Skip to main content

differential_equations/methods/
milstein.rs

1//! Derivative-Free Milstein method for Stochastic Differential Equations
2
3use crate::{
4    error::Error,
5    interpolate::{Interpolation, linear_interpolate},
6    linalg::{component_multiply, component_square},
7    sde::{SDE, StochasticNumericalMethod},
8    stats::Evals,
9    status::Status,
10    traits::{Real, State},
11    utils::validate_step_size_parameters,
12};
13
14/// Derivative-Free Milstein method for solving SDEs.
15///
16/// Provides strong order 1.0 convergence for commutative/diagonal noise,
17/// which is an improvement over the 0.5 strong order of Euler-Maruyama.
18pub struct Milstein<T: Real, Y: State<T>> {
19    pub h0: T,
20    h: T,
21    t: T,
22    y: Y,
23    t_prev: T,
24    y_prev: Y,
25    dydt: Y,
26
27    // Settings
28    pub h_min: T,
29    pub h_max: T,
30    pub max_steps: usize,
31
32    // Statistics
33    steps: usize,
34    status: Status<T, Y>,
35}
36
37impl<T: Real, Y: State<T>> Milstein<T, Y> {
38    /// Creates a new Milstein method solver
39    pub fn new(h0: T) -> Self {
40        Self {
41            h0,
42            h: h0,
43            t: T::zero(),
44            y: Y::zeros(),
45            t_prev: T::zero(),
46            y_prev: Y::zeros(),
47            dydt: Y::zeros(),
48            h_min: T::zero(),
49            h_max: T::infinity(),
50            max_steps: 10_000,
51            steps: 0,
52            status: Status::Uninitialized,
53        }
54    }
55
56    /// Set minimum step size
57    pub fn h_min(mut self, h_min: T) -> Self {
58        self.h_min = h_min;
59        self
60    }
61
62    /// Set maximum step size
63    pub fn h_max(mut self, h_max: T) -> Self {
64        self.h_max = h_max;
65        self
66    }
67
68    /// Set maximum number of steps
69    pub fn max_steps(mut self, max_steps: usize) -> Self {
70        self.max_steps = max_steps;
71        self
72    }
73}
74
75impl<T: Real, Y: State<T>> StochasticNumericalMethod<T, Y> for Milstein<T, Y> {
76    fn init<F>(&mut self, sde: &mut F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
77    where
78        F: SDE<T, Y> + ?Sized,
79    {
80        let mut evals = Evals::new();
81
82        if self.h0 == T::zero() {
83            let duration = (tf - t0).abs();
84            self.h0 = duration / T::from_f64(100.0).unwrap();
85        }
86
87        match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
88            Ok(h0) => self.h = h0,
89            Err(status) => return Err(status),
90        }
91
92        self.steps = 0;
93        self.t = t0;
94        self.y = y0.clone();
95        self.dydt = y0.zeros_like();
96        self.t_prev = t0;
97        self.y_prev = y0.clone();
98
99        sde.drift(self.t, &self.y, &mut self.dydt);
100        evals.function += 1;
101
102        self.status = Status::Initialized;
103
104        Ok(evals)
105    }
106
107    fn step<F>(&mut self, sde: &mut F) -> Result<Evals, Error<T, Y>>
108    where
109        F: SDE<T, Y> + ?Sized,
110    {
111        let mut evals = Evals::new();
112
113        if self.steps >= self.max_steps {
114            self.status = Status::Error(Error::MaxSteps {
115                t: self.t,
116                y: self.y.clone(),
117            });
118            return Err(Error::MaxSteps {
119                t: self.t,
120                y: self.y.clone(),
121            });
122        }
123        self.steps += 1;
124
125        self.t_prev = self.t;
126        self.y_prev = self.y.clone();
127
128        // Calculate diffusion at y_n
129        let mut diffusion = self.y.zeros_like();
130        sde.diffusion(self.t, &self.y, &mut diffusion);
131        evals.function += 1;
132
133        // Generate noise increments
134        let mut dw = self.y.zeros_like();
135        sde.noise(self.h, &mut dw);
136
137        // Derivative-free Milstein correction
138        // y_aux = y_n + b(t_n, y_n) * sqrt(h)
139        let sqrt_h = self.h.sqrt();
140        let mut y_aux = self.y.clone();
141        y_aux.add_scaled(sqrt_h, &diffusion);
142
143        // b_aux = b(t_n, y_aux)
144        let mut diffusion_aux = self.y.zeros_like();
145        sde.diffusion(self.t, &y_aux, &mut diffusion_aux);
146        evals.function += 1;
147
148        // term = (b_aux - b) * (dw^2 - h) / (2 * sqrt(h))
149        let dw_sq = component_square(&dw);
150        let mut milstein_term = self.y.zeros_like();
151        let factor = T::one() / (T::from_f64(2.0).unwrap() * sqrt_h);
152
153        for i in 0..self.y.len() {
154            let diff = diffusion_aux.get_component(i) - diffusion.get_component(i);
155            let dws_minus_h = dw_sq.get_component(i) - self.h;
156            milstein_term.set_component(i, diff * dws_minus_h * factor);
157        }
158
159        // Combine deterministic, Euler-Maruyama, and Milstein parts
160        let mut drift_increment = self.dydt.clone();
161        drift_increment.scale_mut(self.h);
162
163        let diffusion_increment = component_multiply(&diffusion, &dw);
164
165        let y_next = self.y.plus_linear_combination(&[
166            (&drift_increment, T::one()),
167            (&diffusion_increment, T::one()),
168            (&milstein_term, T::one()),
169        ]);
170
171        self.t += self.h;
172        self.y = y_next;
173
174        // Drift for next step
175        sde.drift(self.t, &self.y, &mut self.dydt);
176        evals.function += 1;
177
178        self.status = Status::Solving;
179        Ok(evals)
180    }
181
182    fn t(&self) -> T {
183        self.t
184    }
185    fn y(&self) -> &Y {
186        &self.y
187    }
188    fn t_prev(&self) -> T {
189        self.t_prev
190    }
191    fn y_prev(&self) -> &Y {
192        &self.y_prev
193    }
194    fn h(&self) -> T {
195        self.h
196    }
197    fn set_h(&mut self, h: T) {
198        self.h = h;
199    }
200    fn status(&self) -> &Status<T, Y> {
201        &self.status
202    }
203    fn set_status(&mut self, status: Status<T, Y>) {
204        self.status = status;
205    }
206}
207
208impl<T: Real, Y: State<T>> Interpolation<T, Y> for Milstein<T, Y> {
209    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
210        if t_interp < self.t_prev || t_interp > self.t {
211            return Err(Error::OutOfBounds {
212                t_interp,
213                t_prev: self.t_prev,
214                t_curr: self.t,
215            });
216        }
217        Ok(linear_interpolate(
218            self.t_prev,
219            self.t,
220            &self.y_prev,
221            &self.y,
222            t_interp,
223        ))
224    }
225}