// datasynth_core/distributions/conditional_iet.rs

//! Per-Source inter-event-time sampler driven by SP2's PerSourceIetPrior.
2
use std::collections::HashMap;

use rand::{Rng, RngExt};
6
/// Per-Source state for inter-event-time (IET) sampling.
///
/// The empirical CDF is stored as two parallel vectors: `cdf_values[i]` is the
/// knot position and `cdf_probabilities[i]` the cumulative probability at that
/// knot. The remaining fields carry what is needed to couple consecutive draws.
#[derive(Debug, Clone)]
pub struct SourceIetState {
    /// Quantile knot values (sorted ascending) from the prior's empirical CDF.
    pub cdf_values: Vec<f64>,
    /// Cumulative probabilities matching `cdf_values` (monotone in [0, 1]).
    pub cdf_probabilities: Vec<f64>,
    /// Lag-1 Pearson correlation observed in the corpus for this Source.
    pub lag1_autocorr: f64,
    /// Last sampled IET (in days); `None` until the first draw has been made.
    /// Used to couple the next draw via the autocorrelation.
    pub last_iet_days: Option<f64>,
}
19
20impl SourceIetState {
21    fn sample_quantile<R: Rng>(&self, rng: &mut R) -> f64 {
22        if self.cdf_values.is_empty() {
23            return 0.0;
24        }
25        let u: f64 = rng.random_range(f64::EPSILON..=1.0);
26        let mut idx = self.cdf_probabilities.len() - 1;
27        for (i, &p) in self.cdf_probabilities.iter().enumerate() {
28            if p >= u {
29                idx = i;
30                break;
31            }
32        }
33        self.cdf_values[idx]
34    }
35}
36
/// Empirical CDF value at `x` — linear interpolation between knots.
///
/// `values` holds the knot positions (expected sorted ascending) and
/// `probabilities` the cumulative probability at the matching knot.
///
/// * No knots: returns `0.0`.
/// * `x` at or below the first knot returns the first probability; `x` at or
///   above the last knot returns the last probability (clamped, never
///   extrapolated).
/// * `x` exactly on a knot returns that knot's probability.
/// * If the slices differ in length only their common prefix is used, so
///   malformed input cannot trigger an out-of-range index.
fn empirical_cdf_at(values: &[f64], probabilities: &[f64], x: f64) -> f64 {
    let knots = values.len().min(probabilities.len());
    let (values, probabilities) = (&values[..knots], &probabilities[..knots]);
    let (Some(&first_v), Some(&last_v)) = (values.first(), values.last()) else {
        return 0.0;
    };
    if x <= first_v {
        return probabilities[0];
    }
    if x >= last_v {
        return probabilities[knots - 1];
    }
    // Incomparable operands (NaN) are treated as equal, so the search always
    // lands on some knot rather than producing an unusable insertion point.
    let upper = match values
        .binary_search_by(|v| v.partial_cmp(&x).unwrap_or(std::cmp::Ordering::Equal))
    {
        Ok(hit) => return probabilities[hit],
        Err(insert_at) => insert_at,
    };
    // The two guards above guarantee 1 <= upper <= knots - 1 here.
    let (lo_v, hi_v) = (values[upper - 1], values[upper]);
    let (lo_p, hi_p) = (probabilities[upper - 1], probabilities[upper]);
    if hi_v == lo_v {
        // Degenerate (zero-width) segment: avoid dividing by zero.
        lo_p
    } else {
        let t = (x - lo_v) / (hi_v - lo_v);
        lo_p + t * (hi_p - lo_p)
    }
}
64
/// Inverse CDF: given quantile `p` in [0, 1], return the value via linear
/// interpolation between knots.
///
/// `values` holds the knot positions and `probabilities` the cumulative
/// probability at the matching knot (expected monotone non-decreasing).
///
/// * No knots: returns `0.0`.
/// * `p` is clamped to `[0, 1]` first; at or below the first probability the
///   first value is returned, at or above the last probability the last value.
/// * `p` exactly on a knot probability returns that knot's value.
/// * If the slices differ in length only their common prefix is used, so
///   malformed input cannot trigger an out-of-range index.
fn quantile_at(values: &[f64], probabilities: &[f64], p: f64) -> f64 {
    let knots = values.len().min(probabilities.len());
    let (values, probabilities) = (&values[..knots], &probabilities[..knots]);
    let (Some(&first_p), Some(&last_p)) = (probabilities.first(), probabilities.last()) else {
        return 0.0;
    };
    let p = p.clamp(0.0, 1.0);
    if p <= first_p {
        return values[0];
    }
    if p >= last_p {
        return values[knots - 1];
    }
    // Incomparable operands (NaN) are treated as equal, so the search always
    // lands on some knot rather than producing an unusable insertion point.
    let upper = match probabilities
        .binary_search_by(|prob| prob.partial_cmp(&p).unwrap_or(std::cmp::Ordering::Equal))
    {
        Ok(hit) => return values[hit],
        Err(insert_at) => insert_at,
    };
    // The two guards above guarantee 1 <= upper <= knots - 1 here.
    let (lo_p, hi_p) = (probabilities[upper - 1], probabilities[upper]);
    let (lo_v, hi_v) = (values[upper - 1], values[upper]);
    if hi_p == lo_p {
        // Flat CDF segment: avoid dividing by zero.
        lo_v
    } else {
        let t = (p - lo_p) / (hi_p - lo_p);
        lo_v + t * (hi_v - lo_v)
    }
}
94
/// Inverse standard-normal CDF (Φ⁻¹).
///
/// Piecewise rational approximation with the usual sign convention: negative
/// for `p < 0.5`, zero at `p = 0.5`, positive for `p > 0.5`.
///
/// * Tails (`p < 0.02425` or `p > 1 - 0.02425`): Abramowitz & Stegun
///   §26.2.23, evaluated on the tail probability and mirrored for the lower
///   tail.
/// * Central region: a degree-3 / degree-4 rational in `(p - 0.5)²`; the
///   coefficients match the Beasley–Springer set.
///
/// `p` is clamped to `[1e-12, 1 - 1e-12]`, so `0.0` and `1.0` give large
/// finite values rather than ±∞. A NaN input yields NaN.
///
/// NOTE(review): the original comment states that `standard_normal_quantile`
/// in `copula.rs` has inverted tail signs, which is why this independent
/// implementation exists. That cannot be confirmed from this file.
fn inverse_standard_normal(p: f64) -> f64 {
    // Magnitude of the quantile for a tail of mass `tail_prob` (A&S 26.2.23).
    fn tail_magnitude(tail_prob: f64) -> f64 {
        const C: [f64; 3] = [2.515517, 0.802853, 0.010328];
        const D: [f64; 3] = [1.432788, 0.189269, 0.001308];
        let q = (-2.0 * tail_prob.ln()).sqrt();
        let rational =
            (C[0] + C[1] * q + C[2] * q * q) / (1.0 + D[0] * q + D[1] * q * q + D[2] * q * q * q);
        q - rational
    }

    // Break-point between the tail formula and the central formula.
    const P_LOW: f64 = 0.02425;
    const A: [f64; 4] = [
        2.50662823884,
        -18.61500062529,
        41.39119773534,
        -25.44106049637,
    ];
    const B: [f64; 4] = [
        -8.47351093090,
        23.08336743743,
        -21.06224101826,
        3.13082909833,
    ];

    let p = p.clamp(1e-12, 1.0 - 1e-12);
    let p_high = 1.0 - P_LOW;

    if p < P_LOW {
        // Lower tail: mirror of the upper tail, hence the negation.
        -tail_magnitude(p)
    } else if p <= p_high {
        // Central region: odd in `q`, so the sign comes from `q` itself.
        let q = p - 0.5;
        let r = q * q;
        q * (A[0] + A[1] * r + A[2] * r * r + A[3] * r * r * r)
            / (1.0 + B[0] * r + B[1] * r * r + B[2] * r * r * r + B[3] * r * r * r * r)
    } else {
        // Upper tail: evaluated on the complementary probability.
        tail_magnitude(1.0 - p)
    }
}
142
143/// Standard normal CDF (Φ).
144/// Re-uses the erf-based approximation from `copula.rs`.
145fn standard_normal_cdf(z: f64) -> f64 {
146    super::copula::standard_normal_cdf(z)
147}
148
149/// Sample one standard normal value via Box-Muller.
150fn standard_normal_sample<R: Rng + ?Sized>(rng: &mut R) -> f64 {
151    let u1: f64 = rng.random_range(f64::EPSILON..=1.0);
152    let u2: f64 = rng.random_range(0.0..=1.0);
153    (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
154}
155
156/// Per-Source IET sampler: each call to `sample_next` produces a fresh day-gap
157/// drawn from that Source's empirical CDF, optionally coupled with the previous
158/// sample via the lag-1 autocorrelation.
159#[derive(Clone)]
160pub struct ConditionalIETSampler {
161    per_source: HashMap<String, SourceIetState>,
162    fallback: SourceIetState,
163}
164
165impl ConditionalIETSampler {
166    pub fn from_state_map(
167        per_source: HashMap<String, SourceIetState>,
168        fallback: SourceIetState,
169    ) -> Self {
170        Self {
171            per_source,
172            fallback,
173        }
174    }
175
176    pub fn sample_next<R: Rng>(&mut self, source: &str, rng: &mut R) -> f64 {
177        let state = self
178            .per_source
179            .get_mut(source)
180            .unwrap_or(&mut self.fallback);
181        if state.cdf_values.is_empty() {
182            return 0.0;
183        }
184        let rho = state.lag1_autocorr.clamp(-1.0, 1.0);
185
186        // No-coupling path: |ρ| small or no previous sample.
187        if rho.abs() < 0.1 || state.last_iet_days.is_none() {
188            let s = state.sample_quantile(rng).max(0.0);
189            state.last_iet_days = Some(s);
190            return s;
191        }
192
193        // Gaussian-copula coupling.
194        let prev = state.last_iet_days.expect("checked above");
195        let p_prev = empirical_cdf_at(&state.cdf_values, &state.cdf_probabilities, prev);
196        let z_prev = inverse_standard_normal(p_prev);
197        let z_curr = rho * z_prev + (1.0 - rho * rho).sqrt() * standard_normal_sample(rng);
198        let p_curr = standard_normal_cdf(z_curr);
199        let curr = quantile_at(&state.cdf_values, &state.cdf_probabilities, p_curr).max(0.0);
200
201        state.last_iet_days = Some(curr);
202        curr
203    }
204
205    pub fn has_source(&self, source: &str) -> bool {
206        self.per_source.contains_key(source)
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    /// Builds a state whose knots are `values` with evenly spaced cumulative
    /// probabilities `1/n, 2/n, …, 1` and no previous sample.
    fn known_state(values: Vec<f64>, autocorr: f64) -> SourceIetState {
        let n = values.len();
        SourceIetState {
            cdf_values: values,
            cdf_probabilities: (1..=n).map(|i| i as f64 / n as f64).collect(),
            lag1_autocorr: autocorr,
            last_iet_days: None,
        }
    }

    #[test]
    fn iet_sampler_returns_known_values() {
        let knots = [1.0, 2.0, 5.0, 10.0];
        let mut per_source = HashMap::new();
        per_source.insert("KR".to_string(), known_state(knots.to_vec(), 0.0));
        let mut sampler =
            ConditionalIETSampler::from_state_map(per_source, known_state(vec![0.5, 1.0], 0.0));
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        // ρ = 0 keeps every draw on the independent path, which only ever
        // returns knot values.
        for _ in 0..30 {
            let s = sampler.sample_next("KR", &mut rng);
            assert!(knots.contains(&s), "unexpected sample {s}");
        }
    }

    #[test]
    fn iet_sampler_falls_back_on_unknown_source() {
        let per_source = HashMap::new();
        let mut sampler =
            ConditionalIETSampler::from_state_map(per_source, known_state(vec![7.0], 0.0));
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        assert!((sampler.sample_next("UNKNOWN", &mut rng) - 7.0).abs() < 1e-9);
    }

    #[test]
    fn iet_sampler_autocorr_couples_samples() {
        let mut per_source = HashMap::new();
        per_source.insert("A".to_string(), known_state(vec![1.0, 10.0], 0.9));
        let mut sampler =
            ConditionalIETSampler::from_state_map(per_source, known_state(vec![5.0], 0.0));
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let first = sampler.sample_next("A", &mut rng);
        let second = sampler.sample_next("A", &mut rng);
        assert!(first.is_finite() && second.is_finite());
        // Both paths are bounded by the smallest and largest knot.
        assert!((1.0..=10.0).contains(&first), "first out of range: {first}");
        assert!(
            (1.0..=10.0).contains(&second),
            "second out of range: {second}"
        );
    }

    #[test]
    fn empirical_cdf_at_interpolates_linearly() {
        let v = vec![1.0, 2.0, 4.0];
        let p = vec![0.25, 0.5, 1.0];
        assert!((empirical_cdf_at(&v, &p, 1.0) - 0.25).abs() < 1e-9);
        assert!((empirical_cdf_at(&v, &p, 2.0) - 0.5).abs() < 1e-9);
        assert!((empirical_cdf_at(&v, &p, 3.0) - 0.75).abs() < 1e-9);
        assert!((empirical_cdf_at(&v, &p, 0.5) - 0.25).abs() < 1e-9);
        assert!((empirical_cdf_at(&v, &p, 5.0) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn quantile_at_inverts_empirical_cdf() {
        let v = vec![1.0, 2.0, 4.0];
        let p = vec![0.25, 0.5, 1.0];
        assert!((quantile_at(&v, &p, 0.25) - 1.0).abs() < 1e-9);
        assert!((quantile_at(&v, &p, 0.5) - 2.0).abs() < 1e-9);
        assert!((quantile_at(&v, &p, 1.0) - 4.0).abs() < 1e-9);
        assert!((quantile_at(&v, &p, 0.75) - 3.0).abs() < 1e-9);
    }

    #[test]
    fn inverse_standard_normal_known_values() {
        // Central region
        assert!((inverse_standard_normal(0.5)).abs() < 1e-6);
        // Near-central (still in central region since p_low=0.02425)
        assert!((inverse_standard_normal(0.975) - 1.96).abs() < 1e-2);
        assert!((inverse_standard_normal(0.025) + 1.96).abs() < 1e-2);
        // Tail regions (these verify the sign convention is correct)
        assert!(
            inverse_standard_normal(0.01) < 0.0,
            "lower tail must be negative"
        );
        assert!(
            inverse_standard_normal(0.99) > 0.0,
            "upper tail must be positive"
        );
        assert!((inverse_standard_normal(0.99) + inverse_standard_normal(0.01)).abs() < 1e-3);
    }

    #[test]
    fn standard_normal_cdf_known_values() {
        assert!((standard_normal_cdf(0.0) - 0.5).abs() < 1e-6);
        assert!((standard_normal_cdf(1.96) - 0.975).abs() < 1e-3);
    }

    #[test]
    fn iet_sampler_never_returns_negative() {
        let mut per_source = HashMap::new();
        per_source.insert("X".to_string(), known_state(vec![0.0, 0.0, 0.0], -1.0));
        let mut sampler =
            ConditionalIETSampler::from_state_map(per_source, known_state(vec![0.0], 0.0));
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        for _ in 0..20 {
            assert!(sampler.sample_next("X", &mut rng) >= 0.0);
        }
    }

    #[test]
    fn copula_coupling_preserves_target_rho() {
        // Use a dense fine-grid uniform CDF (1000 knots over [0,1]) so that
        // the Gaussian copula's rank correlation transfers accurately to
        // Pearson correlation in value-space (rank ≈ Pearson for near-continuous
        // uniform marginals).
        let n = 1000usize;
        let grid: Vec<f64> = (1..=n).map(|i| i as f64 / n as f64).collect();
        let mut per_source = HashMap::new();
        per_source.insert(
            "A".to_string(),
            SourceIetState {
                cdf_values: grid.clone(),
                cdf_probabilities: grid,
                lag1_autocorr: 0.6,
                last_iet_days: None,
            },
        );
        let fallback = SourceIetState {
            cdf_values: vec![1.0],
            cdf_probabilities: vec![1.0],
            lag1_autocorr: 0.0,
            last_iet_days: None,
        };
        let mut sampler = ConditionalIETSampler::from_state_map(per_source, fallback);
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let n_samples = 5000;
        let mut series: Vec<f64> = Vec::with_capacity(n_samples);
        for _ in 0..n_samples {
            series.push(sampler.sample_next("A", &mut rng));
        }

        // Empirical lag-1 correlation should be near 0.6 (±0.08).
        let pairs = n_samples - 1;
        let (pre, post) = (&series[..pairs], &series[1..]);
        let mean_pre: f64 = pre.iter().sum::<f64>() / pairs as f64;
        let mean_post: f64 = post.iter().sum::<f64>() / pairs as f64;
        let mut cov = 0.0;
        let mut var_pre = 0.0;
        let mut var_post = 0.0;
        for (&earlier, &later) in pre.iter().zip(post) {
            let a = earlier - mean_pre;
            let b = later - mean_post;
            cov += a * b;
            var_pre += a * a;
            var_post += b * b;
        }
        let empirical_rho = cov / (var_pre.sqrt() * var_post.sqrt());
        assert!(
            (empirical_rho - 0.6).abs() < 0.08,
            "expected empirical ρ ≈ 0.6, got {empirical_rho}"
        );
    }

    #[test]
    fn copula_coupling_low_rho_uses_independent_path() {
        let mut per_source = HashMap::new();
        per_source.insert(
            "A".to_string(),
            SourceIetState {
                cdf_values: vec![5.0; 10],
                cdf_probabilities: (1..=10).map(|i| i as f64 / 10.0).collect(),
                lag1_autocorr: 0.05,
                last_iet_days: None,
            },
        );
        let fallback = SourceIetState {
            cdf_values: vec![5.0],
            cdf_probabilities: vec![1.0],
            lag1_autocorr: 0.0,
            last_iet_days: None,
        };
        let mut sampler = ConditionalIETSampler::from_state_map(per_source, fallback);
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        // All CDF values are 5.0 — all samples must be 5.0 regardless of path.
        for _ in 0..20 {
            let s = sampler.sample_next("A", &mut rng);
            assert!((s - 5.0).abs() < 1e-9);
        }
    }
}