Skip to main content

datasynth_generators/balance/
trial_balance_generator.rs

//! Trial balance generator.
//!
//! Generates period-end trial balances from running balance snapshots,
//! with support for:
//! - Unadjusted, adjusted, and post-closing trial balance types
//! - Category summaries
//! - Comparative trial balances across periods
//! - Consolidated trial balances across companies
9
10use chrono::NaiveDate;
11use rust_decimal::Decimal;
12use rust_decimal_macros::dec;
13use std::collections::HashMap;
14
15use datasynth_core::models::balance::{
16    AccountBalance, AccountCategory, AccountType, BalanceSnapshot, CategorySummary,
17    ComparativeTrialBalance, TrialBalance, TrialBalanceLine, TrialBalanceStatus, TrialBalanceType,
18};
19use datasynth_core::models::ChartOfAccounts;
20
21use super::RunningBalanceTracker;
22
23/// Configuration for trial balance generation.
24#[derive(Debug, Clone)]
25pub struct TrialBalanceConfig {
26    /// Include zero balance accounts.
27    pub include_zero_balances: bool,
28    /// Group accounts by category.
29    pub group_by_category: bool,
30    /// Generate category subtotals.
31    pub generate_subtotals: bool,
32    /// Sort accounts by code.
33    pub sort_by_account_code: bool,
34    /// Trial balance type to generate.
35    pub trial_balance_type: TrialBalanceType,
36}
37
38impl Default for TrialBalanceConfig {
39    fn default() -> Self {
40        Self {
41            include_zero_balances: false,
42            group_by_category: true,
43            generate_subtotals: true,
44            sort_by_account_code: true,
45            trial_balance_type: TrialBalanceType::Unadjusted,
46        }
47    }
48}
49
50/// Generator for trial balance reports.
51pub struct TrialBalanceGenerator {
52    config: TrialBalanceConfig,
53    /// Account category mappings.
54    category_mappings: HashMap<String, AccountCategory>,
55    /// Account descriptions.
56    account_descriptions: HashMap<String, String>,
57}
58
59impl TrialBalanceGenerator {
60    /// Creates a new trial balance generator.
61    pub fn new(config: TrialBalanceConfig) -> Self {
62        Self {
63            config,
64            category_mappings: HashMap::new(),
65            account_descriptions: HashMap::new(),
66        }
67    }
68
69    /// Creates a generator with default configuration.
70    pub fn with_defaults() -> Self {
71        Self::new(TrialBalanceConfig::default())
72    }
73
74    /// Registers category mappings from chart of accounts.
75    pub fn register_from_chart(&mut self, chart: &ChartOfAccounts) {
76        for account in &chart.accounts {
77            self.account_descriptions.insert(
78                account.account_code().to_string(),
79                account.description().to_string(),
80            );
81
82            // Determine category from account code prefix
83            let category = self.determine_category(account.account_code());
84            self.category_mappings
85                .insert(account.account_code().to_string(), category);
86        }
87    }
88
89    /// Registers a custom category mapping.
90    pub fn register_category(&mut self, account_code: &str, category: AccountCategory) {
91        self.category_mappings
92            .insert(account_code.to_string(), category);
93    }
94
95    /// Generates a trial balance from a balance snapshot.
96    pub fn generate_from_snapshot(
97        &self,
98        snapshot: &BalanceSnapshot,
99        fiscal_year: i32,
100        fiscal_period: u32,
101    ) -> TrialBalance {
102        let mut lines = Vec::new();
103        let mut total_debits = Decimal::ZERO;
104        let mut total_credits = Decimal::ZERO;
105
106        // Convert balances to trial balance lines
107        for (account_code, balance) in &snapshot.balances {
108            if !self.config.include_zero_balances && balance.closing_balance == Decimal::ZERO {
109                continue;
110            }
111
112            let (debit, credit) = self.split_balance(balance);
113            total_debits += debit;
114            total_credits += credit;
115
116            let category = self.determine_category(account_code);
117            let description = self
118                .account_descriptions
119                .get(account_code)
120                .cloned()
121                .unwrap_or_else(|| format!("Account {}", account_code));
122
123            lines.push(TrialBalanceLine {
124                account_code: account_code.clone(),
125                account_description: description,
126                category,
127                account_type: balance.account_type,
128                debit_balance: debit,
129                credit_balance: credit,
130                opening_balance: balance.opening_balance,
131                period_debits: balance.period_debits,
132                period_credits: balance.period_credits,
133                closing_balance: balance.closing_balance,
134                cost_center: None,
135                profit_center: None,
136            });
137        }
138
139        // Sort lines
140        if self.config.sort_by_account_code {
141            lines.sort_by(|a, b| a.account_code.cmp(&b.account_code));
142        }
143
144        // Calculate category summaries
145        let category_summary = if self.config.group_by_category {
146            self.calculate_category_summary(&lines)
147        } else {
148            HashMap::new()
149        };
150
151        let out_of_balance = total_debits - total_credits;
152
153        let mut tb = TrialBalance {
154            trial_balance_id: format!(
155                "TB-{}-{}-{:02}",
156                snapshot.company_code, fiscal_year, fiscal_period
157            ),
158            company_code: snapshot.company_code.clone(),
159            company_name: None,
160            as_of_date: snapshot.as_of_date,
161            fiscal_year,
162            fiscal_period,
163            currency: snapshot.currency.clone(),
164            balance_type: self.config.trial_balance_type,
165            lines,
166            total_debits,
167            total_credits,
168            is_balanced: out_of_balance.abs() < dec!(0.01),
169            out_of_balance,
170            is_equation_valid: false,           // Will be calculated below
171            equation_difference: Decimal::ZERO, // Will be calculated below
172            category_summary,
173            created_at: chrono::Utc::now().naive_utc(),
174            created_by: "TrialBalanceGenerator".to_string(),
175            approved_by: None,
176            approved_at: None,
177            status: TrialBalanceStatus::Draft,
178        };
179
180        // Calculate and set accounting equation validity
181        let (is_valid, _assets, _liabilities, _equity, diff) = tb.validate_accounting_equation();
182        tb.is_equation_valid = is_valid;
183        tb.equation_difference = diff;
184
185        tb
186    }
187
188    /// Generates a trial balance from the balance tracker.
189    pub fn generate_from_tracker(
190        &self,
191        tracker: &RunningBalanceTracker,
192        company_code: &str,
193        as_of_date: NaiveDate,
194        fiscal_year: i32,
195        fiscal_period: u32,
196    ) -> Option<TrialBalance> {
197        tracker
198            .get_snapshot(company_code, as_of_date)
199            .map(|snapshot| self.generate_from_snapshot(&snapshot, fiscal_year, fiscal_period))
200    }
201
202    /// Generates trial balances for all companies in the tracker.
203    pub fn generate_all_from_tracker(
204        &self,
205        tracker: &RunningBalanceTracker,
206        as_of_date: NaiveDate,
207        fiscal_year: i32,
208        fiscal_period: u32,
209    ) -> Vec<TrialBalance> {
210        tracker
211            .get_all_snapshots(as_of_date)
212            .iter()
213            .map(|snapshot| self.generate_from_snapshot(snapshot, fiscal_year, fiscal_period))
214            .collect()
215    }
216
217    /// Generates a comparative trial balance across multiple periods.
218    pub fn generate_comparative(
219        &self,
220        snapshots: &[(NaiveDate, BalanceSnapshot)],
221        fiscal_year: i32,
222    ) -> ComparativeTrialBalance {
223        use datasynth_core::models::balance::ComparativeTrialBalanceLine;
224
225        // Generate trial balances for each period
226        let trial_balances: Vec<TrialBalance> = snapshots
227            .iter()
228            .enumerate()
229            .map(|(i, (date, snapshot))| {
230                let mut tb = self.generate_from_snapshot(snapshot, fiscal_year, (i + 1) as u32);
231                tb.as_of_date = *date;
232                tb
233            })
234            .collect();
235
236        // Build periods list
237        let periods: Vec<(i32, u32)> = trial_balances
238            .iter()
239            .map(|tb| (tb.fiscal_year, tb.fiscal_period))
240            .collect();
241
242        // Build comparative lines
243        let mut lines_map: HashMap<String, ComparativeTrialBalanceLine> = HashMap::new();
244
245        for tb in &trial_balances {
246            for line in &tb.lines {
247                let entry = lines_map
248                    .entry(line.account_code.clone())
249                    .or_insert_with(|| ComparativeTrialBalanceLine {
250                        account_code: line.account_code.clone(),
251                        account_description: line.account_description.clone(),
252                        category: line.category,
253                        period_balances: HashMap::new(),
254                        period_changes: HashMap::new(),
255                    });
256
257                entry
258                    .period_balances
259                    .insert((tb.fiscal_year, tb.fiscal_period), line.closing_balance);
260            }
261        }
262
263        // Calculate period-over-period changes
264        for line in lines_map.values_mut() {
265            let mut sorted_periods: Vec<_> = line.period_balances.keys().cloned().collect();
266            sorted_periods.sort();
267
268            for i in 1..sorted_periods.len() {
269                let prev_period = sorted_periods[i - 1];
270                let curr_period = sorted_periods[i];
271
272                if let (Some(&prev_balance), Some(&curr_balance)) = (
273                    line.period_balances.get(&prev_period),
274                    line.period_balances.get(&curr_period),
275                ) {
276                    line.period_changes
277                        .insert(curr_period, curr_balance - prev_balance);
278                }
279            }
280        }
281
282        let lines: Vec<ComparativeTrialBalanceLine> = lines_map.into_values().collect();
283
284        let company_code = snapshots
285            .first()
286            .map(|(_, s)| s.company_code.clone())
287            .unwrap_or_default();
288
289        let currency = snapshots
290            .first()
291            .map(|(_, s)| s.currency.clone())
292            .unwrap_or_else(|| "USD".to_string());
293
294        ComparativeTrialBalance {
295            company_code,
296            currency,
297            periods,
298            lines,
299            created_at: chrono::Utc::now().naive_utc(),
300        }
301    }
302
303    /// Generates a consolidated trial balance across companies.
304    pub fn generate_consolidated(
305        &self,
306        trial_balances: &[TrialBalance],
307        consolidated_company_code: &str,
308    ) -> TrialBalance {
309        let mut consolidated_balances: HashMap<String, TrialBalanceLine> = HashMap::new();
310
311        for tb in trial_balances {
312            for line in &tb.lines {
313                let entry = consolidated_balances
314                    .entry(line.account_code.clone())
315                    .or_insert_with(|| TrialBalanceLine {
316                        account_code: line.account_code.clone(),
317                        account_description: line.account_description.clone(),
318                        category: line.category,
319                        account_type: line.account_type,
320                        debit_balance: Decimal::ZERO,
321                        credit_balance: Decimal::ZERO,
322                        opening_balance: Decimal::ZERO,
323                        period_debits: Decimal::ZERO,
324                        period_credits: Decimal::ZERO,
325                        closing_balance: Decimal::ZERO,
326                        cost_center: None,
327                        profit_center: None,
328                    });
329
330                entry.debit_balance += line.debit_balance;
331                entry.credit_balance += line.credit_balance;
332                entry.opening_balance += line.opening_balance;
333                entry.period_debits += line.period_debits;
334                entry.period_credits += line.period_credits;
335                entry.closing_balance += line.closing_balance;
336            }
337        }
338
339        let mut lines: Vec<TrialBalanceLine> = consolidated_balances.into_values().collect();
340        if self.config.sort_by_account_code {
341            lines.sort_by(|a, b| a.account_code.cmp(&b.account_code));
342        }
343
344        let total_debits: Decimal = lines.iter().map(|l| l.debit_balance).sum();
345        let total_credits: Decimal = lines.iter().map(|l| l.credit_balance).sum();
346
347        let category_summary = if self.config.group_by_category {
348            self.calculate_category_summary(&lines)
349        } else {
350            HashMap::new()
351        };
352
353        let as_of_date = trial_balances
354            .first()
355            .map(|tb| tb.as_of_date)
356            .unwrap_or_else(|| chrono::Local::now().date_naive());
357
358        let fiscal_year = trial_balances.first().map(|tb| tb.fiscal_year).unwrap_or(0);
359        let fiscal_period = trial_balances
360            .first()
361            .map(|tb| tb.fiscal_period)
362            .unwrap_or(0);
363
364        let currency = trial_balances
365            .first()
366            .map(|tb| tb.currency.clone())
367            .unwrap_or_else(|| "USD".to_string());
368
369        let out_of_balance = total_debits - total_credits;
370
371        let mut tb = TrialBalance {
372            trial_balance_id: format!(
373                "TB-CONS-{}-{}-{:02}",
374                consolidated_company_code, fiscal_year, fiscal_period
375            ),
376            company_code: consolidated_company_code.to_string(),
377            company_name: None,
378            as_of_date,
379            fiscal_year,
380            fiscal_period,
381            currency,
382            balance_type: TrialBalanceType::Consolidated,
383            lines,
384            total_debits,
385            total_credits,
386            is_balanced: out_of_balance.abs() < dec!(0.01),
387            out_of_balance,
388            is_equation_valid: false,           // Will be calculated below
389            equation_difference: Decimal::ZERO, // Will be calculated below
390            category_summary,
391            created_at: chrono::Utc::now().naive_utc(),
392            created_by: format!(
393                "TrialBalanceGenerator (Consolidated from {} companies)",
394                trial_balances.len()
395            ),
396            approved_by: None,
397            approved_at: None,
398            status: TrialBalanceStatus::Draft,
399        };
400
401        // Calculate and set accounting equation validity
402        let (is_valid, _assets, _liabilities, _equity, diff) = tb.validate_accounting_equation();
403        tb.is_equation_valid = is_valid;
404        tb.equation_difference = diff;
405
406        tb
407    }
408
409    /// Splits a balance into debit and credit components.
410    fn split_balance(&self, balance: &AccountBalance) -> (Decimal, Decimal) {
411        let closing = balance.closing_balance;
412
413        // Determine natural balance side based on account type
414        match balance.account_type {
415            AccountType::Asset | AccountType::Expense => {
416                if closing >= Decimal::ZERO {
417                    (closing, Decimal::ZERO)
418                } else {
419                    (Decimal::ZERO, closing.abs())
420                }
421            }
422            AccountType::ContraAsset | AccountType::ContraLiability | AccountType::ContraEquity => {
423                // Contra accounts have opposite natural balance
424                if closing >= Decimal::ZERO {
425                    (Decimal::ZERO, closing)
426                } else {
427                    (closing.abs(), Decimal::ZERO)
428                }
429            }
430            AccountType::Liability | AccountType::Equity | AccountType::Revenue => {
431                if closing >= Decimal::ZERO {
432                    (Decimal::ZERO, closing)
433                } else {
434                    (closing.abs(), Decimal::ZERO)
435                }
436            }
437        }
438    }
439
440    /// Determines account category from code prefix.
441    fn determine_category(&self, account_code: &str) -> AccountCategory {
442        // Check registered mappings first
443        if let Some(category) = self.category_mappings.get(account_code) {
444            return *category;
445        }
446
447        // Default logic based on account code ranges
448        let prefix: u32 = account_code
449            .chars()
450            .take(2)
451            .collect::<String>()
452            .parse()
453            .unwrap_or(0);
454
455        match prefix {
456            10..=14 => AccountCategory::CurrentAssets,
457            15..=19 => AccountCategory::NonCurrentAssets,
458            20..=24 => AccountCategory::CurrentLiabilities,
459            25..=29 => AccountCategory::NonCurrentLiabilities,
460            30..=39 => AccountCategory::Equity,
461            40..=44 => AccountCategory::Revenue,
462            50..=54 => AccountCategory::CostOfGoodsSold,
463            55..=69 => AccountCategory::OperatingExpenses,
464            70..=74 => AccountCategory::OtherIncome,
465            75..=99 => AccountCategory::OtherExpenses,
466            _ => AccountCategory::OtherExpenses,
467        }
468    }
469
470    /// Calculates category summaries from lines.
471    fn calculate_category_summary(
472        &self,
473        lines: &[TrialBalanceLine],
474    ) -> HashMap<AccountCategory, CategorySummary> {
475        let mut summaries: HashMap<AccountCategory, CategorySummary> = HashMap::new();
476
477        for line in lines {
478            let summary = summaries
479                .entry(line.category)
480                .or_insert_with(|| CategorySummary::new(line.category));
481
482            summary.add_balance(line.debit_balance, line.credit_balance);
483        }
484
485        summaries
486    }
487
488    /// Calculates variances between periods.
489    fn calculate_period_variances(
490        &self,
491        periods: &[TrialBalance],
492    ) -> HashMap<String, Vec<Decimal>> {
493        let mut variances: HashMap<String, Vec<Decimal>> = HashMap::new();
494
495        if periods.len() < 2 {
496            return variances;
497        }
498
499        // Collect all account codes
500        let mut all_accounts: Vec<String> = periods
501            .iter()
502            .flat_map(|p| p.lines.iter().map(|l| l.account_code.clone()))
503            .collect();
504        all_accounts.sort();
505        all_accounts.dedup();
506
507        // Calculate period-over-period variances
508        for account in all_accounts {
509            let mut period_variances = Vec::new();
510
511            for i in 1..periods.len() {
512                let current = periods[i]
513                    .lines
514                    .iter()
515                    .find(|l| l.account_code == account)
516                    .map(|l| l.closing_balance)
517                    .unwrap_or_default();
518
519                let previous = periods[i - 1]
520                    .lines
521                    .iter()
522                    .find(|l| l.account_code == account)
523                    .map(|l| l.closing_balance)
524                    .unwrap_or_default();
525
526                period_variances.push(current - previous);
527            }
528
529            variances.insert(account, period_variances);
530        }
531
532        variances
533    }
534
535    /// Finalizes a trial balance (changes status to Final).
536    pub fn finalize(&self, mut trial_balance: TrialBalance) -> TrialBalance {
537        trial_balance.status = TrialBalanceStatus::Final;
538        trial_balance
539    }
540
541    /// Approves a trial balance.
542    pub fn approve(&self, mut trial_balance: TrialBalance, approver: &str) -> TrialBalance {
543        trial_balance.status = TrialBalanceStatus::Approved;
544        trial_balance.approved_by = Some(approver.to_string());
545        trial_balance.approved_at = Some(chrono::Utc::now().naive_utc());
546        trial_balance
547    }
548}
549
550/// Builder for trial balance generation with fluent API.
551pub struct TrialBalanceBuilder {
552    generator: TrialBalanceGenerator,
553    snapshots: Vec<(String, BalanceSnapshot)>,
554    fiscal_year: i32,
555    fiscal_period: u32,
556}
557
558impl TrialBalanceBuilder {
559    /// Creates a new builder.
560    pub fn new(fiscal_year: i32, fiscal_period: u32) -> Self {
561        Self {
562            generator: TrialBalanceGenerator::with_defaults(),
563            snapshots: Vec::new(),
564            fiscal_year,
565            fiscal_period,
566        }
567    }
568
569    /// Adds a balance snapshot.
570    pub fn add_snapshot(mut self, company_code: &str, snapshot: BalanceSnapshot) -> Self {
571        self.snapshots.push((company_code.to_string(), snapshot));
572        self
573    }
574
575    /// Sets configuration.
576    pub fn with_config(mut self, config: TrialBalanceConfig) -> Self {
577        self.generator = TrialBalanceGenerator::new(config);
578        self
579    }
580
581    /// Builds individual trial balances.
582    pub fn build(self) -> Vec<TrialBalance> {
583        self.snapshots
584            .iter()
585            .map(|(_, snapshot)| {
586                self.generator.generate_from_snapshot(
587                    snapshot,
588                    self.fiscal_year,
589                    self.fiscal_period,
590                )
591            })
592            .collect()
593    }
594
595    /// Builds a consolidated trial balance.
596    pub fn build_consolidated(self, consolidated_code: &str) -> TrialBalance {
597        let individual = self
598            .snapshots
599            .iter()
600            .map(|(_, snapshot)| {
601                self.generator.generate_from_snapshot(
602                    snapshot,
603                    self.fiscal_year,
604                    self.fiscal_period,
605                )
606            })
607            .collect::<Vec<_>>();
608
609        self.generator
610            .generate_consolidated(&individual, consolidated_code)
611    }
612}
613
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an account balance whose opening and closing balances are both
    /// `opening`.
    fn create_test_balance(
        company: &str,
        account: &str,
        acct_type: AccountType,
        opening: Decimal,
    ) -> AccountBalance {
        let mut bal = AccountBalance::new(
            company.to_string(),
            account.to_string(),
            acct_type,
            "USD".to_string(),
            2024,
            1,
        );
        bal.opening_balance = opening;
        bal.closing_balance = opening;
        bal
    }

    /// Snapshot for company `TEST` at 2024-01-31 containing:
    /// asset 1100 = 10,000, liability 2100 = 5,000, equity 3100 = 5,000.
    fn create_test_snapshot() -> BalanceSnapshot {
        let mut snapshot = BalanceSnapshot::new(
            "SNAP-TEST-2024-01".to_string(),
            "TEST".to_string(),
            NaiveDate::from_ymd_opt(2024, 1, 31).unwrap(),
            2024,
            1,
            "USD".to_string(),
        );

        snapshot.balances.insert(
            "1100".to_string(),
            create_test_balance("TEST", "1100", AccountType::Asset, dec!(10000)),
        );
        snapshot.balances.insert(
            "2100".to_string(),
            create_test_balance("TEST", "2100", AccountType::Liability, dec!(5000)),
        );
        snapshot.balances.insert(
            "3100".to_string(),
            create_test_balance("TEST", "3100", AccountType::Equity, dec!(5000)),
        );

        snapshot.recalculate_totals();
        snapshot
    }

    /// Copies `source` under `snapshot_id` / `company`, doubling every
    /// opening and closing balance.
    fn doubled_snapshot(source: &BalanceSnapshot, snapshot_id: &str, company: &str) -> BalanceSnapshot {
        let mut doubled = BalanceSnapshot::new(
            snapshot_id.to_string(),
            company.to_string(),
            source.as_of_date,
            2024,
            1,
            "USD".to_string(),
        );

        for (code, balance) in &source.balances {
            let mut new_bal = balance.clone();
            new_bal.company_code = company.to_string();
            new_bal.closing_balance *= dec!(2);
            new_bal.opening_balance *= dec!(2);
            doubled.balances.insert(code.clone(), new_bal);
        }
        doubled.recalculate_totals();
        doubled
    }

    #[test]
    fn test_generate_trial_balance() {
        let generator = TrialBalanceGenerator::with_defaults();
        let snapshot = create_test_snapshot();

        let tb = generator.generate_from_snapshot(&snapshot, 2024, 1);

        assert!(tb.is_balanced);
        assert_eq!(tb.lines.len(), 3);
        assert_eq!(tb.total_debits, dec!(10000));
        assert_eq!(tb.total_credits, dec!(10000));
        assert_eq!(tb.out_of_balance, Decimal::ZERO);
        assert_eq!(tb.trial_balance_id, "TB-TEST-2024-01");
        assert!(matches!(tb.status, TrialBalanceStatus::Draft));

        // The default configuration sorts lines by account code.
        let codes: Vec<&str> = tb.lines.iter().map(|l| l.account_code.as_str()).collect();
        assert_eq!(codes, ["1100", "2100", "3100"]);

        // An asset sits on the debit side; a liability on the credit side.
        assert_eq!(
            (tb.lines[0].debit_balance, tb.lines[0].credit_balance),
            (dec!(10000), Decimal::ZERO)
        );
        assert_eq!(
            (tb.lines[1].debit_balance, tb.lines[1].credit_balance),
            (Decimal::ZERO, dec!(5000))
        );
    }

    #[test]
    fn test_zero_balances_respect_config() {
        let mut snapshot = create_test_snapshot();
        snapshot.balances.insert(
            "1200".to_string(),
            create_test_balance("TEST", "1200", AccountType::Asset, Decimal::ZERO),
        );
        snapshot.recalculate_totals();

        let hidden = TrialBalanceGenerator::with_defaults().generate_from_snapshot(&snapshot, 2024, 1);
        assert_eq!(hidden.lines.len(), 3);

        let shown = TrialBalanceGenerator::new(TrialBalanceConfig {
            include_zero_balances: true,
            ..TrialBalanceConfig::default()
        })
        .generate_from_snapshot(&snapshot, 2024, 1);
        assert_eq!(shown.lines.len(), 4);
        assert!(shown.is_balanced);
    }

    #[test]
    fn test_category_summaries() {
        let generator = TrialBalanceGenerator::with_defaults();
        let snapshot = create_test_snapshot();

        let tb = generator.generate_from_snapshot(&snapshot, 2024, 1);

        assert_eq!(tb.category_summary.len(), 3);
        assert!(tb.category_summary.contains_key(&AccountCategory::CurrentAssets));
        assert!(tb.category_summary.contains_key(&AccountCategory::CurrentLiabilities));
        assert!(tb.category_summary.contains_key(&AccountCategory::Equity));
    }

    #[test]
    fn test_consolidated_trial_balance() {
        let generator = TrialBalanceGenerator::with_defaults();

        let snapshot1 = create_test_snapshot();
        let snapshot2 = doubled_snapshot(&snapshot1, "SNAP-TEST2-2024-01", "TEST2");

        let tb1 = generator.generate_from_snapshot(&snapshot1, 2024, 1);
        let tb2 = generator.generate_from_snapshot(&snapshot2, 2024, 1);

        let consolidated = generator.generate_consolidated(&[tb1, tb2], "CONSOL");

        assert_eq!(consolidated.company_code, "CONSOL");
        assert_eq!(consolidated.trial_balance_id, "TB-CONS-CONSOL-2024-01");
        assert!(matches!(consolidated.balance_type, TrialBalanceType::Consolidated));
        assert!(consolidated.is_balanced);
        assert_eq!(consolidated.lines.len(), 3);
        assert_eq!(consolidated.total_debits, dec!(30000));
        assert_eq!(consolidated.total_credits, dec!(30000));

        let cash = consolidated
            .lines
            .iter()
            .find(|l| l.account_code == "1100")
            .expect("consolidated 1100 line");
        assert_eq!(cash.closing_balance, dec!(30000));
    }

    #[test]
    fn test_comparative_trial_balance() {
        let generator = TrialBalanceGenerator::with_defaults();
        let january = create_test_snapshot();
        let february = doubled_snapshot(&january, "SNAP-TEST-2024-02", "TEST");
        let jan_date = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
        let feb_date = NaiveDate::from_ymd_opt(2024, 2, 29).unwrap();

        let comparative =
            generator.generate_comparative(&[(jan_date, january), (feb_date, february)], 2024);

        assert_eq!(comparative.company_code, "TEST");
        assert_eq!(comparative.periods, vec![(2024, 1), (2024, 2)]);
        assert_eq!(comparative.lines.len(), 3);

        let cash = comparative
            .lines
            .iter()
            .find(|l| l.account_code == "1100")
            .expect("comparative 1100 line");
        assert_eq!(cash.period_balances.get(&(2024, 1)), Some(&dec!(10000)));
        assert_eq!(cash.period_balances.get(&(2024, 2)), Some(&dec!(20000)));
        assert_eq!(cash.period_changes.get(&(2024, 2)), Some(&dec!(10000)));
        assert_eq!(cash.period_changes.get(&(2024, 1)), None);
    }

    #[test]
    fn test_builder_pattern() {
        let snapshot = create_test_snapshot();

        let trial_balances = TrialBalanceBuilder::new(2024, 1)
            .add_snapshot("TEST", snapshot)
            .build();

        assert_eq!(trial_balances.len(), 1);
        assert!(trial_balances[0].is_balanced);
        assert_eq!(trial_balances[0].fiscal_period, 1);
    }

    #[test]
    fn test_builder_consolidated() {
        let snapshot1 = create_test_snapshot();
        let snapshot2 = doubled_snapshot(&snapshot1, "SNAP-TEST2-2024-01", "TEST2");

        let consolidated = TrialBalanceBuilder::new(2024, 1)
            .add_snapshot("TEST", snapshot1)
            .add_snapshot("TEST2", snapshot2)
            .build_consolidated("GROUP");

        assert_eq!(consolidated.company_code, "GROUP");
        assert!(consolidated.is_balanced);
        assert_eq!(consolidated.total_debits, dec!(30000));
    }

    #[test]
    fn test_finalize_and_approve() {
        let generator = TrialBalanceGenerator::with_defaults();
        let tb = generator.generate_from_snapshot(&create_test_snapshot(), 2024, 1);

        let tb = generator.finalize(tb);
        assert!(matches!(tb.status, TrialBalanceStatus::Final));

        let tb = generator.approve(tb, "controller");
        assert!(matches!(tb.status, TrialBalanceStatus::Approved));
        assert_eq!(tb.approved_by.as_deref(), Some("controller"));
        assert!(tb.approved_at.is_some());
    }
}