Skip to main content

differential_equations/methods/erk/fixed/
stochastic.rs

1//! Fixed Runge-Kutta methods for SDEs
2
3use crate::{
4    error::Error,
5    interpolate::{Interpolation, linear_interpolate},
6    linalg::component_multiply,
7    methods::{ExplicitRungeKutta, Fixed, Stochastic},
8    sde::{SDE, StochasticNumericalMethod},
9    stats::Evals,
10    status::Status,
11    traits::{Real, State},
12    utils::validate_step_size_parameters,
13};
14
15impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
16    StochasticNumericalMethod<T, Y> for ExplicitRungeKutta<Stochastic, Fixed, T, Y, O, S, I>
17{
18    fn init<F>(&mut self, sde: &mut F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
19    where
20        F: SDE<T, Y> + ?Sized,
21    {
22        let mut evals = Evals::new();
23
24        // If h0 is zero, calculate initial step size for fixed-step methods
25        if self.h0 == T::zero() {
26            // Simple default step size for fixed-step methods
27            let duration = (tf - t0).abs();
28            let default_steps = T::from_usize(100).unwrap();
29            self.h0 = duration / default_steps;
30        }
31
32        // Check bounds
33        match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
34            Ok(h0) => self.h = h0,
35            Err(status) => return Err(status),
36        }
37
38        // Initialize Statistics
39        self.steps = 0;
40
41        // Initialize State
42        self.t = t0;
43        self.y = y0.clone();
44        self.dydt = y0.zeros_like();
45        self.y_prev = y0.clone();
46        self.dydt_prev = y0.zeros_like();
47        self.k = core::array::from_fn(|_| y0.zeros_like());
48        self.cont = core::array::from_fn(|_| y0.zeros_like());
49
50        // Calculate initial drift and diffusion
51        sde.drift(self.t, &self.y, &mut self.dydt);
52        let mut diffusion = y0.zeros_like();
53        sde.diffusion(self.t, &self.y, &mut diffusion);
54        evals.function += 2; // 1 for drift + 1 for diffusion
55
56        // Initialize previous state
57        self.t_prev = self.t;
58        self.y_prev = self.y.clone();
59        self.dydt_prev = self.dydt.clone();
60
61        // Initialize Status
62        self.status = Status::Initialized;
63
64        Ok(evals)
65    }
66
67    fn step<F>(&mut self, sde: &mut F) -> Result<Evals, Error<T, Y>>
68    where
69        F: SDE<T, Y> + ?Sized,
70    {
71        let mut evals = Evals::new();
72
73        // Check max steps
74        if self.steps >= self.max_steps {
75            self.status = Status::Error(Error::MaxSteps {
76                t: self.t,
77                y: self.y.clone(),
78            });
79            return Err(Error::MaxSteps {
80                t: self.t,
81                y: self.y.clone(),
82            });
83        }
84        self.steps += 1;
85
86        // Store current state before update for interpolation
87        self.t_prev = self.t;
88        self.y_prev = self.y.clone();
89        self.dydt_prev = self.dydt.clone();
90
91        // Save k[0] as the current drift
92        self.k[0] = self.dydt.clone();
93
94        // Compute Runge-Kutta stages for the drift term
95        for i in 1..self.stages {
96            let mut y_stage = self.y.clone();
97
98            for j in 0..i {
99                y_stage.add_scaled(self.a[i][j] * self.h, &self.k[j]);
100            }
101
102            sde.drift(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
103        }
104        evals.function += self.stages - 1; // We already have k[0]
105
106        // Compute deterministic part using RK weights
107        let mut drift_increment = self.y.zeros_like();
108        for i in 0..self.stages {
109            drift_increment.add_scaled(self.b[i] * self.h, &self.k[i]);
110        }
111
112        // Compute diffusion term at current state
113        let mut diffusion = self.y.zeros_like();
114        sde.diffusion(self.t, &self.y, &mut diffusion);
115        evals.function += 1;
116
117        // Generate noise increments
118        let mut dw = self.y.zeros_like();
119        sde.noise(self.h, &mut dw);
120
121        // Compute stochastic increment (Euler-Maruyama style)
122        let diffusion_increment = component_multiply(&diffusion, &dw);
123
124        // Combine deterministic and stochastic parts
125        let y_next = self.y.plus_linear_combination(&[
126            (&drift_increment, T::one()),
127            (&diffusion_increment, T::one()),
128        ]);
129
130        // Update state
131        self.t += self.h;
132        self.y = y_next;
133
134        // Calculate new drift for next step
135        if self.fsal {
136            // If FSAL (First Same As Last) is enabled, we can reuse the last derivative
137            self.dydt = self.k[S - 1].clone();
138        } else {
139            // Otherwise, compute the new derivative
140            sde.drift(self.t, &self.y, &mut self.dydt);
141            evals.function += 1;
142        }
143
144        self.status = Status::Solving;
145        Ok(evals)
146    }
147
148    fn t(&self) -> T {
149        self.t
150    }
151    fn y(&self) -> &Y {
152        &self.y
153    }
154    fn t_prev(&self) -> T {
155        self.t_prev
156    }
157    fn y_prev(&self) -> &Y {
158        &self.y_prev
159    }
160    fn h(&self) -> T {
161        self.h
162    }
163    fn set_h(&mut self, h: T) {
164        self.h = h;
165    }
166    fn status(&self) -> &Status<T, Y> {
167        &self.status
168    }
169    fn set_status(&mut self, status: Status<T, Y>) {
170        self.status = status;
171    }
172}
173
174impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
175    for ExplicitRungeKutta<Stochastic, Fixed, T, Y, O, S, I>
176{
177    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
178        // Check if t is within bounds
179        if t_interp < self.t_prev || t_interp > self.t {
180            return Err(Error::OutOfBounds {
181                t_interp,
182                t_prev: self.t_prev,
183                t_curr: self.t,
184            });
185        }
186
187        // For stochastic methods, we typically use linear interpolation
188        // since the exact path between points involves the Wiener process
189        let y_interp = linear_interpolate(self.t_prev, self.t, &self.y_prev, &self.y, t_interp);
190
191        Ok(y_interp)
192    }
193}