datasynth_eval/coherence/
balance.rs1use crate::error::{EvalError, EvalResult};
6use rust_decimal::Decimal;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct BalanceSheetEvaluation {
13 pub equation_balanced: bool,
15 pub max_imbalance: Decimal,
17 pub periods_evaluated: usize,
19 pub periods_imbalanced: usize,
21 pub period_results: Vec<PeriodBalanceResult>,
23 pub companies_evaluated: usize,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct PeriodBalanceResult {
30 pub company_code: String,
32 pub fiscal_year: u16,
34 pub fiscal_period: u8,
36 pub total_assets: Decimal,
38 pub total_liabilities: Decimal,
40 pub total_equity: Decimal,
42 pub net_income: Decimal,
44 pub imbalance: Decimal,
46 pub is_balanced: bool,
48}
49
50#[derive(Debug, Clone)]
52pub struct BalanceSnapshot {
53 pub company_code: String,
55 pub fiscal_year: u16,
57 pub fiscal_period: u8,
59 pub balances: HashMap<AccountType, Decimal>,
61}
62
/// Broad account categories used to key snapshot balances.
///
/// Each `Contra*` variant is subtracted from its parent category when
/// totals are computed. `Revenue` minus `Expense` forms net income.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum AccountType {
    /// Asset accounts.
    Asset,
    /// Offsets to assets; reduces total assets.
    ContraAsset,
    /// Liability accounts.
    Liability,
    /// Offsets to liabilities; reduces total liabilities.
    ContraLiability,
    /// Equity accounts.
    Equity,
    /// Offsets to equity; reduces total equity.
    ContraEquity,
    /// Revenue accounts; add to net income.
    Revenue,
    /// Expense accounts; subtract from net income.
    Expense,
}
75
76pub struct BalanceSheetEvaluator {
78 tolerance: Decimal,
80}
81
82impl BalanceSheetEvaluator {
83 pub fn new(tolerance: Decimal) -> Self {
85 Self { tolerance }
86 }
87
88 pub fn evaluate(&self, snapshots: &[BalanceSnapshot]) -> EvalResult<BalanceSheetEvaluation> {
90 if snapshots.is_empty() {
91 return Err(EvalError::MissingData(
92 "No balance snapshots provided".to_string(),
93 ));
94 }
95
96 let mut period_results = Vec::new();
97 let mut max_imbalance = Decimal::ZERO;
98 let mut periods_imbalanced = 0;
99 let companies: std::collections::HashSet<_> =
100 snapshots.iter().map(|s| &s.company_code).collect();
101
102 for snapshot in snapshots {
103 let result = self.evaluate_snapshot(snapshot);
104 if !result.is_balanced {
105 periods_imbalanced += 1;
106 }
107 if result.imbalance.abs() > max_imbalance {
108 max_imbalance = result.imbalance.abs();
109 }
110 period_results.push(result);
111 }
112
113 let equation_balanced = periods_imbalanced == 0;
114
115 Ok(BalanceSheetEvaluation {
116 equation_balanced,
117 max_imbalance,
118 periods_evaluated: snapshots.len(),
119 periods_imbalanced,
120 period_results,
121 companies_evaluated: companies.len(),
122 })
123 }
124
125 fn evaluate_snapshot(&self, snapshot: &BalanceSnapshot) -> PeriodBalanceResult {
127 let get_balance = |account_type: AccountType| {
128 snapshot
129 .balances
130 .get(&account_type)
131 .copied()
132 .unwrap_or(Decimal::ZERO)
133 };
134
135 let total_assets = get_balance(AccountType::Asset) - get_balance(AccountType::ContraAsset);
137 let total_liabilities =
138 get_balance(AccountType::Liability) - get_balance(AccountType::ContraLiability);
139 let total_equity =
140 get_balance(AccountType::Equity) - get_balance(AccountType::ContraEquity);
141 let net_income = get_balance(AccountType::Revenue) - get_balance(AccountType::Expense);
142
143 let imbalance = total_assets - (total_liabilities + total_equity + net_income);
145 let is_balanced = imbalance.abs() <= self.tolerance;
146
147 PeriodBalanceResult {
148 company_code: snapshot.company_code.clone(),
149 fiscal_year: snapshot.fiscal_year,
150 fiscal_period: snapshot.fiscal_period,
151 total_assets,
152 total_liabilities,
153 total_equity,
154 net_income,
155 imbalance,
156 is_balanced,
157 }
158 }
159}
160
161impl Default for BalanceSheetEvaluator {
162 fn default() -> Self {
163 Self::new(Decimal::new(1, 2)) }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-period snapshot for company `1000` that satisfies
    /// Assets (1000.00) = Liabilities (400.00) + Equity (500.00)
    /// + Revenue (200.00) - Expense (100.00).
    fn create_balanced_snapshot() -> BalanceSnapshot {
        let mut balances = HashMap::new();
        balances.insert(AccountType::Asset, Decimal::new(100000, 2));
        balances.insert(AccountType::Liability, Decimal::new(40000, 2));
        balances.insert(AccountType::Equity, Decimal::new(50000, 2));
        balances.insert(AccountType::Revenue, Decimal::new(20000, 2));
        balances.insert(AccountType::Expense, Decimal::new(10000, 2));
        BalanceSnapshot {
            company_code: "1000".to_string(),
            fiscal_year: 2024,
            fiscal_period: 1,
            balances,
        }
    }

    #[test]
    fn test_balanced_snapshot() {
        let evaluator = BalanceSheetEvaluator::default();
        let snapshot = create_balanced_snapshot();
        let result = evaluator.evaluate(&[snapshot]).unwrap();

        assert!(result.equation_balanced);
        assert_eq!(result.periods_imbalanced, 0);
        assert_eq!(result.periods_evaluated, 1);
        assert_eq!(result.companies_evaluated, 1);
        assert_eq!(result.max_imbalance, Decimal::ZERO);
    }

    #[test]
    fn test_imbalanced_snapshot() {
        let mut snapshot = create_balanced_snapshot();
        snapshot
            .balances
            .insert(AccountType::Asset, Decimal::new(110000, 2));
        let evaluator = BalanceSheetEvaluator::default();
        let result = evaluator.evaluate(&[snapshot]).unwrap();

        assert!(!result.equation_balanced);
        assert_eq!(result.periods_imbalanced, 1);
        assert!(result.max_imbalance > Decimal::ZERO);
        assert_eq!(result.max_imbalance, Decimal::new(10000, 2));
        assert_eq!(result.period_results[0].imbalance, Decimal::new(10000, 2));
    }

    #[test]
    fn test_negative_imbalance_reports_absolute_maximum() {
        // Assets below the claims side: the per-period imbalance is negative,
        // but the aggregate maximum is reported as an absolute value.
        let mut snapshot = create_balanced_snapshot();
        snapshot
            .balances
            .insert(AccountType::Asset, Decimal::new(90000, 2));
        let result = BalanceSheetEvaluator::default()
            .evaluate(&[snapshot])
            .unwrap();

        assert!(!result.equation_balanced);
        assert_eq!(result.period_results[0].imbalance, Decimal::new(-10000, 2));
        assert_eq!(result.max_imbalance, Decimal::new(10000, 2));
    }

    #[test]
    fn test_contra_accounts_reduce_parent_totals() {
        let mut snapshot = create_balanced_snapshot();
        snapshot
            .balances
            .insert(AccountType::ContraAsset, Decimal::new(10000, 2));
        snapshot
            .balances
            .insert(AccountType::ContraEquity, Decimal::new(10000, 2));
        let result = BalanceSheetEvaluator::default()
            .evaluate(&[snapshot])
            .unwrap();

        let period = &result.period_results[0];
        assert_eq!(period.total_assets, Decimal::new(90000, 2));
        assert_eq!(period.total_equity, Decimal::new(40000, 2));
        assert!(period.is_balanced);
        assert!(result.equation_balanced);
    }

    #[test]
    fn test_tolerance_boundary_is_inclusive() {
        let evaluator = BalanceSheetEvaluator::default();

        let mut at_limit = create_balanced_snapshot();
        at_limit
            .balances
            .insert(AccountType::Asset, Decimal::new(100001, 2));
        assert!(evaluator.evaluate(&[at_limit]).unwrap().equation_balanced);

        let mut over_limit = create_balanced_snapshot();
        over_limit
            .balances
            .insert(AccountType::Asset, Decimal::new(100002, 2));
        assert!(!evaluator.evaluate(&[over_limit]).unwrap().equation_balanced);
    }

    #[test]
    fn test_custom_tolerance() {
        let mut snapshot = create_balanced_snapshot();
        snapshot
            .balances
            .insert(AccountType::Asset, Decimal::new(110000, 2));
        let evaluator = BalanceSheetEvaluator::new(Decimal::new(100, 0));
        let result = evaluator.evaluate(&[snapshot]).unwrap();

        assert!(result.equation_balanced);
        assert_eq!(result.max_imbalance, Decimal::new(10000, 2));
    }

    #[test]
    fn test_counts_distinct_companies_and_preserves_order() {
        let first = create_balanced_snapshot();
        let mut second_period = create_balanced_snapshot();
        second_period.fiscal_period = 2;
        let mut other_company = create_balanced_snapshot();
        other_company.company_code = "2000".to_string();

        let result = BalanceSheetEvaluator::default()
            .evaluate(&[first, second_period, other_company])
            .unwrap();

        assert_eq!(result.periods_evaluated, 3);
        assert_eq!(result.companies_evaluated, 2);
        assert_eq!(result.period_results.len(), 3);
        assert_eq!(result.period_results[1].fiscal_period, 2);
        assert_eq!(result.period_results[2].company_code, "2000");
    }

    #[test]
    fn test_empty_snapshots() {
        let evaluator = BalanceSheetEvaluator::default();
        let result = evaluator.evaluate(&[]);
        assert!(matches!(result, Err(EvalError::MissingData(_))));
    }
}