1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct RiskAssessment {
14 pub risk_score: f64,
16 pub risk_factors: Vec<RiskFactor>,
18 pub recommended_action: RiskAction,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct RiskFactor {
25 pub name: String,
27 pub weight: f64,
29 pub value: f64,
31 pub contribution: f64,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
37pub enum RiskAction {
38 Allow,
40 DeviceChallenge,
42 RequireMfa,
44 Block,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct RiskEngineConfig {
51 pub mfa_threshold: f64,
53 pub device_challenge_threshold: f64,
55 pub blocked_login_threshold: f64,
57 pub risk_factors: HashMap<String, f64>,
59 pub risk_rules: Vec<RiskRule>,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct RiskRule {
66 pub condition: String,
68 pub action: RiskAction,
70}
71
72impl Default for RiskEngineConfig {
73 fn default() -> Self {
74 let mut risk_factors = HashMap::new();
75 risk_factors.insert("new_device".to_string(), 0.3);
76 risk_factors.insert("unusual_location".to_string(), 0.4);
77 risk_factors.insert("suspicious_activity".to_string(), 0.5);
78
79 let mut risk_rules = Vec::new();
80 risk_rules.push(RiskRule {
81 condition: "risk_score > 0.9".to_string(),
82 action: RiskAction::Block,
83 });
84 risk_rules.push(RiskRule {
85 condition: "risk_score > 0.7".to_string(),
86 action: RiskAction::RequireMfa,
87 });
88 risk_rules.push(RiskRule {
89 condition: "risk_score > 0.5".to_string(),
90 action: RiskAction::DeviceChallenge,
91 });
92
93 Self {
94 mfa_threshold: 0.7,
95 device_challenge_threshold: 0.5,
96 blocked_login_threshold: 0.9,
97 risk_factors,
98 risk_rules,
99 }
100 }
101}
102
103#[derive(Debug, Clone)]
105pub struct RiskEngine {
106 pub config: RiskEngineConfig,
108 pub simulated_risks: Arc<RwLock<HashMap<String, Option<f64>>>>,
110 pub simulated_factors: Arc<RwLock<HashMap<String, HashMap<String, f64>>>>,
112}
113
114impl RiskEngine {
115 pub fn new(config: RiskEngineConfig) -> Self {
117 Self {
118 config,
119 simulated_risks: Arc::new(RwLock::new(HashMap::new())),
120 simulated_factors: Arc::new(RwLock::new(HashMap::new())),
121 }
122 }
123
124 pub async fn assess_risk(
126 &self,
127 user_id: &str,
128 risk_factors: &HashMap<String, f64>,
129 ) -> RiskAssessment {
130 let simulated_risk = {
132 let risks = self.simulated_risks.read().await;
133 risks.get(user_id).copied().flatten()
134 };
135
136 if let Some(risk_score) = simulated_risk {
137 return self.create_assessment_from_score(risk_score);
138 }
139
140 let factors_to_use = {
142 let simulated = self.simulated_factors.read().await;
143 if let Some(simulated_factors) = simulated.get(user_id) {
144 simulated_factors.clone()
145 } else {
146 risk_factors.clone()
147 }
148 };
149
150 let mut risk_factors_vec = Vec::new();
152 let mut total_score = 0.0;
153
154 for (name, value) in factors_to_use {
155 let weight = self.config.risk_factors.get(&name).copied().unwrap_or(0.0);
156 let contribution = weight * value;
157 total_score += contribution;
158
159 risk_factors_vec.push(RiskFactor {
160 name: name.clone(),
161 weight,
162 value,
163 contribution,
164 });
165 }
166
167 let risk_score = total_score.min(1.0).max(0.0);
169
170 let recommended_action = self.determine_action(risk_score);
172
173 RiskAssessment {
174 risk_score,
175 risk_factors: risk_factors_vec,
176 recommended_action,
177 }
178 }
179
180 fn create_assessment_from_score(&self, risk_score: f64) -> RiskAssessment {
182 let recommended_action = self.determine_action(risk_score);
183
184 RiskAssessment {
185 risk_score,
186 risk_factors: vec![],
187 recommended_action,
188 }
189 }
190
191 fn determine_action(&self, risk_score: f64) -> RiskAction {
193 for rule in &self.config.risk_rules {
195 if self.evaluate_condition(&rule.condition, risk_score) {
196 return rule.action.clone();
197 }
198 }
199
200 if risk_score >= self.config.blocked_login_threshold {
202 RiskAction::Block
203 } else if risk_score >= self.config.mfa_threshold {
204 RiskAction::RequireMfa
205 } else if risk_score >= self.config.device_challenge_threshold {
206 RiskAction::DeviceChallenge
207 } else {
208 RiskAction::Allow
209 }
210 }
211
212 fn evaluate_condition(&self, condition: &str, risk_score: f64) -> bool {
214 if condition.contains(">") {
217 let parts: Vec<&str> = condition.split('>').collect();
218 if parts.len() == 2 {
219 if let Ok(threshold) = parts[1].trim().parse::<f64>() {
220 return risk_score > threshold;
221 }
222 }
223 } else if condition.contains("<") {
224 let parts: Vec<&str> = condition.split('<').collect();
225 if parts.len() == 2 {
226 if let Ok(threshold) = parts[1].trim().parse::<f64>() {
227 return risk_score < threshold;
228 }
229 }
230 } else if condition.contains(">=") {
231 let parts: Vec<&str> = condition.split(">=").collect();
232 if parts.len() == 2 {
233 if let Ok(threshold) = parts[1].trim().parse::<f64>() {
234 return risk_score >= threshold;
235 }
236 }
237 } else if condition.contains("<=") {
238 let parts: Vec<&str> = condition.split("<=").collect();
239 if parts.len() == 2 {
240 if let Ok(threshold) = parts[1].trim().parse::<f64>() {
241 return risk_score <= threshold;
242 }
243 }
244 } else if condition.contains("==") {
245 let parts: Vec<&str> = condition.split("==").collect();
246 if parts.len() == 2 {
247 if let Ok(threshold) = parts[1].trim().parse::<f64>() {
248 return (risk_score - threshold).abs() < 0.001;
249 }
250 }
251 }
252
253 false
254 }
255
256 pub async fn set_simulated_risk(&self, user_id: String, risk_score: Option<f64>) {
258 let mut risks = self.simulated_risks.write().await;
259 if let Some(score) = risk_score {
260 risks.insert(user_id, Some(score));
261 } else {
262 risks.remove(&user_id);
263 }
264 }
265
266 pub async fn set_simulated_factors(&self, user_id: String, factors: HashMap<String, f64>) {
268 let mut simulated = self.simulated_factors.write().await;
269 simulated.insert(user_id, factors);
270 }
271
272 pub async fn clear_simulated_risk(&self, user_id: &str) {
274 let mut risks = self.simulated_risks.write().await;
275 risks.remove(user_id);
276 let mut factors = self.simulated_factors.write().await;
277 factors.remove(user_id);
278 }
279}
280
281impl Default for RiskEngine {
282 fn default() -> Self {
283 Self::new(RiskEngineConfig::default())
284 }
285}