mockforge_http/handlers/
risk_simulation.rs1use axum::{extract::State, http::StatusCode, response::Json};
11use serde::Deserialize;
12use std::collections::HashMap;
13use std::sync::Arc;
14
15use crate::auth::risk_engine::RiskEngine;
16
17#[derive(Clone)]
19pub struct RiskSimulationState {
20 pub risk_engine: Arc<RiskEngine>,
22}
23
24#[derive(Debug, Deserialize)]
26pub struct SetRiskScoreRequest {
27 pub user_id: String,
29 pub risk_score: f64,
31}
32
33#[derive(Debug, Deserialize)]
35pub struct SetRiskFactorsRequest {
36 pub user_id: String,
38 pub risk_factors: HashMap<String, f64>,
40}
41
42#[derive(Debug, Deserialize)]
44pub struct TriggerMfaRequest {
45 pub user_id: String,
47 pub mfa_type: String,
49}
50
51#[derive(Debug, Deserialize)]
53pub struct BlockUserRequest {
54 pub user_id: String,
56 pub reason: String,
58}
59
60pub async fn set_risk_score(
62 State(state): State<RiskSimulationState>,
63 Json(request): Json<SetRiskScoreRequest>,
64) -> Result<Json<serde_json::Value>, StatusCode> {
65 state
66 .risk_engine
67 .set_simulated_risk(request.user_id.clone(), Some(request.risk_score))
68 .await;
69
70 Ok(Json(serde_json::json!({
71 "success": true,
72 "user_id": request.user_id,
73 "risk_score": request.risk_score
74 })))
75}
76
77pub async fn set_risk_factors(
79 State(state): State<RiskSimulationState>,
80 Json(request): Json<SetRiskFactorsRequest>,
81) -> Result<Json<serde_json::Value>, StatusCode> {
82 state
83 .risk_engine
84 .set_simulated_factors(request.user_id.clone(), request.risk_factors.clone())
85 .await;
86
87 Ok(Json(serde_json::json!({
88 "success": true,
89 "user_id": request.user_id,
90 "risk_factors": request.risk_factors
91 })))
92}
93
94pub async fn clear_risk(
96 State(state): State<RiskSimulationState>,
97 axum::extract::Path(user_id): axum::extract::Path<String>,
98) -> Result<Json<serde_json::Value>, StatusCode> {
99 state.risk_engine.clear_simulated_risk(&user_id).await;
100
101 Ok(Json(serde_json::json!({
102 "success": true,
103 "user_id": user_id,
104 "message": "Simulated risk cleared"
105 })))
106}
107
108pub async fn trigger_mfa(
110 State(state): State<RiskSimulationState>,
111 Json(request): Json<TriggerMfaRequest>,
112) -> Result<Json<serde_json::Value>, StatusCode> {
113 state.risk_engine.set_simulated_risk(request.user_id.clone(), Some(0.8)).await;
115
116 Ok(Json(serde_json::json!({
117 "success": true,
118 "user_id": request.user_id,
119 "mfa_type": request.mfa_type,
120 "message": "MFA prompt triggered"
121 })))
122}
123
124pub async fn block_user(
126 State(state): State<RiskSimulationState>,
127 Json(request): Json<BlockUserRequest>,
128) -> Result<Json<serde_json::Value>, StatusCode> {
129 state.risk_engine.set_simulated_risk(request.user_id.clone(), Some(0.95)).await;
131
132 Ok(Json(serde_json::json!({
133 "success": true,
134 "user_id": request.user_id,
135 "reason": request.reason,
136 "message": "User login blocked"
137 })))
138}
139
140pub async fn get_risk_assessment(
142 State(state): State<RiskSimulationState>,
143 axum::extract::Path(user_id): axum::extract::Path<String>,
144) -> Json<serde_json::Value> {
145 let risk_factors = HashMap::new();
146 let assessment = state.risk_engine.assess_risk(&user_id, &risk_factors).await;
147
148 Json(serde_json::json!({
149 "user_id": user_id,
150 "risk_score": assessment.risk_score,
151 "risk_factors": assessment.risk_factors,
152 "recommended_action": assessment.recommended_action
153 }))
154}
155
156pub fn risk_simulation_router(state: RiskSimulationState) -> axum::Router {
158 use axum::routing::{delete, get, post};
159
160 axum::Router::new()
161 .route("/risk/simulate", post(set_risk_score))
162 .route("/risk/factors", post(set_risk_factors))
163 .route("/risk/clear/{user_id}", delete(clear_risk))
164 .route("/risk/trigger-mfa", post(trigger_mfa))
165 .route("/risk/block", post(block_user))
166 .route("/risk/assessment/{user_id}", get(get_risk_assessment))
167 .with_state(state)
168}