use ndarray::{Array1, Array2, ArrayView1, ArrayViewMut1};
/// One full turn in radians (2π).
const TWO_PI: f64 = std::f64::consts::TAU;
/// Cutoff compared against the absolute value of a sphere point's last
/// coordinate by `Sphere::warn_at_typed` when deciding whether to report
/// `ManifoldWarning::SphereNearPole`.
pub const SPHERE_POLE_WARN_THRESHOLD: f64 = 1e-8;
/// Advisory conditions a manifold can report about a candidate point.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ManifoldWarning {
    /// The point's length differs from the manifold's ambient dimension.
    InvalidPointDimension,
    /// Sphere-specific warning; see [`ManifoldWarning::as_str`] for its message.
    SphereNearPole,
    /// Interval-specific warning; see [`ManifoldWarning::as_str`] for its message.
    IntervalNearBoundary,
}

impl ManifoldWarning {
    /// Returns the fixed, human-readable message associated with this warning.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::InvalidPointDimension => {
                "manifold: point dimension does not match ambient dimension"
            }
            Self::SphereNearPole => "sphere: near chart pole, retraction may amplify error",
            Self::IntervalNearBoundary => {
                "interval: near boundary; trust radius should be clipped"
            }
        }
    }
}
/// Converts a dimension-validity flag into a warning.
///
/// Returns `None` when `valid_dimension` is `true`, otherwise
/// `Some(ManifoldWarning::InvalidPointDimension)`.
fn no_manifold_warning(valid_dimension: bool) -> Option<ManifoldWarning> {
    (!valid_dimension).then_some(ManifoldWarning::InvalidPointDimension)
}
fn valid_manifold_warning_dimension(p: ArrayView1<'_, f64>, ambient_dim: usize) -> bool {
match p.len().cmp(&ambient_dim) {
std::cmp::Ordering::Equal => None,
_ => Some(ManifoldWarning::InvalidPointDimension),
}
.is_none()
}
/// Common interface of the manifolds defined in this module.
///
/// Points and tangent vectors are passed as 1-D arrays of length
/// [`Manifold::ambient_dim`]; mutable views are updated in place.
pub trait Manifold: Send + Sync {
    /// Dimension of the manifold; a full tangent basis has this many columns.
    fn dim(&self) -> usize;

    /// Length of the coordinate vectors used for points and tangent vectors.
    fn ambient_dim(&self) -> usize;

    /// Projects `v` in place onto the tangent space at `p`.
    fn project_tangent(&self, p: ArrayView1<f64>, v: ArrayViewMut1<f64>);

    /// Writes into `out` the point reached from `p` by stepping along `xi`.
    fn retract(&self, p: ArrayView1<f64>, xi: ArrayView1<f64>, out: ArrayViewMut1<f64>);

    /// Transports `xi` in place from the tangent space at `from` to the one at `to`.
    fn vector_transport(
        &self,
        from: ArrayView1<f64>,
        to: ArrayView1<f64>,
        xi: ArrayViewMut1<f64>,
    );

    /// Per-coordinate weights of the diagonal metric used by
    /// [`Manifold::inner_product`]. The default is all ones, one per ambient
    /// coordinate.
    fn metric_weights(&self) -> Vec<f64> {
        let len = self.ambient_dim();
        vec![1.0; len]
    }

    /// Weighted inner product `Σ wᵢ · xiᵢ · etaᵢ` with `w = metric_weights()`.
    ///
    /// # Panics
    /// Panics if `p`, `xi`, `eta` or the weight vector do not all have
    /// `ambient_dim()` entries.
    fn inner_product(&self, p: ArrayView1<f64>, xi: ArrayView1<f64>, eta: ArrayView1<f64>) -> f64 {
        assert_eq!(p.len(), self.ambient_dim());
        assert_eq!(xi.len(), eta.len());
        assert_eq!(xi.len(), p.len());
        let weights = self.metric_weights();
        assert_eq!(weights.len(), xi.len());
        // All three sequences have equal length here, so zipping visits every index
        // in order; the fold starts at +0.0 and accumulates left to right.
        weights
            .iter()
            .zip(xi.iter())
            .zip(eta.iter())
            .fold(0.0_f64, |acc, ((w, a), b)| acc + w * a * b)
    }

    /// Turns a Euclidean gradient into a Riemannian one, in place.
    /// The default simply projects `egrad` onto the tangent space at `p`.
    fn euclidean_to_riemannian_grad(&self, p: ArrayView1<f64>, egrad: ArrayViewMut1<f64>) {
        self.project_tangent(p, egrad);
    }

    /// Converts, in place, the Euclidean Hessian-vector product `ehess_vp`
    /// (taken along `xi`, with Euclidean gradient `egrad`) into its Riemannian
    /// counterpart at `p`.
    fn euclidean_to_riemannian_hess_vp(
        &self,
        p: ArrayView1<f64>,
        egrad: ArrayView1<f64>,
        ehess_vp: ArrayViewMut1<f64>,
        xi: ArrayView1<f64>,
    );

    /// Short identifier of the manifold.
    fn name(&self) -> &str;

    /// Message form of [`Manifold::warn_at_typed`].
    fn warn_at(&self, p: ArrayView1<f64>) -> Option<&'static str> {
        self.warn_at_typed(p).map(|warning| warning.as_str())
    }

    /// Reports a warning for `p`, if any. The default only checks that `p`
    /// has `ambient_dim()` entries.
    fn warn_at_typed(&self, p: ArrayView1<f64>) -> Option<ManifoldWarning> {
        let mismatch = p.len() != self.ambient_dim();
        mismatch.then_some(ManifoldWarning::InvalidPointDimension)
    }

    /// Builds an orthonormal set of tangent vectors at `p`, one per column.
    ///
    /// Every coordinate axis is projected onto the tangent space, then the
    /// projections are orthonormalised one by one (Gram-Schmidt with the plain
    /// dot product). Candidates whose remainder has norm `<= 1e-12` are
    /// dropped, and collection stops once `dim()` vectors are kept, so the
    /// result has `ambient_dim()` rows and at most `dim()` columns.
    fn tangent_basis(&self, p: ArrayView1<f64>) -> Array2<f64> {
        let m = self.ambient_dim();
        let d = self.dim();

        // Phase 1: project each coordinate axis e_j.
        let projected: Vec<Array1<f64>> = (0..m)
            .map(|j| {
                let mut axis = Array1::<f64>::zeros(m);
                axis[j] = 1.0;
                self.project_tangent(p, axis.view_mut());
                axis
            })
            .collect();

        // Phase 2: orthonormalise, keeping at most `d` vectors.
        let tol = 1.0e-12;
        let mut kept: Vec<Array1<f64>> = Vec::with_capacity(d);
        for mut cand in projected {
            if kept.len() == d {
                break;
            }
            for q in &kept {
                let overlap = cand
                    .iter()
                    .zip(q.iter())
                    .fold(0.0_f64, |acc, (a, b)| acc + a * b);
                cand.iter_mut()
                    .zip(q.iter())
                    .for_each(|(a, b)| *a -= overlap * b);
            }
            let len = cand.iter().fold(0.0_f64, |acc, a| acc + a * a).sqrt();
            if len > tol {
                cand.iter_mut().for_each(|a| *a /= len);
                kept.push(cand);
            }
        }

        // Phase 3: pack the kept vectors as columns.
        let mut q = Array2::<f64>::zeros((m, kept.len()));
        for (j, col) in kept.iter().enumerate() {
            q.column_mut(j).assign(col);
        }
        q
    }
}
/// Flat `d`-dimensional space: every operation is the identity or plain addition.
#[derive(Debug, Clone)]
pub struct Euclidean {
    /// Number of coordinates (both intrinsic and ambient dimension).
    pub d: usize,
}

impl Manifold for Euclidean {
    fn dim(&self) -> usize {
        self.d
    }

    fn ambient_dim(&self) -> usize {
        self.d
    }

    /// Leaves `v` untouched; only the lengths of `p` and `v` are checked.
    fn project_tangent(&self, p: ArrayView1<f64>, v: ArrayViewMut1<f64>) {
        let expected = self.d;
        assert_eq!(p.len(), expected);
        assert_eq!(v.len(), expected);
    }

    /// Writes `p + xi` into `out`.
    ///
    /// # Panics
    /// Panics if `p` or `xi` do not have `d` entries, or if `out` is shorter than `d`.
    fn retract(&self, p: ArrayView1<f64>, xi: ArrayView1<f64>, mut out: ArrayViewMut1<f64>) {
        assert_eq!(p.len(), self.d);
        assert_eq!(xi.len(), self.d);
        // Both inputs have exactly `d` entries, so the zip covers 0..d.
        for (i, (base, step)) in p.iter().zip(xi.iter()).enumerate() {
            out[i] = base + step;
        }
    }

    /// Leaves `xi` untouched; only the lengths of the three arguments are checked.
    fn vector_transport(
        &self,
        from: ArrayView1<f64>,
        to: ArrayView1<f64>,
        xi: ArrayViewMut1<f64>,
    ) {
        let expected = self.d;
        assert_eq!(from.len(), expected);
        assert_eq!(to.len(), expected);
        assert_eq!(xi.len(), expected);
    }

    /// Leaves `ehess_vp` untouched; only the lengths of the arguments are checked.
    fn euclidean_to_riemannian_hess_vp(
        &self,
        p: ArrayView1<f64>,
        egrad: ArrayView1<f64>,
        ehess_vp: ArrayViewMut1<f64>,
        xi: ArrayView1<f64>,
    ) {
        let expected = self.d;
        assert_eq!(p.len(), expected);
        assert_eq!(egrad.len(), expected);
        assert_eq!(ehess_vp.len(), expected);
        assert_eq!(xi.len(), expected);
    }

    fn name(&self) -> &str {
        "Euclidean"
    }
}
/// The circle, with points stored as 2-vectors that `retract` keeps at unit length.
#[derive(Debug, Clone)]
pub struct Circle;

impl Manifold for Circle {
    fn dim(&self) -> usize {
        1
    }

    fn ambient_dim(&self) -> usize {
        2
    }

    /// Both coordinates get the weight `1 / (2π)²`.
    fn metric_weights(&self) -> Vec<f64> {
        let inv_sq_turn = (TWO_PI * TWO_PI).recip();
        vec![inv_sq_turn, inv_sq_turn]
    }

    /// Removes from `v` its component along `p`, in place.
    ///
    /// # Panics
    /// Panics if either argument is not of length 2, or if the result is not
    /// orthogonal to `p` within `1e-9` (absolute).
    fn project_tangent(&self, p: ArrayView1<f64>, mut v: ArrayViewMut1<f64>) {
        assert_eq!(p.len(), 2);
        assert_eq!(v.len(), 2);
        let radial = v[0] * p[0] + v[1] * p[1];
        for k in 0..2 {
            v[k] -= radial * p[k];
        }
        assert!((v[0] * p[0] + v[1] * p[1]).abs() < 1.0e-9);
    }

    /// Writes `(p + xi) / ‖p + xi‖` into `out`.
    ///
    /// # Panics
    /// Panics if `‖p + xi‖²` is zero or not finite, or if the normalised
    /// result is not of unit length within `1e-9`.
    fn retract(&self, p: ArrayView1<f64>, xi: ArrayView1<f64>, mut out: ArrayViewMut1<f64>) {
        let cx = p[0] + xi[0];
        let cy = p[1] + xi[1];
        let sq_len = cx * cx + cy * cy;
        assert!(
            sq_len.is_finite() && sq_len > 0.0,
            "Circle::retract degenerate ||p+ξ||"
        );
        let len = f64::max(sq_len.sqrt(), 1.0e-300);
        out[0] = cx / len;
        out[1] = cy / len;
        assert!((out[0] * out[0] + out[1] * out[1] - 1.0).abs() < 1.0e-9);
    }

    /// Transports `xi` by projecting it onto the tangent space at `to`.
    fn vector_transport(
        &self,
        from: ArrayView1<f64>,
        to: ArrayView1<f64>,
        xi: ArrayViewMut1<f64>,
    ) {
        assert_eq!(from.len(), 2);
        self.project_tangent(to, xi);
    }

    /// Subtracts `(egrad · p) · xi` from `ehess_vp`, then projects the result
    /// onto the tangent space at `p`.
    fn euclidean_to_riemannian_hess_vp(
        &self,
        p: ArrayView1<f64>,
        egrad: ArrayView1<f64>,
        mut ehess_vp: ArrayViewMut1<f64>,
        xi: ArrayView1<f64>,
    ) {
        assert_eq!(p.len(), 2);
        assert_eq!(egrad.len(), 2);
        assert_eq!(ehess_vp.len(), 2);
        assert_eq!(xi.len(), 2);
        let radial_part = egrad[0] * p[0] + egrad[1] * p[1];
        for k in 0..2 {
            ehess_vp[k] -= radial_part * xi[k];
        }
        self.project_tangent(p, ehess_vp);
    }

    /// Returns the 2×1 matrix whose single column is `(-p[1], p[0])`.
    fn tangent_basis(&self, p: ArrayView1<f64>) -> Array2<f64> {
        assert_eq!(p.len(), 2);
        Array2::from_shape_fn((2, 1), |(row, _)| if row == 0 { -p[1] } else { p[0] })
    }

    fn name(&self) -> &str {
        "Circle"
    }
}
/// The `n`-sphere, with points stored as vectors of length `n + 1` that
/// `retract` keeps at unit length.
#[derive(Debug, Clone)]
pub struct Sphere {
    /// Intrinsic dimension; points carry `n + 1` coordinates.
    pub n: usize,
}

impl Manifold for Sphere {
    fn dim(&self) -> usize {
        self.n
    }

    fn ambient_dim(&self) -> usize {
        self.n + 1
    }

    /// Every coordinate gets the weight `1 / π²`.
    fn metric_weights(&self) -> Vec<f64> {
        let w = 1.0 / (std::f64::consts::PI * std::f64::consts::PI);
        vec![w; self.n + 1]
    }

    /// Removes from `v` its component along `p`, in place.
    ///
    /// # Panics
    /// Panics if either argument does not have `n + 1` entries, or if the
    /// result is not orthogonal to `p` within `1e-8 · (1 + len)` (absolute).
    fn project_tangent(&self, p: ArrayView1<f64>, mut v: ArrayViewMut1<f64>) {
        assert_eq!(p.len(), self.n + 1);
        assert_eq!(v.len(), self.n + 1);
        // Lengths are equal from here on, so zipping visits every index in order.
        let dot = v
            .iter()
            .zip(p.iter())
            .fold(0.0_f64, |acc, (vi, pi)| acc + vi * pi);
        v.iter_mut()
            .zip(p.iter())
            .for_each(|(vi, pi)| *vi -= dot * pi);
        let chk = v
            .iter()
            .zip(p.iter())
            .fold(0.0_f64, |acc, (vi, pi)| acc + vi * pi);
        assert!(chk.abs() < 1.0e-8 * (1.0 + p.len() as f64));
    }

    /// Writes `(p + xi) / ‖p + xi‖` into the first `n + 1` entries of `out`.
    ///
    /// # Panics
    /// Panics if any argument has fewer than `n + 1` entries, if `‖p + xi‖²`
    /// is zero or not finite, or if the normalised result is not of unit
    /// length within `1e-9`.
    fn retract(&self, p: ArrayView1<f64>, xi: ArrayView1<f64>, mut out: ArrayViewMut1<f64>) {
        let m = self.n + 1;
        let mut s2 = 0.0_f64;
        for i in 0..m {
            let v = p[i] + xi[i];
            out[i] = v;
            s2 += v * v;
        }
        assert!(
            s2.is_finite() && s2 > 0.0,
            "Sphere::retract degenerate ||p+ξ||"
        );
        let norm = s2.sqrt().max(1.0e-300);
        for i in 0..m {
            out[i] /= norm;
        }
        let n2 = (0..m).fold(0.0_f64, |acc, i| acc + out[i] * out[i]);
        assert!(
            (n2 - 1.0).abs() < 1.0e-9,
            "Sphere::retract output not on S^n"
        );
    }

    /// Transports `xi` by projecting it onto the tangent space at `to`.
    fn vector_transport(
        &self,
        from: ArrayView1<f64>,
        to: ArrayView1<f64>,
        xi: ArrayViewMut1<f64>,
    ) {
        assert_eq!(from.len(), self.n + 1);
        self.project_tangent(to, xi);
    }

    /// Subtracts `(egrad · p) · xi` from `ehess_vp`, then projects the result
    /// onto the tangent space at `p`.
    fn euclidean_to_riemannian_hess_vp(
        &self,
        p: ArrayView1<f64>,
        egrad: ArrayView1<f64>,
        mut ehess_vp: ArrayViewMut1<f64>,
        xi: ArrayView1<f64>,
    ) {
        assert_eq!(p.len(), self.n + 1);
        assert_eq!(egrad.len(), self.n + 1);
        assert_eq!(ehess_vp.len(), self.n + 1);
        assert_eq!(xi.len(), self.n + 1);
        let radial_egrad = egrad
            .iter()
            .zip(p.iter())
            .fold(0.0_f64, |acc, (g, pi)| acc + g * pi);
        ehess_vp
            .iter_mut()
            .zip(xi.iter())
            .for_each(|(h, x)| *h -= radial_egrad * x);
        self.project_tangent(p, ehess_vp);
    }

    /// Builds an `(n + 1) × n` tangent basis at `p` from a Householder reflection.
    ///
    /// With `a` the index of the largest `|p[i]|` and `s = sign(p[a])`, the
    /// reflection `H = I − 2uuᵀ`, `u ∝ s·p − e_a`, swaps `e_a` and `s·p` when
    /// `p` has unit length. The returned columns are the columns of `H` other
    /// than column `a`, which are then orthonormal and orthogonal to `p`.
    fn tangent_basis(&self, p: ArrayView1<f64>) -> Array2<f64> {
        let m = self.n + 1;
        assert_eq!(p.len(), m);
        if m == 0 {
            return Array2::<f64>::zeros((0, 0));
        }
        // Anchor on the largest-magnitude coordinate (first one wins ties).
        let mut anchor = 0usize;
        let mut amax = -1.0_f64;
        for i in 0..m {
            let a = p[i].abs();
            if a > amax {
                amax = a;
                anchor = i;
            }
        }
        let sign = if p[anchor] >= 0.0 { 1.0 } else { -1.0 };
        // u = sign·p − e_anchor (not yet normalised).
        let mut u = Array1::<f64>::zeros(m);
        for i in 0..m {
            u[i] = sign * p[i];
        }
        u[anchor] -= 1.0;
        let u_n2 = u.iter().fold(0.0_f64, |acc, x| acc + x * x);
        let mut basis = Array2::<f64>::zeros((m, self.n));
        if u_n2 < 1.0e-30 {
            // p is (numerically) ±e_anchor: the reflection is the identity, so the
            // basis is simply the remaining coordinate axes.
            let mut col = 0usize;
            for i in 0..m {
                if i == anchor {
                    continue;
                }
                basis[[i, col]] = 1.0;
                col += 1;
            }
            return basis;
        }
        let inv_un = 1.0 / u_n2.sqrt();
        for i in 0..m {
            u[i] *= inv_un;
        }
        // Column for axis j is H·e_j = e_j − 2·u[j]·u.
        let mut col = 0usize;
        for j in 0..m {
            if j == anchor {
                continue;
            }
            let coef = 2.0 * u[j];
            for i in 0..m {
                basis[[i, col]] = -coef * u[i];
            }
            basis[[j, col]] += 1.0;
            col += 1;
        }
        basis
    }

    /// Reports a warning for `p`, if any.
    ///
    /// The dimension is validated first: a point whose length is not `n + 1`
    /// always yields `InvalidPointDimension`, regardless of its coordinates.
    /// Otherwise `SphereNearPole` is returned when the absolute value of the
    /// last coordinate is below [`SPHERE_POLE_WARN_THRESHOLD`].
    fn warn_at_typed(&self, p: ArrayView1<f64>) -> Option<ManifoldWarning> {
        let dimension_warning =
            no_manifold_warning(valid_manifold_warning_dimension(p, self.ambient_dim()));
        if dimension_warning.is_some() {
            return dimension_warning;
        }
        // The length equals n + 1 >= 1 here, so the last index exists.
        let last = p[p.len() - 1];
        (last.abs() < SPHERE_POLE_WARN_THRESHOLD).then_some(ManifoldWarning::SphereNearPole)
    }

    fn name(&self) -> &str {
        "Sphere"
    }
}
/// A one-dimensional interval with bounds `lo` and `hi`.
#[derive(Debug, Clone)]
pub struct Interval {
    /// Lower bound.
    pub lo: f64,
    /// Upper bound.
    pub hi: f64,
}

impl Interval {
    /// Fraction of the interval width kept as a safety margin at each end.
    const EDGE_FRAC: f64 = 1.0e-6;

    /// Clamps `x` into `[lo + margin, hi - margin]`, where
    /// `margin = |hi - lo| * EDGE_FRAC`. The lower bound is applied first,
    /// then the upper bound.
    fn clip(&self, x: f64) -> f64 {
        let margin = Self::EDGE_FRAC * (self.hi - self.lo).abs();
        let floor = self.lo + margin;
        let ceiling = self.hi - margin;
        f64::min(f64::max(x, floor), ceiling)
    }
}
impl Manifold for Interval {
fn dim(&self) -> usize {
1
}
fn ambient_dim(&self) -> usize {
1
}
fn metric_weights(&self) -> Vec<f64> {
let scale = self.hi - self.lo;
vec![1.0 / (scale * scale)]
}
fn project_tangent(&self, p: ArrayView1<f64>, v: ArrayViewMut1<f64>) {
assert_eq!(p.len(), 1);
assert_eq!(v.len(), 1);
}
fn retract(&self, p: ArrayView1<f64>, xi: ArrayView1<f64>, mut out: ArrayViewMut1<f64>) {
assert_eq!(p.len(), 1);
assert_eq!(xi.len(), 1);
assert_eq!(out.len(), 1);
out[0] = self.clip(p[0] + xi[0]);
assert!(
out[0] > self.lo && out[0] < self.hi,
"Interval::retract output left feasible band"
);
}
fn vector_transport(&self, from: ArrayView1<f64>, to: ArrayView1<f64>, xi: ArrayViewMut1<f64>) {
assert_eq!(from.len(), 1);
assert_eq!(to.len(), 1);
assert_eq!(xi.len(), 1);
}
fn euclidean_to_riemannian_hess_vp(
&self,
p: ArrayView1<f64>,
egrad: ArrayView1<f64>,
ehess_vp: ArrayViewMut1<f64>,
xi: ArrayView1<f64>,
) {
assert_eq!(p.len(), 1);
assert_eq!(egrad.len(), 1);
assert_eq!(ehess_vp.len(), 1);
assert_eq!(xi.len(), 1);
}
fn warn_at_typed(&self, p: ArrayView1<f64>) -> Option<ManifoldWarning> {
let band = (self.hi - self.lo).abs() * Self::EDGE_FRAC * 10.0;
if p[0] < self.lo + band || p[0] > self.hi - band {
Some(ManifoldWarning::IntervalNearBoundary)
} else {
None
}
}
fn tangent_basis(&self, p: ArrayView1<f64>) -> Array2<f64> {
assert_eq!(p.len(), 1);
let mut q = Array2::<f64>::zeros((1, 1));
q[[0, 0]] = 1.0;
q
}
fn name(&self) -> &str {
"Interval"
}
}
/// Product of `d` circles; a point stores the circles back to back as
/// `(x₀, y₀, x₁, y₁, …)`, i.e. `2 · d` coordinates.
#[derive(Debug, Clone)]
pub struct Torus {
    /// Number of circle factors.
    pub d: usize,
}

impl Manifold for Torus {
    fn dim(&self) -> usize {
        self.d
    }

    fn ambient_dim(&self) -> usize {
        2 * self.d
    }

    /// Every coordinate gets the weight `1 / (2π)²`.
    fn metric_weights(&self) -> Vec<f64> {
        let inv_sq_turn = 1.0 / (TWO_PI * TWO_PI);
        vec![inv_sq_turn; 2 * self.d]
    }

    /// For each circle factor, removes from its pair of `v` its component
    /// along the matching pair of `p`, in place.
    fn project_tangent(&self, p: ArrayView1<f64>, mut v: ArrayViewMut1<f64>) {
        for k in 0..self.d {
            let (ix, iy) = (2 * k, 2 * k + 1);
            let px = p[ix];
            let py = p[iy];
            let radial = v[ix] * px + v[iy] * py;
            v[ix] -= radial * px;
            v[iy] -= radial * py;
        }
    }

    /// For each circle factor, writes the normalised pair `(p + xi) / ‖p + xi‖`
    /// into `out`.
    ///
    /// # Panics
    /// Panics if, for some factor, the squared norm is zero or not finite, or
    /// the normalised pair is not of unit length within `1e-9`.
    fn retract(&self, p: ArrayView1<f64>, xi: ArrayView1<f64>, mut out: ArrayViewMut1<f64>) {
        for k in 0..self.d {
            let (ix, iy) = (2 * k, 2 * k + 1);
            let cx = p[ix] + xi[ix];
            let cy = p[iy] + xi[iy];
            let sq_len = cx * cx + cy * cy;
            assert!(
                sq_len.is_finite() && sq_len > 0.0,
                "Torus::retract degenerate at axis {}",
                k
            );
            let len = f64::max(sq_len.sqrt(), 1.0e-300);
            out[ix] = cx / len;
            out[iy] = cy / len;
            let n2 = out[ix] * out[ix] + out[iy] * out[iy];
            assert!((n2 - 1.0).abs() < 1.0e-9);
        }
    }

    /// Transports `xi` by projecting it onto the tangent space at `to`.
    fn vector_transport(
        &self,
        from: ArrayView1<f64>,
        to: ArrayView1<f64>,
        xi: ArrayViewMut1<f64>,
    ) {
        assert_eq!(from.len(), 2 * self.d);
        self.project_tangent(to, xi);
    }

    /// Per circle factor, subtracts `(egrad · p) · xi` (restricted to that
    /// factor's pair) from `ehess_vp`, then projects the whole vector onto
    /// the tangent space at `p`.
    fn euclidean_to_riemannian_hess_vp(
        &self,
        p: ArrayView1<f64>,
        egrad: ArrayView1<f64>,
        mut ehess_vp: ArrayViewMut1<f64>,
        xi: ArrayView1<f64>,
    ) {
        for k in 0..self.d {
            let (ix, iy) = (2 * k, 2 * k + 1);
            let radial_part = egrad[ix] * p[ix] + egrad[iy] * p[iy];
            ehess_vp[ix] -= radial_part * xi[ix];
            ehess_vp[iy] -= radial_part * xi[iy];
        }
        self.project_tangent(p, ehess_vp);
    }

    /// Returns a `2d × d` matrix whose column `k` holds `(-y_k, x_k)` in the
    /// rows of factor `k` and zeros elsewhere.
    fn tangent_basis(&self, p: ArrayView1<f64>) -> Array2<f64> {
        assert_eq!(p.len(), 2 * self.d);
        let mut basis = Array2::<f64>::zeros((2 * self.d, self.d));
        for k in 0..self.d {
            let (ix, iy) = (2 * k, 2 * k + 1);
            basis[[ix, k]] = -p[iy];
            basis[[iy, k]] = p[ix];
        }
        basis
    }

    fn name(&self) -> &str {
        "Torus"
    }
}
/// Cartesian product of manifolds. A product point is the concatenation of
/// the component points, each occupying `ambient_dim()` consecutive entries.
pub struct Product {
    /// Factors, in the order their coordinates are concatenated.
    pub components: Vec<Box<dyn Manifold>>,
    /// Optional explicit metric weights overriding the per-component ones;
    /// must have one entry per ambient coordinate.
    pub weights: Option<Vec<f64>>,
}

impl Manifold for Product {
    fn dim(&self) -> usize {
        self.components.iter().map(|part| part.dim()).sum::<usize>()
    }

    fn ambient_dim(&self) -> usize {
        self.components
            .iter()
            .map(|part| part.ambient_dim())
            .sum::<usize>()
    }

    /// Returns the explicit `weights` when present, otherwise the
    /// concatenation of every component's own weights.
    ///
    /// # Panics
    /// Panics if explicit weights are present but their length differs from
    /// `ambient_dim()`.
    fn metric_weights(&self) -> Vec<f64> {
        match &self.weights {
            Some(explicit) => {
                assert_eq!(
                    explicit.len(),
                    self.ambient_dim(),
                    "Product manifold metric weights length must match ambient dimension"
                );
                explicit.clone()
            }
            None => {
                let mut merged = Vec::with_capacity(self.ambient_dim());
                for part in &self.components {
                    merged.extend(part.metric_weights());
                }
                merged
            }
        }
    }

    /// Applies each component's projection to its own slice of `p` and `v`.
    fn project_tangent(&self, p: ArrayView1<f64>, mut v: ArrayViewMut1<f64>) {
        let mut start = 0usize;
        for part in &self.components {
            let end = start + part.ambient_dim();
            let p_part = p.slice(ndarray::s![start..end]);
            let v_part = v.slice_mut(ndarray::s![start..end]);
            part.project_tangent(p_part, v_part);
            start = end;
        }
    }

    /// Applies each component's retraction to its own slice of `p`, `xi` and `out`.
    fn retract(&self, p: ArrayView1<f64>, xi: ArrayView1<f64>, mut out: ArrayViewMut1<f64>) {
        let mut start = 0usize;
        for part in &self.components {
            let end = start + part.ambient_dim();
            let p_part = p.slice(ndarray::s![start..end]);
            let xi_part = xi.slice(ndarray::s![start..end]);
            let out_part = out.slice_mut(ndarray::s![start..end]);
            part.retract(p_part, xi_part, out_part);
            start = end;
        }
    }

    /// Applies each component's vector transport to its own slice of the arguments.
    fn vector_transport(
        &self,
        from: ArrayView1<f64>,
        to: ArrayView1<f64>,
        mut xi: ArrayViewMut1<f64>,
    ) {
        let mut start = 0usize;
        for part in &self.components {
            let end = start + part.ambient_dim();
            let from_part = from.slice(ndarray::s![start..end]);
            let to_part = to.slice(ndarray::s![start..end]);
            let xi_part = xi.slice_mut(ndarray::s![start..end]);
            part.vector_transport(from_part, to_part, xi_part);
            start = end;
        }
    }

    /// Applies each component's Hessian-vector conversion to its own slice
    /// of the arguments.
    fn euclidean_to_riemannian_hess_vp(
        &self,
        p: ArrayView1<f64>,
        egrad: ArrayView1<f64>,
        mut ehess_vp: ArrayViewMut1<f64>,
        xi: ArrayView1<f64>,
    ) {
        let mut start = 0usize;
        for part in &self.components {
            let end = start + part.ambient_dim();
            let p_part = p.slice(ndarray::s![start..end]);
            let egrad_part = egrad.slice(ndarray::s![start..end]);
            let xi_part = xi.slice(ndarray::s![start..end]);
            let ehess_part = ehess_vp.slice_mut(ndarray::s![start..end]);
            part.euclidean_to_riemannian_hess_vp(p_part, egrad_part, ehess_part, xi_part);
            start = end;
        }
    }

    /// Assembles a block-diagonal basis: each component's basis occupies its
    /// own rows and columns, with zeros elsewhere.
    ///
    /// # Panics
    /// Panics if a component returns a basis whose shape is not
    /// `(ambient_dim, dim)` for that component.
    fn tangent_basis(&self, p: ArrayView1<f64>) -> Array2<f64> {
        let m = self.ambient_dim();
        let d = self.dim();
        let mut basis = Array2::<f64>::zeros((m, d));
        let (mut row, mut col) = (0usize, 0usize);
        for part in &self.components {
            let rows = part.ambient_dim();
            let cols = part.dim();
            let block = part.tangent_basis(p.slice(ndarray::s![row..row + rows]));
            assert_eq!(block.nrows(), rows);
            assert_eq!(block.ncols(), cols);
            basis
                .slice_mut(ndarray::s![row..row + rows, col..col + cols])
                .assign(&block);
            row += rows;
            col += cols;
        }
        basis
    }

    fn name(&self) -> &str {
        "Product"
    }
}
/// Plain-data description of a manifold, turned into a live object by
/// [`ManifoldKind::build`].
#[derive(Debug, Clone)]
pub enum ManifoldKind {
    /// `Euclidean { d }`.
    Euclidean(usize),
    /// `Circle`.
    Circle,
    /// `Sphere { n }`.
    Sphere(usize),
    /// `Interval { lo, hi }`.
    Interval(f64, f64),
    /// `Torus { d }`.
    Torus(usize),
    /// `Product` of the listed kinds using each component's own metric weights.
    Product(Vec<ManifoldKind>),
    /// `Product` of the listed kinds with explicit metric weights.
    ProductWithMetric {
        components: Vec<ManifoldKind>,
        weights: Vec<f64>,
    },
}

impl ManifoldKind {
    /// Instantiates the manifold described by `self`, building product
    /// components recursively.
    #[must_use]
    pub fn build(&self) -> Box<dyn Manifold> {
        match *self {
            Self::Euclidean(d) => Box::new(Euclidean { d }),
            Self::Circle => Box::new(Circle),
            Self::Sphere(n) => Box::new(Sphere { n }),
            Self::Interval(lo, hi) => Box::new(Interval { lo, hi }),
            Self::Torus(d) => Box::new(Torus { d }),
            Self::Product(ref parts) => Box::new(Product {
                components: parts.iter().map(Self::build).collect(),
                weights: None,
            }),
            Self::ProductWithMetric {
                ref components,
                ref weights,
            } => Box::new(Product {
                components: components.iter().map(Self::build).collect(),
                weights: Some(weights.clone()),
            }),
        }
    }

    /// Returns `true` only for the top-level `Euclidean` variant.
    pub fn is_euclidean(&self) -> bool {
        match self {
            Self::Euclidean(_) => true,
            _ => false,
        }
    }

    /// Ambient dimension of the described manifold, computed without building it.
    pub fn ambient_dim(&self) -> usize {
        match self {
            Self::Euclidean(d) => *d,
            Self::Circle => 2,
            Self::Sphere(n) => *n + 1,
            Self::Interval(..) => 1,
            Self::Torus(d) => 2 * *d,
            Self::Product(parts)
            | Self::ProductWithMetric {
                components: parts, ..
            } => parts.iter().map(Self::ambient_dim).sum(),
        }
    }
}
/// Applies an arbitrary ambient-space step `delta` at `point`.
///
/// A copy of `delta` is first projected onto the tangent space at `point`,
/// then the manifold's retraction writes the resulting point into
/// `out_new_point`.
///
/// # Panics
/// Panics if `point` or `delta` does not have `manifold.ambient_dim()`
/// entries (and wherever the manifold's own methods panic).
pub fn retract_euclidean_delta(
    manifold: &dyn Manifold,
    point: ArrayView1<f64>,
    delta: ArrayView1<f64>,
    out_new_point: ArrayViewMut1<f64>,
) {
    let expected = manifold.ambient_dim();
    assert_eq!(point.len(), expected);
    assert_eq!(delta.len(), expected);

    // Work on an owned copy so the caller's `delta` is left untouched.
    let mut tangent = delta.to_owned();
    manifold.project_tangent(point, tangent.view_mut());
    manifold.retract(point, tangent.view(), out_new_point);
}
// Unit tests for the manifold implementations above.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::array;
// Euclidean (2-)norm of a vector.
fn norm2(v: ArrayView1<f64>) -> f64 {
v.iter().map(|x| x * x).sum::<f64>().sqrt()
}
// Retracting from (1, 0) along (0, 0.3) must land on a unit-length point.
#[test]
fn circle_retraction_stays_unit() {
let m = Circle;
let p = array![1.0_f64, 0.0];
let xi = array![0.0_f64, 0.3];
let mut out = array![0.0_f64, 0.0];
m.retract(p.view(), xi.view(), out.view_mut());
let n = norm2(out.view());
assert!((n - 1.0).abs() < 1.0e-12);
}
// After projection at p = (0, 0, 1) the vector must be orthogonal to p.
#[test]
fn sphere_tangent_orthogonal_to_point() {
let m = Sphere { n: 2 };
let p = array![0.0_f64, 0.0, 1.0];
let mut v = array![1.0_f64, 2.0, 3.0];
m.project_tangent(p.view(), v.view_mut());
let dot: f64 = (0..3).map(|i| v[i] * p[i]).sum();
assert!(dot.abs() < 1.0e-12);
}
// A step far past the upper bound must still end strictly inside (-1, 1).
#[test]
fn interval_stays_strictly_inside() {
let m = Interval { lo: -1.0, hi: 1.0 };
let p = array![0.99_f64];
let xi = array![10.0_f64];
let mut out = array![0.0_f64];
m.retract(p.view(), xi.view(), out.view_mut());
assert!(out[0] > -1.0 && out[0] < 1.0);
}
// Euclidean retraction is plain addition p + xi.
#[test]
fn euclidean_is_identity() {
let m = Euclidean { d: 3 };
let p = array![0.1_f64, 0.2, 0.3];
let xi = array![1.0_f64, -1.0, 0.5];
let mut out = array![0.0_f64, 0.0, 0.0];
m.retract(p.view(), xi.view(), out.view_mut());
for i in 0..3 {
assert!((out[i] - (p[i] + xi[i])).abs() < 1.0e-15);
}
}
// The circle's single basis vector is 2x1, orthogonal to p and of unit length.
#[test]
fn circle_tangent_basis_is_orthogonal_to_point() {
let m = Circle;
let p = array![0.6_f64, 0.8];
let q = m.tangent_basis(p.view());
assert_eq!(q.shape(), &[2, 1]);
let dot = q[[0, 0]] * p[0] + q[[1, 0]] * p[1];
assert!(dot.abs() < 1.0e-12);
let nrm = (q[[0, 0]].powi(2) + q[[1, 0]].powi(2)).sqrt();
assert!((nrm - 1.0).abs() < 1.0e-12);
}
// For a normalised point on S^3 the 4x3 basis must satisfy QᵀQ = I and Qᵀp = 0.
#[test]
fn sphere_householder_basis_is_orthonormal_and_tangent() {
let m = Sphere { n: 3 };
let raw = array![0.5_f64, -0.3, 0.7, 0.2];
let nrm = (raw.iter().map(|x| x * x).sum::<f64>()).sqrt();
let p: Array1<f64> = raw.iter().map(|x| x / nrm).collect();
let q = m.tangent_basis(p.view());
assert_eq!(q.shape(), &[4, 3]);
// Orthonormality of the columns.
for a in 0..3 {
for b in 0..3 {
let mut dot = 0.0_f64;
for i in 0..4 {
dot += q[[i, a]] * q[[i, b]];
}
let expected = if a == b { 1.0 } else { 0.0 };
assert!((dot - expected).abs() < 1.0e-10, "QtQ[{a},{b}]={dot}");
}
}
// Tangency: every column is orthogonal to p.
for a in 0..3 {
let mut dot = 0.0_f64;
for i in 0..4 {
dot += q[[i, a]] * p[i];
}
assert!(dot.abs() < 1.0e-10);
}
}
// After project + retract, every (x, y) pair of the torus point has unit length.
#[test]
fn torus_retraction_keeps_each_circle_unit() {
let m = Torus { d: 3 };
let p = array![1.0_f64, 0.0, 0.0, 1.0, 0.6, 0.8];
let xi = array![0.0_f64, 0.5, -0.4, 0.0, 0.1, -0.075];
let mut xi_p = xi.clone();
m.project_tangent(p.view(), xi_p.view_mut());
let mut out = Array1::<f64>::zeros(6);
m.retract(p.view(), xi_p.view(), out.view_mut());
for k in 0..3 {
let n2 = out[2 * k] * out[2 * k] + out[2 * k + 1] * out[2 * k + 1];
assert!((n2 - 1.0).abs() < 1.0e-12, "circle {k} not unit");
}
}
// Off-diagonal blocks of the torus basis must be zero.
#[test]
fn torus_tangent_basis_is_block_diagonal() {
let m = Torus { d: 2 };
let p = array![1.0_f64, 0.0, 0.6, 0.8];
let q = m.tangent_basis(p.view());
assert_eq!(q.shape(), &[4, 2]);
assert!(q[[0, 1]].abs() < 1.0e-15);
assert!(q[[1, 1]].abs() < 1.0e-15);
assert!(q[[2, 0]].abs() < 1.0e-15);
assert!(q[[3, 0]].abs() < 1.0e-15);
}
// Retracting on Circle x Interval x Euclidean(2) must equal retracting each
// component separately on its own slice.
#[test]
fn product_retraction_equals_component_retractions() {
let prod = Product {
components: vec![
Box::new(Circle),
Box::new(Interval { lo: 0.0, hi: 1.0 }),
Box::new(Euclidean { d: 2 }),
],
weights: None,
};
let p = array![1.0_f64, 0.0, 0.5, 3.0, -1.0];
let xi = array![0.0_f64, 0.2, 0.1, 1.5, -0.25];
let mut xi_p = xi.clone();
prod.project_tangent(p.view(), xi_p.view_mut());
let mut out = Array1::<f64>::zeros(5);
prod.retract(p.view(), xi_p.view(), out.view_mut());
// Reference results, one component at a time.
let mut c_out = Array1::<f64>::zeros(2);
Circle.retract(
p.slice(ndarray::s![0..2]),
xi_p.slice(ndarray::s![0..2]),
c_out.view_mut(),
);
let mut i_out = Array1::<f64>::zeros(1);
Interval { lo: 0.0, hi: 1.0 }.retract(
p.slice(ndarray::s![2..3]),
xi_p.slice(ndarray::s![2..3]),
i_out.view_mut(),
);
let mut e_out = Array1::<f64>::zeros(2);
(Euclidean { d: 2 }).retract(
p.slice(ndarray::s![3..5]),
xi_p.slice(ndarray::s![3..5]),
e_out.view_mut(),
);
for i in 0..2 {
assert!((out[i] - c_out[i]).abs() < 1.0e-15);
}
assert!((out[2] - i_out[0]).abs() < 1.0e-15);
for i in 0..2 {
assert!((out[3 + i] - e_out[i]).abs() < 1.0e-15);
}
}
// Euclidean tangent projection must leave the vector unchanged.
#[test]
fn euclidean_projection_is_identity() {
let m = Euclidean { d: 4 };
let p = array![0.0_f64, 0.0, 0.0, 0.0];
let v_in = array![1.0_f64, -2.0, 3.0, -4.0];
let mut v = v_in.clone();
m.project_tangent(p.view(), v.view_mut());
for i in 0..4 {
assert!((v[i] - v_in[i]).abs() < 1.0e-15);
}
}
// 100 tangent steps of size 2π/100 each should bring the point back near
// (1, 0); the tolerance is a loose 1e-3.
#[test]
fn circle_2pi_periodicity_under_retraction() {
let m = Circle;
let mut p = array![1.0_f64, 0.0];
let n_steps = 100usize;
let step = (2.0 * std::f64::consts::PI) / (n_steps as f64);
for _ in 0..n_steps {
// Tangent direction (-y, x) scaled by the step size.
let xi = array![-p[1] * step, p[0] * step];
let mut out = array![0.0_f64, 0.0];
m.retract(p.view(), xi.view(), out.view_mut());
p = out;
}
assert!((p[0] - 1.0).abs() < 1.0e-3, "p[0]={}", p[0]);
assert!(p[1].abs() < 1.0e-3, "p[1]={}", p[1]);
}
// Project + retract on S^4 from a normalised point must give a unit-length point.
#[test]
fn sphere_retraction_stays_on_unit_sphere() {
let m = Sphere { n: 4 };
let p = array![0.4_f64, -0.1, 0.5, 0.6, 0.2];
let nrm = (p.iter().map(|x| x * x).sum::<f64>()).sqrt();
let p: Array1<f64> = p.iter().map(|x| x / nrm).collect();
let mut xi = array![0.3_f64, -0.2, 0.1, 0.05, 0.4];
m.project_tangent(p.view(), xi.view_mut());
let mut out = Array1::<f64>::zeros(5);
m.retract(p.view(), xi.view(), out.view_mut());
let n2: f64 = out.iter().map(|x| x * x).sum();
assert!((n2 - 1.0).abs() < 1.0e-12);
}
// Path A calls euclidean_to_riemannian_hess_vp with xi itself as the input
// Hessian-vector product; path B redoes the same correction by hand
// (subtract (egrad·p)·xi, then project). Both must agree.
#[test]
fn weingarten_correction_matches_two_paths_on_sphere() {
let m = Sphere { n: 2 };
let p = array![
1.0_f64 / 3.0_f64.sqrt(),
1.0 / 3.0_f64.sqrt(),
1.0 / 3.0_f64.sqrt()
];
let egrad = array![0.5_f64, -0.2, 0.7];
let mut xi = array![1.0_f64, 0.0, 0.0];
m.project_tangent(p.view(), xi.view_mut());
let mut path_a = xi.clone(); m.euclidean_to_riemannian_hess_vp(p.view(), egrad.view(), path_a.view_mut(), xi.view());
let radial: f64 = (0..3).map(|i| egrad[i] * p[i]).sum();
let mut path_b = xi.clone();
for i in 0..3 {
path_b[i] -= radial * xi[i];
}
m.project_tangent(p.view(), path_b.view_mut());
for i in 0..3 {
assert!(
(path_a[i] - path_b[i]).abs() < 1.0e-12,
"weingarten mismatch at {i}: {} vs {}",
path_a[i],
path_b[i]
);
}
}
}