use crate::solver::latent_cache::LatentRetractionRegistry;
use crate::terms::basis::{BasisError, RadialScalarKind};
use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3};
use std::sync::atomic::{AtomicU64, Ordering};
/// Weight of the `t tᵀ` term that `LatentManifold::riemannian_hessian_matrix` adds to
/// each sphere block so the normal direction is not left with a zero eigenvalue.
const SPHERE_NORMAL_PIN: f64 = 1.0;

/// Process-wide counter backing [`next_latent_coord_id`]; starts at 1.
static NEXT_LATENT_COORD_ID: AtomicU64 = AtomicU64::new(1);

/// Hands out a fresh identifier for a `LatentCoordValues` instance.
///
/// The counter only needs atomicity (no other memory is published through it),
/// so a relaxed fetch-and-increment is sufficient for uniqueness.
fn next_latent_coord_id() -> u64 {
    let issued = NEXT_LATENT_COORD_ID.fetch_add(1, Ordering::Relaxed);
    issued
}
/// Regression family used by `aux_prior_targets` when regressing latents on auxiliary data.
#[derive(Debug, Clone, Copy)]
pub enum AuxPriorFamily {
    /// Least squares with a small diagonal ridge (scaled to the Gram trace) for stability.
    Ridge,
    /// Plain least squares; fails if the auxiliary Gram matrix is not positive definite.
    Linear,
}
/// Strength setting for the auxiliary prior.
///
/// NOTE(review): how `Auto` is resolved and what scale `Fixed` uses are decided by
/// callers outside this file — confirm there.
#[derive(Debug, Clone, Copy)]
pub enum AuxPriorStrength {
    /// Let the caller choose the strength.
    Auto,
    /// Use the given strength value verbatim.
    Fixed(f64),
}
/// How the gauge freedom of the latent coordinates is resolved.
///
/// Only the `AuxPrior*` variants count as identifiable (see `is_identifiable`);
/// `DimSelection` on its own is rejected by the `LatentCoordValues` constructors.
#[derive(Debug, Clone)]
pub enum LatentIdMode {
    /// Anchor latents to auxiliary covariates `u` (one row per observation).
    AuxPrior {
        u: Array2<f64>,
        family: AuxPriorFamily,
        strength: AuxPriorStrength,
    },
    /// Auxiliary prior combined with dimension selection.
    AuxPriorDimSelection {
        u: Array2<f64>,
        family: AuxPriorFamily,
        strength: AuxPriorStrength,
        /// Presumably initial per-dimension log precisions for ARD — TODO confirm.
        init_log_precision: Option<Array1<f64>>,
    },
    /// Dimension selection alone; not a valid standalone gauge fix.
    DimSelection {
        /// Presumably initial per-dimension log precisions for ARD — TODO confirm.
        init_log_precision: Option<Array1<f64>>,
    },
    /// No identification constraint.
    None,
}
/// Geometry on which each latent row lives.
#[derive(Debug, Clone, PartialEq, Default)]
pub enum LatentManifold {
    /// Flat space; points are left untouched. This is the default.
    #[default]
    Euclidean,
    /// One angular coordinate, wrapped into `[0, period)`.
    Circle { period: f64 },
    /// Unit vectors in `dim` ambient coordinates.
    Sphere { dim: usize },
    /// One coordinate clamped into `[lo, hi]`.
    Interval { lo: f64, hi: f64 },
    /// Concatenation of factors; each factor occupies `ambient_dim(1)` coordinates.
    Product(Vec<LatentManifold>),
    /// Like `Product`, but with explicit metric weights, one per ambient coordinate.
    ProductWithMetric {
        manifolds: Vec<LatentManifold>,
        weights: Vec<f64>,
    },
}
impl LatentManifold {
    /// Returns `true` only for the flat [`LatentManifold::Euclidean`] variant.
    pub fn is_euclidean(&self) -> bool {
        matches!(self, Self::Euclidean)
    }

    /// Number of ambient coordinates used to represent one point.
    ///
    /// `fallback_dim` is only consulted for [`LatentManifold::Euclidean`]. Factors of a
    /// product are sized with a fallback of 1, so a Euclidean factor contributes a
    /// single coordinate.
    pub fn ambient_dim(&self, fallback_dim: usize) -> usize {
        match self {
            Self::Euclidean => fallback_dim,
            Self::Circle { .. } | Self::Interval { .. } => 1,
            Self::Sphere { dim } => *dim,
            Self::Product(parts)
            | Self::ProductWithMetric {
                manifolds: parts, ..
            } => parts.iter().map(|part| part.ambient_dim(1)).sum(),
        }
    }

    /// Per-coordinate metric weights.
    ///
    /// Each weight is the squared reciprocal of the factor's extent:
    /// - `period` for a circle;
    /// - `hi - lo` for an interval;
    /// - `π` for each sphere coordinate.
    ///
    /// `Euclidean` always yields the single weight `1.0`. `Product` concatenates its
    /// factors' weights, and `ProductWithMetric` returns its explicit weights verbatim.
    ///
    /// # Panics
    /// - `Circle` with a non-finite or non-positive period.
    /// - `Interval` with a non-finite bound or with `hi <= lo`.
    /// - `ProductWithMetric` whose weight count differs from its ambient dimension.
    pub fn metric_weights(&self) -> Vec<f64> {
        match self {
            Self::Euclidean => vec![1.0],
            Self::Circle { period } => {
                assert!(
                    period.is_finite() && *period > 0.0,
                    "LatentManifold::Circle requires a finite positive period; got {period}"
                );
                vec![1.0 / (period * period)]
            }
            Self::Sphere { dim } => {
                let w = 1.0 / (std::f64::consts::PI * std::f64::consts::PI);
                vec![w; *dim]
            }
            Self::Interval { lo, hi } => {
                // Validate like `Circle`: a degenerate interval gives an infinite weight,
                // and a non-finite bound gives NaN. Either would silently corrupt the metric.
                assert!(
                    lo.is_finite() && hi.is_finite() && hi > lo,
                    "LatentManifold::Interval requires finite bounds with lo < hi; got [{lo}, {hi}]"
                );
                let scale = hi - lo;
                vec![1.0 / (scale * scale)]
            }
            Self::Product(parts) => {
                let mut out = Vec::with_capacity(self.ambient_dim(1));
                for part in parts {
                    out.extend(part.metric_weights());
                }
                out
            }
            Self::ProductWithMetric { manifolds, weights } => {
                let expected: usize = manifolds.iter().map(|part| part.ambient_dim(1)).sum();
                assert_eq!(
                    weights.len(),
                    expected,
                    "LatentManifold::ProductWithMetric weights length must match ambient dimension"
                );
                weights.clone()
            }
        }
    }

    /// Maps an ambient point onto the manifold.
    ///
    /// - Circle: wraps into `[0, period)`.
    /// - Sphere: normalises to unit length.
    /// - Interval: clamps into `[lo, hi]`.
    /// - Products: applied factor by factor.
    ///
    /// # Panics
    /// - A scalar factor whose input does not have exactly one coordinate.
    /// - A sphere whose input length differs from `dim`, or whose input has zero or
    ///   non-finite norm.
    /// - A product whose factor dimensions do not sum to `t.len()`.
    pub fn project_point(&self, t: ArrayView1<'_, f64>) -> Array1<f64> {
        match self {
            Self::Euclidean => t.to_owned(),
            Self::Circle { period } => {
                Self::assert_scalar_point(t, "Circle");
                Array1::from_elem(1, wrap_to_period(t[0], *period))
            }
            Self::Sphere { dim } => {
                assert_eq!(t.len(), *dim);
                normalize_or_axis(t, *dim)
            }
            Self::Interval { lo, hi } => {
                Self::assert_scalar_point(t, "Interval");
                Array1::from_elem(1, t[0].clamp(*lo, *hi))
            }
            Self::Product(parts)
            | Self::ProductWithMetric {
                manifolds: parts, ..
            } => Self::map_product_blocks(parts, t.len(), |part, offset, dim| {
                part.project_point(t.slice(ndarray::s![offset..offset + dim]))
            }),
        }
    }

    /// Moves from `t` along the ambient step `xi` and lands back on the manifold.
    ///
    /// - Euclidean: `t + xi`.
    /// - Circle: wrapped sum.
    /// - Sphere: normalised sum (projective retraction).
    /// - Interval: clamped sum.
    /// - Products: applied per factor.
    ///
    /// # Panics
    /// - `t` and `xi` have different lengths.
    /// - Any of the conditions listed for [`Self::project_point`].
    pub fn retract(&self, t: ArrayView1<'_, f64>, xi: ArrayView1<'_, f64>) -> Array1<f64> {
        assert_eq!(t.len(), xi.len());
        match self {
            Self::Euclidean => {
                let mut out = t.to_owned();
                for a in 0..out.len() {
                    out[a] += xi[a];
                }
                out
            }
            Self::Circle { period } => {
                Self::assert_scalar_point(t, "Circle");
                Array1::from_elem(1, wrap_to_period(t[0] + xi[0], *period))
            }
            Self::Sphere { dim } => {
                assert_eq!(t.len(), *dim);
                let mut y = Array1::<f64>::zeros(*dim);
                for a in 0..*dim {
                    y[a] = t[a] + xi[a];
                }
                normalize_or_axis(y.view(), *dim)
            }
            Self::Interval { lo, hi } => {
                Self::assert_scalar_point(t, "Interval");
                Array1::from_elem(1, (t[0] + xi[0]).clamp(*lo, *hi))
            }
            Self::Product(parts)
            | Self::ProductWithMetric {
                manifolds: parts, ..
            } => Self::map_product_blocks(parts, t.len(), |part, offset, dim| {
                part.retract(
                    t.slice(ndarray::s![offset..offset + dim]),
                    xi.slice(ndarray::s![offset..offset + dim]),
                )
            }),
        }
    }

    /// Projects an ambient vector `v` onto the tangent space at `t`.
    ///
    /// - Euclidean and Circle: identity.
    /// - Sphere: removes the component along `t`.
    /// - Interval: zeroes `v` when `t` sits on a bound and `v` points outward.
    pub fn project_to_tangent(
        &self,
        t: ArrayView1<'_, f64>,
        v: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        assert_eq!(t.len(), v.len());
        match self {
            Self::Euclidean | Self::Circle { .. } => v.to_owned(),
            Self::Sphere { dim } => {
                assert_eq!(t.len(), *dim);
                let tv = dot_views(t, v);
                let mut out = v.to_owned();
                for a in 0..*dim {
                    out[a] -= tv * t[a];
                }
                out
            }
            Self::Interval { lo, hi } => {
                let at_lo = t[0] <= *lo && v[0] < 0.0;
                let at_hi = t[0] >= *hi && v[0] > 0.0;
                Array1::from_elem(1, if at_lo || at_hi { 0.0 } else { v[0] })
            }
            Self::Product(parts)
            | Self::ProductWithMetric {
                manifolds: parts, ..
            } => Self::map_product_blocks(parts, v.len(), |part, offset, dim| {
                part.project_to_tangent(
                    t.slice(ndarray::s![offset..offset + dim]),
                    v.slice(ndarray::s![offset..offset + dim]),
                )
            }),
        }
    }

    /// Converts a Euclidean gradient `eg` and Hessian `eh` at `t` into the action of
    /// the Riemannian Hessian on the direction `xi`.
    ///
    /// The sphere formula assumes `xi` is already tangent at `t`.
    pub fn euclidean_to_riemannian_hessian(
        &self,
        t: ArrayView1<'_, f64>,
        eg: ArrayView1<'_, f64>,
        eh: ArrayView2<'_, f64>,
        xi: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        assert_eq!(t.len(), eg.len());
        assert_eq!(t.len(), xi.len());
        assert_eq!(eh.nrows(), t.len());
        assert_eq!(eh.ncols(), t.len());
        let eh_xi = matvec(eh, xi);
        self.euclidean_hessian_action_to_riemannian(t, eg, xi, eh_xi.view())
    }

    /// Same as [`Self::euclidean_to_riemannian_hessian`], but takes the precomputed
    /// Euclidean Hessian-vector product `eh_xi` instead of the full Hessian.
    fn euclidean_hessian_action_to_riemannian(
        &self,
        t: ArrayView1<'_, f64>,
        eg: ArrayView1<'_, f64>,
        xi: ArrayView1<'_, f64>,
        eh_xi: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        assert_eq!(t.len(), eg.len());
        assert_eq!(t.len(), xi.len());
        assert_eq!(t.len(), eh_xi.len());
        match self {
            Self::Euclidean | Self::Circle { .. } | Self::Interval { .. } => {
                self.project_to_tangent(t, eh_xi)
            }
            Self::Sphere { dim } => {
                assert_eq!(t.len(), *dim);
                // For tangent xi this evaluates P_t(∇²f·xi) − (tᵀ∇f)·xi.
                // The extra `normal_curve * t` term is parallel to t, so the final
                // projection drops it.
                let grad_r = self.project_to_tangent(t, eg);
                let mut ambient = self.project_to_tangent(t, eh_xi);
                let eg_normal = dot_views(eg, t);
                let normal_curve = dot_views(grad_r.view(), xi);
                for a in 0..*dim {
                    ambient[a] -= eg_normal * xi[a];
                    ambient[a] -= normal_curve * t[a];
                }
                self.project_to_tangent(t, ambient.view())
            }
            Self::Product(parts)
            | Self::ProductWithMetric {
                manifolds: parts, ..
            } => Self::map_product_blocks(parts, t.len(), |part, offset, dim| {
                part.euclidean_hessian_action_to_riemannian(
                    t.slice(ndarray::s![offset..offset + dim]),
                    eg.slice(ndarray::s![offset..offset + dim]),
                    xi.slice(ndarray::s![offset..offset + dim]),
                    eh_xi.slice(ndarray::s![offset..offset + dim]),
                )
            }),
        }
    }

    /// Dense Riemannian Hessian at `t`.
    ///
    /// Column `a` is the Hessian applied to the tangent projection of the `a`-th unit
    /// vector. Sphere blocks additionally get `SPHERE_NORMAL_PIN · t tᵀ` so the normal
    /// direction is not singular. The result is then symmetrized.
    pub fn riemannian_hessian_matrix(
        &self,
        t: ArrayView1<'_, f64>,
        eg: ArrayView1<'_, f64>,
        eh: ArrayView2<'_, f64>,
    ) -> Array2<f64> {
        let d = t.len();
        let mut out = Array2::<f64>::zeros((d, d));
        let mut xi = Array1::<f64>::zeros(d);
        for a in 0..d {
            xi.fill(0.0);
            xi[a] = 1.0;
            let tangent_xi = self.project_to_tangent(t, xi.view());
            let col = self.euclidean_to_riemannian_hessian(t, eg, eh, tangent_xi.view());
            out.column_mut(a).assign(&col);
        }
        self.add_normal_pinning(t, &mut out);
        symmetrize(&mut out);
        out
    }

    /// Applies [`Self::project_to_tangent`] at `t` to every column of `matrix`.
    pub fn project_matrix_columns_to_tangent(
        &self,
        t: ArrayView1<'_, f64>,
        matrix: ArrayView2<'_, f64>,
    ) -> Array2<f64> {
        let mut out = Array2::<f64>::zeros(matrix.dim());
        for col_idx in 0..matrix.ncols() {
            let col = self.project_to_tangent(t, matrix.column(col_idx));
            out.column_mut(col_idx).assign(&col);
        }
        out
    }

    /// Adds `SPHERE_NORMAL_PIN · t tᵀ` to each sphere block on the diagonal of `matrix`.
    /// All other factor types are left unchanged.
    fn add_normal_pinning(&self, t: ArrayView1<'_, f64>, matrix: &mut Array2<f64>) {
        match self {
            Self::Sphere { dim } => {
                assert_eq!(t.len(), *dim);
                for a in 0..*dim {
                    for b in 0..*dim {
                        matrix[[a, b]] += SPHERE_NORMAL_PIN * t[a] * t[b];
                    }
                }
            }
            Self::Product(parts)
            | Self::ProductWithMetric {
                manifolds: parts, ..
            } => {
                let mut offset = 0_usize;
                for part in parts {
                    let dim = part.ambient_dim(1);
                    let mut block =
                        matrix.slice_mut(ndarray::s![offset..offset + dim, offset..offset + dim]);
                    // The recursive call needs an owned `Array2`, so round-trip the block.
                    let mut owned = block.to_owned();
                    part.add_normal_pinning(t.slice(ndarray::s![offset..offset + dim]), &mut owned);
                    block.assign(&owned);
                    offset += dim;
                }
            }
            Self::Euclidean | Self::Circle { .. } | Self::Interval { .. } => {}
        }
    }

    /// Panics unless `t` holds exactly one coordinate, as the scalar factors require.
    fn assert_scalar_point(t: ArrayView1<'_, f64>, name: &str) {
        assert_eq!(
            t.len(),
            1,
            "LatentManifold::{name} points have exactly one coordinate; got {}",
            t.len()
        );
    }

    /// Evaluates `f(part, offset, dim)` for every product factor and stitches the
    /// results into one vector of length `len`.
    ///
    /// `offset..offset + dim` is the factor's coordinate block. Panics if the factor
    /// dimensions do not add up to `len`.
    fn map_product_blocks<F>(parts: &[LatentManifold], len: usize, mut f: F) -> Array1<f64>
    where
        F: FnMut(&LatentManifold, usize, usize) -> Array1<f64>,
    {
        let mut out = Array1::<f64>::zeros(len);
        let mut offset = 0_usize;
        for part in parts {
            let dim = part.ambient_dim(1);
            let block = f(part, offset, dim);
            for a in 0..dim {
                out[offset + a] = block[a];
            }
            offset += dim;
        }
        assert_eq!(offset, len);
        out
    }
}
impl LatentIdMode {
    /// Whether this mode pins down the latent gauge on its own.
    ///
    /// Only the auxiliary-prior variants qualify.
    pub fn is_identifiable(&self) -> bool {
        match self {
            Self::AuxPrior { .. } | Self::AuxPriorDimSelection { .. } => true,
            Self::DimSelection { .. } | Self::None => false,
        }
    }

    /// Panics when the mode is bare `DimSelection`.
    ///
    /// That mode is refused because it is not a standalone gauge fix.
    fn reject_dim_selection_alone(&self) {
        let standalone_ard = matches!(self, Self::DimSelection { .. });
        assert!(
            !standalone_ard,
            "LatentIdMode::DimSelection is not a standalone gauge fix; pair ARD with AuxPrior or Isometry"
        );
    }
}
/// Source of the derivative of a design matrix with respect to the latent inputs.
pub enum InputLocationDerivative<'a> {
    /// Compute the derivative analytically from radial-basis `centers`
    /// (one row per center) and the radial kernel kind.
    Radial {
        centers: ArrayView2<'a, f64>,
        radial_kind: &'a RadialScalarKind,
    },
    /// A precomputed jet of shape `(n_obs, n_basis, latent_dim)`.
    Jet(ArrayView3<'a, f64>),
}
/// Latent coordinates for all observations, stored flat and row-major.
#[derive(Debug, Clone)]
pub struct LatentCoordValues {
    /// Identifier drawn from a process-wide counter; `with_manifold` keeps it.
    id: u64,
    /// `n_obs * latent_dim` values; row `n` occupies `n * latent_dim .. (n + 1) * latent_dim`.
    values: Array1<f64>,
    /// Number of observations (rows).
    n_obs: usize,
    /// Coordinates per observation.
    latent_dim: usize,
    /// Identification strategy for the latent gauge.
    id_mode: LatentIdMode,
    /// Geometry each row is constrained to.
    manifold: LatentManifold,
    /// Per-coordinate retractions, used when not all Euclidean.
    retraction_registry: LatentRetractionRegistry,
}
impl LatentCoordValues {
    /// Builds Euclidean latent coordinates from an `(n_obs, latent_dim)` matrix.
    ///
    /// # Panics
    /// If `id_mode` is [`LatentIdMode::DimSelection`].
    pub fn from_matrix(matrix: ArrayView2<'_, f64>, id_mode: LatentIdMode) -> Self {
        Self::from_matrix_with_manifold(matrix, id_mode, LatentManifold::Euclidean)
    }

    /// Like [`Self::from_matrix`], but projects every row onto `manifold`.
    pub fn from_matrix_with_manifold(
        matrix: ArrayView2<'_, f64>,
        id_mode: LatentIdMode,
        manifold: LatentManifold,
    ) -> Self {
        Self::from_matrix_with_manifold_and_retraction(
            matrix,
            id_mode,
            manifold,
            LatentRetractionRegistry::all_euclidean(),
        )
    }

    /// Full matrix constructor.
    ///
    /// Flattens `matrix` row-major, allocates a fresh id and projects every row onto
    /// `manifold`.
    ///
    /// # Panics
    /// - `id_mode` is `DimSelection`.
    /// - The registry rejects `latent_dim`.
    /// - A non-Euclidean `manifold` has an ambient dimension other than `latent_dim`.
    pub(crate) fn from_matrix_with_manifold_and_retraction(
        matrix: ArrayView2<'_, f64>,
        id_mode: LatentIdMode,
        manifold: LatentManifold,
        retraction_registry: LatentRetractionRegistry,
    ) -> Self {
        id_mode.reject_dim_selection_alone();
        let (n_obs, latent_dim) = matrix.dim();
        retraction_registry
            .validate_dim(latent_dim, "LatentCoordValues::from_matrix_with_manifold")
            .expect("invalid latent retraction dimension");
        // `iter` visits elements in logical row-major order whatever the view's strides,
        // which is exactly the flat `n * latent_dim + k` layout used by this type.
        let values: Array1<f64> = matrix.iter().copied().collect();
        let mut out = Self {
            id: next_latent_coord_id(),
            values,
            n_obs,
            latent_dim,
            id_mode,
            manifold,
            retraction_registry,
        };
        out.project_all_rows_to_manifold();
        out
    }

    /// Builds Euclidean latent coordinates from a flat row-major vector.
    pub fn from_flat(
        values: Array1<f64>,
        n_obs: usize,
        latent_dim: usize,
        id_mode: LatentIdMode,
    ) -> Self {
        Self::from_flat_with_manifold(
            values,
            n_obs,
            latent_dim,
            id_mode,
            LatentManifold::Euclidean,
        )
    }

    /// Like [`Self::from_flat`], but projects every row onto `manifold`.
    pub fn from_flat_with_manifold(
        values: Array1<f64>,
        n_obs: usize,
        latent_dim: usize,
        id_mode: LatentIdMode,
        manifold: LatentManifold,
    ) -> Self {
        Self::from_flat_with_manifold_and_retraction_and_id(
            values,
            n_obs,
            latent_dim,
            id_mode,
            manifold,
            LatentRetractionRegistry::all_euclidean(),
            next_latent_coord_id(),
        )
    }

    /// Full flat constructor with an explicit id.
    ///
    /// # Panics
    /// - `id_mode` is `DimSelection`.
    /// - `values.len() != n_obs * latent_dim`.
    /// - The registry rejects `latent_dim`.
    /// - A non-Euclidean manifold's ambient dimension differs from `latent_dim`.
    pub(crate) fn from_flat_with_manifold_and_retraction_and_id(
        values: Array1<f64>,
        n_obs: usize,
        latent_dim: usize,
        id_mode: LatentIdMode,
        manifold: LatentManifold,
        retraction_registry: LatentRetractionRegistry,
        id: u64,
    ) -> Self {
        id_mode.reject_dim_selection_alone();
        assert_eq!(
            values.len(),
            n_obs * latent_dim,
            "LatentCoordValues::from_flat: length {} != n_obs * latent_dim = {}",
            values.len(),
            n_obs * latent_dim
        );
        retraction_registry
            .validate_dim(latent_dim, "LatentCoordValues::from_flat_with_manifold")
            .expect("invalid latent retraction dimension");
        // `row` borrows `&[f64]` slices, which needs a standard (unit-stride) layout.
        // An owned `Array1` can still carry another stride (e.g. after `slice_move`),
        // so normalise here instead of panicking later inside `row`.
        let values = if values.is_standard_layout() {
            values
        } else {
            values.as_standard_layout().into_owned()
        };
        let mut out = Self {
            id,
            values,
            n_obs,
            latent_dim,
            id_mode,
            manifold,
            retraction_registry,
        };
        out.project_all_rows_to_manifold();
        out
    }

    /// Identifier assigned at construction (kept by [`Self::with_manifold`]).
    pub fn latent_id(&self) -> u64 {
        self.id
    }

    /// Number of observations.
    pub fn n_obs(&self) -> usize {
        self.n_obs
    }

    /// Coordinates per observation.
    pub fn latent_dim(&self) -> usize {
        self.latent_dim
    }

    /// Total number of stored values (`n_obs * latent_dim`).
    pub fn len(&self) -> usize {
        self.values.len()
    }

    /// `true` when no values are stored.
    pub fn is_empty(&self) -> bool {
        self.values.is_empty()
    }

    /// Identification mode.
    pub fn id_mode(&self) -> &LatentIdMode {
        &self.id_mode
    }

    /// Manifold each row is constrained to.
    pub fn manifold(&self) -> &LatentManifold {
        &self.manifold
    }

    /// Per-coordinate retraction registry.
    pub(crate) fn retraction_registry(&self) -> &LatentRetractionRegistry {
        &self.retraction_registry
    }

    /// `true` only if both the manifold and the retraction registry are Euclidean.
    pub(crate) fn effective_is_all_euclidean(&self) -> bool {
        self.manifold.is_euclidean() && self.retraction_registry.is_all_euclidean()
    }

    /// Metric weights in effect.
    ///
    /// Returns the manifold's weights when the manifold is non-Euclidean, otherwise
    /// the registry's weights for `latent_dim` coordinates.
    pub(crate) fn effective_metric_weights(&self) -> Vec<f64> {
        if self.manifold.is_euclidean() {
            self.retraction_registry.metric_weights(self.latent_dim)
        } else {
            self.manifold.metric_weights()
        }
    }

    /// Copy with a different manifold.
    ///
    /// Rows are reprojected onto the new manifold; the id is preserved.
    pub fn with_manifold(&self, manifold: LatentManifold) -> Self {
        Self::from_flat_with_manifold_and_retraction_and_id(
            self.values.clone(),
            self.n_obs,
            self.latent_dim,
            self.id_mode.clone(),
            manifold,
            self.retraction_registry.clone(),
            self.id,
        )
    }

    /// Flat row-major storage.
    pub fn as_flat(&self) -> &Array1<f64> {
        &self.values
    }

    /// Coordinates of observation `n`.
    ///
    /// # Panics
    /// If `n >= n_obs`.
    pub fn row(&self, n: usize) -> &[f64] {
        let start = n * self.latent_dim;
        let end = start + self.latent_dim;
        // Constructors guarantee standard layout, so `as_slice` cannot fail here.
        &self.values.as_slice().expect("contiguous")[start..end]
    }

    /// Copies the coordinates into an `(n_obs, latent_dim)` matrix.
    pub fn as_matrix(&self) -> Array2<f64> {
        let d = self.latent_dim;
        Array2::from_shape_fn((self.n_obs, d), |(n, k)| self.values[n * d + k])
    }

    /// Overwrites all values from a flat row-major vector and reprojects onto the manifold.
    ///
    /// # Panics
    /// If `flat` has the wrong length.
    pub fn set_flat(&mut self, flat: ArrayView1<'_, f64>) {
        assert_eq!(flat.len(), self.values.len());
        self.values.assign(&flat);
        self.project_all_rows_to_manifold();
    }

    /// Applies a flat step `delta` to every row.
    ///
    /// With an all-Euclidean registry the step goes through the manifold's retraction
    /// (plain addition for a Euclidean manifold). Otherwise the registry retracts each
    /// row and the manifold is not consulted.
    ///
    /// # Panics
    /// - `delta` has the wrong length.
    /// - The manifold's ambient dimension differs from `latent_dim`.
    pub fn retract_flat_delta(&mut self, delta: ArrayView1<'_, f64>) {
        assert_eq!(delta.len(), self.values.len());
        if self.retraction_registry.is_all_euclidean() {
            if self.manifold.is_euclidean() {
                for (t, dt) in self.values.iter_mut().zip(delta.iter()) {
                    *t += *dt;
                }
                return;
            }
            assert_eq!(
                self.manifold.ambient_dim(self.latent_dim),
                self.latent_dim,
                "LatentCoordValues::retract_flat_delta: manifold ambient dim does not match latent_dim",
            );
            for n in 0..self.n_obs {
                let start = n * self.latent_dim;
                let end = start + self.latent_dim;
                let next = self.manifold.retract(
                    self.values.slice(ndarray::s![start..end]),
                    delta.slice(ndarray::s![start..end]),
                );
                self.values.slice_mut(ndarray::s![start..end]).assign(&next);
            }
            return;
        }
        for n in 0..self.n_obs {
            let start = n * self.latent_dim;
            let end = start + self.latent_dim;
            let mut current = self.values.slice_mut(ndarray::s![start..end]);
            let xi = delta.slice(ndarray::s![start..end]);
            self.retraction_registry.retract(&mut current, xi);
        }
    }

    /// Projects every row onto the manifold; no-op for a Euclidean manifold.
    fn project_all_rows_to_manifold(&mut self) {
        if self.manifold.is_euclidean() {
            return;
        }
        assert_eq!(self.manifold.ambient_dim(self.latent_dim), self.latent_dim);
        for n in 0..self.n_obs {
            let start = n * self.latent_dim;
            let end = start + self.latent_dim;
            let projected = self
                .manifold
                .project_point(self.values.slice(ndarray::s![start..end]));
            self.values.slice_mut(ndarray::s![start..end]).assign(&projected);
        }
    }

    /// Returns the coordinates as an `(n_obs, latent_dim)` matrix (same as `as_matrix`).
    pub fn apply_tospec(&self) -> Array2<f64> {
        self.as_matrix()
    }

    /// Jet of radial-basis design entries with respect to the latent coordinates.
    ///
    /// For observation `n`, center `k` and coordinate `a`:
    /// `jet[n, k, a] = q(r) · (t_n[a] − c_k[a])`, where `r = ‖t_n − c_k‖` and `q` is the
    /// second value returned by `radial_kind.eval_design_triplet(r)`.
    /// Presumably `q = φ'(r)/r`, which makes this `∂φ/∂t_a` — TODO confirm against
    /// `RadialScalarKind`.
    ///
    /// # Errors
    /// - `centers` does not have `latent_dim` columns.
    /// - `eval_design_triplet` fails.
    pub(crate) fn design_gradient_wrt_t(
        &self,
        centers: ArrayView2<'_, f64>,
        radial_kind: &RadialScalarKind,
    ) -> Result<Array3<f64>, BasisError> {
        let n_obs = self.n_obs;
        let d = self.latent_dim;
        let n_centers = centers.nrows();
        if centers.ncols() != d {
            return Err(BasisError::DimensionMismatch(format!(
                "LatentCoordValues::design_gradient_wrt_t center dimension mismatch: centers have {} cols but latent_dim is {}",
                centers.ncols(),
                d
            )));
        }
        let mut jet = Array3::<f64>::zeros((n_obs, n_centers, d));
        for n in 0..n_obs {
            let t_n = self.row(n);
            for k in 0..n_centers {
                let mut r2 = 0.0_f64;
                for a in 0..d {
                    let delta = t_n[a] - centers[[k, a]];
                    r2 += delta * delta;
                }
                let r = r2.sqrt();
                let (_, q, _) = radial_kind.eval_design_triplet(r)?;
                if q == 0.0 {
                    // The jet is zero-initialised, so nothing to write.
                    continue;
                }
                for a in 0..d {
                    jet[[n, k, a]] = q * (t_n[a] - centers[[k, a]]);
                }
            }
        }
        Ok(jet)
    }

    /// Produces the design jet from either radial centers or a precomputed jet.
    ///
    /// A precomputed jet is copied after checking that its first and last axes are
    /// `n_obs` and `latent_dim`.
    pub(crate) fn design_gradient_wrt_t_dispatch(
        &self,
        input: InputLocationDerivative<'_>,
    ) -> Result<Array3<f64>, BasisError> {
        match input {
            InputLocationDerivative::Radial {
                centers,
                radial_kind,
            } => self.design_gradient_wrt_t(centers, radial_kind),
            InputLocationDerivative::Jet(jet) => {
                let shape = jet.shape();
                if shape[0] != self.n_obs || shape[2] != self.latent_dim {
                    return Err(BasisError::DimensionMismatch(format!(
                        "LatentCoordValues::design_gradient_wrt_t_dispatch jet shape {:?} does not match latent shape ({}, {}, {})",
                        shape,
                        self.n_obs,
                        shape[1],
                        self.latent_dim
                    )));
                }
                Ok(jet.to_owned())
            }
        }
    }
}
/// Reduces `x` into the half-open range `[0, period)`.
///
/// `rem_euclid` can round up to exactly `period` for tiny negative inputs; that
/// value is folded back to `0.0` so the result stays strictly below `period`.
///
/// # Panics
/// If `period` is not finite or not strictly positive.
fn wrap_to_period(x: f64, period: f64) -> f64 {
    let valid_period = period.is_finite() && period > 0.0;
    assert!(
        valid_period,
        "wrap_to_period requires a finite positive period; got {period}"
    );
    match x.rem_euclid(period) {
        wrapped if wrapped == period => 0.0,
        wrapped => wrapped,
    }
}
/// Scales the first `dim` entries of `v` to unit Euclidean length.
///
/// Despite the name, there is no axis fallback: a zero or non-finite squared norm
/// is a hard error.
///
/// # Panics
/// - The squared norm is `<= 0` or non-finite.
/// - `v` has fewer than `dim` entries.
fn normalize_or_axis(v: ArrayView1<'_, f64>, dim: usize) -> Array1<f64> {
    let norm_sq = (0..dim).fold(0.0_f64, |acc, a| acc + v[a] * v[a]);
    if !(norm_sq > 0.0 && norm_sq.is_finite()) {
        panic!("LatentManifold::Sphere cannot normalize a zero or non-finite ambient vector");
    }
    let inv = 1.0 / norm_sq.sqrt();
    Array1::from_iter((0..dim).map(|a| v[a] * inv))
}
/// Inner product of two equal-length vectors, summed left to right.
///
/// # Panics
/// If the lengths differ.
fn dot_views(a: ArrayView1<'_, f64>, b: ArrayView1<'_, f64>) -> f64 {
    assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b.iter())
        .fold(0.0_f64, |acc, (x, y)| acc + x * y)
}
fn matvec(a: ArrayView2<'_, f64>, x: ArrayView1<'_, f64>) -> Array1<f64> {
assert_eq!(a.ncols(), x.len());
let mut out = Array1::<f64>::zeros(a.nrows());
for i in 0..a.nrows() {
let mut acc = 0.0_f64;
for j in 0..a.ncols() {
acc += a[[i, j]] * x[j];
}
out[i] = acc;
}
out
}
/// Makes `a` symmetric in place by delegating to the crate's shared linear-algebra helper.
#[inline]
fn symmetrize(a: &mut Array2<f64>) {
    use crate::linalg::utils::enforce_symmetry;
    enforce_symmetry(a)
}
/// Regresses the latent matrix `t` (`n_obs × d`) on the auxiliary covariates `u`
/// (`n_obs × p`) and returns the fitted values `U · C`.
///
/// The coefficients solve `(UᵀU + εI) C = Uᵀ T`. For `Ridge`,
/// `ε = max(1e-6 · trace(UᵀU) / p, 1e-12)`; for `Linear`, `ε = 0`.
///
/// # Errors
/// - Row counts of `t` and `u` differ.
/// - `u` has no columns.
/// - The regularised Gram matrix is not positive definite.
pub fn aux_prior_targets(
    t: ArrayView2<'_, f64>,
    u: ArrayView2<'_, f64>,
    family: AuxPriorFamily,
) -> Result<Array2<f64>, String> {
    let (n_obs, d) = t.dim();
    if u.nrows() != n_obs {
        return Err(format!(
            "aux_prior_targets: u has {} rows but t has {}",
            u.nrows(),
            n_obs
        ));
    }
    let p = u.ncols();
    if p == 0 {
        return Err("aux_prior_targets: auxiliary u must have at least one column".into());
    }

    // Gram matrix UᵀU, accumulated over observations in order.
    let mut gram = Array2::<f64>::from_shape_fn((p, p), |(i, j)| {
        (0..n_obs).fold(0.0_f64, |acc, n| acc + u[[n, i]] * u[[n, j]])
    });
    let ridge_eps = match family {
        AuxPriorFamily::Linear => 0.0,
        AuxPriorFamily::Ridge => {
            let trace: f64 = (0..p).map(|i| gram[[i, i]]).sum();
            (1e-6 * trace / p as f64).max(1e-12)
        }
    };
    gram.diag_mut().mapv_inplace(|g| g + ridge_eps);

    // Right-hand side UᵀT.
    let rhs = Array2::<f64>::from_shape_fn((p, d), |(i, k)| {
        (0..n_obs).fold(0.0_f64, |acc, n| acc + u[[n, i]] * t[[n, k]])
    });
    let coeffs = solve_spd(gram.view(), rhs.view())?;

    // Fitted values U · C.
    let targets = Array2::<f64>::from_shape_fn((n_obs, d), |(n, k)| {
        (0..p).fold(0.0_f64, |acc, i| acc + u[[n, i]] * coeffs[[i, k]])
    });
    Ok(targets)
}
/// Solves `A X = B` for a symmetric positive-definite `A` via Cholesky factorisation.
///
/// Only the lower triangle of `a` is read. Each column of `b` is solved by forward
/// substitution with `L`, then back substitution with `Lᵀ`.
///
/// # Errors
/// - `a` is not square.
/// - `b` has a different row count.
/// - A Cholesky pivot is non-finite (NaN/∞ in `a`) or non-positive (`a` is not
///   numerically positive definite).
fn solve_spd(a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
    let n = a.nrows();
    if a.ncols() != n {
        return Err("solve_spd: A must be square".into());
    }
    if b.nrows() != n {
        return Err("solve_spd: RHS row count mismatch".into());
    }
    let l = cholesky_lower(a)?;
    let d = b.ncols();
    let mut out = Array2::<f64>::zeros((n, d));
    // Reused across columns: every entry is rewritten before it is read.
    let mut y = Array1::<f64>::zeros(n);
    for col in 0..d {
        // Forward substitution: L y = b[:, col].
        for i in 0..n {
            let mut sum = b[[i, col]];
            for k in 0..i {
                sum -= l[[i, k]] * y[k];
            }
            y[i] = sum / l[[i, i]];
        }
        // Back substitution: Lᵀ x = y.
        for i in (0..n).rev() {
            let mut sum = y[i];
            for k in (i + 1)..n {
                sum -= l[[k, i]] * out[[k, col]];
            }
            out[[i, col]] = sum / l[[i, i]];
        }
    }
    Ok(out)
}

/// Lower-triangular Cholesky factor `L` with `A = L Lᵀ`.
///
/// Reads only the lower triangle of the square matrix `a`.
fn cholesky_lower(a: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
    let n = a.nrows();
    let mut l = Array2::<f64>::zeros((n, n));
    for i in 0..n {
        for j in 0..=i {
            let mut sum = a[[i, j]];
            for k in 0..j {
                sum -= l[[i, k]] * l[[j, k]];
            }
            if i == j {
                // `sum <= 0.0` alone lets NaN through, because every comparison with NaN
                // is false. A NaN or infinite pivot would silently poison the solution.
                if !sum.is_finite() {
                    return Err(format!(
                        "solve_spd: non-finite pivot {sum} at index {i} \
                         (matrix contains NaN or infinite entries)"
                    ));
                }
                if sum <= 0.0 {
                    return Err(format!(
                        "solve_spd: non-positive pivot {sum} at index {i} \
                         (matrix is not positive definite)"
                    ));
                }
                l[[i, j]] = sum.sqrt();
            } else {
                l[[i, j]] = sum / l[[j, j]];
            }
        }
    }
    Ok(l)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Euclidean norm of one latent row.
    fn row_norm(row: &[f64]) -> f64 {
        let norm_sq: f64 = row.iter().map(|x| x * x).sum();
        norm_sq.sqrt()
    }

    #[test]
    fn from_matrix_roundtrip() {
        let original = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let coords = LatentCoordValues::from_matrix(original.view(), LatentIdMode::None);
        assert_eq!(coords.n_obs(), 3);
        assert_eq!(coords.latent_dim(), 2);
        assert_eq!(coords.as_matrix(), original);
    }

    #[test]
    fn row_access() {
        let original = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let coords = LatentCoordValues::from_matrix(original.view(), LatentIdMode::None);
        assert_eq!(coords.row(0), &[1.0, 2.0]);
        assert_eq!(coords.row(1), &[3.0, 4.0]);
    }

    #[test]
    fn circle_manifold_update_wraps_into_canonical_interval() {
        let two_pi = std::f64::consts::TAU;
        let start = 6.2_f64;
        let step = 1.5_f64;
        let mut coords = LatentCoordValues::from_matrix_with_manifold(
            array![[start]].view(),
            LatentIdMode::None,
            LatentManifold::Circle { period: two_pi },
        );

        coords.retract_flat_delta(Array1::from(vec![step]).view());
        let updated = coords.row(0)[0];
        let expected = (start + step).rem_euclid(two_pi);
        assert!(
            (0.0..two_pi).contains(&updated),
            "Circle retraction did not wrap into [0, 2π): got {updated}",
        );
        assert!(
            (updated - expected).abs() < 1e-12,
            "Circle retraction value mismatch: got {updated}, expected {expected}",
        );

        coords.retract_flat_delta(Array1::from(vec![10.0 * two_pi + 0.25_f64]).view());
        let after_big = coords.row(0)[0];
        assert!(
            (0.0..two_pi).contains(&after_big),
            "Circle retraction did not wrap a large delta: got {after_big}",
        );
    }

    #[test]
    fn sphere_manifold_update_preserves_unit_norm() {
        let mut coords = LatentCoordValues::from_matrix_with_manifold(
            array![[1.0_f64, 0.0, 0.0]].view(),
            LatentIdMode::None,
            LatentManifold::Sphere { dim: 3 },
        );

        coords.retract_flat_delta(Array1::from(vec![0.3_f64, 0.7, -0.2]).view());
        let first = row_norm(coords.row(0));
        assert!(
            (first - 1.0).abs() < 1e-12,
            "Sphere retraction did not preserve unit norm: ||t|| = {}",
            first,
        );

        coords.retract_flat_delta(Array1::from(vec![50.0_f64, -25.0, 13.0]).view());
        let second = row_norm(coords.row(0));
        assert!(
            (second - 1.0).abs() < 1e-12,
            "Sphere retraction failed to renormalize after large delta: ||t|| = {}",
            second,
        );
    }
}