differential_equations/methods/erk/fixed/
stochastic.rs1use crate::{
4 error::Error,
5 interpolate::{Interpolation, linear_interpolate},
6 linalg::component_multiply,
7 methods::{ExplicitRungeKutta, Fixed, Stochastic},
8 sde::{SDE, StochasticNumericalMethod},
9 stats::Evals,
10 status::Status,
11 traits::{Real, State},
12 utils::validate_step_size_parameters,
13};
14
15impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
16 StochasticNumericalMethod<T, Y> for ExplicitRungeKutta<Stochastic, Fixed, T, Y, O, S, I>
17{
18 fn init<F>(&mut self, sde: &mut F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
19 where
20 F: SDE<T, Y> + ?Sized,
21 {
22 let mut evals = Evals::new();
23
24 if self.h0 == T::zero() {
26 let duration = (tf - t0).abs();
28 let default_steps = T::from_usize(100).unwrap();
29 self.h0 = duration / default_steps;
30 }
31
32 match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
34 Ok(h0) => self.h = h0,
35 Err(status) => return Err(status),
36 }
37
38 self.steps = 0;
40
41 self.t = t0;
43 self.y = y0.clone();
44 self.dydt = y0.zeros_like();
45 self.y_prev = y0.clone();
46 self.dydt_prev = y0.zeros_like();
47 self.k = core::array::from_fn(|_| y0.zeros_like());
48 self.cont = core::array::from_fn(|_| y0.zeros_like());
49
50 sde.drift(self.t, &self.y, &mut self.dydt);
52 let mut diffusion = y0.zeros_like();
53 sde.diffusion(self.t, &self.y, &mut diffusion);
54 evals.function += 2; self.t_prev = self.t;
58 self.y_prev = self.y.clone();
59 self.dydt_prev = self.dydt.clone();
60
61 self.status = Status::Initialized;
63
64 Ok(evals)
65 }
66
67 fn step<F>(&mut self, sde: &mut F) -> Result<Evals, Error<T, Y>>
68 where
69 F: SDE<T, Y> + ?Sized,
70 {
71 let mut evals = Evals::new();
72
73 if self.steps >= self.max_steps {
75 self.status = Status::Error(Error::MaxSteps {
76 t: self.t,
77 y: self.y.clone(),
78 });
79 return Err(Error::MaxSteps {
80 t: self.t,
81 y: self.y.clone(),
82 });
83 }
84 self.steps += 1;
85
86 self.t_prev = self.t;
88 self.y_prev = self.y.clone();
89 self.dydt_prev = self.dydt.clone();
90
91 self.k[0] = self.dydt.clone();
93
94 for i in 1..self.stages {
96 let mut y_stage = self.y.clone();
97
98 for j in 0..i {
99 y_stage.add_scaled(self.a[i][j] * self.h, &self.k[j]);
100 }
101
102 sde.drift(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
103 }
104 evals.function += self.stages - 1; let mut drift_increment = self.y.zeros_like();
108 for i in 0..self.stages {
109 drift_increment.add_scaled(self.b[i] * self.h, &self.k[i]);
110 }
111
112 let mut diffusion = self.y.zeros_like();
114 sde.diffusion(self.t, &self.y, &mut diffusion);
115 evals.function += 1;
116
117 let mut dw = self.y.zeros_like();
119 sde.noise(self.h, &mut dw);
120
121 let diffusion_increment = component_multiply(&diffusion, &dw);
123
124 let y_next = self.y.plus_linear_combination(&[
126 (&drift_increment, T::one()),
127 (&diffusion_increment, T::one()),
128 ]);
129
130 self.t += self.h;
132 self.y = y_next;
133
134 if self.fsal {
136 self.dydt = self.k[S - 1].clone();
138 } else {
139 sde.drift(self.t, &self.y, &mut self.dydt);
141 evals.function += 1;
142 }
143
144 self.status = Status::Solving;
145 Ok(evals)
146 }
147
148 fn t(&self) -> T {
149 self.t
150 }
151 fn y(&self) -> &Y {
152 &self.y
153 }
154 fn t_prev(&self) -> T {
155 self.t_prev
156 }
157 fn y_prev(&self) -> &Y {
158 &self.y_prev
159 }
160 fn h(&self) -> T {
161 self.h
162 }
163 fn set_h(&mut self, h: T) {
164 self.h = h;
165 }
166 fn status(&self) -> &Status<T, Y> {
167 &self.status
168 }
169 fn set_status(&mut self, status: Status<T, Y>) {
170 self.status = status;
171 }
172}
173
174impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
175 for ExplicitRungeKutta<Stochastic, Fixed, T, Y, O, S, I>
176{
177 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
178 if t_interp < self.t_prev || t_interp > self.t {
180 return Err(Error::OutOfBounds {
181 t_interp,
182 t_prev: self.t_prev,
183 t_curr: self.t,
184 });
185 }
186
187 let y_interp = linear_interpolate(self.t_prev, self.t, &self.y_prev, &self.y, t_interp);
190
191 Ok(y_interp)
192 }
193}