Skip to main content

pathwise_geo/scheme/
milstein.rs

1use crate::sde::ManifoldSDE;
2use cartan_core::{Manifold, ParallelTransport};
3use pathwise_core::state::Increment;
4
/// Geodesic Milstein: strong order 1.0 scheme (for a single scalar Brownian driver)
/// via a finite-difference covariant derivative.
///
/// Correction via finite-difference approximation of nabla_g(g):
///   nabla_g g(x) ≈ (1/eps) * [PT_{y->x}(g(y)) - g(x)]
///   where y = exp_x(eps * g(x))
///
/// This is a forward (one-sided) difference, so the estimated correction carries an
/// O(eps) approximation error on top of the scheme's own discretisation error.
///
/// Full step:
///   v = f(x)*dt + g(x)*dW + 0.5 * nabla_g(g) * (dW^2 - dt)
///   x_new = exp_x(v)
///
/// Stepping requires `M: ParallelTransport`. The covariant derivative is computed by
/// transporting g(y) back to T_x(M) along the geodesic from y to x.
///
/// References:
///   - Milstein (1974), Platen & Wagner (1982) for the scalar correction.
///   - Said & Manton (2012) for geodesic extension to Lie groups.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GeodesicMilstein {
    /// Finite-difference step size for the covariant-derivative approximation.
    ///
    /// Must be a small, finite, strictly positive number. Other values cannot form
    /// a meaningful finite difference.
    pub eps: f64,
}
25
26impl GeodesicMilstein {
27    /// Create with default eps = 1e-4.
28    pub fn new() -> Self {
29        Self { eps: 1e-4 }
30    }
31
32    /// Advance x by one Milstein step on the manifold.
33    ///
34    /// Computes the Milstein correction via finite-difference parallel transport:
35    ///   1. Walk eps along g(x) to get y = exp_x(eps * g(x)).
36    ///   2. Evaluate g at y.
37    ///   3. Transport g(y) back from y to x via ParallelTransport.
38    ///   4. Approx covariant deriv: nabla_g g ≈ (PT(g(y)) - g(x)) / eps.
39    ///   5. Add Milstein correction: 0.5 * nabla_g(g) * (dW^2 - dt).
40    ///   6. Apply exp to the full tangent displacement.
41    ///
42    /// If transport fails (cut locus), falls back to Euler step (no correction).
43    pub fn step<M, D, G>(
44        &self,
45        sde: &ManifoldSDE<M, D, G>,
46        x: &M::Point,
47        t: f64,
48        dt: f64,
49        inc: &Increment<f64>,
50    ) -> M::Point
51    where
52        M: Manifold + ParallelTransport,
53        D: Fn(&M::Point, f64) -> M::Tangent + Send + Sync,
54        G: Fn(&M::Point, f64) -> M::Tangent + Send + Sync,
55    {
56        let dw = inc.dw;
57        let f = (sde.drift)(x, t);
58        let g = (sde.diffusion)(x, t);
59        let eps = self.eps;
60
61        // Walk eps along g(x) to get perturbed point y.
62        let eps_g = g.clone() * eps;
63        let y = sde.manifold.exp(x, &eps_g);
64
65        // Evaluate diffusion at y.
66        let g_at_y = (sde.diffusion)(&y, t);
67
68        // Transport g(y) from y back to T_x(M). Falls back to Euler if it fails.
69        let tangent = match sde.manifold.transport(&y, x, &g_at_y) {
70            Ok(g_transported) => {
71                // Finite-difference covariant derivative: nabla_g g ≈ (PT(g(y)) - g(x)) / eps.
72                let nabla_g_g = (g_transported - g.clone()) * (1.0 / eps);
73                let correction = nabla_g_g * (0.5 * (dw * dw - dt));
74                f * dt + g * dw + correction
75            }
76            Err(_) => {
77                // Degenerate geometry (cut locus): fall back to Euler step.
78                f * dt + g * dw
79            }
80        };
81
82        sde.manifold.exp(x, &tangent)
83    }
84}
85
86impl Default for GeodesicMilstein {
87    fn default() -> Self {
88        Self::new()
89    }
90}