use crate::error::{EvalError, EvalResult};
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Aggregate outcome of checking a batch of balance snapshots.
///
/// Produced by `BalanceSheetEvaluator::evaluate`, which evaluates every
/// snapshot individually and then summarises the per-period results here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BalanceSheetEvaluation {
/// `true` only when no evaluated period was flagged as imbalanced.
pub equation_balanced: bool,
/// Largest absolute imbalance observed across all evaluated periods.
pub max_imbalance: Decimal,
/// Number of snapshots that were evaluated.
pub periods_evaluated: usize,
/// Number of evaluated snapshots whose result was not balanced.
pub periods_imbalanced: usize,
/// One result per input snapshot, in the order the snapshots were supplied.
pub period_results: Vec<PeriodBalanceResult>,
/// Number of distinct company codes among the evaluated snapshots.
pub companies_evaluated: usize,
}
/// Balance check result for a single company and fiscal period.
///
/// The identifying fields are copied from the evaluated `BalanceSnapshot`; the
/// totals are derived from that snapshot's balances by the evaluator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PeriodBalanceResult {
/// Company code copied from the snapshot.
pub company_code: String,
/// Fiscal year copied from the snapshot.
pub fiscal_year: u16,
/// Fiscal period copied from the snapshot.
pub fiscal_period: u8,
/// Asset balance minus contra-asset balance.
pub total_assets: Decimal,
/// Liability balance minus contra-liability balance.
pub total_liabilities: Decimal,
/// Equity balance minus contra-equity balance.
pub total_equity: Decimal,
/// Revenue balance minus expense balance.
pub net_income: Decimal,
/// `total_assets - (total_liabilities + total_equity + net_income)`; signed.
pub imbalance: Decimal,
/// `true` when the absolute imbalance is within the evaluator's tolerance.
pub is_balanced: bool,
}
/// Input to the evaluator: balances per account type for one company and
/// fiscal period.
#[derive(Debug, Clone)]
pub struct BalanceSnapshot {
/// Code identifying the company the balances belong to.
pub company_code: String,
/// Fiscal year of the snapshot.
pub fiscal_year: u16,
/// Fiscal period of the snapshot.
pub fiscal_period: u8,
/// Balance per account type. Account types that are absent from the map are
/// treated as zero by `BalanceSheetEvaluator`.
pub balances: HashMap<AccountType, Decimal>,
}
/// Account classification used as the key of `BalanceSnapshot::balances`.
///
/// When evaluating a snapshot, each `Contra*` balance is subtracted from the
/// balance of its corresponding base type, and `Expense` is subtracted from
/// `Revenue` to obtain net income.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AccountType {
/// Contributes positively to total assets.
Asset,
/// Subtracted from `Asset` when computing total assets.
ContraAsset,
/// Contributes positively to total liabilities.
Liability,
/// Subtracted from `Liability` when computing total liabilities.
ContraLiability,
/// Contributes positively to total equity.
Equity,
/// Subtracted from `Equity` when computing total equity.
ContraEquity,
/// Contributes positively to net income.
Revenue,
/// Subtracted from `Revenue` when computing net income.
Expense,
}
/// Checks balance snapshots against the accounting equation
/// `assets = liabilities + equity + net income`.
///
/// A period counts as balanced when the absolute difference between the two
/// sides of the equation does not exceed the configured tolerance.
///
/// `Debug` and `Clone` are derived so the evaluator can be logged and embedded
/// in other derivable types, matching the other public types in this module.
#[derive(Debug, Clone)]
pub struct BalanceSheetEvaluator {
    /// Maximum absolute imbalance that is still considered balanced (inclusive).
    tolerance: Decimal,
}
impl BalanceSheetEvaluator {
pub fn new(tolerance: Decimal) -> Self {
Self { tolerance }
}
pub fn evaluate(&self, snapshots: &[BalanceSnapshot]) -> EvalResult<BalanceSheetEvaluation> {
if snapshots.is_empty() {
return Err(EvalError::MissingData(
"No balance snapshots provided".to_string(),
));
}
let mut period_results = Vec::new();
let mut max_imbalance = Decimal::ZERO;
let mut periods_imbalanced = 0;
let companies: std::collections::HashSet<_> =
snapshots.iter().map(|s| &s.company_code).collect();
for snapshot in snapshots {
let result = self.evaluate_snapshot(snapshot);
if !result.is_balanced {
periods_imbalanced += 1;
}
if result.imbalance.abs() > max_imbalance {
max_imbalance = result.imbalance.abs();
}
period_results.push(result);
}
let equation_balanced = periods_imbalanced == 0;
Ok(BalanceSheetEvaluation {
equation_balanced,
max_imbalance,
periods_evaluated: snapshots.len(),
periods_imbalanced,
period_results,
companies_evaluated: companies.len(),
})
}
fn evaluate_snapshot(&self, snapshot: &BalanceSnapshot) -> PeriodBalanceResult {
let get_balance = |account_type: AccountType| {
snapshot
.balances
.get(&account_type)
.copied()
.unwrap_or(Decimal::ZERO)
};
let total_assets = get_balance(AccountType::Asset) - get_balance(AccountType::ContraAsset);
let total_liabilities =
get_balance(AccountType::Liability) - get_balance(AccountType::ContraLiability);
let total_equity =
get_balance(AccountType::Equity) - get_balance(AccountType::ContraEquity);
let net_income = get_balance(AccountType::Revenue) - get_balance(AccountType::Expense);
let imbalance = total_assets - (total_liabilities + total_equity + net_income);
let is_balanced = imbalance.abs() <= self.tolerance;
PeriodBalanceResult {
company_code: snapshot.company_code.clone(),
fiscal_year: snapshot.fiscal_year,
fiscal_period: snapshot.fiscal_period,
total_assets,
total_liabilities,
total_equity,
net_income,
imbalance,
is_balanced,
}
}
}
impl Default for BalanceSheetEvaluator {
    /// Builds an evaluator with a tolerance of `Decimal::new(1, 2)`, i.e. 0.01.
    fn default() -> Self {
        let tolerance = Decimal::new(1, 2);
        Self { tolerance }
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Snapshot for company "1000", fiscal year 2024, period 1, whose balances
    /// satisfy the equation: 1000.00 = 400.00 + 500.00 + (200.00 - 100.00).
    fn create_balanced_snapshot() -> BalanceSnapshot {
        // Amounts are given in hundredths and converted with a scale of 2.
        let entries = [
            (AccountType::Asset, 100000),
            (AccountType::Liability, 40000),
            (AccountType::Equity, 50000),
            (AccountType::Revenue, 20000),
            (AccountType::Expense, 10000),
        ];
        let balances: HashMap<AccountType, Decimal> = entries
            .iter()
            .map(|&(account_type, hundredths)| (account_type, Decimal::new(hundredths, 2)))
            .collect();

        BalanceSnapshot {
            company_code: "1000".to_string(),
            fiscal_year: 2024,
            fiscal_period: 1,
            balances,
        }
    }

    /// A snapshot that satisfies the equation is reported as balanced.
    #[test]
    fn test_balanced_snapshot() {
        let snapshots = [create_balanced_snapshot()];
        let result = BalanceSheetEvaluator::default()
            .evaluate(&snapshots)
            .unwrap();

        assert!(result.equation_balanced);
        assert_eq!(result.periods_imbalanced, 0);
    }

    /// Raising assets to 1100.00 breaks the equation and must be detected.
    #[test]
    fn test_imbalanced_snapshot() {
        let mut snapshot = create_balanced_snapshot();
        let inflated_assets = Decimal::new(110000, 2);
        snapshot.balances.insert(AccountType::Asset, inflated_assets);

        let snapshots = [snapshot];
        let result = BalanceSheetEvaluator::default()
            .evaluate(&snapshots)
            .unwrap();

        assert!(!result.equation_balanced);
        assert_eq!(result.periods_imbalanced, 1);
        assert!(result.max_imbalance > Decimal::ZERO);
    }

    /// An empty input slice is rejected with `EvalError::MissingData`.
    #[test]
    fn test_empty_snapshots() {
        let no_snapshots: &[BalanceSnapshot] = &[];
        let result = BalanceSheetEvaluator::default().evaluate(no_snapshots);

        assert!(matches!(result, Err(EvalError::MissingData(_))));
    }
}