use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::geometry::manifold::{
GeometryError, GeometryResult, RiemannianManifold, check_len, cholesky_spd, dot, flatten,
from_flat, inverse, spectral_map_spd, spectral_map_symmetric, sym,
tangent_basis_metric_orthonormal,
};
/// Manifold of `n × n` symmetric positive-definite (SPD) matrices.
///
/// Points and tangent vectors are exchanged as flat vectors of length `n * n`
/// (see `ambient_dim`) and reshaped internally with `from_flat`.
// NOTE(review): the `i * n + j` indexing used by `metric_tensor` suggests a
// row-major flat layout — confirm against `from_flat` / `flatten`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SpdManifold {
/// Side length `n` of the square matrices that make up the manifold.
n: usize,
}
impl SpdManifold {
    /// Relative tolerance of the symmetry check in `matrix`: the largest
    /// `|a_ij - a_ji|` may not exceed this fraction of `max(1, max |a_ij|)`.
    const SYM_REL_TOL: f64 = 1.0e-9;

    /// Creates the manifold of `n × n` SPD matrices.
    pub const fn new(n: usize) -> Self {
        Self { n }
    }

    /// Reshapes a flat point into an `n × n` matrix and validates it.
    ///
    /// The returned matrix is the symmetrised input (`sym`), after checking
    /// that
    /// 1. every entry is finite,
    /// 2. the input is symmetric up to `SYM_REL_TOL`, and
    /// 3. `cholesky_spd` accepts the symmetrised matrix.
    ///
    /// # Errors
    /// Propagates the `from_flat` and `cholesky_spd` errors and returns
    /// `GeometryError::InvalidPoint` for non-finite or asymmetric input.
    fn matrix(&self, point: ArrayView1<'_, f64>) -> GeometryResult<Array2<f64>> {
        let raw = from_flat(point, self.n, self.n)?;

        // `f64::max` returns the non-NaN operand, so NaN entries would be
        // silently dropped from the running maxima below, and a symmetric
        // ±inf entry would make the relative tolerance infinite. Reject
        // non-finite input explicitly before measuring asymmetry.
        if raw.iter().any(|x| !x.is_finite()) {
            return Err(GeometryError::InvalidPoint(
                "SPD point must have finite entries",
            ));
        }

        let mut max_abs = 0.0_f64;
        let mut max_asym = 0.0_f64;
        for i in 0..self.n {
            for j in 0..self.n {
                max_abs = max_abs.max(raw[[i, j]].abs());
                max_asym = max_asym.max((raw[[i, j]] - raw[[j, i]]).abs());
            }
        }
        // The finiteness test still matters: the difference of two finite
        // entries of opposite sign can overflow to infinity.
        if !max_asym.is_finite() || max_asym > Self::SYM_REL_TOL * max_abs.max(1.0) {
            return Err(GeometryError::InvalidPoint(
                "SPD point must be a symmetric matrix",
            ));
        }

        let p = sym(&raw);
        // Called only for its error: a failure means `p` is not SPD.
        cholesky_spd(&p)?;
        Ok(p)
    }

    /// Affine-invariant inner product `tr(P⁻¹ U P⁻¹ V)` of `u` and `v` at `p`.
    ///
    /// # Errors
    /// Propagates the error of `inverse` when `p` cannot be inverted.
    fn affine_inner(
        &self,
        p: &Array2<f64>,
        u: &Array2<f64>,
        v: &Array2<f64>,
    ) -> GeometryResult<f64> {
        use crate::linalg::faer_ndarray::fast_ab;
        let pinv = inverse(p)?;
        let a = fast_ab(&fast_ab(&fast_ab(&pinv, u), &pinv), v);
        let mut trace = 0.0;
        for i in 0..self.n {
            trace += a[[i, i]];
        }
        Ok(trace)
    }
}
impl RiemannianManifold for SpdManifold {
    /// Intrinsic dimension `n (n + 1) / 2`, the number of independent entries
    /// of a symmetric `n × n` matrix.
    fn dim(&self) -> usize {
        self.n * (self.n + 1) / 2
    }

    /// Length `n²` of the flat encoding of points and tangent vectors.
    fn ambient_dim(&self) -> usize {
        self.n * self.n
    }

    /// Tangent basis at `point`, delegated to
    /// `tangent_basis_metric_orthonormal` after a length check.
    // NOTE(review): presumably the columns are orthonormal under
    // `metric_tensor(point)` and there are `dim()` of them — confirm against
    // the helper's contract.
    fn tangent_basis(&self, point: ArrayView1<'_, f64>) -> GeometryResult<Array2<f64>> {
        check_len("SPD point", point.len(), self.ambient_dim())?;
        tangent_basis_metric_orthonormal(self, point, self.n, self.n)
    }

    /// Exponential map `P^{1/2} exp(P^{-1/2} U P^{-1/2}) P^{1/2}`.
    ///
    /// `tangent_vec` is symmetrised before use and the result is symmetrised
    /// before flattening.
    // NOTE(review): assumes `spectral_map_spd` / `spectral_map_symmetric`
    // apply the closure to the eigenvalues of their argument — confirm.
    fn exp_map(
        &self,
        point: ArrayView1<'_, f64>,
        tangent_vec: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        use crate::linalg::faer_ndarray::fast_ab;
        let p = self.matrix(point)?;
        let u = sym(&from_flat(tangent_vec, self.n, self.n)?);
        let sqrt_p = spectral_map_spd(&p, |x| Ok(x.sqrt()))?;
        let inv_sqrt_p = spectral_map_spd(&p, |x| Ok(1.0 / x.sqrt()))?;
        // Whitened tangent P^{-1/2} U P^{-1/2}.
        let middle = fast_ab(&fast_ab(&inv_sqrt_p, &u), &inv_sqrt_p);
        let exp_middle = spectral_map_symmetric(&middle, |x| Ok(x.exp()))?;
        Ok(flatten(&sym(&fast_ab(
            &fast_ab(&sqrt_p, &exp_middle),
            &sqrt_p,
        ))))
    }

    /// Logarithm map `P^{1/2} log(P^{-1/2} Q P^{-1/2}) P^{1/2}` from `p_from`
    /// (`P`) towards `p_to` (`Q`); both points are validated by `matrix`.
    fn log_map(
        &self,
        p_from: ArrayView1<'_, f64>,
        p_to: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        use crate::linalg::faer_ndarray::fast_ab;
        let p = self.matrix(p_from)?;
        let q = self.matrix(p_to)?;
        let sqrt_p = spectral_map_spd(&p, |x| Ok(x.sqrt()))?;
        let inv_sqrt_p = spectral_map_spd(&p, |x| Ok(1.0 / x.sqrt()))?;
        let middle = fast_ab(&fast_ab(&inv_sqrt_p, &q), &inv_sqrt_p);
        let log_middle = spectral_map_spd(&middle, |x| Ok(x.ln()))?;
        Ok(flatten(&sym(&fast_ab(
            &fast_ab(&sqrt_p, &log_middle),
            &sqrt_p,
        ))))
    }

    /// Transports `vec` from the first row of `point_along` to its last row.
    ///
    /// Only the two end points are read; intermediate rows are ignored. With
    /// `P` the first and `Q` the last point, the result is `A U Aᵀ` where
    /// `A = P^{1/2} (P^{-1/2} Q P^{-1/2})^{1/2} P^{-1/2}`. When fewer than two
    /// rows are supplied the symmetrised input vector is returned unchanged.
    fn parallel_transport(
        &self,
        point_along: ArrayView2<'_, f64>,
        vec: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        use crate::linalg::faer_ndarray::{fast_ab, fast_abt};
        check_len("SPD transported vector", vec.len(), self.ambient_dim())?;
        if point_along.nrows() < 2 {
            return Ok(flatten(&sym(&from_flat(vec, self.n, self.n)?)));
        }
        let p = self.matrix(point_along.row(0))?;
        let q = self.matrix(point_along.row(point_along.nrows() - 1))?;
        let u = sym(&from_flat(vec, self.n, self.n)?);
        let inv_sqrt_p = spectral_map_spd(&p, |x| Ok(1.0 / x.sqrt()))?;
        let middle = fast_ab(&fast_ab(&inv_sqrt_p, &q), &inv_sqrt_p);
        let e = spectral_map_spd(&middle, |x| Ok(x.sqrt()))?;
        let sqrt_p = spectral_map_spd(&p, |x| Ok(x.sqrt()))?;
        let a = fast_ab(&fast_ab(&sqrt_p, &e), &inv_sqrt_p);
        Ok(flatten(&sym(&fast_abt(&fast_ab(&a, &u), &a))))
    }

    /// Ambient `n² × n²` metric tensor with entries
    /// `g[(i,j),(k,l)] = P⁻¹[i,k] · P⁻¹[l,j]`, so that `uᵀ g v` equals
    /// `tr(Uᵀ P⁻¹ V P⁻¹)` for flat vectors `u`, `v`.
    fn metric_tensor(&self, point: ArrayView1<'_, f64>) -> GeometryResult<Array2<f64>> {
        let p = self.matrix(point)?;
        let pinv = inverse(&p)?;
        let ambient = self.ambient_dim();
        let mut g = Array2::<f64>::zeros((ambient, ambient));
        for i in 0..self.n {
            for j in 0..self.n {
                for k in 0..self.n {
                    for l in 0..self.n {
                        g[[i * self.n + j, k * self.n + l]] = pinv[[i, k]] * pinv[[l, j]];
                    }
                }
            }
        }
        Ok(g)
    }

    /// Christoffel symbols in ambient coordinates.
    ///
    /// `gamma[k][[a, b]]` is component `k` of
    /// `Γ(E_a, E_b) = -½ (E_a P⁻¹ E_b + E_b P⁻¹ E_a)`, where `E_a` is the
    /// one-hot matrix for flat index `a = ai * n + aj`.
    fn christoffel_symbols(&self, point: ArrayView1<'_, f64>) -> GeometryResult<Vec<Array2<f64>>> {
        let p = self.matrix(point)?;
        let pinv = inverse(&p)?;
        let n = self.n;
        let ambient = self.ambient_dim();
        let mut gamma = (0..ambient)
            .map(|_| Array2::<f64>::zeros((ambient, ambient)))
            .collect::<Vec<_>>();
        for a in 0..ambient {
            let (ai, aj) = (a / n, a % n);
            for b in 0..ambient {
                let (bi, bj) = (b / n, b % n);
                // A product of one-hot matrices has a single nonzero entry:
                //   E_a P⁻¹ E_b = P⁻¹[aj, bi] at position (ai, bj)
                //   E_b P⁻¹ E_a = P⁻¹[bj, ai] at position (bi, aj)
                // Writing those two entries directly avoids building dense
                // matrices and running four matmuls per index pair. `-=`
                // accumulates because both terms hit the same slot when
                // (ai, bj) == (bi, aj).
                gamma[ai * n + bj][[a, b]] -= 0.5 * pinv[[aj, bi]];
                gamma[bi * n + aj][[a, b]] -= 0.5 * pinv[[bj, ai]];
            }
        }
        Ok(gamma)
    }

    /// Sectional curvature of the plane spanned by `tangent_pair` at `point`.
    ///
    /// With whitened tangents `A = P^{-1/2} U P^{-1/2}` and
    /// `B = P^{-1/2} V P^{-1/2}` this returns
    /// `-¼ ‖AB − BA‖_F² / (⟨U,U⟩⟨V,V⟩ − ⟨U,V⟩²)` using the affine-invariant
    /// inner product.
    ///
    /// # Errors
    /// Returns `GeometryError::Singular` when the denominator has absolute
    /// value at most `1e-14`.
    fn sectional_curvature(
        &self,
        point: ArrayView1<'_, f64>,
        tangent_pair: (ArrayView1<'_, f64>, ArrayView1<'_, f64>),
    ) -> GeometryResult<f64> {
        use crate::linalg::faer_ndarray::fast_ab;
        let p = self.matrix(point)?;
        let u = sym(&from_flat(tangent_pair.0, self.n, self.n)?);
        let v = sym(&from_flat(tangent_pair.1, self.n, self.n)?);
        let inv_sqrt_p = spectral_map_spd(&p, |x| Ok(1.0 / x.sqrt()))?;
        let a = fast_ab(&fast_ab(&inv_sqrt_p, &u), &inv_sqrt_p);
        let b = fast_ab(&fast_ab(&inv_sqrt_p, &v), &inv_sqrt_p);
        let comm = &fast_ab(&a, &b) - &fast_ab(&b, &a);
        // Squared Frobenius norm of the commutator; flatten it only once.
        let comm_flat = flatten(&comm);
        let comm_norm = dot(comm_flat.view(), comm_flat.view());
        let uu = self.affine_inner(&p, &u, &u)?;
        let vv = self.affine_inner(&p, &v, &v)?;
        let uv = self.affine_inner(&p, &u, &v)?;
        // Gram determinant of the pair: squared area of the parallelogram.
        let denom = uu * vv - uv * uv;
        if denom.abs() <= 1.0e-14 {
            return Err(GeometryError::Singular(
                "SPD sectional curvature plane is degenerate",
            ));
        }
        Ok(-0.25 * comm_norm / denom)
    }

    /// Projects an ambient vector onto the tangent space by symmetrising it.
    ///
    /// `point` is only length-checked; the symmetrisation does not depend on
    /// it.
    fn project_tangent(
        &self,
        point: ArrayView1<'_, f64>,
        vec: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        check_len("SPD projection point", point.len(), self.ambient_dim())?;
        Ok(flatten(&sym(&from_flat(vec, self.n, self.n)?)))
    }

    /// Vector-Jacobian product of `exp_map`; not implemented.
    ///
    /// # Errors
    /// Validates the three input lengths, then always returns
    /// `GeometryError::Unsupported`.
    fn exp_map_vjp(
        &self,
        point: ArrayView1<'_, f64>,
        tangent_vec: ArrayView1<'_, f64>,
        grad_output: ArrayView1<'_, f64>,
    ) -> GeometryResult<(Array1<f64>, Array1<f64>)> {
        let m = self.ambient_dim();
        check_len("SPD exp_map_vjp point", point.len(), m)?;
        check_len("SPD exp_map_vjp tangent", tangent_vec.len(), m)?;
        check_len("SPD exp_map_vjp grad", grad_output.len(), m)?;
        Err(GeometryError::Unsupported(
            "SPD exp_map_vjp: no analytic backward implemented",
        ))
    }
}
/// Squared affine-invariant norm `tr((P⁻¹ V)²)` of the flat tangent vector `v`.
///
/// `pinv` is the already-inverted base point and `v` is symmetrised before
/// use. The trace is clamped at zero so a tiny negative round-off result
/// cannot reach a later square root.
fn affine_sq_norm(n: usize, pinv: &Array2<f64>, v: ArrayView1<'_, f64>) -> GeometryResult<f64> {
    use crate::linalg::faer_ndarray::fast_ab;

    let tangent = sym(&from_flat(v, n, n)?);
    // P⁻¹ V appears twice in the product, so form it once and square it.
    let whitened = fast_ab(pinv, &tangent);
    let squared = fast_ab(&whitened, &whitened);
    // Accumulate the diagonal in index order, starting from +0.0.
    let trace = (0..n).fold(0.0_f64, |acc, i| acc + squared[[i, i]]);
    Ok(trace.max(0.0))
}
pub fn spd_frechet_mean(
n: usize,
points: ArrayView2<'_, f64>,
weights: Option<ArrayView1<'_, f64>>,
tol: f64,
max_iter: usize,
) -> GeometryResult<Array1<f64>> {
let ambient = n * n;
let (m, cols) = points.dim();
if m == 0 || cols != ambient {
return Err(GeometryError::InvalidPoint(
"SPD Fréchet mean: points must be M×n² with M ≥ 1",
));
}
if !(tol.is_finite() && tol > 0.0) {
return Err(GeometryError::InvalidPoint(
"SPD Fréchet mean tolerance must be finite and positive",
));
}
let spd = SpdManifold::new(n);
let w = crate::geometry::normalize_weights(m, weights)
.map_err(|_| GeometryError::InvalidPoint("SPD Fréchet mean: invalid weights"))?;
let samples: Vec<Array1<f64>> = (0..m).map(|i| points.row(i).to_owned()).collect();
let dispersion = |p: ArrayView1<'_, f64>| -> GeometryResult<f64> {
let pm = spd.matrix(p)?;
let pinv = inverse(&pm)?;
let mut acc = 0.0_f64;
for (i, x) in samples.iter().enumerate() {
let lg = spd.log_map(p, x.view())?;
acc += w[i] * affine_sq_norm(n, &pinv, lg.view())?;
}
Ok(acc)
};
let mut p = Array1::<f64>::zeros(ambient);
for (i, x) in samples.iter().enumerate() {
p.scaled_add(w[i], x);
}
p = flatten(&sym(&from_flat(p.view(), n, n)?));
let mut f_cur = dispersion(p.view())?;
let mut best_p = p.clone();
let mut best_grad = f64::INFINITY;
const STALL_REL: f64 = 5.0e-3;
const STALL_PATIENCE: usize = 10;
let mut stall = 0_usize;
let armijo_c1 = 1.0e-4_f64;
for _ in 0..max_iter {
let pm = spd.matrix(p.view())?;
let pinv = inverse(&pm)?;
let mut xi = Array1::<f64>::zeros(ambient);
for (i, x) in samples.iter().enumerate() {
let lg = spd.log_map(p.view(), x.view())?;
xi.scaled_add(w[i], &lg);
}
let grad_norm = affine_sq_norm(n, &pinv, xi.view())?.sqrt();
if grad_norm <= tol {
return Ok(p);
}
let improved = grad_norm < best_grad * (1.0 - STALL_REL);
if grad_norm < best_grad {
best_grad = grad_norm;
best_p.assign(&p);
}
if improved {
stall = 0;
} else {
stall += 1;
if stall >= STALL_PATIENCE {
return Ok(best_p);
}
}
let pred = grad_norm * grad_norm; let f_tol = 8.0 * f64::EPSILON * (1.0 + f_cur.abs());
let mut t = 1.0_f64;
let mut accepted = false;
for _ in 0..60 {
let step = &xi * t;
let cand = spd.exp_map(p.view(), step.view())?;
let f_cand = dispersion(cand.view())?;
if f_cand <= f_cur - 2.0 * armijo_c1 * t * pred + f_tol {
p = cand;
f_cur = f_cand;
accepted = true;
break;
}
t *= 0.5;
}
if !accepted {
return Ok(best_p);
}
}
Err(GeometryError::Singular(
"SPD Fréchet mean did not reach stationarity tolerance within max_iter",
))
}
#[cfg(test)]
mod tangent_basis_tests {
    use super::SpdManifold;
    use crate::geometry::manifold::RiemannianManifold;
    use ndarray::Array1;

    /// At a non-diagonal 2×2 SPD point the tangent basis `Q` must have
    /// `dim()` columns and satisfy `QᵀWQ = I` for the ambient metric `W`.
    #[test]
    fn spd_tangent_basis_metric_orthonormal() {
        let manifold = SpdManifold::new(2);
        let point = Array1::from(vec![2.0, 0.5, 0.5, 1.0]);

        let basis = manifold.tangent_basis(point.view()).expect("tangent basis");
        let metric = manifold.metric_tensor(point.view()).expect("metric tensor");
        let d = manifold.dim();
        assert_eq!(basis.ncols(), d, "basis must have dim() columns");

        // Gram matrix of the basis columns under the metric.
        let weighted = metric.dot(&basis);
        let gram = basis.t().dot(&weighted);
        for i in 0..d {
            for j in 0..d {
                let deviation = if i == j {
                    (gram[[i, j]] - 1.0).abs()
                } else {
                    (gram[[i, j]] - 0.0).abs()
                };
                assert!(
                    deviation <= 1.0e-10,
                    "QᵀWQ != I at ({i},{j}): got {}",
                    gram[[i, j]]
                );
            }
        }
    }
}
#[cfg(test)]
mod frechet_mean_tests {
    use super::{SpdManifold, spd_frechet_mean};
    use crate::geometry::manifold::RiemannianManifold;
    use ndarray::{Array1, Array2};

    /// Flattened diagonal matrix with the entries of `d` on its diagonal.
    fn diag_flat(d: &[f64]) -> Array1<f64> {
        let side = d.len();
        let mut dense = Array2::<f64>::zeros((side, side));
        for (i, &value) in d.iter().enumerate() {
            dense[[i, i]] = value;
        }
        Array1::from_iter(dense.iter().copied())
    }

    /// Stacks flat points as the rows of a matrix sized after the first row.
    fn stack(rows: &[Array1<f64>]) -> Array2<f64> {
        let mut stacked = Array2::<f64>::zeros((rows.len(), rows[0].len()));
        for (i, row) in rows.iter().enumerate() {
            row.iter()
                .enumerate()
                .for_each(|(j, &value)| stacked[[i, j]] = value);
        }
        stacked
    }

    /// Metric norm at `p` of the weighted sum of `log_p(x)` over `rows`,
    /// i.e. the stationarity residual of a candidate mean.
    fn residual(spd: &SpdManifold, p: &Array1<f64>, rows: &[Array1<f64>], w: &[f64]) -> f64 {
        let mut mean_log = Array1::<f64>::zeros(p.len());
        for (x, &wi) in rows.iter().zip(w) {
            let lg = spd.log_map(p.view(), x.view()).expect("log_map");
            mean_log.scaled_add(wi, &lg);
        }
        let metric = spd.metric_tensor(p.view()).expect("metric_tensor");
        let squared = mean_log.dot(&metric.dot(&mean_log));
        squared.max(0.0).sqrt()
    }

    /// Commuting (diagonal) inputs spanning twelve orders of magnitude: the
    /// mean must equal the entry-wise geometric mean.
    #[test]
    fn spd_frechet_mean_matches_geometric_mean_on_commuting_extreme_magnitudes() {
        let n = 3;
        let diags = [
            [1e6, 1e-6, 1.0],
            [1e-6, 1.0, 1e6],
            [1.0, 1e6, 1e-6],
            [1e2, 1e-2, 1e2],
        ];
        let rows: Vec<Array1<f64>> = diags.iter().map(|d| diag_flat(d)).collect();
        let m = rows.len();

        // Closed form: exponential of the averaged logs, per diagonal slot.
        let mut want = [0.0_f64; 3];
        for (k, slot) in want.iter_mut().enumerate() {
            let log_sum = diags.iter().fold(0.0, |acc, d| acc + d[k].ln());
            *slot = (log_sum / m as f64).exp();
        }

        let mean = spd_frechet_mean(n, stack(&rows).view(), None, 1e-12, 500)
            .expect("frechet mean converges on commuting extreme-magnitude SPD");
        let manifold = SpdManifold::new(n);

        for i in 0..n {
            for j in 0..n {
                let got = mean[i * n + j];
                let exp = if i == j { want[i] } else { 0.0 };
                let scale = exp.abs().max(1.0);
                assert!(
                    (got - exp).abs() <= 1e-7 * scale,
                    "commuting mean[{i},{j}] = {got:.6e}, want {exp:.6e}"
                );
            }
        }

        let uniform = vec![1.0 / m as f64; m];
        let r = residual(&manifold, &mean, &rows, &uniform);
        assert!(r < 1e-9, "commuting case residual {r:.3e} not at floor");
    }

    /// Weighted commuting inputs: the mean must equal the weighted geometric
    /// mean of the diagonals.
    #[test]
    fn spd_frechet_mean_weighted_matches_weighted_geometric_mean() {
        let n = 2;
        let diags = [[4.0, 0.25], [0.5, 16.0], [9.0, 1.0]];
        let raw_w = [0.5, 0.3, 0.2];
        let rows: Vec<Array1<f64>> = diags.iter().map(|d| diag_flat(d)).collect();

        let mut want = [0.0_f64; 2];
        for (k, slot) in want.iter_mut().enumerate() {
            let log_sum = diags
                .iter()
                .zip(&raw_w)
                .fold(0.0, |acc, (d, &wi)| acc + wi * d[k].ln());
            *slot = log_sum.exp();
        }

        let weights = Array1::from(raw_w.to_vec());
        let mean = spd_frechet_mean(n, stack(&rows).view(), Some(weights.view()), 1e-12, 500)
            .expect("weighted frechet mean converges");

        for k in 0..n {
            let got = mean[k * n + k];
            assert!(
                (got - want[k]).abs() <= 1e-9 * want[k].max(1.0),
                "weighted mean diag[{k}] = {got:.9e}, want {want_k:.9e}",
                want_k = want[k]
            );
        }
    }

    /// Rotated, widely spread inputs: even with an unreachable tolerance the
    /// solver must return a point whose residual is below `1e-9` and whose
    /// dispersion is smaller than that of every sample.
    #[test]
    fn spd_frechet_mean_converges_below_sqrt_eps_on_spread_non_commuting() {
        let n = 2;
        let angles = [0.0_f64, 0.6, 1.2, 1.9, 2.7];
        let eig = [(12.0_f64, 0.4_f64), (0.5, 9.0), (3.0, 0.2), (0.3, 6.0), (5.0, 0.7)];

        // Each sample is R(θ) diag(a, b) R(θ)ᵀ written out entry by entry.
        let rows: Vec<Array1<f64>> = angles
            .iter()
            .zip(&eig)
            .map(|(&th, &(a, b))| {
                let (c, s) = (th.cos(), th.sin());
                let m00 = c * c * a + s * s * b;
                let m01 = c * s * (a - b);
                let m11 = s * s * a + c * c * b;
                Array1::from(vec![m00, m01, m01, m11])
            })
            .collect();
        let m = rows.len();

        let mean = spd_frechet_mean(n, stack(&rows).view(), None, 1e-14, 1000)
            .expect("spread non-commuting frechet mean converges via safeguard");
        let manifold = SpdManifold::new(n);
        let uniform = vec![1.0 / m as f64; m];
        let r = residual(&manifold, &mean, &rows, &uniform);
        assert!(
            r < 1e-9,
            "spread non-commuting residual {r:.3e} did not descend below √ε \
             (regression of #693: line search stalled at the V round-off floor)"
        );

        // Uniformly weighted dispersion of the samples about `q`.
        let disp = |q: &Array1<f64>| -> f64 {
            rows.iter()
                .map(|x| {
                    let lg = manifold.log_map(q.view(), x.view()).expect("log_map");
                    let g = manifold.metric_tensor(q.view()).expect("metric");
                    lg.dot(&g.dot(&lg)) / m as f64
                })
                .sum()
        };
        let v_mean = disp(&mean);
        for sample in &rows {
            assert!(
                v_mean < disp(sample),
                "mean does not minimize dispersion: V(mean)={v_mean:.6e}"
            );
        }
    }
}