use chrono::{NaiveDateTime, Utc};
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
use datasynth_core::models::{
AnomalyCausalReason, AnomalyType, ErrorType, FraudType, InjectionStrategy, JournalEntry,
JournalEntryLine, LabeledAnomaly, RelationalAnomalyType, StatisticalAnomalyType,
};
/// An original journal entry paired with a modified copy of it, plus the anomaly label and
/// injection strategy that describe how the copy differs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CounterfactualPair {
    /// Pair identifier; `CounterfactualPair::new` takes it from a `DeterministicUuidFactory`.
    pub pair_id: String,
    /// The unmodified entry.
    pub original: JournalEntry,
    /// The entry after the counterfactual change was applied.
    pub modified: JournalEntry,
    /// Anomaly label attached to the modified entry.
    pub anomaly_label: LabeledAnomaly,
    /// Human-readable summary of the change; `new` fills it from `injection_strategy.description()`.
    pub change_description: String,
    /// Structured record of the change that was applied.
    pub injection_strategy: InjectionStrategy,
    /// Creation timestamp; `new` uses the current UTC wall-clock time (not seed-derived).
    pub generated_at: NaiveDateTime,
    /// Free-form string annotations, populated through `with_metadata`.
    pub metadata: HashMap<String, String>,
}
impl CounterfactualPair {
    /// Assembles a pair from its parts.
    ///
    /// The pair id is the next UUID from `uuid_factory`, the change description is derived
    /// from `injection_strategy`, `generated_at` is the current UTC time, and the metadata
    /// map starts out empty.
    pub fn new(
        original: JournalEntry,
        modified: JournalEntry,
        anomaly_label: LabeledAnomaly,
        injection_strategy: InjectionStrategy,
        uuid_factory: &DeterministicUuidFactory,
    ) -> Self {
        Self {
            pair_id: uuid_factory.next().to_string(),
            // Borrow the strategy for its description before it is moved into the struct.
            change_description: injection_strategy.description(),
            generated_at: Utc::now().naive_utc(),
            metadata: HashMap::new(),
            original,
            modified,
            anomaly_label,
            injection_strategy,
        }
    }

    /// Builder-style setter: records `key -> value` in the metadata map (overwriting any
    /// existing value for `key`) and hands the pair back.
    pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
        let (owned_key, owned_value) = (key.to_owned(), value.to_owned());
        self.metadata.insert(owned_key, owned_value);
        self
    }
}
/// Describes one counterfactual modification to apply to a journal entry.
///
/// Each variant maps to an anomaly category via `to_anomaly_type` and to a human-readable
/// summary via `description`; `CounterfactualGenerator` applies the change itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CounterfactualSpec {
    /// Multiply positive debit/credit amounts by `factor`. Factors strictly above 2.0 are
    /// labelled as revenue-manipulation fraud; others as an unusually high amount.
    ScaleAmount {
        factor: f64,
    },
    /// Add `delta` to the amount of a debit line and a credit line.
    AddAmount {
        delta: Decimal,
    },
    /// Overwrite the amount of a debit line and a credit line with `target`.
    SetAmount {
        target: Decimal,
    },
    /// Move the posting date by `days` (negative values backdate).
    ShiftDate {
        days: i32,
    },
    /// Set the header's fiscal period to `target_period`.
    ChangePeriod {
        target_period: u8,
    },
    /// Replace the account of the first line with `new_account`.
    ReclassifyAccount {
        new_account: String,
    },
    /// Append an extra debit or credit line; the entry is not rebalanced.
    AddLineItem {
        account: String,
        amount: Decimal,
        is_debit: bool,
    },
    /// Remove the line at `line_index` (an out-of-range index leaves the lines unchanged).
    RemoveLineItem {
        line_index: usize,
    },
    /// Split every line into `split_count` lines (a count of 0 is treated as 1).
    SplitTransaction {
        split_count: u32,
    },
    /// Label the entry as a circular flow through `intermediaries`; the entry's lines are
    /// not modified.
    CreateRoundTrip {
        intermediaries: Vec<String>,
    },
    /// Flag the entry as a segregation-of-duties violation by its creator.
    SelfApprove,
    /// Label the entry with `fraud_type` without changing its lines.
    InjectFraud {
        fraud_type: FraudType,
    },
    /// Named, caller-defined transformation; the generator does not change the entry's
    /// lines for it.
    Custom {
        name: String,
        params: HashMap<String, String>,
    },
}
impl CounterfactualSpec {
pub fn to_anomaly_type(&self) -> AnomalyType {
match self {
CounterfactualSpec::ScaleAmount { factor } if *factor > 2.0 => {
AnomalyType::Fraud(FraudType::RevenueManipulation)
}
CounterfactualSpec::ScaleAmount { .. } => {
AnomalyType::Statistical(StatisticalAnomalyType::UnusuallyHighAmount)
}
CounterfactualSpec::AddAmount { .. } => {
AnomalyType::Statistical(StatisticalAnomalyType::UnusuallyHighAmount)
}
CounterfactualSpec::SetAmount { .. } => {
AnomalyType::Statistical(StatisticalAnomalyType::UnusuallyHighAmount)
}
CounterfactualSpec::ShiftDate { .. } => AnomalyType::Fraud(FraudType::TimingAnomaly),
CounterfactualSpec::ChangePeriod { .. } => AnomalyType::Fraud(FraudType::TimingAnomaly),
CounterfactualSpec::ReclassifyAccount { .. } => {
AnomalyType::Error(ErrorType::MisclassifiedAccount)
}
CounterfactualSpec::AddLineItem { .. } => {
AnomalyType::Fraud(FraudType::FictitiousEntry)
}
CounterfactualSpec::RemoveLineItem { .. } => {
AnomalyType::Error(ErrorType::MissingField)
}
CounterfactualSpec::SplitTransaction { .. } => {
AnomalyType::Fraud(FraudType::SplitTransaction)
}
CounterfactualSpec::CreateRoundTrip { .. } => {
AnomalyType::Relational(RelationalAnomalyType::CircularTransaction)
}
CounterfactualSpec::SelfApprove => AnomalyType::Fraud(FraudType::SelfApproval),
CounterfactualSpec::InjectFraud { fraud_type } => AnomalyType::Fraud(*fraud_type),
CounterfactualSpec::Custom { .. } => AnomalyType::Custom("custom".to_string()),
}
}
pub fn description(&self) -> String {
match self {
CounterfactualSpec::ScaleAmount { factor } => {
format!("Scale amount by {factor:.2}x")
}
CounterfactualSpec::AddAmount { delta } => {
format!("Add {delta} to amount")
}
CounterfactualSpec::SetAmount { target } => {
format!("Set amount to {target}")
}
CounterfactualSpec::ShiftDate { days } => {
if *days < 0 {
format!("Backdate by {} days", days.abs())
} else {
format!("Forward-date by {days} days")
}
}
CounterfactualSpec::ChangePeriod { target_period } => {
format!("Change to period {target_period}")
}
CounterfactualSpec::ReclassifyAccount { new_account } => {
format!("Reclassify to account {new_account}")
}
CounterfactualSpec::AddLineItem {
account,
amount,
is_debit,
} => {
format!(
"Add {} line for {} to account {}",
if *is_debit { "debit" } else { "credit" },
amount,
account
)
}
CounterfactualSpec::RemoveLineItem { line_index } => {
format!("Remove line item {line_index}")
}
CounterfactualSpec::SplitTransaction { split_count } => {
format!("Split into {split_count} transactions")
}
CounterfactualSpec::CreateRoundTrip { intermediaries } => {
format!(
"Create round-trip through {} entities",
intermediaries.len()
)
}
CounterfactualSpec::SelfApprove => "Apply self-approval".to_string(),
CounterfactualSpec::InjectFraud { fraud_type } => {
format!("Inject {fraud_type:?} fraud")
}
CounterfactualSpec::Custom { name, .. } => {
format!("Apply custom transformation: {name}")
}
}
}
}
/// Produces counterfactual pairs (original + modified journal entry) from
/// `CounterfactualSpec`s, with ids derived from a fixed seed.
pub struct CounterfactualGenerator {
    /// Seed for the UUID factory; also embedded in anomaly ids and recorded as the
    /// label's `generation_seed`.
    seed: u64,
    /// Number of pairs generated so far; used as the suffix of each anomaly id.
    counter: u64,
    /// Deterministic source of pair ids.
    uuid_factory: DeterministicUuidFactory,
}
impl CounterfactualGenerator {
pub fn new(seed: u64) -> Self {
Self {
seed,
counter: 0,
uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::Anomaly),
}
}
pub fn generate(
&mut self,
original: &JournalEntry,
spec: &CounterfactualSpec,
) -> CounterfactualPair {
self.counter += 1;
let mut modified = original.clone();
let injection_strategy = self.apply_spec(&mut modified, spec, original);
let anomaly_label =
self.create_anomaly_label(&modified, spec, &injection_strategy, original);
if let AnomalyType::Fraud(fraud_type) = spec.to_anomaly_type() {
modified.header.is_fraud = true;
modified.header.fraud_type = Some(fraud_type);
}
CounterfactualPair::new(
original.clone(),
modified,
anomaly_label,
injection_strategy,
&self.uuid_factory,
)
}
pub fn generate_batch(
&mut self,
original: &JournalEntry,
specs: &[CounterfactualSpec],
) -> Vec<CounterfactualPair> {
specs
.iter()
.map(|spec| self.generate(original, spec))
.collect()
}
fn apply_spec(
&self,
entry: &mut JournalEntry,
spec: &CounterfactualSpec,
original: &JournalEntry,
) -> InjectionStrategy {
match spec {
CounterfactualSpec::ScaleAmount { factor } => {
let original_total = original.total_debit();
for line in &mut entry.lines {
if line.debit_amount > Decimal::ZERO {
let new_amount = Decimal::from_f64_retain(
line.debit_amount.to_f64().unwrap_or(0.0) * factor,
)
.unwrap_or(line.debit_amount);
line.debit_amount = new_amount;
line.local_amount = new_amount;
}
if line.credit_amount > Decimal::ZERO {
let new_amount = Decimal::from_f64_retain(
line.credit_amount.to_f64().unwrap_or(0.0) * factor,
)
.unwrap_or(line.credit_amount);
line.credit_amount = new_amount;
line.local_amount = -new_amount;
}
}
InjectionStrategy::AmountManipulation {
original: original_total,
factor: *factor,
}
}
CounterfactualSpec::AddAmount { delta } => {
if !entry.lines.is_empty() {
let original_amount = entry.lines[0].debit_amount;
if entry.lines[0].debit_amount > Decimal::ZERO {
entry.lines[0].debit_amount += delta;
entry.lines[0].local_amount += delta;
}
for line in entry.lines.iter_mut().skip(1) {
if line.credit_amount > Decimal::ZERO {
line.credit_amount += delta;
line.local_amount -= delta;
break;
}
}
InjectionStrategy::AmountManipulation {
original: original_amount,
factor: (original_amount + delta).to_f64().unwrap_or(1.0)
/ original_amount.to_f64().unwrap_or(1.0),
}
} else {
InjectionStrategy::Custom {
name: "AddAmount".to_string(),
parameters: HashMap::new(),
}
}
}
CounterfactualSpec::SetAmount { target } => {
let original_total = original.total_debit();
if !entry.lines.is_empty() {
if entry.lines[0].debit_amount > Decimal::ZERO {
entry.lines[0].debit_amount = *target;
entry.lines[0].local_amount = *target;
}
for line in entry.lines.iter_mut().skip(1) {
if line.credit_amount > Decimal::ZERO {
line.credit_amount = *target;
line.local_amount = -*target;
break;
}
}
}
InjectionStrategy::AmountManipulation {
original: original_total,
factor: target.to_f64().unwrap_or(1.0) / original_total.to_f64().unwrap_or(1.0),
}
}
CounterfactualSpec::ShiftDate { days } => {
let original_date = entry.header.posting_date;
entry.header.posting_date = if *days >= 0 {
entry.header.posting_date + chrono::Duration::days(*days as i64)
} else {
entry.header.posting_date - chrono::Duration::days(days.abs() as i64)
};
InjectionStrategy::DateShift {
days_shifted: *days,
original_date,
}
}
CounterfactualSpec::ChangePeriod { target_period } => {
entry.header.fiscal_period = *target_period;
InjectionStrategy::TimingManipulation {
timing_type: "PeriodChange".to_string(),
original_time: None,
}
}
CounterfactualSpec::ReclassifyAccount { new_account } => {
let old_account = if !entry.lines.is_empty() {
let old = entry.lines[0].gl_account.clone();
entry.lines[0].gl_account = new_account.clone();
entry.lines[0].account_code = new_account.clone();
old
} else {
String::new()
};
InjectionStrategy::AccountMisclassification {
correct_account: old_account,
incorrect_account: new_account.clone(),
}
}
CounterfactualSpec::SelfApprove => {
let user_id = entry.header.created_by.clone();
entry.header.sod_violation = true;
InjectionStrategy::SelfApproval { user_id }
}
CounterfactualSpec::SplitTransaction { split_count } => {
let original_amount = original.total_debit();
let count = (*split_count).max(1);
let divisor = Decimal::from_f64_retain(count as f64).unwrap_or(Decimal::ONE);
let mut new_lines: Vec<JournalEntryLine> = Vec::new();
let mut line_number: u32 = 1;
for orig_line in &original.lines {
for _ in 0..count {
let mut split_line = orig_line.clone();
split_line.line_number = line_number;
if split_line.debit_amount > Decimal::ZERO {
let split_amt = split_line.debit_amount / divisor;
split_line.debit_amount = split_amt;
split_line.local_amount = split_amt;
}
if split_line.credit_amount > Decimal::ZERO {
let split_amt = split_line.credit_amount / divisor;
split_line.credit_amount = split_amt;
split_line.local_amount = -split_amt;
}
new_lines.push(split_line);
line_number += 1;
}
}
entry.lines = new_lines.into();
InjectionStrategy::SplitTransaction {
original_amount,
split_count: *split_count,
split_doc_ids: vec![entry.header.document_id.to_string()],
}
}
CounterfactualSpec::CreateRoundTrip { intermediaries } => {
InjectionStrategy::CircularFlow {
entity_chain: intermediaries.clone(),
}
}
CounterfactualSpec::AddLineItem {
account,
amount,
is_debit,
} => {
let next_line_number =
entry.lines.iter().map(|l| l.line_number).max().unwrap_or(0) + 1;
let new_line = if *is_debit {
JournalEntryLine::debit(
entry.header.document_id,
next_line_number,
account.clone(),
*amount,
)
} else {
JournalEntryLine::credit(
entry.header.document_id,
next_line_number,
account.clone(),
*amount,
)
};
entry.lines.push(new_line);
InjectionStrategy::Custom {
name: "AddLineItem".to_string(),
parameters: HashMap::from([
("account".to_string(), account.clone()),
("amount".to_string(), amount.to_string()),
("is_debit".to_string(), is_debit.to_string()),
]),
}
}
CounterfactualSpec::RemoveLineItem { line_index } => {
let removed_account = if *line_index < entry.lines.len() {
let removed = entry.lines.remove(*line_index);
removed.gl_account
} else {
String::from("(index out of bounds)")
};
InjectionStrategy::Custom {
name: "RemoveLineItem".to_string(),
parameters: HashMap::from([
("line_index".to_string(), line_index.to_string()),
("removed_account".to_string(), removed_account),
]),
}
}
_ => InjectionStrategy::Custom {
name: spec.description(),
parameters: HashMap::new(),
},
}
}
fn create_anomaly_label(
&self,
modified: &JournalEntry,
spec: &CounterfactualSpec,
strategy: &InjectionStrategy,
original: &JournalEntry,
) -> LabeledAnomaly {
let anomaly_id = format!("CF-{}-{}", self.seed, self.counter);
let anomaly_type = spec.to_anomaly_type();
LabeledAnomaly {
anomaly_id,
anomaly_type: anomaly_type.clone(),
document_id: modified.header.document_id.to_string(),
document_type: "JournalEntry".to_string(),
company_code: modified.header.company_code.clone(),
anomaly_date: modified.header.posting_date,
detection_timestamp: Utc::now().naive_utc(),
confidence: 1.0, severity: anomaly_type.severity(),
description: spec.description(),
related_entities: vec![original.header.document_id.to_string()],
monetary_impact: Some(modified.total_debit()),
metadata: HashMap::new(),
is_injected: true,
injection_strategy: Some(strategy.description()),
cluster_id: None,
original_document_hash: Some(format!("{:x}", hash_entry(original))),
causal_reason: Some(AnomalyCausalReason::MLTrainingBalance {
target_class: "counterfactual".to_string(),
}),
structured_strategy: Some(strategy.clone()),
parent_anomaly_id: None,
child_anomaly_ids: vec![],
scenario_id: None,
run_id: None,
generation_seed: Some(self.seed),
}
}
}
/// Fingerprints an entry from its document id, company code, posting date and line count
/// (line accounts and amounts are not included).
///
/// Uses the standard library's `DefaultHasher`, whose algorithm is unspecified across Rust
/// releases, so values are only comparable between runs of the same build.
fn hash_entry(entry: &JournalEntry) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    let header = &entry.header;
    // A tuple hashes its fields in order, so this feeds the hasher the same sequence as
    // hashing each component individually.
    let fingerprint = (
        &header.document_id,
        &header.company_code,
        &header.posting_date,
        entry.lines.len(),
    );
    let mut state = DefaultHasher::new();
    fingerprint.hash(&mut state);
    state.finish()
}
/// Settings for a counterfactual generation run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CounterfactualConfig {
    /// Seed for deterministic generation (presumably passed to `CounterfactualGenerator::new`
    /// — confirm at the call site).
    pub seed: u64,
    /// Number of variants to produce per original entry.
    /// NOTE(review): not consumed in this file; confirm how callers combine it with
    /// `specifications`.
    pub variants_per_original: usize,
    /// Candidate specifications to apply to each original entry.
    pub specifications: Vec<CounterfactualSpec>,
    /// Whether the untouched originals are emitted alongside the variants
    /// (presumably; not consumed in this file — verify against the caller).
    pub include_originals: bool,
}
impl Default for CounterfactualConfig {
fn default() -> Self {
Self {
seed: 42,
variants_per_original: 3,
specifications: vec![
CounterfactualSpec::ScaleAmount { factor: 1.5 },
CounterfactualSpec::ScaleAmount { factor: 2.0 },
CounterfactualSpec::ScaleAmount { factor: 0.5 },
CounterfactualSpec::ShiftDate { days: -7 },
CounterfactualSpec::ShiftDate { days: 30 },
CounterfactualSpec::SelfApprove,
],
include_originals: true,
}
}
}
use rust_decimal::prelude::*;
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use chrono::NaiveDate;
    use datasynth_core::models::{JournalEntryHeader, JournalEntryLine};

    /// Calendar-date shorthand; panicking on an invalid date is fine in tests.
    fn ymd(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).unwrap()
    }

    /// Balanced two-line entry for company "TEST" posted on 2024-06-15:
    /// debit 100.00 to account 1100 and credit 100.00 to account 2000.
    fn create_test_entry() -> JournalEntry {
        let hundred = Decimal::new(10000, 2);
        let mut entry =
            JournalEntry::new(JournalEntryHeader::new("TEST".to_string(), ymd(2024, 6, 15)));
        let doc_id = entry.header.document_id;
        entry.add_line(JournalEntryLine::debit(doc_id, 1, "1100".to_string(), hundred));
        entry.add_line(JournalEntryLine::credit(doc_id, 2, "2000".to_string(), hundred));
        entry
    }

    #[test]
    fn test_counterfactual_generator_scale_amount() {
        let original = create_test_entry();
        let pair = CounterfactualGenerator::new(42)
            .generate(&original, &CounterfactualSpec::ScaleAmount { factor: 2.0 });
        assert_eq!(pair.original.total_debit(), Decimal::new(10000, 2));
        assert_eq!(pair.modified.total_debit(), Decimal::new(20000, 2));
        // A factor of exactly 2.0 does not exceed the fraud threshold.
        assert!(!pair.modified.header.is_fraud);
    }

    #[test]
    fn test_counterfactual_generator_shift_date() {
        let pair = CounterfactualGenerator::new(42)
            .generate(&create_test_entry(), &CounterfactualSpec::ShiftDate { days: -7 });
        assert_eq!(pair.modified.header.posting_date, ymd(2024, 6, 8));
    }

    #[test]
    fn test_counterfactual_spec_to_anomaly_type() {
        assert!(matches!(
            CounterfactualSpec::SelfApprove.to_anomaly_type(),
            AnomalyType::Fraud(FraudType::SelfApproval)
        ));
    }

    #[test]
    fn test_counterfactual_batch_generation() {
        let mut generator = CounterfactualGenerator::new(42);
        let original = create_test_entry();
        let specs = [
            CounterfactualSpec::ScaleAmount { factor: 1.5 },
            CounterfactualSpec::ShiftDate { days: -3 },
            CounterfactualSpec::SelfApprove,
        ];
        let pairs = generator.generate_batch(&original, &specs);
        assert_eq!(pairs.len(), 3);
        // Modest scaling is statistical (not fraud); date shifts and self-approval are fraud.
        let fraud_flags: Vec<bool> = pairs
            .iter()
            .map(|pair| pair.modified.header.is_fraud)
            .collect();
        assert_eq!(fraud_flags, vec![false, true, true]);
    }
}