// differential_equations/methods/erk/fixed/stochastic.rs

1//! Fixed Runge-Kutta methods for SDEs
2
3use crate::{
4    error::Error,
5    interpolate::{Interpolation, linear_interpolate},
6    linalg::component_multiply,
7    methods::{ExplicitRungeKutta, Fixed, Stochastic},
8    sde::{SDE, StochasticNumericalMethod},
9    stats::Evals,
10    status::Status,
11    traits::{CallBackData, Real, State},
12    utils::validate_step_size_parameters,
13};
14
15impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
16    StochasticNumericalMethod<T, Y, D> for ExplicitRungeKutta<Stochastic, Fixed, T, Y, D, O, S, I>
17{
18    fn init<F>(&mut self, sde: &mut F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
19    where
20        F: SDE<T, Y, D>,
21    {
22        let mut evals = Evals::new();
23
24        // If h0 is zero, calculate initial step size for fixed-step methods
25        if self.h0 == T::zero() {
26            // Simple default step size for fixed-step methods
27            let duration = (tf - t0).abs();
28            let default_steps = T::from_usize(100).unwrap();
29            self.h0 = duration / default_steps;
30        }
31
32        // Check bounds
33        match validate_step_size_parameters::<T, Y, D>(self.h0, self.h_min, self.h_max, t0, tf) {
34            Ok(h0) => self.h = h0,
35            Err(status) => return Err(status),
36        }
37
38        // Initialize Statistics
39        self.steps = 0;
40
41        // Initialize State
42        self.t = t0;
43        self.y = *y0;
44
45        // Calculate initial drift and diffusion
46        sde.drift(self.t, &self.y, &mut self.dydt);
47        let mut diffusion = Y::zeros();
48        sde.diffusion(self.t, &self.y, &mut diffusion);
49        evals.function += 2; // 1 for drift + 1 for diffusion
50
51        // Initialize previous state
52        self.t_prev = self.t;
53        self.y_prev = self.y;
54        self.dydt_prev = self.dydt;
55
56        // Initialize Status
57        self.status = Status::Initialized;
58
59        Ok(evals)
60    }
61
62    fn step<F>(&mut self, sde: &mut F) -> Result<Evals, Error<T, Y>>
63    where
64        F: SDE<T, Y, D>,
65    {
66        let mut evals = Evals::new();
67
68        // Check max steps
69        if self.steps >= self.max_steps {
70            self.status = Status::Error(Error::MaxSteps {
71                t: self.t,
72                y: self.y,
73            });
74            return Err(Error::MaxSteps {
75                t: self.t,
76                y: self.y,
77            });
78        }
79        self.steps += 1;
80
81        // Store current state before update for interpolation
82        self.t_prev = self.t;
83        self.y_prev = self.y;
84        self.dydt_prev = self.dydt;
85
86        // Save k[0] as the current drift
87        self.k[0] = self.dydt;
88
89        // Compute Runge-Kutta stages for the drift term
90        for i in 1..self.stages {
91            let mut y_stage = self.y;
92
93            for j in 0..i {
94                y_stage += self.k[j] * (self.a[i][j] * self.h);
95            }
96
97            sde.drift(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
98        }
99        evals.function += self.stages - 1; // We already have k[0]
100
101        // Compute deterministic part using RK weights
102        let mut drift_increment = Y::zeros();
103        for i in 0..self.stages {
104            drift_increment += self.k[i] * (self.b[i] * self.h);
105        }
106
107        // Compute diffusion term at current state
108        let mut diffusion = Y::zeros();
109        sde.diffusion(self.t, &self.y, &mut diffusion);
110        evals.function += 1;
111
112        // Generate noise increments
113        let mut dw = Y::zeros();
114        sde.noise(self.h, &mut dw);
115
116        // Compute stochastic increment (Euler-Maruyama style)
117        let diffusion_increment = component_multiply(&diffusion, &dw);
118
119        // Combine deterministic and stochastic parts
120        let y_next = self.y + drift_increment + diffusion_increment;
121
122        // Update state
123        self.t += self.h;
124        self.y = y_next;
125
126        // Calculate new drift for next step
127        if self.fsal {
128            // If FSAL (First Same As Last) is enabled, we can reuse the last derivative
129            self.dydt = self.k[S - 1];
130        } else {
131            // Otherwise, compute the new derivative
132            sde.drift(self.t, &self.y, &mut self.dydt);
133            evals.function += 1;
134        }
135
136        self.status = Status::Solving;
137        Ok(evals)
138    }
139
140    fn t(&self) -> T {
141        self.t
142    }
143    fn y(&self) -> &Y {
144        &self.y
145    }
146    fn t_prev(&self) -> T {
147        self.t_prev
148    }
149    fn y_prev(&self) -> &Y {
150        &self.y_prev
151    }
152    fn h(&self) -> T {
153        self.h
154    }
155    fn set_h(&mut self, h: T) {
156        self.h = h;
157    }
158    fn status(&self) -> &Status<T, Y, D> {
159        &self.status
160    }
161    fn set_status(&mut self, status: Status<T, Y, D>) {
162        self.status = status;
163    }
164}
165
166impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
167    Interpolation<T, Y> for ExplicitRungeKutta<Stochastic, Fixed, T, Y, D, O, S, I>
168{
169    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
170        // Check if t is within bounds
171        if t_interp < self.t_prev || t_interp > self.t {
172            return Err(Error::OutOfBounds {
173                t_interp,
174                t_prev: self.t_prev,
175                t_curr: self.t,
176            });
177        }
178
179        // For stochastic methods, we typically use linear interpolation
180        // since the exact path between points involves the Wiener process
181        let y_interp = linear_interpolate(self.t_prev, self.t, &self.y_prev, &self.y, t_interp);
182
183        Ok(y_interp)
184    }
185}