use axum::{
extract::{Path, Query, State},
http::StatusCode,
response::Json,
};
use mockforge_core::security::{
emit_security_event,
risk_assessment::{
Impact, Likelihood, Risk, RiskAssessmentEngine, RiskCategory, RiskLevel, TreatmentOption,
TreatmentStatus,
},
EventActor, EventOutcome, EventTarget, SecurityEvent, SecurityEventType,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{error, info};
use crate::handlers::auth_helpers::{extract_user_id_with_fallback, OptionalAuthClaims};
/// State handed to the risk assessment handlers in this file through axum's
/// `State` extractor.
#[derive(Clone)]
pub struct RiskAssessmentState {
/// The risk assessment engine, shared via `Arc` and guarded by an async `RwLock`.
pub engine: Arc<RwLock<RiskAssessmentEngine>>,
}
/// JSON request body accepted by `create_risk`.
///
/// `title`, `description`, `category`, `likelihood` and `impact` are passed to
/// the engine's `create_risk` call. The `Option` fields are copied onto the
/// created risk afterwards, and only when they are present in the request.
#[derive(Debug, Deserialize)]
pub struct CreateRiskRequest {
/// Title passed to the engine when the risk is created.
pub title: String,
/// Description passed to the engine when the risk is created.
pub description: String,
/// Category passed to the engine when the risk is created.
pub category: RiskCategory,
/// When present, stored as the created risk's `subcategory`.
pub subcategory: Option<String>,
/// Likelihood passed to the engine when the risk is created.
pub likelihood: Likelihood,
/// Impact passed to the engine when the risk is created.
pub impact: Impact,
/// When present, stored as the created risk's `threat`.
pub threat: Option<String>,
/// When present, stored as the created risk's `vulnerability`.
pub vulnerability: Option<String>,
/// When present, stored as the created risk's `asset`.
pub asset: Option<String>,
/// When present, replaces the created risk's `existing_controls`.
pub existing_controls: Option<Vec<String>>,
/// When present, replaces the created risk's `compliance_requirements`.
pub compliance_requirements: Option<Vec<String>>,
}
/// JSON request body accepted by `update_risk_assessment`.
///
/// Both fields are forwarded unchanged to the engine's `update_risk_assessment`.
/// NOTE(review): presumably `None` leaves the current value untouched — confirm
/// against the engine implementation.
#[derive(Debug, Deserialize)]
pub struct UpdateRiskAssessmentRequest {
/// Optional new likelihood.
pub likelihood: Option<Likelihood>,
/// Optional new impact.
pub impact: Option<Impact>,
}
/// JSON request body accepted by `update_treatment_plan`.
///
/// Every field is forwarded unchanged to the engine's `update_treatment_plan`.
#[derive(Debug, Deserialize)]
pub struct UpdateTreatmentPlanRequest {
/// Treatment option selected for the risk.
pub treatment_option: TreatmentOption,
/// Entries of the treatment plan, as a list of strings.
pub treatment_plan: Vec<String>,
/// Optional owner of the treatment.
pub treatment_owner: Option<String>,
/// Optional deadline for the treatment, as a UTC timestamp.
pub treatment_deadline: Option<chrono::DateTime<chrono::Utc>>,
}
/// JSON request body accepted by `set_residual_risk`.
///
/// Both fields are forwarded unchanged to the engine's `set_residual_risk`.
#[derive(Debug, Deserialize)]
pub struct SetResidualRiskRequest {
/// Residual likelihood to record for the risk.
pub residual_likelihood: Likelihood,
/// Residual impact to record for the risk.
pub residual_impact: Impact,
}
/// JSON response body returned by `list_risks`.
#[derive(Debug, Serialize)]
pub struct RiskListResponse {
/// Risks selected by the request's filter, or all risks when no filter was given.
pub risks: Vec<Risk>,
/// Summary obtained from the engine's `get_risk_summary`; that call receives
/// no filter argument, so it is requested independently of `risks`.
pub summary: mockforge_core::security::risk_assessment::RiskSummary,
}
/// Creates a risk from the JSON request body and returns the stored risk as JSON.
///
/// The mandatory request fields, together with the user id obtained from
/// `extract_user_id_with_fallback`, are passed to the engine's `create_risk`.
/// Any optional fields present in the request are then copied onto the returned
/// risk and persisted with a follow-up `update_risk` call. On success a
/// `ConfigChanged` security event naming the actor and the new risk id is emitted.
///
/// # Errors
///
/// Returns `500 Internal Server Error` when the engine fails to create the
/// risk, fails to persist the optional fields, or when the risk cannot be
/// serialized to JSON.
pub async fn create_risk(
    State(state): State<RiskAssessmentState>,
    claims: OptionalAuthClaims,
    Json(request): Json<CreateRiskRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let created_by = extract_user_id_with_fallback(&claims);

    let engine = state.engine.write().await;
    let mut risk = engine
        .create_risk(
            request.title,
            request.description,
            request.category,
            request.likelihood,
            request.impact,
            created_by,
        )
        .await
        .map_err(|e| {
            error!("Failed to create risk: {}", e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?;

    // Copy over whichever optional fields were supplied and remember whether
    // the stored risk has to be rewritten.
    let mut needs_update = false;
    if let Some(subcategory) = request.subcategory {
        risk.subcategory = Some(subcategory);
        needs_update = true;
    }
    if let Some(threat) = request.threat {
        risk.threat = Some(threat);
        needs_update = true;
    }
    if let Some(vulnerability) = request.vulnerability {
        risk.vulnerability = Some(vulnerability);
        needs_update = true;
    }
    if let Some(asset) = request.asset {
        risk.asset = Some(asset);
        needs_update = true;
    }
    if let Some(controls) = request.existing_controls {
        risk.existing_controls = controls;
        needs_update = true;
    }
    if let Some(requirements) = request.compliance_requirements {
        risk.compliance_requirements = requirements;
        needs_update = true;
    }

    if needs_update {
        let risk_id = risk.risk_id.clone();
        engine.update_risk(&risk_id, risk.clone()).await.map_err(|e| {
            error!("Failed to update risk with optional fields: {}", e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?;
    }

    // All engine work is done: release the exclusive lock so other handlers
    // are not blocked while the audit event is emitted and the body is built.
    drop(engine);

    info!("Risk created: {}", risk.risk_id);

    let event = SecurityEvent::new(SecurityEventType::ConfigChanged, None, None)
        .with_actor(EventActor {
            user_id: Some(created_by.to_string()),
            username: None,
            ip_address: None,
            user_agent: None,
        })
        .with_target(EventTarget {
            resource_type: Some("risk".to_string()),
            resource_id: Some(risk.risk_id.clone()),
            method: None,
        })
        .with_outcome(EventOutcome {
            success: true,
            reason: Some("Risk created".to_string()),
        });
    emit_security_event(event).await;

    Ok(Json(serde_json::to_value(&risk).map_err(|e| {
        error!("Failed to serialize risk: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?))
}
/// Looks up a single risk by the id in the request path and returns it as JSON.
///
/// # Errors
///
/// Returns `404 Not Found` when the engine reports no risk for the id, and
/// `500 Internal Server Error` when the engine lookup or the JSON
/// serialization fails.
pub async fn get_risk(
    State(state): State<RiskAssessmentState>,
    Path(risk_id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let engine = state.engine.read().await;

    let found = match engine.get_risk(&risk_id).await {
        Ok(Some(existing)) => existing,
        Ok(None) => {
            error!("Risk not found: {}", risk_id);
            return Err(StatusCode::NOT_FOUND);
        }
        Err(e) => {
            error!("Failed to get risk: {}", e);
            return Err(StatusCode::INTERNAL_SERVER_ERROR);
        }
    };

    match serde_json::to_value(&found) {
        Ok(body) => Ok(Json(body)),
        Err(e) => {
            error!("Failed to serialize risk: {}", e);
            Err(StatusCode::INTERNAL_SERVER_ERROR)
        }
    }
}
pub async fn list_risks(
State(state): State<RiskAssessmentState>,
Query(params): Query<HashMap<String, String>>,
) -> Result<Json<RiskListResponse>, StatusCode> {
let engine = state.engine.read().await;
let risks = if let Some(level_str) = params.get("level") {
let level = match level_str.as_str() {
"critical" => RiskLevel::Critical,
"high" => RiskLevel::High,
"medium" => RiskLevel::Medium,
"low" => RiskLevel::Low,
_ => return Err(StatusCode::BAD_REQUEST),
};
engine.get_risks_by_level(level).await.map_err(|e| {
error!("Failed to get risks by level: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?
} else if let Some(category_str) = params.get("category") {
let category = match category_str.as_str() {
"technical" => RiskCategory::Technical,
"operational" => RiskCategory::Operational,
"compliance" => RiskCategory::Compliance,
"business" => RiskCategory::Business,
_ => return Err(StatusCode::BAD_REQUEST),
};
engine.get_risks_by_category(category).await.map_err(|e| {
error!("Failed to get risks by category: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?
} else if let Some(status_str) = params.get("treatment_status") {
let status = match status_str.as_str() {
"not_started" => TreatmentStatus::NotStarted,
"in_progress" => TreatmentStatus::InProgress,
"completed" => TreatmentStatus::Completed,
"on_hold" => TreatmentStatus::OnHold,
_ => return Err(StatusCode::BAD_REQUEST),
};
engine.get_risks_by_treatment_status(status).await.map_err(|e| {
error!("Failed to get risks by treatment status: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?
} else {
engine.get_all_risks().await.map_err(|e| {
error!("Failed to get all risks: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?
};
let summary = engine.get_risk_summary().await.map_err(|e| {
error!("Failed to get risk summary: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
Ok(Json(RiskListResponse { risks, summary }))
}
/// Updates the likelihood and/or impact of the risk named in the request path.
///
/// The optional `likelihood` and `impact` from the body are forwarded to the
/// engine's `update_risk_assessment`. On success a `ConfigChanged` security
/// event naming the actor and the risk id is emitted, and a small JSON
/// acknowledgement (`risk_id`, `status: "updated"`) is returned.
///
/// # Errors
///
/// Returns `400 Bad Request` when the engine rejects the update.
pub async fn update_risk_assessment(
    State(state): State<RiskAssessmentState>,
    Path(risk_id): Path<String>,
    claims: OptionalAuthClaims,
    Json(request): Json<UpdateRiskAssessmentRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let updated_by = extract_user_id_with_fallback(&claims);

    // Keep the exclusive lock only for the engine call so that emitting the
    // audit event below does not block other handlers.
    {
        let engine = state.engine.write().await;
        engine
            .update_risk_assessment(&risk_id, request.likelihood, request.impact)
            .await
            .map_err(|e| {
                error!("Failed to update risk assessment: {}", e);
                StatusCode::BAD_REQUEST
            })?;
    }

    info!("Risk assessment updated: {}", risk_id);

    let event = SecurityEvent::new(SecurityEventType::ConfigChanged, None, None)
        .with_actor(EventActor {
            user_id: Some(updated_by.to_string()),
            username: None,
            ip_address: None,
            user_agent: None,
        })
        .with_target(EventTarget {
            resource_type: Some("risk".to_string()),
            resource_id: Some(risk_id.clone()),
            method: None,
        })
        .with_outcome(EventOutcome {
            success: true,
            reason: Some("Risk assessment updated".to_string()),
        });
    emit_security_event(event).await;

    Ok(Json(serde_json::json!({
        "risk_id": risk_id,
        "status": "updated"
    })))
}
/// Replaces the treatment plan of the risk named in the request path.
///
/// The treatment option, plan entries, owner and deadline from the body are
/// forwarded to the engine's `update_treatment_plan`. On success a
/// `ConfigChanged` security event naming the actor and the risk id is emitted,
/// and a small JSON acknowledgement (`risk_id`, `status: "updated"`) is returned.
///
/// # Errors
///
/// Returns `400 Bad Request` when the engine rejects the update.
pub async fn update_treatment_plan(
    State(state): State<RiskAssessmentState>,
    Path(risk_id): Path<String>,
    claims: OptionalAuthClaims,
    Json(request): Json<UpdateTreatmentPlanRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let updated_by = extract_user_id_with_fallback(&claims);

    // Keep the exclusive lock only for the engine call so that emitting the
    // audit event below does not block other handlers.
    {
        let engine = state.engine.write().await;
        engine
            .update_treatment_plan(
                &risk_id,
                request.treatment_option,
                request.treatment_plan,
                request.treatment_owner,
                request.treatment_deadline,
            )
            .await
            .map_err(|e| {
                error!("Failed to update treatment plan: {}", e);
                StatusCode::BAD_REQUEST
            })?;
    }

    info!("Treatment plan updated: {}", risk_id);

    let event = SecurityEvent::new(SecurityEventType::ConfigChanged, None, None)
        .with_actor(EventActor {
            user_id: Some(updated_by.to_string()),
            username: None,
            ip_address: None,
            user_agent: None,
        })
        .with_target(EventTarget {
            resource_type: Some("risk".to_string()),
            resource_id: Some(risk_id.clone()),
            method: None,
        })
        .with_outcome(EventOutcome {
            success: true,
            reason: Some("Treatment plan updated".to_string()),
        });
    emit_security_event(event).await;

    Ok(Json(serde_json::json!({
        "risk_id": risk_id,
        "status": "updated"
    })))
}
pub async fn update_treatment_status(
State(state): State<RiskAssessmentState>,
Path(risk_id): Path<String>,
claims: OptionalAuthClaims,
Json(request): Json<serde_json::Value>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let _updated_by = extract_user_id_with_fallback(&claims);
let status_str =
request.get("status").and_then(|v| v.as_str()).ok_or(StatusCode::BAD_REQUEST)?;
let status = match status_str {
"not_started" => TreatmentStatus::NotStarted,
"in_progress" => TreatmentStatus::InProgress,
"completed" => TreatmentStatus::Completed,
"on_hold" => TreatmentStatus::OnHold,
_ => return Err(StatusCode::BAD_REQUEST),
};
let engine = state.engine.write().await;
engine.update_treatment_status(&risk_id, status).await.map_err(|e| {
error!("Failed to update treatment status: {}", e);
StatusCode::BAD_REQUEST
})?;
info!("Treatment status updated: {}", risk_id);
Ok(Json(serde_json::json!({
"risk_id": risk_id,
"status": "updated"
})))
}
pub async fn set_residual_risk(
State(state): State<RiskAssessmentState>,
Path(risk_id): Path<String>,
claims: OptionalAuthClaims,
Json(request): Json<SetResidualRiskRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let _updated_by = extract_user_id_with_fallback(&claims);
let engine = state.engine.write().await;
engine
.set_residual_risk(&risk_id, request.residual_likelihood, request.residual_impact)
.await
.map_err(|e| {
error!("Failed to set residual risk: {}", e);
StatusCode::BAD_REQUEST
})?;
info!("Residual risk set: {}", risk_id);
Ok(Json(serde_json::json!({
"risk_id": risk_id,
"status": "updated"
})))
}
pub async fn review_risk(
State(state): State<RiskAssessmentState>,
Path(risk_id): Path<String>,
claims: OptionalAuthClaims,
) -> Result<Json<serde_json::Value>, StatusCode> {
let reviewed_by = extract_user_id_with_fallback(&claims);
let engine = state.engine.write().await;
engine.review_risk(&risk_id, reviewed_by).await.map_err(|e| {
error!("Failed to review risk: {}", e);
StatusCode::BAD_REQUEST
})?;
info!("Risk reviewed: {}", risk_id);
Ok(Json(serde_json::json!({
"risk_id": risk_id,
"status": "reviewed"
})))
}
pub async fn get_risks_due_for_review(
State(state): State<RiskAssessmentState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let engine = state.engine.read().await;
let risks = engine.get_risks_due_for_review().await.map_err(|e| {
error!("Failed to get risks due for review: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
Ok(Json(serde_json::to_value(&risks).map_err(|e| {
error!("Failed to serialize risks: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?))
}
pub async fn get_risk_summary(
State(state): State<RiskAssessmentState>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let engine = state.engine.read().await;
let summary = engine.get_risk_summary().await.map_err(|e| {
error!("Failed to get risk summary: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?;
Ok(Json(serde_json::to_value(&summary).map_err(|e| {
error!("Failed to serialize summary: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?))
}
/// Builds the router exposing the risk assessment handlers, bound to `state`.
///
/// Registered routes:
///
/// | Method | Path                                 | Handler                    |
/// |--------|--------------------------------------|----------------------------|
/// | GET    | `/risks`                             | `list_risks`               |
/// | POST   | `/risks`                             | `create_risk`              |
/// | GET    | `/risks/summary`                     | `get_risk_summary`         |
/// | GET    | `/risks/due-for-review`              | `get_risks_due_for_review` |
/// | GET    | `/risks/{risk_id}`                   | `get_risk`                 |
/// | PUT    | `/risks/{risk_id}/assessment`        | `update_risk_assessment`   |
/// | PUT    | `/risks/{risk_id}/treatment`         | `update_treatment_plan`    |
/// | PATCH  | `/risks/{risk_id}/treatment/status`  | `update_treatment_status`  |
/// | PUT    | `/risks/{risk_id}/residual`          | `set_residual_risk`        |
/// | POST   | `/risks/{risk_id}/review`            | `review_risk`              |
pub fn risk_assessment_router(state: RiskAssessmentState) -> axum::Router {
    use axum::routing::{get, patch, post, put};

    axum::Router::new()
        // Collection endpoints.
        .route("/risks", get(list_risks).post(create_risk))
        .route("/risks/summary", get(get_risk_summary))
        .route("/risks/due-for-review", get(get_risks_due_for_review))
        // Endpoints addressing a single risk.
        .route("/risks/{risk_id}", get(get_risk))
        .route("/risks/{risk_id}/assessment", put(update_risk_assessment))
        .route("/risks/{risk_id}/treatment", put(update_treatment_plan))
        .route("/risks/{risk_id}/treatment/status", patch(update_treatment_status))
        .route("/risks/{risk_id}/residual", put(set_residual_risk))
        .route("/risks/{risk_id}/review", post(review_risk))
        .with_state(state)
}