use datasynth_config::presets::create_preset;
use datasynth_config::{CompanyConfig, GeneratorConfig, TransactionVolume};
use datasynth_core::models::{CoAComplexity, IndustrySector};
use crate::config::PeriodLength;
use crate::errors::{GroupError, GroupResult};
use crate::manifest::builder::{GroupManifest, ManifestEntity};
/// Row budget used by `lookup_row_budget` when the entity's scoping profile is
/// missing, is not a YAML mapping, or has no unsigned-integer `row_budget` key.
const DEFAULT_ROW_BUDGET: u64 = 100_000;
/// Builds the single-company [`GeneratorConfig`] for one entity of a group manifest.
///
/// The config starts from `create_preset` for the entity's industry, period
/// length and row budget, and is then overridden with group-wide settings
/// (seed, start date, currencies), fraud settings, master-data counts scaled
/// to the row budget, the entity's own company record and, when recognised,
/// its accounting framework.
///
/// # Errors
///
/// Returns [`GroupError::Config`] when the assembled config is rejected by
/// `datasynth_config::validate_config`.
pub fn build_entity_generator_config(
    manifest: &GroupManifest,
    entity: &ManifestEntity,
) -> GroupResult<GeneratorConfig> {
    let industry = map_industry(entity.industry.as_deref());
    let period_months = period_months_from_length(manifest.period.length);
    let row_budget = lookup_row_budget(manifest, &entity.scoping_profile);
    let volume = volume_from_rows(row_budget);
    let complexity = CoAComplexity::Medium;

    let mut cfg = create_preset(industry, 1, period_months, complexity, volume);

    // Group-wide settings shared by every entity built from this manifest.
    cfg.global.seed = Some(manifest.group_seed);
    cfg.global.industry = industry;
    cfg.global.start_date = manifest.period.start.format("%Y-%m-%d").to_string();
    cfg.global.period_months = period_months;
    cfg.global.group_currency = manifest.presentation_currency.clone();
    cfg.global.presentation_currency = Some(manifest.presentation_currency.clone());
    cfg.financial_reporting.enabled = true;
    cfg.balance.generate_opening_balances = true;
    cfg.banking.enabled = false;

    // Fraud settings come from `defaults.fraud` when `defaults` is a mapping
    // that carries a parseable `fraud` entry; every other case (non-mapping
    // defaults, missing key, parse failure) uses the built-in defaults.
    let fraud_yaml = match &manifest.defaults {
        serde_yaml::Value::Mapping(defaults_map) => defaults_map.get("fraud"),
        _ => None,
    };
    match fraud_yaml
        .map(|yaml| serde_yaml::from_value::<datasynth_config::FraudConfig>(yaml.clone()))
    {
        Some(Ok(parsed_fraud)) => {
            cfg.fraud = parsed_fraud;
            tracing::debug!(
                entity = %entity.code,
                fraud_enabled = cfg.fraud.enabled,
                fraud_rate = cfg.fraud.fraud_rate,
                "applied fraud config from GroupConfig.defaults.fraud"
            );
        }
        Some(Err(e)) => {
            tracing::warn!(
                entity = %entity.code,
                error = %e,
                "GroupConfig.defaults.fraud failed to parse — falling back to SOTA defaults"
            );
            apply_default_fraud_settings(&mut cfg.fraud);
        }
        None => apply_default_fraud_settings(&mut cfg.fraud),
    }

    // Scale the preset's master-data counts by the row budget relative to a
    // 100k-row reference, clamped so tiny or huge budgets stay within bounds.
    const REFERENCE_BUDGET: u64 = 100_000;
    let scale = (row_budget as f64 / REFERENCE_BUDGET as f64).clamp(0.001, 10.0);
    let scale_usize = |n: usize| -> usize { ((n as f64) * scale).round().max(1.0) as usize };
    let scale_u32 = |n: u32| -> u32 { ((n as f64) * scale).round().max(1.0) as u32 };

    // Per-category floors applied on top of the scaled counts.
    cfg.master_data.vendors.count = scale_usize(cfg.master_data.vendors.count).max(5);
    cfg.master_data.customers.count = scale_usize(cfg.master_data.customers.count).max(5);
    cfg.master_data.materials.count = scale_usize(cfg.master_data.materials.count).max(5);
    cfg.master_data.fixed_assets.count = scale_usize(cfg.master_data.fixed_assets.count).max(3);
    cfg.master_data.employees.count = scale_usize(cfg.master_data.employees.count).max(3);

    if row_budget < 5_000 {
        // Manufacturing is switched off entirely for budgets below 5,000 rows.
        cfg.manufacturing.enabled = false;
    } else {
        // `scale_u32` already floors its result at 1, so no extra clamp is needed.
        cfg.manufacturing.production_orders.orders_per_month =
            scale_u32(cfg.manufacturing.production_orders.orders_per_month);
    }

    // Replace the preset's companies with exactly this entity.
    cfg.companies = vec![CompanyConfig {
        code: entity.code.clone(),
        name: entity.name.clone().unwrap_or_else(|| entity.code.clone()),
        currency: entity.functional_currency.clone(),
        functional_currency: Some(entity.functional_currency.clone()),
        country: entity.country.clone(),
        fiscal_year_variant: "K4".to_string(),
        annual_transaction_volume: volume,
        volume_weight: 1.0,
    }];

    if let Some(fw_str) = entity.accounting_framework.as_deref() {
        match parse_accounting_framework(fw_str) {
            Some(fw_cfg) => {
                cfg.accounting_standards.framework = Some(fw_cfg);
            }
            None => {
                // Previously dropped silently; surface it so a typo in the
                // manifest is visible. The preset's framework stays in effect.
                tracing::warn!(
                    entity = %entity.code,
                    framework = %fw_str,
                    "unrecognised accounting_framework — keeping the preset's framework"
                );
            }
        }
    }

    datasynth_config::validate_config(&cfg).map_err(|e| {
        GroupError::Config(format!(
            "per-entity GeneratorConfig failed validation for {}: {e}",
            entity.code
        ))
    })?;

    Ok(cfg)
}

/// Applies the built-in fraud settings used when the group config supplies no
/// usable `defaults.fraud` entry. Only these four fields are touched.
fn apply_default_fraud_settings(fraud: &mut datasynth_config::FraudConfig) {
    fraud.enabled = true;
    fraud.fraud_rate = 0.05;
    fraud.document_fraud_rate = Some(0.07);
    fraud.propagate_to_lines = true;
}
/// Parses an accounting-framework name into an `AccountingFrameworkConfig`.
///
/// Matching is ASCII case-insensitive and treats `-` and `_` as the same
/// separator, so `us_gaap`, `US-GAAP` and `UsGaap` all resolve to `UsGaap`.
/// The short aliases `pcg` (French GAAP) and `hgb` (German GAAP) are also
/// accepted. Returns `None` for any other input.
fn parse_accounting_framework(
    s: &str,
) -> Option<datasynth_config::schema::AccountingFrameworkConfig> {
    use datasynth_config::schema::AccountingFrameworkConfig::*;

    // Fold case and separator style once so each framework needs only its
    // snake_case spelling plus the run-together form of the CamelCase name.
    let normalized = s.to_ascii_lowercase().replace('-', "_");
    match normalized.as_str() {
        "us_gaap" | "usgaap" => Some(UsGaap),
        "ifrs" => Some(Ifrs),
        "dual_reporting" | "dualreporting" => Some(DualReporting),
        "french_gaap" | "frenchgaap" | "pcg" => Some(FrenchGaap),
        "german_gaap" | "germangaap" | "hgb" => Some(GermanGaap),
        _ => None,
    }
}
/// Maps a free-form industry label to an [`IndustrySector`].
///
/// The comparison is ASCII case-insensitive. A missing label, an empty label
/// and any label that is not listed below all resolve to
/// `IndustrySector::Manufacturing`.
fn map_industry(s: Option<&str>) -> IndustrySector {
    // A missing label becomes the empty string, which falls through to the
    // catch-all arm just like any other unknown label.
    let label = s.map(str::to_ascii_lowercase).unwrap_or_default();

    match label.as_str() {
        "manufacturing" => IndustrySector::Manufacturing,
        "retail" => IndustrySector::Retail,
        "financial_services" | "banking" | "finance" => IndustrySector::FinancialServices,
        "healthcare" | "pharma" | "pharmaceutical" => IndustrySector::Healthcare,
        "technology" | "tech" | "software" => IndustrySector::Technology,
        "professional_services" | "consulting" => IndustrySector::ProfessionalServices,
        "energy" | "oil_gas" | "utilities" => IndustrySector::Energy,
        "transportation" | "logistics" => IndustrySector::Transportation,
        "real_estate" => IndustrySector::RealEstate,
        "telecommunications" | "telecom" => IndustrySector::Telecommunications,
        _ => IndustrySector::Manufacturing,
    }
}
/// Returns the number of calendar months covered by a [`PeriodLength`].
fn period_months_from_length(len: PeriodLength) -> u32 {
    let months = match len {
        PeriodLength::Annual => 12,
        PeriodLength::SemiAnnual => 6,
        PeriodLength::Quarterly => 3,
        PeriodLength::Monthly => 1,
    };
    months
}
/// Looks up the `row_budget` of the named scoping profile in the manifest.
///
/// Falls back to `DEFAULT_ROW_BUDGET` when the profile is absent, is not a
/// YAML mapping, has no `row_budget` key, or the value is not representable
/// as a `u64`.
fn lookup_row_budget(manifest: &GroupManifest, profile: &str) -> u64 {
    let configured = || -> Option<u64> {
        let entry = manifest.scoping_profiles.get(profile)?;
        let fields = entry.as_mapping()?;
        let budget = fields.get(serde_yaml::Value::String("row_budget".to_string()))?;
        budget.as_u64()
    };

    match configured() {
        Some(rows) => rows,
        None => DEFAULT_ROW_BUDGET,
    }
}
/// Picks the [`TransactionVolume`] bucket for a row budget.
///
/// Budgets below 10,000 are passed through verbatim as `Custom(rows)`.
/// Exactly 10,000 maps to `TenK`; larger budgets map to the smallest named
/// bucket whose size is at least `rows`, saturating at `HundredM`.
fn volume_from_rows(rows: u64) -> TransactionVolume {
    match rows {
        0..=9_999 => TransactionVolume::Custom(rows),
        10_000 => TransactionVolume::TenK,
        10_001..=100_000 => TransactionVolume::HundredK,
        100_001..=1_000_000 => TransactionVolume::OneM,
        1_000_001..=10_000_000 => TransactionVolume::TenM,
        _ => TransactionVolume::HundredM,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `volume_from_rows($rows)` matches the pattern `$bucket`.
    macro_rules! assert_bucket {
        ($rows:expr, $bucket:pat) => {{
            let rows: u64 = $rows;
            assert!(
                matches!(volume_from_rows(rows), $bucket),
                "unexpected volume bucket for {} rows",
                rows
            );
        }};
    }

    /// Runs `map_industry` over `(label, expected sector)` pairs.
    fn check_industries(cases: &[(&str, IndustrySector)]) {
        for (label, expected) in cases.iter() {
            assert_eq!(map_industry(Some(*label)), *expected, "label {:?}", label);
        }
    }

    #[test]
    fn map_industry_covers_every_variant() {
        check_industries(&[
            ("manufacturing", IndustrySector::Manufacturing),
            ("retail", IndustrySector::Retail),
            ("financial_services", IndustrySector::FinancialServices),
            ("banking", IndustrySector::FinancialServices),
            ("finance", IndustrySector::FinancialServices),
            ("healthcare", IndustrySector::Healthcare),
            ("pharma", IndustrySector::Healthcare),
            ("pharmaceutical", IndustrySector::Healthcare),
            ("technology", IndustrySector::Technology),
            ("tech", IndustrySector::Technology),
            ("software", IndustrySector::Technology),
            ("professional_services", IndustrySector::ProfessionalServices),
            ("consulting", IndustrySector::ProfessionalServices),
            ("energy", IndustrySector::Energy),
            ("oil_gas", IndustrySector::Energy),
            ("utilities", IndustrySector::Energy),
            ("transportation", IndustrySector::Transportation),
            ("logistics", IndustrySector::Transportation),
            ("real_estate", IndustrySector::RealEstate),
            ("telecommunications", IndustrySector::Telecommunications),
            ("telecom", IndustrySector::Telecommunications),
        ]);
    }

    #[test]
    fn map_industry_is_case_insensitive() {
        check_industries(&[
            ("MANUFACTURING", IndustrySector::Manufacturing),
            ("Retail", IndustrySector::Retail),
            ("FINANCIAL_SERVICES", IndustrySector::FinancialServices),
        ]);
    }

    #[test]
    fn map_industry_unknown_defaults_to_manufacturing() {
        check_industries(&[
            ("spacefaring_megacorp", IndustrySector::Manufacturing),
            ("", IndustrySector::Manufacturing),
        ]);
    }

    #[test]
    fn map_industry_none_defaults_to_manufacturing() {
        assert_eq!(map_industry(None), IndustrySector::Manufacturing);
    }

    #[test]
    fn period_months_covers_every_length() {
        let cases = [
            (PeriodLength::Monthly, 1),
            (PeriodLength::Quarterly, 3),
            (PeriodLength::SemiAnnual, 6),
            (PeriodLength::Annual, 12),
        ];
        for (length, months) in cases {
            assert_eq!(period_months_from_length(length), months);
        }
    }

    #[test]
    fn volume_from_rows_sub_tenk_uses_custom() {
        // Every budget below 10k must be carried through unchanged.
        for &rows in &[200_u64, 1_000, 5_000, 9_999] {
            assert!(
                matches!(volume_from_rows(rows), TransactionVolume::Custom(n) if n == rows),
                "expected Custom({}) for {} rows",
                rows,
                rows
            );
        }
        // Exactly 10k is the first named bucket.
        assert_bucket!(10_000, TransactionVolume::TenK);
    }

    #[test]
    fn volume_from_rows_honours_bucket_boundaries() {
        assert_bucket!(10_001, TransactionVolume::HundredK);
        assert_bucket!(100_000, TransactionVolume::HundredK);
        assert_bucket!(100_001, TransactionVolume::OneM);
        assert_bucket!(1_000_000, TransactionVolume::OneM);
        assert_bucket!(1_000_001, TransactionVolume::TenM);
        assert_bucket!(10_000_000, TransactionVolume::TenM);
        assert_bucket!(10_000_001, TransactionVolume::HundredM);
    }

    #[test]
    fn volume_from_rows_interior_samples() {
        assert_bucket!(50_000, TransactionVolume::HundredK);
        assert_bucket!(500_000, TransactionVolume::OneM);
        assert_bucket!(5_000_000, TransactionVolume::TenM);
        assert_bucket!(50_000_000, TransactionVolume::HundredM);
    }

    #[test]
    fn volume_from_rows_saturates_at_hundredm() {
        assert_bucket!(u64::MAX, TransactionVolume::HundredM);
    }
}