Skip to main content

datasynth_generators/balance/
balance_tracker.rs

//! Running balance tracker.
//!
//! Maintains running account balances as journal entries are processed,
//! with continuous validation of balance sheet integrity.

6use chrono::{Datelike, NaiveDate};
7use rust_decimal::Decimal;
8use rust_decimal_macros::dec;
9use std::collections::HashMap;
10use tracing::debug;
11
12use datasynth_core::models::balance::{
13    AccountBalance, AccountPeriodActivity, AccountType, BalanceSnapshot,
14};
15use datasynth_core::models::JournalEntry;
16use datasynth_core::FrameworkAccounts;
17
18/// Configuration for the balance tracker.
19#[derive(Debug, Clone)]
20pub struct BalanceTrackerConfig {
21    /// Whether to validate balance sheet equation after each entry.
22    pub validate_on_each_entry: bool,
23    /// Whether to track balance history.
24    pub track_history: bool,
25    /// Tolerance for balance sheet validation (for rounding).
26    pub balance_tolerance: Decimal,
27    /// Whether to fail on validation errors.
28    pub fail_on_validation_error: bool,
29}
30
31impl Default for BalanceTrackerConfig {
32    fn default() -> Self {
33        Self {
34            validate_on_each_entry: true,
35            track_history: true,
36            balance_tolerance: dec!(0.01),
37            fail_on_validation_error: false,
38        }
39    }
40}
41
42/// Tracks running balances for all accounts across companies.
43pub struct RunningBalanceTracker {
44    config: BalanceTrackerConfig,
45    /// Balances by company code -> account code -> balance.
46    balances: HashMap<String, HashMap<String, AccountBalance>>,
47    /// Account type registry for determining debit/credit behavior.
48    account_types: HashMap<String, AccountType>,
49    /// Framework-aware account classification.
50    framework_accounts: FrameworkAccounts,
51    /// Balance history by company code.
52    history: HashMap<String, Vec<BalanceHistoryEntry>>,
53    /// Validation errors encountered.
54    validation_errors: Vec<ValidationError>,
55    /// Statistics.
56    stats: TrackerStatistics,
57    /// Default currency for new account balances and snapshots.
58    currency: String,
59}
60
61/// Entry in balance history.
62#[derive(Debug, Clone)]
63pub struct BalanceHistoryEntry {
64    pub date: NaiveDate,
65    pub entry_id: String,
66    pub account_code: String,
67    pub previous_balance: Decimal,
68    pub change: Decimal,
69    pub new_balance: Decimal,
70}
71
72/// Validation error details.
73#[derive(Debug, Clone)]
74pub struct ValidationError {
75    pub date: NaiveDate,
76    pub company_code: String,
77    pub entry_id: Option<String>,
78    pub error_type: ValidationErrorType,
79    pub message: String,
80    pub details: HashMap<String, Decimal>,
81}
82
/// Category of a [`ValidationError`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ValidationErrorType {
    /// The entry's total debits differ from its total credits.
    UnbalancedEntry,
    /// Assets do not equal liabilities + equity + net income.
    BalanceSheetImbalance,
    /// An account carries a negative balance where that is not permitted.
    NegativeBalance,
    /// The account code is not recognized.
    UnknownAccount,
    /// The entry was applied out of chronological order.
    OutOfOrder,
}
97
98/// Statistics about tracked entries.
99#[derive(Debug, Clone, Default)]
100pub struct TrackerStatistics {
101    pub entries_processed: u64,
102    pub lines_processed: u64,
103    pub total_debits: Decimal,
104    pub total_credits: Decimal,
105    pub companies_tracked: usize,
106    pub accounts_tracked: usize,
107    pub validation_errors: usize,
108}
109
110impl RunningBalanceTracker {
111    /// Creates a new balance tracker with the specified currency and accounting framework.
112    pub fn new_with_currency_and_framework(
113        config: BalanceTrackerConfig,
114        currency: String,
115        framework: &str,
116    ) -> Self {
117        Self {
118            config,
119            balances: HashMap::new(),
120            account_types: HashMap::new(),
121            framework_accounts: FrameworkAccounts::for_framework(framework),
122            history: HashMap::new(),
123            validation_errors: Vec::new(),
124            stats: TrackerStatistics::default(),
125            currency,
126        }
127    }
128
129    /// Creates a new balance tracker with the specified currency (defaults to US GAAP).
130    pub fn new_with_currency(config: BalanceTrackerConfig, currency: String) -> Self {
131        Self::new_with_currency_and_framework(config, currency, "us_gaap")
132    }
133
134    /// Creates a new balance tracker (defaults to USD and US GAAP).
135    pub fn new(config: BalanceTrackerConfig) -> Self {
136        Self::new_with_currency(config, "USD".to_string())
137    }
138
139    /// Creates a new balance tracker for a specific accounting framework (defaults to USD).
140    pub fn new_with_framework(config: BalanceTrackerConfig, framework: &str) -> Self {
141        Self::new_with_currency_and_framework(config, "USD".to_string(), framework)
142    }
143
144    /// Creates a tracker with default configuration (US GAAP).
145    pub fn with_defaults() -> Self {
146        Self::new(BalanceTrackerConfig::default())
147    }
148
149    /// Registers an account type for balance tracking.
150    pub fn register_account_type(&mut self, account_code: &str, account_type: AccountType) {
151        self.account_types
152            .insert(account_code.to_string(), account_type);
153    }
154
155    /// Registers multiple account types.
156    pub fn register_account_types(&mut self, types: &[(String, AccountType)]) {
157        for (code, account_type) in types {
158            self.account_types.insert(code.clone(), *account_type);
159        }
160    }
161
162    /// Registers account types from a chart of accounts prefix pattern.
163    pub fn register_from_chart_prefixes(&mut self, prefixes: &[(&str, AccountType)]) {
164        for (prefix, account_type) in prefixes {
165            self.account_types.insert(prefix.to_string(), *account_type);
166        }
167    }
168
169    /// Initializes balances from opening balance snapshot.
170    pub fn initialize_from_snapshot(&mut self, snapshot: &BalanceSnapshot) {
171        let company_balances = self
172            .balances
173            .entry(snapshot.company_code.clone())
174            .or_default();
175
176        for (account_code, balance) in &snapshot.balances {
177            company_balances.insert(account_code.clone(), balance.clone());
178        }
179
180        self.stats.companies_tracked = self.balances.len();
181        self.stats.accounts_tracked = self
182            .balances
183            .values()
184            .map(std::collections::HashMap::len)
185            .sum();
186    }
187
188    /// Applies a journal entry to the running balances.
189    pub fn apply_entry(&mut self, entry: &JournalEntry) -> Result<(), ValidationError> {
190        // Validate entry is balanced first
191        if !entry.is_balanced() {
192            let error = ValidationError {
193                date: entry.posting_date(),
194                company_code: entry.company_code().to_string(),
195                entry_id: Some(entry.document_number().clone()),
196                error_type: ValidationErrorType::UnbalancedEntry,
197                message: format!(
198                    "Entry {} is unbalanced: debits={}, credits={}",
199                    entry.document_number(),
200                    entry.total_debit(),
201                    entry.total_credit()
202                ),
203                details: {
204                    let mut d = HashMap::new();
205                    d.insert("total_debit".to_string(), entry.total_debit());
206                    d.insert("total_credit".to_string(), entry.total_credit());
207                    d
208                },
209            };
210
211            if self.config.fail_on_validation_error {
212                return Err(error);
213            }
214            self.validation_errors.push(error);
215        }
216
217        // Extract data we need before mutably borrowing balances
218        let company_code = entry.company_code().to_string();
219        let document_number = entry.document_number().clone();
220        let posting_date = entry.posting_date();
221        let track_history = self.config.track_history;
222
223        // Pre-compute account types for all lines
224        let line_data: Vec<_> = entry
225            .lines
226            .iter()
227            .map(|line| {
228                let account_type = self.determine_account_type(&line.account_code);
229                (line.clone(), account_type)
230            })
231            .collect();
232
233        // Get or create company balances
234        let company_balances = self.balances.entry(company_code.clone()).or_default();
235
236        // History entries to add
237        let mut history_entries = Vec::new();
238
239        // Apply each line
240        for (line, account_type) in &line_data {
241            // Get or create account balance
242            let balance = company_balances
243                .entry(line.account_code.clone())
244                .or_insert_with(|| {
245                    AccountBalance::new(
246                        company_code.clone(),
247                        line.account_code.clone(),
248                        *account_type,
249                        self.currency.clone(),
250                        posting_date.year(),
251                        posting_date.month(),
252                    )
253                });
254
255            let previous_balance = balance.closing_balance;
256
257            // Apply debit or credit
258            if line.debit_amount > Decimal::ZERO {
259                balance.apply_debit(line.debit_amount);
260            }
261            if line.credit_amount > Decimal::ZERO {
262                balance.apply_credit(line.credit_amount);
263            }
264
265            let new_balance = balance.closing_balance;
266
267            // Record history if configured
268            if track_history {
269                let change = line.debit_amount - line.credit_amount;
270                history_entries.push(BalanceHistoryEntry {
271                    date: posting_date,
272                    entry_id: document_number.clone(),
273                    account_code: line.account_code.clone(),
274                    previous_balance,
275                    change,
276                    new_balance,
277                });
278            }
279        }
280
281        // Add history entries after releasing the balances borrow
282        if !history_entries.is_empty() {
283            let hist = self.history.entry(company_code.clone()).or_default();
284            hist.extend(history_entries);
285        }
286
287        // Update statistics
288        self.stats.entries_processed += 1;
289        self.stats.lines_processed += entry.lines.len() as u64;
290        self.stats.total_debits += entry.total_debit();
291        self.stats.total_credits += entry.total_credit();
292        self.stats.companies_tracked = self.balances.len();
293        self.stats.accounts_tracked = self
294            .balances
295            .values()
296            .map(std::collections::HashMap::len)
297            .sum();
298
299        // Validate balance sheet if configured
300        if self.config.validate_on_each_entry {
301            self.validate_balance_sheet(
302                entry.company_code(),
303                entry.posting_date(),
304                Some(&entry.document_number()),
305            )?;
306        }
307
308        Ok(())
309    }
310
311    /// Applies a batch of entries.
312    pub fn apply_entries(&mut self, entries: &[JournalEntry]) -> Vec<ValidationError> {
313        debug!(
314            entry_count = entries.len(),
315            companies_tracked = self.stats.companies_tracked,
316            accounts_tracked = self.stats.accounts_tracked,
317            "Applying entries to balance tracker"
318        );
319
320        let mut errors = Vec::new();
321
322        for entry in entries {
323            if let Err(error) = self.apply_entry(entry) {
324                errors.push(error);
325            }
326        }
327
328        errors
329    }
330
331    /// Determines account type from code prefix.
332    ///
333    /// Checks explicitly registered types first, then falls back to the
334    /// framework-aware classifier from [`FrameworkAccounts`].
335    fn determine_account_type(&self, account_code: &str) -> AccountType {
336        // Check registered types first (exact match or prefix)
337        for (registered_code, account_type) in &self.account_types {
338            if account_code.starts_with(registered_code) {
339                return *account_type;
340            }
341        }
342
343        // Use framework-aware classification
344        self.framework_accounts.classify_account_type(account_code)
345    }
346
347    /// Validates the balance sheet equation for a company.
348    pub fn validate_balance_sheet(
349        &mut self,
350        company_code: &str,
351        date: NaiveDate,
352        entry_id: Option<&str>,
353    ) -> Result<(), ValidationError> {
354        let Some(company_balances) = self.balances.get(company_code) else {
355            return Ok(()); // No balances to validate
356        };
357
358        let mut total_assets = Decimal::ZERO;
359        let mut total_liabilities = Decimal::ZERO;
360        let mut total_equity = Decimal::ZERO;
361        let mut total_revenue = Decimal::ZERO;
362        let mut total_expenses = Decimal::ZERO;
363
364        for (account_code, balance) in company_balances {
365            let account_type = self.determine_account_type(account_code);
366            match account_type {
367                AccountType::Asset => total_assets += balance.closing_balance,
368                AccountType::ContraAsset => total_assets -= balance.closing_balance.abs(),
369                AccountType::Liability => total_liabilities += balance.closing_balance.abs(),
370                AccountType::ContraLiability => total_liabilities -= balance.closing_balance.abs(),
371                AccountType::Equity => total_equity += balance.closing_balance.abs(),
372                AccountType::ContraEquity => total_equity -= balance.closing_balance.abs(),
373                AccountType::Revenue => total_revenue += balance.closing_balance.abs(),
374                AccountType::Expense => total_expenses += balance.closing_balance.abs(),
375            }
376        }
377
378        // Net income = Revenue - Expenses
379        let net_income = total_revenue - total_expenses;
380
381        // Balance sheet equation: Assets = Liabilities + Equity + Net Income
382        let left_side = total_assets;
383        let right_side = total_liabilities + total_equity + net_income;
384        let difference = (left_side - right_side).abs();
385
386        if difference > self.config.balance_tolerance {
387            let error = ValidationError {
388                date,
389                company_code: company_code.to_string(),
390                entry_id: entry_id.map(String::from),
391                error_type: ValidationErrorType::BalanceSheetImbalance,
392                message: format!(
393                    "Balance sheet imbalance: Assets ({left_side}) != L + E + NI ({right_side}), diff = {difference}"
394                ),
395                details: {
396                    let mut d = HashMap::new();
397                    d.insert("total_assets".to_string(), total_assets);
398                    d.insert("total_liabilities".to_string(), total_liabilities);
399                    d.insert("total_equity".to_string(), total_equity);
400                    d.insert("net_income".to_string(), net_income);
401                    d.insert("difference".to_string(), difference);
402                    d
403                },
404            };
405
406            self.stats.validation_errors += 1;
407
408            if self.config.fail_on_validation_error {
409                return Err(error);
410            }
411            self.validation_errors.push(error);
412        }
413
414        Ok(())
415    }
416
417    /// Gets the current snapshot for a company.
418    pub fn get_snapshot(
419        &self,
420        company_code: &str,
421        as_of_date: NaiveDate,
422    ) -> Option<BalanceSnapshot> {
423        use chrono::Datelike;
424        let currency = self.currency.clone();
425        self.balances.get(company_code).map(|balances| {
426            let mut snapshot = BalanceSnapshot::new(
427                format!("SNAP-{company_code}-{as_of_date}"),
428                company_code.to_string(),
429                as_of_date,
430                as_of_date.year(),
431                as_of_date.month(),
432                currency,
433            );
434            for (account, balance) in balances {
435                snapshot.balances.insert(account.clone(), balance.clone());
436            }
437            snapshot.recalculate_totals();
438            snapshot
439        })
440    }
441
442    /// Gets snapshots for all companies.
443    pub fn get_all_snapshots(&self, as_of_date: NaiveDate) -> Vec<BalanceSnapshot> {
444        use chrono::Datelike;
445        self.balances
446            .iter()
447            .map(|(company_code, balances)| {
448                let mut snapshot = BalanceSnapshot::new(
449                    format!("SNAP-{company_code}-{as_of_date}"),
450                    company_code.clone(),
451                    as_of_date,
452                    as_of_date.year(),
453                    as_of_date.month(),
454                    self.currency.clone(),
455                );
456                for (account, balance) in balances {
457                    snapshot.balances.insert(account.clone(), balance.clone());
458                }
459                snapshot.recalculate_totals();
460                snapshot
461            })
462            .collect()
463    }
464
465    /// Gets balance changes for a period.
466    pub fn get_balance_changes(
467        &self,
468        company_code: &str,
469        from_date: NaiveDate,
470        to_date: NaiveDate,
471    ) -> Vec<AccountPeriodActivity> {
472        let Some(history) = self.history.get(company_code) else {
473            return Vec::new();
474        };
475
476        let mut changes_by_account: HashMap<String, AccountPeriodActivity> = HashMap::new();
477
478        for entry in history
479            .iter()
480            .filter(|e| e.date >= from_date && e.date <= to_date)
481        {
482            let change = changes_by_account
483                .entry(entry.account_code.clone())
484                .or_insert_with(|| AccountPeriodActivity {
485                    account_code: entry.account_code.clone(),
486                    period_start: from_date,
487                    period_end: to_date,
488                    opening_balance: Decimal::ZERO,
489                    closing_balance: Decimal::ZERO,
490                    total_debits: Decimal::ZERO,
491                    total_credits: Decimal::ZERO,
492                    net_change: Decimal::ZERO,
493                    transaction_count: 0,
494                });
495
496            if entry.change > Decimal::ZERO {
497                change.total_debits += entry.change;
498            } else {
499                change.total_credits += entry.change.abs();
500            }
501            change.net_change += entry.change;
502            change.transaction_count += 1;
503        }
504
505        // Update opening/closing balances
506        if let Some(company_balances) = self.balances.get(company_code) {
507            for change in changes_by_account.values_mut() {
508                if let Some(balance) = company_balances.get(&change.account_code) {
509                    change.closing_balance = balance.closing_balance;
510                    change.opening_balance = change.closing_balance - change.net_change;
511                }
512            }
513        }
514
515        changes_by_account.into_values().collect()
516    }
517
518    /// Gets balance for a specific account.
519    pub fn get_account_balance(
520        &self,
521        company_code: &str,
522        account_code: &str,
523    ) -> Option<&AccountBalance> {
524        self.balances
525            .get(company_code)
526            .and_then(|b| b.get(account_code))
527    }
528
529    /// Gets all validation errors.
530    pub fn get_validation_errors(&self) -> &[ValidationError] {
531        &self.validation_errors
532    }
533
534    /// Clears validation errors.
535    pub fn clear_validation_errors(&mut self) {
536        self.validation_errors.clear();
537        self.stats.validation_errors = 0;
538    }
539
540    /// Gets tracker statistics.
541    pub fn get_statistics(&self) -> &TrackerStatistics {
542        &self.stats
543    }
544
545    /// Rolls forward balances to a new period.
546    pub fn roll_forward(&mut self, _new_period_start: NaiveDate) {
547        for company_balances in self.balances.values_mut() {
548            for balance in company_balances.values_mut() {
549                balance.roll_forward();
550            }
551        }
552    }
553
554    /// Exports balances to a simple format.
555    pub fn export_balances(&self, company_code: &str) -> Vec<(String, Decimal)> {
556        self.balances
557            .get(company_code)
558            .map(|balances| {
559                balances
560                    .iter()
561                    .map(|(code, balance)| (code.clone(), balance.closing_balance))
562                    .collect()
563            })
564            .unwrap_or_default()
565    }
566}
567
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use datasynth_core::models::{JournalEntry, JournalEntryLine};

    /// Builds a two-line entry (document `TEST001`, posted 2024-01-15) that debits
    /// `debit_account` and credits `credit_account` by `amount`.
    fn create_test_entry(
        company: &str,
        debit_account: &str,
        credit_account: &str,
        amount: Decimal,
    ) -> JournalEntry {
        let posting_date = NaiveDate::from_ymd_opt(2024, 1, 15).unwrap();
        let mut entry = JournalEntry::new_simple(
            "TEST001".to_string(),
            company.to_string(),
            posting_date,
            "Test entry".to_string(),
        );

        let debit_line = JournalEntryLine {
            line_number: 1,
            gl_account: debit_account.to_string(),
            account_code: debit_account.to_string(),
            debit_amount: amount,
            ..Default::default()
        };
        let credit_line = JournalEntryLine {
            line_number: 2,
            gl_account: credit_account.to_string(),
            account_code: credit_account.to_string(),
            credit_amount: amount,
            ..Default::default()
        };

        entry.add_line(debit_line);
        entry.add_line(credit_line);
        entry
    }

    /// Default tracker with the per-entry balance sheet check switched off.
    fn unvalidated_tracker() -> RunningBalanceTracker {
        let mut tracker = RunningBalanceTracker::with_defaults();
        tracker.config.validate_on_each_entry = false;
        tracker
    }

    /// Asserts that `tracker` classifies each account code as expected.
    fn assert_classifications(tracker: &RunningBalanceTracker, cases: &[(&str, AccountType)]) {
        for &(code, expected) in cases {
            assert_eq!(
                tracker.determine_account_type(code),
                expected,
                "account {code}"
            );
        }
    }

    #[test]
    fn test_apply_balanced_entry() {
        let mut tracker = RunningBalanceTracker::with_defaults();
        tracker.register_account_type("1100", AccountType::Asset);
        tracker.register_account_type("4000", AccountType::Revenue);

        let entry = create_test_entry("1000", "1100", "4000", dec!(1000));
        assert!(tracker.apply_entry(&entry).is_ok());

        let stats = tracker.get_statistics();
        assert_eq!(stats.entries_processed, 1);
        assert_eq!(stats.lines_processed, 2);
    }

    #[test]
    fn test_balance_accumulation() {
        let mut tracker = unvalidated_tracker();

        for amount in [dec!(1000), dec!(500)] {
            let entry = create_test_entry("1000", "1100", "4000", amount);
            tracker.apply_entry(&entry).unwrap();
        }

        let balance = tracker.get_account_balance("1000", "1100").unwrap();
        assert_eq!(balance.closing_balance, dec!(1500));
    }

    #[test]
    fn test_get_snapshot() {
        let mut tracker = unvalidated_tracker();
        let entry = create_test_entry("1000", "1100", "2000", dec!(1000));
        tracker.apply_entry(&entry).unwrap();

        let as_of = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
        let snapshot = tracker.get_snapshot("1000", as_of).unwrap();

        assert_eq!(snapshot.balances.len(), 2);
    }

    #[test]
    fn test_determine_account_type_from_prefix() {
        let tracker = RunningBalanceTracker::with_defaults();

        assert_classifications(
            &tracker,
            &[
                ("1000", AccountType::Asset),
                ("2000", AccountType::Liability),
                ("3000", AccountType::Equity),
                ("4000", AccountType::Revenue),
                ("5000", AccountType::Expense),
            ],
        );
    }

    #[test]
    fn test_determine_account_type_french_gaap() {
        let tracker = RunningBalanceTracker::new_with_framework(
            BalanceTrackerConfig::default(),
            "french_gaap",
        );

        assert_classifications(
            &tracker,
            &[
                // PCG class 2 = Fixed Assets (Asset)
                ("210000", AccountType::Asset),
                // PCG class 1 subclass 0-4 = Equity
                ("101000", AccountType::Equity),
                // PCG class 4 subclass 0 = Suppliers (Liability)
                ("401000", AccountType::Liability),
                // PCG class 6 = Expenses
                ("603000", AccountType::Expense),
                // PCG class 7 = Revenue
                ("701000", AccountType::Revenue),
            ],
        );
    }

    #[test]
    fn test_determine_account_type_german_gaap() {
        let tracker = RunningBalanceTracker::new_with_framework(
            BalanceTrackerConfig::default(),
            "german_gaap",
        );

        assert_classifications(
            &tracker,
            &[
                // SKR04 class 0 = Fixed Assets (Asset)
                ("0200", AccountType::Asset),
                // SKR04 class 2 = Equity
                ("2000", AccountType::Equity),
                // SKR04 class 3 = Liabilities
                ("3300", AccountType::Liability),
                // SKR04 class 4 = Revenue
                ("4000", AccountType::Revenue),
                // SKR04 class 5 = COGS (Expense)
                ("5000", AccountType::Expense),
            ],
        );
    }
}