1use chrono::NaiveDate;
11use rust_decimal::Decimal;
12use rust_decimal_macros::dec;
13use std::collections::HashMap;
14
15use datasynth_core::models::balance::{
16 AccountBalance, AccountCategory, AccountType, BalanceSnapshot, CategorySummary,
17 ComparativeTrialBalance, TrialBalance, TrialBalanceLine, TrialBalanceStatus, TrialBalanceType,
18};
19use datasynth_core::models::ChartOfAccounts;
20
21use super::RunningBalanceTracker;
22
/// Options controlling how [`TrialBalanceGenerator`] assembles a trial balance.
#[derive(Debug, Clone)]
pub struct TrialBalanceConfig {
    /// When `false`, accounts whose closing balance is exactly zero are left
    /// out of the generated lines.
    pub include_zero_balances: bool,
    /// When `true`, a per-category summary is computed from the lines;
    /// otherwise the summary map is left empty.
    pub group_by_category: bool,
    /// Requests subtotal generation.
    // NOTE(review): no code visible in this file reads this flag — confirm
    // where (or whether) it is consumed.
    pub generate_subtotals: bool,
    /// When `true`, lines are sorted ascending by account code (plain string
    /// comparison).
    pub sort_by_account_code: bool,
    /// Balance type stamped on trial balances built from a snapshot.
    pub trial_balance_type: TrialBalanceType,
}
37
38impl Default for TrialBalanceConfig {
39 fn default() -> Self {
40 Self {
41 include_zero_balances: false,
42 group_by_category: true,
43 generate_subtotals: true,
44 sort_by_account_code: true,
45 trial_balance_type: TrialBalanceType::Unadjusted,
46 }
47 }
48}
49
/// Builds trial balances (single-company, comparative and consolidated) from
/// balance snapshots.
pub struct TrialBalanceGenerator {
    /// Generation options (zero-balance filtering, sorting, grouping, type).
    config: TrialBalanceConfig,
    /// Account code -> category. Entries here take precedence over the
    /// prefix-based fallback when a category is looked up.
    category_mappings: HashMap<String, AccountCategory>,
    /// Account code -> human-readable description used on generated lines.
    account_descriptions: HashMap<String, String>,
}
58
59impl TrialBalanceGenerator {
60 pub fn new(config: TrialBalanceConfig) -> Self {
62 Self {
63 config,
64 category_mappings: HashMap::new(),
65 account_descriptions: HashMap::new(),
66 }
67 }
68
69 pub fn with_defaults() -> Self {
71 Self::new(TrialBalanceConfig::default())
72 }
73
74 pub fn register_from_chart(&mut self, chart: &ChartOfAccounts) {
76 for account in &chart.accounts {
77 self.account_descriptions.insert(
78 account.account_code().to_string(),
79 account.description().to_string(),
80 );
81
82 let category = self.determine_category(account.account_code());
84 self.category_mappings
85 .insert(account.account_code().to_string(), category);
86 }
87 }
88
89 pub fn register_category(&mut self, account_code: &str, category: AccountCategory) {
91 self.category_mappings
92 .insert(account_code.to_string(), category);
93 }
94
95 pub fn generate_from_snapshot(
97 &self,
98 snapshot: &BalanceSnapshot,
99 fiscal_year: i32,
100 fiscal_period: u32,
101 ) -> TrialBalance {
102 let mut lines = Vec::new();
103 let mut total_debits = Decimal::ZERO;
104 let mut total_credits = Decimal::ZERO;
105
106 for (account_code, balance) in &snapshot.balances {
108 if !self.config.include_zero_balances && balance.closing_balance == Decimal::ZERO {
109 continue;
110 }
111
112 let (debit, credit) = self.split_balance(balance);
113 total_debits += debit;
114 total_credits += credit;
115
116 let category = self.determine_category(account_code);
117 let description = self
118 .account_descriptions
119 .get(account_code)
120 .cloned()
121 .unwrap_or_else(|| format!("Account {}", account_code));
122
123 lines.push(TrialBalanceLine {
124 account_code: account_code.clone(),
125 account_description: description,
126 category,
127 account_type: balance.account_type,
128 debit_balance: debit,
129 credit_balance: credit,
130 opening_balance: balance.opening_balance,
131 period_debits: balance.period_debits,
132 period_credits: balance.period_credits,
133 closing_balance: balance.closing_balance,
134 cost_center: None,
135 profit_center: None,
136 });
137 }
138
139 if self.config.sort_by_account_code {
141 lines.sort_by(|a, b| a.account_code.cmp(&b.account_code));
142 }
143
144 let category_summary = if self.config.group_by_category {
146 self.calculate_category_summary(&lines)
147 } else {
148 HashMap::new()
149 };
150
151 let out_of_balance = total_debits - total_credits;
152
153 let mut tb = TrialBalance {
154 trial_balance_id: format!(
155 "TB-{}-{}-{:02}",
156 snapshot.company_code, fiscal_year, fiscal_period
157 ),
158 company_code: snapshot.company_code.clone(),
159 company_name: None,
160 as_of_date: snapshot.as_of_date,
161 fiscal_year,
162 fiscal_period,
163 currency: snapshot.currency.clone(),
164 balance_type: self.config.trial_balance_type,
165 lines,
166 total_debits,
167 total_credits,
168 is_balanced: out_of_balance.abs() < dec!(0.01),
169 out_of_balance,
170 is_equation_valid: false, equation_difference: Decimal::ZERO, category_summary,
173 created_at: chrono::Utc::now().naive_utc(),
174 created_by: "TrialBalanceGenerator".to_string(),
175 approved_by: None,
176 approved_at: None,
177 status: TrialBalanceStatus::Draft,
178 };
179
180 let (is_valid, _assets, _liabilities, _equity, diff) = tb.validate_accounting_equation();
182 tb.is_equation_valid = is_valid;
183 tb.equation_difference = diff;
184
185 tb
186 }
187
188 pub fn generate_from_tracker(
190 &self,
191 tracker: &RunningBalanceTracker,
192 company_code: &str,
193 as_of_date: NaiveDate,
194 fiscal_year: i32,
195 fiscal_period: u32,
196 ) -> Option<TrialBalance> {
197 tracker
198 .get_snapshot(company_code, as_of_date)
199 .map(|snapshot| self.generate_from_snapshot(&snapshot, fiscal_year, fiscal_period))
200 }
201
202 pub fn generate_all_from_tracker(
204 &self,
205 tracker: &RunningBalanceTracker,
206 as_of_date: NaiveDate,
207 fiscal_year: i32,
208 fiscal_period: u32,
209 ) -> Vec<TrialBalance> {
210 tracker
211 .get_all_snapshots(as_of_date)
212 .iter()
213 .map(|snapshot| self.generate_from_snapshot(snapshot, fiscal_year, fiscal_period))
214 .collect()
215 }
216
217 pub fn generate_comparative(
219 &self,
220 snapshots: &[(NaiveDate, BalanceSnapshot)],
221 fiscal_year: i32,
222 ) -> ComparativeTrialBalance {
223 use datasynth_core::models::balance::ComparativeTrialBalanceLine;
224
225 let trial_balances: Vec<TrialBalance> = snapshots
227 .iter()
228 .enumerate()
229 .map(|(i, (date, snapshot))| {
230 let mut tb = self.generate_from_snapshot(snapshot, fiscal_year, (i + 1) as u32);
231 tb.as_of_date = *date;
232 tb
233 })
234 .collect();
235
236 let periods: Vec<(i32, u32)> = trial_balances
238 .iter()
239 .map(|tb| (tb.fiscal_year, tb.fiscal_period))
240 .collect();
241
242 let mut lines_map: HashMap<String, ComparativeTrialBalanceLine> = HashMap::new();
244
245 for tb in &trial_balances {
246 for line in &tb.lines {
247 let entry = lines_map
248 .entry(line.account_code.clone())
249 .or_insert_with(|| ComparativeTrialBalanceLine {
250 account_code: line.account_code.clone(),
251 account_description: line.account_description.clone(),
252 category: line.category,
253 period_balances: HashMap::new(),
254 period_changes: HashMap::new(),
255 });
256
257 entry
258 .period_balances
259 .insert((tb.fiscal_year, tb.fiscal_period), line.closing_balance);
260 }
261 }
262
263 for line in lines_map.values_mut() {
265 let mut sorted_periods: Vec<_> = line.period_balances.keys().cloned().collect();
266 sorted_periods.sort();
267
268 for i in 1..sorted_periods.len() {
269 let prev_period = sorted_periods[i - 1];
270 let curr_period = sorted_periods[i];
271
272 if let (Some(&prev_balance), Some(&curr_balance)) = (
273 line.period_balances.get(&prev_period),
274 line.period_balances.get(&curr_period),
275 ) {
276 line.period_changes
277 .insert(curr_period, curr_balance - prev_balance);
278 }
279 }
280 }
281
282 let lines: Vec<ComparativeTrialBalanceLine> = lines_map.into_values().collect();
283
284 let company_code = snapshots
285 .first()
286 .map(|(_, s)| s.company_code.clone())
287 .unwrap_or_default();
288
289 let currency = snapshots
290 .first()
291 .map(|(_, s)| s.currency.clone())
292 .unwrap_or_else(|| "USD".to_string());
293
294 ComparativeTrialBalance {
295 company_code,
296 currency,
297 periods,
298 lines,
299 created_at: chrono::Utc::now().naive_utc(),
300 }
301 }
302
303 pub fn generate_consolidated(
305 &self,
306 trial_balances: &[TrialBalance],
307 consolidated_company_code: &str,
308 ) -> TrialBalance {
309 let mut consolidated_balances: HashMap<String, TrialBalanceLine> = HashMap::new();
310
311 for tb in trial_balances {
312 for line in &tb.lines {
313 let entry = consolidated_balances
314 .entry(line.account_code.clone())
315 .or_insert_with(|| TrialBalanceLine {
316 account_code: line.account_code.clone(),
317 account_description: line.account_description.clone(),
318 category: line.category,
319 account_type: line.account_type,
320 debit_balance: Decimal::ZERO,
321 credit_balance: Decimal::ZERO,
322 opening_balance: Decimal::ZERO,
323 period_debits: Decimal::ZERO,
324 period_credits: Decimal::ZERO,
325 closing_balance: Decimal::ZERO,
326 cost_center: None,
327 profit_center: None,
328 });
329
330 entry.debit_balance += line.debit_balance;
331 entry.credit_balance += line.credit_balance;
332 entry.opening_balance += line.opening_balance;
333 entry.period_debits += line.period_debits;
334 entry.period_credits += line.period_credits;
335 entry.closing_balance += line.closing_balance;
336 }
337 }
338
339 let mut lines: Vec<TrialBalanceLine> = consolidated_balances.into_values().collect();
340 if self.config.sort_by_account_code {
341 lines.sort_by(|a, b| a.account_code.cmp(&b.account_code));
342 }
343
344 let total_debits: Decimal = lines.iter().map(|l| l.debit_balance).sum();
345 let total_credits: Decimal = lines.iter().map(|l| l.credit_balance).sum();
346
347 let category_summary = if self.config.group_by_category {
348 self.calculate_category_summary(&lines)
349 } else {
350 HashMap::new()
351 };
352
353 let as_of_date = trial_balances
354 .first()
355 .map(|tb| tb.as_of_date)
356 .unwrap_or_else(|| chrono::Local::now().date_naive());
357
358 let fiscal_year = trial_balances.first().map(|tb| tb.fiscal_year).unwrap_or(0);
359 let fiscal_period = trial_balances
360 .first()
361 .map(|tb| tb.fiscal_period)
362 .unwrap_or(0);
363
364 let currency = trial_balances
365 .first()
366 .map(|tb| tb.currency.clone())
367 .unwrap_or_else(|| "USD".to_string());
368
369 let out_of_balance = total_debits - total_credits;
370
371 let mut tb = TrialBalance {
372 trial_balance_id: format!(
373 "TB-CONS-{}-{}-{:02}",
374 consolidated_company_code, fiscal_year, fiscal_period
375 ),
376 company_code: consolidated_company_code.to_string(),
377 company_name: None,
378 as_of_date,
379 fiscal_year,
380 fiscal_period,
381 currency,
382 balance_type: TrialBalanceType::Consolidated,
383 lines,
384 total_debits,
385 total_credits,
386 is_balanced: out_of_balance.abs() < dec!(0.01),
387 out_of_balance,
388 is_equation_valid: false, equation_difference: Decimal::ZERO, category_summary,
391 created_at: chrono::Utc::now().naive_utc(),
392 created_by: format!(
393 "TrialBalanceGenerator (Consolidated from {} companies)",
394 trial_balances.len()
395 ),
396 approved_by: None,
397 approved_at: None,
398 status: TrialBalanceStatus::Draft,
399 };
400
401 let (is_valid, _assets, _liabilities, _equity, diff) = tb.validate_accounting_equation();
403 tb.is_equation_valid = is_valid;
404 tb.equation_difference = diff;
405
406 tb
407 }
408
409 fn split_balance(&self, balance: &AccountBalance) -> (Decimal, Decimal) {
411 let closing = balance.closing_balance;
412
413 match balance.account_type {
415 AccountType::Asset | AccountType::Expense => {
416 if closing >= Decimal::ZERO {
417 (closing, Decimal::ZERO)
418 } else {
419 (Decimal::ZERO, closing.abs())
420 }
421 }
422 AccountType::ContraAsset | AccountType::ContraLiability | AccountType::ContraEquity => {
423 if closing >= Decimal::ZERO {
425 (Decimal::ZERO, closing)
426 } else {
427 (closing.abs(), Decimal::ZERO)
428 }
429 }
430 AccountType::Liability | AccountType::Equity | AccountType::Revenue => {
431 if closing >= Decimal::ZERO {
432 (Decimal::ZERO, closing)
433 } else {
434 (closing.abs(), Decimal::ZERO)
435 }
436 }
437 }
438 }
439
440 fn determine_category(&self, account_code: &str) -> AccountCategory {
442 if let Some(category) = self.category_mappings.get(account_code) {
444 return *category;
445 }
446
447 let prefix: u32 = account_code
449 .chars()
450 .take(2)
451 .collect::<String>()
452 .parse()
453 .unwrap_or(0);
454
455 match prefix {
456 10..=14 => AccountCategory::CurrentAssets,
457 15..=19 => AccountCategory::NonCurrentAssets,
458 20..=24 => AccountCategory::CurrentLiabilities,
459 25..=29 => AccountCategory::NonCurrentLiabilities,
460 30..=39 => AccountCategory::Equity,
461 40..=44 => AccountCategory::Revenue,
462 50..=54 => AccountCategory::CostOfGoodsSold,
463 55..=69 => AccountCategory::OperatingExpenses,
464 70..=74 => AccountCategory::OtherIncome,
465 75..=99 => AccountCategory::OtherExpenses,
466 _ => AccountCategory::OtherExpenses,
467 }
468 }
469
470 fn calculate_category_summary(
472 &self,
473 lines: &[TrialBalanceLine],
474 ) -> HashMap<AccountCategory, CategorySummary> {
475 let mut summaries: HashMap<AccountCategory, CategorySummary> = HashMap::new();
476
477 for line in lines {
478 let summary = summaries
479 .entry(line.category)
480 .or_insert_with(|| CategorySummary::new(line.category));
481
482 summary.add_balance(line.debit_balance, line.credit_balance);
483 }
484
485 summaries
486 }
487
488 fn calculate_period_variances(
490 &self,
491 periods: &[TrialBalance],
492 ) -> HashMap<String, Vec<Decimal>> {
493 let mut variances: HashMap<String, Vec<Decimal>> = HashMap::new();
494
495 if periods.len() < 2 {
496 return variances;
497 }
498
499 let mut all_accounts: Vec<String> = periods
501 .iter()
502 .flat_map(|p| p.lines.iter().map(|l| l.account_code.clone()))
503 .collect();
504 all_accounts.sort();
505 all_accounts.dedup();
506
507 for account in all_accounts {
509 let mut period_variances = Vec::new();
510
511 for i in 1..periods.len() {
512 let current = periods[i]
513 .lines
514 .iter()
515 .find(|l| l.account_code == account)
516 .map(|l| l.closing_balance)
517 .unwrap_or_default();
518
519 let previous = periods[i - 1]
520 .lines
521 .iter()
522 .find(|l| l.account_code == account)
523 .map(|l| l.closing_balance)
524 .unwrap_or_default();
525
526 period_variances.push(current - previous);
527 }
528
529 variances.insert(account, period_variances);
530 }
531
532 variances
533 }
534
535 pub fn finalize(&self, mut trial_balance: TrialBalance) -> TrialBalance {
537 trial_balance.status = TrialBalanceStatus::Final;
538 trial_balance
539 }
540
541 pub fn approve(&self, mut trial_balance: TrialBalance, approver: &str) -> TrialBalance {
543 trial_balance.status = TrialBalanceStatus::Approved;
544 trial_balance.approved_by = Some(approver.to_string());
545 trial_balance.approved_at = Some(chrono::Utc::now().naive_utc());
546 trial_balance
547 }
548}
549
/// Fluent builder that collects balance snapshots and turns them into
/// individual or consolidated trial balances for one fiscal period.
pub struct TrialBalanceBuilder {
    /// Generator used to produce the trial balances.
    generator: TrialBalanceGenerator,
    /// Collected `(company code, snapshot)` pairs, in insertion order.
    // NOTE(review): the stored company code is not read by the build methods
    // in this file; they use the code carried by the snapshot itself.
    snapshots: Vec<(String, BalanceSnapshot)>,
    /// Fiscal year applied to every generated trial balance.
    fiscal_year: i32,
    /// Fiscal period applied to every generated trial balance.
    fiscal_period: u32,
}
557
558impl TrialBalanceBuilder {
559 pub fn new(fiscal_year: i32, fiscal_period: u32) -> Self {
561 Self {
562 generator: TrialBalanceGenerator::with_defaults(),
563 snapshots: Vec::new(),
564 fiscal_year,
565 fiscal_period,
566 }
567 }
568
569 pub fn add_snapshot(mut self, company_code: &str, snapshot: BalanceSnapshot) -> Self {
571 self.snapshots.push((company_code.to_string(), snapshot));
572 self
573 }
574
575 pub fn with_config(mut self, config: TrialBalanceConfig) -> Self {
577 self.generator = TrialBalanceGenerator::new(config);
578 self
579 }
580
581 pub fn build(self) -> Vec<TrialBalance> {
583 self.snapshots
584 .iter()
585 .map(|(_, snapshot)| {
586 self.generator.generate_from_snapshot(
587 snapshot,
588 self.fiscal_year,
589 self.fiscal_period,
590 )
591 })
592 .collect()
593 }
594
595 pub fn build_consolidated(self, consolidated_code: &str) -> TrialBalance {
597 let individual = self
598 .snapshots
599 .iter()
600 .map(|(_, snapshot)| {
601 self.generator.generate_from_snapshot(
602 snapshot,
603 self.fiscal_year,
604 self.fiscal_period,
605 )
606 })
607 .collect::<Vec<_>>();
608
609 self.generator
610 .generate_consolidated(&individual, consolidated_code)
611 }
612}
613
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a USD balance for fiscal 2024 / period 1 whose opening and
    /// closing amounts both equal `opening`.
    fn create_test_balance(
        company: &str,
        account: &str,
        acct_type: AccountType,
        opening: Decimal,
    ) -> AccountBalance {
        let mut balance = AccountBalance::new(
            company.to_string(),
            account.to_string(),
            acct_type,
            "USD".to_string(),
            2024,
            1,
        );
        balance.opening_balance = opening;
        balance.closing_balance = opening;
        balance
    }

    /// Snapshot for company "TEST" with one asset (10000), one liability
    /// (5000) and one equity account (5000), so debits equal credits.
    fn create_test_snapshot() -> BalanceSnapshot {
        let mut snapshot = BalanceSnapshot::new(
            "SNAP-TEST-2024-01".to_string(),
            "TEST".to_string(),
            NaiveDate::from_ymd_opt(2024, 1, 31).unwrap(),
            2024,
            1,
            "USD".to_string(),
        );

        let fixtures = [
            ("1100", AccountType::Asset, dec!(10000)),
            ("2100", AccountType::Liability, dec!(5000)),
            ("3100", AccountType::Equity, dec!(5000)),
        ];
        for &(code, acct_type, amount) in fixtures.iter() {
            snapshot.balances.insert(
                code.to_string(),
                create_test_balance("TEST", code, acct_type, amount),
            );
        }

        snapshot.recalculate_totals();
        snapshot
    }

    #[test]
    fn test_generate_trial_balance() {
        let snapshot = create_test_snapshot();
        let generator = TrialBalanceGenerator::with_defaults();

        let trial_balance = generator.generate_from_snapshot(&snapshot, 2024, 1);

        assert!(trial_balance.is_balanced);
        assert_eq!(trial_balance.lines.len(), 3);
        assert_eq!(trial_balance.total_debits, dec!(10000));
        assert_eq!(trial_balance.total_credits, dec!(10000));
    }

    #[test]
    fn test_category_summaries() {
        let snapshot = create_test_snapshot();
        let generator = TrialBalanceGenerator::with_defaults();

        let trial_balance = generator.generate_from_snapshot(&snapshot, 2024, 1);

        // Default configuration groups by category, so a summary must exist.
        assert!(!trial_balance.category_summary.is_empty());
    }

    #[test]
    fn test_consolidated_trial_balance() {
        let generator = TrialBalanceGenerator::with_defaults();

        let first = create_test_snapshot();
        let mut second = BalanceSnapshot::new(
            "SNAP-TEST2-2024-01".to_string(),
            "TEST2".to_string(),
            first.as_of_date,
            2024,
            1,
            "USD".to_string(),
        );

        // Second company mirrors the first with every balance doubled.
        for (code, balance) in first.balances.iter() {
            let mut doubled = balance.clone();
            doubled.company_code = "TEST2".to_string();
            doubled.closing_balance *= dec!(2);
            doubled.opening_balance *= dec!(2);
            second.balances.insert(code.clone(), doubled);
        }
        second.recalculate_totals();

        let tb_first = generator.generate_from_snapshot(&first, 2024, 1);
        let tb_second = generator.generate_from_snapshot(&second, 2024, 1);

        let consolidated = generator.generate_consolidated(&[tb_first, tb_second], "CONSOL");

        assert_eq!(consolidated.company_code, "CONSOL");
        assert!(consolidated.is_balanced);
    }

    #[test]
    fn test_builder_pattern() {
        let builder = TrialBalanceBuilder::new(2024, 1).add_snapshot("TEST", create_test_snapshot());

        let trial_balances = builder.build();

        assert_eq!(trial_balances.len(), 1);
        assert!(trial_balances[0].is_balanced);
    }
}