Skip to main content

datasynth_generators/balance/
trial_balance_generator.rs

//! Trial balance generator.
//!
//! Generates trial balances at period end from running balance snapshots,
//! with support for:
//! - Unadjusted, adjusted, and post-closing trial balances
//! - Category summaries and subtotals
//! - Comparative trial balances across periods
//! - Consolidated trial balances across companies
9
10use chrono::NaiveDate;
11use rust_decimal::Decimal;
12use rust_decimal_macros::dec;
13use std::collections::HashMap;
14use tracing::debug;
15
16use datasynth_core::models::balance::{
17    AccountBalance, AccountCategory, AccountType, BalanceSnapshot, CategorySummary,
18    ComparativeTrialBalance, TrialBalance, TrialBalanceLine, TrialBalanceStatus, TrialBalanceType,
19};
20use datasynth_core::models::ChartOfAccounts;
21use datasynth_core::FrameworkAccounts;
22
23use super::RunningBalanceTracker;
24
25/// Configuration for trial balance generation.
26#[derive(Debug, Clone)]
27pub struct TrialBalanceConfig {
28    /// Include zero balance accounts.
29    pub include_zero_balances: bool,
30    /// Group accounts by category.
31    pub group_by_category: bool,
32    /// Generate category subtotals.
33    pub generate_subtotals: bool,
34    /// Sort accounts by code.
35    pub sort_by_account_code: bool,
36    /// Trial balance type to generate.
37    pub trial_balance_type: TrialBalanceType,
38}
39
40impl Default for TrialBalanceConfig {
41    fn default() -> Self {
42        Self {
43            include_zero_balances: false,
44            group_by_category: true,
45            generate_subtotals: true,
46            sort_by_account_code: true,
47            trial_balance_type: TrialBalanceType::Unadjusted,
48        }
49    }
50}
51
52/// Generator for trial balance reports.
53pub struct TrialBalanceGenerator {
54    config: TrialBalanceConfig,
55    /// Account category mappings.
56    category_mappings: HashMap<String, AccountCategory>,
57    /// Account descriptions.
58    account_descriptions: HashMap<String, String>,
59    /// Framework-aware account classification.
60    framework_accounts: FrameworkAccounts,
61}
62
63impl TrialBalanceGenerator {
64    /// Creates a new trial balance generator for a specific accounting framework.
65    pub fn new_with_framework(config: TrialBalanceConfig, framework: &str) -> Self {
66        Self {
67            config,
68            category_mappings: HashMap::new(),
69            account_descriptions: HashMap::new(),
70            framework_accounts: FrameworkAccounts::for_framework(framework),
71        }
72    }
73
74    /// Creates a new trial balance generator (defaults to US GAAP).
75    pub fn new(config: TrialBalanceConfig) -> Self {
76        Self::new_with_framework(config, "us_gaap")
77    }
78
79    /// Creates a generator with default configuration (US GAAP).
80    pub fn with_defaults() -> Self {
81        Self::new(TrialBalanceConfig::default())
82    }
83
84    /// Registers category mappings from chart of accounts.
85    pub fn register_from_chart(&mut self, chart: &ChartOfAccounts) {
86        for account in &chart.accounts {
87            self.account_descriptions.insert(
88                account.account_code().to_string(),
89                account.description().to_string(),
90            );
91
92            // Determine category from account code prefix
93            let category = self.determine_category(account.account_code());
94            self.category_mappings
95                .insert(account.account_code().to_string(), category);
96        }
97    }
98
99    /// Registers a custom category mapping.
100    pub fn register_category(&mut self, account_code: &str, category: AccountCategory) {
101        self.category_mappings
102            .insert(account_code.to_string(), category);
103    }
104
105    /// Generates a trial balance from a balance snapshot.
106    pub fn generate_from_snapshot(
107        &self,
108        snapshot: &BalanceSnapshot,
109        fiscal_year: i32,
110        fiscal_period: u32,
111    ) -> TrialBalance {
112        debug!(
113            company_code = %snapshot.company_code,
114            fiscal_year,
115            fiscal_period,
116            balance_count = snapshot.balances.len(),
117            "Generating trial balance from snapshot"
118        );
119
120        let mut lines = Vec::new();
121        let mut total_debits = Decimal::ZERO;
122        let mut total_credits = Decimal::ZERO;
123
124        // Convert balances to trial balance lines
125        for (account_code, balance) in &snapshot.balances {
126            if !self.config.include_zero_balances && balance.closing_balance == Decimal::ZERO {
127                continue;
128            }
129
130            let (debit, credit) = self.split_balance(balance);
131            total_debits += debit;
132            total_credits += credit;
133
134            let category = self.determine_category(account_code);
135            let description = self
136                .account_descriptions
137                .get(account_code)
138                .cloned()
139                .unwrap_or_else(|| format!("Account {}", account_code));
140
141            lines.push(TrialBalanceLine {
142                account_code: account_code.clone(),
143                account_description: description,
144                category,
145                account_type: balance.account_type,
146                debit_balance: debit,
147                credit_balance: credit,
148                opening_balance: balance.opening_balance,
149                period_debits: balance.period_debits,
150                period_credits: balance.period_credits,
151                closing_balance: balance.closing_balance,
152                cost_center: None,
153                profit_center: None,
154            });
155        }
156
157        // Sort lines
158        if self.config.sort_by_account_code {
159            lines.sort_by(|a, b| a.account_code.cmp(&b.account_code));
160        }
161
162        // Calculate category summaries
163        let category_summary = if self.config.group_by_category {
164            self.calculate_category_summary(&lines)
165        } else {
166            HashMap::new()
167        };
168
169        let out_of_balance = total_debits - total_credits;
170
171        let mut tb = TrialBalance {
172            trial_balance_id: format!(
173                "TB-{}-{}-{:02}",
174                snapshot.company_code, fiscal_year, fiscal_period
175            ),
176            company_code: snapshot.company_code.clone(),
177            company_name: None,
178            as_of_date: snapshot.as_of_date,
179            fiscal_year,
180            fiscal_period,
181            currency: snapshot.currency.clone(),
182            balance_type: self.config.trial_balance_type,
183            lines,
184            total_debits,
185            total_credits,
186            is_balanced: out_of_balance.abs() < dec!(0.01),
187            out_of_balance,
188            is_equation_valid: false,           // Will be calculated below
189            equation_difference: Decimal::ZERO, // Will be calculated below
190            category_summary,
191            created_at: snapshot
192                .as_of_date
193                .and_hms_opt(23, 59, 59)
194                .unwrap_or_default(),
195            created_by: "TrialBalanceGenerator".to_string(),
196            approved_by: None,
197            approved_at: None,
198            status: TrialBalanceStatus::Draft,
199        };
200
201        // Calculate and set accounting equation validity
202        let (is_valid, _assets, _liabilities, _equity, diff) = tb.validate_accounting_equation();
203        tb.is_equation_valid = is_valid;
204        tb.equation_difference = diff;
205
206        tb
207    }
208
209    /// Generates a trial balance from the balance tracker.
210    pub fn generate_from_tracker(
211        &self,
212        tracker: &RunningBalanceTracker,
213        company_code: &str,
214        as_of_date: NaiveDate,
215        fiscal_year: i32,
216        fiscal_period: u32,
217    ) -> Option<TrialBalance> {
218        tracker
219            .get_snapshot(company_code, as_of_date)
220            .map(|snapshot| self.generate_from_snapshot(&snapshot, fiscal_year, fiscal_period))
221    }
222
223    /// Generates trial balances for all companies in the tracker.
224    pub fn generate_all_from_tracker(
225        &self,
226        tracker: &RunningBalanceTracker,
227        as_of_date: NaiveDate,
228        fiscal_year: i32,
229        fiscal_period: u32,
230    ) -> Vec<TrialBalance> {
231        tracker
232            .get_all_snapshots(as_of_date)
233            .iter()
234            .map(|snapshot| self.generate_from_snapshot(snapshot, fiscal_year, fiscal_period))
235            .collect()
236    }
237
238    /// Generates a comparative trial balance across multiple periods.
239    pub fn generate_comparative(
240        &self,
241        snapshots: &[(NaiveDate, BalanceSnapshot)],
242        fiscal_year: i32,
243    ) -> ComparativeTrialBalance {
244        use datasynth_core::models::balance::ComparativeTrialBalanceLine;
245
246        // Generate trial balances for each period
247        let trial_balances: Vec<TrialBalance> = snapshots
248            .iter()
249            .enumerate()
250            .map(|(i, (date, snapshot))| {
251                let mut tb = self.generate_from_snapshot(snapshot, fiscal_year, (i + 1) as u32);
252                tb.as_of_date = *date;
253                tb
254            })
255            .collect();
256
257        // Build periods list
258        let periods: Vec<(i32, u32)> = trial_balances
259            .iter()
260            .map(|tb| (tb.fiscal_year, tb.fiscal_period))
261            .collect();
262
263        // Build comparative lines
264        let mut lines_map: HashMap<String, ComparativeTrialBalanceLine> = HashMap::new();
265
266        for tb in &trial_balances {
267            for line in &tb.lines {
268                let entry = lines_map
269                    .entry(line.account_code.clone())
270                    .or_insert_with(|| ComparativeTrialBalanceLine {
271                        account_code: line.account_code.clone(),
272                        account_description: line.account_description.clone(),
273                        category: line.category,
274                        period_balances: HashMap::new(),
275                        period_changes: HashMap::new(),
276                    });
277
278                entry
279                    .period_balances
280                    .insert((tb.fiscal_year, tb.fiscal_period), line.closing_balance);
281            }
282        }
283
284        // Calculate period-over-period changes
285        for line in lines_map.values_mut() {
286            let mut sorted_periods: Vec<_> = line.period_balances.keys().cloned().collect();
287            sorted_periods.sort();
288
289            for i in 1..sorted_periods.len() {
290                let prev_period = sorted_periods[i - 1];
291                let curr_period = sorted_periods[i];
292
293                if let (Some(&prev_balance), Some(&curr_balance)) = (
294                    line.period_balances.get(&prev_period),
295                    line.period_balances.get(&curr_period),
296                ) {
297                    line.period_changes
298                        .insert(curr_period, curr_balance - prev_balance);
299                }
300            }
301        }
302
303        let lines: Vec<ComparativeTrialBalanceLine> = lines_map.into_values().collect();
304
305        let company_code = snapshots
306            .first()
307            .map(|(_, s)| s.company_code.clone())
308            .unwrap_or_default();
309
310        let currency = snapshots
311            .first()
312            .map(|(_, s)| s.currency.clone())
313            .unwrap_or_else(|| "USD".to_string());
314
315        let created_at = snapshots
316            .last()
317            .map(|(date, _)| date.and_hms_opt(23, 59, 59).unwrap_or_default())
318            .unwrap_or_default();
319
320        ComparativeTrialBalance {
321            company_code,
322            currency,
323            periods,
324            lines,
325            created_at,
326        }
327    }
328
329    /// Generates a consolidated trial balance across companies.
330    pub fn generate_consolidated(
331        &self,
332        trial_balances: &[TrialBalance],
333        consolidated_company_code: &str,
334    ) -> TrialBalance {
335        let mut consolidated_balances: HashMap<String, TrialBalanceLine> = HashMap::new();
336
337        for tb in trial_balances {
338            for line in &tb.lines {
339                let entry = consolidated_balances
340                    .entry(line.account_code.clone())
341                    .or_insert_with(|| TrialBalanceLine {
342                        account_code: line.account_code.clone(),
343                        account_description: line.account_description.clone(),
344                        category: line.category,
345                        account_type: line.account_type,
346                        debit_balance: Decimal::ZERO,
347                        credit_balance: Decimal::ZERO,
348                        opening_balance: Decimal::ZERO,
349                        period_debits: Decimal::ZERO,
350                        period_credits: Decimal::ZERO,
351                        closing_balance: Decimal::ZERO,
352                        cost_center: None,
353                        profit_center: None,
354                    });
355
356                entry.debit_balance += line.debit_balance;
357                entry.credit_balance += line.credit_balance;
358                entry.opening_balance += line.opening_balance;
359                entry.period_debits += line.period_debits;
360                entry.period_credits += line.period_credits;
361                entry.closing_balance += line.closing_balance;
362            }
363        }
364
365        let mut lines: Vec<TrialBalanceLine> = consolidated_balances.into_values().collect();
366        if self.config.sort_by_account_code {
367            lines.sort_by(|a, b| a.account_code.cmp(&b.account_code));
368        }
369
370        let total_debits: Decimal = lines.iter().map(|l| l.debit_balance).sum();
371        let total_credits: Decimal = lines.iter().map(|l| l.credit_balance).sum();
372
373        let category_summary = if self.config.group_by_category {
374            self.calculate_category_summary(&lines)
375        } else {
376            HashMap::new()
377        };
378
379        let as_of_date = trial_balances
380            .first()
381            .map(|tb| tb.as_of_date)
382            .unwrap_or_else(|| chrono::Local::now().date_naive());
383
384        let fiscal_year = trial_balances.first().map(|tb| tb.fiscal_year).unwrap_or(0);
385        let fiscal_period = trial_balances
386            .first()
387            .map(|tb| tb.fiscal_period)
388            .unwrap_or(0);
389
390        let currency = trial_balances
391            .first()
392            .map(|tb| tb.currency.clone())
393            .unwrap_or_else(|| "USD".to_string());
394
395        let out_of_balance = total_debits - total_credits;
396
397        let mut tb = TrialBalance {
398            trial_balance_id: format!(
399                "TB-CONS-{}-{}-{:02}",
400                consolidated_company_code, fiscal_year, fiscal_period
401            ),
402            company_code: consolidated_company_code.to_string(),
403            company_name: None,
404            as_of_date,
405            fiscal_year,
406            fiscal_period,
407            currency,
408            balance_type: TrialBalanceType::Consolidated,
409            lines,
410            total_debits,
411            total_credits,
412            is_balanced: out_of_balance.abs() < dec!(0.01),
413            out_of_balance,
414            is_equation_valid: false,           // Will be calculated below
415            equation_difference: Decimal::ZERO, // Will be calculated below
416            category_summary,
417            created_at: as_of_date.and_hms_opt(23, 59, 59).unwrap_or_default(),
418            created_by: format!(
419                "TrialBalanceGenerator (Consolidated from {} companies)",
420                trial_balances.len()
421            ),
422            approved_by: None,
423            approved_at: None,
424            status: TrialBalanceStatus::Draft,
425        };
426
427        // Calculate and set accounting equation validity
428        let (is_valid, _assets, _liabilities, _equity, diff) = tb.validate_accounting_equation();
429        tb.is_equation_valid = is_valid;
430        tb.equation_difference = diff;
431
432        tb
433    }
434
435    /// Splits a balance into debit and credit components.
436    fn split_balance(&self, balance: &AccountBalance) -> (Decimal, Decimal) {
437        let closing = balance.closing_balance;
438
439        // Determine natural balance side based on account type
440        match balance.account_type {
441            AccountType::Asset | AccountType::Expense => {
442                if closing >= Decimal::ZERO {
443                    (closing, Decimal::ZERO)
444                } else {
445                    (Decimal::ZERO, closing.abs())
446                }
447            }
448            AccountType::ContraAsset | AccountType::ContraLiability | AccountType::ContraEquity => {
449                // Contra accounts have opposite natural balance
450                if closing >= Decimal::ZERO {
451                    (Decimal::ZERO, closing)
452                } else {
453                    (closing.abs(), Decimal::ZERO)
454                }
455            }
456            AccountType::Liability | AccountType::Equity | AccountType::Revenue => {
457                if closing >= Decimal::ZERO {
458                    (Decimal::ZERO, closing)
459                } else {
460                    (closing.abs(), Decimal::ZERO)
461                }
462            }
463        }
464    }
465
466    /// Determines account category from code prefix.
467    ///
468    /// Checks explicitly registered mappings first, then falls back to the
469    /// framework-aware classifier from [`FrameworkAccounts`].
470    fn determine_category(&self, account_code: &str) -> AccountCategory {
471        // Check registered mappings first
472        if let Some(category) = self.category_mappings.get(account_code) {
473            return *category;
474        }
475
476        // Use framework-aware classification
477        self.framework_accounts
478            .classify_trial_balance_category(account_code)
479    }
480
481    /// Calculates category summaries from lines.
482    fn calculate_category_summary(
483        &self,
484        lines: &[TrialBalanceLine],
485    ) -> HashMap<AccountCategory, CategorySummary> {
486        let mut summaries: HashMap<AccountCategory, CategorySummary> = HashMap::new();
487
488        for line in lines {
489            let summary = summaries
490                .entry(line.category)
491                .or_insert_with(|| CategorySummary::new(line.category));
492
493            summary.add_balance(line.debit_balance, line.credit_balance);
494        }
495
496        summaries
497    }
498
499    /// Finalizes a trial balance (changes status to Final).
500    pub fn finalize(&self, mut trial_balance: TrialBalance) -> TrialBalance {
501        trial_balance.status = TrialBalanceStatus::Final;
502        trial_balance
503    }
504
505    /// Approves a trial balance.
506    pub fn approve(&self, mut trial_balance: TrialBalance, approver: &str) -> TrialBalance {
507        trial_balance.status = TrialBalanceStatus::Approved;
508        trial_balance.approved_by = Some(approver.to_string());
509        trial_balance.approved_at = Some(
510            trial_balance
511                .as_of_date
512                .succ_opt()
513                .unwrap_or(trial_balance.as_of_date)
514                .and_hms_opt(9, 0, 0)
515                .unwrap_or_default(),
516        );
517        trial_balance
518    }
519}
520
521/// Builder for trial balance generation with fluent API.
522pub struct TrialBalanceBuilder {
523    generator: TrialBalanceGenerator,
524    snapshots: Vec<(String, BalanceSnapshot)>,
525    fiscal_year: i32,
526    fiscal_period: u32,
527}
528
529impl TrialBalanceBuilder {
530    /// Creates a new builder.
531    pub fn new(fiscal_year: i32, fiscal_period: u32) -> Self {
532        Self {
533            generator: TrialBalanceGenerator::with_defaults(),
534            snapshots: Vec::new(),
535            fiscal_year,
536            fiscal_period,
537        }
538    }
539
540    /// Adds a balance snapshot.
541    pub fn add_snapshot(mut self, company_code: &str, snapshot: BalanceSnapshot) -> Self {
542        self.snapshots.push((company_code.to_string(), snapshot));
543        self
544    }
545
546    /// Sets configuration.
547    pub fn with_config(mut self, config: TrialBalanceConfig) -> Self {
548        self.generator = TrialBalanceGenerator::new(config);
549        self
550    }
551
552    /// Builds individual trial balances.
553    pub fn build(self) -> Vec<TrialBalance> {
554        self.snapshots
555            .iter()
556            .map(|(_, snapshot)| {
557                self.generator.generate_from_snapshot(
558                    snapshot,
559                    self.fiscal_year,
560                    self.fiscal_period,
561                )
562            })
563            .collect()
564    }
565
566    /// Builds a consolidated trial balance.
567    pub fn build_consolidated(self, consolidated_code: &str) -> TrialBalance {
568        let individual = self
569            .snapshots
570            .iter()
571            .map(|(_, snapshot)| {
572                self.generator.generate_from_snapshot(
573                    snapshot,
574                    self.fiscal_year,
575                    self.fiscal_period,
576                )
577            })
578            .collect::<Vec<_>>();
579
580        self.generator
581            .generate_consolidated(&individual, consolidated_code)
582    }
583}
584
#[cfg(test)]
// Tests may unwrap freely: a failed unwrap is a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a USD balance for fiscal 2024 period 1 whose opening and
    /// closing balances both equal `opening`.
    fn create_test_balance(
        company: &str,
        account: &str,
        acct_type: AccountType,
        opening: Decimal,
    ) -> AccountBalance {
        let mut balance = AccountBalance::new(
            company.to_string(),
            account.to_string(),
            acct_type,
            "USD".to_string(),
            2024,
            1,
        );
        balance.opening_balance = opening;
        balance.closing_balance = opening;
        balance
    }

    /// Builds a snapshot for company `TEST` with one asset (10000),
    /// one liability (5000) and one equity (5000) account.
    fn create_test_snapshot() -> BalanceSnapshot {
        let mut snapshot = BalanceSnapshot::new(
            "SNAP-TEST-2024-01".to_string(),
            "TEST".to_string(),
            NaiveDate::from_ymd_opt(2024, 1, 31).unwrap(),
            2024,
            1,
            "USD".to_string(),
        );

        // Asset, liability and equity accounts, in that order.
        let accounts = [
            ("1100", AccountType::Asset, dec!(10000)),
            ("2100", AccountType::Liability, dec!(5000)),
            ("3100", AccountType::Equity, dec!(5000)),
        ];
        for (code, acct_type, amount) in accounts {
            snapshot.balances.insert(
                code.to_string(),
                create_test_balance("TEST", code, acct_type, amount),
            );
        }

        snapshot.recalculate_totals();
        snapshot
    }

    #[test]
    fn test_generate_trial_balance() {
        let snapshot = create_test_snapshot();
        let generator = TrialBalanceGenerator::with_defaults();

        let trial_balance = generator.generate_from_snapshot(&snapshot, 2024, 1);

        assert!(trial_balance.is_balanced);
        assert_eq!(trial_balance.lines.len(), 3);
        assert_eq!(trial_balance.total_debits, dec!(10000));
        assert_eq!(trial_balance.total_credits, dec!(10000));
    }

    #[test]
    fn test_category_summaries() {
        let snapshot = create_test_snapshot();
        let generator = TrialBalanceGenerator::with_defaults();

        let trial_balance = generator.generate_from_snapshot(&snapshot, 2024, 1);

        assert!(!trial_balance.category_summary.is_empty());
    }

    #[test]
    fn test_consolidated_trial_balance() {
        let generator = TrialBalanceGenerator::with_defaults();

        let first_snapshot = create_test_snapshot();
        let mut second_snapshot = BalanceSnapshot::new(
            "SNAP-TEST2-2024-01".to_string(),
            "TEST2".to_string(),
            first_snapshot.as_of_date,
            2024,
            1,
            "USD".to_string(),
        );

        // The second company mirrors the first with every balance doubled.
        for (code, balance) in &first_snapshot.balances {
            let mut doubled = balance.clone();
            doubled.company_code = "TEST2".to_string();
            doubled.closing_balance *= dec!(2);
            doubled.opening_balance *= dec!(2);
            second_snapshot.balances.insert(code.clone(), doubled);
        }
        second_snapshot.recalculate_totals();

        let first_tb = generator.generate_from_snapshot(&first_snapshot, 2024, 1);
        let second_tb = generator.generate_from_snapshot(&second_snapshot, 2024, 1);

        let consolidated = generator.generate_consolidated(&[first_tb, second_tb], "CONSOL");

        assert_eq!(consolidated.company_code, "CONSOL");
        assert!(consolidated.is_balanced);
    }

    #[test]
    fn test_builder_pattern() {
        let builder = TrialBalanceBuilder::new(2024, 1).add_snapshot("TEST", create_test_snapshot());

        let trial_balances = builder.build();

        assert_eq!(trial_balances.len(), 1);
        assert!(trial_balances[0].is_balanced);
    }
}