Skip to main content

datasynth_eval/behavioral_fidelity/
math.rs

1//! Numerical primitives shared across the behavioral-fidelity metrics.
2
3use std::cmp::Ordering;
4
5use chrono::{Datelike, NaiveDate, Weekday};
6
/// Wasserstein-1 (earth mover's) distance between two empirical 1-D samples.
///
/// Non-finite values (NaN, ±∞) are discarded first. Returns `0.0` when either
/// sample has no finite values left.
///
/// The result is exact (no quantile grid):
/// * equal-length samples: the mean absolute difference between the sorted
///   samples, which equals `∫₀¹ |F_a⁻¹(t) − F_b⁻¹(t)| dt`;
/// * unequal lengths: the equivalent CDF form `∫ |F_a(x) − F_b(x)| dx`,
///   integrated piecewise over the merged sorted support, where both
///   empirical CDFs are constant between adjacent support points.
pub fn wasserstein_1(a: &[f64], b: &[f64]) -> f64 {
    // Keep only finite values, sorted ascending. Filtering happens *before* the
    // emptiness check so an all-NaN sample returns 0.0 instead of reaching the
    // arithmetic below (0/0 = NaN, or a comparison against an implicit zero).
    let sorted_finite = |xs: &[f64]| -> Vec<f64> {
        let mut v: Vec<f64> = xs.iter().copied().filter(|x| x.is_finite()).collect();
        v.sort_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal));
        v
    };
    let sa = sorted_finite(a);
    let sb = sorted_finite(b);
    if sa.is_empty() || sb.is_empty() {
        return 0.0;
    }

    if sa.len() == sb.len() {
        // Equal sizes: the optimal transport plan pairs the i-th smallest
        // values, so W1 is the mean absolute difference of the sorted samples.
        let total: f64 = sa.iter().zip(&sb).map(|(x, y)| (x - y).abs()).sum();
        return total / sa.len() as f64;
    }

    // Unequal sizes: integrate |F_a(x) - F_b(x)| exactly over the merged support.
    let mut merged: Vec<f64> = sa.iter().chain(&sb).copied().collect();
    merged.sort_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal));

    let (na, nb) = (sa.len() as f64, sb.len() as f64);
    let (mut ia, mut ib) = (0usize, 0usize);
    let mut acc = 0.0;
    for w in merged.windows(2) {
        let (lo, hi) = (w[0], w[1]);
        // Count the values <= lo; both CDFs are constant on [lo, hi). The
        // cursors only move forward, so the whole loop is linear.
        while ia < sa.len() && sa[ia] <= lo {
            ia += 1;
        }
        while ib < sb.len() && sb[ib] <= lo {
            ib += 1;
        }
        acc += (ia as f64 / na - ib as f64 / nb).abs() * (hi - lo);
    }
    acc
}
40
/// Linearly interpolated quantile of an already-sorted slice.
///
/// `t` is the quantile level: `0.0` maps to the first element and `1.0` to the
/// last, with linear interpolation between neighbouring elements in between.
/// Returns `0.0` for an empty slice. Callers must keep `t` within `[0, 1]`;
/// larger values would index past the end of the slice.
fn quantile_sorted(sorted: &[f64], t: f64) -> f64 {
    match sorted.len() {
        0 => 0.0,
        len => {
            let rank = t * (len as f64 - 1.0);
            let below = rank.floor() as usize;
            let above = rank.ceil() as usize;
            if below != above {
                // Fractional rank: blend the two bracketing order statistics.
                let weight = rank - below as f64;
                sorted[below] * (1.0 - weight) + sorted[above] * weight
            } else {
                sorted[below]
            }
        }
    }
}
55
/// Lag-1 Pearson correlation between consecutive elements of `xs`.
///
/// Correlates the series `xs[..n-1]` with its one-step shift `xs[1..]`.
///
/// Returns `None` if `xs.len() < 3`, or if either shifted series has zero
/// variance (all of its values are equal), since the correlation is undefined
/// there. The result is clamped to `[-1, 1]` to absorb rounding error. A NaN
/// anywhere in `xs` propagates to a `Some(NaN)` result.
pub fn pearson_lag1_correlation(xs: &[f64]) -> Option<f64> {
    if xs.len() < 3 {
        return None;
    }
    let a = &xs[..xs.len() - 1];
    let b = &xs[1..];

    // Detect constant series exactly instead of relying on `variance == 0.0`:
    // the float mean of identical values (e.g. 0.1) need not equal that value,
    // which leaves a tiny spurious variance and a meaningless ±1 result.
    let is_constant = |s: &[f64]| s.iter().all(|&x| x == s[0]);
    if is_constant(a) || is_constant(b) {
        return None;
    }

    let n = a.len() as f64;
    let mean_a = a.iter().sum::<f64>() / n;
    let mean_b = b.iter().sum::<f64>() / n;
    let (num, da, db) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(num, da, db), (&x, &y)| {
            let (xa, xb) = (x - mean_a, y - mean_b);
            (num + xa * xb, da + xa * xa, db + xb * xb)
        },
    );
    // Distinct values can still have squared deviations that underflow to zero.
    if da == 0.0 || db == 0.0 {
        return None;
    }
    Some((num / (da.sqrt() * db.sqrt())).clamp(-1.0, 1.0))
}
84
85/// Empirical percentile of an unsorted slice (clones + sorts).
86pub fn percentile(xs: &[f64], pct: f64) -> f64 {
87    if xs.is_empty() {
88        return 0.0;
89    }
90    let mut s: Vec<f64> = xs.iter().copied().filter(|x| x.is_finite()).collect();
91    s.sort_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal));
92    quantile_sorted(&s, pct.clamp(0.0, 1.0))
93}
94
95/// Days between two dates, can be negative.
96pub fn days_between(a: NaiveDate, b: NaiveDate) -> i64 {
97    (b - a).num_days()
98}
99
100/// `true` if the date falls on Sat or Sun.
101pub fn is_weekend(d: NaiveDate) -> bool {
102    matches!(d.weekday(), Weekday::Sat | Weekday::Sun)
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// `true` when `actual` lies strictly within `tol` of `expected`.
    fn within(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    /// Builds a calendar date that the test knows to be valid.
    fn date(y: i32, m: u32, d: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(y, m, d).expect("valid calendar date")
    }

    #[test]
    fn w1_identical_samples_is_zero() {
        let sample = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert!(within(wasserstein_1(&sample, &sample), 0.0, 1e-9));
    }

    #[test]
    fn w1_shifted_samples_equals_shift() {
        let base = [1.0, 2.0, 3.0, 4.0, 5.0];
        let shifted: Vec<f64> = base.iter().map(|v| v + 3.0).collect();
        assert!(within(wasserstein_1(&base, &shifted), 3.0, 1e-9));
    }

    #[test]
    fn w1_unequal_lengths_handles_gracefully() {
        let short = vec![1.0; 10];
        let long = vec![2.0; 100];
        assert!(within(wasserstein_1(&short, &long), 1.0, 1e-3));
    }

    #[test]
    fn pearson_lag1_positive_autocorr() {
        let ramp = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let r = pearson_lag1_correlation(&ramp).unwrap();
        assert!(within(r, 1.0, 1e-9));
    }

    #[test]
    fn pearson_lag1_negative_autocorr() {
        let zigzag = [1.0, 10.0, 1.0, 10.0, 1.0, 10.0];
        let r = pearson_lag1_correlation(&zigzag).unwrap();
        assert!(r < -0.9);
    }

    #[test]
    fn pearson_lag1_short_series_returns_none() {
        assert!(pearson_lag1_correlation(&[1.0, 2.0]).is_none());
    }

    #[test]
    fn percentile_known_values() {
        let one_to_hundred: Vec<f64> = (1u32..=100).map(f64::from).collect();
        assert!(within(percentile(&one_to_hundred, 0.50), 50.5, 1.0));
        assert!(within(percentile(&one_to_hundred, 0.90), 90.0, 1.0));
    }

    #[test]
    fn days_between_known() {
        assert_eq!(days_between(date(2022, 4, 25), date(2022, 5, 2)), 7);
    }

    #[test]
    fn is_weekend_known() {
        // 2022-04-30 is a Saturday.
        assert!(is_weekend(date(2022, 4, 30)));
        // 2022-04-25 is a Monday.
        assert!(!is_weekend(date(2022, 4, 25)));
    }
}