Skip to main content

numra_sde/
milstein.rs

1//! Milstein method for SDEs.
2//!
3//! Higher-order SDE solver with strong order 1.0.
4//!
5//! For an SDE: dX = f(t,X) dt + g(t,X) dW
6//!
7//! The Milstein update is:
8//! X_{n+1} = X_n + f(t_n, X_n) * dt + g(t_n, X_n) * ΔW_n
9//!           + 0.5 * g(t_n, X_n) * g'(t_n, X_n) * (ΔW_n² - dt)
10//!
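11//! The extra term is the Itô–Taylor correction g * g' * I_(1,1), with the double Itô integral I_(1,1) = 0.5 * (ΔW_n² - dt); it vanishes when g' = 0 (additive noise).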
12//!
13//! Author: Moussa Leblouba
14//! Date: 4 February 2026
15//! Modified: 2 May 2026
16
17use crate::system::{NoiseType, SdeOptions, SdeResult, SdeSolver, SdeStats, SdeSystem};
18use crate::wiener::create_wiener;
19use numra_core::Scalar;
20
/// Milstein SDE solver.
///
/// Fixed time step solver with strong order 1.0 and weak order 1.0.
///
/// Besides drift and diffusion, every step needs the per-component product
/// `g(t, x) * ∂g/∂x(t, x)` (not `∂g/∂x` alone), which the system supplies through
/// [`SdeSystem::diffusion_derivative`]. Only diagonal noise is supported.
#[derive(Debug, Clone, Copy, Default)]
pub struct Milstein;
26
27impl<S: Scalar> SdeSolver<S> for Milstein {
28    fn solve<Sys: SdeSystem<S>>(
29        system: &Sys,
30        t0: S,
31        tf: S,
32        x0: &[S],
33        options: &SdeOptions<S>,
34        seed: Option<u64>,
35    ) -> Result<SdeResult<S>, String> {
36        let dim = system.dim();
37        if x0.len() != dim {
38            return Err(format!(
39                "Initial state dimension {} doesn't match system dimension {}",
40                x0.len(),
41                dim
42            ));
43        }
44
45        // Milstein currently only supports diagonal noise
46        match system.noise_type() {
47            NoiseType::Diagonal => {}
48            _ => return Err("Milstein currently only supports diagonal noise".to_string()),
49        }
50
51        let dt = options.dt;
52        let n_wiener = system.n_wiener();
53        let actual_seed = seed.or(options.seed);
54        let mut wiener = create_wiener(n_wiener, actual_seed);
55
56        // Allocate storage
57        let mut t = t0;
58        let mut x = x0.to_vec();
59        let mut f = vec![S::ZERO; dim];
60        let mut g = vec![S::ZERO; dim];
61        let mut gdg = vec![S::ZERO; dim]; // g * dg/dx
62
63        let mut t_out = Vec::new();
64        let mut y_out = Vec::new();
65        let mut stats = SdeStats::default();
66
67        // Save initial state
68        if options.save_trajectory {
69            t_out.push(t);
70            y_out.extend_from_slice(&x);
71        }
72
73        let half = S::from_f64(0.5);
74        let mut step = 0;
75        while t < tf && step < options.max_steps {
76            // Adjust final step
77            let h = dt.min(tf - t);
78
79            // Evaluate drift, diffusion, and diffusion derivative
80            system.drift(t, &x, &mut f);
81            system.diffusion(t, &x, &mut g);
82            system.diffusion_derivative(t, &x, &mut gdg);
83            stats.n_drift += 1;
84            stats.n_diffusion += 2; // diffusion_derivative calls diffusion
85
86            // Generate Wiener increment
87            let dw = wiener.increment(h);
88
89            // Milstein update for diagonal noise:
90            // X_{n+1} = X_n + f * dt + g * ΔW + 0.5 * g * g' * (ΔW² - dt)
91            for i in 0..dim {
92                let dw_i = dw.dw[i];
93                let correction = half * gdg[i] * (dw_i * dw_i - h);
94                x[i] += f[i] * h + g[i] * dw_i + correction;
95            }
96
97            t += h;
98            step += 1;
99            stats.n_accept += 1;
100
101            // Save state
102            if options.save_trajectory {
103                t_out.push(t);
104                y_out.extend_from_slice(&x);
105            }
106        }
107
108        if step >= options.max_steps && t < tf {
109            return Err(format!(
110                "Maximum steps ({}) exceeded at t = {}",
111                options.max_steps,
112                t.to_f64()
113            ));
114        }
115
116        // If not saving trajectory, just save final state
117        if !options.save_trajectory {
118            t_out.push(t);
119            y_out.extend_from_slice(&x);
120        }
121
122        Ok(SdeResult::new(t_out, y_out, dim, stats))
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// Geometric Brownian Motion: dS = μS dt + σS dW
    /// For GBM, g(x) = σx, so g'(x) = σ, and g*g' = σ²x
    // "GBM" is the conventional name in the literature; keep it over `Gbm`.
    #[allow(clippy::upper_case_acronyms)]
    struct GBM {
        mu: f64,
        sigma: f64,
    }

    impl SdeSystem<f64> for GBM {
        fn dim(&self) -> usize {
            1
        }

        fn drift(&self, _t: f64, x: &[f64], f: &mut [f64]) {
            f[0] = self.mu * x[0];
        }

        fn diffusion(&self, _t: f64, x: &[f64], g: &mut [f64]) {
            g[0] = self.sigma * x[0];
        }

        fn diffusion_derivative(&self, _t: f64, x: &[f64], gdg: &mut [f64]) {
            // g = σx, dg/dx = σ, so g * dg/dx = σ²x
            gdg[0] = self.sigma * self.sigma * x[0];
        }
    }

    #[test]
    fn test_milstein_gbm() {
        let gbm = GBM {
            mu: 0.05,
            sigma: 0.2,
        };
        let options = SdeOptions::default().dt(0.001).seed(42);

        let result =
            Milstein::solve(&gbm, 0.0, 1.0, &[100.0], &options, None).expect("Solve failed");

        assert!(result.success);
        assert!(!result.t.is_empty());
        let final_price = result.y_final().unwrap()[0];
        assert!(final_price > 0.0);
    }

    #[test]
    fn test_milstein_vs_euler_maruyama() {
        // Milstein adds the 0.5 * g * g' * (ΔW² - dt) correction, so for
        // multiplicative noise its path differs from Euler-Maruyama's even when
        // both consume the same random numbers (same seed).
        let gbm = GBM {
            mu: 0.05,
            sigma: 0.3,
        };
        let options = SdeOptions::default().dt(0.01).seed(42);

        let em_result = crate::EulerMaruyama::solve(&gbm, 0.0, 1.0, &[100.0], &options, None)
            .expect("EM solve failed");
        let mil_result = Milstein::solve(&gbm, 0.0, 1.0, &[100.0], &options, None)
            .expect("Milstein solve failed");

        let em_final = em_result.y_final().unwrap()[0];
        let mil_final = mil_result.y_final().unwrap()[0];

        assert!((em_final - mil_final).abs() > 0.01);
    }

    #[test]
    fn test_milstein_additive_noise() {
        // For additive noise (g constant), Milstein = Euler-Maruyama
        struct Additive;
        impl SdeSystem<f64> for Additive {
            fn dim(&self) -> usize {
                1
            }
            fn drift(&self, _t: f64, x: &[f64], f: &mut [f64]) {
                f[0] = -x[0];
            }
            fn diffusion(&self, _t: f64, _x: &[f64], g: &mut [f64]) {
                g[0] = 0.5; // Constant diffusion
            }
            fn diffusion_derivative(&self, _t: f64, _x: &[f64], gdg: &mut [f64]) {
                gdg[0] = 0.0; // dg/dx = 0, so g*g' = 0
            }
        }

        let options = SdeOptions::default().dt(0.01).seed(42);
        let em = crate::EulerMaruyama::solve(&Additive, 0.0, 1.0, &[1.0], &options, None)
            .expect("EM solve failed");
        let mil = Milstein::solve(&Additive, 0.0, 1.0, &[1.0], &options, None)
            .expect("Milstein solve failed");

        // For additive noise, results should be identical
        let em_final = em.y_final().unwrap()[0];
        let mil_final = mil.y_final().unwrap()[0];
        assert!((em_final - mil_final).abs() < 1e-10);
    }

    #[test]
    fn test_reproducibility() {
        let gbm = GBM {
            mu: 0.05,
            sigma: 0.2,
        };
        let options = SdeOptions::default().dt(0.01);

        let r1 =
            Milstein::solve(&gbm, 0.0, 1.0, &[100.0], &options, Some(42)).expect("Solve failed");
        let r2 =
            Milstein::solve(&gbm, 0.0, 1.0, &[100.0], &options, Some(42)).expect("Solve failed");

        let y1 = r1.y_final().unwrap()[0];
        let y2 = r2.y_final().unwrap()[0];
        assert!((y1 - y2).abs() < 1e-10);
    }

    #[test]
    fn test_zero_volatility_reduces_to_explicit_euler() {
        // With σ = 0, both g·ΔW and the g·g' correction vanish, so each step is
        // x += μ x h: explicit Euler on dx/dt = μx, with a closed-form result.
        let gbm = GBM {
            mu: 0.05,
            sigma: 0.0,
        };
        let options = SdeOptions::default().dt(0.01).seed(7);

        let result =
            Milstein::solve(&gbm, 0.0, 1.0, &[100.0], &options, None).expect("Solve failed");
        let final_value = result.y_final().unwrap()[0];

        let euler = 100.0 * (1.0_f64 + 0.05 * 0.01).powi(100);
        assert!((final_value - euler).abs() < 1e-8);
        // First-order accurate: close to, but not exactly, the exact 100·e^μ.
        assert!((final_value - 100.0 * 0.05_f64.exp()).abs() < 1e-2);
    }

    #[test]
    fn test_dimension_mismatch_is_error() {
        let gbm = GBM {
            mu: 0.05,
            sigma: 0.2,
        };
        let options = SdeOptions::default().dt(0.01);

        let result = Milstein::solve(&gbm, 0.0, 1.0, &[1.0, 2.0], &options, Some(1));
        assert!(result.is_err());
    }

    #[test]
    fn test_last_step_lands_on_tf() {
        // dt does not divide the interval; the final step must be shortened.
        let gbm = GBM {
            mu: 0.05,
            sigma: 0.2,
        };
        let options = SdeOptions::default().dt(0.03).seed(3);

        let result =
            Milstein::solve(&gbm, 0.0, 1.0, &[100.0], &options, None).expect("Solve failed");
        let t_last = *result.t.last().expect("trajectory is non-empty");
        assert!((t_last - 1.0).abs() < 1e-12);
    }
}