1use axum::{
11 extract::{Path, Query, State},
12 http::StatusCode,
13 response::Json,
14};
15use mockforge_core::security::{
16 emit_security_event,
17 risk_assessment::{
18 Impact, Likelihood, Risk, RiskAssessmentEngine, RiskCategory, RiskLevel, TreatmentOption,
19 TreatmentStatus,
20 },
21 EventActor, EventOutcome, EventTarget, SecurityEvent, SecurityEventType,
22};
23use serde::{Deserialize, Serialize};
24use std::collections::HashMap;
25use std::sync::Arc;
26use tokio::sync::RwLock;
27use tracing::{error, info};
28
29use crate::handlers::auth_helpers::{extract_user_id_with_fallback, OptionalAuthClaims};
30
31#[derive(Clone)]
33pub struct RiskAssessmentState {
34 pub engine: Arc<RwLock<RiskAssessmentEngine>>,
36}
37
38#[derive(Debug, Deserialize)]
40pub struct CreateRiskRequest {
41 pub title: String,
43 pub description: String,
45 pub category: RiskCategory,
47 pub subcategory: Option<String>,
49 pub likelihood: Likelihood,
51 pub impact: Impact,
53 pub threat: Option<String>,
55 pub vulnerability: Option<String>,
57 pub asset: Option<String>,
59 pub existing_controls: Option<Vec<String>>,
61 pub compliance_requirements: Option<Vec<String>>,
63}
64
65#[derive(Debug, Deserialize)]
67pub struct UpdateRiskAssessmentRequest {
68 pub likelihood: Option<Likelihood>,
70 pub impact: Option<Impact>,
72}
73
74#[derive(Debug, Deserialize)]
76pub struct UpdateTreatmentPlanRequest {
77 pub treatment_option: TreatmentOption,
79 pub treatment_plan: Vec<String>,
81 pub treatment_owner: Option<String>,
83 pub treatment_deadline: Option<chrono::DateTime<chrono::Utc>>,
85}
86
87#[derive(Debug, Deserialize)]
89pub struct SetResidualRiskRequest {
90 pub residual_likelihood: Likelihood,
92 pub residual_impact: Impact,
94}
95
96#[derive(Debug, Serialize)]
98pub struct RiskListResponse {
99 pub risks: Vec<Risk>,
101 pub summary: mockforge_core::security::risk_assessment::RiskSummary,
103}
104
105pub async fn create_risk(
109 State(state): State<RiskAssessmentState>,
110 claims: OptionalAuthClaims,
111 Json(request): Json<CreateRiskRequest>,
112) -> Result<Json<serde_json::Value>, StatusCode> {
113 let created_by = extract_user_id_with_fallback(&claims);
115
116 let engine = state.engine.write().await;
117 let risk = engine
118 .create_risk(
119 request.title,
120 request.description,
121 request.category,
122 request.likelihood,
123 request.impact,
124 created_by,
125 )
126 .await
127 .map_err(|e| {
128 error!("Failed to create risk: {}", e);
129 StatusCode::INTERNAL_SERVER_ERROR
130 })?;
131
132 let mut risk = risk;
134 let mut needs_update = false;
135
136 if let Some(subcategory) = request.subcategory {
137 risk.subcategory = Some(subcategory);
138 needs_update = true;
139 }
140 if let Some(threat) = request.threat {
141 risk.threat = Some(threat);
142 needs_update = true;
143 }
144 if let Some(vulnerability) = request.vulnerability {
145 risk.vulnerability = Some(vulnerability);
146 needs_update = true;
147 }
148 if let Some(asset) = request.asset {
149 risk.asset = Some(asset);
150 needs_update = true;
151 }
152 if let Some(controls) = request.existing_controls {
153 risk.existing_controls = controls;
154 needs_update = true;
155 }
156 if let Some(requirements) = request.compliance_requirements {
157 risk.compliance_requirements = requirements;
158 needs_update = true;
159 }
160
161 if needs_update {
162 let risk_id = risk.risk_id.clone();
163 engine.update_risk(&risk_id, risk.clone()).await.map_err(|e| {
164 error!("Failed to update risk with optional fields: {}", e);
165 StatusCode::INTERNAL_SERVER_ERROR
166 })?;
167 }
168
169 info!("Risk created: {}", risk.risk_id);
170
171 let event = SecurityEvent::new(SecurityEventType::ConfigChanged, None, None)
173 .with_actor(EventActor {
174 user_id: Some(created_by.to_string()),
175 username: None,
176 ip_address: None,
177 user_agent: None,
178 })
179 .with_target(EventTarget {
180 resource_type: Some("risk".to_string()),
181 resource_id: Some(risk.risk_id.clone()),
182 method: None,
183 })
184 .with_outcome(EventOutcome {
185 success: true,
186 reason: Some("Risk created".to_string()),
187 });
188 emit_security_event(event).await;
189
190 Ok(Json(serde_json::to_value(&risk).map_err(|e| {
191 error!("Failed to serialize risk: {}", e);
192 StatusCode::INTERNAL_SERVER_ERROR
193 })?))
194}
195
196pub async fn get_risk(
200 State(state): State<RiskAssessmentState>,
201 Path(risk_id): Path<String>,
202) -> Result<Json<serde_json::Value>, StatusCode> {
203 let engine = state.engine.read().await;
204 let risk = engine
205 .get_risk(&risk_id)
206 .await
207 .map_err(|e| {
208 error!("Failed to get risk: {}", e);
209 StatusCode::INTERNAL_SERVER_ERROR
210 })?
211 .ok_or_else(|| {
212 error!("Risk not found: {}", risk_id);
213 StatusCode::NOT_FOUND
214 })?;
215
216 Ok(Json(serde_json::to_value(&risk).map_err(|e| {
217 error!("Failed to serialize risk: {}", e);
218 StatusCode::INTERNAL_SERVER_ERROR
219 })?))
220}
221
222pub async fn list_risks(
226 State(state): State<RiskAssessmentState>,
227 Query(params): Query<HashMap<String, String>>,
228) -> Result<Json<RiskListResponse>, StatusCode> {
229 let engine = state.engine.read().await;
230
231 let risks = if let Some(level_str) = params.get("level") {
232 let level = match level_str.as_str() {
233 "critical" => RiskLevel::Critical,
234 "high" => RiskLevel::High,
235 "medium" => RiskLevel::Medium,
236 "low" => RiskLevel::Low,
237 _ => return Err(StatusCode::BAD_REQUEST),
238 };
239 engine.get_risks_by_level(level).await.map_err(|e| {
240 error!("Failed to get risks by level: {}", e);
241 StatusCode::INTERNAL_SERVER_ERROR
242 })?
243 } else if let Some(category_str) = params.get("category") {
244 let category = match category_str.as_str() {
245 "technical" => RiskCategory::Technical,
246 "operational" => RiskCategory::Operational,
247 "compliance" => RiskCategory::Compliance,
248 "business" => RiskCategory::Business,
249 _ => return Err(StatusCode::BAD_REQUEST),
250 };
251 engine.get_risks_by_category(category).await.map_err(|e| {
252 error!("Failed to get risks by category: {}", e);
253 StatusCode::INTERNAL_SERVER_ERROR
254 })?
255 } else if let Some(status_str) = params.get("treatment_status") {
256 let status = match status_str.as_str() {
257 "not_started" => TreatmentStatus::NotStarted,
258 "in_progress" => TreatmentStatus::InProgress,
259 "completed" => TreatmentStatus::Completed,
260 "on_hold" => TreatmentStatus::OnHold,
261 _ => return Err(StatusCode::BAD_REQUEST),
262 };
263 engine.get_risks_by_treatment_status(status).await.map_err(|e| {
264 error!("Failed to get risks by treatment status: {}", e);
265 StatusCode::INTERNAL_SERVER_ERROR
266 })?
267 } else {
268 engine.get_all_risks().await.map_err(|e| {
269 error!("Failed to get all risks: {}", e);
270 StatusCode::INTERNAL_SERVER_ERROR
271 })?
272 };
273
274 let summary = engine.get_risk_summary().await.map_err(|e| {
275 error!("Failed to get risk summary: {}", e);
276 StatusCode::INTERNAL_SERVER_ERROR
277 })?;
278
279 Ok(Json(RiskListResponse { risks, summary }))
280}
281
282pub async fn update_risk_assessment(
286 State(state): State<RiskAssessmentState>,
287 Path(risk_id): Path<String>,
288 claims: OptionalAuthClaims,
289 Json(request): Json<UpdateRiskAssessmentRequest>,
290) -> Result<Json<serde_json::Value>, StatusCode> {
291 let updated_by = extract_user_id_with_fallback(&claims);
293
294 let engine = state.engine.write().await;
295 engine
296 .update_risk_assessment(&risk_id, request.likelihood, request.impact)
297 .await
298 .map_err(|e| {
299 error!("Failed to update risk assessment: {}", e);
300 StatusCode::BAD_REQUEST
301 })?;
302
303 info!("Risk assessment updated: {}", risk_id);
304
305 let event = SecurityEvent::new(SecurityEventType::ConfigChanged, None, None)
307 .with_actor(EventActor {
308 user_id: Some(updated_by.to_string()),
309 username: None,
310 ip_address: None,
311 user_agent: None,
312 })
313 .with_target(EventTarget {
314 resource_type: Some("risk".to_string()),
315 resource_id: Some(risk_id.clone()),
316 method: None,
317 })
318 .with_outcome(EventOutcome {
319 success: true,
320 reason: Some("Risk assessment updated".to_string()),
321 });
322 emit_security_event(event).await;
323
324 Ok(Json(serde_json::json!({
325 "risk_id": risk_id,
326 "status": "updated"
327 })))
328}
329
330pub async fn update_treatment_plan(
334 State(state): State<RiskAssessmentState>,
335 Path(risk_id): Path<String>,
336 claims: OptionalAuthClaims,
337 Json(request): Json<UpdateTreatmentPlanRequest>,
338) -> Result<Json<serde_json::Value>, StatusCode> {
339 let updated_by = extract_user_id_with_fallback(&claims);
341
342 let engine = state.engine.write().await;
343 engine
344 .update_treatment_plan(
345 &risk_id,
346 request.treatment_option,
347 request.treatment_plan,
348 request.treatment_owner,
349 request.treatment_deadline,
350 )
351 .await
352 .map_err(|e| {
353 error!("Failed to update treatment plan: {}", e);
354 StatusCode::BAD_REQUEST
355 })?;
356
357 info!("Treatment plan updated: {}", risk_id);
358
359 let event = SecurityEvent::new(SecurityEventType::ConfigChanged, None, None)
361 .with_actor(EventActor {
362 user_id: Some(updated_by.to_string()),
363 username: None,
364 ip_address: None,
365 user_agent: None,
366 })
367 .with_target(EventTarget {
368 resource_type: Some("risk".to_string()),
369 resource_id: Some(risk_id.clone()),
370 method: None,
371 })
372 .with_outcome(EventOutcome {
373 success: true,
374 reason: Some("Treatment plan updated".to_string()),
375 });
376 emit_security_event(event).await;
377
378 Ok(Json(serde_json::json!({
379 "risk_id": risk_id,
380 "status": "updated"
381 })))
382}
383
384pub async fn update_treatment_status(
388 State(state): State<RiskAssessmentState>,
389 Path(risk_id): Path<String>,
390 claims: OptionalAuthClaims,
391 Json(request): Json<serde_json::Value>,
392) -> Result<Json<serde_json::Value>, StatusCode> {
393 let _updated_by = extract_user_id_with_fallback(&claims);
395
396 let status_str =
397 request.get("status").and_then(|v| v.as_str()).ok_or(StatusCode::BAD_REQUEST)?;
398
399 let status = match status_str {
400 "not_started" => TreatmentStatus::NotStarted,
401 "in_progress" => TreatmentStatus::InProgress,
402 "completed" => TreatmentStatus::Completed,
403 "on_hold" => TreatmentStatus::OnHold,
404 _ => return Err(StatusCode::BAD_REQUEST),
405 };
406
407 let engine = state.engine.write().await;
408 engine.update_treatment_status(&risk_id, status).await.map_err(|e| {
409 error!("Failed to update treatment status: {}", e);
410 StatusCode::BAD_REQUEST
411 })?;
412
413 info!("Treatment status updated: {}", risk_id);
414
415 Ok(Json(serde_json::json!({
416 "risk_id": risk_id,
417 "status": "updated"
418 })))
419}
420
421pub async fn set_residual_risk(
425 State(state): State<RiskAssessmentState>,
426 Path(risk_id): Path<String>,
427 claims: OptionalAuthClaims,
428 Json(request): Json<SetResidualRiskRequest>,
429) -> Result<Json<serde_json::Value>, StatusCode> {
430 let _updated_by = extract_user_id_with_fallback(&claims);
432
433 let engine = state.engine.write().await;
434 engine
435 .set_residual_risk(&risk_id, request.residual_likelihood, request.residual_impact)
436 .await
437 .map_err(|e| {
438 error!("Failed to set residual risk: {}", e);
439 StatusCode::BAD_REQUEST
440 })?;
441
442 info!("Residual risk set: {}", risk_id);
443
444 Ok(Json(serde_json::json!({
445 "risk_id": risk_id,
446 "status": "updated"
447 })))
448}
449
450pub async fn review_risk(
454 State(state): State<RiskAssessmentState>,
455 Path(risk_id): Path<String>,
456 claims: OptionalAuthClaims,
457) -> Result<Json<serde_json::Value>, StatusCode> {
458 let reviewed_by = extract_user_id_with_fallback(&claims);
460
461 let engine = state.engine.write().await;
462 engine.review_risk(&risk_id, reviewed_by).await.map_err(|e| {
463 error!("Failed to review risk: {}", e);
464 StatusCode::BAD_REQUEST
465 })?;
466
467 info!("Risk reviewed: {}", risk_id);
468
469 Ok(Json(serde_json::json!({
470 "risk_id": risk_id,
471 "status": "reviewed"
472 })))
473}
474
475pub async fn get_risks_due_for_review(
479 State(state): State<RiskAssessmentState>,
480) -> Result<Json<serde_json::Value>, StatusCode> {
481 let engine = state.engine.read().await;
482 let risks = engine.get_risks_due_for_review().await.map_err(|e| {
483 error!("Failed to get risks due for review: {}", e);
484 StatusCode::INTERNAL_SERVER_ERROR
485 })?;
486
487 Ok(Json(serde_json::to_value(&risks).map_err(|e| {
488 error!("Failed to serialize risks: {}", e);
489 StatusCode::INTERNAL_SERVER_ERROR
490 })?))
491}
492
493pub async fn get_risk_summary(
497 State(state): State<RiskAssessmentState>,
498) -> Result<Json<serde_json::Value>, StatusCode> {
499 let engine = state.engine.read().await;
500 let summary = engine.get_risk_summary().await.map_err(|e| {
501 error!("Failed to get risk summary: {}", e);
502 StatusCode::INTERNAL_SERVER_ERROR
503 })?;
504
505 Ok(Json(serde_json::to_value(&summary).map_err(|e| {
506 error!("Failed to serialize summary: {}", e);
507 StatusCode::INTERNAL_SERVER_ERROR
508 })?))
509}
510
511pub fn risk_assessment_router(state: RiskAssessmentState) -> axum::Router {
513 use axum::routing::{get, patch, post, put};
514
515 axum::Router::new()
516 .route("/risks", get(list_risks))
517 .route("/risks", post(create_risk))
518 .route("/risks/{risk_id}", get(get_risk))
519 .route("/risks/{risk_id}/assessment", put(update_risk_assessment))
520 .route("/risks/{risk_id}/treatment", put(update_treatment_plan))
521 .route("/risks/{risk_id}/treatment/status", patch(update_treatment_status))
522 .route("/risks/{risk_id}/residual", put(set_residual_risk))
523 .route("/risks/{risk_id}/review", post(review_risk))
524 .route("/risks/due-for-review", get(get_risks_due_for_review))
525 .route("/risks/summary", get(get_risk_summary))
526 .with_state(state)
527}