differential_equations/methods/erk/dormandprince/
ordinary.rs1use crate::{
3 error::Error,
4 interpolate::Interpolation,
5 methods::{DormandPrince, ExplicitRungeKutta, Ordinary, h_init::InitialStepSize},
6 ode::{ODE, OrdinaryNumericalMethod},
7 stats::Evals,
8 status::Status,
9 traits::{Real, State},
10 utils::{constrain_step_size, validate_step_size_parameters},
11};
12
13impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
14 OrdinaryNumericalMethod<T, Y> for ExplicitRungeKutta<Ordinary, DormandPrince, T, Y, O, S, I>
15{
16 fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
17 where
18 F: ODE<T, Y> + ?Sized,
19 {
20 let mut evals = Evals::new();
21
22 if self.h0 == T::zero() {
24 self.h0 = InitialStepSize::<Ordinary>::compute(
26 ode, t0, tf, y0, self.order, &self.rtol, &self.atol, self.h_min, self.h_max,
27 &mut evals,
28 );
29 }
30
31 match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
33 Ok(h0) => self.h = (self.filter)(h0),
34 Err(status) => return Err(status),
35 }
36
37 self.stiffness_counter = 0;
39
40 self.t = t0;
42 self.y = y0.clone();
43 self.dydt = y0.zeros_like();
44 self.y_prev = y0.clone();
45 self.dydt_prev = y0.zeros_like();
46 self.k = core::array::from_fn(|_| y0.zeros_like());
47 self.cont = core::array::from_fn(|_| y0.zeros_like());
48 ode.diff(self.t, &self.y, &mut self.k[0]);
49 self.dydt = self.k[0].clone();
50 evals.function += 1;
51
52 self.t_prev = self.t;
54 self.y_prev = self.y.clone();
55 self.dydt_prev = self.dydt.clone();
56
57 self.status = Status::Initialized;
59
60 Ok(evals)
61 }
62
63 fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
64 where
65 F: ODE<T, Y> + ?Sized,
66 {
67 let mut evals = Evals::new();
68
69 if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
71 self.status = Status::Error(Error::StepSize {
72 t: self.t,
73 y: self.y.clone(),
74 });
75 return Err(Error::StepSize {
76 t: self.t,
77 y: self.y.clone(),
78 });
79 }
80
81 if self.steps >= self.max_steps {
83 self.status = Status::Error(Error::MaxSteps {
84 t: self.t,
85 y: self.y.clone(),
86 });
87 return Err(Error::MaxSteps {
88 t: self.t,
89 y: self.y.clone(),
90 });
91 }
92 self.steps += 1;
93
94 let mut y_stage = self.y.zeros_like();
96 for i in 1..self.stages {
97 y_stage = self.y.clone();
98
99 for j in 0..i {
100 y_stage.add_scaled(self.a[i][j] * self.h, &self.k[j]);
101 }
102
103 ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
104 }
105
106 let ysti = y_stage.clone();
108
109 let mut yseg = self.y.zeros_like();
111 for i in 0..self.stages {
112 yseg.add_scaled(self.b[i], &self.k[i]);
113 }
114
115 let y_new = self.y.plus_scaled(self.h, &yseg);
117
118 let t_new = self.t + self.h;
120
121 evals.function += self.stages - 1; let er = self.er.unwrap();
126 let n = self.y.len();
127 let mut err2 = T::zero();
128 let mut err_state = self.y.zeros_like();
129 for (j, coefficient) in er.iter().enumerate().take(self.stages) {
130 err_state.add_scaled(*coefficient, &self.k[j]);
131 }
132 let mut err = self
133 .y
134 .error_norm(&y_new, &err_state, &self.atol, &self.rtol);
135
136 if let Some(bh) = &self.bh {
137 let mut err2_state = yseg.clone();
138 for (j, coefficient) in bh.iter().enumerate().take(self.stages) {
139 err2_state.add_scaled(-*coefficient, &self.k[j]);
140 }
141 err2 = self
142 .y
143 .error_norm(&y_new, &err2_state, &self.atol, &self.rtol);
144 }
145 let mut deno = err + T::from_f64(0.01).unwrap() * err2;
146 if deno <= T::zero() {
147 deno = T::one();
148 }
149 err = self.h.abs() * err * (T::one() / (deno * T::from_usize(n).unwrap())).sqrt();
150
151 let order = T::from_usize(self.order).unwrap();
153 let error_exponent = T::one() / order;
154 let mut scale = self.safety_factor * err.powf(-error_exponent);
155
156 scale = scale.max(self.min_scale).min(self.max_scale);
158
159 if err <= T::one() {
161 ode.diff(t_new, &y_new, &mut self.dydt);
163 evals.function += 1;
164
165 let n_stiff_threshold = 100;
167 if self.steps.is_multiple_of(n_stiff_threshold) {
168 let stdnum = yseg.diff_norm_squared(&self.k[S - 1]);
169 let stden = self.dydt.diff_norm_squared(&ysti);
170
171 if stden > T::zero() {
172 let h_lamb = self.h * (stdnum / stden).sqrt();
173 if h_lamb > T::from_f64(6.1).unwrap() {
174 self.non_stiffness_counter = 0;
175 self.stiffness_counter += 1;
176 if self.stiffness_counter == 15 {
177 self.status = Status::Error(Error::Stiffness {
179 t: self.t,
180 y: self.y.clone(),
181 });
182 return Err(Error::Stiffness {
183 t: self.t,
184 y: self.y.clone(),
185 });
186 }
187 }
188 } else {
189 self.non_stiffness_counter += 1;
190 if self.non_stiffness_counter == 6 {
191 self.stiffness_counter = 0;
192 }
193 }
194 }
195
196 self.cont[0] = self.y.clone();
198 let ydiff = y_new.minus(&self.y);
199 self.cont[1] = ydiff.clone();
200 let mut bspl = ydiff.zeros_like();
201 bspl.add_scaled(self.h, &self.k[0]);
202 bspl.add_scaled(-T::one(), &ydiff);
203 self.cont[2] = bspl.clone();
204 let mut cont3 = ydiff;
205 cont3.add_scaled(-self.h, &self.dydt);
206 cont3.add_scaled(-T::one(), &bspl);
207 self.cont[3] = cont3;
208
209 if let Some(bi) = &self.bi {
211 if I > S {
213 self.k[self.stages] = self.dydt.clone();
215
216 for i in S + 1..I {
217 let mut y_stage = self.y.clone();
218 for j in 0..i {
219 y_stage.add_scaled(self.a[i][j] * self.h, &self.k[j]);
220 }
221
222 ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
223 evals.function += 1;
224 }
225 }
226
227 for i in 4..self.order {
229 self.cont[i].fill(T::zero());
230 for j in 0..self.dense_stages {
231 self.cont[i].add_scaled(bi[i][j], &self.k[j]);
232 }
233 self.cont[i].scale_by(self.h);
234 }
235 }
236
237 self.t_prev = self.t;
239 self.y_prev = self.y.clone();
240 self.dydt_prev = self.k[0].clone();
241 self.h_prev = self.h;
242
243 self.t = t_new;
245 self.y = y_new;
246 self.k[0] = self.dydt.clone();
247
248 if let Status::RejectedStep = self.status {
250 self.status = Status::Solving;
251
252 scale = scale.min(T::one());
254 }
255 } else {
256 self.status = Status::RejectedStep;
258 }
259
260 self.h *= scale;
262
263 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
265
266 self.h = (self.filter)(self.h);
268
269 Ok(evals)
270 }
271
272 fn t(&self) -> T {
273 self.t
274 }
275 fn y(&self) -> &Y {
276 &self.y
277 }
278 fn t_prev(&self) -> T {
279 self.t_prev
280 }
281 fn y_prev(&self) -> &Y {
282 &self.y_prev
283 }
284 fn h(&self) -> T {
285 self.h
286 }
287 fn set_h(&mut self, h: T) {
288 self.h = (self.filter)(h);
289 }
290 fn status(&self) -> &Status<T, Y> {
291 &self.status
292 }
293 fn set_status(&mut self, status: Status<T, Y>) {
294 self.status = status;
295 }
296}
297
298impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
299 for ExplicitRungeKutta<Ordinary, DormandPrince, T, Y, O, S, I>
300{
301 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
302 if t_interp < self.t_prev || t_interp > self.t {
304 return Err(Error::OutOfBounds {
305 t_interp,
306 t_prev: self.t_prev,
307 t_curr: self.t,
308 });
309 }
310
311 let s = (t_interp - self.t_prev) / self.h_prev;
313 let s1 = T::one() - s;
314
315 let ilast = self.cont.len() - 1;
317 let poly = (1..ilast)
318 .rev()
319 .fold(self.cont[ilast].clone(), |mut acc, i| {
320 let factor = if i >= 4 {
321 if (ilast - i) % 2 == 1 { s1 } else { s }
323 } else {
324 if i % 2 == 1 { s1 } else { s }
326 };
327 acc.scale_by(factor);
328 acc.add_scaled(T::one(), &self.cont[i]);
329 acc
330 });
331
332 let y_interp = self.cont[0].plus_scaled(s, &poly);
334
335 Ok(y_interp)
336 }
337}