use axum::{extract::State, http::StatusCode, response::Json};
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::Arc;
use crate::auth::risk_engine::RiskEngine;
/// Shared state handed to every risk-simulation handler.
///
/// axum's `State` extractor clones this value per request; the clone is cheap
/// because the engine lives behind an [`Arc`].
#[derive(Clone)]
pub struct RiskSimulationState {
    /// Engine whose simulated-risk overrides the handlers write and read.
    pub risk_engine: Arc<RiskEngine>,
}
/// JSON body for `POST /risk/simulate`.
#[derive(Debug, Deserialize)]
pub struct SetRiskScoreRequest {
    /// Identifier of the user whose score is overridden.
    pub user_id: String,
    /// Simulated risk score to store for `user_id`.
    pub risk_score: f64,
}
/// JSON body for `POST /risk/factors`.
#[derive(Debug, Deserialize)]
pub struct SetRiskFactorsRequest {
    /// Identifier of the user whose factors are overridden.
    pub user_id: String,
    /// Simulated factor values, keyed by factor name.
    pub risk_factors: HashMap<String, f64>,
}
/// JSON body for `POST /risk/trigger-mfa`.
#[derive(Debug, Deserialize)]
pub struct TriggerMfaRequest {
    /// Identifier of the user to push onto the MFA path.
    pub user_id: String,
    /// Requested MFA method; the handler only echoes it back.
    pub mfa_type: String,
}
/// JSON body for `POST /risk/block`.
#[derive(Debug, Deserialize)]
pub struct BlockUserRequest {
    /// Identifier of the user whose logins should be blocked.
    pub user_id: String,
    /// Free-form justification; the handler only echoes it back.
    pub reason: String,
}
pub async fn set_risk_score(
State(state): State<RiskSimulationState>,
Json(request): Json<SetRiskScoreRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
state
.risk_engine
.set_simulated_risk(request.user_id.clone(), Some(request.risk_score))
.await;
Ok(Json(serde_json::json!({
"success": true,
"user_id": request.user_id,
"risk_score": request.risk_score
})))
}
pub async fn set_risk_factors(
State(state): State<RiskSimulationState>,
Json(request): Json<SetRiskFactorsRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
state
.risk_engine
.set_simulated_factors(request.user_id.clone(), request.risk_factors.clone())
.await;
Ok(Json(serde_json::json!({
"success": true,
"user_id": request.user_id,
"risk_factors": request.risk_factors
})))
}
/// Drops any simulated risk override for the user named in the URL path.
///
/// This handler never returns an error. It replies with the `user_id` and a
/// confirmation message.
pub async fn clear_risk(
    State(state): State<RiskSimulationState>,
    axum::extract::Path(user_id): axum::extract::Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let engine = &state.risk_engine;
    engine.clear_simulated_risk(&user_id).await;

    let body = serde_json::json!({
        "success": true,
        "user_id": user_id,
        "message": "Simulated risk cleared"
    });
    Ok(Json(body))
}
pub async fn trigger_mfa(
State(state): State<RiskSimulationState>,
Json(request): Json<TriggerMfaRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
state.risk_engine.set_simulated_risk(request.user_id.clone(), Some(0.8)).await;
Ok(Json(serde_json::json!({
"success": true,
"user_id": request.user_id,
"mfa_type": request.mfa_type,
"message": "MFA prompt triggered"
})))
}
pub async fn block_user(
State(state): State<RiskSimulationState>,
Json(request): Json<BlockUserRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
state.risk_engine.set_simulated_risk(request.user_id.clone(), Some(0.95)).await;
Ok(Json(serde_json::json!({
"success": true,
"user_id": request.user_id,
"reason": request.reason,
"message": "User login blocked"
})))
}
/// Reports the engine's risk assessment for the user named in the URL path.
///
/// `assess_risk` is called with an empty factor map. The response carries
/// the assessment's score, factors and recommended action.
pub async fn get_risk_assessment(
    State(state): State<RiskSimulationState>,
    axum::extract::Path(user_id): axum::extract::Path<String>,
) -> Json<serde_json::Value> {
    // No request-time factors are supplied; the map's key/value types are
    // inferred from `assess_risk`'s signature.
    let no_extra_factors = HashMap::new();
    let assessment = state
        .risk_engine
        .assess_risk(&user_id, &no_extra_factors)
        .await;

    let response = serde_json::json!({
        "user_id": user_id,
        "risk_score": assessment.risk_score,
        "risk_factors": assessment.risk_factors,
        "recommended_action": assessment.recommended_action
    });
    Json(response)
}
/// Builds the router for the risk-simulation endpoints, bound to `state`.
///
/// | Method   | Path                         | Handler                 |
/// |----------|------------------------------|-------------------------|
/// | `POST`   | `/risk/simulate`             | [`set_risk_score`]      |
/// | `POST`   | `/risk/factors`              | [`set_risk_factors`]    |
/// | `DELETE` | `/risk/clear/{user_id}`      | [`clear_risk`]          |
/// | `POST`   | `/risk/trigger-mfa`          | [`trigger_mfa`]         |
/// | `POST`   | `/risk/block`                | [`block_user`]          |
/// | `GET`    | `/risk/assessment/{user_id}` | [`get_risk_assessment`] |
///
/// NOTE(review): the write endpoints let a caller override risk decisions for
/// any user id. Presumably this router should only be mounted in test/dev
/// builds or behind admin auth; confirm where it is nested.
pub fn risk_simulation_router(state: RiskSimulationState) -> axum::Router {
    use axum::routing::{delete, get, post};

    let routes = axum::Router::new()
        // Endpoints that mutate the simulated risk state.
        .route("/risk/simulate", post(set_risk_score))
        .route("/risk/factors", post(set_risk_factors))
        .route("/risk/clear/{user_id}", delete(clear_risk))
        .route("/risk/trigger-mfa", post(trigger_mfa))
        .route("/risk/block", post(block_user))
        // Read-only endpoint.
        .route("/risk/assessment/{user_id}", get(get_risk_assessment));

    routes.with_state(state)
}