use cartan_core::{Manifold, Real};
use crate::result::OptResult;
/// Tuning parameters for [`frechet_mean`].
#[derive(Clone, Debug)]
pub struct FrechetConfig {
    /// Upper bound on the number of exponential-map steps taken.
    pub max_iters: usize,
    /// Stopping threshold: iteration stops once the norm of the averaged
    /// log map at the current estimate drops below this value.
    pub tol: Real,
    /// Scale applied to the averaged log map before each exponential-map step.
    pub step_size: Real,
}
impl Default for FrechetConfig {
fn default() -> Self {
Self {
max_iters: 200,
tol: 1e-8,
step_size: 1.0,
}
}
}
/// Estimates the Fréchet mean of `points` on `manifold` by iterating
/// `mu <- exp_mu(step_size * mean_i log_mu(x_i))`.
///
/// At every iterate, including the one reached after the final step, the log
/// maps `log_mu(x_i)` are averaged over the samples for which `log` succeeds.
/// Iteration stops when the norm of that average falls below `config.tol`, or
/// after `config.max_iters` steps.
///
/// # Arguments
/// * `manifold` - manifold providing `log`, `exp`, `dist`, `norm` and
///   `zero_tangent`.
/// * `points` - the samples; must be non-empty.
/// * `init` - starting estimate; `points[0]` is used when `None`.
/// * `config` - iteration limit, stopping tolerance and step size.
///
/// # Returns
/// An `OptResult` with:
/// * `point` - the final estimate.
/// * `value` - the sum of squared distances from `point` to the samples,
///   divided by `points.len()`. Samples whose distance cannot be computed are
///   skipped but still counted in the denominator.
/// * `grad_norm` - the norm of the averaged log map at `point`, or `NaN` if
///   `log` failed for every sample there.
/// * `iterations` - the number of exponential-map steps taken.
/// * `converged` - whether `grad_norm < config.tol` at `point`.
///
/// # Panics
/// Panics if `points` is empty.
pub fn frechet_mean<M>(
    manifold: &M,
    points: &[M::Point],
    init: Option<M::Point>,
    config: &FrechetConfig,
) -> OptResult<M::Point>
where
    M: Manifold,
{
    assert!(!points.is_empty(), "frechet_mean: points must be non-empty");
    let n = points.len();

    // Mean squared distance from `p` to the samples (failed distances skipped,
    // denominator stays `n`).
    let variance = |p: &M::Point| -> Real {
        points
            .iter()
            .filter_map(|xi| manifold.dist(p, xi).ok())
            .map(|d| d * d)
            .sum::<Real>()
            / n as Real
    };

    // Average of `log_p(x_i)` over the samples whose log map succeeds, or
    // `None` if it fails for all of them. Shared by every iterate so the
    // reported `grad_norm` is always normalized the same way as the quantity
    // tested against `tol`.
    let mean_log = |p: &M::Point| {
        let mut sum = manifold.zero_tangent(p);
        let mut count = 0usize;
        for xi in points {
            if let Ok(log_i) = manifold.log(p, xi) {
                sum = sum + log_i;
                count += 1;
            }
        }
        if count == 0 {
            None
        } else {
            Some(sum * (1.0 / count as Real))
        }
    };

    let mut mu = init.unwrap_or_else(|| points[0].clone());
    let mut iter = 0usize;
    loop {
        let velocity = match mean_log(&mu) {
            Some(v) => v,
            None => {
                // No usable descent direction at `mu`; stop without converging.
                let value = variance(&mu);
                return OptResult {
                    point: mu,
                    value,
                    grad_norm: Real::NAN,
                    iterations: iter,
                    converged: false,
                };
            }
        };

        let vel_norm = manifold.norm(&mu, &velocity);
        // Check convergence before honoring the iteration limit, so an estimate
        // reached on the very last step is still reported as converged.
        let converged = vel_norm < config.tol;
        if converged || iter >= config.max_iters {
            let value = variance(&mu);
            return OptResult {
                point: mu,
                value,
                grad_norm: vel_norm,
                iterations: iter,
                converged,
            };
        }

        mu = manifold.exp(&mu, &(velocity * config.step_size));
        iter += 1;
    }
}