differential_equations/methods/erk/adaptive/
ordinary.rs1use super::{ExplicitRungeKutta, Ordinary, Adaptive};
4use crate::{
5 Error, Status,
6 alias::Evals,
7 methods::h_init::InitialStepSize,
8 interpolate::{Interpolation, cubic_hermite_interpolate},
9 ode::{OrdinaryNumericalMethod, ODE},
10 traits::{CallBackData, Real, State},
11 utils::{constrain_step_size, validate_step_size_parameters},
12};
13
14impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> OrdinaryNumericalMethod<T, V, D> for ExplicitRungeKutta<Ordinary, Adaptive, T, V, D, O, S, I> {
15 fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &V) -> Result<Evals, Error<T, V>>
16 where
17 F: ODE<T, V, D>,
18 {
19 let mut evals = Evals::new();
20
21 if self.h0 == T::zero() {
23 self.h0 = InitialStepSize::<Ordinary>::compute(ode, t0, tf, y0, self.order, self.rtol, self.atol, self.h_min, self.h_max, &mut evals);
25 evals.fcn += 2;
26
27 }
28
29 match validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf) {
31 Ok(h0) => self.h = h0,
32 Err(status) => return Err(status),
33 }
34
35 self.stiffness_counter = 0;
37
38 self.t = t0;
40 self.y = *y0;
41 ode.diff(self.t, &self.y, &mut self.dydt);
42 evals.fcn += 1;
43
44 self.t_prev = self.t;
46 self.y_prev = self.y;
47 self.dydt_prev = self.dydt;
48
49 self.status = Status::Initialized;
51
52 Ok(evals)
53 }
54
55 fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, V>>
56 where
57 F: ODE<T, V, D>,
58 {
59 let mut evals = Evals::new();
60
61 if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
63 self.status = Status::Error(Error::StepSize {
64 t: self.t, y: self.y
65 });
66 return Err(Error::StepSize {
67 t: self.t, y: self.y
68 });
69 }
70
71 if self.steps >= self.max_steps {
73 self.status = Status::Error(Error::MaxSteps {
74 t: self.t, y: self.y
75 });
76 return Err(Error::MaxSteps {
77 t: self.t, y: self.y
78 });
79 }
80 self.steps += 1;
81
82 self.k[0] = self.dydt;
84
85 for i in 1..self.stages {
87 let mut y_stage = self.y;
88
89 for j in 0..i {
90 y_stage += self.k[j] * (self.a[i][j] * self.h);
91 }
92
93 ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
94 }
95 evals.fcn += self.stages - 1; let mut y_high = self.y;
100 for i in 0..self.stages {
101 y_high += self.k[i] * (self.b[i] * self.h);
102 }
103
104 let mut y_low = self.y;
106 if let Some(bh) = &self.bh {
107 for i in 0..self.stages {
108 y_low += self.k[i] * (bh[i] * self.h);
109 }
110 }
111
112 let err = y_high - y_low;
114
115 let mut err_norm: T = T::zero();
117
118 for n in 0..self.y.len() {
120 let tol = self.atol + self.rtol * self.y.get(n).abs().max(y_high.get(n).abs());
121 err_norm = err_norm.max((err.get(n) / tol).abs());
122 };
123
124 if err_norm <= T::one() {
126 self.t_prev = self.t;
128 self.y_prev = self.y;
129 self.dydt_prev = self.k[0];
130 self.h_prev = self.h;
131
132 if let Status::RejectedStep = self.status {
133 self.stiffness_counter = 0;
134 self.status = Status::Solving;
135 }
136
137 if self.bi.is_some() {
139 for i in 0..(I - S) {
141 let mut y_stage = self.y;
142 for j in 0..self.stages + i {
143 y_stage += self.k[j] * (self.a[self.stages + i][j] * self.h);
144 }
145
146 ode.diff(self.t + self.c[self.stages + i] * self.h, &y_stage, &mut self.k[self.stages + i]);
147 }
148 evals.fcn += I - S;
149 }
150
151 self.t += self.h;
153 self.y = y_high;
154
155 if self.fsal {
157 self.dydt = self.k[S - 1];
159 } else {
160 ode.diff(self.t, &self.y, &mut self.dydt);
162 evals.fcn += 1;
163 }
164 } else {
165 self.status = Status::RejectedStep;
167 self.stiffness_counter += 1;
168
169 if self.stiffness_counter >= self.max_rejects {
171 self.status = Status::Error(Error::Stiffness {
172 t: self.t, y: self.y
173 });
174 return Err(Error::Stiffness {
175 t: self.t, y: self.y
176 });
177 }
178 }
179
180 let order = T::from_usize(self.order).unwrap();
182 let err_order = T::one() / order;
183
184 let scale = self.safety_factor * err_norm.powf(-err_order);
186 let scale = scale.max(self.min_scale).min(self.max_scale);
187 self.h *= scale;
188
189 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
191
192 Ok(evals)
193 }
194
195 fn t(&self) -> T { self.t }
196 fn y(&self) -> &V { &self.y }
197 fn t_prev(&self) -> T { self.t_prev }
198 fn y_prev(&self) -> &V { &self.y_prev }
199 fn h(&self) -> T { self.h }
200 fn set_h(&mut self, h: T) { self.h = h; }
201 fn status(&self) -> &Status<T, V, D> { &self.status }
202 fn set_status(&mut self, status: Status<T, V, D>) { self.status = status; }
203}
204
205impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> Interpolation<T, V> for ExplicitRungeKutta<Ordinary, Adaptive, T, V, D, O, S, I> {
206 fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
207 if t_interp < self.t_prev || t_interp > self.t {
209 return Err(Error::OutOfBounds {
210 t_interp,
211 t_prev: self.t_prev,
212 t_curr: self.t
213 });
214 }
215
216 if self.bi.is_some() {
218 let s = (t_interp - self.t_prev) / self.h_prev;
220
221 let bi = self.bi.as_ref().unwrap();
223
224 let mut cont = [T::zero(); I];
225 for i in 0..self.dense_stages {
227 cont[i] = bi[i][self.order - 1];
229
230 for j in (0..self.order - 1).rev() {
232 cont[i] = cont[i] * s + bi[i][j];
233 }
234
235 cont[i] *= s;
237 }
238
239 let mut y_interp = self.y_prev;
241 for i in 0..I {
242 y_interp += self.k[i] * cont[i] * self.h_prev;
243 }
244
245 Ok(y_interp)
246 } else {
247 let y_interp = cubic_hermite_interpolate(
249 self.t_prev,
250 self.t,
251 &self.y_prev,
252 &self.y,
253 &self.dydt_prev,
254 &self.dydt,
255 t_interp
256 );
257
258 Ok(y_interp)
259 }
260 }
261}