use ndarray::{Array1, ArrayView1};
use skel::Manifold;
use crate::ode::OdeMethod;
pub fn integrate_fixed_manifold<M>(
method: OdeMethod,
manifold: &M,
x0: &Array1<f64>,
t0: f64,
dt: f64,
steps: usize,
mut f: impl FnMut(&ArrayView1<f64>, f64) -> Array1<f64>,
) -> crate::Result<Array1<f64>>
where
M: Manifold,
{
if steps < 1 {
return Err(crate::Error::Domain("steps must be >= 1"));
}
if !dt.is_finite() {
return Err(crate::Error::Domain("dt must be finite"));
}
let mut x = x0.clone();
let mut t = t0;
match method {
OdeMethod::Euler => {
for _ in 0..steps {
let v = f(&x.view(), t);
let step = v.mapv(|u| u * dt);
x = manifold.exp_map(&x.view(), &step.view());
x = manifold.project(&x.view());
t += dt;
}
}
OdeMethod::Heun => {
for _ in 0..steps {
let v0 = f(&x.view(), t);
let step0 = v0.mapv(|u| u * dt);
let x_pred = manifold.exp_map(&x.view(), &step0.view());
let x_pred = manifold.project(&x_pred.view());
let v1 = f(&x_pred.view(), t + dt);
let v1_at_x = manifold.parallel_transport(&x_pred.view(), &x.view(), &v1.view());
let v_avg = (&v0 + &v1_at_x).mapv(|u| 0.5 * u);
let step = v_avg.mapv(|u| u * dt);
x = manifold.exp_map(&x.view(), &step.view());
x = manifold.project(&x.view());
t += dt;
}
}
}
Ok(x)
}
#[cfg(all(test, feature = "riemannian"))]
mod tests {
    // Accuracy checks for `integrate_fixed_manifold` on the Poincaré ball. The vector field
    // used everywhere parallel-transports a fixed `v0` from `x0` to the current point, so the
    // exact solution at t = 1 is `exp_map(x0, v0)`.
    use super::*;
    use hyperball::PoincareBall;
    use proptest::prelude::*;

    /// Random 2-D point, rescaled so its Euclidean norm never exceeds 0.75.
    fn poincare_point() -> impl Strategy<Value = Array1<f64>> {
        prop::collection::vec(-0.6f64..0.6f64, 2).prop_map(|coords| {
            let p = Array1::from_vec(coords);
            let radius = p.dot(&p).sqrt();
            match radius > 0.75 {
                true => p * (0.75 / radius),
                false => p,
            }
        })
    }

    /// Random 2-D vector with each coordinate drawn from [-0.2, 0.2).
    fn small_vec2() -> impl Strategy<Value = Array1<f64>> {
        prop::collection::vec(-0.2f64..0.2f64, 2).prop_map(Array1::from_vec)
    }

    /// Integrates the transported-`v0` field from t = 0 to t = 1 using `steps` equal steps.
    fn integrate_geodesic(
        method: OdeMethod,
        ball: &PoincareBall<f64>,
        x0: &Array1<f64>,
        v0: &Array1<f64>,
        steps: usize,
    ) -> Array1<f64> {
        let dt = 1.0f64 / (steps as f64);
        integrate_fixed_manifold(method, ball, x0, 0.0, dt, steps, |x, _t| {
            ball.parallel_transport(&x0.view(), x, &v0.view())
        })
        .unwrap()
    }

    /// Euclidean distance between `a` and `b`.
    fn euclid_dist(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
        let diff = a - b;
        diff.dot(&diff).sqrt()
    }

    #[test]
    fn heun_tracks_geodesic_better_than_euler_smoke() {
        let ball = PoincareBall::<f64>::new(1.0);
        let x0 = Array1::from_vec(vec![0.05, -0.02]);
        let v0 = Array1::from_vec(vec![0.12, 0.04]);
        let target = ball.exp_map(&x0.view(), &v0.view());

        let euler = integrate_geodesic(OdeMethod::Euler, &ball, &x0, &v0, 64);
        let heun = integrate_geodesic(OdeMethod::Heun, &ball, &x0, &v0, 64);

        let err_e = euclid_dist(&euler, &target);
        let err_h = euclid_dist(&heun, &target);
        assert!(
            err_h <= err_e + 1e-6,
            "expected Heun <= Euler: err_heun={err_h} err_euler={err_e}"
        );
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(96))]
        #[test]
        fn prop_error_decreases_with_more_steps_on_geodesic_field(
            x0 in poincare_point(),
            v0 in small_vec2(),
            steps in 10usize..80,
        ) {
            let ball = PoincareBall::<f64>::new(1.0);
            let target = ball.exp_map(&x0.view(), &v0.view());
            let run = |method: OdeMethod, n: usize| integrate_geodesic(method, &ball, &x0, &v0, n);

            let coarse_euler = run(OdeMethod::Euler, steps);
            let fine_euler = run(OdeMethod::Euler, 2 * steps);
            let coarse_heun = run(OdeMethod::Heun, steps);
            let fine_heun = run(OdeMethod::Heun, 2 * steps);

            let err_e1 = euclid_dist(&coarse_euler, &target);
            let err_e2 = euclid_dist(&fine_euler, &target);
            let err_h1 = euclid_dist(&coarse_heun, &target);
            let err_h2 = euclid_dist(&fine_heun, &target);

            prop_assert!(err_h2 <= err_h1 + 2e-6, "heun error did not decrease: {err_h1} -> {err_h2}");
            prop_assert!(err_e2 <= err_e1 + 2e-5, "euler got much worse: {err_e1} -> {err_e2}");
        }
    }
}