use std::collections::HashMap;
use serde::Serialize;
use crate::error::FingerprintResult;
use crate::extraction::{DataSource, FingerprintExtractor};
use crate::models::{Fingerprint, NumericStats, StatisticsFingerprint};
/// Tuning knobs for [`FidelityEvaluator`]: the pass/fail threshold and the weight of each
/// component score in the overall score.
///
/// The overall score is a plain weighted sum of the component scores (it is not divided by
/// the total weight), so weights summing to 1.0 keep it on the same [0, 1] scale as the
/// components and the threshold.
#[derive(Debug, Clone)]
pub struct FidelityConfig {
/// Minimum overall score (inclusive) for a report to be marked as passing.
pub threshold: f64,
/// Weight of the statistical (per-column distribution) fidelity score.
pub statistical_weight: f64,
/// Weight of the correlation-matrix fidelity score.
pub correlation_weight: f64,
/// Weight of the schema (tables, columns, row counts) fidelity score.
pub schema_weight: f64,
/// Weight of the balance-rule compliance score.
pub rule_weight: f64,
/// Weight of the anomaly-rate fidelity score.
pub anomaly_weight: f64,
}
impl Default for FidelityConfig {
fn default() -> Self {
Self {
threshold: 0.8,
statistical_weight: 0.30,
correlation_weight: 0.20,
schema_weight: 0.20,
rule_weight: 0.20,
anomaly_weight: 0.10,
}
}
}
/// Result of comparing a synthetic fingerprint against an original one.
///
/// Component scores are fractions in [0, 1] where higher means closer to the original;
/// `overall_score` is their weighted sum using the evaluator's [`FidelityConfig`].
#[derive(Debug, Clone, Serialize)]
pub struct FidelityReport {
/// Weighted sum of the five component scores.
pub overall_score: f64,
/// Similarity of per-column distributions (and Benford first-digit frequencies).
pub statistical_fidelity: f64,
/// `1 - RMSE` over paired correlation-matrix entries. May be replaced by the schema score
/// when the schemas barely overlap.
pub correlation_fidelity: f64,
/// Similarity of table names, column names and total row counts.
pub schema_fidelity: f64,
/// Agreement of balance-rule compliance rates. May be replaced by the schema score when
/// the schemas barely overlap.
pub rule_compliance: f64,
/// Agreement of overall anomaly rates. May be replaced by the schema score when the
/// schemas barely overlap.
pub anomaly_fidelity: f64,
/// Whether `overall_score` reached the configured threshold (inclusive).
pub passes: bool,
/// Raw metrics and warnings behind the scores.
pub details: FidelityDetails,
}
/// Raw metrics and diagnostics collected while computing a [`FidelityReport`].
#[derive(Debug, Clone, Default, Serialize)]
pub struct FidelityDetails {
/// Comparison metrics for each matched numeric column, keyed by the original column name.
pub column_metrics: HashMap<String, ColumnFidelityMetrics>,
/// Percentile-based KS-style distance per matched numeric column.
pub ks_statistics: HashMap<String, f64>,
/// Approximate 1-Wasserstein distance per matched numeric column, in data units.
pub wasserstein_distances: HashMap<String, f64>,
/// Approximate Jensen–Shannon divergence (natural log) per matched numeric column.
pub js_divergences: HashMap<String, f64>,
/// Mean absolute deviation between first-digit frequency vectors; `None` unless both
/// fingerprints carry a Benford analysis.
pub benford_mad: Option<f64>,
/// RMSE over paired correlation-matrix entries; `None` when nothing could be paired.
pub correlation_rmse: Option<f64>,
/// Total synthetic rows divided by total original rows (1.0 when the original has no
/// rows). Stays at the default 0.0 when neither schema has any tables.
pub row_count_ratio: f64,
/// Human-readable warnings raised during evaluation (e.g. missing tables, no matched
/// columns).
pub warnings: Vec<String>,
}
/// Comparison metrics for one numeric column matched between original and synthetic data.
#[derive(Debug, Clone, Serialize)]
pub struct ColumnFidelityMetrics {
/// Column name as passed to the comparison (the unqualified name when the column was
/// matched by the part after the last `.`).
pub name: String,
/// Largest gap between corresponding percentiles, divided by the original range (floored
/// at 1.0). A KS-style heuristic, not an exact Kolmogorov–Smirnov statistic.
pub ks_statistic: f64,
/// Approximate 1-Wasserstein distance reconstructed from percentiles, in data units.
pub wasserstein_distance: f64,
/// Approximate Jensen–Shannon divergence (natural log) reconstructed from percentiles.
pub js_divergence: f64,
/// Signed mean difference (original − synthetic), divided by the original range (floored
/// at 1.0).
pub mean_diff: f64,
/// Signed standard-deviation difference (original − synthetic), normalized; see
/// `FidelityEvaluator::compare_numeric_stats` for the scaling rules.
pub std_dev_diff: f64,
}
/// Compares fingerprints of original and synthetic data and produces [`FidelityReport`]s.
pub struct FidelityEvaluator {
/// Pass threshold and component weights used for scoring.
config: FidelityConfig,
}
impl FidelityEvaluator {
pub fn new() -> Self {
Self {
config: FidelityConfig::default(),
}
}
pub fn with_threshold(threshold: f64) -> Self {
Self {
config: FidelityConfig {
threshold,
..Default::default()
},
}
}
pub fn with_config(config: FidelityConfig) -> Self {
Self { config }
}
/// Extracts a fingerprint from `synthetic_data` and compares it against `fingerprint`.
///
/// # Errors
/// Propagates any error from fingerprint extraction, or from
/// [`Self::evaluate_fingerprints`].
pub fn evaluate(
    &self,
    fingerprint: &Fingerprint,
    synthetic_data: &DataSource,
) -> FingerprintResult<FidelityReport> {
    let extractor = FingerprintExtractor::new();
    let synthetic_fingerprint = extractor.extract(synthetic_data)?;
    self.evaluate_fingerprints(fingerprint, &synthetic_fingerprint)
}
/// Compares two fingerprints and builds a [`FidelityReport`].
///
/// The five component scores are combined as a weighted sum using the weights in
/// [`FidelityConfig`]. The report passes when that sum is at least the configured threshold.
/// Diagnostics gathered by the component evaluators are returned in
/// [`FidelityReport::details`].
///
/// The correlation, rule and anomaly evaluators return 1.0 when they have nothing to
/// compare. A neutral score like that should not mask a poor schema match. So whenever the
/// schema score is below 0.5, any of those three components scoring 0.99 or more is
/// replaced by the schema score.
///
/// # Errors
/// This implementation always returns `Ok`.
pub fn evaluate_fingerprints(
    &self,
    original: &Fingerprint,
    synthetic: &Fingerprint,
) -> FingerprintResult<FidelityReport> {
    let mut details = FidelityDetails::default();

    // The schema score is computed first because it gates three of the other components.
    let schema_fidelity = self.evaluate_schema(original, synthetic, &mut details);
    let statistical_fidelity =
        self.evaluate_statistical(&original.statistics, &synthetic.statistics, &mut details);

    let gate_by_schema = |raw: f64| {
        if raw >= 0.99 && schema_fidelity < 0.5 {
            schema_fidelity
        } else {
            raw
        }
    };
    let correlation_fidelity =
        gate_by_schema(self.evaluate_correlations(original, synthetic, &mut details));
    let rule_compliance = gate_by_schema(self.evaluate_rules(original, synthetic, &mut details));
    let anomaly_fidelity =
        gate_by_schema(self.evaluate_anomalies(original, synthetic, &mut details));

    let weights = &self.config;
    let overall_score = weights.statistical_weight * statistical_fidelity
        + weights.correlation_weight * correlation_fidelity
        + weights.schema_weight * schema_fidelity
        + weights.rule_weight * rule_compliance
        + weights.anomaly_weight * anomaly_fidelity;

    Ok(FidelityReport {
        passes: overall_score >= weights.threshold,
        overall_score,
        statistical_fidelity,
        correlation_fidelity,
        schema_fidelity,
        rule_compliance,
        anomaly_fidelity,
        details,
    })
}
fn evaluate_statistical(
&self,
original: &StatisticsFingerprint,
synthetic: &StatisticsFingerprint,
details: &mut FidelityDetails,
) -> f64 {
let mut scores = Vec::new();
let orig_numeric_count = original.numeric_columns.len();
let orig_categorical_count = original.categorical_columns.len();
let total_orig_columns = orig_numeric_count + orig_categorical_count;
let mut matched_numeric = 0;
let mut matched_categorical = 0;
let syn_numeric_by_stripped: HashMap<String, (&str, &NumericStats)> = synthetic
.numeric_columns
.iter()
.map(|(k, v)| {
let stripped = k.split('.').next_back().unwrap_or(k).to_string();
(stripped, (k.as_str(), v))
})
.collect();
let syn_categorical_keys: std::collections::HashSet<String> = synthetic
.categorical_columns
.keys()
.map(|k| k.split('.').next_back().unwrap_or(k).to_string())
.collect();
for (col_name, orig_stats) in &original.numeric_columns {
let stripped = col_name.split('.').next_back().unwrap_or(col_name);
if let Some(syn_stats) = synthetic.numeric_columns.get(col_name) {
matched_numeric += 1;
let metrics = self.compare_numeric_stats(col_name, orig_stats, syn_stats);
let col_score = 1.0
- (metrics.ks_statistic
+ metrics.mean_diff.abs().min(1.0)
+ metrics.std_dev_diff.abs().min(1.0))
/ 3.0;
scores.push(col_score.max(0.0));
details
.ks_statistics
.insert(col_name.clone(), metrics.ks_statistic);
details
.wasserstein_distances
.insert(col_name.clone(), metrics.wasserstein_distance);
details
.js_divergences
.insert(col_name.clone(), metrics.js_divergence);
details.column_metrics.insert(col_name.clone(), metrics);
} else if let Some((_syn_key, syn_stats)) = syn_numeric_by_stripped.get(stripped) {
matched_numeric += 1;
let metrics = self.compare_numeric_stats(stripped, orig_stats, syn_stats);
let col_score = 1.0
- (metrics.ks_statistic
+ metrics.mean_diff.abs().min(1.0)
+ metrics.std_dev_diff.abs().min(1.0))
/ 3.0;
scores.push(col_score.max(0.0));
details
.ks_statistics
.insert(col_name.clone(), metrics.ks_statistic);
details
.wasserstein_distances
.insert(col_name.clone(), metrics.wasserstein_distance);
details
.js_divergences
.insert(col_name.clone(), metrics.js_divergence);
details.column_metrics.insert(col_name.clone(), metrics);
}
}
for col_name in original.categorical_columns.keys() {
let stripped = col_name.split('.').next_back().unwrap_or(col_name);
if synthetic.categorical_columns.contains_key(col_name)
|| syn_categorical_keys.contains(stripped)
{
matched_categorical += 1;
}
}
if let (Some(orig_benford), Some(syn_benford)) =
(&original.benford_analysis, &synthetic.benford_analysis)
{
let benford_mad = compute_benford_mad(
&orig_benford.observed_frequencies,
&syn_benford.observed_frequencies,
);
details.benford_mad = Some(benford_mad);
scores.push(1.0 - benford_mad.min(0.1) * 10.0); }
if total_orig_columns > 0 {
let total_matched = matched_numeric + matched_categorical;
let match_ratio = total_matched as f64 / total_orig_columns as f64;
if total_matched == 0 {
details.warnings.push(
"No columns matched between original and synthetic fingerprints".to_string(),
);
return 0.0;
}
if scores.is_empty() {
return match_ratio;
}
let avg_score = scores.iter().sum::<f64>() / scores.len() as f64;
return avg_score * match_ratio;
}
if scores.is_empty() {
return 1.0; }
scores.iter().sum::<f64>() / scores.len() as f64
}
/// Compares the summary statistics of one numeric column between original and synthetic data.
///
/// Location and spread differences are signed (original − synthetic) and scaled relative to
/// the original column:
/// - `mean_diff` is divided by the original range (`max - min`, floored at 1.0).
/// - `std_dev_diff` is divided by the original standard deviation when it is positive.
///   When the original column is constant (`std_dev == 0`), it is divided by the same range
///   as the mean instead, so spread introduced by the synthetic data is still reported.
///   Any other original value (negative or NaN) yields 0.0.
///
/// The distribution-shape metrics (KS-style distance, Wasserstein, Jensen–Shannon) are
/// derived from the stored percentiles.
fn compare_numeric_stats(
    &self,
    name: &str,
    original: &NumericStats,
    synthetic: &NumericStats,
) -> ColumnFidelityMetrics {
    let ks_statistic = self.compute_percentile_ks(original, synthetic);
    // Floored at 1.0, which also avoids dividing by zero for constant columns.
    let value_range = (original.max - original.min).max(1.0);
    let mean_diff = (original.mean - synthetic.mean) / value_range;
    let std_dev_diff = if original.std_dev > 0.0 {
        (original.std_dev - synthetic.std_dev) / original.std_dev
    } else if original.std_dev == 0.0 {
        // A relative difference is undefined for a constant original column. Scale by the
        // range so a synthetic column with spread is not reported as a perfect match.
        (original.std_dev - synthetic.std_dev) / value_range
    } else {
        0.0
    };
    let wasserstein_distance =
        wasserstein_distance_from_percentiles(&original.percentiles, &synthetic.percentiles);
    let js_divergence =
        js_divergence_from_percentiles(&original.percentiles, &synthetic.percentiles);
    ColumnFidelityMetrics {
        name: name.to_string(),
        ks_statistic,
        wasserstein_distance,
        js_divergence,
        mean_diff,
        std_dev_diff,
    }
}
/// Approximates a Kolmogorov–Smirnov-style distance from the stored percentiles.
///
/// The result is the largest absolute gap between corresponding percentile values, divided
/// by the original range (`max - min`, floored at 1.0). It compares quantiles rather than
/// CDFs, so it is a heuristic rather than the exact KS statistic. It can exceed 1.0 when the
/// synthetic percentiles fall far outside the original range.
fn compute_percentile_ks(&self, original: &NumericStats, synthetic: &NumericStats) -> f64 {
    let orig_quantiles = original.percentiles.to_array();
    let syn_quantiles = synthetic.percentiles.to_array();
    let scale = (original.max - original.min).max(1.0);
    let mut largest_gap = 0.0_f64;
    for (&o, &s) in orig_quantiles.iter().zip(syn_quantiles.iter()) {
        largest_gap = largest_gap.max(((o - s) / scale).abs());
    }
    largest_gap
}
/// Scores how well the synthetic correlation matrices reproduce the original ones.
///
/// Matrices are paired by table name. Entries are paired by their position in each matrix's
/// flat `correlations` sequence, up to the shorter of the two. The score is `1 - RMSE`, with
/// RMSE capped at 1.0, and the RMSE is recorded in `details`. Returns 1.0 without recording
/// an RMSE when either fingerprint lacks correlations or no entries could be paired.
fn evaluate_correlations(
    &self,
    original: &Fingerprint,
    synthetic: &Fingerprint,
    details: &mut FidelityDetails,
) -> f64 {
    let (orig_corr, syn_corr) = match (&original.correlations, &synthetic.correlations) {
        (Some(o), Some(s)) => (o, s),
        // Nothing to compare: neutral score.
        _ => return 1.0,
    };

    let mut squared_error_total = 0.0;
    let mut paired = 0usize;
    for (table_name, orig_matrix) in &orig_corr.matrices {
        let syn_matrix = match syn_corr.matrices.get(table_name) {
            Some(matrix) => matrix,
            None => continue,
        };
        for (a, b) in orig_matrix
            .correlations
            .iter()
            .zip(syn_matrix.correlations.iter())
        {
            squared_error_total += (a - b).powi(2);
            paired += 1;
        }
    }

    if paired == 0 {
        return 1.0;
    }
    let rmse = (squared_error_total / paired as f64).sqrt();
    details.correlation_rmse = Some(rmse);
    1.0 - rmse.min(1.0)
}
/// Scores how similar the two schemas are, clamped to [0, 1].
///
/// The score is `0.4 * table_ratio + 0.4 * column_score + 0.2 * (1 - row_penalty)`:
/// - `table_ratio`: shared table names divided by the larger table count. It is replaced by
///   1.0 when no names are shared but the best-match column score exceeds 0.8 (the tables
///   look renamed).
/// - `column_score`: mean per-table column-name overlap (shared names divided by the larger
///   column count). Tables with shared names are compared by name. When no table names are
///   shared but both sides have the same number of tables, each original table is compared
///   with its best-matching synthetic table.
/// - `row_penalty`: `0.2 * min(|syn_rows / orig_rows - 1|, 1)`, so the row-count term
///   contributes between 0.16 and 0.2.
///
/// Returns 1.0 when neither side has tables, without recording `row_count_ratio`. Returns
/// 0.0 with a warning when no table names are shared and the table counts differ.
/// Otherwise records the total row-count ratio in `details`, plus a warning when some
/// (but not all) original tables are missing.
fn evaluate_schema(
&self,
original: &Fingerprint,
synthetic: &Fingerprint,
details: &mut FidelityDetails,
) -> f64 {
let orig_tables: std::collections::HashSet<_> = original.schema.tables.keys().collect();
let syn_tables: std::collections::HashSet<_> = synthetic.schema.tables.keys().collect();
let common_tables = orig_tables.intersection(&syn_tables).count();
let total_tables = orig_tables.len().max(syn_tables.len());
// Both schemas are empty: nothing to compare.
if total_tables == 0 {
return 1.0; }
let table_overlap_ratio = common_tables as f64 / total_tables as f64;
let orig_rows: u64 = original.schema.tables.values().map(|t| t.row_count).sum();
let syn_rows: u64 = synthetic.schema.tables.values().map(|t| t.row_count).sum();
// With no original rows the ratio is taken as 1.0 (no penalty), whatever the synthetic count.
let ratio = if orig_rows > 0 {
syn_rows as f64 / orig_rows as f64
} else {
1.0
};
details.row_count_ratio = ratio;
// In [0, 0.2]: a deviation of 100% or more in total rows costs the full 0.2.
let ratio_penalty = (ratio - 1.0).abs().min(1.0) * 0.2;
let mut column_match_scores = Vec::new();
if common_tables > 0 {
// Compare column names of tables present on both sides; tables with no columns on
// either side contribute no score.
for (table_name, orig_table) in &original.schema.tables {
if let Some(syn_table) = synthetic.schema.tables.get(table_name) {
let orig_cols: std::collections::HashSet<_> =
orig_table.columns.iter().map(|c| &c.name).collect();
let syn_cols: std::collections::HashSet<_> =
syn_table.columns.iter().map(|c| &c.name).collect();
let common_cols = orig_cols.intersection(&syn_cols).count();
let total_cols = orig_cols.len().max(syn_cols.len());
if total_cols > 0 {
column_match_scores.push(common_cols as f64 / total_cols as f64);
}
}
}
} else if orig_tables.len() == syn_tables.len() {
// No shared table names but equal table counts: tables may have been renamed, so
// score each original table against its best-matching synthetic table.
let orig_table_list: Vec<_> = original.schema.tables.values().collect();
let syn_table_list: Vec<_> = synthetic.schema.tables.values().collect();
for orig_table in &orig_table_list {
let orig_cols: std::collections::HashSet<_> =
orig_table.columns.iter().map(|c| &c.name).collect();
let mut best_match_score: f64 = 0.0;
for syn_table in &syn_table_list {
let syn_cols: std::collections::HashSet<_> =
syn_table.columns.iter().map(|c| &c.name).collect();
let common_cols = orig_cols.intersection(&syn_cols).count();
let total_cols = orig_cols.len().max(syn_cols.len());
if total_cols > 0 {
let score = common_cols as f64 / total_cols as f64;
best_match_score = best_match_score.max(score);
}
}
column_match_scores.push(best_match_score);
}
} else {
// No shared names and different table counts: treat the schemas as unrelated.
let missing = orig_tables.difference(&syn_tables).count();
details.warnings.push(format!(
"{missing} tables missing in synthetic data (no overlap)"
));
return 0.0;
}
let missing = orig_tables.difference(&syn_tables).count();
if missing > 0 && common_tables > 0 {
details
.warnings
.push(format!("{missing} tables missing in synthetic data"));
}
// Empty scores with shared tables means every shared table had no columns on either side
// (treated as a perfect column match). NOTE(review): the 0.0 arm looks unreachable, since
// the best-match branch pushes one score per original table and empty schemas returned
// early; confirm before relying on it.
let column_score = if column_match_scores.is_empty() {
if common_tables == 0 {
0.0 } else {
1.0 }
} else {
column_match_scores.iter().sum::<f64>() / column_match_scores.len() as f64
};
// Renamed-but-structurally-equivalent tables (strong column match, no shared names) get
// full table credit instead of a 0.0 overlap ratio.
let effective_table_ratio = if common_tables == 0 && column_score > 0.8 {
1.0
} else {
table_overlap_ratio
};
let score = 0.4 * effective_table_ratio + 0.4 * column_score + 0.2 * (1.0 - ratio_penalty);
score.clamp(0.0, 1.0)
}
fn evaluate_rules(
&self,
original: &Fingerprint,
synthetic: &Fingerprint,
_details: &mut FidelityDetails,
) -> f64 {
let (orig_rules, syn_rules) = match (&original.rules, &synthetic.rules) {
(Some(o), Some(s)) => (o, s),
_ => return 1.0,
};
let mut score = 1.0;
for orig_rule in &orig_rules.balance_rules {
if let Some(syn_rule) = syn_rules
.balance_rules
.iter()
.find(|r| r.name == orig_rule.name)
{
let diff = (orig_rule.compliance_rate - syn_rule.compliance_rate).abs();
score -= diff * 0.1;
}
}
score.max(0.0)
}
/// Scores how closely the synthetic overall anomaly rate matches the original one.
///
/// Every 0.01 of absolute rate difference costs 0.1, so a difference of 0.1 or more scores
/// 0.0. Returns 1.0 when either fingerprint lacks anomaly data.
fn evaluate_anomalies(
    &self,
    original: &Fingerprint,
    synthetic: &Fingerprint,
    _details: &mut FidelityDetails,
) -> f64 {
    if let (Some(orig), Some(syn)) = (&original.anomalies, &synthetic.anomalies) {
        let gap = (orig.overall.anomaly_rate - syn.overall.anomaly_rate).abs();
        1.0 - (gap * 10.0).min(1.0)
    } else {
        1.0
    }
}
}
impl Default for FidelityEvaluator {
fn default() -> Self {
Self::new()
}
}
/// Cumulative probabilities of the nine stored percentiles (1st, 5th, 10th, 25th, 50th,
/// 75th, 90th, 95th, 99th). `interp_inv_cdf` pairs index `i` here with index `i` of its
/// nine values; presumably that is the order of `Percentiles::to_array()` (verify).
const PERCENTILE_PROBS: [f64; 9] = [0.01, 0.05, 0.10, 0.25, 0.50, 0.75, 0.90, 0.95, 0.99];
/// Approximates the 1-Wasserstein (earth mover's) distance between two distributions that
/// are known only through their percentiles.
///
/// W1 is the integral over `p` in [0, 1] of `|Q_orig(p) - Q_syn(p)|`, where `Q` is the
/// quantile function. Each `Q` is rebuilt by piecewise-linear interpolation of the stored
/// percentiles, extrapolated linearly beyond the 1st and 99th. The integral is evaluated
/// with the trapezoidal rule on 200 uniform steps. The result is in the data's units.
fn wasserstein_distance_from_percentiles(
    original: &crate::models::Percentiles,
    synthetic: &crate::models::Percentiles,
) -> f64 {
    const STEPS: usize = 200;
    let step = 1.0 / STEPS as f64;
    let orig_q = original.to_array();
    let syn_q = synthetic.to_array();
    let gap_at = |p: f64| (interp_inv_cdf(p, &orig_q) - interp_inv_cdf(p, &syn_q)).abs();

    let (_, area) = (1..=STEPS).fold((gap_at(0.0), 0.0_f64), |(left, area), i| {
        let right = gap_at(i as f64 * step);
        (right, area + (left + right) * step / 2.0)
    });
    area
}
fn interp_inv_cdf(p: f64, values: &[f64; 9]) -> f64 {
let p = p.clamp(0.0, 1.0);
if p <= PERCENTILE_PROBS[0] {
let slope = if (PERCENTILE_PROBS[1] - PERCENTILE_PROBS[0]).abs() > f64::EPSILON {
(values[1] - values[0]) / (PERCENTILE_PROBS[1] - PERCENTILE_PROBS[0])
} else {
0.0
};
values[0] + slope * (p - PERCENTILE_PROBS[0])
} else if p >= PERCENTILE_PROBS[8] {
let slope = if (PERCENTILE_PROBS[8] - PERCENTILE_PROBS[7]).abs() > f64::EPSILON {
(values[8] - values[7]) / (PERCENTILE_PROBS[8] - PERCENTILE_PROBS[7])
} else {
0.0
};
values[8] + slope * (p - PERCENTILE_PROBS[8])
} else {
for i in 0..8 {
if p >= PERCENTILE_PROBS[i] && p <= PERCENTILE_PROBS[i + 1] {
let frac =
(p - PERCENTILE_PROBS[i]) / (PERCENTILE_PROBS[i + 1] - PERCENTILE_PROBS[i]);
return values[i] + frac * (values[i + 1] - values[i]);
}
}
values[4] }
}
/// Approximates the Jensen–Shannon divergence between two distributions that are known only
/// through their percentiles. It uses the natural log, so the value is bounded by ln 2.
///
/// Each distribution is modelled as uniform mass between consecutive percentiles, plus thin
/// tails beyond the 1st and 99th. Both are re-binned onto the union of their bin edges,
/// normalized to sum to 1, and compared bin by bin.
fn js_divergence_from_percentiles(
    original: &crate::models::Percentiles,
    synthetic: &crate::models::Percentiles,
) -> f64 {
    // Rescales masses in place so they sum to 1 (left untouched when the total is not positive).
    fn normalize(masses: &mut [f64]) {
        let total: f64 = masses.iter().sum();
        if total > 0.0 {
            for m in masses.iter_mut() {
                *m /= total;
            }
        }
    }

    // Mass between consecutive edges from `percentile_bin_edges`: lower tail, 1-5, 5-10,
    // 10-25, 25-50, 50-75, 75-90, 90-95, 95-99, upper tail.
    let bin_probs: [f64; 10] = [0.01, 0.04, 0.05, 0.15, 0.25, 0.25, 0.15, 0.05, 0.04, 0.01];
    let orig_vals = original.to_array();
    let syn_vals = synthetic.to_array();

    let mut edges = percentile_bin_edges(&orig_vals);
    edges.extend(percentile_bin_edges(&syn_vals));
    edges.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    edges.dedup_by(|a, b| (*a - *b).abs() < f64::EPSILON * 100.0);
    if edges.len() < 2 {
        return 0.0;
    }

    let (mut p_masses, mut q_masses): (Vec<f64>, Vec<f64>) = edges
        .windows(2)
        .map(|pair| {
            (
                probability_mass_in_interval(pair[0], pair[1], &orig_vals, &bin_probs),
                probability_mass_in_interval(pair[0], pair[1], &syn_vals, &bin_probs),
            )
        })
        .unzip();
    normalize(&mut p_masses);
    normalize(&mut q_masses);

    let mut divergence = 0.0_f64;
    for (&p, &q) in p_masses.iter().zip(q_masses.iter()) {
        let mid = 0.5 * (p + q);
        if mid > 0.0 {
            if p > 0.0 {
                divergence += 0.5 * p * (p / mid).ln();
            }
            if q > 0.0 {
                divergence += 0.5 * q * (q / mid).ln();
            }
        }
    }
    divergence.max(0.0)
}
/// Builds the 11 bin edges that turn nine percentile values into a piecewise-uniform
/// density: one tail edge, the nine percentiles themselves, and another tail edge.
///
/// Each tail extends 1% of the 1st-to-99th percentile spread beyond the extreme percentile,
/// or 1.0 when that spread is (numerically) zero.
fn percentile_bin_edges(values: &[f64; 9]) -> Vec<f64> {
    let spread = values[8] - values[0];
    let tail = if spread.abs() > f64::EPSILON {
        spread * 0.01
    } else {
        1.0
    };
    std::iter::once(values[0] - tail)
        .chain(values.iter().copied())
        .chain(std::iter::once(values[8] + tail))
        .collect()
}
fn probability_mass_in_interval(
lo: f64,
hi: f64,
percentile_vals: &[f64; 9],
bin_probs: &[f64; 10],
) -> f64 {
if hi <= lo {
return 0.0;
}
let margin = if (percentile_vals[8] - percentile_vals[0]).abs() > f64::EPSILON {
(percentile_vals[8] - percentile_vals[0]) * 0.01
} else {
1.0
};
let edges: [f64; 11] = [
percentile_vals[0] - margin, percentile_vals[0], percentile_vals[1], percentile_vals[2], percentile_vals[3], percentile_vals[4], percentile_vals[5], percentile_vals[6], percentile_vals[7], percentile_vals[8], percentile_vals[8] + margin, ];
let mut total_mass = 0.0;
for i in 0..10 {
let bin_lo = edges[i];
let bin_hi = edges[i + 1];
let bin_width = bin_hi - bin_lo;
if bin_width <= 0.0 {
if lo <= bin_lo && bin_lo < hi {
total_mass += bin_probs[i];
}
continue;
}
let overlap_lo = lo.max(bin_lo);
let overlap_hi = hi.min(bin_hi);
if overlap_hi > overlap_lo {
let overlap_fraction = (overlap_hi - overlap_lo) / bin_width;
total_mass += bin_probs[i] * overlap_fraction;
}
}
total_mass
}
/// Mean absolute deviation between two nine-entry first-digit (1–9) frequency vectors.
fn compute_benford_mad(original: &[f64; 9], synthetic: &[f64; 9]) -> f64 {
    let total_abs_dev = original
        .iter()
        .zip(synthetic.iter())
        .fold(0.0_f64, |acc, (o, s)| acc + (o - s).abs());
    total_abs_dev / 9.0
}
/// Renders a [`FidelityReport`] as a standalone HTML page.
///
/// The page shows:
/// - the overall score and PASS/FAIL status;
/// - a table of the five component scores, as percentages with one decimal;
/// - the row-count ratio;
/// - the Benford MAD and correlation RMSE, when present.
pub fn generate_html_report(report: &FidelityReport) -> String {
    let (status_class, status_text) = if report.passes {
        ("pass", "PASS")
    } else {
        ("fail", "FAIL")
    };
    let overall = report.overall_score * 100.0;
    let statistical = report.statistical_fidelity * 100.0;
    let correlation = report.correlation_fidelity * 100.0;
    let schema = report.schema_fidelity * 100.0;
    let rules = report.rule_compliance * 100.0;
    let anomaly = report.anomaly_fidelity * 100.0;
    let row_ratio = report.details.row_count_ratio;
    let benford_html = match report.details.benford_mad {
        Some(m) => format!("<p>Benford MAD: {m:.4}</p>"),
        None => String::new(),
    };
    let correlation_html = match report.details.correlation_rmse {
        Some(r) => format!("<p>Correlation RMSE: {r:.4}</p>"),
        None => String::new(),
    };
    format!(
        r#"<!DOCTYPE html>
<html>
<head>
<title>Fidelity Report</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 40px; }}
.pass {{ color: green; }}
.fail {{ color: red; }}
.metric {{ margin: 10px 0; }}
.score {{ font-weight: bold; }}
table {{ border-collapse: collapse; width: 100%; }}
th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
th {{ background-color: #4CAF50; color: white; }}
</style>
</head>
<body>
<h1>Fidelity Evaluation Report</h1>
<div class="metric">
<h2>Overall Score: <span class="score {status_class}">{overall:.1}%</span></h2>
<p>Status: <span class="{status_class}">{status_text}</span></p>
</div>
<h2>Component Scores</h2>
<table>
<tr><th>Component</th><th>Score</th></tr>
<tr><td>Statistical Fidelity</td><td>{statistical:.1}%</td></tr>
<tr><td>Correlation Fidelity</td><td>{correlation:.1}%</td></tr>
<tr><td>Schema Fidelity</td><td>{schema:.1}%</td></tr>
<tr><td>Rule Compliance</td><td>{rules:.1}%</td></tr>
<tr><td>Anomaly Fidelity</td><td>{anomaly:.1}%</td></tr>
</table>
<h2>Details</h2>
<p>Row count ratio: {row_ratio:.2}</p>
{benford_html}
{correlation_html}
</body>
</html>"#
    )
}
// Unit tests for the percentile-based distance helpers (Benford MAD, Wasserstein, JS,
// quantile interpolation) and for per-column numeric comparison.
#[cfg(test)]
mod tests {
use super::*;
use crate::models::Percentiles;
// Identical first-digit frequency vectors should give a MAD of (essentially) zero.
#[test]
fn test_benford_mad() {
let original = [
0.301, 0.176, 0.125, 0.097, 0.079, 0.067, 0.058, 0.051, 0.046,
];
let synthetic = [
0.301, 0.176, 0.125, 0.097, 0.079, 0.067, 0.058, 0.051, 0.046,
];
let mad = compute_benford_mad(&original, &synthetic);
assert!(mad < 0.001); }
// Builds a Percentiles value from nine values, presumably ordered 1st..99th (the order
// PERCENTILE_PROBS assumes) — confirm against Percentiles::from_array.
fn make_percentiles(vals: [f64; 9]) -> Percentiles {
Percentiles::from_array(vals)
}
// Wasserstein distance from a distribution to itself is zero.
#[test]
fn test_wasserstein_identical_distributions() {
let pcts = make_percentiles([1.0, 5.0, 10.0, 25.0, 50.0, 75.0, 90.0, 95.0, 99.0]);
let w = wasserstein_distance_from_percentiles(&pcts, &pcts);
assert!(
w.abs() < 1e-10,
"Wasserstein distance between identical distributions should be ~0, got {}",
w
);
}
// Percentiles shifted by roughly 10 should yield a clearly positive distance.
#[test]
fn test_wasserstein_shifted_distributions() {
let orig = make_percentiles([1.0, 5.0, 10.0, 25.0, 50.0, 75.0, 90.0, 95.0, 99.0]);
let shifted = make_percentiles([11.0, 15.0, 20.0, 35.0, 60.0, 85.0, 100.0, 105.0, 109.0]);
let w = wasserstein_distance_from_percentiles(&orig, &shifted);
assert!(
w > 5.0,
"Expected W1 > 5 for shifted distributions, got {}",
w
);
}
// Swapping the arguments must not change the distance.
#[test]
fn test_wasserstein_symmetry() {
let a = make_percentiles([1.0, 5.0, 10.0, 25.0, 50.0, 75.0, 90.0, 95.0, 99.0]);
let b = make_percentiles([2.0, 8.0, 15.0, 30.0, 55.0, 78.0, 92.0, 97.0, 100.0]);
let w_ab = wasserstein_distance_from_percentiles(&a, &b);
let w_ba = wasserstein_distance_from_percentiles(&b, &a);
assert!(
(w_ab - w_ba).abs() < 1e-10,
"Wasserstein distance should be symmetric: {} vs {}",
w_ab,
w_ba
);
}
// For a pure translation by c, W1 equals |c| (here c = 5).
#[test]
fn test_wasserstein_constant_shift() {
let orig = make_percentiles([10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0]);
let shifted = make_percentiles([15.0, 25.0, 35.0, 45.0, 55.0, 65.0, 75.0, 85.0, 95.0]);
let w = wasserstein_distance_from_percentiles(&orig, &shifted);
assert!(
(w - 5.0).abs() < 0.5,
"For a constant shift of 5, expected W1 ~ 5.0, got {}",
w
);
}
// A distance is never negative.
#[test]
fn test_wasserstein_non_negative() {
let a = make_percentiles([0.0, 1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 80.0, 100.0]);
let b = make_percentiles([1.0, 3.0, 5.0, 10.0, 20.0, 40.0, 60.0, 85.0, 99.0]);
let w = wasserstein_distance_from_percentiles(&a, &b);
assert!(
w >= 0.0,
"Wasserstein distance should be non-negative, got {}",
w
);
}
// JS divergence of a distribution with itself is zero.
#[test]
fn test_js_identical_distributions() {
let pcts = make_percentiles([1.0, 5.0, 10.0, 25.0, 50.0, 75.0, 90.0, 95.0, 99.0]);
let js = js_divergence_from_percentiles(&pcts, &pcts);
assert!(
js.abs() < 1e-10,
"JS divergence between identical distributions should be ~0, got {}",
js
);
}
// JS divergence is symmetric in its arguments.
#[test]
fn test_js_symmetry() {
let a = make_percentiles([1.0, 5.0, 10.0, 25.0, 50.0, 75.0, 90.0, 95.0, 99.0]);
let b = make_percentiles([2.0, 8.0, 15.0, 30.0, 55.0, 78.0, 92.0, 97.0, 100.0]);
let js_ab = js_divergence_from_percentiles(&a, &b);
let js_ba = js_divergence_from_percentiles(&b, &a);
assert!(
(js_ab - js_ba).abs() < 1e-10,
"JS divergence should be symmetric: {} vs {}",
js_ab,
js_ba
);
}
// JS divergence is never negative.
#[test]
fn test_js_non_negative() {
let a = make_percentiles([0.0, 1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 80.0, 100.0]);
let b = make_percentiles([1.0, 3.0, 5.0, 10.0, 20.0, 40.0, 60.0, 85.0, 99.0]);
let js = js_divergence_from_percentiles(&a, &b);
assert!(
js >= 0.0,
"JS divergence should be non-negative, got {}",
js
);
}
// Fully disjoint supports hit the natural-log upper bound ln(2).
#[test]
fn test_js_bounded_by_ln2() {
let a = make_percentiles([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
let b = make_percentiles([
100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0,
]);
let js = js_divergence_from_percentiles(&a, &b);
assert!(
js <= std::f64::consts::LN_2 + 0.01,
"JS divergence should be <= ln(2) ~ 0.693, got {}",
js
);
}
// Different percentile profiles must give a strictly positive divergence.
#[test]
fn test_js_different_distributions_positive() {
let a = make_percentiles([1.0, 5.0, 10.0, 25.0, 50.0, 75.0, 90.0, 95.0, 99.0]);
let b = make_percentiles([10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0]);
let js = js_divergence_from_percentiles(&a, &b);
assert!(
js > 0.0,
"JS divergence between different distributions should be > 0, got {}",
js
);
}
// At each knot probability the interpolated quantile equals the stored value.
#[test]
fn test_interp_inv_cdf_at_knots() {
let values = [1.0, 5.0, 10.0, 25.0, 50.0, 75.0, 90.0, 95.0, 99.0];
for (i, &prob) in PERCENTILE_PROBS.iter().enumerate() {
let result = interp_inv_cdf(prob, &values);
assert!(
(result - values[i]).abs() < 1e-10,
"At p={}, expected {}, got {}",
prob,
values[i],
result
);
}
}
// p = 0.03 is halfway between the 1st (0.0) and 5th (10.0) percentile knots.
#[test]
fn test_interp_inv_cdf_interpolation() {
let values = [0.0, 10.0, 20.0, 30.0, 50.0, 70.0, 80.0, 90.0, 100.0];
let result = interp_inv_cdf(0.03, &values);
assert!((result - 5.0).abs() < 1e-10, "Expected 5.0, got {}", result);
}
// Similar but not identical columns: the distance metrics should be positive yet small.
#[test]
fn test_compare_numeric_stats_populates_metrics() {
let orig = NumericStats {
count: 1000,
min: 0.0,
max: 100.0,
mean: 50.0,
std_dev: 15.0,
percentiles: make_percentiles([1.0, 5.0, 10.0, 25.0, 50.0, 75.0, 90.0, 95.0, 99.0]),
distribution: crate::models::DistributionType::Normal,
distribution_params: crate::models::DistributionParams::normal(50.0, 15.0),
zero_rate: 0.0,
negative_rate: 0.0,
benford_first_digit: None,
log_magnitude_percentiles: None,
};
let syn = NumericStats {
count: 1000,
min: 2.0,
max: 98.0,
mean: 52.0,
std_dev: 14.0,
percentiles: make_percentiles([2.0, 6.0, 12.0, 27.0, 52.0, 77.0, 91.0, 96.0, 98.0]),
distribution: crate::models::DistributionType::Normal,
distribution_params: crate::models::DistributionParams::normal(52.0, 14.0),
zero_rate: 0.0,
negative_rate: 0.0,
benford_first_digit: None,
log_magnitude_percentiles: None,
};
let evaluator = FidelityEvaluator::new();
let metrics = evaluator.compare_numeric_stats("test_col", &orig, &syn);
assert!(
metrics.wasserstein_distance > 0.0,
"W1 should be positive for different distributions"
);
assert!(
metrics.wasserstein_distance < 10.0,
"W1 should be modest for similar distributions, got {}",
metrics.wasserstein_distance
);
assert!(metrics.js_divergence >= 0.0, "JS should be non-negative");
assert!(
metrics.js_divergence < 0.5,
"JS should be modest for similar distributions, got {}",
metrics.js_divergence
);
}
}