gam 0.2.3

Generalized penalized likelihood engine
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, s};

use crate::geometry::manifold::{GeometryResult, RiemannianManifold, check_len};

pub struct ProductManifold {
    components: Vec<Box<dyn RiemannianManifold>>,
}

impl ProductManifold {
    pub fn new(components: Vec<Box<dyn RiemannianManifold>>) -> Self {
        Self { components }
    }

    pub fn components(&self) -> &[Box<dyn RiemannianManifold>] {
        &self.components
    }
}

impl RiemannianManifold for ProductManifold {
    fn dim(&self) -> usize {
        self.components.iter().map(|c| c.dim()).sum()
    }

    fn ambient_dim(&self) -> usize {
        self.components.iter().map(|c| c.ambient_dim()).sum()
    }

    fn tangent_basis(&self, point: ArrayView1<'_, f64>) -> GeometryResult<Array2<f64>> {
        check_len("Product point", point.len(), self.ambient_dim())?;
        let mut out = Array2::<f64>::zeros((self.ambient_dim(), self.dim()));
        let mut row_off = 0usize;
        let mut col_off = 0usize;
        for component in &self.components {
            let m = component.ambient_dim();
            let d = component.dim();
            let q = component.tangent_basis(point.slice(s![row_off..row_off + m]))?;
            for i in 0..m {
                for j in 0..d {
                    out[[row_off + i, col_off + j]] = q[[i, j]];
                }
            }
            row_off += m;
            col_off += d;
        }
        Ok(out)
    }

    fn exp_map(
        &self,
        point: ArrayView1<'_, f64>,
        tangent_vec: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        check_len("Product point", point.len(), self.ambient_dim())?;
        check_len("Product tangent", tangent_vec.len(), self.ambient_dim())?;
        let mut out = Array1::<f64>::zeros(self.ambient_dim());
        let mut off = 0usize;
        for component in &self.components {
            let m = component.ambient_dim();
            let part = component.exp_map(
                point.slice(s![off..off + m]),
                tangent_vec.slice(s![off..off + m]),
            )?;
            for i in 0..m {
                out[off + i] = part[i];
            }
            off += m;
        }
        Ok(out)
    }

    fn log_map(
        &self,
        p_from: ArrayView1<'_, f64>,
        p_to: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        check_len("Product source", p_from.len(), self.ambient_dim())?;
        check_len("Product target", p_to.len(), self.ambient_dim())?;
        let mut out = Array1::<f64>::zeros(self.ambient_dim());
        let mut off = 0usize;
        for component in &self.components {
            let m = component.ambient_dim();
            let part =
                component.log_map(p_from.slice(s![off..off + m]), p_to.slice(s![off..off + m]))?;
            for i in 0..m {
                out[off + i] = part[i];
            }
            off += m;
        }
        Ok(out)
    }

    fn parallel_transport(
        &self,
        point_along: ArrayView2<'_, f64>,
        vec: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        check_len(
            "Product path width",
            point_along.ncols(),
            self.ambient_dim(),
        )?;
        check_len("Product transported vector", vec.len(), self.ambient_dim())?;
        let mut out = Array1::<f64>::zeros(self.ambient_dim());
        let mut off = 0usize;
        for component in &self.components {
            let m = component.ambient_dim();
            let mut path = Array2::<f64>::zeros((point_along.nrows(), m));
            for row in 0..point_along.nrows() {
                for col in 0..m {
                    path[[row, col]] = point_along[[row, off + col]];
                }
            }
            let part = component.parallel_transport(path.view(), vec.slice(s![off..off + m]))?;
            for i in 0..m {
                out[off + i] = part[i];
            }
            off += m;
        }
        Ok(out)
    }

    fn metric_tensor(&self, point: ArrayView1<'_, f64>) -> GeometryResult<Array2<f64>> {
        check_len("Product metric point", point.len(), self.ambient_dim())?;
        let mut out = Array2::<f64>::zeros((self.ambient_dim(), self.ambient_dim()));
        let mut off = 0usize;
        for component in &self.components {
            let m = component.ambient_dim();
            let g = component.metric_tensor(point.slice(s![off..off + m]))?;
            for i in 0..m {
                for j in 0..m {
                    out[[off + i, off + j]] = g[[i, j]];
                }
            }
            off += m;
        }
        Ok(out)
    }

    fn christoffel_symbols(&self, point: ArrayView1<'_, f64>) -> GeometryResult<Vec<Array2<f64>>> {
        check_len("Product Christoffel point", point.len(), self.ambient_dim())?;
        let ambient = self.ambient_dim();
        let mut out = (0..ambient)
            .map(|_| Array2::<f64>::zeros((ambient, ambient)))
            .collect::<Vec<_>>();
        let mut off = 0usize;
        for component in &self.components {
            let m = component.ambient_dim();
            let gamma = component.christoffel_symbols(point.slice(s![off..off + m]))?;
            for k in 0..m {
                for i in 0..m {
                    for j in 0..m {
                        out[off + k][[off + i, off + j]] = gamma[k][[i, j]];
                    }
                }
            }
            off += m;
        }
        Ok(out)
    }

    fn sectional_curvature(
        &self,
        point: ArrayView1<'_, f64>,
        tangent_pair: (ArrayView1<'_, f64>, ArrayView1<'_, f64>),
    ) -> GeometryResult<f64> {
        check_len("Product curvature point", point.len(), self.ambient_dim())?;
        check_len(
            "Product curvature tangent u",
            tangent_pair.0.len(),
            self.ambient_dim(),
        )?;
        check_len(
            "Product curvature tangent v",
            tangent_pair.1.len(),
            self.ambient_dim(),
        )?;
        Ok(0.0)
    }

    fn project_tangent(
        &self,
        point: ArrayView1<'_, f64>,
        vec: ArrayView1<'_, f64>,
    ) -> GeometryResult<Array1<f64>> {
        check_len("Product projection point", point.len(), self.ambient_dim())?;
        check_len("Product projection vector", vec.len(), self.ambient_dim())?;
        let mut out = Array1::<f64>::zeros(self.ambient_dim());
        let mut off = 0usize;
        for component in &self.components {
            let m = component.ambient_dim();
            let part = component
                .project_tangent(point.slice(s![off..off + m]), vec.slice(s![off..off + m]))?;
            for i in 0..m {
                out[off + i] = part[i];
            }
            off += m;
        }
        Ok(out)
    }
}