1use std::collections::VecDeque;
4
5use crate::{
6 dde::{DDE, DelayNumericalMethod},
7 error::Error,
8 interpolate::{Interpolation, cubic_hermite_interpolate},
9 methods::{Delay, ExplicitRungeKutta, Fixed},
10 stats::Evals,
11 status::Status,
12 traits::{CallBackData, Real, State},
13 utils::validate_step_size_parameters,
14};
15
16impl<
17 const L: usize,
18 T: Real,
19 Y: State<T>,
20 H: Fn(T) -> Y,
21 D: CallBackData,
22 const O: usize,
23 const S: usize,
24 const I: usize,
25> DelayNumericalMethod<L, T, Y, H, D> for ExplicitRungeKutta<Delay, Fixed, T, Y, D, O, S, I>
26{
27 fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &Y, phi: &H) -> Result<Evals, Error<T, Y>>
28 where
29 F: DDE<L, T, Y, D>,
30 {
31 let mut evals = Evals::new();
33
34 if L <= 0 {
36 return Err(Error::NoLags);
37 }
38 self.t0 = t0;
39 self.t = t0;
40 self.y = *y0;
41 self.t_prev = self.t;
42 self.y_prev = self.y;
43 self.status = Status::Initialized;
44 self.steps = 0;
45 self.history = VecDeque::new();
46
47 let mut delays = [T::zero(); L];
49 let mut y_delayed = [Y::zeros(); L];
50
51 dde.lags(self.t, &self.y, &mut delays);
53 for i in 0..L {
54 let t_delayed = self.t - delays[i];
55 if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
57 return Err(Error::BadInput {
58 msg: format!(
59 "Initial delayed time {} is out of history range (t <= {}).",
60 t_delayed, t0
61 ),
62 });
63 }
64 y_delayed[i] = phi(t_delayed);
65 }
66
67 dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
69 evals.function += 1;
70 self.dydt_prev = self.dydt; self.history.push_back((self.t, self.y, self.dydt));
72
73 if self.h0 == T::zero() {
75 let duration = (tf - t0).abs();
76 let default_steps = T::from_usize(100).unwrap();
77 self.h0 = duration / default_steps;
78 }
79
80 match validate_step_size_parameters::<T, Y, D>(self.h0, self.h_min, self.h_max, t0, tf) {
82 Ok(h0) => self.h = h0,
83 Err(status) => return Err(status),
84 }
85 Ok(evals)
86 }
87
88 fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, Y>>
89 where
90 F: DDE<L, T, Y, D>,
91 {
92 let mut evals = Evals::new();
93
94 if self.steps >= self.max_steps {
96 self.status = Status::Error(Error::MaxSteps {
97 t: self.t,
98 y: self.y,
99 });
100 return Err(Error::MaxSteps {
101 t: self.t,
102 y: self.y,
103 });
104 }
105 self.steps += 1;
106
107 let mut delays = [T::zero(); L];
109 let mut y_delayed = [Y::zeros(); L];
110
111 self.k[0] = self.dydt;
114 let mut min_delay_abs = T::infinity();
115 let y_pred_for_lags = self.y + self.k[0] * self.h;
117 dde.lags(self.t + self.h, &y_pred_for_lags, &mut delays);
118 for i in 0..L {
119 min_delay_abs = min_delay_abs.min(delays[i].abs());
120 }
121
122 let max_iter: usize = if min_delay_abs < self.h.abs() && min_delay_abs > T::zero() {
124 5
125 } else {
126 1
127 };
128
129 let mut y_next_candidate_iter = self.y; let mut dydt_next_candidate_iter = Y::zeros(); let mut y_prev_candidate_iter = self.y; let mut dde_iteration_failed = false;
133
134 for iter_idx in 0..max_iter {
136 if iter_idx > 0 {
137 y_prev_candidate_iter = y_next_candidate_iter;
138 }
139
140 for i in 1..self.stages {
142 let mut y_stage = self.y;
143 for j in 0..i {
144 y_stage += self.k[j] * (self.a[i][j] * self.h);
145 }
146 dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut delays);
148 if let Err(e) =
149 self.lagvals(self.t + self.c[i] * self.h, &delays, &mut y_delayed, phi)
150 {
151 self.status = Status::Error(e.clone());
152 return Err(e);
153 }
154 dde.diff(
155 self.t + self.c[i] * self.h,
156 &y_stage,
157 &y_delayed,
158 &mut self.k[i],
159 );
160 }
161 evals.function += self.stages - 1;
162
163 let mut y_next = self.y;
165 for i in 0..self.stages {
166 y_next += self.k[i] * (self.b[i] * self.h);
167 }
168
169 if max_iter > 1 && iter_idx > 0 {
171 let mut dde_iteration_error = T::zero();
172 let n_dim = self.y.len();
173 for i_dim in 0..n_dim {
174 let scale = T::from_f64(1e-10).unwrap()
175 + y_prev_candidate_iter
176 .get(i_dim)
177 .abs()
178 .max(y_next.get(i_dim).abs());
179 if scale > T::zero() {
180 let diff_val = y_next.get(i_dim) - y_prev_candidate_iter.get(i_dim);
181 dde_iteration_error += (diff_val / scale).powi(2);
182 }
183 }
184 if n_dim > 0 {
185 dde_iteration_error =
186 (dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
187 }
188
189 if dde_iteration_error <= T::from_f64(1e-6).unwrap() {
190 break;
191 }
192 if iter_idx == max_iter - 1 {
193 dde_iteration_failed = dde_iteration_error > T::from_f64(1e-6).unwrap();
194 }
195 }
196 y_next_candidate_iter = y_next;
197
198 dde.lags(self.t + self.h, &y_next_candidate_iter, &mut delays);
200 if let Err(e) = self.lagvals(self.t + self.h, &delays, &mut y_delayed, phi) {
201 self.status = Status::Error(e.clone());
202 return Err(e);
203 }
204 dde.diff(
205 self.t + self.h,
206 &y_next_candidate_iter,
207 &y_delayed,
208 &mut dydt_next_candidate_iter,
209 );
210 evals.function += 1;
211 }
212
213 if dde_iteration_failed {
215 let sign = self.h.signum();
216 self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
217 if L > 0
218 && min_delay_abs > T::zero()
219 && self.h.abs() < T::from_f64(2.0).unwrap() * min_delay_abs
220 {
221 self.h = min_delay_abs * sign;
222 }
223 self.status = Status::RejectedStep;
224 return Ok(evals);
225 }
226
227 self.t_prev = self.t;
229 self.y_prev = self.y;
230 self.dydt_prev = self.dydt;
231
232 self.t += self.h;
234 self.y = y_next_candidate_iter;
235
236 if self.fsal {
238 self.dydt = self.k[S - 1];
239 } else {
240 dde.lags(self.t, &self.y, &mut delays);
241 if let Err(e) = self.lagvals(self.t, &delays, &mut y_delayed, phi) {
242 self.status = Status::Error(e.clone());
243 return Err(e);
244 }
245 dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
246 evals.function += 1;
247 }
248
249 if self.bi.is_some() {
251 for i in 0..(I - S) {
252 let mut y_stage_dense = self.y_prev;
253 for j in 0..self.stages + i {
254 y_stage_dense += self.k[j] * (self.a[self.stages + i][j] * self.h);
255 }
256 let t_stage = self.t_prev + self.c[self.stages + i] * self.h;
257 dde.lags(t_stage, &y_stage_dense, &mut delays);
258 if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
259 self.status = Status::Error(e.clone());
260 return Err(e);
261 }
262 dde.diff(
263 self.t_prev + self.c[self.stages + i] * self.h,
264 &y_stage_dense,
265 &y_delayed,
266 &mut self.k[self.stages + i],
267 );
268 }
269 evals.function += I - S;
270 }
271
272 self.history.push_back((self.t, self.y, self.dydt));
274 if let Some(max_delay) = self.max_delay {
275 let cutoff_time = self.t - max_delay;
276 while let Some((t_front, _, _)) = self.history.get(1) {
277 if *t_front < cutoff_time {
278 self.history.pop_front();
279 } else {
280 break;
281 }
282 }
283 }
284
285 self.status = Status::Solving;
286 Ok(evals)
287 }
288
289 fn t(&self) -> T {
290 self.t
291 }
292 fn y(&self) -> &Y {
293 &self.y
294 }
295 fn t_prev(&self) -> T {
296 self.t_prev
297 }
298 fn y_prev(&self) -> &Y {
299 &self.y_prev
300 }
301 fn h(&self) -> T {
302 self.h
303 }
304 fn set_h(&mut self, h: T) {
305 self.h = h;
306 }
307 fn status(&self) -> &Status<T, Y, D> {
308 &self.status
309 }
310 fn set_status(&mut self, status: Status<T, Y, D>) {
311 self.status = status;
312 }
313}
314
315impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
316 ExplicitRungeKutta<Delay, Fixed, T, Y, D, O, S, I>
317{
318 pub fn lagvals<const L: usize, H>(
319 &mut self,
320 t_stage: T,
321 delays: &[T; L],
322 y_delayed: &mut [Y; L],
323 phi: &H,
324 ) -> Result<(), Error<T, Y>>
325 where
326 H: Fn(T) -> Y,
327 {
328 for i in 0..L {
329 let t_delayed = t_stage - delays[i];
330
331 if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
333 y_delayed[i] = phi(t_delayed);
334 } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
336 if self.bi.is_some() {
337 let s = (t_delayed - self.t_prev) / self.h_prev;
338
339 let bi_coeffs = self.bi.as_ref().unwrap();
340
341 let mut cont = [T::zero(); I];
342 for i in 0..I {
343 if i < cont.len() && i < bi_coeffs.len() {
344 cont[i] = bi_coeffs[i][self.dense_stages - 1];
345 for j in (0..self.dense_stages - 1).rev() {
346 cont[i] = cont[i] * s + bi_coeffs[i][j];
347 }
348 cont[i] *= s;
349 }
350 }
351
352 let mut y_interp = self.y_prev;
353 for i in 0..I {
354 if i < self.k.len() && i < cont.len() {
355 y_interp += self.k[i] * (cont[i] * self.h_prev);
356 }
357 }
358 y_delayed[i] = y_interp;
359 } else {
360 y_delayed[i] = cubic_hermite_interpolate(
361 self.t_prev,
362 self.t,
363 &self.y_prev,
364 &self.y,
365 &self.dydt_prev,
366 &self.dydt,
367 t_delayed,
368 );
369 } } else {
371 let mut found_interpolation = false;
373 let buffer = &self.history;
374 let mut buffer_iter = buffer.iter();
376 if let Some(mut prev_entry) = buffer_iter.next() {
377 for curr_entry in buffer_iter {
378 let (t_left, y_left, dydt_left) = prev_entry;
379 let (t_right, y_right, dydt_right) = curr_entry;
380
381 let is_between = if self.h.signum() > T::zero() {
383 *t_left <= t_delayed && t_delayed <= *t_right
385 } else {
386 *t_right <= t_delayed && t_delayed <= *t_left
388 };
389
390 if is_between {
391 y_delayed[i] = cubic_hermite_interpolate(
393 *t_left, *t_right, y_left, y_right, dydt_left, dydt_right,
394 t_delayed,
395 );
396 found_interpolation = true;
397 break;
398 }
399 prev_entry = curr_entry;
400 }
401 } if !found_interpolation {
403 return Err(Error::InsufficientHistory {
404 t_delayed,
405 t_prev: self.t_prev,
406 t_curr: self.t,
407 });
408 }
409 }
410 }
411 Ok(())
412 }
413}
414
415impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
416 Interpolation<T, Y> for ExplicitRungeKutta<Delay, Fixed, T, Y, D, O, S, I>
417{
418 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
420 let dir = self.h.signum();
421 if (t_interp - self.t_prev) * dir < T::zero() || (t_interp - self.t) * dir > T::zero() {
422 return Err(Error::OutOfBounds {
423 t_interp,
424 t_prev: self.t_prev,
425 t_curr: self.t,
426 });
427 }
428
429 if self.bi.is_some() {
431 let s = (t_interp - self.t_prev) / self.h_prev;
432
433 let bi = self.bi.as_ref().unwrap();
434
435 let mut cont = [T::zero(); I];
436 for i in 0..self.dense_stages {
437 cont[i] = bi[i][self.order - 1];
438 for j in (0..self.order - 1).rev() {
439 cont[i] = cont[i] * s + bi[i][j];
440 }
441 cont[i] *= s;
442 }
443
444 let mut y_interp = self.y_prev;
445 for i in 0..I {
446 y_interp += self.k[i] * cont[i] * self.h_prev;
447 }
448
449 Ok(y_interp)
450 } else {
451 let y_interp = cubic_hermite_interpolate(
453 self.t_prev,
454 self.t,
455 &self.y_prev,
456 &self.y,
457 &self.dydt_prev,
458 &self.dydt,
459 t_interp,
460 );
461
462 Ok(y_interp)
463 }
464 }
465}