1use chrono::{Datelike, NaiveDate};
7use rust_decimal::Decimal;
8use rust_decimal_macros::dec;
9use std::collections::HashMap;
10use tracing::debug;
11
12use datasynth_core::models::balance::{
13 AccountBalance, AccountPeriodActivity, AccountType, BalanceSnapshot,
14};
15use datasynth_core::models::JournalEntry;
16use datasynth_core::FrameworkAccounts;
17
18#[derive(Debug, Clone)]
20pub struct BalanceTrackerConfig {
21 pub validate_on_each_entry: bool,
23 pub track_history: bool,
25 pub balance_tolerance: Decimal,
27 pub fail_on_validation_error: bool,
29}
30
31impl Default for BalanceTrackerConfig {
32 fn default() -> Self {
33 Self {
34 validate_on_each_entry: true,
35 track_history: true,
36 balance_tolerance: dec!(0.01),
37 fail_on_validation_error: false,
38 }
39 }
40}
41
42pub struct RunningBalanceTracker {
44 config: BalanceTrackerConfig,
45 balances: HashMap<String, HashMap<String, AccountBalance>>,
47 account_types: HashMap<String, AccountType>,
49 framework_accounts: FrameworkAccounts,
51 history: HashMap<String, Vec<BalanceHistoryEntry>>,
53 validation_errors: Vec<ValidationError>,
55 stats: TrackerStatistics,
57 currency: String,
59}
60
61#[derive(Debug, Clone)]
63pub struct BalanceHistoryEntry {
64 pub date: NaiveDate,
65 pub entry_id: String,
66 pub account_code: String,
67 pub previous_balance: Decimal,
68 pub change: Decimal,
69 pub new_balance: Decimal,
70}
71
72#[derive(Debug, Clone)]
74pub struct ValidationError {
75 pub date: NaiveDate,
76 pub company_code: String,
77 pub entry_id: Option<String>,
78 pub error_type: ValidationErrorType,
79 pub message: String,
80 pub details: HashMap<String, Decimal>,
81}
82
/// Category of a [`ValidationError`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ValidationErrorType {
    /// A journal entry whose `is_balanced()` check failed.
    UnbalancedEntry,
    /// The accounting-equation check for a company exceeded the configured
    /// tolerance.
    BalanceSheetImbalance,
    // NOTE(review): the three variants below are not constructed anywhere in
    // the visible part of this file; their meaning is inferred from the names
    // only - confirm against their producers.
    /// Presumably: an account ended up with a negative balance.
    NegativeBalance,
    /// Presumably: an entry referenced an account that is not known.
    UnknownAccount,
    /// Presumably: an entry was applied out of chronological order.
    OutOfOrder,
}
97
98#[derive(Debug, Clone, Default)]
100pub struct TrackerStatistics {
101 pub entries_processed: u64,
102 pub lines_processed: u64,
103 pub total_debits: Decimal,
104 pub total_credits: Decimal,
105 pub companies_tracked: usize,
106 pub accounts_tracked: usize,
107 pub validation_errors: usize,
108}
109
110impl RunningBalanceTracker {
111 pub fn new_with_currency_and_framework(
113 config: BalanceTrackerConfig,
114 currency: String,
115 framework: &str,
116 ) -> Self {
117 Self {
118 config,
119 balances: HashMap::new(),
120 account_types: HashMap::new(),
121 framework_accounts: FrameworkAccounts::for_framework(framework),
122 history: HashMap::new(),
123 validation_errors: Vec::new(),
124 stats: TrackerStatistics::default(),
125 currency,
126 }
127 }
128
129 pub fn new_with_currency(config: BalanceTrackerConfig, currency: String) -> Self {
131 Self::new_with_currency_and_framework(config, currency, "us_gaap")
132 }
133
134 pub fn new(config: BalanceTrackerConfig) -> Self {
136 Self::new_with_currency(config, "USD".to_string())
137 }
138
139 pub fn new_with_framework(config: BalanceTrackerConfig, framework: &str) -> Self {
141 Self::new_with_currency_and_framework(config, "USD".to_string(), framework)
142 }
143
144 pub fn with_defaults() -> Self {
146 Self::new(BalanceTrackerConfig::default())
147 }
148
149 pub fn register_account_type(&mut self, account_code: &str, account_type: AccountType) {
151 self.account_types
152 .insert(account_code.to_string(), account_type);
153 }
154
155 pub fn register_account_types(&mut self, types: &[(String, AccountType)]) {
157 for (code, account_type) in types {
158 self.account_types.insert(code.clone(), *account_type);
159 }
160 }
161
162 pub fn register_from_chart_prefixes(&mut self, prefixes: &[(&str, AccountType)]) {
164 for (prefix, account_type) in prefixes {
165 self.account_types.insert(prefix.to_string(), *account_type);
166 }
167 }
168
169 pub fn initialize_from_snapshot(&mut self, snapshot: &BalanceSnapshot) {
171 let company_balances = self
172 .balances
173 .entry(snapshot.company_code.clone())
174 .or_default();
175
176 for (account_code, balance) in &snapshot.balances {
177 company_balances.insert(account_code.clone(), balance.clone());
178 }
179
180 self.stats.companies_tracked = self.balances.len();
181 self.stats.accounts_tracked = self.balances.values().map(|b| b.len()).sum();
182 }
183
184 pub fn apply_entry(&mut self, entry: &JournalEntry) -> Result<(), ValidationError> {
186 if !entry.is_balanced() {
188 let error = ValidationError {
189 date: entry.posting_date(),
190 company_code: entry.company_code().to_string(),
191 entry_id: Some(entry.document_number().clone()),
192 error_type: ValidationErrorType::UnbalancedEntry,
193 message: format!(
194 "Entry {} is unbalanced: debits={}, credits={}",
195 entry.document_number(),
196 entry.total_debit(),
197 entry.total_credit()
198 ),
199 details: {
200 let mut d = HashMap::new();
201 d.insert("total_debit".to_string(), entry.total_debit());
202 d.insert("total_credit".to_string(), entry.total_credit());
203 d
204 },
205 };
206
207 if self.config.fail_on_validation_error {
208 return Err(error);
209 }
210 self.validation_errors.push(error);
211 }
212
213 let company_code = entry.company_code().to_string();
215 let document_number = entry.document_number().clone();
216 let posting_date = entry.posting_date();
217 let track_history = self.config.track_history;
218
219 let line_data: Vec<_> = entry
221 .lines
222 .iter()
223 .map(|line| {
224 let account_type = self.determine_account_type(&line.account_code);
225 (line.clone(), account_type)
226 })
227 .collect();
228
229 let company_balances = self.balances.entry(company_code.clone()).or_default();
231
232 let mut history_entries = Vec::new();
234
235 for (line, account_type) in &line_data {
237 let balance = company_balances
239 .entry(line.account_code.clone())
240 .or_insert_with(|| {
241 AccountBalance::new(
242 company_code.clone(),
243 line.account_code.clone(),
244 *account_type,
245 self.currency.clone(),
246 posting_date.year(),
247 posting_date.month(),
248 )
249 });
250
251 let previous_balance = balance.closing_balance;
252
253 if line.debit_amount > Decimal::ZERO {
255 balance.apply_debit(line.debit_amount);
256 }
257 if line.credit_amount > Decimal::ZERO {
258 balance.apply_credit(line.credit_amount);
259 }
260
261 let new_balance = balance.closing_balance;
262
263 if track_history {
265 let change = line.debit_amount - line.credit_amount;
266 history_entries.push(BalanceHistoryEntry {
267 date: posting_date,
268 entry_id: document_number.clone(),
269 account_code: line.account_code.clone(),
270 previous_balance,
271 change,
272 new_balance,
273 });
274 }
275 }
276
277 if !history_entries.is_empty() {
279 let hist = self.history.entry(company_code.clone()).or_default();
280 hist.extend(history_entries);
281 }
282
283 self.stats.entries_processed += 1;
285 self.stats.lines_processed += entry.lines.len() as u64;
286 self.stats.total_debits += entry.total_debit();
287 self.stats.total_credits += entry.total_credit();
288 self.stats.companies_tracked = self.balances.len();
289 self.stats.accounts_tracked = self.balances.values().map(|b| b.len()).sum();
290
291 if self.config.validate_on_each_entry {
293 self.validate_balance_sheet(
294 entry.company_code(),
295 entry.posting_date(),
296 Some(&entry.document_number()),
297 )?;
298 }
299
300 Ok(())
301 }
302
303 pub fn apply_entries(&mut self, entries: &[JournalEntry]) -> Vec<ValidationError> {
305 debug!(
306 entry_count = entries.len(),
307 companies_tracked = self.stats.companies_tracked,
308 accounts_tracked = self.stats.accounts_tracked,
309 "Applying entries to balance tracker"
310 );
311
312 let mut errors = Vec::new();
313
314 for entry in entries {
315 if let Err(error) = self.apply_entry(entry) {
316 errors.push(error);
317 }
318 }
319
320 errors
321 }
322
323 fn determine_account_type(&self, account_code: &str) -> AccountType {
328 for (registered_code, account_type) in &self.account_types {
330 if account_code.starts_with(registered_code) {
331 return *account_type;
332 }
333 }
334
335 self.framework_accounts.classify_account_type(account_code)
337 }
338
339 pub fn validate_balance_sheet(
341 &mut self,
342 company_code: &str,
343 date: NaiveDate,
344 entry_id: Option<&str>,
345 ) -> Result<(), ValidationError> {
346 let Some(company_balances) = self.balances.get(company_code) else {
347 return Ok(()); };
349
350 let mut total_assets = Decimal::ZERO;
351 let mut total_liabilities = Decimal::ZERO;
352 let mut total_equity = Decimal::ZERO;
353 let mut total_revenue = Decimal::ZERO;
354 let mut total_expenses = Decimal::ZERO;
355
356 for (account_code, balance) in company_balances {
357 let account_type = self.determine_account_type(account_code);
358 match account_type {
359 AccountType::Asset => total_assets += balance.closing_balance,
360 AccountType::ContraAsset => total_assets -= balance.closing_balance.abs(),
361 AccountType::Liability => total_liabilities += balance.closing_balance.abs(),
362 AccountType::ContraLiability => total_liabilities -= balance.closing_balance.abs(),
363 AccountType::Equity => total_equity += balance.closing_balance.abs(),
364 AccountType::ContraEquity => total_equity -= balance.closing_balance.abs(),
365 AccountType::Revenue => total_revenue += balance.closing_balance.abs(),
366 AccountType::Expense => total_expenses += balance.closing_balance.abs(),
367 }
368 }
369
370 let net_income = total_revenue - total_expenses;
372
373 let left_side = total_assets;
375 let right_side = total_liabilities + total_equity + net_income;
376 let difference = (left_side - right_side).abs();
377
378 if difference > self.config.balance_tolerance {
379 let error = ValidationError {
380 date,
381 company_code: company_code.to_string(),
382 entry_id: entry_id.map(String::from),
383 error_type: ValidationErrorType::BalanceSheetImbalance,
384 message: format!(
385 "Balance sheet imbalance: Assets ({}) != L + E + NI ({}), diff = {}",
386 left_side, right_side, difference
387 ),
388 details: {
389 let mut d = HashMap::new();
390 d.insert("total_assets".to_string(), total_assets);
391 d.insert("total_liabilities".to_string(), total_liabilities);
392 d.insert("total_equity".to_string(), total_equity);
393 d.insert("net_income".to_string(), net_income);
394 d.insert("difference".to_string(), difference);
395 d
396 },
397 };
398
399 self.stats.validation_errors += 1;
400
401 if self.config.fail_on_validation_error {
402 return Err(error);
403 }
404 self.validation_errors.push(error);
405 }
406
407 Ok(())
408 }
409
410 pub fn get_snapshot(
412 &self,
413 company_code: &str,
414 as_of_date: NaiveDate,
415 ) -> Option<BalanceSnapshot> {
416 use chrono::Datelike;
417 let currency = self.currency.clone();
418 self.balances.get(company_code).map(|balances| {
419 let mut snapshot = BalanceSnapshot::new(
420 format!("SNAP-{}-{}", company_code, as_of_date),
421 company_code.to_string(),
422 as_of_date,
423 as_of_date.year(),
424 as_of_date.month(),
425 currency,
426 );
427 for (account, balance) in balances {
428 snapshot.balances.insert(account.clone(), balance.clone());
429 }
430 snapshot.recalculate_totals();
431 snapshot
432 })
433 }
434
435 pub fn get_all_snapshots(&self, as_of_date: NaiveDate) -> Vec<BalanceSnapshot> {
437 use chrono::Datelike;
438 self.balances
439 .iter()
440 .map(|(company_code, balances)| {
441 let mut snapshot = BalanceSnapshot::new(
442 format!("SNAP-{}-{}", company_code, as_of_date),
443 company_code.clone(),
444 as_of_date,
445 as_of_date.year(),
446 as_of_date.month(),
447 self.currency.clone(),
448 );
449 for (account, balance) in balances {
450 snapshot.balances.insert(account.clone(), balance.clone());
451 }
452 snapshot.recalculate_totals();
453 snapshot
454 })
455 .collect()
456 }
457
458 pub fn get_balance_changes(
460 &self,
461 company_code: &str,
462 from_date: NaiveDate,
463 to_date: NaiveDate,
464 ) -> Vec<AccountPeriodActivity> {
465 let Some(history) = self.history.get(company_code) else {
466 return Vec::new();
467 };
468
469 let mut changes_by_account: HashMap<String, AccountPeriodActivity> = HashMap::new();
470
471 for entry in history
472 .iter()
473 .filter(|e| e.date >= from_date && e.date <= to_date)
474 {
475 let change = changes_by_account
476 .entry(entry.account_code.clone())
477 .or_insert_with(|| AccountPeriodActivity {
478 account_code: entry.account_code.clone(),
479 period_start: from_date,
480 period_end: to_date,
481 opening_balance: Decimal::ZERO,
482 closing_balance: Decimal::ZERO,
483 total_debits: Decimal::ZERO,
484 total_credits: Decimal::ZERO,
485 net_change: Decimal::ZERO,
486 transaction_count: 0,
487 });
488
489 if entry.change > Decimal::ZERO {
490 change.total_debits += entry.change;
491 } else {
492 change.total_credits += entry.change.abs();
493 }
494 change.net_change += entry.change;
495 change.transaction_count += 1;
496 }
497
498 if let Some(company_balances) = self.balances.get(company_code) {
500 for change in changes_by_account.values_mut() {
501 if let Some(balance) = company_balances.get(&change.account_code) {
502 change.closing_balance = balance.closing_balance;
503 change.opening_balance = change.closing_balance - change.net_change;
504 }
505 }
506 }
507
508 changes_by_account.into_values().collect()
509 }
510
511 pub fn get_account_balance(
513 &self,
514 company_code: &str,
515 account_code: &str,
516 ) -> Option<&AccountBalance> {
517 self.balances
518 .get(company_code)
519 .and_then(|b| b.get(account_code))
520 }
521
522 pub fn get_validation_errors(&self) -> &[ValidationError] {
524 &self.validation_errors
525 }
526
527 pub fn clear_validation_errors(&mut self) {
529 self.validation_errors.clear();
530 self.stats.validation_errors = 0;
531 }
532
533 pub fn get_statistics(&self) -> &TrackerStatistics {
535 &self.stats
536 }
537
538 pub fn roll_forward(&mut self, _new_period_start: NaiveDate) {
540 for company_balances in self.balances.values_mut() {
541 for balance in company_balances.values_mut() {
542 balance.roll_forward();
543 }
544 }
545 }
546
547 pub fn export_balances(&self, company_code: &str) -> Vec<(String, Decimal)> {
549 self.balances
550 .get(company_code)
551 .map(|balances| {
552 balances
553 .iter()
554 .map(|(code, balance)| (code.clone(), balance.closing_balance))
555 .collect()
556 })
557 .unwrap_or_default()
558 }
559}
560
561#[cfg(test)]
562#[allow(clippy::unwrap_used)]
563mod tests {
564 use super::*;
565 use datasynth_core::models::{JournalEntry, JournalEntryLine};
566
567 fn create_test_entry(
568 company: &str,
569 account1: &str,
570 account2: &str,
571 amount: Decimal,
572 ) -> JournalEntry {
573 let mut entry = JournalEntry::new_simple(
574 "TEST001".to_string(),
575 company.to_string(),
576 NaiveDate::from_ymd_opt(2024, 1, 15).unwrap(),
577 "Test entry".to_string(),
578 );
579
580 entry.add_line(JournalEntryLine {
581 line_number: 1,
582 gl_account: account1.to_string(),
583 account_code: account1.to_string(),
584 debit_amount: amount,
585 ..Default::default()
586 });
587
588 entry.add_line(JournalEntryLine {
589 line_number: 2,
590 gl_account: account2.to_string(),
591 account_code: account2.to_string(),
592 credit_amount: amount,
593 ..Default::default()
594 });
595
596 entry
597 }
598
// Applying one balanced entry succeeds and updates the entry/line counters.
#[test]
fn test_apply_balanced_entry() {
    let mut tracker = RunningBalanceTracker::with_defaults();
    for (code, account_type) in [("1100", AccountType::Asset), ("4000", AccountType::Revenue)] {
        tracker.register_account_type(code, account_type);
    }

    let entry = create_test_entry("1000", "1100", "4000", dec!(1000));
    let outcome = tracker.apply_entry(&entry);

    assert!(outcome.is_ok());
    let stats = &tracker.stats;
    assert_eq!(stats.entries_processed, 1);
    assert_eq!(stats.lines_processed, 2);
}
612
// Two debits to the same account add up in its closing balance.
#[test]
fn test_balance_accumulation() {
    let mut tracker = RunningBalanceTracker::with_defaults();
    tracker.config.validate_on_each_entry = false;

    for amount in [dec!(1000), dec!(500)] {
        let entry = create_test_entry("1000", "1100", "4000", amount);
        tracker.apply_entry(&entry).unwrap();
    }

    let closing = tracker
        .get_account_balance("1000", "1100")
        .unwrap()
        .closing_balance;
    assert_eq!(closing, dec!(1500));
}
627
// A snapshot contains one balance per account touched by the applied entry.
#[test]
fn test_get_snapshot() {
    let mut tracker = RunningBalanceTracker::with_defaults();
    tracker.config.validate_on_each_entry = false;

    tracker
        .apply_entry(&create_test_entry("1000", "1100", "2000", dec!(1000)))
        .unwrap();

    let as_of = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
    let snapshot = tracker.get_snapshot("1000", as_of).unwrap();

    assert_eq!(snapshot.balances.len(), 2);
}
642
// With no registered types, the default framework classifies by code.
#[test]
fn test_determine_account_type_from_prefix() {
    let tracker = RunningBalanceTracker::with_defaults();

    let expectations = [
        ("1000", AccountType::Asset),
        ("2000", AccountType::Liability),
        ("3000", AccountType::Equity),
        ("4000", AccountType::Revenue),
        ("5000", AccountType::Expense),
    ];
    for (code, expected) in expectations {
        assert_eq!(tracker.determine_account_type(code), expected);
    }
}
656
// Classification of sample account codes under the "french_gaap" framework.
#[test]
fn test_determine_account_type_french_gaap() {
    let config = BalanceTrackerConfig::default();
    let tracker = RunningBalanceTracker::new_with_framework(config, "french_gaap");

    let expectations = [
        ("210000", AccountType::Asset),
        ("101000", AccountType::Equity),
        ("401000", AccountType::Liability),
        ("603000", AccountType::Expense),
        ("701000", AccountType::Revenue),
    ];
    for (code, expected) in expectations {
        assert_eq!(tracker.determine_account_type(code), expected);
    }
}
687
// Classification of sample account codes under the "german_gaap" framework.
#[test]
fn test_determine_account_type_german_gaap() {
    let config = BalanceTrackerConfig::default();
    let tracker = RunningBalanceTracker::new_with_framework(config, "german_gaap");

    let expectations = [
        ("0200", AccountType::Asset),
        ("2000", AccountType::Equity),
        ("3300", AccountType::Liability),
        ("4000", AccountType::Revenue),
        ("5000", AccountType::Expense),
    ];
    for (code, expected) in expectations {
        assert_eq!(tracker.determine_account_type(code), expected);
    }
}
709}