1use axum::{
7 extract::{Path, Query, State},
8 http::StatusCode,
9 response::Json,
10};
11use mockforge_core::security::{
12 emit_security_event,
13 risk_assessment::{
14 Impact, Likelihood, Risk, RiskAssessmentEngine, RiskCategory, RiskLevel, TreatmentOption,
15 TreatmentStatus,
16 },
17 EventActor, EventOutcome, EventTarget, SecurityEvent, SecurityEventType,
18};
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use std::sync::Arc;
22use tokio::sync::RwLock;
23use tracing::{error, info};
24use uuid::Uuid;
25
26use crate::handlers::auth_helpers::{extract_user_id_with_fallback, OptionalAuthClaims};
27
/// Shared state handed to every risk-assessment handler via axum's `State` extractor.
#[derive(Clone)]
pub struct RiskAssessmentState {
    /// The risk-assessment engine, shared between handlers behind an async
    /// (`tokio`) read/write lock. Cloning the state only clones the `Arc`.
    pub engine: Arc<RwLock<RiskAssessmentEngine>>,
}
34
/// JSON body accepted by `create_risk`.
///
/// Only `title`, `description`, `category`, `likelihood` and `impact` are
/// passed to `RiskAssessmentEngine::create_risk`.
///
/// NOTE(review): the remaining optional fields are deserialized but
/// `create_risk` does not forward them to the engine, so they are currently
/// not stored. Confirm whether they should be persisted.
#[derive(Debug, Deserialize)]
pub struct CreateRiskRequest {
    /// Title of the risk.
    pub title: String,
    /// Description of the risk.
    pub description: String,
    /// Category the risk is filed under.
    pub category: RiskCategory,
    /// Optional subcategory (not currently persisted; see note above).
    pub subcategory: Option<String>,
    /// Initial likelihood rating for the risk.
    pub likelihood: Likelihood,
    /// Initial impact rating for the risk.
    pub impact: Impact,
    /// Optional threat description (not currently persisted).
    pub threat: Option<String>,
    /// Optional vulnerability description (not currently persisted).
    pub vulnerability: Option<String>,
    /// Optional affected asset (not currently persisted).
    pub asset: Option<String>,
    /// Optional list of controls already in place (not currently persisted).
    pub existing_controls: Option<Vec<String>>,
    /// Optional list of related compliance requirements (not currently persisted).
    pub compliance_requirements: Option<Vec<String>>,
}
61
/// JSON body accepted by `update_risk_assessment`.
///
/// Both fields are optional and are passed to
/// `RiskAssessmentEngine::update_risk_assessment` as-is.
/// NOTE(review): presumably `None` leaves the stored value unchanged — confirm
/// against the engine.
#[derive(Debug, Deserialize)]
pub struct UpdateRiskAssessmentRequest {
    /// New likelihood rating, if it should change.
    pub likelihood: Option<Likelihood>,
    /// New impact rating, if it should change.
    pub impact: Option<Impact>,
}
70
/// JSON body accepted by `update_treatment_plan`; every field is forwarded to
/// `RiskAssessmentEngine::update_treatment_plan`.
#[derive(Debug, Deserialize)]
pub struct UpdateTreatmentPlanRequest {
    /// Chosen treatment option for the risk.
    pub treatment_option: TreatmentOption,
    /// Steps of the treatment plan.
    pub treatment_plan: Vec<String>,
    /// Optional owner responsible for the treatment.
    pub treatment_owner: Option<String>,
    /// Optional deadline, deserialized by chrono's serde support as a UTC timestamp.
    pub treatment_deadline: Option<chrono::DateTime<chrono::Utc>>,
}
83
/// JSON body accepted by `set_residual_risk`; both ratings are forwarded to
/// `RiskAssessmentEngine::set_residual_risk`.
#[derive(Debug, Deserialize)]
pub struct SetResidualRiskRequest {
    /// Residual likelihood rating.
    pub residual_likelihood: Likelihood,
    /// Residual impact rating.
    pub residual_impact: Impact,
}
92
/// JSON response returned by `list_risks`.
#[derive(Debug, Serialize)]
pub struct RiskListResponse {
    /// Risks selected by the request's query filter (all risks when no filter is given).
    pub risks: Vec<Risk>,
    /// Summary from `RiskAssessmentEngine::get_risk_summary`; it is computed
    /// without the request's filter.
    pub summary: mockforge_core::security::risk_assessment::RiskSummary,
}
101
102pub async fn create_risk(
106 State(state): State<RiskAssessmentState>,
107 claims: OptionalAuthClaims,
108 Json(request): Json<CreateRiskRequest>,
109) -> Result<Json<serde_json::Value>, StatusCode> {
110 let created_by = extract_user_id_with_fallback(&claims);
112
113 let engine = state.engine.write().await;
114 let mut risk = engine
115 .create_risk(
116 request.title,
117 request.description,
118 request.category,
119 request.likelihood,
120 request.impact,
121 created_by,
122 )
123 .await
124 .map_err(|e| {
125 error!("Failed to create risk: {}", e);
126 StatusCode::INTERNAL_SERVER_ERROR
127 })?;
128
129 if let Some(subcategory) = request.subcategory {
131 }
134 if let Some(threat) = request.threat {
135 }
137 if let Some(vulnerability) = request.vulnerability {
138 }
140 if let Some(asset) = request.asset {
141 }
143 if let Some(controls) = request.existing_controls {
144 }
146 if let Some(requirements) = request.compliance_requirements {
147 }
149
150 info!("Risk created: {}", risk.risk_id);
151
152 let event = SecurityEvent::new(SecurityEventType::ConfigChanged, None, None)
154 .with_actor(EventActor {
155 user_id: Some(created_by.to_string()),
156 username: None,
157 ip_address: None,
158 user_agent: None,
159 })
160 .with_target(EventTarget {
161 resource_type: Some("risk".to_string()),
162 resource_id: Some(risk.risk_id.clone()),
163 method: None,
164 })
165 .with_outcome(EventOutcome {
166 success: true,
167 reason: Some("Risk created".to_string()),
168 });
169 emit_security_event(event).await;
170
171 Ok(Json(serde_json::to_value(&risk).unwrap()))
172}
173
174pub async fn get_risk(
178 State(state): State<RiskAssessmentState>,
179 Path(risk_id): Path<String>,
180) -> Result<Json<serde_json::Value>, StatusCode> {
181 let engine = state.engine.read().await;
182 let risk = engine
183 .get_risk(&risk_id)
184 .await
185 .map_err(|e| {
186 error!("Failed to get risk: {}", e);
187 StatusCode::INTERNAL_SERVER_ERROR
188 })?
189 .ok_or_else(|| {
190 error!("Risk not found: {}", risk_id);
191 StatusCode::NOT_FOUND
192 })?;
193
194 Ok(Json(serde_json::to_value(&risk).unwrap()))
195}
196
197pub async fn list_risks(
201 State(state): State<RiskAssessmentState>,
202 Query(params): Query<HashMap<String, String>>,
203) -> Result<Json<RiskListResponse>, StatusCode> {
204 let engine = state.engine.read().await;
205
206 let risks = if let Some(level_str) = params.get("level") {
207 let level = match level_str.as_str() {
208 "critical" => RiskLevel::Critical,
209 "high" => RiskLevel::High,
210 "medium" => RiskLevel::Medium,
211 "low" => RiskLevel::Low,
212 _ => return Err(StatusCode::BAD_REQUEST),
213 };
214 engine.get_risks_by_level(level).await.map_err(|e| {
215 error!("Failed to get risks by level: {}", e);
216 StatusCode::INTERNAL_SERVER_ERROR
217 })?
218 } else if let Some(category_str) = params.get("category") {
219 let category = match category_str.as_str() {
220 "technical" => RiskCategory::Technical,
221 "operational" => RiskCategory::Operational,
222 "compliance" => RiskCategory::Compliance,
223 "business" => RiskCategory::Business,
224 _ => return Err(StatusCode::BAD_REQUEST),
225 };
226 engine.get_risks_by_category(category).await.map_err(|e| {
227 error!("Failed to get risks by category: {}", e);
228 StatusCode::INTERNAL_SERVER_ERROR
229 })?
230 } else if let Some(status_str) = params.get("treatment_status") {
231 let status = match status_str.as_str() {
232 "not_started" => TreatmentStatus::NotStarted,
233 "in_progress" => TreatmentStatus::InProgress,
234 "completed" => TreatmentStatus::Completed,
235 "on_hold" => TreatmentStatus::OnHold,
236 _ => return Err(StatusCode::BAD_REQUEST),
237 };
238 engine.get_risks_by_treatment_status(status).await.map_err(|e| {
239 error!("Failed to get risks by treatment status: {}", e);
240 StatusCode::INTERNAL_SERVER_ERROR
241 })?
242 } else {
243 engine.get_all_risks().await.map_err(|e| {
244 error!("Failed to get all risks: {}", e);
245 StatusCode::INTERNAL_SERVER_ERROR
246 })?
247 };
248
249 let summary = engine.get_risk_summary().await.map_err(|e| {
250 error!("Failed to get risk summary: {}", e);
251 StatusCode::INTERNAL_SERVER_ERROR
252 })?;
253
254 Ok(Json(RiskListResponse { risks, summary }))
255}
256
257pub async fn update_risk_assessment(
261 State(state): State<RiskAssessmentState>,
262 Path(risk_id): Path<String>,
263 claims: OptionalAuthClaims,
264 Json(request): Json<UpdateRiskAssessmentRequest>,
265) -> Result<Json<serde_json::Value>, StatusCode> {
266 let updated_by = extract_user_id_with_fallback(&claims);
268
269 let engine = state.engine.write().await;
270 engine
271 .update_risk_assessment(&risk_id, request.likelihood, request.impact)
272 .await
273 .map_err(|e| {
274 error!("Failed to update risk assessment: {}", e);
275 StatusCode::BAD_REQUEST
276 })?;
277
278 info!("Risk assessment updated: {}", risk_id);
279
280 let event = SecurityEvent::new(SecurityEventType::ConfigChanged, None, None)
282 .with_actor(EventActor {
283 user_id: Some(updated_by.to_string()),
284 username: None,
285 ip_address: None,
286 user_agent: None,
287 })
288 .with_target(EventTarget {
289 resource_type: Some("risk".to_string()),
290 resource_id: Some(risk_id.clone()),
291 method: None,
292 })
293 .with_outcome(EventOutcome {
294 success: true,
295 reason: Some("Risk assessment updated".to_string()),
296 });
297 emit_security_event(event).await;
298
299 Ok(Json(serde_json::json!({
300 "risk_id": risk_id,
301 "status": "updated"
302 })))
303}
304
305pub async fn update_treatment_plan(
309 State(state): State<RiskAssessmentState>,
310 Path(risk_id): Path<String>,
311 claims: OptionalAuthClaims,
312 Json(request): Json<UpdateTreatmentPlanRequest>,
313) -> Result<Json<serde_json::Value>, StatusCode> {
314 let updated_by = extract_user_id_with_fallback(&claims);
316
317 let engine = state.engine.write().await;
318 engine
319 .update_treatment_plan(
320 &risk_id,
321 request.treatment_option,
322 request.treatment_plan,
323 request.treatment_owner,
324 request.treatment_deadline,
325 )
326 .await
327 .map_err(|e| {
328 error!("Failed to update treatment plan: {}", e);
329 StatusCode::BAD_REQUEST
330 })?;
331
332 info!("Treatment plan updated: {}", risk_id);
333
334 let event = SecurityEvent::new(SecurityEventType::ConfigChanged, None, None)
336 .with_actor(EventActor {
337 user_id: Some(updated_by.to_string()),
338 username: None,
339 ip_address: None,
340 user_agent: None,
341 })
342 .with_target(EventTarget {
343 resource_type: Some("risk".to_string()),
344 resource_id: Some(risk_id.clone()),
345 method: None,
346 })
347 .with_outcome(EventOutcome {
348 success: true,
349 reason: Some("Treatment plan updated".to_string()),
350 });
351 emit_security_event(event).await;
352
353 Ok(Json(serde_json::json!({
354 "risk_id": risk_id,
355 "status": "updated"
356 })))
357}
358
359pub async fn update_treatment_status(
363 State(state): State<RiskAssessmentState>,
364 Path(risk_id): Path<String>,
365 claims: OptionalAuthClaims,
366 Json(request): Json<serde_json::Value>,
367) -> Result<Json<serde_json::Value>, StatusCode> {
368 let _updated_by = extract_user_id_with_fallback(&claims);
370
371 let status_str = request
372 .get("status")
373 .and_then(|v| v.as_str())
374 .ok_or_else(|| StatusCode::BAD_REQUEST)?;
375
376 let status = match status_str {
377 "not_started" => TreatmentStatus::NotStarted,
378 "in_progress" => TreatmentStatus::InProgress,
379 "completed" => TreatmentStatus::Completed,
380 "on_hold" => TreatmentStatus::OnHold,
381 _ => return Err(StatusCode::BAD_REQUEST),
382 };
383
384 let engine = state.engine.write().await;
385 engine.update_treatment_status(&risk_id, status).await.map_err(|e| {
386 error!("Failed to update treatment status: {}", e);
387 StatusCode::BAD_REQUEST
388 })?;
389
390 info!("Treatment status updated: {}", risk_id);
391
392 Ok(Json(serde_json::json!({
393 "risk_id": risk_id,
394 "status": "updated"
395 })))
396}
397
398pub async fn set_residual_risk(
402 State(state): State<RiskAssessmentState>,
403 Path(risk_id): Path<String>,
404 claims: OptionalAuthClaims,
405 Json(request): Json<SetResidualRiskRequest>,
406) -> Result<Json<serde_json::Value>, StatusCode> {
407 let _updated_by = extract_user_id_with_fallback(&claims);
409
410 let engine = state.engine.write().await;
411 engine
412 .set_residual_risk(&risk_id, request.residual_likelihood, request.residual_impact)
413 .await
414 .map_err(|e| {
415 error!("Failed to set residual risk: {}", e);
416 StatusCode::BAD_REQUEST
417 })?;
418
419 info!("Residual risk set: {}", risk_id);
420
421 Ok(Json(serde_json::json!({
422 "risk_id": risk_id,
423 "status": "updated"
424 })))
425}
426
427pub async fn review_risk(
431 State(state): State<RiskAssessmentState>,
432 Path(risk_id): Path<String>,
433 claims: OptionalAuthClaims,
434) -> Result<Json<serde_json::Value>, StatusCode> {
435 let reviewed_by = extract_user_id_with_fallback(&claims);
437
438 let engine = state.engine.write().await;
439 engine.review_risk(&risk_id, reviewed_by).await.map_err(|e| {
440 error!("Failed to review risk: {}", e);
441 StatusCode::BAD_REQUEST
442 })?;
443
444 info!("Risk reviewed: {}", risk_id);
445
446 Ok(Json(serde_json::json!({
447 "risk_id": risk_id,
448 "status": "reviewed"
449 })))
450}
451
452pub async fn get_risks_due_for_review(
456 State(state): State<RiskAssessmentState>,
457) -> Result<Json<serde_json::Value>, StatusCode> {
458 let engine = state.engine.read().await;
459 let risks = engine.get_risks_due_for_review().await.map_err(|e| {
460 error!("Failed to get risks due for review: {}", e);
461 StatusCode::INTERNAL_SERVER_ERROR
462 })?;
463
464 Ok(Json(serde_json::to_value(&risks).unwrap()))
465}
466
467pub async fn get_risk_summary(
471 State(state): State<RiskAssessmentState>,
472) -> Result<Json<serde_json::Value>, StatusCode> {
473 let engine = state.engine.read().await;
474 let summary = engine.get_risk_summary().await.map_err(|e| {
475 error!("Failed to get risk summary: {}", e);
476 StatusCode::INTERNAL_SERVER_ERROR
477 })?;
478
479 Ok(Json(serde_json::to_value(&summary).unwrap()))
480}
481
482pub fn risk_assessment_router(state: RiskAssessmentState) -> axum::Router {
484 use axum::routing::{get, patch, post, put};
485
486 axum::Router::new()
487 .route("/risks", get(list_risks))
488 .route("/risks", post(create_risk))
489 .route("/risks/{risk_id}", get(get_risk))
490 .route("/risks/{risk_id}/assessment", put(update_risk_assessment))
491 .route("/risks/{risk_id}/treatment", put(update_treatment_plan))
492 .route("/risks/{risk_id}/treatment/status", patch(update_treatment_status))
493 .route("/risks/{risk_id}/residual", put(set_residual_risk))
494 .route("/risks/{risk_id}/review", post(review_risk))
495 .route("/risks/due-for-review", get(get_risks_due_for_review))
496 .route("/risks/summary", get(get_risk_summary))
497 .with_state(state)
498}