use std::collections::HashMap;
use super::CopulaGenerator;
use crate::error::FingerprintResult;
use crate::models::{
CorrelationMatrix, DistributionType, Fingerprint, GaussianCopula, NumericStats,
};
/// Knobs controlling how a fingerprint is turned into generator configuration.
#[derive(Debug, Clone)]
pub struct SynthesisOptions {
    /// Multiplier applied to the fingerprint's total row count to produce
    /// `transactions.count`.
    pub scale: f64,
    /// When set, written to the patch as `global.seed`.
    pub seed: Option<u64>,
    /// When true, `synthesize_full` builds copula generators from the
    /// fingerprint's correlation data.
    pub preserve_correlations: bool,
    /// When true, the fingerprint's overall anomaly rate is copied into the
    /// `anomaly_injection.*` settings.
    pub inject_anomalies: bool,
}
impl Default for SynthesisOptions {
    /// Unit scale, no fixed seed, with correlation preservation and anomaly
    /// injection both switched on.
    fn default() -> Self {
        let enabled = true;
        SynthesisOptions {
            seed: None,
            scale: 1.0,
            inject_anomalies: enabled,
            preserve_correlations: enabled,
        }
    }
}
pub struct ConfigSynthesizer {
options: SynthesisOptions,
}
impl ConfigSynthesizer {
pub fn new() -> Self {
Self {
options: SynthesisOptions::default(),
}
}
pub fn with_options(options: SynthesisOptions) -> Self {
Self { options }
}
pub fn synthesize(&self, fingerprint: &Fingerprint) -> FingerprintResult<ConfigPatch> {
let mut patch = ConfigPatch::new();
let total_rows: u64 = fingerprint
.schema
.tables
.values()
.map(|t| t.row_count)
.sum();
let scaled_rows = (total_rows as f64 * self.options.scale) as u64;
patch.set(
"transactions.count",
ConfigValue::Integer(scaled_rows as i64),
);
if let Some(seed) = self.options.seed {
patch.set("global.seed", ConfigValue::Integer(seed as i64));
}
let mut amount_cols: Vec<(&String, &NumericStats)> = fingerprint
.statistics
.numeric_columns
.iter()
.filter(|(k, _)| {
let kl = k.to_lowercase();
kl.contains("amount")
|| kl.contains("value")
|| kl.contains("price")
|| kl.contains("dmbtr")
|| kl.contains("wrbtr")
})
.collect();
amount_cols.sort_by(|(a, _), (b, _)| {
fn rank(k: &str) -> u8 {
let kl = k.to_lowercase();
if kl.contains("functional") {
0
} else if kl.contains("reporting") {
1
} else {
2
}
}
rank(a).cmp(&rank(b)).then_with(|| a.cmp(b))
});
if let Some((_, stats)) = amount_cols.first() {
let amount_config = self.map_numeric_distribution(stats);
for (k, v) in amount_config {
patch.set(&format!("transactions.amounts.{k}"), v);
}
}
if self.options.inject_anomalies {
if let Some(ref anomalies) = fingerprint.anomalies {
let rate = anomalies.overall.anomaly_rate;
patch.set("anomaly_injection.overall_rate", ConfigValue::Float(rate));
patch.set("anomaly_injection.enabled", ConfigValue::Bool(rate > 0.0));
}
}
Ok(patch)
}
fn map_numeric_distribution(&self, stats: &NumericStats) -> HashMap<String, ConfigValue> {
let mut config = HashMap::new();
if let Some(lmp) = &stats.log_magnitude_percentiles {
let anchors: [(f64, f64); 6] = [
(0.10, lmp.p10),
(0.25, lmp.p25),
(0.50, lmp.p50),
(0.75, lmp.p75),
(0.90, lmp.p90),
(0.99, lmp.p99),
];
let n = anchors.len();
let mut comps: Vec<(f64, f64, f64)> = Vec::with_capacity(n);
let mut wsum = 0.0_f64;
for i in 0..n {
let (_, v) = anchors[i];
let p_prev = if i == 0 { 0.0 } else { anchors[i - 1].0 };
let p_next = if i == n - 1 { 1.0 } else { anchors[i + 1].0 };
let w = ((p_next - p_prev) / 2.0).max(1e-3);
let v_lo = if i == 0 { v } else { anchors[i - 1].1 };
let v_hi = if i == n - 1 { v } else { anchors[i + 1].1 };
let span = if i == 0 {
v_hi - v
} else if i == n - 1 {
v - v_lo
} else {
(v_hi - v_lo) / 2.0
};
comps.push((w, v, (span / 2.0).abs().max(0.1)));
wsum += w;
}
config.insert(
"mixture_components".to_string(),
ConfigValue::Integer(n as i64),
);
for (i, (w, mu, sg)) in comps.iter().enumerate() {
config.insert(format!("comp{i}_weight"), ConfigValue::Float(w / wsum));
config.insert(format!("comp{i}_mu"), ConfigValue::Float(*mu));
config.insert(format!("comp{i}_sigma"), ConfigValue::Float(*sg));
}
config.insert("lognormal_mu".to_string(), ConfigValue::Float(lmp.p50));
config.insert(
"lognormal_sigma".to_string(),
ConfigValue::Float(((lmp.p75 - lmp.p25) / 1.349).abs().max(0.1)),
);
config.insert("min_amount".to_string(), ConfigValue::Float(lmp.p1.exp()));
config.insert("max_amount".to_string(), ConfigValue::Float(lmp.p99.exp()));
} else {
config.insert("min_amount".to_string(), ConfigValue::Float(stats.min));
config.insert("max_amount".to_string(), ConfigValue::Float(stats.max));
match stats.distribution {
DistributionType::LogNormal => {
if let (Some(mu), Some(sigma)) = (
stats.distribution_params.param1,
stats.distribution_params.param2,
) {
config.insert("lognormal_mu".to_string(), ConfigValue::Float(mu));
config.insert("lognormal_sigma".to_string(), ConfigValue::Float(sigma));
}
}
DistributionType::Normal => {
if stats.mean > 0.0 {
let variance = stats.std_dev.powi(2);
let sigma_sq = (1.0 + variance / stats.mean.powi(2)).ln();
let mu = stats.mean.ln() - sigma_sq / 2.0;
config.insert("lognormal_mu".to_string(), ConfigValue::Float(mu));
config.insert(
"lognormal_sigma".to_string(),
ConfigValue::Float(sigma_sq.sqrt()),
);
}
}
_ => {
if stats.percentiles.p50 > 0.0 {
let mu = stats.percentiles.p50.ln();
let sigma = (stats.percentiles.p75 / stats.percentiles.p25).ln() / 1.349;
config.insert("lognormal_mu".to_string(), ConfigValue::Float(mu));
config.insert(
"lognormal_sigma".to_string(),
ConfigValue::Float(sigma.abs()),
);
}
}
}
}
if let Some(benford) = stats.benford_first_digit {
let round_bias = if benford[0] < 0.25 { 0.3 } else { 0.15 };
config.insert(
"round_number_probability".to_string(),
ConfigValue::Float(round_bias),
);
}
config
}
}
impl Default for ConfigSynthesizer {
fn default() -> Self {
Self::new()
}
}
/// Output of `ConfigSynthesizer::synthesize_full`: the config patch plus any
/// copula generators built from the fingerprint's correlation data.
#[derive(Debug)]
pub struct SynthesisResult {
/// Settings produced by `ConfigSynthesizer::synthesize`.
pub config_patch: ConfigPatch,
/// One entry per copula generator that could be constructed; empty when
/// correlations are not preserved or none were usable.
pub copula_generators: Vec<CopulaGeneratorSpec>,
}
/// A copula generator together with the table and columns it was built for.
#[derive(Debug)]
pub struct CopulaGeneratorSpec {
/// The source copula's name, or `"<table>_copula"` when the generator was
/// built from a correlation matrix.
pub name: String,
/// Table the generator's columns belong to.
pub table: String,
/// Column names covered by the generator, as taken from the source copula
/// or matrix.
pub columns: Vec<String>,
/// The constructed generator.
pub generator: CopulaGenerator,
}
impl ConfigSynthesizer {
    /// Runs [`ConfigSynthesizer::synthesize`] and, when correlations are
    /// preserved, also builds copula generators.
    ///
    /// Generators are built from the fingerprint's fitted copulas first. Only
    /// when none of those yields a generator does this fall back to one
    /// generator per correlation matrix that has at least two columns.
    ///
    /// # Errors
    /// Propagates any error returned by [`ConfigSynthesizer::synthesize`].
    pub fn synthesize_full(
        &self,
        fingerprint: &Fingerprint,
        seed: u64,
    ) -> FingerprintResult<SynthesisResult> {
        let config_patch = self.synthesize(fingerprint)?;

        let correlations = match &fingerprint.correlations {
            Some(found) if self.options.preserve_correlations => found,
            _ => {
                return Ok(SynthesisResult {
                    config_patch,
                    copula_generators: Vec::new(),
                });
            }
        };

        // NOTE(review): every generator receives the same `seed`. If sampling
        // depends only on the seed, generators of equal dimension may produce
        // identical streams. Confirm this is intended.
        let mut specs = Vec::new();
        for copula in &correlations.copulas {
            if let Some(generator) = CopulaGenerator::from_copula(copula, seed) {
                specs.push(CopulaGeneratorSpec {
                    name: copula.name.clone(),
                    table: copula.table.clone(),
                    columns: copula.columns.clone(),
                    generator,
                });
            }
        }

        if specs.is_empty() {
            for (table_name, matrix) in &correlations.matrices {
                // A correlation structure needs at least two columns.
                if matrix.columns.len() < 2 {
                    continue;
                }
                if let Some(generator) = CopulaGenerator::from_correlation_matrix(matrix, seed) {
                    specs.push(CopulaGeneratorSpec {
                        name: format!("{table_name}_copula"),
                        table: table_name.clone(),
                        columns: matrix.columns.clone(),
                        generator,
                    });
                }
            }
        }

        Ok(SynthesisResult {
            config_patch,
            copula_generators: specs,
        })
    }

    /// Builds a generator from one fitted Gaussian copula. Returns `None` when
    /// `CopulaGenerator::from_copula` does.
    pub fn create_copula_generator(copula: &GaussianCopula, seed: u64) -> Option<CopulaGenerator> {
        CopulaGenerator::from_copula(copula, seed)
    }

    /// Builds a generator directly from a correlation matrix. Returns `None`
    /// when `CopulaGenerator::from_correlation_matrix` does.
    pub fn create_copula_from_matrix(
        matrix: &CorrelationMatrix,
        seed: u64,
    ) -> Option<CopulaGenerator> {
        CopulaGenerator::from_correlation_matrix(matrix, seed)
    }
}
/// A flat set of dotted-path overrides (e.g. `"global.seed"`) that renders
/// as nested YAML.
#[derive(Debug, Clone, Default)]
pub struct ConfigPatch {
    values: HashMap<String, ConfigValue>,
}
impl ConfigPatch {
    /// Creates an empty patch.
    pub fn new() -> Self {
        Self {
            values: HashMap::new(),
        }
    }

    /// Sets `path` to `value`, replacing any earlier value at exactly that path.
    pub fn set(&mut self, path: &str, value: ConfigValue) {
        self.values.insert(path.to_string(), value);
    }

    /// Returns the value stored at exactly `path`, if any.
    pub fn get(&self, path: &str) -> Option<&ConfigValue> {
        self.values.get(path)
    }

    /// All stored path/value pairs.
    pub fn values(&self) -> &HashMap<String, ConfigValue> {
        &self.values
    }

    /// Merges `other` into this patch. On path collisions, `other`'s value wins.
    pub fn merge(&mut self, other: ConfigPatch) {
        self.values.extend(other.values);
    }

    /// Renders the patch as YAML, splitting each path on `.` into nested
    /// mappings.
    ///
    /// Paths are inserted in sorted order. `HashMap` iteration order is
    /// randomized per process while `serde_yaml::Mapping` keeps insertion
    /// order, so without the sort the same patch could serialize differently
    /// on every run.
    ///
    /// # Errors
    /// Returns an error if YAML serialization fails.
    pub fn to_yaml(&self) -> FingerprintResult<String> {
        let mut entries: Vec<(&String, &ConfigValue)> = self.values.iter().collect();
        entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));

        let mut root = serde_yaml::Mapping::new();
        for (path, value) in entries {
            let parts: Vec<&str> = path.split('.').collect();
            set_nested_value(&mut root, &parts, value);
        }
        Ok(serde_yaml::to_string(&root)?)
    }
}
#[derive(Debug, Clone)]
pub enum ConfigValue {
Bool(bool),
Integer(i64),
Float(f64),
String(String),
Array(Vec<ConfigValue>),
}
impl ConfigValue {
fn to_yaml_value(&self) -> serde_yaml::Value {
match self {
Self::Bool(b) => serde_yaml::Value::Bool(*b),
Self::Integer(i) => serde_yaml::Value::Number(serde_yaml::Number::from(*i)),
Self::Float(f) => {
if f.is_finite() {
serde_yaml::Value::Number(serde_yaml::Number::from(*f))
} else {
serde_yaml::Value::Null
}
}
Self::String(s) => serde_yaml::Value::String(s.clone()),
Self::Array(arr) => {
serde_yaml::Value::Sequence(arr.iter().map(ConfigValue::to_yaml_value).collect())
}
}
}
}
/// Writes `value` into `root` at the nested location named by `path`.
///
/// The last segment becomes the key for the value and overwrites any existing
/// entry. Each earlier segment becomes a mapping, created empty when absent.
/// If an earlier segment already holds a non-mapping value, the write is
/// silently skipped. An empty `path` does nothing.
fn set_nested_value(root: &mut serde_yaml::Mapping, path: &[&str], value: &ConfigValue) {
    match path.split_first() {
        None => {}
        Some((leaf, [])) => {
            root.insert(
                serde_yaml::Value::String((*leaf).to_string()),
                value.to_yaml_value(),
            );
        }
        Some((segment, rest)) => {
            let child = root
                .entry(serde_yaml::Value::String((*segment).to_string()))
                .or_insert_with(|| serde_yaml::Value::Mapping(serde_yaml::Mapping::new()));
            if let serde_yaml::Value::Mapping(nested) = child {
                set_nested_value(nested, rest, value);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads a `Float` entry from a synthesized config map, panicking with the
    /// key name when it is missing or has another variant.
    fn float_at(cfg: &std::collections::HashMap<String, ConfigValue>, key: &str) -> f64 {
        match cfg.get(key) {
            Some(ConfigValue::Float(v)) => *v,
            other => panic!("expected {key} float, got {other:?}"),
        }
    }

    /// Percentile table with caller-chosen quartiles and fixed tails
    /// (p1 = 0, p5 = 1, p10 = 2, p90 = p75 + 2, p95 = p75 + 3, p99 = 14).
    fn logmag_percentiles(p25: f64, p50: f64, p75: f64) -> crate::models::Percentiles {
        crate::models::Percentiles {
            p1: 0.0,
            p5: 1.0,
            p10: 2.0,
            p25,
            p50,
            p75,
            p90: p75 + 2.0,
            p95: p75 + 3.0,
            p99: 14.0,
        }
    }

    #[test]
    fn test_config_patch() {
        let mut patch = ConfigPatch::new();
        patch.set("global.seed", ConfigValue::Integer(42));
        patch.set("transactions.count", ConfigValue::Integer(1000));
        assert!(patch.get("global.seed").is_some());
        let yaml = patch.to_yaml().unwrap();
        assert!(yaml.contains("global"));
        assert!(yaml.contains("seed"));
    }

    #[test]
    fn log_magnitude_percentiles_override_corrupted_mean() {
        let synth = ConfigSynthesizer::new();
        let mut stats = NumericStats::new(1_000_000, -3062.36, 3062.36, 378_348.3, 237_409.8);
        stats.distribution = DistributionType::Normal;
        stats.percentiles = logmag_percentiles(-36.0, 0.0, 37.7);
        stats.log_magnitude_percentiles = Some(logmag_percentiles(3.0, 4.856, 6.5));
        let cfg = synth.map_numeric_distribution(&stats);

        let mu = float_at(&cfg, "lognormal_mu");
        let sigma = float_at(&cfg, "lognormal_sigma");
        assert!(
            (mu - 4.856).abs() < 1e-9,
            "mu={mu} should equal log-mag p50"
        );
        assert!(
            (sigma - (6.5 - 3.0) / 1.349).abs() < 1e-9,
            "sigma={sigma} should be IQR/1.349"
        );

        let max_amount = float_at(&cfg, "max_amount");
        assert!(
            max_amount > 1e5,
            "max_amount={max_amount} should reflect the heavy tail exp(p99)"
        );

        let nc = match cfg.get("mixture_components") {
            Some(ConfigValue::Integer(n)) => *n as usize,
            other => panic!("expected mixture_components integer, got {other:?}"),
        };
        assert!(nc >= 3, "expected a multi-component mixture, got {nc}");

        let last = nc - 1;
        let tail_mu = float_at(&cfg, &format!("comp{last}_mu"));
        assert!((tail_mu - 14.0).abs() < 1e-9, "tail component centered at p99");
        let head_mu = float_at(&cfg, "comp0_mu");
        assert!(tail_mu > head_mu, "tail mu must exceed the first component");

        let wsum: f64 = (0..nc)
            .map(|i| float_at(&cfg, &format!("comp{i}_weight")))
            .sum();
        assert!(
            (wsum - 1.0).abs() < 1e-9,
            "component weights sum to 1.0, got {wsum}"
        );
    }

    #[test]
    fn legacy_path_used_when_log_magnitude_absent() {
        let synth = ConfigSynthesizer::new();
        let mut stats = NumericStats::new(1000, 1.0, 1000.0, 100.0, 50.0);
        stats.distribution = DistributionType::Normal;
        stats.log_magnitude_percentiles = None;
        let cfg = synth.map_numeric_distribution(&stats);
        assert!(
            cfg.contains_key("lognormal_mu"),
            "legacy normal path should still fit"
        );
    }
}