1use std::collections::VecDeque;
3
4use crate::{
5 dde::{DDE, DelayNumericalMethod},
6 error::Error,
7 interpolate::{Interpolation, cubic_hermite_interpolate},
8 methods::{Delay, ExplicitRungeKutta, Fixed},
9 stats::Evals,
10 status::Status,
11 traits::{Real, State},
12 utils::validate_step_size_parameters,
13};
14
15impl<
16 const L: usize,
17 T: Real,
18 Y: State<T>,
19 H: Fn(T) -> Y,
20 const O: usize,
21 const S: usize,
22 const I: usize,
23> DelayNumericalMethod<L, T, Y, H> for ExplicitRungeKutta<Delay, Fixed, T, Y, O, S, I>
24{
25 fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &Y, phi: &H) -> Result<Evals, Error<T, Y>>
26 where
27 F: DDE<L, T, Y> + ?Sized,
28 {
29 let mut evals = Evals::new();
31
32 if L == 0 {
34 return Err(Error::NoLags);
35 }
36 self.t0 = t0;
37 self.t = t0;
38 self.y = y0.clone();
39 self.dydt = y0.zeros_like();
40 self.y_prev = y0.clone();
41 self.dydt_prev = y0.zeros_like();
42 self.k = core::array::from_fn(|_| y0.zeros_like());
43 self.cont = core::array::from_fn(|_| y0.zeros_like());
44 self.t_prev = self.t;
45 self.y_prev = self.y.clone();
46 self.status = Status::Initialized;
47 self.steps = 0;
48 self.history = VecDeque::new();
49
50 let mut delays = [T::zero(); L];
52 let mut y_delayed = core::array::from_fn(|_| y0.zeros_like());
53
54 dde.lags(self.t, &self.y, &mut delays);
56 for i in 0..L {
57 let t_delayed = self.t - delays[i];
58 if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
60 return Err(Error::BadInput {
61 msg: format!(
62 "Initial delayed time {} is out of history range (t <= {}).",
63 t_delayed, t0
64 ),
65 });
66 }
67 y_delayed[i] = phi(t_delayed);
68 }
69
70 dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
72 evals.function += 1;
73 self.dydt_prev = self.dydt.clone(); self.history
75 .push_back((self.t, self.y.clone(), self.dydt.clone()));
76
77 if self.h0 == T::zero() {
79 let duration = (tf - t0).abs();
80 let default_steps = T::from_usize(100).unwrap();
81 self.h0 = duration / default_steps;
82 }
83
84 match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
86 Ok(h0) => self.h = h0,
87 Err(status) => return Err(status),
88 }
89 Ok(evals)
90 }
91
92 fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, Y>>
93 where
94 F: DDE<L, T, Y> + ?Sized,
95 {
96 let mut evals = Evals::new();
97
98 if self.steps >= self.max_steps {
100 self.status = Status::Error(Error::MaxSteps {
101 t: self.t,
102 y: self.y.clone(),
103 });
104 return Err(Error::MaxSteps {
105 t: self.t,
106 y: self.y.clone(),
107 });
108 }
109 self.steps += 1;
110
111 let mut delays = [T::zero(); L];
113 let mut y_delayed = core::array::from_fn(|_| self.y.zeros_like());
114
115 self.k[0] = self.dydt.clone();
118 let mut min_delay_abs = T::infinity();
119 let y_pred_for_lags = self.y.plus_scaled(self.h, &self.k[0]);
121 dde.lags(self.t + self.h, &y_pred_for_lags, &mut delays);
122 for i in 0..L {
123 min_delay_abs = min_delay_abs.min(delays[i].abs());
124 }
125
126 let max_iter: usize = if min_delay_abs < self.h.abs() && min_delay_abs > T::zero() {
128 5
129 } else {
130 1
131 };
132
133 let mut y_next_candidate_iter = self.y.clone(); let mut dydt_next_candidate_iter = self.y.zeros_like(); let mut y_prev_candidate_iter = self.y.clone(); let mut dde_iteration_failed = false;
137
138 for iter_idx in 0..max_iter {
140 if iter_idx > 0 {
141 y_prev_candidate_iter = y_next_candidate_iter.clone();
142 }
143
144 for i in 1..self.stages {
146 let mut y_stage = self.y.clone();
147 for j in 0..i {
148 y_stage.add_scaled(self.a[i][j] * self.h, &self.k[j]);
149 }
150 dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut delays);
152 if let Err(e) =
153 self.lagvals(self.t + self.c[i] * self.h, &delays, &mut y_delayed, phi)
154 {
155 self.status = Status::Error(e.clone());
156 return Err(e);
157 }
158 dde.diff(
159 self.t + self.c[i] * self.h,
160 &y_stage,
161 &y_delayed,
162 &mut self.k[i],
163 );
164 }
165 evals.function += self.stages - 1;
166
167 let mut y_next = self.y.clone();
169 for i in 0..self.stages {
170 y_next.add_scaled(self.b[i] * self.h, &self.k[i]);
171 }
172
173 if max_iter > 1 && iter_idx > 0 {
175 let n_dim = self.y.len();
176 let mut dde_iteration_error = T::zero();
177 for i_dim in 0..n_dim {
178 let scale = T::from_f64(1e-10).unwrap()
179 + y_prev_candidate_iter
180 .get_component(i_dim)
181 .abs()
182 .max(y_next.get_component(i_dim).abs());
183 if scale > T::zero() {
184 let diff_val = y_next.get_component(i_dim)
185 - y_prev_candidate_iter.get_component(i_dim);
186 let val = diff_val / scale;
187 dde_iteration_error += val * val;
188 }
189 }
190 if n_dim > 0 {
191 dde_iteration_error =
192 (dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
193 }
194
195 if dde_iteration_error <= T::from_f64(1e-6).unwrap() {
196 break;
197 }
198 if iter_idx == max_iter - 1 {
199 dde_iteration_failed = dde_iteration_error > T::from_f64(1e-6).unwrap();
200 }
201 }
202 y_next_candidate_iter = y_next.clone();
203
204 dde.lags(self.t + self.h, &y_next_candidate_iter, &mut delays);
206 if let Err(e) = self.lagvals(self.t + self.h, &delays, &mut y_delayed, phi) {
207 self.status = Status::Error(e.clone());
208 return Err(e);
209 }
210 dde.diff(
211 self.t + self.h,
212 &y_next_candidate_iter,
213 &y_delayed,
214 &mut dydt_next_candidate_iter,
215 );
216 evals.function += 1;
217 }
218
219 if dde_iteration_failed {
221 let sign = self.h.signum();
222 self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
223 if L > 0
224 && min_delay_abs > T::zero()
225 && self.h.abs() < T::from_f64(2.0).unwrap() * min_delay_abs
226 {
227 self.h = min_delay_abs * sign;
228 }
229 self.status = Status::RejectedStep;
230 return Ok(evals);
231 }
232
233 self.t_prev = self.t;
235 self.y_prev = self.y.clone();
236 self.dydt_prev = self.dydt.clone();
237
238 self.t += self.h;
240 self.y = y_next_candidate_iter;
241
242 if self.fsal {
244 self.dydt = self.k[S - 1].clone();
245 } else {
246 dde.lags(self.t, &self.y, &mut delays);
247 if let Err(e) = self.lagvals(self.t, &delays, &mut y_delayed, phi) {
248 self.status = Status::Error(e.clone());
249 return Err(e);
250 }
251 dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
252 evals.function += 1;
253 }
254
255 if self.bi.is_some() {
257 for i in 0..(I - S) {
258 let mut y_stage_dense = self.y_prev.clone();
259 for j in 0..self.stages + i {
260 y_stage_dense.add_scaled(self.a[self.stages + i][j] * self.h, &self.k[j]);
261 }
262 let t_stage = self.t_prev + self.c[self.stages + i] * self.h;
263 dde.lags(t_stage, &y_stage_dense, &mut delays);
264 if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
265 self.status = Status::Error(e.clone());
266 return Err(e);
267 }
268 dde.diff(
269 self.t_prev + self.c[self.stages + i] * self.h,
270 &y_stage_dense,
271 &y_delayed,
272 &mut self.k[self.stages + i],
273 );
274 }
275 evals.function += I - S;
276 }
277
278 self.history
280 .push_back((self.t, self.y.clone(), self.dydt.clone()));
281 if let Some(max_delay) = self.max_delay {
282 let cutoff_time = self.t - max_delay;
283 while let Some((t_front, _, _)) = self.history.get(1) {
284 if *t_front < cutoff_time {
285 self.history.pop_front();
286 } else {
287 break;
288 }
289 }
290 }
291
292 self.status = Status::Solving;
293 Ok(evals)
294 }
295
296 fn t(&self) -> T {
297 self.t
298 }
299 fn y(&self) -> &Y {
300 &self.y
301 }
302 fn t_prev(&self) -> T {
303 self.t_prev
304 }
305 fn y_prev(&self) -> &Y {
306 &self.y_prev
307 }
308 fn h(&self) -> T {
309 self.h
310 }
311 fn set_h(&mut self, h: T) {
312 self.h = h;
313 }
314 fn status(&self) -> &Status<T, Y> {
315 &self.status
316 }
317 fn set_status(&mut self, status: Status<T, Y>) {
318 self.status = status;
319 }
320}
321
322impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
323 ExplicitRungeKutta<Delay, Fixed, T, Y, O, S, I>
324{
325 pub fn lagvals<const L: usize, H>(
326 &mut self,
327 t_stage: T,
328 delays: &[T; L],
329 y_delayed: &mut [Y; L],
330 phi: &H,
331 ) -> Result<(), Error<T, Y>>
332 where
333 H: Fn(T) -> Y,
334 {
335 for i in 0..L {
336 let t_delayed = t_stage - delays[i];
337
338 if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
340 y_delayed[i] = phi(t_delayed);
341 } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
343 if let Some(bi_coeffs) = self.bi.as_ref() {
344 let s = (t_delayed - self.t_prev) / self.h_prev;
345
346 let mut cont = [T::zero(); I];
347 for i in 0..I {
348 if i < cont.len() && i < bi_coeffs.len() {
349 cont[i] = bi_coeffs[i][self.dense_stages - 1];
350 for j in (0..self.dense_stages - 1).rev() {
351 cont[i] = cont[i] * s + bi_coeffs[i][j];
352 }
353 cont[i] *= s;
354 }
355 }
356
357 let mut y_interp = self.y_prev.clone();
358 for i in 0..I {
359 if i < self.k.len() && i < cont.len() {
360 y_interp.add_scaled(cont[i] * self.h_prev, &self.k[i]);
361 }
362 }
363 y_delayed[i] = y_interp;
364 } else {
365 y_delayed[i] = cubic_hermite_interpolate(
366 self.t_prev,
367 self.t,
368 &self.y_prev,
369 &self.y,
370 &self.dydt_prev,
371 &self.dydt,
372 t_delayed,
373 );
374 } } else {
376 let mut found_interpolation = false;
378 let buffer = &self.history;
379 let mut buffer_iter = buffer.iter();
381 if let Some(mut prev_entry) = buffer_iter.next() {
382 for curr_entry in buffer_iter {
383 let (t_left, y_left, dydt_left) = prev_entry;
384 let (t_right, y_right, dydt_right) = curr_entry;
385
386 let is_between = if self.h.signum() > T::zero() {
388 *t_left <= t_delayed && t_delayed <= *t_right
390 } else {
391 *t_right <= t_delayed && t_delayed <= *t_left
393 };
394
395 if is_between {
396 y_delayed[i] = cubic_hermite_interpolate(
398 *t_left, *t_right, y_left, y_right, dydt_left, dydt_right,
399 t_delayed,
400 );
401 found_interpolation = true;
402 break;
403 }
404 prev_entry = curr_entry;
405 }
406 } if !found_interpolation {
408 return Err(Error::InsufficientHistory {
409 t_delayed,
410 t_prev: self.t_prev,
411 t_curr: self.t,
412 });
413 }
414 }
415 }
416 Ok(())
417 }
418}
419
420impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
421 for ExplicitRungeKutta<Delay, Fixed, T, Y, O, S, I>
422{
423 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
425 let dir = self.h.signum();
426 if (t_interp - self.t_prev) * dir < T::zero() || (t_interp - self.t) * dir > T::zero() {
427 return Err(Error::OutOfBounds {
428 t_interp,
429 t_prev: self.t_prev,
430 t_curr: self.t,
431 });
432 }
433
434 if let Some(bi) = self.bi.as_ref() {
436 let s = (t_interp - self.t_prev) / self.h_prev;
437
438 let mut cont = [T::zero(); I];
439 for i in 0..self.dense_stages {
440 cont[i] = bi[i][self.order - 1];
441 for j in (0..self.order - 1).rev() {
442 cont[i] = cont[i] * s + bi[i][j];
443 }
444 cont[i] *= s;
445 }
446
447 let mut y_interp = self.y_prev.clone();
448 for i in 0..I {
449 y_interp.add_scaled(cont[i] * self.h_prev, &self.k[i]);
450 }
451
452 Ok(y_interp)
453 } else {
454 let y_interp = cubic_hermite_interpolate(
456 self.t_prev,
457 self.t,
458 &self.y_prev,
459 &self.y,
460 &self.dydt_prev,
461 &self.dydt,
462 t_interp,
463 );
464
465 Ok(y_interp)
466 }
467 }
468}