use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Result of [`RiskEngine::assess_risk`] for one user.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RiskAssessment {
    /// Aggregate score. Scores computed from factors are clamped into `[0.0, 1.0]`;
    /// a simulated score is reported exactly as it was registered.
    pub risk_score: f64,
    /// Per-factor breakdown; empty when the score came from a simulated override.
    pub risk_factors: Vec<RiskFactor>,
    /// Action selected for `risk_score` by the engine's rules and thresholds.
    pub recommended_action: RiskAction,
}
/// One weighted input that fed into a computed risk score.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RiskFactor {
    /// Factor name, taken from the caller-supplied or simulated factor map.
    pub name: String,
    /// Weight looked up in `RiskEngineConfig::risk_factors` (0.0 when the name is not configured).
    pub weight: f64,
    /// Raw value supplied for this factor.
    pub value: f64,
    /// This factor's share of the total score (`weight * value`).
    pub contribution: f64,
}
/// Action the engine recommends for an assessed score.
///
/// This is a fieldless enum, so it derives `Copy`, `Eq` and `Hash` alongside `PartialEq`.
/// Values can therefore be passed by value and used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum RiskAction {
    /// Fallback choice when the score is below every configured threshold.
    Allow,
    /// Fallback choice for scores `>= device_challenge_threshold` that reached no more
    /// severe threshold.
    DeviceChallenge,
    /// Fallback choice for scores `>= mfa_threshold` below `blocked_login_threshold`.
    RequireMfa,
    /// Fallback choice for scores `>= blocked_login_threshold`.
    Block,
}
/// Tunable inputs for [`RiskEngine`].
///
/// `risk_rules` are tried first, in order, and the first matching rule decides the action.
/// The three thresholds are consulted only when no rule matches. Each acts as an inclusive
/// (`>=`) lower bound and they are checked from most to least severe.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RiskEngineConfig {
    /// Fallback lower bound for `RiskAction::RequireMfa`.
    pub mfa_threshold: f64,
    /// Fallback lower bound for `RiskAction::DeviceChallenge`.
    pub device_challenge_threshold: f64,
    /// Fallback lower bound for `RiskAction::Block`.
    pub blocked_login_threshold: f64,
    /// Weight per factor name; names not listed here weigh 0.0.
    pub risk_factors: HashMap<String, f64>,
    /// Ordered rules evaluated before the thresholds.
    pub risk_rules: Vec<RiskRule>,
}
/// A comparison rule checked before the engine's threshold fallback.
///
/// `condition` has the form `<lhs> <op> <number>`, where `op` is one of `>=`, `<=`, `==`,
/// `>` or `<`. Only the number is read. The left-hand side is ignored, and the comparison
/// is always made against the risk score. `==` matches within 0.001. A condition whose
/// operator appears more than once, or whose right-hand side is not a number, never matches.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RiskRule {
    /// Textual comparison, e.g. `"risk_score > 0.5"`.
    pub condition: String,
    /// Action returned when this is the first rule whose condition matches.
    pub action: RiskAction,
}
impl Default for RiskEngineConfig {
    /// Default weights, thresholds and rules.
    ///
    /// The default `risk_rules` are generated from the threshold values using `>=`. This is
    /// the same inclusive comparison the engine's threshold fallback uses, so a rule and its
    /// threshold always agree on boundary scores. Rules are evaluated before the thresholds.
    /// With strict `>` rules, the 0.7 and 0.9 boundaries would become exclusive (e.g. exactly
    /// 0.7 would get `DeviceChallenge`, not `RequireMfa`), while 0.5 would stay inclusive via
    /// the fallback.
    fn default() -> Self {
        let blocked_login_threshold: f64 = 0.9;
        let mfa_threshold: f64 = 0.7;
        let device_challenge_threshold: f64 = 0.5;

        let mut risk_factors = HashMap::new();
        risk_factors.insert("new_device".to_string(), 0.3);
        risk_factors.insert("unusual_location".to_string(), 0.4);
        risk_factors.insert("suspicious_activity".to_string(), 0.5);

        // Most severe first: the first matching rule decides the action.
        let risk_rules = vec![
            RiskRule {
                condition: format!("risk_score >= {}", blocked_login_threshold),
                action: RiskAction::Block,
            },
            RiskRule {
                condition: format!("risk_score >= {}", mfa_threshold),
                action: RiskAction::RequireMfa,
            },
            RiskRule {
                condition: format!("risk_score >= {}", device_challenge_threshold),
                action: RiskAction::DeviceChallenge,
            },
        ];

        Self {
            mfa_threshold,
            device_challenge_threshold,
            blocked_login_threshold,
            risk_factors,
            risk_rules,
        }
    }
}
/// Scores users against a [`RiskEngineConfig`], with optional per-user simulated overrides.
///
/// `Clone` copies `config` but only bumps the reference counts of the two `Arc`-wrapped
/// simulation maps. Every clone therefore reads and writes the same overrides.
#[derive(Clone, Debug)]
pub struct RiskEngine {
    /// Weights, thresholds and rules used for scoring.
    pub config: RiskEngineConfig,
    /// Per-user score overrides keyed by user id. A `Some` entry short-circuits factor
    /// scoring in `assess_risk`; a `None` entry behaves the same as a missing entry.
    pub simulated_risks: Arc<RwLock<HashMap<String, Option<f64>>>>,
    /// Per-user factor maps that replace the caller-supplied factors in `assess_risk`.
    pub simulated_factors: Arc<RwLock<HashMap<String, HashMap<String, f64>>>>,
}
impl RiskEngine {
pub fn new(config: RiskEngineConfig) -> Self {
Self {
config,
simulated_risks: Arc::new(RwLock::new(HashMap::new())),
simulated_factors: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn assess_risk(
&self,
user_id: &str,
risk_factors: &HashMap<String, f64>,
) -> RiskAssessment {
let simulated_risk = {
let risks = self.simulated_risks.read().await;
risks.get(user_id).copied().flatten()
};
if let Some(risk_score) = simulated_risk {
return self.create_assessment_from_score(risk_score);
}
let factors_to_use = {
let simulated = self.simulated_factors.read().await;
if let Some(simulated_factors) = simulated.get(user_id) {
simulated_factors.clone()
} else {
risk_factors.clone()
}
};
let mut risk_factors_vec = Vec::new();
let mut total_score = 0.0;
for (name, value) in factors_to_use {
let weight = self.config.risk_factors.get(&name).copied().unwrap_or(0.0);
let contribution = weight * value;
total_score += contribution;
risk_factors_vec.push(RiskFactor {
name: name.clone(),
weight,
value,
contribution,
});
}
let risk_score = total_score.clamp(0.0, 1.0);
let recommended_action = self.determine_action(risk_score);
RiskAssessment {
risk_score,
risk_factors: risk_factors_vec,
recommended_action,
}
}
fn create_assessment_from_score(&self, risk_score: f64) -> RiskAssessment {
let recommended_action = self.determine_action(risk_score);
RiskAssessment {
risk_score,
risk_factors: vec![],
recommended_action,
}
}
fn determine_action(&self, risk_score: f64) -> RiskAction {
for rule in &self.config.risk_rules {
if self.evaluate_condition(&rule.condition, risk_score) {
return rule.action.clone();
}
}
if risk_score >= self.config.blocked_login_threshold {
RiskAction::Block
} else if risk_score >= self.config.mfa_threshold {
RiskAction::RequireMfa
} else if risk_score >= self.config.device_challenge_threshold {
RiskAction::DeviceChallenge
} else {
RiskAction::Allow
}
}
fn evaluate_condition(&self, condition: &str, risk_score: f64) -> bool {
if condition.contains(">=") {
let parts: Vec<&str> = condition.split(">=").collect();
if parts.len() == 2 {
if let Ok(threshold) = parts[1].trim().parse::<f64>() {
return risk_score >= threshold;
}
}
} else if condition.contains("<=") {
let parts: Vec<&str> = condition.split("<=").collect();
if parts.len() == 2 {
if let Ok(threshold) = parts[1].trim().parse::<f64>() {
return risk_score <= threshold;
}
}
} else if condition.contains("==") {
let parts: Vec<&str> = condition.split("==").collect();
if parts.len() == 2 {
if let Ok(threshold) = parts[1].trim().parse::<f64>() {
return (risk_score - threshold).abs() < 0.001;
}
}
} else if condition.contains('>') {
let parts: Vec<&str> = condition.split('>').collect();
if parts.len() == 2 {
if let Ok(threshold) = parts[1].trim().parse::<f64>() {
return risk_score > threshold;
}
}
} else if condition.contains('<') {
let parts: Vec<&str> = condition.split('<').collect();
if parts.len() == 2 {
if let Ok(threshold) = parts[1].trim().parse::<f64>() {
return risk_score < threshold;
}
}
}
false
}
pub async fn set_simulated_risk(&self, user_id: String, risk_score: Option<f64>) {
let mut risks = self.simulated_risks.write().await;
if let Some(score) = risk_score {
risks.insert(user_id, Some(score));
} else {
risks.remove(&user_id);
}
}
pub async fn set_simulated_factors(&self, user_id: String, factors: HashMap<String, f64>) {
let mut simulated = self.simulated_factors.write().await;
simulated.insert(user_id, factors);
}
pub async fn clear_simulated_risk(&self, user_id: &str) {
let mut risks = self.simulated_risks.write().await;
risks.remove(user_id);
let mut factors = self.simulated_factors.write().await;
factors.remove(user_id);
}
}
impl Default for RiskEngine {
fn default() -> Self {
Self::new(RiskEngineConfig::default())
}
}