// datasynth_core/distributions/copula.rs
use rand::prelude::*;
use rand_chacha::ChaCha8Rng;
use serde::{Deserialize, Serialize};
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
19#[serde(rename_all = "snake_case")]
20pub enum CopulaType {
21 #[default]
23 Gaussian,
24 Clayton,
26 Gumbel,
28 Frank,
30 StudentT,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct CopulaConfig {
37 pub copula_type: CopulaType,
39 pub theta: f64,
45 #[serde(default = "default_df")]
47 pub degrees_of_freedom: f64,
48}
49
/// Serde fallback for `CopulaConfig::degrees_of_freedom`.
fn default_df() -> f64 {
    const DEFAULT_DEGREES_OF_FREEDOM: f64 = 4.0;
    DEFAULT_DEGREES_OF_FREEDOM
}
53
54impl Default for CopulaConfig {
55 fn default() -> Self {
56 Self {
57 copula_type: CopulaType::Gaussian,
58 theta: 0.5,
59 degrees_of_freedom: 4.0,
60 }
61 }
62}
63
64impl CopulaConfig {
65 pub fn gaussian(correlation: f64) -> Self {
67 Self {
68 copula_type: CopulaType::Gaussian,
69 theta: correlation.clamp(-0.999, 0.999),
70 degrees_of_freedom: 4.0,
71 }
72 }
73
74 pub fn clayton(theta: f64) -> Self {
76 Self {
77 copula_type: CopulaType::Clayton,
78 theta: theta.max(0.001),
79 degrees_of_freedom: 4.0,
80 }
81 }
82
83 pub fn gumbel(theta: f64) -> Self {
85 Self {
86 copula_type: CopulaType::Gumbel,
87 theta: theta.max(1.0),
88 degrees_of_freedom: 4.0,
89 }
90 }
91
92 pub fn frank(theta: f64) -> Self {
94 Self {
95 copula_type: CopulaType::Frank,
96 theta: if theta.abs() < 0.001 { 0.001 } else { theta },
97 degrees_of_freedom: 4.0,
98 }
99 }
100
101 pub fn student_t(correlation: f64, df: f64) -> Self {
103 Self {
104 copula_type: CopulaType::StudentT,
105 theta: correlation.clamp(-0.999, 0.999),
106 degrees_of_freedom: df.max(2.0),
107 }
108 }
109
110 pub fn validate(&self) -> Result<(), String> {
112 match self.copula_type {
113 CopulaType::Gaussian | CopulaType::StudentT => {
114 if self.theta < -1.0 || self.theta > 1.0 {
115 return Err(format!(
116 "Correlation must be in [-1, 1], got {}",
117 self.theta
118 ));
119 }
120 }
121 CopulaType::Clayton => {
122 if self.theta <= 0.0 {
123 return Err(format!("Clayton theta must be > 0, got {}", self.theta));
124 }
125 }
126 CopulaType::Gumbel => {
127 if self.theta < 1.0 {
128 return Err(format!("Gumbel theta must be >= 1, got {}", self.theta));
129 }
130 }
131 CopulaType::Frank => {
132 if self.theta.abs() < 0.0001 {
133 return Err("Frank theta must be non-zero".to_string());
134 }
135 }
136 }
137
138 if self.copula_type == CopulaType::StudentT && self.degrees_of_freedom <= 0.0 {
139 return Err("Degrees of freedom must be positive".to_string());
140 }
141
142 Ok(())
143 }
144
145 pub fn kendalls_tau(&self) -> f64 {
147 match self.copula_type {
148 CopulaType::Gaussian | CopulaType::StudentT => {
149 2.0 * self.theta.asin() / std::f64::consts::PI
151 }
152 CopulaType::Clayton => {
153 self.theta / (self.theta + 2.0)
155 }
156 CopulaType::Gumbel => {
157 1.0 - 1.0 / self.theta
159 }
160 CopulaType::Frank => {
161 let abs_theta = self.theta.abs();
164 if abs_theta < 10.0 {
165 1.0 - 4.0 / self.theta + 4.0 / self.theta.powi(2) * debye_1(abs_theta)
166 } else {
167 self.theta.signum() * (1.0 - 4.0 / abs_theta)
169 }
170 }
171 }
172 }
173
174 pub fn lower_tail_dependence(&self) -> f64 {
176 match self.copula_type {
177 CopulaType::Gaussian | CopulaType::Frank => 0.0,
178 CopulaType::Clayton => 2.0_f64.powf(-1.0 / self.theta),
179 CopulaType::Gumbel => 0.0,
180 CopulaType::StudentT => {
181 let nu = self.degrees_of_freedom;
184 let rho = self.theta;
185 let arg = ((nu + 1.0) * (1.0 - rho) / (1.0 + rho)).sqrt();
186 2.0 * student_t_cdf(-arg, nu + 1.0)
187 }
188 }
189 }
190
191 pub fn upper_tail_dependence(&self) -> f64 {
193 match self.copula_type {
194 CopulaType::Gaussian | CopulaType::Frank => 0.0,
195 CopulaType::Clayton => 0.0,
196 CopulaType::Gumbel => 2.0 - 2.0_f64.powf(1.0 / self.theta),
197 CopulaType::StudentT => self.lower_tail_dependence(), }
199 }
200}
201
202#[derive(Clone)]
204pub struct BivariateCopulaSampler {
205 rng: ChaCha8Rng,
206 config: CopulaConfig,
207}
208
209impl BivariateCopulaSampler {
210 pub fn new(seed: u64, config: CopulaConfig) -> Result<Self, String> {
212 config.validate()?;
213 Ok(Self {
214 rng: ChaCha8Rng::seed_from_u64(seed),
215 config,
216 })
217 }
218
219 pub fn sample(&mut self) -> (f64, f64) {
221 match self.config.copula_type {
222 CopulaType::Gaussian => self.sample_gaussian(),
223 CopulaType::Clayton => self.sample_clayton(),
224 CopulaType::Gumbel => self.sample_gumbel(),
225 CopulaType::Frank => self.sample_frank(),
226 CopulaType::StudentT => self.sample_student_t(),
227 }
228 }
229
230 fn sample_gaussian(&mut self) -> (f64, f64) {
232 let rho = self.config.theta;
233
234 let z1 = self.sample_standard_normal();
236 let z2 = self.sample_standard_normal();
237
238 let x1 = z1;
240 let x2 = rho * z1 + (1.0 - rho.powi(2)).sqrt() * z2;
241
242 (standard_normal_cdf(x1), standard_normal_cdf(x2))
244 }
245
246 fn sample_clayton(&mut self) -> (f64, f64) {
248 let theta = self.config.theta;
249
250 let u: f64 = self.rng.random();
252 let t: f64 = self.rng.random();
253
254 let v = (u.powf(-theta) * (t.powf(-theta / (theta + 1.0)) - 1.0) + 1.0).powf(-1.0 / theta);
256
257 (u, v.clamp(0.0, 1.0))
258 }
259
260 fn sample_gumbel(&mut self) -> (f64, f64) {
262 let theta = self.config.theta;
263
264 let u: f64 = self.rng.random();
268 let t: f64 = self.rng.random();
269
270 let s = sample_positive_stable(&mut self.rng, 1.0 / theta);
272 let e1 = sample_exponential(&mut self.rng, 1.0);
273 let e2 = sample_exponential(&mut self.rng, 1.0);
274
275 let v1 = (-e1 / s).exp().powf(1.0 / theta);
276 let v2 = (-e2 / s).exp().powf(1.0 / theta);
277
278 let c_u = v1 / (v1 + v2);
279 let c_v = v2 / (v1 + v2);
280
281 let u_out = (-((-u.ln()).powf(theta) + (-c_u.ln()).powf(theta)).powf(1.0 / theta)).exp();
283 let v_out = (-((-t.ln()).powf(theta) + (-c_v.ln()).powf(theta)).powf(1.0 / theta)).exp();
284
285 (u_out.clamp(0.0001, 0.9999), v_out.clamp(0.0001, 0.9999))
286 }
287
288 fn sample_frank(&mut self) -> (f64, f64) {
290 let theta = self.config.theta;
291
292 let u: f64 = self.rng.random();
293 let t: f64 = self.rng.random();
294
295 let v = -((1.0 - t)
297 / (t * (-theta).exp() + (1.0 - t) * (1.0 - u * (1.0 - (-theta).exp())).recip()))
298 .ln()
299 / theta;
300
301 (u, v.clamp(0.0, 1.0))
302 }
303
304 fn sample_student_t(&mut self) -> (f64, f64) {
306 let rho = self.config.theta;
307 let nu = self.config.degrees_of_freedom;
308
309 let chi2 = sample_chi_squared(&mut self.rng, nu);
311 let scale = (nu / chi2).sqrt();
312
313 let z1 = self.sample_standard_normal();
315 let z2 = self.sample_standard_normal();
316
317 let x1 = z1 * scale;
318 let x2 = (rho * z1 + (1.0 - rho.powi(2)).sqrt() * z2) * scale;
319
320 (student_t_cdf(x1, nu), student_t_cdf(x2, nu))
322 }
323
324 fn sample_standard_normal(&mut self) -> f64 {
326 let u1: f64 = self.rng.random();
327 let u2: f64 = self.rng.random();
328 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
329 }
330
331 pub fn sample_n(&mut self, n: usize) -> Vec<(f64, f64)> {
333 (0..n).map(|_| self.sample()).collect()
334 }
335
336 pub fn reset(&mut self, seed: u64) {
338 self.rng = ChaCha8Rng::seed_from_u64(seed);
339 }
340
341 pub fn config(&self) -> &CopulaConfig {
343 &self.config
344 }
345}
346
/// Lower-triangular Cholesky factor `L` of `matrix`, so that `L * Lᵀ ≈ matrix`.
///
/// Only the lower triangle of `matrix` is read. A diagonal pivot in
/// `(-0.001, 0]` is nudged up by `0.001` so that nearly singular correlation
/// matrices still factor.
///
/// Returns `None` when
/// - `matrix` is not square (any row length differs from the row count),
/// - a pivot is NaN, infinite, or still negative after the nudge, or
/// - a column has to be divided by a pivot smaller than `1e-10` in magnitude.
///
/// An empty matrix yields `Some` of an empty factor.
pub fn cholesky_decompose(matrix: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
    const JITTER: f64 = 0.001;
    const MIN_PIVOT: f64 = 1e-10;

    let n = matrix.len();
    // Indexing below assumes a square matrix; reject ragged input instead of panicking.
    if matrix.iter().any(|row| row.len() != n) {
        return None;
    }

    let mut l = vec![vec![0.0; n]; n];

    for i in 0..n {
        for j in 0..=i {
            let sum: f64 = l[i][..j]
                .iter()
                .zip(&l[j][..j])
                .map(|(a, b)| a * b)
                .sum();

            if i == j {
                let diag = matrix[i][i] - sum;
                let pivot_sq = if diag <= 0.0 { diag + JITTER } else { diag };
                // Taking the square root of a negative or NaN pivot would
                // silently fill the factor with NaN.
                if !pivot_sq.is_finite() || pivot_sq < 0.0 {
                    return None;
                }
                l[i][j] = pivot_sq.sqrt();
            } else {
                if l[j][j].abs() < MIN_PIVOT {
                    return None;
                }
                l[i][j] = (matrix[i][j] - sum) / l[j][j];
            }
        }
    }

    Some(l)
}
375
376pub fn standard_normal_cdf(x: f64) -> f64 {
378 0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
379}
380
381pub fn standard_normal_quantile(p: f64) -> f64 {
388 use statrs::distribution::{ContinuousCDF, Normal};
389 Normal::new(0.0, 1.0)
390 .expect("standard normal distribution")
391 .inverse_cdf(p.clamp(1e-12, 1.0 - 1e-12))
392}
393
/// Error function via a five-term rational polynomial approximation
/// (the coefficients are those of Abramowitz & Stegun formula 7.1.26).
///
/// The approximation is evaluated on `|x|` and the sign restored afterwards.
fn erf(x: f64) -> f64 {
    // Polynomial coefficients, highest order first, for Horner evaluation.
    const COEFFS: [f64; 5] = [
        1.061405429,
        -1.453152027,
        1.421413741,
        -0.284496736,
        0.254829592,
    ];
    const P: f64 = 0.3275911;

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    let horner = COEFFS[1..].iter().fold(COEFFS[0], |acc, &c| acc * t + c);
    let y = 1.0 - horner * t * (-magnitude * magnitude).exp();

    if x < 0.0 {
        -y
    } else {
        y
    }
}
411
412fn student_t_cdf(x: f64, df: f64) -> f64 {
414 if df > 30.0 {
416 return standard_normal_cdf(x);
417 }
418
419 let t2 = x * x;
421 let prob = 0.5 * incomplete_beta(df / 2.0, 0.5, df / (df + t2));
422
423 if x > 0.0 {
424 1.0 - prob
425 } else {
426 prob
427 }
428}
429
/// Regularized incomplete beta function `I_x(a, b)` for `a, b > 0`.
///
/// Returns 0 for `x <= 0` and 1 for `x >= 1`. In between, the continued
/// fraction is evaluated directly for `x < (a + 1) / (a + b + 2)` and through
/// the symmetry `I_x(a, b) = 1 - I_{1-x}(b, a)` otherwise, which keeps it in
/// the region where it converges quickly.
fn incomplete_beta(a: f64, b: f64, x: f64) -> f64 {
    if x <= 0.0 {
        return 0.0;
    }
    if x >= 1.0 {
        return 1.0;
    }

    // ln( x^a (1 - x)^b / B(a, b) ), assembled in log space to avoid overflow.
    let ln_beta = ln_gamma(a) + ln_gamma(b) - ln_gamma(a + b);
    let front = (a * x.ln() + b * (1.0 - x).ln() - ln_beta).exp();

    if x < (a + 1.0) / (a + b + 2.0) {
        front * beta_continued_fraction(a, b, x) / a
    } else {
        1.0 - front * beta_continued_fraction(b, a, 1.0 - x) / b
    }
}

/// Continued fraction of the incomplete beta function, evaluated with the
/// modified Lentz method.
fn beta_continued_fraction(a: f64, b: f64, x: f64) -> f64 {
    const TINY: f64 = 1e-30;
    const EPS: f64 = 1e-14;
    const MAX_ITER: usize = 300;

    // Lentz terms may legitimately be negative; only a value too close to
    // zero is replaced, and its sign is not forced.
    let guard = |v: f64| if v.abs() < TINY { TINY } else { v };

    let mut c = 1.0;
    let mut d = 1.0 / guard(1.0 - (a + b) * x / (a + 1.0));
    let mut h = d;

    for m in 1..=MAX_ITER {
        let m = m as f64;
        let two_m = 2.0 * m;

        // Even step of the recurrence.
        let even = m * (b - m) * x / ((a + two_m - 1.0) * (a + two_m));
        d = 1.0 / guard(1.0 + even * d);
        c = guard(1.0 + even / c);
        h *= d * c;

        // Odd step of the recurrence.
        let odd = -(a + m) * (a + b + m) * x / ((a + two_m) * (a + two_m + 1.0));
        d = 1.0 / guard(1.0 + odd * d);
        c = guard(1.0 + odd / c);
        let delta = d * c;
        h *= delta;

        if (delta - 1.0).abs() < EPS {
            break;
        }
    }

    h
}

/// Natural logarithm of the gamma function for `x > 0` (Lanczos approximation, g = 7).
///
/// Returns `f64::INFINITY` for `x <= 0` and for `x = +inf`.
fn ln_gamma(x: f64) -> f64 {
    if x <= 0.0 || x.is_infinite() {
        return f64::INFINITY;
    }

    const G: f64 = 7.0;
    const COEFFS: [f64; 9] = [
        0.999_999_999_999_809_93,
        676.520_368_121_885_1,
        -1_259.139_216_722_402_8,
        771.323_428_777_653_13,
        -176.615_029_162_140_59,
        12.507_343_278_686_905,
        -0.138_571_095_265_720_12,
        9.984_369_578_019_571_6e-6,
        1.505_632_735_149_311_6e-7,
    ];

    let pi = std::f64::consts::PI;

    if x < 0.5 {
        // Reflection: Gamma(x) * Gamma(1 - x) = pi / sin(pi x).
        return (pi / (pi * x).sin()).ln() - ln_gamma(1.0 - x);
    }

    let z = x - 1.0;
    let series = COEFFS[1..]
        .iter()
        .enumerate()
        .fold(COEFFS[0], |acc, (i, &c)| acc + c / (z + i as f64 + 1.0));
    let t = z + G + 0.5;

    0.5 * (2.0 * pi).ln() + (z + 0.5) * t.ln() - t + series.ln()
}
488
/// First-order Debye function `D1(x) = (1 / x) * integral_0^x t / (e^t - 1) dt`.
///
/// - `|x| < 0.01`: the series `1 - x/4 + x^2/36`.
/// - `|x| > 20`: the integral over `[0, inf)` is `pi^2 / 6`, corrected by the
///   leading tail term `(|x| + 1) e^{-|x|}`.
/// - otherwise: composite Simpson's rule on 200 panels.
///
/// Negative arguments use the identity `D1(-x) = D1(x) + x / 2`.
fn debye_1(x: f64) -> f64 {
    const PANELS: usize = 200; // must be even for Simpson's rule

    let ax = x.abs();

    let on_positive_axis = if ax < 0.01 {
        1.0 - ax / 4.0 + ax * ax / 36.0
    } else if ax > 20.0 {
        let zeta2 = std::f64::consts::PI.powi(2) / 6.0;
        (zeta2 - (ax + 1.0) * (-ax).exp()) / ax
    } else {
        // The integrand has a removable singularity at 0, where its limit is 1.
        let integrand = |t: f64| if t == 0.0 { 1.0 } else { t / t.exp_m1() };
        let h = ax / PANELS as f64;
        let interior: f64 = (1..PANELS)
            .map(|i| {
                let weight = if i % 2 == 1 { 4.0 } else { 2.0 };
                weight * integrand(i as f64 * h)
            })
            .sum();
        let integral = (integrand(0.0) + interior + integrand(ax)) * h / 3.0;
        integral / ax
    };

    if x < 0.0 {
        on_positive_axis + ax / 2.0
    } else {
        on_positive_axis
    }
}
507
508fn sample_exponential(rng: &mut ChaCha8Rng, lambda: f64) -> f64 {
510 let u: f64 = rng.random();
511 -u.ln() / lambda
512}
513
514fn sample_chi_squared(rng: &mut ChaCha8Rng, df: f64) -> f64 {
516 let n = df.floor() as usize;
517 let mut sum = 0.0;
518 for _ in 0..n {
519 let u1: f64 = rng.random();
520 let u2: f64 = rng.random();
521 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
522 sum += z * z;
523 }
524 sum
525}
526
527fn sample_positive_stable(rng: &mut ChaCha8Rng, alpha: f64) -> f64 {
529 if (alpha - 1.0).abs() < 0.001 {
530 return 1.0;
531 }
532
533 let u: f64 = rng.random::<f64>() * std::f64::consts::PI - std::f64::consts::PI / 2.0;
534 let e = sample_exponential(rng, 1.0);
535
536 let b = (std::f64::consts::PI * alpha / 2.0).tan();
537 let s = (1.0 + b * b).powf(1.0 / (2.0 * alpha));
538
539 let term1 = (alpha * u).sin();
540 let term2 = (u.cos()).powf(1.0 / alpha);
541 let term3 = ((1.0 - alpha) * u).cos() / e;
542
543 s * term1 / term2 * term3.powf((1.0 - alpha) / alpha)
544}
545
#[cfg(test)]
mod tests {
    use super::*;

    /// True when every sampled pair lies inside the closed unit square.
    fn all_in_unit_square(pairs: &[(f64, f64)]) -> bool {
        pairs
            .iter()
            .all(|&(u, v)| (0.0..=1.0).contains(&u) && (0.0..=1.0).contains(&v))
    }

    #[test]
    fn test_copula_validation() {
        // Constructors always produce admissible parameters.
        assert!(CopulaConfig::gaussian(0.5).validate().is_ok());
        assert!(CopulaConfig::clayton(2.0).validate().is_ok());
        assert!(CopulaConfig::gumbel(2.0).validate().is_ok());

        // Out-of-domain parameters written directly into the struct are rejected.
        let invalid_gaussian = CopulaConfig {
            theta: 1.5,
            ..CopulaConfig::gaussian(0.0)
        };
        assert!(invalid_gaussian.validate().is_err());

        let invalid_clayton = CopulaConfig {
            theta: -1.0,
            ..CopulaConfig::clayton(1.0)
        };
        assert!(invalid_clayton.validate().is_err());

        let invalid_gumbel = CopulaConfig {
            theta: 0.5,
            ..CopulaConfig::gumbel(1.0)
        };
        assert!(invalid_gumbel.validate().is_err());
    }

    #[test]
    fn test_gaussian_copula_sampling() {
        let mut sampler = BivariateCopulaSampler::new(42, CopulaConfig::gaussian(0.7)).unwrap();

        let samples = sampler.sample_n(1000);
        assert_eq!(samples.len(), 1000);
        assert!(all_in_unit_square(&samples));

        // A positive correlation must show up as positive sample covariance.
        let count = 1000.0;
        let (sum_u, sum_v) = samples
            .iter()
            .fold((0.0, 0.0), |(su, sv), &(u, v)| (su + u, sv + v));
        let (mean_u, mean_v) = (sum_u / count, sum_v / count);
        let covariance = samples
            .iter()
            .map(|&(u, v)| (u - mean_u) * (v - mean_v))
            .sum::<f64>()
            / count;

        assert!(covariance > 0.0);
    }

    #[test]
    fn test_copula_determinism() {
        let config = CopulaConfig::gaussian(0.5);

        let mut first = BivariateCopulaSampler::new(42, config.clone()).unwrap();
        let mut second = BivariateCopulaSampler::new(42, config).unwrap();

        // Same seed, same configuration: identical streams.
        assert_eq!(first.sample_n(100), second.sample_n(100));
    }

    #[test]
    fn test_kendalls_tau() {
        let gaussian_tau = CopulaConfig::gaussian(0.5).kendalls_tau();
        let expected = 2.0 * (0.5_f64).asin() / std::f64::consts::PI;
        assert!((gaussian_tau - expected).abs() < 0.001);

        // theta / (theta + 2) = 0.5 at theta = 2.
        let clayton_tau = CopulaConfig::clayton(2.0).kendalls_tau();
        assert!((clayton_tau - 0.5).abs() < 0.001);

        // 1 - 1 / theta = 0.5 at theta = 2.
        let gumbel_tau = CopulaConfig::gumbel(2.0).kendalls_tau();
        assert!((gumbel_tau - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_tail_dependence() {
        let gaussian = CopulaConfig::gaussian(0.7);
        assert_eq!(gaussian.lower_tail_dependence(), 0.0);
        assert_eq!(gaussian.upper_tail_dependence(), 0.0);

        let clayton = CopulaConfig::clayton(2.0);
        assert!(clayton.lower_tail_dependence() > 0.0);
        assert_eq!(clayton.upper_tail_dependence(), 0.0);

        let gumbel = CopulaConfig::gumbel(2.0);
        assert_eq!(gumbel.lower_tail_dependence(), 0.0);
        assert!(gumbel.upper_tail_dependence() > 0.0);
    }

    #[test]
    fn test_cholesky_decomposition() {
        let matrix = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
        let l = cholesky_decompose(&matrix).unwrap();

        // Rebuild the lower triangle of L * Lᵀ and compare with the input.
        let rebuilt = [
            (l[0][0] * l[0][0], 1.0),
            (l[0][0] * l[1][0], 0.5),
            (l[1][0] * l[1][0] + l[1][1] * l[1][1], 1.0),
        ];
        for (actual, wanted) in rebuilt {
            assert!((actual - wanted).abs() < 0.001);
        }
    }

    #[test]
    fn test_standard_normal_cdf() {
        assert!((standard_normal_cdf(0.0) - 0.5).abs() < 0.001);
        assert!(standard_normal_cdf(-3.0) < 0.01);
        assert!(standard_normal_cdf(3.0) > 0.99);
    }

    #[test]
    fn test_clayton_copula() {
        let mut sampler = BivariateCopulaSampler::new(42, CopulaConfig::clayton(2.0)).unwrap();
        assert!(all_in_unit_square(&sampler.sample_n(1000)));
    }

    #[test]
    fn test_frank_copula() {
        let mut sampler = BivariateCopulaSampler::new(42, CopulaConfig::frank(5.0)).unwrap();
        assert!(all_in_unit_square(&sampler.sample_n(1000)));
    }

    #[test]
    fn standard_normal_quantile_lower_tail_correct_sign() {
        let q = standard_normal_quantile(0.025);
        assert!(q < 0.0, "Φ⁻¹(0.025) should be negative, got {q}");
        assert!((q + 1.96).abs() < 0.01, "expected ≈ -1.96, got {q}");
    }

    #[test]
    fn standard_normal_quantile_upper_tail_correct_sign() {
        let q = standard_normal_quantile(0.975);
        assert!(q > 0.0, "Φ⁻¹(0.975) should be positive, got {q}");
        assert!((q - 1.96).abs() < 0.01, "expected ≈ +1.96, got {q}");
    }

    #[test]
    fn standard_normal_quantile_median_zero() {
        let median = standard_normal_quantile(0.5);
        assert!(median.abs() < 1e-9);
    }

    #[test]
    fn standard_normal_quantile_extreme_tails_bounded() {
        let q_lo = standard_normal_quantile(0.001);
        let q_hi = standard_normal_quantile(0.999);

        assert!(q_lo < -3.0 && q_lo > -3.2);
        assert!(q_hi > 3.0 && q_hi < 3.2);
        assert!(q_lo.is_sign_negative());
        assert!(q_hi.is_sign_positive());
    }
}