differential_equations/methods/erk/dormandprince/
ordinary.rs1use crate::{
4 error::Error,
5 interpolate::Interpolation,
6 methods::{DormandPrince, ExplicitRungeKutta, Ordinary, h_init::InitialStepSize},
7 ode::{ODE, OrdinaryNumericalMethod},
8 stats::Evals,
9 status::Status,
10 traits::{CallBackData, Real, State},
11 utils::{constrain_step_size, validate_step_size_parameters},
12};
13
14impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
15 OrdinaryNumericalMethod<T, Y, D>
16 for ExplicitRungeKutta<Ordinary, DormandPrince, T, Y, D, O, S, I>
17{
18 fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
19 where
20 F: ODE<T, Y, D>,
21 {
22 let mut evals = Evals::new();
23
24 if self.h0 == T::zero() {
26 self.h0 = InitialStepSize::<Ordinary>::compute(
28 ode, t0, tf, y0, self.order, self.rtol, self.atol, self.h_min, self.h_max,
29 &mut evals,
30 );
31 }
32
33 match validate_step_size_parameters::<T, Y, D>(self.h0, self.h_min, self.h_max, t0, tf) {
35 Ok(h0) => self.h = h0,
36 Err(status) => return Err(status),
37 }
38
39 self.stiffness_counter = 0;
41
42 self.t = t0;
44 self.y = *y0;
45 ode.diff(self.t, &self.y, &mut self.k[0]);
46 self.dydt = self.k[0];
47 evals.function += 1;
48
49 self.t_prev = self.t;
51 self.y_prev = self.y;
52 self.dydt_prev = self.dydt;
53
54 self.status = Status::Initialized;
56
57 Ok(evals)
58 }
59
60 fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
61 where
62 F: ODE<T, Y, D>,
63 {
64 let mut evals = Evals::new();
65
66 if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
68 self.status = Status::Error(Error::StepSize {
69 t: self.t,
70 y: self.y,
71 });
72 return Err(Error::StepSize {
73 t: self.t,
74 y: self.y,
75 });
76 }
77
78 if self.steps >= self.max_steps {
80 self.status = Status::Error(Error::MaxSteps {
81 t: self.t,
82 y: self.y,
83 });
84 return Err(Error::MaxSteps {
85 t: self.t,
86 y: self.y,
87 });
88 }
89 self.steps += 1;
90
91 let mut y_stage = Y::zeros();
93 for i in 1..self.stages {
94 y_stage = Y::zeros();
95
96 for j in 0..i {
97 y_stage += self.k[j] * self.a[i][j];
98 }
99 y_stage = self.y + y_stage * self.h;
100
101 ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
102 }
103
104 let ysti = y_stage;
106
107 let mut yseg = Y::zeros();
109 for i in 0..self.stages {
110 yseg += self.k[i] * self.b[i];
111 }
112
113 let y_new = self.y + yseg * self.h;
115
116 let t_new = self.t + self.h;
118
119 evals.function += self.stages - 1; let er = self.er.unwrap();
124 let n = self.y.len();
125 let mut err = T::zero();
126 let mut err2 = T::zero();
127 let mut erri;
128 for i in 0..n {
129 let sk = self.atol + self.rtol * self.y.get(i).abs().max(y_new.get(i).abs());
131
132 erri = T::zero();
134 for j in 0..self.stages {
135 erri += er[j] * self.k[j].get(i);
136 }
137 err += (erri / sk).powi(2);
138
139 if let Some(bh) = &self.bh {
141 erri = yseg.get(i);
142 for j in 0..self.stages {
143 erri -= bh[j] * self.k[j].get(i);
144 }
145 err2 += (erri / sk).powi(2);
146 }
147 }
148 let mut deno = err + T::from_f64(0.01).unwrap() * err2;
149 if deno <= T::zero() {
150 deno = T::one();
151 }
152 err = self.h.abs() * err * (T::one() / (deno * T::from_usize(n).unwrap())).sqrt();
153
154 let order = T::from_usize(self.order).unwrap();
156 let error_exponent = T::one() / order;
157 let mut scale = self.safety_factor * err.powf(-error_exponent);
158
159 scale = scale.max(self.min_scale).min(self.max_scale);
161
162 if err <= T::one() {
164 ode.diff(t_new, &y_new, &mut self.dydt);
166 evals.function += 1;
167
168 let n_stiff_threshold = 100;
170 if self.steps % n_stiff_threshold == 0 {
171 let mut stdnum = T::zero();
172 let mut stden = T::zero();
173 let sqr = yseg - self.k[S - 1];
174 for i in 0..sqr.len() {
175 stdnum += sqr.get(i).powi(2);
176 }
177 let sqr = self.dydt - ysti;
178 for i in 0..sqr.len() {
179 stden += sqr.get(i).powi(2);
180 }
181
182 if stden > T::zero() {
183 let h_lamb = self.h * (stdnum / stden).sqrt();
184 if h_lamb > T::from_f64(6.1).unwrap() {
185 self.non_stiffness_counter = 0;
186 self.stiffness_counter += 1;
187 if self.stiffness_counter == 15 {
188 self.status = Status::Error(Error::Stiffness {
190 t: self.t,
191 y: self.y,
192 });
193 return Err(Error::Stiffness {
194 t: self.t,
195 y: self.y,
196 });
197 }
198 }
199 } else {
200 self.non_stiffness_counter += 1;
201 if self.non_stiffness_counter == 6 {
202 self.stiffness_counter = 0;
203 }
204 }
205 }
206
207 self.cont[0] = self.y;
209 let ydiff = y_new - self.y;
210 self.cont[1] = ydiff;
211 let bspl = self.k[0] * self.h - ydiff;
212 self.cont[2] = bspl;
213 self.cont[3] = ydiff - self.dydt * self.h - bspl;
214
215 if let Some(bi) = &self.bi {
217 if I > S {
219 self.k[self.stages] = self.dydt;
221
222 for i in S + 1..I {
223 let mut y_stage = Y::zeros();
224 for j in 0..i {
225 y_stage += self.k[j] * self.a[i][j];
226 }
227 y_stage = self.y + y_stage * self.h;
228
229 ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
230 evals.function += 1;
231 }
232 }
233
234 for i in 4..self.order {
236 self.cont[i] = Y::zeros();
237 for j in 0..self.dense_stages {
238 self.cont[i] += self.k[j] * bi[i][j];
239 }
240 self.cont[i] = self.cont[i] * self.h;
241 }
242 }
243
244 self.t_prev = self.t;
246 self.y_prev = self.y;
247 self.dydt_prev = self.k[0];
248 self.h_prev = self.h;
249
250 self.t = t_new;
252 self.y = y_new;
253 self.k[0] = self.dydt;
254
255 if let Status::RejectedStep = self.status {
257 self.status = Status::Solving;
258
259 scale = scale.min(T::one());
261 }
262 } else {
263 self.status = Status::RejectedStep;
265 }
266
267 self.h *= scale;
269
270 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
272
273 Ok(evals)
274 }
275
276 fn t(&self) -> T {
277 self.t
278 }
279 fn y(&self) -> &Y {
280 &self.y
281 }
282 fn t_prev(&self) -> T {
283 self.t_prev
284 }
285 fn y_prev(&self) -> &Y {
286 &self.y_prev
287 }
288 fn h(&self) -> T {
289 self.h
290 }
291 fn set_h(&mut self, h: T) {
292 self.h = h;
293 }
294 fn status(&self) -> &Status<T, Y, D> {
295 &self.status
296 }
297 fn set_status(&mut self, status: Status<T, Y, D>) {
298 self.status = status;
299 }
300}
301
302impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
303 Interpolation<T, Y> for ExplicitRungeKutta<Ordinary, DormandPrince, T, Y, D, O, S, I>
304{
305 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
306 if t_interp < self.t_prev || t_interp > self.t {
308 return Err(Error::OutOfBounds {
309 t_interp,
310 t_prev: self.t_prev,
311 t_curr: self.t,
312 });
313 }
314
315 let s = (t_interp - self.t_prev) / self.h_prev;
317 let s1 = T::one() - s;
318
319 let ilast = self.cont.len() - 1;
321 let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, i| {
322 let factor = if i >= 4 {
323 if (ilast - i) % 2 == 1 { s1 } else { s }
325 } else {
326 if i % 2 == 1 { s1 } else { s }
328 };
329 acc * factor + self.cont[i]
330 });
331
332 let y_interp = self.cont[0] + poly * s;
334
335 Ok(y_interp)
336 }
337}