differential_equations/methods/erk/adaptive/
ordinary.rs1use crate::{
4 error::Error,
5 interpolate::{Interpolation, cubic_hermite_interpolate},
6 methods::{Adaptive, ExplicitRungeKutta, Ordinary, h_init::InitialStepSize},
7 ode::{ODE, OrdinaryNumericalMethod},
8 stats::Evals,
9 status::Status,
10 traits::{CallBackData, Real, State},
11 utils::{constrain_step_size, validate_step_size_parameters},
12};
13
14impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
15 OrdinaryNumericalMethod<T, Y, D> for ExplicitRungeKutta<Ordinary, Adaptive, T, Y, D, O, S, I>
16{
17 fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
18 where
19 F: ODE<T, Y, D>,
20 {
21 let mut evals = Evals::new();
22
23 if self.h0 == T::zero() {
25 self.h0 = InitialStepSize::<Ordinary>::compute(
27 ode, t0, tf, y0, self.order, self.rtol, self.atol, self.h_min, self.h_max,
28 &mut evals,
29 );
30 evals.function += 2;
31 }
32
33 match validate_step_size_parameters::<T, Y, D>(self.h0, self.h_min, self.h_max, t0, tf) {
35 Ok(h0) => self.h = h0,
36 Err(status) => return Err(status),
37 }
38
39 self.stiffness_counter = 0;
41
42 self.t = t0;
44 self.y = *y0;
45 ode.diff(self.t, &self.y, &mut self.dydt);
46 evals.function += 1;
47
48 self.t_prev = self.t;
50 self.y_prev = self.y;
51 self.dydt_prev = self.dydt;
52
53 self.status = Status::Initialized;
55
56 Ok(evals)
57 }
58
59 fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
60 where
61 F: ODE<T, Y, D>,
62 {
63 let mut evals = Evals::new();
64
65 if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
67 self.status = Status::Error(Error::StepSize {
68 t: self.t,
69 y: self.y,
70 });
71 return Err(Error::StepSize {
72 t: self.t,
73 y: self.y,
74 });
75 }
76
77 if self.steps >= self.max_steps {
79 self.status = Status::Error(Error::MaxSteps {
80 t: self.t,
81 y: self.y,
82 });
83 return Err(Error::MaxSteps {
84 t: self.t,
85 y: self.y,
86 });
87 }
88 self.steps += 1;
89
90 self.k[0] = self.dydt;
92
93 for i in 1..self.stages {
95 let mut y_stage = self.y;
96
97 for j in 0..i {
98 y_stage += self.k[j] * (self.a[i][j] * self.h);
99 }
100
101 ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
102 }
103 evals.function += self.stages - 1; let mut y_high = self.y;
108 for i in 0..self.stages {
109 y_high += self.k[i] * (self.b[i] * self.h);
110 }
111
112 let mut y_low = self.y;
114 let bh = &self.bh.unwrap();
115 for i in 0..self.stages {
116 y_low += self.k[i] * (bh[i] * self.h);
117 }
118
119 let err = y_high - y_low;
121
122 let mut err_norm: T = T::zero();
124
125 for n in 0..self.y.len() {
127 let tol = self.atol + self.rtol * self.y.get(n).abs().max(y_high.get(n).abs());
128 err_norm = err_norm.max((err.get(n) / tol).abs());
129 }
130
131 let order = T::from_usize(self.order).unwrap();
133 let error_exponent = T::one() / order;
134 let mut scale = self.safety_factor * err_norm.powf(-error_exponent);
135
136 scale = scale.max(self.min_scale).min(self.max_scale);
138
139 if err_norm <= T::one() {
141 self.t_prev = self.t;
143 self.y_prev = self.y;
144 self.dydt_prev = self.k[0];
145 self.h_prev = self.h;
146
147 if let Status::RejectedStep = self.status {
148 self.stiffness_counter = 0;
149 self.status = Status::Solving;
150
151 scale = scale.min(T::one());
153 }
154
155 if self.bi.is_some() {
157 for i in 0..(I - S) {
159 let mut y_stage = self.y;
160 for j in 0..self.stages + i {
161 y_stage += self.k[j] * (self.a[self.stages + i][j] * self.h);
162 }
163
164 ode.diff(
165 self.t + self.c[self.stages + i] * self.h,
166 &y_stage,
167 &mut self.k[self.stages + i],
168 );
169 }
170 evals.function += I - S;
171 }
172
173 self.t += self.h;
175 self.y = y_high;
176
177 if self.fsal {
179 self.dydt = self.k[S - 1];
181 } else {
182 ode.diff(self.t, &self.y, &mut self.dydt);
184 evals.function += 1;
185 }
186 } else {
187 self.status = Status::RejectedStep;
189 self.stiffness_counter += 1;
190
191 if self.stiffness_counter >= self.max_rejects {
193 self.status = Status::Error(Error::Stiffness {
194 t: self.t,
195 y: self.y,
196 });
197 return Err(Error::Stiffness {
198 t: self.t,
199 y: self.y,
200 });
201 }
202 }
203
204 self.h *= scale;
206
207 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
209
210 Ok(evals)
211 }
212
213 fn t(&self) -> T {
214 self.t
215 }
216 fn y(&self) -> &Y {
217 &self.y
218 }
219 fn t_prev(&self) -> T {
220 self.t_prev
221 }
222 fn y_prev(&self) -> &Y {
223 &self.y_prev
224 }
225 fn h(&self) -> T {
226 self.h
227 }
228 fn set_h(&mut self, h: T) {
229 self.h = h;
230 }
231 fn status(&self) -> &Status<T, Y, D> {
232 &self.status
233 }
234 fn set_status(&mut self, status: Status<T, Y, D>) {
235 self.status = status;
236 }
237}
238
239impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
240 Interpolation<T, Y> for ExplicitRungeKutta<Ordinary, Adaptive, T, Y, D, O, S, I>
241{
242 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
243 if t_interp < self.t_prev || t_interp > self.t {
245 return Err(Error::OutOfBounds {
246 t_interp,
247 t_prev: self.t_prev,
248 t_curr: self.t,
249 });
250 }
251
252 if self.bi.is_some() {
254 let s = (t_interp - self.t_prev) / self.h_prev;
256
257 let bi = self.bi.as_ref().unwrap();
259
260 let mut cont = [T::zero(); I];
261 for i in 0..self.dense_stages {
263 cont[i] = bi[i][self.order - 1];
265
266 for j in (0..self.order - 1).rev() {
268 cont[i] = cont[i] * s + bi[i][j];
269 }
270
271 cont[i] *= s;
273 }
274
275 let mut y_interp = self.y_prev;
277 for i in 0..I {
278 y_interp += self.k[i] * cont[i] * self.h_prev;
279 }
280
281 Ok(y_interp)
282 } else {
283 let y_interp = cubic_hermite_interpolate(
285 self.t_prev,
286 self.t,
287 &self.y_prev,
288 &self.y,
289 &self.dydt_prev,
290 &self.dydt,
291 t_interp,
292 );
293
294 Ok(y_interp)
295 }
296 }
297}