1use super::copula::{
10 cholesky_decompose, standard_normal_cdf, standard_normal_quantile, CopulaType,
11};
12use rand::prelude::*;
13use rand_chacha::ChaCha8Rng;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct CorrelatedField {
20 pub name: String,
22 pub distribution: MarginalDistribution,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28#[serde(rename_all = "snake_case", tag = "type")]
29pub enum MarginalDistribution {
30 Normal { mu: f64, sigma: f64 },
32 LogNormal { mu: f64, sigma: f64 },
34 Uniform { a: f64, b: f64 },
36 DiscreteUniform { min: i32, max: i32 },
38 Custom { quantiles: Vec<f64> },
40}
41
42impl Default for MarginalDistribution {
43 fn default() -> Self {
44 Self::Normal {
45 mu: 0.0,
46 sigma: 1.0,
47 }
48 }
49}
50
51impl MarginalDistribution {
52 pub fn inverse_cdf(&self, u: f64) -> f64 {
54 match self {
55 Self::Normal { mu, sigma } => mu + sigma * standard_normal_quantile(u),
56 Self::LogNormal { mu, sigma } => {
57 let z = standard_normal_quantile(u);
58 (mu + sigma * z).exp()
59 }
60 Self::Uniform { a, b } => a + u * (b - a),
61 Self::DiscreteUniform { min, max } => {
62 let range = (*max - *min + 1) as f64;
63 (*min as f64 + (u * range).floor()).min(*max as f64)
64 }
65 Self::Custom { quantiles } => {
66 if quantiles.is_empty() {
67 return 0.0;
68 }
69 let n = quantiles.len();
71 let idx = u * (n - 1) as f64;
72 let low_idx = idx.floor() as usize;
73 let high_idx = (low_idx + 1).min(n - 1);
74 let frac = idx - low_idx as f64;
75 quantiles[low_idx] * (1.0 - frac) + quantiles[high_idx] * frac
76 }
77 }
78 }
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct CorrelationConfig {
84 pub fields: Vec<CorrelatedField>,
86 pub matrix: Vec<f64>,
89 #[serde(default)]
91 pub copula_type: CopulaType,
92}
93
94impl Default for CorrelationConfig {
95 fn default() -> Self {
96 Self {
97 fields: vec![],
98 matrix: vec![],
99 copula_type: CopulaType::Gaussian,
100 }
101 }
102}
103
104impl CorrelationConfig {
105 pub fn new(fields: Vec<CorrelatedField>, matrix: Vec<f64>) -> Self {
107 Self {
108 fields,
109 matrix,
110 copula_type: CopulaType::Gaussian,
111 }
112 }
113
114 pub fn bivariate(field1: CorrelatedField, field2: CorrelatedField, correlation: f64) -> Self {
116 Self {
117 fields: vec![field1, field2],
118 matrix: vec![correlation],
119 copula_type: CopulaType::Gaussian,
120 }
121 }
122
123 pub fn validate(&self) -> Result<(), String> {
125 let n = self.fields.len();
126 if n < 2 {
127 return Err("At least 2 fields are required for correlation".to_string());
128 }
129
130 let expected_matrix_size = n * (n - 1) / 2;
131 if self.matrix.len() != expected_matrix_size {
132 return Err(format!(
133 "Expected {} correlation values for {} fields, got {}",
134 expected_matrix_size,
135 n,
136 self.matrix.len()
137 ));
138 }
139
140 for (i, &corr) in self.matrix.iter().enumerate() {
142 if !(-1.0..=1.0).contains(&corr) {
143 return Err(format!(
144 "Correlation at index {} must be in [-1, 1], got {}",
145 i, corr
146 ));
147 }
148 }
149
150 let full_matrix = self.to_full_matrix();
152 if cholesky_decompose(&full_matrix).is_none() {
153 return Err(
154 "Correlation matrix is not positive semi-definite (invalid correlations)"
155 .to_string(),
156 );
157 }
158
159 Ok(())
160 }
161
162 pub fn to_full_matrix(&self) -> Vec<Vec<f64>> {
164 let n = self.fields.len();
165 let mut matrix = vec![vec![0.0; n]; n];
166
167 for (i, row) in matrix.iter_mut().enumerate() {
169 row[i] = 1.0;
170 }
171
172 #[allow(clippy::needless_range_loop)]
175 {
176 let mut idx = 0;
177 for i in 0..n {
178 for j in (i + 1)..n {
179 let val = self.matrix[idx];
180 matrix[i][j] = val;
181 matrix[j][i] = val;
182 idx += 1;
183 }
184 }
185 }
186
187 matrix
188 }
189
190 pub fn field_names(&self) -> Vec<&str> {
192 self.fields.iter().map(|f| f.name.as_str()).collect()
193 }
194}
195
196pub struct CorrelationEngine {
198 rng: ChaCha8Rng,
199 config: CorrelationConfig,
200 cholesky: Vec<Vec<f64>>,
202}
203
204impl CorrelationEngine {
205 pub fn new(seed: u64, config: CorrelationConfig) -> Result<Self, String> {
207 config.validate()?;
208
209 let full_matrix = config.to_full_matrix();
210 let cholesky = cholesky_decompose(&full_matrix)
211 .ok_or_else(|| "Failed to compute Cholesky decomposition".to_string())?;
212
213 Ok(Self {
214 rng: ChaCha8Rng::seed_from_u64(seed),
215 config,
216 cholesky,
217 })
218 }
219
220 pub fn sample(&mut self) -> HashMap<String, f64> {
222 let n = self.config.fields.len();
223
224 let z: Vec<f64> = (0..n).map(|_| self.sample_standard_normal()).collect();
226
227 let y: Vec<f64> = self
229 .cholesky
230 .iter()
231 .enumerate()
232 .map(|(i, row)| {
233 row.iter()
234 .take(i + 1)
235 .zip(z.iter())
236 .map(|(c, z)| c * z)
237 .sum()
238 })
239 .collect();
240
241 let u: Vec<f64> = y.iter().map(|&yi| standard_normal_cdf(yi)).collect();
243
244 let mut result = HashMap::new();
246 for (i, field) in self.config.fields.iter().enumerate() {
247 let value = field.distribution.inverse_cdf(u[i]);
248 result.insert(field.name.clone(), value);
249 }
250
251 result
252 }
253
254 pub fn sample_vec(&mut self) -> Vec<f64> {
256 let n = self.config.fields.len();
257
258 let z: Vec<f64> = (0..n).map(|_| self.sample_standard_normal()).collect();
260
261 let y: Vec<f64> = self
263 .cholesky
264 .iter()
265 .enumerate()
266 .map(|(i, row)| {
267 row.iter()
268 .take(i + 1)
269 .zip(z.iter())
270 .map(|(c, z)| c * z)
271 .sum()
272 })
273 .collect();
274
275 let u: Vec<f64> = y.iter().map(|&yi| standard_normal_cdf(yi)).collect();
277
278 self.config
280 .fields
281 .iter()
282 .enumerate()
283 .map(|(i, field)| field.distribution.inverse_cdf(u[i]))
284 .collect()
285 }
286
287 pub fn sample_field(&mut self, name: &str) -> Option<f64> {
289 let sample = self.sample();
290 sample.get(name).copied()
291 }
292
293 pub fn sample_n(&mut self, n: usize) -> Vec<HashMap<String, f64>> {
295 (0..n).map(|_| self.sample()).collect()
296 }
297
298 fn sample_standard_normal(&mut self) -> f64 {
300 let u1: f64 = self.rng.gen();
301 let u2: f64 = self.rng.gen();
302 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
303 }
304
305 pub fn reset(&mut self, seed: u64) {
307 self.rng = ChaCha8Rng::seed_from_u64(seed);
308 }
309
310 pub fn config(&self) -> &CorrelationConfig {
312 &self.config
313 }
314}
315
316pub mod correlation_presets {
318 use super::*;
319
320 pub fn amount_line_items() -> CorrelationConfig {
323 CorrelationConfig::bivariate(
324 CorrelatedField {
325 name: "amount".to_string(),
326 distribution: MarginalDistribution::LogNormal {
327 mu: 7.0,
328 sigma: 2.0,
329 },
330 },
331 CorrelatedField {
332 name: "line_items".to_string(),
333 distribution: MarginalDistribution::DiscreteUniform { min: 2, max: 20 },
334 },
335 0.65,
336 )
337 }
338
339 pub fn amount_approval_level() -> CorrelationConfig {
342 CorrelationConfig::bivariate(
343 CorrelatedField {
344 name: "amount".to_string(),
345 distribution: MarginalDistribution::LogNormal {
346 mu: 8.0,
347 sigma: 2.5,
348 },
349 },
350 CorrelatedField {
351 name: "approval_level".to_string(),
352 distribution: MarginalDistribution::DiscreteUniform { min: 1, max: 5 },
353 },
354 0.72,
355 )
356 }
357
358 pub fn order_processing_time() -> CorrelationConfig {
361 CorrelationConfig::bivariate(
362 CorrelatedField {
363 name: "order_value".to_string(),
364 distribution: MarginalDistribution::LogNormal {
365 mu: 7.5,
366 sigma: 1.5,
367 },
368 },
369 CorrelatedField {
370 name: "processing_days".to_string(),
371 distribution: MarginalDistribution::LogNormal {
372 mu: 1.5,
373 sigma: 0.8,
374 },
375 },
376 0.35,
377 )
378 }
379
380 pub fn transaction_attributes() -> CorrelationConfig {
382 CorrelationConfig {
383 fields: vec![
384 CorrelatedField {
385 name: "amount".to_string(),
386 distribution: MarginalDistribution::LogNormal {
387 mu: 7.0,
388 sigma: 2.0,
389 },
390 },
391 CorrelatedField {
392 name: "line_items".to_string(),
393 distribution: MarginalDistribution::DiscreteUniform { min: 2, max: 15 },
394 },
395 CorrelatedField {
396 name: "approval_level".to_string(),
397 distribution: MarginalDistribution::DiscreteUniform { min: 1, max: 4 },
398 },
399 ],
400 matrix: vec![0.65, 0.72, 0.55],
405 copula_type: CopulaType::Gaussian,
406 }
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a field with a standard normal marginal.
    fn unit_normal(name: &str) -> CorrelatedField {
        CorrelatedField {
            name: name.to_string(),
            distribution: MarginalDistribution::Normal {
                mu: 0.0,
                sigma: 1.0,
            },
        }
    }

    #[test]
    fn test_correlation_config_validation() {
        let valid = CorrelationConfig::bivariate(unit_normal("x"), unit_normal("y"), 0.5);
        assert!(valid.validate().is_ok());

        // A coefficient outside [-1, 1] must be rejected.
        let invalid_corr = CorrelationConfig::bivariate(unit_normal("x"), unit_normal("y"), 1.5);
        assert!(invalid_corr.validate().is_err());
    }

    #[test]
    fn test_full_matrix_conversion() {
        let fields = ["a", "b", "c"]
            .iter()
            .map(|name| CorrelatedField {
                name: name.to_string(),
                distribution: MarginalDistribution::default(),
            })
            .collect();
        let config = CorrelationConfig {
            fields,
            matrix: vec![0.5, 0.3, 0.4],
            copula_type: CopulaType::Gaussian,
        };

        let full = config.to_full_matrix();

        // Unit diagonal.
        for (i, row) in full.iter().enumerate() {
            assert_eq!(row[i], 1.0);
        }

        // Symmetry, then the expected off-diagonal values.
        let expected = [(0usize, 1usize, 0.5), (0, 2, 0.3), (1, 2, 0.4)];
        for &(i, j, _) in expected.iter() {
            assert_eq!(full[i][j], full[j][i]);
        }
        for &(i, j, value) in expected.iter() {
            assert_eq!(full[i][j], value);
        }
    }

    #[test]
    fn test_correlation_engine_sampling() {
        let config = correlation_presets::amount_line_items();
        let mut engine = CorrelationEngine::new(42, config).unwrap();

        let samples = engine.sample_n(2000);
        assert_eq!(samples.len(), 2000);
        let count = samples.len() as f64;

        let amounts: Vec<f64> = samples.iter().map(|s| s["amount"]).collect();
        let line_items: Vec<f64> = samples.iter().map(|s| s["line_items"]).collect();

        // Log-normal values are strictly positive.
        assert!(amounts.iter().all(|&a| a > 0.0));

        // Discrete uniform values stay inside the configured range.
        assert!(line_items.iter().all(|&l| (2.0..=20.0).contains(&l)));

        let mean_a = amounts.iter().sum::<f64>() / count;
        let mean_l = line_items.iter().sum::<f64>() / count;

        // Accumulate the co-moment and both second moments in one pass.
        let (cov, var_a, var_l) = amounts.iter().zip(line_items.iter()).fold(
            (0.0, 0.0, 0.0),
            |(cov, var_a, var_l), (a, l)| {
                let da = a - mean_a;
                let dl = l - mean_l;
                (cov + da * dl, var_a + da * da, var_l + dl * dl)
            },
        );

        let correlation = if var_a > 0.0 && var_l > 0.0 {
            cov / (var_a.sqrt() * var_l.sqrt())
        } else {
            0.0
        };

        // Deliberately loose bound on the sample (Pearson) correlation.
        assert!(
            correlation > -0.5,
            "Correlation {} is unexpectedly strongly negative",
            correlation
        );
    }

    #[test]
    fn test_correlation_engine_determinism() {
        let config = correlation_presets::amount_line_items();

        let mut first = CorrelationEngine::new(42, config.clone()).unwrap();
        let mut second = CorrelationEngine::new(42, config).unwrap();

        // Equal seeds must give identical sample sequences.
        for _ in 0..100 {
            let s1 = first.sample();
            let s2 = second.sample();
            assert_eq!(s1["amount"], s2["amount"]);
            assert_eq!(s1["line_items"], s2["line_items"]);
        }
    }

    #[test]
    fn test_marginal_inverse_cdf() {
        // Median of a normal is its mean.
        let normal = MarginalDistribution::Normal {
            mu: 10.0,
            sigma: 2.0,
        };
        let normal_median = normal.inverse_cdf(0.5);
        assert!((normal_median - 10.0).abs() < 0.1);

        // Log-normal quantiles are positive.
        let lognormal = MarginalDistribution::LogNormal {
            mu: 2.0,
            sigma: 0.5,
        };
        assert!(lognormal.inverse_cdf(0.5) > 0.0);

        // Median of a uniform is the midpoint.
        let uniform = MarginalDistribution::Uniform { a: 0.0, b: 100.0 };
        let uniform_median = uniform.inverse_cdf(0.5);
        assert!((uniform_median - 50.0).abs() < 0.1);

        // Discrete uniform quantiles stay inside the range.
        let discrete = MarginalDistribution::DiscreteUniform { min: 1, max: 10 };
        let value = discrete.inverse_cdf(0.5);
        assert!((1.0..=10.0).contains(&value));
    }

    #[test]
    fn test_multi_field_correlation() {
        let config = correlation_presets::transaction_attributes();
        assert!(config.validate().is_ok());

        let mut engine = CorrelationEngine::new(42, config).unwrap();
        let sample = engine.sample();

        for key in ["amount", "line_items", "approval_level"].iter() {
            assert!(sample.contains_key(*key));
        }
    }

    #[test]
    fn test_sample_vec() {
        let config = correlation_presets::amount_line_items();
        let mut engine = CorrelationEngine::new(42, config).unwrap();

        let values = engine.sample_vec();
        assert_eq!(values.len(), 2);

        // First entry is the log-normal amount.
        assert!(values[0] > 0.0);

        // Second entry is the discrete line-item count.
        assert!(values[1] >= 2.0 && values[1] <= 20.0);
    }
}