use std::collections::HashMap;
use rand::{Rng, RngExt};
/// Empirical inter-event-time (IET) distribution plus coupling state for one source.
///
/// The distribution is described by two parallel arrays: `cdf_probabilities[i]` is the
/// cumulative probability at `cdf_values[i]`. `Default` yields an empty distribution,
/// for which the sampler returns `0.0`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct SourceIetState {
    /// Support points of the distribution, expected sorted ascending because the CDF
    /// lookups in this module binary-search them. Presumably in days (see
    /// `last_iet_days`) — TODO confirm.
    pub cdf_values: Vec<f64>,
    /// Cumulative probability at each entry of `cdf_values`; expected non-decreasing.
    pub cdf_probabilities: Vec<f64>,
    /// Target lag-1 autocorrelation between consecutive samples; clamped to `[-1, 1]`
    /// by the sampler before use.
    pub lag1_autocorr: f64,
    /// Most recently returned sample; `None` means there is no previous sample, so the
    /// next draw is made independently.
    pub last_iet_days: Option<f64>,
}
impl SourceIetState {
    /// Draws one value by inverse-transform sampling of the step-function CDF.
    ///
    /// A uniform `u` is drawn from `[f64::EPSILON, 1.0]`. The method returns the first
    /// support point whose cumulative probability is `>= u`. If no probability reaches
    /// `u`, it returns the last usable support point.
    ///
    /// Only the first `min(cdf_values.len(), cdf_probabilities.len())` pairs are used,
    /// so mismatched lengths cannot index out of bounds. When there is no usable pair,
    /// this returns `0.0` without drawing from `rng`.
    fn sample_quantile<R: Rng>(&self, rng: &mut R) -> f64 {
        let n = self.cdf_values.len().min(self.cdf_probabilities.len());
        if n == 0 {
            return 0.0;
        }
        let u: f64 = rng.random_range(f64::EPSILON..=1.0);
        let idx = self.cdf_probabilities[..n]
            .iter()
            .position(|&p| p >= u)
            .unwrap_or(n - 1);
        self.cdf_values[idx]
    }
}
/// Evaluates a piecewise-linear empirical CDF at `x`.
///
/// `values` are the support points (expected sorted ascending) and `probabilities[i]`
/// is the cumulative probability at `values[i]`. Only the first
/// `min(values.len(), probabilities.len())` pairs are used.
///
/// * No usable pairs → `0.0`.
/// * `x` below the first support point, or NaN → `probabilities[0]`.
/// * `x` at or above the last support point → the last probability.
/// * `x` equal to a support point → that point's probability. With tied support points
///   the last tie wins (right-continuous CDF), so the result is deterministic.
/// * Otherwise → linear interpolation between the two bracketing points.
fn empirical_cdf_at(values: &[f64], probabilities: &[f64], x: f64) -> f64 {
    let n = values.len().min(probabilities.len());
    if n == 0 {
        return 0.0;
    }
    let (values, probabilities) = (&values[..n], &probabilities[..n]);
    // Number of support points <= x. NaN compares false everywhere, so it yields 0.
    let idx = values.partition_point(|&v| v <= x);
    if idx == 0 {
        return probabilities[0];
    }
    if idx == n || values[idx - 1] == x {
        return probabilities[idx - 1];
    }
    let (lo_v, hi_v) = (values[idx - 1], values[idx]);
    let (lo_p, hi_p) = (probabilities[idx - 1], probabilities[idx]);
    if hi_v > lo_v {
        let t = (x - lo_v) / (hi_v - lo_v);
        lo_p + t * (hi_p - lo_p)
    } else {
        // Only reachable when `values` is not sorted; fall back to the lower point.
        lo_p
    }
}
/// Evaluates the quantile function (inverse) of a piecewise-linear empirical CDF at `p`.
///
/// `values` are the support points (expected sorted ascending) and `probabilities[i]`
/// is the cumulative probability at `values[i]` (expected non-decreasing). Only the
/// first `min(values.len(), probabilities.len())` pairs are used, and `p` is clamped to
/// `[0, 1]` first.
///
/// * No usable pairs → `0.0`.
/// * `p` at or below the first probability, or NaN → `values[0]`.
/// * `p` above every probability → the last value.
/// * `p` equal to a cumulative probability → the value at the first index reaching it
///   (the generalized inverse `inf { x : F(x) >= p }`), deterministic under ties.
/// * Otherwise → linear interpolation between the two bracketing points.
fn quantile_at(values: &[f64], probabilities: &[f64], p: f64) -> f64 {
    let n = values.len().min(probabilities.len());
    if n == 0 {
        return 0.0;
    }
    let (values, probabilities) = (&values[..n], &probabilities[..n]);
    let p = p.clamp(0.0, 1.0);
    // First index whose cumulative probability is >= p. NaN compares false, so it yields 0.
    let idx = probabilities.partition_point(|&q| q < p);
    if idx == n {
        return values[n - 1];
    }
    if idx == 0 || probabilities[idx] == p {
        return values[idx];
    }
    let (lo_p, hi_p) = (probabilities[idx - 1], probabilities[idx]);
    let (lo_v, hi_v) = (values[idx - 1], values[idx]);
    if hi_p > lo_p {
        let t = (p - lo_p) / (hi_p - lo_p);
        lo_v + t * (hi_v - lo_v)
    } else {
        // Only reachable when `probabilities` is not sorted; fall back to the lower point.
        lo_v
    }
}
/// Approximates the inverse standard normal CDF, `Φ⁻¹(p)`.
///
/// Uses Peter Acklam's rational approximation (relative error about 1.15e-9). Its
/// breakpoints are `p = 0.02425` and `1 - 0.02425`. The central region uses one
/// rational function in `p - 0.5`, and each tail uses one in `sqrt(-2 ln(tail mass))`.
/// Using a single consistent algorithm avoids the accuracy loss and the jump at the
/// breakpoints that come from mixing approximations with different validity ranges.
///
/// `p` is clamped to `[1e-12, 1 - 1e-12]` first, so the result is always finite
/// (roughly within ±7.03), even for `p = 0` or `p = 1`.
fn inverse_standard_normal(p: f64) -> f64 {
    // Central-region numerator (A) and denominator (B) coefficients.
    const A: [f64; 6] = [
        -3.969683028665376e1,
        2.209460984245205e2,
        -2.759285104469687e2,
        1.38357751867269e2,
        -3.066479806614716e1,
        2.506628277459239,
    ];
    const B: [f64; 5] = [
        -5.447609879822406e1,
        1.615858368580409e2,
        -1.556989798598866e2,
        6.680131188771972e1,
        -1.328068155288572e1,
    ];
    // Tail-region numerator (C) and denominator (D) coefficients.
    const C: [f64; 6] = [
        -7.784894002430293e-3,
        -3.223964580411365e-1,
        -2.400758277161838,
        -2.549732539343734,
        4.374664141464968,
        2.938163982698783,
    ];
    const D: [f64; 4] = [
        7.784695709041462e-3,
        3.224671290700398e-1,
        2.445134137142996,
        3.754408661907416,
    ];
    const P_LOW: f64 = 0.02425;
    const P_HIGH: f64 = 1.0 - P_LOW;

    // Lower-tail value for q = sqrt(-2 ln p); the upper tail is its negation by symmetry.
    let tail = |q: f64| {
        (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
    };

    let p = p.clamp(1e-12, 1.0 - 1e-12);
    if p < P_LOW {
        tail((-2.0 * p.ln()).sqrt())
    } else if p <= P_HIGH {
        let q = p - 0.5;
        let r = q * q;
        (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
            / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
    } else {
        -tail((-2.0 * (1.0 - p).ln()).sqrt())
    }
}
/// Standard normal CDF `Φ(z)`, delegated to the sibling `copula` module's implementation.
fn standard_normal_cdf(z: f64) -> f64 {
    use super::copula::standard_normal_cdf as phi;
    phi(z)
}
/// Draws one standard normal variate via the Box–Muller transform.
///
/// Consumes exactly two uniforms from `rng`, in this order:
/// 1. a radius uniform from `[f64::EPSILON, 1.0]` (the positive lower bound keeps `ln` finite);
/// 2. an angle uniform from `[0.0, 1.0]`.
///
/// Only the cosine branch of the transform is returned.
fn standard_normal_sample<R: Rng + ?Sized>(rng: &mut R) -> f64 {
    let radius_u: f64 = rng.random_range(f64::EPSILON..=1.0);
    let angle_u: f64 = rng.random_range(0.0..=1.0);
    let radius = (-2.0 * radius_u.ln()).sqrt();
    let angle = std::f64::consts::TAU * angle_u;
    radius * angle.cos()
}
/// Per-source inter-event-time sampler with optional lag-1 coupling between draws.
///
/// Each source name maps to its own [`SourceIetState`]. Sources without an entry share
/// the `fallback` state.
#[derive(Debug, Clone)]
pub struct ConditionalIETSampler {
    per_source: HashMap<String, SourceIetState>,
    fallback: SourceIetState,
}
impl ConditionalIETSampler {
    /// Builds a sampler from per-source states and the state shared by every source
    /// that has no entry in `per_source`.
    pub fn from_state_map(
        per_source: HashMap<String, SourceIetState>,
        fallback: SourceIetState,
    ) -> Self {
        Self {
            per_source,
            fallback,
        }
    }

    /// Draws the next inter-event time for `source` and records it as that state's
    /// `last_iet_days`.
    ///
    /// Sources without their own entry use, and update, the shared fallback state. If
    /// the state has no CDF data (empty `cdf_values` or `cdf_probabilities`), this
    /// returns `0.0` and leaves the state untouched.
    ///
    /// Let `ρ` be `lag1_autocorr` clamped to `[-1, 1]`; a NaN autocorrelation counts as `0`.
    /// * If `|ρ| < 0.1` or there is no previous sample, the value is drawn independently
    ///   from the empirical distribution.
    /// * Otherwise the previous value is mapped to normal space through the empirical
    ///   CDF and advanced one Gaussian AR(1) step, `z' = ρ z + sqrt(1 - ρ²) ε`. It is
    ///   then mapped back through the interpolated quantile function.
    ///
    /// Results are floored at `0.0`.
    pub fn sample_next<R: Rng>(&mut self, source: &str, rng: &mut R) -> f64 {
        let state = self
            .per_source
            .get_mut(source)
            .unwrap_or(&mut self.fallback);
        if state.cdf_values.is_empty() || state.cdf_probabilities.is_empty() {
            return 0.0;
        }
        // A NaN survives `clamp` and fails the `|rho| < 0.1` test, which would push the
        // coupled path into NaN arithmetic; treat it as "no autocorrelation" instead.
        let rho = if state.lag1_autocorr.is_nan() {
            0.0
        } else {
            state.lag1_autocorr.clamp(-1.0, 1.0)
        };
        let Some(prev) = state.last_iet_days.filter(|_| rho.abs() >= 0.1) else {
            let s = state.sample_quantile(rng).max(0.0);
            state.last_iet_days = Some(s);
            return s;
        };
        let p_prev = empirical_cdf_at(&state.cdf_values, &state.cdf_probabilities, prev);
        let z_prev = inverse_standard_normal(p_prev);
        let z_curr = rho * z_prev + (1.0 - rho * rho).sqrt() * standard_normal_sample(rng);
        let p_curr = standard_normal_cdf(z_curr);
        let curr = quantile_at(&state.cdf_values, &state.cdf_probabilities, p_curr).max(0.0);
        state.last_iet_days = Some(curr);
        curr
    }

    /// Returns `true` when `source` has its own state rather than using the fallback.
    pub fn has_source(&self, source: &str) -> bool {
        self.per_source.contains_key(source)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    /// Builds a state whose cumulative probabilities step by `1/n` over `values`.
    fn known_state(values: Vec<f64>, autocorr: f64) -> SourceIetState {
        let n = values.len();
        SourceIetState {
            cdf_values: values,
            cdf_probabilities: (1..=n).map(|i| i as f64 / n as f64).collect(),
            lag1_autocorr: autocorr,
            last_iet_days: None,
        }
    }

    /// Builds a sampler with exactly one named source plus the given fallback.
    fn single_source_sampler(
        name: &str,
        state: SourceIetState,
        fallback: SourceIetState,
    ) -> ConditionalIETSampler {
        let mut per_source = HashMap::new();
        per_source.insert(name.to_string(), state);
        ConditionalIETSampler::from_state_map(per_source, fallback)
    }

    /// Pearson correlation between `series[..len - 1]` and `series[1..]`.
    fn lag1_correlation(series: &[f64]) -> f64 {
        let m = series.len() - 1;
        let (pre, post) = (&series[..m], &series[1..]);
        let mean_pre = pre.iter().sum::<f64>() / m as f64;
        let mean_post = post.iter().sum::<f64>() / m as f64;
        let (mut num, mut dp, mut dq) = (0.0, 0.0, 0.0);
        for (&x, &y) in pre.iter().zip(post) {
            let a = x - mean_pre;
            let b = y - mean_post;
            num += a * b;
            dp += a * a;
            dq += b * b;
        }
        num / (dp.sqrt() * dq.sqrt())
    }

    #[test]
    fn iet_sampler_returns_known_values() {
        let mut sampler = single_source_sampler(
            "KR",
            known_state(vec![1.0, 2.0, 5.0, 10.0], 0.0),
            known_state(vec![0.5, 1.0], 0.0),
        );
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        for _ in 0..30 {
            let s = sampler.sample_next("KR", &mut rng);
            assert!([1.0, 2.0, 5.0, 10.0].contains(&s), "unexpected sample {s}");
        }
    }

    #[test]
    fn iet_sampler_falls_back_on_unknown_source() {
        let mut sampler =
            ConditionalIETSampler::from_state_map(HashMap::new(), known_state(vec![7.0], 0.0));
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        assert!((sampler.sample_next("UNKNOWN", &mut rng) - 7.0).abs() < 1e-9);
    }

    #[test]
    fn iet_sampler_reports_known_sources() {
        let sampler = single_source_sampler(
            "KR",
            known_state(vec![1.0], 0.0),
            known_state(vec![2.0], 0.0),
        );
        assert!(sampler.has_source("KR"));
        assert!(!sampler.has_source("UNKNOWN"));
    }

    #[test]
    fn iet_sampler_empty_state_returns_zero() {
        let mut sampler =
            ConditionalIETSampler::from_state_map(HashMap::new(), known_state(Vec::new(), 0.5));
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        assert_eq!(sampler.sample_next("ANY", &mut rng), 0.0);
    }

    #[test]
    fn iet_sampler_autocorr_couples_samples() {
        let mut sampler = single_source_sampler(
            "A",
            known_state(vec![1.0, 10.0], 0.9),
            known_state(vec![5.0], 0.0),
        );
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        for _ in 0..50 {
            let s = sampler.sample_next("A", &mut rng);
            assert!(s.is_finite(), "coupled sample is not finite: {s}");
            assert!(
                (1.0..=10.0).contains(&s),
                "coupled sample {s} left the support [1, 10]"
            );
        }
    }

    #[test]
    fn iet_sampler_is_deterministic_for_a_seed() {
        let draw = || {
            let mut sampler = single_source_sampler(
                "A",
                known_state(vec![1.0, 3.0, 10.0], 0.7),
                known_state(vec![5.0], 0.0),
            );
            let mut rng = ChaCha8Rng::seed_from_u64(7);
            (0..25)
                .map(|_| sampler.sample_next("A", &mut rng))
                .collect::<Vec<f64>>()
        };
        assert_eq!(draw(), draw());
    }

    #[test]
    fn empirical_cdf_at_interpolates_linearly() {
        let v = vec![1.0, 2.0, 4.0];
        let p = vec![0.25, 0.5, 1.0];
        assert!((empirical_cdf_at(&v, &p, 1.0) - 0.25).abs() < 1e-9);
        assert!((empirical_cdf_at(&v, &p, 2.0) - 0.5).abs() < 1e-9);
        assert!((empirical_cdf_at(&v, &p, 3.0) - 0.75).abs() < 1e-9);
        assert!((empirical_cdf_at(&v, &p, 0.5) - 0.25).abs() < 1e-9);
        assert!((empirical_cdf_at(&v, &p, 5.0) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn quantile_at_inverts_empirical_cdf() {
        let v = vec![1.0, 2.0, 4.0];
        let p = vec![0.25, 0.5, 1.0];
        assert!((quantile_at(&v, &p, 0.25) - 1.0).abs() < 1e-9);
        assert!((quantile_at(&v, &p, 0.5) - 2.0).abs() < 1e-9);
        assert!((quantile_at(&v, &p, 1.0) - 4.0).abs() < 1e-9);
        assert!((quantile_at(&v, &p, 0.75) - 3.0).abs() < 1e-9);
    }

    #[test]
    fn inverse_standard_normal_known_values() {
        assert!((inverse_standard_normal(0.5)).abs() < 1e-6);
        assert!((inverse_standard_normal(0.975) - 1.96).abs() < 1e-2);
        assert!((inverse_standard_normal(0.025) + 1.96).abs() < 1e-2);
        assert!(
            inverse_standard_normal(0.01) < 0.0,
            "lower tail must be negative"
        );
        assert!(
            inverse_standard_normal(0.99) > 0.0,
            "upper tail must be positive"
        );
        assert!((inverse_standard_normal(0.99) + inverse_standard_normal(0.01)).abs() < 1e-3);
    }

    #[test]
    fn standard_normal_cdf_known_values() {
        assert!((standard_normal_cdf(0.0) - 0.5).abs() < 1e-6);
        assert!((standard_normal_cdf(1.96) - 0.975).abs() < 1e-3);
    }

    #[test]
    fn iet_sampler_never_returns_negative() {
        let mut sampler = single_source_sampler(
            "X",
            known_state(vec![0.0, 0.0, 0.0], -1.0),
            known_state(vec![0.0], 0.0),
        );
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        for _ in 0..20 {
            assert!(sampler.sample_next("X", &mut rng) >= 0.0);
        }
    }

    #[test]
    fn copula_coupling_preserves_target_rho() {
        let n = 1000usize;
        let grid: Vec<f64> = (1..=n).map(|i| i as f64 / n as f64).collect();
        let mut sampler =
            single_source_sampler("A", known_state(grid, 0.6), known_state(vec![1.0], 0.0));
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let n_samples = 5000;
        let series: Vec<f64> = (0..n_samples)
            .map(|_| sampler.sample_next("A", &mut rng))
            .collect();
        let empirical_rho = lag1_correlation(&series);
        assert!(
            (empirical_rho - 0.6).abs() < 0.08,
            "expected empirical ρ ≈ 0.6, got {empirical_rho}"
        );
    }

    #[test]
    fn copula_coupling_low_rho_uses_independent_path() {
        let mut sampler = single_source_sampler(
            "A",
            known_state(vec![5.0; 10], 0.05),
            known_state(vec![5.0], 0.0),
        );
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        for _ in 0..20 {
            let s = sampler.sample_next("A", &mut rng);
            assert!((s - 5.0).abs() < 1e-9);
        }
    }
}