1pub mod closure_family;
2pub mod curvature_estimand;
3pub mod integrator;
4pub mod latent_seed;
5pub mod manifold;
6pub mod manifolds;
7pub mod optimizer;
8pub mod response_geometry;
9pub mod sinkhorn_barycenter;
10
11pub use manifolds::{
15 circle, constant_curvature, euclidean, grassmann, lie_so, poincare, product, simplex, spd,
16 sphere, stiefel, torus,
17};
18
19pub use closure_family::{
20 ClosureFamily, ClosureProfileCi, boundary_conductance, conductance_penalty_jet,
21 profile_ci_from_grid,
22};
23pub use curvature_estimand::{
24 CurvatureVerdict, DesignCoordKappaJet, FlatnessTest, KappaProfileCi,
25 design_coord_kappa_derivative, flatness_lr_test, profile_ci_walk, wald_half_width,
26};
27pub use integrator::GeodesicIntegrator;
28pub use latent_seed::laplacian_eigenmap_coords;
29pub use manifold::{GeometryError, GeometryResult, ManifoldSpec, RiemannianManifold};
30pub use manifolds::{
31 CircleManifold, ConstantCurvature, EuclideanManifold, GrassmannManifold, ProductManifold,
32 SpdManifold, SphereManifold, StiefelManifold, TorusManifold, distance_kappa_jet,
33 exp_map_kappa_jet, log_map_kappa_jet, spd_frechet_mean,
34};
35pub use optimizer::{RiemannianLBFGS, RiemannianObjective, RiemannianTrustRegion};
36pub use response_geometry::{
37 ResponseCurvatureFit, ResponseManifold, fit_response_curvature, response_curvature_criterion,
38 response_exp_map, response_frechet_mean, response_log_map, response_projection_residual,
39};
40
41use ndarray::{Array1, ArrayView1};
42
43pub(crate) fn normalize_weights(
49 n: usize,
50 weights: Option<ArrayView1<'_, f64>>,
51) -> Result<Array1<f64>, String> {
52 match weights {
53 None => Ok(Array1::from_elem(n, 1.0 / n as f64)),
54 Some(w) => {
55 if w.len() != n {
56 return Err("weights length must match the number of rows".to_string());
57 }
58 let mut total = 0.0_f64;
59 for value in w.iter() {
60 if !value.is_finite() || *value < 0.0 {
61 return Err(
62 "weights must be finite, non-negative, and have positive total".to_string(),
63 );
64 }
65 total += *value;
66 }
67 if total <= 0.0 {
68 return Err(
69 "weights must be finite, non-negative, and have positive total".to_string(),
70 );
71 }
72 Ok(w.mapv(|v| v / total))
73 }
74 }
75}