datasynth_core/distributions/
copula.rs1use rand::prelude::*;
14use rand_chacha::ChaCha8Rng;
15use serde::{Deserialize, Serialize};
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
19#[serde(rename_all = "snake_case")]
20pub enum CopulaType {
21 #[default]
23 Gaussian,
24 Clayton,
26 Gumbel,
28 Frank,
30 StudentT,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct CopulaConfig {
37 pub copula_type: CopulaType,
39 pub theta: f64,
45 #[serde(default = "default_df")]
47 pub degrees_of_freedom: f64,
48}
49
/// Serde fallback for `CopulaConfig::degrees_of_freedom`.
fn default_df() -> f64 {
    const DEFAULT_DEGREES_OF_FREEDOM: f64 = 4.0;
    DEFAULT_DEGREES_OF_FREEDOM
}
53
54impl Default for CopulaConfig {
55 fn default() -> Self {
56 Self {
57 copula_type: CopulaType::Gaussian,
58 theta: 0.5,
59 degrees_of_freedom: 4.0,
60 }
61 }
62}
63
64impl CopulaConfig {
65 pub fn gaussian(correlation: f64) -> Self {
67 Self {
68 copula_type: CopulaType::Gaussian,
69 theta: correlation.clamp(-0.999, 0.999),
70 degrees_of_freedom: 4.0,
71 }
72 }
73
74 pub fn clayton(theta: f64) -> Self {
76 Self {
77 copula_type: CopulaType::Clayton,
78 theta: theta.max(0.001),
79 degrees_of_freedom: 4.0,
80 }
81 }
82
83 pub fn gumbel(theta: f64) -> Self {
85 Self {
86 copula_type: CopulaType::Gumbel,
87 theta: theta.max(1.0),
88 degrees_of_freedom: 4.0,
89 }
90 }
91
92 pub fn frank(theta: f64) -> Self {
94 Self {
95 copula_type: CopulaType::Frank,
96 theta: if theta.abs() < 0.001 { 0.001 } else { theta },
97 degrees_of_freedom: 4.0,
98 }
99 }
100
101 pub fn student_t(correlation: f64, df: f64) -> Self {
103 Self {
104 copula_type: CopulaType::StudentT,
105 theta: correlation.clamp(-0.999, 0.999),
106 degrees_of_freedom: df.max(2.0),
107 }
108 }
109
110 pub fn validate(&self) -> Result<(), String> {
112 match self.copula_type {
113 CopulaType::Gaussian | CopulaType::StudentT => {
114 if self.theta < -1.0 || self.theta > 1.0 {
115 return Err(format!(
116 "Correlation must be in [-1, 1], got {}",
117 self.theta
118 ));
119 }
120 }
121 CopulaType::Clayton => {
122 if self.theta <= 0.0 {
123 return Err(format!("Clayton theta must be > 0, got {}", self.theta));
124 }
125 }
126 CopulaType::Gumbel => {
127 if self.theta < 1.0 {
128 return Err(format!("Gumbel theta must be >= 1, got {}", self.theta));
129 }
130 }
131 CopulaType::Frank => {
132 if self.theta.abs() < 0.0001 {
133 return Err("Frank theta must be non-zero".to_string());
134 }
135 }
136 }
137
138 if self.copula_type == CopulaType::StudentT && self.degrees_of_freedom <= 0.0 {
139 return Err("Degrees of freedom must be positive".to_string());
140 }
141
142 Ok(())
143 }
144
145 pub fn kendalls_tau(&self) -> f64 {
147 match self.copula_type {
148 CopulaType::Gaussian | CopulaType::StudentT => {
149 2.0 * self.theta.asin() / std::f64::consts::PI
151 }
152 CopulaType::Clayton => {
153 self.theta / (self.theta + 2.0)
155 }
156 CopulaType::Gumbel => {
157 1.0 - 1.0 / self.theta
159 }
160 CopulaType::Frank => {
161 let abs_theta = self.theta.abs();
164 if abs_theta < 10.0 {
165 1.0 - 4.0 / self.theta + 4.0 / self.theta.powi(2) * debye_1(abs_theta)
166 } else {
167 self.theta.signum() * (1.0 - 4.0 / abs_theta)
169 }
170 }
171 }
172 }
173
174 pub fn lower_tail_dependence(&self) -> f64 {
176 match self.copula_type {
177 CopulaType::Gaussian | CopulaType::Frank => 0.0,
178 CopulaType::Clayton => 2.0_f64.powf(-1.0 / self.theta),
179 CopulaType::Gumbel => 0.0,
180 CopulaType::StudentT => {
181 let nu = self.degrees_of_freedom;
184 let rho = self.theta;
185 let arg = ((nu + 1.0) * (1.0 - rho) / (1.0 + rho)).sqrt();
186 2.0 * student_t_cdf(-arg, nu + 1.0)
187 }
188 }
189 }
190
191 pub fn upper_tail_dependence(&self) -> f64 {
193 match self.copula_type {
194 CopulaType::Gaussian | CopulaType::Frank => 0.0,
195 CopulaType::Clayton => 0.0,
196 CopulaType::Gumbel => 2.0 - 2.0_f64.powf(1.0 / self.theta),
197 CopulaType::StudentT => self.lower_tail_dependence(), }
199 }
200}
201
202#[derive(Clone)]
204pub struct BivariateCopulaSampler {
205 rng: ChaCha8Rng,
206 config: CopulaConfig,
207}
208
209impl BivariateCopulaSampler {
210 pub fn new(seed: u64, config: CopulaConfig) -> Result<Self, String> {
212 config.validate()?;
213 Ok(Self {
214 rng: ChaCha8Rng::seed_from_u64(seed),
215 config,
216 })
217 }
218
219 pub fn sample(&mut self) -> (f64, f64) {
221 match self.config.copula_type {
222 CopulaType::Gaussian => self.sample_gaussian(),
223 CopulaType::Clayton => self.sample_clayton(),
224 CopulaType::Gumbel => self.sample_gumbel(),
225 CopulaType::Frank => self.sample_frank(),
226 CopulaType::StudentT => self.sample_student_t(),
227 }
228 }
229
230 fn sample_gaussian(&mut self) -> (f64, f64) {
232 let rho = self.config.theta;
233
234 let z1 = self.sample_standard_normal();
236 let z2 = self.sample_standard_normal();
237
238 let x1 = z1;
240 let x2 = rho * z1 + (1.0 - rho.powi(2)).sqrt() * z2;
241
242 (standard_normal_cdf(x1), standard_normal_cdf(x2))
244 }
245
246 fn sample_clayton(&mut self) -> (f64, f64) {
248 let theta = self.config.theta;
249
250 let u: f64 = self.rng.random();
252 let t: f64 = self.rng.random();
253
254 let v = (u.powf(-theta) * (t.powf(-theta / (theta + 1.0)) - 1.0) + 1.0).powf(-1.0 / theta);
256
257 (u, v.clamp(0.0, 1.0))
258 }
259
260 fn sample_gumbel(&mut self) -> (f64, f64) {
262 let theta = self.config.theta;
263
264 let u: f64 = self.rng.random();
268 let t: f64 = self.rng.random();
269
270 let s = sample_positive_stable(&mut self.rng, 1.0 / theta);
272 let e1 = sample_exponential(&mut self.rng, 1.0);
273 let e2 = sample_exponential(&mut self.rng, 1.0);
274
275 let v1 = (-e1 / s).exp().powf(1.0 / theta);
276 let v2 = (-e2 / s).exp().powf(1.0 / theta);
277
278 let c_u = v1 / (v1 + v2);
279 let c_v = v2 / (v1 + v2);
280
281 let u_out = (-((-u.ln()).powf(theta) + (-c_u.ln()).powf(theta)).powf(1.0 / theta)).exp();
283 let v_out = (-((-t.ln()).powf(theta) + (-c_v.ln()).powf(theta)).powf(1.0 / theta)).exp();
284
285 (u_out.clamp(0.0001, 0.9999), v_out.clamp(0.0001, 0.9999))
286 }
287
288 fn sample_frank(&mut self) -> (f64, f64) {
290 let theta = self.config.theta;
291
292 let u: f64 = self.rng.random();
293 let t: f64 = self.rng.random();
294
295 let v = -((1.0 - t)
297 / (t * (-theta).exp() + (1.0 - t) * (1.0 - u * (1.0 - (-theta).exp())).recip()))
298 .ln()
299 / theta;
300
301 (u, v.clamp(0.0, 1.0))
302 }
303
304 fn sample_student_t(&mut self) -> (f64, f64) {
306 let rho = self.config.theta;
307 let nu = self.config.degrees_of_freedom;
308
309 let chi2 = sample_chi_squared(&mut self.rng, nu);
311 let scale = (nu / chi2).sqrt();
312
313 let z1 = self.sample_standard_normal();
315 let z2 = self.sample_standard_normal();
316
317 let x1 = z1 * scale;
318 let x2 = (rho * z1 + (1.0 - rho.powi(2)).sqrt() * z2) * scale;
319
320 (student_t_cdf(x1, nu), student_t_cdf(x2, nu))
322 }
323
324 fn sample_standard_normal(&mut self) -> f64 {
326 let u1: f64 = self.rng.random();
327 let u2: f64 = self.rng.random();
328 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
329 }
330
331 pub fn sample_n(&mut self, n: usize) -> Vec<(f64, f64)> {
333 (0..n).map(|_| self.sample()).collect()
334 }
335
336 pub fn reset(&mut self, seed: u64) {
338 self.rng = ChaCha8Rng::seed_from_u64(seed);
339 }
340
341 pub fn config(&self) -> &CopulaConfig {
343 &self.config
344 }
345}
346
/// Computes the lower-triangular Cholesky factor `L` with `L·Lᵀ ≈ matrix`.
///
/// Only the lower triangle of `matrix` is read. A non-positive pivot is nudged
/// up by `0.001` before taking the square root, so matrices that are only
/// marginally short of positive definite still factor.
///
/// Returns `None` when:
/// * `matrix` is not square (some row's length differs from the row count),
/// * a pivot is NaN or still negative after the nudge, or
/// * a diagonal entry of `L` needed as a divisor has magnitude below `1e-10`.
///
/// An empty matrix yields `Some(vec![])`.
pub fn cholesky_decompose(matrix: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
    let n = matrix.len();
    // Indexing below assumes an n×n layout; ragged input used to panic.
    if matrix.iter().any(|row| row.len() != n) {
        return None;
    }

    let mut l = vec![vec![0.0; n]; n];

    for i in 0..n {
        for j in 0..=i {
            let sum: f64 = l[i][..j]
                .iter()
                .zip(&l[j][..j])
                .map(|(a, b)| a * b)
                .sum();

            if i == j {
                let diag = matrix[i][i] - sum;
                let pivot = if diag <= 0.0 { diag + 0.001 } else { diag };
                // A negative or NaN pivot would put NaN into the factor.
                if pivot.is_nan() || pivot < 0.0 {
                    return None;
                }
                l[i][j] = pivot.sqrt();
            } else {
                let divisor = l[j][j];
                if divisor.abs() < 1e-10 {
                    return None;
                }
                l[i][j] = (matrix[i][j] - sum) / divisor;
            }
        }
    }

    Some(l)
}
375
376pub fn standard_normal_cdf(x: f64) -> f64 {
378 0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
379}
380
/// Approximate inverse of the standard normal CDF.
///
/// Returns `-∞` for `p <= 0` and `+∞` for `p >= 1`. The central region
/// (`0.02425 <= p <= 0.97575`) uses a rational approximation in `p − 0.5`;
/// both tails use the rational approximation in `sqrt(−2·ln(tail probability))`.
/// The result is negative below the median and positive above it.
pub fn standard_normal_quantile(p: f64) -> f64 {
    /// Magnitude of the quantile whose tail probability is `tail`
    /// (Abramowitz & Stegun 26.2.23); positive for `tail < 0.5`.
    fn tail_magnitude(tail: f64) -> f64 {
        const C: [f64; 3] = [2.515517, 0.802853, 0.010328];
        const D: [f64; 3] = [1.432788, 0.189269, 0.001308];

        let q = (-2.0 * tail.ln()).sqrt();
        let numerator = C[0] + C[1] * q + C[2] * q.powi(2);
        let denominator = 1.0 + D[0] * q + D[1] * q.powi(2) + D[2] * q.powi(3);
        q - numerator / denominator
    }

    if p <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }

    let p_low = 0.02425;
    let p_high = 1.0 - p_low;

    if p < p_low {
        // Lower tail: the quantile is the negated magnitude. The previous code
        // had the signs of the two tails swapped.
        -tail_magnitude(p)
    } else if p <= p_high {
        let q = p - 0.5;
        let r = q * q;
        let a = [
            2.50662823884,
            -18.61500062529,
            41.39119773534,
            -25.44106049637,
        ];
        let b = [
            -8.47351093090,
            23.08336743743,
            -21.06224101826,
            3.13082909833,
        ];
        q * (a[0] + a[1] * r + a[2] * r.powi(2) + a[3] * r.powi(3))
            / (1.0 + b[0] * r + b[1] * r.powi(2) + b[2] * r.powi(3) + b[3] * r.powi(4))
    } else {
        // Upper tail: mirror image of the lower tail.
        tail_magnitude(1.0 - p)
    }
}
430
/// Error function via a five-term rational approximation in
/// `t = 1 / (1 + p·|x|)`, extended to negative arguments by odd symmetry.
fn erf(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    // Polynomial coefficients, highest order first, for Horner evaluation.
    const COEFFS: [f64; 5] = [
        1.061405429,
        -1.453152027,
        1.421413741,
        -0.284496736,
        0.254829592,
    ];

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);

    let poly = COEFFS[1..].iter().fold(COEFFS[0], |acc, &c| acc * t + c);
    let value = 1.0 - poly * t * (-magnitude * magnitude).exp();

    if x < 0.0 {
        -value
    } else {
        value
    }
}
448
449fn student_t_cdf(x: f64, df: f64) -> f64 {
451 if df > 30.0 {
453 return standard_normal_cdf(x);
454 }
455
456 let t2 = x * x;
458 let prob = 0.5 * incomplete_beta(df / 2.0, 0.5, df / (df + t2));
459
460 if x > 0.0 {
461 1.0 - prob
462 } else {
463 prob
464 }
465}
466
467fn incomplete_beta(a: f64, b: f64, x: f64) -> f64 {
481 if x <= 0.0 {
482 return 0.0;
483 }
484 if x >= 1.0 {
485 return 1.0;
486 }
487
488 let lbeta = ln_gamma(a) + ln_gamma(b) - ln_gamma(a + b);
490 let front = (x.powf(a) * (1.0 - x).powf(b)) / lbeta.exp();
491
492 let mut c: f64 = 1.0;
494 let mut d: f64 = 1.0 / (1.0 - (a + b) * x / (a + 1.0)).max(1e-30);
495 let mut h = d;
496
497 for m in 1..100 {
498 let m = m as f64;
499 let d1 = m * (b - m) * x / ((a + 2.0 * m - 1.0) * (a + 2.0 * m));
500 let d2 = -(a + m) * (a + b + m) * x / ((a + 2.0 * m) * (a + 2.0 * m + 1.0));
501
502 d = 1.0 / (1.0 + d1 * d).max(1e-30);
503 c = 1.0 + d1 / c.max(1e-30);
504 h *= c * d;
505
506 d = 1.0 / (1.0 + d2 * d).max(1e-30);
507 c = 1.0 + d2 / c.max(1e-30);
508 h *= c * d;
509
510 if ((c * d) - 1.0).abs() < 1e-8 {
511 break;
512 }
513 }
514
515 front * h / a
516}
517
/// Natural logarithm of the gamma function for `x > 0`.
///
/// Returns `+∞` for `x <= 0`. Uses the Lanczos approximation (g = 7, nine
/// coefficients), with the reflection formula
/// `ln Γ(x) = ln(π / sin(πx)) − ln Γ(1 − x)` for `x < 0.5`, where the series
/// itself loses accuracy.
fn ln_gamma(x: f64) -> f64 {
    const G: f64 = 7.0;
    const COEFFS: [f64; 9] = [
        0.99999999999980993,
        676.5203681218851,
        -1259.1392167224028,
        771.32342877765313,
        -176.61502916214059,
        12.507343278686905,
        -0.13857109526572012,
        9.9843695780195716e-6,
        1.5056327351493116e-7,
    ];

    if x <= 0.0 {
        return f64::INFINITY;
    }

    if x < 0.5 {
        let pi = std::f64::consts::PI;
        // sin(πx) > 0 on (0, 0.5), so the logarithm is well defined.
        return (pi / (pi * x).sin()).ln() - ln_gamma(1.0 - x);
    }

    let z = x - 1.0;
    let t = z + G + 0.5;
    let series = COEFFS
        .iter()
        .enumerate()
        .skip(1)
        .fold(COEFFS[0], |acc, (i, c)| acc + c / (z + i as f64));

    0.5 * (2.0 * std::f64::consts::PI).ln() + (z + 0.5) * t.ln() - t + series.ln()
}
525
/// First-order Debye function `D₁(x) = (1/x) · ∫₀ˣ t / (eᵗ − 1) dt`.
///
/// * `|x| < 0.01`: the series `1 − x/4 + x²/36`.
/// * `|x| > 40`: the asymptotic form `π²/(6|x|)` (the neglected terms are of
///   order `e^{-|x|}`), plus `|x|/2` for negative `x` by the identity
///   `D₁(−x) = D₁(x) + x/2`.
/// * otherwise: composite Simpson integration over 100 intervals, using the
///   limit value 1 for the integrand at `t = 0`.
fn debye_1(x: f64) -> f64 {
    // Must be even for Simpson's rule.
    const INTERVALS: usize = 100;

    if x.abs() < 0.01 {
        return 1.0 - x / 4.0 + x.powi(2) / 36.0;
    }

    if x.abs() > 40.0 {
        let magnitude = x.abs();
        let positive_branch = std::f64::consts::PI.powi(2) / (6.0 * magnitude);
        return if x > 0.0 {
            positive_branch
        } else {
            positive_branch + magnitude / 2.0
        };
    }

    let integrand = |t: f64| t / (t.exp() - 1.0);
    let h = x / INTERVALS as f64;

    // Endpoints: the integrand tends to 1 as t -> 0 (evaluating it there is 0/0).
    let mut weighted = 1.0 + integrand(x);
    for i in 1..INTERVALS {
        let weight = if i % 2 == 1 { 4.0 } else { 2.0 };
        weighted += weight * integrand(i as f64 * h);
    }

    // Simpson: integral = weighted·h/3; dividing by x leaves 1/(3·INTERVALS).
    weighted / (3.0 * INTERVALS as f64)
}
544
545fn sample_exponential(rng: &mut ChaCha8Rng, lambda: f64) -> f64 {
547 let u: f64 = rng.random();
548 -u.ln() / lambda
549}
550
551fn sample_chi_squared(rng: &mut ChaCha8Rng, df: f64) -> f64 {
553 let n = df.floor() as usize;
554 let mut sum = 0.0;
555 for _ in 0..n {
556 let u1: f64 = rng.random();
557 let u2: f64 = rng.random();
558 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
559 sum += z * z;
560 }
561 sum
562}
563
564fn sample_positive_stable(rng: &mut ChaCha8Rng, alpha: f64) -> f64 {
566 if (alpha - 1.0).abs() < 0.001 {
567 return 1.0;
568 }
569
570 let u: f64 = rng.random::<f64>() * std::f64::consts::PI - std::f64::consts::PI / 2.0;
571 let e = sample_exponential(rng, 1.0);
572
573 let b = (std::f64::consts::PI * alpha / 2.0).tan();
574 let s = (1.0 + b * b).powf(1.0 / (2.0 * alpha));
575
576 let term1 = (alpha * u).sin();
577 let term2 = (u.cos()).powf(1.0 / alpha);
578 let term3 = ((1.0 - alpha) * u).cos() / e;
579
580 s * term1 / term2 * term3.powf((1.0 - alpha) / alpha)
581}
582
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Number of pairs drawn in the sampling tests.
    const SAMPLE_COUNT: usize = 1000;

    /// Builds a configuration directly, bypassing the clamping constructors.
    fn raw_config(copula_type: CopulaType, theta: f64) -> CopulaConfig {
        CopulaConfig {
            copula_type,
            theta,
            degrees_of_freedom: 4.0,
        }
    }

    /// True when every pair lies in the closed unit square.
    fn all_in_unit_square(samples: &[(f64, f64)]) -> bool {
        samples
            .iter()
            .all(|&(u, v)| (0.0..=1.0).contains(&u) && (0.0..=1.0).contains(&v))
    }

    #[test]
    fn test_copula_validation() {
        // Constructor output is always admissible.
        assert!(CopulaConfig::gaussian(0.5).validate().is_ok());
        // Correlation above 1 is rejected.
        assert!(raw_config(CopulaType::Gaussian, 1.5).validate().is_err());

        assert!(CopulaConfig::clayton(2.0).validate().is_ok());
        // Clayton needs a strictly positive theta.
        assert!(raw_config(CopulaType::Clayton, -1.0).validate().is_err());

        assert!(CopulaConfig::gumbel(2.0).validate().is_ok());
        // Gumbel needs theta >= 1.
        assert!(raw_config(CopulaType::Gumbel, 0.5).validate().is_err());
    }

    #[test]
    fn test_gaussian_copula_sampling() {
        let mut sampler = BivariateCopulaSampler::new(42, CopulaConfig::gaussian(0.7)).unwrap();

        let samples = sampler.sample_n(SAMPLE_COUNT);
        assert_eq!(samples.len(), 1000);
        assert!(all_in_unit_square(&samples));

        let mean_u: f64 = samples.iter().map(|(u, _)| u).sum::<f64>() / 1000.0;
        let mean_v: f64 = samples.iter().map(|(_, v)| v).sum::<f64>() / 1000.0;
        let covariance: f64 = samples
            .iter()
            .map(|(u, v)| (u - mean_u) * (v - mean_v))
            .sum::<f64>()
            / 1000.0;

        // Positive correlation must show up as positive sample covariance.
        assert!(covariance > 0.0);
    }

    #[test]
    fn test_copula_determinism() {
        let config = CopulaConfig::gaussian(0.5);

        let mut first = BivariateCopulaSampler::new(42, config.clone()).unwrap();
        let mut second = BivariateCopulaSampler::new(42, config).unwrap();

        // Identical seeds must give identical streams.
        for _ in 0..100 {
            assert_eq!(first.sample(), second.sample());
        }
    }

    #[test]
    fn test_kendalls_tau() {
        let gaussian_tau = CopulaConfig::gaussian(0.5).kendalls_tau();
        let expected = 2.0 * (0.5_f64).asin() / std::f64::consts::PI;
        assert!((gaussian_tau - expected).abs() < 0.001);

        // Clayton: theta / (theta + 2) = 0.5 at theta = 2.
        let clayton_tau = CopulaConfig::clayton(2.0).kendalls_tau();
        assert!((clayton_tau - 0.5).abs() < 0.001);

        // Gumbel: 1 - 1/theta = 0.5 at theta = 2.
        let gumbel_tau = CopulaConfig::gumbel(2.0).kendalls_tau();
        assert!((gumbel_tau - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_tail_dependence() {
        let gaussian = CopulaConfig::gaussian(0.7);
        assert_eq!(gaussian.lower_tail_dependence(), 0.0);
        assert_eq!(gaussian.upper_tail_dependence(), 0.0);

        let clayton = CopulaConfig::clayton(2.0);
        assert!(clayton.lower_tail_dependence() > 0.0);
        assert_eq!(clayton.upper_tail_dependence(), 0.0);

        let gumbel = CopulaConfig::gumbel(2.0);
        assert_eq!(gumbel.lower_tail_dependence(), 0.0);
        assert!(gumbel.upper_tail_dependence() > 0.0);
    }

    #[test]
    fn test_cholesky_decomposition() {
        let matrix = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
        let l = cholesky_decompose(&matrix).unwrap();

        // Rebuild the lower triangle of L * L^T and compare with the input.
        let reconstructed_00 = l[0][0] * l[0][0];
        let reconstructed_01 = l[0][0] * l[1][0];
        let reconstructed_11 = l[1][0] * l[1][0] + l[1][1] * l[1][1];

        assert!((reconstructed_00 - 1.0).abs() < 0.001);
        assert!((reconstructed_01 - 0.5).abs() < 0.001);
        assert!((reconstructed_11 - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_standard_normal_cdf() {
        assert!((standard_normal_cdf(0.0) - 0.5).abs() < 0.001);
        assert!(standard_normal_cdf(-3.0) < 0.01);
        assert!(standard_normal_cdf(3.0) > 0.99);
    }

    #[test]
    fn test_clayton_copula() {
        let mut sampler = BivariateCopulaSampler::new(42, CopulaConfig::clayton(2.0)).unwrap();
        assert!(all_in_unit_square(&sampler.sample_n(SAMPLE_COUNT)));
    }

    #[test]
    fn test_frank_copula() {
        let mut sampler = BivariateCopulaSampler::new(42, CopulaConfig::frank(5.0)).unwrap();
        assert!(all_in_unit_square(&sampler.sample_n(SAMPLE_COUNT)));
    }
}