differential_equations/methods/erk/dormandprince/
ordinary.rs1use crate::{
4 error::Error,
5 interpolate::Interpolation,
6 methods::{DormandPrince, ExplicitRungeKutta, Ordinary, h_init::InitialStepSize},
7 ode::{ODE, OrdinaryNumericalMethod},
8 stats::Evals,
9 status::Status,
10 traits::{Real, State},
11 utils::{constrain_step_size, validate_step_size_parameters},
12};
13
14impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
15 OrdinaryNumericalMethod<T, Y> for ExplicitRungeKutta<Ordinary, DormandPrince, T, Y, O, S, I>
16{
17 fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
18 where
19 F: ODE<T, Y>,
20 {
21 let mut evals = Evals::new();
22
23 if self.h0 == T::zero() {
25 self.h0 = InitialStepSize::<Ordinary>::compute(
27 ode, t0, tf, y0, self.order, &self.rtol, &self.atol, self.h_min, self.h_max,
28 &mut evals,
29 );
30 }
31
32 match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
34 Ok(h0) => self.h = h0,
35 Err(status) => return Err(status),
36 }
37
38 self.stiffness_counter = 0;
40
41 self.t = t0;
43 self.y = *y0;
44 ode.diff(self.t, &self.y, &mut self.k[0]);
45 self.dydt = self.k[0];
46 evals.function += 1;
47
48 self.t_prev = self.t;
50 self.y_prev = self.y;
51 self.dydt_prev = self.dydt;
52
53 self.status = Status::Initialized;
55
56 Ok(evals)
57 }
58
59 fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
60 where
61 F: ODE<T, Y>,
62 {
63 let mut evals = Evals::new();
64
65 if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
67 self.status = Status::Error(Error::StepSize {
68 t: self.t,
69 y: self.y,
70 });
71 return Err(Error::StepSize {
72 t: self.t,
73 y: self.y,
74 });
75 }
76
77 if self.steps >= self.max_steps {
79 self.status = Status::Error(Error::MaxSteps {
80 t: self.t,
81 y: self.y,
82 });
83 return Err(Error::MaxSteps {
84 t: self.t,
85 y: self.y,
86 });
87 }
88 self.steps += 1;
89
90 let mut y_stage = Y::zeros();
92 for i in 1..self.stages {
93 y_stage = Y::zeros();
94
95 for j in 0..i {
96 y_stage += self.k[j] * self.a[i][j];
97 }
98 y_stage = self.y + y_stage * self.h;
99
100 ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
101 }
102
103 let ysti = y_stage;
105
106 let mut yseg = Y::zeros();
108 for i in 0..self.stages {
109 yseg += self.k[i] * self.b[i];
110 }
111
112 let y_new = self.y + yseg * self.h;
114
115 let t_new = self.t + self.h;
117
118 evals.function += self.stages - 1; let er = self.er.unwrap();
123 let n = self.y.len();
124 let mut err = T::zero();
125 let mut err2 = T::zero();
126 let mut erri;
127 for i in 0..n {
128 let sk = self.atol[i] + self.rtol[i] * self.y.get(i).abs().max(y_new.get(i).abs());
130
131 erri = T::zero();
133 for j in 0..self.stages {
134 erri += er[j] * self.k[j].get(i);
135 }
136 err += (erri / sk).powi(2);
137
138 if let Some(bh) = &self.bh {
140 erri = yseg.get(i);
141 for j in 0..self.stages {
142 erri -= bh[j] * self.k[j].get(i);
143 }
144 err2 += (erri / sk).powi(2);
145 }
146 }
147 let mut deno = err + T::from_f64(0.01).unwrap() * err2;
148 if deno <= T::zero() {
149 deno = T::one();
150 }
151 err = self.h.abs() * err * (T::one() / (deno * T::from_usize(n).unwrap())).sqrt();
152
153 let order = T::from_usize(self.order).unwrap();
155 let error_exponent = T::one() / order;
156 let mut scale = self.safety_factor * err.powf(-error_exponent);
157
158 scale = scale.max(self.min_scale).min(self.max_scale);
160
161 if err <= T::one() {
163 ode.diff(t_new, &y_new, &mut self.dydt);
165 evals.function += 1;
166
167 let n_stiff_threshold = 100;
169 if self.steps % n_stiff_threshold == 0 {
170 let mut stdnum = T::zero();
171 let mut stden = T::zero();
172 let sqr = yseg - self.k[S - 1];
173 for i in 0..sqr.len() {
174 stdnum += sqr.get(i).powi(2);
175 }
176 let sqr = self.dydt - ysti;
177 for i in 0..sqr.len() {
178 stden += sqr.get(i).powi(2);
179 }
180
181 if stden > T::zero() {
182 let h_lamb = self.h * (stdnum / stden).sqrt();
183 if h_lamb > T::from_f64(6.1).unwrap() {
184 self.non_stiffness_counter = 0;
185 self.stiffness_counter += 1;
186 if self.stiffness_counter == 15 {
187 self.status = Status::Error(Error::Stiffness {
189 t: self.t,
190 y: self.y,
191 });
192 return Err(Error::Stiffness {
193 t: self.t,
194 y: self.y,
195 });
196 }
197 }
198 } else {
199 self.non_stiffness_counter += 1;
200 if self.non_stiffness_counter == 6 {
201 self.stiffness_counter = 0;
202 }
203 }
204 }
205
206 self.cont[0] = self.y;
208 let ydiff = y_new - self.y;
209 self.cont[1] = ydiff;
210 let bspl = self.k[0] * self.h - ydiff;
211 self.cont[2] = bspl;
212 self.cont[3] = ydiff - self.dydt * self.h - bspl;
213
214 if let Some(bi) = &self.bi {
216 if I > S {
218 self.k[self.stages] = self.dydt;
220
221 for i in S + 1..I {
222 let mut y_stage = Y::zeros();
223 for j in 0..i {
224 y_stage += self.k[j] * self.a[i][j];
225 }
226 y_stage = self.y + y_stage * self.h;
227
228 ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
229 evals.function += 1;
230 }
231 }
232
233 for i in 4..self.order {
235 self.cont[i] = Y::zeros();
236 for j in 0..self.dense_stages {
237 self.cont[i] += self.k[j] * bi[i][j];
238 }
239 self.cont[i] = self.cont[i] * self.h;
240 }
241 }
242
243 self.t_prev = self.t;
245 self.y_prev = self.y;
246 self.dydt_prev = self.k[0];
247 self.h_prev = self.h;
248
249 self.t = t_new;
251 self.y = y_new;
252 self.k[0] = self.dydt;
253
254 if let Status::RejectedStep = self.status {
256 self.status = Status::Solving;
257
258 scale = scale.min(T::one());
260 }
261 } else {
262 self.status = Status::RejectedStep;
264 }
265
266 self.h *= scale;
268
269 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
271
272 Ok(evals)
273 }
274
275 fn t(&self) -> T {
276 self.t
277 }
278 fn y(&self) -> &Y {
279 &self.y
280 }
281 fn t_prev(&self) -> T {
282 self.t_prev
283 }
284 fn y_prev(&self) -> &Y {
285 &self.y_prev
286 }
287 fn h(&self) -> T {
288 self.h
289 }
290 fn set_h(&mut self, h: T) {
291 self.h = h;
292 }
293 fn status(&self) -> &Status<T, Y> {
294 &self.status
295 }
296 fn set_status(&mut self, status: Status<T, Y>) {
297 self.status = status;
298 }
299}
300
301impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
302 for ExplicitRungeKutta<Ordinary, DormandPrince, T, Y, O, S, I>
303{
304 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
305 if t_interp < self.t_prev || t_interp > self.t {
307 return Err(Error::OutOfBounds {
308 t_interp,
309 t_prev: self.t_prev,
310 t_curr: self.t,
311 });
312 }
313
314 let s = (t_interp - self.t_prev) / self.h_prev;
316 let s1 = T::one() - s;
317
318 let ilast = self.cont.len() - 1;
320 let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, i| {
321 let factor = if i >= 4 {
322 if (ilast - i) % 2 == 1 { s1 } else { s }
324 } else {
325 if i % 2 == 1 { s1 } else { s }
327 };
328 acc * factor + self.cont[i]
329 });
330
331 let y_interp = self.cont[0] + poly * s;
333
334 Ok(y_interp)
335 }
336}