Skip to main content

datasynth_generators/balance/
balance_tracker.rs

1//! Running balance tracker.
2//!
3//! Maintains real-time account balances as journal entries are processed,
4//! with continuous validation of balance sheet integrity.
5
6use chrono::{Datelike, NaiveDate};
7use rust_decimal::Decimal;
8use rust_decimal_macros::dec;
9use std::collections::HashMap;
10use tracing::debug;
11
12use datasynth_core::models::balance::{
13    AccountBalance, AccountPeriodActivity, AccountType, BalanceSnapshot,
14};
15use datasynth_core::models::JournalEntry;
16
17/// Configuration for the balance tracker.
18#[derive(Debug, Clone)]
19pub struct BalanceTrackerConfig {
20    /// Whether to validate balance sheet equation after each entry.
21    pub validate_on_each_entry: bool,
22    /// Whether to track balance history.
23    pub track_history: bool,
24    /// Tolerance for balance sheet validation (for rounding).
25    pub balance_tolerance: Decimal,
26    /// Whether to fail on validation errors.
27    pub fail_on_validation_error: bool,
28}
29
30impl Default for BalanceTrackerConfig {
31    fn default() -> Self {
32        Self {
33            validate_on_each_entry: true,
34            track_history: true,
35            balance_tolerance: dec!(0.01),
36            fail_on_validation_error: false,
37        }
38    }
39}
40
41/// Tracks running balances for all accounts across companies.
42pub struct RunningBalanceTracker {
43    config: BalanceTrackerConfig,
44    /// Balances by company code -> account code -> balance.
45    balances: HashMap<String, HashMap<String, AccountBalance>>,
46    /// Account type registry for determining debit/credit behavior.
47    account_types: HashMap<String, AccountType>,
48    /// Balance history by company code.
49    history: HashMap<String, Vec<BalanceHistoryEntry>>,
50    /// Validation errors encountered.
51    validation_errors: Vec<ValidationError>,
52    /// Statistics.
53    stats: TrackerStatistics,
54    /// Default currency for new account balances and snapshots.
55    currency: String,
56}
57
58/// Entry in balance history.
59#[derive(Debug, Clone)]
60pub struct BalanceHistoryEntry {
61    pub date: NaiveDate,
62    pub entry_id: String,
63    pub account_code: String,
64    pub previous_balance: Decimal,
65    pub change: Decimal,
66    pub new_balance: Decimal,
67}
68
69/// Validation error details.
70#[derive(Debug, Clone)]
71pub struct ValidationError {
72    pub date: NaiveDate,
73    pub company_code: String,
74    pub entry_id: Option<String>,
75    pub error_type: ValidationErrorType,
76    pub message: String,
77    pub details: HashMap<String, Decimal>,
78}
79
/// Categories of validation failure a [`ValidationError`] can carry.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationErrorType {
    /// The entry's total debits differ from its total credits.
    UnbalancedEntry,
    /// Assets do not equal liabilities + equity + net income within tolerance.
    BalanceSheetImbalance,
    /// An account carries a negative balance where that is not permitted.
    NegativeBalance,
    /// An account code could not be resolved.
    UnknownAccount,
    /// An entry was applied out of chronological order.
    OutOfOrder,
}
94
95/// Statistics about tracked entries.
96#[derive(Debug, Clone, Default)]
97pub struct TrackerStatistics {
98    pub entries_processed: u64,
99    pub lines_processed: u64,
100    pub total_debits: Decimal,
101    pub total_credits: Decimal,
102    pub companies_tracked: usize,
103    pub accounts_tracked: usize,
104    pub validation_errors: usize,
105}
106
107impl RunningBalanceTracker {
108    /// Creates a new balance tracker with the specified currency.
109    pub fn new_with_currency(config: BalanceTrackerConfig, currency: String) -> Self {
110        Self {
111            config,
112            balances: HashMap::new(),
113            account_types: HashMap::new(),
114            history: HashMap::new(),
115            validation_errors: Vec::new(),
116            stats: TrackerStatistics::default(),
117            currency,
118        }
119    }
120
121    /// Creates a new balance tracker (defaults to USD).
122    pub fn new(config: BalanceTrackerConfig) -> Self {
123        Self::new_with_currency(config, "USD".to_string())
124    }
125
126    /// Creates a tracker with default configuration.
127    pub fn with_defaults() -> Self {
128        Self::new(BalanceTrackerConfig::default())
129    }
130
131    /// Registers an account type for balance tracking.
132    pub fn register_account_type(&mut self, account_code: &str, account_type: AccountType) {
133        self.account_types
134            .insert(account_code.to_string(), account_type);
135    }
136
137    /// Registers multiple account types.
138    pub fn register_account_types(&mut self, types: &[(String, AccountType)]) {
139        for (code, account_type) in types {
140            self.account_types.insert(code.clone(), *account_type);
141        }
142    }
143
144    /// Registers account types from a chart of accounts prefix pattern.
145    pub fn register_from_chart_prefixes(&mut self, prefixes: &[(&str, AccountType)]) {
146        for (prefix, account_type) in prefixes {
147            self.account_types.insert(prefix.to_string(), *account_type);
148        }
149    }
150
151    /// Initializes balances from opening balance snapshot.
152    pub fn initialize_from_snapshot(&mut self, snapshot: &BalanceSnapshot) {
153        let company_balances = self
154            .balances
155            .entry(snapshot.company_code.clone())
156            .or_default();
157
158        for (account_code, balance) in &snapshot.balances {
159            company_balances.insert(account_code.clone(), balance.clone());
160        }
161
162        self.stats.companies_tracked = self.balances.len();
163        self.stats.accounts_tracked = self.balances.values().map(|b| b.len()).sum();
164    }
165
166    /// Applies a journal entry to the running balances.
167    pub fn apply_entry(&mut self, entry: &JournalEntry) -> Result<(), ValidationError> {
168        // Validate entry is balanced first
169        if !entry.is_balanced() {
170            let error = ValidationError {
171                date: entry.posting_date(),
172                company_code: entry.company_code().to_string(),
173                entry_id: Some(entry.document_number().clone()),
174                error_type: ValidationErrorType::UnbalancedEntry,
175                message: format!(
176                    "Entry {} is unbalanced: debits={}, credits={}",
177                    entry.document_number(),
178                    entry.total_debit(),
179                    entry.total_credit()
180                ),
181                details: {
182                    let mut d = HashMap::new();
183                    d.insert("total_debit".to_string(), entry.total_debit());
184                    d.insert("total_credit".to_string(), entry.total_credit());
185                    d
186                },
187            };
188
189            if self.config.fail_on_validation_error {
190                return Err(error);
191            }
192            self.validation_errors.push(error);
193        }
194
195        // Extract data we need before mutably borrowing balances
196        let company_code = entry.company_code().to_string();
197        let document_number = entry.document_number().clone();
198        let posting_date = entry.posting_date();
199        let track_history = self.config.track_history;
200
201        // Pre-compute account types for all lines
202        let line_data: Vec<_> = entry
203            .lines
204            .iter()
205            .map(|line| {
206                let account_type = self.determine_account_type(&line.account_code);
207                (line.clone(), account_type)
208            })
209            .collect();
210
211        // Get or create company balances
212        let company_balances = self.balances.entry(company_code.clone()).or_default();
213
214        // History entries to add
215        let mut history_entries = Vec::new();
216
217        // Apply each line
218        for (line, account_type) in &line_data {
219            // Get or create account balance
220            let balance = company_balances
221                .entry(line.account_code.clone())
222                .or_insert_with(|| {
223                    AccountBalance::new(
224                        company_code.clone(),
225                        line.account_code.clone(),
226                        *account_type,
227                        self.currency.clone(),
228                        posting_date.year(),
229                        posting_date.month(),
230                    )
231                });
232
233            let previous_balance = balance.closing_balance;
234
235            // Apply debit or credit
236            if line.debit_amount > Decimal::ZERO {
237                balance.apply_debit(line.debit_amount);
238            }
239            if line.credit_amount > Decimal::ZERO {
240                balance.apply_credit(line.credit_amount);
241            }
242
243            let new_balance = balance.closing_balance;
244
245            // Record history if configured
246            if track_history {
247                let change = line.debit_amount - line.credit_amount;
248                history_entries.push(BalanceHistoryEntry {
249                    date: posting_date,
250                    entry_id: document_number.clone(),
251                    account_code: line.account_code.clone(),
252                    previous_balance,
253                    change,
254                    new_balance,
255                });
256            }
257        }
258
259        // Add history entries after releasing the balances borrow
260        if !history_entries.is_empty() {
261            let hist = self.history.entry(company_code.clone()).or_default();
262            hist.extend(history_entries);
263        }
264
265        // Update statistics
266        self.stats.entries_processed += 1;
267        self.stats.lines_processed += entry.lines.len() as u64;
268        self.stats.total_debits += entry.total_debit();
269        self.stats.total_credits += entry.total_credit();
270        self.stats.companies_tracked = self.balances.len();
271        self.stats.accounts_tracked = self.balances.values().map(|b| b.len()).sum();
272
273        // Validate balance sheet if configured
274        if self.config.validate_on_each_entry {
275            self.validate_balance_sheet(
276                entry.company_code(),
277                entry.posting_date(),
278                Some(&entry.document_number()),
279            )?;
280        }
281
282        Ok(())
283    }
284
285    /// Applies a batch of entries.
286    pub fn apply_entries(&mut self, entries: &[JournalEntry]) -> Vec<ValidationError> {
287        debug!(
288            entry_count = entries.len(),
289            companies_tracked = self.stats.companies_tracked,
290            accounts_tracked = self.stats.accounts_tracked,
291            "Applying entries to balance tracker"
292        );
293
294        let mut errors = Vec::new();
295
296        for entry in entries {
297            if let Err(error) = self.apply_entry(entry) {
298                errors.push(error);
299            }
300        }
301
302        errors
303    }
304
305    /// Determines account type from code prefix.
306    fn determine_account_type(&self, account_code: &str) -> AccountType {
307        // Check registered types first (exact match or prefix)
308        for (registered_code, account_type) in &self.account_types {
309            if account_code.starts_with(registered_code) {
310                return *account_type;
311            }
312        }
313
314        // Default logic based on first digit
315        match account_code.chars().next() {
316            Some('1') => AccountType::Asset,
317            Some('2') => AccountType::Liability,
318            Some('3') => AccountType::Equity,
319            Some('4') => AccountType::Revenue,
320            Some('5') | Some('6') | Some('7') | Some('8') => AccountType::Expense,
321            _ => AccountType::Asset, // Default fallback
322        }
323    }
324
325    /// Validates the balance sheet equation for a company.
326    pub fn validate_balance_sheet(
327        &mut self,
328        company_code: &str,
329        date: NaiveDate,
330        entry_id: Option<&str>,
331    ) -> Result<(), ValidationError> {
332        let Some(company_balances) = self.balances.get(company_code) else {
333            return Ok(()); // No balances to validate
334        };
335
336        let mut total_assets = Decimal::ZERO;
337        let mut total_liabilities = Decimal::ZERO;
338        let mut total_equity = Decimal::ZERO;
339        let mut total_revenue = Decimal::ZERO;
340        let mut total_expenses = Decimal::ZERO;
341
342        for (account_code, balance) in company_balances {
343            let account_type = self.determine_account_type(account_code);
344            match account_type {
345                AccountType::Asset => total_assets += balance.closing_balance,
346                AccountType::ContraAsset => total_assets -= balance.closing_balance.abs(),
347                AccountType::Liability => total_liabilities += balance.closing_balance.abs(),
348                AccountType::ContraLiability => total_liabilities -= balance.closing_balance.abs(),
349                AccountType::Equity => total_equity += balance.closing_balance.abs(),
350                AccountType::ContraEquity => total_equity -= balance.closing_balance.abs(),
351                AccountType::Revenue => total_revenue += balance.closing_balance.abs(),
352                AccountType::Expense => total_expenses += balance.closing_balance.abs(),
353            }
354        }
355
356        // Net income = Revenue - Expenses
357        let net_income = total_revenue - total_expenses;
358
359        // Balance sheet equation: Assets = Liabilities + Equity + Net Income
360        let left_side = total_assets;
361        let right_side = total_liabilities + total_equity + net_income;
362        let difference = (left_side - right_side).abs();
363
364        if difference > self.config.balance_tolerance {
365            let error = ValidationError {
366                date,
367                company_code: company_code.to_string(),
368                entry_id: entry_id.map(String::from),
369                error_type: ValidationErrorType::BalanceSheetImbalance,
370                message: format!(
371                    "Balance sheet imbalance: Assets ({}) != L + E + NI ({}), diff = {}",
372                    left_side, right_side, difference
373                ),
374                details: {
375                    let mut d = HashMap::new();
376                    d.insert("total_assets".to_string(), total_assets);
377                    d.insert("total_liabilities".to_string(), total_liabilities);
378                    d.insert("total_equity".to_string(), total_equity);
379                    d.insert("net_income".to_string(), net_income);
380                    d.insert("difference".to_string(), difference);
381                    d
382                },
383            };
384
385            self.stats.validation_errors += 1;
386
387            if self.config.fail_on_validation_error {
388                return Err(error);
389            }
390            self.validation_errors.push(error);
391        }
392
393        Ok(())
394    }
395
396    /// Gets the current snapshot for a company.
397    pub fn get_snapshot(
398        &self,
399        company_code: &str,
400        as_of_date: NaiveDate,
401    ) -> Option<BalanceSnapshot> {
402        use chrono::Datelike;
403        let currency = self.currency.clone();
404        self.balances.get(company_code).map(|balances| {
405            let mut snapshot = BalanceSnapshot::new(
406                format!("SNAP-{}-{}", company_code, as_of_date),
407                company_code.to_string(),
408                as_of_date,
409                as_of_date.year(),
410                as_of_date.month(),
411                currency,
412            );
413            for (account, balance) in balances {
414                snapshot.balances.insert(account.clone(), balance.clone());
415            }
416            snapshot.recalculate_totals();
417            snapshot
418        })
419    }
420
421    /// Gets snapshots for all companies.
422    pub fn get_all_snapshots(&self, as_of_date: NaiveDate) -> Vec<BalanceSnapshot> {
423        use chrono::Datelike;
424        self.balances
425            .iter()
426            .map(|(company_code, balances)| {
427                let mut snapshot = BalanceSnapshot::new(
428                    format!("SNAP-{}-{}", company_code, as_of_date),
429                    company_code.clone(),
430                    as_of_date,
431                    as_of_date.year(),
432                    as_of_date.month(),
433                    self.currency.clone(),
434                );
435                for (account, balance) in balances {
436                    snapshot.balances.insert(account.clone(), balance.clone());
437                }
438                snapshot.recalculate_totals();
439                snapshot
440            })
441            .collect()
442    }
443
444    /// Gets balance changes for a period.
445    pub fn get_balance_changes(
446        &self,
447        company_code: &str,
448        from_date: NaiveDate,
449        to_date: NaiveDate,
450    ) -> Vec<AccountPeriodActivity> {
451        let Some(history) = self.history.get(company_code) else {
452            return Vec::new();
453        };
454
455        let mut changes_by_account: HashMap<String, AccountPeriodActivity> = HashMap::new();
456
457        for entry in history
458            .iter()
459            .filter(|e| e.date >= from_date && e.date <= to_date)
460        {
461            let change = changes_by_account
462                .entry(entry.account_code.clone())
463                .or_insert_with(|| AccountPeriodActivity {
464                    account_code: entry.account_code.clone(),
465                    period_start: from_date,
466                    period_end: to_date,
467                    opening_balance: Decimal::ZERO,
468                    closing_balance: Decimal::ZERO,
469                    total_debits: Decimal::ZERO,
470                    total_credits: Decimal::ZERO,
471                    net_change: Decimal::ZERO,
472                    transaction_count: 0,
473                });
474
475            if entry.change > Decimal::ZERO {
476                change.total_debits += entry.change;
477            } else {
478                change.total_credits += entry.change.abs();
479            }
480            change.net_change += entry.change;
481            change.transaction_count += 1;
482        }
483
484        // Update opening/closing balances
485        if let Some(company_balances) = self.balances.get(company_code) {
486            for change in changes_by_account.values_mut() {
487                if let Some(balance) = company_balances.get(&change.account_code) {
488                    change.closing_balance = balance.closing_balance;
489                    change.opening_balance = change.closing_balance - change.net_change;
490                }
491            }
492        }
493
494        changes_by_account.into_values().collect()
495    }
496
497    /// Gets balance for a specific account.
498    pub fn get_account_balance(
499        &self,
500        company_code: &str,
501        account_code: &str,
502    ) -> Option<&AccountBalance> {
503        self.balances
504            .get(company_code)
505            .and_then(|b| b.get(account_code))
506    }
507
508    /// Gets all validation errors.
509    pub fn get_validation_errors(&self) -> &[ValidationError] {
510        &self.validation_errors
511    }
512
513    /// Clears validation errors.
514    pub fn clear_validation_errors(&mut self) {
515        self.validation_errors.clear();
516        self.stats.validation_errors = 0;
517    }
518
519    /// Gets tracker statistics.
520    pub fn get_statistics(&self) -> &TrackerStatistics {
521        &self.stats
522    }
523
524    /// Rolls forward balances to a new period.
525    pub fn roll_forward(&mut self, _new_period_start: NaiveDate) {
526        for company_balances in self.balances.values_mut() {
527            for balance in company_balances.values_mut() {
528                balance.roll_forward();
529            }
530        }
531    }
532
533    /// Exports balances to a simple format.
534    pub fn export_balances(&self, company_code: &str) -> Vec<(String, Decimal)> {
535        self.balances
536            .get(company_code)
537            .map(|balances| {
538                balances
539                    .iter()
540                    .map(|(code, balance)| (code.clone(), balance.closing_balance))
541                    .collect()
542            })
543            .unwrap_or_default()
544    }
545}
546
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use datasynth_core::models::{JournalEntry, JournalEntryLine};

    /// Builds a two-line entry "TEST001" dated 2024-01-15 that debits `account1`
    /// and credits `account2` by the same `amount`.
    fn create_test_entry(
        company: &str,
        account1: &str,
        account2: &str,
        amount: Decimal,
    ) -> JournalEntry {
        let posted_on = NaiveDate::from_ymd_opt(2024, 1, 15).unwrap();
        let mut journal = JournalEntry::new_simple(
            "TEST001".to_string(),
            company.to_string(),
            posted_on,
            "Test entry".to_string(),
        );

        let debit_side = JournalEntryLine {
            line_number: 1,
            gl_account: account1.to_string(),
            account_code: account1.to_string(),
            debit_amount: amount,
            ..Default::default()
        };
        let credit_side = JournalEntryLine {
            line_number: 2,
            gl_account: account2.to_string(),
            account_code: account2.to_string(),
            credit_amount: amount,
            ..Default::default()
        };

        journal.add_line(debit_side);
        journal.add_line(credit_side);
        journal
    }

    #[test]
    fn test_apply_balanced_entry() {
        let mut tracker = RunningBalanceTracker::with_defaults();
        tracker.register_account_type("1100", AccountType::Asset);
        tracker.register_account_type("4000", AccountType::Revenue);

        let outcome = tracker.apply_entry(&create_test_entry("1000", "1100", "4000", dec!(1000)));

        assert!(outcome.is_ok());
        assert_eq!(tracker.stats.entries_processed, 1);
        assert_eq!(tracker.stats.lines_processed, 2);
    }

    #[test]
    fn test_balance_accumulation() {
        let mut tracker = RunningBalanceTracker::with_defaults();
        tracker.config.validate_on_each_entry = false;

        for amount in [dec!(1000), dec!(500)] {
            tracker
                .apply_entry(&create_test_entry("1000", "1100", "4000", amount))
                .unwrap();
        }

        let cash = tracker.get_account_balance("1000", "1100").unwrap();
        assert_eq!(cash.closing_balance, dec!(1500));
    }

    #[test]
    fn test_get_snapshot() {
        let mut tracker = RunningBalanceTracker::with_defaults();
        tracker.config.validate_on_each_entry = false;

        tracker
            .apply_entry(&create_test_entry("1000", "1100", "2000", dec!(1000)))
            .unwrap();

        let month_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
        let snapshot = tracker.get_snapshot("1000", month_end).unwrap();

        assert_eq!(snapshot.balances.len(), 2);
    }

    #[test]
    fn test_determine_account_type_from_prefix() {
        let tracker = RunningBalanceTracker::with_defaults();

        let expectations = [
            ("1000", AccountType::Asset),
            ("2000", AccountType::Liability),
            ("3000", AccountType::Equity),
            ("4000", AccountType::Revenue),
            ("5000", AccountType::Expense),
        ];
        for (code, expected) in expectations {
            assert_eq!(tracker.determine_account_type(code), expected);
        }
    }
}