use std::collections::{BTreeMap, HashSet};
use std::path::Path;
use chrono::NaiveDate;
use datasynth_eval::behavioral_fidelity::loader::{load_csv_records, load_parquet_records};
use super::manual_extractor::extract_manual_share_from_parquet;
use datasynth_eval::behavioral_fidelity::math::pearson_lag1_correlation;
use datasynth_eval::behavioral_fidelity::Record;
use crate::error::FingerprintError;
use crate::models::behavioral::BehavioralPriors;
use crate::models::behavioral::{
ActiveLifetimePrior, ActiveSegmentsPrior, AmountQuantileSketch, CategoricalDistribution,
EntityCluster, EntityClustersPrior, FanoutPrior, IetSummary, LagSummary, LineCountHistogram,
LinesPerJePrior, LognormalAmount, LognormalParams, MonthVolumePrior, PerSourceAmountPrior,
PerSourceAttributePrior, PerSourceFlowPairPrior, PerSourceIetPrior, PerSourceRolePrior,
PostingLagPrior, SourceMixPrior, SourceSegmentSummary, ACTIVE_LIFETIME_DAY_BUCKETS,
FANOUT_BUCKETS, LINE_COUNT_BUCKETS, SEGMENT_COUNT_BUCKETS, SEGMENT_GAP_BUCKETS,
};
use crate::models::EmpiricalCdf;
use super::reference_extractor::extract_reference_formats;
use super::user_extractor::extract_user_personas;
/// Minimum-records threshold passed to [`extract_user_personas`] by the prior builder.
pub const DEFAULT_MIN_USER_RECORDS: usize = 100;
/// Minimum-occurrences threshold passed to [`extract_reference_formats`] by the prior builder.
pub const DEFAULT_MIN_REFERENCE_OCCURRENCES: usize = 10;
/// Default minimum share of counted lines a source needs to be listed in the source mix.
pub const DEFAULT_MIN_SOURCE_THRESHOLD: f64 = 0.005;
/// Default minimum number of counted lines a source needs before its share is considered.
pub const DEFAULT_MIN_SOURCE_OBSERVATIONS: usize = 1_000;

/// Thresholds applied by [`extract_source_mix`] when building the line-level source mix.
#[derive(Debug, Clone, Copy)]
pub struct SourceMixGates {
    /// Passed as `min_threshold`: sources below this share are folded into `other_fraction`.
    pub min_share: f64,
    /// Passed as `min_observations`: sources with fewer counted lines are dropped.
    pub min_observations: usize,
}

impl Default for SourceMixGates {
    /// Uses [`DEFAULT_MIN_SOURCE_THRESHOLD`] and [`DEFAULT_MIN_SOURCE_OBSERVATIONS`].
    fn default() -> Self {
        let min_share = DEFAULT_MIN_SOURCE_THRESHOLD;
        let min_observations = DEFAULT_MIN_SOURCE_OBSERVATIONS;
        Self {
            min_share,
            min_observations,
        }
    }
}
/// Estimates the share of journal lines attributable to each source.
///
/// Only lines with a non-zero functional amount are counted. If no line qualifies (all zero
/// or NaN), every line is counted instead, so a degenerate amount column still yields a mix.
/// Sources with fewer than `min_observations` counted lines are dropped. Sources whose share
/// of all counted lines is below `min_threshold` are folded into `other_fraction`. The
/// surviving probabilities are renormalised to sum to 1, and `other_fraction` is the counted
/// mass that was not retained.
pub fn extract_source_mix(
    records: &[Record],
    min_threshold: f64,
    min_observations: usize,
) -> SourceMixPrior {
    if records.is_empty() {
        return SourceMixPrior {
            probabilities: BTreeMap::new(),
            other_fraction: 0.0,
            min_threshold,
        };
    }

    // Per-source line counts plus the total number of lines counted.
    let tally = |include_zero: bool| -> (BTreeMap<String, usize>, usize) {
        let mut per_source: BTreeMap<String, usize> = BTreeMap::new();
        let mut n_counted = 0usize;
        for rec in records
            .iter()
            .filter(|r| include_zero || r.functional_amount.abs() > 0.0)
        {
            *per_source.entry(rec.source.clone()).or_insert(0) += 1;
            n_counted += 1;
        }
        (per_source, n_counted)
    };
    let (mut per_source, n_counted) = match tally(false) {
        (_, 0) => tally(true),
        nonzero => nonzero,
    };

    per_source.retain(|_, count| *count >= min_observations);
    let denominator = n_counted as f64;

    let mut probabilities = BTreeMap::new();
    let mut below_threshold = 0.0;
    for (source, count) in per_source {
        let share = count as f64 / denominator;
        if share >= min_threshold {
            probabilities.insert(source, share);
        } else {
            below_threshold += share;
        }
    }

    let retained_sum: f64 = probabilities.values().sum();
    // Mass of sources removed by the `min_observations` gate.
    let dropped_mass = 1.0 - retained_sum - below_threshold;
    let other_fraction = below_threshold + dropped_mass;
    if retained_sum > 0.0 {
        probabilities.values_mut().for_each(|p| *p /= retained_sum);
    }

    SourceMixPrior {
        probabilities,
        other_fraction,
        min_threshold,
    }
}
/// Estimates the source mix at journal-entry (JE) level, restricted to `retained` sources.
///
/// Each JE is identified by `(effective_date.year(), je_number)`. It is attributed to the
/// source of its first line that has a positive, finite functional amount and a non-empty
/// source. Probabilities are normalised over the JEs whose source is a key of `retained`.
/// `other_fraction` is the share of all attributed JEs whose source is not retained.
///
/// Returns `None` when `retained` is empty, no JE can be attributed, or no attributed JE
/// belongs to a retained source.
pub fn extract_source_mix_je(
    records: &[Record],
    retained: &BTreeMap<String, f64>,
) -> Option<SourceMixPrior> {
    use chrono::Datelike;
    if retained.is_empty() {
        return None;
    }

    let mut je_source: BTreeMap<(i32, &str), &str> = BTreeMap::new();
    records
        .iter()
        .filter(|r| {
            r.functional_amount > 0.0 && r.functional_amount.is_finite() && !r.source.is_empty()
        })
        .for_each(|r| {
            je_source
                .entry((r.effective_date.year(), r.je_number.as_str()))
                .or_insert(r.source.as_str());
        });
    if je_source.is_empty() {
        return None;
    }

    let n_jes = je_source.len() as f64;
    let mut per_source: BTreeMap<&str, u64> = BTreeMap::new();
    for &src in je_source.values() {
        *per_source.entry(src).or_insert(0) += 1;
    }

    let kept: Vec<(&str, u64)> = per_source
        .into_iter()
        .filter(|(s, _)| retained.contains_key(*s))
        .collect();
    let kept_total: u64 = kept.iter().map(|(_, c)| c).sum();
    if kept_total == 0 {
        return None;
    }

    let denominator = kept_total as f64;
    let probabilities: BTreeMap<String, f64> = kept
        .into_iter()
        .map(|(s, c)| (s.to_string(), c as f64 / denominator))
        .collect();

    Some(SourceMixPrior {
        probabilities,
        other_fraction: 1.0 - kept_total as f64 / n_jes,
        min_threshold: 0.0,
    })
}
/// Minimum number of inter-event gaps a source needs in [`extract_per_source_iet`].
pub const DEFAULT_MIN_IET_SAMPLES: usize = 100;

/// Summarises, per source, the inter-event times (IETs) in days between consecutive
/// `entry_date`s.
///
/// Every record contributes its entry date. Several lines posted on the same day therefore
/// produce zero-day gaps. Sources with fewer than two dates, or fewer than `min_samples`
/// gaps, are omitted. Each summary holds the empirical CDF, a log-normal fit (see
/// [`fit_lognormal`]) and the lag-1 autocorrelation of the gaps. The autocorrelation is 0.0
/// when it cannot be computed.
pub fn extract_per_source_iet(records: &[Record], min_samples: usize) -> PerSourceIetPrior {
    let mut dates_by_source: BTreeMap<String, Vec<NaiveDate>> = BTreeMap::new();
    for rec in records {
        dates_by_source
            .entry(rec.source.clone())
            .or_default()
            .push(rec.entry_date);
    }

    let by_source: BTreeMap<String, IetSummary> = dates_by_source
        .into_iter()
        .filter_map(|(source, mut dates)| {
            if dates.len() < 2 {
                return None;
            }
            dates.sort();
            let gaps: Vec<f64> = dates
                .windows(2)
                .map(|pair| (pair[1] - pair[0]).num_days() as f64)
                .collect();
            if gaps.len() < min_samples {
                return None;
            }
            let summary = IetSummary {
                n: gaps.len(),
                empirical_cdf_days: build_empirical_cdf(&format!("iet_{source}"), &gaps),
                lognormal_fit: fit_lognormal(&gaps),
                lag1_autocorr: pearson_lag1_correlation(&gaps).unwrap_or(0.0),
            };
            Some((source, summary))
        })
        .collect();

    PerSourceIetPrior { by_source }
}
/// Builds an [`EmpiricalCdf`] named `column` from the finite values of `samples`.
///
/// NaN and infinite values are discarded. The rest are sorted ascending with a stable sort
/// before being handed to `EmpiricalCdf::from_sorted_values`.
fn build_empirical_cdf(column: &str, samples: &[f64]) -> EmpiricalCdf {
    let mut finite: Vec<f64> = Vec::with_capacity(samples.len());
    finite.extend(samples.iter().copied().filter(|v| v.is_finite()));
    finite.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));
    EmpiricalCdf::from_sorted_values(column.to_owned(), finite)
}
pub const DEFAULT_MIN_LAG_SAMPLES: usize = 100;
pub fn extract_posting_lag(records: &[Record], min_samples: usize) -> Option<PostingLagPrior> {
if records.is_empty() {
return None;
}
let mut by_source: BTreeMap<String, Vec<f64>> = BTreeMap::new();
let mut seen_jes: HashSet<&str> = HashSet::new();
for r in records {
if r.functional_amount == 0.0 {
continue;
}
if !r.je_number.is_empty() && !seen_jes.insert(r.je_number.as_str()) {
continue;
}
let lag = (r.effective_date - r.entry_date).num_days() as f64;
by_source.entry(r.source.clone()).or_default().push(lag);
}
let mut summaries: BTreeMap<String, LagSummary> = BTreeMap::new();
for (source, samples) in by_source {
if samples.len() < min_samples {
continue;
}
let n = samples.len();
let mean = samples.iter().sum::<f64>() / n as f64;
let var = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n as f64;
let cdf = build_empirical_cdf(&format!("lag_{source}"), &samples);
summaries.insert(
source,
LagSummary {
empirical_cdf_days: cdf,
mean,
stddev: var.sqrt(),
n,
},
);
}
if summaries.is_empty() {
None
} else {
Some(PostingLagPrior {
by_source: summaries,
})
}
}
/// Minimum number of counted JEs required by [`extract_month_volume`].
pub const DEFAULT_MIN_MONTH_VOLUME_JES: usize = 60;

/// Counts journal entries per calendar month of `effective_date`.
///
/// Lines with a zero functional amount are skipped. A JE is identified by
/// `(effective_date.year(), je_number)`, consistent with [`extract_source_mix_je`]. JEs in
/// later years that reuse a document number are therefore still counted. Lines with an empty
/// `je_number` are each counted individually.
///
/// Returns `None` when fewer than `min_jes` JEs are counted, or when the resulting prior
/// reports no data.
pub fn extract_month_volume(records: &[Record], min_jes: usize) -> Option<MonthVolumePrior> {
    use chrono::Datelike;
    let mut counts = [0.0f64; 12];
    let mut seen_jes: HashSet<(i32, &str)> = HashSet::new();
    let mut n = 0usize;
    for r in records {
        if r.functional_amount == 0.0 {
            continue;
        }
        let date = r.effective_date;
        if !r.je_number.is_empty() && !seen_jes.insert((date.year(), r.je_number.as_str())) {
            continue;
        }
        let m = date.month();
        // `month()` is 1..=12 for any valid date; the guard is purely defensive.
        if (1..=12).contains(&m) {
            counts[(m - 1) as usize] += 1.0;
            n += 1;
        }
    }
    if n < min_jes {
        return None;
    }
    let prior = MonthVolumePrior::from_counts(counts, n);
    prior.has_data().then_some(prior)
}
/// Fits log-normal parameters to `ln(x + 1)` of the finite, strictly positive samples.
///
/// Zero, negative and non-finite samples are ignored. Returns `None` when fewer than three
/// samples remain. `mu` is the mean and `sigma` the population (divide-by-n) standard
/// deviation of the transformed values.
fn fit_lognormal(samples: &[f64]) -> Option<LognormalParams> {
    let logs: Vec<f64> = samples
        .iter()
        .copied()
        .filter(|&x| x.is_finite() && x > 0.0)
        .map(|x| (x + 1.0).ln())
        .collect();
    if logs.len() < 3 {
        return None;
    }
    let count = logs.len() as f64;
    let mu = logs.iter().sum::<f64>() / count;
    let variance = logs.iter().map(|v| (v - mu).powi(2)).sum::<f64>() / count.max(1.0);
    Some(LognormalParams {
        mu,
        sigma: variance.sqrt(),
    })
}
/// Measures, per source, the span in days between its earliest and latest `entry_date`.
///
/// `overall` histograms the lifetimes of all sources. `by_source` holds a one-sample
/// histogram of each source's own lifetime. Both use `ACTIVE_LIFETIME_DAY_BUCKETS`.
pub fn extract_active_lifetime(records: &[Record]) -> ActiveLifetimePrior {
    /// Non-negative whole days from `first` to `last`.
    fn lifetime_days(first: NaiveDate, last: NaiveDate) -> u32 {
        last.signed_duration_since(first).num_days().max(0) as u32
    }

    let mut span: BTreeMap<String, (NaiveDate, NaiveDate)> = BTreeMap::new();
    for rec in records {
        let day = rec.entry_date;
        match span.get_mut(&rec.source) {
            Some((first, last)) => {
                *first = (*first).min(day);
                *last = (*last).max(day);
            }
            None => {
                span.insert(rec.source.clone(), (day, day));
            }
        }
    }

    let lifetimes: Vec<u32> = span
        .values()
        .map(|&(first, last)| lifetime_days(first, last))
        .collect();
    let (overall, _) = LineCountHistogram::build(&lifetimes, ACTIVE_LIFETIME_DAY_BUCKETS);

    let by_source: BTreeMap<String, LineCountHistogram> = span
        .into_iter()
        .map(|(src, (first, last))| {
            let (hist, _) = LineCountHistogram::build(
                &[lifetime_days(first, last)],
                ACTIVE_LIFETIME_DAY_BUCKETS,
            );
            (src, hist)
        })
        .collect();

    ActiveLifetimePrior { by_source, overall }
}
/// Reads one attribute value off a record; `None` when the attribute is absent.
type AttributeProjector = fn(&Record) -> Option<String>;

/// For each attribute (GL account, cost centre, profit centre, trading partner), histograms
/// how many distinct sources use each observed value, using `FANOUT_BUCKETS`.
pub fn extract_fanout(records: &[Record]) -> FanoutPrior {
    let projectors: [(&str, AttributeProjector); 4] = [
        ("GLAccount", |r| Some(r.gl_account.clone())),
        ("CostCenter", |r| r.cost_center.clone()),
        ("ProfitCenter", |r| r.profit_center.clone()),
        ("TradingPartner", |r| r.trading_partner.clone()),
    ];

    let by_attribute: BTreeMap<String, LineCountHistogram> = projectors
        .iter()
        .map(|&(name, project)| {
            let mut users_of_value: BTreeMap<String, HashSet<String>> = BTreeMap::new();
            for rec in records {
                let Some(value) = project(rec) else {
                    continue;
                };
                users_of_value
                    .entry(value)
                    .or_default()
                    .insert(rec.source.clone());
            }
            let fanouts: Vec<u32> = users_of_value
                .values()
                .map(|sources| sources.len() as u32)
                .collect();
            let (hist, _) = LineCountHistogram::build(&fanouts, FANOUT_BUCKETS);
            (name.to_string(), hist)
        })
        .collect();

    FanoutPrior { by_attribute }
}
/// Minimum number of non-empty values an attribute needs per source in
/// [`extract_per_source_attribute`].
pub const DEFAULT_MIN_ATTRIBUTE_OBSERVATIONS: usize = 10;

/// Builds, per source, categorical distributions over the values of `gl_account`,
/// `cost_center`, `profit_center` and `trading_partner`.
///
/// Records with an empty source are ignored, and empty attribute values are not counted.
/// An attribute is kept for a source only when it has at least `min_observations` counted
/// values. Sources left with no attributes are omitted.
pub fn extract_per_source_attribute(
    records: &[Record],
    min_observations: usize,
) -> PerSourceAttributePrior {
    type ValueCounts = BTreeMap<String, usize>;
    let mut counts: BTreeMap<String, BTreeMap<String, ValueCounts>> = BTreeMap::new();

    for rec in records.iter().filter(|r| !r.source.is_empty()) {
        let per_attr = counts.entry(rec.source.clone()).or_default();
        let observed: [(&str, Option<&String>); 4] = [
            ("gl_account", Some(&rec.gl_account)),
            ("cost_center", rec.cost_center.as_ref()),
            ("profit_center", rec.profit_center.as_ref()),
            ("trading_partner", rec.trading_partner.as_ref()),
        ];
        for (attr, value) in observed {
            if let Some(v) = value.filter(|v| !v.is_empty()) {
                *per_attr
                    .entry(attr.to_string())
                    .or_default()
                    .entry(v.clone())
                    .or_default() += 1;
            }
        }
    }

    let by_source = counts
        .into_iter()
        .map(|(source, per_attr)| {
            let kept: BTreeMap<String, CategoricalDistribution> = per_attr
                .into_iter()
                .filter(|(_, value_counts)| {
                    value_counts.values().sum::<usize>() >= min_observations
                })
                .map(|(attr, value_counts)| {
                    (attr, CategoricalDistribution::from_counts(value_counts))
                })
                .collect();
            (source, kept)
        })
        .filter(|(_, kept)| !kept.is_empty())
        .collect();

    PerSourceAttributePrior {
        by_source,
        min_observations,
    }
}
/// Minimum count threshold used by [`extract_source_role_gl`].
pub const DEFAULT_MIN_SOURCE_ROLE_OBSERVATIONS: usize = 10;

/// Builds, per source and posting side, a categorical distribution over GL accounts.
///
/// Positive amounts are side `"DR"` and negative amounts are side `"CR"`. Zero or NaN amounts,
/// and records with an empty source or GL account, are ignored. Accounts seen fewer than
/// `min_observations` times are dropped. A side is kept only when its remaining counts total
/// at least `min_observations`. Sources with no kept side are omitted.
pub fn extract_source_role_gl(records: &[Record], min_observations: usize) -> PerSourceRolePrior {
    /// Debit/credit side of an amount; `None` for zero or NaN.
    fn posting_side(amount: f64) -> Option<&'static str> {
        if amount > 0.0 {
            Some("DR")
        } else if amount < 0.0 {
            Some("CR")
        } else {
            None
        }
    }

    let mut counts: BTreeMap<String, BTreeMap<String, BTreeMap<String, usize>>> = BTreeMap::new();
    for rec in records {
        if rec.source.is_empty() || rec.gl_account.is_empty() {
            continue;
        }
        let Some(side) = posting_side(rec.functional_amount) else {
            continue;
        };
        *counts
            .entry(rec.source.clone())
            .or_default()
            .entry(side.to_string())
            .or_default()
            .entry(rec.gl_account.clone())
            .or_default() += 1;
    }

    let mut by_source_and_role: BTreeMap<String, BTreeMap<String, CategoricalDistribution>> =
        BTreeMap::new();
    for (source, per_side) in counts {
        let roles: BTreeMap<String, CategoricalDistribution> = per_side
            .into_iter()
            .filter_map(|(side, gl_counts)| {
                let frequent: BTreeMap<String, usize> = gl_counts
                    .into_iter()
                    .filter(|&(_, c)| c >= min_observations)
                    .collect();
                let total: usize = frequent.values().sum();
                (total >= min_observations)
                    .then(|| (side, CategoricalDistribution::from_counts(frequent)))
            })
            .collect();
        if !roles.is_empty() {
            by_source_and_role.insert(source, roles);
        }
    }

    PerSourceRolePrior { by_source_and_role }
}
pub const DEFAULT_MIN_FLOW_PAIR_OBSERVATIONS: usize = 10;
pub const DEFAULT_FLOW_PAIR_GRANULARITY: usize = 1;
pub fn extract_source_flow_pairs(
records: &[Record],
granularity: usize,
min_observations: usize,
) -> PerSourceFlowPairPrior {
use std::collections::BTreeSet;
let mut jes: BTreeMap<&str, (String, BTreeSet<String>, BTreeSet<String>)> = BTreeMap::new();
for r in records {
if r.je_number.is_empty() || r.source.is_empty() || r.gl_account.is_empty() {
continue;
}
let Some(class) = PerSourceFlowPairPrior::account_class(&r.gl_account, granularity) else {
continue;
};
let entry = jes
.entry(r.je_number.as_str())
.or_insert_with(|| (r.source.clone(), BTreeSet::new(), BTreeSet::new()));
if r.functional_amount > 0.0 {
entry.1.insert(class);
} else if r.functional_amount < 0.0 {
entry.2.insert(class);
}
}
let mut counts: BTreeMap<String, BTreeMap<String, usize>> = BTreeMap::new();
for (source, deb, cred) in jes.into_values() {
for d in &deb {
for c in &cred {
*counts
.entry(source.clone())
.or_default()
.entry(PerSourceFlowPairPrior::pair_key(d, c))
.or_default() += 1;
}
}
}
let by_source: BTreeMap<String, CategoricalDistribution> = counts
.into_iter()
.filter_map(|(source, pair_counts)| {
let filtered: BTreeMap<String, usize> = pair_counts
.into_iter()
.filter(|(_, c)| *c >= min_observations)
.collect();
let total: usize = filtered.values().sum();
if total < min_observations {
None
} else {
Some((source, CategoricalDistribution::from_counts(filtered)))
}
})
.collect();
PerSourceFlowPairPrior {
by_source,
granularity,
}
}
pub const DEFAULT_MIN_AMOUNT_OBSERVATIONS: usize = 10;
pub fn extract_source_amount_conditionals(
records: &[Record],
min_observations: usize,
) -> PerSourceAmountPrior {
use std::collections::BTreeMap;
use chrono::Datelike;
let mut by_pair: BTreeMap<(String, String), Vec<f64>> = BTreeMap::new();
let mut by_src: BTreeMap<String, Vec<f64>> = BTreeMap::new();
let mut je_acc: BTreeMap<(i32, &str), (&str, f64)> = BTreeMap::new();
for r in records {
let abs_amt = r.functional_amount.abs();
if abs_amt <= 0.0 || !abs_amt.is_finite() {
continue;
}
if r.source.is_empty() {
continue;
}
let gl_prefix: String = if r.gl_account.len() >= 4 {
r.gl_account[..4].to_string()
} else {
r.gl_account.clone()
};
by_pair
.entry((r.source.clone(), gl_prefix))
.or_default()
.push(abs_amt);
by_src.entry(r.source.clone()).or_default().push(abs_amt);
if r.functional_amount > 0.0 {
let key = (r.effective_date.year(), r.je_number.as_str());
let entry = je_acc.entry(key).or_insert((r.source.as_str(), 0.0));
entry.1 += r.functional_amount;
}
}
let mut by_source_and_class: BTreeMap<String, BTreeMap<String, LognormalAmount>> =
BTreeMap::new();
for ((source, gl_prefix), values) in by_pair {
if values.len() < min_observations {
continue;
}
if let Some(params) = fit_lognormal_amount(&values) {
by_source_and_class
.entry(source)
.or_default()
.insert(gl_prefix, params);
}
}
let mut by_source: BTreeMap<String, LognormalAmount> = BTreeMap::new();
for (source, values) in &by_src {
if values.len() < min_observations {
continue;
}
if let Some(params) = fit_lognormal_amount(values) {
by_source.insert(source.clone(), params);
}
}
let mut quantile_sketch_by_source: BTreeMap<String, AmountQuantileSketch> = BTreeMap::new();
for (source, values) in by_src {
if values.len() < DEFAULT_MIN_SKETCH_OBSERVATIONS {
continue;
}
if let Some(sketch) = build_amount_quantile_sketch(values) {
quantile_sketch_by_source.insert(source, sketch);
}
}
let mut je_totals_by_src: BTreeMap<&str, Vec<f64>> = BTreeMap::new();
let mut je_totals_global: Vec<f64> = Vec::new();
for (source, total) in je_acc.into_values() {
if total > 0.0 && total.is_finite() {
je_totals_by_src.entry(source).or_default().push(total);
je_totals_global.push(total);
}
}
let mut je_total_sketch_by_source: BTreeMap<String, AmountQuantileSketch> = BTreeMap::new();
for (source, totals) in je_totals_by_src {
if totals.len() < DEFAULT_MIN_SKETCH_OBSERVATIONS {
continue;
}
if let Some(sketch) = build_amount_quantile_sketch(totals) {
je_total_sketch_by_source.insert(source.to_string(), sketch);
}
}
let je_total_sketch_global = if je_totals_global.len() >= DEFAULT_MIN_SKETCH_OBSERVATIONS {
build_amount_quantile_sketch(je_totals_global)
} else {
None
};
PerSourceAmountPrior {
by_source_and_class,
by_source,
quantile_sketch_by_source,
je_total_sketch_by_source,
je_total_sketch_global,
}
}
/// Minimum number of values required before an amount quantile sketch is attempted.
pub const DEFAULT_MIN_SKETCH_OBSERVATIONS: usize = 1000;
/// Candidate cumulative probabilities at which amount quantiles are recorded.
pub const AMOUNT_SKETCH_GRID: &[f64] = &[
    0.01, 0.02, 0.05, 0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95, 0.98, 0.99,
    0.995, 0.999,
];

/// Builds a quantile sketch of the finite, strictly positive `values`.
///
/// Only grid probabilities `p` with `p <= 1 - 5/n` are used, i.e. points where about five or
/// more observations lie above the quantile. Each knot is linearly interpolated between
/// adjacent order statistics at position `(n - 1) * p`. Returns `None` when fewer than two
/// values qualify, fewer than four grid points survive, or the sketch is not usable.
fn build_amount_quantile_sketch(mut values: Vec<f64>) -> Option<AmountQuantileSketch> {
    values.retain(|v| *v > 0.0 && v.is_finite());
    let n = values.len();
    if n < 2 {
        return None;
    }
    values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

    let tail_cap = 1.0 - 5.0 / n as f64;
    let probabilities: Vec<f64> = AMOUNT_SKETCH_GRID
        .iter()
        .filter(|&&p| p <= tail_cap)
        .copied()
        .collect();
    if probabilities.len() < 4 {
        return None;
    }

    let last = n - 1;
    let knots: Vec<f64> = probabilities
        .iter()
        .map(|&p| {
            let pos = last as f64 * p;
            let below = pos.floor() as usize;
            let above = (below + 1).min(last);
            values[below] + (pos - below as f64) * (values[above] - values[below])
        })
        .collect();

    let sketch = AmountQuantileSketch {
        probabilities,
        values: knots,
        n,
    };
    sketch.is_usable().then_some(sketch)
}
/// Fits a log-normal model to the finite, strictly positive `values`.
///
/// `mu` and `sigma` are the mean and population standard deviation of `ln(v)`. `n` is the
/// number of qualifying values and `median_abs` is their median. Returns `None` when fewer
/// than two values qualify.
fn fit_lognormal_amount(values: &[f64]) -> Option<LognormalAmount> {
    fn usable(v: f64) -> bool {
        v > 0.0 && v.is_finite()
    }

    let logs: Vec<f64> = values
        .iter()
        .copied()
        .filter(|&v| usable(v))
        .map(f64::ln)
        .collect();
    if logs.len() < 2 {
        return None;
    }
    let count = logs.len() as f64;
    let mu = logs.iter().sum::<f64>() / count;
    let variance = logs.iter().map(|l| (l - mu).powi(2)).sum::<f64>() / count.max(1.0);

    let mut positive: Vec<f64> = values.iter().copied().filter(|&v| usable(v)).collect();
    positive.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let mid = positive.len() / 2;
    let median_abs = match positive.len() {
        0 => 0.0,
        len if len.is_multiple_of(2) => (positive[mid - 1] + positive[mid]) / 2.0,
        _ => positive[mid],
    };

    Some(LognormalAmount {
        mu,
        sigma: variance.sqrt(),
        n: logs.len(),
        median_abs,
    })
}
pub const DEFAULT_MIN_JES_PER_SOURCE: usize = 500;
pub fn extract_lines_per_je(records: &[Record], min_jes_per_source: usize) -> LinesPerJePrior {
let mut lines_per_je: BTreeMap<String, u32> = BTreeMap::new();
let mut source_of_je: BTreeMap<String, String> = BTreeMap::new();
for r in records {
*lines_per_je.entry(r.je_number.clone()).or_insert(0) += 1;
source_of_je
.entry(r.je_number.clone())
.or_insert_with(|| r.source.clone());
}
let overall_values: Vec<u32> = lines_per_je.values().copied().collect();
let (overall, _) = LineCountHistogram::build(&overall_values, LINE_COUNT_BUCKETS);
let mut by_source_values: BTreeMap<String, Vec<u32>> = BTreeMap::new();
for (je, n_lines) in &lines_per_je {
if let Some(src) = source_of_je.get(je) {
by_source_values
.entry(src.clone())
.or_default()
.push(*n_lines);
}
}
let mut by_source: BTreeMap<String, LineCountHistogram> = BTreeMap::new();
for (src, values) in by_source_values {
if values.len() < min_jes_per_source {
continue;
}
let (hist, _) = LineCountHistogram::build(&values, LINE_COUNT_BUCKETS);
by_source.insert(src, hist);
}
LinesPerJePrior {
overall,
by_source,
min_jes_per_source,
}
}
/// Result type returned by the behavioral-prior entry points.
pub type BehavioralResult<T> = Result<T, FingerprintError>;

/// Extracts [`BehavioralPriors`] from in-memory records using `SourceMixGates::default()`.
pub fn extract_behavioral_priors(
    records: &[Record],
    industry: &str,
) -> BehavioralResult<BehavioralPriors> {
    let gates = SourceMixGates::default();
    extract_behavioral_priors_with_gates(records, industry, gates)
}
pub fn extract_behavioral_priors_with_gates(
records: &[Record],
industry: &str,
source_mix_gates: SourceMixGates,
) -> BehavioralResult<BehavioralPriors> {
let source_mix = extract_source_mix(
records,
source_mix_gates.min_share,
source_mix_gates.min_observations,
);
let source_mix_je = extract_source_mix_je(records, &source_mix.probabilities);
Ok(BehavioralPriors {
schema_version: BehavioralPriors::SCHEMA_VERSION,
generator_version: env!("CARGO_PKG_VERSION").to_string(),
industry: industry.to_string(),
n_client_inputs: 1,
n_rows_aggregated: records.len(),
source_mix,
per_source_iet: extract_per_source_iet(records, DEFAULT_MIN_IET_SAMPLES),
lines_per_je: extract_lines_per_je(records, DEFAULT_MIN_JES_PER_SOURCE),
active_lifetime: extract_active_lifetime(records),
fanout: extract_fanout(records),
posting_lag: extract_posting_lag(records, DEFAULT_MIN_LAG_SAMPLES),
month_volume: extract_month_volume(records, DEFAULT_MIN_MONTH_VOLUME_JES),
active_segments: Some(extract_active_segments(records)),
entity_clusters: Some(extract_entity_clusters(records)),
per_source_attribute: Some(extract_per_source_attribute(
records,
DEFAULT_MIN_ATTRIBUTE_OBSERVATIONS,
)),
tp_entity_clusters: Some(extract_tp_entity_clusters(records)),
reference_formats: {
let rf = extract_reference_formats(records, DEFAULT_MIN_REFERENCE_OCCURRENCES);
if rf.by_source.is_empty() {
None
} else {
Some(rf)
}
},
coa_semantic: None,
user_personas: {
let up = extract_user_personas(records, DEFAULT_MIN_USER_RECORDS);
Some(up)
},
source_amount_conditionals: {
let sac = extract_source_amount_conditionals(records, DEFAULT_MIN_AMOUNT_OBSERVATIONS);
if sac.by_source.is_empty() && sac.by_source_and_class.is_empty() {
None
} else {
Some(sac)
}
},
source_role_gl_conditionals: {
let srg = extract_source_role_gl(records, DEFAULT_MIN_SOURCE_ROLE_OBSERVATIONS);
if srg.by_source_and_role.is_empty() {
None
} else {
Some(srg)
}
},
source_flow_pairs: {
let sfp = extract_source_flow_pairs(
records,
DEFAULT_FLOW_PAIR_GRANULARITY,
DEFAULT_MIN_FLOW_PAIR_OBSERVATIONS,
);
if sfp.by_source.is_empty() {
None
} else {
Some(sfp)
}
},
text_taxonomy: None,
approver: None,
source_mix_je,
tb_anchor: None,
manual_share: None,
})
}
/// Loads records from `path` and extracts priors using `SourceMixGates::default()`.
///
/// # Errors
/// Propagates the errors of [`extract_behavioral_priors_from_path_with_gates`].
pub fn extract_behavioral_priors_from_path(
    path: &Path,
    industry: &str,
) -> BehavioralResult<BehavioralPriors> {
    let gates = SourceMixGates::default();
    extract_behavioral_priors_from_path_with_gates(path, industry, gates)
}
/// Loads records from a `.parquet` or `.csv` file and extracts behavioral priors.
///
/// The extension is matched case-insensitively, so `data.PARQUET` and `data.Csv` are
/// accepted. For parquet inputs, the manual-share and approver priors are also extracted
/// from the file. These steps are best-effort: a failure is logged with `tracing::warn!`
/// and leaves that prior `None`.
///
/// # Errors
/// Returns `FingerprintError::ExtractionError` when the extension is unsupported or the
/// file cannot be loaded.
pub fn extract_behavioral_priors_from_path_with_gates(
    path: &Path,
    industry: &str,
    source_mix_gates: SourceMixGates,
) -> BehavioralResult<BehavioralPriors> {
    let extension = path
        .extension()
        .and_then(|s| s.to_str())
        .map(str::to_ascii_lowercase);
    let is_parquet = extension.as_deref() == Some("parquet");

    let records = match extension.as_deref() {
        Some("parquet") => load_parquet_records(path)
            .map_err(|e| io_error_to_fp(format!("parquet load failed: {e}")))?,
        Some("csv") => {
            load_csv_records(path).map_err(|e| io_error_to_fp(format!("csv load failed: {e}")))?
        }
        _ => {
            return Err(io_error_to_fp(format!(
                "unsupported extension at {}",
                path.display()
            )));
        }
    };

    let mut priors = extract_behavioral_priors_with_gates(&records, industry, source_mix_gates)?;

    if is_parquet {
        priors.manual_share = match extract_manual_share_from_parquet(
            path,
            super::manual_extractor::DEFAULT_MIN_MANUAL_OBSERVATIONS,
        ) {
            Ok(ms) => ms,
            Err(e) => {
                tracing::warn!("manual-share extraction failed (skipping): {e}");
                None
            }
        };
        priors.approver = match super::approver_extractor::extract_approver_prior_from_parquet(
            path,
            super::approver_extractor::DEFAULT_MIN_APPROVER_OBSERVATIONS,
        ) {
            Ok(ap) => ap,
            Err(e) => {
                tracing::warn!("approver extraction failed (skipping): {e}");
                None
            }
        };
    }

    Ok(priors)
}
/// Wraps an error message as `FingerprintError::ExtractionError` with extractor
/// `"behavioral_priors"`.
fn io_error_to_fp(msg: String) -> FingerprintError {
    let extractor = String::from("behavioral_priors");
    FingerprintError::ExtractionError {
        extractor,
        message: msg,
    }
}
/// A gap strictly longer than this many days between consecutive dates starts a new segment.
pub const SEGMENT_GAP_THRESHOLD_DAYS: i64 = 7;

/// Splits `dates` into runs in which consecutive dates are at most `gap_threshold` days apart.
///
/// Returns the `(start, end)` date of every run, plus the length in days of each gap that
/// separated two runs. Expects `dates` in ascending order; an empty slice yields two empty
/// vectors.
fn split_into_segments(
    dates: &[NaiveDate],
    gap_threshold: i64,
) -> (Vec<(NaiveDate, NaiveDate)>, Vec<u32>) {
    let Some((&first, rest)) = dates.split_first() else {
        return (vec![], vec![]);
    };

    let mut segments = Vec::new();
    let mut gaps = Vec::new();
    let (mut start, mut end) = (first, first);
    for &day in rest {
        let gap = (day - end).num_days();
        if gap > gap_threshold {
            segments.push((start, end));
            gaps.push(gap as u32);
            start = day;
        }
        end = day;
    }
    segments.push((start, end));
    (segments, gaps)
}
/// Summarises, per source, the active segments of its distinct `entry_date`s.
///
/// Segments are split at gaps longer than [`SEGMENT_GAP_THRESHOLD_DAYS`] (see
/// [`split_into_segments`]). Sources with fewer than two distinct dates are omitted. Each
/// summary holds histograms of the segment count, the segment lengths in days and the gap
/// lengths in days.
pub fn extract_active_segments(records: &[Record]) -> ActiveSegmentsPrior {
    let mut days_by_source: BTreeMap<String, Vec<NaiveDate>> = BTreeMap::new();
    for rec in records {
        days_by_source
            .entry(rec.source.clone())
            .or_default()
            .push(rec.entry_date);
    }

    let by_source: BTreeMap<String, SourceSegmentSummary> = days_by_source
        .into_iter()
        .filter_map(|(src, mut days)| {
            days.sort_unstable();
            days.dedup();
            if days.len() < 2 {
                return None;
            }
            let (segments, gaps) = split_into_segments(&days, SEGMENT_GAP_THRESHOLD_DAYS);
            let lengths: Vec<u32> = segments
                .iter()
                .map(|&(start, end)| (end - start).num_days().max(0) as u32)
                .collect();
            let (segment_count_histogram, _) =
                LineCountHistogram::build(&[segments.len() as u32], SEGMENT_COUNT_BUCKETS);
            let (segment_length_histogram, _) =
                LineCountHistogram::build(&lengths, ACTIVE_LIFETIME_DAY_BUCKETS);
            let (gap_length_histogram, _) = LineCountHistogram::build(&gaps, SEGMENT_GAP_BUCKETS);
            Some((
                src,
                SourceSegmentSummary {
                    segment_count_histogram,
                    segment_length_histogram,
                    gap_length_histogram,
                },
            ))
        })
        .collect();

    ActiveSegmentsPrior { by_source }
}
/// Number of most frequent raw sources considered by `extract_entity_clusters`.
const MAX_SOURCES_FOR_CLUSTERING: usize = 50;
/// Minimum Jaccard similarity of two attribute sets for their entities to be linked.
const JACCARD_THRESHOLD: f64 = 0.3;
/// Source codes that [`normalise_source_code`] accepts verbatim (after trimming).
const CANONICAL_SAP_CODES: &[&str] = &[
    "KR", "RV", "DZ", "WE", "RE", "SA", "IM", "KZ", "AB", "AF", "DR", "KK", "K9", "KX", "PK", "RB",
    "RY", "SL", "ZP",
];

/// Maps a raw source code to its canonical form.
///
/// Surrounding whitespace is trimmed. Codes listed in [`CANONICAL_SAP_CODES`] are returned
/// unchanged (case-sensitive). The numeric aliases `0`/`00`, `1`/`01` and `2`/`02` map to
/// `SA`, `RV` and `KR`. Anything else, including an empty string, yields `None`.
fn normalise_source_code(raw: &str) -> Option<String> {
    let code = raw.trim();
    let canonical = match code {
        "" => return None,
        "0" | "00" => "SA",
        "1" | "01" => "RV",
        "2" | "02" => "KR",
        known if CANONICAL_SAP_CODES.contains(&known) => known,
        _ => return None,
    };
    Some(canonical.to_string())
}
/// Clusters canonical source codes by the overlap of the accounting attributes they touch.
///
/// The `MAX_SOURCES_FOR_CLUSTERING` raw sources with the most lines are considered (ties
/// broken by ascending code). Each one collects the set of `GL:`, `CC:`, `PC:` and `TP:`
/// values seen on its lines. Raw codes are then mapped through [`normalise_source_code`]:
/// unrecognised codes are dropped and aliases of one canonical code have their sets merged.
/// The canonical sets are clustered by [`cluster_by_jaccard`].
pub fn extract_entity_clusters(records: &[Record]) -> EntityClustersPrior {
    let top_sources = most_frequent(
        records.iter().map(|r| r.source.as_str()),
        MAX_SOURCES_FOR_CLUSTERING,
    );

    let mut raw_sets: BTreeMap<&str, HashSet<String>> = BTreeMap::new();
    for r in records
        .iter()
        .filter(|r| top_sources.contains(r.source.as_str()))
    {
        let set = raw_sets.entry(r.source.as_str()).or_default();
        set.insert(format!("GL:{}", r.gl_account));
        if let Some(cc) = &r.cost_center {
            set.insert(format!("CC:{cc}"));
        }
        if let Some(pc) = &r.profit_center {
            set.insert(format!("PC:{pc}"));
        }
        if let Some(tp) = &r.trading_partner {
            set.insert(format!("TP:{tp}"));
        }
    }

    let mut canonical_sets: BTreeMap<String, HashSet<String>> = BTreeMap::new();
    for (raw, set) in raw_sets {
        if let Some(code) = normalise_source_code(raw) {
            canonical_sets.entry(code).or_default().extend(set);
        }
    }

    cluster_by_jaccard(&canonical_sets)
}

/// Number of most frequent trading partners considered by [`extract_tp_entity_clusters`].
const MAX_TP_FOR_CLUSTERING: usize = 200;

/// Clusters trading partners by the overlap of the attributes seen on their lines.
///
/// Only non-empty trading partners are considered, limited to the `MAX_TP_FOR_CLUSTERING`
/// with the most lines (ties broken by ascending value). Each collects the set of `GL:`,
/// `CC:`, `PC:` and `SRC:` values seen on its lines. The sets are clustered by
/// [`cluster_by_jaccard`].
pub fn extract_tp_entity_clusters(records: &[Record]) -> EntityClustersPrior {
    /// The record's trading partner, treating an empty string like a missing one.
    fn trading_partner(r: &Record) -> Option<&str> {
        r.trading_partner.as_deref().filter(|tp| !tp.is_empty())
    }

    let top_tps = most_frequent(
        records.iter().filter_map(trading_partner),
        MAX_TP_FOR_CLUSTERING,
    );

    let mut attr_sets: BTreeMap<String, HashSet<String>> = BTreeMap::new();
    for r in records {
        let Some(tp) = trading_partner(r) else {
            continue;
        };
        if !top_tps.contains(tp) {
            continue;
        }
        let set = attr_sets.entry(tp.to_string()).or_default();
        set.insert(format!("GL:{}", r.gl_account));
        if let Some(cc) = &r.cost_center {
            set.insert(format!("CC:{cc}"));
        }
        if let Some(pc) = &r.profit_center {
            set.insert(format!("PC:{pc}"));
        }
        set.insert(format!("SRC:{}", r.source));
    }

    cluster_by_jaccard(&attr_sets)
}

/// Returns the `limit` most frequent keys, breaking count ties by ascending key order.
fn most_frequent<'a>(keys: impl Iterator<Item = &'a str>, limit: usize) -> HashSet<&'a str> {
    let mut counts: BTreeMap<&'a str, usize> = BTreeMap::new();
    for key in keys {
        *counts.entry(key).or_insert(0) += 1;
    }
    let mut ranked: Vec<(&'a str, usize)> = counts.into_iter().collect();
    // Stable sort over key-ordered input keeps ascending key order within equal counts.
    ranked.sort_by_key(|&(_, count)| std::cmp::Reverse(count));
    ranked.into_iter().take(limit).map(|(key, _)| key).collect()
}

/// Orders a pair of entity names so each undirected edge has a single key.
fn edge_key<'a>(a: &'a str, b: &'a str) -> (&'a str, &'a str) {
    if a < b {
        (a, b)
    } else {
        (b, a)
    }
}

/// Links every pair of entities whose attribute sets have a Jaccard similarity of at least
/// [`JACCARD_THRESHOLD`]. Every connected component with two or more members is reported
/// as a cluster.
///
/// A cluster's `avg_jaccard` is the mean weight of the linking edges between its members.
/// `clustering_rate` is the fraction of entities that belong to some cluster (the
/// denominator is clamped to at least 1).
fn cluster_by_jaccard(attr_sets: &BTreeMap<String, HashSet<String>>) -> EntityClustersPrior {
    let names: Vec<&str> = attr_sets.keys().map(String::as_str).collect();

    let mut adj: BTreeMap<&str, Vec<&str>> = BTreeMap::new();
    let mut edge_weights: BTreeMap<(&str, &str), f64> = BTreeMap::new();
    for (i, &a_name) in names.iter().enumerate() {
        for &b_name in &names[i + 1..] {
            let a = &attr_sets[a_name];
            let b = &attr_sets[b_name];
            if a.is_empty() || b.is_empty() {
                continue;
            }
            let intersection = a.intersection(b).count() as f64;
            let union = a.union(b).count() as f64;
            if union == 0.0 {
                continue;
            }
            let jaccard = intersection / union;
            if jaccard >= JACCARD_THRESHOLD {
                adj.entry(a_name).or_default().push(b_name);
                adj.entry(b_name).or_default().push(a_name);
                edge_weights.insert(edge_key(a_name, b_name), jaccard);
            }
        }
    }

    let mut visited: HashSet<&str> = HashSet::new();
    let mut clusters: Vec<EntityCluster> = Vec::new();
    for &start in &names {
        if visited.contains(start) {
            continue;
        }
        // Depth-first walk of the connected component containing `start`.
        let mut members: Vec<&str> = Vec::new();
        let mut stack = vec![start];
        while let Some(node) = stack.pop() {
            if !visited.insert(node) {
                continue;
            }
            members.push(node);
            if let Some(neighbors) = adj.get(node) {
                stack.extend(neighbors.iter().copied().filter(|n| !visited.contains(n)));
            }
        }
        if members.len() < 2 {
            continue;
        }

        let mut sum = 0.0;
        let mut count = 0.0;
        for (i, &a) in members.iter().enumerate() {
            for &b in &members[i + 1..] {
                if let Some(&w) = edge_weights.get(&edge_key(a, b)) {
                    sum += w;
                    count += 1.0;
                }
            }
        }
        let avg_jaccard = if count > 0.0 { sum / count } else { 0.0 };
        clusters.push(EntityCluster {
            members: members.into_iter().map(str::to_string).collect(),
            avg_jaccard,
        });
    }

    let total_in_clusters: usize = clusters.iter().map(|c| c.members.len()).sum();
    let clustering_rate = total_in_clusters as f64 / names.len().max(1) as f64;
    EntityClustersPrior {
        clusters,
        clustering_rate,
    }
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::{Duration, NaiveDate};
use rand::{RngExt, SeedableRng};
/// Builds a minimal one-line `Record` attributed to source `src`.
///
/// Both dates are 2022-01-01, the JE is `J1` line `001`, the GL account is `"1"`,
/// the functional amount is `1.0`, every optional attribute is `None`, and both
/// text fields are empty.
pub(crate) fn rec(src: &str) -> Record {
    let day = NaiveDate::from_ymd_opt(2022, 1, 1).expect("date");
    Record {
        source: src.into(),
        gl_account: "1".into(),
        je_number: "J1".into(),
        je_line_number: "001".into(),
        functional_amount: 1.0,
        effective_date: day,
        entry_date: day,
        created_at: None,
        cost_center: None,
        profit_center: None,
        trading_partner: None,
        header_text: String::new(),
        line_text: String::new(),
    }
}
pub(crate) fn rec_amt(src: &str, amount: f64) -> Record {
let mut r = rec(src);
r.functional_amount = amount;
r
}
/// Zero-amount lines are left out of the line-level source mix, so the 200
/// zero-amount `B` lines do not shift the 60/30 split between `A` and `B`.
#[test]
fn source_mix_ignores_zero_amount_lines() {
    let recs: Vec<Record> = (0..60)
        .map(|_| rec_amt("A", 10.0))
        .chain((0..30).map(|_| rec_amt("B", -5.0)))
        .chain((0..200).map(|_| rec_amt("B", 0.0)))
        .collect();
    let mix = extract_source_mix(&recs, DEFAULT_MIN_SOURCE_THRESHOLD, 0);
    assert!((mix.probabilities["A"] - 60.0 / 90.0).abs() < 1e-9);
    assert!((mix.probabilities["B"] - 30.0 / 90.0).abs() < 1e-9);
}
/// When every line has a zero amount, the mix is computed over all lines,
/// giving a 70/30 split for 70 `A` lines and 30 `B` lines.
#[test]
fn source_mix_falls_back_to_all_lines_when_amounts_degenerate() {
    let mut recs: Vec<Record> = Vec::with_capacity(100);
    for _ in 0..70 {
        recs.push(rec_amt("A", 0.0));
    }
    for _ in 0..30 {
        recs.push(rec_amt("B", 0.0));
    }
    let mix = extract_source_mix(&recs, DEFAULT_MIN_SOURCE_THRESHOLD, 0);
    assert!((mix.probabilities["A"] - 0.7).abs() < 1e-9);
    assert!((mix.probabilities["B"] - 0.3).abs() < 1e-9);
}
/// Line counts of 60/30/10 map directly to shares 0.6/0.3/0.1, and nothing is
/// rolled into `other_fraction`.
#[test]
fn source_mix_shares_match() {
    let recs: Vec<Record> = [("A", 60_usize), ("B", 30), ("C", 10)]
        .iter()
        .flat_map(|&(src, n)| std::iter::repeat_with(move || rec(src)).take(n))
        .collect();
    let mix = extract_source_mix(&recs, DEFAULT_MIN_SOURCE_THRESHOLD, 0);
    for (src, expected) in [("A", 0.6), ("B", 0.3), ("C", 0.1)] {
        assert!((mix.probabilities[src] - expected).abs() < 1e-9);
    }
    assert!(mix.other_fraction.abs() < 1e-9);
}
/// Five singleton sources, each 0.1% of 1 000 lines, fall under a 0.005 share
/// threshold. They are excluded from `probabilities` and show up as a positive
/// `other_fraction`.
#[test]
fn source_mix_long_tail_rolls_into_other() {
    let mut recs: Vec<Record> = (0..995).map(|_| rec("A")).collect();
    recs.extend((1..=5).map(|i| rec(&format!("X{i}"))));
    let mix = extract_source_mix(&recs, 0.005, 0);
    assert!(mix.probabilities.contains_key("A"));
    assert!(!mix.probabilities.contains_key("X1"));
    assert!(mix.other_fraction > 0.0);
}
/// An empty record slice produces an empty mix with a zero `other_fraction`.
#[test]
fn source_mix_empty_input_returns_empty() {
    let no_records: [Record; 0] = [];
    let mix = extract_source_mix(&no_records, DEFAULT_MIN_SOURCE_THRESHOLD, 0);
    assert!(mix.probabilities.is_empty());
    assert!(mix.other_fraction.abs() < 1e-9);
}
/// `SourceMixGates::default()` reproduces the module-level default constants.
#[test]
fn source_mix_gates_default_matches_constants() {
    let SourceMixGates {
        min_share,
        min_observations,
    } = SourceMixGates::default();
    assert!((min_share - DEFAULT_MIN_SOURCE_THRESHOLD).abs() < 1e-12);
    assert_eq!(min_observations, DEFAULT_MIN_SOURCE_OBSERVATIONS);
}
/// Source `B` has 10 of 2 500 lines (0.4%). The default gates drop it, but
/// lowering the share and observation gates keeps it, with the shares still
/// summing to 1.
#[test]
fn lowered_gates_retain_rare_but_real_sources() {
    let recs: Vec<Record> = (0..2490)
        .map(|_| rec("A"))
        .chain((0..10).map(|_| rec("B")))
        .collect();

    let default_mix = extract_behavioral_priors(&recs, "technology")
        .expect("default extraction")
        .source_mix;
    assert!(!default_mix.probabilities.contains_key("B"));

    let lowered = SourceMixGates {
        min_share: 0.001,
        min_observations: 10,
    };
    let mix = extract_behavioral_priors_with_gates(&recs, "technology", lowered)
        .expect("gated extraction")
        .source_mix;
    assert!(mix.probabilities.contains_key("B"), "B should survive");
    assert!((mix.probabilities["B"] - 0.004).abs() < 1e-9);
    let total = mix.probabilities.values().sum::<f64>();
    assert!((total - 1.0).abs() < 1e-9);
}
/// The JE-level mix counts journal entries, not lines. Sixty two-line `CO` JEs
/// and ten eight-line `WL` JEs give a 60/40 line mix but a 6/7 vs 1/7 JE mix.
#[test]
fn extract_source_mix_je_uses_je_counts_not_line_counts() {
    let mut records: Vec<Record> = Vec::new();
    for i in 0..60 {
        let je = format!("CO-{i}");
        records.extend([100.0, -100.0].map(|amt| make_je_rec("CO", &je, 2024, amt)));
    }
    for i in 0..10 {
        let je = format!("WL-{i}");
        for _ in 0..4 {
            records.push(make_je_rec("WL", &je, 2024, 25.0));
            records.push(make_je_rec("WL", &je, 2024, -25.0));
        }
    }

    let line_mix = extract_source_mix(&records, 0.001, 10);
    assert!((line_mix.probabilities["CO"] - 0.60).abs() < 1e-9);
    assert!((line_mix.probabilities["WL"] - 0.40).abs() < 1e-9);

    let je_mix = extract_source_mix_je(&records, &line_mix.probabilities)
        .expect("JE mix present when the line mix has a vocabulary");
    let co_share = je_mix.probabilities["CO"];
    assert!(
        (co_share - 60.0 / 70.0).abs() < 1e-9,
        "CO JE share must be 6/7, got {}",
        co_share
    );
    assert!((je_mix.probabilities["WL"] - 10.0 / 70.0).abs() < 1e-9);
    let total = je_mix.probabilities.values().sum::<f64>();
    assert!((total - 1.0).abs() < 1e-9);
}
/// JE numbers are scoped by year. The same `J1` posted in 2024 and in 2025
/// counts as two `CO` JEs, so `CO` holds half of the four JEs.
#[test]
fn extract_source_mix_je_year_scopes_je_numbers() {
    let mut records: Vec<Record> = Vec::new();
    for year in [2024, 2025] {
        for amt in [100.0, -100.0] {
            records.push(make_je_rec("CO", "J1", year, amt));
        }
    }
    for i in 0..2 {
        let je = format!("W-{i}");
        for amt in [50.0, -50.0] {
            records.push(make_je_rec("WL", &je, 2024, amt));
        }
    }

    let line_mix = extract_source_mix(&records, 0.001, 1);
    let je_mix =
        extract_source_mix_je(&records, &line_mix.probabilities).expect("JE mix present");
    let co_share = je_mix.probabilities["CO"];
    assert!(
        (co_share - 0.5).abs() < 1e-9,
        "CO JE share must be 0.5 (year-scoped), got {}",
        co_share
    );
}
/// The JE mix keeps only sources present in the supplied line-mix vocabulary.
/// JEs from the excluded source (`ZX`, 10 of 50 JEs) are reported through
/// `other_fraction`.
#[test]
fn extract_source_mix_je_inherits_line_mix_vocabulary() {
    let mut records: Vec<Record> = Vec::new();
    for i in 0..40 {
        let je = format!("CO-{i}");
        records.push(make_je_rec("CO", &je, 2024, 100.0));
        records.push(make_je_rec("CO", &je, 2024, -100.0));
    }
    records.extend((0..10).map(|i| make_je_rec("ZX", &format!("ZX-{i}"), 2024, 9.0)));

    let retained = BTreeMap::from([("CO".to_string(), 1.0)]);
    let je_mix = extract_source_mix_je(&records, &retained).expect("JE mix present");
    assert!(!je_mix.probabilities.contains_key("ZX"));
    assert!((je_mix.probabilities["CO"] - 1.0).abs() < 1e-9);
    let other = je_mix.other_fraction;
    assert!(
        (other - 0.2).abs() < 1e-9,
        "excluded JE mass should be reported, got {}",
        other
    );
}
/// End-to-end check. With lowered gates, `source_mix_je` is populated, has the
/// same keys as the line-level mix, and weights `CO` by JE count (30 of 40).
/// With default gates, an empty line mix must imply no JE mix.
#[test]
fn extract_behavioral_priors_populates_source_mix_je() {
    let mut records: Vec<Record> = Vec::new();
    for i in 0..30 {
        let je = format!("CO-{i}");
        records.push(make_je_rec("CO", &je, 2024, 100.0));
        records.push(make_je_rec("CO", &je, 2024, -100.0));
    }
    for i in 0..10 {
        let je = format!("WL-{i}");
        records.extend((0..3).map(|l| make_je_rec("WL", &je, 2024, 10.0 + l as f64)));
        records.push(make_je_rec("WL", &je, 2024, -33.0));
    }

    let gates = SourceMixGates {
        min_share: 0.001,
        min_observations: 10,
    };
    let priors = extract_behavioral_priors_with_gates(&records, "test", gates).expect("ok");
    let je_mix = priors.source_mix_je.expect("source_mix_je populated");
    let line_keys: Vec<&String> = priors.source_mix.probabilities.keys().collect();
    let je_keys: Vec<&String> = je_mix.probabilities.keys().collect();
    assert_eq!(je_keys, line_keys, "vocabularies must match");
    let co_share = je_mix.probabilities["CO"];
    assert!(
        (co_share - 0.75).abs() < 1e-9,
        "CO is 30 of 40 JEs, got {}",
        co_share
    );

    let defaults = extract_behavioral_priors(&records, "test").expect("ok");
    if defaults.source_mix.probabilities.is_empty() {
        assert!(defaults.source_mix_je.is_none());
    }
}
/// With a zero share threshold and a 1 000-observation floor, only `KR`
/// (1 500 lines) survives, and it is renormalised to probability 1.0.
#[test]
fn extract_source_mix_drops_low_volume_codes() {
    let records: Vec<Record> = [("KR", 1500), ("RV", 100), ("DZ", 5)]
        .into_iter()
        .flat_map(|(code, n)| (0..n).map(move |_| rec(code)))
        .collect();
    let mix = extract_source_mix(&records, 0.0, 1000);
    assert!(mix.probabilities.contains_key("KR"), "KR should survive");
    assert!(
        !mix.probabilities.contains_key("RV"),
        "RV (100 obs) should be dropped"
    );
    assert!(
        !mix.probabilities.contains_key("DZ"),
        "DZ (5 obs) should be dropped"
    );
    assert!(
        (mix.probabilities["KR"] - 1.0).abs() < 1e-9,
        "KR probability should be 1.0 after renormalisation"
    );
}
/// Source `A` has 120 daily entries, giving 119 inter-event gaps and a lognormal
/// fit. Source `B` has only 50 entries, which is under the `100` minimum passed
/// in, so it is omitted.
#[test]
fn per_source_iet_basic() {
    let base = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
    let dated = |src: &str, offset: i64| {
        let mut r = rec(src);
        r.entry_date = base + chrono::Duration::days(offset);
        r
    };
    let mut recs: Vec<Record> = (0..120).map(|i| dated("A", i)).collect();
    recs.extend((0..50).map(|i| dated("B", i)));

    let p = extract_per_source_iet(&recs, 100);
    assert!(p.by_source.contains_key("A"));
    assert!(!p.by_source.contains_key("B"));
    let summary = &p.by_source["A"];
    assert_eq!(summary.n, 119);
    assert!(summary.lognormal_fit.is_some());
}
/// A perfectly regular 3-day cadence reports a lag-1 autocorrelation of zero.
#[test]
fn per_source_iet_constant_gap_zero_autocorr() {
    let base = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
    let recs: Vec<Record> = (0..200)
        .map(|i| {
            let mut r = rec("A");
            r.entry_date = base + chrono::Duration::days(3 * i);
            r
        })
        .collect();
    let p = extract_per_source_iet(&recs, 100);
    assert!((p.by_source["A"].lag1_autocorr).abs() < 1e-9);
}
/// Three JEs with 3, 2 and 1 lines each put probability 1/3 into the
/// line-count buckets 3, 2 and 1.
#[test]
fn lines_per_je_overall_known() {
    let mut recs: Vec<Record> = Vec::new();
    for (je, lines) in [("JE-A", 3), ("JE-B", 2), ("JE-C", 1)] {
        for _ in 0..lines {
            let mut r = rec("S");
            r.je_number = je.into();
            recs.push(r);
        }
    }
    let p = extract_lines_per_je(&recs, DEFAULT_MIN_JES_PER_SOURCE);
    for target in [1, 2, 3] {
        let idx = LINE_COUNT_BUCKETS
            .iter()
            .position(|&b| b == target)
            .unwrap();
        assert!((p.overall.probabilities[idx] - 1.0 / 3.0).abs() < 1e-9);
    }
    assert_eq!(p.overall.n, 3);
}
/// Source `A` spans 24 days (5 entries, 6 days apart) and `B` spans 200 days
/// (5 entries, 50 days apart). The overall lifetime histogram splits evenly
/// between the buckets labelled 7 and 180.
#[test]
fn active_lifetime_basic() {
    let base = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
    let mut recs: Vec<Record> = Vec::new();
    for (src, step) in [("A", 6), ("B", 50)] {
        for i in 0..5 {
            let mut r = rec(src);
            r.entry_date = base + chrono::Duration::days(i * step);
            recs.push(r);
        }
    }
    let p = extract_active_lifetime(&recs);
    for days in [7, 180] {
        let idx = ACTIVE_LIFETIME_DAY_BUCKETS
            .iter()
            .position(|&b| b == days)
            .unwrap();
        assert!((p.overall.probabilities[idx] - 0.5).abs() < 1e-9);
    }
}
/// GL `X` is used by three sources and GL `Y` by one. The `GLAccount` fan-out
/// histogram therefore puts half its mass in bucket 1 and half in bucket 3.
#[test]
fn fanout_basic() {
    let recs: Vec<Record> = [("A", "X"), ("B", "X"), ("C", "X"), ("A", "Y")]
        .into_iter()
        .map(|(src, gl)| {
            let mut r = rec(src);
            r.gl_account = gl.into();
            r
        })
        .collect();
    let p = extract_fanout(&recs);
    let hist = &p.by_attribute["GLAccount"];
    for fan in [1, 3] {
        let idx = FANOUT_BUCKETS.iter().position(|&b| b == fan).unwrap();
        assert!((hist.probabilities[idx] - 0.5).abs() < 1e-9);
    }
}
/// Every `A` JE is effective 5 days after entry and every `B` JE 2 days before.
/// The per-source mean lags are +5 and -2, and `A` has zero spread.
#[test]
fn posting_lag_known() {
    let base = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
    let mut recs: Vec<Record> = Vec::new();
    for (src, lag) in [("A", 5), ("B", -2)] {
        for i in 0..120 {
            let mut r = rec(src);
            r.je_number = format!("{src}{i}");
            r.entry_date = base + chrono::Duration::days(i);
            r.effective_date = r.entry_date + chrono::Duration::days(lag);
            recs.push(r);
        }
    }
    let p = extract_posting_lag(&recs, 100).expect("non-empty");
    assert!((p.by_source["A"].mean - 5.0).abs() < 1e-9);
    assert!((p.by_source["B"].mean - (-2.0)).abs() < 1e-9);
    assert!((p.by_source["A"].stddev).abs() < 1e-9);
}
/// A three-line JE contributes one lag observation, and a zero-amount JE
/// contributes none. The mean is therefore (5 - 10) / 2 over two JEs.
#[test]
fn posting_lag_counts_amount_bearing_jes_once() {
    let base = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
    let line = |je: &str, lag_days: i64, amount: f64| {
        let mut r = rec_amt("A", amount);
        r.je_number = je.into();
        r.entry_date = base;
        r.effective_date = base + chrono::Duration::days(lag_days);
        r
    };
    let mut recs: Vec<Record> = (0..3).map(|_| line("j1", 5, 1.0)).collect();
    recs.push(line("j2", -10, 1.0));
    recs.push(line("j3", -50, 0.0));

    let p = extract_posting_lag(&recs, 1).expect("non-empty");
    let a = &p.by_source["A"];
    assert_eq!(a.n, 2, "two amount-bearing JEs, not four lines");
    assert!(
        (a.mean - (-2.5)).abs() < 1e-9,
        "JE-level mean (5 + -10)/2 = -2.5, got {}",
        a.mean
    );
}
/// Month volume counts distinct amount-bearing JEs by effective-date month:
/// three in January (`a2` counted once) and one in February. The zero-amount
/// February JE is ignored. A minimum of 100 JEs suppresses the prior entirely.
#[test]
fn extract_month_volume_counts_amount_bearing_jes_per_month() {
    let jan = NaiveDate::from_ymd_opt(2022, 1, 15).unwrap();
    let feb = NaiveDate::from_ymd_opt(2022, 2, 15).unwrap();
    let lines = [
        ("a1", jan),
        ("a2", jan),
        ("a2", jan),
        ("a3", jan),
        ("b1", feb),
    ];
    let mut recs: Vec<Record> = lines
        .into_iter()
        .map(|(je, effective)| {
            let mut r = rec("A");
            r.je_number = je.into();
            r.effective_date = effective;
            r
        })
        .collect();
    let mut zero_amount = rec_amt("A", 0.0);
    zero_amount.je_number = "z1".into();
    zero_amount.effective_date = feb;
    recs.push(zero_amount);

    let mv = extract_month_volume(&recs, 1).expect("non-empty");
    assert_eq!(mv.n, 4, "3 distinct amount-bearing Jan JEs + 1 Feb JE");
    assert!((mv.shares[0] - 0.75).abs() < 1e-9, "Jan share 3/4");
    assert!((mv.shares[1] - 0.25).abs() < 1e-9, "Feb share 1/4");
    assert!(extract_month_volume(&recs, 100).is_none());
}
/// Activity on days 0-2, 14-17 and 49, with a gap threshold of 7 days, splits
/// into three segments separated by gaps of 12 and 32 days.
#[test]
fn split_into_segments_known() {
    let base = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
    let dates: Vec<NaiveDate> = [0i64, 1, 2, 14, 15, 16, 17, 49]
        .iter()
        .map(|&d| base + chrono::Duration::days(d))
        .collect();
    let (segs, gaps) = split_into_segments(&dates, 7);
    // One assertion per line: the original packed three onto a single line,
    // which rustfmt would reject and which hides which check failed.
    assert_eq!(segs.len(), 3);
    assert_eq!(gaps.len(), 2);
    assert_eq!(gaps[0], 12);
    assert_eq!(gaps[1], 32);
}
/// A single source active on days 0-2, 14-17 and 49 puts all of its
/// segment-count mass in the bucket labelled 3.
#[test]
fn extract_active_segments_basic() {
    let base = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
    let recs: Vec<Record> = [0i64, 1, 2, 14, 15, 16, 17, 49]
        .into_iter()
        .map(|offset| {
            let mut r = rec("A");
            r.entry_date = base + chrono::Duration::days(offset);
            r
        })
        .collect();
    let p = extract_active_segments(&recs);
    assert!(p.by_source.contains_key("A"));
    let summary = &p.by_source["A"];
    let idx_3 = SEGMENT_COUNT_BUCKETS
        .iter()
        .position(|&b| b == 3)
        .unwrap();
    assert!((summary.segment_count_histogram.probabilities[idx_3] - 1.0).abs() < 1e-9);
}
/// `KR`, `RV` and `DZ` post to largely overlapping GL accounts and must end up in
/// one cluster. `WE` posts only to an unshared account and must stay unclustered.
#[test]
fn extract_entity_clusters_finds_shared_attrs() {
    let postings: [(&str, &[&str]); 4] = [
        ("KR", &["1", "2", "3"]),
        ("RV", &["1", "2", "3"]),
        ("DZ", &["1", "2", "4"]),
        ("WE", &["99"]),
    ];
    let recs: Vec<Record> = postings
        .iter()
        .flat_map(|&(src, gls)| {
            gls.iter().map(move |&gl| {
                let mut r = rec(src);
                r.gl_account = gl.into();
                r
            })
        })
        .collect();

    let p = extract_entity_clusters(&recs);
    let wanted = ["KR", "RV", "DZ"];
    let any_cluster_has_kr_rv_dz = p.clusters.iter().any(|c| {
        wanted
            .iter()
            .all(|w| c.members.iter().any(|m| m.as_str() == *w))
    });
    assert!(
        any_cluster_has_kr_rv_dz,
        "expected a cluster containing KR, RV, DZ"
    );
    let any_cluster_has_we = p
        .clusters
        .iter()
        .flat_map(|c| c.members.iter())
        .any(|m| m == "WE");
    assert!(!any_cluster_has_we, "WE should be an isolate (no cluster)");
}
/// Raw numeric source codes must not appear as cluster members. If any cluster
/// forms, at least one member must be one of the listed SAP-style codes.
#[test]
fn extract_entity_clusters_normalises_source_codes() {
    let base = chrono::NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
    let mut recs: Vec<Record> = Vec::new();
    for source in ["0", "0", "0", "KR", "KR", "KR"] {
        recs.extend(["1", "2", "3"].into_iter().map(|gl| {
            let mut r = rec(source);
            r.entry_date = base;
            r.gl_account = gl.into();
            r
        }));
    }

    let p = extract_entity_clusters(&recs);
    let raw_codes = ["0", "1", "2", "00", "01", "02"];
    for member in p.clusters.iter().flat_map(|c| c.members.iter()) {
        assert!(
            !raw_codes.contains(&member.as_str()),
            "raw numeric code {member:?} should have been normalised"
        );
    }

    if !p.clusters.is_empty() {
        let sap_codes = ["KR", "RV", "DZ", "SA", "WE", "RE", "IM", "KZ"];
        let any_sap = p
            .clusters
            .iter()
            .flat_map(|c| c.members.iter())
            .any(|m| sap_codes.contains(&m.as_str()));
        assert!(
            any_sap,
            "expected at least one SAP-style canonical code in clusters"
        );
    }
}
/// Table-driven check of `normalise_source_code`. Canonical codes pass through,
/// numeric codes map to SAP codes, surrounding whitespace is trimmed, and
/// unknown or empty input yields `None`.
#[test]
fn normalise_source_code_known_mappings() {
    let cases: [(&str, Option<&str>); 9] = [
        ("KR", Some("KR")),
        ("RV", Some("RV")),
        ("0", Some("SA")),
        ("00", Some("SA")),
        ("1", Some("RV")),
        ("2", Some("KR")),
        (" KR ", Some("KR")),
        ("XYZ", None),
        ("", None),
    ];
    for (raw, expected) in cases {
        assert_eq!(normalise_source_code(raw), expected.map(String::from));
    }
}
/// Builds a one-line `Record` with the given source, GL account and optional
/// cost/profit centres.
///
/// Remaining fields are fixed: dates 2022-01-01, JE `J1` line `001`, amount
/// `1.0`, no trading partner, no creation time, and empty text fields.
fn make_record(
    src: &str,
    gl: &str,
    cost_center: Option<&str>,
    profit_center: Option<&str>,
) -> Record {
    let date = NaiveDate::from_ymd_opt(2022, 1, 1).expect("date");
    Record {
        source: src.to_owned(),
        gl_account: gl.to_owned(),
        cost_center: cost_center.map(str::to_owned),
        profit_center: profit_center.map(str::to_owned),
        trading_partner: None,
        je_number: "J1".to_owned(),
        je_line_number: "001".to_owned(),
        effective_date: date,
        entry_date: date,
        created_at: None,
        functional_amount: 1.0,
        header_text: String::new(),
        line_text: String::new(),
    }
}
/// With a minimum of 10 observations, `KR` (15 lines, all GL `200001`) is kept
/// and `RV` (3 lines) is filtered out.
#[test]
fn extract_per_source_attribute_filters_low_observations() {
    let records: Vec<Record> =
        std::iter::repeat_with(|| make_record("KR", "200001", Some("CC1"), Some("PC1")))
            .take(15)
            .chain(
                std::iter::repeat_with(|| make_record("RV", "400001", Some("CC2"), Some("PC2")))
                    .take(3),
            )
            .collect();
    let prior = extract_per_source_attribute(&records, 10);
    assert!(prior.by_source.contains_key("KR"), "KR should be retained");
    let kr_gl = prior
        .conditional("KR", "gl_account")
        .expect("KR/gl_account");
    assert!(
        kr_gl.probabilities.contains_key("200001"),
        "200001 should appear"
    );
    assert_eq!(kr_gl.n, 15);
    assert!((kr_gl.probabilities["200001"] - 1.0).abs() < 1e-9);
    assert!(
        !prior.by_source.contains_key("RV"),
        "RV should be filtered out"
    );
}
/// Rows whose source code is the empty string produce no per-source entry.
#[test]
fn extract_per_source_attribute_skips_empty_source() {
    let records: Vec<Record> = std::iter::repeat_with(|| make_record("", "100001", None, None))
        .take(20)
        .collect();
    let prior = extract_per_source_attribute(&records, 5);
    assert!(
        prior.by_source.is_empty(),
        "empty source rows must be skipped"
    );
}
/// 8 and 12 `KR` lines on two GL accounts give normalised shares of 0.4 and
/// 0.6 over n = 20.
#[test]
fn extract_per_source_attribute_multiple_values_normalise() {
    let records: Vec<Record> = [("200001", 8), ("200002", 12)]
        .into_iter()
        .flat_map(|(gl, n)| (0..n).map(move |_| make_record("KR", gl, None, None)))
        .collect();
    let prior = extract_per_source_attribute(&records, 10);
    let kr_gl = prior
        .conditional("KR", "gl_account")
        .expect("KR/gl_account");
    assert_eq!(kr_gl.n, 20);
    assert!((kr_gl.probabilities["200001"] - 0.4).abs() < 1e-9);
    assert!((kr_gl.probabilities["200002"] - 0.6).abs() < 1e-9);
    let total = kr_gl.probabilities.values().sum::<f64>();
    assert!((total - 1.0).abs() < 1e-9, "probabilities must sum to 1.0");
}
/// Smoke test for `extract_behavioral_priors` over 2 400 lines from two sources
/// with different JE sizes, posting lags and GL spreads. It checks that
/// metadata and each sub-prior are populated.
#[test]
fn extract_behavioral_priors_smoke() {
    let base = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
    // (source, lines per JE, posting-lag days, number of distinct GL accounts)
    let specs = [("A", 3, 1, 5), ("B", 2, -1, 7)];
    let mut recs: Vec<Record> = Vec::new();
    for (src, lines_per_je, lag, n_accounts) in specs {
        for i in 0..1200i64 {
            let mut r = rec(src);
            r.je_number = format!("JE-{src}-{:04}", i / lines_per_je);
            r.entry_date = base + chrono::Duration::days(i);
            r.effective_date = r.entry_date + chrono::Duration::days(lag);
            r.gl_account = format!("ACC-{}", i % n_accounts);
            recs.push(r);
        }
    }

    let bp = extract_behavioral_priors(&recs, "test_industry").expect("ok");
    assert_eq!(bp.schema_version, BehavioralPriors::SCHEMA_VERSION);
    assert_eq!(bp.industry, "test_industry");
    assert_eq!(bp.n_client_inputs, 1);
    assert_eq!(bp.n_rows_aggregated, 2400);
    assert!(!bp.source_mix.probabilities.is_empty());
    assert!(bp.per_source_iet.by_source.contains_key("A"));
    assert!(bp.per_source_iet.by_source.contains_key("B"));
    assert!(bp.lines_per_je.overall.n > 0);
    assert!(bp.active_lifetime.overall.n > 0);
    assert_eq!(bp.fanout.by_attribute.len(), 4);
    assert!(bp.posting_lag.is_some());
    assert!(
        bp.per_source_attribute.is_some(),
        "per_source_attribute should be extracted"
    );
    let psa = bp.per_source_attribute.as_ref().unwrap();
    assert!(psa.by_source.contains_key("A") || psa.by_source.contains_key("B"));
}
/// `trading_partner` is one of the per-source conditionals. `KR` sees V100 on
/// 15 of 20 lines and V200 on 5, while `RV` (3 lines) is filtered out.
#[test]
fn extract_per_source_attribute_includes_trading_partner() {
    let with_tp = |src: &str, gl: &str, cc: &str, pc: &str, tp: &str| {
        let mut r = make_record(src, gl, Some(cc), Some(pc));
        r.trading_partner = Some(tp.to_string());
        r
    };
    let mut records: Vec<Record> = Vec::new();
    records.extend((0..15).map(|_| with_tp("KR", "200001", "CC1", "PC1", "V100")));
    records.extend((0..5).map(|_| with_tp("KR", "200001", "CC1", "PC1", "V200")));
    records.extend((0..3).map(|_| with_tp("RV", "400001", "CC2", "PC2", "V300")));

    let prior = extract_per_source_attribute(&records, 10);
    let kr_tp = prior
        .conditional("KR", "trading_partner")
        .expect("KR/trading_partner conditional must be present");
    assert!(
        kr_tp.probabilities.contains_key("V100"),
        "V100 should appear in KR trading_partner conditional"
    );
    assert!(
        kr_tp.probabilities.contains_key("V200"),
        "V200 should appear in KR trading_partner conditional"
    );
    assert_eq!(kr_tp.n, 20, "total observations should be 20");
    assert!(
        (kr_tp.probabilities["V100"] - 0.75).abs() < 1e-9,
        "V100 share should be 0.75"
    );
    assert!(
        (kr_tp.probabilities["V200"] - 0.25).abs() < 1e-9,
        "V200 share should be 0.25"
    );
    let total = kr_tp.probabilities.values().sum::<f64>();
    assert!((total - 1.0).abs() < 1e-9, "probabilities must sum to 1.0");
    assert!(
        !prior.by_source.contains_key("RV"),
        "RV should be filtered out"
    );
}
/// Trading partners `T1` and `T2` share GL accounts 10/20/30 and must cluster
/// together. `T3`, seen only on GL 99, stays unclustered. The clustering rate
/// must be positive.
#[test]
fn extract_tp_entity_clusters_finds_shared_attrs() {
    let day = NaiveDate::from_ymd_opt(2022, 1, 1).expect("date");
    let tp_line = |tp: &str, gl: &str| Record {
        source: "KR".into(),
        gl_account: gl.into(),
        trading_partner: Some(tp.into()),
        cost_center: None,
        profit_center: None,
        je_number: "J1".into(),
        je_line_number: "001".into(),
        effective_date: day,
        entry_date: day,
        created_at: None,
        functional_amount: 1.0,
        header_text: String::new(),
        line_text: String::new(),
    };
    let mut recs = Vec::new();
    for gl in ["10", "20", "30"] {
        for tp in ["T1", "T2"] {
            recs.push(tp_line(tp, gl));
        }
    }
    recs.push(tp_line("T3", "99"));

    let p = extract_tp_entity_clusters(&recs);
    let cluster_has_t1_t2 = p.clusters.iter().any(|c| {
        ["T1", "T2"]
            .iter()
            .all(|want| c.members.iter().any(|m| m.as_str() == *want))
    });
    assert!(cluster_has_t1_t2, "expected a cluster containing T1 and T2");
    let cluster_has_t3 = p
        .clusters
        .iter()
        .flat_map(|c| c.members.iter())
        .any(|m| m == "T3");
    assert!(!cluster_has_t3, "T3 should be an isolate (no cluster)");
    assert!(p.clustering_rate > 0.0, "clustering_rate must be > 0");
}
/// Builds a one-line `Record` for `src`, posting `amount` to GL account `gl`.
///
/// Other fields are fixed: dates 2022-01-01, JE `J1` line `001`, no optional
/// attributes, no creation time, and empty text fields.
fn make_amount_rec(src: &str, gl: &str, amount: f64) -> Record {
    let posted_on = NaiveDate::from_ymd_opt(2022, 1, 1).expect("date");
    Record {
        functional_amount: amount,
        source: src.into(),
        gl_account: gl.into(),
        je_number: "J1".into(),
        je_line_number: "001".into(),
        effective_date: posted_on,
        entry_date: posted_on,
        created_at: None,
        cost_center: None,
        profit_center: None,
        trading_partner: None,
        header_text: String::new(),
        line_text: String::new(),
    }
}
/// With a minimum of 10 observations, the `KR` marginal and the `KR`/`0041`
/// pair (15 lines) are kept, and the `RV` marginal (3 lines) is dropped.
#[test]
fn extract_source_amount_conditionals_filters_low_count_pairs() {
    let kr_lines = (0..15).map(|i| make_amount_rec("KR", "0041", 100.0 + i as f64));
    let rv_lines = (0..3).map(|i| make_amount_rec("RV", "0013", 500.0 + i as f64));
    let records: Vec<Record> = kr_lines.chain(rv_lines).collect();

    let prior = extract_source_amount_conditionals(&records, 10);
    assert!(
        prior.by_source.contains_key("KR"),
        "KR marginal should be retained"
    );
    let kr_pair_kept = prior
        .by_source_and_class
        .get("KR")
        .is_some_and(|classes| classes.contains_key("0041"));
    assert!(kr_pair_kept, "KR/0041 pair should be retained");
    assert!(
        !prior.by_source.contains_key("RV"),
        "RV marginal should be dropped (only 3 observations)"
    );
}
/// Fifty amounts within about ±25% of e^4.5 give a fitted log-mean (`mu`) close
/// to 4.5 for the `KR`/`0041` pair. All 50 observations are counted, and the
/// median absolute amount is positive.
#[test]
fn extract_source_amount_conditionals_lognormal_params_sensible() {
    // Kept on its own line: the original squeezed this and the next statement
    // onto one line, which rustfmt rejects.
    let base_amount = 4.5_f64.exp();
    let records: Vec<Record> = (0..50)
        .map(|i| {
            let amt = base_amount * (1.0 + 0.01 * ((i as f64) - 25.0));
            make_amount_rec("KR", "0041", amt)
        })
        .collect();
    let prior = extract_source_amount_conditionals(&records, 10);
    let params = prior
        .by_source_and_class
        .get("KR")
        .and_then(|m| m.get("0041"))
        .expect("KR/0041 params should be present");
    assert!(
        (params.mu - 4.5).abs() < 0.1,
        "mu {:.3} should be close to 4.5",
        params.mu
    );
    assert_eq!(params.n, 50);
    assert!(params.median_abs > 0.0, "median_abs must be positive");
}
/// Zero amounts are excluded from the observation count, while negative amounts
/// are kept: 20 positive + 5 negative = 25.
#[test]
fn extract_source_amount_conditionals_skips_zeros() {
    let mut records: Vec<Record> = Vec::new();
    for (amount, count) in [(100.0, 20), (0.0, 5), (-50.0, 5)] {
        records.extend((0..count).map(|_| make_amount_rec("KR", "0041", amount)));
    }
    let prior = extract_source_amount_conditionals(&records, 10);
    let params = prior
        .by_source_and_class
        .get("KR")
        .and_then(|classes| classes.get("0041"))
        .expect("KR/0041 should be present");
    assert_eq!(params.n, 25, "zero-amount records must be excluded from n");
}
/// `KR` (2 000 lines with amounts 1..=2000) gets a usable quantile sketch. Its
/// top knot is 0.995, its median knot is near 1 000.5, and its values are
/// non-decreasing. `RV` (500 lines) keeps its marginal but gets no sketch.
#[test]
fn extract_source_amount_conditionals_builds_privacy_gated_sketch() {
    let records: Vec<Record> = (1..=2000)
        .map(|x| make_amount_rec("KR", "0041", x as f64))
        .chain((1..=500).map(|x| make_amount_rec("RV", "0040", x as f64)))
        .collect();
    let prior = extract_source_amount_conditionals(&records, 10);

    let kr = prior
        .quantile_sketch_by_source
        .get("KR")
        .expect("KR sketch present at n=2000");
    assert_eq!(kr.n, 2000);
    assert!(kr.is_usable());
    let top = *kr.probabilities.last().expect("knots");
    assert!(
        (top - 0.995).abs() < 1e-12,
        "top knot must be 0.995 under the 1-5/n cap, got {top}"
    );
    let p50_idx = kr
        .probabilities
        .iter()
        .position(|p| (*p - 0.50).abs() < 1e-12)
        .expect("p50 knot");
    let p50 = kr.values[p50_idx];
    assert!(
        (p50 - 1000.5).abs() < 2.0,
        "p50 knot should be ~1000.5, got {}",
        p50
    );
    assert!(kr.values.windows(2).all(|pair| pair[1] >= pair[0]));
    assert!(!prior.quantile_sketch_by_source.contains_key("RV"));
    assert!(prior.by_source.contains_key("RV"));
}
fn make_je_rec(src: &str, je: &str, year: i32, amount: f64) -> Record {
let mut r = make_amount_rec(src, "0041", amount);
r.je_number = je.into();
r.effective_date = NaiveDate::from_ymd_opt(year, 3, 15).expect("date");
r.entry_date = r.effective_date;
r
}
/// `SA` has 1 200 balanced two-line JEs with totals 1..=1200. It gets a JE-total
/// sketch with one observation per JE, a median near 600.5, non-decreasing
/// values and a top knot of 0.995. `DR` (300 JEs) gets no JE-total sketch.
#[test]
fn extract_source_amount_conditionals_builds_je_total_sketch() {
    let mut records: Vec<Record> = Vec::new();
    for i in 0..1200 {
        let je = format!("JE-{i}");
        let x = (i + 1) as f64;
        records.push(make_je_rec("SA", &je, 2024, x));
        records.push(make_je_rec("SA", &je, 2024, -x));
    }
    for i in 0..300 {
        let je = format!("DR-{i}");
        for amt in [50_000.0, -50_000.0] {
            records.push(make_je_rec("DR", &je, 2024, amt));
        }
    }

    let prior = extract_source_amount_conditionals(&records, 10);
    let sa = prior
        .je_total_sketch_by_source
        .get("SA")
        .expect("SA JE-total sketch present at 1200 JEs");
    assert_eq!(sa.n, 1200, "one observation per JE, not per line");
    assert!(sa.is_usable());
    let p50_idx = sa
        .probabilities
        .iter()
        .position(|p| (*p - 0.50).abs() < 1e-12)
        .expect("p50 knot");
    let p50 = sa.values[p50_idx];
    assert!(
        (p50 - 600.5).abs() < 2.0,
        "JE-total p50 knot should be ~600.5, got {}",
        p50
    );
    assert!(sa.values.windows(2).all(|pair| pair[1] >= pair[0]));
    let top = *sa.probabilities.last().expect("knots");
    assert!(
        (top - 0.995).abs() < 1e-12,
        "top knot must respect the 1-5/n cap, got {top}"
    );
    assert!(!prior.je_total_sketch_by_source.contains_key("DR"));
}
/// Neither source alone (600 JEs each) gets a per-source JE-total sketch. Pooled
/// (1 200 JEs), they produce a global sketch whose median lies strictly between
/// the two sources' JE totals of 100 and 1 000.
#[test]
fn extract_source_amount_conditionals_builds_global_je_total_sketch() {
    let mut records: Vec<Record> = Vec::new();
    for (src, amount) in [("AA", 100.0), ("BB", 1000.0)] {
        for i in 0..600 {
            let je = format!("{src}-{i}");
            records.push(make_je_rec(src, &je, 2024, amount));
            records.push(make_je_rec(src, &je, 2024, -amount));
        }
    }

    let prior = extract_source_amount_conditionals(&records, 10);
    assert!(!prior.je_total_sketch_by_source.contains_key("AA"));
    assert!(!prior.je_total_sketch_by_source.contains_key("BB"));
    let g = prior
        .je_total_sketch_global
        .as_ref()
        .expect("global JE-total sketch present once pooled JEs clear the gate");
    assert_eq!(g.n, 1200, "one observation per JE across both sources");
    assert!(g.is_usable());
    let p50_idx = g
        .probabilities
        .iter()
        .position(|p| (*p - 0.50).abs() < 1e-12)
        .expect("p50 knot");
    let p50 = g.values[p50_idx];
    assert!(
        p50 > 100.0 && p50 < 1000.0,
        "pooled p50 must sit between the two sources' totals, got {p50}"
    );
}
/// With only 500 JEs in total, no global JE-total sketch is produced.
#[test]
fn global_je_total_sketch_absent_below_gate() {
    let records: Vec<Record> = (0..500)
        .flat_map(|i| {
            let je = format!("AA-{i}");
            [100.0, -100.0].map(|amt| make_je_rec("AA", &je, 2024, amt))
        })
        .collect();
    let prior = extract_source_amount_conditionals(&records, 10);
    assert!(prior.je_total_sketch_global.is_none());
}
/// Identical JE numbers in 2024 and 2025 are distinct JEs. The sketch sees
/// 2 000 JEs, each totalling 100, rather than 1 000 merged JEs of 200.
#[test]
fn je_total_sketch_year_scopes_je_numbers() {
    let mut records: Vec<Record> = Vec::new();
    for year in [2024, 2025] {
        for i in 0..1000 {
            let je = format!("JE-{i}");
            records.extend([100.0, -100.0].map(|amt| make_je_rec("SA", &je, year, amt)));
        }
    }

    let prior = extract_source_amount_conditionals(&records, 10);
    let sa = prior
        .je_total_sketch_by_source
        .get("SA")
        .expect("SA JE-total sketch present");
    assert_eq!(sa.n, 2000, "year-scoped: 1000 JEs per year, not merged");
    assert!(
        sa.values.iter().all(|v| (*v - 100.0).abs() < 1e-9),
        "every JE total is 100 (a cross-year merge would double it), got {:?}",
        &sa.values[..3.min(sa.values.len())]
    );
}
/// A JE total is the sum of its positive (debit) legs. Each JE with legs 60, 40
/// and -100 totals 100, and the 50 credit-only JEs contribute nothing.
#[test]
fn je_total_sums_debit_legs_only() {
    let mut records: Vec<Record> = Vec::new();
    for i in 0..1000 {
        let je = format!("JE-{i}");
        for leg in [60.0, 40.0, -100.0] {
            records.push(make_je_rec("SA", &je, 2024, leg));
        }
    }
    records.extend((0..50).map(|i| make_je_rec("SA", &format!("CR-{i}"), 2024, -77.0)));

    let prior = extract_source_amount_conditionals(&records, 10);
    let sa = prior
        .je_total_sketch_by_source
        .get("SA")
        .expect("SA JE-total sketch present");
    assert_eq!(sa.n, 1000, "credit-only JEs must not contribute");
    assert!(
        sa.values.iter().all(|v| (*v - 100.0).abs() < 1e-9),
        "JE total must be the debit-leg sum (100), not the abs-sum (200); got {:?}",
        &sa.values[..3.min(sa.values.len())]
    );
}
/// End-to-end: `extract_behavioral_priors` fills `source_amount_conditionals`,
/// including a `KR` marginal, from 50 single-line `KR` JEs.
#[test]
fn extract_behavioral_priors_populates_source_amount_conditionals() {
    let base = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
    let recs: Vec<Record> = (0..50)
        .map(|i| {
            let mut r = make_amount_rec("KR", "0041", 100.0 + i as f64);
            r.entry_date = base + chrono::Duration::days(i);
            r.effective_date = r.entry_date;
            r.je_number = format!("JE-{i}");
            r
        })
        .collect();

    let bp = extract_behavioral_priors(&recs, "test").expect("ok");
    assert!(
        bp.source_amount_conditionals.is_some(),
        "source_amount_conditionals should be populated"
    );
    let sac = bp.source_amount_conditionals.as_ref().unwrap();
    assert!(
        sac.by_source.contains_key("KR"),
        "KR marginal should be present"
    );
}
/// Builds `n` one-line records for `source` on GL `gl`, each with amount
/// `role_sign * 100.0`.
///
/// JE numbers are `J000000`, `J000001`, …; all dates are 2022-01-01.
fn make_role_records(source: &str, role_sign: f64, gl: &str, n: usize) -> Vec<Record> {
    let base = NaiveDate::from_ymd_opt(2022, 1, 1).expect("date");
    let amount = role_sign * 100.0;
    let mut out = Vec::with_capacity(n);
    for i in 0..n {
        out.push(Record {
            source: source.to_string(),
            gl_account: gl.to_string(),
            je_number: format!("J{i:06}"),
            je_line_number: "001".to_string(),
            functional_amount: amount,
            effective_date: base,
            entry_date: base,
            created_at: None,
            cost_center: None,
            profit_center: None,
            trading_partner: None,
            header_text: String::new(),
            line_text: String::new(),
        });
    }
    out
}
/// Positive-amount lines feed the `DR` conditional and negative-amount lines
/// the `CR` conditional. Each role carries its own GL-account distribution
/// (6000 for DR, 2000 for CR).
#[test]
fn sp4_6_extract_source_role_gl_produces_role_conditionals() {
    // One statement per line (the original put these two on a single line,
    // which rustfmt rejects).
    let mut records = make_role_records("KR", 1.0, "6000", 15);
    records.extend(make_role_records("KR", -1.0, "2000", 15));
    let prior = extract_source_role_gl(&records, 10);
    assert!(
        prior.conditional("KR", "DR").is_some(),
        "KR DR should be present"
    );
    assert!(
        prior.conditional("KR", "CR").is_some(),
        "KR CR should be present"
    );
    let dr_dist = prior.conditional("KR", "DR").unwrap();
    assert!(
        dr_dist.probabilities.contains_key("6000"),
        "DR should have 6000"
    );
    let cr_dist = prior.conditional("KR", "CR").unwrap();
    assert!(
        cr_dist.probabilities.contains_key("2000"),
        "CR should have 2000"
    );
}
/// A role bucket below the minimum count (3 credit lines vs. a threshold of 10)
/// is dropped, while the well-populated debit bucket (15 lines) survives.
#[test]
fn sp4_6_extract_source_role_gl_filters_low_counts() {
    let debit_lines = make_role_records("KR", 1.0, "6000", 15);
    let credit_lines = make_role_records("KR", -1.0, "2000", 3);
    let records: Vec<Record> = debit_lines.into_iter().chain(credit_lines).collect();
    let prior = extract_source_role_gl(&records, 10);
    let has_dr = prior.conditional("KR", "DR").is_some();
    let has_cr = prior.conditional("KR", "CR").is_some();
    assert!(has_dr, "KR DR should pass threshold");
    assert!(!has_cr, "KR CR with only 3 obs should be dropped");
}
/// Records whose functional amount is exactly zero must not produce any
/// (source, role) bucket, so 20 such records yield an empty prior.
#[test]
fn sp4_6_extract_source_role_gl_skips_zero_amounts() {
    let day = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
    let mut records: Vec<Record> = Vec::with_capacity(20);
    for idx in 0..20 {
        records.push(Record {
            source: "SA".to_string(),
            gl_account: "4000".to_string(),
            cost_center: None,
            profit_center: None,
            trading_partner: None,
            je_number: format!("J{idx:06}"),
            je_line_number: "001".to_string(),
            effective_date: day,
            entry_date: day,
            created_at: None,
            functional_amount: 0.0,
            header_text: String::new(),
            line_text: String::new(),
        });
    }
    let prior = extract_source_role_gl(&records, 10);
    assert!(
        prior.by_source_and_role.is_empty(),
        "zero-amount records should yield an empty prior"
    );
}
/// End-to-end check: 1200 lines spread round-robin over KR/RV/SA. Even lines are
/// +100 debits on 6xxx accounts and odd lines are -100 credits on 1xxx/2xxx
/// accounts, grouped three lines per JE. Both `source_role_gl_conditionals` and
/// `source_flow_pairs` must be populated.
#[test]
fn sp4_6_extract_behavioral_priors_populates_source_role_gl_conditionals() {
    let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
    let base = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
    let sources = ["KR", "RV", "SA"];
    let gl_dr = ["6000", "6100", "6200"];
    let gl_cr = ["2000", "2100", "1000"];
    let mut recs: Vec<Record> = Vec::new();
    for idx in 0..1200usize {
        let is_debit = idx % 2 == 0;
        let gl = if is_debit {
            gl_dr[idx % gl_dr.len()]
        } else {
            gl_cr[idx % gl_cr.len()]
        };
        let amount = if is_debit { 100.0 } else { -100.0 };
        // Draw the effective date before the entry date to keep the seeded RNG
        // sequence identical to the field order used originally.
        let effective_date = base + Duration::days(rng.random_range(0..365));
        let entry_date = base + Duration::days(rng.random_range(0..365));
        recs.push(Record {
            source: sources[idx % sources.len()].to_string(),
            gl_account: gl.to_string(),
            cost_center: None,
            profit_center: None,
            trading_partner: None,
            je_number: format!("J{:06}", idx / 3),
            je_line_number: format!("{:03}", (idx % 3) + 1),
            effective_date,
            entry_date,
            created_at: None,
            functional_amount: amount,
            header_text: String::new(),
            line_text: String::new(),
        });
    }
    let bp = extract_behavioral_priors(&recs, "test").expect("ok");
    assert!(
        bp.source_role_gl_conditionals.is_some(),
        "source_role_gl_conditionals should be populated with sufficient data"
    );
    assert!(
        bp.source_flow_pairs.is_some(),
        "source_flow_pairs should be populated with sufficient data"
    );
}
/// Builds `n` two-line journal entries for `source`, all dated 2022-01-01.
///
/// Each entry has a +100 line on `gl_dr` (line "001") followed by a -100 line on
/// `gl_cr` (line "002"). JE numbers are `{source}-J000000`, `{source}-J000001`, …
fn make_flow_records(source: &str, gl_dr: &str, gl_cr: &str, n: usize) -> Vec<Record> {
    let day = NaiveDate::from_ymd_opt(2022, 1, 1).expect("date");
    let legs = [(1.0, gl_dr, "001"), (-1.0, gl_cr, "002")];
    (0..n)
        .flat_map(|idx| {
            legs.iter().map(move |&(sign, gl, line)| Record {
                source: source.to_string(),
                gl_account: gl.to_string(),
                cost_center: None,
                profit_center: None,
                trading_partner: None,
                je_number: format!("{source}-J{idx:06}"),
                je_line_number: line.to_string(),
                effective_date: day,
                entry_date: day,
                created_at: None,
                functional_amount: sign * 100.0,
                header_text: String::new(),
                line_text: String::new(),
            })
        })
        .collect()
}
/// Twenty identical 5010-debit / 2100-credit entries collapse, at granularity 1,
/// into the single joint pair `"5|2"` with probability 1 and `n == 20`.
#[test]
fn sp4_8_extract_source_flow_pairs_produces_joint_pairs() {
    let prior = extract_source_flow_pairs(&make_flow_records("KR", "5010", "2100", 20), 1, 10);
    let dist = prior.pairs("KR").expect("KR pairs present");
    assert_eq!(dist.probabilities.len(), 1);
    let share = dist.probabilities["5|2"];
    assert!((share - 1.0).abs() < 1e-9);
    assert_eq!(dist.n, 20);
    assert_eq!(prior.granularity, 1);
}
/// Pairs and sources below the minimum count of 10 are dropped. For KR, the
/// 3-entry `1|4` pair is removed while the 20-entry `5|2` pair stays. Source DZ,
/// with only 4 entries, disappears entirely.
#[test]
fn sp4_8_extract_source_flow_pairs_filters_low_counts() {
    let mut records = make_flow_records("KR", "5010", "2100", 20);
    records.extend(make_flow_records("KR", "1200", "4000", 3));
    records.extend(make_flow_records("DZ", "1000", "1100", 4));
    let prior = extract_source_flow_pairs(&records, 1, 10);

    let kr = prior.pairs("KR").expect("KR pairs present");
    assert_eq!(
        kr.probabilities.len(),
        1,
        "low-count 1|4 pair must be gated out"
    );
    assert!(kr.probabilities.contains_key("5|2"));
    assert!(prior.pairs("DZ").is_none(), "sparse source must be dropped");
}
/// A credit account that is not numeric ("SUSPENSE") cannot be bucketed, so no
/// debit×credit pair is formed and KR has no pair distribution.
#[test]
fn sp4_8_extract_source_flow_pairs_skips_non_numeric_accounts() {
    let prior = extract_source_flow_pairs(
        &make_flow_records("KR", "5010", "SUSPENSE", 20),
        1,
        10,
    );
    let kr_pairs = prior.pairs("KR");
    assert!(
        kr_pairs.is_none(),
        "non-numeric credit side yields no cross pairs"
    );
}
/// Each of 15 SA entries has two debit lines (5010, 6010) and one credit line
/// (2100). The debit×credit cross product gives the pairs `5|2` and `6|2` in
/// equal shares.
#[test]
fn sp4_8_extract_source_flow_pairs_cross_product() {
    let day = NaiveDate::from_ymd_opt(2022, 1, 1).expect("date");
    let legs = [
        (1.0, "5010", "001"),
        (1.0, "6010", "002"),
        (-1.0, "2100", "003"),
    ];
    let records: Vec<Record> = (0..15)
        .flat_map(|idx| {
            legs.iter().map(move |&(sign, gl, line)| Record {
                source: "SA".to_string(),
                gl_account: gl.to_string(),
                cost_center: None,
                profit_center: None,
                trading_partner: None,
                je_number: format!("SA-J{idx:06}"),
                je_line_number: line.to_string(),
                effective_date: day,
                entry_date: day,
                created_at: None,
                functional_amount: sign * 100.0,
                header_text: String::new(),
                line_text: String::new(),
            })
        })
        .collect();

    let prior = extract_source_flow_pairs(&records, 1, 10);
    let dist = prior.pairs("SA").expect("SA pairs present");
    assert_eq!(dist.probabilities.len(), 2);
    let five_two = dist.probabilities["5|2"];
    let six_two = dist.probabilities["6|2"];
    assert!((five_two - 0.5).abs() < 1e-9);
    assert!((six_two - 0.5).abs() < 1e-9);
}
}