mockforge_http/auth/
risk_engine.rs

1//! Risk assessment engine for authentication
2//!
3//! This module provides risk-based authentication challenges including
4//! MFA prompts, device challenges, and blocked login simulation.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11/// Risk assessment result
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct RiskAssessment {
14    /// Overall risk score (0.0 - 1.0)
15    pub risk_score: f64,
16    /// Risk factors contributing to the score
17    pub risk_factors: Vec<RiskFactor>,
18    /// Recommended action based on risk
19    pub recommended_action: RiskAction,
20}
21
22/// Individual risk factor
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct RiskFactor {
25    /// Factor name
26    pub name: String,
27    /// Factor weight
28    pub weight: f64,
29    /// Factor value (0.0 - 1.0)
30    pub value: f64,
31    /// Contribution to overall risk score
32    pub contribution: f64,
33}
34
35/// Risk-based action
36#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
37pub enum RiskAction {
38    /// Allow normal authentication
39    Allow,
40    /// Require device challenge
41    DeviceChallenge,
42    /// Require MFA
43    RequireMfa,
44    /// Block login
45    Block,
46}
47
48/// Risk engine configuration
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct RiskEngineConfig {
51    /// MFA threshold (0.0 - 1.0)
52    pub mfa_threshold: f64,
53    /// Device challenge threshold (0.0 - 1.0)
54    pub device_challenge_threshold: f64,
55    /// Blocked login threshold (0.0 - 1.0)
56    pub blocked_login_threshold: f64,
57    /// Risk factor weights
58    pub risk_factors: HashMap<String, f64>,
59    /// Risk rules (conditions -> actions)
60    pub risk_rules: Vec<RiskRule>,
61}
62
63/// Risk rule
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct RiskRule {
66    /// Condition (e.g., "risk_score > 0.9")
67    pub condition: String,
68    /// Action to take
69    pub action: RiskAction,
70}
71
72impl Default for RiskEngineConfig {
73    fn default() -> Self {
74        let mut risk_factors = HashMap::new();
75        risk_factors.insert("new_device".to_string(), 0.3);
76        risk_factors.insert("unusual_location".to_string(), 0.4);
77        risk_factors.insert("suspicious_activity".to_string(), 0.5);
78
79        let mut risk_rules = Vec::new();
80        risk_rules.push(RiskRule {
81            condition: "risk_score > 0.9".to_string(),
82            action: RiskAction::Block,
83        });
84        risk_rules.push(RiskRule {
85            condition: "risk_score > 0.7".to_string(),
86            action: RiskAction::RequireMfa,
87        });
88        risk_rules.push(RiskRule {
89            condition: "risk_score > 0.5".to_string(),
90            action: RiskAction::DeviceChallenge,
91        });
92
93        Self {
94            mfa_threshold: 0.7,
95            device_challenge_threshold: 0.5,
96            blocked_login_threshold: 0.9,
97            risk_factors,
98            risk_rules,
99        }
100    }
101}
102
/// Risk engine state.
///
/// Clones share the simulation maps, since both are behind `Arc`. Each clone gets its own
/// copy of `config`.
#[derive(Debug, Clone)]
pub struct RiskEngine {
    /// Configuration (thresholds, factor weights and rules)
    pub config: RiskEngineConfig,
    /// Simulated risk scores (user_id -> risk_score override). A `Some` score is used
    /// directly when assessing that user, bypassing factor weighting. A `None` entry is
    /// treated the same as no entry.
    pub simulated_risks: Arc<RwLock<HashMap<String, Option<f64>>>>,
    /// Simulated risk factors (user_id -> risk_factors override), used instead of the
    /// caller-supplied factors when assessing that user
    pub simulated_factors: Arc<RwLock<HashMap<String, HashMap<String, f64>>>>,
}
113
114impl RiskEngine {
115    /// Create new risk engine
116    pub fn new(config: RiskEngineConfig) -> Self {
117        Self {
118            config,
119            simulated_risks: Arc::new(RwLock::new(HashMap::new())),
120            simulated_factors: Arc::new(RwLock::new(HashMap::new())),
121        }
122    }
123
124    /// Assess risk for an authentication request
125    pub async fn assess_risk(
126        &self,
127        user_id: &str,
128        risk_factors: &HashMap<String, f64>,
129    ) -> RiskAssessment {
130        // Check for simulated risk score override
131        let simulated_risk = {
132            let risks = self.simulated_risks.read().await;
133            risks.get(user_id).copied().flatten()
134        };
135
136        if let Some(risk_score) = simulated_risk {
137            return self.create_assessment_from_score(risk_score);
138        }
139
140        // Check for simulated risk factors override
141        let factors_to_use = {
142            let simulated = self.simulated_factors.read().await;
143            if let Some(simulated_factors) = simulated.get(user_id) {
144                simulated_factors.clone()
145            } else {
146                risk_factors.clone()
147            }
148        };
149
150        // Calculate risk score from factors
151        let mut risk_factors_vec = Vec::new();
152        let mut total_score = 0.0;
153
154        for (name, value) in factors_to_use {
155            let weight = self.config.risk_factors.get(&name).copied().unwrap_or(0.0);
156            let contribution = weight * value;
157            total_score += contribution;
158
159            risk_factors_vec.push(RiskFactor {
160                name: name.clone(),
161                weight,
162                value,
163                contribution,
164            });
165        }
166
167        // Clamp score to 0.0 - 1.0
168        let risk_score = total_score.min(1.0).max(0.0);
169
170        // Determine recommended action
171        let recommended_action = self.determine_action(risk_score);
172
173        RiskAssessment {
174            risk_score,
175            risk_factors: risk_factors_vec,
176            recommended_action,
177        }
178    }
179
180    /// Create assessment from a risk score (for simulation)
181    fn create_assessment_from_score(&self, risk_score: f64) -> RiskAssessment {
182        let recommended_action = self.determine_action(risk_score);
183
184        RiskAssessment {
185            risk_score,
186            risk_factors: vec![],
187            recommended_action,
188        }
189    }
190
191    /// Determine action based on risk score
192    fn determine_action(&self, risk_score: f64) -> RiskAction {
193        // Check risk rules first
194        for rule in &self.config.risk_rules {
195            if self.evaluate_condition(&rule.condition, risk_score) {
196                return rule.action.clone();
197            }
198        }
199
200        // Fallback to threshold-based logic
201        if risk_score >= self.config.blocked_login_threshold {
202            RiskAction::Block
203        } else if risk_score >= self.config.mfa_threshold {
204            RiskAction::RequireMfa
205        } else if risk_score >= self.config.device_challenge_threshold {
206            RiskAction::DeviceChallenge
207        } else {
208            RiskAction::Allow
209        }
210    }
211
212    /// Evaluate a risk condition
213    fn evaluate_condition(&self, condition: &str, risk_score: f64) -> bool {
214        // Simple condition evaluation
215        // In production, use a proper expression evaluator
216        if condition.contains(">") {
217            let parts: Vec<&str> = condition.split('>').collect();
218            if parts.len() == 2 {
219                if let Ok(threshold) = parts[1].trim().parse::<f64>() {
220                    return risk_score > threshold;
221                }
222            }
223        } else if condition.contains("<") {
224            let parts: Vec<&str> = condition.split('<').collect();
225            if parts.len() == 2 {
226                if let Ok(threshold) = parts[1].trim().parse::<f64>() {
227                    return risk_score < threshold;
228                }
229            }
230        } else if condition.contains(">=") {
231            let parts: Vec<&str> = condition.split(">=").collect();
232            if parts.len() == 2 {
233                if let Ok(threshold) = parts[1].trim().parse::<f64>() {
234                    return risk_score >= threshold;
235                }
236            }
237        } else if condition.contains("<=") {
238            let parts: Vec<&str> = condition.split("<=").collect();
239            if parts.len() == 2 {
240                if let Ok(threshold) = parts[1].trim().parse::<f64>() {
241                    return risk_score <= threshold;
242                }
243            }
244        } else if condition.contains("==") {
245            let parts: Vec<&str> = condition.split("==").collect();
246            if parts.len() == 2 {
247                if let Ok(threshold) = parts[1].trim().parse::<f64>() {
248                    return (risk_score - threshold).abs() < 0.001;
249                }
250            }
251        }
252
253        false
254    }
255
256    /// Set simulated risk score for a user
257    pub async fn set_simulated_risk(&self, user_id: String, risk_score: Option<f64>) {
258        let mut risks = self.simulated_risks.write().await;
259        if let Some(score) = risk_score {
260            risks.insert(user_id, Some(score));
261        } else {
262            risks.remove(&user_id);
263        }
264    }
265
266    /// Set simulated risk factors for a user
267    pub async fn set_simulated_factors(&self, user_id: String, factors: HashMap<String, f64>) {
268        let mut simulated = self.simulated_factors.write().await;
269        simulated.insert(user_id, factors);
270    }
271
272    /// Clear simulated risk for a user
273    pub async fn clear_simulated_risk(&self, user_id: &str) {
274        let mut risks = self.simulated_risks.write().await;
275        risks.remove(user_id);
276        let mut factors = self.simulated_factors.write().await;
277        factors.remove(user_id);
278    }
279}
280
281impl Default for RiskEngine {
282    fn default() -> Self {
283        Self::new(RiskEngineConfig::default())
284    }
285}