1use chrono::{Datelike, NaiveDate};
7use rust_decimal::Decimal;
8use rust_decimal_macros::dec;
9use std::collections::HashMap;
10use tracing::debug;
11
12use datasynth_core::models::balance::{
13 AccountBalance, AccountPeriodActivity, AccountType, BalanceSnapshot,
14};
15use datasynth_core::models::JournalEntry;
16use datasynth_core::FrameworkAccounts;
17
/// Configuration for [`RunningBalanceTracker`].
#[derive(Debug, Clone)]
pub struct BalanceTrackerConfig {
    /// When `true`, `apply_entry` runs `validate_balance_sheet` for the entry's
    /// company after the entry's lines have been posted.
    pub validate_on_each_entry: bool,
    /// When `true`, `apply_entry` records one `BalanceHistoryEntry` per journal
    /// line; `get_balance_changes` reads these records.
    pub track_history: bool,
    /// Largest absolute difference between assets and
    /// liabilities + equity + net income that `validate_balance_sheet` accepts
    /// without reporting an error.
    pub balance_tolerance: Decimal,
    /// When `true`, validation errors are returned as `Err` instead of being
    /// collected in the tracker's error list.
    pub fail_on_validation_error: bool,
}
30
31impl Default for BalanceTrackerConfig {
32 fn default() -> Self {
33 Self {
34 validate_on_each_entry: true,
35 track_history: true,
36 balance_tolerance: dec!(0.01),
37 fail_on_validation_error: false,
38 }
39 }
40}
41
/// Maintains running account balances per company as journal entries are
/// applied, optionally recording per-line history and checking the accounting
/// equation after each entry.
pub struct RunningBalanceTracker {
    /// Behaviour switches (validation, history, tolerance, error handling).
    config: BalanceTrackerConfig,
    /// Balances keyed by company code, then by account code.
    balances: HashMap<String, HashMap<String, AccountBalance>>,
    /// Explicitly registered account types, keyed by account-code prefix.
    account_types: HashMap<String, AccountType>,
    /// Framework-specific classifier used when no registered prefix matches.
    framework_accounts: FrameworkAccounts,
    /// Per-company balance movements, in the order entries were applied.
    history: HashMap<String, Vec<BalanceHistoryEntry>>,
    /// Validation errors collected while `fail_on_validation_error` is `false`.
    validation_errors: Vec<ValidationError>,
    /// Running counters exposed through `get_statistics`.
    stats: TrackerStatistics,
    /// Currency assigned to newly created account balances and to snapshots.
    currency: String,
}
60
/// One recorded movement of a single account caused by one journal line.
#[derive(Debug, Clone)]
pub struct BalanceHistoryEntry {
    /// Posting date of the journal entry.
    pub date: NaiveDate,
    /// Document number of the journal entry.
    pub entry_id: String,
    /// Account affected by the line.
    pub account_code: String,
    /// The account's closing balance before the line was applied.
    pub previous_balance: Decimal,
    /// The line's debit amount minus its credit amount.
    pub change: Decimal,
    /// The account's closing balance after the line was applied.
    pub new_balance: Decimal,
}
71
72#[derive(Debug, Clone)]
74pub struct ValidationError {
75 pub date: NaiveDate,
76 pub company_code: String,
77 pub entry_id: Option<String>,
78 pub error_type: ValidationErrorType,
79 pub message: String,
80 pub details: HashMap<String, Decimal>,
81}
82
/// Category of a [`ValidationError`].
///
/// Fieldless, so it is `Copy` and `Hash` and can be used directly as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ValidationErrorType {
    /// A journal entry whose total debits differ from its total credits.
    UnbalancedEntry,
    /// Assets differ from liabilities + equity + net income by more than the
    /// configured tolerance.
    BalanceSheetImbalance,
    /// Presumably an account balance with an unexpected sign; not produced by
    /// `RunningBalanceTracker` in this module.
    NegativeBalance,
    /// Presumably an account code that could not be resolved; not produced by
    /// `RunningBalanceTracker` in this module.
    UnknownAccount,
    /// Presumably an entry applied out of chronological order; not produced by
    /// `RunningBalanceTracker` in this module.
    OutOfOrder,
}
97
/// Running counters maintained by [`RunningBalanceTracker`].
#[derive(Debug, Clone, Default)]
pub struct TrackerStatistics {
    /// Number of journal entries applied.
    pub entries_processed: u64,
    /// Number of journal lines applied.
    pub lines_processed: u64,
    /// Sum of the total debits of all applied entries.
    pub total_debits: Decimal,
    /// Sum of the total credits of all applied entries.
    pub total_credits: Decimal,
    /// Number of companies the tracker holds balances for.
    pub companies_tracked: usize,
    /// Number of (company, account) balances currently held.
    pub accounts_tracked: usize,
    /// Running count of validation errors; reset to zero by
    /// `clear_validation_errors`.
    pub validation_errors: usize,
}
109
110impl RunningBalanceTracker {
111 pub fn new_with_currency_and_framework(
113 config: BalanceTrackerConfig,
114 currency: String,
115 framework: &str,
116 ) -> Self {
117 Self {
118 config,
119 balances: HashMap::new(),
120 account_types: HashMap::new(),
121 framework_accounts: FrameworkAccounts::for_framework(framework),
122 history: HashMap::new(),
123 validation_errors: Vec::new(),
124 stats: TrackerStatistics::default(),
125 currency,
126 }
127 }
128
129 pub fn new_with_currency(config: BalanceTrackerConfig, currency: String) -> Self {
131 Self::new_with_currency_and_framework(config, currency, "us_gaap")
132 }
133
134 pub fn new(config: BalanceTrackerConfig) -> Self {
136 Self::new_with_currency(config, "USD".to_string())
137 }
138
139 pub fn new_with_framework(config: BalanceTrackerConfig, framework: &str) -> Self {
141 Self::new_with_currency_and_framework(config, "USD".to_string(), framework)
142 }
143
144 pub fn with_defaults() -> Self {
146 Self::new(BalanceTrackerConfig::default())
147 }
148
149 pub fn register_account_type(&mut self, account_code: &str, account_type: AccountType) {
151 self.account_types
152 .insert(account_code.to_string(), account_type);
153 }
154
155 pub fn register_account_types(&mut self, types: &[(String, AccountType)]) {
157 for (code, account_type) in types {
158 self.account_types.insert(code.clone(), *account_type);
159 }
160 }
161
162 pub fn register_from_chart_prefixes(&mut self, prefixes: &[(&str, AccountType)]) {
164 for (prefix, account_type) in prefixes {
165 self.account_types.insert(prefix.to_string(), *account_type);
166 }
167 }
168
169 pub fn initialize_from_snapshot(&mut self, snapshot: &BalanceSnapshot) {
171 let company_balances = self
172 .balances
173 .entry(snapshot.company_code.clone())
174 .or_default();
175
176 for (account_code, balance) in &snapshot.balances {
177 company_balances.insert(account_code.clone(), balance.clone());
178 }
179
180 self.stats.companies_tracked = self.balances.len();
181 self.stats.accounts_tracked = self
182 .balances
183 .values()
184 .map(std::collections::HashMap::len)
185 .sum();
186 }
187
188 pub fn apply_entry(&mut self, entry: &JournalEntry) -> Result<(), ValidationError> {
190 if !entry.is_balanced() {
192 let error = ValidationError {
193 date: entry.posting_date(),
194 company_code: entry.company_code().to_string(),
195 entry_id: Some(entry.document_number().clone()),
196 error_type: ValidationErrorType::UnbalancedEntry,
197 message: format!(
198 "Entry {} is unbalanced: debits={}, credits={}",
199 entry.document_number(),
200 entry.total_debit(),
201 entry.total_credit()
202 ),
203 details: {
204 let mut d = HashMap::new();
205 d.insert("total_debit".to_string(), entry.total_debit());
206 d.insert("total_credit".to_string(), entry.total_credit());
207 d
208 },
209 };
210
211 if self.config.fail_on_validation_error {
212 return Err(error);
213 }
214 self.validation_errors.push(error);
215 }
216
217 let company_code = entry.company_code().to_string();
219 let document_number = entry.document_number().clone();
220 let posting_date = entry.posting_date();
221 let track_history = self.config.track_history;
222
223 let line_data: Vec<_> = entry
225 .lines
226 .iter()
227 .map(|line| {
228 let account_type = self.determine_account_type(&line.account_code);
229 (line.clone(), account_type)
230 })
231 .collect();
232
233 let company_balances = self.balances.entry(company_code.clone()).or_default();
235
236 let mut history_entries = Vec::new();
238
239 for (line, account_type) in &line_data {
241 let balance = company_balances
243 .entry(line.account_code.clone())
244 .or_insert_with(|| {
245 AccountBalance::new(
246 company_code.clone(),
247 line.account_code.clone(),
248 *account_type,
249 self.currency.clone(),
250 posting_date.year(),
251 posting_date.month(),
252 )
253 });
254
255 let previous_balance = balance.closing_balance;
256
257 if line.debit_amount > Decimal::ZERO {
259 balance.apply_debit(line.debit_amount);
260 }
261 if line.credit_amount > Decimal::ZERO {
262 balance.apply_credit(line.credit_amount);
263 }
264
265 let new_balance = balance.closing_balance;
266
267 if track_history {
269 let change = line.debit_amount - line.credit_amount;
270 history_entries.push(BalanceHistoryEntry {
271 date: posting_date,
272 entry_id: document_number.clone(),
273 account_code: line.account_code.clone(),
274 previous_balance,
275 change,
276 new_balance,
277 });
278 }
279 }
280
281 if !history_entries.is_empty() {
283 let hist = self.history.entry(company_code.clone()).or_default();
284 hist.extend(history_entries);
285 }
286
287 self.stats.entries_processed += 1;
289 self.stats.lines_processed += entry.lines.len() as u64;
290 self.stats.total_debits += entry.total_debit();
291 self.stats.total_credits += entry.total_credit();
292 self.stats.companies_tracked = self.balances.len();
293 self.stats.accounts_tracked = self
294 .balances
295 .values()
296 .map(std::collections::HashMap::len)
297 .sum();
298
299 if self.config.validate_on_each_entry {
301 self.validate_balance_sheet(
302 entry.company_code(),
303 entry.posting_date(),
304 Some(&entry.document_number()),
305 )?;
306 }
307
308 Ok(())
309 }
310
311 pub fn apply_entries(&mut self, entries: &[JournalEntry]) -> Vec<ValidationError> {
313 debug!(
314 entry_count = entries.len(),
315 companies_tracked = self.stats.companies_tracked,
316 accounts_tracked = self.stats.accounts_tracked,
317 "Applying entries to balance tracker"
318 );
319
320 let mut errors = Vec::new();
321
322 for entry in entries {
323 if let Err(error) = self.apply_entry(entry) {
324 errors.push(error);
325 }
326 }
327
328 errors
329 }
330
331 fn determine_account_type(&self, account_code: &str) -> AccountType {
336 for (registered_code, account_type) in &self.account_types {
338 if account_code.starts_with(registered_code) {
339 return *account_type;
340 }
341 }
342
343 self.framework_accounts.classify_account_type(account_code)
345 }
346
347 pub fn validate_balance_sheet(
349 &mut self,
350 company_code: &str,
351 date: NaiveDate,
352 entry_id: Option<&str>,
353 ) -> Result<(), ValidationError> {
354 let Some(company_balances) = self.balances.get(company_code) else {
355 return Ok(()); };
357
358 let mut total_assets = Decimal::ZERO;
359 let mut total_liabilities = Decimal::ZERO;
360 let mut total_equity = Decimal::ZERO;
361 let mut total_revenue = Decimal::ZERO;
362 let mut total_expenses = Decimal::ZERO;
363
364 for (account_code, balance) in company_balances {
365 let account_type = self.determine_account_type(account_code);
366 match account_type {
367 AccountType::Asset => total_assets += balance.closing_balance,
368 AccountType::ContraAsset => total_assets -= balance.closing_balance.abs(),
369 AccountType::Liability => total_liabilities += balance.closing_balance.abs(),
370 AccountType::ContraLiability => total_liabilities -= balance.closing_balance.abs(),
371 AccountType::Equity => total_equity += balance.closing_balance.abs(),
372 AccountType::ContraEquity => total_equity -= balance.closing_balance.abs(),
373 AccountType::Revenue => total_revenue += balance.closing_balance.abs(),
374 AccountType::Expense => total_expenses += balance.closing_balance.abs(),
375 }
376 }
377
378 let net_income = total_revenue - total_expenses;
380
381 let left_side = total_assets;
383 let right_side = total_liabilities + total_equity + net_income;
384 let difference = (left_side - right_side).abs();
385
386 if difference > self.config.balance_tolerance {
387 let error = ValidationError {
388 date,
389 company_code: company_code.to_string(),
390 entry_id: entry_id.map(String::from),
391 error_type: ValidationErrorType::BalanceSheetImbalance,
392 message: format!(
393 "Balance sheet imbalance: Assets ({left_side}) != L + E + NI ({right_side}), diff = {difference}"
394 ),
395 details: {
396 let mut d = HashMap::new();
397 d.insert("total_assets".to_string(), total_assets);
398 d.insert("total_liabilities".to_string(), total_liabilities);
399 d.insert("total_equity".to_string(), total_equity);
400 d.insert("net_income".to_string(), net_income);
401 d.insert("difference".to_string(), difference);
402 d
403 },
404 };
405
406 self.stats.validation_errors += 1;
407
408 if self.config.fail_on_validation_error {
409 return Err(error);
410 }
411 self.validation_errors.push(error);
412 }
413
414 Ok(())
415 }
416
417 pub fn get_snapshot(
419 &self,
420 company_code: &str,
421 as_of_date: NaiveDate,
422 ) -> Option<BalanceSnapshot> {
423 use chrono::Datelike;
424 let currency = self.currency.clone();
425 self.balances.get(company_code).map(|balances| {
426 let mut snapshot = BalanceSnapshot::new(
427 format!("SNAP-{company_code}-{as_of_date}"),
428 company_code.to_string(),
429 as_of_date,
430 as_of_date.year(),
431 as_of_date.month(),
432 currency,
433 );
434 for (account, balance) in balances {
435 snapshot.balances.insert(account.clone(), balance.clone());
436 }
437 snapshot.recalculate_totals();
438 snapshot
439 })
440 }
441
442 pub fn get_all_snapshots(&self, as_of_date: NaiveDate) -> Vec<BalanceSnapshot> {
444 use chrono::Datelike;
445 self.balances
446 .iter()
447 .map(|(company_code, balances)| {
448 let mut snapshot = BalanceSnapshot::new(
449 format!("SNAP-{company_code}-{as_of_date}"),
450 company_code.clone(),
451 as_of_date,
452 as_of_date.year(),
453 as_of_date.month(),
454 self.currency.clone(),
455 );
456 for (account, balance) in balances {
457 snapshot.balances.insert(account.clone(), balance.clone());
458 }
459 snapshot.recalculate_totals();
460 snapshot
461 })
462 .collect()
463 }
464
465 pub fn get_balance_changes(
467 &self,
468 company_code: &str,
469 from_date: NaiveDate,
470 to_date: NaiveDate,
471 ) -> Vec<AccountPeriodActivity> {
472 let Some(history) = self.history.get(company_code) else {
473 return Vec::new();
474 };
475
476 let mut changes_by_account: HashMap<String, AccountPeriodActivity> = HashMap::new();
477
478 for entry in history
479 .iter()
480 .filter(|e| e.date >= from_date && e.date <= to_date)
481 {
482 let change = changes_by_account
483 .entry(entry.account_code.clone())
484 .or_insert_with(|| AccountPeriodActivity {
485 account_code: entry.account_code.clone(),
486 period_start: from_date,
487 period_end: to_date,
488 opening_balance: Decimal::ZERO,
489 closing_balance: Decimal::ZERO,
490 total_debits: Decimal::ZERO,
491 total_credits: Decimal::ZERO,
492 net_change: Decimal::ZERO,
493 transaction_count: 0,
494 });
495
496 if entry.change > Decimal::ZERO {
497 change.total_debits += entry.change;
498 } else {
499 change.total_credits += entry.change.abs();
500 }
501 change.net_change += entry.change;
502 change.transaction_count += 1;
503 }
504
505 if let Some(company_balances) = self.balances.get(company_code) {
507 for change in changes_by_account.values_mut() {
508 if let Some(balance) = company_balances.get(&change.account_code) {
509 change.closing_balance = balance.closing_balance;
510 change.opening_balance = change.closing_balance - change.net_change;
511 }
512 }
513 }
514
515 changes_by_account.into_values().collect()
516 }
517
518 pub fn get_account_balance(
520 &self,
521 company_code: &str,
522 account_code: &str,
523 ) -> Option<&AccountBalance> {
524 self.balances
525 .get(company_code)
526 .and_then(|b| b.get(account_code))
527 }
528
529 pub fn get_validation_errors(&self) -> &[ValidationError] {
531 &self.validation_errors
532 }
533
534 pub fn clear_validation_errors(&mut self) {
536 self.validation_errors.clear();
537 self.stats.validation_errors = 0;
538 }
539
540 pub fn get_statistics(&self) -> &TrackerStatistics {
542 &self.stats
543 }
544
545 pub fn roll_forward(&mut self, _new_period_start: NaiveDate) {
547 for company_balances in self.balances.values_mut() {
548 for balance in company_balances.values_mut() {
549 balance.roll_forward();
550 }
551 }
552 }
553
554 pub fn export_balances(&self, company_code: &str) -> Vec<(String, Decimal)> {
556 self.balances
557 .get(company_code)
558 .map(|balances| {
559 balances
560 .iter()
561 .map(|(code, balance)| (code.clone(), balance.closing_balance))
562 .collect()
563 })
564 .unwrap_or_default()
565 }
566}
567
#[cfg(test)]
// Tests construct fixed, known-valid dates and look up balances they just posted.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use datasynth_core::models::{JournalEntry, JournalEntryLine};

    /// Builds a two-line entry debiting `account1` and crediting `account2` by `amount`.
    fn create_test_entry(
        company: &str,
        account1: &str,
        account2: &str,
        amount: Decimal,
    ) -> JournalEntry {
        let posting_date = NaiveDate::from_ymd_opt(2024, 1, 15).unwrap();
        let mut journal = JournalEntry::new_simple(
            "TEST001".to_string(),
            company.to_string(),
            posting_date,
            "Test entry".to_string(),
        );

        journal.add_line(JournalEntryLine {
            line_number: 1,
            gl_account: account1.to_string(),
            account_code: account1.to_string(),
            debit_amount: amount,
            ..Default::default()
        });
        journal.add_line(JournalEntryLine {
            line_number: 2,
            gl_account: account2.to_string(),
            account_code: account2.to_string(),
            credit_amount: amount,
            ..Default::default()
        });

        journal
    }

    /// A default tracker that skips the per-entry balance-sheet check.
    fn tracker_without_validation() -> RunningBalanceTracker {
        RunningBalanceTracker::new(BalanceTrackerConfig {
            validate_on_each_entry: false,
            ..BalanceTrackerConfig::default()
        })
    }

    #[test]
    fn test_apply_balanced_entry() {
        let mut tracker = RunningBalanceTracker::with_defaults();
        tracker.register_account_type("1100", AccountType::Asset);
        tracker.register_account_type("4000", AccountType::Revenue);

        let outcome = tracker.apply_entry(&create_test_entry("1000", "1100", "4000", dec!(1000)));
        assert!(outcome.is_ok());

        let stats = tracker.get_statistics();
        assert_eq!(stats.entries_processed, 1);
        assert_eq!(stats.lines_processed, 2);
    }

    #[test]
    fn test_balance_accumulation() {
        let mut tracker = tracker_without_validation();

        for amount in [dec!(1000), dec!(500)] {
            tracker
                .apply_entry(&create_test_entry("1000", "1100", "4000", amount))
                .unwrap();
        }

        let cash = tracker.get_account_balance("1000", "1100").unwrap();
        assert_eq!(cash.closing_balance, dec!(1500));
    }

    #[test]
    fn test_get_snapshot() {
        let mut tracker = tracker_without_validation();
        tracker
            .apply_entry(&create_test_entry("1000", "1100", "2000", dec!(1000)))
            .unwrap();

        let as_of = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
        let snapshot = tracker.get_snapshot("1000", as_of).unwrap();

        assert_eq!(snapshot.balances.len(), 2);
    }

    #[test]
    fn test_determine_account_type_from_prefix() {
        let tracker = RunningBalanceTracker::with_defaults();

        let expectations = [
            ("1000", AccountType::Asset),
            ("2000", AccountType::Liability),
            ("3000", AccountType::Equity),
            ("4000", AccountType::Revenue),
            ("5000", AccountType::Expense),
        ];
        for (code, expected) in expectations {
            assert_eq!(tracker.determine_account_type(code), expected);
        }
    }

    #[test]
    fn test_determine_account_type_french_gaap() {
        let tracker = RunningBalanceTracker::new_with_framework(
            BalanceTrackerConfig::default(),
            "french_gaap",
        );

        let expectations = [
            ("210000", AccountType::Asset),
            ("101000", AccountType::Equity),
            ("401000", AccountType::Liability),
            ("603000", AccountType::Expense),
            ("701000", AccountType::Revenue),
        ];
        for (code, expected) in expectations {
            assert_eq!(tracker.determine_account_type(code), expected);
        }
    }

    #[test]
    fn test_determine_account_type_german_gaap() {
        let tracker = RunningBalanceTracker::new_with_framework(
            BalanceTrackerConfig::default(),
            "german_gaap",
        );

        let expectations = [
            ("0200", AccountType::Asset),
            ("2000", AccountType::Equity),
            ("3300", AccountType::Liability),
            ("4000", AccountType::Revenue),
            ("5000", AccountType::Expense),
        ];
        for (code, expected) in expectations {
            assert_eq!(tracker.determine_account_type(code), expected);
        }
    }
}