use ndarray::{Array1, ArrayView1};
use super::constant_curvature::{ConstantCurvature, log_map_kappa_jet};
use super::manifold::GeometryResult;
use super::closure_family::inv_std_normal;
/// Standard normal cumulative distribution function Φ(x).
///
/// Computed through the complementary error function, Φ(x) = erfc(-x / √2) / 2, which stays
/// accurate deep in the lower tail where the result is tiny.
fn std_normal_cdf(x: f64) -> f64 {
    let scaled = -x / std::f64::consts::SQRT_2;
    libm::erfc(scaled) * 0.5
}
/// Survival function of the χ² distribution with one degree of freedom, `P(X > t)`.
///
/// With `X = Z²` and `Z` standard normal, `P(X > t) = P(|Z| > √t) = 2 Φ(-√t)`.
/// Evaluating the lower normal tail directly keeps full relative precision for large `t`.
/// The algebraically equivalent `2 (1 - Φ(√t))` cancels catastrophically. It collapses to
/// exactly 0 once `Φ(√t)` rounds to 1 (around `t ≈ 70`), which would report p-values of 0
/// for strong but finite evidence.
///
/// Non-positive `t` returns 1. NaN propagates.
fn chi2_1_sf(t: f64) -> f64 {
    if t <= 0.0 {
        return 1.0;
    }
    2.0 * std_normal_cdf(-t.sqrt())
}
/// Quantile of the χ²₁ distribution at probability `level`.
///
/// χ²₁ is the square of a standard normal, so its `level` quantile is `z²`, where
/// `z = Φ⁻¹((1 + level) / 2)` is the two-sided normal critical value.
fn chi2_1_quantile(level: f64) -> f64 {
    let p = (1.0 + level) / 2.0;
    let z = inv_std_normal(p);
    z * z
}
/// Sign classification of the curvature implied by a profile confidence interval
/// (assigned at the end of `profile_ci_walk`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CurvatureVerdict {
/// The whole interval lies above zero (`ci_lo > 0`).
Spherical,
/// The whole interval lies below zero (`ci_hi < 0`).
Hyperbolic,
/// Neither of the above: the interval reaches zero, so flatness is not excluded.
Flat,
}
/// Result of `profile_ci_walk`: a profile-likelihood confidence interval for `kappa`.
#[derive(Clone, Copy, Debug)]
pub struct KappaProfileCi {
/// The starting point of the walk, echoed back from the caller.
pub kappa_hat: f64,
/// Lower endpoint. Equals `kappa_min` when `lo_at_bound` is set.
pub ci_lo: f64,
/// Upper endpoint. Equals `kappa_max` when `hi_at_bound` is set.
pub ci_hi: f64,
/// True when the lower walk reached `kappa_min` (to within `tol`) without the profile
/// rising past the threshold, i.e. the true endpoint may lie beyond the bound.
pub lo_at_bound: bool,
/// Same as `lo_at_bound`, for the upper walk towards `kappa_max`.
pub hi_at_bound: bool,
/// Sign classification derived from `ci_lo` / `ci_hi`.
pub verdict: CurvatureVerdict,
}
/// Result of `flatness_lr_test`: likelihood-ratio test of `kappa = 0`.
#[derive(Clone, Copy, Debug)]
pub struct FlatnessTest {
/// `2 (V_p(0) - V_p(kappa_hat))`, clamped below at 0.
pub lr_stat: f64,
/// Survival probability of `lr_stat` under a χ² distribution with one degree of freedom.
pub p_value: f64,
/// The `kappa_hat` the statistic was computed against, echoed back from the caller.
pub kappa_hat: f64,
}
/// Wald (quadratic-approximation) half-width of a two-sided `level` confidence interval.
///
/// `v_pp` is the curvature (second derivative) of the profiled objective at its minimiser.
/// The half-width is `z / sqrt(v_pp)` with `z = Φ⁻¹((1 + level) / 2)`.
///
/// Returns `None` when:
/// - `v_pp` is not finite and strictly positive, or
/// - `level` is not in the open interval `(0, 1)`. Such levels would produce a zero,
///   negative or non-finite "width", assuming `inv_std_normal` is the standard normal quantile.
pub fn wald_half_width(v_pp: f64, level: f64) -> Option<f64> {
    // Same acceptance rule for `level` as `profile_ci_walk` (NaN is rejected too).
    if !(level > 0.0 && level < 1.0) || !v_pp.is_finite() || v_pp <= 0.0 {
        return None;
    }
    let z = inv_std_normal(0.5 * (1.0 + level));
    Some(z / v_pp.sqrt())
}
/// Profile-likelihood confidence interval for the curvature `kappa`.
///
/// `v_p(kappa)` is treated as a negative profile log-likelihood. The interval is the set of
/// `kappa` in `[kappa_min, kappa_max]` whose rise `v_p(kappa) - v_p(kappa_hat)` stays below
/// `chi2_1_quantile(level) / 2`. Each endpoint is found by walking outward from `kappa_hat`
/// with doubling steps, then bisecting the first bracket that crosses the threshold. A side
/// that reaches its bound without crossing reports the bound and sets its `*_at_bound` flag.
///
/// `v_pp` (curvature of `v_p` at `kappa_hat`) is used only to size the first step, via
/// [`wald_half_width`]. Without a usable width the first step is
/// `0.1 * max(kappa_max - kappa_min, tol)`. A non-positive (or NaN) `tol` is replaced by `1e-6`.
///
/// `verdict` is `Spherical` when `ci_lo > 0`, `Hyperbolic` when `ci_hi < 0`, otherwise `Flat`.
///
/// # Errors
/// Returns `Err` in any of these cases:
/// - `level` is outside `(0, 1)`;
/// - `kappa_min >= kappa_max` (or either bound is NaN);
/// - `kappa_hat` is non-finite or outside the bounds;
/// - `v_p` returns an error or any non-finite value.
pub fn profile_ci_walk<F>(
    mut v_p: F,
    kappa_hat: f64,
    v_pp: f64,
    kappa_min: f64,
    kappa_max: f64,
    level: f64,
    tol: f64,
) -> Result<KappaProfileCi, String>
where
    F: FnMut(f64) -> Result<f64, String>,
{
    let level_ok = level > 0.0 && level < 1.0;
    if !level_ok {
        return Err("profile CI level must lie in (0, 1)".into());
    }
    let bounds_ok = kappa_min < kappa_max;
    if !bounds_ok {
        return Err("kappa bounds must satisfy kappa_min < kappa_max".into());
    }
    let hat_ok = kappa_hat.is_finite() && (kappa_min..=kappa_max).contains(&kappa_hat);
    if !hat_ok {
        return Err("kappa_hat must be finite and inside [kappa_min, kappa_max]".into());
    }

    let tol = Some(tol).filter(|t| *t > 0.0).unwrap_or(1e-6);
    let half_thresh = 0.5 * chi2_1_quantile(level);

    // Reference value at the minimiser. Evaluated before either walk because `v_p` is `FnMut`.
    let v_hat = v_p(kappa_hat)?;
    if !v_hat.is_finite() {
        return Err("V_p(kappa_hat) is non-finite".into());
    }

    let init_step = match wald_half_width(v_pp, level) {
        Some(h) if h.is_finite() && h > 0.0 => h,
        _ => 0.1 * (kappa_max - kappa_min).max(tol),
    };
    let cfg = WalkCfg {
        kappa_hat,
        init_step,
        half_thresh,
        tol,
    };
    let rise = |v: f64| v - v_hat;

    let (ci_lo, lo_at_bound) = walk_one_side(&mut v_p, &cfg, -1.0, kappa_min, &rise)?;
    let (ci_hi, hi_at_bound) = walk_one_side(&mut v_p, &cfg, 1.0, kappa_max, &rise)?;

    let verdict = match (ci_lo > 0.0, ci_hi < 0.0) {
        (true, _) => CurvatureVerdict::Spherical,
        (false, true) => CurvatureVerdict::Hyperbolic,
        (false, false) => CurvatureVerdict::Flat,
    };

    Ok(KappaProfileCi {
        kappa_hat,
        ci_lo,
        ci_hi,
        lo_at_bound,
        hi_at_bound,
        verdict,
    })
}
/// Settings shared by the lower and upper halves of the profile CI walk.
struct WalkCfg {
    /// Point the walk starts from (the profile minimiser).
    kappa_hat: f64,
    /// Initial outward probe distance. Doubled after every probe that stays inside the CI.
    init_step: f64,
    /// Threshold on `drop(V_p(kappa))` that marks the CI boundary.
    half_thresh: f64,
    /// Absolute tolerance on the located endpoint. Also the minimum probe step.
    tol: f64,
}

/// Walk from `cfg.kappa_hat` towards `bound` (direction `sign`, ±1) and locate the CI edge.
///
/// Probes at geometrically growing distances until `drop(v_p(kappa)) >= cfg.half_thresh`.
/// It then bisects between the last probe below the threshold and the first one at or above
/// it. This assumes the drop is non-decreasing moving away from `kappa_hat` on this side.
///
/// Returns `(endpoint, false)` when a crossing is found. Returns `(bound, true)` when `bound`
/// is reached (to within `cfg.tol`) without crossing, or when `kappa_hat` already sits within
/// `cfg.tol` of `bound`.
///
/// # Errors
/// Propagates errors from `v_p`. Reports a non-finite `v_p` value as an error.
fn walk_one_side<F, D>(
    v_p: &mut F,
    cfg: &WalkCfg,
    sign: f64,
    bound: f64,
    drop: &D,
) -> Result<(f64, bool), String>
where
    F: FnMut(f64) -> Result<f64, String>,
    D: Fn(f64) -> f64,
{
    let WalkCfg {
        kappa_hat,
        init_step,
        half_thresh,
        tol,
    } = *cfg;
    let span = (bound - kappa_hat) * sign;
    if span <= tol {
        return Ok((bound, true));
    }
    let mut step = init_step.max(tol);
    let mut probe = step.min(span);
    // Last abscissa known to lie strictly below the threshold.
    let mut inside = kappa_hat;
    loop {
        // Once the probe is clamped to the full span, evaluate at `bound` itself. Rounding in
        // `kappa_hat + sign * span` could otherwise land just outside the admissible range.
        let kappa = if probe >= span {
            bound
        } else {
            kappa_hat + sign * probe
        };
        let v = v_p(kappa)?;
        if !v.is_finite() {
            return Err("V_p returned a non-finite value during the CI walk".into());
        }
        if drop(v) >= half_thresh {
            let edge = bisect_crossing(v_p, drop, half_thresh, tol, inside, kappa)?;
            return Ok((edge, false));
        }
        inside = kappa;
        if (probe - span).abs() <= tol {
            return Ok((bound, true));
        }
        step *= 2.0;
        probe = (probe + step).min(span);
    }
}

/// Bisect the bracket `[inside, outside]` for the point where `drop(v_p(kappa))` reaches
/// `half_thresh`.
///
/// `inside` must be below the threshold and `outside` at or above it (either order on the
/// real line). Stops when the bracket is no wider than `tol`, or when it can no longer shrink
/// in floating point. Returns the midpoint of the final bracket.
fn bisect_crossing<F, D>(
    v_p: &mut F,
    drop: &D,
    half_thresh: f64,
    tol: f64,
    inside: f64,
    outside: f64,
) -> Result<f64, String>
where
    F: FnMut(f64) -> Result<f64, String>,
    D: Fn(f64) -> f64,
{
    let (mut a, mut b) = (inside, outside);
    while (b - a).abs() > tol {
        let m = 0.5 * (a + b);
        // When `a` and `b` are adjacent floats, the midpoint rounds onto one of them and the
        // bracket stops shrinking. Without this exit, a `tol` finer than the float spacing
        // at |kappa| (e.g. tol = 1e-9 near kappa = 1e8) would loop forever.
        if m == a || m == b {
            break;
        }
        let vm = v_p(m)?;
        if !vm.is_finite() {
            return Err("V_p returned a non-finite value during bisection".into());
        }
        if drop(vm) >= half_thresh {
            b = m;
        } else {
            a = m;
        }
    }
    Ok(0.5 * (a + b))
}
/// Likelihood-ratio test of flatness (`kappa = 0`) against the fitted `kappa_hat`.
///
/// `v_p` is treated as a negative profile log-likelihood. The statistic is
/// `2 (v_p(0) - v_p(kappa_hat))`, clamped below at 0 so a slightly sub-optimal `kappa_hat`
/// cannot produce a negative value. The p-value is the plain χ²₁ survival probability of
/// that statistic.
///
/// # Errors
/// Propagates errors from `v_p`. Rejects non-finite profile values.
pub fn flatness_lr_test<F>(mut v_p: F, kappa_hat: f64) -> Result<FlatnessTest, String>
where
    F: FnMut(f64) -> Result<f64, String>,
{
    // `v_p` is `FnMut`, so keep the call order: the fitted value first, then kappa = 0.
    let at_hat = v_p(kappa_hat)?;
    let at_flat = v_p(0.0)?;
    let both_finite = at_hat.is_finite() && at_flat.is_finite();
    if !both_finite {
        return Err("V_p evaluated to a non-finite value in the flatness test".into());
    }
    let deviance = 2.0 * (at_flat - at_hat);
    let lr_stat = deviance.max(0.0);
    Ok(FlatnessTest {
        lr_stat,
        p_value: chi2_1_sf(lr_stat),
        kappa_hat,
    })
}
/// Coordinates of a point together with their first and second derivatives with respect to
/// the curvature `kappa`, as returned by `design_coord_kappa_derivative`.
#[derive(Clone, Debug)]
pub struct DesignCoordKappaJet {
/// Coordinate vector (the value component of `log_map_kappa_jet`).
pub coord: Array1<f64>,
/// Component-wise first derivative d coord / d kappa.
pub d_kappa: Array1<f64>,
/// Component-wise second derivative d² coord / d kappa².
pub d_kappa2: Array1<f64>,
}
/// Coordinates of `point` relative to `base` on `manifold`, with their first and second
/// derivatives in the curvature `kappa`.
///
/// Thin wrapper that packages the `(value, d/dkappa, d²/dkappa²)` tuple from
/// `log_map_kappa_jet` (presumably the log map of `point` at `base`) into a
/// [`DesignCoordKappaJet`].
///
/// # Errors
/// Propagates any error from `log_map_kappa_jet`.
pub fn design_coord_kappa_derivative(
    manifold: &ConstantCurvature,
    base: ArrayView1<'_, f64>,
    point: ArrayView1<'_, f64>,
) -> GeometryResult<DesignCoordKappaJet> {
    let (value, first, second) = log_map_kappa_jet(manifold, base, point)?;
    let jet = DesignCoordKappaJet {
        coord: value,
        d_kappa: first,
        d_kappa2: second,
    };
    Ok(jet)
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::array;
// Synthetic profile V(k) = v0 + a/2 (k - k_star)^2: minimised at k_star with curvature a.
// The exact profile CI is k_star ± sqrt(chi2_1_quantile(level) / a), and the flatness
// LR statistic at k_star is a * k_star^2.
fn quad(v0: f64, a: f64, k_star: f64) -> impl Fn(f64) -> Result<f64, String> {
move |k: f64| Ok(v0 + 0.5 * a * (k - k_star) * (k - k_star))
}
// Wald half-width equals z / sqrt(v_pp); non-positive curvature yields None.
#[test]
fn wald_half_width_matches_closed_form() {
let level = 0.95;
let a = 3.0;
let h = wald_half_width(a, level).expect("positive curvature");
let z = inv_std_normal(0.5 * (1.0 + level));
assert!((h - z / a.sqrt()).abs() < 1e-12);
assert!(wald_half_width(0.0, level).is_none());
assert!(wald_half_width(-1.0, level).is_none());
}
// On an exact quadratic the walk must land on the analytic endpoints without touching the
// ±10 bounds. The verdict must follow the same sign rule profile_ci_walk applies.
#[test]
fn profile_ci_walk_recovers_quadratic_crossing() {
let level = 0.95;
let a = 2.5;
let k_star = -0.7;
let f = quad(0.3, a, k_star);
let ci = profile_ci_walk(
|k| f(k),
k_star,
a, -10.0,
10.0,
level,
1e-9,
)
.expect("CI walk");
let chi2 = chi2_1_quantile(level);
let half = (chi2 / a).sqrt();
assert!((ci.ci_lo - (k_star - half)).abs() < 1e-6, "lo {}", ci.ci_lo);
assert!((ci.ci_hi - (k_star + half)).abs() < 1e-6, "hi {}", ci.ci_hi);
assert!(!ci.lo_at_bound && !ci.hi_at_bound);
let expected = if ci.ci_lo > 0.0 {
CurvatureVerdict::Spherical
} else if ci.ci_hi < 0.0 {
CurvatureVerdict::Hyperbolic
} else {
CurvatureVerdict::Flat
};
assert_eq!(ci.verdict, expected);
}
// A sharply curved profile centred at -2 gives an interval entirely below 0, so Hyperbolic.
#[test]
fn profile_ci_walk_verdict_hyperbolic_when_far_negative() {
let level = 0.95;
let a = 50.0; let k_star = -2.0;
let f = quad(0.0, a, k_star);
let ci = profile_ci_walk(|k| f(k), k_star, a, -10.0, 10.0, level, 1e-9).unwrap();
assert!(ci.ci_hi < 0.0, "ci_hi {}", ci.ci_hi);
assert_eq!(ci.verdict, CurvatureVerdict::Hyperbolic);
}
// With negligible curvature the threshold is never reached inside ±0.01. Both sides must
// report their bounds exactly, and the interval straddles 0, so the verdict is Flat.
#[test]
fn profile_ci_walk_flags_bound_when_profile_too_flat() {
let level = 0.95;
let a = 1e-6;
let k_star = 0.0;
let f = quad(0.0, a, k_star);
let ci = profile_ci_walk(|k| f(k), k_star, a, -0.01, 0.01, level, 1e-9).unwrap();
assert!(ci.lo_at_bound && ci.hi_at_bound);
assert!((ci.ci_lo + 0.01).abs() < 1e-12 && (ci.ci_hi - 0.01).abs() < 1e-12);
assert_eq!(ci.verdict, CurvatureVerdict::Flat);
}
// When the fitted kappa is already 0 the LR statistic is 0 and the p-value is 1.
#[test]
fn flatness_test_zero_when_minimiser_is_flat() {
let f = quad(1.0, 4.0, 0.0);
let t = flatness_lr_test(|k| f(k), 0.0).unwrap();
assert!(t.lr_stat.abs() < 1e-12);
assert!((t.p_value - 1.0).abs() < 1e-12);
}
// LR = a * k_star^2. The p-value uses the plain chi2_1 survival function, not the halved
// (50:50 boundary-mixture) value, which the last assertion rules out.
#[test]
fn flatness_test_lr_and_pvalue_match_chi2_1() {
let a = 3.0;
let k_star = 0.8;
let f = quad(0.5, a, k_star);
let t = flatness_lr_test(|k| f(k), k_star).unwrap();
let expected_lr = a * k_star * k_star;
assert!((t.lr_stat - expected_lr).abs() < 1e-10, "lr {}", t.lr_stat);
let expected_p = chi2_1_sf(expected_lr);
assert!((t.p_value - expected_p).abs() < 1e-12);
let half_chi2_p = 0.5 * expected_p;
assert!((t.p_value - half_chi2_p).abs() > 1e-6);
}
// chi2_1 quantile / survival round trip at 0.95 and 0.99, against the reference 95% quantile.
#[test]
fn chi2_1_sf_matches_known_quantiles() {
let q = chi2_1_quantile(0.95);
assert!((q - 3.841_458_820_694_124).abs() < 1e-6, "q {}", q);
assert!((chi2_1_sf(q) - 0.05).abs() < 1e-9);
assert!((chi2_1_sf(chi2_1_quantile(0.99)) - 0.01).abs() < 1e-9);
}
// The wrapper must reproduce log_map_kappa_jet exactly. The jet must also agree with
// central finite differences in kappa (first and second derivatives).
#[test]
fn design_coord_kappa_derivative_matches_jet_and_fd() {
let dim = 3;
let kappa = 0.6;
let manifold = ConstantCurvature::new(dim, kappa);
let base = array![0.05, -0.1, 0.07];
let point = array![0.2, 0.15, -0.05];
let jet = design_coord_kappa_derivative(&manifold, base.view(), point.view()).unwrap();
let (val, dk, dkk) = log_map_kappa_jet(&manifold, base.view(), point.view()).unwrap();
for i in 0..dim {
assert!((jet.coord[i] - val[i]).abs() < 1e-14);
assert!((jet.d_kappa[i] - dk[i]).abs() < 1e-14);
assert!((jet.d_kappa2[i] - dkk[i]).abs() < 1e-14);
}
let h = 1e-5;
let coord_at = |k: f64| -> Array1<f64> {
let m = ConstantCurvature::new(dim, k);
log_map_kappa_jet(&m, base.view(), point.view()).unwrap().0
};
let cp = coord_at(kappa + h);
let cm = coord_at(kappa - h);
let c0 = jet.coord.clone();
for i in 0..dim {
let fd1 = (cp[i] - cm[i]) / (2.0 * h);
let fd2 = (cp[i] - 2.0 * c0[i] + cm[i]) / (h * h);
assert!((jet.d_kappa[i] - fd1).abs() < 1e-6, "d_kappa[{i}] vs FD");
assert!((jet.d_kappa2[i] - fd2).abs() < 1e-4, "d_kappa2[{i}] vs FD");
}
}
// Near-flat case. With kappa = 1e-6 and h = 1e-4 the FD stencil straddles kappa = 0 (both
// signs of curvature), so this checks d_kappa against an estimate taken across the flat point.
#[test]
fn design_coord_kappa_derivative_fd_through_flat() {
let dim = 2;
let kappa = 1e-6;
let manifold = ConstantCurvature::new(dim, kappa);
let base = array![0.1, -0.2];
let point = array![0.25, 0.05];
let jet = design_coord_kappa_derivative(&manifold, base.view(), point.view()).unwrap();
let h = 1e-4;
let coord_at = |k: f64| -> Array1<f64> {
let m = ConstantCurvature::new(dim, k);
log_map_kappa_jet(&m, base.view(), point.view()).unwrap().0
};
let cp = coord_at(kappa + h);
let cm = coord_at(kappa - h);
for i in 0..dim {
let fd1 = (cp[i] - cm[i]) / (2.0 * h);
assert!((jet.d_kappa[i] - fd1).abs() < 1e-5, "flat d_kappa[{i}]");
}
}
}