Skip to main content

gam_geometry/
lib.rs

1pub mod closure_family;
2pub mod curvature_estimand;
3pub mod integrator;
4pub mod latent_seed;
5pub mod manifold;
6pub mod manifolds;
7pub mod optimizer;
8pub mod response_geometry;
9pub mod sinkhorn_barycenter;
10
11// Re-export each manifold submodule at the crate root so the historical paths
12// (`gam_geometry::sphere::SphereManifold`, …) keep resolving after the
13// `manifolds/` regrouping.
14pub use manifolds::{
15    circle, constant_curvature, euclidean, grassmann, lie_so, poincare, product, simplex, spd,
16    sphere, stiefel, torus,
17};
18
19pub use closure_family::{
20    ClosureFamily, ClosureProfileCi, boundary_conductance, conductance_penalty_jet,
21    profile_ci_from_grid,
22};
23pub use curvature_estimand::{
24    CurvatureVerdict, DesignCoordKappaJet, FlatnessTest, KappaProfileCi,
25    design_coord_kappa_derivative, flatness_lr_test, profile_ci_walk, wald_half_width,
26};
27pub use integrator::GeodesicIntegrator;
28pub use latent_seed::laplacian_eigenmap_coords;
29pub use manifold::{GeometryError, GeometryResult, ManifoldSpec, RiemannianManifold};
30pub use manifolds::{
31    CircleManifold, ConstantCurvature, EuclideanManifold, GrassmannManifold, ProductManifold,
32    SpdManifold, SphereManifold, StiefelManifold, TorusManifold, distance_kappa_jet,
33    exp_map_kappa_jet, log_map_kappa_jet, spd_frechet_mean,
34};
35pub use optimizer::{RiemannianLBFGS, RiemannianObjective, RiemannianTrustRegion};
36pub use response_geometry::{
37    ResponseCurvatureFit, ResponseManifold, fit_response_curvature, response_curvature_criterion,
38    response_exp_map, response_frechet_mean, response_log_map, response_projection_residual,
39};
40
41use ndarray::{Array1, ArrayView1};
42
43/// Validate and normalize per-row weights for a manifold barycenter computation.
44///
45/// With `None`, returns uniform weights `1/n`. With `Some(w)`, requires `w.len() == n`
46/// and every entry finite, non-negative, with positive total, then returns `w` divided
47/// by its total so the weights sum to one.
48pub(crate) fn normalize_weights(
49    n: usize,
50    weights: Option<ArrayView1<'_, f64>>,
51) -> Result<Array1<f64>, String> {
52    match weights {
53        None => Ok(Array1::from_elem(n, 1.0 / n as f64)),
54        Some(w) => {
55            if w.len() != n {
56                return Err("weights length must match the number of rows".to_string());
57            }
58            let mut total = 0.0_f64;
59            for value in w.iter() {
60                if !value.is_finite() || *value < 0.0 {
61                    return Err(
62                        "weights must be finite, non-negative, and have positive total".to_string(),
63                    );
64                }
65                total += *value;
66            }
67            if total <= 0.0 {
68                return Err(
69                    "weights must be finite, non-negative, and have positive total".to_string(),
70                );
71            }
72            Ok(w.mapv(|v| v / total))
73        }
74    }
75}