1use super::{ExplicitRungeKutta, Delay, DormandPrince};
4use crate::{
5 Error, Status,
6 methods::h_init::InitialStepSize,
7 alias::Evals,
8 interpolate::{Interpolation, cubic_hermite_interpolate},
9 dde::{DelayNumericalMethod, DDE},
10 traits::{CallBackData, Real, State},
11 utils::{constrain_step_size, validate_step_size_parameters},
12};
13use std::collections::VecDeque;
14
15impl<const L: usize, T: Real, V: State<T>, H: Fn(T) -> V, D: CallBackData, const O: usize, const S: usize, const I: usize> DelayNumericalMethod<L, T, V, H, D> for ExplicitRungeKutta<Delay, DormandPrince, T, V, D, O, S, I> {
16 fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &V, phi: &H) -> Result<Evals, Error<T, V>>
17 where
18 F: DDE<L, T, V, D>,
19 {
20 let mut evals = Evals::new();
21
22 self.t0 = t0;
24 self.t = t0;
25 self.y = *y0;
26 self.t_prev = self.t;
27 self.y_prev = self.y;
28 self.status = Status::Initialized;
29 self.steps = 0;
30 self.stiffness_counter = 0;
31 self.non_stiffness_counter = 0;
32 self.history = VecDeque::new();
33
34 let mut lags = [T::zero(); L];
36 let mut yd = [V::zeros(); L];
37
38 if L > 0 {
40 dde.lags(self.t, &self.y, &mut lags);
41 for i in 0..L {
42 if lags[i] <= T::zero() {
43 return Err(Error::BadInput {
44 msg: "All lags must be positive.".to_string(),
45 });
46 }
47 let t_delayed = self.t - lags[i];
48 if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
50 return Err(Error::BadInput {
51 msg: format!(
52 "Delayed time {} is beyond initial time {}",
53 t_delayed, t0
54 ),
55 });
56 }
57 yd[i] = phi(t_delayed);
58 }
59 }
60
61 dde.diff(self.t, &self.y, &yd, &mut self.k[0]);
63 self.dydt = self.k[0];
64 evals.fcn += 1;
65 self.dydt_prev = self.dydt;
66
67 self.history.push_back((self.t, self.y, self.dydt));
69
70 if self.h0 == T::zero() {
72 self.h0 = InitialStepSize::<Delay>::compute(dde, t0, tf, y0, self.order, self.rtol, self.atol, self.h_min, self.h_max, phi, &self.k[0], &mut evals);
74 }
75
76 match validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf) {
78 Ok(h0) => self.h = h0,
79 Err(status) => return Err(status),
80 }
81 Ok(evals)
82 }
83
84 fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, V>>
85 where
86 F: DDE<L, T, V, D>,
87 {
88 let mut evals = Evals::new();
89
90 if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
92 self.status = Status::Error(Error::StepSize { t: self.t, y: self.y });
93 return Err(Error::StepSize { t: self.t, y: self.y });
94 }
95
96 if self.steps >= self.max_steps {
98 self.status = Status::Error(Error::MaxSteps { t: self.t, y: self.y });
99 return Err(Error::MaxSteps { t: self.t, y: self.y });
100 }
101 self.steps += 1;
102
103 let mut lags = [T::zero(); L];
105 let mut yd = [V::zeros(); L];
106
107 let mut min_lag_abs = T::infinity();
109 if L > 0 {
110 let y_pred_for_lags = self.y + self.k[0] * self.h;
112 dde.lags(self.t + self.h, &y_pred_for_lags, &mut lags);
113 for i in 0..L {
114 min_lag_abs = min_lag_abs.min(lags[i].abs());
115 }
116 }
117
118 let max_iter: usize = if L > 0 && min_lag_abs < self.h.abs() && min_lag_abs > T::zero() {
120 5
121 } else {
122 1
123 }; let mut y_next_candidate_iter = self.y; let mut y_prev_candidate_iter = self.y; let mut dde_iteration_failed = false;
126 let mut err: T = T::zero(); let mut ysti = V::zeros(); for iter_idx in 0..max_iter {
131 if iter_idx > 0 {
132 y_prev_candidate_iter = y_next_candidate_iter;
133 }
134
135 let mut y_stage = V::zeros();
137 for i in 1..self.stages {
138 y_stage = V::zeros();
139 for j in 0..i {
140 y_stage += self.k[j] * self.a[i][j];
141 }
142 y_stage = self.y + y_stage * self.h;
143
144 if L > 0 {
146 dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut lags);
147 self.lagvals(self.t + self.c[i] * self.h, &lags, &mut yd, phi);
148 }
149 dde.diff(self.t + self.c[i] * self.h, &y_stage, &yd, &mut self.k[i]);
150 }
151 evals.fcn += self.stages - 1; ysti = y_stage;
155
156 let mut yseg = V::zeros();
158 for i in 0..self.stages {
159 yseg += self.k[i] * self.b[i];
160 }
161
162 let y_new = self.y + yseg * self.h;
164
165 let er = self.er.unwrap();
167 let n = self.y.len();
168 let mut err_val = T::zero();
169 let mut err2 = T::zero();
170 let mut erri;
171 for i in 0..n {
172 let sk = self.atol + self.rtol * self.y.get(i).abs().max(y_new.get(i).abs());
174
175 erri = T::zero();
177 for j in 0..self.stages {
178 erri += er[j] * self.k[j].get(i);
179 }
180 err_val += (erri / sk).powi(2);
181
182 if let Some(bh) = &self.bh {
184 erri = yseg.get(i);
185 for j in 0..self.stages {
186 erri -= bh[j] * self.k[j].get(i);
187 }
188 err2 += (erri / sk).powi(2);
189 }
190 }
191 let mut deno = err_val + T::from_f64(0.01).unwrap() * err2;
192 if deno <= T::zero() {
193 deno = T::one();
194 }
195 err = self.h.abs() * err_val * (T::one() / (deno * T::from_usize(n).unwrap())).sqrt();
196
197 if max_iter > 1 && iter_idx > 0 {
199 let mut dde_iteration_error = T::zero();
200 let n_dim = self.y.len();
201 for i_dim in 0..n_dim {
202 let scale = self.atol + self.rtol * y_prev_candidate_iter.get(i_dim).abs().max(y_new.get(i_dim).abs());
203 if scale > T::zero() {
204 let diff_val = y_new.get(i_dim) - y_prev_candidate_iter.get(i_dim);
205 dde_iteration_error += (diff_val / scale).powi(2);
206 }
207 }
208 if n_dim > 0 {
209 dde_iteration_error = (dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
210 }
211
212 if dde_iteration_error <= self.rtol * T::from_f64(0.1).unwrap() {
213 break; }
215 if iter_idx == max_iter - 1 { dde_iteration_failed = dde_iteration_error > self.rtol * T::from_f64(0.1).unwrap();
217 }
218 }
219 y_next_candidate_iter = y_new; if iter_idx == max_iter - 1 || max_iter == 1 {
223 }
225 } if dde_iteration_failed {
229 let sign = self.h.signum();
230 self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
231 if L > 0 && min_lag_abs > T::zero() && self.h.abs() < T::from_f64(2.0).unwrap() * min_lag_abs {
233 self.h = min_lag_abs * sign;
234 }
235 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
236 self.status = Status::RejectedStep;
237 return Ok(evals); }
239
240 let order = T::from_usize(self.order).unwrap();
242 let error_exponent = T::one() / order;
243 let mut scale = self.safety_factor * err.powf(-error_exponent);
244
245 scale = scale.max(self.min_scale).min(self.max_scale);
247
248 if err <= T::one() { let y_new = y_next_candidate_iter;
251 let t_new = self.t + self.h;
252
253 if L > 0 {
255 dde.lags(t_new, &y_new, &mut lags);
256 self.lagvals(t_new, &lags, &mut yd, phi);
257 }
258 dde.diff(t_new, &y_new, &yd, &mut self.dydt);
259 evals.fcn += 1; let n_stiff_threshold = 100;
261 if self.steps % n_stiff_threshold == 0 {
262 let mut stdnum = T::zero();
263 let mut stden = T::zero();
264 let sqr = {
265 let mut yseg = V::zeros();
266 for i in 0..self.stages {
267 yseg += self.k[i] * self.b[i];
268 }
269 yseg - self.k[S-1]
270 };
271 for i in 0..sqr.len() {
272 stdnum += sqr.get(i).powi(2);
273 }
274 let sqr = self.dydt - ysti;
275 for i in 0..sqr.len() {
276 stden += sqr.get(i).powi(2);
277 }
278
279 if stden > T::zero() {
280 let h_lamb = self.h * (stdnum / stden).sqrt();
281 if h_lamb > T::from_f64(6.1).unwrap() {
282 self.non_stiffness_counter = 0;
283 self.stiffness_counter += 1;
284 if self.stiffness_counter == 15 {
285 self.status = Status::Error(Error::Stiffness {
286 t: self.t,
287 y: self.y,
288 });
289 return Err(Error::Stiffness {
290 t: self.t,
291 y: self.y,
292 });
293 }
294 }
295 } else {
296 self.non_stiffness_counter += 1;
297 if self.non_stiffness_counter == 6 {
298 self.stiffness_counter = 0;
299 }
300 }
301 }
302
303 self.cont[0] = self.y;
305 let ydiff = y_new - self.y;
306 self.cont[1] = ydiff;
307 let bspl = self.k[0] * self.h - ydiff;
308 self.cont[2] = bspl;
309 self.cont[3] = ydiff - self.dydt * self.h - bspl;
310
311 if let Some(bi) = &self.bi {
313 if I > S {
315 self.k[self.stages] = self.dydt; for i in S+1..I {
317 let mut y_stage = V::zeros();
318 for j in 0..i {
319 y_stage += self.k[j] * self.a[i][j];
320 }
321 y_stage = self.y + y_stage * self.h;
322
323 if L > 0 {
324 dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut lags);
325 for lag_idx in 0..L {
327 let t_delayed = (self.t + self.c[i] * self.h) - lags[lag_idx];
328
329 if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
331 yd[lag_idx] = phi(t_delayed);
332 } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
334 if self.bi.is_some() {
335 let s = (t_delayed - self.t_prev) / self.h_prev;
336 let s1 = T::one() - s;
337 let ilast = self.cont.len() - 1;
338 let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, cont_i| {
339 let factor = if cont_i >= 4 {
340 if (ilast - cont_i) % 2 == 1 { s1 } else { s }
341 } else {
342 if cont_i % 2 == 1 { s1 } else { s }
343 };
344 acc * factor + self.cont[cont_i]
345 });
346 yd[lag_idx] = self.cont[0] + poly * s;
347 } else {
348 yd[lag_idx] = cubic_hermite_interpolate(
349 self.t_prev,
350 self.t,
351 &self.y_prev,
352 &self.y,
353 &self.dydt_prev,
354 &self.dydt,
355 t_delayed
356 );
357 }
358 } else {
359 let mut found_interpolation = false;
361 let buffer = &self.history;
362 let mut buffer_iter = buffer.iter();
363 if let Some(mut prev_entry) = buffer_iter.next() {
364 for curr_entry in buffer_iter {
365 let (t_left, y_left, dydt_left) = prev_entry;
366 let (t_right, y_right, dydt_right) = curr_entry;
367
368 let is_between = if self.h.signum() > T::zero() {
369 *t_left <= t_delayed && t_delayed <= *t_right
370 } else {
371 *t_right <= t_delayed && t_delayed <= *t_left
372 };
373
374 if is_between {
375 yd[lag_idx] = cubic_hermite_interpolate(
376 *t_left,
377 *t_right,
378 y_left,
379 y_right,
380 dydt_left,
381 dydt_right,
382 t_delayed
383 );
384 found_interpolation = true;
385 break;
386 }
387 prev_entry = curr_entry;
388 }
389 }
390 if !found_interpolation {
391 panic!("Insufficient history for t_delayed = {} (t_prev = {}, t = {})", t_delayed, self.t_prev, self.t);
392 }
393 }
394 }
395 }
396 dde.diff(self.t + self.c[i] * self.h, &y_stage, &yd, &mut self.k[i]);
397 evals.fcn += 1;
398 }
399 }
400
401 for i in 4..self.order {
403 self.cont[i] = V::zeros();
404 for j in 0..self.dense_stages {
405 self.cont[i] += self.k[j] * bi[i][j];
406 }
407 self.cont[i] = self.cont[i] * self.h;
408 }
409 }
410
411 self.t_prev = self.t;
413 self.y_prev = self.y;
414 self.dydt_prev = self.k[0];
415 self.h_prev = self.h;
416
417 self.t = t_new;
419 self.y = y_new;
420 self.k[0] = self.dydt;
421
422 self.history.push_back((self.t, self.y, self.dydt));
424 if let Some(max_delay) = self.max_delay {
425 let cutoff_time = self.t - max_delay;
426 while let Some((t_front, _, _)) = self.history.get(1){
427 if *t_front < cutoff_time {
428 self.history.pop_front();
429 } else {
430 break;
431 }
432 }
433 } if let Status::RejectedStep = self.status {
435 self.status = Status::Solving;
436
437 scale = scale.min(T::one());
439 }
440 } else {
441 self.status = Status::RejectedStep;
443 }
444
445 self.h *= scale;
447
448 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
450
451 Ok(evals)
452 }
453
454 fn t(&self) -> T { self.t }
455 fn y(&self) -> &V { &self.y }
456 fn t_prev(&self) -> T { self.t_prev }
457 fn y_prev(&self) -> &V { &self.y_prev }
458 fn h(&self) -> T { self.h }
459 fn set_h(&mut self, h: T) { self.h = h; }
460 fn status(&self) -> &Status<T, V, D> { &self.status }
461 fn set_status(&mut self, status: Status<T, V, D>) { self.status = status; }
462}
463
464impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> ExplicitRungeKutta<Delay, DormandPrince, T, V, D, O, S, I> {
465 fn lagvals<const L: usize, H>(&mut self, t_stage: T, lags: &[T; L], yd: &mut [V; L], phi: &H)
466 where
467 H: Fn(T) -> V,
468 {
469 for i in 0..L {
470 let t_delayed = t_stage - lags[i];
471
472 if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
474 yd[i] = phi(t_delayed);
475 } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
477 if self.bi.is_some() {
478 let s = (t_delayed - self.t_prev) / self.h_prev;
479
480 let s1 = T::one() - s;
482
483 let ilast = self.cont.len() - 1;
485 let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, i| {
486 let factor = if i >= 4 {
487 if (ilast - i) % 2 == 1 { s1 } else { s }
489 } else {
490 if i % 2 == 1 { s1 } else { s }
492 };
493 acc * factor + self.cont[i]
494 });
495
496 let y_interp = self.cont[0] + poly * s;
498 yd[i] = y_interp;
499 } else {
500 yd[i] = cubic_hermite_interpolate(
501 self.t_prev,
502 self.t,
503 &self.y_prev,
504 &self.y,
505 &self.dydt_prev,
506 &self.dydt,
507 t_delayed
508 );
509 }
510 } else {
512 let mut found_interpolation = false;
514 let buffer = &self.history;
515 let mut buffer_iter = buffer.iter();
517 if let Some(mut prev_entry) = buffer_iter.next() {
518 for curr_entry in buffer_iter {
519 let (t_left, y_left, dydt_left) = prev_entry;
520 let (t_right, y_right, dydt_right) = curr_entry;
521
522 let is_between = if self.h.signum() > T::zero() {
524 *t_left <= t_delayed && t_delayed <= *t_right
525 } else {
526 *t_right <= t_delayed && t_delayed <= *t_left
527 };
528
529 if is_between {
530 yd[i] = cubic_hermite_interpolate(
531 *t_left,
532 *t_right,
533 y_left,
534 y_right,
535 dydt_left,
536 dydt_right,
537 t_delayed
538 );
539 found_interpolation = true;
540 break;
541 }
542 prev_entry = curr_entry;
543 }
544 }
545 if !found_interpolation {
547 let buffer = &self.history;
549 println!("Buffer contents ({} entries):", buffer.len());
550 for (idx, (t_buf, _, _)) in buffer.iter().enumerate() {
551 println!(" [{}]: t = {}", idx, t_buf);
552 }
553 panic!("Insufficient history in history for t_delayed = {} (t_prev = {}, t = {}). Buffer may need to retain more points or there's a logic error in determining interpolation intervals.", t_delayed, self.t_prev, self.t);
554 }
555 }
556 }
557 }
558}
559
560impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> Interpolation<T, V> for ExplicitRungeKutta<Delay, DormandPrince, T, V, D, O, S, I> {
561 fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
562 let posneg = (self.t - self.t_prev).signum();
564 if (t_interp - self.t_prev) * posneg < T::zero() || (t_interp - self.t) * posneg > T::zero() {
565 return Err(Error::OutOfBounds {
566 t_interp,
567 t_prev: self.t_prev,
568 t_curr: self.t,
569 });
570 }
571
572 let s = (t_interp - self.t_prev) / self.h_prev;
574 let s1 = T::one() - s;
575
576 let ilast = self.cont.len() - 1;
578 let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, i| {
579 let factor = if i >= 4 {
580 if (ilast - i) % 2 == 1 { s1 } else { s }
582 } else {
583 if i % 2 == 1 { s1 } else { s }
585 };
586 acc * factor + self.cont[i]
587 });
588
589 let y_interp = self.cont[0] + poly * s;
591
592 Ok(y_interp)
593 }
594}