1use axum::{
7 extract::{Extension, Path, Query, State},
8 http::StatusCode,
9 response::Json,
10};
11use mockforge_core::security::{
12 emit_security_event,
13 risk_assessment::{
14 Impact, Likelihood, Risk, RiskAssessmentEngine, RiskCategory, RiskLevel, TreatmentOption,
15 TreatmentStatus,
16 },
17 EventActor, EventOutcome, EventTarget, SecurityEvent, SecurityEventType,
18};
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use std::sync::Arc;
22use tokio::sync::RwLock;
23use tracing::{error, info};
24use uuid::Uuid;
25
26use crate::auth::types::AuthClaims;
27use crate::handlers::auth_helpers::extract_user_id_with_fallback;
28
29#[derive(Clone)]
31pub struct RiskAssessmentState {
32 pub engine: Arc<RwLock<RiskAssessmentEngine>>,
34}
35
36#[derive(Debug, Deserialize)]
38pub struct CreateRiskRequest {
39 pub title: String,
41 pub description: String,
43 pub category: RiskCategory,
45 pub subcategory: Option<String>,
47 pub likelihood: Likelihood,
49 pub impact: Impact,
51 pub threat: Option<String>,
53 pub vulnerability: Option<String>,
55 pub asset: Option<String>,
57 pub existing_controls: Option<Vec<String>>,
59 pub compliance_requirements: Option<Vec<String>>,
61}
62
63#[derive(Debug, Deserialize)]
65pub struct UpdateRiskAssessmentRequest {
66 pub likelihood: Option<Likelihood>,
68 pub impact: Option<Impact>,
70}
71
72#[derive(Debug, Deserialize)]
74pub struct UpdateTreatmentPlanRequest {
75 pub treatment_option: TreatmentOption,
77 pub treatment_plan: Vec<String>,
79 pub treatment_owner: Option<String>,
81 pub treatment_deadline: Option<chrono::DateTime<chrono::Utc>>,
83}
84
85#[derive(Debug, Deserialize)]
87pub struct SetResidualRiskRequest {
88 pub residual_likelihood: Likelihood,
90 pub residual_impact: Impact,
92}
93
94#[derive(Debug, Serialize)]
96pub struct RiskListResponse {
97 pub risks: Vec<Risk>,
99 pub summary: mockforge_core::security::risk_assessment::RiskSummary,
101}
102
103pub async fn create_risk(
107 State(state): State<RiskAssessmentState>,
108 Json(request): Json<CreateRiskRequest>,
109 claims: Option<Extension<AuthClaims>>,
110) -> Result<Json<serde_json::Value>, StatusCode> {
111 let created_by = extract_user_id_with_fallback(claims);
113
114 let engine = state.engine.write().await;
115 let mut risk = engine
116 .create_risk(
117 request.title,
118 request.description,
119 request.category,
120 request.likelihood,
121 request.impact,
122 created_by,
123 )
124 .await
125 .map_err(|e| {
126 error!("Failed to create risk: {}", e);
127 StatusCode::INTERNAL_SERVER_ERROR
128 })?;
129
130 if let Some(subcategory) = request.subcategory {
132 }
135 if let Some(threat) = request.threat {
136 }
138 if let Some(vulnerability) = request.vulnerability {
139 }
141 if let Some(asset) = request.asset {
142 }
144 if let Some(controls) = request.existing_controls {
145 }
147 if let Some(requirements) = request.compliance_requirements {
148 }
150
151 info!("Risk created: {}", risk.risk_id);
152
153 let event = SecurityEvent::new(SecurityEventType::ConfigChanged, None, None)
155 .with_actor(EventActor {
156 user_id: Some(created_by.to_string()),
157 username: None,
158 ip_address: None,
159 user_agent: None,
160 })
161 .with_target(EventTarget {
162 resource_type: Some("risk".to_string()),
163 resource_id: Some(risk.risk_id.clone()),
164 method: None,
165 })
166 .with_outcome(EventOutcome {
167 success: true,
168 reason: Some("Risk created".to_string()),
169 });
170 emit_security_event(event).await;
171
172 Ok(Json(serde_json::to_value(&risk).unwrap()))
173}
174
175pub async fn get_risk(
179 State(state): State<RiskAssessmentState>,
180 Path(risk_id): Path<String>,
181) -> Result<Json<serde_json::Value>, StatusCode> {
182 let engine = state.engine.read().await;
183 let risk = engine
184 .get_risk(&risk_id)
185 .await
186 .map_err(|e| {
187 error!("Failed to get risk: {}", e);
188 StatusCode::INTERNAL_SERVER_ERROR
189 })?
190 .ok_or_else(|| {
191 error!("Risk not found: {}", risk_id);
192 StatusCode::NOT_FOUND
193 })?;
194
195 Ok(Json(serde_json::to_value(&risk).unwrap()))
196}
197
198pub async fn list_risks(
202 State(state): State<RiskAssessmentState>,
203 Query(params): Query<HashMap<String, String>>,
204) -> Result<Json<RiskListResponse>, StatusCode> {
205 let engine = state.engine.read().await;
206
207 let risks = if let Some(level_str) = params.get("level") {
208 let level = match level_str.as_str() {
209 "critical" => RiskLevel::Critical,
210 "high" => RiskLevel::High,
211 "medium" => RiskLevel::Medium,
212 "low" => RiskLevel::Low,
213 _ => return Err(StatusCode::BAD_REQUEST),
214 };
215 engine
216 .get_risks_by_level(level)
217 .await
218 .map_err(|e| {
219 error!("Failed to get risks by level: {}", e);
220 StatusCode::INTERNAL_SERVER_ERROR
221 })?
222 } else if let Some(category_str) = params.get("category") {
223 let category = match category_str.as_str() {
224 "technical" => RiskCategory::Technical,
225 "operational" => RiskCategory::Operational,
226 "compliance" => RiskCategory::Compliance,
227 "business" => RiskCategory::Business,
228 _ => return Err(StatusCode::BAD_REQUEST),
229 };
230 engine
231 .get_risks_by_category(category)
232 .await
233 .map_err(|e| {
234 error!("Failed to get risks by category: {}", e);
235 StatusCode::INTERNAL_SERVER_ERROR
236 })?
237 } else if let Some(status_str) = params.get("treatment_status") {
238 let status = match status_str.as_str() {
239 "not_started" => TreatmentStatus::NotStarted,
240 "in_progress" => TreatmentStatus::InProgress,
241 "completed" => TreatmentStatus::Completed,
242 "on_hold" => TreatmentStatus::OnHold,
243 _ => return Err(StatusCode::BAD_REQUEST),
244 };
245 engine
246 .get_risks_by_treatment_status(status)
247 .await
248 .map_err(|e| {
249 error!("Failed to get risks by treatment status: {}", e);
250 StatusCode::INTERNAL_SERVER_ERROR
251 })?
252 } else {
253 engine
254 .get_all_risks()
255 .await
256 .map_err(|e| {
257 error!("Failed to get all risks: {}", e);
258 StatusCode::INTERNAL_SERVER_ERROR
259 })?
260 };
261
262 let summary = engine
263 .get_risk_summary()
264 .await
265 .map_err(|e| {
266 error!("Failed to get risk summary: {}", e);
267 StatusCode::INTERNAL_SERVER_ERROR
268 })?;
269
270 Ok(Json(RiskListResponse { risks, summary }))
271}
272
273pub async fn update_risk_assessment(
277 State(state): State<RiskAssessmentState>,
278 Path(risk_id): Path<String>,
279 Json(request): Json<UpdateRiskAssessmentRequest>,
280 claims: Option<Extension<AuthClaims>>,
281) -> Result<Json<serde_json::Value>, StatusCode> {
282 let updated_by = extract_user_id_with_fallback(claims);
284
285 let engine = state.engine.write().await;
286 engine
287 .update_risk_assessment(&risk_id, request.likelihood, request.impact)
288 .await
289 .map_err(|e| {
290 error!("Failed to update risk assessment: {}", e);
291 StatusCode::BAD_REQUEST
292 })?;
293
294 info!("Risk assessment updated: {}", risk_id);
295
296 let event = SecurityEvent::new(SecurityEventType::ConfigChanged, None, None)
298 .with_actor(EventActor {
299 user_id: Some(updated_by.to_string()),
300 username: None,
301 ip_address: None,
302 user_agent: None,
303 })
304 .with_target(EventTarget {
305 resource_type: Some("risk".to_string()),
306 resource_id: Some(risk_id.clone()),
307 method: None,
308 })
309 .with_outcome(EventOutcome {
310 success: true,
311 reason: Some("Risk assessment updated".to_string()),
312 });
313 emit_security_event(event).await;
314
315 Ok(Json(serde_json::json!({
316 "risk_id": risk_id,
317 "status": "updated"
318 })))
319}
320
321pub async fn update_treatment_plan(
325 State(state): State<RiskAssessmentState>,
326 Path(risk_id): Path<String>,
327 Json(request): Json<UpdateTreatmentPlanRequest>,
328 claims: Option<Extension<AuthClaims>>,
329) -> Result<Json<serde_json::Value>, StatusCode> {
330 let updated_by = extract_user_id_with_fallback(claims);
332
333 let engine = state.engine.write().await;
334 engine
335 .update_treatment_plan(
336 &risk_id,
337 request.treatment_option,
338 request.treatment_plan,
339 request.treatment_owner,
340 request.treatment_deadline,
341 )
342 .await
343 .map_err(|e| {
344 error!("Failed to update treatment plan: {}", e);
345 StatusCode::BAD_REQUEST
346 })?;
347
348 info!("Treatment plan updated: {}", risk_id);
349
350 let event = SecurityEvent::new(SecurityEventType::ConfigChanged, None, None)
352 .with_actor(EventActor {
353 user_id: Some(updated_by.to_string()),
354 username: None,
355 ip_address: None,
356 user_agent: None,
357 })
358 .with_target(EventTarget {
359 resource_type: Some("risk".to_string()),
360 resource_id: Some(risk_id.clone()),
361 method: None,
362 })
363 .with_outcome(EventOutcome {
364 success: true,
365 reason: Some("Treatment plan updated".to_string()),
366 });
367 emit_security_event(event).await;
368
369 Ok(Json(serde_json::json!({
370 "risk_id": risk_id,
371 "status": "updated"
372 })))
373}
374
375pub async fn update_treatment_status(
379 State(state): State<RiskAssessmentState>,
380 Path(risk_id): Path<String>,
381 Json(request): Json<serde_json::Value>,
382 claims: Option<Extension<AuthClaims>>,
383) -> Result<Json<serde_json::Value>, StatusCode> {
384 let _updated_by = extract_user_id_with_fallback(claims);
386
387 let status_str = request
388 .get("status")
389 .and_then(|v| v.as_str())
390 .ok_or_else(|| StatusCode::BAD_REQUEST)?;
391
392 let status = match status_str {
393 "not_started" => TreatmentStatus::NotStarted,
394 "in_progress" => TreatmentStatus::InProgress,
395 "completed" => TreatmentStatus::Completed,
396 "on_hold" => TreatmentStatus::OnHold,
397 _ => return Err(StatusCode::BAD_REQUEST),
398 };
399
400 let engine = state.engine.write().await;
401 engine
402 .update_treatment_status(&risk_id, status)
403 .await
404 .map_err(|e| {
405 error!("Failed to update treatment status: {}", e);
406 StatusCode::BAD_REQUEST
407 })?;
408
409 info!("Treatment status updated: {}", risk_id);
410
411 Ok(Json(serde_json::json!({
412 "risk_id": risk_id,
413 "status": "updated"
414 })))
415}
416
417pub async fn set_residual_risk(
421 State(state): State<RiskAssessmentState>,
422 Path(risk_id): Path<String>,
423 Json(request): Json<SetResidualRiskRequest>,
424 claims: Option<Extension<AuthClaims>>,
425) -> Result<Json<serde_json::Value>, StatusCode> {
426 let _updated_by = extract_user_id_with_fallback(claims);
428
429 let engine = state.engine.write().await;
430 engine
431 .set_residual_risk(&risk_id, request.residual_likelihood, request.residual_impact)
432 .await
433 .map_err(|e| {
434 error!("Failed to set residual risk: {}", e);
435 StatusCode::BAD_REQUEST
436 })?;
437
438 info!("Residual risk set: {}", risk_id);
439
440 Ok(Json(serde_json::json!({
441 "risk_id": risk_id,
442 "status": "updated"
443 })))
444}
445
446pub async fn review_risk(
450 State(state): State<RiskAssessmentState>,
451 Path(risk_id): Path<String>,
452 claims: Option<Extension<AuthClaims>>,
453) -> Result<Json<serde_json::Value>, StatusCode> {
454 let reviewed_by = extract_user_id_with_fallback(claims);
456
457 let engine = state.engine.write().await;
458 engine
459 .review_risk(&risk_id, reviewed_by)
460 .await
461 .map_err(|e| {
462 error!("Failed to review risk: {}", e);
463 StatusCode::BAD_REQUEST
464 })?;
465
466 info!("Risk reviewed: {}", risk_id);
467
468 Ok(Json(serde_json::json!({
469 "risk_id": risk_id,
470 "status": "reviewed"
471 })))
472}
473
474pub async fn get_risks_due_for_review(
478 State(state): State<RiskAssessmentState>,
479) -> Result<Json<serde_json::Value>, StatusCode> {
480 let engine = state.engine.read().await;
481 let risks = engine
482 .get_risks_due_for_review()
483 .await
484 .map_err(|e| {
485 error!("Failed to get risks due for review: {}", e);
486 StatusCode::INTERNAL_SERVER_ERROR
487 })?;
488
489 Ok(Json(serde_json::to_value(&risks).unwrap()))
490}
491
492pub async fn get_risk_summary(
496 State(state): State<RiskAssessmentState>,
497) -> Result<Json<serde_json::Value>, StatusCode> {
498 let engine = state.engine.read().await;
499 let summary = engine
500 .get_risk_summary()
501 .await
502 .map_err(|e| {
503 error!("Failed to get risk summary: {}", e);
504 StatusCode::INTERNAL_SERVER_ERROR
505 })?;
506
507 Ok(Json(serde_json::to_value(&summary).unwrap()))
508}
509
510pub fn risk_assessment_router(state: RiskAssessmentState) -> axum::Router {
512 use axum::routing::{get, post, put, patch};
513
514 axum::Router::new()
515 .route("/risks", get(list_risks))
516 .route("/risks", post(create_risk))
517 .route("/risks/{risk_id}", get(get_risk))
518 .route("/risks/{risk_id}/assessment", put(update_risk_assessment))
519 .route("/risks/{risk_id}/treatment", put(update_treatment_plan))
520 .route("/risks/{risk_id}/treatment/status", patch(update_treatment_status))
521 .route("/risks/{risk_id}/residual", put(set_residual_risk))
522 .route("/risks/{risk_id}/review", post(review_risk))
523 .route("/risks/due-for-review", get(get_risks_due_for_review))
524 .route("/risks/summary", get(get_risk_summary))
525 .with_state(state)
526}