Skip to main content

numra_ode/
esdirk.rs

1//! ESDIRK: Explicit first Stage, Singly Diagonally Implicit Runge-Kutta methods.
2//!
3//! ESDIRK methods are efficient implicit methods for stiff ODEs.
4//! They have the property that all implicit stages share the same diagonal
5//! coefficient, allowing efficient Jacobian reuse.
6//!
7//! ## Available Methods
8//! - `Esdirk32` - ESDIRK2(1)3L\[2\]SA: 3-stage, 2nd order (A-stable, L-stable)
//! - `Esdirk43` - Kvaerno (2004) ESDIRK3(2): 4-stage, 3rd order with embedded 2nd order (L-stable, stiffly-accurate)
10//! - `Esdirk54` - ESDIRK4(3)6L\[2\]SA: 6-stage, 4th order (L-stable, stiffly-accurate)
11//!
12//! ## Reference
13//! - Kennedy, C.A. & Carpenter, M.H. (2016), "Diagonally Implicit Runge-Kutta
14//!   Methods for Ordinary Differential Equations. A Review", NASA/TM-2016-219173.
15//!
16//! Author: Moussa Leblouba
17//! Date: 5 March 2026
18//! Modified: 2 May 2026
19
20use faer::{ComplexField, Conjugate, SimpleEntity};
21use numra_core::Scalar;
22use numra_linalg::{DenseMatrix, LUFactorization, Matrix};
23
24use crate::error::SolverError;
25use crate::problem::OdeSystem;
26use crate::solver::{Solver, SolverOptions, SolverResult, SolverStats};
27use crate::t_eval::{validate_grid, TEvalEmitter};
28
29// ============================================================================
30// ESDIRK3(2) - 3 stages, 2nd order with embedded 1st order
31// ============================================================================
32
/// Three-stage, second-order ESDIRK solver with an embedded error estimate.
///
/// This is a stateless marker type: all of the work happens in its
/// `Solver::solve` implementation, which drives the generic ESDIRK
/// integrator with the coefficients from `esdirk32_tableau`.
#[derive(Clone, Debug, Default)]
pub struct Esdirk32;

impl Esdirk32 {
    /// Create the solver; equivalent to `Esdirk32::default()`.
    pub fn new() -> Self {
        Esdirk32
    }
}
42
/// ESDIRK2(1)3L\[2\]SA coefficients (the TR-BDF2 tableau in ESDIRK form).
///
/// gamma = (2 - sqrt(2)) / 2 and c = [0, 2*gamma, 1]. The last row of `A` is
/// `[w, w, gamma]` with w = (1 - gamma) / 2 = sqrt(2) / 4. This is the unique
/// choice with c3 = 1 that satisfies the second-order condition
/// sum(b_i * c_i) = 1/2 (here w*2*gamma + gamma = gamma*(2 - gamma) = 1/2).
/// `B` equals the last row of `A` (stiffly accurate).
///
/// Error estimate: `B_HAT` are the Hosea & Shampine (1996) companion weights
/// [(1 - w)/3, (3w + 1)/3, gamma/3]. They satisfy the third-order conditions,
/// so `h * sum(E_s * k_s)` with `E = B - B_HAT` estimates the O(h^3) local
/// error of the 2nd-order solution. This matches `order = 2` in the
/// step-size controller.
mod esdirk32_tableau {
    // Diagonal coefficient (gamma)
    pub const GAMMA: f64 = 0.2928932188134525; // (2 - sqrt(2)) / 2

    // w = (1 - gamma) / 2 = sqrt(2) / 4: weight shared by the first two stages.
    const W: f64 = (1.0 - GAMMA) / 2.0;

    pub const C: [f64; 3] = [0.0, 2.0 * GAMMA, 1.0];

    pub const A: [[f64; 3]; 3] = [
        [0.0, 0.0, 0.0],
        [GAMMA, GAMMA, 0.0],
        [W, W, GAMMA],
    ];

    // Stiffly accurate: the solution weights are the last row of A.
    pub const B: [f64; 3] = A[2];

    // Embedded (3rd-order) weights used only for error estimation.
    const B_HAT: [f64; 3] = [(1.0 - W) / 3.0, (3.0 * W + 1.0) / 3.0, GAMMA / 3.0];

    // Error coefficients: E = B - B_HAT (sums to zero).
    pub const E: [f64; 3] = [B[0] - B_HAT[0], B[1] - B_HAT[1], B[2] - B_HAT[2]];
}
61
62// ============================================================================
63// ESDIRK4(3) - 4 stages, 3rd order with embedded 2nd order
64// ============================================================================
65
/// Four-stage, third-order ESDIRK solver with an embedded second-order estimate.
///
/// A stateless marker type; `Solver::solve` runs the generic ESDIRK driver
/// with the coefficients from `esdirk43_tableau`.
#[derive(Clone, Debug, Default)]
pub struct Esdirk43;

impl Esdirk43 {
    /// Create the solver; equivalent to `Esdirk43::default()`.
    pub fn new() -> Self {
        Esdirk43
    }
}
75
/// Kvaerno's 4-stage, 3rd-order ESDIRK tableau ("Kvaerno3").
///
/// Reference: Kvaerno, "Singly diagonally implicit Runge-Kutta methods
/// with an explicit first stage", BIT Numerical Mathematics 44, 489-502 (2004).
///
/// gamma is the root of 6*gamma^3 - 18*gamma^2 + 9*gamma - 1 = 0 (L-stability).
/// The main solution is stiffly accurate (`B` is the last row of `A`). The
/// embedded 2nd-order solution uses row 2 of `A` as its weights, which also
/// ends at c = 1.
mod esdirk43_tableau {
    pub const GAMMA: f64 = 0.4358665215084590;

    pub const C: [f64; 4] = [0.0, 2.0 * GAMMA, 1.0, 1.0];

    pub const A: [[f64; 4]; 4] = [
        [0.0, 0.0, 0.0, 0.0],
        [GAMMA, GAMMA, 0.0, 0.0],
        [0.4905633884217806, 0.0735700900697604, GAMMA, 0.0],
        [0.3088099699767466, 1.4905633884217800, -1.2352398799069855, GAMMA],
    ];

    // Main (3rd-order) weights: stiffly accurate, so the last row of A.
    pub const B: [f64; 4] = A[3];

    // Embedded (2nd-order) weights: row 2 of A (zero in the last column).
    const B_HAT: [f64; 4] = A[2];

    // Error coefficients: E = B - B_HAT.
    // Values: [-0.1817534184450340, 1.4169932983520196, -1.6711064014154445, 0.4358665215084590]
    pub const E: [f64; 4] = [
        B[0] - B_HAT[0],
        B[1] - B_HAT[1],
        B[2] - B_HAT[2],
        B[3] - B_HAT[3],
    ];
}
115
116// ============================================================================
117// ESDIRK5(4) - 6 stages, 4th order with embedded 3rd order
118// ============================================================================
119
/// Six-stage, fourth-order ESDIRK solver (L-stable, stiffly accurate).
///
/// Uses ESDIRK4(3)6L\[2\]SA from Kennedy & Carpenter (2016),
/// NASA/TM-2016-219173 (Table 7). The 4th-order main method and the
/// 3rd-order embedded method used for error control are both L-stable.
/// A stateless marker type; `Solver::solve` supplies `esdirk54_tableau`.
#[derive(Clone, Debug, Default)]
pub struct Esdirk54;

impl Esdirk54 {
    /// Create the solver; equivalent to `Esdirk54::default()`.
    pub fn new() -> Self {
        Esdirk54
    }
}
133
/// ESDIRK4(3)6L\[2\]SA coefficients from Kennedy & Carpenter (2016).
///
/// 6 stages; 4th-order main method with a 3rd-order embedded method.
/// gamma = 1/4, stiffly accurate (`B` is the last row of `A`), L-stable.
///
/// Reference: C.A. Kennedy, M.H. Carpenter, "Diagonally Implicit Runge-Kutta
/// Methods for Ordinary Differential Equations. A Review", NASA/TM-2016-219173.
mod esdirk54_tableau {
    // Diagonal coefficient (gamma = 1/4)
    pub const GAMMA: f64 = 0.25;

    pub const C: [f64; 6] = [
        0.0,
        0.5,                 // 2 * gamma
        0.14644660940672624, // (2 - sqrt(2)) / 4
        0.625,               // 5/8
        1.04,                // 26/25
        1.0,
    ];

    // Off-diagonal entries, rounded from exact expressions in sqrt(2).
    // In every implicit row the first two columns are equal.
    const A20: f64 = -0.05177669529663689; // (1 - sqrt(2)) / 8
    const A30: f64 = -0.07655460838455727; // (5 - 7*sqrt(2)) / 64
    const A32: f64 = 0.5281092167691145; // 7*(1 + sqrt(2)) / 32
    const A40: f64 = -0.7274063478261299; // (-13796 - 54539*sqrt(2)) / 125000
    const A42: f64 = 1.5849950617406794; // (506605 + 132109*sqrt(2)) / 437500
    const A43: f64 = 0.6598176339115805; // 166*(-97 + 376*sqrt(2)) / 109375
    const A50: f64 = -0.01558763503571651; // (1181 - 987*sqrt(2)) / 13782
    const A52: f64 = 0.3876576709132033; // 47*(-267 + 1783*sqrt(2)) / 273343
    const A53: f64 = 0.5017726195721631; // -16*(-22922 + 3525*sqrt(2)) / 571953
    const A54: f64 = -0.10825502041393352; // -15625*(97 + 376*sqrt(2)) / 90749876

    pub const A: [[f64; 6]; 6] = [
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [GAMMA, GAMMA, 0.0, 0.0, 0.0, 0.0],
        [A20, A20, GAMMA, 0.0, 0.0, 0.0],
        [A30, A30, A32, GAMMA, 0.0, 0.0],
        [A40, A40, A42, A43, GAMMA, 0.0],
        [A50, A50, A52, A53, A54, GAMMA],
    ];

    // Stiffly accurate: the solution weights are the last row of A.
    pub const B: [f64; 6] = A[5];

    // Error estimation: E = Bhat - B, where Bhat is the 3rd-order embedded method
    // Bhat = [-480923228411/4982971448372, -480923228411/4982971448372,
    //          6709447293961/12833189095359, 3513175791894/6748737351361,
    //         -498863281070/6042575550617, 2077005547802/8945017530137]
    pub const E: [f64; 6] = [
        -0.08092570713246382,
        -0.08092570713246382,
        0.13516228008303094,
        0.01879524505002539,
        0.0256969660063123,
        -0.01780307687444085,
    ];
}
225
226// ============================================================================
227// Generic ESDIRK solver
228// ============================================================================
229
230fn solve_esdirk<S, Sys, const STAGES: usize>(
231    problem: &Sys,
232    t0: S,
233    tf: S,
234    y0: &[S],
235    options: &SolverOptions<S>,
236    c: &[f64],
237    a: &[[f64; STAGES]; STAGES],
238    b: &[f64],
239    e: &[f64],
240    gamma: f64,
241    order: usize,
242) -> Result<SolverResult<S>, SolverError>
243where
244    S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
245    Sys: OdeSystem<S>,
246{
247    let dim = problem.dim();
248    if y0.len() != dim {
249        return Err(SolverError::DimensionMismatch {
250            expected: dim,
251            actual: y0.len(),
252        });
253    }
254
255    let mut t = t0;
256    let mut y = y0.to_vec();
257
258    let direction_init = if tf > t0 { S::ONE } else { -S::ONE };
259    if let Some(grid) = options.t_eval.as_deref() {
260        validate_grid(grid, t0, tf)?;
261    }
262    let mut grid_emitter = options
263        .t_eval
264        .as_deref()
265        .map(|g| TEvalEmitter::new(g, direction_init));
266    let (mut t_out, mut y_out) = if grid_emitter.is_some() {
267        (Vec::new(), Vec::new())
268    } else {
269        (vec![t0], y0.to_vec())
270    };
271    // Slope at the start of the current step. f0 holds the slope at the
272    // current accepted state and is refreshed inside the accept branch; we
273    // snapshot it here before that overwrite for the Hermite emitter.
274    let mut dy_old_buf = vec![S::ZERO; dim];
275
276    let mut k: Vec<Vec<S>> = (0..STAGES).map(|_| vec![S::ZERO; dim]).collect();
277    let mut y_stage = vec![S::ZERO; dim];
278    let mut y_new = vec![S::ZERO; dim];
279    let mut err = vec![S::ZERO; dim];
280    let mut jac_data = vec![S::ZERO; dim * dim];
281    let mut f0 = vec![S::ZERO; dim];
282
283    let mut stats = SolverStats::default();
284
285    // Initial evaluation
286    problem.rhs(t, &y, &mut k[0]);
287    stats.n_eval += 1;
288    f0.copy_from_slice(&k[0]);
289
290    let mut h = initial_step_size(&y, &k[0], options, dim);
291    let h_min = options.h_min;
292    let h_max = options.h_max.min((tf - t0).abs());
293
294    // Jacobian and LU
295    let mut lu: Option<LUFactorization<S>> = None;
296    let mut need_jac = true;
297    let mut jac_h = h;
298
299    let direction = direction_init;
300    let mut step_count = 0_usize;
301    let mut consecutive_failures = 0_usize;
302
303    while (tf - t) * direction > S::from_f64(1e-10) * (tf - t0).abs() {
304        if step_count >= options.max_steps {
305            return Err(SolverError::MaxIterationsExceeded { t: t.to_f64() });
306        }
307
308        if (t + h - tf) * direction > S::ZERO {
309            h = tf - t;
310        }
311
312        h = h.abs().max(h_min) * direction;
313        if h.abs() > h_max {
314            h = h_max * direction;
315        }
316
317        // Recompute Jacobian only when the state has changed (need_jac).
318        // The Jacobian depends on (t, y), NOT on h, so h changes alone
319        // should only trigger an LU refactorization.
320        if need_jac {
321            compute_jacobian(problem, t, &y, &f0, &mut jac_data, dim);
322            stats.n_jac += 1;
323            need_jac = false;
324        }
325
326        // Form and factorize iteration matrix: I - h*gamma*J
327        if lu.is_none() || (h - jac_h).abs() > S::from_f64(1e-10) * h.abs() {
328            let iter_matrix = form_iteration_matrix(&jac_data, h * S::from_f64(gamma), dim);
329            lu = Some(LUFactorization::new(&iter_matrix)?);
330            stats.n_lu += 1;
331            jac_h = h;
332        }
333
334        // Compute stages
335        let step_ok = compute_esdirk_stages::<S, Sys, STAGES>(
336            problem,
337            t,
338            h,
339            &y,
340            c,
341            a,
342            gamma,
343            lu.as_ref().unwrap(),
344            &mut k,
345            &mut y_stage,
346            &mut stats,
347            dim,
348        )?;
349
350        if !step_ok {
351            stats.n_reject += 1;
352            consecutive_failures += 1;
353            h = h * S::from_f64(0.5);
354            need_jac = true;
355
356            if consecutive_failures >= 5 {
357                return Err(SolverError::Other(format!(
358                    "Too many consecutive failures at t = {}",
359                    t.to_f64()
360                )));
361            }
362            continue;
363        }
364
365        // Compute solution and error
366        for i in 0..dim {
367            let mut sum_b = S::ZERO;
368            let mut sum_e = S::ZERO;
369            for s in 0..STAGES {
370                sum_b = sum_b + S::from_f64(b[s]) * k[s][i];
371                sum_e = sum_e + S::from_f64(e[s]) * k[s][i];
372            }
373            y_new[i] = y[i] + h * sum_b;
374            err[i] = h * sum_e;
375        }
376
377        let err_norm = error_norm(&err, &y, &y_new, options, dim);
378
379        let safety = S::from_f64(0.9);
380        let fac_max = S::from_f64(3.0);
381        let fac_min = S::from_f64(0.2);
382        let order_f = S::from_usize(order + 1);
383
384        if err_norm <= S::ONE {
385            stats.n_accept += 1;
386            consecutive_failures = 0;
387
388            let t_new = t + h;
389            // Save start-of-step slope before f0 is refreshed to the new t.
390            dy_old_buf.copy_from_slice(&f0);
391            problem.rhs(t_new, &y_new, &mut f0);
392            stats.n_eval += 1;
393
394            if let Some(ref mut emitter) = grid_emitter {
395                emitter.emit_step(
396                    t,
397                    &y,
398                    &dy_old_buf,
399                    t_new,
400                    &y_new,
401                    &f0,
402                    &mut t_out,
403                    &mut y_out,
404                );
405            } else {
406                t_out.push(t_new);
407                y_out.extend_from_slice(&y_new);
408            }
409
410            t = t_new;
411            y.copy_from_slice(&y_new);
412            k[0].copy_from_slice(&f0);
413
414            let err_safe = err_norm.max(S::EPSILON * S::from_f64(100.0));
415            let fac = safety * err_safe.powf(-S::ONE / order_f);
416            let fac = fac.min(fac_max).max(fac_min);
417            h = h * fac;
418        } else {
419            stats.n_reject += 1;
420            consecutive_failures += 1;
421
422            let err_safe = err_norm.max(S::EPSILON * S::from_f64(100.0));
423            let fac = safety * err_safe.powf(-S::ONE / order_f);
424            let fac = fac.max(fac_min);
425            h = h * fac;
426
427            if consecutive_failures >= 3 {
428                need_jac = true;
429            }
430        }
431
432        if h.abs() < h_min {
433            return Err(SolverError::StepSizeTooSmall {
434                t: t.to_f64(),
435                h: h.to_f64(),
436                h_min: h_min.to_f64(),
437            });
438        }
439
440        step_count += 1;
441    }
442
443    Ok(SolverResult::new(t_out, y_out, dim, stats))
444}
445
446fn compute_esdirk_stages<S, Sys, const STAGES: usize>(
447    problem: &Sys,
448    t: S,
449    h: S,
450    y: &[S],
451    c: &[f64],
452    a: &[[f64; STAGES]; STAGES],
453    gamma: f64,
454    lu: &LUFactorization<S>,
455    k: &mut [Vec<S>],
456    y_stage: &mut [S],
457    stats: &mut SolverStats,
458    dim: usize,
459) -> Result<bool, SolverError>
460where
461    S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
462    Sys: OdeSystem<S>,
463{
464    // Stage 0 is explicit (already computed as f(t, y))
465
466    for s in 1..STAGES {
467        // Compute initial guess from explicit part
468        for i in 0..dim {
469            let mut sum = S::ZERO;
470            for j in 0..s {
471                sum = sum + S::from_f64(a[s][j]) * k[j][i];
472            }
473            y_stage[i] = y[i] + h * sum;
474        }
475
476        // Newton iteration for implicit stage
477        let t_stage = t + S::from_f64(c[s]) * h;
478        let h_gamma = h * S::from_f64(gamma);
479
480        let mut converged = false;
481        for _iter in 0..10 {
482            let mut f_stage = vec![S::ZERO; dim];
483            problem.rhs(t_stage, y_stage, &mut f_stage);
484            stats.n_eval += 1;
485
486            // Residual: y_stage - y - h * sum(a[s][j] * k[j]) - h*gamma*f_stage
487            let mut residual = vec![S::ZERO; dim];
488            let mut res_norm = S::ZERO;
489            for i in 0..dim {
490                let mut sum = S::ZERO;
491                for j in 0..s {
492                    sum = sum + S::from_f64(a[s][j]) * k[j][i];
493                }
494                residual[i] = y_stage[i] - y[i] - h * sum - h_gamma * f_stage[i];
495                res_norm = res_norm + residual[i] * residual[i];
496            }
497            res_norm = res_norm.sqrt();
498
499            if res_norm < S::from_f64(1e-10) {
500                k[s].copy_from_slice(&f_stage);
501                converged = true;
502                break;
503            }
504
505            // Newton correction
506            let delta = lu.solve(&residual)?;
507            for i in 0..dim {
508                y_stage[i] = y_stage[i] - delta[i];
509            }
510        }
511
512        if !converged {
513            return Ok(false);
514        }
515    }
516
517    Ok(true)
518}
519
520fn compute_jacobian<S, Sys>(problem: &Sys, t: S, y: &[S], f0: &[S], jac: &mut [S], dim: usize)
521where
522    S: Scalar,
523    Sys: OdeSystem<S>,
524{
525    let h_factor = S::EPSILON.sqrt();
526    let mut y_pert = y.to_vec();
527    let mut f_pert = vec![S::ZERO; dim];
528
529    for j in 0..dim {
530        let yj = y[j];
531        let h = h_factor * (S::ONE + yj.abs());
532        y_pert[j] = yj + h;
533        problem.rhs(t, &y_pert, &mut f_pert);
534        y_pert[j] = yj;
535
536        for i in 0..dim {
537            jac[i * dim + j] = (f_pert[i] - f0[i]) / h;
538        }
539    }
540}
541
542fn form_iteration_matrix<S>(jac: &[S], h_gamma: S, dim: usize) -> DenseMatrix<S>
543where
544    S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
545{
546    let mut m = DenseMatrix::zeros(dim, dim);
547    for i in 0..dim {
548        for j in 0..dim {
549            let jij = jac[i * dim + j];
550            if i == j {
551                m.set(i, j, S::ONE - h_gamma * jij);
552            } else {
553                m.set(i, j, -h_gamma * jij);
554            }
555        }
556    }
557    m
558}
559
560fn initial_step_size<S: Scalar>(y0: &[S], f0: &[S], options: &SolverOptions<S>, dim: usize) -> S {
561    if let Some(h0) = options.h0 {
562        return h0;
563    }
564
565    let mut y_norm = S::ZERO;
566    let mut f_norm = S::ZERO;
567    for i in 0..dim {
568        let sc = options.atol + options.rtol * y0[i].abs();
569        y_norm = y_norm + (y0[i] / sc) * (y0[i] / sc);
570        f_norm = f_norm + (f0[i] / sc) * (f0[i] / sc);
571    }
572    y_norm = (y_norm / S::from_usize(dim)).sqrt();
573    f_norm = (f_norm / S::from_usize(dim)).sqrt();
574
575    if y_norm < S::EPSILON.sqrt() || f_norm < S::EPSILON.sqrt() {
576        S::from_f64(1e-6)
577    } else {
578        (S::from_f64(0.01) * y_norm / f_norm).min(options.h_max)
579    }
580}
581
582fn error_norm<S: Scalar>(
583    err: &[S],
584    y: &[S],
585    y_new: &[S],
586    options: &SolverOptions<S>,
587    dim: usize,
588) -> S {
589    let mut err_norm = S::ZERO;
590    for i in 0..dim {
591        let sc = options.atol + options.rtol * y[i].abs().max(y_new[i].abs());
592        let sc = sc.max(S::from_f64(1e-15));
593        let scaled_err = err[i] / sc;
594        err_norm = err_norm + scaled_err * scaled_err;
595    }
596    (err_norm / S::from_usize(dim)).sqrt()
597}
598
599// ============================================================================
600// Solver trait implementations
601// ============================================================================
602
603impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> Solver<S> for Esdirk32 {
604    fn solve<Sys: OdeSystem<S>>(
605        problem: &Sys,
606        t0: S,
607        tf: S,
608        y0: &[S],
609        options: &SolverOptions<S>,
610    ) -> Result<SolverResult<S>, SolverError> {
611        solve_esdirk::<S, Sys, 3>(
612            problem,
613            t0,
614            tf,
615            y0,
616            options,
617            &esdirk32_tableau::C,
618            &esdirk32_tableau::A,
619            &esdirk32_tableau::B,
620            &esdirk32_tableau::E,
621            esdirk32_tableau::GAMMA,
622            2,
623        )
624    }
625}
626
627impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> Solver<S> for Esdirk43 {
628    fn solve<Sys: OdeSystem<S>>(
629        problem: &Sys,
630        t0: S,
631        tf: S,
632        y0: &[S],
633        options: &SolverOptions<S>,
634    ) -> Result<SolverResult<S>, SolverError> {
635        solve_esdirk::<S, Sys, 4>(
636            problem,
637            t0,
638            tf,
639            y0,
640            options,
641            &esdirk43_tableau::C,
642            &esdirk43_tableau::A,
643            &esdirk43_tableau::B,
644            &esdirk43_tableau::E,
645            esdirk43_tableau::GAMMA,
646            3,
647        )
648    }
649}
650
651impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> Solver<S> for Esdirk54 {
652    fn solve<Sys: OdeSystem<S>>(
653        problem: &Sys,
654        t0: S,
655        tf: S,
656        y0: &[S],
657        options: &SolverOptions<S>,
658    ) -> Result<SolverResult<S>, SolverError> {
659        solve_esdirk::<S, Sys, 6>(
660            problem,
661            t0,
662            tf,
663            y0,
664            options,
665            &esdirk54_tableau::C,
666            &esdirk54_tableau::A,
667            &esdirk54_tableau::B,
668            &esdirk54_tableau::E,
669            esdirk54_tableau::GAMMA,
670            4,
671        )
672    }
673}
674
#[cfg(test)]
mod tests {
    use super::*;
    use crate::problem::OdeProblem;

    /// Decay y' = -y, y(0) = 1 on [0, 5]; the exact final value is e^-5.
    #[test]
    fn test_esdirk32_exponential() {
        let (t0, tf) = (0.0, 5.0);
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            t0,
            tf,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let solution = Esdirk32::solve(&problem, t0, tf, &[1.0], &options).unwrap();
        assert!(solution.success);

        let exact = (-5.0_f64).exp();
        let y_end = solution.y_final().unwrap();
        assert!((y_end[0] - exact).abs() < 1e-3);
    }

    /// Stiff decay y' = -50 y on [0, 0.5]; the exact final value is e^-25.
    #[test]
    fn test_esdirk43_stiff() {
        let (t0, tf) = (0.0, 0.5);
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -50.0 * y[0];
            },
            t0,
            tf,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let solution = Esdirk43::solve(&problem, t0, tf, &[1.0], &options).unwrap();
        assert!(solution.success);

        let exact = (-25.0_f64).exp();
        let y_end = solution.y_final().unwrap();
        assert!((y_end[0] - exact).abs() < 0.01);
    }

    /// Linear exchange system whose total y1 + y2 is conserved (stays 1).
    #[test]
    fn test_esdirk54_linear_system() {
        let (t0, tf) = (0.0, 5.0);
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0] + y[1];
                dydt[1] = y[0] - y[1];
            },
            t0,
            tf,
            vec![1.0, 0.0],
        );
        let options = SolverOptions::default().rtol(1e-5).atol(1e-7);
        let solution = Esdirk54::solve(&problem, t0, tf, &[1.0, 0.0], &options).unwrap();
        assert!(solution.success);

        let y_end = solution.y_final().unwrap();
        // Conservation: y1 + y2 = 1
        assert!((y_end[0] + y_end[1] - 1.0).abs() < 1e-4);
    }

    /// Van der Pol oscillator with mu = 10: only checks the solve completes.
    #[test]
    fn test_esdirk_van_der_pol() {
        let mu = 10.0;
        let (t0, tf) = (0.0, 10.0);
        let problem = OdeProblem::new(
            move |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = mu * (1.0 - y[0] * y[0]) * y[1] - y[0];
            },
            t0,
            tf,
            vec![2.0, 0.0],
        );
        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let outcome = Esdirk54::solve(&problem, t0, tf, &[2.0, 0.0], &options);

        assert!(outcome.is_ok());
    }

    /// All three methods approximate e^-2 for y' = -y on [0, 2].
    #[test]
    fn test_esdirk_methods_agree() {
        let (t0, tf) = (0.0, 2.0);
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            t0,
            tf,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-3).atol(1e-5);

        let y32 = Esdirk32::solve(&problem, t0, tf, &[1.0], &options)
            .unwrap()
            .y_final()
            .unwrap()[0];
        let y43 = Esdirk43::solve(&problem, t0, tf, &[1.0], &options)
            .unwrap()
            .y_final()
            .unwrap()[0];
        let y54 = Esdirk54::solve(&problem, t0, tf, &[1.0], &options)
            .unwrap()
            .y_final()
            .unwrap()[0];
        let expected = (-2.0_f64).exp();

        // Loose tolerance: the methods have different orders of accuracy.
        assert!(
            (y32 - expected).abs() < 1e-2,
            "ESDIRK32: got {}, expected {}",
            y32,
            expected
        );
        assert!(
            (y43 - expected).abs() < 1e-2,
            "ESDIRK43: got {}, expected {}",
            y43,
            expected
        );
        assert!(
            (y54 - expected).abs() < 1e-2,
            "ESDIRK54: got {}, expected {}",
            y54,
            expected
        );
    }
}