1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct RiskAssessment {
14 pub risk_score: f64,
16 pub risk_factors: Vec<RiskFactor>,
18 pub recommended_action: RiskAction,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct RiskFactor {
25 pub name: String,
27 pub weight: f64,
29 pub value: f64,
31 pub contribution: f64,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
37pub enum RiskAction {
38 Allow,
40 DeviceChallenge,
42 RequireMfa,
44 Block,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct RiskEngineConfig {
51 pub mfa_threshold: f64,
53 pub device_challenge_threshold: f64,
55 pub blocked_login_threshold: f64,
57 pub risk_factors: HashMap<String, f64>,
59 pub risk_rules: Vec<RiskRule>,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct RiskRule {
66 pub condition: String,
68 pub action: RiskAction,
70}
71
72impl Default for RiskEngineConfig {
73 fn default() -> Self {
74 let mut risk_factors = HashMap::new();
75 risk_factors.insert("new_device".to_string(), 0.3);
76 risk_factors.insert("unusual_location".to_string(), 0.4);
77 risk_factors.insert("suspicious_activity".to_string(), 0.5);
78
79 let risk_rules = vec![
80 RiskRule {
81 condition: "risk_score > 0.9".to_string(),
82 action: RiskAction::Block,
83 },
84 RiskRule {
85 condition: "risk_score > 0.7".to_string(),
86 action: RiskAction::RequireMfa,
87 },
88 RiskRule {
89 condition: "risk_score > 0.5".to_string(),
90 action: RiskAction::DeviceChallenge,
91 },
92 ];
93
94 Self {
95 mfa_threshold: 0.7,
96 device_challenge_threshold: 0.5,
97 blocked_login_threshold: 0.9,
98 risk_factors,
99 risk_rules,
100 }
101 }
102}
103
104#[derive(Debug, Clone)]
106pub struct RiskEngine {
107 pub config: RiskEngineConfig,
109 pub simulated_risks: Arc<RwLock<HashMap<String, Option<f64>>>>,
111 pub simulated_factors: Arc<RwLock<HashMap<String, HashMap<String, f64>>>>,
113}
114
115impl RiskEngine {
116 pub fn new(config: RiskEngineConfig) -> Self {
118 Self {
119 config,
120 simulated_risks: Arc::new(RwLock::new(HashMap::new())),
121 simulated_factors: Arc::new(RwLock::new(HashMap::new())),
122 }
123 }
124
125 pub async fn assess_risk(
127 &self,
128 user_id: &str,
129 risk_factors: &HashMap<String, f64>,
130 ) -> RiskAssessment {
131 let simulated_risk = {
133 let risks = self.simulated_risks.read().await;
134 risks.get(user_id).copied().flatten()
135 };
136
137 if let Some(risk_score) = simulated_risk {
138 return self.create_assessment_from_score(risk_score);
139 }
140
141 let factors_to_use = {
143 let simulated = self.simulated_factors.read().await;
144 if let Some(simulated_factors) = simulated.get(user_id) {
145 simulated_factors.clone()
146 } else {
147 risk_factors.clone()
148 }
149 };
150
151 let mut risk_factors_vec = Vec::new();
153 let mut total_score = 0.0;
154
155 for (name, value) in factors_to_use {
156 let weight = self.config.risk_factors.get(&name).copied().unwrap_or(0.0);
157 let contribution = weight * value;
158 total_score += contribution;
159
160 risk_factors_vec.push(RiskFactor {
161 name: name.clone(),
162 weight,
163 value,
164 contribution,
165 });
166 }
167
168 let risk_score = total_score.clamp(0.0, 1.0);
170
171 let recommended_action = self.determine_action(risk_score);
173
174 RiskAssessment {
175 risk_score,
176 risk_factors: risk_factors_vec,
177 recommended_action,
178 }
179 }
180
181 fn create_assessment_from_score(&self, risk_score: f64) -> RiskAssessment {
183 let recommended_action = self.determine_action(risk_score);
184
185 RiskAssessment {
186 risk_score,
187 risk_factors: vec![],
188 recommended_action,
189 }
190 }
191
192 fn determine_action(&self, risk_score: f64) -> RiskAction {
194 for rule in &self.config.risk_rules {
196 if self.evaluate_condition(&rule.condition, risk_score) {
197 return rule.action.clone();
198 }
199 }
200
201 if risk_score >= self.config.blocked_login_threshold {
203 RiskAction::Block
204 } else if risk_score >= self.config.mfa_threshold {
205 RiskAction::RequireMfa
206 } else if risk_score >= self.config.device_challenge_threshold {
207 RiskAction::DeviceChallenge
208 } else {
209 RiskAction::Allow
210 }
211 }
212
213 fn evaluate_condition(&self, condition: &str, risk_score: f64) -> bool {
215 if condition.contains(">") {
218 let parts: Vec<&str> = condition.split('>').collect();
219 if parts.len() == 2 {
220 if let Ok(threshold) = parts[1].trim().parse::<f64>() {
221 return risk_score > threshold;
222 }
223 }
224 } else if condition.contains("<") {
225 let parts: Vec<&str> = condition.split('<').collect();
226 if parts.len() == 2 {
227 if let Ok(threshold) = parts[1].trim().parse::<f64>() {
228 return risk_score < threshold;
229 }
230 }
231 } else if condition.contains(">=") {
232 let parts: Vec<&str> = condition.split(">=").collect();
233 if parts.len() == 2 {
234 if let Ok(threshold) = parts[1].trim().parse::<f64>() {
235 return risk_score >= threshold;
236 }
237 }
238 } else if condition.contains("<=") {
239 let parts: Vec<&str> = condition.split("<=").collect();
240 if parts.len() == 2 {
241 if let Ok(threshold) = parts[1].trim().parse::<f64>() {
242 return risk_score <= threshold;
243 }
244 }
245 } else if condition.contains("==") {
246 let parts: Vec<&str> = condition.split("==").collect();
247 if parts.len() == 2 {
248 if let Ok(threshold) = parts[1].trim().parse::<f64>() {
249 return (risk_score - threshold).abs() < 0.001;
250 }
251 }
252 }
253
254 false
255 }
256
257 pub async fn set_simulated_risk(&self, user_id: String, risk_score: Option<f64>) {
259 let mut risks = self.simulated_risks.write().await;
260 if let Some(score) = risk_score {
261 risks.insert(user_id, Some(score));
262 } else {
263 risks.remove(&user_id);
264 }
265 }
266
267 pub async fn set_simulated_factors(&self, user_id: String, factors: HashMap<String, f64>) {
269 let mut simulated = self.simulated_factors.write().await;
270 simulated.insert(user_id, factors);
271 }
272
273 pub async fn clear_simulated_risk(&self, user_id: &str) {
275 let mut risks = self.simulated_risks.write().await;
276 risks.remove(user_id);
277 let mut factors = self.simulated_factors.write().await;
278 factors.remove(user_id);
279 }
280}
281
282impl Default for RiskEngine {
283 fn default() -> Self {
284 Self::new(RiskEngineConfig::default())
285 }
286}