datasynth_generators/anomaly/
scheme_advancer.rs1use chrono::NaiveDate;
7use datasynth_core::utils::seeded_rng;
8use rand::Rng;
9use rand_chacha::ChaCha8Rng;
10use rust_decimal::Decimal;
11use serde::{Deserialize, Serialize};
12use uuid::Uuid;
13
14use datasynth_core::models::{SchemeDetectionStatus, SchemeType};
15
16use super::schemes::{
17 FraudScheme, GradualEmbezzlementScheme, RevenueManipulationScheme, SchemeAction, SchemeContext,
18 SchemeStatus, VendorKickbackScheme,
19};
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct SchemeAdvancerConfig {
24 pub embezzlement_probability: f64,
26 pub revenue_manipulation_probability: f64,
28 pub kickback_probability: f64,
30 pub max_concurrent_schemes: usize,
32 pub allow_repeat_perpetrators: bool,
34 pub seed: u64,
36}
37
38impl Default for SchemeAdvancerConfig {
39 fn default() -> Self {
40 Self {
41 embezzlement_probability: 0.02,
42 revenue_manipulation_probability: 0.01,
43 kickback_probability: 0.01,
44 max_concurrent_schemes: 5,
45 allow_repeat_perpetrators: false,
46 seed: 42,
47 }
48 }
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct CompletedScheme {
54 pub scheme_id: Uuid,
56 pub scheme_type: SchemeType,
58 pub perpetrator_id: String,
60 pub start_date: Option<NaiveDate>,
62 pub end_date: NaiveDate,
64 pub final_status: SchemeStatus,
66 pub detection_status: SchemeDetectionStatus,
68 pub total_impact: Decimal,
70 pub stages_completed: u32,
72 pub transaction_count: usize,
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct MultiStageAnomalyLabel {
79 pub anomaly_id: String,
81 pub scheme_id: Uuid,
83 pub scheme_type: SchemeType,
85 pub stage_number: u32,
87 pub stage_name: String,
89 pub total_stages: u32,
91 pub perpetrator_id: String,
93 pub scheme_detected: bool,
95}
96
97pub struct SchemeAdvancer {
99 config: SchemeAdvancerConfig,
100 rng: ChaCha8Rng,
101 active_schemes: Vec<Box<dyn FraudScheme>>,
103 completed_schemes: Vec<CompletedScheme>,
105 active_perpetrators: Vec<String>,
107 active_vendors: Vec<String>,
109 labels: Vec<MultiStageAnomalyLabel>,
111}
112
113impl SchemeAdvancer {
114 pub fn new(config: SchemeAdvancerConfig) -> Self {
116 let rng = seeded_rng(config.seed, 0);
117 Self {
118 config,
119 rng,
120 active_schemes: Vec::new(),
121 completed_schemes: Vec::new(),
122 active_perpetrators: Vec::new(),
123 active_vendors: Vec::new(),
124 labels: Vec::new(),
125 }
126 }
127
128 pub fn maybe_start_scheme(&mut self, context: &SchemeContext) -> Option<Uuid> {
130 if self.active_schemes.len() >= self.config.max_concurrent_schemes {
132 return None;
133 }
134
135 let available_users: Vec<_> = if self.config.allow_repeat_perpetrators {
137 context.available_users.clone()
138 } else {
139 context
140 .available_users
141 .iter()
142 .filter(|u| !self.active_perpetrators.contains(u))
143 .cloned()
144 .collect()
145 };
146
147 if available_users.is_empty() {
148 return None;
149 }
150
151 let r = self.rng.gen::<f64>();
153 let total_prob = self.config.embezzlement_probability
154 + self.config.revenue_manipulation_probability
155 + self.config.kickback_probability;
156
157 if r > total_prob {
158 return None;
159 }
160
161 let normalized_r = r / total_prob;
162 let embezzlement_threshold = self.config.embezzlement_probability / total_prob;
163 let revenue_threshold =
164 embezzlement_threshold + self.config.revenue_manipulation_probability / total_prob;
165
166 let user_idx = self.rng.gen_range(0..available_users.len());
167 let perpetrator = available_users[user_idx].clone();
168
169 let scheme: Box<dyn FraudScheme> = if normalized_r < embezzlement_threshold {
170 let scheme = GradualEmbezzlementScheme::new(&perpetrator)
172 .with_accounts(context.available_accounts.clone());
173 Box::new(scheme)
174 } else if normalized_r < revenue_threshold {
175 let scheme = RevenueManipulationScheme::new(&perpetrator);
177 Box::new(scheme)
178 } else {
179 if context.available_counterparties.is_empty() {
181 return None;
182 }
183
184 let available_vendors: Vec<_> = context
185 .available_counterparties
186 .iter()
187 .filter(|v| !self.active_vendors.contains(v))
188 .cloned()
189 .collect();
190
191 if available_vendors.is_empty() {
192 return None;
193 }
194
195 let vendor_idx = self.rng.gen_range(0..available_vendors.len());
196 let vendor = available_vendors[vendor_idx].clone();
197
198 let inflation = 0.10 + self.rng.gen::<f64>() * 0.15; let scheme =
200 VendorKickbackScheme::new(&perpetrator, &vendor).with_inflation_percent(inflation);
201
202 self.active_vendors.push(vendor);
203 Box::new(scheme)
204 };
205
206 let scheme_id = scheme.scheme_id();
207 self.active_perpetrators.push(perpetrator);
208 self.active_schemes.push(scheme);
209
210 Some(scheme_id)
211 }
212
213 pub fn advance_all(&mut self, context: &SchemeContext) -> Vec<SchemeAction> {
215 let mut all_actions = Vec::new();
216 let mut schemes_to_complete = Vec::new();
217
218 for (idx, scheme) in self.active_schemes.iter_mut().enumerate() {
219 let mut scheme_rng = seeded_rng(self.config.seed, scheme.scheme_id().as_u128() as u64);
221
222 let actions = scheme.advance(context, &mut scheme_rng);
223 all_actions.extend(actions);
224
225 if matches!(
227 scheme.status(),
228 SchemeStatus::Completed | SchemeStatus::Terminated | SchemeStatus::Detected
229 ) {
230 schemes_to_complete.push(idx);
231 }
232 }
233
234 for idx in schemes_to_complete.into_iter().rev() {
236 let scheme = self.active_schemes.remove(idx);
237 let completed = CompletedScheme {
238 scheme_id: scheme.scheme_id(),
239 scheme_type: scheme.scheme_type(),
240 perpetrator_id: scheme.perpetrator_id().to_string(),
241 start_date: scheme.start_date(),
242 end_date: context.current_date,
243 final_status: scheme.status(),
244 detection_status: scheme.detection_status(),
245 total_impact: scheme.total_impact(),
246 stages_completed: scheme.current_stage_number(),
247 transaction_count: scheme.transaction_refs().len(),
248 };
249
250 self.active_perpetrators
252 .retain(|p| p != scheme.perpetrator_id());
253
254 self.completed_schemes.push(completed);
255 }
256
257 all_actions
258 }
259
260 pub fn record_label(&mut self, anomaly_id: impl Into<String>, action: &SchemeAction) {
262 if let Some(scheme) = self
263 .active_schemes
264 .iter()
265 .find(|s| s.scheme_id() == action.scheme_id)
266 {
267 let label = MultiStageAnomalyLabel {
268 anomaly_id: anomaly_id.into(),
269 scheme_id: scheme.scheme_id(),
270 scheme_type: scheme.scheme_type(),
271 stage_number: action.stage,
272 stage_name: scheme.current_stage().name.clone(),
273 total_stages: scheme.stages().len() as u32,
274 perpetrator_id: scheme.perpetrator_id().to_string(),
275 scheme_detected: scheme.detection_status() != SchemeDetectionStatus::Undetected,
276 };
277 self.labels.push(label);
278 }
279 }
280
281 pub fn get_labels(&self) -> &[MultiStageAnomalyLabel] {
283 &self.labels
284 }
285
286 pub fn get_completed_schemes(&self) -> &[CompletedScheme] {
288 &self.completed_schemes
289 }
290
291 pub fn active_scheme_count(&self) -> usize {
293 self.active_schemes.len()
294 }
295
296 pub fn completed_scheme_count(&self) -> usize {
298 self.completed_schemes.len()
299 }
300
301 pub fn active_schemes_summary(&self) -> Vec<(Uuid, SchemeType, SchemeStatus)> {
303 self.active_schemes
304 .iter()
305 .map(|s| (s.scheme_id(), s.scheme_type(), s.status()))
306 .collect()
307 }
308
309 pub fn get_scheme(&self, scheme_id: Uuid) -> Option<&dyn FraudScheme> {
311 self.active_schemes
312 .iter()
313 .find(|s| s.scheme_id() == scheme_id)
314 .map(|s| s.as_ref())
315 }
316
317 pub fn reset(&mut self) {
319 self.active_schemes.clear();
320 self.completed_schemes.clear();
321 self.active_perpetrators.clear();
322 self.active_vendors.clear();
323 self.labels.clear();
324 self.rng = seeded_rng(self.config.seed, 0);
325 }
326
327 pub fn get_statistics(&self) -> SchemeStatistics {
329 let total_impact: Decimal = self
330 .completed_schemes
331 .iter()
332 .map(|s| s.total_impact)
333 .sum::<Decimal>()
334 + self
335 .active_schemes
336 .iter()
337 .map(|s| s.total_impact())
338 .sum::<Decimal>();
339
340 let detected_count = self
341 .completed_schemes
342 .iter()
343 .filter(|s| s.detection_status != SchemeDetectionStatus::Undetected)
344 .count();
345
346 let by_type = |t: SchemeType| {
347 self.completed_schemes
348 .iter()
349 .filter(|s| s.scheme_type == t)
350 .count()
351 + self
352 .active_schemes
353 .iter()
354 .filter(|s| s.scheme_type() == t)
355 .count()
356 };
357
358 SchemeStatistics {
359 total_schemes: self.active_schemes.len() + self.completed_schemes.len(),
360 active_schemes: self.active_schemes.len(),
361 completed_schemes: self.completed_schemes.len(),
362 detected_schemes: detected_count,
363 total_impact,
364 embezzlement_count: by_type(SchemeType::GradualEmbezzlement),
365 revenue_manipulation_count: by_type(SchemeType::RevenueManipulation),
366 kickback_count: by_type(SchemeType::VendorKickback),
367 }
368 }
369}
370
371#[derive(Debug, Clone, Serialize, Deserialize)]
373pub struct SchemeStatistics {
374 pub total_schemes: usize,
376 pub active_schemes: usize,
378 pub completed_schemes: usize,
380 pub detected_schemes: usize,
382 pub total_impact: Decimal,
384 pub embezzlement_count: usize,
386 pub revenue_manipulation_count: usize,
388 pub kickback_count: usize,
390}
391
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Simulation date every test starts from.
    fn start_date() -> NaiveDate {
        NaiveDate::from_ymd_opt(2024, 1, 15).unwrap()
    }

    /// Context on `start_date()` with the given users and the single account "5000".
    fn context_with_users(users: &[&str]) -> SchemeContext {
        let users = users.iter().map(|u| u.to_string()).collect::<Vec<String>>();
        SchemeContext::new(start_date(), "1000")
            .with_users(users)
            .with_accounts(vec!["5000".to_string()])
    }

    /// Config whose start draw always succeeds and always picks embezzlement.
    fn always_embezzle() -> SchemeAdvancerConfig {
        SchemeAdvancerConfig {
            embezzlement_probability: 1.0,
            ..Default::default()
        }
    }

    #[test]
    fn test_scheme_advancer_creation() {
        let fresh = SchemeAdvancer::new(SchemeAdvancerConfig::default());
        assert_eq!(fresh.active_scheme_count(), 0);
        assert_eq!(fresh.completed_scheme_count(), 0);
    }

    #[test]
    fn test_scheme_advancer_start_scheme() {
        let mut advancer = SchemeAdvancer::new(always_embezzle());
        let ctx = context_with_users(&["USER001", "USER002"]);

        let started = advancer.maybe_start_scheme(&ctx);

        assert!(started.is_some());
        assert_eq!(advancer.active_scheme_count(), 1);
    }

    #[test]
    fn test_scheme_advancer_max_concurrent() {
        let mut advancer = SchemeAdvancer::new(SchemeAdvancerConfig {
            max_concurrent_schemes: 2,
            ..always_embezzle()
        });
        let ctx = context_with_users(&["USER001", "USER002", "USER003"]);

        // The first two attempts fill the cap; the third must be refused.
        let _ = advancer.maybe_start_scheme(&ctx);
        let _ = advancer.maybe_start_scheme(&ctx);
        let third = advancer.maybe_start_scheme(&ctx);

        assert_eq!(advancer.active_scheme_count(), 2);
        assert!(third.is_none());
    }

    #[test]
    fn test_scheme_advancer_advance_all() {
        let mut advancer = SchemeAdvancer::new(always_embezzle());
        let base = context_with_users(&["USER001"]);
        advancer.maybe_start_scheme(&base);

        let first_day = start_date();
        for offset in 0..30 {
            let mut step = base.clone();
            step.current_date = first_day + chrono::Duration::days(offset);
            let _actions = advancer.advance_all(&step);
        }

        assert_eq!(advancer.active_scheme_count(), 1);
    }

    #[test]
    fn test_scheme_advancer_statistics() {
        let mut advancer = SchemeAdvancer::new(always_embezzle());
        advancer.maybe_start_scheme(&context_with_users(&["USER001"]));

        let stats = advancer.get_statistics();

        assert_eq!(stats.total_schemes, 1);
        assert_eq!(stats.active_schemes, 1);
        assert_eq!(stats.embezzlement_count, 1);
    }
}