Skip to main content

datasynth_generators/counterfactual/
mod.rs

1//! Counterfactual generation for what-if scenarios and paired examples.
2//!
3//! This module provides:
4//! - Paired normal/anomaly example generation for ML training
5//! - Controllable anomaly injection with specific parameters
6//! - What-if scenario generation for testing and analysis
7//!
8//! Counterfactual generation is essential for:
9//! - Training robust anomaly detection models
10//! - Understanding the impact of specific changes
11//! - Testing detection system sensitivity
12//! - Generating balanced ML datasets
13
14use chrono::{NaiveDateTime, Utc};
15use rust_decimal::Decimal;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use uuid::Uuid;
19
20use datasynth_core::models::{
21    AnomalyCausalReason, AnomalyType, ErrorType, FraudType, InjectionStrategy, JournalEntry,
22    LabeledAnomaly, RelationalAnomalyType, StatisticalAnomalyType,
23};
24
25/// A counterfactual pair containing both the original and modified versions.
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct CounterfactualPair {
28    /// Unique identifier for this pair.
29    pub pair_id: String,
30
31    /// The original (normal) journal entry.
32    pub original: JournalEntry,
33
34    /// The modified (anomalous) journal entry.
35    pub modified: JournalEntry,
36
37    /// The anomaly label for the modified entry.
38    pub anomaly_label: LabeledAnomaly,
39
40    /// Description of what changed.
41    pub change_description: String,
42
43    /// The injection strategy applied.
44    pub injection_strategy: InjectionStrategy,
45
46    /// Timestamp when the pair was generated.
47    pub generated_at: NaiveDateTime,
48
49    /// Additional metadata.
50    pub metadata: HashMap<String, String>,
51}
52
53impl CounterfactualPair {
54    /// Create a new counterfactual pair.
55    pub fn new(
56        original: JournalEntry,
57        modified: JournalEntry,
58        anomaly_label: LabeledAnomaly,
59        injection_strategy: InjectionStrategy,
60    ) -> Self {
61        let pair_id = Uuid::new_v4().to_string();
62        let change_description = injection_strategy.description();
63
64        Self {
65            pair_id,
66            original,
67            modified,
68            anomaly_label,
69            change_description,
70            injection_strategy,
71            generated_at: Utc::now().naive_utc(),
72            metadata: HashMap::new(),
73        }
74    }
75
76    /// Add metadata to the pair.
77    pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
78        self.metadata.insert(key.to_string(), value.to_string());
79        self
80    }
81}
82
83/// Specification for a counterfactual modification.
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub enum CounterfactualSpec {
86    /// Multiply amount by a factor.
87    ScaleAmount {
88        /// Multiplication factor.
89        factor: f64,
90    },
91
92    /// Add a fixed amount.
93    AddAmount {
94        /// Amount to add (can be negative).
95        delta: Decimal,
96    },
97
98    /// Set amount to a specific value.
99    SetAmount {
100        /// Target amount.
101        target: Decimal,
102    },
103
104    /// Shift the posting date.
105    ShiftDate {
106        /// Days to shift (negative = earlier).
107        days: i32,
108    },
109
110    /// Change the fiscal period.
111    ChangePeriod {
112        /// Target fiscal period.
113        target_period: u8,
114    },
115
116    /// Change the account classification.
117    ReclassifyAccount {
118        /// New account number.
119        new_account: String,
120    },
121
122    /// Add a line item.
123    AddLineItem {
124        /// Account for the new line.
125        account: String,
126        /// Amount for the new line.
127        amount: Decimal,
128        /// Is debit (true) or credit (false).
129        is_debit: bool,
130    },
131
132    /// Remove a line item by index.
133    RemoveLineItem {
134        /// Index of line to remove.
135        line_index: usize,
136    },
137
138    /// Split into multiple transactions.
139    SplitTransaction {
140        /// Number of splits.
141        split_count: u32,
142    },
143
144    /// Create a round-tripping pattern.
145    CreateRoundTrip {
146        /// Intermediate entities.
147        intermediaries: Vec<String>,
148    },
149
150    /// Mark as self-approved.
151    SelfApprove,
152
153    /// Inject a specific fraud type.
154    InjectFraud {
155        /// The fraud type to inject.
156        fraud_type: FraudType,
157    },
158
159    /// Apply a custom transformation.
160    Custom {
161        /// Transformation name.
162        name: String,
163        /// Parameters.
164        params: HashMap<String, String>,
165    },
166}
167
168impl CounterfactualSpec {
169    /// Get the anomaly type this spec would produce.
170    pub fn to_anomaly_type(&self) -> AnomalyType {
171        match self {
172            CounterfactualSpec::ScaleAmount { factor } if *factor > 2.0 => {
173                AnomalyType::Fraud(FraudType::RevenueManipulation)
174            }
175            CounterfactualSpec::ScaleAmount { .. } => {
176                AnomalyType::Statistical(StatisticalAnomalyType::UnusuallyHighAmount)
177            }
178            CounterfactualSpec::AddAmount { .. } => {
179                AnomalyType::Statistical(StatisticalAnomalyType::UnusuallyHighAmount)
180            }
181            CounterfactualSpec::SetAmount { .. } => {
182                AnomalyType::Statistical(StatisticalAnomalyType::UnusuallyHighAmount)
183            }
184            CounterfactualSpec::ShiftDate { .. } => AnomalyType::Fraud(FraudType::TimingAnomaly),
185            CounterfactualSpec::ChangePeriod { .. } => AnomalyType::Fraud(FraudType::TimingAnomaly),
186            CounterfactualSpec::ReclassifyAccount { .. } => {
187                AnomalyType::Error(ErrorType::MisclassifiedAccount)
188            }
189            CounterfactualSpec::AddLineItem { .. } => {
190                AnomalyType::Fraud(FraudType::FictitiousEntry)
191            }
192            CounterfactualSpec::RemoveLineItem { .. } => {
193                AnomalyType::Error(ErrorType::MissingField)
194            }
195            CounterfactualSpec::SplitTransaction { .. } => {
196                AnomalyType::Fraud(FraudType::SplitTransaction)
197            }
198            CounterfactualSpec::CreateRoundTrip { .. } => {
199                AnomalyType::Relational(RelationalAnomalyType::CircularTransaction)
200            }
201            CounterfactualSpec::SelfApprove => AnomalyType::Fraud(FraudType::SelfApproval),
202            CounterfactualSpec::InjectFraud { fraud_type } => AnomalyType::Fraud(*fraud_type),
203            CounterfactualSpec::Custom { .. } => AnomalyType::Custom("custom".to_string()),
204        }
205    }
206
207    /// Get a description of this specification.
208    pub fn description(&self) -> String {
209        match self {
210            CounterfactualSpec::ScaleAmount { factor } => {
211                format!("Scale amount by {:.2}x", factor)
212            }
213            CounterfactualSpec::AddAmount { delta } => {
214                format!("Add {} to amount", delta)
215            }
216            CounterfactualSpec::SetAmount { target } => {
217                format!("Set amount to {}", target)
218            }
219            CounterfactualSpec::ShiftDate { days } => {
220                if *days < 0 {
221                    format!("Backdate by {} days", days.abs())
222                } else {
223                    format!("Forward-date by {} days", days)
224                }
225            }
226            CounterfactualSpec::ChangePeriod { target_period } => {
227                format!("Change to period {}", target_period)
228            }
229            CounterfactualSpec::ReclassifyAccount { new_account } => {
230                format!("Reclassify to account {}", new_account)
231            }
232            CounterfactualSpec::AddLineItem {
233                account,
234                amount,
235                is_debit,
236            } => {
237                format!(
238                    "Add {} line for {} to account {}",
239                    if *is_debit { "debit" } else { "credit" },
240                    amount,
241                    account
242                )
243            }
244            CounterfactualSpec::RemoveLineItem { line_index } => {
245                format!("Remove line item {}", line_index)
246            }
247            CounterfactualSpec::SplitTransaction { split_count } => {
248                format!("Split into {} transactions", split_count)
249            }
250            CounterfactualSpec::CreateRoundTrip { intermediaries } => {
251                format!(
252                    "Create round-trip through {} entities",
253                    intermediaries.len()
254                )
255            }
256            CounterfactualSpec::SelfApprove => "Apply self-approval".to_string(),
257            CounterfactualSpec::InjectFraud { fraud_type } => {
258                format!("Inject {:?} fraud", fraud_type)
259            }
260            CounterfactualSpec::Custom { name, .. } => {
261                format!("Apply custom transformation: {}", name)
262            }
263        }
264    }
265}
266
/// Generator for counterfactual pairs.
///
/// Holds only plain counters, so it is cheap to clone and to print for
/// diagnostics.
#[derive(Debug, Clone)]
pub struct CounterfactualGenerator {
    /// Seed for reproducibility; also embedded in generated anomaly ids.
    seed: u64,
    /// Counter for generating unique IDs (incremented once per generated pair).
    counter: u64,
}
274
275impl CounterfactualGenerator {
276    /// Create a new counterfactual generator.
277    pub fn new(seed: u64) -> Self {
278        Self { seed, counter: 0 }
279    }
280
281    /// Generate a counterfactual pair by applying a specification to an entry.
282    pub fn generate(
283        &mut self,
284        original: &JournalEntry,
285        spec: &CounterfactualSpec,
286    ) -> CounterfactualPair {
287        self.counter += 1;
288
289        // Clone the original to create modified version
290        let mut modified = original.clone();
291
292        // Apply the specification to create the modified entry
293        let injection_strategy = self.apply_spec(&mut modified, spec, original);
294
295        // Create the anomaly label
296        let anomaly_label =
297            self.create_anomaly_label(&modified, spec, &injection_strategy, original);
298
299        // Mark the modified entry as fraudulent if the anomaly type is fraud
300        if let AnomalyType::Fraud(fraud_type) = spec.to_anomaly_type() {
301            modified.header.is_fraud = true;
302            modified.header.fraud_type = Some(fraud_type);
303        }
304
305        CounterfactualPair::new(
306            original.clone(),
307            modified,
308            anomaly_label,
309            injection_strategy,
310        )
311    }
312
313    /// Generate multiple counterfactual pairs from a single original.
314    pub fn generate_batch(
315        &mut self,
316        original: &JournalEntry,
317        specs: &[CounterfactualSpec],
318    ) -> Vec<CounterfactualPair> {
319        specs
320            .iter()
321            .map(|spec| self.generate(original, spec))
322            .collect()
323    }
324
325    /// Apply a specification to a journal entry.
326    fn apply_spec(
327        &self,
328        entry: &mut JournalEntry,
329        spec: &CounterfactualSpec,
330        original: &JournalEntry,
331    ) -> InjectionStrategy {
332        match spec {
333            CounterfactualSpec::ScaleAmount { factor } => {
334                let original_total = original.total_debit();
335                for line in &mut entry.lines {
336                    if line.debit_amount > Decimal::ZERO {
337                        let new_amount = Decimal::from_f64_retain(
338                            line.debit_amount.to_f64().unwrap_or(0.0) * factor,
339                        )
340                        .unwrap_or(line.debit_amount);
341                        line.debit_amount = new_amount;
342                        line.local_amount = new_amount;
343                    }
344                    if line.credit_amount > Decimal::ZERO {
345                        let new_amount = Decimal::from_f64_retain(
346                            line.credit_amount.to_f64().unwrap_or(0.0) * factor,
347                        )
348                        .unwrap_or(line.credit_amount);
349                        line.credit_amount = new_amount;
350                        line.local_amount = -new_amount;
351                    }
352                }
353                InjectionStrategy::AmountManipulation {
354                    original: original_total,
355                    factor: *factor,
356                }
357            }
358            CounterfactualSpec::AddAmount { delta } => {
359                // Add delta to first debit line and first credit line to keep balanced
360                if !entry.lines.is_empty() {
361                    let original_amount = entry.lines[0].debit_amount;
362                    if entry.lines[0].debit_amount > Decimal::ZERO {
363                        entry.lines[0].debit_amount += delta;
364                        entry.lines[0].local_amount += delta;
365                    }
366                    // Find first credit line and add to it
367                    for line in entry.lines.iter_mut().skip(1) {
368                        if line.credit_amount > Decimal::ZERO {
369                            line.credit_amount += delta;
370                            line.local_amount -= delta;
371                            break;
372                        }
373                    }
374                    InjectionStrategy::AmountManipulation {
375                        original: original_amount,
376                        factor: (original_amount + delta).to_f64().unwrap_or(1.0)
377                            / original_amount.to_f64().unwrap_or(1.0),
378                    }
379                } else {
380                    InjectionStrategy::Custom {
381                        name: "AddAmount".to_string(),
382                        parameters: HashMap::new(),
383                    }
384                }
385            }
386            CounterfactualSpec::SetAmount { target } => {
387                let original_total = original.total_debit();
388                if !entry.lines.is_empty() {
389                    // Set first debit line
390                    if entry.lines[0].debit_amount > Decimal::ZERO {
391                        entry.lines[0].debit_amount = *target;
392                        entry.lines[0].local_amount = *target;
393                    }
394                    // Find first credit line and set it
395                    for line in entry.lines.iter_mut().skip(1) {
396                        if line.credit_amount > Decimal::ZERO {
397                            line.credit_amount = *target;
398                            line.local_amount = -*target;
399                            break;
400                        }
401                    }
402                }
403                InjectionStrategy::AmountManipulation {
404                    original: original_total,
405                    factor: target.to_f64().unwrap_or(1.0) / original_total.to_f64().unwrap_or(1.0),
406                }
407            }
408            CounterfactualSpec::ShiftDate { days } => {
409                let original_date = entry.header.posting_date;
410                entry.header.posting_date = if *days >= 0 {
411                    entry.header.posting_date + chrono::Duration::days(*days as i64)
412                } else {
413                    entry.header.posting_date - chrono::Duration::days(days.abs() as i64)
414                };
415                InjectionStrategy::DateShift {
416                    days_shifted: *days,
417                    original_date,
418                }
419            }
420            CounterfactualSpec::ChangePeriod { target_period } => {
421                entry.header.fiscal_period = *target_period;
422                InjectionStrategy::TimingManipulation {
423                    timing_type: "PeriodChange".to_string(),
424                    original_time: None,
425                }
426            }
427            CounterfactualSpec::ReclassifyAccount { new_account } => {
428                let old_account = if !entry.lines.is_empty() {
429                    let old = entry.lines[0].gl_account.clone();
430                    entry.lines[0].gl_account = new_account.clone();
431                    entry.lines[0].account_code = new_account.clone();
432                    old
433                } else {
434                    String::new()
435                };
436                InjectionStrategy::AccountMisclassification {
437                    correct_account: old_account,
438                    incorrect_account: new_account.clone(),
439                }
440            }
441            CounterfactualSpec::SelfApprove => {
442                let user_id = entry.header.created_by.clone();
443                entry.header.sod_violation = true;
444                InjectionStrategy::SelfApproval { user_id }
445            }
446            CounterfactualSpec::SplitTransaction { split_count } => {
447                let original_amount = original.total_debit();
448                InjectionStrategy::SplitTransaction {
449                    original_amount,
450                    split_count: *split_count,
451                    split_doc_ids: vec![entry.header.document_id.to_string()],
452                }
453            }
454            CounterfactualSpec::CreateRoundTrip { intermediaries } => {
455                InjectionStrategy::CircularFlow {
456                    entity_chain: intermediaries.clone(),
457                }
458            }
459            _ => InjectionStrategy::Custom {
460                name: spec.description(),
461                parameters: HashMap::new(),
462            },
463        }
464    }
465
466    /// Create an anomaly label for the modified entry.
467    fn create_anomaly_label(
468        &self,
469        modified: &JournalEntry,
470        spec: &CounterfactualSpec,
471        strategy: &InjectionStrategy,
472        original: &JournalEntry,
473    ) -> LabeledAnomaly {
474        let anomaly_id = format!("CF-{}-{}", self.seed, self.counter);
475        let anomaly_type = spec.to_anomaly_type();
476
477        LabeledAnomaly {
478            anomaly_id,
479            anomaly_type: anomaly_type.clone(),
480            document_id: modified.header.document_id.to_string(),
481            document_type: "JournalEntry".to_string(),
482            company_code: modified.header.company_code.clone(),
483            anomaly_date: modified.header.posting_date,
484            detection_timestamp: Utc::now().naive_utc(),
485            confidence: 1.0, // Counterfactuals are known anomalies
486            severity: anomaly_type.severity(),
487            description: spec.description(),
488            related_entities: vec![original.header.document_id.to_string()],
489            monetary_impact: Some(modified.total_debit()),
490            metadata: HashMap::new(),
491            is_injected: true,
492            injection_strategy: Some(strategy.description()),
493            cluster_id: None,
494            original_document_hash: Some(format!("{:x}", hash_entry(original))),
495            causal_reason: Some(AnomalyCausalReason::MLTrainingBalance {
496                target_class: "counterfactual".to_string(),
497            }),
498            structured_strategy: Some(strategy.clone()),
499            parent_anomaly_id: None,
500            child_anomaly_ids: vec![],
501            scenario_id: None,
502            run_id: None,
503            generation_seed: Some(self.seed),
504        }
505    }
506}
507
508/// Simple hash function for journal entries (for provenance tracking).
509fn hash_entry(entry: &JournalEntry) -> u64 {
510    use std::collections::hash_map::DefaultHasher;
511    use std::hash::{Hash, Hasher};
512
513    let mut hasher = DefaultHasher::new();
514    entry.header.document_id.hash(&mut hasher);
515    entry.header.company_code.hash(&mut hasher);
516    entry.header.posting_date.hash(&mut hasher);
517    entry.lines.len().hash(&mut hasher);
518    hasher.finish()
519}
520
521/// Configuration for batch counterfactual generation.
522#[derive(Debug, Clone, Serialize, Deserialize)]
523pub struct CounterfactualConfig {
524    /// Seed for reproducibility.
525    pub seed: u64,
526    /// Number of counterfactual variants per original.
527    pub variants_per_original: usize,
528    /// Specifications to apply (randomly selected).
529    pub specifications: Vec<CounterfactualSpec>,
530    /// Whether to include the original in output.
531    pub include_originals: bool,
532}
533
534impl Default for CounterfactualConfig {
535    fn default() -> Self {
536        Self {
537            seed: 42,
538            variants_per_original: 3,
539            specifications: vec![
540                CounterfactualSpec::ScaleAmount { factor: 1.5 },
541                CounterfactualSpec::ScaleAmount { factor: 2.0 },
542                CounterfactualSpec::ScaleAmount { factor: 0.5 },
543                CounterfactualSpec::ShiftDate { days: -7 },
544                CounterfactualSpec::ShiftDate { days: 30 },
545                CounterfactualSpec::SelfApprove,
546            ],
547            include_originals: true,
548        }
549    }
550}
551
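// Bring the `rust_decimal` prelude into scope (not a re-export): it supplies the
// numeric conversion traits behind the `Decimal` <-> `f64` conversions above (e.g. `to_f64`).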
553use rust_decimal::prelude::*;
554
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::NaiveDate;
    use datasynth_core::models::{JournalEntryHeader, JournalEntryLine};

    /// Build a two-line entry for company "TEST" posted on 2024-06-15:
    /// a 100.00 debit to account 1100 and a 100.00 credit to account 2000.
    fn create_test_entry() -> JournalEntry {
        let posting_date = NaiveDate::from_ymd_opt(2024, 6, 15).unwrap();
        let mut entry = JournalEntry::new(JournalEntryHeader::new(
            "TEST".to_string(),
            posting_date,
        ));

        let document_id = entry.header.document_id;
        let hundred = Decimal::new(10000, 2); // 100.00

        entry.add_line(JournalEntryLine::debit(
            document_id,
            1,
            "1100".to_string(),
            hundred,
        ));
        entry.add_line(JournalEntryLine::credit(
            document_id,
            2,
            "2000".to_string(),
            hundred,
        ));

        entry
    }

    #[test]
    fn test_counterfactual_generator_scale_amount() {
        let source_entry = create_test_entry();
        let doubling = CounterfactualSpec::ScaleAmount { factor: 2.0 };

        let pair = CounterfactualGenerator::new(42).generate(&source_entry, &doubling);

        // The original keeps 100.00 while the modified entry doubles to 200.00.
        assert_eq!(pair.original.total_debit(), Decimal::new(10000, 2));
        assert_eq!(pair.modified.total_debit(), Decimal::new(20000, 2));
        // ScaleAmount with factor <= 2.0 is a statistical anomaly, not fraud.
        assert!(!pair.modified.header.is_fraud);
    }

    #[test]
    fn test_counterfactual_generator_shift_date() {
        let source_entry = create_test_entry();
        let backdate = CounterfactualSpec::ShiftDate { days: -7 };

        let pair = CounterfactualGenerator::new(42).generate(&source_entry, &backdate);

        // 2024-06-15 moved seven days earlier.
        let expected_date = NaiveDate::from_ymd_opt(2024, 6, 8).unwrap();
        assert_eq!(pair.modified.header.posting_date, expected_date);
    }

    #[test]
    fn test_counterfactual_spec_to_anomaly_type() {
        // SelfApprove is classified as Fraud (FraudType::SelfApproval).
        let anomaly_type = CounterfactualSpec::SelfApprove.to_anomaly_type();

        assert!(matches!(
            anomaly_type,
            AnomalyType::Fraud(FraudType::SelfApproval)
        ));
    }

    #[test]
    fn test_counterfactual_batch_generation() {
        let source_entry = create_test_entry();
        let specs = vec![
            CounterfactualSpec::ScaleAmount { factor: 1.5 },
            CounterfactualSpec::ShiftDate { days: -3 },
            CounterfactualSpec::SelfApprove,
        ];

        let mut generator = CounterfactualGenerator::new(42);
        let pairs = generator.generate_batch(&source_entry, &specs);

        assert_eq!(pairs.len(), 3);
        // Only fraud types (ShiftDate, SelfApprove) set is_fraud = true;
        // ScaleAmount with factor <= 2.0 is statistical, not fraud.
        assert!(!pairs[0].modified.header.is_fraud); // ScaleAmount -> Statistical
        assert!(pairs[1].modified.header.is_fraud); // ShiftDate -> TimingAnomaly (Fraud)
        assert!(pairs[2].modified.header.is_fraud); // SelfApprove -> SelfApproval (Fraud)
    }
}
641}