use super::copula::{
cholesky_decompose, standard_normal_cdf, standard_normal_quantile, CopulaType,
};
use rand::prelude::*;
use rand_chacha::ChaCha8Rng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A named field paired with the marginal distribution its sampled values
/// are drawn from (see `CorrelationEngine::sample`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelatedField {
    /// Key under which this field's value appears in sampled records.
    pub name: String,
    /// Marginal distribution applied via the inverse-CDF transform.
    pub distribution: MarginalDistribution,
}
/// Marginal distribution a correlated field is mapped onto.
///
/// Serialized with an internal `"type"` tag in snake_case, e.g.
/// `{"type": "log_normal", "mu": 7.0, "sigma": 2.0}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum MarginalDistribution {
    /// Normal with mean `mu` and standard deviation `sigma`.
    Normal { mu: f64, sigma: f64 },
    /// Log-normal: `ln X ~ N(mu, sigma)` (see `inverse_cdf`).
    LogNormal { mu: f64, sigma: f64 },
    /// Continuous uniform on `[a, b]`.
    Uniform { a: f64, b: f64 },
    /// Integer uniform on the inclusive range `[min, max]`.
    DiscreteUniform { min: i32, max: i32 },
    /// Pre-computed quantile table, linearly interpolated by `inverse_cdf`.
    Custom { quantiles: Vec<f64> },
}
impl Default for MarginalDistribution {
fn default() -> Self {
Self::Normal {
mu: 0.0,
sigma: 1.0,
}
}
}
impl MarginalDistribution {
    /// Maps a uniform variate `u` (expected in `[0, 1]`) to a draw from this
    /// marginal via the inverse-CDF (quantile) transform.
    ///
    /// `Normal`/`LogNormal` delegate to `standard_normal_quantile`;
    /// `DiscreteUniform` buckets `u` into equal-width integer steps;
    /// `Custom` linearly interpolates a pre-computed quantile table
    /// (an empty table yields `0.0`).
    pub fn inverse_cdf(&self, u: f64) -> f64 {
        match self {
            Self::Normal { mu, sigma } => mu + sigma * standard_normal_quantile(u),
            Self::LogNormal { mu, sigma } => {
                // ln X ~ N(mu, sigma), so X = exp(mu + sigma * z).
                let z = standard_normal_quantile(u);
                (mu + sigma * z).exp()
            }
            Self::Uniform { a, b } => a + u * (b - a),
            Self::DiscreteUniform { min, max } => {
                // Each of the (max - min + 1) integers owns an equal-width
                // bucket; the final `.min` keeps u == 1.0 from overshooting.
                let range = (*max - *min + 1) as f64;
                (*min as f64 + (u * range).floor()).min(*max as f64)
            }
            Self::Custom { quantiles } => {
                if quantiles.is_empty() {
                    return 0.0;
                }
                let n = quantiles.len();
                // Clamp so an out-of-range `u` (possible for external callers
                // of this public method) cannot index past the table and panic.
                let idx = u.clamp(0.0, 1.0) * (n - 1) as f64;
                let low_idx = idx.floor() as usize;
                let high_idx = (low_idx + 1).min(n - 1);
                let frac = idx - low_idx as f64;
                quantiles[low_idx] * (1.0 - frac) + quantiles[high_idx] * frac
            }
        }
    }
}
/// Declarative description of a set of correlated fields: the fields
/// themselves plus the packed correlation matrix and the copula type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationConfig {
    /// Fields to sample, in the order `matrix` entries refer to them.
    pub fields: Vec<CorrelatedField>,
    /// Strict upper triangle of the correlation matrix, row-major:
    /// (0,1), (0,2), ..., (1,2), ... — length must be n*(n-1)/2
    /// (enforced by `validate`).
    pub matrix: Vec<f64>,
    /// Copula joining the marginals; defaults when absent from serialized input.
    #[serde(default)]
    pub copula_type: CopulaType,
}
impl Default for CorrelationConfig {
fn default() -> Self {
Self {
fields: vec![],
matrix: vec![],
copula_type: CopulaType::Gaussian,
}
}
}
impl CorrelationConfig {
    /// Builds a configuration over `fields` using `matrix` as the packed
    /// upper-triangular correlations (row-major, diagonal excluded), with a
    /// Gaussian copula.
    pub fn new(fields: Vec<CorrelatedField>, matrix: Vec<f64>) -> Self {
        Self {
            fields,
            matrix,
            copula_type: CopulaType::Gaussian,
        }
    }

    /// Convenience constructor for exactly two fields with a single pairwise
    /// `correlation`.
    pub fn bivariate(field1: CorrelatedField, field2: CorrelatedField, correlation: f64) -> Self {
        Self {
            fields: vec![field1, field2],
            matrix: vec![correlation],
            copula_type: CopulaType::Gaussian,
        }
    }

    /// Validates the configuration.
    ///
    /// # Errors
    /// Returns a message if there are fewer than 2 fields, a field name is
    /// duplicated, the packed matrix has the wrong length, any correlation is
    /// outside `[-1, 1]` (NaN included), or the reconstructed full matrix is
    /// not positive semi-definite (Cholesky decomposition fails).
    pub fn validate(&self) -> Result<(), String> {
        let n = self.fields.len();
        if n < 2 {
            return Err("At least 2 fields are required for correlation".to_string());
        }
        // Field names become keys in sampled records; duplicates would
        // silently overwrite each other, so reject them up front.
        let mut seen = std::collections::HashSet::new();
        for field in &self.fields {
            if !seen.insert(field.name.as_str()) {
                return Err(format!("Duplicate field name: {}", field.name));
            }
        }
        let expected_matrix_size = n * (n - 1) / 2;
        if self.matrix.len() != expected_matrix_size {
            return Err(format!(
                "Expected {} correlation values for {} fields, got {}",
                expected_matrix_size,
                n,
                self.matrix.len()
            ));
        }
        // `contains` is false for NaN, so NaN correlations are rejected too.
        for (i, &corr) in self.matrix.iter().enumerate() {
            if !(-1.0..=1.0).contains(&corr) {
                return Err(format!(
                    "Correlation at index {i} must be in [-1, 1], got {corr}"
                ));
            }
        }
        let full_matrix = self.to_full_matrix();
        if cholesky_decompose(&full_matrix).is_none() {
            return Err(
                "Correlation matrix is not positive semi-definite (invalid correlations)"
                    .to_string(),
            );
        }
        Ok(())
    }

    /// Expands the packed upper triangle into a full symmetric n x n
    /// correlation matrix with a unit diagonal.
    ///
    /// # Panics
    /// Panics if `matrix` is shorter than n*(n-1)/2; call `validate` first.
    pub fn to_full_matrix(&self) -> Vec<Vec<f64>> {
        let n = self.fields.len();
        let mut matrix = vec![vec![0.0; n]; n];
        for (i, row) in matrix.iter_mut().enumerate() {
            row[i] = 1.0;
        }
        // Packed order is (0,1), (0,2), ..., (1,2), ... — row-major over the
        // strict upper triangle; mirror each value across the diagonal.
        let mut pairs = self.matrix.iter();
        for i in 0..n {
            for j in (i + 1)..n {
                let val = *pairs
                    .next()
                    .expect("packed matrix length checked by validate");
                matrix[i][j] = val;
                matrix[j][i] = val;
            }
        }
        matrix
    }

    /// Field names in declaration order (borrowed from `self`).
    pub fn field_names(&self) -> Vec<&str> {
        self.fields.iter().map(|f| f.name.as_str()).collect()
    }
}
/// Draws correlated samples from a validated `CorrelationConfig` using a
/// seeded (deterministic) ChaCha8 RNG and a precomputed Cholesky factor.
pub struct CorrelationEngine {
    // Deterministic RNG; reseedable via `reset`.
    rng: ChaCha8Rng,
    config: CorrelationConfig,
    // Lower-triangular Cholesky factor of the full correlation matrix.
    cholesky: Vec<Vec<f64>>,
}
impl CorrelationEngine {
    /// Validates `config`, precomputes the Cholesky factor of its full
    /// correlation matrix, and seeds a deterministic ChaCha8 RNG.
    ///
    /// # Errors
    /// Propagates `config.validate()` messages, or reports a Cholesky
    /// decomposition failure.
    pub fn new(seed: u64, config: CorrelationConfig) -> Result<Self, String> {
        config.validate()?;
        let full_matrix = config.to_full_matrix();
        let cholesky = cholesky_decompose(&full_matrix)
            .ok_or_else(|| "Failed to compute Cholesky decomposition".to_string())?;
        Ok(Self {
            rng: ChaCha8Rng::seed_from_u64(seed),
            config,
            cholesky,
        })
    }

    /// Draws one correlated vector of uniforms, one entry per configured
    /// field (shared by `sample` and `sample_vec`).
    ///
    /// Independent standard normals z are correlated via the Cholesky factor
    /// (y = L z), then mapped into [0, 1] with the standard normal CDF.
    fn correlated_uniforms(&mut self) -> Vec<f64> {
        let n = self.config.fields.len();
        let z: Vec<f64> = (0..n).map(|_| self.sample_standard_normal()).collect();
        self.cholesky
            .iter()
            .enumerate()
            .map(|(i, row)| {
                // L is lower-triangular, so only the first i + 1 entries of
                // each row contribute.
                let y: f64 = row
                    .iter()
                    .take(i + 1)
                    .zip(z.iter())
                    .map(|(c, zi)| c * zi)
                    .sum();
                standard_normal_cdf(y)
            })
            .collect()
    }

    /// Samples one correlated record keyed by field name.
    pub fn sample(&mut self) -> HashMap<String, f64> {
        let u = self.correlated_uniforms();
        self.config
            .fields
            .iter()
            .zip(u)
            .map(|(field, ui)| (field.name.clone(), field.distribution.inverse_cdf(ui)))
            .collect()
    }

    /// Samples one correlated record as values in field-declaration order
    /// (no name lookup, no map allocation).
    pub fn sample_vec(&mut self) -> Vec<f64> {
        let u = self.correlated_uniforms();
        self.config
            .fields
            .iter()
            .zip(u)
            .map(|(field, ui)| field.distribution.inverse_cdf(ui))
            .collect()
    }

    /// Samples a full record and returns just the value for `name`, if that
    /// field exists. Note: advances the RNG by a whole record either way.
    pub fn sample_field(&mut self, name: &str) -> Option<f64> {
        let sample = self.sample();
        sample.get(name).copied()
    }

    /// Draws `n` correlated records.
    pub fn sample_n(&mut self, n: usize) -> Vec<HashMap<String, f64>> {
        (0..n).map(|_| self.sample()).collect()
    }

    /// One standard-normal draw via the Box-Muller transform.
    fn sample_standard_normal(&mut self) -> f64 {
        // `random()` yields [0, 1); nudge u1 away from exactly 0 so ln(u1)
        // cannot produce -inf/NaN (rare but real Box-Muller failure mode).
        // The draw sequence is unchanged, so seeded determinism is preserved.
        let u1: f64 = self.rng.random::<f64>().max(f64::MIN_POSITIVE);
        let u2: f64 = self.rng.random();
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }

    /// Reseeds the RNG so the sample stream restarts deterministically.
    pub fn reset(&mut self, seed: u64) {
        self.rng = ChaCha8Rng::seed_from_u64(seed);
    }

    /// Read-only access to the engine's configuration.
    pub fn config(&self) -> &CorrelationConfig {
        &self.config
    }
}
pub mod correlation_presets {
use super::*;
pub fn amount_line_items() -> CorrelationConfig {
CorrelationConfig::bivariate(
CorrelatedField {
name: "amount".to_string(),
distribution: MarginalDistribution::LogNormal {
mu: 7.0,
sigma: 2.0,
},
},
CorrelatedField {
name: "line_items".to_string(),
distribution: MarginalDistribution::DiscreteUniform { min: 2, max: 20 },
},
0.65,
)
}
pub fn amount_approval_level() -> CorrelationConfig {
CorrelationConfig::bivariate(
CorrelatedField {
name: "amount".to_string(),
distribution: MarginalDistribution::LogNormal {
mu: 8.0,
sigma: 2.5,
},
},
CorrelatedField {
name: "approval_level".to_string(),
distribution: MarginalDistribution::DiscreteUniform { min: 1, max: 5 },
},
0.72,
)
}
pub fn order_processing_time() -> CorrelationConfig {
CorrelationConfig::bivariate(
CorrelatedField {
name: "order_value".to_string(),
distribution: MarginalDistribution::LogNormal {
mu: 7.5,
sigma: 1.5,
},
},
CorrelatedField {
name: "processing_days".to_string(),
distribution: MarginalDistribution::LogNormal {
mu: 1.5,
sigma: 0.8,
},
},
0.35,
)
}
pub fn transaction_attributes() -> CorrelationConfig {
CorrelationConfig {
fields: vec![
CorrelatedField {
name: "amount".to_string(),
distribution: MarginalDistribution::LogNormal {
mu: 7.0,
sigma: 2.0,
},
},
CorrelatedField {
name: "line_items".to_string(),
distribution: MarginalDistribution::DiscreteUniform { min: 2, max: 15 },
},
CorrelatedField {
name: "approval_level".to_string(),
distribution: MarginalDistribution::DiscreteUniform { min: 1, max: 4 },
},
],
matrix: vec![0.65, 0.72, 0.55],
copula_type: CopulaType::Gaussian,
}
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // validate() accepts a legal pairwise correlation and rejects |rho| > 1.
    #[test]
    fn test_correlation_config_validation() {
        let valid = CorrelationConfig::bivariate(
            CorrelatedField {
                name: "x".to_string(),
                distribution: MarginalDistribution::Normal {
                    mu: 0.0,
                    sigma: 1.0,
                },
            },
            CorrelatedField {
                name: "y".to_string(),
                distribution: MarginalDistribution::Normal {
                    mu: 0.0,
                    sigma: 1.0,
                },
            },
            0.5,
        );
        assert!(valid.validate().is_ok());
        let invalid_corr = CorrelationConfig::bivariate(
            CorrelatedField {
                name: "x".to_string(),
                distribution: MarginalDistribution::Normal {
                    mu: 0.0,
                    sigma: 1.0,
                },
            },
            CorrelatedField {
                name: "y".to_string(),
                distribution: MarginalDistribution::Normal {
                    mu: 0.0,
                    sigma: 1.0,
                },
            },
            1.5, // out of [-1, 1]
        );
        assert!(invalid_corr.validate().is_err());
    }

    // Packed triangle [0.5, 0.3, 0.4] expands to a symmetric 3x3 matrix with
    // a unit diagonal in (0,1), (0,2), (1,2) order.
    #[test]
    fn test_full_matrix_conversion() {
        let config = CorrelationConfig {
            fields: vec![
                CorrelatedField {
                    name: "a".to_string(),
                    distribution: MarginalDistribution::default(),
                },
                CorrelatedField {
                    name: "b".to_string(),
                    distribution: MarginalDistribution::default(),
                },
                CorrelatedField {
                    name: "c".to_string(),
                    distribution: MarginalDistribution::default(),
                },
            ],
            matrix: vec![0.5, 0.3, 0.4],
            copula_type: CopulaType::Gaussian,
        };
        let full = config.to_full_matrix();
        assert_eq!(full[0][0], 1.0);
        assert_eq!(full[1][1], 1.0);
        assert_eq!(full[2][2], 1.0);
        // Symmetry.
        assert_eq!(full[0][1], full[1][0]);
        assert_eq!(full[0][2], full[2][0]);
        assert_eq!(full[1][2], full[2][1]);
        // Packed-order placement.
        assert_eq!(full[0][1], 0.5);
        assert_eq!(full[0][2], 0.3);
        assert_eq!(full[1][2], 0.4);
    }

    // Sampled marginals stay in range; the empirical (Pearson) correlation of
    // the rank-transformed marginals must at least not come out strongly
    // negative for a +0.65 configured correlation.
    #[test]
    fn test_correlation_engine_sampling() {
        let config = correlation_presets::amount_line_items();
        let mut engine = CorrelationEngine::new(42, config).unwrap();
        let samples = engine.sample_n(2000);
        assert_eq!(samples.len(), 2000);
        let n = samples.len() as f64;
        let amounts: Vec<f64> = samples.iter().map(|s| s["amount"]).collect();
        let line_items: Vec<f64> = samples.iter().map(|s| s["line_items"]).collect();
        assert!(amounts.iter().all(|&a| a > 0.0));
        assert!(line_items.iter().all(|&l| (2.0..=20.0).contains(&l)));
        let mean_a = amounts.iter().sum::<f64>() / n;
        let mean_l = line_items.iter().sum::<f64>() / n;
        let mut cov = 0.0;
        let mut var_a = 0.0;
        let mut var_l = 0.0;
        for (a, l) in amounts.iter().zip(line_items.iter()) {
            let da = a - mean_a;
            let dl = l - mean_l;
            cov += da * dl;
            var_a += da * da;
            var_l += dl * dl;
        }
        // Guard against division by zero for degenerate samples.
        let correlation = if var_a > 0.0 && var_l > 0.0 {
            cov / (var_a.sqrt() * var_l.sqrt())
        } else {
            0.0
        };
        assert!(
            correlation > -0.5,
            "Correlation {} is unexpectedly strongly negative",
            correlation
        );
    }

    // Two engines seeded identically must emit identical sample streams.
    #[test]
    fn test_correlation_engine_determinism() {
        let config = correlation_presets::amount_line_items();
        let mut engine1 = CorrelationEngine::new(42, config.clone()).unwrap();
        let mut engine2 = CorrelationEngine::new(42, config).unwrap();
        for _ in 0..100 {
            let s1 = engine1.sample();
            let s2 = engine2.sample();
            assert_eq!(s1["amount"], s2["amount"]);
            assert_eq!(s1["line_items"], s2["line_items"]);
        }
    }

    // Spot-checks of inverse_cdf at the median (u = 0.5) for each marginal.
    #[test]
    fn test_marginal_inverse_cdf() {
        let normal = MarginalDistribution::Normal {
            mu: 10.0,
            sigma: 2.0,
        };
        assert!((normal.inverse_cdf(0.5) - 10.0).abs() < 0.1);
        let lognormal = MarginalDistribution::LogNormal {
            mu: 2.0,
            sigma: 0.5,
        };
        assert!(lognormal.inverse_cdf(0.5) > 0.0);
        let uniform = MarginalDistribution::Uniform { a: 0.0, b: 100.0 };
        assert!((uniform.inverse_cdf(0.5) - 50.0).abs() < 0.1);
        let discrete = MarginalDistribution::DiscreteUniform { min: 1, max: 10 };
        let value = discrete.inverse_cdf(0.5);
        assert!((1.0..=10.0).contains(&value));
    }

    // A three-field preset validates and produces all three keys per sample.
    #[test]
    fn test_multi_field_correlation() {
        let config = correlation_presets::transaction_attributes();
        assert!(config.validate().is_ok());
        let mut engine = CorrelationEngine::new(42, config).unwrap();
        let sample = engine.sample();
        assert!(sample.contains_key("amount"));
        assert!(sample.contains_key("line_items"));
        assert!(sample.contains_key("approval_level"));
    }

    // sample_vec yields values in field-declaration order with the same
    // range guarantees as sample().
    #[test]
    fn test_sample_vec() {
        let config = correlation_presets::amount_line_items();
        let mut engine = CorrelationEngine::new(42, config).unwrap();
        let vec = engine.sample_vec();
        assert_eq!(vec.len(), 2);
        assert!(vec[0] > 0.0);
        assert!(vec[1] >= 2.0 && vec[1] <= 20.0);
    }
}