use super::FixedPoint;
use super::FixedVector;
use super::FixedMatrix;
use super::compute_matrix::ComputeMatrix;
use super::linalg::compute_tier_dot_raw;
use super::matrix_functions::{matrix_exp, matrix_log, matrix_sqrt};
use super::derived::{inverse_spd, frobenius_norm};
use super::decompose::{svd_decompose, qr_decompose};
use crate::fixed_point::universal::fasc::stack_evaluator::BinaryStorage;
use crate::fixed_point::core_types::errors::OverflowDetected;
/// A Riemannian manifold whose points and tangent vectors are encoded as flat
/// [`FixedVector`]s and manipulated in fixed-point arithmetic.
///
/// Each implementor fixes its own encoding (e.g. a packed upper triangle for
/// SPD matrices, column-major frames for Grassmann/Stiefel), and every method
/// interprets its vector arguments in that encoding.
pub trait Manifold {
    /// Dimension reported by the implementor (the matrix manifolds in this
    /// module return their intrinsic dimension, e.g. `k(n-k)` for `Gr(k, n)`).
    fn dimension(&self) -> usize;

    /// Metric `g_base(u, v)` between two tangent vectors at `base`.
    fn inner_product(&self, base: &FixedVector, u: &FixedVector, v: &FixedVector)
        -> FixedPoint;

    /// Length of the tangent vector `v` at `base`.
    ///
    /// Defaults to `sqrt(g_base(v, v))`; implementors with an indefinite form
    /// (such as the hyperboloid model) override it.
    fn norm(&self, base: &FixedVector, v: &FixedVector) -> FixedPoint {
        let squared_length = self.inner_product(base, v, v);
        squared_length.sqrt()
    }

    /// Exponential map (or a retraction standing in for it): the point reached
    /// by moving from `base` along `tangent`.
    ///
    /// # Errors
    /// Returns [`OverflowDetected`] when an intermediate fixed-point step fails.
    fn exp_map(&self, base: &FixedVector, tangent: &FixedVector)
        -> Result<FixedVector, OverflowDetected>;

    /// Logarithm map: the tangent vector at `base` pointing towards `target`.
    /// Some implementors only provide an approximation.
    ///
    /// # Errors
    /// Returns [`OverflowDetected`] when an intermediate fixed-point step fails.
    fn log_map(&self, base: &FixedVector, target: &FixedVector)
        -> Result<FixedVector, OverflowDetected>;

    /// Distance between the points `p` and `q`.
    ///
    /// # Errors
    /// Returns [`OverflowDetected`] when an intermediate fixed-point step fails.
    fn distance(&self, p: &FixedVector, q: &FixedVector)
        -> Result<FixedPoint, OverflowDetected>;

    /// Carries `tangent`, a tangent vector at `base`, into the tangent space at
    /// `target` (parallel transport, or a vector transport approximating it).
    ///
    /// # Errors
    /// Returns [`OverflowDetected`] when an intermediate fixed-point step fails.
    fn parallel_transport(
        &self,
        base: &FixedVector,
        target: &FixedVector,
        tangent: &FixedVector,
    ) -> Result<FixedVector, OverflowDetected>;
}
/// Flat space with the dot product as its metric.
///
/// Points and tangent vectors share the same coordinates: geodesics are
/// straight lines and parallel transport is the identity.
pub struct EuclideanSpace {
    /// Number of coordinates.
    pub dim: usize,
}

impl Manifold for EuclideanSpace {
    fn dimension(&self) -> usize {
        self.dim
    }

    /// `u · v`; the metric does not depend on the base point.
    fn inner_product(&self, _base: &FixedVector, u: &FixedVector, v: &FixedVector) -> FixedPoint {
        u.dot_precise(v)
    }

    /// `base + tangent`.
    fn exp_map(
        &self,
        base: &FixedVector,
        tangent: &FixedVector,
    ) -> Result<FixedVector, OverflowDetected> {
        let endpoint = base + tangent;
        Ok(endpoint)
    }

    /// `target - base`.
    fn log_map(
        &self,
        base: &FixedVector,
        target: &FixedVector,
    ) -> Result<FixedVector, OverflowDetected> {
        let displacement = target - base;
        Ok(displacement)
    }

    /// Distance as computed by `FixedVector::metric_distance_safe`.
    fn distance(&self, p: &FixedVector, q: &FixedVector) -> Result<FixedPoint, OverflowDetected> {
        let gap = p.metric_distance_safe(q);
        Ok(gap)
    }

    /// Identity: all tangent spaces are canonically the same space.
    fn parallel_transport(
        &self,
        _base: &FixedVector,
        _target: &FixedVector,
        tangent: &FixedVector,
    ) -> Result<FixedVector, OverflowDetected> {
        Ok(tangent.clone())
    }
}
/// Sphere embedded in Euclidean space, with the metric inherited from the
/// ambient dot product.
///
/// The formulas in its `Manifold` impl assume points are unit vectors and that
/// tangent vectors at `p` are orthogonal to `p`; neither is checked.
pub struct Sphere {
    /// Dimension parameter (NOTE(review): whether this is the intrinsic or the
    /// ambient dimension is not fixed by this file — confirm).
    pub dim: usize,
}

impl Sphere {
    /// Clamps `x` into `[-1, 1]` so rounding error cannot push a cosine outside
    /// the domain of `acos`.
    fn clamp_unit(x: FixedPoint) -> FixedPoint {
        let upper = FixedPoint::one();
        let lower = -upper;
        if x > upper {
            upper
        } else if x < lower {
            lower
        } else {
            x
        }
    }
}
impl Manifold for Sphere {
fn dimension(&self) -> usize { self.dim }
fn inner_product(&self, _base: &FixedVector, u: &FixedVector, v: &FixedVector) -> FixedPoint {
u.dot_precise(v)
}
fn exp_map(&self, base: &FixedVector, tangent: &FixedVector) -> Result<FixedVector, OverflowDetected> {
let theta = tangent.length();
if theta.is_zero() {
return Ok(base.clone());
}
let cos_t = theta.try_cos()?;
let sin_t = theta.try_sin()?;
let direction = tangent * (FixedPoint::one() / theta);
Ok(base * cos_t + direction * sin_t)
}
fn log_map(&self, base: &FixedVector, target: &FixedVector) -> Result<FixedVector, OverflowDetected> {
let cos_theta = Self::clamp_unit(base.dot_precise(target));
let theta = cos_theta.try_acos()?;
if theta.is_zero() {
return Ok(FixedVector::new(base.len()));
}
let direction = target - &(base * cos_theta);
let dir_len = direction.length();
if dir_len.is_zero() {
return Ok(FixedVector::new(base.len()));
}
Ok(&direction * (theta / dir_len))
}
fn distance(&self, p: &FixedVector, q: &FixedVector) -> Result<FixedPoint, OverflowDetected> {
let cos_theta = Self::clamp_unit(p.dot_precise(q));
cos_theta.try_acos()
}
fn parallel_transport(&self, base: &FixedVector, target: &FixedVector, tangent: &FixedVector) -> Result<FixedVector, OverflowDetected> {
let cos_theta = Self::clamp_unit(base.dot_precise(target));
let one = FixedPoint::one();
let denom = one + cos_theta;
if denom.is_zero() {
return Err(OverflowDetected::DomainError);
}
let p_plus_q = base + target;
let coeff = tangent.dot_precise(&p_plus_q) / denom;
Ok(tangent - &(&p_plus_q * coeff))
}
}
/// Hyperbolic space in the hyperboloid (Lorentz) model.
///
/// Coordinate 0 is the time-like axis and the bilinear form used throughout is
/// `⟨u, v⟩_L = -u₀v₀ + Σᵢ₌₁ uᵢvᵢ`. The formulas in its `Manifold` impl assume
/// points satisfy `⟨p, p⟩_L = -1` and tangent vectors at `p` satisfy
/// `⟨p, v⟩_L = 0`; neither is checked.
pub struct HyperbolicSpace {
    /// Dimension parameter (NOTE(review): intrinsic vs. embedding length is not
    /// fixed by this file — confirm).
    pub dim: usize,
}

impl HyperbolicSpace {
    /// Minkowski form `-u₀v₀ + Σᵢ₌₁ uᵢvᵢ`; zero for empty vectors.
    ///
    /// # Panics
    /// Panics if `u` and `v` have different lengths.
    fn minkowski_dot(u: &FixedVector, v: &FixedVector) -> FixedPoint {
        assert_eq!(u.len(), v.len());
        let len = u.len();
        if len == 0 {
            return FixedPoint::ZERO;
        }
        // The spatial part is accumulated on raw storage via `compute_tier_dot_raw`.
        let space_part = if len <= 1 {
            FixedPoint::ZERO
        } else {
            let (lhs, rhs): (Vec<BinaryStorage>, Vec<BinaryStorage>) =
                (1..len).map(|i| (u[i].raw(), v[i].raw())).unzip();
            FixedPoint::from_raw(compute_tier_dot_raw(&lhs, &rhs))
        };
        space_part - u[0] * v[0]
    }

    /// `sqrt(|⟨v, v⟩_L|)`: time-like vectors (negative square) report the
    /// magnitude of their square, otherwise `try_sqrt` is used.
    fn minkowski_norm(v: &FixedVector) -> Result<FixedPoint, OverflowDetected> {
        let squared = Self::minkowski_dot(v, v);
        if squared.is_negative() {
            Ok((-squared).sqrt())
        } else {
            squared.try_sqrt()
        }
    }
}
impl Manifold for HyperbolicSpace {
fn dimension(&self) -> usize { self.dim }
fn inner_product(&self, _base: &FixedVector, u: &FixedVector, v: &FixedVector) -> FixedPoint {
Self::minkowski_dot(u, v)
}
fn norm(&self, _base: &FixedVector, v: &FixedVector) -> FixedPoint {
Self::minkowski_norm(v).unwrap_or(FixedPoint::ZERO)
}
fn exp_map(&self, base: &FixedVector, tangent: &FixedVector) -> Result<FixedVector, OverflowDetected> {
let theta = Self::minkowski_norm(tangent)?;
if theta.is_zero() {
return Ok(base.clone());
}
let (sinh_t, cosh_t) = theta.try_sinhcosh()?;
let direction = tangent * (FixedPoint::one() / theta);
Ok(base * cosh_t + direction * sinh_t)
}
fn log_map(&self, base: &FixedVector, target: &FixedVector) -> Result<FixedVector, OverflowDetected> {
let minus_alpha = Self::minkowski_dot(base, target); let alpha = -minus_alpha; let alpha_clamped = if alpha < FixedPoint::one() { FixedPoint::one() } else { alpha };
let theta = alpha_clamped.try_acosh()?;
if theta.is_zero() {
return Ok(FixedVector::new(base.len()));
}
let direction = target - &(base * (-minus_alpha));
let dir_norm = Self::minkowski_norm(&direction)?;
if dir_norm.is_zero() {
return Ok(FixedVector::new(base.len()));
}
Ok(&direction * (theta / dir_norm))
}
fn distance(&self, p: &FixedVector, q: &FixedVector) -> Result<FixedPoint, OverflowDetected> {
let minus_alpha = Self::minkowski_dot(p, q);
let alpha = -minus_alpha;
let alpha_clamped = if alpha < FixedPoint::one() { FixedPoint::one() } else { alpha };
alpha_clamped.try_acosh()
}
fn parallel_transport(&self, base: &FixedVector, target: &FixedVector, tangent: &FixedVector) -> Result<FixedVector, OverflowDetected> {
let log_pq = self.log_map(base, target)?;
let theta = Self::minkowski_norm(&log_pq)?;
if theta.is_zero() {
return Ok(tangent.clone());
}
let u = &log_pq * (FixedPoint::one() / theta);
let vu = Self::minkowski_dot(tangent, &u);
let (sinh_t, cosh_t) = theta.try_sinhcosh()?;
let correction = &(base * sinh_t) + &(&u * (cosh_t - FixedPoint::one()));
Ok(tangent + &(&correction * vu))
}
}
/// Manifold of `n × n` symmetric positive-definite matrices with the metric
/// `⟨U, V⟩_P = tr(P⁻¹ U P⁻¹ V)`.
///
/// Points and tangent vectors (symmetric matrices) are stored as their upper
/// triangle, packed row by row into `n(n+1)/2` entries.
pub struct SPDManifold {
    /// Side length of the matrices.
    pub n: usize,
}
/// Packs the upper triangle (diagonal included) of the square matrix `m`
/// row by row into a vector of length `n(n+1)/2`.
fn sym_to_vec(m: &FixedMatrix) -> FixedVector {
    let n = m.rows();
    let mut packed = FixedVector::new(n * (n + 1) / 2);
    let upper = (0..n).flat_map(|row| (row..n).map(move |col| (row, col)));
    for (slot, (row, col)) in upper.enumerate() {
        packed[slot] = m.get(row, col);
    }
    packed
}
/// Inverse of `sym_to_vec`: rebuilds the symmetric `n × n` matrix whose upper
/// triangle is stored row by row in `v`, mirroring each entry below the diagonal.
fn vec_to_sym(v: &FixedVector, n: usize) -> FixedMatrix {
    let mut sym = FixedMatrix::new(n, n);
    let upper = (0..n).flat_map(|row| (row..n).map(move |col| (row, col)));
    for (slot, (row, col)) in upper.enumerate() {
        let entry = v[slot];
        sym.set(row, col, entry);
        sym.set(col, row, entry);
    }
    sym
}
impl SPDManifold {
    /// Returns `(P^{1/2}, P^{-1/2})`: the matrix square root of `p` and the
    /// inverse of that square root.
    ///
    /// # Errors
    /// Propagates failures of `matrix_sqrt` or `inverse_spd`.
    fn sqrt_and_inv_sqrt(p: &FixedMatrix) -> Result<(FixedMatrix, FixedMatrix), OverflowDetected> {
        let root = matrix_sqrt(p)?;
        let inverse_root = inverse_spd(&root)?;
        Ok((root, inverse_root))
    }
}
impl Manifold for SPDManifold {
    /// Number of independent entries of a symmetric `n × n` matrix, `n(n+1)/2`.
    fn dimension(&self) -> usize {
        self.n * (self.n + 1) / 2
    }

    /// Metric `tr(P⁻¹ U P⁻¹ V)` at `P = base`.
    ///
    /// NOTE: if `P` cannot be inverted, the identity is substituted for `P⁻¹`,
    /// so `tr(U V)` is returned silently — the trait signature offers no way to
    /// report the failure.
    fn inner_product(
        &self,
        base: &FixedVector,
        u: &FixedVector,
        v: &FixedVector,
    ) -> FixedPoint {
        let p = vec_to_sym(base, self.n);
        let u_mat = vec_to_sym(u, self.n);
        let v_mat = vec_to_sym(v, self.n);
        let p_inv = inverse_spd(&p).unwrap_or_else(|_| FixedMatrix::identity(self.n));
        let p_inv_c = ComputeMatrix::from_fixed_matrix(&p_inv);
        let u_c = ComputeMatrix::from_fixed_matrix(&u_mat);
        let v_c = ComputeMatrix::from_fixed_matrix(&v_mat);
        let p_inv_u_c = p_inv_c.mat_mul(&u_c);
        let p_inv_v_c = p_inv_c.mat_mul(&v_c);
        let product_c = p_inv_u_c.mat_mul(&p_inv_v_c);
        product_c.trace_compute()
    }

    /// `Exp_P(V) = P^{1/2} exp(P^{-1/2} V P^{-1/2}) P^{1/2}`.
    fn exp_map(
        &self,
        base: &FixedVector,
        tangent: &FixedVector,
    ) -> Result<FixedVector, OverflowDetected> {
        let p = vec_to_sym(base, self.n);
        let v = vec_to_sym(tangent, self.n);
        let (sqrt_p, inv_sqrt_p) = Self::sqrt_and_inv_sqrt(&p)?;
        let inner = &(&inv_sqrt_p * &v) * &inv_sqrt_p;
        let exp_inner = matrix_exp(&inner)?;
        let result = &(&sqrt_p * &exp_inner) * &sqrt_p;
        Ok(sym_to_vec(&result))
    }

    /// `Log_P(Q) = P^{1/2} log(P^{-1/2} Q P^{-1/2}) P^{1/2}`.
    fn log_map(
        &self,
        base: &FixedVector,
        target: &FixedVector,
    ) -> Result<FixedVector, OverflowDetected> {
        let p = vec_to_sym(base, self.n);
        let q = vec_to_sym(target, self.n);
        let (sqrt_p, inv_sqrt_p) = Self::sqrt_and_inv_sqrt(&p)?;
        let inner = &(&inv_sqrt_p * &q) * &inv_sqrt_p;
        let log_inner = matrix_log(&inner)?;
        let result = &(&sqrt_p * &log_inner) * &sqrt_p;
        Ok(sym_to_vec(&result))
    }

    /// `‖log(P^{-1/2} Q P^{-1/2})‖_F`.
    ///
    /// This equals `sqrt(tr((P⁻¹ Log_P(Q))²))` but skips the round trip through
    /// `log_map` and a second inversion of `P`. It is evaluated as
    /// `sqrt(tr(Lᵀ L))`, a sum of squares, so rounding cannot drive the value
    /// under the square root negative when `p ≈ q`.
    fn distance(
        &self,
        p: &FixedVector,
        q: &FixedVector,
    ) -> Result<FixedPoint, OverflowDetected> {
        let p_mat = vec_to_sym(p, self.n);
        let q_mat = vec_to_sym(q, self.n);
        let (_, inv_sqrt_p) = Self::sqrt_and_inv_sqrt(&p_mat)?;
        let whitened = &(&inv_sqrt_p * &q_mat) * &inv_sqrt_p;
        let log_w = matrix_log(&whitened)?;
        let log_c = ComputeMatrix::from_fixed_matrix(&log_w);
        log_c.transpose().mat_mul(&log_c).trace_compute().try_sqrt()
    }

    /// Transports `V` from `P` to `Q` as `E V Eᵀ` with `E = (Q P⁻¹)^{1/2}`.
    ///
    /// `E` is formed as `P^{1/2} (P^{-1/2} Q P^{-1/2})^{1/2} P^{-1/2}`, which
    /// squares to `Q P⁻¹`. This way `matrix_sqrt` is only ever applied to
    /// symmetric positive-definite matrices, as in `exp_map`/`log_map`, instead
    /// of to the non-symmetric product `Q P⁻¹`.
    fn parallel_transport(
        &self,
        base: &FixedVector,
        target: &FixedVector,
        tangent: &FixedVector,
    ) -> Result<FixedVector, OverflowDetected> {
        let p = vec_to_sym(base, self.n);
        let q = vec_to_sym(target, self.n);
        let v = vec_to_sym(tangent, self.n);
        let (sqrt_p, inv_sqrt_p) = Self::sqrt_and_inv_sqrt(&p)?;
        let whitened_q = &(&inv_sqrt_p * &q) * &inv_sqrt_p;
        let sqrt_whitened = matrix_sqrt(&whitened_q)?;
        let sqrt_p_c = ComputeMatrix::from_fixed_matrix(&sqrt_p);
        let inv_sqrt_p_c = ComputeMatrix::from_fixed_matrix(&inv_sqrt_p);
        let sqrt_whitened_c = ComputeMatrix::from_fixed_matrix(&sqrt_whitened);
        let e_c = sqrt_p_c.mat_mul(&sqrt_whitened_c).mat_mul(&inv_sqrt_p_c);
        let et_c = e_c.transpose();
        let v_c = ComputeMatrix::from_fixed_matrix(&v);
        let result = e_c.mat_mul(&v_c).mat_mul(&et_c).to_fixed_matrix();
        Ok(sym_to_vec(&result))
    }
}
/// Grassmann manifold `Gr(k, n)` of `k`-dimensional subspaces of an
/// `n`-dimensional space.
///
/// A point is represented by an `n × k` matrix (assumed to have orthonormal
/// columns spanning the subspace), stored column-major in a [`FixedVector`] of
/// length `n·k`. Tangent vectors use the same layout.
pub struct Grassmannian {
    /// Subspace dimension.
    pub k: usize,
    /// Ambient dimension.
    pub n: usize,
}

impl Grassmannian {
    /// Flattens `m` column by column.
    fn mat_to_vec(m: &FixedMatrix) -> FixedVector {
        let (rows, cols) = (m.rows(), m.cols());
        let mut flat = FixedVector::new(rows * cols);
        for col in 0..cols {
            let offset = col * rows;
            for row in 0..rows {
                flat[offset + row] = m.get(row, col);
            }
        }
        flat
    }

    /// Rebuilds an `n × k` matrix from its column-major flattening `v`.
    fn vec_to_mat(v: &FixedVector, n: usize, k: usize) -> FixedMatrix {
        FixedMatrix::from_fn(n, k, |row, col| v[col * n + row])
    }

    /// Horizontal projection `Δ - Q (Qᵀ Δ)` onto the tangent space at `q`.
    // Kept for callers that need a Grassmann tangent projection; not used by
    // the `Manifold` impl, hence the lint allowance.
    #[allow(dead_code)]
    fn project_tangent(q: &FixedMatrix, delta: &FixedMatrix) -> FixedMatrix {
        let basis = ComputeMatrix::from_fixed_matrix(q);
        let step = ComputeMatrix::from_fixed_matrix(delta);
        let coords = basis.transpose().mat_mul(&step);
        step.sub(&basis.mat_mul(&coords)).to_fixed_matrix()
    }
}
impl Manifold for Grassmannian {
fn dimension(&self) -> usize {
self.k * (self.n - self.k)
}
fn inner_product(
&self,
_base: &FixedVector,
u: &FixedVector,
v: &FixedVector,
) -> FixedPoint {
let u_mat = Self::vec_to_mat(u, self.n, self.k);
let v_mat = Self::vec_to_mat(v, self.n, self.k);
let u_c = ComputeMatrix::from_fixed_matrix(&u_mat);
let v_c = ComputeMatrix::from_fixed_matrix(&v_mat);
u_c.transpose().mat_mul(&v_c).trace_compute()
}
fn exp_map(
&self,
base: &FixedVector,
tangent: &FixedVector,
) -> Result<FixedVector, OverflowDetected> {
let q = Self::vec_to_mat(base, self.n, self.k);
let delta = Self::vec_to_mat(tangent, self.n, self.k);
let svd = svd_decompose(&delta)?;
let kk = svd.sigma.len().min(self.k);
let mut cos_sigma = FixedMatrix::new(kk, kk);
let mut sin_sigma = FixedMatrix::new(kk, kk);
for i in 0..kk {
cos_sigma.set(i, i, svd.sigma[i].try_cos()?);
sin_sigma.set(i, i, svd.sigma[i].try_sin()?);
}
let u_thin = FixedMatrix::from_fn(self.n, kk, |r, c| {
if r < svd.u.rows() && c < svd.u.cols() { svd.u.get(r, c) } else { FixedPoint::ZERO }
});
let vt_thin = FixedMatrix::from_fn(kk, kk, |r, c| {
if r < svd.vt.rows() && c < svd.vt.cols() { svd.vt.get(r, c) } else { FixedPoint::ZERO }
});
let v_thin = vt_thin.transpose();
let q_c = ComputeMatrix::from_fixed_matrix(&q);
let v_thin_c = ComputeMatrix::from_fixed_matrix(&v_thin);
let cos_sigma_c = ComputeMatrix::from_fixed_matrix(&cos_sigma);
let sin_sigma_c = ComputeMatrix::from_fixed_matrix(&sin_sigma);
let vt_thin_c = ComputeMatrix::from_fixed_matrix(&vt_thin);
let u_thin_c = ComputeMatrix::from_fixed_matrix(&u_thin);
let term1_c = q_c.mat_mul(&v_thin_c).mat_mul(&cos_sigma_c).mat_mul(&vt_thin_c);
let term2_c = u_thin_c.mat_mul(&sin_sigma_c).mat_mul(&vt_thin_c);
let result = term1_c.add(&term2_c).to_fixed_matrix();
Ok(Self::mat_to_vec(&result))
}
fn log_map(
&self,
base: &FixedVector,
target: &FixedVector,
) -> Result<FixedVector, OverflowDetected> {
let q1 = Self::vec_to_mat(base, self.n, self.k);
let q2 = Self::vec_to_mat(target, self.n, self.k);
let q1_c = ComputeMatrix::from_fixed_matrix(&q1);
let q2_c = ComputeMatrix::from_fixed_matrix(&q2);
let qt_q2_c = q1_c.transpose().mat_mul(&q2_c); let proj_c = q1_c.mat_mul(&qt_q2_c); let perp_c = q2_c.sub(&proj_c); let qt_q2 = qt_q2_c.to_fixed_matrix(); let perp = perp_c.to_fixed_matrix();
let svd = svd_decompose(&perp)?;
let kk = svd.sigma.len().min(self.k);
let svd_parallel = svd_decompose(&qt_q2)?;
let kk_par = svd_parallel.sigma.len().min(self.k);
let mut theta_diag = FixedMatrix::new(kk, kk);
for i in 0..kk {
let s_perp = svd.sigma[i];
let s_par = if i < kk_par { svd_parallel.sigma[i] } else { FixedPoint::ZERO };
let theta_i = s_perp.try_atan2(s_par)?;
theta_diag.set(i, i, theta_i);
}
let u_thin = FixedMatrix::from_fn(self.n, kk, |r, c| {
if r < svd.u.rows() && c < svd.u.cols() { svd.u.get(r, c) } else { FixedPoint::ZERO }
});
let vt_thin = FixedMatrix::from_fn(kk, kk, |r, c| {
if r < svd.vt.rows() && c < svd.vt.cols() { svd.vt.get(r, c) } else { FixedPoint::ZERO }
});
let u_thin_c = ComputeMatrix::from_fixed_matrix(&u_thin);
let theta_diag_c = ComputeMatrix::from_fixed_matrix(&theta_diag);
let vt_thin_c = ComputeMatrix::from_fixed_matrix(&vt_thin);
let result = u_thin_c.mat_mul(&theta_diag_c).mat_mul(&vt_thin_c).to_fixed_matrix();
Ok(Self::mat_to_vec(&result))
}
fn distance(
&self,
p: &FixedVector,
q: &FixedVector,
) -> Result<FixedPoint, OverflowDetected> {
let q1 = Self::vec_to_mat(p, self.n, self.k);
let q2 = Self::vec_to_mat(q, self.n, self.k);
let q1_c = ComputeMatrix::from_fixed_matrix(&q1);
let q2_c = ComputeMatrix::from_fixed_matrix(&q2);
let qt_q = q1_c.transpose().mat_mul(&q2_c).to_fixed_matrix();
let svd = svd_decompose(&qt_q)?;
let kk = svd.sigma.len().min(self.k);
let one = FixedPoint::one();
let mut thetas: Vec<BinaryStorage> = Vec::with_capacity(kk);
for i in 0..kk {
let s = if svd.sigma[i] > one { one }
else if svd.sigma[i] < -one { -one }
else { svd.sigma[i] };
let theta = s.try_acos()?;
thetas.push(theta.raw());
}
let dist_sq = FixedPoint::from_raw(compute_tier_dot_raw(&thetas, &thetas));
dist_sq.try_sqrt()
}
fn parallel_transport(
&self,
base: &FixedVector,
target: &FixedVector,
tangent: &FixedVector,
) -> Result<FixedVector, OverflowDetected> {
let log_v = self.log_map(base, target)?;
let q1 = Self::vec_to_mat(base, self.n, self.k);
let delta_mat = Self::vec_to_mat(tangent, self.n, self.k);
let tangent_mat = Self::vec_to_mat(&log_v, self.n, self.k);
let svd = svd_decompose(&tangent_mat)?;
let kk = svd.sigma.len().min(self.k);
let u_thin = FixedMatrix::from_fn(self.n, kk, |r, c| {
if r < svd.u.rows() && c < svd.u.cols() { svd.u.get(r, c) } else { FixedPoint::ZERO }
});
let vt_thin = FixedMatrix::from_fn(kk, kk, |r, c| {
if r < svd.vt.rows() && c < svd.vt.cols() { svd.vt.get(r, c) } else { FixedPoint::ZERO }
});
let mut cos_sigma = FixedMatrix::new(kk, kk);
let mut sin_sigma = FixedMatrix::new(kk, kk);
for i in 0..kk {
cos_sigma.set(i, i, svd.sigma[i].try_cos()?);
sin_sigma.set(i, i, svd.sigma[i].try_sin()?);
}
let q1_c = ComputeMatrix::from_fixed_matrix(&q1);
let u_thin_c = ComputeMatrix::from_fixed_matrix(&u_thin);
let vt_thin_c = ComputeMatrix::from_fixed_matrix(&vt_thin);
let cos_sigma_c = ComputeMatrix::from_fixed_matrix(&cos_sigma);
let sin_sigma_c = ComputeMatrix::from_fixed_matrix(&sin_sigma);
let delta_c = ComputeMatrix::from_fixed_matrix(&delta_mat);
let ut_delta_c = u_thin_c.transpose().mat_mul(&delta_c);
let v_thin_c = vt_thin_c.transpose();
let term1_c = q1_c.mat_mul(&v_thin_c).mat_mul(&sin_sigma_c).mat_mul(&ut_delta_c).neg();
let term2_c = u_thin_c.mat_mul(&cos_sigma_c).mat_mul(&ut_delta_c);
let term3_c = delta_c.sub(&u_thin_c.mat_mul(&ut_delta_c));
let result = term1_c.add(&term2_c).add(&term3_c).to_fixed_matrix();
Ok(Self::mat_to_vec(&result))
}
}
/// Stiefel manifold `St(k, n)` of `n × k` matrices with orthonormal columns,
/// stored column-major in a [`FixedVector`] of length `n·k`.
pub struct StiefelManifold {
    /// Number of columns (frame size).
    pub k: usize,
    /// Number of rows (ambient dimension).
    pub n: usize,
}
/// Flattens `m` column by column into a vector of length `rows · cols`.
fn stiefel_mat_to_vec(m: &FixedMatrix) -> FixedVector {
    let rows = m.rows();
    let cols = m.cols();
    let mut flat = FixedVector::new(rows * cols);
    for (slot, (row, col)) in (0..cols)
        .flat_map(|col| (0..rows).map(move |row| (row, col)))
        .enumerate()
    {
        flat[slot] = m.get(row, col);
    }
    flat
}
/// Rebuilds an `n × k` matrix from its column-major flattening `v`
/// (inverse of `stiefel_mat_to_vec`).
fn stiefel_vec_to_mat(v: &FixedVector, n: usize, k: usize) -> FixedMatrix {
    FixedMatrix::from_fn(n, k, |row, col| v[col * n + row])
}
impl StiefelManifold {
    /// Projects `delta` onto the tangent space at `q` under the embedded
    /// metric: `Δ - Q sym(Qᵀ Δ)`, where `sym(A) = (A + Aᵀ) / 2`.
    fn project_tangent(q: &FixedMatrix, delta: &FixedMatrix) -> FixedMatrix {
        let basis = ComputeMatrix::from_fixed_matrix(q);
        let step = ComputeMatrix::from_fixed_matrix(delta);
        let coupling = basis.transpose().mat_mul(&step);
        let sym_part = coupling.add(&coupling.transpose()).halve();
        let normal_component = basis.mat_mul(&sym_part);
        step.sub(&normal_component).to_fixed_matrix()
    }
}
impl Manifold for StiefelManifold {
fn dimension(&self) -> usize {
self.n * self.k - self.k * (self.k + 1) / 2
}
fn inner_product(
&self,
_base: &FixedVector,
u: &FixedVector,
v: &FixedVector,
) -> FixedPoint {
let u_mat = stiefel_vec_to_mat(u, self.n, self.k);
let v_mat = stiefel_vec_to_mat(v, self.n, self.k);
let u_c = ComputeMatrix::from_fixed_matrix(&u_mat);
let v_c = ComputeMatrix::from_fixed_matrix(&v_mat);
u_c.transpose().mat_mul(&v_c).trace_compute()
}
fn exp_map(
&self,
base: &FixedVector,
tangent: &FixedVector,
) -> Result<FixedVector, OverflowDetected> {
let q = stiefel_vec_to_mat(base, self.n, self.k);
let delta = stiefel_vec_to_mat(tangent, self.n, self.k);
let q_plus_delta = &q + δ
let qr = qr_decompose(&q_plus_delta)?;
let q_new = FixedMatrix::from_fn(self.n, self.k, |r, c| {
let sign = if qr.r.get(c, c).is_negative() {
FixedPoint::from_int(-1)
} else {
FixedPoint::one()
};
qr.q.get(r, c) * sign
});
Ok(stiefel_mat_to_vec(&q_new))
}
fn log_map(
&self,
base: &FixedVector,
target: &FixedVector,
) -> Result<FixedVector, OverflowDetected> {
let q = stiefel_vec_to_mat(base, self.n, self.k);
let q_target = stiefel_vec_to_mat(target, self.n, self.k);
let diff = &q_target - &q;
let tangent = Self::project_tangent(&q, &diff);
Ok(stiefel_mat_to_vec(&tangent))
}
fn distance(
&self,
p: &FixedVector,
q: &FixedVector,
) -> Result<FixedPoint, OverflowDetected> {
let log_v = self.log_map(p, q)?;
let log_mat = stiefel_vec_to_mat(&log_v, self.n, self.k);
frobenius_norm(&log_mat).try_sqrt()
}
fn parallel_transport(
&self,
_base: &FixedVector,
target: &FixedVector,
tangent: &FixedVector,
) -> Result<FixedVector, OverflowDetected> {
let q_target = stiefel_vec_to_mat(target, self.n, self.k);
let delta = stiefel_vec_to_mat(tangent, self.n, self.k);
let transported = Self::project_tangent(&q_target, &delta);
Ok(stiefel_mat_to_vec(&transported))
}
}
/// Cartesian product `M1 × M2` of two manifolds.
///
/// Points and tangent vectors are the concatenation `[x1 | x2]` of the
/// factors' encodings, with `x1` occupying the first `dim1_embed` entries and
/// `x2` the next `dim2_embed`.
pub struct ProductManifold {
    m1: Box<dyn Manifold>,
    m2: Box<dyn Manifold>,
    /// Length of the first factor's encoding within a joined vector.
    dim1_embed: usize,
    /// Length of the second factor's encoding within a joined vector.
    dim2_embed: usize,
}

impl ProductManifold {
    /// Builds `m1 × m2`, where `dim1_embed`/`dim2_embed` are the encoding
    /// lengths of the respective factors.
    pub fn new(
        m1: Box<dyn Manifold>,
        dim1_embed: usize,
        m2: Box<dyn Manifold>,
        dim2_embed: usize,
    ) -> Self {
        Self { m1, m2, dim1_embed, dim2_embed }
    }

    /// Splits a joined vector into the two factors' slices (copied).
    fn split(&self, v: &FixedVector) -> (FixedVector, FixedVector) {
        let copy_range = |start: usize, len: usize| {
            let mut out = FixedVector::new(len);
            for i in 0..len {
                out[i] = v[start + i];
            }
            out
        };
        (copy_range(0, self.dim1_embed), copy_range(self.dim1_embed, self.dim2_embed))
    }

    /// Concatenates `v1` followed by `v2`.
    fn join(v1: &FixedVector, v2: &FixedVector) -> FixedVector {
        let split_at = v1.len();
        let mut joined = FixedVector::new(split_at + v2.len());
        for i in 0..split_at {
            joined[i] = v1[i];
        }
        for j in 0..v2.len() {
            joined[split_at + j] = v2[j];
        }
        joined
    }
}
impl Manifold for ProductManifold {
    /// Sum of the factors' dimensions.
    fn dimension(&self) -> usize {
        self.m1.dimension() + self.m2.dimension()
    }

    /// Sum of the factors' metrics evaluated on the corresponding slices.
    fn inner_product(
        &self,
        base: &FixedVector,
        u: &FixedVector,
        v: &FixedVector,
    ) -> FixedPoint {
        let (base_a, base_b) = self.split(base);
        let (u_a, u_b) = self.split(u);
        let (v_a, v_b) = self.split(v);
        let first = self.m1.inner_product(&base_a, &u_a, &v_a);
        let second = self.m2.inner_product(&base_b, &u_b, &v_b);
        first + second
    }

    /// Factor-wise exponential map, re-joined.
    fn exp_map(
        &self,
        base: &FixedVector,
        tangent: &FixedVector,
    ) -> Result<FixedVector, OverflowDetected> {
        let (base_a, base_b) = self.split(base);
        let (step_a, step_b) = self.split(tangent);
        let moved_a = self.m1.exp_map(&base_a, &step_a)?;
        let moved_b = self.m2.exp_map(&base_b, &step_b)?;
        Ok(Self::join(&moved_a, &moved_b))
    }

    /// Factor-wise logarithm map, re-joined.
    fn log_map(
        &self,
        base: &FixedVector,
        target: &FixedVector,
    ) -> Result<FixedVector, OverflowDetected> {
        let (base_a, base_b) = self.split(base);
        let (goal_a, goal_b) = self.split(target);
        let log_a = self.m1.log_map(&base_a, &goal_a)?;
        let log_b = self.m2.log_map(&base_b, &goal_b)?;
        Ok(Self::join(&log_a, &log_b))
    }

    /// `sqrt(d1² + d2²)` from the factors' distances.
    fn distance(
        &self,
        p: &FixedVector,
        q: &FixedVector,
    ) -> Result<FixedPoint, OverflowDetected> {
        let (p_a, p_b) = self.split(p);
        let (q_a, q_b) = self.split(q);
        let d_a = self.m1.distance(&p_a, &q_a)?;
        let d_b = self.m2.distance(&p_b, &q_b)?;
        let squared = d_a * d_a + d_b * d_b;
        squared.try_sqrt()
    }

    /// Factor-wise transport, re-joined.
    fn parallel_transport(
        &self,
        base: &FixedVector,
        target: &FixedVector,
        tangent: &FixedVector,
    ) -> Result<FixedVector, OverflowDetected> {
        let (base_a, base_b) = self.split(base);
        let (goal_a, goal_b) = self.split(target);
        let (vec_a, vec_b) = self.split(tangent);
        let carried_a = self.m1.parallel_transport(&base_a, &goal_a, &vec_a)?;
        let carried_b = self.m2.parallel_transport(&base_b, &goal_b, &vec_b)?;
        Ok(Self::join(&carried_a, &carried_b))
    }
}