1use crate::error::EvalResult;
8use serde::{Deserialize, Serialize};
9
/// A named pair of value sets to be compared by [`DomainGapAnalyzer::analyze`].
#[derive(Debug, Clone)]
pub struct DistributionSample {
    /// Label copied into the matching [`DomainGapDetail::name`].
    pub name: String,
    /// Values taken from the synthetic data.
    pub synthetic_values: Vec<f64>,
    /// Values taken from the reference data the synthetic values are compared against.
    pub reference_values: Vec<f64>,
}
20
/// Pass/fail limits applied by [`DomainGapAnalyzer`].
#[derive(Debug, Clone)]
pub struct DomainGapThresholds {
    /// Largest acceptable overall gap score; a strictly greater score is reported as an issue.
    pub max_domain_gap: f64,
}
27
28impl Default for DomainGapThresholds {
29 fn default() -> Self {
30 Self {
31 max_domain_gap: 0.25,
32 }
33 }
34}
35
/// Raw (un-normalized) gap metrics computed for one [`DistributionSample`].
///
/// All three metrics are `0.0` for a sample whose synthetic or reference values are empty,
/// because such a sample is not evaluated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DomainGapDetail {
    /// Name of the sample these metrics belong to.
    pub name: String,
    /// Population Stability Index computed over equal-width bins of the combined value range.
    pub psi: f64,
    /// Two-sample Kolmogorov-Smirnov statistic (largest distance between the empirical CDFs).
    pub ks_statistic: f64,
    /// Maximum Mean Discrepancy estimate computed with a Gaussian kernel.
    pub mmd: f64,
}
48
/// Result of a [`DomainGapAnalyzer::analyze`] run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DomainGapAnalysis {
    /// Aggregate gap score in `[0, 1]`; `0.0` when no distributions were provided.
    pub domain_gap_score: f64,
    /// One entry per input sample, in input order.
    pub per_distribution: Vec<DomainGapDetail>,
    /// Number of samples that were passed to the analyzer.
    pub total_distributions: usize,
    /// Whether the analysis is considered to have passed.
    pub passes: bool,
    /// Human-readable findings, e.g. a threshold violation or a note that no input was given.
    pub issues: Vec<String>,
}
63
/// Compares synthetic and reference value distributions and scores the gap between them.
pub struct DomainGapAnalyzer {
    /// Limits used to decide whether the computed gap score is acceptable.
    thresholds: DomainGapThresholds,
}
68
69impl DomainGapAnalyzer {
70 pub fn new() -> Self {
72 Self {
73 thresholds: DomainGapThresholds::default(),
74 }
75 }
76
77 pub fn with_thresholds(thresholds: DomainGapThresholds) -> Self {
79 Self { thresholds }
80 }
81
82 pub fn analyze(&self, samples: &[DistributionSample]) -> EvalResult<DomainGapAnalysis> {
84 let mut issues = Vec::new();
85 let total_distributions = samples.len();
86
87 if samples.is_empty() {
88 return Ok(DomainGapAnalysis {
89 domain_gap_score: 0.0,
90 per_distribution: Vec::new(),
91 total_distributions: 0,
92 passes: true,
93 issues: vec!["No distributions provided".to_string()],
94 });
95 }
96
97 let mut details = Vec::new();
98 let mut gap_sum = 0.0;
99
100 for sample in samples {
101 if sample.synthetic_values.is_empty() || sample.reference_values.is_empty() {
102 details.push(DomainGapDetail {
103 name: sample.name.clone(),
104 psi: 0.0,
105 ks_statistic: 0.0,
106 mmd: 0.0,
107 });
108 continue;
109 }
110
111 let psi = self.compute_psi(&sample.synthetic_values, &sample.reference_values);
112 let ks = self.compute_ks(&sample.synthetic_values, &sample.reference_values);
113 let mmd = self.compute_mmd(&sample.synthetic_values, &sample.reference_values);
114
115 let psi_norm = (psi / 0.5).clamp(0.0, 1.0); let ks_norm = ks.clamp(0.0, 1.0);
118 let mmd_norm = mmd.clamp(0.0, 1.0);
119
120 let gap = (psi_norm + ks_norm + mmd_norm) / 3.0;
121 gap_sum += gap;
122
123 details.push(DomainGapDetail {
124 name: sample.name.clone(),
125 psi,
126 ks_statistic: ks,
127 mmd,
128 });
129 }
130
131 let domain_gap_score = if total_distributions > 0 {
132 (gap_sum / total_distributions as f64).clamp(0.0, 1.0)
133 } else {
134 0.0
135 };
136
137 if domain_gap_score > self.thresholds.max_domain_gap {
138 issues.push(format!(
139 "Domain gap score {:.4} > {:.4} (threshold)",
140 domain_gap_score, self.thresholds.max_domain_gap
141 ));
142 }
143
144 let passes = issues.is_empty();
145
146 Ok(DomainGapAnalysis {
147 domain_gap_score,
148 per_distribution: details,
149 total_distributions,
150 passes,
151 issues,
152 })
153 }
154
155 fn compute_psi(&self, synthetic: &[f64], reference: &[f64]) -> f64 {
160 let num_bins = 10;
161 let epsilon = 1e-6;
162
163 let all_min = synthetic
165 .iter()
166 .chain(reference.iter())
167 .cloned()
168 .fold(f64::INFINITY, f64::min);
169 let all_max = synthetic
170 .iter()
171 .chain(reference.iter())
172 .cloned()
173 .fold(f64::NEG_INFINITY, f64::max);
174
175 if (all_max - all_min).abs() < 1e-12 {
176 return 0.0;
177 }
178
179 let bin_width = (all_max - all_min) / num_bins as f64;
180
181 let bin_index = |val: f64| -> usize {
182 let idx = ((val - all_min) / bin_width) as usize;
183 idx.min(num_bins - 1)
184 };
185
186 let mut syn_counts = vec![0usize; num_bins];
187 let mut ref_counts = vec![0usize; num_bins];
188
189 for &v in synthetic {
190 syn_counts[bin_index(v)] += 1;
191 }
192 for &v in reference {
193 ref_counts[bin_index(v)] += 1;
194 }
195
196 let syn_total = synthetic.len() as f64;
197 let ref_total = reference.len() as f64;
198
199 let mut psi = 0.0;
200 for i in 0..num_bins {
201 let p = (syn_counts[i] as f64 / syn_total) + epsilon;
202 let q = (ref_counts[i] as f64 / ref_total) + epsilon;
203 psi += (p - q) * (p / q).ln();
204 }
205
206 psi.max(0.0)
207 }
208
209 fn compute_ks(&self, synthetic: &[f64], reference: &[f64]) -> f64 {
211 let mut syn_sorted = synthetic.to_vec();
212 let mut ref_sorted = reference.to_vec();
213 syn_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
214 ref_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
215
216 let syn_n = syn_sorted.len() as f64;
217 let ref_n = ref_sorted.len() as f64;
218
219 let mut max_diff = 0.0_f64;
220 let mut i = 0usize;
221 let mut j = 0usize;
222
223 while i < syn_sorted.len() && j < ref_sorted.len() {
224 let syn_cdf = (i + 1) as f64 / syn_n;
225 let ref_cdf = (j + 1) as f64 / ref_n;
226
227 if syn_sorted[i] <= ref_sorted[j] {
228 let diff = (syn_cdf - (j as f64 / ref_n)).abs();
229 if diff > max_diff {
230 max_diff = diff;
231 }
232 i += 1;
233 } else {
234 let diff = ((i as f64 / syn_n) - ref_cdf).abs();
235 if diff > max_diff {
236 max_diff = diff;
237 }
238 j += 1;
239 }
240 }
241
242 while i < syn_sorted.len() {
244 let syn_cdf = (i + 1) as f64 / syn_n;
245 let diff = (syn_cdf - 1.0).abs();
246 if diff > max_diff {
247 max_diff = diff;
248 }
249 i += 1;
250 }
251 while j < ref_sorted.len() {
252 let ref_cdf = (j + 1) as f64 / ref_n;
253 let diff = (1.0 - ref_cdf).abs();
254 if diff > max_diff {
255 max_diff = diff;
256 }
257 j += 1;
258 }
259
260 max_diff
261 }
262
263 fn compute_mmd(&self, synthetic: &[f64], reference: &[f64]) -> f64 {
267 let max_samples = 1000;
268 let syn_sub = subsample(synthetic, max_samples);
269 let ref_sub = subsample(reference, max_samples);
270
271 if syn_sub.is_empty() || ref_sub.is_empty() {
272 return 0.0;
273 }
274
275 let sigma = self.median_bandwidth(&syn_sub, &ref_sub);
277 if sigma < 1e-12 {
278 return 0.0;
279 }
280
281 let gamma = -1.0 / (2.0 * sigma * sigma);
282
283 let k_xx = self.mean_kernel(&syn_sub, &syn_sub, gamma);
284 let k_yy = self.mean_kernel(&ref_sub, &ref_sub, gamma);
285 let k_xy = self.mean_kernel(&syn_sub, &ref_sub, gamma);
286
287 (k_xx + k_yy - 2.0 * k_xy).max(0.0).sqrt()
288 }
289
290 fn mean_kernel(&self, x: &[f64], y: &[f64], gamma: f64) -> f64 {
292 let mut sum = 0.0;
293 for &xi in x {
294 for &yi in y {
295 let diff = xi - yi;
296 sum += (gamma * diff * diff).exp();
297 }
298 }
299 sum / (x.len() as f64 * y.len() as f64)
300 }
301
302 fn median_bandwidth(&self, x: &[f64], y: &[f64]) -> f64 {
304 let mut dists = Vec::new();
305 let step_x = if x.len() > 50 { x.len() / 50 } else { 1 };
306 let step_y = if y.len() > 50 { y.len() / 50 } else { 1 };
307
308 let mut ix = 0;
309 while ix < x.len() {
310 let mut iy = 0;
311 while iy < y.len() {
312 dists.push((x[ix] - y[iy]).abs());
313 iy += step_y;
314 }
315 ix += step_x;
316 }
317
318 if dists.is_empty() {
319 return 1.0;
320 }
321
322 dists.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
323 dists[dists.len() / 2].max(1e-6)
324 }
325}
326
/// Returns at most `max` values from `data`, picked at evenly spaced positions.
///
/// When `data` already fits within `max` it is returned unchanged. Otherwise the `k`-th
/// output is `data[k * len / max]`, so the picks span the whole slice rather than only a
/// prefix of it. When `len` is an exact multiple of `max` this is a plain fixed stride of
/// `len / max`. A `max` of zero yields an empty vector.
fn subsample(data: &[f64], max: usize) -> Vec<f64> {
    let len = data.len();
    if len <= max {
        return data.to_vec();
    }
    // `len > max` here, so `max == 0` means "take nothing" from non-empty data;
    // returning early also avoids a division by zero below.
    if max == 0 {
        return Vec::new();
    }

    (0..max)
        .map(|k| {
            // Widen so `k * len` cannot overflow; the quotient is always < `len`.
            let idx = (k as u128 * len as u128 / max as u128) as usize;
            data[idx]
        })
        .collect()
}
335
336impl Default for DomainGapAnalyzer {
337 fn default() -> Self {
338 Self::new()
339 }
340}
341
#[cfg(test)]
// Tests unwrap results directly; a failure there should abort the test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Wraps one synthetic/reference pair into a single-element sample list named "amount".
    fn amount_sample(synthetic: Vec<f64>, reference: Vec<f64>) -> Vec<DistributionSample> {
        vec![DistributionSample {
            name: "amount".to_string(),
            synthetic_values: synthetic,
            reference_values: reference,
        }]
    }

    /// Two identical value sets must stay below the default threshold and pass.
    #[test]
    fn test_identical_distributions() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let samples = amount_sample(values.clone(), values);

        let result = DomainGapAnalyzer::new().analyze(&samples).unwrap();

        assert!(result.domain_gap_score < 0.25);
        assert!(result.passes);
    }

    /// Value sets with no overlap must exceed the default threshold and fail.
    #[test]
    fn test_divergent_distributions() {
        let samples = amount_sample(
            vec![1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5],
            vec![50.0, 55.0, 60.0, 65.0, 70.0, 75.0, 80.0, 85.0, 90.0, 95.0],
        );

        let result = DomainGapAnalyzer::new().analyze(&samples).unwrap();

        assert!(result.domain_gap_score > 0.25);
        assert!(!result.passes);
    }

    /// An empty sample list reports zero distributions and a zero score.
    #[test]
    fn test_empty_samples() {
        let result = DomainGapAnalyzer::new().analyze(&[]).unwrap();

        assert_eq!(result.total_distributions, 0);
        assert_eq!(result.domain_gap_score, 0.0);
    }
}