1use crate::{
4 Error, Status,
5 alias::Evals,
6 dde::{DDE, DDENumericalMethod, methods::h_init::h_init},
7 interpolate::Interpolation,
8 traits::{CallBackData, Real, State},
9 utils::{constrain_step_size, validate_step_size_parameters},
10};
11use std::collections::VecDeque;
12
13pub struct BS23<const L: usize, T: Real, V: State<T>, H: Fn(T) -> V, D: CallBackData> {
68 pub h0: T,
69 t: T,
70 y: V,
71 h: T,
72 pub rtol: T,
73 pub atol: T,
74 pub h_max: T,
75 pub h_min: T,
76 pub max_steps: usize,
77 pub safe: T,
79 pub fac1: T,
80 pub fac2: T,
81 pub beta: T,
82 pub max_delay: Option<T>,
83 expo1: T,
84 facc1: T,
85 facc2: T,
86 facold: T,
87 fac11: T,
88 fac: T,
89 status: Status<T, V, D>,
90 steps: usize,
91 n_accepted: usize,
92 a: [[T; 4]; 3], b: [T; 3], c: [T; 3], er: [T; 4], k: [V; 4], y_old: V,
98 t_old: T,
99 h_old: T,
100 cont: [V; 4], cont_buffer: VecDeque<(T, T, T, [V; 4])>,
102 phi: Option<H>,
103 t0: T,
104 tf: T,
105 posneg: T,
106 lags: [T; L], yd: [V; L], }
109
110impl<const L: usize, T: Real, V: State<T>, H: Fn(T) -> V, D: CallBackData>
111 DDENumericalMethod<L, T, V, H, D> for BS23<L, T, V, H, D>
112{
113 fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &V, phi: H) -> Result<Evals, Error<T, V>>
114 where
115 F: DDE<L, T, V, D>,
116 {
117 let mut evals = Evals::new();
118
119 self.t = t0;
120 self.y = *y0;
121 self.t0 = t0;
122 self.tf = tf;
123 self.posneg = (tf - t0).signum();
124 self.phi = Some(phi);
125
126 if L > 0 {
127 dde.lags(self.t, &self.y, &mut self.lags);
128 for i in 0..L {
129 if self.lags[i] <= T::zero() {
130 return Err(Error::BadInput {
131 msg: "All lags must be positive.".to_string(),
132 });
133 }
134 let t_delayed = self.t - self.lags[i];
135 if (t_delayed - self.t0) * self.posneg > T::default_epsilon() {
139 return Err(Error::BadInput {
141 msg: format!(
142 "Initial delayed time {} is out of history range (t <= {}).",
143 t_delayed, self.t0
144 ),
145 });
146 }
147 self.yd[i] = (self.phi.as_ref().unwrap())(t_delayed);
148 }
149 }
150 dde.diff(self.t, &self.y, &self.yd, &mut self.k[0]); evals.fcn += 1;
152
153 if self.h0 == T::zero() {
154 let h_est = h_init(
155 dde,
156 self.t,
157 self.tf,
158 &self.y,
159 3,
160 self.rtol,
161 self.atol,
162 self.h_min,
163 self.h_max,
164 self.phi.as_ref().unwrap(),
165 &self.k[0],
166 &mut evals,
167 );
168 self.h0 = h_est;
169 }
170
171 match validate_step_size_parameters::<T, V, D>(
172 self.h0, self.h_min, self.h_max, self.t, self.tf,
173 ) {
174 Ok(h0_validated) => self.h = h0_validated,
175 Err(status) => return Err(status),
176 }
177
178 self.t_old = self.t;
179 self.y_old = self.y;
180 self.h_old = self.h; self.steps = 0;
183 self.n_accepted = 0;
184 self.status = Status::Initialized;
185 Ok(evals)
186 }
187
188 fn step<F>(&mut self, dde: &F) -> Result<Evals, Error<T, V>>
189 where
190 F: DDE<L, T, V, D>,
191 {
192 let mut evals = Evals::new();
193
194 if self.steps >= self.max_steps {
195 self.status = Status::Error(Error::MaxSteps {
196 t: self.t,
197 y: self.y,
198 });
199 return Err(Error::MaxSteps {
200 t: self.t,
201 y: self.y,
202 });
203 }
204
205 let t_current_step_start = self.t;
206 let y_current_step_start = self.y;
207 let k0_current_step_start = self.k[0]; let mut min_lag_abs = T::infinity();
210 if L > 0 {
211 let temp_y_for_lags = y_current_step_start + k0_current_step_start * self.h; dde.lags(
213 t_current_step_start + self.h,
214 &temp_y_for_lags,
215 &mut self.lags,
216 );
217 for i in 0..L {
218 min_lag_abs = min_lag_abs.min(self.lags[i].abs());
219 }
220 }
221
222 let max_iter: usize = if L > 0 && min_lag_abs < self.h.abs() && min_lag_abs > T::zero() {
223 5
224 } else {
225 1
226 };
227
228 let mut y_new_from_iter = y_current_step_start;
229 let mut k_from_iter = self.k; let mut last_y_for_errit_calc = V::zeros();
232 let mut iteration_failed_to_converge = false;
233
234 for iter_idx in 0..max_iter {
235 if iter_idx > 0 {
236 last_y_for_errit_calc = y_new_from_iter;
237 }
238 self.k[0] = k0_current_step_start; let mut ti = t_current_step_start + self.c[0] * self.h; let mut yi = y_current_step_start + self.k[0] * (self.a[0][0] * self.h);
243 if L > 0 {
244 dde.lags(ti, &yi, &mut self.lags);
245 self.lagvals(ti, &yi);
246 }
247 dde.diff(ti, &yi, &self.yd, &mut self.k[1]); ti = t_current_step_start + self.c[1] * self.h; yi = y_current_step_start + self.k[1] * (self.a[1][1] * self.h); if L > 0 {
253 dde.lags(ti, &yi, &mut self.lags);
254 self.lagvals(ti, &yi);
255 }
256 dde.diff(ti, &yi, &self.yd, &mut self.k[2]); y_new_from_iter = y_current_step_start
260 + (self.k[0] * self.b[0] + self.k[1] * self.b[1] + self.k[2] * self.b[2]) * self.h;
261
262 let t_new_val = t_current_step_start + self.h;
264 if L > 0 {
267 dde.lags(t_new_val, &y_new_from_iter, &mut self.lags);
268 self.lagvals(t_new_val, &y_new_from_iter);
269 }
270 dde.diff(t_new_val, &y_new_from_iter, &self.yd, &mut self.k[3]); evals.fcn += 3; k_from_iter.copy_from_slice(&self.k);
274
275 if max_iter > 1 && iter_idx > 0 {
276 let mut errit_val = T::zero();
277 let n_dim = y_current_step_start.len();
278 for i_dim in 0..n_dim {
279 let scale = self.atol
280 + self.rtol
281 * last_y_for_errit_calc
282 .get(i_dim)
283 .abs()
284 .max(y_new_from_iter.get(i_dim).abs());
285 if scale > T::zero() {
286 let diff_val =
287 y_new_from_iter.get(i_dim) - last_y_for_errit_calc.get(i_dim);
288 errit_val += (diff_val / scale).powi(2);
289 }
290 }
291 if n_dim > 0 {
292 errit_val = (errit_val / T::from_usize(n_dim).unwrap()).sqrt();
293 }
294
295 if errit_val <= self.rtol * T::from_f64(0.1).unwrap() {
296 break;
297 }
298 }
299 if iter_idx == max_iter - 1 && max_iter > 1 {
300 iteration_failed_to_converge = true;
301 }
302 }
303
304 if iteration_failed_to_converge {
305 self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * self.posneg;
306 if L > 0
307 && min_lag_abs > T::zero()
308 && self.h.abs() < T::from_f64(2.0).unwrap() * min_lag_abs
309 {
310 self.h = min_lag_abs * self.posneg;
311 }
312 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
313 self.status = Status::RejectedStep;
314 return Ok(evals);
315 }
316
317 let mut err_final = T::zero();
318 let n = y_current_step_start.len();
319 for i in 0..n {
320 let sk = self.atol
321 + self.rtol
322 * y_current_step_start
323 .get(i)
324 .abs()
325 .max(y_new_from_iter.get(i).abs());
326 let err_comp = k_from_iter[0].get(i) * self.er[0]
328 + k_from_iter[1].get(i) * self.er[1]
329 + k_from_iter[2].get(i) * self.er[2]
330 + k_from_iter[3].get(i) * self.er[3];
331 let erri = self.h * err_comp;
332 if sk > T::zero() {
333 err_final += (erri / sk).powi(2);
334 }
335 }
336 if n > 0 {
337 err_final = (err_final / T::from_usize(n).unwrap()).sqrt();
338 }
339
340 self.fac11 = err_final.powf(self.expo1);
341 let fac_beta = if self.beta > T::zero() && self.facold > T::zero() {
342 self.facold.powf(self.beta)
343 } else {
344 T::one()
345 };
346 self.fac = self.fac11 / fac_beta;
347 self.fac = self.facc2.max(self.facc1.min(self.fac / self.safe));
348 let mut h_new_final = self.h / self.fac;
349
350 let t_new_val = t_current_step_start + self.h;
351
352 if err_final <= T::one() {
353 self.facold = err_final.max(T::from_f64(1.0e-4).unwrap());
354 self.n_accepted += 1;
355
356 let k_old_for_cont = k_from_iter[0]; let k_new_for_cont = k_from_iter[3]; let y_diff_cont = y_new_from_iter - y_current_step_start;
365
366 self.cont[0] = y_current_step_start;
367 self.cont[1] = k_old_for_cont * self.h;
368 self.cont[2] = y_diff_cont * T::from_f64(3.0).unwrap()
369 - (k_old_for_cont * T::from_f64(2.0).unwrap() + k_new_for_cont) * self.h;
370 self.cont[3] = y_diff_cont * T::from_f64(-2.0).unwrap()
371 + (k_old_for_cont + k_new_for_cont) * self.h;
372
373 self.cont_buffer
374 .push_back((t_current_step_start, t_new_val, self.h, self.cont));
375
376 if let Some(max_delay_val) = self.max_delay {
377 let prune_time = if self.posneg > T::zero() {
378 t_new_val - max_delay_val
379 } else {
380 t_new_val + max_delay_val
381 };
382 while let Some((buf_t_start, buf_t_end, _, _)) = self.cont_buffer.front() {
383 if (self.posneg > T::zero() && *buf_t_end < prune_time)
384 || (self.posneg < T::zero() && *buf_t_start > prune_time)
385 {
386 self.cont_buffer.pop_front();
387 } else {
388 break;
389 }
390 }
391 }
392
393 self.y_old = y_current_step_start;
394 self.t_old = t_current_step_start;
395 self.h_old = self.h;
396
397 self.k[0] = k_from_iter[3]; self.y = y_new_from_iter;
399 self.t = t_new_val;
400
401 if let Status::RejectedStep = self.status {
402 h_new_final = self.h_old.min(h_new_final);
403 self.status = Status::Solving;
404 }
405 } else {
406 h_new_final = self.h / self.facc1.min(self.fac11 / self.safe);
407 self.status = Status::RejectedStep;
408 }
409
410 self.steps += 1;
411 self.h = constrain_step_size(h_new_final, self.h_min, self.h_max);
412 Ok(evals)
413 }
414
415 fn t(&self) -> T {
416 self.t
417 }
418 fn y(&self) -> &V {
419 &self.y
420 }
421 fn t_prev(&self) -> T {
422 self.t_old
423 }
424 fn y_prev(&self) -> &V {
425 &self.y_old
426 }
427 fn h(&self) -> T {
428 self.h
429 }
430 fn set_h(&mut self, h: T) {
431 self.h = h;
432 }
433 fn status(&self) -> &Status<T, V, D> {
434 &self.status
435 }
436 fn set_status(&mut self, status: Status<T, V, D>) {
437 self.status = status;
438 }
439}
440
441impl<const L: usize, T: Real, V: State<T>, H: Fn(T) -> V, D: CallBackData> Interpolation<T, V>
442 for BS23<L, T, V, H, D>
443{
444 fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
445 if (t_interp - self.t_old) * self.posneg < T::zero()
447 || (t_interp - self.t) * self.posneg > T::zero()
448 {
449 if (t_interp - self.t_old).abs() > T::default_epsilon()
451 && (t_interp - self.t).abs() > T::default_epsilon()
452 {
453 return Err(Error::OutOfBounds {
454 t_interp,
455 t_prev: self.t_old,
456 t_curr: self.t,
457 });
458 }
459 }
460
461 let s = if self.h_old == T::zero() {
462 if (t_interp - self.t_old).abs() < T::default_epsilon() {
463 T::zero()
464 } else {
465 T::one()
466 }
467 } else {
468 (t_interp - self.t_old) / self.h_old
469 };
470
471 let y_interp = self.cont[0] + (self.cont[1] + (self.cont[2] + self.cont[3] * s) * s) * s;
473 Ok(y_interp)
474 }
475}
476
477impl<const L: usize, T: Real, V: State<T>, H: Fn(T) -> V, D: CallBackData> BS23<L, T, V, H, D> {
478 pub fn new() -> Self {
479 Self::default()
480 }
481
482 pub fn rtol(mut self, rtol: T) -> Self {
484 self.rtol = rtol;
485 self
486 }
487 pub fn atol(mut self, atol: T) -> Self {
488 self.atol = atol;
489 self
490 }
491 pub fn h0(mut self, h0: T) -> Self {
492 self.h0 = h0;
493 self
494 }
495 pub fn h_max(mut self, h_max: T) -> Self {
496 self.h_max = h_max;
497 self
498 }
499 pub fn h_min(mut self, h_min: T) -> Self {
500 self.h_min = h_min;
501 self
502 }
503 pub fn max_steps(mut self, max_steps: usize) -> Self {
504 self.max_steps = max_steps;
505 self
506 }
507 pub fn safe(mut self, safe: T) -> Self {
508 self.safe = safe;
509 self
510 }
511 pub fn fac1(mut self, fac1: T) -> Self {
512 self.fac1 = fac1;
513 self.facc1 = T::one() / fac1;
514 self
515 }
516 pub fn fac2(mut self, fac2: T) -> Self {
517 self.fac2 = fac2;
518 self.facc2 = T::one() / fac2;
519 self
520 }
521 pub fn beta(mut self, beta: T) -> Self {
522 self.beta = beta;
523 self
524 }
525 pub fn max_delay(mut self, max_delay: T) -> Self {
526 self.max_delay = Some(max_delay.abs());
527 self
528 }
529
530 fn lagvals(&mut self, t_stage: T, _y_stage: &V) {
531 for i in 0..L {
533 let t_delayed = t_stage - self.lags[i];
534 if (t_delayed - self.t0) * self.posneg <= T::default_epsilon() {
535 self.yd[i] = (self.phi.as_ref().unwrap())(t_delayed);
537 } else {
538 let mut found_in_buffer = false;
540 for (buf_t_start, buf_t_end, buf_h, buf_cont) in self.cont_buffer.iter().rev() {
541 if (t_delayed - *buf_t_start) * self.posneg >= -T::default_epsilon()
543 && (t_delayed - *buf_t_end) * self.posneg <= T::default_epsilon()
544 {
545 let s = if *buf_h == T::zero() {
546 if (t_delayed - *buf_t_start).abs() < T::default_epsilon() {
547 T::zero()
548 } else {
549 T::one()
550 }
551 } else {
552 (t_delayed - *buf_t_start) / *buf_h
553 };
554 self.yd[i] =
555 buf_cont[0] + (buf_cont[1] + (buf_cont[2] + buf_cont[3] * s) * s) * s;
556 found_in_buffer = true;
557 break;
558 }
559 }
560 if !found_in_buffer {
561 if let Some((buf_t_start, _buf_t_end, buf_h, buf_cont)) =
565 self.cont_buffer.back()
566 {
567 let s = if *buf_h == T::zero() {
568 T::one()
569 } else {
570 (t_delayed - *buf_t_start) / *buf_h
571 }; self.yd[i] =
573 buf_cont[0] + (buf_cont[1] + (buf_cont[2] + buf_cont[3] * s) * s) * s;
574 } else {
575 self.yd[i] = (self.phi.as_ref().unwrap())(t_delayed);
580 }
582 }
583 }
584 }
585 }
586}
587
588const BS23_C: [f64; 3] = [1.0 / 2.0, 3.0 / 4.0, 1.0]; const BS23_A: [[f64; 4]; 3] = [
600 [1.0 / 2.0, 0.0, 0.0, 0.0], [0.0, 3.0 / 4.0, 0.0, 0.0], [2.0 / 9.0, 1.0 / 3.0, 4.0 / 9.0, 0.0], ];
605const BS23_B_SOL: [f64; 3] = [2.0 / 9.0, 1.0 / 3.0, 4.0 / 9.0]; const BS23_E: [f64; 4] = [-5.0 / 72.0, 1.0 / 12.0, 1.0 / 9.0, -1.0 / 8.0]; impl<const L: usize, T: Real, V: State<T>, H: Fn(T) -> V, D: CallBackData> Default
609 for BS23<L, T, V, H, D>
610{
611 fn default() -> Self {
612 let c_conv = BS23_C.map(|x| T::from_f64(x).unwrap());
613 let a_conv = BS23_A.map(|row| row.map(|x| T::from_f64(x).unwrap()));
614 let b_sol_conv = BS23_B_SOL.map(|x| T::from_f64(x).unwrap());
615 let er_conv = BS23_E.map(|x| T::from_f64(x).unwrap());
616
617 let expo1_final = T::one() / T::from_f64(3.0).unwrap();
618
619 let fac1_default = T::from_f64(0.2).unwrap(); let fac2_default = T::from_f64(10.0).unwrap();
621
622 BS23 {
623 t: T::zero(),
624 y: V::zeros(),
625 h: T::zero(),
626 h0: T::zero(),
627 rtol: T::from_f64(1e-3).unwrap(),
628 atol: T::from_f64(1e-6).unwrap(),
629 h_max: T::infinity(),
630 h_min: T::zero(),
631 max_steps: 100_000,
632 safe: T::from_f64(0.9).unwrap(),
633 fac1: fac1_default,
634 fac2: fac2_default,
635 beta: T::zero(), max_delay: None,
637 expo1: expo1_final,
638 facc1: T::one() / fac1_default,
639 facc2: T::one() / fac2_default,
640 facold: T::from_f64(1.0e-4).unwrap(),
641 fac11: T::zero(),
642 fac: T::zero(),
643 status: Status::Uninitialized,
644 steps: 0,
645 n_accepted: 0,
646 a: a_conv,
647 b: b_sol_conv,
648 c: c_conv,
649 er: er_conv,
650 k: [V::zeros(); 4],
651 y_old: V::zeros(),
652 t_old: T::zero(),
653 h_old: T::zero(),
654 cont: [V::zeros(); 4],
655 cont_buffer: VecDeque::new(),
656 phi: None,
657 t0: T::zero(),
658 tf: T::zero(),
659 posneg: T::zero(),
660 lags: [T::zero(); L],
661 yd: [V::zeros(); L],
662 }
663 }
664}