mockforge_http/handlers/
risk_simulation.rs1use axum::{
7 extract::State,
8 http::StatusCode,
9 response::Json,
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14
15use crate::auth::risk_engine::{RiskEngine, RiskEngineConfig};
16
17#[derive(Clone)]
19pub struct RiskSimulationState {
20 pub risk_engine: Arc<RiskEngine>,
22}
23
24#[derive(Debug, Deserialize)]
26pub struct SetRiskScoreRequest {
27 pub user_id: String,
29 pub risk_score: f64,
31}
32
33#[derive(Debug, Deserialize)]
35pub struct SetRiskFactorsRequest {
36 pub user_id: String,
38 pub risk_factors: HashMap<String, f64>,
40}
41
42#[derive(Debug, Deserialize)]
44pub struct TriggerMfaRequest {
45 pub user_id: String,
47 pub mfa_type: String,
49}
50
51#[derive(Debug, Deserialize)]
53pub struct BlockUserRequest {
54 pub user_id: String,
56 pub reason: String,
58}
59
60pub async fn set_risk_score(
62 State(state): State<RiskSimulationState>,
63 Json(request): Json<SetRiskScoreRequest>,
64) -> Result<Json<serde_json::Value>, StatusCode> {
65 state
66 .risk_engine
67 .set_simulated_risk(request.user_id.clone(), Some(request.risk_score))
68 .await;
69
70 Ok(Json(serde_json::json!({
71 "success": true,
72 "user_id": request.user_id,
73 "risk_score": request.risk_score
74 })))
75}
76
77pub async fn set_risk_factors(
79 State(state): State<RiskSimulationState>,
80 Json(request): Json<SetRiskFactorsRequest>,
81) -> Result<Json<serde_json::Value>, StatusCode> {
82 state
83 .risk_engine
84 .set_simulated_factors(request.user_id.clone(), request.risk_factors.clone())
85 .await;
86
87 Ok(Json(serde_json::json!({
88 "success": true,
89 "user_id": request.user_id,
90 "risk_factors": request.risk_factors
91 })))
92}
93
94pub async fn clear_risk(
96 State(state): State<RiskSimulationState>,
97 axum::extract::Path(user_id): axum::extract::Path<String>,
98) -> Result<Json<serde_json::Value>, StatusCode> {
99 state.risk_engine.clear_simulated_risk(&user_id).await;
100
101 Ok(Json(serde_json::json!({
102 "success": true,
103 "user_id": user_id,
104 "message": "Simulated risk cleared"
105 })))
106}
107
108pub async fn trigger_mfa(
110 State(state): State<RiskSimulationState>,
111 Json(request): Json<TriggerMfaRequest>,
112) -> Result<Json<serde_json::Value>, StatusCode> {
113 state
115 .risk_engine
116 .set_simulated_risk(request.user_id.clone(), Some(0.8))
117 .await;
118
119 Ok(Json(serde_json::json!({
120 "success": true,
121 "user_id": request.user_id,
122 "mfa_type": request.mfa_type,
123 "message": "MFA prompt triggered"
124 })))
125}
126
127pub async fn block_user(
129 State(state): State<RiskSimulationState>,
130 Json(request): Json<BlockUserRequest>,
131) -> Result<Json<serde_json::Value>, StatusCode> {
132 state
134 .risk_engine
135 .set_simulated_risk(request.user_id.clone(), Some(0.95))
136 .await;
137
138 Ok(Json(serde_json::json!({
139 "success": true,
140 "user_id": request.user_id,
141 "reason": request.reason,
142 "message": "User login blocked"
143 })))
144}
145
146pub async fn get_risk_assessment(
148 State(state): State<RiskSimulationState>,
149 axum::extract::Path(user_id): axum::extract::Path<String>,
150) -> Json<serde_json::Value> {
151 let risk_factors = HashMap::new();
152 let assessment = state
153 .risk_engine
154 .assess_risk(&user_id, &risk_factors)
155 .await;
156
157 Json(serde_json::json!({
158 "user_id": user_id,
159 "risk_score": assessment.risk_score,
160 "risk_factors": assessment.risk_factors,
161 "recommended_action": assessment.recommended_action
162 }))
163}
164
165pub fn risk_simulation_router(state: RiskSimulationState) -> axum::Router {
167 use axum::routing::{get, post, delete};
168
169 axum::Router::new()
170 .route("/risk/simulate", post(set_risk_score))
171 .route("/risk/factors", post(set_risk_factors))
172 .route("/risk/clear/:user_id", delete(clear_risk))
173 .route("/risk/trigger-mfa", post(trigger_mfa))
174 .route("/risk/block", post(block_user))
175 .route("/risk/assessment/:user_id", get(get_risk_assessment))
176 .with_state(state)
177}
178