pathwise_geo/scheme/milstein.rs
1use crate::sde::ManifoldSDE;
2use cartan_core::{Manifold, ParallelTransport};
3use pathwise_core::state::Increment;
4
/// Geodesic Milstein: strong order 1.0 scheme for manifold-valued SDEs with
/// scalar noise, using a finite-difference covariant derivative.
///
/// The Milstein correction needs `nabla_g(g)`, which is approximated by a
/// one-sided finite difference along the diffusion direction:
///
/// ```text
/// nabla_g g(x) ≈ (1/eps) * [PT_{y->x}(g(y)) - g(x)],   y = exp_x(eps * g(x))
/// ```
///
/// Full step:
///
/// ```text
/// v     = f(x)*dt + g(x)*dW + 0.5 * nabla_g(g) * (dW^2 - dt)
/// x_new = exp_x(v)
/// ```
///
/// Requires `M: ParallelTransport`, because `g(y)` lives in `T_y(M)` and must be
/// transported back to `T_x(M)` along the geodesic from `y` to `x` before the
/// difference with `g(x)` is meaningful.
///
/// Accuracy note: the one-sided difference carries an `O(eps)` truncation
/// error in the correction term. The nominal order is therefore only observed
/// when `eps` is small enough for that error to be negligible. `eps` must be
/// finite and non-zero, because the difference is divided by it.
///
/// References:
/// - Milstein (1974), Platen & Wagner (1982) for the scalar correction.
/// - Said & Manton (2012) for geodesic extension to Lie groups.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GeodesicMilstein {
    /// Finite-difference step size for the covariant derivative approximation.
    pub eps: f64,
}
25
26impl GeodesicMilstein {
27 /// Create with default eps = 1e-4.
28 pub fn new() -> Self {
29 Self { eps: 1e-4 }
30 }
31
32 /// Advance x by one Milstein step on the manifold.
33 ///
34 /// Computes the Milstein correction via finite-difference parallel transport:
35 /// 1. Walk eps along g(x) to get y = exp_x(eps * g(x)).
36 /// 2. Evaluate g at y.
37 /// 3. Transport g(y) back from y to x via ParallelTransport.
38 /// 4. Approx covariant deriv: nabla_g g ≈ (PT(g(y)) - g(x)) / eps.
39 /// 5. Add Milstein correction: 0.5 * nabla_g(g) * (dW^2 - dt).
40 /// 6. Apply exp to the full tangent displacement.
41 ///
42 /// If transport fails (cut locus), falls back to Euler step (no correction).
43 pub fn step<M, D, G>(
44 &self,
45 sde: &ManifoldSDE<M, D, G>,
46 x: &M::Point,
47 t: f64,
48 dt: f64,
49 inc: &Increment<f64>,
50 ) -> M::Point
51 where
52 M: Manifold + ParallelTransport,
53 D: Fn(&M::Point, f64) -> M::Tangent + Send + Sync,
54 G: Fn(&M::Point, f64) -> M::Tangent + Send + Sync,
55 {
56 let dw = inc.dw;
57 let f = (sde.drift)(x, t);
58 let g = (sde.diffusion)(x, t);
59 let eps = self.eps;
60
61 // Walk eps along g(x) to get perturbed point y.
62 let eps_g = g.clone() * eps;
63 let y = sde.manifold.exp(x, &eps_g);
64
65 // Evaluate diffusion at y.
66 let g_at_y = (sde.diffusion)(&y, t);
67
68 // Transport g(y) from y back to T_x(M). Falls back to Euler if it fails.
69 let tangent = match sde.manifold.transport(&y, x, &g_at_y) {
70 Ok(g_transported) => {
71 // Finite-difference covariant derivative: nabla_g g ≈ (PT(g(y)) - g(x)) / eps.
72 let nabla_g_g = (g_transported - g.clone()) * (1.0 / eps);
73 let correction = nabla_g_g * (0.5 * (dw * dw - dt));
74 f * dt + g * dw + correction
75 }
76 Err(_) => {
77 // Degenerate geometry (cut locus): fall back to Euler step.
78 f * dt + g * dw
79 }
80 };
81
82 sde.manifold.exp(x, &tangent)
83 }
84}
85
86impl Default for GeodesicMilstein {
87 fn default() -> Self {
88 Self::new()
89 }
90}