1use super::{ExplicitRungeKutta, Delay, Fixed};
4use crate::{
5 Error, Status,
6 alias::Evals,
7 interpolate::{Interpolation, cubic_hermite_interpolate},
8 dde::{DelayNumericalMethod, DDE},
9 traits::{CallBackData, Real, State},
10 utils::validate_step_size_parameters,
11};
12use std::collections::VecDeque;
13
14impl<const L: usize, T: Real, V: State<T>, H: Fn(T) -> V, D: CallBackData, const O: usize, const S: usize, const I: usize> DelayNumericalMethod<L, T, V, H, D> for ExplicitRungeKutta<Delay, Fixed, T, V, D, O, S, I> {
15 fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &V, phi: &H) -> Result<Evals, Error<T, V>>
16 where
17 F: DDE<L, T, V, D>,
18 {
19 let mut evals = Evals::new();
21 self.t0 = t0;
22 self.t = t0;
23 self.y = *y0;
24 self.t_prev = self.t;
25 self.y_prev = self.y;
26 self.status = Status::Initialized;
27 self.steps = 0;
28 self.history = VecDeque::new();
29
30 let mut lags = [T::zero(); L];
32 let mut yd = [V::zeros(); L];
33
34 if L > 0 {
36 dde.lags(self.t, &self.y, &mut lags);
37 for i in 0..L {
38 if lags[i] <= T::zero() {
39 return Err(Error::BadInput {
40 msg: "All lags must be positive.".to_string(),
41 });
42 }
43 let t_delayed = self.t - lags[i];
44 if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
46 return Err(Error::BadInput {
47 msg: format!(
48 "Initial delayed time {} is out of history range (t <= {}).",
49 t_delayed, t0
50 ),
51 });
52 }
53 yd[i] = phi(t_delayed);
54 }
55 }
56
57 dde.diff(self.t, &self.y, &yd, &mut self.dydt);
59 evals.fcn += 1;
60 self.dydt_prev = self.dydt; self.history.push_back((self.t, self.y, self.dydt));
62
63 if self.h0 == T::zero() {
65 let duration = (tf - t0).abs();
67 let default_steps = T::from_usize(100).unwrap();
68 self.h0 = duration / default_steps;
69 }
70
71 match validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf) {
73 Ok(h0) => self.h = h0,
74 Err(status) => return Err(status),
75 }
76 Ok(evals)
77 }
78
79 fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, V>>
80 where
81 F: DDE<L, T, V, D>,
82 {
83 let mut evals = Evals::new();
84
85 if self.steps >= self.max_steps {
87 self.status = Status::Error(Error::MaxSteps { t: self.t, y: self.y });
88 return Err(Error::MaxSteps { t: self.t, y: self.y });
89 }
90 self.steps += 1;
91
92 let mut lags = [T::zero(); L];
94 let mut yd = [V::zeros(); L];
95
96 self.k[0] = self.dydt; let mut min_lag_abs = T::infinity();
99 if L > 0 {
100 let y_pred_for_lags = self.y + self.k[0] * self.h;
102 dde.lags(self.t + self.h, &y_pred_for_lags, &mut lags);
103 for i in 0..L {
104 min_lag_abs = min_lag_abs.min(lags[i].abs());
105 }
106 }
107
108 let max_iter: usize = if L > 0 && min_lag_abs < self.h.abs() && min_lag_abs > T::zero() {
110 5
111 } else {
112 1
113 };
114
115 let mut y_next_candidate_iter = self.y; let mut dydt_next_candidate_iter = V::zeros(); let mut y_prev_candidate_iter = self.y; let mut dde_iteration_failed = false;
119
120 for iter_idx in 0..max_iter {
122 if iter_idx > 0 {
123 y_prev_candidate_iter = y_next_candidate_iter;
124 }
125
126 for i in 1..self.stages {
128 let mut y_stage = self.y;
129 for j in 0..i {
130 y_stage += self.k[j] * (self.a[i][j] * self.h);
131 }
132 if L > 0 {
134 dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut lags);
135 self.lagvals(self.t + self.c[i] * self.h, &lags, &mut yd, phi);
136 }
137 dde.diff(self.t + self.c[i] * self.h, &y_stage, &yd, &mut self.k[i]);
138 }
139 evals.fcn += self.stages - 1; let mut y_next = self.y;
143 for i in 0..self.stages {
144 y_next += self.k[i] * (self.b[i] * self.h);
145 }
146
147 if max_iter > 1 && iter_idx > 0 {
149 let mut dde_iteration_error = T::zero();
150 let n_dim = self.y.len();
151 for i_dim in 0..n_dim {
152 let scale = T::from_f64(1e-10).unwrap() + y_prev_candidate_iter.get(i_dim).abs().max(y_next.get(i_dim).abs());
153 if scale > T::zero() {
154 let diff_val = y_next.get(i_dim) - y_prev_candidate_iter.get(i_dim);
155 dde_iteration_error += (diff_val / scale).powi(2);
156 }
157 }
158 if n_dim > 0 {
159 dde_iteration_error = (dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
160 }
161
162 if dde_iteration_error <= T::from_f64(1e-6).unwrap() {
163 break; }
165 if iter_idx == max_iter - 1 { dde_iteration_failed = dde_iteration_error > T::from_f64(1e-6).unwrap();
167 }
168 }
169 y_next_candidate_iter = y_next; if L > 0 {
173 dde.lags(self.t + self.h, &y_next_candidate_iter, &mut lags);
174 self.lagvals(self.t + self.h, &lags, &mut yd, phi);
175 }
176 dde.diff(self.t + self.h, &y_next_candidate_iter, &yd, &mut dydt_next_candidate_iter);
177 evals.fcn += 1;
178 } if dde_iteration_failed {
182 let sign = self.h.signum();
183 self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
184 if L > 0 && min_lag_abs > T::zero() && self.h.abs() < T::from_f64(2.0).unwrap() * min_lag_abs {
186 self.h = min_lag_abs * sign; }
188 self.status = Status::RejectedStep; return Ok(evals); }
191
192 self.t_prev = self.t;
194 self.y_prev = self.y;
195 self.dydt_prev = self.dydt;
196
197 self.t += self.h;
199 self.y = y_next_candidate_iter;
200
201 if self.fsal {
203 self.dydt = self.k[S - 1];
205 } else {
206 if L > 0 {
208 dde.lags(self.t, &self.y, &mut lags);
209 self.lagvals(self.t, &lags, &mut yd, phi);
210 }
211 dde.diff(self.t, &self.y, &yd, &mut self.dydt);
212 evals.fcn += 1;
213 }
214
215 if self.bi.is_some() {
217 for i in 0..(I - S) { let mut y_stage_dense = self.y_prev; for j in 0..self.stages + i { y_stage_dense += self.k[j] * (self.a[self.stages + i][j] * self.h);
222 }
223 if L > 0 {
225 dde.lags(self.t_prev + self.c[self.stages + i] * self.h, &y_stage_dense, &mut lags);
226 self.lagvals(self.t_prev + self.c[self.stages + i] * self.h, &lags, &mut yd, phi);
227 }
228 dde.diff(self.t_prev + self.c[self.stages + i] * self.h, &y_stage_dense, &yd, &mut self.k[self.stages + i]);
229 }
230 evals.fcn += I - S; }
232
233 self.history.push_back((self.t, self.y, self.dydt));
235 if let Some(max_delay) = self.max_delay {
236 let cutoff_time = self.t - max_delay;
237 while let Some((t_front, _, _)) = self.history.get(1){
238 if *t_front < cutoff_time {
239 self.history.pop_front();
240 } else {
241 break; }
243 }
244 }
245
246 self.status = Status::Solving;
247 Ok(evals)
248 }
249
250 fn t(&self) -> T { self.t }
251 fn y(&self) -> &V { &self.y }
252 fn t_prev(&self) -> T { self.t_prev }
253 fn y_prev(&self) -> &V { &self.y_prev }
254 fn h(&self) -> T { self.h }
255 fn set_h(&mut self, h: T) { self.h = h; }
256 fn status(&self) -> &Status<T, V, D> { &self.status }
257 fn set_status(&mut self, status: Status<T, V, D>) { self.status = status; }
258}
259
260impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> ExplicitRungeKutta<Delay, Fixed, T, V, D, O, S, I> {
261 pub fn lagvals<const L: usize, H>(&mut self, t_stage: T, lags: &[T; L], yd: &mut [V; L], phi: &H)
262 where
263 H: Fn(T) -> V,
264 {
265 for i in 0..L {
266 let t_delayed = t_stage - lags[i];
267
268 if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
270 yd[i] = phi(t_delayed);
271 } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
273 if self.bi.is_some() {
274 let s = (t_delayed - self.t_prev) / self.h;
275
276 let bi_coeffs = self.bi.as_ref().unwrap();
277
278 let mut cont = [T::zero(); I];
279 for i in 0..I {
280 if i < cont.len() && i < bi_coeffs.len() {
281 cont[i] = bi_coeffs[i][self.dense_stages - 1];
282 for j in (0..self.dense_stages - 1).rev() {
283 cont[i] = cont[i] * s + bi_coeffs[i][j];
284 }
285 cont[i] *= s;
286 }
287 }
288
289 let mut y_interp = self.y_prev;
290 for i in 0..I {
291 if i < self.k.len() && i < cont.len() {
292 y_interp += self.k[i] * (cont[i] * self.h);
293 }
294 }
295 yd[i] = y_interp;
296 } else {
297 yd[i] = cubic_hermite_interpolate(
298 self.t_prev,
299 self.t,
300 &self.y_prev,
301 &self.y,
302 &self.dydt_prev,
303 &self.dydt,
304 t_delayed
305 );
306 } } else { let mut found_interpolation = false;
309 let buffer = &self.history;
310 let mut buffer_iter = buffer.iter();
312 if let Some(mut prev_entry) = buffer_iter.next() {
313 for curr_entry in buffer_iter {
314 let (t_left, y_left, dydt_left) = prev_entry;
315 let (t_right, y_right, dydt_right) = curr_entry;
316
317 let is_between = if self.h.signum() > T::zero() {
319 *t_left <= t_delayed && t_delayed <= *t_right
321 } else {
322 *t_right <= t_delayed && t_delayed <= *t_left
324 };
325
326 if is_between {
327 yd[i] = cubic_hermite_interpolate(
329 *t_left,
330 *t_right,
331 y_left,
332 y_right,
333 dydt_left,
334 dydt_right,
335 t_delayed
336 );
337 found_interpolation = true;
338 break;
339 }
340 prev_entry = curr_entry;
341 }
342 }if !found_interpolation {
344 let buffer = &self.history;
346 println!("Buffer contents ({} entries):", buffer.len());
347 for (idx, (t_buf, _, _)) in buffer.iter().enumerate() {
348 if idx < 5 || idx >= buffer.len() - 5 {
349 println!(" [{}] t = {}", idx, t_buf);
350 } else if idx == 5 {
351 println!(" ... ({} more entries) ...", buffer.len() - 10);
352 }
353 }
354 panic!("Insufficient history in history for t_delayed = {} (t_prev = {}, t = {}). Buffer may need to retain more points or there's a logic error in determining interpolation intervals.", t_delayed, self.t_prev, self.t);
355 }
356 }
357 }
358 }
359}
360
361impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> Interpolation<T, V> for ExplicitRungeKutta<Delay, Fixed, T, V, D, O, S, I> {
362 fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
364 let posneg = self.h.signum();
365 if (t_interp - self.t_prev) * posneg < T::zero() || (t_interp - self.t) * posneg > T::zero() {
366 return Err(Error::OutOfBounds {
367 t_interp,
368 t_prev: self.t_prev,
369 t_curr: self.t,
370 });
371 }
372
373 if self.bi.is_some() {
375 let s = (t_interp - self.t_prev) / self.h_prev;
377
378 let bi = self.bi.as_ref().unwrap();
380
381 let mut cont = [T::zero(); I];
382 for i in 0..self.dense_stages {
384 cont[i] = bi[i][self.order - 1];
386
387 for j in (0..self.order - 1).rev() {
389 cont[i] = cont[i] * s + bi[i][j];
390 }
391
392 cont[i] *= s;
394 }
395
396 let mut y_interp = self.y_prev;
398 for i in 0..I {
399 y_interp += self.k[i] * cont[i] * self.h_prev;
400 }
401
402 Ok(y_interp)
403 } else {
404 let y_interp = cubic_hermite_interpolate(
406 self.t_prev,
407 self.t,
408 &self.y_prev,
409 &self.y,
410 &self.dydt_prev,
411 &self.dydt,
412 t_interp
413 );
414
415 Ok(y_interp)
416 }
417 }
418
419}