1use axum::{
7 extract::{Path, Query, State},
8 http::StatusCode,
9 response::Json,
10};
11use mockforge_core::security::{
12 emit_security_event,
13 risk_assessment::{
14 Impact, Likelihood, Risk, RiskAssessmentEngine, RiskCategory, RiskLevel, TreatmentOption,
15 TreatmentStatus,
16 },
17 EventActor, EventOutcome, EventTarget, SecurityEvent, SecurityEventType,
18};
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use std::sync::Arc;
22use tokio::sync::RwLock;
23use tracing::{error, info};
24
25use crate::handlers::auth_helpers::{extract_user_id_with_fallback, OptionalAuthClaims};
26
/// Shared state for the risk assessment HTTP handlers.
///
/// The engine sits behind `Arc<RwLock<_>>`, so cloning this state (as axum does
/// for every request via `#[derive(Clone)]`) shares a single engine instance.
#[derive(Clone)]
pub struct RiskAssessmentState {
    /// Risk assessment engine. The handlers in this module take the read lock
    /// for queries and the write lock for mutating operations.
    pub engine: Arc<RwLock<RiskAssessmentEngine>>,
}
33
/// JSON body for creating a new risk.
///
/// `title`, `description`, `category`, `likelihood` and `impact` are required
/// and are passed to `RiskAssessmentEngine::create_risk`. The `Option` fields
/// may be omitted from the JSON body (serde deserializes a missing `Option`
/// field as `None`).
#[derive(Debug, Deserialize)]
pub struct CreateRiskRequest {
    /// Risk title.
    pub title: String,
    /// Free-text description of the risk.
    pub description: String,
    /// Category the risk belongs to.
    pub category: RiskCategory,
    /// Optional finer-grained category label.
    ///
    /// NOTE(review): this and the other `Option` fields below are accepted but
    /// are not applied to the stored risk by the `create_risk` handler in this
    /// file — confirm whether that is intended.
    pub subcategory: Option<String>,
    /// Assessed likelihood of the risk.
    pub likelihood: Likelihood,
    /// Assessed impact of the risk.
    pub impact: Impact,
    /// Optional description of the threat.
    pub threat: Option<String>,
    /// Optional description of the vulnerability.
    pub vulnerability: Option<String>,
    /// Optional identifier/description of the affected asset.
    pub asset: Option<String>,
    /// Optional list of controls already in place.
    pub existing_controls: Option<Vec<String>>,
    /// Optional list of related compliance requirements.
    pub compliance_requirements: Option<Vec<String>>,
}
60
/// JSON body for re-assessing an existing risk.
///
/// Both fields are optional and are forwarded as-is to
/// `RiskAssessmentEngine::update_risk_assessment`; presumably `None` means
/// "keep the current value" — TODO confirm against the engine.
#[derive(Debug, Deserialize)]
pub struct UpdateRiskAssessmentRequest {
    /// New likelihood, if it should change.
    pub likelihood: Option<Likelihood>,
    /// New impact, if it should change.
    pub impact: Option<Impact>,
}
69
/// JSON body for setting a risk's treatment plan.
///
/// All fields are forwarded to `RiskAssessmentEngine::update_treatment_plan`.
#[derive(Debug, Deserialize)]
pub struct UpdateTreatmentPlanRequest {
    /// Chosen treatment option for the risk.
    pub treatment_option: TreatmentOption,
    /// Ordered list of treatment steps/actions.
    pub treatment_plan: Vec<String>,
    /// Optional owner responsible for the treatment.
    pub treatment_owner: Option<String>,
    /// Optional deadline in UTC. With chrono's serde support this is expected
    /// as an RFC 3339 timestamp string — confirm the enabled chrono features.
    pub treatment_deadline: Option<chrono::DateTime<chrono::Utc>>,
}
82
/// JSON body for recording a risk's residual likelihood and impact.
///
/// Both values are forwarded to `RiskAssessmentEngine::set_residual_risk`;
/// presumably they describe the risk remaining after treatment.
#[derive(Debug, Deserialize)]
pub struct SetResidualRiskRequest {
    /// Residual likelihood.
    pub residual_likelihood: Likelihood,
    /// Residual impact.
    pub residual_impact: Impact,
}
91
/// Response body for the risk listing endpoint.
#[derive(Debug, Serialize)]
pub struct RiskListResponse {
    /// Risks matching the request's filter (or all risks when unfiltered).
    pub risks: Vec<Risk>,
    /// Engine-wide risk summary; `list_risks` computes it without applying the
    /// request's filter.
    pub summary: mockforge_core::security::risk_assessment::RiskSummary,
}
100
101pub async fn create_risk(
105 State(state): State<RiskAssessmentState>,
106 claims: OptionalAuthClaims,
107 Json(request): Json<CreateRiskRequest>,
108) -> Result<Json<serde_json::Value>, StatusCode> {
109 let created_by = extract_user_id_with_fallback(&claims);
111
112 let engine = state.engine.write().await;
113 let risk = engine
114 .create_risk(
115 request.title,
116 request.description,
117 request.category,
118 request.likelihood,
119 request.impact,
120 created_by,
121 )
122 .await
123 .map_err(|e| {
124 error!("Failed to create risk: {}", e);
125 StatusCode::INTERNAL_SERVER_ERROR
126 })?;
127
128 if let Some(subcategory) = request.subcategory {
130 }
133 if let Some(threat) = request.threat {
134 }
136 if let Some(vulnerability) = request.vulnerability {
137 }
139 if let Some(asset) = request.asset {
140 }
142 if let Some(controls) = request.existing_controls {
143 }
145 if let Some(requirements) = request.compliance_requirements {
146 }
148
149 info!("Risk created: {}", risk.risk_id);
150
151 let event = SecurityEvent::new(SecurityEventType::ConfigChanged, None, None)
153 .with_actor(EventActor {
154 user_id: Some(created_by.to_string()),
155 username: None,
156 ip_address: None,
157 user_agent: None,
158 })
159 .with_target(EventTarget {
160 resource_type: Some("risk".to_string()),
161 resource_id: Some(risk.risk_id.clone()),
162 method: None,
163 })
164 .with_outcome(EventOutcome {
165 success: true,
166 reason: Some("Risk created".to_string()),
167 });
168 emit_security_event(event).await;
169
170 Ok(Json(serde_json::to_value(&risk).unwrap()))
171}
172
173pub async fn get_risk(
177 State(state): State<RiskAssessmentState>,
178 Path(risk_id): Path<String>,
179) -> Result<Json<serde_json::Value>, StatusCode> {
180 let engine = state.engine.read().await;
181 let risk = engine
182 .get_risk(&risk_id)
183 .await
184 .map_err(|e| {
185 error!("Failed to get risk: {}", e);
186 StatusCode::INTERNAL_SERVER_ERROR
187 })?
188 .ok_or_else(|| {
189 error!("Risk not found: {}", risk_id);
190 StatusCode::NOT_FOUND
191 })?;
192
193 Ok(Json(serde_json::to_value(&risk).unwrap()))
194}
195
196pub async fn list_risks(
200 State(state): State<RiskAssessmentState>,
201 Query(params): Query<HashMap<String, String>>,
202) -> Result<Json<RiskListResponse>, StatusCode> {
203 let engine = state.engine.read().await;
204
205 let risks = if let Some(level_str) = params.get("level") {
206 let level = match level_str.as_str() {
207 "critical" => RiskLevel::Critical,
208 "high" => RiskLevel::High,
209 "medium" => RiskLevel::Medium,
210 "low" => RiskLevel::Low,
211 _ => return Err(StatusCode::BAD_REQUEST),
212 };
213 engine.get_risks_by_level(level).await.map_err(|e| {
214 error!("Failed to get risks by level: {}", e);
215 StatusCode::INTERNAL_SERVER_ERROR
216 })?
217 } else if let Some(category_str) = params.get("category") {
218 let category = match category_str.as_str() {
219 "technical" => RiskCategory::Technical,
220 "operational" => RiskCategory::Operational,
221 "compliance" => RiskCategory::Compliance,
222 "business" => RiskCategory::Business,
223 _ => return Err(StatusCode::BAD_REQUEST),
224 };
225 engine.get_risks_by_category(category).await.map_err(|e| {
226 error!("Failed to get risks by category: {}", e);
227 StatusCode::INTERNAL_SERVER_ERROR
228 })?
229 } else if let Some(status_str) = params.get("treatment_status") {
230 let status = match status_str.as_str() {
231 "not_started" => TreatmentStatus::NotStarted,
232 "in_progress" => TreatmentStatus::InProgress,
233 "completed" => TreatmentStatus::Completed,
234 "on_hold" => TreatmentStatus::OnHold,
235 _ => return Err(StatusCode::BAD_REQUEST),
236 };
237 engine.get_risks_by_treatment_status(status).await.map_err(|e| {
238 error!("Failed to get risks by treatment status: {}", e);
239 StatusCode::INTERNAL_SERVER_ERROR
240 })?
241 } else {
242 engine.get_all_risks().await.map_err(|e| {
243 error!("Failed to get all risks: {}", e);
244 StatusCode::INTERNAL_SERVER_ERROR
245 })?
246 };
247
248 let summary = engine.get_risk_summary().await.map_err(|e| {
249 error!("Failed to get risk summary: {}", e);
250 StatusCode::INTERNAL_SERVER_ERROR
251 })?;
252
253 Ok(Json(RiskListResponse { risks, summary }))
254}
255
256pub async fn update_risk_assessment(
260 State(state): State<RiskAssessmentState>,
261 Path(risk_id): Path<String>,
262 claims: OptionalAuthClaims,
263 Json(request): Json<UpdateRiskAssessmentRequest>,
264) -> Result<Json<serde_json::Value>, StatusCode> {
265 let updated_by = extract_user_id_with_fallback(&claims);
267
268 let engine = state.engine.write().await;
269 engine
270 .update_risk_assessment(&risk_id, request.likelihood, request.impact)
271 .await
272 .map_err(|e| {
273 error!("Failed to update risk assessment: {}", e);
274 StatusCode::BAD_REQUEST
275 })?;
276
277 info!("Risk assessment updated: {}", risk_id);
278
279 let event = SecurityEvent::new(SecurityEventType::ConfigChanged, None, None)
281 .with_actor(EventActor {
282 user_id: Some(updated_by.to_string()),
283 username: None,
284 ip_address: None,
285 user_agent: None,
286 })
287 .with_target(EventTarget {
288 resource_type: Some("risk".to_string()),
289 resource_id: Some(risk_id.clone()),
290 method: None,
291 })
292 .with_outcome(EventOutcome {
293 success: true,
294 reason: Some("Risk assessment updated".to_string()),
295 });
296 emit_security_event(event).await;
297
298 Ok(Json(serde_json::json!({
299 "risk_id": risk_id,
300 "status": "updated"
301 })))
302}
303
304pub async fn update_treatment_plan(
308 State(state): State<RiskAssessmentState>,
309 Path(risk_id): Path<String>,
310 claims: OptionalAuthClaims,
311 Json(request): Json<UpdateTreatmentPlanRequest>,
312) -> Result<Json<serde_json::Value>, StatusCode> {
313 let updated_by = extract_user_id_with_fallback(&claims);
315
316 let engine = state.engine.write().await;
317 engine
318 .update_treatment_plan(
319 &risk_id,
320 request.treatment_option,
321 request.treatment_plan,
322 request.treatment_owner,
323 request.treatment_deadline,
324 )
325 .await
326 .map_err(|e| {
327 error!("Failed to update treatment plan: {}", e);
328 StatusCode::BAD_REQUEST
329 })?;
330
331 info!("Treatment plan updated: {}", risk_id);
332
333 let event = SecurityEvent::new(SecurityEventType::ConfigChanged, None, None)
335 .with_actor(EventActor {
336 user_id: Some(updated_by.to_string()),
337 username: None,
338 ip_address: None,
339 user_agent: None,
340 })
341 .with_target(EventTarget {
342 resource_type: Some("risk".to_string()),
343 resource_id: Some(risk_id.clone()),
344 method: None,
345 })
346 .with_outcome(EventOutcome {
347 success: true,
348 reason: Some("Treatment plan updated".to_string()),
349 });
350 emit_security_event(event).await;
351
352 Ok(Json(serde_json::json!({
353 "risk_id": risk_id,
354 "status": "updated"
355 })))
356}
357
358pub async fn update_treatment_status(
362 State(state): State<RiskAssessmentState>,
363 Path(risk_id): Path<String>,
364 claims: OptionalAuthClaims,
365 Json(request): Json<serde_json::Value>,
366) -> Result<Json<serde_json::Value>, StatusCode> {
367 let _updated_by = extract_user_id_with_fallback(&claims);
369
370 let status_str =
371 request.get("status").and_then(|v| v.as_str()).ok_or(StatusCode::BAD_REQUEST)?;
372
373 let status = match status_str {
374 "not_started" => TreatmentStatus::NotStarted,
375 "in_progress" => TreatmentStatus::InProgress,
376 "completed" => TreatmentStatus::Completed,
377 "on_hold" => TreatmentStatus::OnHold,
378 _ => return Err(StatusCode::BAD_REQUEST),
379 };
380
381 let engine = state.engine.write().await;
382 engine.update_treatment_status(&risk_id, status).await.map_err(|e| {
383 error!("Failed to update treatment status: {}", e);
384 StatusCode::BAD_REQUEST
385 })?;
386
387 info!("Treatment status updated: {}", risk_id);
388
389 Ok(Json(serde_json::json!({
390 "risk_id": risk_id,
391 "status": "updated"
392 })))
393}
394
395pub async fn set_residual_risk(
399 State(state): State<RiskAssessmentState>,
400 Path(risk_id): Path<String>,
401 claims: OptionalAuthClaims,
402 Json(request): Json<SetResidualRiskRequest>,
403) -> Result<Json<serde_json::Value>, StatusCode> {
404 let _updated_by = extract_user_id_with_fallback(&claims);
406
407 let engine = state.engine.write().await;
408 engine
409 .set_residual_risk(&risk_id, request.residual_likelihood, request.residual_impact)
410 .await
411 .map_err(|e| {
412 error!("Failed to set residual risk: {}", e);
413 StatusCode::BAD_REQUEST
414 })?;
415
416 info!("Residual risk set: {}", risk_id);
417
418 Ok(Json(serde_json::json!({
419 "risk_id": risk_id,
420 "status": "updated"
421 })))
422}
423
424pub async fn review_risk(
428 State(state): State<RiskAssessmentState>,
429 Path(risk_id): Path<String>,
430 claims: OptionalAuthClaims,
431) -> Result<Json<serde_json::Value>, StatusCode> {
432 let reviewed_by = extract_user_id_with_fallback(&claims);
434
435 let engine = state.engine.write().await;
436 engine.review_risk(&risk_id, reviewed_by).await.map_err(|e| {
437 error!("Failed to review risk: {}", e);
438 StatusCode::BAD_REQUEST
439 })?;
440
441 info!("Risk reviewed: {}", risk_id);
442
443 Ok(Json(serde_json::json!({
444 "risk_id": risk_id,
445 "status": "reviewed"
446 })))
447}
448
449pub async fn get_risks_due_for_review(
453 State(state): State<RiskAssessmentState>,
454) -> Result<Json<serde_json::Value>, StatusCode> {
455 let engine = state.engine.read().await;
456 let risks = engine.get_risks_due_for_review().await.map_err(|e| {
457 error!("Failed to get risks due for review: {}", e);
458 StatusCode::INTERNAL_SERVER_ERROR
459 })?;
460
461 Ok(Json(serde_json::to_value(&risks).unwrap()))
462}
463
464pub async fn get_risk_summary(
468 State(state): State<RiskAssessmentState>,
469) -> Result<Json<serde_json::Value>, StatusCode> {
470 let engine = state.engine.read().await;
471 let summary = engine.get_risk_summary().await.map_err(|e| {
472 error!("Failed to get risk summary: {}", e);
473 StatusCode::INTERNAL_SERVER_ERROR
474 })?;
475
476 Ok(Json(serde_json::to_value(&summary).unwrap()))
477}
478
479pub fn risk_assessment_router(state: RiskAssessmentState) -> axum::Router {
481 use axum::routing::{get, patch, post, put};
482
483 axum::Router::new()
484 .route("/risks", get(list_risks))
485 .route("/risks", post(create_risk))
486 .route("/risks/{risk_id}", get(get_risk))
487 .route("/risks/{risk_id}/assessment", put(update_risk_assessment))
488 .route("/risks/{risk_id}/treatment", put(update_treatment_plan))
489 .route("/risks/{risk_id}/treatment/status", patch(update_treatment_status))
490 .route("/risks/{risk_id}/residual", put(set_residual_risk))
491 .route("/risks/{risk_id}/review", post(review_risk))
492 .route("/risks/due-for-review", get(get_risks_due_for_review))
493 .route("/risks/summary", get(get_risk_summary))
494 .with_state(state)
495}