// differential_equations/methods/erk/fixed/stochastic.rs
use crate::{
4 error::Error,
5 interpolate::{Interpolation, linear_interpolate},
6 linalg::component_multiply,
7 methods::{ExplicitRungeKutta, Fixed, Stochastic},
8 sde::{SDE, StochasticNumericalMethod},
9 stats::Evals,
10 status::Status,
11 traits::{CallBackData, Real, State},
12 utils::validate_step_size_parameters,
13};
14
15impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
16 StochasticNumericalMethod<T, Y, D> for ExplicitRungeKutta<Stochastic, Fixed, T, Y, D, O, S, I>
17{
18 fn init<F>(&mut self, sde: &mut F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
19 where
20 F: SDE<T, Y, D>,
21 {
22 let mut evals = Evals::new();
23
24 if self.h0 == T::zero() {
26 let duration = (tf - t0).abs();
28 let default_steps = T::from_usize(100).unwrap();
29 self.h0 = duration / default_steps;
30 }
31
32 match validate_step_size_parameters::<T, Y, D>(self.h0, self.h_min, self.h_max, t0, tf) {
34 Ok(h0) => self.h = h0,
35 Err(status) => return Err(status),
36 }
37
38 self.steps = 0;
40
41 self.t = t0;
43 self.y = *y0;
44
45 sde.drift(self.t, &self.y, &mut self.dydt);
47 let mut diffusion = Y::zeros();
48 sde.diffusion(self.t, &self.y, &mut diffusion);
49 evals.function += 2; self.t_prev = self.t;
53 self.y_prev = self.y;
54 self.dydt_prev = self.dydt;
55
56 self.status = Status::Initialized;
58
59 Ok(evals)
60 }
61
62 fn step<F>(&mut self, sde: &mut F) -> Result<Evals, Error<T, Y>>
63 where
64 F: SDE<T, Y, D>,
65 {
66 let mut evals = Evals::new();
67
68 if self.steps >= self.max_steps {
70 self.status = Status::Error(Error::MaxSteps {
71 t: self.t,
72 y: self.y,
73 });
74 return Err(Error::MaxSteps {
75 t: self.t,
76 y: self.y,
77 });
78 }
79 self.steps += 1;
80
81 self.t_prev = self.t;
83 self.y_prev = self.y;
84 self.dydt_prev = self.dydt;
85
86 self.k[0] = self.dydt;
88
89 for i in 1..self.stages {
91 let mut y_stage = self.y;
92
93 for j in 0..i {
94 y_stage += self.k[j] * (self.a[i][j] * self.h);
95 }
96
97 sde.drift(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
98 }
99 evals.function += self.stages - 1; let mut drift_increment = Y::zeros();
103 for i in 0..self.stages {
104 drift_increment += self.k[i] * (self.b[i] * self.h);
105 }
106
107 let mut diffusion = Y::zeros();
109 sde.diffusion(self.t, &self.y, &mut diffusion);
110 evals.function += 1;
111
112 let mut dw = Y::zeros();
114 sde.noise(self.h, &mut dw);
115
116 let diffusion_increment = component_multiply(&diffusion, &dw);
118
119 let y_next = self.y + drift_increment + diffusion_increment;
121
122 self.t += self.h;
124 self.y = y_next;
125
126 if self.fsal {
128 self.dydt = self.k[S - 1];
130 } else {
131 sde.drift(self.t, &self.y, &mut self.dydt);
133 evals.function += 1;
134 }
135
136 self.status = Status::Solving;
137 Ok(evals)
138 }
139
140 fn t(&self) -> T {
141 self.t
142 }
143 fn y(&self) -> &Y {
144 &self.y
145 }
146 fn t_prev(&self) -> T {
147 self.t_prev
148 }
149 fn y_prev(&self) -> &Y {
150 &self.y_prev
151 }
152 fn h(&self) -> T {
153 self.h
154 }
155 fn set_h(&mut self, h: T) {
156 self.h = h;
157 }
158 fn status(&self) -> &Status<T, Y, D> {
159 &self.status
160 }
161 fn set_status(&mut self, status: Status<T, Y, D>) {
162 self.status = status;
163 }
164}
165
166impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
167 Interpolation<T, Y> for ExplicitRungeKutta<Stochastic, Fixed, T, Y, D, O, S, I>
168{
169 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
170 if t_interp < self.t_prev || t_interp > self.t {
172 return Err(Error::OutOfBounds {
173 t_interp,
174 t_prev: self.t_prev,
175 t_curr: self.t,
176 });
177 }
178
179 let y_interp = linear_interpolate(self.t_prev, self.t, &self.y_prev, &self.y, t_interp);
182
183 Ok(y_interp)
184 }
185}