pub mod circle;
pub mod euclidean;
pub mod grassmann;
pub mod integrator;
pub mod manifold;
pub mod optimizer;
pub mod poincare;
pub mod product;
pub mod simplex;
pub mod spd;
pub mod sphere;
pub mod stiefel;
pub mod torus;
pub use circle::CircleManifold;
pub use euclidean::EuclideanManifold;
pub use grassmann::GrassmannManifold;
pub use integrator::GeodesicIntegrator;
pub use manifold::{GeometryError, GeometryResult, ManifoldSpec, RiemannianManifold};
pub use optimizer::{RiemannianLBFGS, RiemannianObjective, RiemannianTrustRegion};
pub use product::ProductManifold;
pub use spd::SpdManifold;
pub use sphere::SphereManifold;
pub use stiefel::StiefelManifold;
pub use torus::TorusManifold;
use ndarray::{Array1, ArrayView1};
/// Validates optional per-row weights and rescales them so they sum to one.
///
/// * `n` — the number of rows the weights must cover.
/// * `weights` — `None` selects uniform weights `1 / n`; `Some(w)` supplies
///   explicit weights that are checked and then divided by their total.
///
/// Returns an owned vector of length `n` whose entries sum to one, up to
/// rounding. When `n == 0` and `weights` is `None`, the result is an empty
/// vector.
///
/// # Errors
///
/// Returns an `Err(String)` when:
/// * `w.len() != n`;
/// * any weight is NaN, infinite, or negative;
/// * the weights sum to zero (this includes an empty `w`).
pub(crate) fn normalize_weights(
    n: usize,
    weights: Option<ArrayView1<'_, f64>>,
) -> Result<Array1<f64>, String> {
    const INVALID_WEIGHTS: &str =
        "weights must be finite, non-negative, and have positive total";

    let Some(w) = weights else {
        return Ok(Array1::from_elem(n, 1.0 / n as f64));
    };

    if w.len() != n {
        return Err("weights length must match the number of rows".to_string());
    }
    if w.iter().any(|v| !v.is_finite() || *v < 0.0) {
        return Err(INVALID_WEIGHTS.to_string());
    }

    // Sequential left-to-right sum, matching the original accumulation order.
    let total = w.iter().fold(0.0_f64, |acc, v| acc + v);
    if total <= 0.0 {
        return Err(INVALID_WEIGHTS.to_string());
    }
    if total.is_finite() {
        return Ok(w.mapv(|v| v / total));
    }

    // Every weight is finite, yet their sum overflowed to +inf. Dividing by
    // that total would silently produce all zeros. Instead, rescale by the
    // largest weight first; it is finite and > 0 here because total > 0.
    // The scaled entries lie in [0, 1], so their sum (at most n) is finite.
    let max = w.iter().copied().fold(0.0_f64, f64::max);
    let scaled_total = w.iter().fold(0.0_f64, |acc, v| acc + v / max);
    Ok(w.mapv(|v| (v / max) / scaled_total))
}