Skip to main content

mockforge_http/auth/
risk_engine.rs

//! Risk assessment engine for authentication.
//!
//! This module provides risk-based authentication challenges, including
//! MFA prompts, device challenges, and blocked-login simulation.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11/// Risk assessment result
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct RiskAssessment {
14    /// Overall risk score (0.0 - 1.0)
15    pub risk_score: f64,
16    /// Risk factors contributing to the score
17    pub risk_factors: Vec<RiskFactor>,
18    /// Recommended action based on risk
19    pub recommended_action: RiskAction,
20}
21
22/// Individual risk factor
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct RiskFactor {
25    /// Factor name
26    pub name: String,
27    /// Factor weight
28    pub weight: f64,
29    /// Factor value (0.0 - 1.0)
30    pub value: f64,
31    /// Contribution to overall risk score
32    pub contribution: f64,
33}
34
35/// Risk-based action
36#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
37pub enum RiskAction {
38    /// Allow normal authentication
39    Allow,
40    /// Require device challenge
41    DeviceChallenge,
42    /// Require MFA
43    RequireMfa,
44    /// Block login
45    Block,
46}
47
48/// Risk engine configuration
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct RiskEngineConfig {
51    /// MFA threshold (0.0 - 1.0)
52    pub mfa_threshold: f64,
53    /// Device challenge threshold (0.0 - 1.0)
54    pub device_challenge_threshold: f64,
55    /// Blocked login threshold (0.0 - 1.0)
56    pub blocked_login_threshold: f64,
57    /// Risk factor weights
58    pub risk_factors: HashMap<String, f64>,
59    /// Risk rules (conditions -> actions)
60    pub risk_rules: Vec<RiskRule>,
61}
62
63/// Risk rule
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct RiskRule {
66    /// Condition (e.g., "risk_score > 0.9")
67    pub condition: String,
68    /// Action to take
69    pub action: RiskAction,
70}
71
72impl Default for RiskEngineConfig {
73    fn default() -> Self {
74        let mut risk_factors = HashMap::new();
75        risk_factors.insert("new_device".to_string(), 0.3);
76        risk_factors.insert("unusual_location".to_string(), 0.4);
77        risk_factors.insert("suspicious_activity".to_string(), 0.5);
78
79        let risk_rules = vec![
80            RiskRule {
81                condition: "risk_score > 0.9".to_string(),
82                action: RiskAction::Block,
83            },
84            RiskRule {
85                condition: "risk_score > 0.7".to_string(),
86                action: RiskAction::RequireMfa,
87            },
88            RiskRule {
89                condition: "risk_score > 0.5".to_string(),
90                action: RiskAction::DeviceChallenge,
91            },
92        ];
93
94        Self {
95            mfa_threshold: 0.7,
96            device_challenge_threshold: 0.5,
97            blocked_login_threshold: 0.9,
98            risk_factors,
99            risk_rules,
100        }
101    }
102}
103
104/// Risk engine state
105#[derive(Debug, Clone)]
106pub struct RiskEngine {
107    /// Configuration
108    pub config: RiskEngineConfig,
109    /// Simulated risk scores (user_id -> risk_score override)
110    pub simulated_risks: Arc<RwLock<HashMap<String, Option<f64>>>>,
111    /// Simulated risk factors (user_id -> risk_factors override)
112    pub simulated_factors: Arc<RwLock<HashMap<String, HashMap<String, f64>>>>,
113}
114
115impl RiskEngine {
116    /// Create new risk engine
117    pub fn new(config: RiskEngineConfig) -> Self {
118        Self {
119            config,
120            simulated_risks: Arc::new(RwLock::new(HashMap::new())),
121            simulated_factors: Arc::new(RwLock::new(HashMap::new())),
122        }
123    }
124
125    /// Assess risk for an authentication request
126    pub async fn assess_risk(
127        &self,
128        user_id: &str,
129        risk_factors: &HashMap<String, f64>,
130    ) -> RiskAssessment {
131        // Check for simulated risk score override
132        let simulated_risk = {
133            let risks = self.simulated_risks.read().await;
134            risks.get(user_id).copied().flatten()
135        };
136
137        if let Some(risk_score) = simulated_risk {
138            return self.create_assessment_from_score(risk_score);
139        }
140
141        // Check for simulated risk factors override
142        let factors_to_use = {
143            let simulated = self.simulated_factors.read().await;
144            if let Some(simulated_factors) = simulated.get(user_id) {
145                simulated_factors.clone()
146            } else {
147                risk_factors.clone()
148            }
149        };
150
151        // Calculate risk score from factors
152        let mut risk_factors_vec = Vec::new();
153        let mut total_score = 0.0;
154
155        for (name, value) in factors_to_use {
156            let weight = self.config.risk_factors.get(&name).copied().unwrap_or(0.0);
157            let contribution = weight * value;
158            total_score += contribution;
159
160            risk_factors_vec.push(RiskFactor {
161                name: name.clone(),
162                weight,
163                value,
164                contribution,
165            });
166        }
167
168        // Clamp score to 0.0 - 1.0
169        let risk_score = total_score.clamp(0.0, 1.0);
170
171        // Determine recommended action
172        let recommended_action = self.determine_action(risk_score);
173
174        RiskAssessment {
175            risk_score,
176            risk_factors: risk_factors_vec,
177            recommended_action,
178        }
179    }
180
181    /// Create assessment from a risk score (for simulation)
182    fn create_assessment_from_score(&self, risk_score: f64) -> RiskAssessment {
183        let recommended_action = self.determine_action(risk_score);
184
185        RiskAssessment {
186            risk_score,
187            risk_factors: vec![],
188            recommended_action,
189        }
190    }
191
192    /// Determine action based on risk score
193    fn determine_action(&self, risk_score: f64) -> RiskAction {
194        // Check risk rules first
195        for rule in &self.config.risk_rules {
196            if self.evaluate_condition(&rule.condition, risk_score) {
197                return rule.action.clone();
198            }
199        }
200
201        // Fallback to threshold-based logic
202        if risk_score >= self.config.blocked_login_threshold {
203            RiskAction::Block
204        } else if risk_score >= self.config.mfa_threshold {
205            RiskAction::RequireMfa
206        } else if risk_score >= self.config.device_challenge_threshold {
207            RiskAction::DeviceChallenge
208        } else {
209            RiskAction::Allow
210        }
211    }
212
213    /// Evaluate a risk condition
214    fn evaluate_condition(&self, condition: &str, risk_score: f64) -> bool {
215        // Check multi-character operators before single-character ones
216        // to avoid `>=` being incorrectly matched by `>`.
217        if condition.contains(">=") {
218            let parts: Vec<&str> = condition.split(">=").collect();
219            if parts.len() == 2 {
220                if let Ok(threshold) = parts[1].trim().parse::<f64>() {
221                    return risk_score >= threshold;
222                }
223            }
224        } else if condition.contains("<=") {
225            let parts: Vec<&str> = condition.split("<=").collect();
226            if parts.len() == 2 {
227                if let Ok(threshold) = parts[1].trim().parse::<f64>() {
228                    return risk_score <= threshold;
229                }
230            }
231        } else if condition.contains("==") {
232            let parts: Vec<&str> = condition.split("==").collect();
233            if parts.len() == 2 {
234                if let Ok(threshold) = parts[1].trim().parse::<f64>() {
235                    return (risk_score - threshold).abs() < 0.001;
236                }
237            }
238        } else if condition.contains('>') {
239            let parts: Vec<&str> = condition.split('>').collect();
240            if parts.len() == 2 {
241                if let Ok(threshold) = parts[1].trim().parse::<f64>() {
242                    return risk_score > threshold;
243                }
244            }
245        } else if condition.contains('<') {
246            let parts: Vec<&str> = condition.split('<').collect();
247            if parts.len() == 2 {
248                if let Ok(threshold) = parts[1].trim().parse::<f64>() {
249                    return risk_score < threshold;
250                }
251            }
252        }
253
254        false
255    }
256
257    /// Set simulated risk score for a user
258    pub async fn set_simulated_risk(&self, user_id: String, risk_score: Option<f64>) {
259        let mut risks = self.simulated_risks.write().await;
260        if let Some(score) = risk_score {
261            risks.insert(user_id, Some(score));
262        } else {
263            risks.remove(&user_id);
264        }
265    }
266
267    /// Set simulated risk factors for a user
268    pub async fn set_simulated_factors(&self, user_id: String, factors: HashMap<String, f64>) {
269        let mut simulated = self.simulated_factors.write().await;
270        simulated.insert(user_id, factors);
271    }
272
273    /// Clear simulated risk for a user
274    pub async fn clear_simulated_risk(&self, user_id: &str) {
275        let mut risks = self.simulated_risks.write().await;
276        risks.remove(user_id);
277        let mut factors = self.simulated_factors.write().await;
278        factors.remove(user_id);
279    }
280}
281
282impl Default for RiskEngine {
283    fn default() -> Self {
284        Self::new(RiskEngineConfig::default())
285    }
286}