1use super::copula::{
10 cholesky_decompose, standard_normal_cdf, standard_normal_quantile, CopulaType,
11};
12use rand::prelude::*;
13use rand_chacha::ChaCha8Rng;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16
/// A named field paired with the marginal distribution its sampled values follow.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelatedField {
    /// Field name; used as the key of this field's value in the map returned by
    /// `CorrelationEngine::sample`.
    pub name: String,
    /// Marginal distribution whose inverse CDF turns a correlated uniform into a value.
    pub distribution: MarginalDistribution,
}
25
/// Marginal distribution of a single correlated field.
///
/// Serialized by serde as an internally tagged enum: the variant is carried in a
/// `type` key and variant names are rendered in `snake_case`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum MarginalDistribution {
    /// Normal distribution; the quantile is `mu + sigma * z` for a standard normal `z`.
    Normal { mu: f64, sigma: f64 },
    /// Log-normal distribution; the quantile is `exp(mu + sigma * z)`, so `mu` and
    /// `sigma` describe the underlying normal (log scale).
    LogNormal { mu: f64, sigma: f64 },
    /// Continuous uniform distribution; the quantile is `a + u * (b - a)`.
    Uniform { a: f64, b: f64 },
    /// Uniform distribution over the integers `min..=max`, returned as `f64`.
    DiscreteUniform { min: i32, max: i32 },
    /// Tabulated quantile function: the entries are treated as equally spaced over
    /// `u` in `[0, 1]` and linearly interpolated.
    Custom { quantiles: Vec<f64> },
}
41
42impl Default for MarginalDistribution {
43 fn default() -> Self {
44 Self::Normal {
45 mu: 0.0,
46 sigma: 1.0,
47 }
48 }
49}
50
51impl MarginalDistribution {
52 pub fn inverse_cdf(&self, u: f64) -> f64 {
54 match self {
55 Self::Normal { mu, sigma } => mu + sigma * standard_normal_quantile(u),
56 Self::LogNormal { mu, sigma } => {
57 let z = standard_normal_quantile(u);
58 (mu + sigma * z).exp()
59 }
60 Self::Uniform { a, b } => a + u * (b - a),
61 Self::DiscreteUniform { min, max } => {
62 let range = (*max - *min + 1) as f64;
63 (*min as f64 + (u * range).floor()).min(*max as f64)
64 }
65 Self::Custom { quantiles } => {
66 if quantiles.is_empty() {
67 return 0.0;
68 }
69 let n = quantiles.len();
71 let idx = u * (n - 1) as f64;
72 let low_idx = idx.floor() as usize;
73 let high_idx = (low_idx + 1).min(n - 1);
74 let frac = idx - low_idx as f64;
75 quantiles[low_idx] * (1.0 - frac) + quantiles[high_idx] * frac
76 }
77 }
78 }
79}
80
/// Describes a set of fields and the pairwise correlations between them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationConfig {
    /// Fields to sample jointly; their order defines the row/column order of the
    /// correlation matrix.
    pub fields: Vec<CorrelatedField>,
    /// Strict upper triangle of the correlation matrix in row-major order, i.e.
    /// `n * (n - 1) / 2` values for `n` fields (for three fields: `[r01, r02, r12]`).
    pub matrix: Vec<f64>,
    /// Copula family. When missing from serialized input, `CopulaType::default()`
    /// is used.
    #[serde(default)]
    pub copula_type: CopulaType,
}
93
94impl Default for CorrelationConfig {
95 fn default() -> Self {
96 Self {
97 fields: vec![],
98 matrix: vec![],
99 copula_type: CopulaType::Gaussian,
100 }
101 }
102}
103
104impl CorrelationConfig {
105 pub fn new(fields: Vec<CorrelatedField>, matrix: Vec<f64>) -> Self {
107 Self {
108 fields,
109 matrix,
110 copula_type: CopulaType::Gaussian,
111 }
112 }
113
114 pub fn bivariate(field1: CorrelatedField, field2: CorrelatedField, correlation: f64) -> Self {
116 Self {
117 fields: vec![field1, field2],
118 matrix: vec![correlation],
119 copula_type: CopulaType::Gaussian,
120 }
121 }
122
123 pub fn validate(&self) -> Result<(), String> {
125 let n = self.fields.len();
126 if n < 2 {
127 return Err("At least 2 fields are required for correlation".to_string());
128 }
129
130 let expected_matrix_size = n * (n - 1) / 2;
131 if self.matrix.len() != expected_matrix_size {
132 return Err(format!(
133 "Expected {} correlation values for {} fields, got {}",
134 expected_matrix_size,
135 n,
136 self.matrix.len()
137 ));
138 }
139
140 for (i, &corr) in self.matrix.iter().enumerate() {
142 if !(-1.0..=1.0).contains(&corr) {
143 return Err(format!(
144 "Correlation at index {i} must be in [-1, 1], got {corr}"
145 ));
146 }
147 }
148
149 let full_matrix = self.to_full_matrix();
151 if cholesky_decompose(&full_matrix).is_none() {
152 return Err(
153 "Correlation matrix is not positive semi-definite (invalid correlations)"
154 .to_string(),
155 );
156 }
157
158 Ok(())
159 }
160
161 pub fn to_full_matrix(&self) -> Vec<Vec<f64>> {
163 let n = self.fields.len();
164 let mut matrix = vec![vec![0.0; n]; n];
165
166 for (i, row) in matrix.iter_mut().enumerate() {
168 row[i] = 1.0;
169 }
170
171 #[allow(clippy::needless_range_loop)]
174 {
175 let mut idx = 0;
176 for i in 0..n {
177 for j in (i + 1)..n {
178 let val = self.matrix[idx];
179 matrix[i][j] = val;
180 matrix[j][i] = val;
181 idx += 1;
182 }
183 }
184 }
185
186 matrix
187 }
188
189 pub fn field_names(&self) -> Vec<&str> {
191 self.fields.iter().map(|f| f.name.as_str()).collect()
192 }
193}
194
/// Seeded sampler that draws jointly correlated values for the fields of a
/// `CorrelationConfig`.
pub struct CorrelationEngine {
    /// Random number generator, seeded in `new` and re-seeded by `reset`.
    rng: ChaCha8Rng,
    /// Configuration the engine was built from.
    config: CorrelationConfig,
    /// Result of `cholesky_decompose` on the full correlation matrix, computed
    /// once in `new`.
    cholesky: Vec<Vec<f64>>,
}
202
203impl CorrelationEngine {
204 pub fn new(seed: u64, config: CorrelationConfig) -> Result<Self, String> {
206 config.validate()?;
207
208 let full_matrix = config.to_full_matrix();
209 let cholesky = cholesky_decompose(&full_matrix)
210 .ok_or_else(|| "Failed to compute Cholesky decomposition".to_string())?;
211
212 Ok(Self {
213 rng: ChaCha8Rng::seed_from_u64(seed),
214 config,
215 cholesky,
216 })
217 }
218
219 pub fn sample(&mut self) -> HashMap<String, f64> {
221 let n = self.config.fields.len();
222
223 let z: Vec<f64> = (0..n).map(|_| self.sample_standard_normal()).collect();
225
226 let y: Vec<f64> = self
228 .cholesky
229 .iter()
230 .enumerate()
231 .map(|(i, row)| {
232 row.iter()
233 .take(i + 1)
234 .zip(z.iter())
235 .map(|(c, z)| c * z)
236 .sum()
237 })
238 .collect();
239
240 let u: Vec<f64> = y.iter().map(|&yi| standard_normal_cdf(yi)).collect();
242
243 let mut result = HashMap::new();
245 for (i, field) in self.config.fields.iter().enumerate() {
246 let value = field.distribution.inverse_cdf(u[i]);
247 result.insert(field.name.clone(), value);
248 }
249
250 result
251 }
252
253 pub fn sample_vec(&mut self) -> Vec<f64> {
255 let n = self.config.fields.len();
256
257 let z: Vec<f64> = (0..n).map(|_| self.sample_standard_normal()).collect();
259
260 let y: Vec<f64> = self
262 .cholesky
263 .iter()
264 .enumerate()
265 .map(|(i, row)| {
266 row.iter()
267 .take(i + 1)
268 .zip(z.iter())
269 .map(|(c, z)| c * z)
270 .sum()
271 })
272 .collect();
273
274 let u: Vec<f64> = y.iter().map(|&yi| standard_normal_cdf(yi)).collect();
276
277 self.config
279 .fields
280 .iter()
281 .enumerate()
282 .map(|(i, field)| field.distribution.inverse_cdf(u[i]))
283 .collect()
284 }
285
286 pub fn sample_field(&mut self, name: &str) -> Option<f64> {
288 let sample = self.sample();
289 sample.get(name).copied()
290 }
291
292 pub fn sample_n(&mut self, n: usize) -> Vec<HashMap<String, f64>> {
294 (0..n).map(|_| self.sample()).collect()
295 }
296
297 fn sample_standard_normal(&mut self) -> f64 {
299 let u1: f64 = self.rng.random();
300 let u2: f64 = self.rng.random();
301 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
302 }
303
304 pub fn reset(&mut self, seed: u64) {
306 self.rng = ChaCha8Rng::seed_from_u64(seed);
307 }
308
309 pub fn config(&self) -> &CorrelationConfig {
311 &self.config
312 }
313}
314
315pub mod correlation_presets {
317 use super::*;
318
319 pub fn amount_line_items() -> CorrelationConfig {
322 CorrelationConfig::bivariate(
323 CorrelatedField {
324 name: "amount".to_string(),
325 distribution: MarginalDistribution::LogNormal {
326 mu: 7.0,
327 sigma: 2.0,
328 },
329 },
330 CorrelatedField {
331 name: "line_items".to_string(),
332 distribution: MarginalDistribution::DiscreteUniform { min: 2, max: 20 },
333 },
334 0.65,
335 )
336 }
337
338 pub fn amount_approval_level() -> CorrelationConfig {
341 CorrelationConfig::bivariate(
342 CorrelatedField {
343 name: "amount".to_string(),
344 distribution: MarginalDistribution::LogNormal {
345 mu: 8.0,
346 sigma: 2.5,
347 },
348 },
349 CorrelatedField {
350 name: "approval_level".to_string(),
351 distribution: MarginalDistribution::DiscreteUniform { min: 1, max: 5 },
352 },
353 0.72,
354 )
355 }
356
357 pub fn order_processing_time() -> CorrelationConfig {
360 CorrelationConfig::bivariate(
361 CorrelatedField {
362 name: "order_value".to_string(),
363 distribution: MarginalDistribution::LogNormal {
364 mu: 7.5,
365 sigma: 1.5,
366 },
367 },
368 CorrelatedField {
369 name: "processing_days".to_string(),
370 distribution: MarginalDistribution::LogNormal {
371 mu: 1.5,
372 sigma: 0.8,
373 },
374 },
375 0.35,
376 )
377 }
378
379 pub fn transaction_attributes() -> CorrelationConfig {
381 CorrelationConfig {
382 fields: vec![
383 CorrelatedField {
384 name: "amount".to_string(),
385 distribution: MarginalDistribution::LogNormal {
386 mu: 7.0,
387 sigma: 2.0,
388 },
389 },
390 CorrelatedField {
391 name: "line_items".to_string(),
392 distribution: MarginalDistribution::DiscreteUniform { min: 2, max: 15 },
393 },
394 CorrelatedField {
395 name: "approval_level".to_string(),
396 distribution: MarginalDistribution::DiscreteUniform { min: 1, max: 4 },
397 },
398 ],
399 matrix: vec![0.65, 0.72, 0.55],
404 copula_type: CopulaType::Gaussian,
405 }
406 }
407}
408
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a field with an explicit standard normal marginal.
    fn unit_normal_field(name: &str) -> CorrelatedField {
        CorrelatedField {
            name: name.to_string(),
            distribution: MarginalDistribution::Normal {
                mu: 0.0,
                sigma: 1.0,
            },
        }
    }

    /// Builds a field that uses the default marginal distribution.
    fn default_field(name: &str) -> CorrelatedField {
        CorrelatedField {
            name: name.to_string(),
            distribution: MarginalDistribution::default(),
        }
    }

    /// Pearson correlation of two equally long series; 0.0 when either has no variance.
    fn pearson(xs: &[f64], ys: &[f64]) -> f64 {
        let n = xs.len() as f64;
        let mean_x = xs.iter().sum::<f64>() / n;
        let mean_y = ys.iter().sum::<f64>() / n;

        let (cov, var_x, var_y) = xs.iter().zip(ys.iter()).fold(
            (0.0_f64, 0.0_f64, 0.0_f64),
            |(cov, var_x, var_y), (x, y)| {
                let dx = x - mean_x;
                let dy = y - mean_y;
                (cov + dx * dy, var_x + dx * dx, var_y + dy * dy)
            },
        );

        if var_x > 0.0 && var_y > 0.0 {
            cov / (var_x.sqrt() * var_y.sqrt())
        } else {
            0.0
        }
    }

    #[test]
    fn test_correlation_config_validation() {
        let valid =
            CorrelationConfig::bivariate(unit_normal_field("x"), unit_normal_field("y"), 0.5);
        assert!(valid.validate().is_ok());

        // A correlation outside [-1, 1] must be rejected.
        let invalid_corr =
            CorrelationConfig::bivariate(unit_normal_field("x"), unit_normal_field("y"), 1.5);
        assert!(invalid_corr.validate().is_err());
    }

    #[test]
    fn test_full_matrix_conversion() {
        let config = CorrelationConfig {
            fields: ["a", "b", "c"].iter().map(|name| default_field(name)).collect(),
            matrix: vec![0.5, 0.3, 0.4],
            copula_type: CopulaType::Gaussian,
        };

        let full = config.to_full_matrix();

        // Unit diagonal.
        for (d, row) in full.iter().enumerate() {
            assert_eq!(row[d], 1.0);
        }

        // Symmetric, with the upper triangle filled in row-major order.
        for (r, c, expected) in [(0, 1, 0.5), (0, 2, 0.3), (1, 2, 0.4)] {
            assert_eq!(full[r][c], full[c][r]);
            assert_eq!(full[r][c], expected);
        }
    }

    #[test]
    fn test_correlation_engine_sampling() {
        let config = correlation_presets::amount_line_items();
        let mut engine = CorrelationEngine::new(42, config).unwrap();

        let samples = engine.sample_n(2000);
        assert_eq!(samples.len(), 2000);

        let column = |key: &str| -> Vec<f64> { samples.iter().map(|s| s[key]).collect() };
        let amounts = column("amount");
        let line_items = column("line_items");

        // Log-normal values are strictly positive.
        assert!(amounts.iter().all(|&a| a > 0.0));

        // Discrete uniform values stay inside the configured bounds.
        assert!(line_items.iter().all(|&l| (2.0..=20.0).contains(&l)));

        let correlation = pearson(&amounts, &line_items);

        assert!(
            correlation > -0.5,
            "Correlation {} is unexpectedly strongly negative",
            correlation
        );
    }

    #[test]
    fn test_correlation_engine_determinism() {
        let config = correlation_presets::amount_line_items();

        let mut engine1 = CorrelationEngine::new(42, config.clone()).unwrap();
        let mut engine2 = CorrelationEngine::new(42, config).unwrap();

        // Identical seeds must give identical sample streams.
        for _ in 0..100 {
            let (s1, s2) = (engine1.sample(), engine2.sample());
            for key in ["amount", "line_items"] {
                assert_eq!(s1[key], s2[key]);
            }
        }
    }

    #[test]
    fn test_marginal_inverse_cdf() {
        // The median of a normal is its mean.
        let normal = MarginalDistribution::Normal {
            mu: 10.0,
            sigma: 2.0,
        };
        let normal_median = normal.inverse_cdf(0.5);
        assert!((normal_median - 10.0).abs() < 0.1);

        let lognormal = MarginalDistribution::LogNormal {
            mu: 2.0,
            sigma: 0.5,
        };
        assert!(lognormal.inverse_cdf(0.5) > 0.0);

        // The median of a uniform is the midpoint.
        let uniform = MarginalDistribution::Uniform { a: 0.0, b: 100.0 };
        let uniform_median = uniform.inverse_cdf(0.5);
        assert!((uniform_median - 50.0).abs() < 0.1);

        let discrete = MarginalDistribution::DiscreteUniform { min: 1, max: 10 };
        let value = discrete.inverse_cdf(0.5);
        assert!((1.0..=10.0).contains(&value));
    }

    #[test]
    fn test_multi_field_correlation() {
        let config = correlation_presets::transaction_attributes();
        assert!(config.validate().is_ok());

        let mut engine = CorrelationEngine::new(42, config).unwrap();
        let sample = engine.sample();

        for key in ["amount", "line_items", "approval_level"] {
            assert!(sample.contains_key(key));
        }
    }

    #[test]
    fn test_sample_vec() {
        let config = correlation_presets::amount_line_items();
        let mut engine = CorrelationEngine::new(42, config).unwrap();

        let vec = engine.sample_vec();
        assert_eq!(vec.len(), 2);

        // First entry is the log-normal amount.
        assert!(vec[0] > 0.0);

        // Second entry is the discrete line-item count.
        assert!((2.0..=20.0).contains(&vec[1]));
    }
}