differential_equations/methods/
milstein.rs1use crate::{
4 error::Error,
5 interpolate::{Interpolation, linear_interpolate},
6 linalg::{component_multiply, component_square},
7 sde::{SDE, StochasticNumericalMethod},
8 stats::Evals,
9 status::Status,
10 traits::{Real, State},
11 utils::validate_step_size_parameters,
12};
13
14pub struct Milstein<T: Real, Y: State<T>> {
19 pub h0: T,
20 h: T,
21 t: T,
22 y: Y,
23 t_prev: T,
24 y_prev: Y,
25 dydt: Y,
26
27 pub h_min: T,
29 pub h_max: T,
30 pub max_steps: usize,
31
32 steps: usize,
34 status: Status<T, Y>,
35}
36
37impl<T: Real, Y: State<T>> Milstein<T, Y> {
38 pub fn new(h0: T) -> Self {
40 Self {
41 h0,
42 h: h0,
43 t: T::zero(),
44 y: Y::zeros(),
45 t_prev: T::zero(),
46 y_prev: Y::zeros(),
47 dydt: Y::zeros(),
48 h_min: T::zero(),
49 h_max: T::infinity(),
50 max_steps: 10_000,
51 steps: 0,
52 status: Status::Uninitialized,
53 }
54 }
55
56 pub fn h_min(mut self, h_min: T) -> Self {
58 self.h_min = h_min;
59 self
60 }
61
62 pub fn h_max(mut self, h_max: T) -> Self {
64 self.h_max = h_max;
65 self
66 }
67
68 pub fn max_steps(mut self, max_steps: usize) -> Self {
70 self.max_steps = max_steps;
71 self
72 }
73}
74
75impl<T: Real, Y: State<T>> StochasticNumericalMethod<T, Y> for Milstein<T, Y> {
76 fn init<F>(&mut self, sde: &mut F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
77 where
78 F: SDE<T, Y> + ?Sized,
79 {
80 let mut evals = Evals::new();
81
82 if self.h0 == T::zero() {
83 let duration = (tf - t0).abs();
84 self.h0 = duration / T::from_f64(100.0).unwrap();
85 }
86
87 match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
88 Ok(h0) => self.h = h0,
89 Err(status) => return Err(status),
90 }
91
92 self.steps = 0;
93 self.t = t0;
94 self.y = y0.clone();
95 self.dydt = y0.zeros_like();
96 self.t_prev = t0;
97 self.y_prev = y0.clone();
98
99 sde.drift(self.t, &self.y, &mut self.dydt);
100 evals.function += 1;
101
102 self.status = Status::Initialized;
103
104 Ok(evals)
105 }
106
107 fn step<F>(&mut self, sde: &mut F) -> Result<Evals, Error<T, Y>>
108 where
109 F: SDE<T, Y> + ?Sized,
110 {
111 let mut evals = Evals::new();
112
113 if self.steps >= self.max_steps {
114 self.status = Status::Error(Error::MaxSteps {
115 t: self.t,
116 y: self.y.clone(),
117 });
118 return Err(Error::MaxSteps {
119 t: self.t,
120 y: self.y.clone(),
121 });
122 }
123 self.steps += 1;
124
125 self.t_prev = self.t;
126 self.y_prev = self.y.clone();
127
128 let mut diffusion = self.y.zeros_like();
130 sde.diffusion(self.t, &self.y, &mut diffusion);
131 evals.function += 1;
132
133 let mut dw = self.y.zeros_like();
135 sde.noise(self.h, &mut dw);
136
137 let sqrt_h = self.h.sqrt();
140 let mut y_aux = self.y.clone();
141 y_aux.add_scaled(sqrt_h, &diffusion);
142
143 let mut diffusion_aux = self.y.zeros_like();
145 sde.diffusion(self.t, &y_aux, &mut diffusion_aux);
146 evals.function += 1;
147
148 let dw_sq = component_square(&dw);
150 let mut milstein_term = self.y.zeros_like();
151 let factor = T::one() / (T::from_f64(2.0).unwrap() * sqrt_h);
152
153 for i in 0..self.y.len() {
154 let diff = diffusion_aux.get_component(i) - diffusion.get_component(i);
155 let dws_minus_h = dw_sq.get_component(i) - self.h;
156 milstein_term.set_component(i, diff * dws_minus_h * factor);
157 }
158
159 let mut drift_increment = self.dydt.clone();
161 drift_increment.scale_mut(self.h);
162
163 let diffusion_increment = component_multiply(&diffusion, &dw);
164
165 let y_next = self.y.plus_linear_combination(&[
166 (&drift_increment, T::one()),
167 (&diffusion_increment, T::one()),
168 (&milstein_term, T::one()),
169 ]);
170
171 self.t += self.h;
172 self.y = y_next;
173
174 sde.drift(self.t, &self.y, &mut self.dydt);
176 evals.function += 1;
177
178 self.status = Status::Solving;
179 Ok(evals)
180 }
181
182 fn t(&self) -> T {
183 self.t
184 }
185 fn y(&self) -> &Y {
186 &self.y
187 }
188 fn t_prev(&self) -> T {
189 self.t_prev
190 }
191 fn y_prev(&self) -> &Y {
192 &self.y_prev
193 }
194 fn h(&self) -> T {
195 self.h
196 }
197 fn set_h(&mut self, h: T) {
198 self.h = h;
199 }
200 fn status(&self) -> &Status<T, Y> {
201 &self.status
202 }
203 fn set_status(&mut self, status: Status<T, Y>) {
204 self.status = status;
205 }
206}
207
208impl<T: Real, Y: State<T>> Interpolation<T, Y> for Milstein<T, Y> {
209 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
210 if t_interp < self.t_prev || t_interp > self.t {
211 return Err(Error::OutOfBounds {
212 t_interp,
213 t_prev: self.t_prev,
214 t_curr: self.t,
215 });
216 }
217 Ok(linear_interpolate(
218 self.t_prev,
219 self.t,
220 &self.y_prev,
221 &self.y,
222 t_interp,
223 ))
224 }
225}