use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::geometry::manifold::{
GEOMETRY_EPS, GeometryError, GeometryResult, RiemannianManifold, check_len, dot, identity, norm,
};
use crate::geometry::normalize_weights;
/// The unit sphere `S^n` embedded in `R^(n + 1)`, with `n = intrinsic_dim`.
///
/// Points are handled as ambient `(n + 1)`-vectors that must have unit Euclidean norm.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SphereManifold {
    /// Intrinsic dimension `n`; the ambient dimension is `n + 1`.
    intrinsic_dim: usize,
}
impl SphereManifold {
const UNIT_TOL: f64 = 1.0e-6;
pub const fn new(intrinsic_dim: usize) -> Self {
Self { intrinsic_dim }
}
fn normalize(&self, x: Array1<f64>) -> GeometryResult<Array1<f64>> {
let nrm = norm(x.view());
if nrm <= GEOMETRY_EPS || !nrm.is_finite() {
return Err(GeometryError::InvalidPoint(
"sphere normalization underflow",
));
}
Ok(x / nrm)
}
fn require_unit(&self, point: ArrayView1<'_, f64>) -> GeometryResult<()> {
let n2 = dot(point, point);
if !n2.is_finite() || (n2 - 1.0).abs() > Self::UNIT_TOL {
return Err(GeometryError::InvalidPoint(
"sphere operation requires a unit-norm base point",
));
}
Ok(())
}
}
impl RiemannianManifold for SphereManifold {
    /// Intrinsic dimension `n` of `S^n`.
    fn dim(&self) -> usize {
        self.intrinsic_dim
    }

    /// Dimension `n + 1` of the ambient Euclidean space holding point coordinates.
    fn ambient_dim(&self) -> usize {
        self.intrinsic_dim + 1
    }

    /// Returns an `(n + 1) x n` matrix whose columns form an orthonormal basis of the tangent
    /// space at `point`.
    ///
    /// The columns are `H e_j` for `j != anchor`, where `H = I - 2 u u^T` is the Householder
    /// reflection with `u` proportional to `sign * point - e_anchor`. `anchor` is the
    /// coordinate of largest magnitude, which keeps `u` well away from zero. `H` is symmetric
    /// and orthogonal and maps `sign * point` to `e_anchor`, so `H e_anchor = sign * point`.
    /// The remaining columns are therefore orthonormal and orthogonal to `point`, up to the
    /// unit-norm tolerance. When `u` vanishes, `point` is `±e_anchor` and the other coordinate
    /// axes are returned directly.
    ///
    /// # Errors
    /// Length mismatch, or `point` not unit-norm.
    fn tangent_basis(&self, point: ArrayView1<'_, f64>) -> GeometryResult<Array2<f64>> {
        let m = self.ambient_dim();
        check_len("Sphere point", point.len(), m)?;
        self.require_unit(point)?;
        let mut anchor = 0usize;
        let mut max_abs = 0.0;
        for i in 0..m {
            if point[i].abs() > max_abs {
                max_abs = point[i].abs();
                anchor = i;
            }
        }
        let sign = if point[anchor] >= 0.0 { 1.0 } else { -1.0 };
        let mut u = point.to_owned() * sign;
        u[anchor] -= 1.0;
        let u_nrm = norm(u.view());
        let mut basis = Array2::<f64>::zeros((m, self.intrinsic_dim));
        if u_nrm <= GEOMETRY_EPS {
            let mut col = 0usize;
            for row in 0..m {
                if row != anchor {
                    basis[[row, col]] = 1.0;
                    col += 1;
                }
            }
            return Ok(basis);
        }
        u /= u_nrm;
        let mut col = 0usize;
        for j in 0..m {
            if j == anchor {
                continue;
            }
            // Column H e_j = e_j - 2 u_j u.
            let coef = 2.0 * u[j];
            for i in 0..m {
                basis[[i, col]] = -coef * u[i];
            }
            basis[[j, col]] += 1.0;
            col += 1;
        }
        Ok(basis)
    }

    /// Riemannian exponential map: follows the great circle leaving `point` in the direction
    /// of `tangent_vec`.
    ///
    /// `tangent_vec` is first projected onto the tangent space at `point`, giving `xi`. The
    /// result is `point * cos|xi| + xi * sin|xi| / |xi|`. For `|xi| < 1e-10`, the first-order
    /// step `point + xi` is renormalized instead, which avoids the `0/0` in `sin|xi| / |xi|`.
    ///
    /// # Errors
    /// Length mismatch, non-unit `point`, or renormalization failure in the small-step branch.
    fn exp_map(
        &self,
        point: ArrayView1<'_, f64>,
        tangent_vec: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        let m = self.ambient_dim();
        check_len("Sphere point", point.len(), m)?;
        check_len("Sphere tangent", tangent_vec.len(), m)?;
        self.require_unit(point)?;
        let xi = self.project_tangent(point, tangent_vec)?;
        let theta = norm(xi.view());
        if theta < 1.0e-10 {
            return self.normalize(&point + &xi);
        }
        Ok(point.to_owned() * theta.cos() + xi * (theta.sin() / theta))
    }

    /// Riemannian logarithm: the tangent vector at `p_from` whose exponential is `p_to`.
    ///
    /// The geodesic distance is computed as `2 asin(|p_to - p_from| / 2)` rather than
    /// `acos(<p_from, p_to>)`, because the latter loses precision for nearby points. The
    /// direction is the component of `p_to` orthogonal to `p_from`, rescaled to that length.
    /// Points closer than `1e-10` yield the zero vector.
    ///
    /// # Errors
    /// Length mismatch, non-unit inputs, or `GeometryError::Singular` when `p_to` is
    /// (numerically) antipodal to `p_from`, where the logarithm is not unique.
    fn log_map(
        &self,
        p_from: ArrayView1<'_, f64>,
        p_to: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        let m = self.ambient_dim();
        check_len("Sphere source", p_from.len(), m)?;
        check_len("Sphere target", p_to.len(), m)?;
        self.require_unit(p_from)?;
        self.require_unit(p_to)?;
        let c = dot(p_from, p_to).clamp(-1.0, 1.0);
        let mut chord_sq = 0.0_f64;
        for i in 0..m {
            let d = p_to[i] - p_from[i];
            chord_sq += d * d;
        }
        let theta = 2.0 * (0.5 * chord_sq.sqrt()).min(1.0).asin();
        if theta < 1.0e-10 {
            return Ok(Array1::<f64>::zeros(m));
        }
        let mut u = &p_to - &(p_from.to_owned() * c);
        let u_nrm = norm(u.view());
        if u_nrm < 1.0e-10 {
            return Err(GeometryError::Singular(
                "sphere log map is undefined at the antipode (cut locus)",
            ));
        }
        u *= theta / u_nrm;
        Ok(u)
    }

    /// Parallel-transports `vec` along the minimizing geodesic from the first row of
    /// `point_along` to its last row.
    ///
    /// Only the two endpoints are used; intermediate rows are ignored. The transport uses the
    /// closed form `vec - <vec, to> / (1 + <from, to>) * (from + to)`. With fewer than two rows,
    /// `vec` is returned unchanged. `vec` is assumed, not checked, to be tangent at `from`.
    ///
    /// # Errors
    /// Width/length mismatch, non-unit endpoints, or `GeometryError::Singular` for
    /// (numerically) antipodal endpoints, where the minimizing geodesic is not unique.
    fn parallel_transport(
        &self,
        point_along: ArrayView2<'_, f64>,
        vec: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        let m = self.ambient_dim();
        check_len("Sphere path width", point_along.ncols(), m)?;
        check_len("Sphere transported vector", vec.len(), m)?;
        if point_along.nrows() < 2 {
            return Ok(vec.to_owned());
        }
        let from = point_along.row(0);
        let to = point_along.row(point_along.nrows() - 1);
        self.require_unit(from)?;
        self.require_unit(to)?;
        let denom = 1.0 + dot(from, to);
        if denom.abs() < 1.0e-10 {
            return Err(GeometryError::Singular(
                "sphere parallel transport across antipodal endpoints is path-dependent",
            ));
        }
        let scale = dot(vec, to) / denom;
        Ok(vec.to_owned() - &(from.to_owned() + to.to_owned()) * scale)
    }

    /// Metric in ambient coordinates. The sphere inherits the Euclidean inner product, so this
    /// is the `(n + 1) x (n + 1)` identity.
    fn metric_tensor(&self, point: ArrayView1<'_, f64>) -> GeometryResult<Array2<f64>> {
        check_len("Sphere metric point", point.len(), self.ambient_dim())?;
        self.require_unit(point)?;
        Ok(identity(self.ambient_dim()))
    }

    /// Not available in ambient coordinates. After validating `point`, this always returns
    /// `GeometryError::Unsupported`.
    fn christoffel_symbols(&self, point: ArrayView1<'_, f64>) -> GeometryResult<Vec<Array2<f64>>> {
        check_len("Sphere Christoffel point", point.len(), self.ambient_dim())?;
        self.require_unit(point)?;
        Err(GeometryError::Unsupported(
            "Christoffel symbols of the embedded sphere require a local chart",
        ))
    }

    /// Sectional curvature of the unit sphere, which is `1` for every point and tangent plane.
    ///
    /// The point is validated (length and unit norm) like in every other method of this
    /// manifold. The tangent pair is only length-checked; linear independence is not verified.
    ///
    /// # Errors
    /// Length mismatch or non-unit `point`.
    fn sectional_curvature(
        &self,
        point: ArrayView1<'_, f64>,
        tangent_pair: (ArrayView1<'_, f64>, ArrayView1<'_, f64>),
    ) -> GeometryResult<f64> {
        check_len("Sphere curvature point", point.len(), self.ambient_dim())?;
        check_len(
            "Sphere curvature tangent u",
            tangent_pair.0.len(),
            self.ambient_dim(),
        )?;
        check_len(
            "Sphere curvature tangent v",
            tangent_pair.1.len(),
            self.ambient_dim(),
        )?;
        // Previously the only method that accepted off-sphere points; now consistent.
        self.require_unit(point)?;
        Ok(1.0)
    }

    /// Orthogonal projection of `vec` onto the tangent space at the unit `point`:
    /// `vec - <point, vec> point`.
    fn project_tangent(
        &self,
        point: ArrayView1<'_, f64>,
        vec: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        check_len("Sphere projection point", point.len(), self.ambient_dim())?;
        check_len("Sphere projection vector", vec.len(), self.ambient_dim())?;
        self.require_unit(point)?;
        Ok(vec.to_owned() - &(point.to_owned() * dot(point, vec)))
    }

    /// Vector-Jacobian product of `exp_map` with respect to both inputs.
    ///
    /// Differentiates the closed form `F(p, v) = p cos(theta) + xi sin(theta) / theta`, with
    /// `xi = v - <p, v> p` and `theta = |xi|`. Both `p` and `v` are treated as free ambient
    /// vectors, so `|p|^2` enters through `n2`, and the returned gradients are not projected.
    /// For `theta < 1e-10`, the `theta -> 0` limit of the same expressions is used.
    ///
    /// Returns `(grad_point, grad_tangent)` = `(J_p^T g, J_v^T g)` with `g = grad_output`.
    ///
    /// # Errors
    /// Length mismatch or non-unit `point`.
    fn exp_map_vjp(
        &self,
        point: ArrayView1<'_, f64>,
        tangent_vec: ArrayView1<'_, f64>,
        grad_output: ArrayView1<'_, f64>,
    ) -> GeometryResult<(Array1<f64>, Array1<f64>)> {
        let m = self.ambient_dim();
        check_len("Sphere exp_map_vjp point", point.len(), m)?;
        check_len("Sphere exp_map_vjp tangent", tangent_vec.len(), m)?;
        check_len("Sphere exp_map_vjp grad", grad_output.len(), m)?;
        self.require_unit(point)?;
        // Forward pass, as in `exp_map`: c = <p, v>, xi = v - c p, theta = |xi|.
        let c = dot(point, tangent_vec);
        let xi = &tangent_vec.to_owned() - &(point.to_owned() * c);
        let theta = norm(xi.view());
        let g = grad_output;
        let p = point;
        let v = tangent_vec;
        if theta < 1.0e-10 {
            // theta -> 0 limit of the general branch: cos -> 1, sin(t)/t -> 1, alpha terms -> 0.
            let p_dot_g = dot(p, g.view());
            let grad_v = &g.to_owned() - &(p.to_owned() * p_dot_g);
            let grad_p = &(g.to_owned() * (1.0 - c)) - &(v.to_owned() * p_dot_g);
            return Ok((grad_p, grad_v));
        }
        let sin_t = theta.sin();
        let cos_t = theta.cos();
        // f(theta) = sin(theta) / theta and its derivative f'(theta).
        let g_fn = sin_t / theta;
        let g_prime = (theta * cos_t - sin_t) / (theta * theta);
        let n2 = dot(p, p);
        let p_dot_g = dot(p, g);
        let xi_dot_g = dot(xi.view(), g);
        // alpha = <g, dF/dtheta>.
        let alpha = -sin_t * p_dot_g + g_prime * xi_dot_g;
        // cn = <xi, p> = c (1 - |p|^2); zero for an exactly unit point.
        let cn = c * (1.0 - n2);
        // dtheta/dv.
        let w_v = (&xi - &(p.to_owned() * cn)) / theta;
        let g_perp = &g.to_owned() - &(p.to_owned() * p_dot_g);
        let grad_v = &(&w_v * alpha) + &(&g_perp * g_fn);
        // dtheta/dp.
        let w_p = (&(&xi * c) + &(v.to_owned() * cn)) / (-theta);
        let p_term = &(g.to_owned() * c) + &(v.to_owned() * p_dot_g);
        let grad_p = &(&(&w_p * alpha) + &(g.to_owned() * cos_t)) - &(&p_term * g_fn);
        Ok((grad_p, grad_v))
    }
}
/// Checks that `values` is a usable matrix of spherical observations.
///
/// # Errors
/// Returns a message when there are no rows, fewer than two columns (even the circle needs
/// two ambient coordinates), or any entry is NaN or infinite. In the last case, the first
/// offending entry in row-major order is reported together with its `(row, col)` position.
pub fn validate_sphere_matrix(values: ArrayView2<'_, f64>) -> Result<(), String> {
    let (rows, cols) = values.dim();
    if rows == 0 || cols < 2 {
        return Err(
            "spherical values must have at least one row and at least two columns".to_string(),
        );
    }
    match values.indexed_iter().find(|(_, entry)| !entry.is_finite()) {
        Some(((row, col), value)) => Err(format!(
            "spherical values must contain only finite values; got {value} at ({row}, {col})"
        )),
        None => Ok(()),
    }
}
/// Validates `values` and rescales every row to unit Euclidean norm.
///
/// Rows whose plain norm is finite and positive are divided by it directly. If the plain norm
/// is unusable, the row is first divided by its largest magnitude so that every entry lies in
/// `[-1, 1]`, and then normalized. This covers overflow to infinity for huge finite entries and
/// underflow to zero for tiny non-zero entries, depending on how `norm` is computed. Without
/// this fallback, huge rows would silently become all-zero (off the sphere) and tiny rows
/// would be rejected.
///
/// # Errors
/// Any [`validate_sphere_matrix`] failure, or a row that is identically zero.
pub fn normalize_sphere_matrix(values: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
    validate_sphere_matrix(values)?;
    let mut out = Array2::<f64>::zeros(values.dim());
    for (mut dst, src) in out.outer_iter_mut().zip(values.outer_iter()) {
        let row_norm = norm(src);
        if row_norm.is_finite() && row_norm > 0.0 {
            // Common case: plain division by the Euclidean norm.
            dst.zip_mut_with(&src, |d, &s| *d = s / row_norm);
            continue;
        }
        // Fallback for an overflowed or underflowed norm; entries are finite (validated).
        let max_abs = src.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
        if max_abs <= 0.0 {
            return Err("spherical rows must have non-zero norm".to_string());
        }
        dst.zip_mut_with(&src, |d, &s| *d = s / max_abs);
        // The largest entry is now ±1, so this norm is in [1, sqrt(d)].
        let scaled_norm = norm(dst.view());
        dst.mapv_inplace(|v| v / scaled_norm);
    }
    Ok(out)
}
/// Maps each row of `values` into the tangent space at `base` via the sphere logarithm.
///
/// Both `values` and `base` are projected to unit norm with [`normalize_sphere_matrix`]. For a
/// unit row `y` and the unit base `b`, the result is `atan2(|r|, <y, b>) * r / |r|`, where
/// `r = y - <y, b> b`. Rows with `|r| < 1e-12` map to the zero vector.
///
/// # Errors
/// Fails if either input is invalid, the column counts differ, or a row is (numerically)
/// antipodal to `base`.
pub fn response_sphere_log_map(
    values: ArrayView2<'_, f64>,
    base: ArrayView1<'_, f64>,
) -> Result<Array2<f64>, String> {
    let unit_rows = normalize_sphere_matrix(values)?;
    let base_as_row = Array2::from_shape_fn((1, base.len()), |(_, j)| base[j]);
    let unit_base_mat = normalize_sphere_matrix(base_as_row.view())?;
    let (n_rows, n_cols) = unit_rows.dim();
    if n_cols != unit_base_mat.ncols() {
        return Err("spherical values and base point have different dimensions".to_string());
    }
    let unit_base = unit_base_mat.row(0);
    // (n_cols x 1) column, so the batched product yields <y_i, b> for every row i.
    let base_column = unit_base_mat.t().to_owned();
    let cosine_mat =
        crate::geometry::manifold::fast_ab_rows_multi_gpu(unit_rows.view(), base_column.view());
    let cosines = cosine_mat.column(0);
    let mut logs = Array2::<f64>::zeros((n_rows, n_cols));
    for (row_idx, (mut log_row, y_row)) in logs
        .outer_iter_mut()
        .zip(unit_rows.outer_iter())
        .enumerate()
    {
        let cos_angle = cosines[row_idx].clamp(-1.0, 1.0);
        if cos_angle <= -1.0 + 1.0e-12 {
            return Err("spherical log map is undefined at antipodal points".to_string());
        }
        // Component of y orthogonal to the base point; its length is sin(angle).
        let residual: Vec<f64> = y_row
            .iter()
            .zip(unit_base.iter())
            .map(|(&y, &b)| y - cos_angle * b)
            .collect();
        let sin_angle = residual.iter().fold(0.0_f64, |acc, r| acc + r * r).sqrt();
        if sin_angle < 1.0e-12 {
            // Coincides with the base point: the row stays zero.
            continue;
        }
        let scale = sin_angle.atan2(cos_angle) / sin_angle;
        for (slot, r) in log_row.iter_mut().zip(&residual) {
            *slot = r * scale;
        }
    }
    Ok(logs)
}
/// Maps each tangent row at `base` back onto the unit sphere via the sphere exponential.
///
/// `base` is normalized first. For every row, the component along `base` is dropped, leaving
/// `z`, and the point `cos|z| * b + sin|z| / |z| * z` is formed. When `|z| < 1e-12`, the
/// first-order point `b + z` is used instead. Each result is renormalized to unit length.
///
/// # Errors
/// Fails if `base` is invalid, the column counts differ, `tangent` has non-finite entries, or a
/// resulting point has a non-finite or zero norm.
pub fn response_sphere_exp_map(
    tangent: ArrayView2<'_, f64>,
    base: ArrayView1<'_, f64>,
) -> Result<Array2<f64>, String> {
    let base_as_row = Array2::from_shape_fn((1, base.len()), |(_, j)| base[j]);
    let unit_base_mat = normalize_sphere_matrix(base_as_row.view())?;
    let (n_rows, n_cols) = tangent.dim();
    if n_cols != unit_base_mat.ncols() {
        return Err("spherical tangent and base point have different dimensions".to_string());
    }
    if tangent.iter().any(|v| !v.is_finite()) {
        return Err("spherical tangent must contain only finite values".to_string());
    }
    let unit_base = unit_base_mat.row(0);
    let base_column = unit_base_mat.t().to_owned();
    let radial_mat = crate::geometry::manifold::fast_ab_rows_multi_gpu(tangent, base_column.view());
    let radials = radial_mat.column(0);
    let mut points = Array2::<f64>::zeros((n_rows, n_cols));
    for (row_idx, (mut point_row, t_row)) in points
        .outer_iter_mut()
        .zip(tangent.outer_iter())
        .enumerate()
    {
        let radial = radials[row_idx];
        // Tangent component orthogonal to the base point.
        let z: Vec<f64> = t_row
            .iter()
            .zip(unit_base.iter())
            .map(|(&t, &b)| t - radial * b)
            .collect();
        let arc = z.iter().fold(0.0_f64, |acc, v| acc + v * v).sqrt();
        if arc < 1.0e-12 {
            for ((slot, &b), &zc) in point_row.iter_mut().zip(unit_base.iter()).zip(&z) {
                *slot = b + zc;
            }
        } else {
            let cos_arc = arc.cos();
            let sin_over_arc = arc.sin() / arc;
            for ((slot, &b), &zc) in point_row.iter_mut().zip(unit_base.iter()).zip(&z) {
                *slot = cos_arc * b + sin_over_arc * zc;
            }
        }
        let length = point_row.iter().fold(0.0_f64, |acc, v| acc + v * v).sqrt();
        if !length.is_finite() || length <= 0.0 {
            return Err("spherical exponential map produced a non-finite point".to_string());
        }
        point_row.mapv_inplace(|v| v / length);
    }
    Ok(points)
}
/// Returns a unit vector orthogonal to `vector`. Orthogonality relies on `vector` having unit
/// norm.
///
/// Applies one Gram–Schmidt step to the coordinate axis where `vector` has the smallest
/// magnitude (the first such axis on ties), which keeps the result well-conditioned.
///
/// # Errors
/// Fails when the projected axis has zero length.
///
/// # Panics
/// Panics if `vector` is empty.
fn sphere_orthogonal_unit(vector: ArrayView1<'_, f64>) -> Result<Array1<f64>, String> {
    let (axis, _) = vector.iter().enumerate().skip(1).fold(
        (0usize, vector[0].abs()),
        |(best_i, best_abs), (i, value)| {
            let magnitude = value.abs();
            if magnitude < best_abs {
                (i, magnitude)
            } else {
                (best_i, best_abs)
            }
        },
    );
    let overlap = vector[axis];
    let tangent = Array1::from_shape_fn(vector.len(), |col| {
        let unit_entry = if col == axis { 1.0 } else { 0.0 };
        unit_entry - overlap * vector[col]
    });
    let length = norm(tangent.view());
    if length <= 0.0 {
        return Err("cannot construct a tangent direction for the spherical mean".to_string());
    }
    Ok(tangent.mapv(|v| v / length))
}
/// Builds starting points for the Karcher iteration of the spherical Fréchet mean.
///
/// Candidates are produced in this order:
/// 1. the normalized weighted extrinsic mean from `fast_atv(values, weights)`, skipped when it
///    is the zero vector;
/// 2. `+a` and `-a`, where `a` comes from 64 power-iteration steps on the weighted second
///    moment, started from the uniform vector. The iteration stops early if an image vanishes.
///
/// Never returns `Err`; the `Result` keeps the call site uniform.
fn sphere_mean_candidates(
    values: ArrayView2<'_, f64>,
    weights: ArrayView1<'_, f64>,
) -> Result<Vec<Array1<f64>>, String> {
    let dim = values.ncols();
    let mut candidates: Vec<Array1<f64>> = Vec::with_capacity(3);

    let extrinsic = crate::linalg::faer_ndarray::fast_atv(&values, &weights);
    let extrinsic_len = norm(extrinsic.view());
    if extrinsic_len > 0.0 {
        candidates.push(extrinsic.mapv(|v| v / extrinsic_len));
    }

    let moment = sphere_second_moment(values, weights);
    let mut axis = Array1::<f64>::from_elem(dim, 1.0 / (dim as f64).sqrt());
    for _ in 0..64 {
        let image = Array1::from_shape_fn(dim, |r| {
            (0..dim).fold(0.0, |acc, c| acc + moment[[r, c]] * axis[c])
        });
        let image_len = norm(image.view());
        if image_len <= 0.0 {
            break;
        }
        axis = image.mapv(|x| x / image_len);
    }

    let axis_len = norm(axis.view());
    if axis_len > 0.0 {
        let unit = axis.mapv(|x| x / axis_len);
        let opposite = unit.mapv(|x| -x);
        candidates.push(unit);
        candidates.push(opposite);
    }
    Ok(candidates)
}
/// Weighted second-moment matrix (`d x d`) of the rows of `values`.
///
/// Delegates to `fast_xt_diag_x`. Judging by its name, this computes
/// `values^T * diag(weights) * values` (confirm against that helper).
fn sphere_second_moment(values: ArrayView2<'_, f64>, weights: ArrayView1<'_, f64>) -> Array2<f64> {
    use crate::linalg::faer_ndarray::fast_xt_diag_x;
    fast_xt_diag_x(&values, &weights)
}
/// Unit eigenvector estimate for the largest-magnitude eigenvalue of `moment`.
///
/// Runs 128 power-iteration steps starting from the uniform vector. Returns `None` for an empty
/// matrix, or as soon as an iterate is mapped to a vector of non-positive norm.
/// NOTE(review): with a fixed start vector, a start that is orthogonal to the dominant
/// eigenvector leaves the iteration on a different eigenvector.
fn sphere_dominant_axis(moment: ArrayView2<'_, f64>) -> Option<Array1<f64>> {
    let size = moment.nrows();
    if size == 0 {
        return None;
    }
    let start = Array1::<f64>::from_elem(size, 1.0 / (size as f64).sqrt());
    let settled = (0..128).try_fold(start, |current, _| {
        let image = Array1::from_shape_fn(size, |r| {
            (0..size).fold(0.0, |acc, c| acc + moment[[r, c]] * current[c])
        });
        let length = norm(image.view());
        if length <= 0.0 {
            None
        } else {
            Some(image.mapv(|x| x / length))
        }
    })?;
    let length = norm(settled.view());
    (length > 0.0).then(|| settled.mapv(|x| x / length))
}
/// Deterministic fallback point used when every Karcher candidate failed.
///
/// Takes the coordinate axis with the smallest weighted second moment (the first one on ties),
/// removes its component along the dominant axis of the moment matrix, and normalizes. If that
/// vanishes, the first coordinate axis with a non-zero orthogonal component is used instead.
/// For an antipodal pair the result lies on the equator between the two points, which is
/// where the Fréchet objective is minimized. Returns `None` for zero columns, when no dominant
/// axis is found, or when every axis is parallel to it.
fn sphere_equatorial_minimizer(
    values: ArrayView2<'_, f64>,
    weights: ArrayView1<'_, f64>,
) -> Option<Array1<f64>> {
    let dim = values.ncols();
    if dim == 0 {
        return None;
    }
    let moment = sphere_second_moment(values, weights);
    let axis = sphere_dominant_axis(moment.view())?;

    let mut least_k = 0usize;
    for k in 1..dim {
        if moment[[k, k]] < moment[[least_k, least_k]] {
            least_k = k;
        }
    }

    // e_k with its component along `axis` removed, normalized; None if nothing remains.
    let complement_of_axis = |k: usize| -> Option<Array1<f64>> {
        let mut e = Array1::<f64>::zeros(dim);
        e[k] = 1.0;
        let along = dot(e.view(), axis.view());
        for col in 0..dim {
            e[col] -= along * axis[col];
        }
        let length = norm(e.view());
        if length > 0.0 {
            Some(e.mapv(|x| x / length))
        } else {
            None
        }
    };

    complement_of_axis(least_k).or_else(|| (0..dim).find_map(|k| complement_of_axis(k)))
}
/// Weighted sum `sum_i w_i * log_base(y_i)` of sphere logarithms, i.e. the Karcher update
/// direction at `base`.
///
/// Angles use the chord form `2 asin(|y_i - base| / 2)`. Rows within `1e-12` of `base` are
/// skipped. The `theta / sin(theta)` factor falls back to `1` when `sin(theta) <= 1e-12`.
///
/// # Errors
/// Fails when some row is (numerically) antipodal to `base`.
fn sphere_weighted_log_step(
    values: ArrayView2<'_, f64>,
    weights: ArrayView1<'_, f64>,
    base: ArrayView1<'_, f64>,
) -> Result<Array1<f64>, String> {
    let dim = base.len();
    let mut step = Array1::<f64>::zeros(dim);
    for row in 0..values.nrows() {
        let (cosine, chord_sq) = (0..dim).fold((0.0_f64, 0.0_f64), |(dp, cs), col| {
            let y = values[[row, col]];
            let b = base[col];
            let gap = y - b;
            (dp + y * b, cs + gap * gap)
        });
        let cosine = cosine.clamp(-1.0, 1.0);
        if cosine <= -1.0 + 1.0e-12 {
            return Err("spherical log map is undefined at antipodal points".to_string());
        }
        let theta = 2.0 * (0.5 * chord_sq.sqrt()).min(1.0).asin();
        if theta < 1.0e-12 {
            continue;
        }
        let sin_theta = theta.sin();
        let scale = if sin_theta > 1.0e-12 {
            theta / sin_theta
        } else {
            1.0
        };
        for col in 0..dim {
            step[col] += weights[row] * (values[[row, col]] - cosine * base[col]) * scale;
        }
    }
    Ok(step)
}
/// Sphere exponential of a single `tangent` vector at `base`, renormalized to unit length.
///
/// The component of `tangent` along `base` is discarded. The orthogonal part `z` is followed
/// along the great circle for arc length `|z|`; the first-order point `base + z` is used when
/// `|z| < 1e-12`.
///
/// # Errors
/// Fails when the result has a non-finite or non-positive norm.
fn sphere_exp_single(
    tangent: ArrayView1<'_, f64>,
    base: ArrayView1<'_, f64>,
) -> Result<Array1<f64>, String> {
    let n = base.len();
    let radial = (0..n).fold(0.0_f64, |acc, i| acc + tangent[i] * base[i]);
    let z = Array1::from_shape_fn(n, |col| tangent[col] - radial * base[col]);
    let arc = norm(z.view());
    let point = if arc < 1.0e-12 {
        Array1::from_shape_fn(n, |col| base[col] + z[col])
    } else {
        let cos_arc = arc.cos();
        let sin_over_arc = arc.sin() / arc;
        Array1::from_shape_fn(n, |col| cos_arc * base[col] + sin_over_arc * z[col])
    };
    let length = norm(point.view());
    if !length.is_finite() || length <= 0.0 {
        return Err("spherical exponential map produced a non-finite point".to_string());
    }
    Ok(point.mapv(|v| v / length))
}
/// Weighted Fréchet objective `sum_i w_i * d(y_i, base)^2` on the unit sphere.
///
/// Geodesic distances use the chord form `2 asin(min(|y_i - base| / 2, 1))`, which is valid for
/// unit-norm rows and `base` and stays accurate for nearby points.
fn sphere_frechet_objective(
    values: ArrayView2<'_, f64>,
    weights: ArrayView1<'_, f64>,
    base: ArrayView1<'_, f64>,
) -> f64 {
    (0..values.nrows()).fold(0.0_f64, |total, row| {
        let chord_sq = (0..base.len()).fold(0.0_f64, |acc, col| {
            let gap = values[[row, col]] - base[col];
            acc + gap * gap
        });
        let angle = 2.0 * (0.5 * chord_sq.sqrt()).min(1.0).asin();
        total + weights[row] * angle * angle
    })
}
/// Runs the fixed-step Karcher iteration `mu <- exp_mu(sum_i w_i log_mu(y_i))` from `start`.
///
/// The iteration stops once the update norm drops below `tol`, or after `max_iter` steps. It
/// returns `None` when it cannot proceed: either some row becomes antipodal to the iterate
/// (log map undefined) or an exponential step yields a non-finite point. The caller then
/// discards this starting candidate.
fn sphere_karcher_refine(
    values: ArrayView2<'_, f64>,
    weights: ArrayView1<'_, f64>,
    start: Array1<f64>,
    tol: f64,
    max_iter: usize,
) -> Option<Array1<f64>> {
    let mut mu = start;
    for _ in 0..max_iter {
        let step = sphere_weighted_log_step(values, weights, mu.view()).ok()?;
        if norm(step.view()) < tol {
            break;
        }
        mu = sphere_exp_single(step.view(), mu.view()).ok()?;
    }
    Some(mu)
}

/// Weighted Fréchet (Karcher) mean of the rows of `points` on the unit sphere.
///
/// Rows are normalized to unit length, and weights are normalized via `normalize_weights`.
/// Each start from `sphere_mean_candidates` is refined with the Karcher iteration. The refined
/// point with the smallest weighted sum of squared geodesic distances wins. A candidate whose
/// iteration fails is skipped, whether the log step fails or the exp step does; previously an
/// exp-step failure aborted the whole call. If every candidate fails, the deterministic
/// equatorial point from `sphere_equatorial_minimizer` is returned.
///
/// # Arguments
/// * `points` - `n x d` observations with `d >= 2`; rows need not be unit-norm.
/// * `weights` - optional per-row weights. `None` lets `normalize_weights` pick its default
///   (presumably uniform).
/// * `tol` - a candidate stops iterating once its update norm falls below `tol`.
/// * `max_iter` - maximum Karcher iterations per candidate (`0` scores the starts as-is).
///
/// # Errors
/// Invalid `tol`, invalid points or weights, or no identifiable mean.
pub fn sphere_frechet_mean(
    points: ArrayView2<'_, f64>,
    weights: Option<ArrayView1<'_, f64>>,
    tol: f64,
    max_iter: usize,
) -> Result<Vec<f64>, String> {
    if !(tol.is_finite() && tol >= 0.0) {
        return Err("spherical Fréchet mean tolerance must be finite and non-negative".to_string());
    }
    let y = normalize_sphere_matrix(points)?;
    let w = normalize_weights(y.nrows(), weights)?;
    let mut candidates = sphere_mean_candidates(y.view(), w.view())?;
    if candidates.is_empty() {
        candidates.push(sphere_orthogonal_unit(y.row(0))?);
    }
    let mut best_mu: Option<Array1<f64>> = None;
    let mut best_obj = f64::INFINITY;
    for candidate in candidates {
        let mu = match sphere_karcher_refine(y.view(), w.view(), candidate, tol, max_iter) {
            Some(mu) => mu,
            None => continue,
        };
        let obj = sphere_frechet_objective(y.view(), w.view(), mu.view());
        // Strict comparison: the earliest candidate wins ties, and NaN objectives never win.
        if obj < best_obj {
            best_obj = obj;
            best_mu = Some(mu);
        }
    }
    if let Some(mu) = best_mu {
        return Ok(mu.to_vec());
    }
    if let Some(mu) = sphere_equatorial_minimizer(y.view(), w.view()) {
        return Ok(mu.to_vec());
    }
    Err("spherical Fréchet mean is not identifiable for these points".to_string())
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Evaluates the Fréchet objective at an explicit point given as a slice.
    fn obj_at(values: ArrayView2<'_, f64>, weights: ArrayView1<'_, f64>, mu: &[f64]) -> f64 {
        let mu_arr = Array1::from(mu.to_vec());
        sphere_frechet_objective(values, weights, mu_arr.view())
    }

    /// Unwraps a result without requiring its error type to implement `Debug`.
    fn ok_or_panic<T, E>(result: Result<T, E>, what: &str) -> T {
        match result {
            Ok(value) => value,
            Err(_) => panic!("{what} unexpectedly failed"),
        }
    }

    #[test]
    fn antipodal_pair_returns_deterministic_equatorial_minimizer() {
        let values = array![[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]];
        let mean = sphere_frechet_mean(values.view(), None, 1.0e-12, 256)
            .expect("antipodal pair must return a deterministic minimizer");
        assert_eq!(mean.len(), 3);
        let nrm = (mean[0] * mean[0] + mean[1] * mean[1] + mean[2] * mean[2]).sqrt();
        assert!(
            (nrm - 1.0).abs() < 1e-9,
            "mean must be a unit vector, got {nrm}"
        );
        assert!(
            mean[0].abs() < 1e-9,
            "mean must be orthogonal to e1, got {mean:?}"
        );
        assert!((mean[1] - 1.0).abs() < 1e-9, "expected +e2, got {mean:?}");
        assert!(mean[2].abs() < 1e-9, "expected +e2, got {mean:?}");
        let w = normalize_weights(2, None).unwrap();
        let y = normalize_sphere_matrix(values.view()).unwrap();
        let obj_mean = obj_at(y.view(), w.view(), &mean);
        let obj_e3 = obj_at(y.view(), w.view(), &[0.0, 0.0, 1.0]);
        let obj_e1 = obj_at(y.view(), w.view(), &[1.0, 0.0, 0.0]);
        assert!(
            (obj_mean - obj_e3).abs() < 1e-9,
            "equatorial minimizer must tie other equatorial points: {obj_mean} vs {obj_e3}"
        );
        assert!(
            obj_mean < obj_e1 - 1e-9,
            "equatorial minimizer must beat an endpoint: {obj_mean} vs {obj_e1}"
        );
    }

    #[test]
    fn antipodal_minimizer_is_deterministic_across_calls() {
        let values = array![[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]];
        let a = sphere_frechet_mean(values.view(), None, 1.0e-12, 256).unwrap();
        let b = sphere_frechet_mean(values.view(), None, 1.0e-12, 256).unwrap();
        assert_eq!(a, b, "tie-breaker must be deterministic across calls");
    }

    /// A single row whose only weight is zero carries no information and must be rejected.
    #[test]
    fn zero_weight_input_errors() {
        let values = array![[1.0, 0.0, 0.0]];
        let zero = array![0.0_f64];
        let err = sphere_frechet_mean(values.view(), Some(zero.view()), 1.0e-12, 256);
        assert!(err.is_err(), "zero-weight input must still error");
    }

    #[test]
    fn non_degenerate_mean_unchanged() {
        let values = array![[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.9, 0.0, 0.1]];
        let mean = sphere_frechet_mean(values.view(), None, 1.0e-12, 256).unwrap();
        assert!(mean[0] > 0.9, "expected near-e1 mean, got {mean:?}");
    }

    #[test]
    fn tangent_basis_is_orthonormal_and_orthogonal_to_point() {
        let sphere = SphereManifold::new(2);
        let p = array![0.6, 0.8, 0.0];
        let basis = ok_or_panic(sphere.tangent_basis(p.view()), "tangent_basis");
        assert_eq!(basis.dim(), (3, 2));
        let gram = basis.t().dot(&basis);
        for i in 0..2 {
            for j in 0..2 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (gram[[i, j]] - expected).abs() < 1e-12,
                    "basis is not orthonormal: {gram:?}"
                );
            }
        }
        let normal = basis.t().dot(&p);
        assert!(
            normal.iter().all(|x| x.abs() < 1e-12),
            "basis is not tangent at p: {normal:?}"
        );
    }

    #[test]
    fn log_map_inverts_exp_map() {
        let sphere = SphereManifold::new(2);
        let p = array![1.0, 0.0, 0.0];
        let q = array![0.0, 0.6, 0.8];
        let v = ok_or_panic(sphere.log_map(p.view(), q.view()), "log_map");
        assert!(
            (v.dot(&v).sqrt() - std::f64::consts::FRAC_PI_2).abs() < 1e-12,
            "orthogonal points must be a quarter circle apart, got {v:?}"
        );
        let back = ok_or_panic(sphere.exp_map(p.view(), v.view()), "exp_map");
        for (a, b) in back.iter().zip(q.iter()) {
            assert!((a - b).abs() < 1e-12, "exp(log(q)) = {back:?}, expected {q:?}");
        }
    }

    #[test]
    fn log_map_rejects_antipodal_points() {
        let sphere = SphereManifold::new(2);
        let p = array![0.0, 0.0, 1.0];
        let q = array![0.0, 0.0, -1.0];
        assert!(sphere.log_map(p.view(), q.view()).is_err());
    }

    /// Checks the hand-derived VJP against central differences of `exp_map`.
    #[test]
    fn exp_map_vjp_matches_central_differences() {
        let sphere = SphereManifold::new(2);
        let p = array![0.6, 0.8, 0.0];
        let v = array![0.3, -0.1, 0.5];
        let g = array![0.2, -0.7, 0.4];
        let (grad_p, grad_v) = ok_or_panic(
            sphere.exp_map_vjp(p.view(), v.view(), g.view()),
            "exp_map_vjp",
        );
        // h is small enough that perturbed points stay within the unit-norm tolerance.
        let h = 1.0e-7;
        let objective = |pp: &Array1<f64>, vv: &Array1<f64>| -> f64 {
            let out = ok_or_panic(sphere.exp_map(pp.view(), vv.view()), "exp_map");
            out.dot(&g)
        };
        for i in 0..3 {
            let mut v_plus = v.clone();
            v_plus[i] += h;
            let mut v_minus = v.clone();
            v_minus[i] -= h;
            let fd_v = (objective(&p, &v_plus) - objective(&p, &v_minus)) / (2.0 * h);
            assert!(
                (fd_v - grad_v[i]).abs() < 1e-6,
                "grad_v[{i}] = {}, finite difference = {fd_v}",
                grad_v[i]
            );

            let mut p_plus = p.clone();
            p_plus[i] += h;
            let mut p_minus = p.clone();
            p_minus[i] -= h;
            let fd_p = (objective(&p_plus, &v) - objective(&p_minus, &v)) / (2.0 * h);
            assert!(
                (fd_p - grad_p[i]).abs() < 1e-6,
                "grad_p[{i}] = {}, finite difference = {fd_p}",
                grad_p[i]
            );
        }
    }
}