// pathwise_geo/scheme/sri.rs

1use crate::sde::ManifoldSDE;
2use crate::simulate::GeoScheme;
3use cartan_core::{Manifold, ParallelTransport};
4use pathwise_core::state::Increment;
5
6/// Geodesic SRI: strong order 1.5 approximation for manifold SDEs.
7///
8/// Extends GeodesicMilstein with the dZ iterated-integral correction term.
9///
10/// Full step:
11///   v = f(x)*dt + g(x)*dW + 0.5*nabla_g(g)*(dW^2 - dt) + nabla_g(g)*dZ
12///   x_new = exp_x(v)
13///
14/// where nabla_g(g) is approximated by finite-difference parallel transport:
15///   nabla_g g(x) ≈ (1/eps) * [PT_{y->x}(g(y)) - g(x)],  y = exp_x(eps*g(x))
16///
17/// # Single-FD note
18///
19/// This uses nabla_g g for both the Milstein correction and the dZ term.
20/// Full SRI1 would require a second PT-based FD for nabla_g(nabla_g g).
21/// This approximation is O(dt^{3/2}) accurate for smooth diffusion fields.
22pub struct GeodesicSRI {
23    /// Finite-difference step size for covariant derivative approximation.
24    pub eps: f64,
25}
26
27impl GeodesicSRI {
28    /// Create with default eps = 1e-4.
29    pub fn new() -> Self {
30        Self { eps: 1e-4 }
31    }
32
33    /// Advance x by one SRI step on the manifold.
34    ///
35    /// Computes the Milstein and dZ corrections via finite-difference parallel transport:
36    ///   1. Walk eps along g(x) to get y = exp_x(eps * g(x)).
37    ///   2. Evaluate g at y.
38    ///   3. Transport g(y) back from y to x via ParallelTransport.
39    ///   4. Approx covariant deriv: nabla_g g ≈ (PT(g(y)) - g(x)) / eps.
40    ///   5. Milstein correction: 0.5 * nabla_g(g) * (dW^2 - dt).
41    ///   6. SRI correction: nabla_g(g) * dZ.
42    ///   7. Apply exp to the full tangent displacement.
43    ///
44    /// If transport fails (cut locus), falls back to Euler step.
45    pub fn step<M, D, G>(
46        &self,
47        sde: &ManifoldSDE<M, D, G>,
48        x: &M::Point,
49        t: f64,
50        dt: f64,
51        inc: &Increment<f64>,
52    ) -> M::Point
53    where
54        M: Manifold + ParallelTransport,
55        D: Fn(&M::Point, f64) -> M::Tangent + Send + Sync,
56        G: Fn(&M::Point, f64) -> M::Tangent + Send + Sync,
57        M::Tangent: std::ops::Add<Output = M::Tangent>
58            + std::ops::Mul<f64, Output = M::Tangent>
59            + std::ops::Sub<Output = M::Tangent>
60            + Clone,
61    {
62        let dw = inc.dw;
63        let dz = inc.dz;
64        let f = (sde.drift)(x, t);
65        let g = (sde.diffusion)(x, t);
66        let eps = self.eps;
67
68        // Walk eps along g(x) to get perturbed point y.
69        let eps_g = g.clone() * eps;
70        let y = sde.manifold.exp(x, &eps_g);
71
72        // Evaluate diffusion at y.
73        let g_at_y = (sde.diffusion)(&y, t);
74
75        // Transport g(y) back from y to T_x(M). Falls back to Euler if it fails.
76        let tangent = match sde.manifold.transport(&y, x, &g_at_y) {
77            Ok(g_transported) => {
78                // Finite-difference covariant derivative.
79                let nabla_g_g = (g_transported - g.clone()) * (1.0 / eps);
80                // Milstein correction: 0.5 * nabla_g(g) * (dW^2 - dt)
81                let milstein_correction = nabla_g_g.clone() * (0.5 * (dw * dw - dt));
82                // SRI dZ correction: nabla_g(g) * dZ
83                let sri_correction = nabla_g_g * dz;
84                f * dt + g * dw + milstein_correction + sri_correction
85            }
86            Err(_) => {
87                // Degenerate geometry (cut locus): fall back to Euler step.
88                f * dt + g * dw
89            }
90        };
91
92        sde.manifold.exp(x, &tangent)
93    }
94}
95
96impl Default for GeodesicSRI {
97    fn default() -> Self {
98        Self::new()
99    }
100}
101
102impl<M, D, G> GeoScheme<M, D, G> for GeodesicSRI
103where
104    M: Manifold + ParallelTransport,
105    D: Fn(&M::Point, f64) -> M::Tangent + Send + Sync,
106    G: Fn(&M::Point, f64) -> M::Tangent + Send + Sync,
107    M::Tangent: std::ops::Add<Output = M::Tangent>
108        + std::ops::Mul<f64, Output = M::Tangent>
109        + std::ops::Sub<Output = M::Tangent>
110        + Clone,
111{
112    fn step_geo(
113        &self,
114        sde: &ManifoldSDE<M, D, G>,
115        x: &M::Point,
116        t: f64,
117        dt: f64,
118        inc: &Increment<f64>,
119    ) -> M::Point {
120        self.step(sde, x, t, dt, inc)
121    }
122}