1use crate::error::Result;
4use crate::models::{RiskAction, RiskAssessment, RiskFactor, RiskLevel, User};
5use std::collections::HashMap;
6use std::net::IpAddr;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10pub struct RiskEngine {
11 config: RiskConfig,
12 user_behavior: Arc<RwLock<HashMap<String, UserBehaviorProfile>>>,
13}
14
/// Tunable thresholds and switches for the risk engine.
///
/// Deriving `Debug` and `PartialEq` lets callers log the active configuration and compare
/// configurations (e.g. in tests) without hand-written impls.
#[derive(Debug, Clone, PartialEq)]
pub struct RiskConfig {
    /// Total score at or above which multi-factor authentication is required.
    pub mfa_threshold: u8,
    /// Total score at or above which the login is denied outright.
    pub block_threshold: u8,
    /// Whether the "impossible travel" (geo-velocity) check runs.
    pub geo_velocity_enabled: bool,
    /// Speed in km/h above which travel between consecutive logins is deemed impossible.
    pub max_travel_speed_kmh: f64,
}
22
23#[derive(Clone)]
24struct UserBehaviorProfile {
25 usual_locations: Vec<Location>,
26 usual_devices: Vec<String>,
27 usual_login_times: Vec<chrono::NaiveTime>,
28 last_location: Option<Location>,
29 last_login: Option<chrono::DateTime<chrono::Utc>>,
30 successful_logins: u64,
31 failed_logins: u64,
32}
33
/// A geographic point resolved from an IP address, with optional human-readable names.
#[derive(Clone, Debug)]
struct Location {
    /// Latitude in decimal degrees (converted with `to_radians` for distance maths).
    latitude: f64,
    /// Longitude in decimal degrees.
    longitude: f64,
    /// City name, when the lookup provides one.
    city: Option<String>,
    /// Country name, when the lookup provides one.
    country: Option<String>,
}
41
42impl RiskEngine {
43 pub fn new(config: RiskConfig) -> Self {
44 Self {
45 config,
46 user_behavior: Arc::new(RwLock::new(HashMap::new())),
47 }
48 }
49
50 pub async fn assess_risk(
51 &self,
52 user: &User,
53 ip_address: Option<IpAddr>,
54 device_id: Option<&str>,
55 _user_agent: Option<&str>,
56 ) -> Result<RiskAssessment> {
57 let mut factors = Vec::new();
58 let mut total_score = 0u8;
59
60 if user.status != crate::models::UserStatus::Active {
62 factors.push(RiskFactor {
63 name: "Account Status".to_string(),
64 score: 100,
65 reason: "Account is not active".to_string(),
66 });
67 total_score = 100;
68 }
69
70 if user.failed_login_attempts > 0 {
72 let score = (user.failed_login_attempts * 10).min(50) as u8;
73 factors.push(RiskFactor {
74 name: "Failed Attempts".to_string(),
75 score,
76 reason: format!("{} recent failed login attempts", user.failed_login_attempts),
77 });
78 total_score = total_score.saturating_add(score);
79 }
80
81 let account_age = chrono::Utc::now() - user.created_at;
83 if account_age < chrono::Duration::days(1) {
84 factors.push(RiskFactor {
85 name: "New Account".to_string(),
86 score: 30,
87 reason: "Account created less than 24 hours ago".to_string(),
88 });
89 total_score = total_score.saturating_add(30);
90 }
91
92 if let Some(ip) = ip_address {
94 if let Some(location_score) = self.check_location_risk(user, &ip).await {
95 factors.push(location_score.clone());
96 total_score = total_score.saturating_add(location_score.score);
97 }
98 }
99
100 if let Some(device) = device_id {
102 if let Some(device_score) = self.check_device_risk(user, device).await {
103 factors.push(device_score.clone());
104 total_score = total_score.saturating_add(device_score.score);
105 }
106 }
107
108 if let Some(time_score) = self.check_time_risk(user).await {
110 factors.push(time_score.clone());
111 total_score = total_score.saturating_add(time_score.score);
112 }
113
114 if self.config.geo_velocity_enabled {
116 if let (Some(_ip), Some(velocity_score)) = (ip_address, self.check_geo_velocity(user, &ip_address.unwrap()).await) {
117 factors.push(velocity_score.clone());
118 total_score = total_score.saturating_add(velocity_score.score);
119 }
120 }
121
122 let level = match total_score {
124 0..=30 => RiskLevel::Low,
125 31..=60 => RiskLevel::Medium,
126 61..=85 => RiskLevel::High,
127 _ => RiskLevel::Critical,
128 };
129
130 let recommended_action = if total_score >= self.config.block_threshold {
131 RiskAction::Deny
132 } else if total_score >= self.config.mfa_threshold {
133 RiskAction::RequireMfa
134 } else if total_score >= 40 {
135 RiskAction::Challenge
136 } else {
137 RiskAction::Allow
138 };
139
140 Ok(RiskAssessment {
141 score: total_score,
142 level,
143 factors,
144 recommended_action,
145 })
146 }
147
148 async fn check_location_risk(&self, user: &User, ip: &IpAddr) -> Option<RiskFactor> {
149 let location = self.get_location_from_ip(ip)?;
151
152 let behavior = self.user_behavior.read().await;
153 let profile = behavior.get(&user.id.to_string())?;
154
155 let is_usual = profile.usual_locations.iter().any(|loc| {
156 self.distance_km(loc, &location) < 100.0
157 });
158
159 if !is_usual {
160 Some(RiskFactor {
161 name: "Unusual Location".to_string(),
162 score: 25,
163 reason: format!("Login from unfamiliar location: {:?}", location.city),
164 })
165 } else {
166 None
167 }
168 }
169
170 async fn check_device_risk(&self, user: &User, device_id: &str) -> Option<RiskFactor> {
171 let behavior = self.user_behavior.read().await;
172 let profile = behavior.get(&user.id.to_string())?;
173
174 if !profile.usual_devices.contains(&device_id.to_string()) {
175 Some(RiskFactor {
176 name: "Unknown Device".to_string(),
177 score: 20,
178 reason: "Login from unrecognized device".to_string(),
179 })
180 } else {
181 None
182 }
183 }
184
185 async fn check_time_risk(&self, user: &User) -> Option<RiskFactor> {
186 let current_time = chrono::Utc::now().time();
187
188 let behavior = self.user_behavior.read().await;
189 let profile = behavior.get(&user.id.to_string())?;
190
191 let is_usual_time = profile.usual_login_times.iter().any(|usual| {
193 let diff = if current_time >= *usual {
194 (current_time - *usual).num_hours()
195 } else {
196 (*usual - current_time).num_hours()
197 };
198 diff <= 2
199 });
200
201 if !is_usual_time && !profile.usual_login_times.is_empty() {
202 Some(RiskFactor {
203 name: "Unusual Time".to_string(),
204 score: 15,
205 reason: "Login at unusual time of day".to_string(),
206 })
207 } else {
208 None
209 }
210 }
211
212 async fn check_geo_velocity(&self, user: &User, ip: &IpAddr) -> Option<RiskFactor> {
213 let current_location = self.get_location_from_ip(ip)?;
214
215 let behavior = self.user_behavior.read().await;
216 let profile = behavior.get(&user.id.to_string())?;
217
218 let last_location = profile.last_location.as_ref()?;
219 let last_login = profile.last_login?;
220
221 let distance = self.distance_km(last_location, ¤t_location);
222 let time_diff = (chrono::Utc::now() - last_login).num_hours() as f64;
223
224 if time_diff > 0.0 {
225 let velocity = distance / time_diff;
226
227 if velocity > self.config.max_travel_speed_kmh {
228 return Some(RiskFactor {
229 name: "Impossible Travel".to_string(),
230 score: 40,
231 reason: format!(
232 "Travel speed of {:.0} km/h exceeds maximum",
233 velocity
234 ),
235 });
236 }
237 }
238
239 None
240 }
241
242 fn get_location_from_ip(&self, _ip: &IpAddr) -> Option<Location> {
243 Some(Location {
245 latitude: -23.5505,
246 longitude: -46.6333,
247 city: Some("São Paulo".to_string()),
248 country: Some("Brazil".to_string()),
249 })
250 }
251
252 fn distance_km(&self, loc1: &Location, loc2: &Location) -> f64 {
253 let r = 6371.0; let lat1 = loc1.latitude.to_radians();
257 let lat2 = loc2.latitude.to_radians();
258 let delta_lat = (loc2.latitude - loc1.latitude).to_radians();
259 let delta_lon = (loc2.longitude - loc1.longitude).to_radians();
260
261 let a = (delta_lat / 2.0).sin().powi(2)
262 + lat1.cos() * lat2.cos() * (delta_lon / 2.0).sin().powi(2);
263 let c = 2.0 * a.sqrt().atan2((1.0 - a).sqrt());
264
265 r * c
266 }
267
268 pub async fn update_behavior_profile(
269 &self,
270 user_id: &uuid::Uuid,
271 ip_address: Option<IpAddr>,
272 device_id: Option<String>,
273 success: bool,
274 ) {
275 let mut behavior = self.user_behavior.write().await;
276 let profile = behavior.entry(user_id.to_string()).or_insert_with(|| UserBehaviorProfile {
277 usual_locations: Vec::new(),
278 usual_devices: Vec::new(),
279 usual_login_times: Vec::new(),
280 last_location: None,
281 last_login: None,
282 successful_logins: 0,
283 failed_logins: 0,
284 });
285
286 if success {
287 profile.successful_logins += 1;
288
289 if let Some(ip) = ip_address {
290 if let Some(location) = self.get_location_from_ip(&ip) {
291 profile.last_location = Some(location.clone());
292
293 if !profile.usual_locations.iter().any(|l| self.distance_km(l, &location) < 50.0) {
295 profile.usual_locations.push(location);
296 }
297 }
298 }
299
300 if let Some(device) = device_id {
301 if !profile.usual_devices.contains(&device) {
302 profile.usual_devices.push(device);
303 }
304 }
305
306 let current_time = chrono::Utc::now().time();
307 profile.usual_login_times.push(current_time);
308 profile.last_login = Some(chrono::Utc::now());
309 } else {
310 profile.failed_logins += 1;
311 }
312 }
313}