datasynth_core/distributions/
conditional_iet.rs1use std::collections::HashMap;
4
5use rand::{Rng, RngExt};
6
/// Empirical inter-event-time (IET) distribution for a single source, together
/// with the running state needed to draw consecutive, correlated samples.
#[derive(Clone, Debug)]
pub struct SourceIetState {
    /// Knot values of the empirical CDF. The lookup helpers in this module
    /// search this vector with a binary search, so it is expected to be sorted
    /// in ascending order.
    pub cdf_values: Vec<f64>,
    /// Cumulative probabilities, paired index-for-index with `cdf_values`.
    pub cdf_probabilities: Vec<f64>,
    /// Lag-1 autocorrelation used to couple a sample to the previous one.
    pub lag1_autocorr: f64,
    /// Most recently drawn sample for this source, or `None` before the first
    /// draw. (Presumably expressed in days, going by the field name.)
    pub last_iet_days: Option<f64>,
}
19
20impl SourceIetState {
21 fn sample_quantile<R: Rng>(&self, rng: &mut R) -> f64 {
22 if self.cdf_values.is_empty() {
23 return 0.0;
24 }
25 let u: f64 = rng.random_range(f64::EPSILON..=1.0);
26 let mut idx = self.cdf_probabilities.len() - 1;
27 for (i, &p) in self.cdf_probabilities.iter().enumerate() {
28 if p >= u {
29 idx = i;
30 break;
31 }
32 }
33 self.cdf_values[idx]
34 }
35}
36
/// Evaluates a tabulated empirical CDF at `x`, interpolating linearly between
/// neighbouring knots.
///
/// `values` must be sorted ascending and `probabilities[i]` is the cumulative
/// probability at `values[i]`. Only the common prefix of the two slices is
/// used, so mismatched lengths never panic.
///
/// * Empty table: returns `0.0`.
/// * `x` below the first knot (or `x` is NaN): returns the first probability.
/// * `x` at or above the last knot: returns the last probability.
/// * `x` equal to a knot: returns that knot's probability. When several knots
///   share the same value, the probability of the *last* of them is returned,
///   i.e. the full P(X <= x). The choice is made with `partition_point`, so it
///   does not depend on which match a binary search happens to land on.
fn empirical_cdf_at(values: &[f64], probabilities: &[f64], x: f64) -> f64 {
    let n = values.len().min(probabilities.len());
    if n == 0 {
        return 0.0;
    }
    let (values, probabilities) = (&values[..n], &probabilities[..n]);

    // Number of knots that are <= x; NaN compares false and yields 0.
    let upper = values.partition_point(|&v| v <= x);
    if upper == 0 {
        return probabilities[0];
    }
    if upper == n {
        return probabilities[n - 1];
    }

    let (lo_v, hi_v) = (values[upper - 1], values[upper]);
    let (lo_p, hi_p) = (probabilities[upper - 1], probabilities[upper]);
    // Exact hit on a knot, or a degenerate segment from malformed input:
    // report the lower knot's probability instead of dividing by zero.
    if x == lo_v || hi_v == lo_v {
        lo_p
    } else {
        let t = (x - lo_v) / (hi_v - lo_v);
        lo_p + t * (hi_p - lo_p)
    }
}
64
/// Evaluates the quantile function (inverse CDF) of a tabulated empirical
/// distribution at probability `p`, interpolating linearly between knots.
///
/// `probabilities` must be sorted ascending and is paired index-for-index with
/// `values`. Only the common prefix of the two slices is used, so mismatched
/// lengths never panic. `p` is clamped to `[0, 1]` first.
///
/// * Empty table: returns `0.0`.
/// * `p` at or below the first probability (or `p` is NaN): returns the first
///   value.
/// * `p` above every tabulated probability: returns the last value.
/// * `p` equal to a tabulated probability: returns the value of the *first*
///   knot that reaches `p` (the generalized inverse), which is the same rule
///   `SourceIetState::sample_quantile` applies. The choice is made with
///   `partition_point`, so flat runs in `probabilities` resolve the same way
///   regardless of which match a binary search would land on.
fn quantile_at(values: &[f64], probabilities: &[f64], p: f64) -> f64 {
    let n = values.len().min(probabilities.len());
    if n == 0 {
        return 0.0;
    }
    let (values, probabilities) = (&values[..n], &probabilities[..n]);
    let p = p.clamp(0.0, 1.0);

    // Index of the first knot whose cumulative probability is >= p.
    let idx = probabilities.partition_point(|&q| q < p);
    if idx == 0 {
        return values[0];
    }
    if idx == n {
        return values[n - 1];
    }

    let (lo_p, hi_p) = (probabilities[idx - 1], probabilities[idx]);
    let (lo_v, hi_v) = (values[idx - 1], values[idx]);
    if hi_p == p {
        // Exact hit: return the knot itself rather than an interpolated
        // value that might differ by a rounding error.
        return hi_v;
    }
    if hi_p == lo_p {
        // Degenerate segment (only reachable with malformed input).
        lo_v
    } else {
        let t = (p - lo_p) / (hi_p - lo_p);
        lo_v + t * (hi_v - lo_v)
    }
}
94
/// Inverse of the standard normal CDF (the probit function).
///
/// Uses Peter J. Acklam's rational approximation: one rational function of
/// `(p - 0.5)^2` in the central region `[0.02425, 0.97575]` and one rational
/// function of `sqrt(-2 ln p)` in each tail. Its published relative error is
/// about 1.15e-9 over the whole open interval, and the pieces agree at the
/// split points to that accuracy.
///
/// `p` is clamped to `[1e-12, 1 - 1e-12]`, so `0.0` and `1.0` map to large
/// finite values (about ∓7.03) rather than infinities. A NaN input yields NaN.
// The coefficients are quoted digit-for-digit from the published algorithm.
#[allow(clippy::excessive_precision)]
fn inverse_standard_normal(p: f64) -> f64 {
    // Central-region numerator / denominator coefficients.
    const A: [f64; 6] = [
        -3.969683028665376e+01,
        2.209460984245205e+02,
        -2.759285104469687e+02,
        1.383577518672690e+02,
        -3.066479806614716e+01,
        2.506628277459239e+00,
    ];
    const B: [f64; 5] = [
        -5.447609879822406e+01,
        1.615858368580409e+02,
        -1.556989798598866e+02,
        6.680131188771972e+01,
        -1.328068155288572e+01,
    ];
    // Tail numerator / denominator coefficients.
    const C: [f64; 6] = [
        -7.784894002430293e-03,
        -3.223964580411365e-01,
        -2.400758277161838e+00,
        -2.549732539343734e+00,
        4.374664141464968e+00,
        2.938163982698783e+00,
    ];
    const D: [f64; 4] = [
        7.784695709041462e-03,
        3.224671290700398e-01,
        2.445134137142996e+00,
        3.754408661907416e+00,
    ];
    const P_LOW: f64 = 0.02425;
    const P_HIGH: f64 = 1.0 - P_LOW;

    /// Horner evaluation with the highest-order coefficient first.
    fn horner(coefficients: &[f64], x: f64) -> f64 {
        coefficients.iter().fold(0.0, |acc, &c| acc * x + c)
    }

    /// Lower-tail quantile (a negative number) for `q = sqrt(-2 ln p)`.
    fn lower_tail(q: f64) -> f64 {
        horner(&C, q) / (horner(&D, q) * q + 1.0)
    }

    let p = p.clamp(1e-12, 1.0 - 1e-12);

    if p < P_LOW {
        lower_tail((-2.0 * p.ln()).sqrt())
    } else if p <= P_HIGH {
        let q = p - 0.5;
        let r = q * q;
        horner(&A, r) * q / (horner(&B, r) * r + 1.0)
    } else {
        // The upper tail mirrors the lower one.
        -lower_tail((-2.0 * (1.0 - p).ln()).sqrt())
    }
}
142
143fn standard_normal_cdf(z: f64) -> f64 {
146 super::copula::standard_normal_cdf(z)
147}
148
149fn standard_normal_sample<R: Rng + ?Sized>(rng: &mut R) -> f64 {
151 let u1: f64 = rng.random_range(f64::EPSILON..=1.0);
152 let u2: f64 = rng.random_range(0.0..=1.0);
153 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
154}
155
156#[derive(Clone)]
160pub struct ConditionalIETSampler {
161 per_source: HashMap<String, SourceIetState>,
162 fallback: SourceIetState,
163}
164
165impl ConditionalIETSampler {
166 pub fn from_state_map(
167 per_source: HashMap<String, SourceIetState>,
168 fallback: SourceIetState,
169 ) -> Self {
170 Self {
171 per_source,
172 fallback,
173 }
174 }
175
176 pub fn sample_next<R: Rng>(&mut self, source: &str, rng: &mut R) -> f64 {
177 let state = self
178 .per_source
179 .get_mut(source)
180 .unwrap_or(&mut self.fallback);
181 if state.cdf_values.is_empty() {
182 return 0.0;
183 }
184 let rho = state.lag1_autocorr.clamp(-1.0, 1.0);
185
186 if rho.abs() < 0.1 || state.last_iet_days.is_none() {
188 let s = state.sample_quantile(rng).max(0.0);
189 state.last_iet_days = Some(s);
190 return s;
191 }
192
193 let prev = state.last_iet_days.expect("checked above");
195 let p_prev = empirical_cdf_at(&state.cdf_values, &state.cdf_probabilities, prev);
196 let z_prev = inverse_standard_normal(p_prev);
197 let z_curr = rho * z_prev + (1.0 - rho * rho).sqrt() * standard_normal_sample(rng);
198 let p_curr = standard_normal_cdf(z_curr);
199 let curr = quantile_at(&state.cdf_values, &state.cdf_probabilities, p_curr).max(0.0);
200
201 state.last_iet_days = Some(curr);
202 curr
203 }
204
205 pub fn has_source(&self, source: &str) -> bool {
206 self.per_source.contains_key(source)
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    /// Builds a state whose knots carry equal probability mass
    /// (`probabilities[i] = (i + 1) / n`) and that has no previous sample.
    fn known_state(values: Vec<f64>, autocorr: f64) -> SourceIetState {
        let n = values.len();
        SourceIetState {
            cdf_values: values,
            cdf_probabilities: (1..=n).map(|i| i as f64 / n as f64).collect(),
            lag1_autocorr: autocorr,
            last_iet_days: None,
        }
    }

    /// With zero autocorrelation every draw is an exact knot value.
    #[test]
    fn iet_sampler_returns_known_values() {
        let mut per_source = HashMap::new();
        per_source.insert(
            "KR".to_string(),
            known_state(vec![1.0, 2.0, 5.0, 10.0], 0.0),
        );
        let mut sampler =
            ConditionalIETSampler::from_state_map(per_source, known_state(vec![0.5, 1.0], 0.0));
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        for _ in 0..30 {
            let s = sampler.sample_next("KR", &mut rng);
            assert!([1.0, 2.0, 5.0, 10.0].contains(&s), "unexpected sample {s}");
        }
    }

    /// A source with no entry is served from the fallback state.
    #[test]
    fn iet_sampler_falls_back_on_unknown_source() {
        let per_source = HashMap::new();
        let mut sampler =
            ConditionalIETSampler::from_state_map(per_source, known_state(vec![7.0], 0.0));
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        assert!((sampler.sample_next("UNKNOWN", &mut rng) - 7.0).abs() < 1e-9);
    }

    /// Coupled draws stay finite and inside the range spanned by the knots:
    /// the first draw is a knot, and every later one is an interpolated
    /// quantile between the first and last knot.
    #[test]
    fn iet_sampler_autocorr_couples_samples() {
        let mut per_source = HashMap::new();
        per_source.insert("A".to_string(), known_state(vec![1.0, 10.0], 0.9));
        let mut sampler =
            ConditionalIETSampler::from_state_map(per_source, known_state(vec![5.0], 0.0));
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        for _ in 0..50 {
            let s = sampler.sample_next("A", &mut rng);
            assert!(s.is_finite(), "non-finite sample {s}");
            assert!((1.0..=10.0).contains(&s), "sample {s} outside knot range");
        }
    }

    #[test]
    fn empirical_cdf_at_interpolates_linearly() {
        let v = [1.0, 2.0, 4.0];
        let p = [0.25, 0.5, 1.0];
        // Exact knots.
        assert!((empirical_cdf_at(&v, &p, 1.0) - 0.25).abs() < 1e-9);
        assert!((empirical_cdf_at(&v, &p, 2.0) - 0.5).abs() < 1e-9);
        // Midway between two knots.
        assert!((empirical_cdf_at(&v, &p, 3.0) - 0.75).abs() < 1e-9);
        // Outside the table: clamped to the first / last probability.
        assert!((empirical_cdf_at(&v, &p, 0.5) - 0.25).abs() < 1e-9);
        assert!((empirical_cdf_at(&v, &p, 5.0) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn quantile_at_inverts_empirical_cdf() {
        let v = [1.0, 2.0, 4.0];
        let p = [0.25, 0.5, 1.0];
        assert!((quantile_at(&v, &p, 0.25) - 1.0).abs() < 1e-9);
        assert!((quantile_at(&v, &p, 0.5) - 2.0).abs() < 1e-9);
        assert!((quantile_at(&v, &p, 1.0) - 4.0).abs() < 1e-9);
        assert!((quantile_at(&v, &p, 0.75) - 3.0).abs() < 1e-9);
    }

    #[test]
    fn inverse_standard_normal_known_values() {
        assert!((inverse_standard_normal(0.5)).abs() < 1e-6);
        assert!((inverse_standard_normal(0.975) - 1.96).abs() < 1e-2);
        assert!((inverse_standard_normal(0.025) + 1.96).abs() < 1e-2);
        assert!(
            inverse_standard_normal(0.01) < 0.0,
            "lower tail must be negative"
        );
        assert!(
            inverse_standard_normal(0.99) > 0.0,
            "upper tail must be positive"
        );
        // The two tails must mirror each other.
        assert!((inverse_standard_normal(0.99) + inverse_standard_normal(0.01)).abs() < 1e-3);
    }

    #[test]
    fn standard_normal_cdf_known_values() {
        assert!((standard_normal_cdf(0.0) - 0.5).abs() < 1e-6);
        assert!((standard_normal_cdf(1.96) - 0.975).abs() < 1e-3);
    }

    /// Even with the most extreme negative coupling the output is clamped to
    /// be non-negative.
    #[test]
    fn iet_sampler_never_returns_negative() {
        let mut per_source = HashMap::new();
        per_source.insert("X".to_string(), known_state(vec![0.0, 0.0, 0.0], -1.0));
        let mut sampler =
            ConditionalIETSampler::from_state_map(per_source, known_state(vec![0.0], 0.0));
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        for _ in 0..20 {
            assert!(sampler.sample_next("X", &mut rng) >= 0.0);
        }
    }

    /// The lag-1 Pearson correlation of a long coupled series should land
    /// near the configured ρ.
    ///
    /// The marginal is uniform on (0, 1] (knot i/n with probability i/n). For
    /// uniform marginals under a Gaussian copula the Pearson correlation is
    /// (6/π)·asin(ρ/2) ≈ 0.58 for ρ = 0.6, hence the generous tolerance.
    #[test]
    fn copula_coupling_preserves_target_rho() {
        let n = 1000usize;
        let grid: Vec<f64> = (1..=n).map(|i| i as f64 / n as f64).collect();
        let mut per_source = HashMap::new();
        per_source.insert(
            "A".to_string(),
            SourceIetState {
                cdf_values: grid.clone(),
                cdf_probabilities: grid,
                lag1_autocorr: 0.6,
                last_iet_days: None,
            },
        );
        let fallback = SourceIetState {
            cdf_values: vec![1.0],
            cdf_probabilities: vec![1.0],
            lag1_autocorr: 0.0,
            last_iet_days: None,
        };
        let mut sampler = ConditionalIETSampler::from_state_map(per_source, fallback);
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let n_samples = 5000;
        let mut series: Vec<f64> = Vec::with_capacity(n_samples);
        for _ in 0..n_samples {
            series.push(sampler.sample_next("A", &mut rng));
        }

        // Pearson correlation between the series and itself shifted by one.
        let pairs = (n_samples - 1) as f64;
        let mean_pre: f64 = series[..n_samples - 1].iter().sum::<f64>() / pairs;
        let mean_post: f64 = series[1..].iter().sum::<f64>() / pairs;
        let mut num = 0.0;
        let mut dp = 0.0;
        let mut dq = 0.0;
        for pair in series.windows(2) {
            let a = pair[0] - mean_pre;
            let b = pair[1] - mean_post;
            num += a * b;
            dp += a * a;
            dq += b * b;
        }
        let empirical_rho = num / (dp.sqrt() * dq.sqrt());
        assert!(
            (empirical_rho - 0.6).abs() < 0.08,
            "expected empirical ρ ≈ 0.6, got {empirical_rho}"
        );
    }

    /// Below the coupling threshold every draw — including those made after a
    /// previous sample exists — must come from the independent path.
    ///
    /// The knots are distinct on purpose: the independent path returns exact
    /// knot values, whereas the coupled path interpolates between knots, so a
    /// non-knot sample would reveal that the coupled path was taken.
    #[test]
    fn copula_coupling_low_rho_uses_independent_path() {
        let knots = [1.0, 2.0, 5.0, 10.0];
        let mut per_source = HashMap::new();
        per_source.insert("A".to_string(), known_state(knots.to_vec(), 0.05));
        let mut sampler =
            ConditionalIETSampler::from_state_map(per_source, known_state(vec![5.0], 0.0));
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        for _ in 0..200 {
            let s = sampler.sample_next("A", &mut rng);
            assert!(knots.contains(&s), "sample {s} is not a knot value");
        }
    }
}