use scirs2_core::ndarray::Array1;
use scirs2_metrics::anomaly::{
js_divergence, kl_divergence, maximum_mean_discrepancy, wasserstein_distance,
};
use scirs2_metrics::distance::{cosine_similarity, euclidean_distance};
/// Produces `n` synthetic feature values sampled from a scaled, shifted sine.
///
/// Sample `i` is evaluated at `x = i / n` (so `x` lies in `[0, 1)`) and equals
/// `scale * sin(PI * x) + shift`. An `n` of zero yields an empty vector.
fn simulate_feature_extractor(n: usize, shift: f64, scale: f64) -> Vec<f64> {
    let denom = n as f64;
    let mut features = Vec::with_capacity(n);
    for idx in 0..n {
        // Phase sweeps [0, PI): half a sine period across the whole sample set.
        let phase = std::f64::consts::PI * (idx as f64 / denom);
        features.push(scale * phase.sin() + shift);
    }
    features
}
/// Builds a normalised histogram (relative frequencies) of `samples`.
///
/// The range `[min, max]` of the data is split into `n_bins` equal-width bins;
/// the maximum value is clamped into the last bin. Each returned entry is the
/// fraction of samples that fell into the corresponding bin, so for non-empty
/// input the entries sum to 1.
///
/// Edge cases:
/// * `n_bins == 0` returns an empty vector (there are no bins to fill).
/// * Empty `samples` returns `n_bins` zeros instead of dividing `0 / 0`.
/// * If every sample has the same value (zero-width range), all mass is put
///   into bin 0.
fn histogram(samples: &[f64], n_bins: usize) -> Vec<f64> {
    if n_bins == 0 {
        // `n_bins - 1` below would underflow; nothing sensible to compute.
        return Vec::new();
    }
    if samples.is_empty() {
        // Avoid 0/0 = NaN frequencies, which would poison downstream metrics.
        return vec![0.0; n_bins];
    }

    let min = samples.iter().copied().fold(f64::INFINITY, f64::min);
    let max = samples.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let width = (max - min) / n_bins as f64;
    let last_bin = n_bins - 1;

    let mut counts = vec![0usize; n_bins];
    for &v in samples {
        let bin = if width > 0.0 {
            // The float-to-int cast saturates (and maps NaN to 0); the `min`
            // clamps `v == max` into the final bin.
            (((v - min) / width) as usize).min(last_bin)
        } else {
            // Degenerate (zero-width or non-finite) range: use the first bin
            // explicitly rather than relying on a NaN-to-integer cast.
            0
        };
        counts[bin] += 1;
    }

    let total = samples.len() as f64;
    counts.iter().map(|&c| c as f64 / total).collect()
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("=== Integration: Transformation Outputs → Distribution Metrics ===\n");
let n: usize = 200;
let n_bins: usize = 20;
let source = simulate_feature_extractor(n, 0.0, 1.0);
let transformed_good = simulate_feature_extractor(n, 0.05, 1.02);
let transformed_bad = simulate_feature_extractor(n, 1.5, 2.0);
let eps = 1e-9;
let source_hist: Vec<f64> = histogram(&source, n_bins)
.into_iter()
.map(|v| v + eps)
.collect();
let good_hist: Vec<f64> = histogram(&transformed_good, n_bins)
.into_iter()
.map(|v| v + eps)
.collect();
let bad_hist: Vec<f64> = histogram(&transformed_bad, n_bins)
.into_iter()
.map(|v| v + eps)
.collect();
let renorm = |h: Vec<f64>| -> Vec<f64> {
let s: f64 = h.iter().sum();
h.into_iter().map(|v| v / s).collect()
};
let source_pmf = renorm(source_hist);
let good_pmf = renorm(good_hist);
let bad_pmf = renorm(bad_hist);
let src_arr = Array1::from(source.clone());
let good_arr = Array1::from(transformed_good.clone());
let bad_arr = Array1::from(transformed_bad.clone());
let src_pmf_arr = Array1::from(source_pmf.clone());
let good_pmf_arr = Array1::from(good_pmf.clone());
let bad_pmf_arr = Array1::from(bad_pmf.clone());
let w1_good: f64 = wasserstein_distance(&src_arr, &good_arr)?;
let w1_bad: f64 = wasserstein_distance(&src_arr, &bad_arr)?;
let mmd_good: f64 = maximum_mean_discrepancy(&src_arr, &good_arr, None)?;
let mmd_bad: f64 = maximum_mean_discrepancy(&src_arr, &bad_arr, None)?;
let kl_good: f64 = kl_divergence(&src_pmf_arr, &good_pmf_arr)?;
let kl_bad: f64 = kl_divergence(&src_pmf_arr, &bad_pmf_arr)?;
let js_good: f64 = js_divergence(&src_pmf_arr, &good_pmf_arr)?;
let js_bad: f64 = js_divergence(&src_pmf_arr, &bad_pmf_arr)?;
let src_pt = Array1::from(source_pmf.clone());
let good_pt = Array1::from(good_pmf.clone());
let bad_pt = Array1::from(bad_pmf.clone());
let euc_good: f64 = euclidean_distance(&src_pt, &good_pt)?;
let euc_bad: f64 = euclidean_distance(&src_pt, &bad_pt)?;
let cos_good: f64 = cosine_similarity(&src_pt, &good_pt)?;
let cos_bad: f64 = cosine_similarity(&src_pt, &bad_pt)?;
println!("Source: sin-feature extractor, n={n}");
println!("Good transform: shift=0.05, scale=1.02 (small domain gap)");
println!("Bad transform: shift=1.5, scale=2.0 (large domain gap)\n");
println!(
"{:<38} {:>12} {:>12}",
"Metric (source vs transform)", "Good", "Bad"
);
println!("{}", "-".repeat(66));
println!(
"{:<38} {:>12.6} {:>12.6}",
"Wasserstein-1 (↓ → aligned)", w1_good, w1_bad
);
println!(
"{:<38} {:>12.6} {:>12.6}",
"MMD (↓ → aligned)", mmd_good, mmd_bad
);
println!(
"{:<38} {:>12.6} {:>12.6}",
"KL divergence KL(P||Q) (↓)", kl_good, kl_bad
);
println!(
"{:<38} {:>12.6} {:>12.6}",
"JS divergence (↓, max=ln2)", js_good, js_bad
);
println!(
"{:<38} {:>12.6} {:>12.6}",
"Euclidean dist (histogram)", euc_good, euc_bad
);
println!(
"{:<38} {:>12.6} {:>12.6}",
"Cosine similarity (↑ → similar)", cos_good, cos_bad
);
println!("\n--- Tips for transform pipeline integration ---");
println!(" Use Wasserstein or MMD to detect dataset shift before/after transform");
println!(" KL/JS divergence works on histogram representations of embeddings");
println!(" Cosine similarity on aggregate histograms captures shape similarity");
println!(" Threshold: Wasserstein > 0.1 typically indicates meaningful shift");
println!("\n=== Done ===");
Ok(())
}