use std::collections::{BTreeMap, HashSet};
use std::path::Path;
use chrono::NaiveDate;
use datasynth_eval::behavioral_fidelity::loader::{load_csv_records, load_parquet_records};
use datasynth_eval::behavioral_fidelity::math::pearson_lag1_correlation;
use datasynth_eval::behavioral_fidelity::Record;
use crate::error::FingerprintError;
use crate::models::behavioral::BehavioralPriors;
use crate::models::behavioral::{
ActiveLifetimePrior, ActiveSegmentsPrior, CategoricalDistribution, EntityCluster,
EntityClustersPrior, FanoutPrior, IetSummary, LagSummary, LineCountHistogram, LinesPerJePrior,
LognormalAmount, LognormalParams, PerSourceAmountPrior, PerSourceAttributePrior,
PerSourceIetPrior, PerSourceRolePrior, PostingLagPrior, SourceMixPrior, SourceSegmentSummary,
ACTIVE_LIFETIME_DAY_BUCKETS, FANOUT_BUCKETS, LINE_COUNT_BUCKETS, SEGMENT_COUNT_BUCKETS,
SEGMENT_GAP_BUCKETS,
};
use crate::models::EmpiricalCdf;
use super::reference_extractor::extract_reference_formats;
use super::user_extractor::extract_user_personas;
/// Threshold handed to `extract_user_personas` by `extract_behavioral_priors`.
pub const DEFAULT_MIN_USER_RECORDS: usize = 100;
/// Threshold handed to `extract_reference_formats` by `extract_behavioral_priors`.
pub const DEFAULT_MIN_REFERENCE_OCCURRENCES: usize = 10;
/// Default `min_threshold` for `extract_source_mix`: a source whose share of
/// all rows is below this fraction is not listed individually.
pub const DEFAULT_MIN_SOURCE_THRESHOLD: f64 = 0.005;
/// Default `min_observations` for `extract_source_mix`: a source with fewer
/// rows than this is dropped before shares are computed.
pub const DEFAULT_MIN_SOURCE_OBSERVATIONS: usize = 1000;
/// Estimates the mix of `source` codes across `records`.
///
/// A source is listed in `probabilities` only if it has at least
/// `min_observations` rows *and* its share of all rows is at least
/// `min_threshold`. The listed shares are then rescaled so they sum to 1.
/// `other_fraction` is the share of all rows (before rescaling) that belongs
/// to sources which were not listed.
///
/// Empty input yields an empty mix with `other_fraction == 0.0`.
pub fn extract_source_mix(
    records: &[Record],
    min_threshold: f64,
    min_observations: usize,
) -> SourceMixPrior {
    if records.is_empty() {
        return SourceMixPrior {
            probabilities: BTreeMap::new(),
            other_fraction: 0.0,
            min_threshold,
        };
    }

    let mut rows_per_source: BTreeMap<String, usize> = BTreeMap::new();
    for record in records {
        *rows_per_source.entry(record.source.clone()).or_default() += 1;
    }

    // Shares are always relative to the full row count, including rows of
    // sources that get dropped below.
    let total_rows = records.len() as f64;
    let mut probabilities: BTreeMap<String, f64> = BTreeMap::new();
    let mut below_threshold: f64 = 0.0;
    for (source, rows) in rows_per_source {
        if rows < min_observations {
            continue;
        }
        let share = rows as f64 / total_rows;
        if share >= min_threshold {
            probabilities.insert(source, share);
        } else {
            below_threshold += share;
        }
    }

    let retained: f64 = probabilities.values().sum();
    // Mass of the sources removed by the `min_observations` filter.
    let unobserved = 1.0 - retained - below_threshold;
    let other_fraction = below_threshold + unobserved;

    if retained > 0.0 {
        probabilities
            .values_mut()
            .for_each(|share| *share /= retained);
    }

    SourceMixPrior {
        probabilities,
        other_fraction,
        min_threshold,
    }
}
/// Default `min_samples` used for `extract_per_source_iet`.
pub const DEFAULT_MIN_IET_SAMPLES: usize = 100;

/// Summarises inter-event times (in whole days) between consecutive
/// `entry_date`s for every source.
///
/// A source is skipped when it has fewer than two rows or fewer than
/// `min_samples` gaps. For the rest, the summary holds the gap count, an
/// empirical CDF, an optional lognormal fit and the lag-1 autocorrelation
/// (0.0 when `pearson_lag1_correlation` yields no value).
pub fn extract_per_source_iet(records: &[Record], min_samples: usize) -> PerSourceIetPrior {
    let mut dates_by_source: BTreeMap<String, Vec<NaiveDate>> = BTreeMap::new();
    for record in records {
        dates_by_source
            .entry(record.source.clone())
            .or_default()
            .push(record.entry_date);
    }

    let by_source: BTreeMap<String, IetSummary> = dates_by_source
        .into_iter()
        .filter_map(|(source, mut dates)| {
            if dates.len() < 2 {
                return None;
            }
            dates.sort();
            let gaps: Vec<f64> = dates
                .windows(2)
                .map(|pair| (pair[1] - pair[0]).num_days() as f64)
                .collect();
            if gaps.len() < min_samples {
                return None;
            }
            let summary = IetSummary {
                n: gaps.len(),
                empirical_cdf_days: build_empirical_cdf(&format!("iet_{source}"), &gaps),
                lognormal_fit: fit_lognormal(&gaps),
                lag1_autocorr: pearson_lag1_correlation(&gaps).unwrap_or(0.0),
            };
            Some((source, summary))
        })
        .collect();

    PerSourceIetPrior { by_source }
}
/// Builds an `EmpiricalCdf` named `column` from the finite values in
/// `samples`, sorted ascending. Non-finite values are discarded.
fn build_empirical_cdf(column: &str, samples: &[f64]) -> EmpiricalCdf {
    let mut finite: Vec<f64> = samples.to_vec();
    finite.retain(|value| value.is_finite());
    finite.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));
    EmpiricalCdf::from_sorted_values(column.to_owned(), finite)
}
/// Default `min_samples` used for `extract_posting_lag`.
pub const DEFAULT_MIN_LAG_SAMPLES: usize = 100;

/// Summarises, per source, the lag in whole days from `entry_date` to
/// `effective_date` (negative when the effective date is earlier).
///
/// Sources with fewer than `min_samples` rows are skipped. Each summary holds
/// the mean, the population standard deviation, the sample count and an
/// empirical CDF. Returns `None` when no source qualifies, which includes
/// empty input.
pub fn extract_posting_lag(records: &[Record], min_samples: usize) -> Option<PostingLagPrior> {
    let mut lags_by_source: BTreeMap<String, Vec<f64>> = BTreeMap::new();
    for record in records {
        let lag_days = (record.effective_date - record.entry_date).num_days() as f64;
        lags_by_source
            .entry(record.source.clone())
            .or_default()
            .push(lag_days);
    }

    let by_source: BTreeMap<String, LagSummary> = lags_by_source
        .into_iter()
        .filter(|(_, lags)| lags.len() >= min_samples)
        .map(|(source, lags)| {
            let n = lags.len();
            let mean = lags.iter().sum::<f64>() / n as f64;
            let variance = lags.iter().map(|lag| (lag - mean).powi(2)).sum::<f64>() / n as f64;
            let empirical_cdf_days = build_empirical_cdf(&format!("lag_{source}"), &lags);
            let summary = LagSummary {
                empirical_cdf_days,
                mean,
                stddev: variance.sqrt(),
                n,
            };
            (source, summary)
        })
        .collect();

    if by_source.is_empty() {
        return None;
    }
    Some(PostingLagPrior { by_source })
}
/// Fits lognormal parameters to `samples` from the mean and population
/// standard deviation of `ln(x + 1)`.
///
/// Only finite, strictly positive samples are used; returns `None` when fewer
/// than three remain.
fn fit_lognormal(samples: &[f64]) -> Option<LognormalParams> {
    // NOTE(review): zeros are excluded even though the `+ 1` shift would make
    // them representable — confirm this combination is intentional.
    let mut shifted_logs: Vec<f64> = Vec::new();
    for &x in samples {
        if x.is_finite() && x > 0.0 {
            shifted_logs.push((x + 1.0).ln());
        }
    }
    if shifted_logs.len() < 3 {
        return None;
    }

    let count = shifted_logs.len() as f64;
    let mu = shifted_logs.iter().sum::<f64>() / count;
    let variance = shifted_logs
        .iter()
        .map(|value| (value - mu).powi(2))
        .sum::<f64>()
        / count;

    Some(LognormalParams {
        mu,
        sigma: variance.sqrt(),
    })
}
/// Measures how long each source stays active: the number of days between its
/// earliest and latest `entry_date`.
///
/// `overall` is the histogram of those lifetimes across all sources;
/// `by_source` holds a single-observation histogram per source. Both use
/// `ACTIVE_LIFETIME_DAY_BUCKETS`.
pub fn extract_active_lifetime(records: &[Record]) -> ActiveLifetimePrior {
    let mut span_by_source: BTreeMap<String, (NaiveDate, NaiveDate)> = BTreeMap::new();
    for record in records {
        let day = record.entry_date;
        let span = span_by_source
            .entry(record.source.clone())
            .or_insert((day, day));
        span.0 = span.0.min(day);
        span.1 = span.1.max(day);
    }

    let mut lifetimes: Vec<u32> = Vec::with_capacity(span_by_source.len());
    let mut by_source: BTreeMap<String, LineCountHistogram> = BTreeMap::new();
    for (source, (first, last)) in &span_by_source {
        let days = last.signed_duration_since(*first).num_days().max(0) as u32;
        lifetimes.push(days);
        let (histogram, _) = LineCountHistogram::build(&[days], ACTIVE_LIFETIME_DAY_BUCKETS);
        by_source.insert(source.clone(), histogram);
    }
    let (overall, _) = LineCountHistogram::build(&lifetimes, ACTIVE_LIFETIME_DAY_BUCKETS);

    ActiveLifetimePrior { by_source, overall }
}
/// Reads one attribute value off a record, or `None` when the record has none.
type AttributeProjector = fn(&Record) -> Option<String>;

/// For each of GL account, cost center, profit center and trading partner,
/// builds a histogram of how many distinct sources use each attribute value.
///
/// The result is keyed by `"GLAccount"`, `"CostCenter"`, `"ProfitCenter"` and
/// `"TradingPartner"` and bucketed with `FANOUT_BUCKETS`.
pub fn extract_fanout(records: &[Record]) -> FanoutPrior {
    let projectors: [(&str, AttributeProjector); 4] = [
        ("GLAccount", |r| Some(r.gl_account.clone())),
        ("CostCenter", |r| r.cost_center.clone()),
        ("ProfitCenter", |r| r.profit_center.clone()),
        ("TradingPartner", |r| r.trading_partner.clone()),
    ];

    let mut by_attribute: BTreeMap<String, LineCountHistogram> = BTreeMap::new();
    for (label, project) in projectors {
        let mut sources_by_value: BTreeMap<String, HashSet<String>> = BTreeMap::new();
        for record in records {
            let Some(value) = project(record) else {
                continue;
            };
            sources_by_value
                .entry(value)
                .or_default()
                .insert(record.source.clone());
        }

        let fanouts: Vec<u32> = sources_by_value
            .values()
            .map(|sources| sources.len() as u32)
            .collect();
        let (histogram, _) = LineCountHistogram::build(&fanouts, FANOUT_BUCKETS);
        by_attribute.insert(label.to_string(), histogram);
    }

    FanoutPrior { by_attribute }
}
/// Default `min_observations` used for `extract_per_source_attribute`.
pub const DEFAULT_MIN_ATTRIBUTE_OBSERVATIONS: usize = 10;

/// Builds, per source, a categorical distribution over the values of
/// `gl_account`, `cost_center`, `profit_center` and `trading_partner`.
///
/// Rows with an empty source are ignored, as are missing or empty attribute
/// values. A (source, attribute) pair is kept only when it has at least
/// `min_observations` counted values; sources left with no attribute are
/// omitted.
pub fn extract_per_source_attribute(
    records: &[Record],
    min_observations: usize,
) -> PerSourceAttributePrior {
    // source -> attribute name -> attribute value -> row count
    let mut counts: BTreeMap<String, BTreeMap<String, BTreeMap<String, usize>>> = BTreeMap::new();
    for record in records {
        if record.source.is_empty() {
            continue;
        }
        let per_attribute = counts.entry(record.source.clone()).or_default();
        let observed: [(&str, Option<&str>); 4] = [
            ("gl_account", Some(record.gl_account.as_str())),
            ("cost_center", record.cost_center.as_deref()),
            ("profit_center", record.profit_center.as_deref()),
            ("trading_partner", record.trading_partner.as_deref()),
        ];
        for (attribute, value) in observed {
            match value {
                Some(v) if !v.is_empty() => {
                    *per_attribute
                        .entry(attribute.to_string())
                        .or_default()
                        .entry(v.to_string())
                        .or_default() += 1;
                }
                _ => {}
            }
        }
    }

    let by_source = counts
        .into_iter()
        .filter_map(|(source, per_attribute)| {
            let mut kept: BTreeMap<String, CategoricalDistribution> = BTreeMap::new();
            for (attribute, value_counts) in per_attribute {
                let observations: usize = value_counts.values().sum();
                if observations >= min_observations {
                    kept.insert(attribute, CategoricalDistribution::from_counts(value_counts));
                }
            }
            if kept.is_empty() {
                None
            } else {
                Some((source, kept))
            }
        })
        .collect();

    PerSourceAttributePrior {
        by_source,
        min_observations,
    }
}
/// Default `min_observations` used for `extract_source_role_gl`.
pub const DEFAULT_MIN_SOURCE_ROLE_OBSERVATIONS: usize = 10;

/// Builds, per source and posting role, a categorical distribution over GL
/// accounts.
///
/// The role is `"DR"` for a positive `functional_amount` and `"CR"` for a
/// negative one; rows with any other amount, an empty source or an empty GL
/// account are ignored. GL accounts seen fewer than `min_observations` times
/// within a (source, role) pair are dropped, and a pair is omitted when the
/// remaining counts total less than `min_observations`.
pub fn extract_source_role_gl(records: &[Record], min_observations: usize) -> PerSourceRolePrior {
    // source -> role -> GL account -> row count
    let mut counts: BTreeMap<String, BTreeMap<String, BTreeMap<String, usize>>> = BTreeMap::new();
    for record in records {
        if record.source.is_empty() || record.gl_account.is_empty() {
            continue;
        }
        let role = match record.functional_amount {
            amount if amount > 0.0 => "DR",
            amount if amount < 0.0 => "CR",
            _ => continue,
        };
        *counts
            .entry(record.source.clone())
            .or_default()
            .entry(role.to_string())
            .or_default()
            .entry(record.gl_account.clone())
            .or_default() += 1;
    }

    let mut by_source_and_role: BTreeMap<String, BTreeMap<String, CategoricalDistribution>> =
        BTreeMap::new();
    for (source, per_role) in counts {
        let roles: BTreeMap<String, CategoricalDistribution> = per_role
            .into_iter()
            .filter_map(|(role, gl_counts)| {
                let frequent: BTreeMap<String, usize> = gl_counts
                    .into_iter()
                    .filter(|entry| entry.1 >= min_observations)
                    .collect();
                let total: usize = frequent.values().sum();
                (total >= min_observations)
                    .then(|| (role, CategoricalDistribution::from_counts(frequent)))
            })
            .collect();
        if !roles.is_empty() {
            by_source_and_role.insert(source, roles);
        }
    }

    PerSourceRolePrior { by_source_and_role }
}
pub const DEFAULT_MIN_AMOUNT_OBSERVATIONS: usize = 10;
pub fn extract_source_amount_conditionals(
records: &[Record],
min_observations: usize,
) -> PerSourceAmountPrior {
use std::collections::BTreeMap;
let mut by_pair: BTreeMap<(String, String), Vec<f64>> = BTreeMap::new();
let mut by_src: BTreeMap<String, Vec<f64>> = BTreeMap::new();
for r in records {
let abs_amt = r.functional_amount.abs();
if abs_amt <= 0.0 || !abs_amt.is_finite() {
continue;
}
if r.source.is_empty() {
continue;
}
let gl_prefix: String = if r.gl_account.len() >= 4 {
r.gl_account[..4].to_string()
} else {
r.gl_account.clone()
};
by_pair
.entry((r.source.clone(), gl_prefix))
.or_default()
.push(abs_amt);
by_src.entry(r.source.clone()).or_default().push(abs_amt);
}
let mut by_source_and_class: BTreeMap<String, BTreeMap<String, LognormalAmount>> =
BTreeMap::new();
for ((source, gl_prefix), values) in by_pair {
if values.len() < min_observations {
continue;
}
if let Some(params) = fit_lognormal_amount(&values) {
by_source_and_class
.entry(source)
.or_default()
.insert(gl_prefix, params);
}
}
let mut by_source: BTreeMap<String, LognormalAmount> = BTreeMap::new();
for (source, values) in by_src {
if values.len() < min_observations {
continue;
}
if let Some(params) = fit_lognormal_amount(&values) {
by_source.insert(source, params);
}
}
PerSourceAmountPrior {
by_source_and_class,
by_source,
}
}
/// Fits a lognormal model to the finite, strictly positive entries of
/// `values`.
///
/// `mu` and `sigma` are the mean and population standard deviation of the
/// natural logs; `n` is the number of values used and `median_abs` their
/// median. Returns `None` when fewer than two values qualify.
fn fit_lognormal_amount(values: &[f64]) -> Option<LognormalAmount> {
    let mut positives: Vec<f64> = values
        .iter()
        .copied()
        .filter(|v| *v > 0.0 && v.is_finite())
        .collect();
    if positives.len() < 2 {
        return None;
    }

    // Take the logs in input order, before sorting for the median.
    let log_vals: Vec<f64> = positives.iter().map(|v| v.ln()).collect();
    let count = log_vals.len() as f64;
    let mu = log_vals.iter().sum::<f64>() / count;
    let variance = log_vals.iter().map(|x| (x - mu).powi(2)).sum::<f64>() / count;

    positives.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));
    let mid = positives.len() / 2;
    let median_abs = if positives.len().is_multiple_of(2) {
        (positives[mid - 1] + positives[mid]) / 2.0
    } else {
        positives[mid]
    };

    Some(LognormalAmount {
        mu,
        sigma: variance.sqrt(),
        n: log_vals.len(),
        median_abs,
    })
}
/// Default `min_jes_per_source` used for `extract_lines_per_je`.
pub const DEFAULT_MIN_JES_PER_SOURCE: usize = 500;

/// Builds histograms of the number of lines per journal entry, overall and
/// per source.
///
/// Lines are grouped by `je_number`; a journal entry is attributed to the
/// source of its first record in `records`. Per-source histograms are produced
/// only for sources with at least `min_jes_per_source` journal entries.
pub fn extract_lines_per_je(records: &[Record], min_jes_per_source: usize) -> LinesPerJePrior {
    // je_number -> (line count, source of the first line seen)
    let mut per_je: BTreeMap<String, (u32, &str)> = BTreeMap::new();
    for record in records {
        let slot = per_je
            .entry(record.je_number.clone())
            .or_insert((0, record.source.as_str()));
        slot.0 += 1;
    }

    let overall_values: Vec<u32> = per_je.values().map(|(n_lines, _)| *n_lines).collect();
    let (overall, _) = LineCountHistogram::build(&overall_values, LINE_COUNT_BUCKETS);

    let mut counts_by_source: BTreeMap<String, Vec<u32>> = BTreeMap::new();
    for (n_lines, source) in per_je.values() {
        counts_by_source
            .entry((*source).to_owned())
            .or_default()
            .push(*n_lines);
    }

    let by_source: BTreeMap<String, LineCountHistogram> = counts_by_source
        .into_iter()
        .filter(|(_, line_counts)| line_counts.len() >= min_jes_per_source)
        .map(|(source, line_counts)| {
            let (histogram, _) = LineCountHistogram::build(&line_counts, LINE_COUNT_BUCKETS);
            (source, histogram)
        })
        .collect();

    LinesPerJePrior {
        overall,
        by_source,
        min_jes_per_source,
    }
}
/// Result alias for behavioral-prior extraction.
pub type BehavioralResult<T> = Result<T, FingerprintError>;

/// Runs every extractor in this module over `records` with its default
/// threshold and assembles the results into a `BehavioralPriors`.
///
/// Reference formats, source-amount conditionals and source-role GL
/// conditionals are stored as `None` when their extractor found nothing;
/// `coa_semantic`, `text_taxonomy` and `tb_anchor` are always `None`.
/// As written, this function always returns `Ok`.
pub fn extract_behavioral_priors(
    records: &[Record],
    industry: &str,
) -> BehavioralResult<BehavioralPriors> {
    let reference_formats = Some(extract_reference_formats(
        records,
        DEFAULT_MIN_REFERENCE_OCCURRENCES,
    ))
    .filter(|rf| !rf.by_source.is_empty());

    let source_amount_conditionals = Some(extract_source_amount_conditionals(
        records,
        DEFAULT_MIN_AMOUNT_OBSERVATIONS,
    ))
    .filter(|sac| !sac.by_source.is_empty() || !sac.by_source_and_class.is_empty());

    let source_role_gl_conditionals = Some(extract_source_role_gl(
        records,
        DEFAULT_MIN_SOURCE_ROLE_OBSERVATIONS,
    ))
    .filter(|srg| !srg.by_source_and_role.is_empty());

    let user_personas = Some(extract_user_personas(records, DEFAULT_MIN_USER_RECORDS));

    let priors = BehavioralPriors {
        schema_version: BehavioralPriors::SCHEMA_VERSION,
        generator_version: env!("CARGO_PKG_VERSION").to_string(),
        industry: industry.to_string(),
        n_client_inputs: 1,
        n_rows_aggregated: records.len(),
        source_mix: extract_source_mix(
            records,
            DEFAULT_MIN_SOURCE_THRESHOLD,
            DEFAULT_MIN_SOURCE_OBSERVATIONS,
        ),
        per_source_iet: extract_per_source_iet(records, DEFAULT_MIN_IET_SAMPLES),
        lines_per_je: extract_lines_per_je(records, DEFAULT_MIN_JES_PER_SOURCE),
        active_lifetime: extract_active_lifetime(records),
        fanout: extract_fanout(records),
        posting_lag: extract_posting_lag(records, DEFAULT_MIN_LAG_SAMPLES),
        active_segments: Some(extract_active_segments(records)),
        entity_clusters: Some(extract_entity_clusters(records)),
        per_source_attribute: Some(extract_per_source_attribute(
            records,
            DEFAULT_MIN_ATTRIBUTE_OBSERVATIONS,
        )),
        tp_entity_clusters: Some(extract_tp_entity_clusters(records)),
        reference_formats,
        coa_semantic: None,
        user_personas,
        source_amount_conditionals,
        source_role_gl_conditionals,
        text_taxonomy: None,
        tb_anchor: None,
    };
    Ok(priors)
}
pub fn extract_behavioral_priors_from_path(
path: &Path,
industry: &str,
) -> BehavioralResult<BehavioralPriors> {
let records = match path.extension().and_then(|s| s.to_str()) {
Some("parquet") => load_parquet_records(path)
.map_err(|e| io_error_to_fp(format!("parquet load failed: {e}")))?,
Some("csv") => {
load_csv_records(path).map_err(|e| io_error_to_fp(format!("csv load failed: {e}")))?
}
_ => {
return Err(io_error_to_fp(format!(
"unsupported extension at {}",
path.display()
)));
}
};
extract_behavioral_priors(&records, industry)
}
/// Wraps `msg` in a `FingerprintError::ExtractionError` attributed to the
/// `"behavioral_priors"` extractor.
fn io_error_to_fp(msg: String) -> FingerprintError {
    let extractor = String::from("behavioral_priors");
    FingerprintError::ExtractionError {
        extractor,
        message: msg,
    }
}
/// Gap threshold, in days, that `extract_active_segments` passes to
/// `split_into_segments`.
pub const SEGMENT_GAP_THRESHOLD_DAYS: i64 = 7;

/// Splits `dates` into runs of activity separated by long gaps.
///
/// A new segment starts whenever the gap to the previous date exceeds
/// `gap_threshold` days. Returns the `(first, last)` date of every segment and
/// the length in days of each separating gap. Empty input yields two empty
/// vectors. The caller in this file passes sorted, de-duplicated dates.
fn split_into_segments(
    dates: &[NaiveDate],
    gap_threshold: i64,
) -> (Vec<(NaiveDate, NaiveDate)>, Vec<u32>) {
    let Some(&first) = dates.first() else {
        return (Vec::new(), Vec::new());
    };

    let mut segments = Vec::new();
    let mut gaps = Vec::new();
    let mut current_start = first;
    for pair in dates.windows(2) {
        let (previous, next) = (pair[0], pair[1]);
        let gap_days = (next - previous).num_days();
        if gap_days > gap_threshold {
            // The running segment ends at the date just before the gap.
            segments.push((current_start, previous));
            gaps.push(gap_days as u32);
            current_start = next;
        }
    }
    let last = dates[dates.len() - 1];
    segments.push((current_start, last));

    (segments, gaps)
}
/// Describes each source's activity as segments of entry dates separated by
/// gaps longer than `SEGMENT_GAP_THRESHOLD_DAYS`.
///
/// Per source, the distinct entry dates are split into segments and three
/// histograms are built: the segment count, the segment lengths in days, and
/// the gap lengths in days. Sources with fewer than two distinct entry dates
/// are omitted.
pub fn extract_active_segments(records: &[Record]) -> ActiveSegmentsPrior {
    let mut days_by_source: BTreeMap<String, Vec<NaiveDate>> = BTreeMap::new();
    for record in records {
        days_by_source
            .entry(record.source.clone())
            .or_default()
            .push(record.entry_date);
    }

    let by_source: BTreeMap<String, SourceSegmentSummary> = days_by_source
        .into_iter()
        .filter_map(|(source, mut days)| {
            days.sort();
            days.dedup();
            if days.len() < 2 {
                return None;
            }

            let (segments, gaps) = split_into_segments(&days, SEGMENT_GAP_THRESHOLD_DAYS);
            let n_segments = segments.len() as u32;
            let lengths: Vec<u32> = segments
                .iter()
                .map(|&(start, end)| (end - start).num_days().max(0) as u32)
                .collect();

            let (segment_count_histogram, _) =
                LineCountHistogram::build(&[n_segments], SEGMENT_COUNT_BUCKETS);
            let (segment_length_histogram, _) =
                LineCountHistogram::build(&lengths, ACTIVE_LIFETIME_DAY_BUCKETS);
            let (gap_length_histogram, _) = LineCountHistogram::build(&gaps, SEGMENT_GAP_BUCKETS);

            let summary = SourceSegmentSummary {
                segment_count_histogram,
                segment_length_histogram,
                gap_length_histogram,
            };
            Some((source, summary))
        })
        .collect();

    ActiveSegmentsPrior { by_source }
}
/// Number of highest-volume sources considered by `extract_entity_clusters`.
const MAX_SOURCES_FOR_CLUSTERING: usize = 50;
/// Minimum Jaccard similarity for two entities to be linked when clustering.
const JACCARD_THRESHOLD: f64 = 0.3;
/// Source codes that `normalise_source_code` accepts unchanged.
const CANONICAL_SAP_CODES: &[&str] = &[
    "KR", "RV", "DZ", "WE", "RE", "SA", "IM", "KZ", "AB", "AF", "DR", "KK", "K9", "KX", "PK", "RB",
    "RY", "SL", "ZP",
];

/// Maps a raw source code to its canonical form.
///
/// Surrounding whitespace is ignored. Codes listed in `CANONICAL_SAP_CODES`
/// are returned as-is; `0`/`00`, `1`/`01` and `2`/`02` map to `SA`, `RV` and
/// `KR` respectively. Anything else, including an empty code, yields `None`.
fn normalise_source_code(raw: &str) -> Option<String> {
    let code = raw.trim();
    let canonical = match code {
        "" => return None,
        known if CANONICAL_SAP_CODES.contains(&known) => known,
        "0" | "00" => "SA",
        "1" | "01" => "RV",
        "2" | "02" => "KR",
        _ => return None,
    };
    Some(canonical.to_string())
}
/// Clusters source codes that post to similar attribute values.
///
/// Only the `MAX_SOURCES_FOR_CLUSTERING` sources with the most rows are
/// considered. For each one, the set of GL accounts, cost centers, profit
/// centers and trading partners it touches is collected. Raw source codes are
/// then canonicalised with `normalise_source_code`: codes that do not
/// normalise are discarded, and codes sharing a canonical form have their
/// sets merged. Canonical sources whose sets have a Jaccard similarity of at
/// least `JACCARD_THRESHOLD` are linked, and every connected group of two or
/// more becomes a cluster.
pub fn extract_entity_clusters(records: &[Record]) -> EntityClustersPrior {
    let mut rows_per_source: BTreeMap<String, usize> = BTreeMap::new();
    for r in records {
        *rows_per_source.entry(r.source.clone()).or_insert(0) += 1;
    }
    let top_sources = most_frequent_keys(rows_per_source, MAX_SOURCES_FOR_CLUSTERING);

    let mut raw_sets: BTreeMap<String, HashSet<String>> = BTreeMap::new();
    for r in records {
        if !top_sources.contains(&r.source) {
            continue;
        }
        let set = raw_sets.entry(r.source.clone()).or_default();
        set.insert(format!("GL:{}", r.gl_account));
        if let Some(cc) = &r.cost_center {
            set.insert(format!("CC:{cc}"));
        }
        if let Some(pc) = &r.profit_center {
            set.insert(format!("PC:{pc}"));
        }
        if let Some(tp) = &r.trading_partner {
            set.insert(format!("TP:{tp}"));
        }
    }

    // Re-key by canonical code; several raw codes may fold into one entry.
    let mut attr_sets: BTreeMap<String, HashSet<String>> = BTreeMap::new();
    for (raw, set) in raw_sets {
        if let Some(canonical) = normalise_source_code(&raw) {
            attr_sets.entry(canonical).or_default().extend(set);
        }
    }

    cluster_by_jaccard(&attr_sets)
}

/// Number of highest-volume trading partners considered by
/// `extract_tp_entity_clusters`.
const MAX_TP_FOR_CLUSTERING: usize = 200;

/// Clusters trading partners that appear with similar attribute values.
///
/// Only the `MAX_TP_FOR_CLUSTERING` non-empty trading partners with the most
/// rows are considered. For each one, the set of GL accounts, cost centers,
/// profit centers and sources it appears with is collected. Partners whose
/// sets have a Jaccard similarity of at least `JACCARD_THRESHOLD` are linked,
/// and every connected group of two or more becomes a cluster.
pub fn extract_tp_entity_clusters(records: &[Record]) -> EntityClustersPrior {
    let mut rows_per_tp: BTreeMap<String, usize> = BTreeMap::new();
    for tp in records
        .iter()
        .filter_map(|r| r.trading_partner.as_ref())
        .filter(|tp| !tp.is_empty())
    {
        *rows_per_tp.entry(tp.clone()).or_insert(0) += 1;
    }
    let top_tps = most_frequent_keys(rows_per_tp, MAX_TP_FOR_CLUSTERING);

    let mut attr_sets: BTreeMap<String, HashSet<String>> = BTreeMap::new();
    for r in records {
        let tp = match &r.trading_partner {
            Some(tp) if !tp.is_empty() && top_tps.contains(tp) => tp.clone(),
            _ => continue,
        };
        let set = attr_sets.entry(tp).or_default();
        set.insert(format!("GL:{}", r.gl_account));
        if let Some(cc) = &r.cost_center {
            set.insert(format!("CC:{cc}"));
        }
        if let Some(pc) = &r.profit_center {
            set.insert(format!("PC:{pc}"));
        }
        set.insert(format!("SRC:{}", r.source));
    }

    cluster_by_jaccard(&attr_sets)
}

/// Returns the `limit` keys with the highest counts.
///
/// The sort is stable and starts from the map's key order, so ties are
/// resolved in favour of the lexicographically smaller key.
fn most_frequent_keys(counts: BTreeMap<String, usize>, limit: usize) -> HashSet<String> {
    let mut ranked: Vec<(String, usize)> = counts.into_iter().collect();
    ranked.sort_by_key(|entry| std::cmp::Reverse(entry.1));
    ranked.into_iter().take(limit).map(|(key, _)| key).collect()
}

/// Groups the entities in `attr_sets` into clusters of similar attribute sets.
///
/// Two entities are linked when the Jaccard similarity of their sets is at
/// least `JACCARD_THRESHOLD`; entities with an empty set are never linked.
/// Each connected component with two or more members becomes an
/// `EntityCluster` whose `avg_jaccard` is the mean similarity over the linked
/// pairs inside it. `clustering_rate` is the fraction of entities that ended
/// up in some cluster (0.0 when there are no entities).
fn cluster_by_jaccard(attr_sets: &BTreeMap<String, HashSet<String>>) -> EntityClustersPrior {
    // Work on indices into the key-ordered entity list; index order therefore
    // matches lexicographic name order.
    let names: Vec<&String> = attr_sets.keys().collect();
    let sets: Vec<&HashSet<String>> = attr_sets.values().collect();
    let n = names.len();

    let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); n];
    let mut edge_weights: BTreeMap<(usize, usize), f64> = BTreeMap::new();
    for i in 0..n {
        for j in (i + 1)..n {
            let (a, b) = (sets[i], sets[j]);
            if a.is_empty() || b.is_empty() {
                continue;
            }
            let intersection = a.intersection(b).count();
            // |A ∪ B| = |A| + |B| - |A ∩ B|; at least 1 since both are non-empty.
            let union = a.len() + b.len() - intersection;
            let jaccard = intersection as f64 / union as f64;
            if jaccard >= JACCARD_THRESHOLD {
                adjacency[i].push(j);
                adjacency[j].push(i);
                edge_weights.insert((i, j), jaccard);
            }
        }
    }

    // Depth-first search over the similarity graph to find connected components.
    let mut visited = vec![false; n];
    let mut clusters: Vec<EntityCluster> = Vec::new();
    for start in 0..n {
        if visited[start] {
            continue;
        }
        let mut members: Vec<usize> = Vec::new();
        let mut stack = vec![start];
        while let Some(node) = stack.pop() {
            if visited[node] {
                continue;
            }
            visited[node] = true;
            members.push(node);
            stack.extend(
                adjacency[node]
                    .iter()
                    .copied()
                    .filter(|&neighbor| !visited[neighbor]),
            );
        }
        if members.len() < 2 {
            continue;
        }

        // Average only over pairs that are directly linked; members joined
        // transitively contribute no weight of their own.
        let mut sum = 0.0_f64;
        let mut linked_pairs = 0.0_f64;
        for (pos, &a) in members.iter().enumerate() {
            for &b in &members[pos + 1..] {
                let key = if a < b { (a, b) } else { (b, a) };
                if let Some(&weight) = edge_weights.get(&key) {
                    sum += weight;
                    linked_pairs += 1.0;
                }
            }
        }
        let avg_jaccard = if linked_pairs > 0.0 {
            sum / linked_pairs
        } else {
            0.0
        };

        clusters.push(EntityCluster {
            members: members.iter().map(|&idx| names[idx].clone()).collect(),
            avg_jaccard,
        });
    }

    let total_in_clusters: usize = clusters.iter().map(|c| c.members.len()).sum();
    let clustering_rate = total_in_clusters as f64 / n.max(1) as f64;

    EntityClustersPrior {
        clusters,
        clustering_rate,
    }
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::{Duration, NaiveDate};
use rand::{RngExt, SeedableRng};
pub(crate) fn rec(src: &str) -> Record {
let d = NaiveDate::from_ymd_opt(2022, 1, 1).expect("date");
Record {
source: src.into(),
gl_account: "1".into(),
cost_center: None,
profit_center: None,
trading_partner: None,
je_number: "J1".into(),
je_line_number: "001".into(),
effective_date: d,
entry_date: d,
created_at: None,
functional_amount: 1.0,
header_text: String::new(),
line_text: String::new(),
}
}
#[test]
fn source_mix_shares_match() {
let mut recs: Vec<Record> = Vec::new();
recs.extend(std::iter::repeat_with(|| rec("A")).take(60));
recs.extend(std::iter::repeat_with(|| rec("B")).take(30));
recs.extend(std::iter::repeat_with(|| rec("C")).take(10));
let mix = extract_source_mix(&recs, DEFAULT_MIN_SOURCE_THRESHOLD, 0);
assert!((mix.probabilities["A"] - 0.6).abs() < 1e-9);
assert!((mix.probabilities["B"] - 0.3).abs() < 1e-9);
assert!((mix.probabilities["C"] - 0.1).abs() < 1e-9);
assert!(mix.other_fraction.abs() < 1e-9);
}
#[test]
fn source_mix_long_tail_rolls_into_other() {
let mut recs: Vec<Record> = Vec::new();
recs.extend(std::iter::repeat_with(|| rec("A")).take(995));
for i in 1..=5 {
recs.push(rec(&format!("X{i}")));
}
let mix = extract_source_mix(&recs, 0.005, 0);
assert!(mix.probabilities.contains_key("A"));
assert!(!mix.probabilities.contains_key("X1"));
assert!(mix.other_fraction > 0.0);
}
#[test]
fn source_mix_empty_input_returns_empty() {
let mix = extract_source_mix(&[], DEFAULT_MIN_SOURCE_THRESHOLD, 0);
assert!(mix.probabilities.is_empty());
assert!(mix.other_fraction.abs() < 1e-9);
}
#[test]
fn extract_source_mix_drops_low_volume_codes() {
let mut records = Vec::new();
records.extend(std::iter::repeat_with(|| rec("KR")).take(1500));
records.extend(std::iter::repeat_with(|| rec("RV")).take(100));
records.extend(std::iter::repeat_with(|| rec("DZ")).take(5));
let mix = extract_source_mix(&records, 0.0, 1000);
assert!(mix.probabilities.contains_key("KR"), "KR should survive");
assert!(
!mix.probabilities.contains_key("RV"),
"RV (100 obs) should be dropped"
);
assert!(
!mix.probabilities.contains_key("DZ"),
"DZ (5 obs) should be dropped"
);
assert!(
(mix.probabilities["KR"] - 1.0).abs() < 1e-9,
"KR probability should be 1.0 after renormalisation"
);
}
#[test]
fn per_source_iet_basic() {
let mut recs: Vec<Record> = Vec::new();
let base = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
for i in 0..120 {
let mut r = rec("A");
r.entry_date = base + chrono::Duration::days(i);
recs.push(r);
}
for i in 0..50 {
let mut r = rec("B");
r.entry_date = base + chrono::Duration::days(i);
recs.push(r);
}
let p = extract_per_source_iet(&recs, 100);
assert!(p.by_source.contains_key("A"));
assert!(!p.by_source.contains_key("B"));
let summ = &p.by_source["A"];
assert_eq!(summ.n, 119);
assert!(summ.lognormal_fit.is_some());
}
#[test]
fn per_source_iet_constant_gap_zero_autocorr() {
let mut recs: Vec<Record> = Vec::new();
let base = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
for i in 0..200 {
let mut r = rec("A");
r.entry_date = base + chrono::Duration::days(3 * i);
recs.push(r);
}
let p = extract_per_source_iet(&recs, 100);
assert!((p.by_source["A"].lag1_autocorr).abs() < 1e-9);
}
#[test]
fn lines_per_je_overall_known() {
let mut recs: Vec<Record> = Vec::new();
for _ in 0..3 {
let mut r = rec("S");
r.je_number = "JE-A".into();
recs.push(r);
}
for _ in 0..2 {
let mut r = rec("S");
r.je_number = "JE-B".into();
recs.push(r);
}
let mut r = rec("S");
r.je_number = "JE-C".into();
recs.push(r);
let p = extract_lines_per_je(&recs, DEFAULT_MIN_JES_PER_SOURCE);
let idx_1 = LINE_COUNT_BUCKETS.iter().position(|&b| b == 1).unwrap();
let idx_2 = LINE_COUNT_BUCKETS.iter().position(|&b| b == 2).unwrap();
let idx_3 = LINE_COUNT_BUCKETS.iter().position(|&b| b == 3).unwrap();
assert!((p.overall.probabilities[idx_1] - 1.0 / 3.0).abs() < 1e-9);
assert!((p.overall.probabilities[idx_2] - 1.0 / 3.0).abs() < 1e-9);
assert!((p.overall.probabilities[idx_3] - 1.0 / 3.0).abs() < 1e-9);
assert_eq!(p.overall.n, 3);
}
#[test]
fn active_lifetime_basic() {
let base = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
let mut recs: Vec<Record> = Vec::new();
for i in 0..5 {
let mut r = rec("A");
r.entry_date = base + chrono::Duration::days(i * 6);
recs.push(r);
}
for i in 0..5 {
let mut r = rec("B");
r.entry_date = base + chrono::Duration::days(i * 50);
recs.push(r);
}
let p = extract_active_lifetime(&recs);
let idx_7 = ACTIVE_LIFETIME_DAY_BUCKETS
.iter()
.position(|&b| b == 7)
.unwrap();
let idx_180 = ACTIVE_LIFETIME_DAY_BUCKETS
.iter()
.position(|&b| b == 180)
.unwrap();
assert!((p.overall.probabilities[idx_7] - 0.5).abs() < 1e-9);
assert!((p.overall.probabilities[idx_180] - 0.5).abs() < 1e-9);
}
#[test]
fn fanout_basic() {
let mut recs: Vec<Record> = Vec::new();
for &(src, gl) in &[("A", "X"), ("B", "X"), ("C", "X"), ("A", "Y")] {
let mut r = rec(src);
r.gl_account = gl.into();
recs.push(r);
}
let p = extract_fanout(&recs);
let hist = &p.by_attribute["GLAccount"];
let idx_1 = FANOUT_BUCKETS.iter().position(|&b| b == 1).unwrap();
let idx_3 = FANOUT_BUCKETS.iter().position(|&b| b == 3).unwrap();
assert!((hist.probabilities[idx_1] - 0.5).abs() < 1e-9);
assert!((hist.probabilities[idx_3] - 0.5).abs() < 1e-9);
}
#[test]
fn posting_lag_known() {
let mut recs: Vec<Record> = Vec::new();
let base = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
for i in 0..120 {
let mut r = rec("A");
r.entry_date = base + chrono::Duration::days(i);
r.effective_date = r.entry_date + chrono::Duration::days(5);
recs.push(r);
}
for i in 0..120 {
let mut r = rec("B");
r.entry_date = base + chrono::Duration::days(i);
r.effective_date = r.entry_date - chrono::Duration::days(2);
recs.push(r);
}
let p = extract_posting_lag(&recs, 100).expect("non-empty");
assert!((p.by_source["A"].mean - 5.0).abs() < 1e-9);
assert!((p.by_source["B"].mean - (-2.0)).abs() < 1e-9);
assert!((p.by_source["A"].stddev).abs() < 1e-9);
}
#[test]
fn split_into_segments_known() {
let base = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
let dates: Vec<NaiveDate> = [0i64, 1, 2, 14, 15, 16, 17, 49]
.iter()
.map(|&d| base + chrono::Duration::days(d))
.collect();
let (segs, gaps) = split_into_segments(&dates, 7);
assert_eq!(segs.len(), 3); assert_eq!(gaps.len(), 2); assert_eq!(gaps[0], 12);
assert_eq!(gaps[1], 32);
}
#[test]
fn extract_active_segments_basic() {
let base = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
let mut recs: Vec<Record> = Vec::new();
for &day_off in &[0i64, 1, 2, 14, 15, 16, 17, 49] {
let mut r = rec("A");
r.entry_date = base + chrono::Duration::days(day_off);
recs.push(r);
}
let p = extract_active_segments(&recs);
assert!(p.by_source.contains_key("A"));
let summary = &p.by_source["A"];
let idx_3 = SEGMENT_COUNT_BUCKETS.iter().position(|&b| b == 3).unwrap();
assert!((summary.segment_count_histogram.probabilities[idx_3] - 1.0).abs() < 1e-9);
}
#[test]
fn extract_entity_clusters_finds_shared_attrs() {
let mut recs: Vec<Record> = Vec::new();
for gl in ["1", "2", "3"] {
let mut r = rec("KR");
r.gl_account = gl.into();
recs.push(r);
}
for gl in ["1", "2", "3"] {
let mut r = rec("RV");
r.gl_account = gl.into();
recs.push(r);
}
for gl in ["1", "2", "4"] {
let mut r = rec("DZ");
r.gl_account = gl.into();
recs.push(r);
}
let mut r = rec("WE");
r.gl_account = "99".into();
recs.push(r);
let p = extract_entity_clusters(&recs);
let any_cluster_has_kr_rv_dz = p.clusters.iter().any(|c| {
let members: HashSet<&String> = c.members.iter().collect();
members.contains(&"KR".to_string())
&& members.contains(&"RV".to_string())
&& members.contains(&"DZ".to_string())
});
assert!(
any_cluster_has_kr_rv_dz,
"expected a cluster containing KR, RV, DZ"
);
let any_cluster_has_we = p
.clusters
.iter()
.any(|c| c.members.iter().any(|m| m == "WE"));
assert!(!any_cluster_has_we, "WE should be an isolate (no cluster)");
}
/// Cluster members must never be one of the raw numeric source codes; if any
/// cluster exists, at least one member must be a two-letter canonical code.
#[test]
fn extract_entity_clusters_normalises_source_codes() {
    let day = chrono::NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
    let sources = ["0", "0", "0", "KR", "KR", "KR"];
    let accounts = ["1", "2", "3"];
    // One record per (source, account) pair, sources outermost.
    let recs: Vec<Record> = sources
        .iter()
        .flat_map(|&source| {
            accounts.iter().map(move |&gl| {
                let mut r = rec(source);
                r.entry_date = day;
                r.gl_account = gl.into();
                r
            })
        })
        .collect();

    let prior = extract_entity_clusters(&recs);

    let raw_codes = ["0", "1", "2", "00", "01", "02"];
    for member in prior.clusters.iter().flat_map(|c| c.members.iter()) {
        assert!(
            !raw_codes.contains(&member.as_str()),
            "raw numeric code {member:?} should have been normalised"
        );
    }

    // NOTE(review): when no clusters are produced this check is skipped entirely.
    if !prior.clusters.is_empty() {
        let canonical = ["KR", "RV", "DZ", "SA", "WE", "RE", "IM", "KZ"];
        let any_sap = prior
            .clusters
            .iter()
            .flat_map(|c| c.members.iter())
            .any(|m| canonical.contains(&m.as_str()));
        assert!(
            any_sap,
            "expected at least one SAP-style canonical code in clusters"
        );
    }
}
/// Pins the input -> output table of `normalise_source_code`.
#[test]
fn normalise_source_code_known_mappings() {
    // Two-letter codes come back unchanged.
    assert_eq!(normalise_source_code("KR").as_deref(), Some("KR"));
    assert_eq!(normalise_source_code("RV").as_deref(), Some("RV"));
    // Numeric codes are translated to two-letter codes.
    assert_eq!(normalise_source_code("0").as_deref(), Some("SA"));
    assert_eq!(normalise_source_code("00").as_deref(), Some("SA"));
    assert_eq!(normalise_source_code("1").as_deref(), Some("RV"));
    assert_eq!(normalise_source_code("2").as_deref(), Some("KR"));
    // Surrounding whitespace is ignored.
    assert_eq!(normalise_source_code(" KR ").as_deref(), Some("KR"));
    // Unrecognised and empty inputs yield no code.
    assert_eq!(normalise_source_code("XYZ").as_deref(), None);
    assert_eq!(normalise_source_code("").as_deref(), None);
}
/// Builds a fixture `Record` for the given source and GL account.
///
/// `cost_center` and `profit_center` are copied into owned strings when
/// present. Every other field is fixed: both dates are 2022-01-01, the JE
/// number is "J1", the line number is "001", the amount is 1.0, and the
/// trading partner, creation timestamp and text fields are left empty.
fn make_record(
    src: &str,
    gl: &str,
    cost_center: Option<&str>,
    profit_center: Option<&str>,
) -> Record {
    let day = NaiveDate::from_ymd_opt(2022, 1, 1).expect("date");
    Record {
        source: src.to_owned(),
        gl_account: gl.to_owned(),
        cost_center: cost_center.map(String::from),
        profit_center: profit_center.map(String::from),
        trading_partner: None,
        je_number: String::from("J1"),
        je_line_number: String::from("001"),
        effective_date: day,
        entry_date: day,
        created_at: None,
        functional_amount: 1.0,
        header_text: String::new(),
        line_text: String::new(),
    }
}
/// With 10 as the second argument (presumably the minimum observation count),
/// a source with 15 rows is kept and a source with 3 rows is dropped.
#[test]
fn extract_per_source_attribute_filters_low_observations() {
    let kr_rows = (0..15).map(|_| make_record("KR", "200001", Some("CC1"), Some("PC1")));
    let rv_rows = (0..3).map(|_| make_record("RV", "400001", Some("CC2"), Some("PC2")));
    let records: Vec<Record> = kr_rows.chain(rv_rows).collect();

    let prior = extract_per_source_attribute(&records, 10);

    assert!(prior.by_source.contains_key("KR"), "KR should be retained");
    let gl_dist = prior
        .conditional("KR", "gl_account")
        .expect("KR/gl_account");
    assert!(
        gl_dist.probabilities.contains_key("200001"),
        "200001 should appear"
    );
    assert_eq!(gl_dist.n, 15);
    // All KR rows share one account, so it carries the whole probability mass.
    let share = gl_dist.probabilities["200001"];
    assert!((share - 1.0).abs() < 1e-9);

    assert!(
        !prior.by_source.contains_key("RV"),
        "RV should be filtered out"
    );
}
/// Twenty rows whose source is the empty string must produce no per-source entry.
#[test]
fn extract_per_source_attribute_skips_empty_source() {
    let records: Vec<Record> =
        std::iter::repeat_with(|| make_record("", "100001", None, None))
            .take(20)
            .collect();

    let prior = extract_per_source_attribute(&records, 5);

    assert!(
        prior.by_source.is_empty(),
        "empty source rows must be skipped"
    );
}
/// Eight rows on one GL account and twelve on another should give shares of
/// 0.4 and 0.6 that sum to one.
#[test]
fn extract_per_source_attribute_multiple_values_normalise() {
    let first = (0..8).map(|_| make_record("KR", "200001", None, None));
    let second = (0..12).map(|_| make_record("KR", "200002", None, None));
    let records: Vec<Record> = first.chain(second).collect();

    let prior = extract_per_source_attribute(&records, 10);
    let gl_dist = prior
        .conditional("KR", "gl_account")
        .expect("KR/gl_account");

    assert_eq!(gl_dist.n, 20);
    let share_a = gl_dist.probabilities["200001"];
    let share_b = gl_dist.probabilities["200002"];
    assert!((share_a - 0.4).abs() < 1e-9);
    assert!((share_b - 0.6).abs() < 1e-9);

    let mass: f64 = gl_dist.probabilities.values().sum();
    assert!((mass - 1.0).abs() < 1e-9, "probabilities must sum to 1.0");
}
/// End-to-end smoke test: 1200 rows for each of two sources go through
/// `extract_behavioral_priors`, and the top-level fields of the result are
/// checked for presence and basic consistency.
#[test]
fn extract_behavioral_priors_smoke() {
    let base = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
    // (source, rows per JE number, effective-date shift in days, distinct GL accounts)
    let streams: [(&str, i64, i64, i64); 2] = [("A", 3, 1, 5), ("B", 2, -1, 7)];
    let mut recs: Vec<Record> = Vec::with_capacity(2400);
    for &(source, rows_per_je, shift_days, gl_count) in &streams {
        for i in 0..1200i64 {
            let mut r = rec(source);
            r.je_number = format!("JE-{}-{:04}", source, i / rows_per_je);
            r.entry_date = base + chrono::Duration::days(i);
            r.effective_date = r.entry_date + chrono::Duration::days(shift_days);
            r.gl_account = format!("ACC-{}", i % gl_count);
            recs.push(r);
        }
    }

    let bp = extract_behavioral_priors(&recs, "test_industry").expect("ok");

    assert_eq!(bp.schema_version, BehavioralPriors::SCHEMA_VERSION);
    assert_eq!(bp.industry, "test_industry");
    assert_eq!(bp.n_client_inputs, 1);
    assert_eq!(bp.n_rows_aggregated, 2400);
    assert!(!bp.source_mix.probabilities.is_empty());
    assert!(bp.per_source_iet.by_source.contains_key("A"));
    assert!(bp.per_source_iet.by_source.contains_key("B"));
    assert!(bp.lines_per_je.overall.n > 0);
    assert!(bp.active_lifetime.overall.n > 0);
    assert_eq!(bp.fanout.by_attribute.len(), 4);
    assert!(bp.posting_lag.is_some());

    // `expect` checks presence and borrows in one step, avoiding a separate bare `unwrap`.
    let psa = bp
        .per_source_attribute
        .as_ref()
        .expect("per_source_attribute should be extracted");
    assert!(psa.by_source.contains_key("A") || psa.by_source.contains_key("B"));
}
/// The per-source attribute prior must expose a `trading_partner` conditional:
/// 15 V100 rows and 5 V200 rows under KR should give shares 0.75 / 0.25, and
/// the 3-row RV source should be dropped.
#[test]
fn extract_per_source_attribute_includes_trading_partner() {
    // (row count, source, GL account, cost center, profit center, trading partner)
    let groups = [
        (15, "KR", "200001", "CC1", "PC1", "V100"),
        (5, "KR", "200001", "CC1", "PC1", "V200"),
        (3, "RV", "400001", "CC2", "PC2", "V300"),
    ];
    let mut records: Vec<Record> = Vec::new();
    for &(count, src, gl, cc, pc, partner) in &groups {
        for _ in 0..count {
            let mut row = make_record(src, gl, Some(cc), Some(pc));
            row.trading_partner = Some(partner.to_string());
            records.push(row);
        }
    }

    let prior = extract_per_source_attribute(&records, 10);
    let partner_dist = prior
        .conditional("KR", "trading_partner")
        .expect("KR/trading_partner conditional must be present");

    assert!(
        partner_dist.probabilities.contains_key("V100"),
        "V100 should appear in KR trading_partner conditional"
    );
    assert!(
        partner_dist.probabilities.contains_key("V200"),
        "V200 should appear in KR trading_partner conditional"
    );
    assert_eq!(partner_dist.n, 20, "total observations should be 20");

    let v100_share = partner_dist.probabilities["V100"];
    assert!((v100_share - 0.75).abs() < 1e-9, "V100 share should be 0.75");
    let v200_share = partner_dist.probabilities["V200"];
    assert!((v200_share - 0.25).abs() < 1e-9, "V200 share should be 0.25");

    let mass: f64 = partner_dist.probabilities.values().sum();
    assert!((mass - 1.0).abs() < 1e-9, "probabilities must sum to 1.0");

    assert!(
        !prior.by_source.contains_key("RV"),
        "RV should be filtered out"
    );
}
/// Trading partners that post to the same GL accounts should share a cluster;
/// a partner with its own account should not appear in any cluster.
#[test]
fn extract_tp_entity_clusters_finds_shared_attrs() {
    // Same fixture row as `make_record("KR", gl, None, None)` plus a trading partner.
    let tp_rec = |tp: &str, gl: &str| Record {
        trading_partner: Some(tp.into()),
        ..make_record("KR", gl, None, None)
    };

    let mut recs: Vec<Record> = Vec::new();
    // T1 and T2 each post once to accounts 10, 20 and 30.
    for &gl in &["10", "20", "30"] {
        for &tp in &["T1", "T2"] {
            recs.push(tp_rec(tp, gl));
        }
    }
    // T3 only ever posts to account 99.
    recs.push(tp_rec("T3", "99"));

    let prior = extract_tp_entity_clusters(&recs);

    let t1_t2_together = prior.clusters.iter().any(|cluster| {
        let members: HashSet<&str> = cluster.members.iter().map(String::as_str).collect();
        members.contains("T1") && members.contains("T2")
    });
    assert!(t1_t2_together, "expected a cluster containing T1 and T2");

    let t3_clustered = prior
        .clusters
        .iter()
        .any(|cluster| cluster.members.iter().any(|m| m == "T3"));
    assert!(!t3_clustered, "T3 should be an isolate (no cluster)");

    assert!(prior.clustering_rate > 0.0, "clustering_rate must be > 0");
}
fn make_amount_rec(src: &str, gl: &str, amount: f64) -> Record {
let d = NaiveDate::from_ymd_opt(2022, 1, 1).expect("date");
Record {
source: src.into(),
gl_account: gl.into(),
cost_center: None,
profit_center: None,
trading_partner: None,
je_number: "J1".into(),
je_line_number: "001".into(),
effective_date: d,
entry_date: d,
created_at: None,
functional_amount: amount,
header_text: String::new(),
line_text: String::new(),
}
}
/// With 10 as the second argument, the 15-row KR/0041 pair is kept (both as a
/// source marginal and as a source/class pair) and the 3-row RV source is dropped.
#[test]
fn extract_source_amount_conditionals_filters_low_count_pairs() {
    let kr_rows = (0..15).map(|i| make_amount_rec("KR", "0041", 100.0 + i as f64));
    let rv_rows = (0..3).map(|i| make_amount_rec("RV", "0013", 500.0 + i as f64));
    let records: Vec<Record> = kr_rows.chain(rv_rows).collect();

    let prior = extract_source_amount_conditionals(&records, 10);

    assert!(
        prior.by_source.contains_key("KR"),
        "KR marginal should be retained"
    );
    let kr_pair_kept = matches!(
        prior.by_source_and_class.get("KR"),
        Some(classes) if classes.contains_key("0041")
    );
    assert!(kr_pair_kept, "KR/0041 pair should be retained");
    assert!(
        !prior.by_source.contains_key("RV"),
        "RV marginal should be dropped (only 3 observations)"
    );
}
/// Fifty amounts spread around exp(4.5) should yield a fitted `mu` within 0.1
/// of 4.5, an observation count of 50 and a positive `median_abs`.
#[test]
fn extract_source_amount_conditionals_lognormal_params_sensible() {
    let centre = 4.5_f64.exp();
    // Scale the centre by a factor running from 0.75 to 1.24 in steps of 0.01.
    let records: Vec<Record> = (0..50)
        .map(|i| {
            let factor = 1.0 + 0.01 * ((i as f64) - 25.0);
            make_amount_rec("KR", "0041", centre * factor)
        })
        .collect();

    let prior = extract_source_amount_conditionals(&records, 10);
    let params = prior
        .by_source_and_class
        .get("KR")
        .and_then(|classes| classes.get("0041"))
        .expect("KR/0041 params should be present");

    let mu_error = (params.mu - 4.5).abs();
    assert!(
        mu_error < 0.1,
        "mu {:.3} should be close to 4.5",
        params.mu
    );
    assert_eq!(params.n, 50);
    assert!(params.median_abs > 0.0, "median_abs must be positive");
}
/// Of 20 positive, 5 zero and 5 negative amounts, only 25 are expected to be
/// counted: per the assertion message the zero rows are the ones left out.
#[test]
fn extract_source_amount_conditionals_skips_zeros() {
    // (row count, amount) groups, appended in this order.
    let groups = [(20, 100.0), (5, 0.0), (5, -50.0)];
    let records: Vec<Record> = groups
        .iter()
        .flat_map(|&(count, amount)| (0..count).map(move |_| make_amount_rec("KR", "0041", amount)))
        .collect();

    let prior = extract_source_amount_conditionals(&records, 10);
    let params = prior
        .by_source_and_class
        .get("KR")
        .and_then(|classes| classes.get("0041"))
        .expect("KR/0041 should be present");

    assert_eq!(params.n, 25, "zero-amount records must be excluded from n");
}
/// `extract_behavioral_priors` must fill `source_amount_conditionals`, and the
/// result must contain a marginal for the only source in the fixture.
#[test]
fn extract_behavioral_priors_populates_source_amount_conditionals() {
    let base = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
    // Fifty KR rows, one per day, each with its own JE number and amount.
    let recs: Vec<Record> = (0..50)
        .map(|i| {
            let mut r = make_amount_rec("KR", "0041", 100.0 + i as f64);
            r.entry_date = base + chrono::Duration::days(i);
            r.effective_date = r.entry_date;
            r.je_number = format!("JE-{i}");
            r
        })
        .collect();

    let bp = extract_behavioral_priors(&recs, "test").expect("ok");

    // `expect` checks presence and borrows in one step, avoiding a separate bare `unwrap`.
    let sac = bp
        .source_amount_conditionals
        .as_ref()
        .expect("source_amount_conditionals should be populated");
    assert!(
        sac.by_source.contains_key("KR"),
        "KR marginal should be present"
    );
}
/// Builds `n` fixture records for `source` on GL account `gl`.
///
/// Each record has amount `role_sign * 100.0`, JE number `J` followed by the
/// zero-padded six-digit index, line number "001", and both dates set to
/// 2022-01-01. Optional fields are `None` and the text fields are empty.
fn make_role_records(source: &str, role_sign: f64, gl: &str, n: usize) -> Vec<Record> {
    let day = NaiveDate::from_ymd_opt(2022, 1, 1).expect("date");
    let amount = role_sign * 100.0;
    let mut out = Vec::with_capacity(n);
    for i in 0..n {
        out.push(Record {
            source: source.to_string(),
            gl_account: gl.to_string(),
            cost_center: None,
            profit_center: None,
            trading_partner: None,
            je_number: format!("J{i:06}"),
            je_line_number: "001".to_string(),
            effective_date: day,
            entry_date: day,
            created_at: None,
            functional_amount: amount,
            header_text: String::new(),
            line_text: String::new(),
        });
    }
    out
}
/// Fifteen positive-amount rows on account 6000 and fifteen negative-amount
/// rows on account 2000 must yield a "DR" conditional containing 6000 and a
/// "CR" conditional containing 2000 for source KR.
#[test]
fn sp4_6_extract_source_role_gl_produces_role_conditionals() {
    let mut records = make_role_records("KR", 1.0, "6000", 15);
    records.extend(make_role_records("KR", -1.0, "2000", 15));

    let prior = extract_source_role_gl(&records, 10);

    // Look each role up once; `expect` reports a missing conditional directly.
    let dr_dist = prior
        .conditional("KR", "DR")
        .expect("KR DR should be present");
    let cr_dist = prior
        .conditional("KR", "CR")
        .expect("KR CR should be present");

    assert!(
        dr_dist.probabilities.contains_key("6000"),
        "DR should have 6000"
    );
    assert!(
        cr_dist.probabilities.contains_key("2000"),
        "CR should have 2000"
    );
}
/// With 10 as the second argument, the 15-row "DR" side of KR is kept while
/// the 3-row "CR" side is dropped.
#[test]
fn sp4_6_extract_source_role_gl_filters_low_counts() {
    let records: Vec<Record> = make_role_records("KR", 1.0, "6000", 15)
        .into_iter()
        .chain(make_role_records("KR", -1.0, "2000", 3))
        .collect();

    let prior = extract_source_role_gl(&records, 10);

    let dr = prior.conditional("KR", "DR");
    assert!(dr.is_some(), "KR DR should pass threshold");
    let cr = prior.conditional("KR", "CR");
    assert!(cr.is_none(), "KR CR with only 3 obs should be dropped");
}
/// Twenty rows whose functional amount is zero must produce an empty
/// source/role prior.
#[test]
fn sp4_6_extract_source_role_gl_skips_zero_amounts() {
    // A sign of 0.0 makes every amount 0.0 * 100.0 = 0.0; the other fields
    // (JE numbering, line number, 2022-01-01 dates) come from the shared helper.
    let records: Vec<Record> = make_role_records("SA", 0.0, "4000", 20);

    let prior = extract_source_role_gl(&records, 10);

    assert!(
        prior.by_source_and_role.is_empty(),
        "zero-amount records should yield an empty prior"
    );
}
/// Runs `extract_behavioral_priors` over 1200 seeded pseudo-random rows that
/// alternate positive and negative amounts across three sources, and expects
/// `source_role_gl_conditionals` to be populated.
#[test]
fn sp4_6_extract_behavioral_priors_populates_source_role_gl_conditionals() {
    // Fixed seed keeps the generated dates reproducible.
    let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
    let base = NaiveDate::from_ymd_opt(2022, 1, 1).unwrap();
    let sources = ["KR", "RV", "SA"];
    let gl_dr = ["6000", "6100", "6200"];
    let gl_cr = ["2000", "2100", "1000"];

    let recs: Vec<Record> = (0..1200usize)
        .map(|i| {
            // Even rows carry a positive amount, odd rows a negative one.
            let (amt, gl) = match i % 2 {
                0 => (100.0, gl_dr[i % gl_dr.len()]),
                _ => (-100.0, gl_cr[i % gl_cr.len()]),
            };
            // Draw the effective date before the entry date so the RNG
            // sequence is consumed in a fixed order.
            let effective_date = base + Duration::days(rng.random_range(0..365));
            let entry_date = base + Duration::days(rng.random_range(0..365));
            Record {
                source: sources[i % sources.len()].to_string(),
                gl_account: gl.to_string(),
                cost_center: None,
                profit_center: None,
                trading_partner: None,
                // Three consecutive rows share one JE number, with lines 001-003.
                je_number: format!("J{:06}", i / 3),
                je_line_number: format!("{:03}", (i % 3) + 1),
                effective_date,
                entry_date,
                created_at: None,
                functional_amount: amt,
                header_text: String::new(),
                line_text: String::new(),
            }
        })
        .collect();

    let bp = extract_behavioral_priors(&recs, "test").expect("ok");

    assert!(
        bp.source_role_gl_conditionals.is_some(),
        "source_role_gl_conditionals should be populated with sufficient data"
    );
}
}