1use super::{ExplicitRungeKutta, Delay, Adaptive};
4use crate::{
5 Error, Status,
6 methods::h_init::InitialStepSize,
7 alias::Evals,
8 interpolate::{Interpolation, cubic_hermite_interpolate},
9 dde::{DelayNumericalMethod, DDE},
10 traits::{CallBackData, Real, State},
11 utils::{constrain_step_size, validate_step_size_parameters},
12};
13use std::collections::VecDeque;
14
15impl<const L: usize, T: Real, V: State<T>, H: Fn(T) -> V, D: CallBackData, const O: usize, const S: usize, const I: usize> DelayNumericalMethod<L, T, V, H, D> for ExplicitRungeKutta<Delay, Adaptive, T, V, D, O, S, I> {
16 fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &V, phi: &H) -> Result<Evals, Error<T, V>>
17 where
18 F: DDE<L, T, V, D>,
19 {
20 let mut evals = Evals::new();
21
22 self.t0 = t0;
24 self.t = t0;
25 self.y = *y0;
26 self.t_prev = self.t;
27 self.y_prev = self.y;
28 self.status = Status::Initialized;
29 self.steps = 0;
30 self.stiffness_counter = 0;
31 self.history = VecDeque::new();
32
33 let mut lags = [T::zero(); L];
35 let mut yd = [V::zeros(); L];
36
37 if L > 0 {
39 dde.lags(self.t, &self.y, &mut lags);
40 for i in 0..L {
41 if lags[i] <= T::zero() {
42 return Err(Error::BadInput {
43 msg: "All lags must be positive.".to_string(),
44 });
45 }
46 let t_delayed = self.t - lags[i];
47 if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
49 return Err(Error::BadInput {
50 msg: format!(
51 "Initial delayed time {} is out of history range (t <= {}).",
52 t_delayed, t0
53 ),
54 });
55 }
56 yd[i] = phi(t_delayed);
57 }
58 }
59
60 dde.diff(self.t, &self.y, &yd, &mut self.dydt);
62 evals.fcn += 1;
63 self.dydt_prev = self.dydt; self.history.push_back((self.t, self.y, self.dydt));
65
66 if self.h0 == T::zero() {
68 self.h0 = InitialStepSize::<Delay>::compute(dde, t0, tf, y0, self.order, self.rtol, self.atol, self.h_min, self.h_max, phi, &self.k[0], &mut evals);
70 evals.fcn += 2; }
72
73 match validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf) {
75 Ok(h0) => self.h = h0,
76 Err(status) => return Err(status),
77 }
78 Ok(evals)
79 }
80
81 fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, V>>
82 where
83 F: DDE<L, T, V, D>,
84 {
85 let mut evals = Evals::new();
86
87 if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
89 self.status = Status::Error(Error::StepSize { t: self.t, y: self.y });
90 return Err(Error::StepSize { t: self.t, y: self.y });
91 }
92
93 if self.steps >= self.max_steps {
95 self.status = Status::Error(Error::MaxSteps { t: self.t, y: self.y });
96 return Err(Error::MaxSteps { t: self.t, y: self.y });
97 }
98 self.steps += 1;
99
100 let mut lags = [T::zero(); L];
102 let mut yd = [V::zeros(); L];
103
104 self.k[0] = self.dydt;
106
107 let mut min_lag_abs = T::infinity();
109 if L > 0 {
110 let y_pred_for_lags = self.y + self.k[0] * self.h;
112 dde.lags(self.t + self.h, &y_pred_for_lags, &mut lags);
113 for i in 0..L {
114 min_lag_abs = min_lag_abs.min(lags[i].abs());
115 }
116 }
117
118 let max_iter: usize = if L > 0 && min_lag_abs < self.h.abs() && min_lag_abs > T::zero() {
120 5
121 } else {
122 1
123 };
124
125 let mut y_next_candidate_iter = self.y; let mut dydt_next_candidate_iter = V::zeros(); let mut y_prev_candidate_iter = self.y; let mut dde_iteration_failed = false;
129 let mut err_norm: T = T::zero(); for iter_idx in 0..max_iter {
133 if iter_idx > 0 {
134 y_prev_candidate_iter = y_next_candidate_iter;
135 }
136
137 for i in 1..self.stages {
139 let mut y_stage = self.y;
140 for j in 0..i {
141 y_stage += self.k[j] * (self.a[i][j] * self.h);
142 }
143 if L > 0 {
145 dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut lags);
146 self.lagvals(self.t + self.c[i] * self.h, &lags, &mut yd, phi);
147 }
148 dde.diff(self.t + self.c[i] * self.h, &y_stage, &yd, &mut self.k[i]);
149 }
150 evals.fcn += self.stages - 1; let mut y_high = self.y; for i in 0..self.stages {
155 y_high += self.k[i] * (self.b[i] * self.h);
156 }
157 let mut y_low = self.y; if let Some(bh_coeffs) = &self.bh {
159 for i in 0..self.stages {
160 y_low += self.k[i] * (bh_coeffs[i] * self.h);
161 }
162 }
163 let err_vec: V = y_high - y_low; err_norm = T::zero();
167 for n in 0..self.y.len() {
168 let tol = self.atol + self.rtol * self.y.get(n).abs().max(y_high.get(n).abs());
169 err_norm = err_norm.max((err_vec.get(n) / tol).abs());
170 }
171
172 if max_iter > 1 && iter_idx > 0 {
174 let mut dde_iteration_error = T::zero();
175 let n_dim = self.y.len();
176 for i_dim in 0..n_dim {
177 let scale = self.atol + self.rtol * y_prev_candidate_iter.get(i_dim).abs().max(y_high.get(i_dim).abs());
178 if scale > T::zero() {
179 let diff_val = y_high.get(i_dim) - y_prev_candidate_iter.get(i_dim);
180 dde_iteration_error += (diff_val / scale).powi(2);
181 }
182 }
183 if n_dim > 0 {
184 dde_iteration_error = (dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
185 }
186
187 if dde_iteration_error <= self.rtol * T::from_f64(0.1).unwrap() {
188 break; }
190 if iter_idx == max_iter - 1 { dde_iteration_failed = dde_iteration_error > self.rtol * T::from_f64(0.1).unwrap();
192 }
193 }
194 y_next_candidate_iter = y_high; if L > 0 {
198 dde.lags(self.t + self.h, &y_next_candidate_iter, &mut lags);
199 self.lagvals(self.t + self.h, &lags, &mut yd, phi);
200 }
201 dde.diff(self.t + self.h, &y_next_candidate_iter, &yd, &mut dydt_next_candidate_iter);
202 evals.fcn += 1;
203 } if dde_iteration_failed {
207 let sign = self.h.signum();
208 self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
209 if L > 0 && min_lag_abs > T::zero() && self.h.abs() < T::from_f64(2.0).unwrap() * min_lag_abs {
211 self.h = min_lag_abs * sign; }
213 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
214 self.status = Status::RejectedStep; return Ok(evals); }
218
219 let order = T::from_usize(self.order).unwrap();
221 let error_exponent = T::one() / order;
222 let mut scale = self.safety_factor * err_norm.powf(-error_exponent);
223
224 scale = scale.max(self.min_scale).min(self.max_scale);
226
227 if err_norm <= T::one() { self.t_prev = self.t;
230 self.y_prev = self.y;
231 self.dydt_prev = self.dydt; self.h_prev = self.h; if let Status::RejectedStep = self.status { self.stiffness_counter = 0;
236
237 scale = scale.min(T::one());
239 }
240 self.status = Status::Solving;
241
242 if self.bi.is_some() {
244 for i in 0..(I - S) { let mut y_stage_dense = self.y; for j in 0..self.stages + i { y_stage_dense += self.k[j] * (self.a[self.stages + i][j] * self.h);
249 }
250 if L > 0 {
252 dde.lags(self.t + self.c[self.stages + i] * self.h, &y_stage_dense, &mut lags);
253 self.lagvals(self.t + self.c[self.stages + i] * self.h, &lags, &mut yd, phi);
254 }
255 dde.diff(self.t + self.c[self.stages + i] * self.h, &y_stage_dense, &yd, &mut self.k[self.stages + i]);
256 }
257 evals.fcn += I - S; }
259
260 self.t += self.h;
262 self.y = y_next_candidate_iter;
263 if self.fsal {
266 self.dydt = self.k[S - 1];
268 } else {
269 if L > 0 {
270 dde.lags(self.t, &self.y, &mut lags);
271 self.lagvals(self.t, &lags, &mut yd, phi);
272 }
273 dde.diff(self.t, &self.y, &yd, &mut self.dydt);
275 evals.fcn += 1;
276 }
277
278 self.history.push_back((self.t, self.y, self.dydt));
280 if let Some(max_delay) = self.max_delay {
281 let cutoff_time = self.t - max_delay;
282 while let Some((t_front, _, _)) = self.history.get(1){
283 if *t_front < cutoff_time {
284 self.history.pop_front();
285 } else {
286 break; }
288 }
289 }
290 } else { self.status = Status::RejectedStep;
292 self.stiffness_counter += 1;
293
294 if self.stiffness_counter >= self.max_rejects {
296 self.status = Status::Error(Error::Stiffness { t: self.t, y: self.y });
297 return Err(Error::Stiffness { t: self.t, y: self.y });
298 }
299 }
300
301 self.h *= scale;
303
304 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
306
307 Ok(evals)
308 }
309
310 fn t(&self) -> T { self.t }
311 fn y(&self) -> &V { &self.y }
312 fn t_prev(&self) -> T { self.t_prev }
313 fn y_prev(&self) -> &V { &self.y_prev }
314 fn h(&self) -> T { self.h }
315 fn set_h(&mut self, h: T) { self.h = h; }
316 fn status(&self) -> &Status<T, V, D> { &self.status }
317 fn set_status(&mut self, status: Status<T, V, D>) { self.status = status; }
318}
319
320impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> ExplicitRungeKutta<Delay, Adaptive, T, V, D, O, S, I> {
321 fn lagvals<const L: usize, H>(&mut self, t_stage: T, lags: &[T; L], yd: &mut [V; L], phi: &H)
322 where
323 H: Fn(T) -> V,
324 {
325 for i in 0..L {
326 let t_delayed = t_stage - lags[i];
327
328 if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
330 yd[i] = phi(t_delayed);
331 } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
333 if self.bi.is_some() {
334 let s = (t_delayed - self.t_prev) / self.h_prev;
335
336 let bi_coeffs = self.bi.as_ref().unwrap();
337
338 let mut cont = [T::zero(); I];
339 for i in 0..I {
340 if i < self.cont.len() && i < bi_coeffs.len() {
341 cont[i] = bi_coeffs[i][self.dense_stages - 1];
342 for j in (0..self.dense_stages - 1).rev() {
343 cont[i] = cont[i] * s + bi_coeffs[i][j];
344 }
345 cont[i] *= s;
346 }
347 }
348
349 let mut y_interp = self.y_prev;
350 for i in 0..I {
351 if i < self.k.len() && i < self.cont.len() {
352 y_interp += self.k[i] * (cont[i] * self.h_prev);
353 }
354 }
355 yd[i] = y_interp;
356 } else {
357 yd[i] = cubic_hermite_interpolate(
358 self.t_prev,
359 self.t,
360 &self.y_prev,
361 &self.y,
362 &self.dydt_prev,
363 &self.dydt,
364 t_delayed
365 );
366 } } else { let mut found_interpolation = false;
369 let buffer = &self.history;
370 let mut buffer_iter = buffer.iter();
372 if let Some(mut prev_entry) = buffer_iter.next() {
373 for curr_entry in buffer_iter {
374 let (t_left, y_left, dydt_left) = prev_entry;
375 let (t_right, y_right, dydt_right) = curr_entry;
376
377 let is_between = if self.h.signum() > T::zero() {
379 *t_left <= t_delayed && t_delayed <= *t_right
381 } else {
382 *t_right <= t_delayed && t_delayed <= *t_left
384 };
385
386 if is_between {
387 yd[i] = cubic_hermite_interpolate(
389 *t_left,
390 *t_right,
391 y_left,
392 y_right,
393 dydt_left,
394 dydt_right,
395 t_delayed
396 );
397 found_interpolation = true;
398 break;
399 }
400 prev_entry = curr_entry;
401 }
402 }if !found_interpolation {
404 let buffer = &self.history;
406 println!("Buffer contents ({} entries):", buffer.len());
407 for (idx, (t_buf, _, _)) in buffer.iter().enumerate() {
408 if idx < 5 || idx >= buffer.len() - 5 {
409 println!(" [{}] t = {}", idx, t_buf);
410 } else if idx == 5 {
411 println!(" ... ({} more entries) ...", buffer.len() - 10);
412 }
413 }
414 panic!("Insufficient history in history for t_delayed = {} (t_prev = {}, t = {}). Buffer may need to retain more points or there's a logic error in determining interpolation intervals.", t_delayed, self.t_prev, self.t);
415 }
416 }
417 }
418 }
419}
420
421impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> Interpolation<T, V> for ExplicitRungeKutta<Delay, Adaptive, T, V, D, O, S, I> {
422 fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
424 let posneg = (self.t - self.t_prev).signum();
425 if (t_interp - self.t_prev) * posneg < T::zero() || (t_interp - self.t) * posneg > T::zero() {
426 return Err(Error::OutOfBounds {
427 t_interp,
428 t_prev: self.t_prev,
429 t_curr: self.t,
430 });
431 }
432
433 if self.bi.is_some() {
435 let s = (t_interp - self.t_prev) / self.h_prev;
437
438 let bi = self.bi.as_ref().unwrap();
440
441 let mut cont = [T::zero(); I];
442 for i in 0..self.dense_stages {
444 cont[i] = bi[i][self.order - 1];
446
447 for j in (0..self.order - 1).rev() {
449 cont[i] = cont[i] * s + bi[i][j];
450 }
451
452 cont[i] *= s;
454 }
455
456 let mut y_interp = self.y_prev;
458 for i in 0..I {
459 y_interp += self.k[i] * cont[i] * self.h_prev;
460 }
461
462 Ok(y_interp)
463 } else {
464 let y_interp = cubic_hermite_interpolate(
466 self.t_prev,
467 self.t,
468 &self.y_prev,
469 &self.y,
470 &self.dydt_prev,
471 &self.dydt,
472 t_interp
473 );
474
475 Ok(y_interp)
476 }
477 }
478}