differential_equations/methods/erk/fixed/stochastic.rs

1//! Fixed Runge-Kutta methods for SDEs
2
3use super::{ExplicitRungeKutta, Stochastic, Fixed};
4use crate::{
5    Error, Status,
6    alias::Evals,
7    interpolate::Interpolation,
8    linalg::component_multiply,
9    sde::{StochasticNumericalMethod, SDE},
10    traits::{CallBackData, Real, State},
11    utils::validate_step_size_parameters,
12};
13
14impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> StochasticNumericalMethod<T, V, D> for ExplicitRungeKutta<Stochastic, Fixed, T, V, D, O, S, I> {    
15    fn init<F>(&mut self, sde: &mut F, t0: T, tf: T, y0: &V) -> Result<Evals, Error<T, V>>
16    where
17        F: SDE<T, V, D>,
18    {
19        let mut evals = Evals::new();
20
21        // If h0 is zero, calculate initial step size for fixed-step methods
22        if self.h0 == T::zero() {
23            // Simple default step size for fixed-step methods
24            let duration = (tf - t0).abs();
25            let default_steps = T::from_usize(100).unwrap();
26            self.h0 = duration / default_steps;
27        }
28
29        // Check bounds
30        match validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf) {
31            Ok(h0) => self.h = h0,
32            Err(status) => return Err(status),
33        }
34
35        // Initialize Statistics
36        self.steps = 0;
37
38        // Initialize State
39        self.t = t0;
40        self.y = *y0;
41        
42        // Calculate initial drift and diffusion
43        sde.drift(self.t, &self.y, &mut self.dydt);
44        let mut diffusion = V::zeros();
45        sde.diffusion(self.t, &self.y, &mut diffusion);
46        evals.fcn += 2; // 1 for drift + 1 for diffusion
47
48        // Initialize previous state
49        self.t_prev = self.t;
50        self.y_prev = self.y;
51        self.dydt_prev = self.dydt;
52
53        // Initialize Status
54        self.status = Status::Initialized;
55
56        Ok(evals)
57    }
58
59    fn step<F>(&mut self, sde: &mut F) -> Result<Evals, Error<T, V>>
60    where
61        F: SDE<T, V, D>,
62    {
63        let mut evals = Evals::new();
64
65        // Check max steps
66        if self.steps >= self.max_steps {
67            self.status = Status::Error(Error::MaxSteps {
68                t: self.t, y: self.y
69            });
70            return Err(Error::MaxSteps {
71                t: self.t, y: self.y
72            });
73        }
74        self.steps += 1;
75
76        // Store current state before update for interpolation
77        self.t_prev = self.t;
78        self.y_prev = self.y;
79        self.dydt_prev = self.dydt;
80
81        // Save k[0] as the current drift
82        self.k[0] = self.dydt;
83
84        // Compute Runge-Kutta stages for the drift term
85        for i in 1..self.stages {
86            let mut y_stage = self.y;
87
88            for j in 0..i {
89                y_stage += self.k[j] * (self.a[i][j] * self.h);
90            }
91
92            sde.drift(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
93        }
94        evals.fcn += self.stages - 1; // We already have k[0]
95
96        // Compute deterministic part using RK weights
97        let mut drift_increment = V::zeros();
98        for i in 0..self.stages {
99            drift_increment += self.k[i] * (self.b[i] * self.h);
100        }
101
102        // Compute diffusion term at current state
103        let mut diffusion = V::zeros();
104        sde.diffusion(self.t, &self.y, &mut diffusion);
105        evals.fcn += 1;
106
107        // Generate noise increments
108        let mut dw = V::zeros();
109        sde.noise(self.h, &mut dw);
110
111        // Compute stochastic increment (Euler-Maruyama style)
112        let diffusion_increment = component_multiply(&diffusion, &dw);
113
114        // Combine deterministic and stochastic parts
115        let y_next = self.y + drift_increment + diffusion_increment;
116
117        // Update state
118        self.t += self.h;
119        self.y = y_next;
120        
121        // Calculate new drift for next step
122        if self.fsal {
123            // If FSAL (First Same As Last) is enabled, we can reuse the last derivative
124            self.dydt = self.k[S - 1];
125        } else {
126            // Otherwise, compute the new derivative
127            sde.drift(self.t, &self.y, &mut self.dydt);
128            evals.fcn += 1;
129        }
130        
131        self.status = Status::Solving;        
132        Ok(evals)
133    }
134
135    fn t(&self) -> T { self.t }
136    fn y(&self) -> &V { &self.y }
137    fn t_prev(&self) -> T { self.t_prev }
138    fn y_prev(&self) -> &V { &self.y_prev }
139    fn h(&self) -> T { self.h }
140    fn set_h(&mut self, h: T) { self.h = h; }
141    fn status(&self) -> &Status<T, V, D> { &self.status }
142    fn set_status(&mut self, status: Status<T, V, D>) { self.status = status; }
143}
144
145impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> Interpolation<T, V> for ExplicitRungeKutta<Stochastic, Fixed, T, V, D, O, S, I> {
146    fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
147        // Check if t is within bounds
148        if t_interp < self.t_prev || t_interp > self.t {
149            return Err(Error::OutOfBounds {
150                t_interp,
151                t_prev: self.t_prev,
152                t_curr: self.t
153            });
154        }       
155        
156        // For stochastic methods, we typically use linear interpolation
157        // since the exact path between points involves the Wiener process
158        // which is not deterministic
159        let s = (t_interp - self.t_prev) / (self.t - self.t_prev);
160        let y_interp = self.y_prev + (self.y - self.y_prev) * s;
161
162        Ok(y_interp)
163    }
164}