differential_equations/sde/methods/
milstein.rs

1//! Milstein Method for solving stochastic differential equations.
2
3use crate::{
4    Error, Status,
5    interpolate::Interpolation,
6    linalg::{component_multiply, component_square},
7    sde::{SDENumericalMethod, SDE},
8    alias::Evals,
9    traits::{CallBackData, Real, State},
10    utils::validate_step_size_parameters,
11};
12
13/// Milstein Method for solving stochastic differential equations.
14///
15/// The Milstein method is a higher-order method for SDEs that includes
16/// an additional term from the Itô-Taylor expansion to achieve better accuracy.
17///
18/// For an SDE of the form:
19/// dY = a(t,Y)dt + b(t,Y)dW
20///
21/// The Milstein update is:
22/// Y_{n+1} = Y_n + a(t_n, Y_n)Δt + b(t_n, Y_n)ΔW_n +
23///           0.5 * b(t_n, Y_n)^2 * [(ΔW_n)² - Δt]
24///
25/// where ΔW_n is a Wiener process increment.
26///
27/// This implementation is based on the standard form of the Milstein method for
28/// scalar SDEs. For geometric Brownian motion, the b(t,Y) term is σY, so the
29/// correction term becomes 0.5 * (σY)^2 * [(ΔW)² - Δt].
30///
31/// The Milstein method has strong order of convergence 1.0
32/// (compared to 0.5 for Euler-Maruyama), making it more accurate
33/// for SDEs with significant diffusion effects.
34///
35/// # Example
36/// ```
37/// use differential_equations::prelude::*;
38/// use nalgebra::SVector;
39/// use rand::SeedableRng;
40/// use rand_distr::{Distribution, Normal};
41///
42/// struct GBM {
43///     mu: f64,     // Drift rate
44///     sigma: f64,  // Volatility
45///     rng: rand::rngs::StdRng,
46/// }
47///
48/// impl GBM {
49///     fn new(mu: f64, sigma: f64, seed: u64) -> Self {
50///         Self {
51///             mu,
52///             sigma,
53///             rng: rand::rngs::StdRng::seed_from_u64(seed),
54///         }
55///     }
56/// }
57///
58/// impl SDE<f64, SVector<f64, 1>> for GBM {
59///     fn drift(&self, _t: f64, y: &SVector<f64, 1>, dydt: &mut SVector<f64, 1>) {
60///         dydt[0] = self.mu * y[0];
61///     }
62///     
63///     fn diffusion(&self, _t: f64, y: &SVector<f64, 1>, dydw: &mut SVector<f64, 1>) {
64///         dydw[0] = self.sigma * y[0];
65///     }
66///     
67///     fn noise(&self, dt: f64, dw: &mut SVector<f64, 1>) {
68///         let normal = Normal::new(0.0, dt.sqrt()).unwrap();
69///         dw[0] = normal.sample(&mut self.rng.clone());
70///     }
71/// }
72///
73/// let t0 = 0.0;
74/// let tf = 1.0;
75/// let y0 = SVector::<f64, 1>::new(100.0);
76/// let mut solver = Milstein::new(0.01);
77/// let gbm = GBM::new(0.1, 0.2, 42);
78/// let gbm_problem = SDEProblem::new(gbm, t0, tf, y0);
79///
80/// // Solve the SDE
81/// let result = gbm_problem.solve(&mut solver);
82/// ```
83///
84pub struct Milstein<T: Real, V: State<T>, D: CallBackData> {
85    // Step Size
86    pub h: T,
87
88    // Current State
89    t: T,
90    y: V,
91
92    // Previous State
93    t_prev: T,
94    y_prev: V,
95
96    // Temporary storage for derivatives
97    drift: V,
98    diffusion: V,
99
100    // Status
101    status: Status<T, V, D>,
102}
103
104impl<T: Real, V: State<T>, D: CallBackData> Default for Milstein<T, V, D> {
105    fn default() -> Self {
106        Milstein {
107            h: T::from_f64(0.01).unwrap(),
108            t: T::zero(),
109            y: V::zeros(),
110            t_prev: T::zero(),
111            y_prev: V::zeros(),
112            drift: V::zeros(),
113            diffusion: V::zeros(),
114            status: Status::Uninitialized,
115        }
116    }
117}
118
119impl<T: Real, V: State<T>, D: CallBackData> SDENumericalMethod<T, V, D> for Milstein<T, V, D> {
120    fn init<F>(&mut self, sde: &F, t0: T, tf: T, y0: &V) -> Result<Evals, Error<T, V>>
121    where
122        F: SDE<T, V, D>,
123    {
124        let mut evals = Evals::new();
125
126        // Check Bounds
127        match validate_step_size_parameters::<T, V, D>(self.h, T::zero(), T::infinity(), t0, tf) {
128            Ok(_) => {}
129            Err(e) => return Err(e),
130        }
131
132        // Initialize State
133        self.t = t0;
134        self.y = *y0;
135
136        // Initialize previous state
137        self.t_prev = t0;
138        self.y_prev = *y0;
139
140        // Initialize derivatives
141        sde.drift(t0, y0, &mut self.drift);
142        sde.diffusion(t0, y0, &mut self.diffusion);
143        evals.fcn += 2; // 2 function evaluations: drift and diffusion
144
145        // Initialize Status
146        self.status = Status::Initialized;
147
148        Ok(evals) // 2 function evaluations: drift and diffusion
149    }
150
151    fn step<F>(&mut self, sde: &F) -> Result<Evals, Error<T, V>>
152    where
153        F: SDE<T, V, D>,
154    {
155        let mut evals = Evals::new();
156
157        // Log previous state
158        self.t_prev = self.t;
159        self.y_prev = self.y;
160
161        // Compute derivatives at current time and state
162        sde.drift(self.t, &self.y, &mut self.drift);
163        sde.diffusion(self.t, &self.y, &mut self.diffusion);
164
165        evals.fcn += 2; // 2 function evaluations: drift and diffusion
166
167        // Generate noise increments using the SDE's noise method
168        let mut dw = V::zeros();
169        sde.noise(self.h, &mut dw);
170
171        // Compute next state using Milstein method
172
173        // Standard Euler-Maruyama terms
174        let drift_term = self.drift * self.h;
175        let diffusion_term = component_multiply(&self.diffusion, &dw);
176
177        // Additional Milstein correction term: 0.5 * b^2 * [(ΔW)² - Δt]
178        // Calculate dW² - dt
179        let dw_squared = component_square(&dw);
180        let dw_squared_minus_dt = dw_squared - dw * self.h;
181
182        // Calculate b²
183        let diffusion_squared = component_square(&self.diffusion);
184
185        // Calculate 0.5 * b² * (dW² - dt)
186        let half = T::from_f64(0.5).unwrap();
187        let milstein_term = component_multiply(&diffusion_squared, &dw_squared_minus_dt) * half;
188
189        // Combine all terms
190        let y_new = self.y + drift_term + diffusion_term + milstein_term;
191
192        // Update state
193        self.t += self.h;
194        self.y = y_new;
195
196        Ok(evals)
197    }
198
199    fn t(&self) -> T {
200        self.t
201    }
202
203    fn y(&self) -> &V {
204        &self.y
205    }
206
207    fn t_prev(&self) -> T {
208        self.t_prev
209    }
210
211    fn y_prev(&self) -> &V {
212        &self.y_prev
213    }
214
215    fn h(&self) -> T {
216        self.h
217    }
218
219    fn set_h(&mut self, h: T) {
220        self.h = h;
221    }
222
223    fn status(&self) -> &Status<T, V, D> {
224        &self.status
225    }
226
227    fn set_status(&mut self, status: Status<T, V, D>) {
228        self.status = status;
229    }
230}
231
232impl<T: Real, V: State<T>, D: CallBackData> Interpolation<T, V> for Milstein<T, V, D> {
233    fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
234        // Check if t is within the bounds of the current step
235        if t_interp < self.t_prev || t_interp > self.t {
236            return Err(Error::OutOfBounds {
237                t_interp,
238                t_prev: self.t_prev,
239                t_curr: self.t,
240            });
241        }
242
243        // For stochastic methods, linear interpolation is often used as it's not easy to
244        // determine the precise path between points without knowledge of the entire Wiener path
245        let s = (t_interp - self.t_prev) / (self.t - self.t_prev);
246        let y_interp = self.y_prev + (self.y - self.y_prev) * s;
247
248        Ok(y_interp)
249    }
250}
251
252impl<T: Real, V: State<T>, D: CallBackData> Milstein<T, V, D> {
253    /// Create a new Milstein solver with the specified step size
254    ///
255    /// # Arguments
256    /// * `h` - Step size
257    ///
258    /// # Returns
259    /// * A new solver instance
260    pub fn new(h: T) -> Self {
261        Milstein {
262            h,
263            ..Default::default()
264        }
265    }
266}