Skip to main content

datasynth_eval/
diff_engine.rs

1//! Diff engine for comparing baseline vs counterfactual output directories.
2
3use crate::scenario_diff::*;
4use std::collections::{HashMap, HashSet};
5use std::path::Path;
6use thiserror::Error;
7
8/// Errors from the diff engine.
9#[derive(Debug, Error)]
10pub enum DiffError {
11    #[error("IO error: {0}")]
12    Io(#[from] std::io::Error),
13    #[error("CSV parse error: {0}")]
14    CsvParse(String),
15    #[error("mismatched schemas: baseline has {baseline} columns, counterfactual has {counterfactual} for file {file}")]
16    MismatchedSchemas {
17        file: String,
18        baseline: usize,
19        counterfactual: usize,
20    },
21}
22
/// Diff format options: which views of the comparison to compute.
///
/// This is a plain fieldless enum, so it is `Copy` and can be used as a
/// `HashSet`/`HashMap` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiffFormat {
    /// High-level impact summary (KPIs, anomalies, financial statements).
    Summary,
    /// Per-record added/removed/modified comparison of each CSV file.
    RecordLevel,
    /// Per-file aggregate metrics such as record counts.
    Aggregate,
}
30
31/// Configuration for diff computation.
32#[derive(Debug, Clone)]
33pub struct DiffConfig {
34    pub formats: Vec<DiffFormat>,
35    /// Files to compare (empty = all CSV files found in baseline directory).
36    pub scope: Vec<String>,
37    pub max_sample_changes: usize,
38}
39
40impl Default for DiffConfig {
41    fn default() -> Self {
42        Self {
43            formats: vec![DiffFormat::Summary, DiffFormat::Aggregate],
44            scope: vec![],
45            max_sample_changes: 1000,
46        }
47    }
48}
49
/// Engine for computing diffs between baseline and counterfactual outputs.
///
/// A stateless unit struct; all work is done by associated functions such
/// as [`DiffEngine::compute`].
#[derive(Debug, Clone, Copy, Default)]
pub struct DiffEngine;
52
53impl DiffEngine {
54    /// Compute a diff between baseline and counterfactual directories.
55    pub fn compute(
56        baseline_path: &Path,
57        counterfactual_path: &Path,
58        config: &DiffConfig,
59    ) -> Result<ScenarioDiff, DiffError> {
60        let summary = if config.formats.contains(&DiffFormat::Summary) {
61            Some(Self::compute_summary(baseline_path, counterfactual_path)?)
62        } else {
63            None
64        };
65
66        let record_level = if config.formats.contains(&DiffFormat::RecordLevel) {
67            Some(Self::compute_record_level(
68                baseline_path,
69                counterfactual_path,
70                &config.scope,
71                config.max_sample_changes,
72            )?)
73        } else {
74            None
75        };
76
77        let aggregate = if config.formats.contains(&DiffFormat::Aggregate) {
78            Some(Self::compute_aggregate(baseline_path, counterfactual_path)?)
79        } else {
80            None
81        };
82
83        Ok(ScenarioDiff {
84            summary,
85            record_level,
86            aggregate,
87            intervention_trace: None, // populated separately by causal engine
88        })
89    }
90
91    /// Compute impact summary from the two directories.
92    fn compute_summary(
93        baseline_path: &Path,
94        counterfactual_path: &Path,
95    ) -> Result<ImpactSummary, DiffError> {
96        let mut kpi_impacts = Vec::new();
97
98        // Compare journal_entries.csv if present
99        let baseline_je = baseline_path.join("journal_entries.csv");
100        let counter_je = counterfactual_path.join("journal_entries.csv");
101
102        if baseline_je.exists() && counter_je.exists() {
103            let baseline_stats = Self::csv_stats(&baseline_je)?;
104            let counter_stats = Self::csv_stats(&counter_je)?;
105
106            // Record count KPI
107            let b_count = baseline_stats.record_count as f64;
108            let c_count = counter_stats.record_count as f64;
109            kpi_impacts.push(Self::make_kpi("total_transactions", b_count, c_count));
110
111            // Total amount KPI (sum of first numeric column after ID)
112            if let (Some(b_sum), Some(c_sum)) =
113                (baseline_stats.numeric_sum, counter_stats.numeric_sum)
114            {
115                kpi_impacts.push(Self::make_kpi("total_amount", b_sum, c_sum));
116            }
117        }
118
119        // Compare anomaly_labels.csv if present
120        let baseline_al = baseline_path.join("anomaly_labels.csv");
121        let counter_al = counterfactual_path.join("anomaly_labels.csv");
122        let anomaly_impact = if baseline_al.exists() && counter_al.exists() {
123            let b_stats = Self::csv_stats(&baseline_al)?;
124            let c_stats = Self::csv_stats(&counter_al)?;
125            let b_count = b_stats.record_count;
126            let c_count = c_stats.record_count;
127            let rate_change = if b_count > 0 {
128                ((c_count as f64 - b_count as f64) / b_count as f64) * 100.0
129            } else if c_count > 0 {
130                100.0
131            } else {
132                0.0
133            };
134
135            // Parse anomaly types from both files
136            let b_types = Self::extract_anomaly_types(&baseline_al)?;
137            let c_types = Self::extract_anomaly_types(&counter_al)?;
138
139            let new_types: Vec<String> = c_types.difference(&b_types).cloned().collect();
140            let removed_types: Vec<String> = b_types.difference(&c_types).cloned().collect();
141
142            Some(AnomalyImpact {
143                baseline_count: b_count,
144                counterfactual_count: c_count,
145                new_types,
146                removed_types,
147                rate_change_pct: rate_change,
148            })
149        } else {
150            None
151        };
152
153        // Compute financial statement impacts if trial_balance.csv exists
154        let financial_statement_impacts =
155            Self::compute_financial_impacts(baseline_path, counterfactual_path)?;
156
157        Ok(ImpactSummary {
158            scenario_name: String::new(),
159            generation_timestamp: chrono::Utc::now().to_rfc3339(),
160            interventions_applied: 0,
161            kpi_impacts,
162            financial_statement_impacts,
163            anomaly_impact,
164            control_impact: None,
165        })
166    }
167
168    /// Compute record-level diffs for CSV files.
169    fn compute_record_level(
170        baseline_path: &Path,
171        counterfactual_path: &Path,
172        scope: &[String],
173        max_samples: usize,
174    ) -> Result<Vec<RecordLevelDiff>, DiffError> {
175        let files = if scope.is_empty() {
176            Self::find_csv_files(baseline_path)?
177        } else {
178            scope.to_vec()
179        };
180
181        let mut diffs = Vec::new();
182        for file in &files {
183            let b_path = baseline_path.join(file);
184            let c_path = counterfactual_path.join(file);
185
186            if !b_path.exists() || !c_path.exists() {
187                continue;
188            }
189
190            let diff = Self::diff_csv_file(&b_path, &c_path, file, max_samples)?;
191            diffs.push(diff);
192        }
193        Ok(diffs)
194    }
195
196    /// Compute aggregate comparison.
197    fn compute_aggregate(
198        baseline_path: &Path,
199        counterfactual_path: &Path,
200    ) -> Result<AggregateComparison, DiffError> {
201        let files = Self::find_csv_files(baseline_path)?;
202        let mut metrics = Vec::new();
203
204        for file in &files {
205            let b_path = baseline_path.join(file);
206            let c_path = counterfactual_path.join(file);
207
208            if !c_path.exists() {
209                continue;
210            }
211
212            let b_stats = Self::csv_stats(&b_path)?;
213            let c_stats = Self::csv_stats(&c_path)?;
214
215            let b_count = b_stats.record_count as f64;
216            let c_count = c_stats.record_count as f64;
217            let change_pct = if b_count > 0.0 {
218                ((c_count - b_count) / b_count) * 100.0
219            } else {
220                0.0
221            };
222
223            metrics.push(MetricComparison {
224                metric_name: format!("{}_record_count", file.trim_end_matches(".csv")),
225                baseline: b_count,
226                counterfactual: c_count,
227                change_pct,
228            });
229        }
230
231        Ok(AggregateComparison {
232            metrics,
233            period_comparisons: vec![],
234        })
235    }
236
237    /// Create a KpiImpact from baseline and counterfactual values.
238    fn make_kpi(name: &str, baseline: f64, counterfactual: f64) -> KpiImpact {
239        let abs = counterfactual - baseline;
240        let pct = if baseline.abs() > f64::EPSILON {
241            (abs / baseline) * 100.0
242        } else {
243            0.0
244        };
245        let direction = if abs > f64::EPSILON {
246            ChangeDirection::Increase
247        } else if abs < -f64::EPSILON {
248            ChangeDirection::Decrease
249        } else {
250            ChangeDirection::Unchanged
251        };
252        KpiImpact {
253            kpi_name: name.to_string(),
254            baseline_value: baseline,
255            counterfactual_value: counterfactual,
256            absolute_change: abs,
257            percent_change: pct,
258            direction,
259        }
260    }
261
262    /// Compute basic CSV statistics (record count, column count, first numeric column sum).
263    fn csv_stats(path: &Path) -> Result<CsvStats, DiffError> {
264        let content = std::fs::read_to_string(path)?;
265        let mut lines = content.lines();
266        let header = lines.next().unwrap_or("");
267        let col_count = header.split(',').count();
268
269        let mut record_count = 0;
270        let mut numeric_sum: Option<f64> = None;
271
272        for line in lines {
273            if line.trim().is_empty() {
274                continue;
275            }
276            record_count += 1;
277            // Try to find a numeric column to sum (skip first column as ID)
278            let fields: Vec<&str> = line.split(',').collect();
279            for field in fields.iter().skip(1) {
280                let trimmed = field.trim().trim_matches('"');
281                if let Ok(val) = trimmed.parse::<f64>() {
282                    *numeric_sum.get_or_insert(0.0) += val;
283                    break;
284                }
285            }
286        }
287
288        Ok(CsvStats {
289            record_count,
290            _col_count: col_count,
291            numeric_sum,
292        })
293    }
294
295    /// Find all CSV files in a directory, sorted by name.
296    fn find_csv_files(dir: &Path) -> Result<Vec<String>, DiffError> {
297        let mut files = Vec::new();
298        if dir.is_dir() {
299            for entry in std::fs::read_dir(dir)? {
300                let entry = entry?;
301                let path = entry.path();
302                if path.extension().and_then(|e| e.to_str()) == Some("csv") {
303                    if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
304                        files.push(name.to_string());
305                    }
306                }
307            }
308        }
309        files.sort();
310        Ok(files)
311    }
312
313    /// Diff a single CSV file between baseline and counterfactual directories.
314    fn diff_csv_file(
315        baseline: &Path,
316        counterfactual: &Path,
317        file_name: &str,
318        max_samples: usize,
319    ) -> Result<RecordLevelDiff, DiffError> {
320        let b_content = std::fs::read_to_string(baseline)?;
321        let c_content = std::fs::read_to_string(counterfactual)?;
322
323        let b_records = Self::parse_csv_records(&b_content);
324        let c_records = Self::parse_csv_records(&c_content);
325
326        let b_ids: HashSet<&str> = b_records.keys().copied().collect();
327        let c_ids: HashSet<&str> = c_records.keys().copied().collect();
328
329        let added: Vec<&str> = c_ids.difference(&b_ids).copied().collect();
330        let removed: Vec<&str> = b_ids.difference(&c_ids).copied().collect();
331        let common: Vec<&str> = b_ids.intersection(&c_ids).copied().collect();
332
333        let mut modified_count = 0;
334        let mut unchanged_count = 0;
335        let mut sample_changes = Vec::new();
336
337        // Get header for field names
338        let header: Vec<&str> = b_content.lines().next().unwrap_or("").split(',').collect();
339
340        for id in &common {
341            let b_line = b_records[id];
342            let c_line = c_records[id];
343            if b_line == c_line {
344                unchanged_count += 1;
345            } else {
346                modified_count += 1;
347                if sample_changes.len() < max_samples {
348                    let b_fields: Vec<&str> = b_line.split(',').collect();
349                    let c_fields: Vec<&str> = c_line.split(',').collect();
350                    let mut field_changes = Vec::new();
351                    for (i, (bf, cf)) in b_fields.iter().zip(c_fields.iter()).enumerate() {
352                        if bf != cf {
353                            field_changes.push(FieldChange {
354                                field_name: header.get(i).unwrap_or(&"unknown").to_string(),
355                                baseline_value: bf.to_string(),
356                                counterfactual_value: cf.to_string(),
357                            });
358                        }
359                    }
360                    sample_changes.push(RecordChange {
361                        record_id: id.to_string(),
362                        change_type: RecordChangeType::Modified,
363                        field_changes,
364                    });
365                }
366            }
367        }
368
369        // Add samples for added records
370        for id in added
371            .iter()
372            .take(max_samples.saturating_sub(sample_changes.len()))
373        {
374            sample_changes.push(RecordChange {
375                record_id: id.to_string(),
376                change_type: RecordChangeType::Added,
377                field_changes: vec![],
378            });
379        }
380
381        // Add samples for removed records
382        for id in removed
383            .iter()
384            .take(max_samples.saturating_sub(sample_changes.len()))
385        {
386            sample_changes.push(RecordChange {
387                record_id: id.to_string(),
388                change_type: RecordChangeType::Removed,
389                field_changes: vec![],
390            });
391        }
392
393        Ok(RecordLevelDiff {
394            file_name: file_name.to_string(),
395            records_added: added.len(),
396            records_removed: removed.len(),
397            records_modified: modified_count,
398            records_unchanged: unchanged_count,
399            sample_changes,
400        })
401    }
402
403    /// Extract unique anomaly type values from an anomaly_labels CSV.
404    /// Looks for a column named "anomaly_type" or "type" in the header.
405    fn extract_anomaly_types(path: &Path) -> Result<HashSet<String>, DiffError> {
406        let content = std::fs::read_to_string(path)?;
407        let mut lines = content.lines();
408        let header = lines.next().unwrap_or("");
409        let columns: Vec<&str> = header.split(',').collect();
410
411        // Find the type column index
412        let type_col = columns
413            .iter()
414            .position(|c| {
415                let trimmed = c.trim().trim_matches('"').to_lowercase();
416                trimmed == "anomaly_type" || trimmed == "type"
417            })
418            .unwrap_or(1); // Default to second column if not found
419
420        let mut types = HashSet::new();
421        for line in lines {
422            if line.trim().is_empty() {
423                continue;
424            }
425            let fields: Vec<&str> = line.split(',').collect();
426            if let Some(field) = fields.get(type_col) {
427                let val = field.trim().trim_matches('"').to_string();
428                if !val.is_empty() {
429                    types.insert(val);
430                }
431            }
432        }
433        Ok(types)
434    }
435
436    /// Compute financial statement impacts by comparing trial_balance.csv
437    /// or balance_sheet.csv between baseline and counterfactual.
438    fn compute_financial_impacts(
439        baseline_path: &Path,
440        counterfactual_path: &Path,
441    ) -> Result<Option<FinancialStatementImpact>, DiffError> {
442        // Try trial_balance.csv first, then balance_sheet.csv
443        let file_candidates = ["trial_balance.csv", "balance_sheet.csv"];
444        let mut b_file = None;
445        let mut c_file = None;
446
447        for candidate in &file_candidates {
448            let bp = baseline_path.join(candidate);
449            let cp = counterfactual_path.join(candidate);
450            if bp.exists() && cp.exists() {
451                b_file = Some(bp);
452                c_file = Some(cp);
453                break;
454            }
455        }
456
457        let (b_path, c_path) = match (b_file, c_file) {
458            (Some(b), Some(c)) => (b, c),
459            _ => return Ok(None),
460        };
461
462        let b_items = Self::parse_financial_line_items(&b_path)?;
463        let c_items = Self::parse_financial_line_items(&c_path)?;
464
465        let pct_change = |key: &str| -> f64 {
466            let b_val = b_items.get(key).copied().unwrap_or(0.0);
467            let c_val = c_items.get(key).copied().unwrap_or(0.0);
468            if b_val.abs() > f64::EPSILON {
469                ((c_val - b_val) / b_val) * 100.0
470            } else {
471                0.0
472            }
473        };
474
475        // Collect top changed line items
476        let mut line_item_impacts: Vec<LineItemImpact> = b_items
477            .keys()
478            .chain(c_items.keys())
479            .collect::<HashSet<_>>()
480            .into_iter()
481            .filter_map(|key| {
482                let b_val = b_items.get(key).copied().unwrap_or(0.0);
483                let c_val = c_items.get(key).copied().unwrap_or(0.0);
484                let change = if b_val.abs() > f64::EPSILON {
485                    ((c_val - b_val) / b_val) * 100.0
486                } else {
487                    0.0
488                };
489                if change.abs() > f64::EPSILON {
490                    Some(LineItemImpact {
491                        line_item: key.clone(),
492                        baseline: b_val,
493                        counterfactual: c_val,
494                        change_pct: change,
495                    })
496                } else {
497                    None
498                }
499            })
500            .collect();
501
502        // Sort by absolute change percentage, descending
503        line_item_impacts.sort_by(|a, b| {
504            b.change_pct
505                .abs()
506                .partial_cmp(&a.change_pct.abs())
507                .unwrap_or(std::cmp::Ordering::Equal)
508        });
509        line_item_impacts.truncate(10);
510
511        Ok(Some(FinancialStatementImpact {
512            revenue_change_pct: pct_change("revenue"),
513            cogs_change_pct: pct_change("cogs"),
514            margin_change_pct: pct_change("gross_margin"),
515            net_income_change_pct: pct_change("net_income"),
516            total_assets_change_pct: pct_change("total_assets"),
517            total_liabilities_change_pct: pct_change("total_liabilities"),
518            cash_flow_change_pct: pct_change("cash_flow"),
519            top_changed_line_items: line_item_impacts,
520        }))
521    }
522
523    /// Parse a financial CSV into a map of line item name → value.
524    /// Expects columns like: account/line_item, amount/balance/value.
525    fn parse_financial_line_items(path: &Path) -> Result<HashMap<String, f64>, DiffError> {
526        let content = std::fs::read_to_string(path)?;
527        let mut lines = content.lines();
528        let header = lines.next().unwrap_or("");
529        let columns: Vec<&str> = header.split(',').collect();
530
531        // Find name and value column indices
532        let name_col = columns
533            .iter()
534            .position(|c| {
535                let t = c.trim().trim_matches('"').to_lowercase();
536                t == "account" || t == "line_item" || t == "item" || t == "name"
537            })
538            .unwrap_or(0);
539
540        let value_col = columns
541            .iter()
542            .position(|c| {
543                let t = c.trim().trim_matches('"').to_lowercase();
544                t == "amount" || t == "balance" || t == "value" || t == "total"
545            })
546            .unwrap_or(1);
547
548        let mut items = HashMap::new();
549        for line in lines {
550            if line.trim().is_empty() {
551                continue;
552            }
553            let fields: Vec<&str> = line.split(',').collect();
554            if let (Some(name), Some(val_str)) = (fields.get(name_col), fields.get(value_col)) {
555                let name = name.trim().trim_matches('"').to_lowercase();
556                let val = val_str
557                    .trim()
558                    .trim_matches('"')
559                    .parse::<f64>()
560                    .unwrap_or(0.0);
561                items.insert(name, val);
562            }
563        }
564        Ok(items)
565    }
566
567    /// Parse CSV content into a map of (first-column value) -> (full line).
568    fn parse_csv_records(content: &str) -> HashMap<&str, &str> {
569        let mut records = HashMap::new();
570        for (i, line) in content.lines().enumerate() {
571            if i == 0 || line.trim().is_empty() {
572                continue; // skip header
573            }
574            let id = line.split(',').next().unwrap_or("");
575            records.insert(id, line);
576        }
577        records
578    }
579}
580
/// Per-file counters gathered by one pass over a CSV file.
struct CsvStats {
    /// Number of non-blank rows after the header line.
    record_count: usize,
    /// Number of comma-separated header fields (kept but currently unused).
    _col_count: usize,
    /// Running total of the numeric column chosen by `csv_stats`; `None` when
    /// no row contributed a value.
    numeric_sum: Option<f64>,
}
587
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Write `content` to `dir/name`, panicking if the write fails.
    fn write_csv(dir: &Path, name: &str, content: &str) {
        fs::write(dir.join(name), content).unwrap();
    }

    /// Create a fresh `(baseline, counterfactual)` pair of temporary directories.
    fn dir_pair() -> (TempDir, TempDir) {
        (TempDir::new().unwrap(), TempDir::new().unwrap())
    }

    /// Run the engine over the pair, requesting exactly `formats`.
    fn diff_with(baseline: &TempDir, counter: &TempDir, formats: Vec<DiffFormat>) -> ScenarioDiff {
        let config = DiffConfig {
            formats,
            ..Default::default()
        };
        DiffEngine::compute(baseline.path(), counter.path(), &config).unwrap()
    }

    #[test]
    fn test_diff_identical_dirs() {
        let (baseline, counter) = dir_pair();
        let csv = "id,amount,desc\n1,100.0,test\n2,200.0,test2\n";
        write_csv(baseline.path(), "data.csv", csv);
        write_csv(counter.path(), "data.csv", csv);

        let all_formats = vec![
            DiffFormat::Summary,
            DiffFormat::RecordLevel,
            DiffFormat::Aggregate,
        ];
        let records = diff_with(&baseline, &counter, all_formats)
            .record_level
            .unwrap();

        // Identical inputs: nothing added, removed or modified.
        assert_eq!(records.len(), 1);
        let file = &records[0];
        assert_eq!(file.records_modified, 0);
        assert_eq!(file.records_added, 0);
        assert_eq!(file.records_removed, 0);
        assert_eq!(file.records_unchanged, 2);
    }

    #[test]
    fn test_diff_record_added() {
        let (baseline, counter) = dir_pair();
        write_csv(baseline.path(), "data.csv", "id,amount\n1,100.0\n");
        write_csv(counter.path(), "data.csv", "id,amount\n1,100.0\n2,200.0\n");

        let records = diff_with(&baseline, &counter, vec![DiffFormat::RecordLevel])
            .record_level
            .unwrap();
        let file = &records[0];
        assert_eq!(file.records_added, 1);
        assert_eq!(file.records_unchanged, 1);
    }

    #[test]
    fn test_diff_field_changed() {
        let (baseline, counter) = dir_pair();
        write_csv(baseline.path(), "data.csv", "id,amount\n1,100.0\n2,200.0\n");
        write_csv(counter.path(), "data.csv", "id,amount\n1,150.0\n2,200.0\n");

        let records = diff_with(&baseline, &counter, vec![DiffFormat::RecordLevel])
            .record_level
            .unwrap();
        let file = &records[0];
        assert_eq!(file.records_modified, 1);
        assert_eq!(file.records_unchanged, 1);
        assert_eq!(file.sample_changes.len(), 1);
        assert_eq!(file.sample_changes[0].field_changes[0].field_name, "amount");
    }

    #[test]
    fn test_diff_summary_kpis() {
        let (baseline, counter) = dir_pair();
        write_csv(
            baseline.path(),
            "journal_entries.csv",
            "id,amount\n1,100.0\n2,200.0\n",
        );
        write_csv(
            counter.path(),
            "journal_entries.csv",
            "id,amount\n1,150.0\n2,200.0\n3,50.0\n",
        );

        let summary = diff_with(&baseline, &counter, vec![DiffFormat::Summary])
            .summary
            .unwrap();
        assert_eq!(summary.kpi_impacts.len(), 2); // transaction count + total_amount

        let tx_kpi = summary
            .kpi_impacts
            .iter()
            .find(|kpi| kpi.kpi_name == "total_transactions")
            .unwrap();
        assert_eq!(tx_kpi.baseline_value, 2.0);
        assert_eq!(tx_kpi.counterfactual_value, 3.0);
        assert_eq!(tx_kpi.direction, ChangeDirection::Increase);
    }

    #[test]
    fn test_diff_anomaly_types_new_and_removed() {
        let (baseline, counter) = dir_pair();
        write_csv(
            baseline.path(),
            "anomaly_labels.csv",
            "id,anomaly_type,severity\n1,FictitiousTransaction,high\n2,DuplicateEntry,medium\n",
        );
        write_csv(
            counter.path(),
            "anomaly_labels.csv",
            "id,anomaly_type,severity\n1,DuplicateEntry,medium\n2,SplitTransaction,high\n3,BenfordViolation,low\n",
        );

        let anomaly = diff_with(&baseline, &counter, vec![DiffFormat::Summary])
            .summary
            .unwrap()
            .anomaly_impact
            .unwrap();

        let has_new = |name: &str| anomaly.new_types.contains(&name.to_string());
        assert_eq!(anomaly.baseline_count, 2);
        assert_eq!(anomaly.counterfactual_count, 3);
        assert!(has_new("SplitTransaction"));
        assert!(has_new("BenfordViolation"));
        assert!(anomaly
            .removed_types
            .contains(&"FictitiousTransaction".to_string()));
        assert!(!has_new("DuplicateEntry"));
    }

    #[test]
    fn test_diff_financial_statement_impacts() {
        let (baseline, counter) = dir_pair();
        write_csv(
            baseline.path(),
            "trial_balance.csv",
            "account,amount\nrevenue,1000000.0\ncogs,600000.0\ntotal_assets,5000000.0\n",
        );
        write_csv(
            counter.path(),
            "trial_balance.csv",
            "account,amount\nrevenue,850000.0\ncogs,550000.0\ntotal_assets,4800000.0\n",
        );

        let fi = diff_with(&baseline, &counter, vec![DiffFormat::Summary])
            .summary
            .unwrap()
            .financial_statement_impacts
            .unwrap();

        assert!(fi.revenue_change_pct < 0.0); // Revenue decreased
        assert!(fi.total_assets_change_pct < 0.0); // Assets decreased
        assert!(!fi.top_changed_line_items.is_empty());
    }

    #[test]
    fn test_diff_aggregate() {
        let (baseline, counter) = dir_pair();
        write_csv(baseline.path(), "data.csv", "id,val\n1,10\n2,20\n");
        write_csv(counter.path(), "data.csv", "id,val\n1,10\n2,20\n3,30\n");

        let agg = diff_with(&baseline, &counter, vec![DiffFormat::Aggregate])
            .aggregate
            .unwrap();
        assert_eq!(agg.metrics.len(), 1);
        let metric = &agg.metrics[0];
        assert_eq!(metric.baseline, 2.0);
        assert_eq!(metric.counterfactual, 3.0);
    }
}