use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::prelude::*;
use scirs2_core::random::Uniform as RandUniform;
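/// Modified Bessel function of the first kind, I_nu(x), for nu >= 0 and x >= 0.
/// Uses the ascending power series for moderate arguments and a leading-order
/// asymptotic expansion for x > 30; overflows once e^x exceeds f64 range (x ~ 709).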
fn bessel_i(nu: f64, x: f64) -> f64 {
if x < 0.0 {
return f64::NAN;
}
if x == 0.0 {
return if nu == 0.0 { 1.0 } else { 0.0 };
}
if x > 30.0 {
// Leading-order asymptotic expansion for large x:
// I_nu(x) ~ e^x / sqrt(2*pi*x) * (1 - (4*nu^2 - 1) / (8*x) + ...).
let leading = x.exp() / (2.0 * std::f64::consts::PI * x).sqrt();
let mu = 4.0 * nu * nu;
let correction = 1.0 - (mu - 1.0) / (8.0 * x);
return leading * correction;
}
let half_x = x / 2.0;
let mut sum = 0.0_f64;
let mut m = 0_u64;
let mut term = half_x.powf(nu) / gamma_fn(nu + 1.0);
while term.abs() > 1e-15 * sum.abs().max(1e-300) && m < 200 {
sum += term;
m += 1;
term *= (half_x * half_x) / (m as f64 * (m as f64 + nu));
}
sum
}
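/// Mean resultant length A_p(kappa) = I_{p/2}(kappa) / I_{p/2 - 1}(kappa)
/// of a von Mises-Fisher distribution on S^{p-1} with concentration kappa.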
fn a_p(p: usize, kappa: f64) -> f64 {
let half_p = p as f64 / 2.0;
bessel_i(half_p, kappa) / bessel_i(half_p - 1.0, kappa)
}
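/// Log of the vMF normalization constant
/// C_p(kappa) = kappa^{p/2 - 1} / ((2*pi)^{p/2} * I_{p/2 - 1}(kappa)),
/// evaluated via `bessel_i`, so it inherits that function's overflow limit for
/// very large kappa.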
fn log_c_p(p: usize, kappa: f64) -> f64 {
let half_p = p as f64 / 2.0;
let nu = half_p - 1.0;
let log_bessel = bessel_i(nu, kappa).ln();
(half_p - 1.0) * kappa.ln() - half_p * (2.0 * std::f64::consts::PI).ln() - log_bessel
}
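/// Natural log of the gamma function via the Lanczos approximation (g = 7,
/// nine coefficients), with the reflection formula for x < 0.5.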
fn ln_gamma(x: f64) -> f64 {
let coeffs = [
0.99999999999980993,
676.5203681218851,
-1259.1392167224028,
771.32342877765313,
-176.61502916214059,
12.507343278686905,
-0.13857109526572012,
9.9843695780195716e-6,
1.5056327351493116e-7,
];
if x < 0.5 {
let pi = std::f64::consts::PI;
return pi.ln() - (pi * x).sin().ln() - ln_gamma(1.0 - x);
}
let xm1 = x - 1.0;
let mut s = coeffs[0];
for (k, &c) in coeffs[1..].iter().enumerate() {
s += c / (xm1 + k as f64 + 1.0);
}
let t = xm1 + 7.5;
0.5 * (2.0 * std::f64::consts::PI).ln() + (xm1 + 0.5) * t.ln() - t + s.ln()
}
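/// Gamma function computed as exp(ln Gamma(x)); adequate for the moderate
/// positive arguments used in this module.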
fn gamma_fn(x: f64) -> f64 {
ln_gamma(x).exp()
}
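/// Von Mises-Fisher distribution on the unit sphere S^{p-1} with density
/// f(x) = C_p(kappa) * exp(kappa * mu^T x), where `mu` is the unit mean
/// direction and `kappa >= 0` the concentration; `kappa == 0` is the uniform
/// distribution on the sphere.
///
/// A minimal usage sketch (marked `ignore` because it does not assume the
/// module path of this type inside the crate; the imports mirror the ones used
/// in this module's tests):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
/// use scirs2_core::random::{SeedableRng, SmallRng};
///
/// let mu = array![0.0_f64, 0.0, 1.0];
/// let vmf = VonMisesFisher::new(mu, 10.0).expect("valid parameters");
/// let mut rng = SmallRng::seed_from_u64(42);
/// // 100 x 3 matrix whose rows are unit vectors on S^2.
/// let samples = vmf.rvs(100, &mut rng).expect("sampling succeeds");
/// ```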
pub struct VonMisesFisher {
pub mu: Array1<f64>,
pub kappa: f64,
pub dim: usize,
log_norm_const: f64,
uniform_distr: RandUniform<f64>,
}
impl VonMisesFisher {
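/// Construct a vMF distribution from a (not necessarily normalized) mean
/// direction `mu` and concentration `kappa >= 0`. `mu` is normalized to unit
/// length. Errors if the dimension is below 2, `kappa` is negative, or `mu`
/// has non-finite entries or (near-)zero length.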
pub fn new(mu: Array1<f64>, kappa: f64) -> StatsResult<Self> {
let p = mu.len();
if p < 2 {
return Err(StatsError::InvalidArgument(
"Dimension must be at least 2".to_string(),
));
}
if kappa < 0.0 {
return Err(StatsError::DomainError(
"Concentration kappa must be non-negative".to_string(),
));
}
if !mu.iter().all(|v| v.is_finite()) {
return Err(StatsError::DomainError(
"Mean direction mu must be finite".to_string(),
));
}
let norm = mu.iter().map(|&v| v * v).sum::<f64>().sqrt();
if norm < 1e-12 {
return Err(StatsError::DomainError(
"Mean direction mu must be non-zero".to_string(),
));
}
let mu_unit = mu / norm;
let log_nc = if kappa == 0.0 {
// Uniform density on S^{p-1}: Gamma(p/2) / (2 * pi^{p/2}).
ln_gamma(p as f64 / 2.0)
- (2.0_f64).ln()
- 0.5 * (p as f64) * std::f64::consts::PI.ln()
} else {
log_c_p(p, kappa)
};
let uniform_distr = RandUniform::new(0.0_f64, 1.0_f64).map_err(|_| {
StatsError::ComputationError(
"Failed to create uniform distribution for vMF sampling".to_string(),
)
})?;
Ok(Self {
mu: mu_unit,
kappa,
dim: p,
log_norm_const: log_nc,
uniform_distr,
})
}
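/// Log density log C_p(kappa) + kappa * mu^T x. The point `x` is assumed to
/// lie on the unit sphere; this is not checked. Returns negative infinity on
/// a dimension mismatch.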
pub fn log_pdf(&self, x: &Array1<f64>) -> f64 {
if x.len() != self.dim {
return f64::NEG_INFINITY;
}
let dot = x.iter().zip(self.mu.iter()).map(|(&xi, &mi)| xi * mi).sum::<f64>();
self.log_norm_const + self.kappa * dot
}
pub fn pdf(&self, x: &Array1<f64>) -> f64 {
self.log_pdf(x).exp()
}
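/// Mean vector E[X] = A_p(kappa) * mu, which lies inside the unit ball rather
/// than on the sphere.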
pub fn mean(&self) -> Array1<f64> {
let r = if self.kappa == 0.0 {
0.0
} else {
a_p(self.dim, self.kappa)
};
self.mu.mapv(|mi| r * mi)
}
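/// Mean resultant length A_p(kappa) in [0, 1); zero in the uniform case.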
pub fn mean_resultant_length(&self) -> f64 {
if self.kappa == 0.0 {
0.0
} else {
a_p(self.dim, self.kappa)
}
}
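/// Differential entropy H = -log C_p(kappa) - kappa * A_p(kappa).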
pub fn entropy(&self) -> f64 {
-self.log_norm_const - self.kappa * self.mean_resultant_length()
}
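/// Draw a single unit vector: Best-Fisher wrapping on the circle for p = 2,
/// Wood's rejection sampler for p >= 3.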
pub fn sample_one<R: Rng + ?Sized>(&self, rng: &mut R) -> Array1<f64> {
if self.dim == 2 {
self.sample_2d(rng)
} else {
self.sample_wood(rng)
}
}
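/// Draw `n` samples, returned as the rows of an `n x p` matrix.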
pub fn rvs<R: Rng + ?Sized>(&self, n: usize, rng: &mut R) -> StatsResult<Array2<f64>> {
let mut samples = Array2::zeros((n, self.dim));
for i in 0..n {
let s = self.sample_one(rng);
samples.row_mut(i).assign(&s);
}
Ok(samples)
}
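/// Best & Fisher (1979) rejection sampler for the von Mises distribution on
/// the circle, rotated so the mode sits at the angle of `mu`.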
fn sample_2d<R: Rng + ?Sized>(&self, rng: &mut R) -> Array1<f64> {
if self.kappa < 1e-12 {
// Near-zero concentration: the distribution is uniform on the circle,
// and the Best-Fisher constants below would divide by kappa.
let u: f64 = self.uniform_distr.sample(rng);
let angle = 2.0 * std::f64::consts::PI * u;
return Array1::from_vec(vec![angle.cos(), angle.sin()]);
}
let mu_angle = self.mu[1].atan2(self.mu[0]);
let a = 1.0 + (1.0 + 4.0 * self.kappa * self.kappa).sqrt();
let b = (a - (2.0 * a).sqrt()) / (2.0 * self.kappa);
let r = (1.0 + b * b) / (2.0 * b);
loop {
let u1: f64 = self.uniform_distr.sample(rng);
let u2: f64 = self.uniform_distr.sample(rng);
let u3: f64 = self.uniform_distr.sample(rng);
// Best-Fisher wrapping: z = cos(pi * u1).
let z = (std::f64::consts::PI * u1).cos();
let f_val = (1.0 + r * z) / (r + z);
let c = self.kappa * (r - f_val);
if c * (2.0 - c) - u2 >= 0.0 || c.ln() + 1.0 - c >= u2.ln() {
let theta = if u3 - 0.5 >= 0.0 { f_val.acos() } else { -f_val.acos() };
let angle = theta + mu_angle;
return Array1::from_vec(vec![angle.cos(), angle.sin()]);
}
}
}
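/// Wood (1994) sampler for p >= 3: draw the component `w` along the pole e_1
/// by rejection, attach a uniform tangential direction scaled by sqrt(1 - w^2),
/// then reflect the pole onto `mu` with a Householder transform.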
fn sample_wood<R: Rng + ?Sized>(&self, rng: &mut R) -> Array1<f64> {
let p = self.dim;
let kappa = self.kappa;
let b = (-2.0 * kappa + (4.0 * kappa * kappa + (p - 1) as f64 * (p - 1) as f64).sqrt())
/ ((p - 1) as f64);
let x0 = (1.0 - b) / (1.0 + b);
let c = kappa * x0 + (p as f64 - 1.0) * (1.0 - x0 * x0).ln();
let w = loop {
// Z ~ Beta((p-1)/2, (p-1)/2); since p >= 3 here, the shape is >= 1 and the
// gamma-based sampler below handles it directly.
let a = (p as f64 - 1.0) / 2.0;
let z: f64 = self.sample_beta_symmetric(a, rng);
let w_candidate = (1.0 - (1.0 + b) * z) / (1.0 - (1.0 - b) * z);
let u: f64 = self.uniform_distr.sample(rng);
// Wood's acceptance test: kappa*W + (p-1)*ln(1 - x0*W) - c >= ln(U).
let log_accept = kappa * w_candidate + (p as f64 - 1.0) * (1.0 - x0 * w_candidate).ln() - c;
if log_accept >= u.ln() {
break w_candidate;
}
};
let v = self.sample_uniform_sphere_minus1(rng);
let sqrt_1mw2 = (1.0 - w * w).max(0.0).sqrt();
let mut x = Array1::zeros(p);
x[0] = w;
for i in 1..p {
x[i] = sqrt_1mw2 * v[i - 1];
}
self.householder_rotate(&x)
}
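/// Symmetric Beta(a, a) variate as G1 / (G1 + G2) with G1, G2 ~ Gamma(a, 1).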
fn sample_beta_symmetric<R: Rng + ?Sized>(&self, a: f64, rng: &mut R) -> f64 {
let g1 = self.sample_gamma(a, rng);
let g2 = self.sample_gamma(a, rng);
g1 / (g1 + g2)
}
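/// Gamma(shape, 1) variate via Marsaglia-Tsang squeeze rejection, with the
/// usual boosting step for shape < 1.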
fn sample_gamma<R: Rng + ?Sized>(&self, shape: f64, rng: &mut R) -> f64 {
if shape < 1.0 {
let u: f64 = self.uniform_distr.sample(rng);
return self.sample_gamma(1.0 + shape, rng) * u.powf(1.0 / shape);
}
let d = shape - 1.0 / 3.0;
let c = 1.0 / (9.0 * d).sqrt();
loop {
let u1: f64 = self.uniform_distr.sample(rng);
let u2: f64 = self.uniform_distr.sample(rng);
let z = (-2.0 * u1.max(f64::EPSILON).ln()).sqrt()
* (2.0 * std::f64::consts::PI * u2).cos();
let v = (1.0 + c * z).powi(3);
if v <= 0.0 {
continue;
}
let u3: f64 = self.uniform_distr.sample(rng);
if u3 < 1.0 - 0.0331 * z.powi(4)
|| u3.ln() < 0.5 * z * z + d * (1.0 - v + v.ln())
{
return d * v;
}
}
}
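/// Uniform direction on S^{p-2}: normalize a vector of independent standard
/// normals generated by Box-Muller.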
fn sample_uniform_sphere_minus1<R: Rng + ?Sized>(&self, rng: &mut R) -> Array1<f64> {
let dim = self.dim - 1;
let mut v = Array1::zeros(dim);
loop {
let mut norm_sq = 0.0_f64;
for i in 0..dim {
let u1: f64 = self.uniform_distr.sample(rng).max(f64::EPSILON);
let u2: f64 = self.uniform_distr.sample(rng);
let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
v[i] = z;
norm_sq += z * z;
}
let norm = norm_sq.sqrt();
if norm > 1e-12 {
v /= norm;
break;
}
}
v
}
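/// Apply the Householder reflection H = I - 2*u*u^T / ||u||^2 with
/// u = e_1 - mu, which maps the pole e_1 onto `mu`.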
fn householder_rotate(&self, x: &Array1<f64>) -> Array1<f64> {
let p = self.dim;
let mu = &self.mu;
let mut u = Array1::zeros(p);
u[0] = 1.0 - mu[0];
for i in 1..p {
u[i] = -mu[i];
}
let norm_u_sq: f64 = u.iter().map(|&v| v * v).sum();
if norm_u_sq < 1e-24 {
return x.clone();
}
let dot = x.iter().zip(u.iter()).map(|(&xi, &ui)| xi * ui).sum::<f64>();
let scale = 2.0 * dot / norm_u_sq;
let mut result = x.clone();
for i in 0..p {
result[i] -= scale * u[i];
}
result
}
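/// Approximate maximum-likelihood fit, assuming each row of `data` is a unit
/// vector: the mean direction is the normalized resultant vector and kappa
/// uses the Banerjee et al. (2005) approximation
/// kappa_hat ~= r_bar * (p - r_bar^2) / (1 - r_bar^2).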
pub fn fit_mle(data: &Array2<f64>) -> StatsResult<(Array1<f64>, f64)> {
let (n, p) = data.dim();
if n < 2 {
return Err(StatsError::InsufficientData(
"Need at least 2 observations".to_string(),
));
}
if p < 2 {
return Err(StatsError::InvalidArgument(
"Dimension must be at least 2".to_string(),
));
}
let mut mean = Array1::<f64>::zeros(p);
for row in data.rows() {
mean = mean + row;
}
mean /= n as f64;
let r_bar = mean.iter().map(|&v| v * v).sum::<f64>().sqrt();
if r_bar < 1e-12 {
return Err(StatsError::ComputationError(
"Data is too dispersed to estimate mean direction".to_string(),
));
}
let mu_hat = mean / r_bar;
let kappa_hat = r_bar * (p as f64 - r_bar * r_bar) / (1.0 - r_bar * r_bar);
let kappa_hat = kappa_hat.max(0.0);
Ok((mu_hat, kappa_hat))
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
use scirs2_core::random::{SmallRng, SeedableRng};
#[test]
fn test_log_pdf_at_mean() {
let mu = array![0.0f64, 0.0, 1.0];
let vmf = VonMisesFisher::new(mu.clone(), 5.0).expect("valid params");
let log_p_at_mean = vmf.log_pdf(&mu);
assert!(log_p_at_mean.is_finite());
let anti = array![0.0f64, 0.0, -1.0];
let log_p_anti = vmf.log_pdf(&anti);
assert!(log_p_at_mean > log_p_anti);
}
#[test]
fn test_kappa_zero_is_uniform() {
let mu = array![1.0f64, 0.0, 0.0];
let vmf = VonMisesFisher::new(mu, 0.0).expect("valid params");
let x1 = array![1.0f64, 0.0, 0.0];
let x2 = array![0.0f64, 1.0, 0.0];
let x3 = array![0.0f64, 0.0, 1.0];
let d1 = vmf.pdf(&x1);
let d2 = vmf.pdf(&x2);
let d3 = vmf.pdf(&x3);
assert!((d1 - d2).abs() < 1e-6);
assert!((d1 - d3).abs() < 1e-6);
}
#[test]
fn test_samples_on_sphere() {
let mut rng = SmallRng::seed_from_u64(42);
let mu = array![0.0f64, 0.0, 1.0];
let vmf = VonMisesFisher::new(mu, 10.0).expect("valid params");
let samples = vmf.rvs(200, &mut rng).expect("sampling should succeed");
for row in samples.rows() {
let norm_sq: f64 = row.iter().map(|&v| v * v).sum();
assert!(
(norm_sq.sqrt() - 1.0).abs() < 1e-10,
"sample not on sphere: norm={}",
norm_sq.sqrt()
);
}
}
#[test]
fn test_samples_concentrated_near_mu() {
let mut rng = SmallRng::seed_from_u64(99);
let mu = array![0.0f64, 0.0, 1.0];
let vmf = VonMisesFisher::new(mu.clone(), 50.0).expect("valid params");
let samples = vmf.rvs(500, &mut rng).expect("sampling should succeed");
let mut avg_dot = 0.0_f64;
for row in samples.rows() {
let dot: f64 = row.iter().zip(mu.iter()).map(|(&xi, &mi)| xi * mi).sum();
avg_dot += dot;
}
avg_dot /= 500.0;
assert!(avg_dot > 0.9, "avg_dot={}", avg_dot);
}
#[test]
fn test_mean_and_entropy() {
let mu = array![1.0f64, 0.0];
let vmf = VonMisesFisher::new(mu.clone(), 3.0).expect("valid params");
let mean = vmf.mean();
assert!(mean[0] > 0.0);
let entropy = vmf.entropy();
assert!(entropy.is_finite());
}
#[test]
fn test_fit_mle() {
let mut rng = SmallRng::seed_from_u64(7);
let mu = array![0.0f64, 1.0, 0.0];
let vmf = VonMisesFisher::new(mu.clone(), 8.0).expect("valid params");
let samples = vmf.rvs(500, &mut rng).expect("sampling should succeed");
let (mu_hat, kappa_hat) = VonMisesFisher::fit_mle(&samples).expect("fit should succeed");
let dot: f64 = mu_hat.iter().zip(mu.iter()).map(|(&a, &b)| a * b).sum();
assert!(dot > 0.9, "mean direction dot product too low: {}", dot);
assert!(kappa_hat > 3.0 && kappa_hat < 20.0, "kappa_hat={}", kappa_hat);
}
}