1use chrono::{NaiveDateTime, Utc};
15use rust_decimal::Decimal;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use uuid::Uuid;
19
20use datasynth_core::models::{
21 AnomalyCausalReason, AnomalyType, ErrorType, FraudType, InjectionStrategy, JournalEntry,
22 LabeledAnomaly, RelationalAnomalyType, StatisticalAnomalyType,
23};
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct CounterfactualPair {
28 pub pair_id: String,
30
31 pub original: JournalEntry,
33
34 pub modified: JournalEntry,
36
37 pub anomaly_label: LabeledAnomaly,
39
40 pub change_description: String,
42
43 pub injection_strategy: InjectionStrategy,
45
46 pub generated_at: NaiveDateTime,
48
49 pub metadata: HashMap<String, String>,
51}
52
53impl CounterfactualPair {
54 pub fn new(
56 original: JournalEntry,
57 modified: JournalEntry,
58 anomaly_label: LabeledAnomaly,
59 injection_strategy: InjectionStrategy,
60 ) -> Self {
61 let pair_id = Uuid::new_v4().to_string();
62 let change_description = injection_strategy.description();
63
64 Self {
65 pair_id,
66 original,
67 modified,
68 anomaly_label,
69 change_description,
70 injection_strategy,
71 generated_at: Utc::now().naive_utc(),
72 metadata: HashMap::new(),
73 }
74 }
75
76 pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
78 self.metadata.insert(key.to_string(), value.to_string());
79 self
80 }
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
85pub enum CounterfactualSpec {
86 ScaleAmount {
88 factor: f64,
90 },
91
92 AddAmount {
94 delta: Decimal,
96 },
97
98 SetAmount {
100 target: Decimal,
102 },
103
104 ShiftDate {
106 days: i32,
108 },
109
110 ChangePeriod {
112 target_period: u8,
114 },
115
116 ReclassifyAccount {
118 new_account: String,
120 },
121
122 AddLineItem {
124 account: String,
126 amount: Decimal,
128 is_debit: bool,
130 },
131
132 RemoveLineItem {
134 line_index: usize,
136 },
137
138 SplitTransaction {
140 split_count: u32,
142 },
143
144 CreateRoundTrip {
146 intermediaries: Vec<String>,
148 },
149
150 SelfApprove,
152
153 InjectFraud {
155 fraud_type: FraudType,
157 },
158
159 Custom {
161 name: String,
163 params: HashMap<String, String>,
165 },
166}
167
168impl CounterfactualSpec {
169 pub fn to_anomaly_type(&self) -> AnomalyType {
171 match self {
172 CounterfactualSpec::ScaleAmount { factor } if *factor > 2.0 => {
173 AnomalyType::Fraud(FraudType::RevenueManipulation)
174 }
175 CounterfactualSpec::ScaleAmount { .. } => {
176 AnomalyType::Statistical(StatisticalAnomalyType::UnusuallyHighAmount)
177 }
178 CounterfactualSpec::AddAmount { .. } => {
179 AnomalyType::Statistical(StatisticalAnomalyType::UnusuallyHighAmount)
180 }
181 CounterfactualSpec::SetAmount { .. } => {
182 AnomalyType::Statistical(StatisticalAnomalyType::UnusuallyHighAmount)
183 }
184 CounterfactualSpec::ShiftDate { .. } => AnomalyType::Fraud(FraudType::TimingAnomaly),
185 CounterfactualSpec::ChangePeriod { .. } => AnomalyType::Fraud(FraudType::TimingAnomaly),
186 CounterfactualSpec::ReclassifyAccount { .. } => {
187 AnomalyType::Error(ErrorType::MisclassifiedAccount)
188 }
189 CounterfactualSpec::AddLineItem { .. } => {
190 AnomalyType::Fraud(FraudType::FictitiousEntry)
191 }
192 CounterfactualSpec::RemoveLineItem { .. } => {
193 AnomalyType::Error(ErrorType::MissingField)
194 }
195 CounterfactualSpec::SplitTransaction { .. } => {
196 AnomalyType::Fraud(FraudType::SplitTransaction)
197 }
198 CounterfactualSpec::CreateRoundTrip { .. } => {
199 AnomalyType::Relational(RelationalAnomalyType::CircularTransaction)
200 }
201 CounterfactualSpec::SelfApprove => AnomalyType::Fraud(FraudType::SelfApproval),
202 CounterfactualSpec::InjectFraud { fraud_type } => AnomalyType::Fraud(*fraud_type),
203 CounterfactualSpec::Custom { .. } => AnomalyType::Custom("custom".to_string()),
204 }
205 }
206
207 pub fn description(&self) -> String {
209 match self {
210 CounterfactualSpec::ScaleAmount { factor } => {
211 format!("Scale amount by {:.2}x", factor)
212 }
213 CounterfactualSpec::AddAmount { delta } => {
214 format!("Add {} to amount", delta)
215 }
216 CounterfactualSpec::SetAmount { target } => {
217 format!("Set amount to {}", target)
218 }
219 CounterfactualSpec::ShiftDate { days } => {
220 if *days < 0 {
221 format!("Backdate by {} days", days.abs())
222 } else {
223 format!("Forward-date by {} days", days)
224 }
225 }
226 CounterfactualSpec::ChangePeriod { target_period } => {
227 format!("Change to period {}", target_period)
228 }
229 CounterfactualSpec::ReclassifyAccount { new_account } => {
230 format!("Reclassify to account {}", new_account)
231 }
232 CounterfactualSpec::AddLineItem {
233 account,
234 amount,
235 is_debit,
236 } => {
237 format!(
238 "Add {} line for {} to account {}",
239 if *is_debit { "debit" } else { "credit" },
240 amount,
241 account
242 )
243 }
244 CounterfactualSpec::RemoveLineItem { line_index } => {
245 format!("Remove line item {}", line_index)
246 }
247 CounterfactualSpec::SplitTransaction { split_count } => {
248 format!("Split into {} transactions", split_count)
249 }
250 CounterfactualSpec::CreateRoundTrip { intermediaries } => {
251 format!(
252 "Create round-trip through {} entities",
253 intermediaries.len()
254 )
255 }
256 CounterfactualSpec::SelfApprove => "Apply self-approval".to_string(),
257 CounterfactualSpec::InjectFraud { fraud_type } => {
258 format!("Inject {:?} fraud", fraud_type)
259 }
260 CounterfactualSpec::Custom { name, .. } => {
261 format!("Apply custom transformation: {}", name)
262 }
263 }
264 }
265}
266
/// Applies [`CounterfactualSpec`]s to journal entries and labels the resulting pairs.
pub struct CounterfactualGenerator {
    /// Seed embedded in every anomaly id and recorded as the label's generation seed.
    seed: u64,
    /// Number of pairs generated so far; keeps anomaly ids unique per generator.
    counter: u64,
}
274
275impl CounterfactualGenerator {
276 pub fn new(seed: u64) -> Self {
278 Self { seed, counter: 0 }
279 }
280
281 pub fn generate(
283 &mut self,
284 original: &JournalEntry,
285 spec: &CounterfactualSpec,
286 ) -> CounterfactualPair {
287 self.counter += 1;
288
289 let mut modified = original.clone();
291
292 let injection_strategy = self.apply_spec(&mut modified, spec, original);
294
295 let anomaly_label =
297 self.create_anomaly_label(&modified, spec, &injection_strategy, original);
298
299 if let AnomalyType::Fraud(fraud_type) = spec.to_anomaly_type() {
301 modified.header.is_fraud = true;
302 modified.header.fraud_type = Some(fraud_type);
303 }
304
305 CounterfactualPair::new(
306 original.clone(),
307 modified,
308 anomaly_label,
309 injection_strategy,
310 )
311 }
312
313 pub fn generate_batch(
315 &mut self,
316 original: &JournalEntry,
317 specs: &[CounterfactualSpec],
318 ) -> Vec<CounterfactualPair> {
319 specs
320 .iter()
321 .map(|spec| self.generate(original, spec))
322 .collect()
323 }
324
325 fn apply_spec(
327 &self,
328 entry: &mut JournalEntry,
329 spec: &CounterfactualSpec,
330 original: &JournalEntry,
331 ) -> InjectionStrategy {
332 match spec {
333 CounterfactualSpec::ScaleAmount { factor } => {
334 let original_total = original.total_debit();
335 for line in &mut entry.lines {
336 if line.debit_amount > Decimal::ZERO {
337 let new_amount = Decimal::from_f64_retain(
338 line.debit_amount.to_f64().unwrap_or(0.0) * factor,
339 )
340 .unwrap_or(line.debit_amount);
341 line.debit_amount = new_amount;
342 line.local_amount = new_amount;
343 }
344 if line.credit_amount > Decimal::ZERO {
345 let new_amount = Decimal::from_f64_retain(
346 line.credit_amount.to_f64().unwrap_or(0.0) * factor,
347 )
348 .unwrap_or(line.credit_amount);
349 line.credit_amount = new_amount;
350 line.local_amount = -new_amount;
351 }
352 }
353 InjectionStrategy::AmountManipulation {
354 original: original_total,
355 factor: *factor,
356 }
357 }
358 CounterfactualSpec::AddAmount { delta } => {
359 if !entry.lines.is_empty() {
361 let original_amount = entry.lines[0].debit_amount;
362 if entry.lines[0].debit_amount > Decimal::ZERO {
363 entry.lines[0].debit_amount += delta;
364 entry.lines[0].local_amount += delta;
365 }
366 for line in entry.lines.iter_mut().skip(1) {
368 if line.credit_amount > Decimal::ZERO {
369 line.credit_amount += delta;
370 line.local_amount -= delta;
371 break;
372 }
373 }
374 InjectionStrategy::AmountManipulation {
375 original: original_amount,
376 factor: (original_amount + delta).to_f64().unwrap_or(1.0)
377 / original_amount.to_f64().unwrap_or(1.0),
378 }
379 } else {
380 InjectionStrategy::Custom {
381 name: "AddAmount".to_string(),
382 parameters: HashMap::new(),
383 }
384 }
385 }
386 CounterfactualSpec::SetAmount { target } => {
387 let original_total = original.total_debit();
388 if !entry.lines.is_empty() {
389 if entry.lines[0].debit_amount > Decimal::ZERO {
391 entry.lines[0].debit_amount = *target;
392 entry.lines[0].local_amount = *target;
393 }
394 for line in entry.lines.iter_mut().skip(1) {
396 if line.credit_amount > Decimal::ZERO {
397 line.credit_amount = *target;
398 line.local_amount = -*target;
399 break;
400 }
401 }
402 }
403 InjectionStrategy::AmountManipulation {
404 original: original_total,
405 factor: target.to_f64().unwrap_or(1.0) / original_total.to_f64().unwrap_or(1.0),
406 }
407 }
408 CounterfactualSpec::ShiftDate { days } => {
409 let original_date = entry.header.posting_date;
410 entry.header.posting_date = if *days >= 0 {
411 entry.header.posting_date + chrono::Duration::days(*days as i64)
412 } else {
413 entry.header.posting_date - chrono::Duration::days(days.abs() as i64)
414 };
415 InjectionStrategy::DateShift {
416 days_shifted: *days,
417 original_date,
418 }
419 }
420 CounterfactualSpec::ChangePeriod { target_period } => {
421 entry.header.fiscal_period = *target_period;
422 InjectionStrategy::TimingManipulation {
423 timing_type: "PeriodChange".to_string(),
424 original_time: None,
425 }
426 }
427 CounterfactualSpec::ReclassifyAccount { new_account } => {
428 let old_account = if !entry.lines.is_empty() {
429 let old = entry.lines[0].gl_account.clone();
430 entry.lines[0].gl_account = new_account.clone();
431 entry.lines[0].account_code = new_account.clone();
432 old
433 } else {
434 String::new()
435 };
436 InjectionStrategy::AccountMisclassification {
437 correct_account: old_account,
438 incorrect_account: new_account.clone(),
439 }
440 }
441 CounterfactualSpec::SelfApprove => {
442 let user_id = entry.header.created_by.clone();
443 entry.header.sod_violation = true;
444 InjectionStrategy::SelfApproval { user_id }
445 }
446 CounterfactualSpec::SplitTransaction { split_count } => {
447 let original_amount = original.total_debit();
448 InjectionStrategy::SplitTransaction {
449 original_amount,
450 split_count: *split_count,
451 split_doc_ids: vec![entry.header.document_id.to_string()],
452 }
453 }
454 CounterfactualSpec::CreateRoundTrip { intermediaries } => {
455 InjectionStrategy::CircularFlow {
456 entity_chain: intermediaries.clone(),
457 }
458 }
459 _ => InjectionStrategy::Custom {
460 name: spec.description(),
461 parameters: HashMap::new(),
462 },
463 }
464 }
465
466 fn create_anomaly_label(
468 &self,
469 modified: &JournalEntry,
470 spec: &CounterfactualSpec,
471 strategy: &InjectionStrategy,
472 original: &JournalEntry,
473 ) -> LabeledAnomaly {
474 let anomaly_id = format!("CF-{}-{}", self.seed, self.counter);
475 let anomaly_type = spec.to_anomaly_type();
476
477 LabeledAnomaly {
478 anomaly_id,
479 anomaly_type: anomaly_type.clone(),
480 document_id: modified.header.document_id.to_string(),
481 document_type: "JournalEntry".to_string(),
482 company_code: modified.header.company_code.clone(),
483 anomaly_date: modified.header.posting_date,
484 detection_timestamp: Utc::now().naive_utc(),
485 confidence: 1.0, severity: anomaly_type.severity(),
487 description: spec.description(),
488 related_entities: vec![original.header.document_id.to_string()],
489 monetary_impact: Some(modified.total_debit()),
490 metadata: HashMap::new(),
491 is_injected: true,
492 injection_strategy: Some(strategy.description()),
493 cluster_id: None,
494 original_document_hash: Some(format!("{:x}", hash_entry(original))),
495 causal_reason: Some(AnomalyCausalReason::MLTrainingBalance {
496 target_class: "counterfactual".to_string(),
497 }),
498 structured_strategy: Some(strategy.clone()),
499 parent_anomaly_id: None,
500 child_anomaly_ids: vec![],
501 scenario_id: None,
502 run_id: None,
503 generation_seed: Some(self.seed),
504 }
505 }
506}
507
508fn hash_entry(entry: &JournalEntry) -> u64 {
510 use std::collections::hash_map::DefaultHasher;
511 use std::hash::{Hash, Hasher};
512
513 let mut hasher = DefaultHasher::new();
514 entry.header.document_id.hash(&mut hasher);
515 entry.header.company_code.hash(&mut hasher);
516 entry.header.posting_date.hash(&mut hasher);
517 entry.lines.len().hash(&mut hasher);
518 hasher.finish()
519}
520
521#[derive(Debug, Clone, Serialize, Deserialize)]
523pub struct CounterfactualConfig {
524 pub seed: u64,
526 pub variants_per_original: usize,
528 pub specifications: Vec<CounterfactualSpec>,
530 pub include_originals: bool,
532}
533
534impl Default for CounterfactualConfig {
535 fn default() -> Self {
536 Self {
537 seed: 42,
538 variants_per_original: 3,
539 specifications: vec![
540 CounterfactualSpec::ScaleAmount { factor: 1.5 },
541 CounterfactualSpec::ScaleAmount { factor: 2.0 },
542 CounterfactualSpec::ScaleAmount { factor: 0.5 },
543 CounterfactualSpec::ShiftDate { days: -7 },
544 CounterfactualSpec::ShiftDate { days: 30 },
545 CounterfactualSpec::SelfApprove,
546 ],
547 include_originals: true,
548 }
549 }
550}
551
552use rust_decimal::prelude::*;
554
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::NaiveDate;
    use datasynth_core::models::{JournalEntryHeader, JournalEntryLine};

    /// Builds a two-line entry dated 2024-06-15. It has a 100.00 debit line on "1100" and a
    /// 100.00 credit line on "2000".
    fn create_test_entry() -> JournalEntry {
        let posting_date = NaiveDate::from_ymd_opt(2024, 6, 15).unwrap();
        let mut entry =
            JournalEntry::new(JournalEntryHeader::new("TEST".to_string(), posting_date));

        let doc_id = entry.header.document_id;
        let amount = Decimal::new(10000, 2);
        entry.add_line(JournalEntryLine::debit(doc_id, 1, "1100".to_string(), amount));
        entry.add_line(JournalEntryLine::credit(doc_id, 2, "2000".to_string(), amount));

        entry
    }

    #[test]
    fn test_counterfactual_generator_scale_amount() {
        let original = create_test_entry();
        let pair = CounterfactualGenerator::new(42)
            .generate(&original, &CounterfactualSpec::ScaleAmount { factor: 2.0 });

        assert_eq!(pair.original.total_debit(), Decimal::new(10000, 2));
        assert_eq!(pair.modified.total_debit(), Decimal::new(20000, 2));
        // Exactly 2x is not above the fraud threshold.
        assert!(!pair.modified.header.is_fraud);
    }

    #[test]
    fn test_counterfactual_generator_shift_date() {
        let mut generator = CounterfactualGenerator::new(42);
        let pair = generator.generate(
            &create_test_entry(),
            &CounterfactualSpec::ShiftDate { days: -7 },
        );

        assert_eq!(
            pair.modified.header.posting_date,
            NaiveDate::from_ymd_opt(2024, 6, 8).unwrap()
        );
    }

    #[test]
    fn test_counterfactual_spec_to_anomaly_type() {
        assert!(matches!(
            CounterfactualSpec::SelfApprove.to_anomaly_type(),
            AnomalyType::Fraud(FraudType::SelfApproval)
        ));
    }

    #[test]
    fn test_counterfactual_batch_generation() {
        let specs = [
            CounterfactualSpec::ScaleAmount { factor: 1.5 },
            CounterfactualSpec::ShiftDate { days: -3 },
            CounterfactualSpec::SelfApprove,
        ];
        let pairs = CounterfactualGenerator::new(42).generate_batch(&create_test_entry(), &specs);

        assert_eq!(pairs.len(), 3);
        let fraud_flags: Vec<bool> = pairs
            .iter()
            .map(|pair| pair.modified.header.is_fraud)
            .collect();
        assert_eq!(fraud_flags, [false, true, true]);
    }
}