// gam 0.3.82
//
// Generalized penalized likelihood engine
// Documentation
pub mod circle;
pub mod euclidean;
pub mod grassmann;
pub mod integrator;
pub mod latent_seed;
pub mod lie_so;
pub mod manifold;
pub mod optimizer;
pub mod poincare;
pub mod product;
pub mod simplex;
pub mod spd;
pub mod sphere;
pub mod stiefel;
pub mod torus;

pub use circle::CircleManifold;
pub use euclidean::EuclideanManifold;
pub use grassmann::GrassmannManifold;
pub use integrator::GeodesicIntegrator;
pub use latent_seed::laplacian_eigenmap_coords;
pub use manifold::{GeometryError, GeometryResult, ManifoldSpec, RiemannianManifold};
pub use optimizer::{RiemannianLBFGS, RiemannianObjective, RiemannianTrustRegion};
pub use product::ProductManifold;
pub use spd::SpdManifold;
pub use sphere::SphereManifold;
pub use stiefel::StiefelManifold;
pub use torus::TorusManifold;

use ndarray::{Array1, ArrayView1};

/// Validate and normalize per-row weights for a manifold barycenter computation.
///
/// With `None`, returns uniform weights `1/n` (an empty array when `n == 0`). With
/// `Some(w)`, requires `w.len() == n` and every entry finite and non-negative with a
/// positive total, then returns `w` divided by its total so the weights sum to one.
///
/// Very large finite weights whose sum overflows `f64` are handled by rescaling with
/// the largest entry before normalizing, so the result still sums to one instead of
/// collapsing to all zeros.
///
/// # Errors
///
/// Returns `Err` with a descriptive message when `w.len() != n`, when any entry is
/// NaN, infinite, or negative, or when the entries sum to zero (which includes an
/// empty `w`).
pub(crate) fn normalize_weights(
    n: usize,
    weights: Option<ArrayView1<'_, f64>>,
) -> Result<Array1<f64>, String> {
    const INVALID_WEIGHTS: &str = "weights must be finite, non-negative, and have positive total";

    let w = match weights {
        None => return Ok(Array1::from_elem(n, 1.0 / n as f64)),
        Some(w) => w,
    };
    if w.len() != n {
        return Err("weights length must match the number of rows".to_string());
    }
    // `!is_finite()` also rejects NaN, which would slip past a plain `< 0.0` check.
    if w.iter().any(|v| !v.is_finite() || *v < 0.0) {
        return Err(INVALID_WEIGHTS.to_string());
    }
    // Sequential left-to-right sum, matching the original accumulation order.
    let total = w.iter().fold(0.0_f64, |acc, &v| acc + v);
    if total <= 0.0 {
        return Err(INVALID_WEIGHTS.to_string());
    }
    if total.is_infinite() {
        // Every entry is finite, so the sum overflowed. Dividing by `inf` would turn
        // every weight into zero. Rescale by the largest entry (strictly positive
        // here) so the intermediate sum is bounded by `n`, then normalize.
        let max = w.iter().fold(0.0_f64, |m, &v| m.max(v));
        let scaled = w.mapv(|v| v / max);
        let scaled_total = scaled.iter().fold(0.0_f64, |acc, &v| acc + v);
        return Ok(scaled.mapv(|v| v / scaled_total));
    }
    Ok(w.mapv(|v| v / total))
}