use std::collections::BTreeMap;
use std::fs;
use std::path::Path;
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use datasynth_core::models::balance::{AccountType, TrialBalance};
use crate::errors::{GroupError, GroupResult};
/// Opening balance of a single account, split into its debit and credit side.
///
/// Values of this type are produced by [`extract_opening_balances`], which
/// copies each field from one line of a [`TrialBalance`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OpeningBalance {
/// Account code, copied from the source trial-balance line.
pub account_code: String,
/// Account type, copied from the source trial-balance line.
pub account_type: AccountType,
/// Debit-side amount (the line's `debit_balance`).
pub debit: Decimal,
/// Credit-side amount (the line's `credit_balance`).
pub credit: Decimal,
}
/// Loads each entity's closing trial balance from a prior-period directory.
///
/// For every code in `entity_codes` the file
/// `<prior_period_dir>/entities/<code>/period_close/trial_balances.json` is
/// read and parsed as a JSON array of [`TrialBalance`]. The element with the
/// greatest `(fiscal_year, fiscal_period)` pair is kept and stored in the
/// returned map under the entity code.
///
/// If `prior_period_dir` does not exist the result is an empty map, and an
/// entity whose file does not exist is left out of the map.
///
/// # Errors
///
/// Returns [`GroupError::Aggregate`] when an existing file cannot be read,
/// cannot be parsed as `Vec<TrialBalance>`, or contains an empty array.
pub fn read_prior_period_closing_tbs(
    prior_period_dir: &Path,
    entity_codes: &[String],
) -> GroupResult<BTreeMap<String, TrialBalance>> {
    let mut closing_by_entity: BTreeMap<String, TrialBalance> = BTreeMap::new();
    if !prior_period_dir.exists() {
        return Ok(closing_by_entity);
    }

    // Every entity file lives under the same `entities` sub-directory.
    let entities_root = prior_period_dir.join("entities");

    for code in entity_codes {
        let tb_file = entities_root
            .join(code)
            .join("period_close")
            .join("trial_balances.json");
        if !tb_file.exists() {
            continue;
        }

        let raw = match fs::read(&tb_file) {
            Ok(raw) => raw,
            Err(e) => {
                return Err(GroupError::Aggregate(format!(
                    "read_prior_period_closing_tbs: cannot read `{}`: {e}",
                    tb_file.display()
                )));
            }
        };

        let parsed = match serde_json::from_slice::<Vec<TrialBalance>>(&raw) {
            Ok(parsed) => parsed,
            Err(e) => {
                return Err(GroupError::Aggregate(format!(
                    "read_prior_period_closing_tbs: cannot parse `{}` as Vec<TrialBalance>: {e}",
                    tb_file.display()
                )));
            }
        };

        // An empty array has no maximum, so `max_by_key` yielding `None`
        // doubles as the emptiness check.
        let latest = parsed
            .into_iter()
            .max_by_key(|tb| (tb.fiscal_year, tb.fiscal_period))
            .ok_or_else(|| {
                GroupError::Aggregate(format!(
                    "read_prior_period_closing_tbs: `{}` contains no trial balances",
                    tb_file.display()
                ))
            })?;

        closing_by_entity.insert(code.clone(), latest);
    }

    Ok(closing_by_entity)
}
/// Builds the opening balances carried by a trial balance.
///
/// Only lines whose account type passes [`is_balance_sheet_account`] are
/// kept; each one becomes an [`OpeningBalance`] holding the line's account
/// code, account type, `debit_balance` and `credit_balance`.
///
/// The result is ordered by `account_code`. The sort is stable, so lines
/// sharing a code keep the relative order they had in `tb.lines`.
pub fn extract_opening_balances(tb: &TrialBalance) -> Vec<OpeningBalance> {
    let mut openings: Vec<OpeningBalance> = Vec::new();

    for line in tb.lines.iter() {
        if !is_balance_sheet_account(line.account_type) {
            continue;
        }
        openings.push(OpeningBalance {
            account_code: line.account_code.clone(),
            account_type: line.account_type,
            debit: line.debit_balance,
            credit: line.credit_balance,
        });
    }

    openings.sort_by(|lhs, rhs| lhs.account_code.cmp(&rhs.account_code));
    openings
}
/// Returns `true` for asset, liability and equity account types and their
/// contra counterparts; every other account type yields `false`.
fn is_balance_sheet_account(account_type: AccountType) -> bool {
    matches!(
        account_type,
        // Base balance-sheet types.
        AccountType::Asset
            | AccountType::Liability
            | AccountType::Equity
            // Their contra counterparts.
            | AccountType::ContraAsset
            | AccountType::ContraLiability
            | AccountType::ContraEquity
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::NaiveDate;
    use datasynth_core::models::balance::{
        AccountCategory, TrialBalance, TrialBalanceLine, TrialBalanceType,
    };
    use rust_decimal_macros::dec;

    /// Period-end date shared by every fixture trial balance.
    fn period_end() -> NaiveDate {
        NaiveDate::from_ymd_opt(2024, 3, 31).unwrap()
    }

    /// Builds a balanced fixture trial balance (debits and credits both total
    /// 19000) with four balance-sheet lines and two P&L lines.
    fn make_tb(company_code: &str) -> TrialBalance {
        let mut tb = TrialBalance::new(
            format!("TB_{company_code}"),
            company_code.to_string(),
            period_end(),
            2024,
            3,
            "EUR".to_string(),
            TrialBalanceType::Adjusted,
        );
        // (account code, account type, debit balance, credit balance)
        let rows = [
            ("1000", AccountType::Asset, dec!(10000), Decimal::ZERO),
            ("1100", AccountType::Asset, dec!(5000), Decimal::ZERO),
            ("2000", AccountType::Liability, Decimal::ZERO, dec!(3000)),
            ("3000", AccountType::Equity, Decimal::ZERO, dec!(9000)),
            ("4000", AccountType::Revenue, Decimal::ZERO, dec!(7000)),
            ("5000", AccountType::Expense, dec!(4000), Decimal::ZERO),
        ];
        for (code, ty, debit, credit) in rows {
            tb.add_line(line(code, ty, debit, credit));
        }
        tb
    }

    /// Builds one trial-balance line; the closing balance is whichever side
    /// is larger (the credit side when both are equal).
    fn line(code: &str, ty: AccountType, debit: Decimal, credit: Decimal) -> TrialBalanceLine {
        let closing_balance = if debit > credit { debit } else { credit };
        TrialBalanceLine {
            account_code: code.to_string(),
            account_description: code.to_string(),
            category: AccountCategory::from_account_type(ty),
            account_type: ty,
            opening_balance: Decimal::ZERO,
            period_debits: Decimal::ZERO,
            period_credits: Decimal::ZERO,
            closing_balance,
            debit_balance: debit,
            credit_balance: credit,
            cost_center: None,
            profit_center: None,
        }
    }

    /// Writes `tbs` as pretty-printed JSON to
    /// `<prior>/entities/<code>/period_close/trial_balances.json`.
    fn write_closing_tbs(prior: &Path, code: &str, tbs: &[TrialBalance]) {
        let entity_dir = prior.join("entities").join(code).join("period_close");
        fs::create_dir_all(&entity_dir).unwrap();
        fs::write(
            entity_dir.join("trial_balances.json"),
            serde_json::to_string_pretty(tbs).unwrap(),
        )
        .unwrap();
    }

    #[test]
    fn read_prior_period_closing_tbs_returns_empty_when_dir_missing() {
        let tmp = tempfile::tempdir().unwrap();
        let missing = tmp.path().join("non_existent");

        let loaded = read_prior_period_closing_tbs(&missing, &["E1".to_string()]).unwrap();

        assert!(loaded.is_empty());
    }

    #[test]
    fn read_prior_period_closing_tbs_loads_matching_entities() {
        let tmp = tempfile::tempdir().unwrap();
        write_closing_tbs(tmp.path(), "E1", &[make_tb("E1")]);

        // Only E1 has a file on disk; E2_MISSING must be skipped silently.
        let codes = ["E1".to_string(), "E2_MISSING".to_string()];
        let loaded = read_prior_period_closing_tbs(tmp.path(), &codes).unwrap();

        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded.get("E1").unwrap().company_code, "E1");
        assert!(!loaded.contains_key("E2_MISSING"));
    }

    #[test]
    fn read_prior_period_closing_tbs_picks_latest_by_fiscal_period() {
        let tmp = tempfile::tempdir().unwrap();
        let tbs: Vec<TrialBalance> = (1..=3)
            .map(|period| {
                let mut tb = make_tb("E1");
                tb.fiscal_period = period;
                tb
            })
            .collect();
        write_closing_tbs(tmp.path(), "E1", &tbs);

        let loaded = read_prior_period_closing_tbs(tmp.path(), &["E1".to_string()]).unwrap();

        assert_eq!(loaded.get("E1").unwrap().fiscal_period, 3);
    }

    #[test]
    fn read_prior_period_closing_tbs_rejects_empty_array() {
        let tmp = tempfile::tempdir().unwrap();
        write_closing_tbs(tmp.path(), "E1", &[]);

        let err = read_prior_period_closing_tbs(tmp.path(), &["E1".to_string()]).unwrap_err();

        assert!(err.to_string().contains("contains no trial balances"));
    }

    #[test]
    fn extract_opening_balances_drops_pl_accounts() {
        let openings = extract_opening_balances(&make_tb("E1"));

        assert_eq!(openings.len(), 4);
        let codes: Vec<&str> = openings
            .iter()
            .map(|opening| opening.account_code.as_str())
            .collect();
        assert_eq!(codes, vec!["1000", "1100", "2000", "3000"]);
    }

    #[test]
    fn extract_opening_balances_preserves_dr_cr_sides() {
        let openings = extract_opening_balances(&make_tb("E1"));
        let by_code = |code: &str| {
            openings
                .iter()
                .find(|opening| opening.account_code == code)
                .unwrap()
        };

        let cash = by_code("1000");
        assert_eq!(cash.debit, dec!(10000));
        assert_eq!(cash.credit, Decimal::ZERO);

        let ap = by_code("2000");
        assert_eq!(ap.debit, Decimal::ZERO);
        assert_eq!(ap.credit, dec!(3000));
    }

    #[test]
    fn extract_opening_balances_sorted_by_account_code() {
        let openings = extract_opening_balances(&make_tb("E1"));
        let codes: Vec<&str> = openings
            .iter()
            .map(|opening| opening.account_code.as_str())
            .collect();

        // Sorted means every adjacent pair is in non-decreasing order.
        assert!(codes.windows(2).all(|pair| pair[0] <= pair[1]));
    }

    #[test]
    fn extract_opening_balances_total_dr_equals_total_cr_for_balanced_tb() {
        let openings = extract_opening_balances(&make_tb("E1"));
        let total_dr: Decimal = openings.iter().map(|opening| opening.debit).sum();
        let total_cr: Decimal = openings.iter().map(|opening| opening.credit).sum();

        // Despite the test name, the balance-sheet-only openings do not net
        // to zero: the fixture is balanced only once its P&L lines are
        // included, so the Dr/Cr gap equals revenue minus expense.
        let pl_net = dec!(7000) - dec!(4000);
        assert_eq!(total_dr - total_cr, pl_net);
    }

    #[test]
    fn opening_balance_round_trips_json() {
        let original = OpeningBalance {
            account_code: "1000".to_string(),
            account_type: AccountType::Asset,
            debit: dec!(12345.67),
            credit: Decimal::ZERO,
        };

        let json = serde_json::to_string(&original).unwrap();
        let decoded: OpeningBalance = serde_json::from_str(&json).unwrap();

        assert_eq!(original, decoded);
    }
}