use std::collections::BTreeMap;
use std::path::Path;
use datasynth_core::models::balance::{EntityOpeningBalance, TrialBalance};
use datasynth_generators::balance::project_closing_to_opening;
use crate::errors::{GroupError, GroupResult};
use crate::manifest::builder::GroupManifest;
use crate::shard::runner::{run_shard_with_opening_balances, ShardSummary};
/// Framework identifier to pass as the `framework` argument of
/// [`build_opening_balances_from_prior`] / [`run_shard_chained`] when the caller has
/// no specific accounting framework in mind. It is forwarded unchanged to
/// `project_closing_to_opening`.
pub const DEFAULT_FRAMEWORK: &str = "us_gaap";
/// Derives each entity's opening balances from a prior period's shard output.
///
/// For every entity in `manifest.ownership_graph.entities` this reads
/// `<prior_period_shards_dir>/entities/<code>/period_close/trial_balances.json`
/// (a JSON array of [`TrialBalance`]). It selects the trial balance with the latest
/// `as_of_date`; if several share that date, the last one in the file wins. It then
/// projects that closing balance into opening rows with [`project_closing_to_opening`],
/// using `framework`.
///
/// Returns a map from entity code to that entity's opening-balance rows.
///
/// # Errors
///
/// - [`GroupError::Shard`] if an entity's trial-balance file does not exist, or if it
///   contains an empty array. Every manifest entity must be present in the prior period.
/// - [`GroupError::Io`] if the file exists but cannot be read (e.g. permission denied).
/// - A JSON deserialization error (converted into [`GroupError`] via `?`) if the file
///   is not a valid array of trial balances.
pub fn build_opening_balances_from_prior(
    manifest: &GroupManifest,
    prior_period_shards_dir: &Path,
    framework: &str,
) -> GroupResult<BTreeMap<String, Vec<EntityOpeningBalance>>> {
    let mut openings_by_entity: BTreeMap<String, Vec<EntityOpeningBalance>> = BTreeMap::new();
    for entity in &manifest.ownership_graph.entities {
        let tb_path = prior_period_shards_dir
            .join("entities")
            .join(&entity.code)
            .join("period_close")
            .join("trial_balances.json");
        // Read directly rather than probing with `Path::exists` first. `exists()` reports
        // `false` for *any* metadata failure (e.g. permission denied), which would be
        // misreported as a missing file. A separate probe also races with the read.
        let bytes = match std::fs::read(&tb_path) {
            Ok(bytes) => bytes,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
                return Err(GroupError::Shard(format!(
                    "build_opening_balances_from_prior: missing prior-period TB for `{}` at `{}` — \
                     the prior-period shard must contain every entity in the manifest",
                    entity.code,
                    tb_path.display(),
                )));
            }
            Err(e) => return Err(GroupError::Io(e)),
        };
        let tbs: Vec<TrialBalance> = serde_json::from_slice(&bytes)?;
        // The latest TB in the file is taken to be the period-end closing balance.
        let closing_tb = tbs
            .into_iter()
            .max_by_key(|tb| tb.as_of_date)
            .ok_or_else(|| {
                GroupError::Shard(format!(
                    "build_opening_balances_from_prior: prior-period TB file for `{}` was empty",
                    entity.code,
                ))
            })?;
        let openings = project_closing_to_opening(&closing_tb, framework);
        openings_by_entity.insert(entity.code.clone(), openings);
    }
    Ok(openings_by_entity)
}
/// Runs a shard whose opening balances are carried forward from a prior period.
///
/// This is a convenience wrapper with two steps:
/// 1. Build every entity's opening balances from the prior-period shard output under
///    `prior_period_shards_dir`, using [`build_opening_balances_from_prior`] with
///    `framework`.
/// 2. Pass those balances, together with `manifest`, `shard_id` and `out_dir`, to
///    [`run_shard_with_opening_balances`].
///
/// # Errors
///
/// Returns the first error from either step. If the opening balances cannot be
/// built, the shard is never run.
pub fn run_shard_chained(
    manifest: &GroupManifest,
    shard_id: &str,
    out_dir: &Path,
    prior_period_shards_dir: &Path,
    framework: &str,
) -> GroupResult<ShardSummary> {
    build_opening_balances_from_prior(manifest, prior_period_shards_dir, framework).and_then(
        |carried_forward| {
            run_shard_with_opening_balances(manifest, shard_id, out_dir, &carried_forward)
        },
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{NaiveDate, NaiveDateTime};
    use datasynth_core::models::balance::{
        AccountCategory, AccountType, TrialBalanceLine, TrialBalanceStatus, TrialBalanceType,
    };
    use rust_decimal_macros::dec;
    use std::collections::HashMap;
    use std::fs;
    use tempfile::TempDir;

    /// Builds a final, post-closing `TrialBalance` for `company_code` as of `as_of_date`.
    /// Debit/credit totals and the balanced flags are derived from `lines`.
    fn make_tb(
        company_code: &str,
        as_of_date: NaiveDate,
        lines: Vec<TrialBalanceLine>,
    ) -> TrialBalance {
        let total_debits: rust_decimal::Decimal = lines.iter().map(|l| l.debit_balance).sum();
        let total_credits: rust_decimal::Decimal = lines.iter().map(|l| l.credit_balance).sum();
        TrialBalance {
            trial_balance_id: format!("TB-{company_code}-{as_of_date}"),
            company_code: company_code.to_string(),
            company_name: None,
            as_of_date,
            fiscal_year: as_of_date.format("%Y").to_string().parse().unwrap_or(2026),
            fiscal_period: 12,
            currency: "USD".into(),
            balance_type: TrialBalanceType::PostClosing,
            lines,
            total_debits,
            total_credits,
            is_balanced: total_debits == total_credits,
            out_of_balance: total_debits - total_credits,
            is_equation_valid: true,
            equation_difference: rust_decimal::Decimal::ZERO,
            category_summary: HashMap::new(),
            created_at: NaiveDateTime::default(),
            created_by: "test".into(),
            approved_by: None,
            approved_at: None,
            status: TrialBalanceStatus::Final,
        }
    }

    /// Builds one TB line with the given debit/credit balances. The closing balance is
    /// `debit - credit`; opening balance and period movements are zero.
    fn make_line(
        code: &str,
        at: AccountType,
        cat: AccountCategory,
        debit: rust_decimal::Decimal,
        credit: rust_decimal::Decimal,
    ) -> TrialBalanceLine {
        TrialBalanceLine {
            account_code: code.into(),
            account_description: format!("Test {code}"),
            category: cat,
            account_type: at,
            opening_balance: rust_decimal::Decimal::ZERO,
            period_debits: rust_decimal::Decimal::ZERO,
            period_credits: rust_decimal::Decimal::ZERO,
            closing_balance: debit - credit,
            debit_balance: debit,
            credit_balance: credit,
            cost_center: None,
            profit_center: None,
        }
    }

    /// Writes `entities/<entity_code>/period_close/trial_balances.json` under `root`.
    /// The file holds a single TB dated `as_of`, with one asset (1000), one liability
    /// (2000) and one equity (3000) line.
    fn stage_prior_shard(
        root: &Path,
        entity_code: &str,
        as_of: NaiveDate,
        assets: rust_decimal::Decimal,
        liab: rust_decimal::Decimal,
        equity: rust_decimal::Decimal,
    ) {
        let pc_dir = root.join("entities").join(entity_code).join("period_close");
        fs::create_dir_all(&pc_dir).unwrap();
        let lines = vec![
            make_line(
                "1000",
                AccountType::Asset,
                AccountCategory::CurrentAssets,
                assets,
                rust_decimal::Decimal::ZERO,
            ),
            make_line(
                "2000",
                AccountType::Liability,
                AccountCategory::CurrentLiabilities,
                rust_decimal::Decimal::ZERO,
                liab,
            ),
            make_line(
                "3000",
                AccountType::Equity,
                AccountCategory::Equity,
                rust_decimal::Decimal::ZERO,
                equity,
            ),
        ];
        let tbs = vec![make_tb(entity_code, as_of, lines)];
        let json = serde_json::to_string(&tbs).unwrap();
        fs::write(pc_dir.join("trial_balances.json"), json).unwrap();
    }

    /// Loads the `mini_acme.yaml` fixture and builds its manifest. When `entity_codes`
    /// is non-empty, only those entities are kept in the ownership graph.
    fn build_test_manifest(entity_codes: &[&str]) -> crate::manifest::builder::GroupManifest {
        let yaml = include_str!("../../tests/fixtures/mini_acme.yaml");
        let cfg: crate::config::GroupConfig =
            serde_yaml::from_str(yaml).expect("mini_acme.yaml must parse");
        let mut manifest = crate::manifest::builder::build_manifest(&cfg)
            .expect("build_manifest must succeed for mini_acme");
        if !entity_codes.is_empty() {
            let want: std::collections::BTreeSet<&str> = entity_codes.iter().copied().collect();
            manifest
                .ownership_graph
                .entities
                .retain(|e| want.contains(e.code.as_str()));
        }
        manifest
    }

    #[test]
    fn build_opening_balances_reads_each_entity_tb() {
        let tmp = TempDir::new().unwrap();
        let prior = tmp.path();
        stage_prior_shard(
            prior,
            "ACME_USA",
            NaiveDate::from_ymd_opt(2026, 12, 31).unwrap(),
            dec!(10_000),
            dec!(4_000),
            dec!(6_000),
        );
        stage_prior_shard(
            prior,
            "ACME_DE",
            NaiveDate::from_ymd_opt(2026, 12, 31).unwrap(),
            dec!(20_000),
            dec!(7_000),
            dec!(13_000),
        );
        let manifest = build_test_manifest(&["ACME_USA", "ACME_DE"]);
        let openings = build_opening_balances_from_prior(&manifest, prior, DEFAULT_FRAMEWORK)
            .expect("walk should succeed");
        assert_eq!(openings.len(), 2);
        let usa = openings.get("ACME_USA").expect("ACME_USA missing");
        assert_eq!(
            usa.len(),
            3,
            "expected 3 BS opening rows for ACME_USA, got {}",
            usa.len()
        );
        let de = openings.get("ACME_DE").expect("ACME_DE missing");
        assert_eq!(de.len(), 3);
    }

    #[test]
    fn missing_prior_entity_is_a_hard_error() {
        let tmp = TempDir::new().unwrap();
        let prior = tmp.path();
        stage_prior_shard(
            prior,
            "ACME_USA",
            NaiveDate::from_ymd_opt(2026, 12, 31).unwrap(),
            dec!(10_000),
            dec!(4_000),
            dec!(6_000),
        );
        let manifest = build_test_manifest(&["ACME_USA", "ACME_DE"]);
        let err = build_opening_balances_from_prior(&manifest, prior, DEFAULT_FRAMEWORK)
            .expect_err("missing entity must error");
        let msg = format!("{err}");
        assert!(
            msg.contains("ACME_DE") && msg.contains("missing prior-period TB"),
            "error should name the missing entity and be specific: {msg}",
        );
    }

    #[test]
    fn empty_prior_tb_file_is_a_hard_error() {
        let tmp = TempDir::new().unwrap();
        let prior = tmp.path();
        let pc_dir = prior.join("entities").join("ACME_USA").join("period_close");
        fs::create_dir_all(&pc_dir).unwrap();
        // A syntactically valid but empty array: there is no closing TB to project.
        fs::write(pc_dir.join("trial_balances.json"), "[]").unwrap();
        let manifest = build_test_manifest(&["ACME_USA"]);
        let err = build_opening_balances_from_prior(&manifest, prior, DEFAULT_FRAMEWORK)
            .expect_err("empty TB file must error");
        let msg = format!("{err}");
        assert!(
            msg.contains("ACME_USA") && msg.contains("was empty"),
            "error should name the entity and say the file was empty: {msg}",
        );
    }

    #[test]
    fn latest_tb_wins_when_multiple_periods_present() {
        let tmp = TempDir::new().unwrap();
        let prior = tmp.path();
        let pc_dir = prior.join("entities").join("ACME_USA").join("period_close");
        fs::create_dir_all(&pc_dir).unwrap();
        let early = make_tb(
            "ACME_USA",
            NaiveDate::from_ymd_opt(2026, 9, 30).unwrap(),
            vec![make_line(
                "1000",
                AccountType::Asset,
                AccountCategory::CurrentAssets,
                dec!(5_000),
                rust_decimal::Decimal::ZERO,
            )],
        );
        let late = make_tb(
            "ACME_USA",
            NaiveDate::from_ymd_opt(2026, 12, 31).unwrap(),
            vec![make_line(
                "1000",
                AccountType::Asset,
                AccountCategory::CurrentAssets,
                dec!(10_000),
                rust_decimal::Decimal::ZERO,
            )],
        );
        let tbs = vec![early, late];
        fs::write(
            pc_dir.join("trial_balances.json"),
            serde_json::to_string(&tbs).unwrap(),
        )
        .unwrap();
        let manifest = build_test_manifest(&["ACME_USA"]);
        let openings = build_opening_balances_from_prior(&manifest, prior, DEFAULT_FRAMEWORK)
            .expect("walk should succeed");
        let usa = openings.get("ACME_USA").unwrap();
        assert_eq!(usa.len(), 1);
        assert_eq!(
            usa[0].debit,
            dec!(10_000),
            "should have picked the Q4 TB (10_000), not the Q3 TB (5_000)"
        );
    }
}