datasynth_core/distributions/
copula.rs1use rand::prelude::*;
14use rand_chacha::ChaCha8Rng;
15use serde::{Deserialize, Serialize};
16
/// Copula families understood by this module.
///
/// Values are (de)serialized in `snake_case` (for example `student_t`), and
/// `Gaussian` is the variant produced by `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum CopulaType {
    /// Gaussian copula; `theta` is a correlation and both tail-dependence
    /// coefficients are reported as zero.
    #[default]
    Gaussian,
    /// Clayton copula; `theta > 0`, with lower-tail dependence only.
    Clayton,
    /// Gumbel copula; `theta >= 1`, with upper-tail dependence only.
    Gumbel,
    /// Frank copula; non-zero `theta`, both tail-dependence coefficients are
    /// reported as zero.
    Frank,
    /// Student-t copula; `theta` is a correlation and the tail dependence is
    /// the same in both tails, driven by `degrees_of_freedom`.
    StudentT,
}
33
/// Parameters describing one bivariate copula.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CopulaConfig {
    /// Which copula family to use.
    pub copula_type: CopulaType,
    /// Dependence parameter. Its meaning depends on `copula_type`: a
    /// correlation in `[-1, 1]` for Gaussian and Student-t, `> 0` for Clayton,
    /// `>= 1` for Gumbel and non-zero for Frank (see `validate`).
    pub theta: f64,
    /// Degrees of freedom; only validated and used for the Student-t family.
    /// Falls back to `default_df()` when the field is absent during
    /// deserialization.
    #[serde(default = "default_df")]
    pub degrees_of_freedom: f64,
}
49
/// Serde fallback for `CopulaConfig::degrees_of_freedom` when the field is
/// missing from the input.
fn default_df() -> f64 {
    const FALLBACK_DEGREES_OF_FREEDOM: f64 = 4.0;
    FALLBACK_DEGREES_OF_FREEDOM
}
53
54impl Default for CopulaConfig {
55 fn default() -> Self {
56 Self {
57 copula_type: CopulaType::Gaussian,
58 theta: 0.5,
59 degrees_of_freedom: 4.0,
60 }
61 }
62}
63
64impl CopulaConfig {
65 pub fn gaussian(correlation: f64) -> Self {
67 Self {
68 copula_type: CopulaType::Gaussian,
69 theta: correlation.clamp(-0.999, 0.999),
70 degrees_of_freedom: 4.0,
71 }
72 }
73
74 pub fn clayton(theta: f64) -> Self {
76 Self {
77 copula_type: CopulaType::Clayton,
78 theta: theta.max(0.001),
79 degrees_of_freedom: 4.0,
80 }
81 }
82
83 pub fn gumbel(theta: f64) -> Self {
85 Self {
86 copula_type: CopulaType::Gumbel,
87 theta: theta.max(1.0),
88 degrees_of_freedom: 4.0,
89 }
90 }
91
92 pub fn frank(theta: f64) -> Self {
94 Self {
95 copula_type: CopulaType::Frank,
96 theta: if theta.abs() < 0.001 { 0.001 } else { theta },
97 degrees_of_freedom: 4.0,
98 }
99 }
100
101 pub fn student_t(correlation: f64, df: f64) -> Self {
103 Self {
104 copula_type: CopulaType::StudentT,
105 theta: correlation.clamp(-0.999, 0.999),
106 degrees_of_freedom: df.max(2.0),
107 }
108 }
109
110 pub fn validate(&self) -> Result<(), String> {
112 match self.copula_type {
113 CopulaType::Gaussian | CopulaType::StudentT => {
114 if self.theta < -1.0 || self.theta > 1.0 {
115 return Err(format!(
116 "Correlation must be in [-1, 1], got {}",
117 self.theta
118 ));
119 }
120 }
121 CopulaType::Clayton => {
122 if self.theta <= 0.0 {
123 return Err(format!("Clayton theta must be > 0, got {}", self.theta));
124 }
125 }
126 CopulaType::Gumbel => {
127 if self.theta < 1.0 {
128 return Err(format!("Gumbel theta must be >= 1, got {}", self.theta));
129 }
130 }
131 CopulaType::Frank => {
132 if self.theta.abs() < 0.0001 {
133 return Err("Frank theta must be non-zero".to_string());
134 }
135 }
136 }
137
138 if self.copula_type == CopulaType::StudentT && self.degrees_of_freedom <= 0.0 {
139 return Err("Degrees of freedom must be positive".to_string());
140 }
141
142 Ok(())
143 }
144
145 pub fn kendalls_tau(&self) -> f64 {
147 match self.copula_type {
148 CopulaType::Gaussian | CopulaType::StudentT => {
149 2.0 * self.theta.asin() / std::f64::consts::PI
151 }
152 CopulaType::Clayton => {
153 self.theta / (self.theta + 2.0)
155 }
156 CopulaType::Gumbel => {
157 1.0 - 1.0 / self.theta
159 }
160 CopulaType::Frank => {
161 let abs_theta = self.theta.abs();
164 if abs_theta < 10.0 {
165 1.0 - 4.0 / self.theta + 4.0 / self.theta.powi(2) * debye_1(abs_theta)
166 } else {
167 self.theta.signum() * (1.0 - 4.0 / abs_theta)
169 }
170 }
171 }
172 }
173
174 pub fn lower_tail_dependence(&self) -> f64 {
176 match self.copula_type {
177 CopulaType::Gaussian | CopulaType::Frank => 0.0,
178 CopulaType::Clayton => 2.0_f64.powf(-1.0 / self.theta),
179 CopulaType::Gumbel => 0.0,
180 CopulaType::StudentT => {
181 let nu = self.degrees_of_freedom;
184 let rho = self.theta;
185 let arg = ((nu + 1.0) * (1.0 - rho) / (1.0 + rho)).sqrt();
186 2.0 * student_t_cdf(-arg, nu + 1.0)
187 }
188 }
189 }
190
191 pub fn upper_tail_dependence(&self) -> f64 {
193 match self.copula_type {
194 CopulaType::Gaussian | CopulaType::Frank => 0.0,
195 CopulaType::Clayton => 0.0,
196 CopulaType::Gumbel => 2.0 - 2.0_f64.powf(1.0 / self.theta),
197 CopulaType::StudentT => self.lower_tail_dependence(), }
199 }
200}
201
/// Draws pairs of dependent values from a configured bivariate copula.
pub struct BivariateCopulaSampler {
    /// Random source; created from a `u64` seed in `new` and re-created in `reset`.
    rng: ChaCha8Rng,
    /// Copula parameters; checked with `CopulaConfig::validate` in `new`.
    config: CopulaConfig,
}
207
208impl BivariateCopulaSampler {
209 pub fn new(seed: u64, config: CopulaConfig) -> Result<Self, String> {
211 config.validate()?;
212 Ok(Self {
213 rng: ChaCha8Rng::seed_from_u64(seed),
214 config,
215 })
216 }
217
218 pub fn sample(&mut self) -> (f64, f64) {
220 match self.config.copula_type {
221 CopulaType::Gaussian => self.sample_gaussian(),
222 CopulaType::Clayton => self.sample_clayton(),
223 CopulaType::Gumbel => self.sample_gumbel(),
224 CopulaType::Frank => self.sample_frank(),
225 CopulaType::StudentT => self.sample_student_t(),
226 }
227 }
228
229 fn sample_gaussian(&mut self) -> (f64, f64) {
231 let rho = self.config.theta;
232
233 let z1 = self.sample_standard_normal();
235 let z2 = self.sample_standard_normal();
236
237 let x1 = z1;
239 let x2 = rho * z1 + (1.0 - rho.powi(2)).sqrt() * z2;
240
241 (standard_normal_cdf(x1), standard_normal_cdf(x2))
243 }
244
245 fn sample_clayton(&mut self) -> (f64, f64) {
247 let theta = self.config.theta;
248
249 let u: f64 = self.rng.gen();
251 let t: f64 = self.rng.gen();
252
253 let v = (u.powf(-theta) * (t.powf(-theta / (theta + 1.0)) - 1.0) + 1.0).powf(-1.0 / theta);
255
256 (u, v.clamp(0.0, 1.0))
257 }
258
259 fn sample_gumbel(&mut self) -> (f64, f64) {
261 let theta = self.config.theta;
262
263 let u: f64 = self.rng.gen();
267 let t: f64 = self.rng.gen();
268
269 let s = sample_positive_stable(&mut self.rng, 1.0 / theta);
271 let e1 = sample_exponential(&mut self.rng, 1.0);
272 let e2 = sample_exponential(&mut self.rng, 1.0);
273
274 let v1 = (-e1 / s).exp().powf(1.0 / theta);
275 let v2 = (-e2 / s).exp().powf(1.0 / theta);
276
277 let c_u = v1 / (v1 + v2);
278 let c_v = v2 / (v1 + v2);
279
280 let u_out = (-((-u.ln()).powf(theta) + (-c_u.ln()).powf(theta)).powf(1.0 / theta)).exp();
282 let v_out = (-((-t.ln()).powf(theta) + (-c_v.ln()).powf(theta)).powf(1.0 / theta)).exp();
283
284 (u_out.clamp(0.0001, 0.9999), v_out.clamp(0.0001, 0.9999))
285 }
286
287 fn sample_frank(&mut self) -> (f64, f64) {
289 let theta = self.config.theta;
290
291 let u: f64 = self.rng.gen();
292 let t: f64 = self.rng.gen();
293
294 let v = -((1.0 - t)
296 / (t * (-theta).exp() + (1.0 - t) * (1.0 - u * (1.0 - (-theta).exp())).recip()))
297 .ln()
298 / theta;
299
300 (u, v.clamp(0.0, 1.0))
301 }
302
303 fn sample_student_t(&mut self) -> (f64, f64) {
305 let rho = self.config.theta;
306 let nu = self.config.degrees_of_freedom;
307
308 let chi2 = sample_chi_squared(&mut self.rng, nu);
310 let scale = (nu / chi2).sqrt();
311
312 let z1 = self.sample_standard_normal();
314 let z2 = self.sample_standard_normal();
315
316 let x1 = z1 * scale;
317 let x2 = (rho * z1 + (1.0 - rho.powi(2)).sqrt() * z2) * scale;
318
319 (student_t_cdf(x1, nu), student_t_cdf(x2, nu))
321 }
322
323 fn sample_standard_normal(&mut self) -> f64 {
325 let u1: f64 = self.rng.gen();
326 let u2: f64 = self.rng.gen();
327 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
328 }
329
330 pub fn sample_n(&mut self, n: usize) -> Vec<(f64, f64)> {
332 (0..n).map(|_| self.sample()).collect()
333 }
334
335 pub fn reset(&mut self, seed: u64) {
337 self.rng = ChaCha8Rng::seed_from_u64(seed);
338 }
339
340 pub fn config(&self) -> &CopulaConfig {
342 &self.config
343 }
344}
345
/// Computes the lower-triangular Cholesky factor `L` with `L * L^T = matrix`.
///
/// Only the lower triangle of `matrix` is read (entries `[i][j]` with
/// `j <= i`). A pivot that is non-positive by at most `0.001` is nudged up by
/// `0.001` before taking the square root, so marginally indefinite inputs are
/// still factorized.
///
/// Returns `None` when
/// * a row is too short to supply its lower-triangle entries,
/// * a pivot is NaN or remains negative after the nudge (the matrix is not
///   positive definite), or
/// * a previously computed diagonal entry is numerically zero.
pub fn cholesky_decompose(matrix: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
    let n = matrix.len();
    let mut l = vec![vec![0.0; n]; n];

    for i in 0..n {
        for j in 0..=i {
            let sum: f64 = l[i][..j]
                .iter()
                .zip(&l[j][..j])
                .map(|(a, b)| a * b)
                .sum();
            // Checked access: a ragged or short row yields `None`, not a panic.
            let entry = *matrix[i].get(j)?;

            if i == j {
                let diag = entry - sum;
                let pivot_squared = if diag <= 0.0 { diag + 0.001 } else { diag };
                if pivot_squared.is_nan() || pivot_squared < 0.0 {
                    // sqrt would produce NaN and poison the whole factor.
                    return None;
                }
                l[i][j] = pivot_squared.sqrt();
            } else {
                if l[j][j].abs() < 1e-10 {
                    return None;
                }
                l[i][j] = (entry - sum) / l[j][j];
            }
        }
    }

    Some(l)
}
374
375pub fn standard_normal_cdf(x: f64) -> f64 {
377 0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
378}
379
/// Returns the standard normal quantile (inverse CDF) at probability `p`.
///
/// Uses Peter Acklam's rational approximation (relative error about
/// `1.15e-9`): one rational function for the central region
/// `[0.02425, 0.97575]` and another, in `sqrt(-2 ln p)`, for each tail.
///
/// Returns `-inf` for `p <= 0`, `+inf` for `p >= 1` and NaN for NaN input.
pub fn standard_normal_quantile(p: f64) -> f64 {
    const A: [f64; 6] = [
        -3.969683028665376e+01,
        2.209460984245205e+02,
        -2.759285104469687e+02,
        1.383577518672690e+02,
        -3.066479806614716e+01,
        2.506628277459239e+00,
    ];
    const B: [f64; 5] = [
        -5.447609879822406e+01,
        1.615858368580409e+02,
        -1.556989798598866e+02,
        6.680131188771972e+01,
        -1.328068155288572e+01,
    ];
    const C: [f64; 6] = [
        -7.784894002430293e-03,
        -3.223964580411365e-01,
        -2.400758277161838e+00,
        -2.549732539343734e+00,
        4.374664141464968e+00,
        2.938163982698783e+00,
    ];
    const D: [f64; 4] = [
        7.784695709041462e-03,
        3.224671290700398e-01,
        2.445134137142996e+00,
        3.754408661907416e+00,
    ];
    const P_LOW: f64 = 0.02425;
    const P_HIGH: f64 = 1.0 - P_LOW;

    if p.is_nan() {
        return f64::NAN;
    }
    if p <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }

    // Horner evaluation, highest-order coefficient first.
    let horner = |coeffs: &[f64], x: f64| coeffs.iter().fold(0.0, |acc, &c| acc * x + c);

    // Quantile for a lower-tail probability; always negative.
    let lower_tail = |tail_p: f64| {
        let q = (-2.0 * tail_p.ln()).sqrt();
        horner(&C, q) / (horner(&D, q) * q + 1.0)
    };

    if p < P_LOW {
        lower_tail(p)
    } else if p <= P_HIGH {
        let q = p - 0.5;
        let r = q * q;
        horner(&A, r) * q / (horner(&B, r) * r + 1.0)
    } else {
        // The upper tail mirrors the lower one.
        -lower_tail(1.0 - p)
    }
}
429
/// Approximates the error function with a five-term polynomial in
/// `t = 1 / (1 + p * |x|)` multiplied by `exp(-x^2)`.
///
/// The approximation is evaluated for `|x|` and the sign of `x` is restored
/// afterwards, so the result is odd in `x`.
fn erf(x: f64) -> f64 {
    // Polynomial coefficients, highest order first.
    const COEFFS_DESCENDING: [f64; 5] = [
        1.061405429,
        -1.453152027,
        1.421413741,
        -0.284496736,
        0.254829592,
    ];
    const P: f64 = 0.3275911;

    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let magnitude = x.abs();

    let t = 1.0 / (1.0 + P * magnitude);
    // Horner scheme: ((((a5*t + a4)*t + a3)*t + a2)*t + a1)*t
    let poly = COEFFS_DESCENDING
        .iter()
        .fold(0.0, |acc, &coeff| (acc + coeff) * t);
    let y = 1.0 - poly * (-magnitude * magnitude).exp();

    sign * y
}
447
448fn student_t_cdf(x: f64, df: f64) -> f64 {
450 if df > 30.0 {
452 return standard_normal_cdf(x);
453 }
454
455 let t2 = x * x;
457 let prob = 0.5 * incomplete_beta(df / 2.0, 0.5, df / (df + t2));
458
459 if x > 0.0 {
460 1.0 - prob
461 } else {
462 prob
463 }
464}
465
466fn incomplete_beta(a: f64, b: f64, x: f64) -> f64 {
468 if x <= 0.0 {
469 return 0.0;
470 }
471 if x >= 1.0 {
472 return 1.0;
473 }
474
475 let lbeta = ln_gamma(a) + ln_gamma(b) - ln_gamma(a + b);
477 let front = (x.powf(a) * (1.0 - x).powf(b)) / lbeta.exp();
478
479 let mut c: f64 = 1.0;
481 let mut d: f64 = 1.0 / (1.0 - (a + b) * x / (a + 1.0)).max(1e-30);
482 let mut h = d;
483
484 for m in 1..100 {
485 let m = m as f64;
486 let d1 = m * (b - m) * x / ((a + 2.0 * m - 1.0) * (a + 2.0 * m));
487 let d2 = -(a + m) * (a + b + m) * x / ((a + 2.0 * m) * (a + 2.0 * m + 1.0));
488
489 d = 1.0 / (1.0 + d1 * d).max(1e-30);
490 c = 1.0 + d1 / c.max(1e-30);
491 h *= c * d;
492
493 d = 1.0 / (1.0 + d2 * d).max(1e-30);
494 c = 1.0 + d2 / c.max(1e-30);
495 h *= c * d;
496
497 if ((c * d) - 1.0).abs() < 1e-8 {
498 break;
499 }
500 }
501
502 front * h / a
503}
504
/// Returns `ln(Gamma(x))` for `x > 0` using the Lanczos approximation
/// (`g = 7`, nine coefficients).
///
/// Arguments below `0.5` go through the reflection formula
/// `Gamma(x) * Gamma(1 - x) = pi / sin(pi * x)`. Non-positive `x` returns
/// `+inf`.
fn ln_gamma(x: f64) -> f64 {
    const LANCZOS_G: f64 = 7.0;
    const LANCZOS: [f64; 9] = [
        0.999_999_999_999_809_93,
        676.520_368_121_885_1,
        -1_259.139_216_722_402_8,
        771.323_428_777_653_13,
        -176.615_029_162_140_59,
        12.507_343_278_686_905,
        -0.138_571_095_265_720_12,
        9.984_369_578_019_571_6e-6,
        1.505_632_735_149_311_6e-7,
    ];

    if x <= 0.0 {
        return f64::INFINITY;
    }
    if x < 0.5 {
        // Here 0 < x < 0.5, so sin(pi * x) is strictly positive.
        let pi = std::f64::consts::PI;
        return (pi / (pi * x).sin()).ln() - ln_gamma(1.0 - x);
    }

    let shifted = x - 1.0;
    let series = LANCZOS[1..]
        .iter()
        .enumerate()
        .fold(LANCZOS[0], |acc, (i, &coeff)| {
            acc + coeff / (shifted + (i + 1) as f64)
        });
    let t = shifted + LANCZOS_G + 0.5;

    0.5 * (2.0 * std::f64::consts::PI).ln() + (shifted + 0.5) * t.ln() - t + series.ln()
}
512
/// Evaluates the normalized first-order Debye function
/// `D1(x) = (1 / x) * integral_0^x t / (e^t - 1) dt`.
///
/// For `|x| < 0.01` the series `1 - x/4 + x^2/36` is used. Otherwise the
/// integral is computed with composite Simpson quadrature on 100 panels; the
/// integrand's value at `t = 0` is its limit, `1`.
fn debye_1(x: f64) -> f64 {
    const PANELS: usize = 100; // must be even for Simpson's rule

    if x.abs() < 0.01 {
        return 1.0 - x / 4.0 + x.powi(2) / 36.0;
    }

    let h = x / PANELS as f64;
    // `exp_m1` keeps precision for small t and overflows cleanly to +inf for
    // large t, where the integrand is effectively zero.
    let integrand = |t: f64| t / t.exp_m1();

    // End points: f(0) = 1 (limit) and f(x).
    let mut weighted_sum = 1.0 + integrand(x);
    for i in 1..PANELS {
        let weight = if i % 2 == 1 { 4.0 } else { 2.0 };
        weighted_sum += weight * integrand(i as f64 * h);
    }

    weighted_sum * h / 3.0 / x
}
531
532fn sample_exponential(rng: &mut ChaCha8Rng, lambda: f64) -> f64 {
534 let u: f64 = rng.gen();
535 -u.ln() / lambda
536}
537
538fn sample_chi_squared(rng: &mut ChaCha8Rng, df: f64) -> f64 {
540 let n = df.floor() as usize;
541 let mut sum = 0.0;
542 for _ in 0..n {
543 let u1: f64 = rng.gen();
544 let u2: f64 = rng.gen();
545 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
546 sum += z * z;
547 }
548 sum
549}
550
551fn sample_positive_stable(rng: &mut ChaCha8Rng, alpha: f64) -> f64 {
553 if (alpha - 1.0).abs() < 0.001 {
554 return 1.0;
555 }
556
557 let u: f64 = rng.gen::<f64>() * std::f64::consts::PI - std::f64::consts::PI / 2.0;
558 let e = sample_exponential(rng, 1.0);
559
560 let b = (std::f64::consts::PI * alpha / 2.0).tan();
561 let s = (1.0 + b * b).powf(1.0 / (2.0 * alpha));
562
563 let term1 = (alpha * u).sin();
564 let term2 = (u.cos()).powf(1.0 / alpha);
565 let term3 = ((1.0 - alpha) * u).cos() / e;
566
567 s * term1 / term2 * term3.powf((1.0 - alpha) / alpha)
568}
569
// Unit tests for copula configuration, sampling and the numeric helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // `validate` accepts parameters produced by the constructors and rejects
    // hand-built configurations whose `theta` is outside the family's domain.
    #[test]
    fn test_copula_validation() {
        let gaussian = CopulaConfig::gaussian(0.5);
        assert!(gaussian.validate().is_ok());

        // Correlation above 1 is invalid.
        let invalid_gaussian = CopulaConfig {
            copula_type: CopulaType::Gaussian,
            theta: 1.5,
            degrees_of_freedom: 4.0,
        };
        assert!(invalid_gaussian.validate().is_err());

        let clayton = CopulaConfig::clayton(2.0);
        assert!(clayton.validate().is_ok());

        // Clayton requires theta > 0.
        let invalid_clayton = CopulaConfig {
            copula_type: CopulaType::Clayton,
            theta: -1.0,
            degrees_of_freedom: 4.0,
        };
        assert!(invalid_clayton.validate().is_err());

        let gumbel = CopulaConfig::gumbel(2.0);
        assert!(gumbel.validate().is_ok());

        // Gumbel requires theta >= 1.
        let invalid_gumbel = CopulaConfig {
            copula_type: CopulaType::Gumbel,
            theta: 0.5,
            degrees_of_freedom: 4.0,
        };
        assert!(invalid_gumbel.validate().is_err());
    }

    // Gaussian samples stay in the unit square and, with a positive
    // correlation, have positive sample covariance.
    #[test]
    fn test_gaussian_copula_sampling() {
        let config = CopulaConfig::gaussian(0.7);
        let mut sampler = BivariateCopulaSampler::new(42, config).unwrap();

        let samples = sampler.sample_n(1000);
        assert_eq!(samples.len(), 1000);

        assert!(samples
            .iter()
            .all(|(u, v)| *u >= 0.0 && *u <= 1.0 && *v >= 0.0 && *v <= 1.0));

        let mean_u: f64 = samples.iter().map(|(u, _)| u).sum::<f64>() / 1000.0;
        let mean_v: f64 = samples.iter().map(|(_, v)| v).sum::<f64>() / 1000.0;
        let covariance: f64 = samples
            .iter()
            .map(|(u, v)| (u - mean_u) * (v - mean_v))
            .sum::<f64>()
            / 1000.0;

        assert!(covariance > 0.0);
    }

    // Two samplers built from the same seed and configuration produce the
    // same sequence.
    #[test]
    fn test_copula_determinism() {
        let config = CopulaConfig::gaussian(0.5);

        let mut sampler1 = BivariateCopulaSampler::new(42, config.clone()).unwrap();
        let mut sampler2 = BivariateCopulaSampler::new(42, config).unwrap();

        for _ in 0..100 {
            assert_eq!(sampler1.sample(), sampler2.sample());
        }
    }

    // Kendall's tau matches the closed forms for Gaussian, Clayton and Gumbel.
    #[test]
    fn test_kendalls_tau() {
        let gaussian = CopulaConfig::gaussian(0.5);
        let tau = gaussian.kendalls_tau();
        let expected = 2.0 * (0.5_f64).asin() / std::f64::consts::PI;
        assert!((tau - expected).abs() < 0.001);

        // Clayton: theta / (theta + 2) = 2 / 4.
        let clayton = CopulaConfig::clayton(2.0);
        let tau = clayton.kendalls_tau();
        assert!((tau - 0.5).abs() < 0.001);

        // Gumbel: 1 - 1 / theta = 1 - 1 / 2.
        let gumbel = CopulaConfig::gumbel(2.0);
        let tau = gumbel.kendalls_tau();
        assert!((tau - 0.5).abs() < 0.001);
    }

    // Tail dependence: none for Gaussian, lower only for Clayton, upper only
    // for Gumbel.
    #[test]
    fn test_tail_dependence() {
        let gaussian = CopulaConfig::gaussian(0.7);
        assert_eq!(gaussian.lower_tail_dependence(), 0.0);
        assert_eq!(gaussian.upper_tail_dependence(), 0.0);

        let clayton = CopulaConfig::clayton(2.0);
        assert!(clayton.lower_tail_dependence() > 0.0);
        assert_eq!(clayton.upper_tail_dependence(), 0.0);

        let gumbel = CopulaConfig::gumbel(2.0);
        assert_eq!(gumbel.lower_tail_dependence(), 0.0);
        assert!(gumbel.upper_tail_dependence() > 0.0);
    }

    // The Cholesky factor of a 2x2 correlation matrix reproduces the matrix
    // when multiplied by its transpose.
    #[test]
    fn test_cholesky_decomposition() {
        let matrix = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
        let l = cholesky_decompose(&matrix).unwrap();

        let reconstructed_00 = l[0][0] * l[0][0];
        let reconstructed_01 = l[0][0] * l[1][0];
        let reconstructed_11 = l[1][0] * l[1][0] + l[1][1] * l[1][1];

        assert!((reconstructed_00 - 1.0).abs() < 0.001);
        assert!((reconstructed_01 - 0.5).abs() < 0.001);
        assert!((reconstructed_11 - 1.0).abs() < 0.001);
    }

    // Spot checks of the normal CDF at the centre and in both tails.
    #[test]
    fn test_standard_normal_cdf() {
        assert!((standard_normal_cdf(0.0) - 0.5).abs() < 0.001);
        assert!(standard_normal_cdf(-3.0) < 0.01);
        assert!(standard_normal_cdf(3.0) > 0.99);
    }

    // Clayton samples stay inside the unit square.
    #[test]
    fn test_clayton_copula() {
        let config = CopulaConfig::clayton(2.0);
        let mut sampler = BivariateCopulaSampler::new(42, config).unwrap();

        let samples = sampler.sample_n(1000);
        assert!(samples
            .iter()
            .all(|(u, v)| *u >= 0.0 && *u <= 1.0 && *v >= 0.0 && *v <= 1.0));
    }

    // Frank samples stay inside the unit square.
    #[test]
    fn test_frank_copula() {
        let config = CopulaConfig::frank(5.0);
        let mut sampler = BivariateCopulaSampler::new(42, config).unwrap();

        let samples = sampler.sample_n(1000);
        assert!(samples
            .iter()
            .all(|(u, v)| *u >= 0.0 && *u <= 1.0 && *v >= 0.0 && *v <= 1.0));
    }
}