1use super::{ExplicitRungeKutta, Delay, DormandPrince};
4use crate::{
5 Error, Status,
6 methods::h_init::InitialStepSize,
7 alias::Evals,
8 interpolate::{Interpolation, cubic_hermite_interpolate},
9 dde::{DelayNumericalMethod, DDE},
10 traits::{CallBackData, Real, State},
11 utils::{constrain_step_size, validate_step_size_parameters},
12};
13use std::collections::VecDeque;
14
15impl<const L: usize, T: Real, V: State<T>, H: Fn(T) -> V, D: CallBackData, const O: usize, const S: usize, const I: usize> DelayNumericalMethod<L, T, V, H, D> for ExplicitRungeKutta<Delay, DormandPrince, T, V, D, O, S, I> {
16 fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &V, phi: &H) -> Result<Evals, Error<T, V>>
17 where
18 F: DDE<L, T, V, D>,
19 {
20 let mut evals = Evals::new();
21
22 self.t0 = t0;
24 self.t = t0;
25 self.y = *y0;
26 self.t_prev = self.t;
27 self.y_prev = self.y;
28 self.status = Status::Initialized; self.steps = 0;
29 self.stiffness_counter = 0;
30 self.non_stiffness_counter = 0;
31 self.history = VecDeque::new();
32
33 let mut lags = [T::zero(); L];
35 let mut yd = [V::zeros(); L];
36
37 if L > 0 {
39 dde.lags(self.t, &self.y, &mut lags);
40 for i in 0..L {
41 if lags[i] <= T::zero() {
42 return Err(Error::BadInput {
43 msg: "All lags must be positive.".to_string(),
44 });
45 }
46 let t_delayed = self.t - lags[i];
47 if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
49 return Err(Error::BadInput {
50 msg: format!(
51 "Delayed time {} is beyond initial time {}",
52 t_delayed, t0
53 ),
54 });
55 }
56 yd[i] = phi(t_delayed);
57 }
58 }
59
60 dde.diff(self.t, &self.y, &yd, &mut self.k[0]);
62 self.dydt = self.k[0];
63 evals.fcn += 1;
64 self.dydt_prev = self.dydt;
65
66 self.history.push_back((self.t, self.y, self.dydt));
68
69 if self.h0 == T::zero() {
71 self.h0 = InitialStepSize::<Delay>::compute(dde, t0, tf, y0, self.order, self.rtol, self.atol, self.h_min, self.h_max, phi, &self.k[0], &mut evals);
73 }
74
75 match validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf) {
77 Ok(h0) => self.h = h0,
78 Err(status) => return Err(status),
79 }
80 Ok(evals)
81 }
82
83 fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, V>>
84 where
85 F: DDE<L, T, V, D>,
86 {
87 let mut evals = Evals::new();
88
89 if self.h.abs() < T::default_epsilon() {
91 self.status = Status::Error(Error::StepSize { t: self.t, y: self.y });
92 return Err(Error::StepSize { t: self.t, y: self.y });
93 }
94
95 if self.steps >= self.max_steps {
97 self.status = Status::Error(Error::MaxSteps { t: self.t, y: self.y });
98 return Err(Error::MaxSteps { t: self.t, y: self.y });
99 }
100 self.steps += 1;
101
102 let mut lags = [T::zero(); L];
104 let mut yd = [V::zeros(); L];
105
106 let mut min_lag_abs = T::infinity();
108 if L > 0 {
109 let y_pred_for_lags = self.y + self.k[0] * self.h;
111 dde.lags(self.t + self.h, &y_pred_for_lags, &mut lags);
112 for i in 0..L {
113 min_lag_abs = min_lag_abs.min(lags[i].abs());
114 }
115 }
116
117 let max_iter: usize = if L > 0 && min_lag_abs < self.h.abs() && min_lag_abs > T::zero() {
119 5
120 } else {
121 1
122 }; let mut y_next_candidate_iter = self.y; let mut y_prev_candidate_iter = self.y; let mut dde_iteration_failed = false;
125 let mut err: T = T::zero(); let mut ysti = V::zeros(); for iter_idx in 0..max_iter {
130 if iter_idx > 0 {
131 y_prev_candidate_iter = y_next_candidate_iter;
132 }
133
134 let mut y_stage = V::zeros();
136 for i in 1..self.stages {
137 y_stage = V::zeros();
138 for j in 0..i {
139 y_stage += self.k[j] * self.a[i][j];
140 }
141 y_stage = self.y + y_stage * self.h;
142
143 if L > 0 {
145 dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut lags);
146 self.lagvals(self.t + self.c[i] * self.h, &lags, &mut yd, phi);
147 }
148 dde.diff(self.t + self.c[i] * self.h, &y_stage, &yd, &mut self.k[i]);
149 }
150 evals.fcn += self.stages - 1; ysti = y_stage;
154
155 let mut yseg = V::zeros();
157 for i in 0..self.stages {
158 yseg += self.k[i] * self.b[i];
159 }
160
161 let y_new = self.y + yseg * self.h;
163
164 let er = self.er.unwrap();
166 let n = self.y.len();
167 let mut err_val = T::zero();
168 let mut err2 = T::zero();
169 let mut erri;
170 for i in 0..n {
171 let sk = self.atol + self.rtol * self.y.get(i).abs().max(y_new.get(i).abs());
173
174 erri = T::zero();
176 for j in 0..self.stages {
177 erri += er[j] * self.k[j].get(i);
178 }
179 err_val += (erri / sk).powi(2);
180
181 if let Some(bh) = &self.bh {
183 erri = yseg.get(i);
184 for j in 0..self.stages {
185 erri -= bh[j] * self.k[j].get(i);
186 }
187 err2 += (erri / sk).powi(2);
188 }
189 }
190 let mut deno = err_val + T::from_f64(0.01).unwrap() * err2;
191 if deno <= T::zero() {
192 deno = T::one();
193 }
194 err = self.h.abs() * err_val * (T::one() / (deno * T::from_usize(n).unwrap())).sqrt();
195
196 if max_iter > 1 && iter_idx > 0 {
198 let mut dde_iteration_error = T::zero();
199 let n_dim = self.y.len();
200 for i_dim in 0..n_dim {
201 let scale = self.atol + self.rtol * y_prev_candidate_iter.get(i_dim).abs().max(y_new.get(i_dim).abs());
202 if scale > T::zero() {
203 let diff_val = y_new.get(i_dim) - y_prev_candidate_iter.get(i_dim);
204 dde_iteration_error += (diff_val / scale).powi(2);
205 }
206 }
207 if n_dim > 0 {
208 dde_iteration_error = (dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
209 }
210
211 if dde_iteration_error <= self.rtol * T::from_f64(0.1).unwrap() {
212 break; }
214 if iter_idx == max_iter - 1 { dde_iteration_failed = dde_iteration_error > self.rtol * T::from_f64(0.1).unwrap();
216 }
217 }
218 y_next_candidate_iter = y_new; if iter_idx == max_iter - 1 || max_iter == 1 {
222 }
224 } if dde_iteration_failed {
228 let sign = self.h.signum();
229 self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
230 if L > 0 && min_lag_abs > T::zero() && self.h.abs() < T::from_f64(2.0).unwrap() * min_lag_abs {
232 self.h = min_lag_abs * sign;
233 }
234 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
235 self.status = Status::RejectedStep;
236 return Ok(evals); }
238
239 if err <= T::one() { let y_new = y_next_candidate_iter;
242 let t_new = self.t + self.h;
243
244 if L > 0 {
246 dde.lags(t_new, &y_new, &mut lags);
247 self.lagvals(t_new, &lags, &mut yd, phi);
248 }
249 dde.diff(t_new, &y_new, &yd, &mut self.dydt);
250 evals.fcn += 1; let n_stiff_threshold = 100;
252 if self.steps % n_stiff_threshold == 0 {
253 let mut stdnum = T::zero();
254 let mut stden = T::zero();
255 let sqr = {
256 let mut yseg = V::zeros();
257 for i in 0..self.stages {
258 yseg += self.k[i] * self.b[i];
259 }
260 yseg - self.k[S-1]
261 };
262 for i in 0..sqr.len() {
263 stdnum += sqr.get(i).powi(2);
264 }
265 let sqr = self.dydt - ysti;
266 for i in 0..sqr.len() {
267 stden += sqr.get(i).powi(2);
268 }
269
270 if stden > T::zero() {
271 let h_lamb = self.h * (stdnum / stden).sqrt();
272 if h_lamb > T::from_f64(6.1).unwrap() {
273 self.non_stiffness_counter = 0;
274 self.stiffness_counter += 1;
275 if self.stiffness_counter == 15 {
276 self.status = Status::Error(Error::Stiffness {
277 t: self.t,
278 y: self.y,
279 });
280 return Err(Error::Stiffness {
281 t: self.t,
282 y: self.y,
283 });
284 }
285 }
286 } else {
287 self.non_stiffness_counter += 1;
288 if self.non_stiffness_counter == 6 {
289 self.stiffness_counter = 0;
290 }
291 }
292 }
293
294 self.cont[0] = self.y;
296 let ydiff = y_new - self.y;
297 self.cont[1] = ydiff;
298 let bspl = self.k[0] * self.h - ydiff;
299 self.cont[2] = bspl;
300 self.cont[3] = ydiff - self.dydt * self.h - bspl;
301
302 if let Some(bi) = &self.bi {
304 if I > S {
306 self.k[self.stages] = self.dydt; for i in S+1..I {
308 let mut y_stage = V::zeros();
309 for j in 0..i {
310 y_stage += self.k[j] * self.a[i][j];
311 }
312 y_stage = self.y + y_stage * self.h;
313
314 if L > 0 {
315 dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut lags);
316 for lag_idx in 0..L {
318 let t_delayed = (self.t + self.c[i] * self.h) - lags[lag_idx];
319
320 if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
322 yd[lag_idx] = phi(t_delayed);
323 } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
325 if self.bi.is_some() {
326 let s = (t_delayed - self.t_prev) / self.h_prev;
327 let s1 = T::one() - s;
328 let ilast = self.cont.len() - 1;
329 let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, cont_i| {
330 let factor = if cont_i >= 4 {
331 if (ilast - cont_i) % 2 == 1 { s1 } else { s }
332 } else {
333 if cont_i % 2 == 1 { s1 } else { s }
334 };
335 acc * factor + self.cont[cont_i]
336 });
337 yd[lag_idx] = self.cont[0] + poly * s;
338 } else {
339 yd[lag_idx] = cubic_hermite_interpolate(
340 self.t_prev,
341 self.t,
342 &self.y_prev,
343 &self.y,
344 &self.dydt_prev,
345 &self.dydt,
346 t_delayed
347 );
348 }
349 } else {
350 let mut found_interpolation = false;
352 let buffer = &self.history;
353 let mut buffer_iter = buffer.iter();
354 if let Some(mut prev_entry) = buffer_iter.next() {
355 for curr_entry in buffer_iter {
356 let (t_left, y_left, dydt_left) = prev_entry;
357 let (t_right, y_right, dydt_right) = curr_entry;
358
359 let is_between = if self.h.signum() > T::zero() {
360 *t_left <= t_delayed && t_delayed <= *t_right
361 } else {
362 *t_right <= t_delayed && t_delayed <= *t_left
363 };
364
365 if is_between {
366 yd[lag_idx] = cubic_hermite_interpolate(
367 *t_left,
368 *t_right,
369 y_left,
370 y_right,
371 dydt_left,
372 dydt_right,
373 t_delayed
374 );
375 found_interpolation = true;
376 break;
377 }
378 prev_entry = curr_entry;
379 }
380 }
381 if !found_interpolation {
382 panic!("Insufficient history for t_delayed = {} (t_prev = {}, t = {})", t_delayed, self.t_prev, self.t);
383 }
384 }
385 }
386 }
387 dde.diff(self.t + self.c[i] * self.h, &y_stage, &yd, &mut self.k[i]);
388 evals.fcn += 1;
389 }
390 }
391
392 for i in 4..self.order {
394 self.cont[i] = V::zeros();
395 for j in 0..self.dense_stages {
396 self.cont[i] += self.k[j] * bi[i][j];
397 }
398 self.cont[i] = self.cont[i] * self.h;
399 }
400 }
401
402 self.t_prev = self.t;
404 self.y_prev = self.y;
405 self.dydt_prev = self.k[0];
406 self.h_prev = self.h;
407
408 self.t = t_new;
410 self.y = y_new;
411 self.k[0] = self.dydt;
412
413 self.history.push_back((self.t, self.y, self.dydt));
415 if let Some(max_delay) = self.max_delay {
416 let cutoff_time = self.t - max_delay;
417 while let Some((t_front, _, _)) = self.history.get(1){
418 if *t_front < cutoff_time {
419 self.history.pop_front();
420 } else {
421 break;
422 }
423 }
424 } if let Status::RejectedStep = self.status {
426 self.status = Status::Solving;
427 }
428 } else {
429 self.status = Status::RejectedStep;
431 }
432
433 let order = T::from_usize(self.order).unwrap();
435 let err_order = T::one() / order;
436
437 let scale = self.safety_factor * err.powf(-err_order);
439 let scale = scale.max(self.min_scale).min(self.max_scale);
440 self.h *= scale;
441
442 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
444
445 Ok(evals)
446 }
447
448 fn t(&self) -> T { self.t }
449 fn y(&self) -> &V { &self.y }
450 fn t_prev(&self) -> T { self.t_prev }
451 fn y_prev(&self) -> &V { &self.y_prev }
452 fn h(&self) -> T { self.h }
453 fn set_h(&mut self, h: T) { self.h = h; }
454 fn status(&self) -> &Status<T, V, D> { &self.status }
455 fn set_status(&mut self, status: Status<T, V, D>) { self.status = status; }
456}
457
458impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> ExplicitRungeKutta<Delay, DormandPrince, T, V, D, O, S, I> {
459 fn lagvals<const L: usize, H>(&mut self, t_stage: T, lags: &[T; L], yd: &mut [V; L], phi: &H)
460 where
461 H: Fn(T) -> V,
462 {
463 for i in 0..L {
464 let t_delayed = t_stage - lags[i];
465
466 if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
468 yd[i] = phi(t_delayed);
469 } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
471 if self.bi.is_some() {
472 let s = (t_delayed - self.t_prev) / self.h_prev;
473
474 let s1 = T::one() - s;
476
477 let ilast = self.cont.len() - 1;
479 let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, i| {
480 let factor = if i >= 4 {
481 if (ilast - i) % 2 == 1 { s1 } else { s }
483 } else {
484 if i % 2 == 1 { s1 } else { s }
486 };
487 acc * factor + self.cont[i]
488 });
489
490 let y_interp = self.cont[0] + poly * s;
492 yd[i] = y_interp;
493 } else {
494 yd[i] = cubic_hermite_interpolate(
495 self.t_prev,
496 self.t,
497 &self.y_prev,
498 &self.y,
499 &self.dydt_prev,
500 &self.dydt,
501 t_delayed
502 );
503 }
504 } else {
506 let mut found_interpolation = false;
508 let buffer = &self.history;
509 let mut buffer_iter = buffer.iter();
511 if let Some(mut prev_entry) = buffer_iter.next() {
512 for curr_entry in buffer_iter {
513 let (t_left, y_left, dydt_left) = prev_entry;
514 let (t_right, y_right, dydt_right) = curr_entry;
515
516 let is_between = if self.h.signum() > T::zero() {
518 *t_left <= t_delayed && t_delayed <= *t_right
519 } else {
520 *t_right <= t_delayed && t_delayed <= *t_left
521 };
522
523 if is_between {
524 yd[i] = cubic_hermite_interpolate(
525 *t_left,
526 *t_right,
527 y_left,
528 y_right,
529 dydt_left,
530 dydt_right,
531 t_delayed
532 );
533 found_interpolation = true;
534 break;
535 }
536 prev_entry = curr_entry;
537 }
538 }
539 if !found_interpolation {
541 let buffer = &self.history;
543 println!("Buffer contents ({} entries):", buffer.len());
544 for (idx, (t_buf, _, _)) in buffer.iter().enumerate() {
545 println!(" [{}]: t = {}", idx, t_buf);
546 }
547 panic!("Insufficient history in history for t_delayed = {} (t_prev = {}, t = {}). Buffer may need to retain more points or there's a logic error in determining interpolation intervals.", t_delayed, self.t_prev, self.t);
548 }
549 }
550 }
551 }
552}
553
554impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> Interpolation<T, V> for ExplicitRungeKutta<Delay, DormandPrince, T, V, D, O, S, I> {
555 fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
556 let posneg = (self.t - self.t_prev).signum();
558 if (t_interp - self.t_prev) * posneg < T::zero() || (t_interp - self.t) * posneg > T::zero() {
559 return Err(Error::OutOfBounds {
560 t_interp,
561 t_prev: self.t_prev,
562 t_curr: self.t,
563 });
564 }
565
566 let s = (t_interp - self.t_prev) / self.h_prev;
568 let s1 = T::one() - s;
569
570 let ilast = self.cont.len() - 1;
572 let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, i| {
573 let factor = if i >= 4 {
574 if (ilast - i) % 2 == 1 { s1 } else { s }
576 } else {
577 if i % 2 == 1 { s1 } else { s }
579 };
580 acc * factor + self.cont[i]
581 });
582
583 let y_interp = self.cont[0] + poly * s;
585
586 Ok(y_interp)
587 }
588}