1use std::collections::VecDeque;
4
5use crate::{
6 dde::{DDE, DelayNumericalMethod},
7 error::Error,
8 interpolate::{Interpolation, cubic_hermite_interpolate},
9 methods::{Delay, ExplicitRungeKutta, Fixed},
10 stats::Evals,
11 status::Status,
12 traits::{Real, State},
13 utils::validate_step_size_parameters,
14};
15
16impl<
17 const L: usize,
18 T: Real,
19 Y: State<T>,
20 H: Fn(T) -> Y,
21 const O: usize,
22 const S: usize,
23 const I: usize,
24> DelayNumericalMethod<L, T, Y, H> for ExplicitRungeKutta<Delay, Fixed, T, Y, O, S, I>
25{
26 fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &Y, phi: &H) -> Result<Evals, Error<T, Y>>
27 where
28 F: DDE<L, T, Y>,
29 {
30 let mut evals = Evals::new();
32
33 if L <= 0 {
35 return Err(Error::NoLags);
36 }
37 self.t0 = t0;
38 self.t = t0;
39 self.y = *y0;
40 self.t_prev = self.t;
41 self.y_prev = self.y;
42 self.status = Status::Initialized;
43 self.steps = 0;
44 self.history = VecDeque::new();
45
46 let mut delays = [T::zero(); L];
48 let mut y_delayed = [Y::zeros(); L];
49
50 dde.lags(self.t, &self.y, &mut delays);
52 for i in 0..L {
53 let t_delayed = self.t - delays[i];
54 if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
56 return Err(Error::BadInput {
57 msg: format!(
58 "Initial delayed time {} is out of history range (t <= {}).",
59 t_delayed, t0
60 ),
61 });
62 }
63 y_delayed[i] = phi(t_delayed);
64 }
65
66 dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
68 evals.function += 1;
69 self.dydt_prev = self.dydt; self.history.push_back((self.t, self.y, self.dydt));
71
72 if self.h0 == T::zero() {
74 let duration = (tf - t0).abs();
75 let default_steps = T::from_usize(100).unwrap();
76 self.h0 = duration / default_steps;
77 }
78
79 match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
81 Ok(h0) => self.h = h0,
82 Err(status) => return Err(status),
83 }
84 Ok(evals)
85 }
86
87 fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, Y>>
88 where
89 F: DDE<L, T, Y>,
90 {
91 let mut evals = Evals::new();
92
93 if self.steps >= self.max_steps {
95 self.status = Status::Error(Error::MaxSteps {
96 t: self.t,
97 y: self.y,
98 });
99 return Err(Error::MaxSteps {
100 t: self.t,
101 y: self.y,
102 });
103 }
104 self.steps += 1;
105
106 let mut delays = [T::zero(); L];
108 let mut y_delayed = [Y::zeros(); L];
109
110 self.k[0] = self.dydt;
113 let mut min_delay_abs = T::infinity();
114 let y_pred_for_lags = self.y + self.k[0] * self.h;
116 dde.lags(self.t + self.h, &y_pred_for_lags, &mut delays);
117 for i in 0..L {
118 min_delay_abs = min_delay_abs.min(delays[i].abs());
119 }
120
121 let max_iter: usize = if min_delay_abs < self.h.abs() && min_delay_abs > T::zero() {
123 5
124 } else {
125 1
126 };
127
128 let mut y_next_candidate_iter = self.y; let mut dydt_next_candidate_iter = Y::zeros(); let mut y_prev_candidate_iter = self.y; let mut dde_iteration_failed = false;
132
133 for iter_idx in 0..max_iter {
135 if iter_idx > 0 {
136 y_prev_candidate_iter = y_next_candidate_iter;
137 }
138
139 for i in 1..self.stages {
141 let mut y_stage = self.y;
142 for j in 0..i {
143 y_stage += self.k[j] * (self.a[i][j] * self.h);
144 }
145 dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut delays);
147 if let Err(e) =
148 self.lagvals(self.t + self.c[i] * self.h, &delays, &mut y_delayed, phi)
149 {
150 self.status = Status::Error(e.clone());
151 return Err(e);
152 }
153 dde.diff(
154 self.t + self.c[i] * self.h,
155 &y_stage,
156 &y_delayed,
157 &mut self.k[i],
158 );
159 }
160 evals.function += self.stages - 1;
161
162 let mut y_next = self.y;
164 for i in 0..self.stages {
165 y_next += self.k[i] * (self.b[i] * self.h);
166 }
167
168 if max_iter > 1 && iter_idx > 0 {
170 let mut dde_iteration_error = T::zero();
171 let n_dim = self.y.len();
172 for i_dim in 0..n_dim {
173 let scale = T::from_f64(1e-10).unwrap()
174 + y_prev_candidate_iter
175 .get(i_dim)
176 .abs()
177 .max(y_next.get(i_dim).abs());
178 if scale > T::zero() {
179 let diff_val = y_next.get(i_dim) - y_prev_candidate_iter.get(i_dim);
180 dde_iteration_error += (diff_val / scale).powi(2);
181 }
182 }
183 if n_dim > 0 {
184 dde_iteration_error =
185 (dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
186 }
187
188 if dde_iteration_error <= T::from_f64(1e-6).unwrap() {
189 break;
190 }
191 if iter_idx == max_iter - 1 {
192 dde_iteration_failed = dde_iteration_error > T::from_f64(1e-6).unwrap();
193 }
194 }
195 y_next_candidate_iter = y_next;
196
197 dde.lags(self.t + self.h, &y_next_candidate_iter, &mut delays);
199 if let Err(e) = self.lagvals(self.t + self.h, &delays, &mut y_delayed, phi) {
200 self.status = Status::Error(e.clone());
201 return Err(e);
202 }
203 dde.diff(
204 self.t + self.h,
205 &y_next_candidate_iter,
206 &y_delayed,
207 &mut dydt_next_candidate_iter,
208 );
209 evals.function += 1;
210 }
211
212 if dde_iteration_failed {
214 let sign = self.h.signum();
215 self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
216 if L > 0
217 && min_delay_abs > T::zero()
218 && self.h.abs() < T::from_f64(2.0).unwrap() * min_delay_abs
219 {
220 self.h = min_delay_abs * sign;
221 }
222 self.status = Status::RejectedStep;
223 return Ok(evals);
224 }
225
226 self.t_prev = self.t;
228 self.y_prev = self.y;
229 self.dydt_prev = self.dydt;
230
231 self.t += self.h;
233 self.y = y_next_candidate_iter;
234
235 if self.fsal {
237 self.dydt = self.k[S - 1];
238 } else {
239 dde.lags(self.t, &self.y, &mut delays);
240 if let Err(e) = self.lagvals(self.t, &delays, &mut y_delayed, phi) {
241 self.status = Status::Error(e.clone());
242 return Err(e);
243 }
244 dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
245 evals.function += 1;
246 }
247
248 if self.bi.is_some() {
250 for i in 0..(I - S) {
251 let mut y_stage_dense = self.y_prev;
252 for j in 0..self.stages + i {
253 y_stage_dense += self.k[j] * (self.a[self.stages + i][j] * self.h);
254 }
255 let t_stage = self.t_prev + self.c[self.stages + i] * self.h;
256 dde.lags(t_stage, &y_stage_dense, &mut delays);
257 if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
258 self.status = Status::Error(e.clone());
259 return Err(e);
260 }
261 dde.diff(
262 self.t_prev + self.c[self.stages + i] * self.h,
263 &y_stage_dense,
264 &y_delayed,
265 &mut self.k[self.stages + i],
266 );
267 }
268 evals.function += I - S;
269 }
270
271 self.history.push_back((self.t, self.y, self.dydt));
273 if let Some(max_delay) = self.max_delay {
274 let cutoff_time = self.t - max_delay;
275 while let Some((t_front, _, _)) = self.history.get(1) {
276 if *t_front < cutoff_time {
277 self.history.pop_front();
278 } else {
279 break;
280 }
281 }
282 }
283
284 self.status = Status::Solving;
285 Ok(evals)
286 }
287
288 fn t(&self) -> T {
289 self.t
290 }
291 fn y(&self) -> &Y {
292 &self.y
293 }
294 fn t_prev(&self) -> T {
295 self.t_prev
296 }
297 fn y_prev(&self) -> &Y {
298 &self.y_prev
299 }
300 fn h(&self) -> T {
301 self.h
302 }
303 fn set_h(&mut self, h: T) {
304 self.h = h;
305 }
306 fn status(&self) -> &Status<T, Y> {
307 &self.status
308 }
309 fn set_status(&mut self, status: Status<T, Y>) {
310 self.status = status;
311 }
312}
313
314impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
315 ExplicitRungeKutta<Delay, Fixed, T, Y, O, S, I>
316{
317 pub fn lagvals<const L: usize, H>(
318 &mut self,
319 t_stage: T,
320 delays: &[T; L],
321 y_delayed: &mut [Y; L],
322 phi: &H,
323 ) -> Result<(), Error<T, Y>>
324 where
325 H: Fn(T) -> Y,
326 {
327 for i in 0..L {
328 let t_delayed = t_stage - delays[i];
329
330 if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
332 y_delayed[i] = phi(t_delayed);
333 } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
335 if self.bi.is_some() {
336 let s = (t_delayed - self.t_prev) / self.h_prev;
337
338 let bi_coeffs = self.bi.as_ref().unwrap();
339
340 let mut cont = [T::zero(); I];
341 for i in 0..I {
342 if i < cont.len() && i < bi_coeffs.len() {
343 cont[i] = bi_coeffs[i][self.dense_stages - 1];
344 for j in (0..self.dense_stages - 1).rev() {
345 cont[i] = cont[i] * s + bi_coeffs[i][j];
346 }
347 cont[i] *= s;
348 }
349 }
350
351 let mut y_interp = self.y_prev;
352 for i in 0..I {
353 if i < self.k.len() && i < cont.len() {
354 y_interp += self.k[i] * (cont[i] * self.h_prev);
355 }
356 }
357 y_delayed[i] = y_interp;
358 } else {
359 y_delayed[i] = cubic_hermite_interpolate(
360 self.t_prev,
361 self.t,
362 &self.y_prev,
363 &self.y,
364 &self.dydt_prev,
365 &self.dydt,
366 t_delayed,
367 );
368 } } else {
370 let mut found_interpolation = false;
372 let buffer = &self.history;
373 let mut buffer_iter = buffer.iter();
375 if let Some(mut prev_entry) = buffer_iter.next() {
376 for curr_entry in buffer_iter {
377 let (t_left, y_left, dydt_left) = prev_entry;
378 let (t_right, y_right, dydt_right) = curr_entry;
379
380 let is_between = if self.h.signum() > T::zero() {
382 *t_left <= t_delayed && t_delayed <= *t_right
384 } else {
385 *t_right <= t_delayed && t_delayed <= *t_left
387 };
388
389 if is_between {
390 y_delayed[i] = cubic_hermite_interpolate(
392 *t_left, *t_right, y_left, y_right, dydt_left, dydt_right,
393 t_delayed,
394 );
395 found_interpolation = true;
396 break;
397 }
398 prev_entry = curr_entry;
399 }
400 } if !found_interpolation {
402 return Err(Error::InsufficientHistory {
403 t_delayed,
404 t_prev: self.t_prev,
405 t_curr: self.t,
406 });
407 }
408 }
409 }
410 Ok(())
411 }
412}
413
414impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
415 for ExplicitRungeKutta<Delay, Fixed, T, Y, O, S, I>
416{
417 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
419 let dir = self.h.signum();
420 if (t_interp - self.t_prev) * dir < T::zero() || (t_interp - self.t) * dir > T::zero() {
421 return Err(Error::OutOfBounds {
422 t_interp,
423 t_prev: self.t_prev,
424 t_curr: self.t,
425 });
426 }
427
428 if self.bi.is_some() {
430 let s = (t_interp - self.t_prev) / self.h_prev;
431
432 let bi = self.bi.as_ref().unwrap();
433
434 let mut cont = [T::zero(); I];
435 for i in 0..self.dense_stages {
436 cont[i] = bi[i][self.order - 1];
437 for j in (0..self.order - 1).rev() {
438 cont[i] = cont[i] * s + bi[i][j];
439 }
440 cont[i] *= s;
441 }
442
443 let mut y_interp = self.y_prev;
444 for i in 0..I {
445 y_interp += self.k[i] * cont[i] * self.h_prev;
446 }
447
448 Ok(y_interp)
449 } else {
450 let y_interp = cubic_hermite_interpolate(
452 self.t_prev,
453 self.t,
454 &self.y_prev,
455 &self.y,
456 &self.dydt_prev,
457 &self.dydt,
458 t_interp,
459 );
460
461 Ok(y_interp)
462 }
463 }
464}