pathwise_geo/scheme/sri.rs
1use crate::sde::ManifoldSDE;
2use crate::simulate::GeoScheme;
3use cartan_core::{Manifold, ParallelTransport};
4use pathwise_core::state::Increment;
5
/// Geodesic SRI: an SRI-style scheme for manifold SDEs, aiming at strong order 1.5.
///
/// Extends GeodesicMilstein with a dZ iterated-integral correction term.
///
/// Full step:
///
/// ```text
/// v     = f(x)*dt + g(x)*dW + 0.5*nabla_g(g)*(dW^2 - dt) + nabla_g(g)*dZ
/// x_new = exp_x(v)
/// ```
///
/// where nabla_g(g) is approximated by finite-difference parallel transport:
///
/// ```text
/// nabla_g g(x) ≈ (1/eps) * [PT_{y->x}(g(y)) - g(x)],   y = exp_x(eps*g(x))
/// ```
///
/// # Single-FD approximation
///
/// One covariant-derivative estimate, nabla_g g, is reused for both the Milstein
/// correction and the dZ term. This is a simplification.
///
/// In the scalar Itô–Taylor order-1.5 scheme, the dZ coefficient is built from nabla_g f
/// and L^0 g (the generator applied to g). Additional dW*dt, dt^2 and triple-integral
/// terms also appear. The triple-integral term needs nabla_g(nabla_g g), i.e. a second
/// PT-based finite difference, which full SRI1 would require.
///
/// The dZ term is itself O(dt^{3/2}), so the simplified term stays at the size of the
/// Milstein remainder. However, strong order 1.5 should not be assumed in general;
/// verify convergence empirically for the SDE at hand.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GeodesicSRI {
    /// Finite-difference step size for the covariant-derivative approximation.
    ///
    /// The estimate divides by this value, so it should be finite and non-zero.
    pub eps: f64,
}
26
27impl GeodesicSRI {
28 /// Create with default eps = 1e-4.
29 pub fn new() -> Self {
30 Self { eps: 1e-4 }
31 }
32
33 /// Advance x by one SRI step on the manifold.
34 ///
35 /// Computes the Milstein and dZ corrections via finite-difference parallel transport:
36 /// 1. Walk eps along g(x) to get y = exp_x(eps * g(x)).
37 /// 2. Evaluate g at y.
38 /// 3. Transport g(y) back from y to x via ParallelTransport.
39 /// 4. Approx covariant deriv: nabla_g g ≈ (PT(g(y)) - g(x)) / eps.
40 /// 5. Milstein correction: 0.5 * nabla_g(g) * (dW^2 - dt).
41 /// 6. SRI correction: nabla_g(g) * dZ.
42 /// 7. Apply exp to the full tangent displacement.
43 ///
44 /// If transport fails (cut locus), falls back to Euler step.
45 pub fn step<M, D, G>(
46 &self,
47 sde: &ManifoldSDE<M, D, G>,
48 x: &M::Point,
49 t: f64,
50 dt: f64,
51 inc: &Increment<f64>,
52 ) -> M::Point
53 where
54 M: Manifold + ParallelTransport,
55 D: Fn(&M::Point, f64) -> M::Tangent + Send + Sync,
56 G: Fn(&M::Point, f64) -> M::Tangent + Send + Sync,
57 M::Tangent: std::ops::Add<Output = M::Tangent>
58 + std::ops::Mul<f64, Output = M::Tangent>
59 + std::ops::Sub<Output = M::Tangent>
60 + Clone,
61 {
62 let dw = inc.dw;
63 let dz = inc.dz;
64 let f = (sde.drift)(x, t);
65 let g = (sde.diffusion)(x, t);
66 let eps = self.eps;
67
68 // Walk eps along g(x) to get perturbed point y.
69 let eps_g = g.clone() * eps;
70 let y = sde.manifold.exp(x, &eps_g);
71
72 // Evaluate diffusion at y.
73 let g_at_y = (sde.diffusion)(&y, t);
74
75 // Transport g(y) back from y to T_x(M). Falls back to Euler if it fails.
76 let tangent = match sde.manifold.transport(&y, x, &g_at_y) {
77 Ok(g_transported) => {
78 // Finite-difference covariant derivative.
79 let nabla_g_g = (g_transported - g.clone()) * (1.0 / eps);
80 // Milstein correction: 0.5 * nabla_g(g) * (dW^2 - dt)
81 let milstein_correction = nabla_g_g.clone() * (0.5 * (dw * dw - dt));
82 // SRI dZ correction: nabla_g(g) * dZ
83 let sri_correction = nabla_g_g * dz;
84 f * dt + g * dw + milstein_correction + sri_correction
85 }
86 Err(_) => {
87 // Degenerate geometry (cut locus): fall back to Euler step.
88 f * dt + g * dw
89 }
90 };
91
92 sde.manifold.exp(x, &tangent)
93 }
94}
95
96impl Default for GeodesicSRI {
97 fn default() -> Self {
98 Self::new()
99 }
100}
101
102impl<M, D, G> GeoScheme<M, D, G> for GeodesicSRI
103where
104 M: Manifold + ParallelTransport,
105 D: Fn(&M::Point, f64) -> M::Tangent + Send + Sync,
106 G: Fn(&M::Point, f64) -> M::Tangent + Send + Sync,
107 M::Tangent: std::ops::Add<Output = M::Tangent>
108 + std::ops::Mul<f64, Output = M::Tangent>
109 + std::ops::Sub<Output = M::Tangent>
110 + Clone,
111{
112 fn step_geo(
113 &self,
114 sde: &ManifoldSDE<M, D, G>,
115 x: &M::Point,
116 t: f64,
117 dt: f64,
118 inc: &Increment<f64>,
119 ) -> M::Point {
120 self.step(sde, x, t, dt, inc)
121 }
122}