datasynth_eval/behavioral_fidelity/
math.rs1use std::cmp::Ordering;
4
5use chrono::{Datelike, NaiveDate, Weekday};
6
/// Approximates the 1-Wasserstein (earth mover's) distance between two
/// empirical samples.
///
/// Non-finite values (NaN, ±∞) are discarded from both samples before any
/// comparison is made.
///
/// * When the filtered samples have the same length, the result is the mean
///   absolute difference between their order statistics.
/// * Otherwise the integral of `|Qa(t) - Qb(t)|` over `t ∈ (0, 1)` is
///   approximated with a 1024-point midpoint rule, where `Qa` / `Qb` are the
///   linearly interpolated quantile functions of the two samples.
///
/// # Returns
///
/// The distance, or `0.0` when either sample is empty or contains no finite
/// values (there is nothing meaningful to compare in that case).
pub fn wasserstein_1(a: &[f64], b: &[f64]) -> f64 {
    // Number of quantile levels sampled when the sample sizes differ.
    const STEPS: usize = 1024;

    let sorted_finite = |xs: &[f64]| -> Vec<f64> {
        let mut kept: Vec<f64> = xs.iter().copied().filter(|x| x.is_finite()).collect();
        // Only finite values remain, so `partial_cmp` never actually yields `None`.
        kept.sort_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal));
        kept
    };
    let sa = sorted_finite(a);
    let sb = sorted_finite(b);

    // The emptiness check must come AFTER filtering: a sample made only of
    // NaN/∞ would otherwise produce 0/0 = NaN in the equal-length branch, or be
    // compared against an all-zero phantom distribution in the quantile branch.
    if sa.is_empty() || sb.is_empty() {
        return 0.0;
    }

    if sa.len() == sb.len() {
        let total: f64 = sa.iter().zip(&sb).map(|(x, y)| (x - y).abs()).sum();
        return total / sa.len() as f64;
    }

    let total: f64 = (0..STEPS)
        .map(|k| {
            // Midpoint of the k-th of STEPS equal sub-intervals of [0, 1].
            let t = (k as f64 + 0.5) / STEPS as f64;
            (quantile_sorted(&sa, t) - quantile_sorted(&sb, t)).abs()
        })
        .sum();
    total / STEPS as f64
}

/// Linearly interpolated quantile of an ascending-sorted slice.
///
/// `t` is a fraction: `0.0` selects the first element and `1.0` the last.
/// Values outside `[0, 1]` are clamped to that range so the computed index can
/// never run past the end of the slice. Returns `0.0` for an empty slice.
fn quantile_sorted(sorted: &[f64], t: f64) -> f64 {
    if sorted.is_empty() {
        return 0.0;
    }
    let last = sorted.len() - 1;
    // Fractional index into the slice; bounded by `last` thanks to the clamp.
    let pos = t.clamp(0.0, 1.0) * last as f64;
    let lo = pos.floor() as usize;
    let hi = pos.ceil() as usize;
    if lo == hi {
        sorted[lo]
    } else {
        let frac = pos - lo as f64;
        sorted[lo] * (1.0 - frac) + sorted[hi] * frac
    }
}
55
/// Pearson correlation between a series and itself shifted by one position
/// (lag-1 autocorrelation).
///
/// The series is split into two overlapping windows, `xs[..n-1]` and
/// `xs[1..]`, and the ordinary Pearson coefficient of those windows is
/// computed, each window being centred on its own mean.
///
/// # Returns
///
/// `None` when fewer than three observations are supplied, or when either
/// window has zero variance (the coefficient would be undefined). Otherwise
/// `Some(r)`.
pub fn pearson_lag1_correlation(xs: &[f64]) -> Option<f64> {
    if xs.len() < 3 {
        return None;
    }

    let lead = &xs[..xs.len() - 1];
    let lagged = &xs[1..];
    let count = lead.len() as f64;
    let lead_mean = lead.iter().sum::<f64>() / count;
    let lagged_mean = lagged.iter().sum::<f64>() / count;

    // Single pass accumulating the cross-product and both sums of squares.
    let (cross, lead_ss, lagged_ss) = lead.iter().zip(lagged).fold(
        (0.0, 0.0, 0.0),
        |(cross, lead_ss, lagged_ss), (&u, &v)| {
            let du = u - lead_mean;
            let dv = v - lagged_mean;
            (cross + du * dv, lead_ss + du * du, lagged_ss + dv * dv)
        },
    );

    if lead_ss == 0.0 || lagged_ss == 0.0 {
        None
    } else {
        Some(cross / (lead_ss.sqrt() * lagged_ss.sqrt()))
    }
}
84
85pub fn percentile(xs: &[f64], pct: f64) -> f64 {
87 if xs.is_empty() {
88 return 0.0;
89 }
90 let mut s: Vec<f64> = xs.iter().copied().filter(|x| x.is_finite()).collect();
91 s.sort_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal));
92 quantile_sorted(&s, pct.clamp(0.0, 1.0))
93}
94
95pub fn days_between(a: NaiveDate, b: NaiveDate) -> i64 {
97 (b - a).num_days()
98}
99
100pub fn is_weekend(d: NaiveDate) -> bool {
102 matches!(d.weekday(), Weekday::Sat | Weekday::Sun)
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a calendar date; panics if the triple is not a
    /// real date.
    fn date(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).unwrap()
    }

    // --- wasserstein_1 ---------------------------------------------------

    #[test]
    fn w1_identical_samples_is_zero() {
        let left = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let right = left.clone();
        let dist = wasserstein_1(&left, &right);
        assert!(dist.abs() < 1e-9);
    }

    #[test]
    fn w1_shifted_samples_equals_shift() {
        // Translating every point by a constant moves the distance by exactly
        // that constant.
        let base = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shifted: Vec<f64> = base.iter().map(|v| v + 3.0).collect();
        let dist = wasserstein_1(&base, &shifted);
        assert!((dist - 3.0).abs() < 1e-9);
    }

    #[test]
    fn w1_unequal_lengths_handles_gracefully() {
        // Different sample sizes exercise the quantile-sampling branch.
        let ones = vec![1.0; 10];
        let twos = vec![2.0; 100];
        let dist = wasserstein_1(&ones, &twos);
        assert!((dist - 1.0).abs() < 1e-3);
    }

    // --- pearson_lag1_correlation ----------------------------------------

    #[test]
    fn pearson_lag1_positive_autocorr() {
        let ramp = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let r = pearson_lag1_correlation(&ramp).unwrap();
        assert!((r - 1.0).abs() < 1e-9);
    }

    #[test]
    fn pearson_lag1_negative_autocorr() {
        let zigzag = vec![1.0, 10.0, 1.0, 10.0, 1.0, 10.0];
        let r = pearson_lag1_correlation(&zigzag).unwrap();
        assert!(r < -0.9);
    }

    #[test]
    fn pearson_lag1_short_series_returns_none() {
        let too_short = vec![1.0, 2.0];
        assert!(pearson_lag1_correlation(&too_short).is_none());
    }

    // --- percentile --------------------------------------------------------

    #[test]
    fn percentile_known_values() {
        let one_to_hundred: Vec<f64> = (1..=100).map(|i| i as f64).collect();

        let median = percentile(&one_to_hundred, 0.50);
        assert!((median - 50.5).abs() < 1.0);

        let ninetieth = percentile(&one_to_hundred, 0.90);
        assert!((ninetieth - 90.0).abs() < 1.0);
    }

    // --- date helpers ------------------------------------------------------

    #[test]
    fn days_between_known() {
        let start = date(2022, 4, 25);
        let end = date(2022, 5, 2);
        assert_eq!(days_between(start, end), 7);
    }

    #[test]
    fn is_weekend_known() {
        // 2022-04-30 was a Saturday; 2022-04-25 was a Monday.
        assert!(is_weekend(date(2022, 4, 30)));
        assert!(!is_weekend(date(2022, 4, 25)));
    }
}