use serde::{Deserialize, Serialize};
/// Kernel selection used by [`KsdConfig`] and reported in [`KsdResult`].
///
/// The default kernel is [`KernelType::Rbf`]. The enum is `#[non_exhaustive]`,
/// so `match`es outside this crate must include a wildcard arm.
#[derive(Debug, Default, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum KernelType {
    /// Radial basis function kernel (the default variant).
    #[default]
    Rbf,
    /// Inverse multi-quadric kernel.
    ///
    /// NOTE(review): conventionally `k(x, y) = (c² + ‖x − y‖²)^beta`; confirm
    /// against the code that evaluates this kernel.
    Imq {
        /// Parameter `c` of the IMQ kernel.
        c: f64,
        /// Exponent parameter `beta` of the IMQ kernel.
        beta: f64,
    },
    /// Polynomial kernel.
    ///
    /// NOTE(review): conventionally `k(x, y) = (alpha·⟨x, y⟩ + c)^degree`;
    /// confirm against the code that evaluates this kernel.
    Polynomial {
        /// Scale parameter `alpha`.
        alpha: f64,
        /// Offset parameter `c`.
        c: f64,
        /// Integer polynomial degree.
        degree: u32,
    },
}
/// Configuration for a kernel Stein discrepancy (KSD) computation.
///
/// The [`Default`] configuration uses [`KernelType::Rbf`], no explicit
/// bandwidth, and 1000 bootstrap resamples.
#[derive(Debug, Clone, PartialEq)]
pub struct KsdConfig {
    /// Kernel used to compute the discrepancy.
    pub kernel: KernelType,
    /// Explicit kernel bandwidth. `None` leaves the choice to the consumer of
    /// this config (NOTE(review): presumably a data-driven heuristic — confirm).
    pub bandwidth: Option<f64>,
    /// Number of bootstrap resamples (NOTE(review): presumably used to derive
    /// `KsdResult::p_value` — confirm).
    pub n_bootstrap: usize,
}
impl Default for KsdConfig {
fn default() -> Self {
Self {
kernel: KernelType::Rbf,
bandwidth: None,
n_bootstrap: 1000,
}
}
}
/// Outcome of a KSD computation.
#[derive(Debug, Clone, PartialEq)]
pub struct KsdResult {
    /// Value of the discrepancy statistic.
    pub statistic: f64,
    /// p-value of the test, if one was computed.
    pub p_value: Option<f64>,
    /// Test decision, if one was made (NOTE(review): presumably `true` means
    /// the null hypothesis was rejected — confirm).
    pub rejected: Option<bool>,
    /// Kernel associated with `statistic`.
    pub kernel: KernelType,
    /// Bandwidth reported for this result. Unlike the optional
    /// [`KsdConfig::bandwidth`], this is a plain `f64` (NOTE(review):
    /// presumably the value actually used — confirm).
    pub bandwidth: f64,
}
/// Parameters for a Sinkhorn computation.
///
/// The [`Default`] values are `epsilon = 0.1`, `max_iter = 1000`,
/// `tol = 1e-9` and `log_domain = true`.
#[derive(Debug, Clone, PartialEq)]
pub struct SinkhornConfig {
    /// Regularization strength (NOTE(review): presumably the entropic
    /// regularization weight — confirm).
    pub epsilon: f64,
    /// Upper bound on the number of iterations.
    pub max_iter: usize,
    /// Convergence tolerance.
    pub tol: f64,
    /// Whether to run the iteration in the log domain.
    pub log_domain: bool,
}
impl Default for SinkhornConfig {
fn default() -> Self {
Self {
epsilon: 0.1,
max_iter: 1000,
tol: 1e-9,
log_domain: true,
}
}
}
/// Output of a Sinkhorn computation.
#[derive(Debug, Clone, PartialEq)]
pub struct SinkhornResult {
    /// Computed divergence value.
    pub divergence: f64,
    /// Transport plan stored as nested rows (NOTE(review): presumably indexed
    /// `[source][target]` — confirm).
    pub transport_plan: Vec<Vec<f64>>,
    /// Whether the iteration reported convergence.
    pub converged: bool,
    /// Number of iterations performed.
    pub iterations: usize,
}
/// Identifies which distance or divergence measure produced a
/// [`DistanceResult`].
///
/// Derives `Hash` (alongside `Eq`) so methods can key a `HashMap`/`HashSet`.
/// The enum is `#[non_exhaustive]`, so `match`es outside this crate must
/// include a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum DistanceMethod {
    /// Total variation distance.
    TotalVariation,
    /// Hellinger distance.
    Hellinger,
    /// Kullback–Leibler divergence.
    KullbackLeibler,
    /// Jensen–Shannon divergence.
    JensenShannon,
    /// Chi-square divergence.
    ChiSquare,
    /// Energy distance.
    Energy,
    /// Wasserstein distance.
    Wasserstein,
    /// Sliced Wasserstein distance.
    SlicedWasserstein,
    /// Sinkhorn divergence (NOTE(review): presumably configured via
    /// [`SinkhornConfig`] — confirm).
    Sinkhorn,
}
/// A scalar distance value tagged with the method that produced it.
#[derive(Debug, Clone, PartialEq)]
pub struct DistanceResult {
    /// The computed distance or divergence.
    pub value: f64,
    /// Measure used to compute `value`.
    pub method: DistanceMethod,
}
impl DistanceResult {
    /// Creates a result pairing `value` with the `method` that produced it.
    ///
    /// `#[must_use]` because constructing a result and discarding it has no
    /// effect. The function is `const` so it is also usable in constant
    /// contexts.
    #[must_use]
    pub const fn new(value: f64, method: DistanceMethod) -> Self {
        Self { value, method }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kernel_type_default_is_rbf() {
        assert_eq!(KernelType::default(), KernelType::Rbf);
    }

    #[test]
    fn test_sinkhorn_config_default() {
        let cfg = SinkhornConfig::default();
        assert!((cfg.epsilon - 0.1).abs() < 1e-12);
        assert_eq!(cfg.max_iter, 1000);
        // Tolerance is tiny, so compare with an even tinier slack.
        assert!((cfg.tol - 1e-9).abs() < 1e-18);
        assert!(cfg.log_domain);
    }

    #[test]
    fn test_ksd_config_default() {
        let cfg = KsdConfig::default();
        assert_eq!(cfg.kernel, KernelType::Rbf);
        assert!(cfg.bandwidth.is_none());
        assert_eq!(cfg.n_bootstrap, 1000);
    }

    #[test]
    fn test_distance_result_new() {
        let r = DistanceResult::new(0.42, DistanceMethod::Hellinger);
        assert!((r.value - 0.42).abs() < 1e-12);
        assert_eq!(r.method, DistanceMethod::Hellinger);
    }

    /// Every variant formats as its own name and differs from every other variant.
    #[test]
    fn test_distance_method_debug_names_and_distinct() {
        let methods = [
            (DistanceMethod::TotalVariation, "TotalVariation"),
            (DistanceMethod::Hellinger, "Hellinger"),
            (DistanceMethod::KullbackLeibler, "KullbackLeibler"),
            (DistanceMethod::JensenShannon, "JensenShannon"),
            (DistanceMethod::ChiSquare, "ChiSquare"),
            (DistanceMethod::Energy, "Energy"),
            (DistanceMethod::Wasserstein, "Wasserstein"),
            (DistanceMethod::SlicedWasserstein, "SlicedWasserstein"),
            (DistanceMethod::Sinkhorn, "Sinkhorn"),
        ];
        for (i, (method, name)) in methods.iter().enumerate() {
            assert_eq!(format!("{method:?}"), *name);
            for (other, _) in &methods[i + 1..] {
                assert_ne!(method, other);
            }
        }
    }
}