Skip to main content

datasynth_generators/balance/
balance_tracker.rs

1//! Running balance tracker.
2//!
3//! Maintains real-time account balances as journal entries are processed,
4//! with continuous validation of balance sheet integrity.
5
6use chrono::{Datelike, NaiveDate};
7use rust_decimal::Decimal;
8use rust_decimal_macros::dec;
9use std::collections::HashMap;
10use tracing::debug;
11
12use datasynth_core::distributions::behavioral_priors::TbAnchorPrior;
13use datasynth_core::models::balance::{
14    AccountBalance, AccountPeriodActivity, AccountType, BalanceSnapshot,
15};
16use datasynth_core::models::JournalEntry;
17use datasynth_core::FrameworkAccounts;
18
19/// Configuration for the balance tracker.
20#[derive(Debug, Clone)]
21pub struct BalanceTrackerConfig {
22    /// Whether to validate balance sheet equation after each entry.
23    pub validate_on_each_entry: bool,
24    /// Whether to track balance history.
25    pub track_history: bool,
26    /// Tolerance for balance sheet validation (for rounding).
27    pub balance_tolerance: Decimal,
28    /// Whether to fail on validation errors.
29    pub fail_on_validation_error: bool,
30}
31
32impl Default for BalanceTrackerConfig {
33    fn default() -> Self {
34        Self {
35            validate_on_each_entry: true,
36            track_history: true,
37            balance_tolerance: dec!(0.01),
38            fail_on_validation_error: false,
39        }
40    }
41}
42
43/// Tracks running balances for all accounts across companies.
44pub struct RunningBalanceTracker {
45    config: BalanceTrackerConfig,
46    /// Balances by company code -> account code -> balance.
47    balances: HashMap<String, HashMap<String, AccountBalance>>,
48    /// Account type registry for determining debit/credit behavior.
49    account_types: HashMap<String, AccountType>,
50    /// Framework-aware account classification.
51    framework_accounts: FrameworkAccounts,
52    /// Balance history by company code.
53    history: HashMap<String, Vec<BalanceHistoryEntry>>,
54    /// Validation errors encountered.
55    validation_errors: Vec<ValidationError>,
56    /// Statistics.
57    stats: TrackerStatistics,
58    /// Default currency for new account balances and snapshots.
59    currency: String,
60    /// SP4.1 — optional TB anchor prior.  When `Some`, `account_drift()` and
61    /// `drift_correction_needed()` compare running balances against targets.
62    tb_anchor: Option<TbAnchorPrior>,
63}
64
65/// Entry in balance history.
66#[derive(Debug, Clone)]
67pub struct BalanceHistoryEntry {
68    pub date: NaiveDate,
69    pub entry_id: String,
70    pub account_code: String,
71    pub previous_balance: Decimal,
72    pub change: Decimal,
73    pub new_balance: Decimal,
74}
75
76/// Validation error details.
77#[derive(Debug, Clone)]
78pub struct ValidationError {
79    pub date: NaiveDate,
80    pub company_code: String,
81    pub entry_id: Option<String>,
82    pub error_type: ValidationErrorType,
83    pub message: String,
84    pub details: HashMap<String, Decimal>,
85}
86
/// Category of a [`ValidationError`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationErrorType {
    /// The entry's total debits differ from its total credits.
    UnbalancedEntry,
    /// Assets differ from liabilities + equity + net income by more than the
    /// configured tolerance.
    BalanceSheetImbalance,
    /// An account carries a negative balance where that is not allowed
    /// (not raised by `RunningBalanceTracker` at present).
    NegativeBalance,
    /// An account code is not recognised
    /// (not raised by `RunningBalanceTracker` at present).
    UnknownAccount,
    /// An entry was applied out of chronological order
    /// (not raised by `RunningBalanceTracker` at present).
    OutOfOrder,
}
101
102/// Statistics about tracked entries.
103#[derive(Debug, Clone, Default)]
104pub struct TrackerStatistics {
105    pub entries_processed: u64,
106    pub lines_processed: u64,
107    pub total_debits: Decimal,
108    pub total_credits: Decimal,
109    pub companies_tracked: usize,
110    pub accounts_tracked: usize,
111    pub validation_errors: usize,
112}
113
114impl RunningBalanceTracker {
115    /// Creates a new balance tracker with the specified currency and accounting framework.
116    pub fn new_with_currency_and_framework(
117        config: BalanceTrackerConfig,
118        currency: String,
119        framework: &str,
120    ) -> Self {
121        Self {
122            config,
123            balances: HashMap::new(),
124            account_types: HashMap::new(),
125            framework_accounts: FrameworkAccounts::for_framework(framework),
126            history: HashMap::new(),
127            validation_errors: Vec::new(),
128            stats: TrackerStatistics::default(),
129            currency,
130            tb_anchor: None,
131        }
132    }
133
134    /// Creates a new balance tracker with the specified currency (defaults to US GAAP).
135    pub fn new_with_currency(config: BalanceTrackerConfig, currency: String) -> Self {
136        Self::new_with_currency_and_framework(config, currency, "us_gaap")
137    }
138
139    /// Creates a new balance tracker (defaults to USD and US GAAP).
140    pub fn new(config: BalanceTrackerConfig) -> Self {
141        Self::new_with_currency(config, "USD".to_string())
142    }
143
144    /// Creates a new balance tracker for a specific accounting framework (defaults to USD).
145    pub fn new_with_framework(config: BalanceTrackerConfig, framework: &str) -> Self {
146        Self::new_with_currency_and_framework(config, "USD".to_string(), framework)
147    }
148
149    /// Creates a tracker with default configuration (US GAAP).
150    pub fn with_defaults() -> Self {
151        Self::new(BalanceTrackerConfig::default())
152    }
153
154    /// Registers an account type for balance tracking.
155    pub fn register_account_type(&mut self, account_code: &str, account_type: AccountType) {
156        self.account_types
157            .insert(account_code.to_string(), account_type);
158    }
159
160    /// Registers multiple account types.
161    pub fn register_account_types(&mut self, types: &[(String, AccountType)]) {
162        for (code, account_type) in types {
163            self.account_types.insert(code.clone(), *account_type);
164        }
165    }
166
167    /// Registers account types from a chart of accounts prefix pattern.
168    pub fn register_from_chart_prefixes(&mut self, prefixes: &[(&str, AccountType)]) {
169        for (prefix, account_type) in prefixes {
170            self.account_types.insert(prefix.to_string(), *account_type);
171        }
172    }
173
174    /// Initializes balances from opening balance snapshot.
175    pub fn initialize_from_snapshot(&mut self, snapshot: &BalanceSnapshot) {
176        let company_balances = self
177            .balances
178            .entry(snapshot.company_code.clone())
179            .or_default();
180
181        for (account_code, balance) in &snapshot.balances {
182            company_balances.insert(account_code.clone(), balance.clone());
183        }
184
185        self.stats.companies_tracked = self.balances.len();
186        self.stats.accounts_tracked = self
187            .balances
188            .values()
189            .map(std::collections::HashMap::len)
190            .sum();
191    }
192
193    /// Applies a journal entry to the running balances.
194    pub fn apply_entry(&mut self, entry: &JournalEntry) -> Result<(), ValidationError> {
195        // Validate entry is balanced first
196        if !entry.is_balanced() {
197            let error = ValidationError {
198                date: entry.posting_date(),
199                company_code: entry.company_code().to_string(),
200                entry_id: Some(entry.document_number().clone()),
201                error_type: ValidationErrorType::UnbalancedEntry,
202                message: format!(
203                    "Entry {} is unbalanced: debits={}, credits={}",
204                    entry.document_number(),
205                    entry.total_debit(),
206                    entry.total_credit()
207                ),
208                details: {
209                    let mut d = HashMap::new();
210                    d.insert("total_debit".to_string(), entry.total_debit());
211                    d.insert("total_credit".to_string(), entry.total_credit());
212                    d
213                },
214            };
215
216            if self.config.fail_on_validation_error {
217                return Err(error);
218            }
219            self.validation_errors.push(error);
220        }
221
222        // Extract data we need before mutably borrowing balances
223        let company_code = entry.company_code().to_string();
224        let document_number = entry.document_number().clone();
225        let posting_date = entry.posting_date();
226        let track_history = self.config.track_history;
227
228        // Pre-compute account types for all lines
229        let line_data: Vec<_> = entry
230            .lines
231            .iter()
232            .map(|line| {
233                let account_type = self.determine_account_type(&line.account_code);
234                (line.clone(), account_type)
235            })
236            .collect();
237
238        // Get or create company balances
239        let company_balances = self.balances.entry(company_code.clone()).or_default();
240
241        // History entries to add
242        let mut history_entries = Vec::new();
243
244        // Apply each line
245        for (line, account_type) in &line_data {
246            // Get or create account balance
247            let balance = company_balances
248                .entry(line.account_code.clone())
249                .or_insert_with(|| {
250                    AccountBalance::new(
251                        company_code.clone(),
252                        line.account_code.clone(),
253                        *account_type,
254                        self.currency.clone(),
255                        posting_date.year(),
256                        posting_date.month(),
257                    )
258                });
259
260            let previous_balance = balance.closing_balance;
261
262            // Apply debit or credit
263            if line.debit_amount > Decimal::ZERO {
264                balance.apply_debit(line.debit_amount);
265            }
266            if line.credit_amount > Decimal::ZERO {
267                balance.apply_credit(line.credit_amount);
268            }
269
270            let new_balance = balance.closing_balance;
271
272            // Record history if configured
273            if track_history {
274                let change = line.debit_amount - line.credit_amount;
275                history_entries.push(BalanceHistoryEntry {
276                    date: posting_date,
277                    entry_id: document_number.clone(),
278                    account_code: line.account_code.clone(),
279                    previous_balance,
280                    change,
281                    new_balance,
282                });
283            }
284        }
285
286        // Add history entries after releasing the balances borrow
287        if !history_entries.is_empty() {
288            let hist = self.history.entry(company_code.clone()).or_default();
289            hist.extend(history_entries);
290        }
291
292        // Update statistics
293        self.stats.entries_processed += 1;
294        self.stats.lines_processed += entry.lines.len() as u64;
295        self.stats.total_debits += entry.total_debit();
296        self.stats.total_credits += entry.total_credit();
297        self.stats.companies_tracked = self.balances.len();
298        self.stats.accounts_tracked = self
299            .balances
300            .values()
301            .map(std::collections::HashMap::len)
302            .sum();
303
304        // Validate balance sheet if configured
305        if self.config.validate_on_each_entry {
306            self.validate_balance_sheet(
307                entry.company_code(),
308                entry.posting_date(),
309                Some(&entry.document_number()),
310            )?;
311        }
312
313        Ok(())
314    }
315
316    /// Applies a batch of entries.
317    pub fn apply_entries(&mut self, entries: &[JournalEntry]) -> Vec<ValidationError> {
318        debug!(
319            entry_count = entries.len(),
320            companies_tracked = self.stats.companies_tracked,
321            accounts_tracked = self.stats.accounts_tracked,
322            "Applying entries to balance tracker"
323        );
324
325        let mut errors = Vec::new();
326
327        for entry in entries {
328            if let Err(error) = self.apply_entry(entry) {
329                errors.push(error);
330            }
331        }
332
333        errors
334    }
335
336    /// Determines account type from code prefix.
337    ///
338    /// Checks explicitly registered types first, then falls back to the
339    /// framework-aware classifier from [`FrameworkAccounts`].
340    fn determine_account_type(&self, account_code: &str) -> AccountType {
341        // Check registered types first (exact match or prefix)
342        for (registered_code, account_type) in &self.account_types {
343            if account_code.starts_with(registered_code) {
344                return *account_type;
345            }
346        }
347
348        // Use framework-aware classification
349        self.framework_accounts.classify_account_type(account_code)
350    }
351
352    /// Validates the balance sheet equation for a company.
353    pub fn validate_balance_sheet(
354        &mut self,
355        company_code: &str,
356        date: NaiveDate,
357        entry_id: Option<&str>,
358    ) -> Result<(), ValidationError> {
359        let Some(company_balances) = self.balances.get(company_code) else {
360            return Ok(()); // No balances to validate
361        };
362
363        let mut total_assets = Decimal::ZERO;
364        let mut total_liabilities = Decimal::ZERO;
365        let mut total_equity = Decimal::ZERO;
366        let mut total_revenue = Decimal::ZERO;
367        let mut total_expenses = Decimal::ZERO;
368
369        for (account_code, balance) in company_balances {
370            let account_type = self.determine_account_type(account_code);
371            match account_type {
372                AccountType::Asset => total_assets += balance.closing_balance,
373                AccountType::ContraAsset => total_assets -= balance.closing_balance.abs(),
374                AccountType::Liability => total_liabilities += balance.closing_balance.abs(),
375                AccountType::ContraLiability => total_liabilities -= balance.closing_balance.abs(),
376                AccountType::Equity => total_equity += balance.closing_balance.abs(),
377                AccountType::ContraEquity => total_equity -= balance.closing_balance.abs(),
378                AccountType::Revenue => total_revenue += balance.closing_balance.abs(),
379                AccountType::Expense => total_expenses += balance.closing_balance.abs(),
380            }
381        }
382
383        // Net income = Revenue - Expenses
384        let net_income = total_revenue - total_expenses;
385
386        // Balance sheet equation: Assets = Liabilities + Equity + Net Income
387        let left_side = total_assets;
388        let right_side = total_liabilities + total_equity + net_income;
389        let difference = (left_side - right_side).abs();
390
391        if difference > self.config.balance_tolerance {
392            let error = ValidationError {
393                date,
394                company_code: company_code.to_string(),
395                entry_id: entry_id.map(String::from),
396                error_type: ValidationErrorType::BalanceSheetImbalance,
397                message: format!(
398                    "Balance sheet imbalance: Assets ({left_side}) != L + E + NI ({right_side}), diff = {difference}"
399                ),
400                details: {
401                    let mut d = HashMap::new();
402                    d.insert("total_assets".to_string(), total_assets);
403                    d.insert("total_liabilities".to_string(), total_liabilities);
404                    d.insert("total_equity".to_string(), total_equity);
405                    d.insert("net_income".to_string(), net_income);
406                    d.insert("difference".to_string(), difference);
407                    d
408                },
409            };
410
411            self.stats.validation_errors += 1;
412
413            if self.config.fail_on_validation_error {
414                return Err(error);
415            }
416            self.validation_errors.push(error);
417        }
418
419        Ok(())
420    }
421
422    /// Gets the current snapshot for a company.
423    pub fn get_snapshot(
424        &self,
425        company_code: &str,
426        as_of_date: NaiveDate,
427    ) -> Option<BalanceSnapshot> {
428        use chrono::Datelike;
429        let currency = self.currency.clone();
430        self.balances.get(company_code).map(|balances| {
431            let mut snapshot = BalanceSnapshot::new(
432                format!("SNAP-{company_code}-{as_of_date}"),
433                company_code.to_string(),
434                as_of_date,
435                as_of_date.year(),
436                as_of_date.month(),
437                currency,
438            );
439            for (account, balance) in balances {
440                snapshot.balances.insert(account.clone(), balance.clone());
441            }
442            snapshot.recalculate_totals();
443            snapshot
444        })
445    }
446
447    /// Gets snapshots for all companies.
448    pub fn get_all_snapshots(&self, as_of_date: NaiveDate) -> Vec<BalanceSnapshot> {
449        use chrono::Datelike;
450        self.balances
451            .iter()
452            .map(|(company_code, balances)| {
453                let mut snapshot = BalanceSnapshot::new(
454                    format!("SNAP-{company_code}-{as_of_date}"),
455                    company_code.clone(),
456                    as_of_date,
457                    as_of_date.year(),
458                    as_of_date.month(),
459                    self.currency.clone(),
460                );
461                for (account, balance) in balances {
462                    snapshot.balances.insert(account.clone(), balance.clone());
463                }
464                snapshot.recalculate_totals();
465                snapshot
466            })
467            .collect()
468    }
469
470    /// Gets balance changes for a period.
471    pub fn get_balance_changes(
472        &self,
473        company_code: &str,
474        from_date: NaiveDate,
475        to_date: NaiveDate,
476    ) -> Vec<AccountPeriodActivity> {
477        let Some(history) = self.history.get(company_code) else {
478            return Vec::new();
479        };
480
481        let mut changes_by_account: HashMap<String, AccountPeriodActivity> = HashMap::new();
482
483        for entry in history
484            .iter()
485            .filter(|e| e.date >= from_date && e.date <= to_date)
486        {
487            let change = changes_by_account
488                .entry(entry.account_code.clone())
489                .or_insert_with(|| AccountPeriodActivity {
490                    account_code: entry.account_code.clone(),
491                    period_start: from_date,
492                    period_end: to_date,
493                    opening_balance: Decimal::ZERO,
494                    closing_balance: Decimal::ZERO,
495                    total_debits: Decimal::ZERO,
496                    total_credits: Decimal::ZERO,
497                    net_change: Decimal::ZERO,
498                    transaction_count: 0,
499                });
500
501            if entry.change > Decimal::ZERO {
502                change.total_debits += entry.change;
503            } else {
504                change.total_credits += entry.change.abs();
505            }
506            change.net_change += entry.change;
507            change.transaction_count += 1;
508        }
509
510        // Update opening/closing balances
511        if let Some(company_balances) = self.balances.get(company_code) {
512            for change in changes_by_account.values_mut() {
513                if let Some(balance) = company_balances.get(&change.account_code) {
514                    change.closing_balance = balance.closing_balance;
515                    change.opening_balance = change.closing_balance - change.net_change;
516                }
517            }
518        }
519
520        changes_by_account.into_values().collect()
521    }
522
523    /// Gets balance for a specific account.
524    pub fn get_account_balance(
525        &self,
526        company_code: &str,
527        account_code: &str,
528    ) -> Option<&AccountBalance> {
529        self.balances
530            .get(company_code)
531            .and_then(|b| b.get(account_code))
532    }
533
534    // ---- SP4.1 — target-aware drift methods --------------------------------
535
536    /// SP4.1 — Attach a TB anchor prior to this tracker.  When set, enables
537    /// `account_drift()` and `drift_correction_needed()`.  Does not change
538    /// the existing balance-tracking behaviour — purely additive.
539    pub fn with_tb_anchor(mut self, anchor: TbAnchorPrior) -> Self {
540        self.tb_anchor = Some(anchor);
541        self
542    }
543
544    /// SP4.1 — Set the TB anchor prior on this tracker (mutable version).
545    pub fn set_tb_anchor(&mut self, anchor: TbAnchorPrior) {
546        self.tb_anchor = Some(anchor);
547    }
548
549    /// SP4.1 — Returns the per-account drift (current closing balance − target
550    /// closing balance) for every account that appears in the TB anchor prior.
551    ///
552    /// Positive drift means the synthetic account balance is higher than the
553    /// target; negative means it is lower.
554    ///
555    /// Returns an empty `Vec` when no TB anchor is loaded or no company has
556    /// been tracked yet.
557    pub fn account_drift(&self, company_code: &str) -> Vec<(String, f64)> {
558        let Some(anchor) = &self.tb_anchor else {
559            return Vec::new();
560        };
561        let company_balances = match self.balances.get(company_code) {
562            Some(b) => b,
563            None => return Vec::new(),
564        };
565
566        anchor
567            .per_account
568            .iter()
569            .map(|(account, target)| {
570                use rust_decimal::prelude::ToPrimitive;
571                let current = company_balances
572                    .get(account)
573                    .map(|b| b.closing_balance.to_f64().unwrap_or(0.0))
574                    .unwrap_or(0.0);
575                let drift = current - target.closing_balance;
576                (account.clone(), drift)
577            })
578            .collect()
579    }
580
581    /// SP4.1 / SP5.1 — Returns `true` when the TB anchor is loaded and any
582    /// single account's absolute drift exceeds `2 × closing_stdev` (or a
583    /// fallback of 2% of `|closing_balance|` when stdev is zero), OR the
584    /// aggregate absolute drift across all tracked accounts exceeds 0.5% of
585    /// `total_assets`.
586    ///
587    /// Thresholds were tuned to 2σ / 0.5% in SP5.1 (previously 3σ / 1%)
588    /// so that the drift-correction pass fires on realistic synthetic runs
589    /// where balances are shaped by priors but not pinned to the corpus median.
590    ///
591    /// Returns `false` when no TB anchor is loaded (backwards-compatible
592    /// behaviour — caller does not emit drift-correction entries).
593    pub fn drift_correction_needed(&self, company_code: &str) -> bool {
594        let Some(anchor) = &self.tb_anchor else {
595            return false;
596        };
597        if !anchor.has_data() {
598            return false;
599        }
600        let total_assets = anchor.total_assets.abs().max(1.0);
601
602        let drifts = self.account_drift(company_code);
603        if drifts.is_empty() {
604            return false;
605        }
606
607        // SP5.1 — Check per-account threshold: 2σ or 2% of |closing_balance|
608        // (previously 3σ or 5% of total_assets for the stdev=0 case).
609        for (account, drift) in &drifts {
610            if let Some(target) = anchor.per_account.get(account) {
611                let threshold = if target.closing_stdev > 1e-9 {
612                    2.0 * target.closing_stdev
613                } else {
614                    // No cross-client stdev — use 2% of |closing_balance| so
615                    // single-client targets fire on meaningful deviations without
616                    // requiring an unreachably-large per-account swing.
617                    (target.closing_balance.abs() * 0.02).max(1.0)
618                };
619                if drift.abs() > threshold {
620                    return true;
621                }
622            }
623        }
624
625        // SP5.1 — Aggregate threshold lowered from 1% → 0.5% of total_assets.
626        let aggregate_drift: f64 = drifts.iter().map(|(_, d)| d.abs()).sum();
627        if aggregate_drift > 0.005 * total_assets {
628            return true;
629        }
630
631        false
632    }
633
634    /// SP4.1 W8.1 — Build a balanced drift-correction JE that nudges the most-drifted
635    /// accounts back toward their TB anchor targets.
636    ///
637    /// The emitted JE:
638    /// - Has `document_type = "SA"` and `source = Adjustment` (period-end style).
639    /// - Includes at most 8 account lines (the worst drifters), plus one balancing line
640    ///   posted to suspense account "9999" when the selected lines don't net to zero.
641    /// - Always satisfies `total_debit == total_credit` (mandatory for `apply_entry`).
642    ///
643    /// Returns `None` when no TB anchor is loaded, when no drifts exceed the noise
644    /// floor, or when the resulting JE would have fewer than 2 lines.
645    ///
646    /// SP5.1: The net / balancing amount is now computed in Decimal (not f64) so
647    /// that f64→Decimal precision loss cannot produce an "unbalanced" JE.
648    pub fn build_drift_correction_je<R: rand::RngExt>(
649        &self,
650        company_code: &str,
651        posting_date: NaiveDate,
652        rng: &mut R,
653    ) -> Option<datasynth_core::models::JournalEntry> {
654        use datasynth_core::models::{
655            JournalEntry, JournalEntryHeader, JournalEntryLine, TransactionSource,
656        };
657        use rust_decimal::prelude::FromPrimitive;
658
659        // Only include accounts whose absolute drift exceeds 1% of target or $1.
660        let anchor = self.tb_anchor.as_ref()?;
661        let mut drifts: Vec<(String, f64)> = self
662            .account_drift(company_code)
663            .into_iter()
664            .filter(|(account, drift)| {
665                let threshold = anchor
666                    .per_account
667                    .get(account)
668                    .map(|t| (t.closing_balance.abs() * 0.01).max(1.0))
669                    .unwrap_or(1.0);
670                drift.abs() > threshold
671            })
672            .collect();
673
674        if drifts.is_empty() {
675            tracing::debug!(
676                target: "datasynth_generators::balance_tracker",
677                company = %company_code,
678                "W8.1 drift-correction: all drifts below noise floor — returning None"
679            );
680            return None;
681        }
682
683        // Take the top-8 worst drifters to keep the JE manageable.
684        drifts.sort_by(|a, b| {
685            b.1.abs()
686                .partial_cmp(&a.1.abs())
687                .unwrap_or(std::cmp::Ordering::Equal)
688        });
689        drifts.truncate(8);
690
691        let document_id = uuid::Uuid::now_v7();
692        let mut header = JournalEntryHeader::with_deterministic_id(
693            company_code.to_string(),
694            posting_date,
695            document_id,
696        );
697        header.source = TransactionSource::Adjustment;
698        header.document_type = "SA".to_string();
699        header.reference = Some(format!(
700            "DRIFT-CORR-{:08}",
701            rng.random_range(0u32..u32::MAX)
702        ));
703        header.header_text = Some("W8.1 Trial Balance Drift Correction".to_string());
704
705        let mut entry = JournalEntry::new(header);
706        let mut line_num = 1u32;
707
708        // SP5.1 — accumulate Decimal debit and credit totals as we add lines so
709        // the balancing suspense line can be computed from exact Decimal arithmetic
710        // rather than from the f64 `net`, avoiding f64→Decimal precision skew.
711        let mut decimal_debits = Decimal::ZERO;
712        let mut decimal_credits = Decimal::ZERO;
713
714        // For each drifted account: if drift > 0 (over-target) → credit; if < 0 (under-target) → debit.
715        for (account_number, drift) in &drifts {
716            let amount = match Decimal::from_f64(drift.abs()) {
717                Some(a) if a > Decimal::ZERO => a,
718                _ => continue,
719            };
720            let line = if *drift > 0.0 {
721                decimal_credits += amount;
722                JournalEntryLine::credit(document_id, line_num, account_number.clone(), amount)
723            } else {
724                decimal_debits += amount;
725                JournalEntryLine::debit(document_id, line_num, account_number.clone(), amount)
726            };
727            entry.add_line(line);
728            line_num += 1;
729        }
730
731        // SP5.1 — Compute the balancing amount in Decimal (not f64) to ensure
732        // `entry.is_balanced()` passes even for very large or fractional amounts.
733        let decimal_net = decimal_debits - decimal_credits;
734        if decimal_net.abs() > dec!(0.005) {
735            // decimal_net > 0 means debits exceed credits → credit the suspense.
736            // decimal_net < 0 means credits exceed debits → debit the suspense.
737            let balancing_line = if decimal_net > Decimal::ZERO {
738                JournalEntryLine::credit(
739                    document_id,
740                    line_num,
741                    "9999".to_string(),
742                    decimal_net.abs(),
743                )
744            } else {
745                JournalEntryLine::debit(
746                    document_id,
747                    line_num,
748                    "9999".to_string(),
749                    decimal_net.abs(),
750                )
751            };
752            entry.add_line(balancing_line);
753        }
754
755        if entry.lines.len() < 2 {
756            tracing::debug!(
757                target: "datasynth_generators::balance_tracker",
758                company = %company_code,
759                lines = entry.lines.len(),
760                "W8.1 drift-correction: too few lines — returning None"
761            );
762            return None;
763        }
764
765        if !entry.is_balanced() {
766            tracing::warn!(
767                target: "datasynth_generators::balance_tracker",
768                company = %company_code,
769                debit = %entry.total_debit(),
770                credit = %entry.total_credit(),
771                diff = %(entry.total_debit() - entry.total_credit()),
772                "W8.1 drift-correction: JE is unbalanced — returning None (should not happen with Decimal net)"
773            );
774            return None;
775        }
776
777        Some(entry)
778    }
779
780    /// Gets all validation errors.
781    pub fn get_validation_errors(&self) -> &[ValidationError] {
782        &self.validation_errors
783    }
784
785    /// Clears validation errors.
786    pub fn clear_validation_errors(&mut self) {
787        self.validation_errors.clear();
788        self.stats.validation_errors = 0;
789    }
790
791    /// Gets tracker statistics.
792    pub fn get_statistics(&self) -> &TrackerStatistics {
793        &self.stats
794    }
795
796    /// Rolls forward balances to a new period.
797    pub fn roll_forward(&mut self, _new_period_start: NaiveDate) {
798        for company_balances in self.balances.values_mut() {
799            for balance in company_balances.values_mut() {
800                balance.roll_forward();
801            }
802        }
803    }
804
805    /// Exports balances to a simple format.
806    pub fn export_balances(&self, company_code: &str) -> Vec<(String, Decimal)> {
807        self.balances
808            .get(company_code)
809            .map(|balances| {
810                balances
811                    .iter()
812                    .map(|(code, balance)| (code.clone(), balance.closing_balance))
813                    .collect()
814            })
815            .unwrap_or_default()
816    }
817}
818
#[cfg(test)]
mod tests {
    use super::*;
    use datasynth_core::models::{JournalEntry, JournalEntryLine};

    /// Builds a two-line journal entry for `company`, dated 2024-01-15.
    ///
    /// The entry debits `account1` and credits `account2` by the same
    /// `amount`, so it is balanced.
    fn create_test_entry(
        company: &str,
        account1: &str,
        account2: &str,
        amount: Decimal,
    ) -> JournalEntry {
        let debit_line = JournalEntryLine {
            line_number: 1,
            gl_account: account1.to_string(),
            account_code: account1.to_string(),
            debit_amount: amount,
            ..Default::default()
        };
        let credit_line = JournalEntryLine {
            line_number: 2,
            gl_account: account2.to_string(),
            account_code: account2.to_string(),
            credit_amount: amount,
            ..Default::default()
        };

        let posting_date = NaiveDate::from_ymd_opt(2024, 1, 15).unwrap();
        let mut entry = JournalEntry::new_simple(
            "TEST001".to_string(),
            company.to_string(),
            posting_date,
            "Test entry".to_string(),
        );
        entry.add_line(debit_line);
        entry.add_line(credit_line);
        entry
    }

    #[test]
    fn test_apply_balanced_entry() {
        let mut tracker = RunningBalanceTracker::with_defaults();
        for (code, account_type) in [("1100", AccountType::Asset), ("4000", AccountType::Revenue)] {
            tracker.register_account_type(code, account_type);
        }

        let entry = create_test_entry("1000", "1100", "4000", dec!(1000));
        assert!(tracker.apply_entry(&entry).is_ok());

        assert_eq!(tracker.stats.entries_processed, 1);
        assert_eq!(tracker.stats.lines_processed, 2);
    }

    #[test]
    fn test_balance_accumulation() {
        let mut tracker = RunningBalanceTracker::with_defaults();
        tracker.config.validate_on_each_entry = false;

        // Two postings to the same debit account should accumulate.
        for amount in [dec!(1000), dec!(500)] {
            let entry = create_test_entry("1000", "1100", "4000", amount);
            tracker.apply_entry(&entry).unwrap();
        }

        let account_balance = tracker.get_account_balance("1000", "1100").unwrap();
        assert_eq!(account_balance.closing_balance, dec!(1500));
    }

    #[test]
    fn test_get_snapshot() {
        let mut tracker = RunningBalanceTracker::with_defaults();
        tracker.config.validate_on_each_entry = false;

        let entry = create_test_entry("1000", "1100", "2000", dec!(1000));
        tracker.apply_entry(&entry).unwrap();

        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
        let snapshot = tracker.get_snapshot("1000", period_end).unwrap();

        // One balance per account touched by the entry.
        assert_eq!(snapshot.balances.len(), 2);
    }

    #[test]
    fn test_determine_account_type_from_prefix() {
        let tracker = RunningBalanceTracker::with_defaults();

        let expectations = [
            ("1000", AccountType::Asset),
            ("2000", AccountType::Liability),
            ("3000", AccountType::Equity),
            ("4000", AccountType::Revenue),
            ("5000", AccountType::Expense),
        ];
        for (code, expected) in expectations {
            assert_eq!(tracker.determine_account_type(code), expected);
        }
    }

    #[test]
    fn test_determine_account_type_french_gaap() {
        let tracker = RunningBalanceTracker::new_with_framework(
            BalanceTrackerConfig::default(),
            "french_gaap",
        );

        let expectations = [
            // PCG class 2 = Fixed Assets (Asset)
            ("210000", AccountType::Asset),
            // PCG class 1 subclass 0-4 = Equity
            ("101000", AccountType::Equity),
            // PCG class 4 subclass 0 = Suppliers (Liability)
            ("401000", AccountType::Liability),
            // PCG class 6 = Expenses
            ("603000", AccountType::Expense),
            // PCG class 7 = Revenue
            ("701000", AccountType::Revenue),
        ];
        for (code, expected) in expectations {
            assert_eq!(tracker.determine_account_type(code), expected);
        }
    }

    #[test]
    fn test_determine_account_type_german_gaap() {
        let tracker = RunningBalanceTracker::new_with_framework(
            BalanceTrackerConfig::default(),
            "german_gaap",
        );

        let expectations = [
            // SKR04 class 0 = Fixed Assets (Asset)
            ("0200", AccountType::Asset),
            // SKR04 class 2 = Equity
            ("2000", AccountType::Equity),
            // SKR04 class 3 = Liabilities
            ("3300", AccountType::Liability),
            // SKR04 class 4 = Revenue
            ("4000", AccountType::Revenue),
            // SKR04 class 5 = COGS (Expense)
            ("5000", AccountType::Expense),
        ];
        for (code, expected) in expectations {
            assert_eq!(tracker.determine_account_type(code), expected);
        }
    }
}