Skip to main content

datasynth_eval/coherence/
balance.rs

1//! Balance sheet equation validation.
2//!
3//! Validates that Assets = Liabilities + Equity + Net Income across all periods.
4
5use crate::error::{EvalError, EvalResult};
6use rust_decimal::Decimal;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// Results of balance sheet evaluation.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct BalanceSheetEvaluation {
13    /// Whether the balance sheet equation holds.
14    pub equation_balanced: bool,
15    /// Maximum imbalance observed across all periods.
16    pub max_imbalance: Decimal,
17    /// Number of periods evaluated.
18    pub periods_evaluated: usize,
19    /// Number of periods with imbalance.
20    pub periods_imbalanced: usize,
21    /// Per-period results.
22    pub period_results: Vec<PeriodBalanceResult>,
23    /// Companies evaluated.
24    pub companies_evaluated: usize,
25}
26
27/// Balance result for a single period.
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct PeriodBalanceResult {
30    /// Company code.
31    pub company_code: String,
32    /// Fiscal year.
33    pub fiscal_year: u16,
34    /// Fiscal period (month).
35    pub fiscal_period: u8,
36    /// Total assets.
37    pub total_assets: Decimal,
38    /// Total liabilities.
39    pub total_liabilities: Decimal,
40    /// Total equity.
41    pub total_equity: Decimal,
42    /// Net income (Revenue - Expenses).
43    pub net_income: Decimal,
44    /// Imbalance amount (should be zero).
45    pub imbalance: Decimal,
46    /// Whether this period is balanced.
47    pub is_balanced: bool,
48}
49
50/// Input for balance sheet evaluation.
51#[derive(Debug, Clone)]
52pub struct BalanceSnapshot {
53    /// Company code.
54    pub company_code: String,
55    /// Fiscal year.
56    pub fiscal_year: u16,
57    /// Fiscal period.
58    pub fiscal_period: u8,
59    /// Account balances by account type.
60    pub balances: HashMap<AccountType, Decimal>,
61}
62
/// Account categories that participate in the balance sheet equation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AccountType {
    /// Base balance for total assets.
    Asset,
    /// Subtracted from `Asset` when computing total assets.
    ContraAsset,
    /// Base balance for total liabilities.
    Liability,
    /// Subtracted from `Liability` when computing total liabilities.
    ContraLiability,
    /// Base balance for total equity.
    Equity,
    /// Subtracted from `Equity` when computing total equity.
    ContraEquity,
    /// Positive component of net income.
    Revenue,
    /// Subtracted from `Revenue` when computing net income.
    Expense,
}
75
76/// Evaluator for balance sheet equations.
77pub struct BalanceSheetEvaluator {
78    /// Tolerance for balance differences.
79    tolerance: Decimal,
80}
81
82impl BalanceSheetEvaluator {
83    /// Create a new evaluator with the specified tolerance.
84    pub fn new(tolerance: Decimal) -> Self {
85        Self { tolerance }
86    }
87
88    /// Evaluate balance sheet equation across all snapshots.
89    pub fn evaluate(&self, snapshots: &[BalanceSnapshot]) -> EvalResult<BalanceSheetEvaluation> {
90        if snapshots.is_empty() {
91            return Err(EvalError::MissingData(
92                "No balance snapshots provided".to_string(),
93            ));
94        }
95
96        let mut period_results = Vec::new();
97        let mut max_imbalance = Decimal::ZERO;
98        let mut periods_imbalanced = 0;
99        let companies: std::collections::HashSet<_> =
100            snapshots.iter().map(|s| &s.company_code).collect();
101
102        for snapshot in snapshots {
103            let result = self.evaluate_snapshot(snapshot);
104            if !result.is_balanced {
105                periods_imbalanced += 1;
106            }
107            if result.imbalance.abs() > max_imbalance {
108                max_imbalance = result.imbalance.abs();
109            }
110            period_results.push(result);
111        }
112
113        let equation_balanced = periods_imbalanced == 0;
114
115        Ok(BalanceSheetEvaluation {
116            equation_balanced,
117            max_imbalance,
118            periods_evaluated: snapshots.len(),
119            periods_imbalanced,
120            period_results,
121            companies_evaluated: companies.len(),
122        })
123    }
124
125    /// Evaluate a single balance snapshot.
126    fn evaluate_snapshot(&self, snapshot: &BalanceSnapshot) -> PeriodBalanceResult {
127        let get_balance = |account_type: AccountType| {
128            snapshot
129                .balances
130                .get(&account_type)
131                .copied()
132                .unwrap_or(Decimal::ZERO)
133        };
134
135        // Calculate totals
136        let total_assets = get_balance(AccountType::Asset) - get_balance(AccountType::ContraAsset);
137        let total_liabilities =
138            get_balance(AccountType::Liability) - get_balance(AccountType::ContraLiability);
139        let total_equity =
140            get_balance(AccountType::Equity) - get_balance(AccountType::ContraEquity);
141        let net_income = get_balance(AccountType::Revenue) - get_balance(AccountType::Expense);
142
143        // Balance equation: Assets = Liabilities + Equity + Net Income
144        let imbalance = total_assets - (total_liabilities + total_equity + net_income);
145        let is_balanced = imbalance.abs() <= self.tolerance;
146
147        PeriodBalanceResult {
148            company_code: snapshot.company_code.clone(),
149            fiscal_year: snapshot.fiscal_year,
150            fiscal_period: snapshot.fiscal_period,
151            total_assets,
152            total_liabilities,
153            total_equity,
154            net_income,
155            imbalance,
156            is_balanced,
157        }
158    }
159}
160
161impl Default for BalanceSheetEvaluator {
162    fn default() -> Self {
163        Self::new(Decimal::new(1, 2)) // 0.01 tolerance
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Snapshot for company "1000", fiscal year 2024, period 1, that satisfies the equation:
    /// assets 1000.00 = liabilities 400.00 + equity 500.00 + net income (200.00 - 100.00).
    fn create_balanced_snapshot() -> BalanceSnapshot {
        let balances: HashMap<AccountType, Decimal> = [
            (AccountType::Asset, Decimal::new(100000, 2)),
            (AccountType::Liability, Decimal::new(40000, 2)),
            (AccountType::Equity, Decimal::new(50000, 2)),
            (AccountType::Revenue, Decimal::new(20000, 2)),
            (AccountType::Expense, Decimal::new(10000, 2)),
        ]
        .iter()
        .copied()
        .collect();

        BalanceSnapshot {
            company_code: String::from("1000"),
            fiscal_year: 2024,
            fiscal_period: 1,
            balances,
        }
    }

    #[test]
    fn test_balanced_snapshot() {
        let outcome = BalanceSheetEvaluator::default()
            .evaluate(&[create_balanced_snapshot()])
            .unwrap();

        assert!(outcome.equation_balanced);
        assert_eq!(outcome.periods_imbalanced, 0);
    }

    #[test]
    fn test_imbalanced_snapshot() {
        // Replace assets with 1100.00 so they exceed the other side by 100.00.
        let mut skewed = create_balanced_snapshot();
        skewed
            .balances
            .insert(AccountType::Asset, Decimal::new(110000, 2));

        let outcome = BalanceSheetEvaluator::default()
            .evaluate(&[skewed])
            .unwrap();

        assert!(!outcome.equation_balanced);
        assert_eq!(outcome.periods_imbalanced, 1);
        assert!(outcome.max_imbalance > Decimal::ZERO);
    }

    #[test]
    fn test_empty_snapshots() {
        let no_snapshots: &[BalanceSnapshot] = &[];
        let outcome = BalanceSheetEvaluator::default().evaluate(no_snapshots);

        assert!(matches!(outcome, Err(EvalError::MissingData(_))));
    }
}