Skip to main content

datasynth_eval/ml/
domain_gap.rs

1//! Domain gap evaluation.
2//!
3//! Measures distribution divergence between synthetic and reference data
4//! using PSI (Population Stability Index), KS statistic, and MMD
5//! (Maximum Mean Discrepancy).
6
7use crate::error::EvalResult;
8use serde::{Deserialize, Serialize};
9
/// A pair of value distributions to compare.
///
/// `synthetic_values` and `reference_values` are independent samples of the
/// same quantity; they may have different lengths.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct DistributionSample {
    /// Name of this distribution (used to label the per-distribution results).
    pub name: String,
    /// Values from the synthetic dataset.
    pub synthetic_values: Vec<f64>,
    /// Values from the reference (real) dataset.
    pub reference_values: Vec<f64>,
}
20
/// Thresholds for domain gap analysis.
#[derive(Debug, Clone, PartialEq)]
pub struct DomainGapThresholds {
    /// Maximum acceptable domain gap score. The analysis fails when the
    /// overall score (which is clamped to `[0, 1]`) is strictly greater than
    /// this value.
    pub max_domain_gap: f64,
}
27
28impl Default for DomainGapThresholds {
29    fn default() -> Self {
30        Self {
31            max_domain_gap: 0.25,
32        }
33    }
34}
35
36/// Detail for a single distribution comparison.
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct DomainGapDetail {
39    /// Name of the distribution.
40    pub name: String,
41    /// Population Stability Index.
42    pub psi: f64,
43    /// Kolmogorov-Smirnov statistic.
44    pub ks_statistic: f64,
45    /// Maximum Mean Discrepancy (Gaussian kernel).
46    pub mmd: f64,
47}
48
49/// Results of domain gap analysis.
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct DomainGapAnalysis {
52    /// Overall domain gap score (0 = identical, 1 = very different).
53    pub domain_gap_score: f64,
54    /// Per-distribution comparison details.
55    pub per_distribution: Vec<DomainGapDetail>,
56    /// Total number of distributions compared.
57    pub total_distributions: usize,
58    /// Whether the analysis passes all thresholds.
59    pub passes: bool,
60    /// Issues found during analysis.
61    pub issues: Vec<String>,
62}
63
64/// Analyzer for domain gap between synthetic and reference distributions.
65pub struct DomainGapAnalyzer {
66    thresholds: DomainGapThresholds,
67}
68
69impl DomainGapAnalyzer {
70    /// Create a new analyzer with default thresholds.
71    pub fn new() -> Self {
72        Self {
73            thresholds: DomainGapThresholds::default(),
74        }
75    }
76
77    /// Create an analyzer with custom thresholds.
78    pub fn with_thresholds(thresholds: DomainGapThresholds) -> Self {
79        Self { thresholds }
80    }
81
82    /// Analyze domain gap across distribution samples.
83    pub fn analyze(&self, samples: &[DistributionSample]) -> EvalResult<DomainGapAnalysis> {
84        let mut issues = Vec::new();
85        let total_distributions = samples.len();
86
87        if samples.is_empty() {
88            return Ok(DomainGapAnalysis {
89                domain_gap_score: 0.0,
90                per_distribution: Vec::new(),
91                total_distributions: 0,
92                passes: true,
93                issues: vec!["No distributions provided".to_string()],
94            });
95        }
96
97        let mut details = Vec::new();
98        let mut gap_sum = 0.0;
99
100        for sample in samples {
101            if sample.synthetic_values.is_empty() || sample.reference_values.is_empty() {
102                details.push(DomainGapDetail {
103                    name: sample.name.clone(),
104                    psi: 0.0,
105                    ks_statistic: 0.0,
106                    mmd: 0.0,
107                });
108                continue;
109            }
110
111            let psi = self.compute_psi(&sample.synthetic_values, &sample.reference_values);
112            let ks = self.compute_ks(&sample.synthetic_values, &sample.reference_values);
113            let mmd = self.compute_mmd(&sample.synthetic_values, &sample.reference_values);
114
115            // Normalize each metric to [0, 1] and average for composite gap
116            let psi_norm = (psi / 0.5).clamp(0.0, 1.0); // PSI > 0.25 is major shift
117            let ks_norm = ks.clamp(0.0, 1.0);
118            let mmd_norm = mmd.clamp(0.0, 1.0);
119
120            let gap = (psi_norm + ks_norm + mmd_norm) / 3.0;
121            gap_sum += gap;
122
123            details.push(DomainGapDetail {
124                name: sample.name.clone(),
125                psi,
126                ks_statistic: ks,
127                mmd,
128            });
129        }
130
131        let domain_gap_score = if total_distributions > 0 {
132            (gap_sum / total_distributions as f64).clamp(0.0, 1.0)
133        } else {
134            0.0
135        };
136
137        if domain_gap_score > self.thresholds.max_domain_gap {
138            issues.push(format!(
139                "Domain gap score {:.4} > {:.4} (threshold)",
140                domain_gap_score, self.thresholds.max_domain_gap
141            ));
142        }
143
144        let passes = issues.is_empty();
145
146        Ok(DomainGapAnalysis {
147            domain_gap_score,
148            per_distribution: details,
149            total_distributions,
150            passes,
151            issues,
152        })
153    }
154
155    /// Compute Population Stability Index.
156    ///
157    /// Bins both distributions into equal-width buckets and computes
158    /// PSI = sum((p_i - q_i) * ln(p_i / q_i)).
159    fn compute_psi(&self, synthetic: &[f64], reference: &[f64]) -> f64 {
160        let num_bins = 10;
161        let epsilon = 1e-6;
162
163        // Find global min/max
164        let all_min = synthetic
165            .iter()
166            .chain(reference.iter())
167            .cloned()
168            .fold(f64::INFINITY, f64::min);
169        let all_max = synthetic
170            .iter()
171            .chain(reference.iter())
172            .cloned()
173            .fold(f64::NEG_INFINITY, f64::max);
174
175        if (all_max - all_min).abs() < 1e-12 {
176            return 0.0;
177        }
178
179        let bin_width = (all_max - all_min) / num_bins as f64;
180
181        let bin_index = |val: f64| -> usize {
182            let idx = ((val - all_min) / bin_width) as usize;
183            idx.min(num_bins - 1)
184        };
185
186        let mut syn_counts = vec![0usize; num_bins];
187        let mut ref_counts = vec![0usize; num_bins];
188
189        for &v in synthetic {
190            syn_counts[bin_index(v)] += 1;
191        }
192        for &v in reference {
193            ref_counts[bin_index(v)] += 1;
194        }
195
196        let syn_total = synthetic.len() as f64;
197        let ref_total = reference.len() as f64;
198
199        let mut psi = 0.0;
200        for i in 0..num_bins {
201            let p = (syn_counts[i] as f64 / syn_total) + epsilon;
202            let q = (ref_counts[i] as f64 / ref_total) + epsilon;
203            psi += (p - q) * (p / q).ln();
204        }
205
206        psi.max(0.0)
207    }
208
209    /// Compute Kolmogorov-Smirnov statistic: max|F_synthetic - F_reference|.
210    fn compute_ks(&self, synthetic: &[f64], reference: &[f64]) -> f64 {
211        let mut syn_sorted = synthetic.to_vec();
212        let mut ref_sorted = reference.to_vec();
213        syn_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
214        ref_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
215
216        let syn_n = syn_sorted.len() as f64;
217        let ref_n = ref_sorted.len() as f64;
218
219        let mut max_diff = 0.0_f64;
220        let mut i = 0usize;
221        let mut j = 0usize;
222
223        while i < syn_sorted.len() && j < ref_sorted.len() {
224            let syn_cdf = (i + 1) as f64 / syn_n;
225            let ref_cdf = (j + 1) as f64 / ref_n;
226
227            if syn_sorted[i] <= ref_sorted[j] {
228                let diff = (syn_cdf - (j as f64 / ref_n)).abs();
229                if diff > max_diff {
230                    max_diff = diff;
231                }
232                i += 1;
233            } else {
234                let diff = ((i as f64 / syn_n) - ref_cdf).abs();
235                if diff > max_diff {
236                    max_diff = diff;
237                }
238                j += 1;
239            }
240        }
241
242        // Handle remaining elements
243        while i < syn_sorted.len() {
244            let syn_cdf = (i + 1) as f64 / syn_n;
245            let diff = (syn_cdf - 1.0).abs();
246            if diff > max_diff {
247                max_diff = diff;
248            }
249            i += 1;
250        }
251        while j < ref_sorted.len() {
252            let ref_cdf = (j + 1) as f64 / ref_n;
253            let diff = (1.0 - ref_cdf).abs();
254            if diff > max_diff {
255                max_diff = diff;
256            }
257            j += 1;
258        }
259
260        max_diff
261    }
262
263    /// Compute Maximum Mean Discrepancy with Gaussian kernel.
264    ///
265    /// Subsamples both distributions to min(1000, n) for efficiency.
266    fn compute_mmd(&self, synthetic: &[f64], reference: &[f64]) -> f64 {
267        let max_samples = 1000;
268        let syn_sub = subsample(synthetic, max_samples);
269        let ref_sub = subsample(reference, max_samples);
270
271        if syn_sub.is_empty() || ref_sub.is_empty() {
272            return 0.0;
273        }
274
275        // Estimate bandwidth using median heuristic
276        let sigma = self.median_bandwidth(&syn_sub, &ref_sub);
277        if sigma < 1e-12 {
278            return 0.0;
279        }
280
281        let gamma = -1.0 / (2.0 * sigma * sigma);
282
283        let k_xx = self.mean_kernel(&syn_sub, &syn_sub, gamma);
284        let k_yy = self.mean_kernel(&ref_sub, &ref_sub, gamma);
285        let k_xy = self.mean_kernel(&syn_sub, &ref_sub, gamma);
286
287        (k_xx + k_yy - 2.0 * k_xy).max(0.0).sqrt()
288    }
289
290    /// Compute mean Gaussian kernel value between two sets.
291    fn mean_kernel(&self, x: &[f64], y: &[f64], gamma: f64) -> f64 {
292        let mut sum = 0.0;
293        for &xi in x {
294            for &yi in y {
295                let diff = xi - yi;
296                sum += (gamma * diff * diff).exp();
297            }
298        }
299        sum / (x.len() as f64 * y.len() as f64)
300    }
301
302    /// Estimate kernel bandwidth using the median heuristic.
303    fn median_bandwidth(&self, x: &[f64], y: &[f64]) -> f64 {
304        let mut dists = Vec::new();
305        let step_x = if x.len() > 50 { x.len() / 50 } else { 1 };
306        let step_y = if y.len() > 50 { y.len() / 50 } else { 1 };
307
308        let mut ix = 0;
309        while ix < x.len() {
310            let mut iy = 0;
311            while iy < y.len() {
312                dists.push((x[ix] - y[iy]).abs());
313                iy += step_y;
314            }
315            ix += step_x;
316        }
317
318        if dists.is_empty() {
319            return 1.0;
320        }
321
322        dists.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
323        dists[dists.len() / 2].max(1e-6)
324    }
325}
326
/// Subsample a slice to at most `max` elements.
///
/// Slices that already fit are copied unchanged. Otherwise `max` indices are
/// spread evenly over the whole slice (`k * len / max` for `k in 0..max`), so
/// the subsample covers the entire input. A plain `step_by(len / max)` with
/// `take(max)` would keep only a prefix whenever `len / max` rounds down.
/// `max == 0` yields an empty vector instead of dividing by zero.
fn subsample(data: &[f64], max: usize) -> Vec<f64> {
    let len = data.len();
    if len <= max {
        return data.to_vec();
    }
    (0..max)
        .map(|k| {
            // u128 keeps `k * len` exact even for very large slices; the
            // quotient is < len, so converting back to usize cannot truncate.
            let idx = (k as u128 * len as u128 / max as u128) as usize;
            data[idx]
        })
        .collect()
}
335
336impl Default for DomainGapAnalyzer {
337    fn default() -> Self {
338        Self::new()
339    }
340}
341
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Wrap one synthetic/reference pair into a single "amount" sample.
    fn amount_sample(synthetic: Vec<f64>, reference: Vec<f64>) -> Vec<DistributionSample> {
        vec![DistributionSample {
            name: "amount".to_string(),
            synthetic_values: synthetic,
            reference_values: reference,
        }]
    }

    #[test]
    fn test_identical_distributions() {
        // 1.0, 2.0, ..., 10.0 on both sides.
        let values: Vec<f64> = (1..=10).map(f64::from).collect();
        let samples = amount_sample(values.clone(), values);

        let result = DomainGapAnalyzer::new().analyze(&samples).unwrap();

        assert!(result.domain_gap_score < 0.25);
        assert!(result.passes);
    }

    #[test]
    fn test_divergent_distributions() {
        // Synthetic: 1.0, 1.5, ..., 5.5. Reference: 50.0, 55.0, ..., 95.0.
        let synthetic: Vec<f64> = (2..=11).map(|k| f64::from(k) * 0.5).collect();
        let reference: Vec<f64> = (10..=19).map(|k| f64::from(k) * 5.0).collect();
        let samples = amount_sample(synthetic, reference);

        let result = DomainGapAnalyzer::new().analyze(&samples).unwrap();

        assert!(result.domain_gap_score > 0.25);
        assert!(!result.passes);
    }

    #[test]
    fn test_empty_samples() {
        let result = DomainGapAnalyzer::new().analyze(&[]).unwrap();

        assert_eq!(result.total_distributions, 0);
        assert_eq!(result.domain_gap_score, 0.0);
    }
}
385}