differential_equations/methods/erk/fixed/
stochastic.rs1use super::{ExplicitRungeKutta, Stochastic, Fixed};
4use crate::{
5 Error, Status,
6 alias::Evals,
7 interpolate::Interpolation,
8 linalg::component_multiply,
9 sde::{StochasticNumericalMethod, SDE},
10 traits::{CallBackData, Real, State},
11 utils::validate_step_size_parameters,
12};
13
14impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> StochasticNumericalMethod<T, V, D> for ExplicitRungeKutta<Stochastic, Fixed, T, V, D, O, S, I> {
15 fn init<F>(&mut self, sde: &mut F, t0: T, tf: T, y0: &V) -> Result<Evals, Error<T, V>>
16 where
17 F: SDE<T, V, D>,
18 {
19 let mut evals = Evals::new();
20
21 if self.h0 == T::zero() {
23 let duration = (tf - t0).abs();
25 let default_steps = T::from_usize(100).unwrap();
26 self.h0 = duration / default_steps;
27 }
28
29 match validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf) {
31 Ok(h0) => self.h = h0,
32 Err(status) => return Err(status),
33 }
34
35 self.steps = 0;
37
38 self.t = t0;
40 self.y = *y0;
41
42 sde.drift(self.t, &self.y, &mut self.dydt);
44 let mut diffusion = V::zeros();
45 sde.diffusion(self.t, &self.y, &mut diffusion);
46 evals.fcn += 2; self.t_prev = self.t;
50 self.y_prev = self.y;
51 self.dydt_prev = self.dydt;
52
53 self.status = Status::Initialized;
55
56 Ok(evals)
57 }
58
59 fn step<F>(&mut self, sde: &mut F) -> Result<Evals, Error<T, V>>
60 where
61 F: SDE<T, V, D>,
62 {
63 let mut evals = Evals::new();
64
65 if self.steps >= self.max_steps {
67 self.status = Status::Error(Error::MaxSteps {
68 t: self.t, y: self.y
69 });
70 return Err(Error::MaxSteps {
71 t: self.t, y: self.y
72 });
73 }
74 self.steps += 1;
75
76 self.t_prev = self.t;
78 self.y_prev = self.y;
79 self.dydt_prev = self.dydt;
80
81 self.k[0] = self.dydt;
83
84 for i in 1..self.stages {
86 let mut y_stage = self.y;
87
88 for j in 0..i {
89 y_stage += self.k[j] * (self.a[i][j] * self.h);
90 }
91
92 sde.drift(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
93 }
94 evals.fcn += self.stages - 1; let mut drift_increment = V::zeros();
98 for i in 0..self.stages {
99 drift_increment += self.k[i] * (self.b[i] * self.h);
100 }
101
102 let mut diffusion = V::zeros();
104 sde.diffusion(self.t, &self.y, &mut diffusion);
105 evals.fcn += 1;
106
107 let mut dw = V::zeros();
109 sde.noise(self.h, &mut dw);
110
111 let diffusion_increment = component_multiply(&diffusion, &dw);
113
114 let y_next = self.y + drift_increment + diffusion_increment;
116
117 self.t += self.h;
119 self.y = y_next;
120
121 if self.fsal {
123 self.dydt = self.k[S - 1];
125 } else {
126 sde.drift(self.t, &self.y, &mut self.dydt);
128 evals.fcn += 1;
129 }
130
131 self.status = Status::Solving;
132 Ok(evals)
133 }
134
135 fn t(&self) -> T { self.t }
136 fn y(&self) -> &V { &self.y }
137 fn t_prev(&self) -> T { self.t_prev }
138 fn y_prev(&self) -> &V { &self.y_prev }
139 fn h(&self) -> T { self.h }
140 fn set_h(&mut self, h: T) { self.h = h; }
141 fn status(&self) -> &Status<T, V, D> { &self.status }
142 fn set_status(&mut self, status: Status<T, V, D>) { self.status = status; }
143}
144
145impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> Interpolation<T, V> for ExplicitRungeKutta<Stochastic, Fixed, T, V, D, O, S, I> {
146 fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
147 if t_interp < self.t_prev || t_interp > self.t {
149 return Err(Error::OutOfBounds {
150 t_interp,
151 t_prev: self.t_prev,
152 t_curr: self.t
153 });
154 }
155
156 let s = (t_interp - self.t_prev) / (self.t - self.t_prev);
160 let y_interp = self.y_prev + (self.y - self.y_prev) * s;
161
162 Ok(y_interp)
163 }
164}