use chrono::NaiveDate;
use datasynth_core::utils::seeded_rng;
use rand::RngExt;
use rand_chacha::ChaCha8Rng;
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use datasynth_core::models::{SchemeDetectionStatus, SchemeType};
use super::schemes::{
FraudScheme, GradualEmbezzlementScheme, RevenueManipulationScheme, SchemeAction, SchemeContext,
SchemeStatus, VendorKickbackScheme,
};
/// Tunable parameters for [`SchemeAdvancer`].
///
/// The three probabilities are summed by `SchemeAdvancer::maybe_start_scheme`
/// to decide whether a scheme starts at all, and their relative sizes decide
/// which scheme type is picked.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemeAdvancerConfig {
    /// Contribution of gradual embezzlement schemes to the start probability.
    pub embezzlement_probability: f64,
    /// Contribution of revenue manipulation schemes to the start probability.
    pub revenue_manipulation_probability: f64,
    /// Contribution of vendor kickback schemes to the start probability.
    pub kickback_probability: f64,
    /// No new scheme is started once this many schemes are active.
    pub max_concurrent_schemes: usize,
    /// When `false`, users who already perpetrate an active scheme are not
    /// picked as perpetrator of a new one.
    pub allow_repeat_perpetrators: bool,
    /// Seed from which the advancer's random number generators are derived.
    pub seed: u64,
}
impl Default for SchemeAdvancerConfig {
    /// Returns the baseline configuration: a combined start probability of
    /// 0.04 split 0.02 / 0.01 / 0.01 across embezzlement, revenue manipulation
    /// and kickbacks, at most five concurrent schemes, no repeat perpetrators,
    /// and seed 42.
    fn default() -> Self {
        // Start-probability contributions, one per scheme type.
        let embezzlement_probability = 0.02;
        let revenue_manipulation_probability = 0.01;
        let kickback_probability = 0.01;

        Self {
            embezzlement_probability,
            revenue_manipulation_probability,
            kickback_probability,
            max_concurrent_schemes: 5,
            allow_repeat_perpetrators: false,
            seed: 42,
        }
    }
}
/// Snapshot of a scheme taken when `SchemeAdvancer::advance_all` removes it
/// from the active set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletedScheme {
    /// Identifier of the archived scheme.
    pub scheme_id: Uuid,
    /// Kind of scheme that was run.
    pub scheme_type: SchemeType,
    /// Perpetrator reported by the scheme.
    pub perpetrator_id: String,
    /// Start date reported by the scheme, if it had one.
    pub start_date: Option<NaiveDate>,
    /// Context date of the `advance_all` call that archived the scheme.
    pub end_date: NaiveDate,
    /// Scheme status at the moment it was archived.
    pub final_status: SchemeStatus,
    /// Detection status at the moment it was archived.
    pub detection_status: SchemeDetectionStatus,
    /// Total impact reported by the scheme when archived.
    pub total_impact: Decimal,
    /// The scheme's current stage number when archived.
    pub stages_completed: u32,
    /// Number of transaction references the scheme held when archived.
    pub transaction_count: usize,
}
/// Label linking one anomaly to the multi-stage scheme that produced it, as
/// built by `SchemeAdvancer::record_label`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiStageAnomalyLabel {
    /// Anomaly identifier supplied by the caller.
    pub anomaly_id: String,
    /// Scheme the anomaly belongs to.
    pub scheme_id: Uuid,
    /// Kind of scheme the anomaly belongs to.
    pub scheme_type: SchemeType,
    /// Stage number carried by the labelled action.
    pub stage_number: u32,
    /// Name of the scheme's current stage at the time the label was recorded.
    pub stage_name: String,
    /// Number of stages the scheme defines.
    pub total_stages: u32,
    /// Perpetrator reported by the scheme.
    pub perpetrator_id: String,
    /// `true` when the scheme's detection status was anything other than
    /// `Undetected` at the time the label was recorded.
    pub scheme_detected: bool,
}
/// Manages the lifecycle of multi-stage fraud schemes: randomly starting new
/// schemes, advancing the active ones, archiving those that reach a terminal
/// status, and collecting anomaly labels.
pub struct SchemeAdvancer {
    /// Probabilities, limits and seed supplied at construction.
    config: SchemeAdvancerConfig,
    /// RNG used for start rolls and for picking perpetrators and vendors.
    rng: ChaCha8Rng,
    /// Schemes that have been started and not yet archived.
    active_schemes: Vec<Box<dyn FraudScheme>>,
    /// Snapshots of schemes that were archived by `advance_all`.
    completed_schemes: Vec<CompletedScheme>,
    /// Perpetrators of the schemes currently in `active_schemes`.
    active_perpetrators: Vec<String>,
    /// Vendors reserved by active kickback schemes, paired with the id of the
    /// owning scheme so the vendor can be released when that scheme is archived.
    active_vendors: Vec<(Uuid, String)>,
    /// Labels recorded through `record_label`.
    labels: Vec<MultiStageAnomalyLabel>,
}

impl SchemeAdvancer {
    /// Creates an advancer with no schemes and an RNG derived from `config.seed`.
    pub fn new(config: SchemeAdvancerConfig) -> Self {
        let rng = seeded_rng(config.seed, 0);
        Self {
            config,
            rng,
            active_schemes: Vec::new(),
            completed_schemes: Vec::new(),
            active_perpetrators: Vec::new(),
            active_vendors: Vec::new(),
            labels: Vec::new(),
        }
    }

    /// Rolls the dice and possibly starts one new scheme.
    ///
    /// Returns the id of the started scheme, or `None` when
    /// - the concurrent-scheme limit is reached,
    /// - no eligible perpetrator is available,
    /// - the roll does not fall below the combined start probability, or
    /// - a kickback scheme was selected but no unreserved vendor is available.
    ///
    /// The scheme type is chosen in proportion to the three configured
    /// probabilities.
    pub fn maybe_start_scheme(&mut self, context: &SchemeContext) -> Option<Uuid> {
        if self.active_schemes.len() >= self.config.max_concurrent_schemes {
            return None;
        }

        let available_users: Vec<_> = if self.config.allow_repeat_perpetrators {
            context.available_users.clone()
        } else {
            context
                .available_users
                .iter()
                .filter(|u| !self.active_perpetrators.contains(u))
                .cloned()
                .collect()
        };
        if available_users.is_empty() {
            return None;
        }

        let r = self.rng.random::<f64>();
        let total_prob = self.config.embezzlement_probability
            + self.config.revenue_manipulation_probability
            + self.config.kickback_probability;
        // `r` is drawn from [0, 1), so a start with probability `total_prob`
        // means `r < total_prob`. Rejecting `r == total_prob` also keeps
        // `normalized_r` strictly below 1.0, so the final (kickback) branch
        // cannot be reached through the boundary value.
        if total_prob <= 0.0 || r >= total_prob {
            return None;
        }

        let normalized_r = r / total_prob;
        let embezzlement_threshold = self.config.embezzlement_probability / total_prob;
        let revenue_threshold =
            embezzlement_threshold + self.config.revenue_manipulation_probability / total_prob;

        let user_idx = self.rng.random_range(0..available_users.len());
        let perpetrator = available_users[user_idx].clone();

        let scheme: Box<dyn FraudScheme> = if normalized_r < embezzlement_threshold {
            let scheme = GradualEmbezzlementScheme::new(&perpetrator)
                .with_accounts(context.available_accounts.clone());
            Box::new(scheme)
        } else if normalized_r < revenue_threshold {
            let scheme = RevenueManipulationScheme::new(&perpetrator);
            Box::new(scheme)
        } else {
            if context.available_counterparties.is_empty() {
                return None;
            }
            // Only vendors not reserved by another active kickback scheme.
            let available_vendors: Vec<_> = context
                .available_counterparties
                .iter()
                .filter(|v| {
                    !self
                        .active_vendors
                        .iter()
                        .any(|(_, reserved)| reserved == *v)
                })
                .cloned()
                .collect();
            if available_vendors.is_empty() {
                return None;
            }
            let vendor_idx = self.rng.random_range(0..available_vendors.len());
            let vendor = available_vendors[vendor_idx].clone();
            // Inflation drawn uniformly from [0.10, 0.25).
            let inflation = 0.10 + self.rng.random::<f64>() * 0.15;
            let scheme =
                VendorKickbackScheme::new(&perpetrator, &vendor).with_inflation_percent(inflation);
            // Remember which scheme owns the vendor so `advance_all` can
            // release it when the scheme is archived.
            self.active_vendors.push((scheme.scheme_id(), vendor));
            Box::new(scheme)
        };

        let scheme_id = scheme.scheme_id();
        self.active_perpetrators.push(perpetrator);
        self.active_schemes.push(scheme);
        Some(scheme_id)
    }

    /// Advances every active scheme once and returns all actions produced.
    ///
    /// Schemes whose status is `Completed`, `Terminated` or `Detected` after
    /// advancing are removed from the active set and archived as
    /// [`CompletedScheme`] with `context.current_date` as end date. Their
    /// perpetrator and any vendor they reserved become available again.
    pub fn advance_all(&mut self, context: &SchemeContext) -> Vec<SchemeAction> {
        let mut all_actions = Vec::new();
        let mut schemes_to_complete = Vec::new();

        for (idx, scheme) in self.active_schemes.iter_mut().enumerate() {
            // NOTE(review): the RNG is re-derived from the same (seed, scheme id)
            // pair on every call, so a given scheme receives an identical random
            // stream each time it is advanced. Confirm this is intended.
            let mut scheme_rng = seeded_rng(self.config.seed, scheme.scheme_id().as_u128() as u64);
            let actions = scheme.advance(context, &mut scheme_rng);
            all_actions.extend(actions);
            if matches!(
                scheme.status(),
                SchemeStatus::Completed | SchemeStatus::Terminated | SchemeStatus::Detected
            ) {
                schemes_to_complete.push(idx);
            }
        }

        // Remove from the back so earlier indices stay valid.
        for idx in schemes_to_complete.into_iter().rev() {
            let scheme = self.active_schemes.remove(idx);
            let scheme_id = scheme.scheme_id();
            let completed = CompletedScheme {
                scheme_id,
                scheme_type: scheme.scheme_type(),
                perpetrator_id: scheme.perpetrator_id().to_string(),
                start_date: scheme.start_date(),
                end_date: context.current_date,
                final_status: scheme.status(),
                detection_status: scheme.detection_status(),
                total_impact: scheme.total_impact(),
                stages_completed: scheme.current_stage_number(),
                transaction_count: scheme.transaction_refs().len(),
            };
            self.active_perpetrators
                .retain(|p| p != scheme.perpetrator_id());
            // Release the vendor reserved by this scheme (no-op for schemes
            // that never reserved one).
            self.active_vendors.retain(|(owner, _)| *owner != scheme_id);
            self.completed_schemes.push(completed);
        }

        all_actions
    }

    /// Records a label tying `anomaly_id` to the scheme that emitted `action`.
    ///
    /// Only active schemes are searched; if no active scheme has
    /// `action.scheme_id` (for example because it was already archived), the
    /// call does nothing. The stage name is taken from the scheme's current
    /// stage at the time of the call.
    pub fn record_label(&mut self, anomaly_id: impl Into<String>, action: &SchemeAction) {
        if let Some(scheme) = self
            .active_schemes
            .iter()
            .find(|s| s.scheme_id() == action.scheme_id)
        {
            let label = MultiStageAnomalyLabel {
                anomaly_id: anomaly_id.into(),
                scheme_id: scheme.scheme_id(),
                scheme_type: scheme.scheme_type(),
                stage_number: action.stage,
                stage_name: scheme.current_stage().name.clone(),
                // Saturate instead of silently wrapping on an oversized count.
                total_stages: u32::try_from(scheme.stages().len()).unwrap_or(u32::MAX),
                perpetrator_id: scheme.perpetrator_id().to_string(),
                scheme_detected: scheme.detection_status() != SchemeDetectionStatus::Undetected,
            };
            self.labels.push(label);
        }
    }

    /// Returns all labels recorded so far, in recording order.
    pub fn get_labels(&self) -> &[MultiStageAnomalyLabel] {
        &self.labels
    }

    /// Returns the snapshots of all archived schemes.
    pub fn get_completed_schemes(&self) -> &[CompletedScheme] {
        &self.completed_schemes
    }

    /// Number of schemes currently active.
    pub fn active_scheme_count(&self) -> usize {
        self.active_schemes.len()
    }

    /// Number of schemes archived so far.
    pub fn completed_scheme_count(&self) -> usize {
        self.completed_schemes.len()
    }

    /// Returns `(id, type, status)` for every active scheme.
    pub fn active_schemes_summary(&self) -> Vec<(Uuid, SchemeType, SchemeStatus)> {
        self.active_schemes
            .iter()
            .map(|s| (s.scheme_id(), s.scheme_type(), s.status()))
            .collect()
    }

    /// Looks up an active scheme by id; archived schemes are not searched.
    pub fn get_scheme(&self, scheme_id: Uuid) -> Option<&dyn FraudScheme> {
        self.active_schemes
            .iter()
            .find(|s| s.scheme_id() == scheme_id)
            .map(std::convert::AsRef::as_ref)
    }

    /// Drops all schemes, reservations and labels and re-seeds the RNG from
    /// the configured seed, returning the advancer to its freshly built state.
    pub fn reset(&mut self) {
        self.active_schemes.clear();
        self.completed_schemes.clear();
        self.active_perpetrators.clear();
        self.active_vendors.clear();
        self.labels.clear();
        self.rng = seeded_rng(self.config.seed, 0);
    }

    /// Aggregates counts and total impact over active and archived schemes.
    ///
    /// `detected_schemes` considers archived schemes only; all other totals
    /// include both active and archived schemes.
    pub fn get_statistics(&self) -> SchemeStatistics {
        let total_impact: Decimal = self
            .completed_schemes
            .iter()
            .map(|s| s.total_impact)
            .sum::<Decimal>()
            + self
                .active_schemes
                .iter()
                .map(|s| s.total_impact())
                .sum::<Decimal>();

        let detected_count = self
            .completed_schemes
            .iter()
            .filter(|s| s.detection_status != SchemeDetectionStatus::Undetected)
            .count();

        // Counts schemes of type `t` across the archived and active sets.
        let by_type = |t: SchemeType| {
            self.completed_schemes
                .iter()
                .filter(|s| s.scheme_type == t)
                .count()
                + self
                    .active_schemes
                    .iter()
                    .filter(|s| s.scheme_type() == t)
                    .count()
        };

        SchemeStatistics {
            total_schemes: self.active_schemes.len() + self.completed_schemes.len(),
            active_schemes: self.active_schemes.len(),
            completed_schemes: self.completed_schemes.len(),
            detected_schemes: detected_count,
            total_impact,
            embezzlement_count: by_type(SchemeType::GradualEmbezzlement),
            revenue_manipulation_count: by_type(SchemeType::RevenueManipulation),
            kickback_count: by_type(SchemeType::VendorKickback),
        }
    }
}
/// Aggregate figures produced by `SchemeAdvancer::get_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemeStatistics {
    /// Active plus archived schemes.
    pub total_schemes: usize,
    /// Schemes currently active.
    pub active_schemes: usize,
    /// Schemes that have been archived.
    pub completed_schemes: usize,
    /// Archived schemes whose detection status is not `Undetected`.
    pub detected_schemes: usize,
    /// Sum of the total impact of all active and archived schemes.
    pub total_impact: Decimal,
    /// Active plus archived gradual embezzlement schemes.
    pub embezzlement_count: usize,
    /// Active plus archived revenue manipulation schemes.
    pub revenue_manipulation_count: usize,
    /// Active plus archived vendor kickback schemes.
    pub kickback_count: usize,
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Date used as the starting point of every test context.
    fn start_date() -> NaiveDate {
        NaiveDate::from_ymd_opt(2024, 1, 15).unwrap()
    }

    /// Configuration under which every start roll succeeds and always selects
    /// gradual embezzlement. The other scheme types are zeroed so the tests do
    /// not depend on what the seeded RNG happens to draw first.
    fn embezzlement_only_config() -> SchemeAdvancerConfig {
        SchemeAdvancerConfig {
            embezzlement_probability: 1.0,
            revenue_manipulation_probability: 0.0,
            kickback_probability: 0.0,
            ..Default::default()
        }
    }

    /// Builds a context on `start_date()` with the given users and a single
    /// account.
    fn context_with_users(users: &[&str]) -> SchemeContext {
        SchemeContext::new(start_date(), "1000")
            .with_users(users.iter().map(|u| (*u).to_string()).collect::<Vec<String>>())
            .with_accounts(vec!["5000".to_string()])
    }

    #[test]
    fn test_scheme_advancer_creation() {
        let advancer = SchemeAdvancer::new(SchemeAdvancerConfig::default());
        assert_eq!(advancer.active_scheme_count(), 0);
        assert_eq!(advancer.completed_scheme_count(), 0);
    }

    #[test]
    fn test_scheme_advancer_start_scheme() {
        let mut advancer = SchemeAdvancer::new(embezzlement_only_config());
        let context = context_with_users(&["USER001", "USER002"]);

        let scheme_id = advancer.maybe_start_scheme(&context);

        assert!(scheme_id.is_some());
        assert_eq!(advancer.active_scheme_count(), 1);
    }

    #[test]
    fn test_scheme_advancer_max_concurrent() {
        let mut advancer = SchemeAdvancer::new(SchemeAdvancerConfig {
            max_concurrent_schemes: 2,
            ..embezzlement_only_config()
        });
        let context = context_with_users(&["USER001", "USER002", "USER003"]);

        advancer.maybe_start_scheme(&context);
        advancer.maybe_start_scheme(&context);
        let third = advancer.maybe_start_scheme(&context);

        assert_eq!(advancer.active_scheme_count(), 2);
        // The limit of two concurrent schemes blocks the third start.
        assert!(third.is_none());
    }

    #[test]
    fn test_scheme_advancer_advance_all() {
        let mut advancer = SchemeAdvancer::new(embezzlement_only_config());
        let context = context_with_users(&["USER001"]);
        advancer.maybe_start_scheme(&context);

        for day in 0..30 {
            let mut ctx = context.clone();
            ctx.current_date = start_date() + chrono::Duration::days(day);
            let _actions = advancer.advance_all(&ctx);
        }

        // The scheme is expected to still be running after 30 days.
        assert_eq!(advancer.active_scheme_count(), 1);
    }

    #[test]
    fn test_scheme_advancer_statistics() {
        let mut advancer = SchemeAdvancer::new(embezzlement_only_config());
        let context = context_with_users(&["USER001"]);
        advancer.maybe_start_scheme(&context);

        let stats = advancer.get_statistics();

        assert_eq!(stats.total_schemes, 1);
        assert_eq!(stats.active_schemes, 1);
        assert_eq!(stats.embezzlement_count, 1);
    }
}