differential_equations/methods/erk/adaptive/
ordinary.rs1use super::{ExplicitRungeKutta, Ordinary, Adaptive};
4use crate::{
5 Error, Status,
6 alias::Evals,
7 methods::h_init::InitialStepSize,
8 interpolate::{Interpolation, cubic_hermite_interpolate},
9 ode::{OrdinaryNumericalMethod, ODE},
10 traits::{CallBackData, Real, State},
11 utils::{constrain_step_size, validate_step_size_parameters},
12};
13
14impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> OrdinaryNumericalMethod<T, V, D> for ExplicitRungeKutta<Ordinary, Adaptive, T, V, D, O, S, I> {
15 fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &V) -> Result<Evals, Error<T, V>>
16 where
17 F: ODE<T, V, D>,
18 {
19 let mut evals = Evals::new();
20
21 if self.h0 == T::zero() {
23 self.h0 = InitialStepSize::<Ordinary>::compute(ode, t0, tf, y0, self.order, self.rtol, self.atol, self.h_min, self.h_max, &mut evals);
25 evals.fcn += 2;
26
27 }
28
29 match validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf) {
31 Ok(h0) => self.h = h0,
32 Err(status) => return Err(status),
33 }
34
35 self.stiffness_counter = 0;
37
38 self.t = t0;
40 self.y = *y0;
41 ode.diff(self.t, &self.y, &mut self.dydt);
42 evals.fcn += 1;
43
44 self.t_prev = self.t;
46 self.y_prev = self.y;
47 self.dydt_prev = self.dydt;
48
49 self.status = Status::Initialized;
51
52 Ok(evals)
53 }
54
55 fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, V>>
56 where
57 F: ODE<T, V, D>,
58 {
59 let mut evals = Evals::new();
60
61 if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
63 self.status = Status::Error(Error::StepSize {
64 t: self.t, y: self.y
65 });
66 return Err(Error::StepSize {
67 t: self.t, y: self.y
68 });
69 }
70
71 if self.steps >= self.max_steps {
73 self.status = Status::Error(Error::MaxSteps {
74 t: self.t, y: self.y
75 });
76 return Err(Error::MaxSteps {
77 t: self.t, y: self.y
78 });
79 }
80 self.steps += 1;
81
82 self.k[0] = self.dydt;
84
85 for i in 1..self.stages {
87 let mut y_stage = self.y;
88
89 for j in 0..i {
90 y_stage += self.k[j] * (self.a[i][j] * self.h);
91 }
92
93 ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
94 }
95 evals.fcn += self.stages - 1; let mut y_high = self.y;
100 for i in 0..self.stages {
101 y_high += self.k[i] * (self.b[i] * self.h);
102 }
103
104 let mut y_low = self.y;
106 if let Some(bh) = &self.bh {
107 for i in 0..self.stages {
108 y_low += self.k[i] * (bh[i] * self.h);
109 }
110 }
111
112 let err = y_high - y_low;
114
115 let mut err_norm: T = T::zero();
117
118 for n in 0..self.y.len() {
120 let tol = self.atol + self.rtol * self.y.get(n).abs().max(y_high.get(n).abs());
121 err_norm = err_norm.max((err.get(n) / tol).abs());
122 };
123
124 let order = T::from_usize(self.order).unwrap();
126 let error_exponent = T::one() / order;
127 let mut scale = self.safety_factor * err_norm.powf(-error_exponent);
128
129 scale = scale.max(self.min_scale).min(self.max_scale);
131
132 if err_norm <= T::one() {
134 self.t_prev = self.t;
136 self.y_prev = self.y;
137 self.dydt_prev = self.k[0];
138 self.h_prev = self.h;
139
140 if let Status::RejectedStep = self.status {
141 self.stiffness_counter = 0;
142 self.status = Status::Solving;
143
144 scale = scale.min(T::one());
146 }
147
148 if self.bi.is_some() {
150 for i in 0..(I - S) {
152 let mut y_stage = self.y;
153 for j in 0..self.stages + i {
154 y_stage += self.k[j] * (self.a[self.stages + i][j] * self.h);
155 }
156
157 ode.diff(self.t + self.c[self.stages + i] * self.h, &y_stage, &mut self.k[self.stages + i]);
158 }
159 evals.fcn += I - S;
160 }
161
162 self.t += self.h;
164 self.y = y_high;
165
166 if self.fsal {
168 self.dydt = self.k[S - 1];
170 } else {
171 ode.diff(self.t, &self.y, &mut self.dydt);
173 evals.fcn += 1;
174 }
175 } else {
176 self.status = Status::RejectedStep;
178 self.stiffness_counter += 1;
179
180 if self.stiffness_counter >= self.max_rejects {
182 self.status = Status::Error(Error::Stiffness {
183 t: self.t, y: self.y
184 });
185 return Err(Error::Stiffness {
186 t: self.t, y: self.y
187 });
188 }
189 }
190
191 self.h *= scale;
193
194 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
196
197 Ok(evals)
198 }
199
200 fn t(&self) -> T { self.t }
201 fn y(&self) -> &V { &self.y }
202 fn t_prev(&self) -> T { self.t_prev }
203 fn y_prev(&self) -> &V { &self.y_prev }
204 fn h(&self) -> T { self.h }
205 fn set_h(&mut self, h: T) { self.h = h; }
206 fn status(&self) -> &Status<T, V, D> { &self.status }
207 fn set_status(&mut self, status: Status<T, V, D>) { self.status = status; }
208}
209
210impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> Interpolation<T, V> for ExplicitRungeKutta<Ordinary, Adaptive, T, V, D, O, S, I> {
211 fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
212 if t_interp < self.t_prev || t_interp > self.t {
214 return Err(Error::OutOfBounds {
215 t_interp,
216 t_prev: self.t_prev,
217 t_curr: self.t
218 });
219 }
220
221 if self.bi.is_some() {
223 let s = (t_interp - self.t_prev) / self.h_prev;
225
226 let bi = self.bi.as_ref().unwrap();
228
229 let mut cont = [T::zero(); I];
230 for i in 0..self.dense_stages {
232 cont[i] = bi[i][self.order - 1];
234
235 for j in (0..self.order - 1).rev() {
237 cont[i] = cont[i] * s + bi[i][j];
238 }
239
240 cont[i] *= s;
242 }
243
244 let mut y_interp = self.y_prev;
246 for i in 0..I {
247 y_interp += self.k[i] * cont[i] * self.h_prev;
248 }
249
250 Ok(y_interp)
251 } else {
252 let y_interp = cubic_hermite_interpolate(
254 self.t_prev,
255 self.t,
256 &self.y_prev,
257 &self.y,
258 &self.dydt_prev,
259 &self.dydt,
260 t_interp
261 );
262
263 Ok(y_interp)
264 }
265 }
266}