use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::manifold::{
GeometryError, GeometryResult, RiemannianManifold, check_len, flatten, from_flat, identity,
matrix_det, matrix_exp, orthonormal_completion, qr_thin, skew_log_orthogonal, sym,
tangent_basis_metric_orthonormal,
};
use crate::manifolds::sphere::SphereManifold;
/// The Stiefel manifold St(n, k) of n×k matrices with orthonormal columns
/// (YᵀY = I_k). Points and tangent vectors are handled as flat vectors of
/// length n·k in row-major order (entry (i, j) at index i·k + j).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StiefelManifold {
    /// Number of orthonormal columns.
    k: usize,
    /// Number of rows (ambient column dimension).
    n: usize,
}

impl StiefelManifold {
    /// Builds St(n, k). Note the argument order is `(k, n)`.
    ///
    /// # Errors
    /// Returns `GeometryError::InvalidPoint` unless `1 <= k <= n`
    /// (this also rejects `n == 0`).
    pub fn new(k: usize, n: usize) -> GeometryResult<Self> {
        if (1..=n).contains(&k) {
            Ok(Self { k, n })
        } else {
            Err(GeometryError::InvalidPoint(
                "Stiefel St(n, k) requires 1 <= k <= n",
            ))
        }
    }

    /// Maps an n×k matrix back onto the manifold through its thin QR factor.
    ///
    /// Every column of Q whose matching diagonal entry of R is negative is
    /// negated, so the result corresponds to the QR factorisation with a
    /// non-negative R diagonal and does not depend on the sign conventions of
    /// `qr_thin`.
    fn qr_retraction(&self, y: &Array2<f64>) -> Array2<f64> {
        let (mut q, r) = qr_thin(y);
        (0..self.k)
            .filter(|&col| r[[col, col]] < 0.0)
            .for_each(|col| q.column_mut(col).mapv_inplace(|x| -x));
        q
    }

    /// For k = 1 the manifold is the unit sphere in Rⁿ. This returns the
    /// matching `SphereManifold` (constructed with `n - 1`) so closed-form
    /// sphere routines can be reused. Returns `None` for k > 1.
    fn as_sphere(&self) -> Option<SphereManifold> {
        if self.k == 1 {
            Some(SphereManifold::new(self.n - 1))
        } else {
            None
        }
    }
}
/// Canonical-metric geometry of St(n, k) on row-major flattened n×k matrices.
///
/// For k = 1, `exp_map`, `log_map`, `parallel_transport`, `metric_tensor`,
/// `sectional_curvature`, `riemannian_gradient` and `exp_map_vjp` delegate to
/// the equivalent sphere from `as_sphere`. `project_tangent`, `retract`,
/// `tangent_basis` and the dimension queries use the general Stiefel formulas
/// for every k.
impl RiemannianManifold for StiefelManifold {
/// Intrinsic dimension n·k − k(k+1)/2: YᵀY = I imposes k(k+1)/2 constraints.
fn dim(&self) -> usize {
self.n * self.k - self.k * (self.k + 1) / 2
}
/// Length n·k of the flattened point / tangent-vector representation.
fn ambient_dim(&self) -> usize {
self.n * self.k
}
/// Basis of the tangent space at `point`, produced by
/// `tangent_basis_metric_orthonormal` with this manifold (and therefore its
/// `metric_tensor`). The `stiefel_tangent_basis_metric_orthonormal` test
/// checks QᵀWQ = I for the result.
///
/// # Errors
/// Fails if `point` does not have length `ambient_dim()`.
fn tangent_basis(&self, point: ArrayView1<'_, f64>) -> GeometryResult<Array2<f64>> {
check_len("Stiefel point", point.len(), self.ambient_dim())?;
tangent_basis_metric_orthonormal(self, point, self.n, self.k)
}
/// Canonical-metric exponential map.
///
/// `tangent_vec` is first projected onto the tangent space at Y (giving Δ),
/// so non-tangent input is silently projected. With A = YᵀΔ (skew for
/// tangent Δ), the n×n matrix W = ΔYᵀ − YΔᵀ − YAYᵀ is skew-symmetric and
/// satisfies WY = Δ; t ↦ exp(tW)·Y is then the canonical geodesic, and the
/// result is exp(W)·Y flattened.
fn exp_map(
&self,
point: ArrayView1<'_, f64>,
tangent_vec: ArrayView1<'_, f64>,
) -> GeometryResult<Array1<f64>> {
if let Some(sphere) = self.as_sphere() {
return sphere.exp_map(point, tangent_vec);
}
let y = from_flat(point, self.n, self.k)?;
let delta = from_flat(
self.project_tangent(point, tangent_vec)?.view(),
self.n,
self.k,
)?;
use gam_linalg::faer_ndarray::{fast_ab, fast_abt, fast_atb};
// a = YᵀΔ, w = ΔYᵀ − YΔᵀ − Y·a·Yᵀ, result = exp(w)·Y.
let a = fast_atb(&y, &delta); let delta_yt = fast_abt(&delta, &y); let y_dt = fast_abt(&y, &delta); let yayt = fast_abt(&fast_ab(&y, &a), &y); let w = &(&delta_yt - &y_dt) - &yayt; let expw = matrix_exp(&w)?; let result = fast_ab(&expw, &y); Ok(flatten(&result))
}
/// Canonical-metric logarithm: the tangent vector at `p_from` whose geodesic
/// reaches `p_to`, computed by `stiefel_canonical_log`.
///
/// # Errors
/// Length mismatches, errors from the matrix helpers, and `Unsupported` when
/// the log iteration does not converge (targets near the cut locus).
fn log_map(
&self,
p_from: ArrayView1<'_, f64>,
p_to: ArrayView1<'_, f64>,
) -> GeometryResult<Array1<f64>> {
if let Some(sphere) = self.as_sphere() {
return sphere.log_map(p_from, p_to);
}
check_len("Stiefel source", p_from.len(), self.ambient_dim())?;
check_len("Stiefel target", p_to.len(), self.ambient_dim())?;
let y = from_flat(p_from, self.n, self.k)?;
let y_target = from_flat(p_to, self.n, self.k)?;
stiefel_canonical_log(&y, &y_target, self.n, self.k)
}
/// Parallel transport along `point_along`. Only implemented for k = 1 (via
/// the sphere). For k > 1 it validates the length of `vec` and returns
/// `GeometryError::Unsupported`.
fn parallel_transport(
&self,
point_along: ArrayView2<'_, f64>,
vec: ArrayView1<'_, f64>,
) -> GeometryResult<Array1<f64>> {
if let Some(sphere) = self.as_sphere() {
return sphere.parallel_transport(point_along, vec);
}
check_len("Stiefel transported vector", vec.len(), self.ambient_dim())?;
Err(GeometryError::Unsupported(
"Stiefel parallel_transport: no closed-form transport for k > 1",
))
}
/// Dense (n·k)×(n·k) matrix G of the canonical metric at Y.
///
/// With M = I − ½YYᵀ, entry (i·k + j, p·k + j) is M[i, p] for every column
/// j, and all other entries are zero. So for row-major flattened Δ₁, Δ₂,
/// Δ₁ᵀGΔ₂ = tr(Δ₁ᵀ M Δ₂), the canonical inner product.
fn metric_tensor(&self, point: ArrayView1<'_, f64>) -> GeometryResult<Array2<f64>> {
if let Some(sphere) = self.as_sphere() {
return sphere.metric_tensor(point);
}
let y = from_flat(point, self.n, self.k)?;
let yyt = gam_linalg::faer_ndarray::fast_abt(&y, &y);
// m = I − ½·YYᵀ (n×n row-space weighting of the canonical metric).
let mut m = identity(self.n);
for i in 0..self.n {
for p in 0..self.n {
m[[i, p]] -= 0.5 * yyt[[i, p]];
}
}
let ambient = self.ambient_dim();
let mut g = Array2::<f64>::zeros((ambient, ambient));
// Expand m to the flattened layout: the same weight couples column j of
// row i with column j of row p; different columns never couple.
for i in 0..self.n {
for p in 0..self.n {
let block = m[[i, p]];
for j in 0..self.k {
g[[i * self.k + j, p * self.k + j]] = block;
}
}
}
Ok(g)
}
/// Sectional curvature of the plane spanned by `tangent_pair`. Only
/// implemented for k = 1 (via the sphere). For k > 1 it validates all input
/// lengths and returns `GeometryError::Unsupported`.
fn sectional_curvature(
&self,
point: ArrayView1<'_, f64>,
tangent_pair: (ArrayView1<'_, f64>, ArrayView1<'_, f64>),
) -> GeometryResult<f64> {
if let Some(sphere) = self.as_sphere() {
return sphere.sectional_curvature(point, tangent_pair);
}
check_len("Stiefel curvature point", point.len(), self.ambient_dim())?;
check_len(
"Stiefel curvature tangent u",
tangent_pair.0.len(),
self.ambient_dim(),
)?;
check_len(
"Stiefel curvature tangent v",
tangent_pair.1.len(),
self.ambient_dim(),
)?;
Err(GeometryError::Unsupported(
"Stiefel sectional_curvature: no closed-form value for k > 1",
))
}
/// Projection onto the tangent space at Y: Z − Y·sym(YᵀZ).
///
/// Afterwards YᵀΔ is skew-symmetric, as the
/// `project_tangent_makes_ytz_skew_symmetric` test checks. There is no k = 1
/// special case; for k = 1 this reduces to z − y(yᵀz).
fn project_tangent(
&self,
point: ArrayView1<'_, f64>,
vec: ArrayView1<'_, f64>,
) -> GeometryResult<Array1<f64>> {
use gam_linalg::faer_ndarray::{fast_ab, fast_atb};
let y = from_flat(point, self.n, self.k)?;
let z = from_flat(vec, self.n, self.k)?;
let correction = fast_ab(&y, &sym(&fast_atb(&y, &z)));
Ok(flatten(&(z - correction)))
}
/// Converts a Euclidean gradient G into the canonical-metric Riemannian
/// gradient G − Y·GᵀY (Edelman, Arias & Smith).
fn riemannian_gradient(
&self,
point: ArrayView1<'_, f64>,
euclidean_grad: ArrayView1<'_, f64>,
) -> GeometryResult<Array1<f64>> {
if let Some(sphere) = self.as_sphere() {
return sphere.riemannian_gradient(point, euclidean_grad);
}
use gam_linalg::faer_ndarray::{fast_ab, fast_atb};
let y = from_flat(point, self.n, self.k)?;
let e = from_flat(euclidean_grad, self.n, self.k)?;
let correction = fast_ab(&y, &fast_atb(&e, &y));
Ok(flatten(&(e - correction)))
}
/// QR retraction: projects `tangent_vec` to the tangent space, adds it to Y,
/// and re-orthonormalises with the sign-normalised thin QR
/// (`qr_retraction`). Used for every k, including k = 1.
fn retract(
&self,
point: ArrayView1<'_, f64>,
tangent_vec: ArrayView1<'_, f64>,
) -> GeometryResult<Array1<f64>> {
let y = from_flat(point, self.n, self.k)?;
let tangent = from_flat(
self.project_tangent(point, tangent_vec)?.view(),
self.n,
self.k,
)?;
Ok(flatten(&self.qr_retraction(&(y + tangent))))
}
/// The QR retraction matches the exponential map only to first order.
fn retraction_is_second_order(&self) -> bool {
false
}
/// Reverse-mode derivative (VJP) of `exp_map`.
///
/// Given the cotangent `grad_output` = ∂L/∂out of out = exp(W)·Y, returns
/// the ambient cotangents (∂L/∂point, ∂L/∂tangent_vec). `tangent_vec` is
/// treated as raw input Z: the projection Δ = Z − Y·sym(YᵀZ) is part of the
/// differentiated map, so the returned ∂L/∂tangent_vec is itself projected.
///
/// # Errors
/// Length mismatches and errors from `matrix_exp` / `matrix_exp_vjp`.
fn exp_map_vjp(
&self,
point: ArrayView1<'_, f64>,
tangent_vec: ArrayView1<'_, f64>,
grad_output: ArrayView1<'_, f64>,
) -> GeometryResult<(Array1<f64>, Array1<f64>)> {
if let Some(sphere) = self.as_sphere() {
return sphere.exp_map_vjp(point, tangent_vec, grad_output);
}
let m = self.ambient_dim();
check_len("Stiefel exp_map_vjp point", point.len(), m)?;
check_len("Stiefel exp_map_vjp tangent", tangent_vec.len(), m)?;
check_len("Stiefel exp_map_vjp grad", grad_output.len(), m)?;
use gam_linalg::faer_ndarray::{fast_ab, fast_abt, fast_atb};
let y = from_flat(point, self.n, self.k)?;
// Forward pass, recomputed as in exp_map: S = sym(YᵀZ), Δ = Z − YS,
// A = YᵀΔ, W = ΔYᵀ − YΔᵀ − YAYᵀ, E = exp(W).
let z = from_flat(tangent_vec, self.n, self.k)?; let s_proj = sym(&fast_atb(&y, &z)); let delta = &z - &fast_ab(&y, &s_proj); let a = fast_atb(&y, &delta); let delta_yt = fast_abt(&delta, &y); let y_dt = fast_abt(&y, &delta); let yayt = fast_abt(&fast_ab(&y, &a), &y); let w = &(&delta_yt - &y_dt) - &yayt; let expw = matrix_exp(&w)?;
let grad = from_flat(grad_output, self.n, self.k)?;
// out = E·Y  ⇒  Ē = Ḡ·Yᵀ and Ȳ = Eᵀ·Ḡ.
let expw_bar = fast_abt(&grad, &y); let mut y_bar = fast_atb(&expw, &grad);
// Adjoint of the matrix exponential at W.
let w_bar = matrix_exp_vjp(&w, &expw_bar)?;
// Through W = ΔYᵀ − YΔᵀ − YAYᵀ:
//   Δ̄ = W̄Y − W̄ᵀY;  Ȳ += W̄ᵀΔ − W̄Δ − W̄YAᵀ − W̄ᵀYA;  Ā = −YᵀW̄Y.
let wb_y = fast_ab(&w_bar, &y); let wbt_y = fast_atb(&w_bar, &y); let mut delta_bar = &wb_y - &wbt_y; y_bar = y_bar + &fast_atb(&w_bar, &delta); y_bar = y_bar - &fast_ab(&w_bar, &delta); y_bar = y_bar - &fast_abt(&wb_y, &a); y_bar = y_bar - &fast_ab(&wbt_y, &a); let a_bar = -fast_ab(&fast_atb(&y, &w_bar), &y);
// Through A = YᵀΔ:  Ȳ += ΔĀᵀ,  Δ̄ += YĀ.
y_bar = y_bar + &fast_abt(&delta, &a_bar); delta_bar = delta_bar + &fast_ab(&y, &a_bar);
// Through Δ = Z − Y·sym(YᵀZ):
//   Z̄ = Δ̄ − Y·sym(YᵀΔ̄);  Ȳ −= Δ̄S + Z·sym(YᵀΔ̄).
let sym_yt_db = sym(&fast_atb(&y, &delta_bar));
let z_bar = &delta_bar - &fast_ab(&y, &sym_yt_db);
y_bar = y_bar - &fast_ab(&delta_bar, &s_proj) - &fast_ab(&z, &sym_yt_db);
Ok((flatten(&y_bar), flatten(&z_bar)))
}
}
fn matrix_exp_vjp(b: &Array2<f64>, cotangent: &Array2<f64>) -> GeometryResult<Array2<f64>> {
let m = b.nrows();
if b.ncols() != m || cotangent.nrows() != m || cotangent.ncols() != m {
return Err(GeometryError::InvalidPoint(
"matrix_exp_vjp requires square matrices of equal size",
));
}
let two_m = 2 * m;
let mut aug = Array2::<f64>::zeros((two_m, two_m));
for i in 0..m {
for j in 0..m {
let bt = b[[j, i]]; aug[[i, j]] = bt;
aug[[m + i, m + j]] = bt;
aug[[i, m + j]] = cotangent[[i, j]];
}
}
let exp_aug = matrix_exp(&aug)?;
Ok(exp_aug.slice(ndarray::s![0..m, m..two_m]).to_owned())
}
fn stiefel_canonical_log(
y: &Array2<f64>,
y_target: &Array2<f64>,
n: usize,
k: usize,
) -> GeometryResult<Array1<f64>> {
use gam_linalg::faer_ndarray::{fast_ab, fast_atb};
let c_dim = n - k;
let frame_y = orthonormal_completion(y); let mut y_perp = Array2::<f64>::zeros((n, c_dim));
for j in 0..c_dim {
for i in 0..n {
y_perp[[i, j]] = frame_y[[i, k + j]];
}
}
let yt_yperp = fast_atb(y_target, &y_perp); let mut y_perp_t = &y_perp - &fast_ab(y_target, &yt_yperp); for j in 0..c_dim {
for _pass in 0..2 {
for col in 0..k {
let mut dot = 0.0_f64;
for i in 0..n {
dot += y_target[[i, col]] * y_perp_t[[i, j]];
}
for i in 0..n {
y_perp_t[[i, j]] -= dot * y_target[[i, col]];
}
}
for prev in 0..j {
let mut dot = 0.0_f64;
for i in 0..n {
dot += y_perp_t[[i, prev]] * y_perp_t[[i, j]];
}
for i in 0..n {
y_perp_t[[i, j]] -= dot * y_perp_t[[i, prev]];
}
}
}
let mut nrm = 0.0_f64;
for i in 0..n {
nrm += y_perp_t[[i, j]] * y_perp_t[[i, j]];
}
let nrm = nrm.sqrt();
if nrm > 1.0e-12 {
for i in 0..n {
y_perp_t[[i, j]] /= nrm;
}
}
}
let mut frame_yt = Array2::<f64>::zeros((n, n));
for j in 0..k {
for i in 0..n {
frame_yt[[i, j]] = y_target[[i, j]];
}
}
for j in 0..c_dim {
for i in 0..n {
frame_yt[[i, k + j]] = y_perp_t[[i, j]];
}
}
if c_dim >= 1 && matrix_det(&frame_yt) < 0.0 {
for i in 0..n {
frame_yt[[i, n - 1]] = -frame_yt[[i, n - 1]];
}
}
let mut v = fast_atb(&frame_y, &frame_yt);
const MAX_ITER: usize = 100;
const TOL: f64 = 1.0e-13;
let mut a_block = Array2::<f64>::zeros((k, k));
let mut b_block = Array2::<f64>::zeros((c_dim, k));
let mut converged = false;
for _ in 0..MAX_ITER {
let log_v = skew_log_orthogonal(&v)?; let mut c_norm_sq = 0.0_f64;
for i in 0..k {
for j in 0..k {
a_block[[i, j]] = log_v[[i, j]];
}
}
for i in 0..c_dim {
for j in 0..k {
b_block[[i, j]] = log_v[[k + i, j]];
}
for j in 0..c_dim {
let c = log_v[[k + i, k + j]];
c_norm_sq += c * c;
}
}
if c_norm_sq.sqrt() <= TOL {
converged = true;
break;
}
let mut neg_c = Array2::<f64>::zeros((c_dim, c_dim));
for i in 0..c_dim {
for j in 0..c_dim {
neg_c[[i, j]] = -log_v[[k + i, k + j]];
}
}
let phi = matrix_exp(&neg_c)?;
let mut v_new = v.clone();
for r in 0..n {
for j in 0..c_dim {
let mut acc = 0.0_f64;
for t in 0..c_dim {
acc += v[[r, k + t]] * phi[[t, j]];
}
v_new[[r, k + j]] = acc;
}
}
v = v_new;
}
if !converged {
return Err(GeometryError::Unsupported(
"Stiefel log_map: iteration did not converge \
(frames beyond the injectivity radius / near the cut locus)",
));
}
let delta = &fast_ab(y, &a_block) + &fast_ab(&y_perp, &b_block);
Ok(flatten(&delta))
}
#[cfg(test)]
mod tangent_basis_tests {
    use super::StiefelManifold;
    use crate::manifold::RiemannianManifold;
    use ndarray::Array1;

    /// The frame Y = [e1 e2] on St(3, 2), flattened row-major.
    fn base_point() -> Array1<f64> {
        Array1::from(vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0])
    }

    /// The tangent basis Q at Y must satisfy QᵀWQ = I for the metric matrix W.
    #[test]
    fn stiefel_tangent_basis_metric_orthonormal() {
        let manifold = StiefelManifold::new(2, 3).expect("St(3,2) exists");
        let point = base_point();
        let basis = manifold.tangent_basis(point.view()).expect("tangent basis");
        let metric = manifold.metric_tensor(point.view()).expect("metric tensor");
        let dim = manifold.dim();
        assert_eq!(basis.ncols(), dim, "basis must have dim() columns");
        let gram = basis.t().dot(&metric.dot(&basis));
        for (i, j) in (0..dim).flat_map(|i| (0..dim).map(move |j| (i, j))) {
            let want = if i == j { 1.0 } else { 0.0 };
            assert!(
                (gram[[i, j]] - want).abs() <= 1.0e-10,
                "QᵀWQ != I at ({i},{j}): got {}",
                gram[[i, j]]
            );
        }
    }

    /// The vertical direction Δ = [[0, −1], [1, 0], [0, 0]] has Euclidean
    /// norm² 2; the ½YYᵀ term of the canonical metric halves it to exactly 1.
    #[test]
    fn stiefel_vertical_tangent_canonical_norm() {
        let manifold = StiefelManifold::new(2, 3).expect("St(3,2) exists");
        let point = base_point();
        let delta = Array1::from(vec![0.0, -1.0, 1.0, 0.0, 0.0, 0.0]);
        let metric = manifold.metric_tensor(point.view()).expect("metric tensor");
        let norm_sq: f64 = delta
            .iter()
            .zip(metric.dot(&delta).iter())
            .map(|(d, wd)| d * wd)
            .sum();
        assert!(
            (norm_sq - 1.0).abs() <= 1.0e-12,
            "canonical-metric norm² of vertical tangent must be 1, got {norm_sq}"
        );
    }
}
#[cfg(test)]
mod stiefel_tests {
    use super::StiefelManifold;
    use crate::manifold::{GeometryError, RiemannianManifold, from_flat};
    use ndarray::{Array1, Array2};

    /// Row-major flattening of the frame [e1 … ek] ∈ R^{n×k}.
    fn identity_frame(n: usize, k: usize) -> Array1<f64> {
        let mut frame = Array1::<f64>::zeros(n * k);
        (0..k).for_each(|j| frame[j * k + j] = 1.0);
        frame
    }

    /// max_i |actual[i] − expected[i]|, indexed over `expected`.
    fn max_abs_diff(actual: &Array1<f64>, expected: &Array1<f64>) -> f64 {
        (0..expected.len())
            .map(|i| (actual[i] - expected[i]).abs())
            .fold(0.0_f64, f64::max)
    }

    /// √(vᵀ G v) for a dense metric matrix G.
    fn canonical_norm(metric: &Array2<f64>, v: &Array1<f64>) -> f64 {
        let gv = metric.dot(v);
        (0..v.len()).map(|i| v[i] * gv[i]).sum::<f64>().sqrt()
    }

    /// Seeded 64-bit LCG (Knuth MMIX constants) producing values in [-1, 1).
    struct Lcg(u64);

    impl Lcg {
        fn next_signed(&mut self) -> f64 {
            self.0 = self
                .0
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            ((self.0 >> 11) as f64) / ((1u64 << 53) as f64) * 2.0 - 1.0
        }
    }

    #[test]
    fn constructor_rejects_invalid_args() {
        let cases = [
            ((3usize, 2usize), false),
            ((0, 3), false),
            ((1, 0), false),
            ((2, 2), true),
            ((1, 5), true),
        ];
        for &((k, n), valid) in &cases {
            assert_eq!(StiefelManifold::new(k, n).is_ok(), valid);
        }
    }

    #[test]
    fn dim_and_ambient_dim_are_correct() {
        for &(k, n, dim, ambient) in &[(2usize, 3usize, 3usize, 6usize), (1, 4, 3, 4)] {
            let manifold = StiefelManifold::new(k, n).unwrap();
            assert_eq!(manifold.dim(), dim);
            assert_eq!(manifold.ambient_dim(), ambient);
        }
    }

    #[test]
    fn log_inverts_exp_k2() {
        let manifold = StiefelManifold::new(2, 4).unwrap();
        let base = identity_frame(4, 2);
        let direction = Array1::from(vec![0.0_f64, -0.30, 0.30, 0.0, 0.15, -0.05, 0.10, 0.20]);
        let delta = manifold.project_tangent(base.view(), direction.view()).unwrap();
        let target = manifold.exp_map(base.view(), delta.view()).unwrap();
        let recovered = manifold.log_map(base.view(), target.view()).unwrap();
        let worst = max_abs_diff(&recovered, &delta);
        assert!(worst < 1e-9, "Log∘Exp != id: max|Δ̂ − Δ| = {worst:.3e}");
    }

    #[test]
    fn exp_inverts_log_k2() {
        let manifold = StiefelManifold::new(2, 3).unwrap();
        let base = identity_frame(3, 2);
        let direction = Array1::from(vec![0.0_f64, 0.12, -0.12, 0.0, 0.08, 0.05]);
        let step = manifold.project_tangent(base.view(), direction.view()).unwrap();
        let y_target = manifold.exp_map(base.view(), step.view()).unwrap();
        let lg = manifold.log_map(base.view(), y_target.view()).unwrap();
        let back = manifold.exp_map(base.view(), lg.view()).unwrap();
        let worst = max_abs_diff(&back, &y_target);
        assert!(worst < 1e-9, "Exp∘Log != id: max|Ŷ − Ỹ| = {worst:.3e}");
    }

    #[test]
    fn log_inverts_exp_sweep_all_regimes() {
        let mut rng = Lcg(0x9e3779b97f4a7c15);
        for &(k, n) in &[(2usize, 3usize), (2, 4), (2, 5), (2, 6), (3, 5), (3, 7)] {
            let manifold = StiefelManifold::new(k, n).unwrap();
            let base = identity_frame(n, k);
            let metric = manifold.metric_tensor(base.view()).unwrap();
            for &scale in &[0.05_f64, 0.3, 0.7, 1.1] {
                let raw: Array1<f64> = (0..n * k).map(|_| rng.next_signed()).collect();
                let mut delta = manifold.project_tangent(base.view(), raw.view()).unwrap();
                let nrm = canonical_norm(&metric, &delta);
                if nrm > 1e-12 {
                    delta.mapv_inplace(|x| x * scale / nrm);
                }

                let target = manifold.exp_map(base.view(), delta.view()).unwrap();
                let frame = from_flat(target.view(), n, k).unwrap();
                let gram = frame.t().dot(&frame);
                for (a, b) in (0..k).flat_map(|a| (0..k).map(move |b| (a, b))) {
                    let want = if a == b { 1.0 } else { 0.0 };
                    assert!(
                        (gram[[a, b]] - want).abs() < 1e-10,
                        "St({n},{k}) exp off-manifold"
                    );
                }

                let recovered = manifold.log_map(base.view(), target.view()).unwrap();
                let worst = max_abs_diff(&recovered, &delta);
                assert!(
                    worst < 1e-8,
                    "St({n},{k}) Log∘Exp != id at scale {scale}: max err {worst:.3e}"
                );
            }
        }
    }

    #[test]
    fn log_of_self_is_zero_k2() {
        let manifold = StiefelManifold::new(2, 5).unwrap();
        let base = identity_frame(5, 2);
        let lg = manifold.log_map(base.view(), base.view()).unwrap();
        let worst = lg.iter().map(|x| x.abs()).fold(0.0_f64, f64::max);
        assert!(worst < 1e-12, "Log_Y(Y) != 0: max = {worst:.3e}");
    }

    #[test]
    fn log_is_tangent_and_isometric_k2() {
        let manifold = StiefelManifold::new(2, 4).unwrap();
        let base = identity_frame(4, 2);
        let direction = Array1::from(vec![0.0_f64, -0.2, 0.2, 0.0, 0.25, -0.1, 0.05, 0.15]);
        let delta = manifold.project_tangent(base.view(), direction.view()).unwrap();
        let target = manifold.exp_map(base.view(), delta.view()).unwrap();
        let recovered = manifold.log_map(base.view(), target.view()).unwrap();

        let proj = manifold.project_tangent(base.view(), recovered.view()).unwrap();
        let tan_err = max_abs_diff(&proj, &recovered);
        assert!(tan_err < 1e-9, "Log not tangent: max|P Δ̂ − Δ̂| = {tan_err:.3e}");

        let metric = manifold.metric_tensor(base.view()).unwrap();
        let d_norm = canonical_norm(&metric, &delta);
        let r_norm = canonical_norm(&metric, &recovered);
        assert!(
            (d_norm - r_norm).abs() < 1e-9,
            "geodesic distance not preserved: ‖Δ‖={d_norm:.6}, ‖Δ̂‖={r_norm:.6}"
        );
    }

    #[test]
    fn parallel_transport_k_gt_1_returns_unsupported() {
        let manifold = StiefelManifold::new(2, 3).unwrap();
        let path = Array2::from_shape_vec((1, 6), vec![1.0_f64, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();
        let vertical = Array1::from(vec![0.0_f64, -1.0, 1.0, 0.0, 0.0, 0.0]);
        let result = manifold.parallel_transport(path.view(), vertical.view());
        assert!(
            matches!(result, Err(GeometryError::Unsupported(_))),
            "expected Unsupported for k>1, got {result:?}"
        );
    }

    #[test]
    fn sectional_curvature_k_gt_1_returns_unsupported() {
        let manifold = StiefelManifold::new(2, 3).unwrap();
        let base = identity_frame(3, 2);
        let u = Array1::from(vec![0.0_f64, -1.0, 1.0, 0.0, 0.0, 0.0]);
        let v = Array1::from(vec![0.0_f64, 0.0, 0.0, 0.0, 1.0, 0.0]);
        let result = manifold.sectional_curvature(base.view(), (u.view(), v.view()));
        assert!(
            matches!(result, Err(GeometryError::Unsupported(_))),
            "expected Unsupported for k>1, got {result:?}"
        );
    }

    #[test]
    fn project_tangent_makes_ytz_skew_symmetric() {
        let manifold = StiefelManifold::new(2, 3).unwrap();
        let base = identity_frame(3, 2);
        let raw = Array1::from(vec![0.5_f64, 1.0, 0.0, 0.5, 1.0, 0.0]);
        let h = manifold.project_tangent(base.view(), raw.view()).unwrap();
        // With Y = [e1 e2], YᵀH is the top 2×2 block: h[0], h[1] / h[2], h[3].
        assert!(h[0].abs() < 1e-12, "YᵀH[0,0] = {}", h[0]);
        assert!(h[3].abs() < 1e-12, "YᵀH[1,1] = {}", h[3]);
        assert!((h[1] + h[2]).abs() < 1e-12, "YᵀH not skew: h[1]={}, h[2]={}", h[1], h[2]);
    }

    #[test]
    fn retract_stays_on_stiefel_manifold() {
        let (n, k) = (4usize, 2usize);
        let manifold = StiefelManifold::new(k, n).unwrap();
        let base = identity_frame(n, k);
        let step = Array1::from(vec![0.0_f64, -0.2, 0.2, 0.0, 0.1, 0.0, 0.0, 0.05]);
        let q = manifold.retract(base.view(), step.view()).unwrap();
        let gram_entry = |a: usize, b: usize| -> f64 {
            (0..n).map(|r| q[r * k + a] * q[r * k + b]).sum()
        };
        for i in 0..k {
            for j in 0..k {
                let want = if i == j { 1.0 } else { 0.0 };
                let got = gram_entry(i, j);
                assert!((got - want).abs() < 1e-12, "QᵀQ[{i},{j}] = {}", got);
            }
        }
    }

    #[test]
    fn exp_map_k1_sphere_half_pi_rotation() {
        let sphere_like = StiefelManifold::new(1, 3).unwrap();
        let north = identity_frame(3, 1);
        let quarter_turn = Array1::from(vec![0.0_f64, std::f64::consts::FRAC_PI_2, 0.0]);
        let q = sphere_like.exp_map(north.view(), quarter_turn.view()).unwrap();
        assert!(q[0].abs() < 1e-12, "q[0] = {}", q[0]);
        assert!((q[1] - 1.0).abs() < 1e-12, "q[1] = {}", q[1]);
        assert!(q[2].abs() < 1e-12, "q[2] = {}", q[2]);
    }
}