1use axum::{
7 extract::{Path, Query, State},
8 http::StatusCode,
9 response::Json,
10};
11use mockforge_core::security::{
12 emit_security_event,
13 risk_assessment::{
14 Impact, Likelihood, Risk, RiskAssessmentEngine, RiskCategory, RiskLevel, TreatmentOption,
15 TreatmentStatus,
16 },
17 EventActor, EventOutcome, EventTarget, SecurityEvent, SecurityEventType,
18};
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use std::sync::Arc;
22use tokio::sync::RwLock;
23use tracing::{error, info};
24
25use crate::handlers::auth_helpers::{extract_user_id_with_fallback, OptionalAuthClaims};
26
27#[derive(Clone)]
29pub struct RiskAssessmentState {
30 pub engine: Arc<RwLock<RiskAssessmentEngine>>,
32}
33
34#[derive(Debug, Deserialize)]
36pub struct CreateRiskRequest {
37 pub title: String,
39 pub description: String,
41 pub category: RiskCategory,
43 pub subcategory: Option<String>,
45 pub likelihood: Likelihood,
47 pub impact: Impact,
49 pub threat: Option<String>,
51 pub vulnerability: Option<String>,
53 pub asset: Option<String>,
55 pub existing_controls: Option<Vec<String>>,
57 pub compliance_requirements: Option<Vec<String>>,
59}
60
61#[derive(Debug, Deserialize)]
63pub struct UpdateRiskAssessmentRequest {
64 pub likelihood: Option<Likelihood>,
66 pub impact: Option<Impact>,
68}
69
70#[derive(Debug, Deserialize)]
72pub struct UpdateTreatmentPlanRequest {
73 pub treatment_option: TreatmentOption,
75 pub treatment_plan: Vec<String>,
77 pub treatment_owner: Option<String>,
79 pub treatment_deadline: Option<chrono::DateTime<chrono::Utc>>,
81}
82
83#[derive(Debug, Deserialize)]
85pub struct SetResidualRiskRequest {
86 pub residual_likelihood: Likelihood,
88 pub residual_impact: Impact,
90}
91
92#[derive(Debug, Serialize)]
94pub struct RiskListResponse {
95 pub risks: Vec<Risk>,
97 pub summary: mockforge_core::security::risk_assessment::RiskSummary,
99}
100
101pub async fn create_risk(
105 State(state): State<RiskAssessmentState>,
106 claims: OptionalAuthClaims,
107 Json(request): Json<CreateRiskRequest>,
108) -> Result<Json<serde_json::Value>, StatusCode> {
109 let created_by = extract_user_id_with_fallback(&claims);
111
112 let engine = state.engine.write().await;
113 let risk = engine
114 .create_risk(
115 request.title,
116 request.description,
117 request.category,
118 request.likelihood,
119 request.impact,
120 created_by,
121 )
122 .await
123 .map_err(|e| {
124 error!("Failed to create risk: {}", e);
125 StatusCode::INTERNAL_SERVER_ERROR
126 })?;
127
128 let mut risk = risk;
130 let mut needs_update = false;
131
132 if let Some(subcategory) = request.subcategory {
133 risk.subcategory = Some(subcategory);
134 needs_update = true;
135 }
136 if let Some(threat) = request.threat {
137 risk.threat = Some(threat);
138 needs_update = true;
139 }
140 if let Some(vulnerability) = request.vulnerability {
141 risk.vulnerability = Some(vulnerability);
142 needs_update = true;
143 }
144 if let Some(asset) = request.asset {
145 risk.asset = Some(asset);
146 needs_update = true;
147 }
148 if let Some(controls) = request.existing_controls {
149 risk.existing_controls = controls;
150 needs_update = true;
151 }
152 if let Some(requirements) = request.compliance_requirements {
153 risk.compliance_requirements = requirements;
154 needs_update = true;
155 }
156
157 if needs_update {
158 let risk_id = risk.risk_id.clone();
159 engine.update_risk(&risk_id, risk.clone()).await.map_err(|e| {
160 error!("Failed to update risk with optional fields: {}", e);
161 StatusCode::INTERNAL_SERVER_ERROR
162 })?;
163 }
164
165 info!("Risk created: {}", risk.risk_id);
166
167 let event = SecurityEvent::new(SecurityEventType::ConfigChanged, None, None)
169 .with_actor(EventActor {
170 user_id: Some(created_by.to_string()),
171 username: None,
172 ip_address: None,
173 user_agent: None,
174 })
175 .with_target(EventTarget {
176 resource_type: Some("risk".to_string()),
177 resource_id: Some(risk.risk_id.clone()),
178 method: None,
179 })
180 .with_outcome(EventOutcome {
181 success: true,
182 reason: Some("Risk created".to_string()),
183 });
184 emit_security_event(event).await;
185
186 Ok(Json(serde_json::to_value(&risk).map_err(|e| {
187 error!("Failed to serialize risk: {}", e);
188 StatusCode::INTERNAL_SERVER_ERROR
189 })?))
190}
191
192pub async fn get_risk(
196 State(state): State<RiskAssessmentState>,
197 Path(risk_id): Path<String>,
198) -> Result<Json<serde_json::Value>, StatusCode> {
199 let engine = state.engine.read().await;
200 let risk = engine
201 .get_risk(&risk_id)
202 .await
203 .map_err(|e| {
204 error!("Failed to get risk: {}", e);
205 StatusCode::INTERNAL_SERVER_ERROR
206 })?
207 .ok_or_else(|| {
208 error!("Risk not found: {}", risk_id);
209 StatusCode::NOT_FOUND
210 })?;
211
212 Ok(Json(serde_json::to_value(&risk).map_err(|e| {
213 error!("Failed to serialize risk: {}", e);
214 StatusCode::INTERNAL_SERVER_ERROR
215 })?))
216}
217
218pub async fn list_risks(
222 State(state): State<RiskAssessmentState>,
223 Query(params): Query<HashMap<String, String>>,
224) -> Result<Json<RiskListResponse>, StatusCode> {
225 let engine = state.engine.read().await;
226
227 let risks = if let Some(level_str) = params.get("level") {
228 let level = match level_str.as_str() {
229 "critical" => RiskLevel::Critical,
230 "high" => RiskLevel::High,
231 "medium" => RiskLevel::Medium,
232 "low" => RiskLevel::Low,
233 _ => return Err(StatusCode::BAD_REQUEST),
234 };
235 engine.get_risks_by_level(level).await.map_err(|e| {
236 error!("Failed to get risks by level: {}", e);
237 StatusCode::INTERNAL_SERVER_ERROR
238 })?
239 } else if let Some(category_str) = params.get("category") {
240 let category = match category_str.as_str() {
241 "technical" => RiskCategory::Technical,
242 "operational" => RiskCategory::Operational,
243 "compliance" => RiskCategory::Compliance,
244 "business" => RiskCategory::Business,
245 _ => return Err(StatusCode::BAD_REQUEST),
246 };
247 engine.get_risks_by_category(category).await.map_err(|e| {
248 error!("Failed to get risks by category: {}", e);
249 StatusCode::INTERNAL_SERVER_ERROR
250 })?
251 } else if let Some(status_str) = params.get("treatment_status") {
252 let status = match status_str.as_str() {
253 "not_started" => TreatmentStatus::NotStarted,
254 "in_progress" => TreatmentStatus::InProgress,
255 "completed" => TreatmentStatus::Completed,
256 "on_hold" => TreatmentStatus::OnHold,
257 _ => return Err(StatusCode::BAD_REQUEST),
258 };
259 engine.get_risks_by_treatment_status(status).await.map_err(|e| {
260 error!("Failed to get risks by treatment status: {}", e);
261 StatusCode::INTERNAL_SERVER_ERROR
262 })?
263 } else {
264 engine.get_all_risks().await.map_err(|e| {
265 error!("Failed to get all risks: {}", e);
266 StatusCode::INTERNAL_SERVER_ERROR
267 })?
268 };
269
270 let summary = engine.get_risk_summary().await.map_err(|e| {
271 error!("Failed to get risk summary: {}", e);
272 StatusCode::INTERNAL_SERVER_ERROR
273 })?;
274
275 Ok(Json(RiskListResponse { risks, summary }))
276}
277
278pub async fn update_risk_assessment(
282 State(state): State<RiskAssessmentState>,
283 Path(risk_id): Path<String>,
284 claims: OptionalAuthClaims,
285 Json(request): Json<UpdateRiskAssessmentRequest>,
286) -> Result<Json<serde_json::Value>, StatusCode> {
287 let updated_by = extract_user_id_with_fallback(&claims);
289
290 let engine = state.engine.write().await;
291 engine
292 .update_risk_assessment(&risk_id, request.likelihood, request.impact)
293 .await
294 .map_err(|e| {
295 error!("Failed to update risk assessment: {}", e);
296 StatusCode::BAD_REQUEST
297 })?;
298
299 info!("Risk assessment updated: {}", risk_id);
300
301 let event = SecurityEvent::new(SecurityEventType::ConfigChanged, None, None)
303 .with_actor(EventActor {
304 user_id: Some(updated_by.to_string()),
305 username: None,
306 ip_address: None,
307 user_agent: None,
308 })
309 .with_target(EventTarget {
310 resource_type: Some("risk".to_string()),
311 resource_id: Some(risk_id.clone()),
312 method: None,
313 })
314 .with_outcome(EventOutcome {
315 success: true,
316 reason: Some("Risk assessment updated".to_string()),
317 });
318 emit_security_event(event).await;
319
320 Ok(Json(serde_json::json!({
321 "risk_id": risk_id,
322 "status": "updated"
323 })))
324}
325
326pub async fn update_treatment_plan(
330 State(state): State<RiskAssessmentState>,
331 Path(risk_id): Path<String>,
332 claims: OptionalAuthClaims,
333 Json(request): Json<UpdateTreatmentPlanRequest>,
334) -> Result<Json<serde_json::Value>, StatusCode> {
335 let updated_by = extract_user_id_with_fallback(&claims);
337
338 let engine = state.engine.write().await;
339 engine
340 .update_treatment_plan(
341 &risk_id,
342 request.treatment_option,
343 request.treatment_plan,
344 request.treatment_owner,
345 request.treatment_deadline,
346 )
347 .await
348 .map_err(|e| {
349 error!("Failed to update treatment plan: {}", e);
350 StatusCode::BAD_REQUEST
351 })?;
352
353 info!("Treatment plan updated: {}", risk_id);
354
355 let event = SecurityEvent::new(SecurityEventType::ConfigChanged, None, None)
357 .with_actor(EventActor {
358 user_id: Some(updated_by.to_string()),
359 username: None,
360 ip_address: None,
361 user_agent: None,
362 })
363 .with_target(EventTarget {
364 resource_type: Some("risk".to_string()),
365 resource_id: Some(risk_id.clone()),
366 method: None,
367 })
368 .with_outcome(EventOutcome {
369 success: true,
370 reason: Some("Treatment plan updated".to_string()),
371 });
372 emit_security_event(event).await;
373
374 Ok(Json(serde_json::json!({
375 "risk_id": risk_id,
376 "status": "updated"
377 })))
378}
379
380pub async fn update_treatment_status(
384 State(state): State<RiskAssessmentState>,
385 Path(risk_id): Path<String>,
386 claims: OptionalAuthClaims,
387 Json(request): Json<serde_json::Value>,
388) -> Result<Json<serde_json::Value>, StatusCode> {
389 let _updated_by = extract_user_id_with_fallback(&claims);
391
392 let status_str =
393 request.get("status").and_then(|v| v.as_str()).ok_or(StatusCode::BAD_REQUEST)?;
394
395 let status = match status_str {
396 "not_started" => TreatmentStatus::NotStarted,
397 "in_progress" => TreatmentStatus::InProgress,
398 "completed" => TreatmentStatus::Completed,
399 "on_hold" => TreatmentStatus::OnHold,
400 _ => return Err(StatusCode::BAD_REQUEST),
401 };
402
403 let engine = state.engine.write().await;
404 engine.update_treatment_status(&risk_id, status).await.map_err(|e| {
405 error!("Failed to update treatment status: {}", e);
406 StatusCode::BAD_REQUEST
407 })?;
408
409 info!("Treatment status updated: {}", risk_id);
410
411 Ok(Json(serde_json::json!({
412 "risk_id": risk_id,
413 "status": "updated"
414 })))
415}
416
417pub async fn set_residual_risk(
421 State(state): State<RiskAssessmentState>,
422 Path(risk_id): Path<String>,
423 claims: OptionalAuthClaims,
424 Json(request): Json<SetResidualRiskRequest>,
425) -> Result<Json<serde_json::Value>, StatusCode> {
426 let _updated_by = extract_user_id_with_fallback(&claims);
428
429 let engine = state.engine.write().await;
430 engine
431 .set_residual_risk(&risk_id, request.residual_likelihood, request.residual_impact)
432 .await
433 .map_err(|e| {
434 error!("Failed to set residual risk: {}", e);
435 StatusCode::BAD_REQUEST
436 })?;
437
438 info!("Residual risk set: {}", risk_id);
439
440 Ok(Json(serde_json::json!({
441 "risk_id": risk_id,
442 "status": "updated"
443 })))
444}
445
446pub async fn review_risk(
450 State(state): State<RiskAssessmentState>,
451 Path(risk_id): Path<String>,
452 claims: OptionalAuthClaims,
453) -> Result<Json<serde_json::Value>, StatusCode> {
454 let reviewed_by = extract_user_id_with_fallback(&claims);
456
457 let engine = state.engine.write().await;
458 engine.review_risk(&risk_id, reviewed_by).await.map_err(|e| {
459 error!("Failed to review risk: {}", e);
460 StatusCode::BAD_REQUEST
461 })?;
462
463 info!("Risk reviewed: {}", risk_id);
464
465 Ok(Json(serde_json::json!({
466 "risk_id": risk_id,
467 "status": "reviewed"
468 })))
469}
470
471pub async fn get_risks_due_for_review(
475 State(state): State<RiskAssessmentState>,
476) -> Result<Json<serde_json::Value>, StatusCode> {
477 let engine = state.engine.read().await;
478 let risks = engine.get_risks_due_for_review().await.map_err(|e| {
479 error!("Failed to get risks due for review: {}", e);
480 StatusCode::INTERNAL_SERVER_ERROR
481 })?;
482
483 Ok(Json(serde_json::to_value(&risks).map_err(|e| {
484 error!("Failed to serialize risks: {}", e);
485 StatusCode::INTERNAL_SERVER_ERROR
486 })?))
487}
488
489pub async fn get_risk_summary(
493 State(state): State<RiskAssessmentState>,
494) -> Result<Json<serde_json::Value>, StatusCode> {
495 let engine = state.engine.read().await;
496 let summary = engine.get_risk_summary().await.map_err(|e| {
497 error!("Failed to get risk summary: {}", e);
498 StatusCode::INTERNAL_SERVER_ERROR
499 })?;
500
501 Ok(Json(serde_json::to_value(&summary).map_err(|e| {
502 error!("Failed to serialize summary: {}", e);
503 StatusCode::INTERNAL_SERVER_ERROR
504 })?))
505}
506
507pub fn risk_assessment_router(state: RiskAssessmentState) -> axum::Router {
509 use axum::routing::{get, patch, post, put};
510
511 axum::Router::new()
512 .route("/risks", get(list_risks))
513 .route("/risks", post(create_risk))
514 .route("/risks/{risk_id}", get(get_risk))
515 .route("/risks/{risk_id}/assessment", put(update_risk_assessment))
516 .route("/risks/{risk_id}/treatment", put(update_treatment_plan))
517 .route("/risks/{risk_id}/treatment/status", patch(update_treatment_status))
518 .route("/risks/{risk_id}/residual", put(set_residual_risk))
519 .route("/risks/{risk_id}/review", post(review_risk))
520 .route("/risks/due-for-review", get(get_risks_due_for_review))
521 .route("/risks/summary", get(get_risk_summary))
522 .with_state(state)
523}