1use super::{ExplicitRungeKutta, Delay, Adaptive};
4use crate::{
5 Error, Status,
6 methods::h_init::InitialStepSize,
7 alias::Evals,
8 interpolate::{Interpolation, cubic_hermite_interpolate},
9 dde::{DelayNumericalMethod, DDE},
10 traits::{CallBackData, Real, State},
11 utils::{constrain_step_size, validate_step_size_parameters},
12};
13use std::collections::VecDeque;
14
15impl<const L: usize, T: Real, V: State<T>, H: Fn(T) -> V, D: CallBackData, const O: usize, const S: usize, const I: usize> DelayNumericalMethod<L, T, V, H, D> for ExplicitRungeKutta<Delay, Adaptive, T, V, D, O, S, I> {
16 fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &V, phi: &H) -> Result<Evals, Error<T, V>>
17 where
18 F: DDE<L, T, V, D>,
19 {
20 let mut evals = Evals::new();
21
22 self.t0 = t0;
24 self.t = t0;
25 self.y = *y0;
26 self.t_prev = self.t;
27 self.y_prev = self.y;
28 self.status = Status::Initialized;
29 self.steps = 0;
30 self.stiffness_counter = 0;
31 self.history = VecDeque::new();
32
33 let mut lags = [T::zero(); L];
35 let mut yd = [V::zeros(); L];
36
37 if L > 0 {
39 dde.lags(self.t, &self.y, &mut lags);
40 for i in 0..L {
41 if lags[i] <= T::zero() {
42 return Err(Error::BadInput {
43 msg: "All lags must be positive.".to_string(),
44 });
45 }
46 let t_delayed = self.t - lags[i];
47 if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
49 return Err(Error::BadInput {
50 msg: format!(
51 "Initial delayed time {} is out of history range (t <= {}).",
52 t_delayed, t0
53 ),
54 });
55 }
56 yd[i] = phi(t_delayed);
57 }
58 }
59
60 dde.diff(self.t, &self.y, &yd, &mut self.dydt);
62 evals.fcn += 1;
63 self.dydt_prev = self.dydt; self.history.push_back((self.t, self.y, self.dydt));
65
66 if self.h0 == T::zero() {
68 self.h0 = InitialStepSize::<Delay>::compute(dde, t0, tf, y0, self.order, self.rtol, self.atol, self.h_min, self.h_max, phi, &self.k[0], &mut evals);
70 evals.fcn += 2; }
72
73 match validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf) {
75 Ok(h0) => self.h = h0,
76 Err(status) => return Err(status),
77 }
78 Ok(evals)
79 }
80
81 fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, V>>
82 where
83 F: DDE<L, T, V, D>,
84 {
85 let mut evals = Evals::new();
86
87 if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
89 self.status = Status::Error(Error::StepSize { t: self.t, y: self.y });
90 return Err(Error::StepSize { t: self.t, y: self.y });
91 }
92
93 if self.steps >= self.max_steps {
95 self.status = Status::Error(Error::MaxSteps { t: self.t, y: self.y });
96 return Err(Error::MaxSteps { t: self.t, y: self.y });
97 }
98 self.steps += 1;
99
100 let mut lags = [T::zero(); L];
102 let mut yd = [V::zeros(); L];
103
104 self.k[0] = self.dydt;
106
107 let mut min_lag_abs = T::infinity();
109 if L > 0 {
110 let y_pred_for_lags = self.y + self.k[0] * self.h;
112 dde.lags(self.t + self.h, &y_pred_for_lags, &mut lags);
113 for i in 0..L {
114 min_lag_abs = min_lag_abs.min(lags[i].abs());
115 }
116 }
117
118 let max_iter: usize = if L > 0 && min_lag_abs < self.h.abs() && min_lag_abs > T::zero() {
120 5
121 } else {
122 1
123 };
124
125 let mut y_next_candidate_iter = self.y; let mut dydt_next_candidate_iter = V::zeros(); let mut y_prev_candidate_iter = self.y; let mut dde_iteration_failed = false;
129 let mut err_norm: T = T::zero(); for iter_idx in 0..max_iter {
133 if iter_idx > 0 {
134 y_prev_candidate_iter = y_next_candidate_iter;
135 }
136
137 for i in 1..self.stages {
139 let mut y_stage = self.y;
140 for j in 0..i {
141 y_stage += self.k[j] * (self.a[i][j] * self.h);
142 }
143 if L > 0 {
145 dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut lags);
146 self.lagvals(self.t + self.c[i] * self.h, &lags, &mut yd, phi);
147 }
148 dde.diff(self.t + self.c[i] * self.h, &y_stage, &yd, &mut self.k[i]);
149 }
150 evals.fcn += self.stages - 1; let mut y_high = self.y; for i in 0..self.stages {
155 y_high += self.k[i] * (self.b[i] * self.h);
156 }
157 let mut y_low = self.y; if let Some(bh_coeffs) = &self.bh {
159 for i in 0..self.stages {
160 y_low += self.k[i] * (bh_coeffs[i] * self.h);
161 }
162 }
163 let err_vec: V = y_high - y_low; err_norm = T::zero();
167 for n in 0..self.y.len() {
168 let tol = self.atol + self.rtol * self.y.get(n).abs().max(y_high.get(n).abs());
169 err_norm = err_norm.max((err_vec.get(n) / tol).abs());
170 }
171
172 if max_iter > 1 && iter_idx > 0 {
174 let mut dde_iteration_error = T::zero();
175 let n_dim = self.y.len();
176 for i_dim in 0..n_dim {
177 let scale = self.atol + self.rtol * y_prev_candidate_iter.get(i_dim).abs().max(y_high.get(i_dim).abs());
178 if scale > T::zero() {
179 let diff_val = y_high.get(i_dim) - y_prev_candidate_iter.get(i_dim);
180 dde_iteration_error += (diff_val / scale).powi(2);
181 }
182 }
183 if n_dim > 0 {
184 dde_iteration_error = (dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
185 }
186
187 if dde_iteration_error <= self.rtol * T::from_f64(0.1).unwrap() {
188 break; }
190 if iter_idx == max_iter - 1 { dde_iteration_failed = dde_iteration_error > self.rtol * T::from_f64(0.1).unwrap();
192 }
193 }
194 y_next_candidate_iter = y_high; if L > 0 {
198 dde.lags(self.t + self.h, &y_next_candidate_iter, &mut lags);
199 self.lagvals(self.t + self.h, &lags, &mut yd, phi);
200 }
201 dde.diff(self.t + self.h, &y_next_candidate_iter, &yd, &mut dydt_next_candidate_iter);
202 evals.fcn += 1;
203 } if dde_iteration_failed {
207 let sign = self.h.signum();
208 self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
209 if L > 0 && min_lag_abs > T::zero() && self.h.abs() < T::from_f64(2.0).unwrap() * min_lag_abs {
211 self.h = min_lag_abs * sign; }
213 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
214 self.status = Status::RejectedStep; return Ok(evals); }
218
219 let order_t = T::from_usize(self.order).unwrap();
221 let err_order_inv = T::one() / order_t;
222 let mut scale_factor = self.safety_factor * err_norm.powf(-err_order_inv);
223 scale_factor = scale_factor.max(self.min_scale).min(self.max_scale);
224
225 let h_new = self.h * scale_factor;
226
227 if err_norm <= T::one() { self.t_prev = self.t;
230 self.y_prev = self.y;
231 self.dydt_prev = self.dydt; self.h_prev = self.h; if let Status::RejectedStep = self.status { self.stiffness_counter = 0;
236 }
237 self.status = Status::Solving;
238
239 if self.bi.is_some() {
241 for i in 0..(I - S) { let mut y_stage_dense = self.y; for j in 0..self.stages + i { y_stage_dense += self.k[j] * (self.a[self.stages + i][j] * self.h);
246 }
247 if L > 0 {
249 dde.lags(self.t + self.c[self.stages + i] * self.h, &y_stage_dense, &mut lags);
250 self.lagvals(self.t + self.c[self.stages + i] * self.h, &lags, &mut yd, phi);
251 }
252 dde.diff(self.t + self.c[self.stages + i] * self.h, &y_stage_dense, &yd, &mut self.k[self.stages + i]);
253 }
254 evals.fcn += I - S; }
256
257 self.t += self.h;
259 self.y = y_next_candidate_iter;
260 if self.fsal {
263 self.dydt = self.k[S - 1];
265 } else {
266 if L > 0 {
267 dde.lags(self.t, &self.y, &mut lags);
268 self.lagvals(self.t, &lags, &mut yd, phi);
269 }
270 dde.diff(self.t, &self.y, &yd, &mut self.dydt);
272 evals.fcn += 1;
273 }
274
275 self.history.push_back((self.t, self.y, self.dydt));
277 if let Some(max_delay) = self.max_delay {
278 let cutoff_time = self.t - max_delay;
279 while let Some((t_front, _, _)) = self.history.get(1){
280 if *t_front < cutoff_time {
281 self.history.pop_front();
282 } else {
283 break; }
285 }
286 }
287
288 self.h = constrain_step_size(h_new, self.h_min, self.h_max); } else { self.status = Status::RejectedStep;
291 self.stiffness_counter += 1;
292
293 if self.stiffness_counter >= self.max_rejects {
295 self.status = Status::Error(Error::Stiffness { t: self.t, y: self.y });
296 return Err(Error::Stiffness { t: self.t, y: self.y });
297 }
298 self.h = constrain_step_size(h_new, self.h_min, self.h_max);
300 }
301 Ok(evals)
302 }
303
304 fn t(&self) -> T { self.t }
305 fn y(&self) -> &V { &self.y }
306 fn t_prev(&self) -> T { self.t_prev }
307 fn y_prev(&self) -> &V { &self.y_prev }
308 fn h(&self) -> T { self.h }
309 fn set_h(&mut self, h: T) { self.h = h; }
310 fn status(&self) -> &Status<T, V, D> { &self.status }
311 fn set_status(&mut self, status: Status<T, V, D>) { self.status = status; }
312}
313
314impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> ExplicitRungeKutta<Delay, Adaptive, T, V, D, O, S, I> {
315 fn lagvals<const L: usize, H>(&mut self, t_stage: T, lags: &[T; L], yd: &mut [V; L], phi: &H)
316 where
317 H: Fn(T) -> V,
318 {
319 for i in 0..L {
320 let t_delayed = t_stage - lags[i];
321
322 if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
324 yd[i] = phi(t_delayed);
325 } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
327 if self.bi.is_some() {
328 let s = (t_delayed - self.t_prev) / self.h_prev;
329
330 let bi_coeffs = self.bi.as_ref().unwrap();
331
332 let mut cont = [T::zero(); I];
333 for i in 0..I {
334 if i < self.cont.len() && i < bi_coeffs.len() {
335 cont[i] = bi_coeffs[i][self.dense_stages - 1];
336 for j in (0..self.dense_stages - 1).rev() {
337 cont[i] = cont[i] * s + bi_coeffs[i][j];
338 }
339 cont[i] *= s;
340 }
341 }
342
343 let mut y_interp = self.y_prev;
344 for i in 0..I {
345 if i < self.k.len() && i < self.cont.len() {
346 y_interp += self.k[i] * (cont[i] * self.h_prev);
347 }
348 }
349 yd[i] = y_interp;
350 } else {
351 yd[i] = cubic_hermite_interpolate(
352 self.t_prev,
353 self.t,
354 &self.y_prev,
355 &self.y,
356 &self.dydt_prev,
357 &self.dydt,
358 t_delayed
359 );
360 } } else { let mut found_interpolation = false;
363 let buffer = &self.history;
364 let mut buffer_iter = buffer.iter();
366 if let Some(mut prev_entry) = buffer_iter.next() {
367 for curr_entry in buffer_iter {
368 let (t_left, y_left, dydt_left) = prev_entry;
369 let (t_right, y_right, dydt_right) = curr_entry;
370
371 let is_between = if self.h.signum() > T::zero() {
373 *t_left <= t_delayed && t_delayed <= *t_right
375 } else {
376 *t_right <= t_delayed && t_delayed <= *t_left
378 };
379
380 if is_between {
381 yd[i] = cubic_hermite_interpolate(
383 *t_left,
384 *t_right,
385 y_left,
386 y_right,
387 dydt_left,
388 dydt_right,
389 t_delayed
390 );
391 found_interpolation = true;
392 break;
393 }
394 prev_entry = curr_entry;
395 }
396 }if !found_interpolation {
398 let buffer = &self.history;
400 println!("Buffer contents ({} entries):", buffer.len());
401 for (idx, (t_buf, _, _)) in buffer.iter().enumerate() {
402 if idx < 5 || idx >= buffer.len() - 5 {
403 println!(" [{}] t = {}", idx, t_buf);
404 } else if idx == 5 {
405 println!(" ... ({} more entries) ...", buffer.len() - 10);
406 }
407 }
408 panic!("Insufficient history in history for t_delayed = {} (t_prev = {}, t = {}). Buffer may need to retain more points or there's a logic error in determining interpolation intervals.", t_delayed, self.t_prev, self.t);
409 }
410 }
411 }
412 }
413}
414
415impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> Interpolation<T, V> for ExplicitRungeKutta<Delay, Adaptive, T, V, D, O, S, I> {
416 fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
418 let posneg = (self.t - self.t_prev).signum();
419 if (t_interp - self.t_prev) * posneg < T::zero() || (t_interp - self.t) * posneg > T::zero() {
420 return Err(Error::OutOfBounds {
421 t_interp,
422 t_prev: self.t_prev,
423 t_curr: self.t,
424 });
425 }
426
427 if self.bi.is_some() {
429 let s = (t_interp - self.t_prev) / self.h_prev;
431
432 let bi = self.bi.as_ref().unwrap();
434
435 let mut cont = [T::zero(); I];
436 for i in 0..self.dense_stages {
438 cont[i] = bi[i][self.order - 1];
440
441 for j in (0..self.order - 1).rev() {
443 cont[i] = cont[i] * s + bi[i][j];
444 }
445
446 cont[i] *= s;
448 }
449
450 let mut y_interp = self.y_prev;
452 for i in 0..I {
453 y_interp += self.k[i] * cont[i] * self.h_prev;
454 }
455
456 Ok(y_interp)
457 } else {
458 let y_interp = cubic_hermite_interpolate(
460 self.t_prev,
461 self.t,
462 &self.y_prev,
463 &self.y,
464 &self.dydt_prev,
465 &self.dydt,
466 t_interp
467 );
468
469 Ok(y_interp)
470 }
471 }
472}