1use faer::{ComplexField, Conjugate, SimpleEntity};
21use numra_core::Scalar;
22use numra_linalg::{DenseMatrix, LUFactorization, Matrix};
23
24use crate::error::SolverError;
25use crate::problem::OdeSystem;
26use crate::solver::{Solver, SolverOptions, SolverResult, SolverStats};
27use crate::t_eval::{validate_grid, TEvalEmitter};
28
29#[derive(Clone, Debug, Default)]
35pub struct Esdirk32;
36
37impl Esdirk32 {
38 pub fn new() -> Self {
39 Self
40 }
41}
42
43mod esdirk32_tableau {
45 pub const GAMMA: f64 = 0.2928932188134525; pub const C: [f64; 3] = [0.0, 2.0 * GAMMA, 1.0];
49
50 pub const A: [[f64; 3]; 3] = [
51 [0.0, 0.0, 0.0],
52 [GAMMA, GAMMA, 0.0],
53 [1.0 - 2.0 * GAMMA, GAMMA, GAMMA],
54 ];
55
56 pub const B: [f64; 3] = [1.0 - 2.0 * GAMMA, GAMMA, GAMMA];
57
58 pub const E: [f64; 3] = [1.0 - 2.0 * GAMMA - 0.5, GAMMA - 0.0, GAMMA - 0.5];
60}
61
62#[derive(Clone, Debug, Default)]
68pub struct Esdirk43;
69
70impl Esdirk43 {
71 pub fn new() -> Self {
72 Self
73 }
74}
75
76mod esdirk43_tableau {
83 pub const GAMMA: f64 = 0.4358665215084590;
84
85 pub const C: [f64; 4] = [0.0, 2.0 * GAMMA, 1.0, 1.0];
86
87 pub const A: [[f64; 4]; 4] = [
88 [0.0, 0.0, 0.0, 0.0],
89 [GAMMA, GAMMA, 0.0, 0.0],
90 [0.4905633884217806, 0.0735700900697604, GAMMA, 0.0],
91 [
92 0.3088099699767466,
93 1.4905633884217800,
94 -1.2352398799069855,
95 GAMMA,
96 ],
97 ];
98
99 pub const B: [f64; 4] = [
100 0.3088099699767466,
101 1.4905633884217800,
102 -1.2352398799069855,
103 GAMMA,
104 ];
105
106 pub const E: [f64; 4] = [
109 0.3088099699767466 - 0.4905633884217806, 1.4905633884217800 - 0.0735700900697604, -1.2352398799069855 - GAMMA, GAMMA, ];
114}
115
116#[derive(Clone, Debug, Default)]
126pub struct Esdirk54;
127
128impl Esdirk54 {
129 pub fn new() -> Self {
130 Self
131 }
132}
133
134mod esdirk54_tableau {
142 pub const GAMMA: f64 = 0.25;
144
145 pub const C: [f64; 6] = [
146 0.0,
147 0.5, 0.14644660940672624, 0.625, 1.04, 1.0,
152 ];
153
154 pub const A: [[f64; 6]; 6] = [
166 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
167 [GAMMA, GAMMA, 0.0, 0.0, 0.0, 0.0],
168 [
169 -0.05177669529663689,
170 -0.05177669529663689,
171 GAMMA,
172 0.0,
173 0.0,
174 0.0,
175 ],
176 [
177 -0.07655460838455727,
178 -0.07655460838455727,
179 0.5281092167691145,
180 GAMMA,
181 0.0,
182 0.0,
183 ],
184 [
185 -0.7274063478261299,
186 -0.7274063478261299,
187 1.5849950617406794,
188 0.6598176339115805,
189 GAMMA,
190 0.0,
191 ],
192 [
193 -0.01558763503571651,
194 -0.01558763503571651,
195 0.3876576709132033,
196 0.5017726195721631,
197 -0.10825502041393352,
198 GAMMA,
199 ],
200 ];
201
202 pub const B: [f64; 6] = [
204 -0.01558763503571651,
205 -0.01558763503571651,
206 0.3876576709132033,
207 0.5017726195721631,
208 -0.10825502041393352,
209 GAMMA,
210 ];
211
212 pub const E: [f64; 6] = [
217 -0.08092570713246382,
218 -0.08092570713246382,
219 0.13516228008303094,
220 0.01879524505002539,
221 0.0256969660063123,
222 -0.01780307687444085,
223 ];
224}
225
226fn solve_esdirk<S, Sys, const STAGES: usize>(
231 problem: &Sys,
232 t0: S,
233 tf: S,
234 y0: &[S],
235 options: &SolverOptions<S>,
236 c: &[f64],
237 a: &[[f64; STAGES]; STAGES],
238 b: &[f64],
239 e: &[f64],
240 gamma: f64,
241 order: usize,
242) -> Result<SolverResult<S>, SolverError>
243where
244 S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
245 Sys: OdeSystem<S>,
246{
247 let dim = problem.dim();
248 if y0.len() != dim {
249 return Err(SolverError::DimensionMismatch {
250 expected: dim,
251 actual: y0.len(),
252 });
253 }
254
255 let mut t = t0;
256 let mut y = y0.to_vec();
257
258 let direction_init = if tf > t0 { S::ONE } else { -S::ONE };
259 if let Some(grid) = options.t_eval.as_deref() {
260 validate_grid(grid, t0, tf)?;
261 }
262 let mut grid_emitter = options
263 .t_eval
264 .as_deref()
265 .map(|g| TEvalEmitter::new(g, direction_init));
266 let (mut t_out, mut y_out) = if grid_emitter.is_some() {
267 (Vec::new(), Vec::new())
268 } else {
269 (vec![t0], y0.to_vec())
270 };
271 let mut dy_old_buf = vec![S::ZERO; dim];
275
276 let mut k: Vec<Vec<S>> = (0..STAGES).map(|_| vec![S::ZERO; dim]).collect();
277 let mut y_stage = vec![S::ZERO; dim];
278 let mut y_new = vec![S::ZERO; dim];
279 let mut err = vec![S::ZERO; dim];
280 let mut jac_data = vec![S::ZERO; dim * dim];
281 let mut f0 = vec![S::ZERO; dim];
282
283 let mut stats = SolverStats::default();
284
285 problem.rhs(t, &y, &mut k[0]);
287 stats.n_eval += 1;
288 f0.copy_from_slice(&k[0]);
289
290 let mut h = initial_step_size(&y, &k[0], options, dim);
291 let h_min = options.h_min;
292 let h_max = options.h_max.min((tf - t0).abs());
293
294 let mut lu: Option<LUFactorization<S>> = None;
296 let mut need_jac = true;
297 let mut jac_h = h;
298
299 let direction = direction_init;
300 let mut step_count = 0_usize;
301 let mut consecutive_failures = 0_usize;
302
303 while (tf - t) * direction > S::from_f64(1e-10) * (tf - t0).abs() {
304 if step_count >= options.max_steps {
305 return Err(SolverError::MaxIterationsExceeded { t: t.to_f64() });
306 }
307
308 if (t + h - tf) * direction > S::ZERO {
309 h = tf - t;
310 }
311
312 h = h.abs().max(h_min) * direction;
313 if h.abs() > h_max {
314 h = h_max * direction;
315 }
316
317 if need_jac {
321 compute_jacobian(problem, t, &y, &f0, &mut jac_data, dim);
322 stats.n_jac += 1;
323 need_jac = false;
324 }
325
326 if lu.is_none() || (h - jac_h).abs() > S::from_f64(1e-10) * h.abs() {
328 let iter_matrix = form_iteration_matrix(&jac_data, h * S::from_f64(gamma), dim);
329 lu = Some(LUFactorization::new(&iter_matrix)?);
330 stats.n_lu += 1;
331 jac_h = h;
332 }
333
334 let step_ok = compute_esdirk_stages::<S, Sys, STAGES>(
336 problem,
337 t,
338 h,
339 &y,
340 c,
341 a,
342 gamma,
343 lu.as_ref().unwrap(),
344 &mut k,
345 &mut y_stage,
346 &mut stats,
347 dim,
348 )?;
349
350 if !step_ok {
351 stats.n_reject += 1;
352 consecutive_failures += 1;
353 h = h * S::from_f64(0.5);
354 need_jac = true;
355
356 if consecutive_failures >= 5 {
357 return Err(SolverError::Other(format!(
358 "Too many consecutive failures at t = {}",
359 t.to_f64()
360 )));
361 }
362 continue;
363 }
364
365 for i in 0..dim {
367 let mut sum_b = S::ZERO;
368 let mut sum_e = S::ZERO;
369 for s in 0..STAGES {
370 sum_b = sum_b + S::from_f64(b[s]) * k[s][i];
371 sum_e = sum_e + S::from_f64(e[s]) * k[s][i];
372 }
373 y_new[i] = y[i] + h * sum_b;
374 err[i] = h * sum_e;
375 }
376
377 let err_norm = error_norm(&err, &y, &y_new, options, dim);
378
379 let safety = S::from_f64(0.9);
380 let fac_max = S::from_f64(3.0);
381 let fac_min = S::from_f64(0.2);
382 let order_f = S::from_usize(order + 1);
383
384 if err_norm <= S::ONE {
385 stats.n_accept += 1;
386 consecutive_failures = 0;
387
388 let t_new = t + h;
389 dy_old_buf.copy_from_slice(&f0);
391 problem.rhs(t_new, &y_new, &mut f0);
392 stats.n_eval += 1;
393
394 if let Some(ref mut emitter) = grid_emitter {
395 emitter.emit_step(
396 t,
397 &y,
398 &dy_old_buf,
399 t_new,
400 &y_new,
401 &f0,
402 &mut t_out,
403 &mut y_out,
404 );
405 } else {
406 t_out.push(t_new);
407 y_out.extend_from_slice(&y_new);
408 }
409
410 t = t_new;
411 y.copy_from_slice(&y_new);
412 k[0].copy_from_slice(&f0);
413
414 let err_safe = err_norm.max(S::EPSILON * S::from_f64(100.0));
415 let fac = safety * err_safe.powf(-S::ONE / order_f);
416 let fac = fac.min(fac_max).max(fac_min);
417 h = h * fac;
418 } else {
419 stats.n_reject += 1;
420 consecutive_failures += 1;
421
422 let err_safe = err_norm.max(S::EPSILON * S::from_f64(100.0));
423 let fac = safety * err_safe.powf(-S::ONE / order_f);
424 let fac = fac.max(fac_min);
425 h = h * fac;
426
427 if consecutive_failures >= 3 {
428 need_jac = true;
429 }
430 }
431
432 if h.abs() < h_min {
433 return Err(SolverError::StepSizeTooSmall {
434 t: t.to_f64(),
435 h: h.to_f64(),
436 h_min: h_min.to_f64(),
437 });
438 }
439
440 step_count += 1;
441 }
442
443 Ok(SolverResult::new(t_out, y_out, dim, stats))
444}
445
446fn compute_esdirk_stages<S, Sys, const STAGES: usize>(
447 problem: &Sys,
448 t: S,
449 h: S,
450 y: &[S],
451 c: &[f64],
452 a: &[[f64; STAGES]; STAGES],
453 gamma: f64,
454 lu: &LUFactorization<S>,
455 k: &mut [Vec<S>],
456 y_stage: &mut [S],
457 stats: &mut SolverStats,
458 dim: usize,
459) -> Result<bool, SolverError>
460where
461 S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
462 Sys: OdeSystem<S>,
463{
464 for s in 1..STAGES {
467 for i in 0..dim {
469 let mut sum = S::ZERO;
470 for j in 0..s {
471 sum = sum + S::from_f64(a[s][j]) * k[j][i];
472 }
473 y_stage[i] = y[i] + h * sum;
474 }
475
476 let t_stage = t + S::from_f64(c[s]) * h;
478 let h_gamma = h * S::from_f64(gamma);
479
480 let mut converged = false;
481 for _iter in 0..10 {
482 let mut f_stage = vec![S::ZERO; dim];
483 problem.rhs(t_stage, y_stage, &mut f_stage);
484 stats.n_eval += 1;
485
486 let mut residual = vec![S::ZERO; dim];
488 let mut res_norm = S::ZERO;
489 for i in 0..dim {
490 let mut sum = S::ZERO;
491 for j in 0..s {
492 sum = sum + S::from_f64(a[s][j]) * k[j][i];
493 }
494 residual[i] = y_stage[i] - y[i] - h * sum - h_gamma * f_stage[i];
495 res_norm = res_norm + residual[i] * residual[i];
496 }
497 res_norm = res_norm.sqrt();
498
499 if res_norm < S::from_f64(1e-10) {
500 k[s].copy_from_slice(&f_stage);
501 converged = true;
502 break;
503 }
504
505 let delta = lu.solve(&residual)?;
507 for i in 0..dim {
508 y_stage[i] = y_stage[i] - delta[i];
509 }
510 }
511
512 if !converged {
513 return Ok(false);
514 }
515 }
516
517 Ok(true)
518}
519
520fn compute_jacobian<S, Sys>(problem: &Sys, t: S, y: &[S], f0: &[S], jac: &mut [S], dim: usize)
521where
522 S: Scalar,
523 Sys: OdeSystem<S>,
524{
525 let h_factor = S::EPSILON.sqrt();
526 let mut y_pert = y.to_vec();
527 let mut f_pert = vec![S::ZERO; dim];
528
529 for j in 0..dim {
530 let yj = y[j];
531 let h = h_factor * (S::ONE + yj.abs());
532 y_pert[j] = yj + h;
533 problem.rhs(t, &y_pert, &mut f_pert);
534 y_pert[j] = yj;
535
536 for i in 0..dim {
537 jac[i * dim + j] = (f_pert[i] - f0[i]) / h;
538 }
539 }
540}
541
542fn form_iteration_matrix<S>(jac: &[S], h_gamma: S, dim: usize) -> DenseMatrix<S>
543where
544 S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
545{
546 let mut m = DenseMatrix::zeros(dim, dim);
547 for i in 0..dim {
548 for j in 0..dim {
549 let jij = jac[i * dim + j];
550 if i == j {
551 m.set(i, j, S::ONE - h_gamma * jij);
552 } else {
553 m.set(i, j, -h_gamma * jij);
554 }
555 }
556 }
557 m
558}
559
560fn initial_step_size<S: Scalar>(y0: &[S], f0: &[S], options: &SolverOptions<S>, dim: usize) -> S {
561 if let Some(h0) = options.h0 {
562 return h0;
563 }
564
565 let mut y_norm = S::ZERO;
566 let mut f_norm = S::ZERO;
567 for i in 0..dim {
568 let sc = options.atol + options.rtol * y0[i].abs();
569 y_norm = y_norm + (y0[i] / sc) * (y0[i] / sc);
570 f_norm = f_norm + (f0[i] / sc) * (f0[i] / sc);
571 }
572 y_norm = (y_norm / S::from_usize(dim)).sqrt();
573 f_norm = (f_norm / S::from_usize(dim)).sqrt();
574
575 if y_norm < S::EPSILON.sqrt() || f_norm < S::EPSILON.sqrt() {
576 S::from_f64(1e-6)
577 } else {
578 (S::from_f64(0.01) * y_norm / f_norm).min(options.h_max)
579 }
580}
581
582fn error_norm<S: Scalar>(
583 err: &[S],
584 y: &[S],
585 y_new: &[S],
586 options: &SolverOptions<S>,
587 dim: usize,
588) -> S {
589 let mut err_norm = S::ZERO;
590 for i in 0..dim {
591 let sc = options.atol + options.rtol * y[i].abs().max(y_new[i].abs());
592 let sc = sc.max(S::from_f64(1e-15));
593 let scaled_err = err[i] / sc;
594 err_norm = err_norm + scaled_err * scaled_err;
595 }
596 (err_norm / S::from_usize(dim)).sqrt()
597}
598
599impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> Solver<S> for Esdirk32 {
604 fn solve<Sys: OdeSystem<S>>(
605 problem: &Sys,
606 t0: S,
607 tf: S,
608 y0: &[S],
609 options: &SolverOptions<S>,
610 ) -> Result<SolverResult<S>, SolverError> {
611 solve_esdirk::<S, Sys, 3>(
612 problem,
613 t0,
614 tf,
615 y0,
616 options,
617 &esdirk32_tableau::C,
618 &esdirk32_tableau::A,
619 &esdirk32_tableau::B,
620 &esdirk32_tableau::E,
621 esdirk32_tableau::GAMMA,
622 2,
623 )
624 }
625}
626
627impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> Solver<S> for Esdirk43 {
628 fn solve<Sys: OdeSystem<S>>(
629 problem: &Sys,
630 t0: S,
631 tf: S,
632 y0: &[S],
633 options: &SolverOptions<S>,
634 ) -> Result<SolverResult<S>, SolverError> {
635 solve_esdirk::<S, Sys, 4>(
636 problem,
637 t0,
638 tf,
639 y0,
640 options,
641 &esdirk43_tableau::C,
642 &esdirk43_tableau::A,
643 &esdirk43_tableau::B,
644 &esdirk43_tableau::E,
645 esdirk43_tableau::GAMMA,
646 3,
647 )
648 }
649}
650
651impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> Solver<S> for Esdirk54 {
652 fn solve<Sys: OdeSystem<S>>(
653 problem: &Sys,
654 t0: S,
655 tf: S,
656 y0: &[S],
657 options: &SolverOptions<S>,
658 ) -> Result<SolverResult<S>, SolverError> {
659 solve_esdirk::<S, Sys, 6>(
660 problem,
661 t0,
662 tf,
663 y0,
664 options,
665 &esdirk54_tableau::C,
666 &esdirk54_tableau::A,
667 &esdirk54_tableau::B,
668 &esdirk54_tableau::E,
669 esdirk54_tableau::GAMMA,
670 4,
671 )
672 }
673}
674
675#[cfg(test)]
676mod tests {
677 use super::*;
678 use crate::problem::OdeProblem;
679
680 #[test]
681 fn test_esdirk32_exponential() {
682 let problem = OdeProblem::new(
683 |_t, y: &[f64], dydt: &mut [f64]| {
684 dydt[0] = -y[0];
685 },
686 0.0,
687 5.0,
688 vec![1.0],
689 );
690 let options = SolverOptions::default().rtol(1e-4).atol(1e-6);
691 let result = Esdirk32::solve(&problem, 0.0, 5.0, &[1.0], &options).unwrap();
692
693 assert!(result.success);
694 let y_final = result.y_final().unwrap();
695 let expected = (-5.0_f64).exp();
696 assert!((y_final[0] - expected).abs() < 1e-3);
697 }
698
699 #[test]
700 fn test_esdirk43_stiff() {
701 let problem = OdeProblem::new(
702 |_t, y: &[f64], dydt: &mut [f64]| {
703 dydt[0] = -50.0 * y[0];
704 },
705 0.0,
706 0.5,
707 vec![1.0],
708 );
709 let options = SolverOptions::default().rtol(1e-4).atol(1e-6);
710 let result = Esdirk43::solve(&problem, 0.0, 0.5, &[1.0], &options).unwrap();
711
712 assert!(result.success);
713 let y_final = result.y_final().unwrap();
714 let expected = (-25.0_f64).exp();
715 assert!((y_final[0] - expected).abs() < 0.01);
716 }
717
718 #[test]
719 fn test_esdirk54_linear_system() {
720 let problem = OdeProblem::new(
721 |_t, y: &[f64], dydt: &mut [f64]| {
722 dydt[0] = -y[0] + y[1];
723 dydt[1] = y[0] - y[1];
724 },
725 0.0,
726 5.0,
727 vec![1.0, 0.0],
728 );
729 let options = SolverOptions::default().rtol(1e-5).atol(1e-7);
730 let result = Esdirk54::solve(&problem, 0.0, 5.0, &[1.0, 0.0], &options).unwrap();
731
732 assert!(result.success);
733 let y_final = result.y_final().unwrap();
734 assert!((y_final[0] + y_final[1] - 1.0).abs() < 1e-4);
736 }
737
738 #[test]
739 fn test_esdirk_van_der_pol() {
740 let mu = 10.0;
741 let problem = OdeProblem::new(
742 move |_t, y: &[f64], dydt: &mut [f64]| {
743 dydt[0] = y[1];
744 dydt[1] = mu * (1.0 - y[0] * y[0]) * y[1] - y[0];
745 },
746 0.0,
747 10.0,
748 vec![2.0, 0.0],
749 );
750 let options = SolverOptions::default().rtol(1e-4).atol(1e-6);
751 let result = Esdirk54::solve(&problem, 0.0, 10.0, &[2.0, 0.0], &options);
752
753 assert!(result.is_ok());
754 }
755
756 #[test]
757 fn test_esdirk_methods_agree() {
758 let problem = OdeProblem::new(
759 |_t, y: &[f64], dydt: &mut [f64]| {
760 dydt[0] = -y[0];
761 },
762 0.0,
763 2.0,
764 vec![1.0],
765 );
766 let options = SolverOptions::default().rtol(1e-3).atol(1e-5);
767
768 let r32 = Esdirk32::solve(&problem, 0.0, 2.0, &[1.0], &options).unwrap();
769 let r43 = Esdirk43::solve(&problem, 0.0, 2.0, &[1.0], &options).unwrap();
770 let r54 = Esdirk54::solve(&problem, 0.0, 2.0, &[1.0], &options).unwrap();
771
772 let y32 = r32.y_final().unwrap()[0];
773 let y43 = r43.y_final().unwrap()[0];
774 let y54 = r54.y_final().unwrap()[0];
775 let expected = (-2.0_f64).exp();
776
777 assert!(
779 (y32 - expected).abs() < 1e-2,
780 "ESDIRK32: got {}, expected {}",
781 y32,
782 expected
783 );
784 assert!(
785 (y43 - expected).abs() < 1e-2,
786 "ESDIRK43: got {}, expected {}",
787 y43,
788 expected
789 );
790 assert!(
791 (y54 - expected).abs() < 1e-2,
792 "ESDIRK54: got {}, expected {}",
793 y54,
794 expected
795 );
796 }
797}