use crate::error::EvalResult;
use serde::{Deserialize, Serialize};
/// One named feature to compare: synthetic observations alongside reference
/// observations of the same quantity.
#[derive(Debug, Clone, PartialEq)]
pub struct DistributionSample {
    /// Label copied into the matching [`DomainGapDetail`].
    pub name: String,
    /// Observations from the synthetic data; if empty, the sample is not compared.
    pub synthetic_values: Vec<f64>,
    /// Observations from the reference data; if empty, the sample is not compared.
    pub reference_values: Vec<f64>,
}
/// Pass/fail limits applied by [`DomainGapAnalyzer`].
#[derive(Debug, Clone, PartialEq)]
pub struct DomainGapThresholds {
    /// Largest aggregate domain-gap score (on the 0-1 scale) that still passes;
    /// a score strictly above this value is reported as an issue.
    pub max_domain_gap: f64,
}
impl Default for DomainGapThresholds {
    /// Fails any analysis whose aggregate domain-gap score exceeds 0.25.
    fn default() -> Self {
        let max_domain_gap = 0.25;
        Self { max_domain_gap }
    }
}
/// Divergence statistics computed for one [`DistributionSample`].
///
/// All three statistics are 0.0 when either side of the sample was empty.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DomainGapDetail {
    /// Name copied from the originating sample.
    pub name: String,
    /// Raw Population Stability Index over 10 equal-width bins (before the
    /// normalisation applied when building the aggregate score).
    pub psi: f64,
    /// Two-sample Kolmogorov-Smirnov statistic, in `[0, 1]`.
    pub ks_statistic: f64,
    /// Gaussian-kernel maximum mean discrepancy (square root of the biased
    /// estimate, median-heuristic bandwidth).
    pub mmd: f64,
}
/// Result of [`DomainGapAnalyzer::analyze`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DomainGapAnalysis {
    /// Aggregate gap score in `[0, 1]` (0.0 when nothing could be compared);
    /// higher means the synthetic data is further from the reference.
    pub domain_gap_score: f64,
    /// One entry per input sample, in input order.
    pub per_distribution: Vec<DomainGapDetail>,
    /// Number of samples passed to the analyzer, including ones with an empty side.
    pub total_distributions: usize,
    /// `false` when `domain_gap_score` exceeds
    /// [`DomainGapThresholds::max_domain_gap`].
    pub passes: bool,
    /// Human-readable findings. This may be non-empty even when `passes` is
    /// `true` (for example, when no distributions were provided).
    pub issues: Vec<String>,
}
/// Scores how far synthetic value distributions are from their references and
/// checks the result against a [`DomainGapThresholds`].
#[derive(Debug, Clone)]
pub struct DomainGapAnalyzer {
    thresholds: DomainGapThresholds,
}
impl DomainGapAnalyzer {
    /// Creates an analyzer that uses [`DomainGapThresholds::default`].
    pub fn new() -> Self {
        Self {
            thresholds: DomainGapThresholds::default(),
        }
    }

    /// Creates an analyzer that uses the given thresholds.
    pub fn with_thresholds(thresholds: DomainGapThresholds) -> Self {
        Self { thresholds }
    }

    /// Scores how far each synthetic distribution is from its reference.
    ///
    /// For every sample whose synthetic and reference values are both
    /// non-empty, three statistics are computed: PSI, the two-sample KS
    /// statistic and a Gaussian-kernel MMD. Each is mapped onto `[0, 1]` (PSI
    /// is divided by 0.5, then all three are clamped), and the three are
    /// averaged into a per-sample gap. `domain_gap_score` is the mean of those
    /// gaps over the samples that were actually compared.
    ///
    /// Samples with an empty side still get an all-zero entry in
    /// `per_distribution` and count towards `total_distributions`. They are
    /// excluded from the mean so they cannot dilute the score of the samples
    /// that were compared.
    ///
    /// An empty `samples` slice yields a score of 0.0, `passes == true` and a
    /// "No distributions provided" issue.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`; the `EvalResult` return type is part of
    /// the public interface.
    pub fn analyze(&self, samples: &[DistributionSample]) -> EvalResult<DomainGapAnalysis> {
        let total_distributions = samples.len();
        if samples.is_empty() {
            return Ok(DomainGapAnalysis {
                domain_gap_score: 0.0,
                per_distribution: Vec::new(),
                total_distributions: 0,
                passes: true,
                issues: vec!["No distributions provided".to_string()],
            });
        }

        let mut details = Vec::with_capacity(total_distributions);
        let mut gap_sum = 0.0;
        // Number of samples that actually contributed to `gap_sum`.
        let mut evaluated = 0usize;
        for sample in samples {
            let (synthetic, reference) = (&sample.synthetic_values, &sample.reference_values);
            if synthetic.is_empty() || reference.is_empty() {
                // Nothing to compare: report a zeroed entry, but keep it out of
                // the mean so it does not read as a perfect match.
                details.push(DomainGapDetail {
                    name: sample.name.clone(),
                    psi: 0.0,
                    ks_statistic: 0.0,
                    mmd: 0.0,
                });
                continue;
            }
            let psi = self.compute_psi(synthetic, reference);
            let ks = self.compute_ks(synthetic, reference);
            let mmd = self.compute_mmd(synthetic, reference);
            // PSI is unbounded above; values of 0.5 or more saturate at 1.0.
            let psi_norm = (psi / 0.5).clamp(0.0, 1.0);
            let ks_norm = ks.clamp(0.0, 1.0);
            let mmd_norm = mmd.clamp(0.0, 1.0);
            gap_sum += (psi_norm + ks_norm + mmd_norm) / 3.0;
            evaluated += 1;
            details.push(DomainGapDetail {
                name: sample.name.clone(),
                psi,
                ks_statistic: ks,
                mmd,
            });
        }

        let domain_gap_score = if evaluated > 0 {
            (gap_sum / evaluated as f64).clamp(0.0, 1.0)
        } else {
            0.0
        };

        let mut issues = Vec::new();
        if domain_gap_score > self.thresholds.max_domain_gap {
            issues.push(format!(
                "Domain gap score {:.4} > {:.4} (threshold)",
                domain_gap_score, self.thresholds.max_domain_gap
            ));
        }
        let passes = issues.is_empty();
        Ok(DomainGapAnalysis {
            domain_gap_score,
            per_distribution: details,
            total_distributions,
            passes,
            issues,
        })
    }

    /// Population Stability Index between the two samples.
    ///
    /// Both samples are bucketed into 10 equal-width bins spanning their
    /// combined range. `1e-6` is added to every bin proportion so empty bins do
    /// not produce `ln(0)` or a division by zero. Returns 0.0 when all values
    /// (effectively) coincide.
    fn compute_psi(&self, synthetic: &[f64], reference: &[f64]) -> f64 {
        let num_bins = 10;
        let epsilon = 1e-6;
        let all_min = synthetic
            .iter()
            .chain(reference.iter())
            .cloned()
            .fold(f64::INFINITY, f64::min);
        let all_max = synthetic
            .iter()
            .chain(reference.iter())
            .cloned()
            .fold(f64::NEG_INFINITY, f64::max);
        if (all_max - all_min).abs() < 1e-12 {
            return 0.0;
        }
        let bin_width = (all_max - all_min) / num_bins as f64;
        // The maximum value lands exactly on the upper edge; fold it into the last bin.
        let bin_index = |val: f64| -> usize {
            let idx = ((val - all_min) / bin_width) as usize;
            idx.min(num_bins - 1)
        };
        let mut syn_counts = vec![0usize; num_bins];
        let mut ref_counts = vec![0usize; num_bins];
        for &v in synthetic {
            syn_counts[bin_index(v)] += 1;
        }
        for &v in reference {
            ref_counts[bin_index(v)] += 1;
        }
        let syn_total = synthetic.len() as f64;
        let ref_total = reference.len() as f64;
        let mut psi = 0.0;
        for i in 0..num_bins {
            let p = (syn_counts[i] as f64 / syn_total) + epsilon;
            let q = (ref_counts[i] as f64 / ref_total) + epsilon;
            psi += (p - q) * (p / q).ln();
        }
        psi.max(0.0)
    }

    /// Two-sample Kolmogorov-Smirnov statistic: the largest absolute difference
    /// between the two empirical CDFs.
    ///
    /// All copies of a value, in either sample, are consumed before the CDFs
    /// are compared. Tied observations therefore never create a spurious gap,
    /// and identical inputs give 0.0. Returns 0.0 if either input is empty.
    fn compute_ks(&self, synthetic: &[f64], reference: &[f64]) -> f64 {
        // Canonicalise -0.0 so signed zeros tie, then sort with a total order.
        // A `partial_cmp` fallback is not a total order when NaN is present,
        // which `sort_by` is allowed to panic on.
        let sorted = |values: &[f64]| -> Vec<f64> {
            let mut out: Vec<f64> = values
                .iter()
                .map(|&v| if v == 0.0 { 0.0 } else { v })
                .collect();
            out.sort_by(f64::total_cmp);
            out
        };
        let syn_sorted = sorted(synthetic);
        let ref_sorted = sorted(reference);
        let (syn_len, ref_len) = (syn_sorted.len(), ref_sorted.len());
        let (syn_n, ref_n) = (syn_len as f64, ref_len as f64);

        let (mut i, mut j) = (0usize, 0usize);
        let mut max_diff = 0.0_f64;
        while i < syn_len && j < ref_len {
            // Next distinct value in the merged order.
            let x = if syn_sorted[i].total_cmp(&ref_sorted[j]).is_le() {
                syn_sorted[i]
            } else {
                ref_sorted[j]
            };
            // `x` equals the current head of one side, so at least one index
            // always advances and the loop terminates.
            while i < syn_len && syn_sorted[i].total_cmp(&x).is_le() {
                i += 1;
            }
            while j < ref_len && ref_sorted[j].total_cmp(&x).is_le() {
                j += 1;
            }
            max_diff = max_diff.max((i as f64 / syn_n - j as f64 / ref_n).abs());
        }
        // Once one side is exhausted its CDF is 1, so the remaining
        // differences only shrink; the maximum has already been seen.
        max_diff
    }

    /// Maximum mean discrepancy with a Gaussian (RBF) kernel.
    ///
    /// Each input is subsampled to at most 1000 values, and the bandwidth is
    /// chosen by the median heuristic (`median_bandwidth`). The result is the
    /// square root of the biased (diagonal-included) MMD^2 estimate. Cost is
    /// quadratic in the subsample sizes.
    fn compute_mmd(&self, synthetic: &[f64], reference: &[f64]) -> f64 {
        let max_samples = 1000;
        let syn_sub = subsample(synthetic, max_samples);
        let ref_sub = subsample(reference, max_samples);
        if syn_sub.is_empty() || ref_sub.is_empty() {
            return 0.0;
        }
        let sigma = self.median_bandwidth(&syn_sub, &ref_sub);
        // Defensive: `median_bandwidth` already floors its result at 1e-6.
        if sigma < 1e-12 {
            return 0.0;
        }
        let gamma = -1.0 / (2.0 * sigma * sigma);
        let k_xx = self.mean_kernel(&syn_sub, &syn_sub, gamma);
        let k_yy = self.mean_kernel(&ref_sub, &ref_sub, gamma);
        let k_xy = self.mean_kernel(&syn_sub, &ref_sub, gamma);
        // Rounding can push the estimate slightly negative; clamp before sqrt.
        (k_xx + k_yy - 2.0 * k_xy).max(0.0).sqrt()
    }

    /// Mean of `exp(gamma * (xi - yi)^2)` over every pair `(xi, yi)`.
    ///
    /// The caller passes a negative `gamma`. The result is NaN if either slice
    /// is empty.
    fn mean_kernel(&self, x: &[f64], y: &[f64], gamma: f64) -> f64 {
        let mut sum = 0.0;
        for &xi in x {
            for &yi in y {
                let diff = xi - yi;
                sum += (gamma * diff * diff).exp();
            }
        }
        sum / (x.len() as f64 * y.len() as f64)
    }

    /// Median pairwise `|x - y|` distance, used as the RBF bandwidth.
    ///
    /// A slice longer than 50 is strided with step `len / 50` to bound the
    /// number of pairs. Returns the upper median floored at `1e-6`, or 1.0 when
    /// there are no pairs.
    fn median_bandwidth(&self, x: &[f64], y: &[f64]) -> f64 {
        let step_x = (x.len() / 50).max(1);
        let step_y = (y.len() / 50).max(1);
        let mut dists: Vec<f64> = x
            .iter()
            .step_by(step_x)
            .flat_map(|&xv| y.iter().step_by(step_y).map(move |&yv| (xv - yv).abs()))
            .collect();
        if dists.is_empty() {
            return 1.0;
        }
        // Total order: NaN-safe and honours `sort_by`'s ordering contract.
        dists.sort_by(f64::total_cmp);
        dists[dists.len() / 2].max(1e-6)
    }
}
/// Returns at most `max` values picked at evenly spaced positions across `data`.
///
/// If `data` already fits, it is copied unchanged. Otherwise, for each `k` in
/// `0..max`, the element at index `k * len / max` is taken, so the picks span
/// the whole slice. `max == 0` yields an empty vector.
fn subsample(data: &[f64], max: usize) -> Vec<f64> {
    if data.len() <= max {
        return data.to_vec();
    }
    // A fixed `step_by(len / max)` followed by `take(max)` only reaches the
    // prefix when `len` is not a multiple of `max` (e.g. 1500 -> first 1000),
    // which biases the sample. Integer scaling spreads the picks instead.
    // Widen to u128 so `k * len` cannot overflow, even on 32-bit targets.
    let len = data.len() as u128;
    let count = max as u128;
    // `k < count` keeps every index strictly below `len`. The range is empty
    // when `max == 0`, so there is no division by zero.
    (0..count)
        .map(|k| data[(k * len / count) as usize])
        .collect()
}
impl Default for DomainGapAnalyzer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
// Unwrapping is acceptable here: a failed analysis should fail the test loudly.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Wraps one synthetic/reference pair in a single-element sample list named "amount".
    fn amount_samples(
        synthetic_values: Vec<f64>,
        reference_values: Vec<f64>,
    ) -> Vec<DistributionSample> {
        vec![DistributionSample {
            name: "amount".to_string(),
            synthetic_values,
            reference_values,
        }]
    }

    #[test]
    fn test_identical_distributions() {
        // 1.0, 2.0, ..., 10.0 on both sides.
        let values: Vec<f64> = (1..=10_i32).map(f64::from).collect();
        let samples = amount_samples(values.clone(), values);

        let result = DomainGapAnalyzer::new().analyze(&samples).unwrap();

        assert!(result.domain_gap_score < 0.25);
        assert!(result.passes);
    }

    #[test]
    fn test_divergent_distributions() {
        // Synthetic 1.0..=5.5 in steps of 0.5; reference 50.0..=95.0 in steps of 5.0.
        let synthetic: Vec<f64> = (0..10_i32).map(|k| 1.0 + 0.5 * f64::from(k)).collect();
        let reference: Vec<f64> = (0..10_i32).map(|k| 50.0 + 5.0 * f64::from(k)).collect();
        let samples = amount_samples(synthetic, reference);

        let result = DomainGapAnalyzer::new().analyze(&samples).unwrap();

        assert!(result.domain_gap_score > 0.25);
        assert!(!result.passes);
    }

    #[test]
    fn test_empty_samples() {
        let result = DomainGapAnalyzer::new().analyze(&[]).unwrap();
        assert_eq!(result.total_distributions, 0);
        assert_eq!(result.domain_gap_score, 0.0);
    }
}