use crate::math::{sqrt_f32, ln_f32};
/// A point `N(mu, sigma²)` on the manifold of univariate Gaussian distributions.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GaussPoint {
    /// Mean of the distribution.
    pub mu: f32,
    /// Standard deviation. [`GaussPoint::new`] floors it at `1e-9`; a struct literal does not.
    pub sigma: f32,
}

impl GaussPoint {
    /// Creates a point, clamping `sigma` from below to `1e-9`.
    ///
    /// Zero, negative, and NaN sigmas all become `1e-9`, because `f32::max` ignores NaN.
    /// Downstream distance formulas therefore never divide by zero.
    #[inline]
    pub fn new(mu: f32, sigma: f32) -> Self {
        const SIGMA_FLOOR: f32 = 1e-9;
        let sigma = sigma.max(SIGMA_FLOOR);
        GaussPoint { mu, sigma }
    }
}
/// Cheap local approximation of the Fisher–Rao distance between two Gaussians.
///
/// Evaluates the Fisher metric `ds² = (dμ² + 2 dσ²) / σ²` once, at the average of the two
/// sigmas, and returns `sqrt((Δμ/σ̄)² + 2 (Δσ/σ̄)²)`. The result is symmetric in its
/// arguments and is zero for identical points. It is accurate only when the points are
/// close; see `fisher_rao_distance_exact` for the closed form.
#[inline]
pub fn fisher_rao_distance(p1: GaussPoint, p2: GaussPoint) -> f32 {
    // Average sigma, kept strictly positive even for hand-built points with sigma <= 0.
    let mean_sigma = 0.5 * (p1.sigma + p2.sigma).max(1e-9);
    let (dm, ds) = (
        (p2.mu - p1.mu) / mean_sigma,
        (p2.sigma - p1.sigma) / mean_sigma,
    );
    let squared_len = dm * dm + 2.0 * ds * ds;
    sqrt_f32(squared_len)
}
/// Exact Fisher–Rao geodesic distance between two univariate Gaussians.
///
/// Under the Fisher metric `ds² = (dμ² + 2 dσ²) / σ²`, the substitution `x = μ/√2`, `y = σ`
/// gives `ds² = 2 (dx² + dy²) / y²`. That is the Poincaré half-plane scaled by `√2`, so
///
/// `d = √2 · arcosh(1 + ((Δμ)²/2 + (Δσ)²) / (2 σ₁ σ₂))`.
///
/// For equal means this reduces to `√2 · |ln(σ₂/σ₁)|`. Sigmas are floored at `1e-9`, as
/// [`GaussPoint::new`] does. A NaN mean propagates to a NaN result instead of being
/// silently mapped to 0.
pub fn fisher_rao_distance_exact(p1: GaussPoint, p2: GaussPoint) -> f32 {
    let s1 = p1.sigma.max(1e-9);
    let s2 = p2.sigma.max(1e-9);
    let d_mu = p2.mu - p1.mu;
    let d_sigma = s2 - s1;
    // arcosh argument minus one; >= 0 because it is a sum of squares over a positive product.
    let t = (0.5 * d_mu * d_mu + d_sigma * d_sigma) / (2.0 * s1 * s2);
    // arcosh(1 + t) = ln(1 + t + sqrt(t (t + 2))). Writing it in terms of t avoids the
    // catastrophic cancellation of `c*c - 1` in f32 when the two points are close.
    let inner = 1.0 + t + sqrt_f32(t * (t + 2.0));
    core::f32::consts::SQRT_2 * ln_f32(inner)
}
/// Normalised bending of the two-step path `p0 → p1 → p2` on the Gaussian manifold.
///
/// Computes `1 - chord / path`, where `path = d(p0, p1) + d(p1, p2)` and
/// `chord = d(p0, p2)`, all measured with [`fisher_rao_distance`]. The result is 0 for a
/// straight path and 1 when `p2` returns to `p0`. The ratio is capped at 1 because the
/// local approximation may violate the triangle inequality, so the result never goes
/// negative. A degenerate path (`path < 1e-9`) yields 0.
pub fn geodesic_curvature(p0: GaussPoint, p1: GaussPoint, p2: GaussPoint) -> f32 {
    let path_len = fisher_rao_distance(p0, p1) + fisher_rao_distance(p1, p2);
    if path_len < 1e-9 {
        0.0
    } else {
        let straightness = (fisher_rao_distance(p0, p2) / path_len).min(1.0);
        1.0 - straightness
    }
}
/// Qualitative shape of a drift trajectory, bucketed from a curvature value `kappa`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DriftGeometry {
    /// `kappa < 0.05`.
    Linear,
    /// `0.05 <= kappa < 0.15`.
    Settling,
    /// `0.15 <= kappa < 0.35`.
    NonLinear,
    /// `kappa >= 0.35`, and also any `kappa` that fails every comparison (e.g. NaN).
    Oscillatory,
}

impl DriftGeometry {
    /// Maps a curvature value onto a geometry bucket using fixed thresholds.
    ///
    /// The thresholds are checked in ascending order. A NaN `kappa` fails every `<` test
    /// and therefore lands in [`DriftGeometry::Oscillatory`].
    pub fn classify(kappa: f32) -> Self {
        match kappa {
            k if k < 0.05 => Self::Linear,
            k if k < 0.15 => Self::Settling,
            k if k < 0.35 => Self::NonLinear,
            _ => Self::Oscillatory,
        }
    }

    /// Returns the variant's name as a static string.
    pub const fn label(self) -> &'static str {
        match self {
            Self::Linear => "Linear",
            Self::Settling => "Settling",
            Self::NonLinear => "NonLinear",
            Self::Oscillatory => "Oscillatory",
        }
    }
}
/// Streaming accumulator of path statistics over a sequence of [`GaussPoint`]s.
///
/// Each [`ManifoldTracker::push`] measures the step from the previously pushed point with
/// the local approximation [`fisher_rao_distance`]. It keeps the total path length, the
/// number of points seen, and the largest single step. `Default` is identical to
/// [`ManifoldTracker::new`].
#[derive(Debug, Clone, Default)]
pub struct ManifoldTracker {
    /// Most recently pushed point; `None` before the first push and after `reset`.
    prev: Option<GaussPoint>,
    /// Sum of all step distances so far.
    cumulative: f32,
    /// Number of points pushed (saturating); there are `step_count - 1` steps.
    step_count: u32,
    /// Largest single step distance observed.
    peak_distance: f32,
}

impl ManifoldTracker {
    /// Creates an empty tracker.
    pub const fn new() -> Self {
        Self { prev: None, cumulative: 0.0, step_count: 0, peak_distance: 0.0 }
    }

    /// Records `p` and returns the distance from the previously pushed point.
    ///
    /// Returns `None` for the first point after construction or [`ManifoldTracker::reset`].
    /// In that case the point is still stored and counted.
    pub fn push(&mut self, p: GaussPoint) -> Option<f32> {
        let prev = self.prev.replace(p);
        // Saturate instead of overflowing: `+= 1` would panic in debug builds after
        // u32::MAX pushes.
        self.step_count = self.step_count.saturating_add(1);
        let d = fisher_rao_distance(prev?, p);
        self.cumulative += d;
        // f32::max ignores NaN, so a NaN step never overwrites the peak.
        self.peak_distance = self.peak_distance.max(d);
        Some(d)
    }

    /// Total path length accumulated so far.
    #[inline]
    pub fn cumulative_length(&self) -> f32 {
        self.cumulative
    }

    /// Average step distance, or 0 when fewer than two points have been pushed.
    #[inline]
    pub fn mean_step_distance(&self) -> f32 {
        if self.step_count < 2 {
            0.0
        } else {
            self.cumulative / (self.step_count - 1) as f32
        }
    }

    /// Largest single step distance seen so far (0 if none).
    #[inline]
    pub fn peak_distance(&self) -> f32 {
        self.peak_distance
    }

    /// Clears all accumulated state, returning the tracker to its freshly constructed form.
    pub fn reset(&mut self) {
        *self = Self::new();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zero_distance_identical_points() {
        let p = GaussPoint::new(0.1, 0.05);
        let d = fisher_rao_distance(p, p);
        assert!(d.abs() < 1e-6, "identical points: distance = {}", d);
    }

    #[test]
    fn gauss_point_new_floors_sigma() {
        assert_eq!(GaussPoint::new(0.0, 0.0).sigma, 1e-9);
        assert_eq!(GaussPoint::new(0.0, -1.0).sigma, 1e-9);
    }

    #[test]
    fn distance_increases_with_mean_separation() {
        let base = GaussPoint::new(0.0, 0.05);
        let close = GaussPoint::new(0.05, 0.05);
        let far = GaussPoint::new(0.15, 0.05);
        assert!(
            fisher_rao_distance(base, far) > fisher_rao_distance(base, close),
            "larger mean separation must give larger FR distance"
        );
    }

    #[test]
    fn distance_increases_with_sigma_separation() {
        let base = GaussPoint::new(0.0, 0.05);
        let sigma1 = GaussPoint::new(0.0, 0.06);
        let sigma2 = GaussPoint::new(0.0, 0.10);
        assert!(
            fisher_rao_distance(base, sigma2) > fisher_rao_distance(base, sigma1),
            "larger sigma change must give larger FR distance"
        );
    }

    #[test]
    fn drift_geometry_linear_for_constant_rate() {
        let p0 = GaussPoint::new(0.00, 0.05);
        let p1 = GaussPoint::new(0.01, 0.05);
        let p2 = GaussPoint::new(0.02, 0.05);
        let kappa = geodesic_curvature(p0, p1, p2);
        let geom = DriftGeometry::classify(kappa);
        assert_eq!(
            geom,
            DriftGeometry::Linear,
            "constant-rate mean drift must be Linear: kappa={}",
            kappa
        );
    }

    #[test]
    fn drift_geometry_oscillatory_for_reversal() {
        let p0 = GaussPoint::new(0.00, 0.05);
        let p1 = GaussPoint::new(0.10, 0.05);
        let p2 = GaussPoint::new(0.00, 0.05);
        let kappa = geodesic_curvature(p0, p1, p2);
        let geom = DriftGeometry::classify(kappa);
        assert!(
            matches!(geom, DriftGeometry::NonLinear | DriftGeometry::Oscillatory),
            "reversal must be NonLinear or Oscillatory: kappa={}",
            kappa
        );
    }

    #[test]
    fn manifold_tracker_accumulates_path() {
        let mut tracker = ManifoldTracker::new();
        assert_eq!(tracker.push(GaussPoint::new(0.0, 0.05)), None);
        let d1 = tracker.push(GaussPoint::new(0.01, 0.05)).unwrap();
        let d2 = tracker.push(GaussPoint::new(0.02, 0.05)).unwrap();
        assert!(
            (tracker.cumulative_length() - d1 - d2).abs() < 1e-6,
            "cumulative must equal sum of steps"
        );
        assert!(tracker.peak_distance() >= d1.max(d2) - 1e-6);
    }

    #[test]
    fn manifold_tracker_mean_step_distance() {
        let mut tracker = ManifoldTracker::new();
        assert_eq!(tracker.mean_step_distance(), 0.0);
        tracker.push(GaussPoint::new(0.0, 0.05));
        assert_eq!(tracker.mean_step_distance(), 0.0, "a single point has no steps");
        let d1 = tracker.push(GaussPoint::new(0.01, 0.05)).unwrap();
        let d2 = tracker.push(GaussPoint::new(0.03, 0.05)).unwrap();
        let expected = (d1 + d2) / 2.0;
        assert!(
            (tracker.mean_step_distance() - expected).abs() < 1e-6,
            "mean step must be cumulative / (points - 1)"
        );
    }

    #[test]
    fn manifold_tracker_reset_clears_state() {
        let mut tracker = ManifoldTracker::new();
        tracker.push(GaussPoint::new(0.0, 0.05));
        tracker.push(GaussPoint::new(0.1, 0.1));
        tracker.reset();
        assert_eq!(
            tracker.push(GaussPoint::new(0.0, 0.05)),
            None,
            "after reset, first push returns None"
        );
        assert_eq!(tracker.cumulative_length(), 0.0);
        assert_eq!(tracker.peak_distance(), 0.0);
    }

    #[test]
    fn exact_distance_zero_for_identical_points() {
        let p = GaussPoint::new(0.3, 0.2);
        let d = fisher_rao_distance_exact(p, p);
        assert!(d.abs() < 1e-6, "identical points: exact distance = {}", d);
    }

    #[test]
    fn exact_distance_consistent_with_approx() {
        let p1 = GaussPoint::new(0.0, 0.1);
        let p2 = GaussPoint::new(0.1, 0.12);
        let approx = fisher_rao_distance(p1, p2);
        let exact = fisher_rao_distance_exact(p1, p2);
        // No clamp on the ratio: clamping it made the lower bound unreachable and hid a
        // collapsed exact distance.
        let ratio = exact / approx.max(1e-9);
        assert!(
            ratio > 0.3 && ratio < 3.0,
            "approx and exact should agree within a factor of ~3: approx={:.4} exact={:.4}",
            approx,
            exact
        );
    }

    #[test]
    fn drift_geometry_label_correct() {
        assert_eq!(DriftGeometry::Linear.label(), "Linear");
        assert_eq!(DriftGeometry::Settling.label(), "Settling");
        assert_eq!(DriftGeometry::NonLinear.label(), "NonLinear");
        assert_eq!(DriftGeometry::Oscillatory.label(), "Oscillatory");
    }

    #[test]
    fn robust_mode_mad_gives_nonzero_sigma() {
        let samples = [0.1f32, 0.1, 0.1, 0.1, 5.0];
        let p = GaussPointRobust::from_samples_mad(&samples).unwrap();
        assert!(p.sigma < 0.5, "MAD sigma must not be inflated by outlier: {}", p.sigma);
        let gaussian_std = {
            let mean: f32 = samples.iter().sum::<f32>() / samples.len() as f32;
            let var: f32 = samples.iter().map(|&x| (x - mean).powi(2)).sum::<f32>()
                / samples.len() as f32;
            crate::math::sqrt_f32(var)
        };
        assert!(
            p.sigma < gaussian_std * 0.5,
            "MAD-robust sigma must be less than sample std: MAD={} samplestd={}",
            p.sigma,
            gaussian_std
        );
    }

    #[test]
    fn robust_mode_mad_odd_and_even_counts() {
        let odd = GaussPointRobust::from_samples_mad(&[1.0, 3.0, 2.0]).unwrap();
        assert!((odd.mu - 2.0).abs() < 1e-6, "odd median must be the middle value: {}", odd.mu);
        assert!((odd.sigma - 1.482_602_2).abs() < 1e-5, "MAD of [1,3,2] is 1: {}", odd.sigma);
        assert_eq!(odd.mode, RobustManifoldMode::MadRegularized);

        let even = GaussPointRobust::from_samples_mad(&[4.0, 1.0, 3.0, 2.0]).unwrap();
        assert!((even.mu - 2.5).abs() < 1e-6, "even median must average the middle pair");
        assert!((even.sigma - 1.482_602_2).abs() < 1e-5, "MAD of [4,1,3,2] is 1: {}", even.sigma);
    }

    #[test]
    fn robust_mode_gaussian_moments() {
        let p = GaussPointRobust::from_samples_gaussian(&[1.0, 3.0]).unwrap();
        assert!((p.mu - 2.0).abs() < 1e-6);
        assert!((p.sigma - 1.0).abs() < 1e-4, "population std of [1,3] is 1: {}", p.sigma);
        assert_eq!(p.mode, RobustManifoldMode::Gaussian);
        assert_eq!(p.to_gauss_point(), GaussPoint::new(p.mu, p.sigma));
        assert!(GaussPointRobust::from_samples_gaussian(&[]).is_none());
    }

    #[test]
    fn robust_mode_from_empty_returns_none() {
        let r = GaussPointRobust::from_samples_mad(&[]);
        assert!(r.is_none());
    }
}
/// Which estimator produced a [`GaussPointRobust`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RobustManifoldMode {
    /// Sample mean and population standard deviation.
    #[default]
    Gaussian,
    /// Sample median and scaled median absolute deviation (MAD).
    MadRegularized,
}

/// A Gaussian point estimated from samples, tagged with the estimator that produced it.
///
/// Derives `PartialEq` for consistency with `GaussPoint`, so fitted points can be compared
/// directly.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GaussPointRobust {
    /// Location estimate (mean or median, depending on `mode`).
    pub mu: f32,
    /// Scale estimate, floored at `1e-9` by the constructors.
    pub sigma: f32,
    /// Estimator used to compute `mu` and `sigma`.
    pub mode: RobustManifoldMode,
}
impl GaussPointRobust {
    /// Fits a Gaussian from the classical moments of `samples`.
    ///
    /// `mu` is the arithmetic mean. `sigma` is the square root of the population
    /// (divide-by-`n`) variance, floored at `1e-9`. Returns `None` for an empty slice.
    pub fn from_samples_gaussian(samples: &[f32]) -> Option<Self> {
        if samples.is_empty() {
            return None;
        }
        let n = samples.len() as f32;
        let mu: f32 = samples.iter().sum::<f32>() / n;
        let var: f32 = samples.iter().map(|&x| (x - mu) * (x - mu)).sum::<f32>() / n;
        let sigma = crate::math::sqrt_f32(var).max(1e-9);
        Some(Self { mu, sigma, mode: RobustManifoldMode::Gaussian })
    }

    /// Fits a Gaussian robustly from `samples` using the median and the MAD.
    ///
    /// `mu` is the sample median, with the two middle values averaged for even counts.
    /// `sigma` is the median absolute deviation from `mu`, times `1.4826` (the usual
    /// normal-consistency factor), floored at `1e-9`. Returns `None` for an empty slice.
    ///
    /// Only the first 256 samples are used. The computation works in a fixed-size stack
    /// buffer, so any later samples are ignored. Sorting uses IEEE total order, so NaN
    /// samples move to the ends of the ordering instead of corrupting it.
    pub fn from_samples_mad(samples: &[f32]) -> Option<Self> {
        const CAP: usize = 256;
        const MAD_SCALE: f32 = 1.482_602_2;
        if samples.is_empty() {
            return None;
        }
        let used = &samples[..samples.len().min(CAP)];

        let mut storage = [0.0f32; CAP];
        let buf = &mut storage[..used.len()];
        buf.copy_from_slice(used);
        let mu = Self::median_in_place(buf);

        // Reuse the buffer for absolute deviations from the median.
        for (slot, &x) in buf.iter_mut().zip(used) {
            let d = x - mu;
            *slot = if d < 0.0 { -d } else { d };
        }
        let mad = Self::median_in_place(buf);

        let sigma = (MAD_SCALE * mad).max(1e-9_f32);
        Some(Self { mu, sigma, mode: RobustManifoldMode::MadRegularized })
    }

    /// Converts to a plain [`GaussPoint`], applying its sigma floor and dropping the mode tag.
    pub fn to_gauss_point(self) -> GaussPoint {
        GaussPoint::new(self.mu, self.sigma)
    }

    /// Sorts `buf` by IEEE total order and returns its median.
    ///
    /// For even lengths the median is the mean of the two middle values. `buf` must be
    /// non-empty; callers guarantee this.
    fn median_in_place(buf: &mut [f32]) -> f32 {
        debug_assert!(!buf.is_empty(), "median of an empty buffer");
        buf.sort_unstable_by(f32::total_cmp);
        let mid = buf.len() / 2;
        if buf.len() % 2 == 1 {
            buf[mid]
        } else {
            (buf[mid - 1] + buf[mid]) * 0.5
        }
    }
}