differential_equations/sde/methods/
euler_maruyama.rs

1//! Euler-Maruyama Method for solving stochastic differential equations.
2
3use crate::{
4    Error, Status,
5    interpolate::Interpolation,
6    linalg::component_multiply,
7    sde::{SDENumericalMethod, SDE},
8    alias::Evals,
9    traits::{CallBackData, Real, State},
10    utils::validate_step_size_parameters,
11};
12
13/// Euler-Maruyama Method for solving stochastic differential equations.
14///
15/// The Euler-Maruyama method is the simplest numerical method for SDEs,
16/// essentially extending Euler's method to stochastic equations.
17///
18/// For an SDE of the form:
19/// dY = a(t,Y)dt + b(t,Y)dW
20///
21/// The Euler-Maruyama update is:
22/// Y_{n+1} = Y_n + a(t_n, Y_n)Δt + b(t_n, Y_n)ΔW_n
23///
24/// where ΔW_n is a Wiener process increment, produced by the SDE's noise method.
25///
26/// # Example
27/// ```
28/// use differential_equations::prelude::*;
29/// use nalgebra::SVector;
30/// use rand::SeedableRng;
31/// use rand_distr::{Distribution, Normal};
32///
33/// struct GBM {
34///     rng: rand::rngs::StdRng,
35/// }
36///
37/// impl GBM {
38///     fn new(seed: u64) -> Self {
39///         Self {
40///             rng: rand::rngs::StdRng::seed_from_u64(seed),
41///         }
42///     }
43/// }
44///
45/// impl SDE<f64, SVector<f64, 1>> for GBM {
46///     fn drift(&self, _t: f64, y: &SVector<f64, 1>, dydt: &mut SVector<f64, 1>) {
47///         dydt[0] = 0.1 * y[0]; // μS
48///     }
49///     
50///     fn diffusion(&self, _t: f64, y: &SVector<f64, 1>, dydw: &mut SVector<f64, 1>) {
51///         dydw[0] = 0.2 * y[0]; // σS
52///     }
53///     
54///     fn noise(&self, dt: f64, dw: &mut SVector<f64, 1>) {
55///         let normal = Normal::new(0.0, dt.sqrt()).unwrap();
56///         dw[0] = normal.sample(&mut self.rng.clone());
57///     }
58/// }
59///
60/// let t0 = 0.0;
61/// let tf = 1.0;
62/// let y0 = SVector::<f64, 1>::new(100.0);
63/// let mut solver = EM::new(0.01);
64/// let gbm = GBM::new(42);
65/// let gbm_problem = SDEProblem::new(gbm, t0, tf, y0);
66///
67/// // Solve the SDE
68/// let result = gbm_problem.solve(&mut solver);
69/// ```
70///
71pub struct EM<T: Real, V: State<T>, D: CallBackData> {
72    // Step Size
73    pub h: T,
74
75    // Current State
76    t: T,
77    y: V,
78
79    // Previous State
80    t_prev: T,
81    y_prev: V,
82
83    // Temporary storage for derivatives
84    drift: V,
85    diffusion: V,
86
87    // Status
88    status: Status<T, V, D>,
89}
90
91impl<T: Real, V: State<T>, D: CallBackData> Default for EM<T, V, D> {
92    fn default() -> Self {
93        EM {
94            h: T::from_f64(0.01).unwrap(),
95            t: T::zero(),
96            y: V::zeros(),
97            t_prev: T::zero(),
98            y_prev: V::zeros(),
99            drift: V::zeros(),
100            diffusion: V::zeros(),
101            status: Status::Uninitialized,
102        }
103    }
104}
105
106impl<T: Real, V: State<T>, D: CallBackData> SDENumericalMethod<T, V, D> for EM<T, V, D> {
107    fn init<F>(&mut self, sde: &F, t0: T, tf: T, y0: &V) -> Result<Evals, Error<T, V>>
108    where
109        F: SDE<T, V, D>,
110    {
111        let mut evals = Evals::new();
112
113        // Check Bounds
114        match validate_step_size_parameters::<T, V, D>(self.h, T::zero(), T::infinity(), t0, tf) {
115            Ok(_) => {}
116            Err(e) => return Err(e),
117        }
118
119        // Initialize State
120        self.t = t0;
121        self.y = *y0;
122
123        // Initialize previous state
124        self.t_prev = t0;
125        self.y_prev = *y0;
126
127        // Initialize derivatives
128        sde.drift(t0, y0, &mut self.drift);
129        sde.diffusion(t0, y0, &mut self.diffusion);
130
131        evals.fcn += 2; // 2 function evaluations: drift and diffusion
132
133        // Initialize Status
134        self.status = Status::Initialized;
135
136        Ok(evals)
137    }
138
139    fn step<F>(&mut self, sde: &F) -> Result<Evals, Error<T, V>>
140    where
141        F: SDE<T, V, D>,
142    {
143        let mut evals = Evals::new();
144
145        // Log previous state
146        self.t_prev = self.t;
147        self.y_prev = self.y;
148
149        // Compute derivatives at current time and state
150        sde.drift(self.t, &self.y, &mut self.drift);
151        sde.diffusion(self.t, &self.y, &mut self.diffusion);
152
153        evals.fcn += 2; // 2 function evaluations: drift and diffusion
154
155        // Generate noise increments using the SDE's noise method
156        let mut dw = V::zeros();
157        sde.noise(self.h, &mut dw);
158
159        // Compute next state using Euler-Maruyama method
160        let drift_term = self.drift * self.h;
161        let diffusion_term = component_multiply(&self.diffusion, &dw);
162        let y_new = self.y + drift_term + diffusion_term;
163
164        // Update state
165        self.t += self.h;
166        self.y = y_new;
167
168        Ok(evals) // 2 function evaluations: drift and diffusion
169    }
170
171    fn t(&self) -> T {
172        self.t
173    }
174
175    fn y(&self) -> &V {
176        &self.y
177    }
178
179    fn t_prev(&self) -> T {
180        self.t_prev
181    }
182
183    fn y_prev(&self) -> &V {
184        &self.y_prev
185    }
186
187    fn h(&self) -> T {
188        self.h
189    }
190
191    fn set_h(&mut self, h: T) {
192        self.h = h;
193    }
194
195    fn status(&self) -> &Status<T, V, D> {
196        &self.status
197    }
198
199    fn set_status(&mut self, status: Status<T, V, D>) {
200        self.status = status;
201    }
202}
203
204impl<T: Real, V: State<T>, D: CallBackData> Interpolation<T, V> for EM<T, V, D> {
205    fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
206        // Check if t is within the bounds of the current step
207        if t_interp < self.t_prev || t_interp > self.t {
208            return Err(Error::OutOfBounds {
209                t_interp,
210                t_prev: self.t_prev,
211                t_curr: self.t,
212            });
213        }
214
215        // For stochastic methods, linear interpolation is often used as it's not easy to
216        // determine the precise path between points without knowledge of the entire Wiener path
217        let s = (t_interp - self.t_prev) / (self.t - self.t_prev);
218        let y_interp = self.y_prev + (self.y - self.y_prev) * s;
219
220        Ok(y_interp)
221    }
222}
223
224impl<T: Real, V: State<T>, D: CallBackData> EM<T, V, D> {
225    /// Create a new Euler-Maruyama solver with the specified step size
226    ///
227    /// # Arguments
228    /// * `h` - Step size
229    ///
230    /// # Returns
231    /// * A new solver instance
232    pub fn new(h: T) -> Self {
233        EM {
234            h,
235            ..Default::default()
236        }
237    }
238}