use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};
use crate::manifold::{
GEOMETRY_EPS, GeometryError, GeometryResult, RiemannianManifold, check_len, quad_form,
};
/// Cartesian product `M_1 × … × M_k` of Riemannian manifolds.
///
/// The factors are kept in the order they were supplied. A point (or tangent
/// vector) of the product is the concatenation of the factors' ambient
/// coordinates in that same order.
pub struct ProductManifold {
    /// Factor manifolds, in coordinate order.
    components: Vec<Box<dyn RiemannianManifold>>,
}

impl ProductManifold {
    /// Builds the product of `components`, taken in the given order.
    ///
    /// An empty list is accepted; the resulting product has dimension 0.
    pub fn new(components: Vec<Box<dyn RiemannianManifold>>) -> Self {
        ProductManifold { components }
    }

    /// Borrows the factor manifolds in coordinate order.
    pub fn components(&self) -> &[Box<dyn RiemannianManifold>] {
        self.components.as_slice()
    }
}
/// Riemannian structure of the product, evaluated factor by factor.
///
/// Points and tangent vectors are concatenations of the component ambient
/// coordinates, in the order of `components`. Factor `r` occupies the contiguous
/// range `off_r..off_r + ambient_dim_r`, where `off_r` is the sum of the ambient
/// dimensions of the preceding factors. Matrix-valued results (tangent basis,
/// metric, Christoffel symbols) are block-diagonal in that layout.
///
/// Every method checks its inputs against the product's ambient dimension. It also
/// checks that each component returns a result of the shape its block expects, so a
/// misbehaving component surfaces as an error rather than an index panic or a
/// silently truncated copy.
impl RiemannianManifold for ProductManifold {
    /// Sum of the intrinsic dimensions of the factors.
    fn dim(&self) -> usize {
        self.components.iter().map(|c| c.dim()).sum()
    }

    /// Sum of the ambient dimensions of the factors.
    fn ambient_dim(&self) -> usize {
        self.components.iter().map(|c| c.ambient_dim()).sum()
    }

    /// Block-diagonal `ambient_dim × dim` tangent basis assembled from the
    /// components' `ambient_dim_r × dim_r` bases.
    fn tangent_basis(&self, point: ArrayView1<'_, f64>) -> GeometryResult<Array2<f64>> {
        let ambient = self.ambient_dim();
        check_len("Product point", point.len(), ambient)?;
        let mut out = Array2::<f64>::zeros((ambient, self.dim()));
        let mut row_off = 0usize;
        let mut col_off = 0usize;
        for component in &self.components {
            let m = component.ambient_dim();
            let d = component.dim();
            let q = component.tangent_basis(point.slice(s![row_off..row_off + m]))?;
            check_len("Product tangent basis component rows", q.nrows(), m)?;
            check_len("Product tangent basis component cols", q.ncols(), d)?;
            out.slice_mut(s![row_off..row_off + m, col_off..col_off + d])
                .assign(&q);
            row_off += m;
            col_off += d;
        }
        Ok(out)
    }

    /// Componentwise exponential map.
    fn exp_map(
        &self,
        point: ArrayView1<'_, f64>,
        tangent_vec: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        let ambient = self.ambient_dim();
        check_len("Product point", point.len(), ambient)?;
        check_len("Product tangent", tangent_vec.len(), ambient)?;
        let mut out = Array1::<f64>::zeros(ambient);
        let mut off = 0usize;
        for component in &self.components {
            let m = component.ambient_dim();
            let part = component.exp_map(
                point.slice(s![off..off + m]),
                tangent_vec.slice(s![off..off + m]),
            )?;
            check_len("Product exp_map component output", part.len(), m)?;
            out.slice_mut(s![off..off + m]).assign(&part);
            off += m;
        }
        Ok(out)
    }

    /// Componentwise vector-Jacobian product of the exponential map.
    ///
    /// The exponential map acts on each factor independently, so its Jacobian is
    /// block-diagonal and each factor's slice of `grad_output` is pulled back by
    /// that factor alone.
    fn exp_map_vjp(
        &self,
        point: ArrayView1<'_, f64>,
        tangent_vec: ArrayView1<'_, f64>,
        grad_output: ArrayView1<'_, f64>,
    ) -> GeometryResult<(Array1<f64>, Array1<f64>)> {
        let ambient = self.ambient_dim();
        check_len("Product exp_map_vjp point", point.len(), ambient)?;
        check_len("Product exp_map_vjp tangent", tangent_vec.len(), ambient)?;
        check_len("Product exp_map_vjp grad", grad_output.len(), ambient)?;
        let mut grad_point = Array1::<f64>::zeros(ambient);
        let mut grad_tangent = Array1::<f64>::zeros(ambient);
        let mut off = 0usize;
        for component in &self.components {
            let m = component.ambient_dim();
            let (gp, gt) = component.exp_map_vjp(
                point.slice(s![off..off + m]),
                tangent_vec.slice(s![off..off + m]),
                grad_output.slice(s![off..off + m]),
            )?;
            check_len("Product exp_map_vjp component grad point", gp.len(), m)?;
            check_len("Product exp_map_vjp component grad tangent", gt.len(), m)?;
            grad_point.slice_mut(s![off..off + m]).assign(&gp);
            grad_tangent.slice_mut(s![off..off + m]).assign(&gt);
            off += m;
        }
        Ok((grad_point, grad_tangent))
    }

    /// Componentwise logarithmic map.
    fn log_map(
        &self,
        p_from: ArrayView1<'_, f64>,
        p_to: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        let ambient = self.ambient_dim();
        check_len("Product source", p_from.len(), ambient)?;
        check_len("Product target", p_to.len(), ambient)?;
        let mut out = Array1::<f64>::zeros(ambient);
        let mut off = 0usize;
        for component in &self.components {
            let m = component.ambient_dim();
            let part =
                component.log_map(p_from.slice(s![off..off + m]), p_to.slice(s![off..off + m]))?;
            check_len("Product log_map component output", part.len(), m)?;
            out.slice_mut(s![off..off + m]).assign(&part);
            off += m;
        }
        Ok(out)
    }

    /// Componentwise parallel transport along a sampled path.
    ///
    /// `point_along` holds one path point per row. Each factor receives the rows
    /// restricted to its own column range.
    fn parallel_transport(
        &self,
        point_along: ArrayView2<'_, f64>,
        vec: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        let ambient = self.ambient_dim();
        check_len("Product path width", point_along.ncols(), ambient)?;
        check_len("Product transported vector", vec.len(), ambient)?;
        let steps = point_along.nrows();
        let mut out = Array1::<f64>::zeros(ambient);
        let mut off = 0usize;
        for component in &self.components {
            let m = component.ambient_dim();
            // Copy the factor's columns into a fresh standard-layout matrix so the
            // component always receives a contiguous path, whatever the caller's layout.
            let mut path = Array2::<f64>::zeros((steps, m));
            path.assign(&point_along.slice(s![.., off..off + m]));
            let part = component.parallel_transport(path.view(), vec.slice(s![off..off + m]))?;
            check_len("Product parallel_transport component output", part.len(), m)?;
            out.slice_mut(s![off..off + m]).assign(&part);
            off += m;
        }
        Ok(out)
    }

    /// Block-diagonal metric whose diagonal blocks are the component metrics.
    fn metric_tensor(&self, point: ArrayView1<'_, f64>) -> GeometryResult<Array2<f64>> {
        let ambient = self.ambient_dim();
        check_len("Product metric point", point.len(), ambient)?;
        let mut out = Array2::<f64>::zeros((ambient, ambient));
        let mut off = 0usize;
        for component in &self.components {
            let m = component.ambient_dim();
            let g = component.metric_tensor(point.slice(s![off..off + m]))?;
            check_len("Product metric component rows", g.nrows(), m)?;
            check_len("Product metric component cols", g.ncols(), m)?;
            out.slice_mut(s![off..off + m, off..off + m]).assign(&g);
            off += m;
        }
        Ok(out)
    }

    /// Christoffel symbols of the product, returned as one `ambient × ambient`
    /// matrix per upper index `k`.
    ///
    /// Symbols that mix coordinates of different factors are zero. Each factor
    /// fills only the entries whose three indices all lie in its own range.
    fn christoffel_symbols(&self, point: ArrayView1<'_, f64>) -> GeometryResult<Vec<Array2<f64>>> {
        let ambient = self.ambient_dim();
        check_len("Product Christoffel point", point.len(), ambient)?;
        let mut out = (0..ambient)
            .map(|_| Array2::<f64>::zeros((ambient, ambient)))
            .collect::<Vec<_>>();
        let mut off = 0usize;
        for component in &self.components {
            let m = component.ambient_dim();
            let gamma = component.christoffel_symbols(point.slice(s![off..off + m]))?;
            check_len("Product Christoffel component count", gamma.len(), m)?;
            for (k, gamma_k) in gamma.iter().enumerate() {
                check_len("Product Christoffel component rows", gamma_k.nrows(), m)?;
                check_len("Product Christoffel component cols", gamma_k.ncols(), m)?;
                out[off + k]
                    .slice_mut(s![off..off + m, off..off + m])
                    .assign(gamma_k);
            }
            off += m;
        }
        Ok(out)
    }

    /// Sectional curvature of the plane spanned by `tangent_pair` at `point`.
    ///
    /// The curvature tensor of a product is the sum of the factor tensors. This gives
    ///
    /// `K(u, v) = Σ_r K_r(u_r, v_r) · G_r / G`,
    ///
    /// where `G_r = ⟨u_r,u_r⟩⟨v_r,v_r⟩ − ⟨u_r,v_r⟩²` uses factor `r`'s metric and
    /// `G` is the same Gram determinant for the full (block-diagonal) metric. Inner
    /// products are computed with `quad_form`.
    ///
    /// Factors with `G_r <= GEOMETRY_EPS` are skipped: their restricted vectors are
    /// (numerically) parallel and contribute no curvature.
    ///
    /// # Errors
    ///
    /// Returns `GeometryError::Singular` when `G <= GEOMETRY_EPS`, meaning `u` and
    /// `v` do not span a plane.
    fn sectional_curvature(
        &self,
        point: ArrayView1<'_, f64>,
        tangent_pair: (ArrayView1<'_, f64>, ArrayView1<'_, f64>),
    ) -> GeometryResult<f64> {
        let ambient = self.ambient_dim();
        check_len("Product curvature point", point.len(), ambient)?;
        check_len("Product curvature tangent u", tangent_pair.0.len(), ambient)?;
        check_len("Product curvature tangent v", tangent_pair.1.len(), ambient)?;
        let (u, v) = tangent_pair;
        let mut numerator = 0.0;
        let mut uu_total = 0.0;
        let mut vv_total = 0.0;
        let mut uv_total = 0.0;
        let mut off = 0usize;
        for component in &self.components {
            let m = component.ambient_dim();
            let p_r = point.slice(s![off..off + m]);
            let u_r = u.slice(s![off..off + m]);
            let v_r = v.slice(s![off..off + m]);
            let g_r = component.metric_tensor(p_r)?;
            check_len("Product curvature metric component rows", g_r.nrows(), m)?;
            check_len("Product curvature metric component cols", g_r.ncols(), m)?;
            let uu_r = quad_form(g_r.view(), u_r, u_r);
            let vv_r = quad_form(g_r.view(), v_r, v_r);
            let uv_r = quad_form(g_r.view(), u_r, v_r);
            let gram_r = uu_r * vv_r - uv_r * uv_r;
            if gram_r > GEOMETRY_EPS {
                let k_r = component.sectional_curvature(p_r, (u_r, v_r))?;
                numerator += k_r * gram_r;
            }
            uu_total += uu_r;
            vv_total += vv_r;
            uv_total += uv_r;
            off += m;
        }
        let denom = uu_total * vv_total - uv_total * uv_total;
        if denom <= GEOMETRY_EPS {
            return Err(GeometryError::Singular(
                "Product sectional curvature plane is degenerate",
            ));
        }
        Ok(numerator / denom)
    }

    /// Componentwise projection of `vec` onto the tangent space at `point`.
    fn project_tangent(
        &self,
        point: ArrayView1<'_, f64>,
        vec: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        let ambient = self.ambient_dim();
        check_len("Product projection point", point.len(), ambient)?;
        check_len("Product projection vector", vec.len(), ambient)?;
        let mut out = Array1::<f64>::zeros(ambient);
        let mut off = 0usize;
        for component in &self.components {
            let m = component.ambient_dim();
            let part = component
                .project_tangent(point.slice(s![off..off + m]), vec.slice(s![off..off + m]))?;
            check_len("Product projection component output", part.len(), m)?;
            out.slice_mut(s![off..off + m]).assign(&part);
            off += m;
        }
        Ok(out)
    }

    /// Componentwise conversion of a Euclidean gradient into a Riemannian one.
    fn riemannian_gradient(
        &self,
        point: ArrayView1<'_, f64>,
        euclidean_grad: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        let ambient = self.ambient_dim();
        check_len("Product gradient point", point.len(), ambient)?;
        check_len("Product gradient vector", euclidean_grad.len(), ambient)?;
        let mut out = Array1::<f64>::zeros(ambient);
        let mut off = 0usize;
        for component in &self.components {
            let m = component.ambient_dim();
            let part = component.riemannian_gradient(
                point.slice(s![off..off + m]),
                euclidean_grad.slice(s![off..off + m]),
            )?;
            check_len("Product gradient component output", part.len(), m)?;
            out.slice_mut(s![off..off + m]).assign(&part);
            off += m;
        }
        Ok(out)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::manifold::RiemannianManifold;
    use crate::manifolds::euclidean::EuclideanManifold;
    use ndarray::array;

    /// R^2 × R^3, a 5-dimensional flat product used by most tests below.
    fn two_euclidean() -> ProductManifold {
        ProductManifold::new(vec![
            Box::new(EuclideanManifold::new(2)),
            Box::new(EuclideanManifold::new(3)),
        ])
    }

    #[test]
    fn dim_is_sum_of_component_dims() {
        assert_eq!(two_euclidean().dim(), 5);
    }

    #[test]
    fn ambient_dim_equals_dim_for_euclidean_factors() {
        assert_eq!(two_euclidean().ambient_dim(), 5);
    }

    #[test]
    fn exp_map_euclidean_product_is_componentwise_add() {
        let m = two_euclidean();
        let p = array![1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let v = array![10.0_f64, 20.0, 30.0, 40.0, 50.0];
        let q = m.exp_map(p.view(), v.view()).unwrap();
        assert_eq!(q.len(), 5);
        for i in 0..5 {
            assert!((q[i] - (p[i] + v[i])).abs() < 1e-12, "index {i}: {}", q[i]);
        }
    }

    #[test]
    fn log_map_euclidean_product_is_componentwise_sub() {
        let m = two_euclidean();
        let p = array![1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let q = array![4.0_f64, 2.0, 1.0, 9.0, 5.0];
        let v = m.log_map(p.view(), q.view()).unwrap();
        let expected = array![3.0_f64, 0.0, -2.0, 5.0, 0.0];
        for i in 0..5 {
            assert!((v[i] - expected[i]).abs() < 1e-12, "index {i}: {}", v[i]);
        }
    }

    #[test]
    fn log_map_inverts_exp_map_for_euclidean_factors() {
        let m = two_euclidean();
        let p = array![1.0_f64, -2.0, 0.5, 4.0, 3.0];
        let v = array![0.25_f64, 1.0, -3.0, 2.0, 0.0];
        let q = m.exp_map(p.view(), v.view()).unwrap();
        let back = m.log_map(p.view(), q.view()).unwrap();
        for i in 0..5 {
            assert!((back[i] - v[i]).abs() < 1e-12, "index {i}: {}", back[i]);
        }
    }

    #[test]
    fn metric_tensor_is_block_identity_for_euclidean_factors() {
        let m = two_euclidean();
        let p = Array1::<f64>::zeros(5);
        let g = m.metric_tensor(p.view()).unwrap();
        assert_eq!(g.dim(), (5, 5));
        for i in 0..5 {
            for j in 0..5 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!((g[[i, j]] - expected).abs() < 1e-14);
            }
        }
    }

    #[test]
    fn dimension_mismatch_returns_error() {
        let m = two_euclidean();
        let p = array![1.0_f64, 2.0];
        let v = array![0.0_f64, 0.0];
        assert!(m.exp_map(p.view(), v.view()).is_err());

        // A correctly sized point does not rescue a wrongly sized tangent.
        let p_ok = Array1::<f64>::zeros(5);
        assert!(m.exp_map(p_ok.view(), v.view()).is_err());
    }

    #[test]
    fn metric_tensor_rejects_wrong_point_length() {
        let m = two_euclidean();
        let p = Array1::<f64>::zeros(4);
        assert!(m.metric_tensor(p.view()).is_err());
    }

    #[test]
    fn sectional_curvature_of_degenerate_plane_is_error() {
        let m = two_euclidean();
        let p = Array1::<f64>::zeros(5);
        // u == v spans no plane, so the total Gram determinant is zero.
        let u = array![1.0_f64, 0.0, 0.0, 1.0, 0.0];
        assert!(m.sectional_curvature(p.view(), (u.view(), u.view())).is_err());
    }

    #[test]
    fn empty_product_is_zero_dimensional() {
        let m = ProductManifold::new(Vec::new());
        assert_eq!(m.dim(), 0);
        assert_eq!(m.ambient_dim(), 0);
        let p = Array1::<f64>::zeros(0);
        let q = m.exp_map(p.view(), p.view()).unwrap();
        assert_eq!(q.len(), 0);
    }

    #[test]
    fn single_factor_product_behaves_like_that_manifold() {
        let m = ProductManifold::new(vec![Box::new(EuclideanManifold::new(3))]);
        assert_eq!(m.dim(), 3);
        let p = array![1.0_f64, 0.0, -1.0];
        let v = array![2.0_f64, 3.0, 4.0];
        let q = m.exp_map(p.view(), v.view()).unwrap();
        assert!((q[0] - 3.0).abs() < 1e-12);
        assert!((q[1] - 3.0).abs() < 1e-12);
        assert!((q[2] - 3.0).abs() < 1e-12);
    }
}