Skip to main content

numra_ode/
radau5.rs

1//! Radau5: 3-stage Radau IIA implicit Runge-Kutta method (5th order, L-stable).
2//!
3//! This is a corrected implementation following Hairer & Wanner's algorithm
4//! ("Solving Ordinary Differential Equations II", §IV.8) and aligned with the
5//! reference implementations in radau5.f (Hairer, Geneva) and SciPy's port
6//! (`scipy.integrate.Radau`).
7//!
8//! ## Mathematical Formulation
9//!
10//! Solves systems of the form:
11//! ```text
12//! M * y'(t) = f(t, y)
13//! ```
14//! where M is a (possibly singular) mass matrix. When M is the identity, this
15//! reduces to the standard ODE form y' = f(t, y).
16//!
17//! **DAE support (index-1):** When M is singular, the system is a
18//! differential-algebraic equation. Rows of M with zero diagonal correspond to
19//! algebraic constraints. Radau5 handles index-1 DAEs natively because it is
20//! L-stable and stiffly accurate (the last stage coincides with the step endpoint).
21//!
22//! ## Algorithm
23//!
24//! At each step, the 3-stage system is transformed via real-Schur decomposition
25//! of the Radau IIA coefficient matrix A. This reduces the 3×3 block system to
//! one real and one complex linear solve per Newton iteration, cutting the cost
//! of each factorization from O((3n)³) to O(n³).
28//!
29//! Step size control uses Hairer's ESTRAD error estimator with refinement on
30//! the first step / rejected steps, combined with Gustafsson's predictive
31//! controller (Hairer's `IWORK(8) = 1`, the default in radau5.f).
32//!
33//! ## Corrections vs. previous revision (5 May 2026)
34//!
35//! 1. (CRITICAL) `error_estimate`: forcing term is `f(t, y)` (the RHS), not
36//!    `y` itself. Same bug existed in the mass-matrix branch (was `M*y`).
37//!    Also: scale by max(|y|, |y_new|) per Hairer's ESTRAD.
38//! 2. Newton initial guess: extrapolated collocation polynomial from the
39//!    previous step (Hairer's `STARTN = 0`, the default).
40//! 3. LU re-factorization is only triggered when the step ratio leaves
41//!    [1.0, 1.2] (Hairer's `QUOT1`/`QUOT2` heuristic).
42//! 4. Off-by-one in Newton convergence-rate check (`newt > 1` → `newt >= 1`).
43//! 5. Step-controller `nit` is the max Newton iteration count (= 7), not 0.
44//! 6. Gustafsson predictive step-size controller (uses prior step's err norm).
45//! 7. `facl` is 8.0 (Hairer's default), not 5.0.
46//! 8. `f0` is tracked across steps and updated after each accepted step
47//!    (required for FIX 1; SciPy does this).
48//!
49//! ## Known Limitations
50//!
51//! - Only supports index-1 DAEs (algebraic variables appear linearly).
52//! - The error estimator falls back to step rejection when the LU solve fails.
53//!
54//! ## Jacobian
55//!
//! Radau5 calls `OdeSystem::jacobian` for each Jacobian rebuild. Systems that
//! override the trait method get an analytical Jacobian for free; systems
//! that don't override it fall through to the canonical forward-FD default in
//! `crate::problem` (`h = sqrt(S::EPSILON) * (1 + |y_j|)`, row-major
//! dense output).
61//!
62//! ## References
63//! - Hairer, E. & Wanner, G. (1996), "Solving Ordinary Differential Equations II:
64//!   Stiff and Differential-Algebraic Problems", Springer (2nd ed.), §IV.8.
65//! - radau5.f source (E. Hairer), available at <https://www.unige.ch/~hairer/>.
66//! - SciPy `scipy/integrate/_ivp/radau.py` (Apache-2.0 / BSD-3 implementation).
67//!
68//! Author: Moussa Leblouba
69//! Date: 10 February 2026
70//! Modified: 5 May 2026 (corrections per Hairer-Wanner ODE II §IV.8 and SciPy port)
71
72use faer::{ComplexField, Conjugate, SimpleEntity};
73use numra_core::Scalar;
74use numra_linalg::{DenseMatrix, LUFactorization, Matrix};
75
76use crate::error::SolverError;
77use crate::problem::OdeSystem;
78use crate::solver::{Solver, SolverOptions, SolverResult, SolverStats};
79use crate::t_eval::{validate_grid, TEvalEmitter};
80
81/// Radau5 solver for stiff ODEs.
82#[derive(Clone, Debug, Default)]
83pub struct Radau5;
84
85impl Radau5 {
86    /// Create a new Radau5 solver.
87    pub fn new() -> Self {
88        Self
89    }
90}
91
92/// Radau IIA coefficients and Hairer transformation matrices.
93mod coefficients {
94    pub const SQRT6: f64 = 2.449489742783178;
95
96    // Stage points (c values)
97    pub const C1: f64 = (4.0 - SQRT6) / 10.0; // ≈ 0.1551
98    pub const C2: f64 = (4.0 + SQRT6) / 10.0; // ≈ 0.6449
99    #[allow(dead_code)]
100    pub const C3: f64 = 1.0; // Third Radau node (always 1 for Radau IIA)
101
102    // Error estimation coefficients (DD values from Hairer's ESTRAD)
103    pub const DD1: f64 = -(13.0 + 7.0 * SQRT6) / 3.0;
104    pub const DD2: f64 = (-13.0 + 7.0 * SQRT6) / 3.0;
105    pub const DD3: f64 = -1.0 / 3.0;
106
107    // Eigenvalue-related constants from Hairer-Wanner.
108    // 81^(1/3) ≈ 4.3267, 9^(1/3) ≈ 2.0801
109    const CUBERT81: f64 = 4.3267487109222245;
110    const CUBERT9: f64 = 2.080083823051904;
111
112    // U1 = inverse of real eigenvalue
113    const U1_RAW: f64 = (6.0 + CUBERT81 - CUBERT9) / 30.0;
114    pub const U1: f64 = 1.0 / U1_RAW; // ≈ 3.6378342527444957
115
116    // Complex eigenvalue: α ± iβ (after normalization)
117    const ALPH_RAW: f64 = (12.0 - CUBERT81 + CUBERT9) / 60.0;
118    const BETA_RAW: f64 = (CUBERT81 + CUBERT9) * 1.7320508075688772 / 60.0; // sqrt(3)
119    const CNO: f64 = ALPH_RAW * ALPH_RAW + BETA_RAW * BETA_RAW;
120    pub const ALPH: f64 = ALPH_RAW / CNO; // ≈ 2.6812
121    pub const BETA: f64 = BETA_RAW / CNO; // ≈ 3.0504
122
123    // Transformation matrix T (transforms from decoupled to original space).
124    // Z = T * F where F is in transformed space. From radau5.f.
125    pub const T11: f64 = 9.1232394870892942792e-02;
126    pub const T12: f64 = -0.14125529502095420843;
127    pub const T13: f64 = -3.0029194105147424492e-02;
128    pub const T21: f64 = 0.24171793270710701896;
129    pub const T22: f64 = 0.20412935229379993199;
130    pub const T23: f64 = 0.38294211275726193779;
131    pub const T31: f64 = 0.96604818261509293619;
132    pub const T32: f64 = 1.0;
133    #[allow(dead_code)]
134    pub const T33: f64 = 0.0;
135
136    // Inverse transformation matrix TI = T^{-1}. From radau5.f.
137    pub const TI11: f64 = 4.3255798900631553510;
138    pub const TI12: f64 = 0.33919925181580986954;
139    pub const TI13: f64 = 0.54177053993587487119;
140    pub const TI21: f64 = -4.1787185915519047273;
141    pub const TI22: f64 = -0.32768282076106238708;
142    pub const TI23: f64 = 0.47662355450055045196;
143    pub const TI31: f64 = -0.50287263494578687595;
144    pub const TI32: f64 = 2.5719269498556054292;
145    pub const TI33: f64 = -0.59603920482822492497;
146
147    // ---- Dense-output (continuous extension) coefficients ---------------
148    //
149    // The cubic collocation polynomial through {(0, y_old), (C1, y+Z1),
150    // (C2, y+Z2), (1, y+Z3)} can be written as:
151    //   y(t_old + θ*h) = y_old + Σ_k Q[i,k] * θ^(k+1)
152    // where Q[i,k] = Σ_j P[j,k] * Z[j,i].
153    //
154    // We use this to extrapolate the previous step's stages to the current
155    // step's abscissae as a Newton initial guess (Hairer's STARTN = 0).
156    //
157    // Reference: SciPy radau.py (which cites Hairer-Wanner §IV.8).
158    pub const P11: f64 = 13.0 / 3.0 + 7.0 * SQRT6 / 3.0;
159    pub const P12: f64 = -23.0 / 3.0 - 22.0 * SQRT6 / 3.0;
160    pub const P13: f64 = 10.0 / 3.0 + 5.0 * SQRT6;
161    pub const P21: f64 = 13.0 / 3.0 - 7.0 * SQRT6 / 3.0;
162    pub const P22: f64 = -23.0 / 3.0 + 22.0 * SQRT6 / 3.0;
163    pub const P23: f64 = 10.0 / 3.0 - 5.0 * SQRT6;
164    pub const P31: f64 = 1.0 / 3.0;
165    pub const P32: f64 = -8.0 / 3.0;
166    pub const P33: f64 = 10.0 / 3.0;
167}
168
169/// Maximum Newton iterations per step (Hairer's NIT default).
170const MAX_NEWTON_ITER: usize = 7;
171
172impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> Solver<S> for Radau5 {
173    fn solve<Sys: OdeSystem<S>>(
174        problem: &Sys,
175        t0: S,
176        tf: S,
177        y0: &[S],
178        options: &SolverOptions<S>,
179    ) -> Result<SolverResult<S>, SolverError> {
180        let dim = problem.dim();
181        if y0.len() != dim {
182            return Err(SolverError::DimensionMismatch {
183                expected: dim,
184                actual: y0.len(),
185            });
186        }
187
188        let mut t = t0;
189        let mut y = y0.to_vec();
190
191        let direction_init = if tf > t0 { S::ONE } else { -S::ONE };
192        if let Some(grid) = options.t_eval.as_deref() {
193            validate_grid(grid, t0, tf)?;
194        }
195        let mut grid_emitter = options
196            .t_eval
197            .as_deref()
198            .map(|g| TEvalEmitter::new(g, direction_init));
199        let (mut t_out, mut y_out) = if grid_emitter.is_some() {
200            (Vec::new(), Vec::new())
201        } else {
202            (vec![t0], y0.to_vec())
203        };
204        // Buffer holding f(t_old, y_old) right before output emission.
205        let mut dy_old_buf = vec![S::ZERO; dim];
206
207        // Working arrays
208        let mut f0 = vec![S::ZERO; dim];
209        let mut z1 = vec![S::ZERO; dim];
210        let mut z2 = vec![S::ZERO; dim];
211        let mut z3 = vec![S::ZERO; dim];
212        let mut w1 = vec![S::ZERO; dim];
213        let mut w2 = vec![S::ZERO; dim];
214        let mut w3 = vec![S::ZERO; dim];
215        let mut cont = vec![S::ZERO; dim];
216        let mut scal = vec![S::ZERO; dim];
217        let mut y_new = vec![S::ZERO; dim];
218        let mut err = vec![S::ZERO; dim];
219        let mut jac_data = vec![S::ZERO; dim * dim];
220
221        // FIX 2 state: previous-step stages, used to extrapolate the
222        // collocation polynomial as Newton's initial guess.
223        let mut z1_prev = vec![S::ZERO; dim];
224        let mut z2_prev = vec![S::ZERO; dim];
225        let mut z3_prev = vec![S::ZERO; dim];
226        let mut h_prev: S = S::ONE; // dummy until first accepted step
227        let mut have_prev = false;
228
229        // FIX 6 state: Gustafsson predictive controller history.
230        let mut h_abs_old: Option<S> = None;
231        let mut err_norm_old: Option<S> = None;
232
233        // Mass matrix support for DAEs
234        let has_mass = problem.has_mass_matrix();
235        let mass_data = if has_mass {
236            let mut m = vec![S::ZERO; dim * dim];
237            problem.mass_matrix(&mut m);
238            Some(m)
239        } else {
240            None
241        };
242        let mass_ref = mass_data.as_deref();
243
244        let mut stats = SolverStats::default();
245
246        // Initial scaling
247        for i in 0..dim {
248            scal[i] = options.atol + options.rtol * y[i].abs();
249        }
250
251        // FIX 8: f0 is tracked across steps; initialize once and refresh after
252        // every accepted step so that the error estimator (FIX 1) always sees
253        // f(t, y) at the start of the current step.
254        problem.rhs(t, &y, &mut f0);
255        stats.n_eval += 1;
256
257        // Initial step size
258        let mut h = Self::initial_step_size(&y, &f0, options, dim);
259        let h_min = options.h_min;
260        let h_max = (tf - t0).abs() * S::from_f64(0.5);
261
262        // LU factorizations for the decoupled systems.
263        let mut lu_real: Option<LUFactorization<S>> = None;
264        let mut lu_complex: Option<LUFactorization<S>> = None;
265        let mut need_jac = true;
266
267        let mut first = true;
268        let mut reject = false;
269        let mut step_count = 0usize;
270        let direction = if tf > t0 { S::ONE } else { -S::ONE };
271
272        while (tf - t) * direction > S::ZERO {
273            if step_count >= options.max_steps {
274                return Err(SolverError::MaxIterationsExceeded { t: t.to_f64() });
275            }
276
277            // Don't overshoot tf
278            if (t + h - tf) * direction > S::ZERO {
279                h = tf - t;
280            }
281
282            // Recompute Jacobian if needed (set on Newton failure or first step).
283            // The Jacobian depends on (t, y), not on h, so an h change alone does
284            // NOT trigger a Jacobian recompute -- only an LU refactor.
285            //
286            // Delegates to OdeSystem::jacobian, which lets a system override
287            // with an analytical Jacobian (e.g. MOLSystem*) and otherwise
288            // falls through to the canonical FD default in problem.rs.
289            // Costs one extra rhs eval per Jacobian rebuild and one allocation
290            // for the trait default's internal scratch buffers, vs the previous
291            // inlined zero-alloc path; benched within the ≤5% regression
292            // bound on Van der Pol stiff workloads (see commit message).
293            if need_jac {
294                problem.jacobian(t, &y, &mut jac_data);
295                stats.n_jac += 1;
296                need_jac = false;
297                // New Jacobian => existing LU is stale.
298                lu_real = None;
299                lu_complex = None;
300            }
301
302            // FIX 3: refactor LU only when explicitly invalidated. Hairer's
303            // QUOT1 = 1.0 / QUOT2 = 1.2 heuristic skips refactorization when
304            // h has changed by less than ~20% (we set this in the accept branch).
305            if lu_real.is_none() {
306                let (e1, e2) = Self::form_transformed_matrices(&jac_data, h, dim, mass_ref);
307                lu_real = Some(LUFactorization::new(&e1)?);
308                lu_complex = Some(LUFactorization::new(&e2)?);
309                stats.n_lu += 2;
310            }
311
312            // Update scaling for this step
313            for i in 0..dim {
314                scal[i] = options.atol + options.rtol * y[i].abs();
315            }
316
317            // FIX 2: Newton initial guess.
318            //   - On the first step or after a rejection, use zero.
319            //   - Otherwise, extrapolate from the previous step's collocation
320            //     polynomial. This typically reduces Newton iterations 5-7 -> 1-3.
321            let use_extrapolation = !first && !reject && have_prev;
322            if !use_extrapolation {
323                for i in 0..dim {
324                    z1[i] = S::ZERO;
325                    z2[i] = S::ZERO;
326                    z3[i] = S::ZERO;
327                    w1[i] = S::ZERO;
328                    w2[i] = S::ZERO;
329                    w3[i] = S::ZERO;
330                }
331            } else {
332                // Q[i,k] = Σ_j P[j,k] * Z_prev[j,i].
333                // Then Z_init[k,i] = q0*r[k] + q1*r[k]^2 + q2*r[k]^3
334                // where r[k] = h * C[k] / h_prev (relative position in
335                // the previous step's parameterization).
336                let p11 = S::from_f64(coefficients::P11);
337                let p12 = S::from_f64(coefficients::P12);
338                let p13 = S::from_f64(coefficients::P13);
339                let p21 = S::from_f64(coefficients::P21);
340                let p22 = S::from_f64(coefficients::P22);
341                let p23 = S::from_f64(coefficients::P23);
342                let p31 = S::from_f64(coefficients::P31);
343                let p32 = S::from_f64(coefficients::P32);
344                let p33 = S::from_f64(coefficients::P33);
345
346                let c1 = S::from_f64(coefficients::C1);
347                let c2 = S::from_f64(coefficients::C2);
348                let c3 = S::ONE;
349
350                let r1 = h * c1 / h_prev;
351                let r2 = h * c2 / h_prev;
352                let r3 = h * c3 / h_prev;
353
354                for i in 0..dim {
355                    let q0 = z1_prev[i] * p11 + z2_prev[i] * p21 + z3_prev[i] * p31;
356                    let q1 = z1_prev[i] * p12 + z2_prev[i] * p22 + z3_prev[i] * p32;
357                    let q2 = z1_prev[i] * p13 + z2_prev[i] * p23 + z3_prev[i] * p33;
358
359                    z1[i] = q0 * r1 + q1 * r1 * r1 + q2 * r1 * r1 * r1;
360                    z2[i] = q0 * r2 + q1 * r2 * r2 + q2 * r2 * r2 * r2;
361                    z3[i] = q0 * r3 + q1 * r3 * r3 + q2 * r3 * r3 * r3;
362                }
363
364                // W = TI * Z (must be consistent with Z so the Newton inner
365                // loop's back-transform Z = T*W stays stable).
366                let ti11 = S::from_f64(coefficients::TI11);
367                let ti12 = S::from_f64(coefficients::TI12);
368                let ti13 = S::from_f64(coefficients::TI13);
369                let ti21 = S::from_f64(coefficients::TI21);
370                let ti22 = S::from_f64(coefficients::TI22);
371                let ti23 = S::from_f64(coefficients::TI23);
372                let ti31 = S::from_f64(coefficients::TI31);
373                let ti32 = S::from_f64(coefficients::TI32);
374                let ti33 = S::from_f64(coefficients::TI33);
375
376                for i in 0..dim {
377                    w1[i] = ti11 * z1[i] + ti12 * z2[i] + ti13 * z3[i];
378                    w2[i] = ti21 * z1[i] + ti22 * z2[i] + ti23 * z3[i];
379                    w3[i] = ti31 * z1[i] + ti32 * z2[i] + ti33 * z3[i];
380                }
381            }
382
383            // Simplified Newton iteration (in transformed space).
384            let newton_result = Self::newton_iteration(
385                problem,
386                t,
387                h,
388                &y,
389                &scal,
390                &mut z1,
391                &mut z2,
392                &mut z3,
393                &mut w1,
394                &mut w2,
395                &mut w3,
396                &mut cont,
397                lu_real.as_ref().unwrap(),
398                lu_complex.as_ref().unwrap(),
399                mass_ref,
400                &mut stats,
401                dim,
402                options,
403            );
404
405            let (newton_converged, newt_iter) = match newton_result {
406                Ok((converged, iter)) => (converged, iter),
407                Err(_) => (false, MAX_NEWTON_ITER),
408            };
409
410            if !newton_converged {
411                // Newton failed -- reduce step size, force fresh Jacobian.
412                h = h * S::from_f64(0.5);
413                stats.n_reject += 1;
414                reject = true;
415                need_jac = true;
416
417                if h.abs() < h_min {
418                    return Err(SolverError::StepSizeTooSmall {
419                        t: t.to_f64(),
420                        h: h.to_f64(),
421                        h_min: h_min.to_f64(),
422                    });
423                }
424                continue;
425            }
426
427            // Compute new solution candidate: y_new = y + Z3 (since c3 = 1).
428            for i in 0..dim {
429                y_new[i] = y[i] + z3[i];
430            }
431
432            // FIX 1 + FIX 1b: error estimation now takes f0 explicitly and
433            // scales by max(|y|, |y_new|).
434            let err_norm = Self::error_estimate(
435                problem,
436                t,
437                &f0,
438                &z1,
439                &z2,
440                &z3,
441                &y,
442                &y_new,
443                h,
444                options,
445                lu_real.as_ref().unwrap(),
446                &mut err,
447                dim,
448                first,
449                reject,
450                &mut stats,
451                mass_ref,
452            );
453
454            // FIX 5 + FIX 6: Gustafsson predictive controller with safety
455            // factor 0.9 * (2*NIT+1) / (2*NIT + newt_iter), classical bound
456            // [0.2, 8.0] (FIX 7).
457            let safety = Self::safety_factor::<S>(newt_iter, MAX_NEWTON_ITER);
458            let pred = Self::predict_factor(h.abs(), h_abs_old, err_norm, err_norm_old);
459            let factor = (safety * pred).max(S::from_f64(0.2)).min(S::from_f64(8.0));
460
461            if err_norm < S::ONE {
462                // ----- Step accepted -----
463                stats.n_accept += 1;
464
465                // Save stages and h for next step's extrapolation (FIX 2).
466                z1_prev.copy_from_slice(&z1);
467                z2_prev.copy_from_slice(&z2);
468                z3_prev.copy_from_slice(&z3);
469                h_prev = h;
470                have_prev = true;
471
472                // Update Gustafsson state (FIX 6).
473                h_abs_old = Some(h.abs());
474                err_norm_old = Some(err_norm);
475
476                let t_new = t + h;
477                // Save the slope at the start of the step before we
478                // overwrite f0 with the slope at the end (used both for the
479                // next-step error estimator and Hermite interpolation).
480                dy_old_buf.copy_from_slice(&f0);
481                problem.rhs(t_new, &y_new, &mut f0);
482                stats.n_eval += 1;
483
484                if let Some(ref mut emitter) = grid_emitter {
485                    emitter.emit_step(
486                        t,
487                        &y,
488                        &dy_old_buf,
489                        t_new,
490                        &y_new,
491                        &f0,
492                        &mut t_out,
493                        &mut y_out,
494                    );
495                } else {
496                    t_out.push(t_new);
497                    y_out.extend_from_slice(&y_new);
498                }
499
500                t = t_new;
501                y.copy_from_slice(&y_new);
502
503                first = false;
504                reject = false;
505
506                // FIX 3: only invalidate LU when factor >= 1.2 (Hairer's
507                // QUOT2 = 1.2 heuristic). Small step changes don't justify
508                // an LU refactor; keep h and LU in that case.
509                if factor < S::from_f64(1.2) {
510                    // Don't change h, don't refactor LU.
511                } else {
512                    let h_proposed = h * factor;
513                    let h_capped = if h_proposed.abs() > h_max {
514                        if h_proposed > S::ZERO {
515                            h_max
516                        } else {
517                            -h_max
518                        }
519                    } else {
520                        h_proposed
521                    };
522                    h = h_capped;
523                    lu_real = None;
524                    lu_complex = None;
525                }
526            } else {
527                // ----- Step rejected -----
528                stats.n_reject += 1;
529                reject = true;
530
531                // Always shrink h and refactor LU on rejection.
532                h = h * factor;
533                lu_real = None;
534                lu_complex = None;
535
536                if h.abs() < h_min {
537                    return Err(SolverError::StepSizeTooSmall {
538                        t: t.to_f64(),
539                        h: h.to_f64(),
540                        h_min: h_min.to_f64(),
541                    });
542                }
543            }
544
545            step_count += 1;
546        }
547
548        Ok(SolverResult::new(t_out, y_out, dim, stats))
549    }
550}
551
552impl Radau5 {
553    /// Initial step size estimation (simple heuristic, not Hairer's full HINIT).
554    fn initial_step_size<S: Scalar>(y: &[S], f: &[S], options: &SolverOptions<S>, dim: usize) -> S {
555        let mut d0 = S::ZERO;
556        let mut d1 = S::ZERO;
557
558        for i in 0..dim {
559            let sc = options.atol + options.rtol * y[i].abs();
560            d0 = d0 + (y[i] / sc) * (y[i] / sc);
561            d1 = d1 + (f[i] / sc) * (f[i] / sc);
562        }
563
564        let d0 = (d0 / S::from_usize(dim)).sqrt();
565        let d1 = (d1 / S::from_usize(dim)).sqrt();
566
567        let h0 = if d0 < S::from_f64(1e-5) || d1 < S::from_f64(1e-5) {
568            S::from_f64(1e-6)
569        } else {
570            S::from_f64(0.01) * d0 / d1
571        };
572
573        h0.min(options.h_max).max(options.h_min)
574    }
575
576    /// Compute Jacobian by finite differences.
577    /// Form the transformed iteration matrices E1 (n×n real) and E2 (2n×2n
578    /// real form of the complex system).
579    ///
580    /// For ODEs (identity mass):
581    ///   E1 = (U1/h)*I - J
582    ///   E2 = real form of ((α + iβ)/h)*I - J
583    ///
584    /// For DAEs (general mass M):
585    ///   E1 = (U1/h)*M - J
586    ///   E2 uses M instead of I in the diagonal blocks.
587    fn form_transformed_matrices<S>(
588        jac: &[S],
589        h: S,
590        dim: usize,
591        mass: Option<&[S]>,
592    ) -> (DenseMatrix<S>, DenseMatrix<S>)
593    where
594        S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
595    {
596        let fac1 = S::from_f64(coefficients::U1) / h;
597        let mut e1 = DenseMatrix::zeros(dim, dim);
598        for i in 0..dim {
599            for j in 0..dim {
600                let jij = jac[i * dim + j];
601                let mij = match mass {
602                    Some(m) => m[i * dim + j],
603                    None => {
604                        if i == j {
605                            S::ONE
606                        } else {
607                            S::ZERO
608                        }
609                    }
610                };
611                e1.set(i, j, fac1 * mij - jij);
612            }
613        }
614
615        // E2 is the 2n×2n real matrix for the complex system:
616        //   | alphn*M - J   -betan*M       |
617        //   | betan*M        alphn*M - J   |
618        // representing (alphn + i*betan)*M - J acting on (W2 + i*W3).
619        let alphn = S::from_f64(coefficients::ALPH) / h;
620        let betan = S::from_f64(coefficients::BETA) / h;
621        let mut e2 = DenseMatrix::zeros(2 * dim, 2 * dim);
622
623        for i in 0..dim {
624            for j in 0..dim {
625                let jij = jac[i * dim + j];
626                let mij = match mass {
627                    Some(m) => m[i * dim + j],
628                    None => {
629                        if i == j {
630                            S::ONE
631                        } else {
632                            S::ZERO
633                        }
634                    }
635                };
636                e2.set(i, j, alphn * mij - jij);
637                e2.set(i, dim + j, -betan * mij);
638                e2.set(dim + i, j, betan * mij);
639                e2.set(dim + i, dim + j, alphn * mij - jij);
640            }
641        }
642
643        (e1, e2)
644    }
645
646    /// Safety factor for step-size selection (FIX 5).
647    ///
648    /// SciPy's formulation: `safety = 0.9 * (2*NIT + 1) / (2*NIT + n_iter)`.
649    /// This is equivalent in spirit to Hairer's `min(SAFE, CFAC/(NEWT+2*NIT))`
650    /// — the more Newton iterations a step needed, the more conservatively we
651    /// scale h.
652    fn safety_factor<S: Scalar>(n_iter: usize, max_iter: usize) -> S {
653        let num = 0.9 * (2.0 * max_iter as f64 + 1.0);
654        let den = 2.0 * max_iter as f64 + n_iter as f64;
655        S::from_f64(num / den)
656    }
657
658    /// Gustafsson predictive step-size factor (FIX 6).
659    ///
660    /// Returns the multiplier such that `h_new = h * factor`. With history,
661    /// the predictor multiplies the classical err^{-1/4} factor by
662    /// `(h_old/h)^{-1} * (err_old/err)^{1/4}` (capped at 1) to brake when
663    /// errors are trending up.
664    ///
665    /// References:
666    /// - Gustafsson (1991), "Control theoretic techniques for stepsize selection
667    ///   in explicit Runge-Kutta methods".
668    /// - Hairer-Wanner ODE II, §IV.8 (PI controller variant).
669    /// - SciPy `radau.py::predict_factor`.
670    fn predict_factor<S: Scalar>(
671        h_abs: S,
672        h_abs_old: Option<S>,
673        err_norm: S,
674        err_norm_old: Option<S>,
675    ) -> S {
676        let multiplier = match (h_abs_old, err_norm_old) {
677            (Some(h_old), Some(err_old)) if err_norm > S::ZERO && h_old > S::ZERO => {
678                (h_abs / h_old) * (err_old / err_norm).powf(S::from_f64(0.25))
679            }
680            _ => S::ONE,
681        };
682        multiplier.min(S::ONE) * err_norm.powf(S::from_f64(-0.25))
683    }
684
685    /// Newton iteration in transformed space.
686    ///
687    /// Solves the implicit Radau IIA stage equations using a simplified Newton
688    /// iteration on the transformed variables W = TI * Z. The 3×3 block system
689    /// decouples into one real solve (E1) and one complex solve (E2).
690    ///
691    /// For DAEs with mass matrix M, the residual involves M*Z; the algorithm
692    /// computes M*Z in the original space and then transforms.
693    ///
694    /// FIX 4: convergence-rate (theta) check now uses `newt >= 1` (matching
695    /// Hairer's NEWT > 1 in 1-based Fortran), not the original `newt > 1`.
696    #[allow(clippy::too_many_arguments)]
697    fn newton_iteration<S, Sys>(
698        problem: &Sys,
699        t: S,
700        h: S,
701        y: &[S],
702        scal: &[S],
703        z1: &mut [S],
704        z2: &mut [S],
705        z3: &mut [S],
706        w1: &mut [S],
707        w2: &mut [S],
708        w3: &mut [S],
709        cont: &mut [S],
710        lu_real: &LUFactorization<S>,
711        lu_complex: &LUFactorization<S>,
712        mass: Option<&[S]>,
713        stats: &mut SolverStats,
714        dim: usize,
715        options: &SolverOptions<S>,
716    ) -> Result<(bool, usize), SolverError>
717    where
718        S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
719        Sys: OdeSystem<S>,
720    {
721        let c1 = S::from_f64(coefficients::C1);
722        let c2 = S::from_f64(coefficients::C2);
723
724        // FNEWT: Newton tolerance, per Hairer's formula
725        //   FNEWT = max(10*UROUND/RTOL, min(0.03, sqrt(RTOL))).
726        let uround = S::from_f64(1e-16);
727        let fnewt = (S::from_f64(10.0) * uround / options.rtol)
728            .max(S::from_f64(0.03).min(options.rtol.sqrt()));
729
730        // Eigenvalue scalings
731        let fac1 = S::from_f64(coefficients::U1) / h;
732        let alphn = S::from_f64(coefficients::ALPH) / h;
733        let betan = S::from_f64(coefficients::BETA) / h;
734
735        // T and TI (used inside the loop)
736        let ti11 = S::from_f64(coefficients::TI11);
737        let ti12 = S::from_f64(coefficients::TI12);
738        let ti13 = S::from_f64(coefficients::TI13);
739        let ti21 = S::from_f64(coefficients::TI21);
740        let ti22 = S::from_f64(coefficients::TI22);
741        let ti23 = S::from_f64(coefficients::TI23);
742        let ti31 = S::from_f64(coefficients::TI31);
743        let ti32 = S::from_f64(coefficients::TI32);
744        let ti33 = S::from_f64(coefficients::TI33);
745
746        let t11 = S::from_f64(coefficients::T11);
747        let t12 = S::from_f64(coefficients::T12);
748        let t13 = S::from_f64(coefficients::T13);
749        let t21 = S::from_f64(coefficients::T21);
750        let t22 = S::from_f64(coefficients::T22);
751        let t23 = S::from_f64(coefficients::T23);
752        let t31 = S::from_f64(coefficients::T31);
753        let t32 = S::from_f64(coefficients::T32);
754        // T33 = 0
755
756        // State for convergence diagnostics
757        let mut dynold: S = uround;
758        let mut thqold: S = S::ONE;
759        let mut faccon: S = S::ONE;
760
761        let n3 = S::from_usize(3 * dim);
762
763        // Pre-allocated buffers (avoid per-iteration allocation)
764        let mut f2_temp = vec![S::ZERO; dim];
765        let mut f3_temp = vec![S::ZERO; dim];
766        let mut z1_orig = vec![S::ZERO; dim];
767        let mut z2_orig = vec![S::ZERO; dim];
768        let mut z3_orig = vec![S::ZERO; dim];
769        let mut mz1_buf = vec![S::ZERO; dim];
770        let mut mz2_buf = vec![S::ZERO; dim];
771        let mut mz3_buf = vec![S::ZERO; dim];
772        let mut rhs1 = vec![S::ZERO; dim];
773        let mut rhs2 = vec![S::ZERO; dim];
774        let mut rhs3 = vec![S::ZERO; dim];
775        let mut rhs_complex = vec![S::ZERO; 2 * dim];
776
777        for newt in 0..MAX_NEWTON_ITER {
778            // Stage RHS evaluations: Y_i = y + Z_i, F_i = f(t + C_i*h, Y_i)
779            for i in 0..dim {
780                cont[i] = y[i] + z1[i];
781            }
782            problem.rhs(t + c1 * h, cont, z1); // store F1 in z1 temporarily
783
784            for i in 0..dim {
785                cont[i] = y[i] + z2[i];
786            }
787            problem.rhs(t + c2 * h, cont, &mut f2_temp);
788
789            for i in 0..dim {
790                cont[i] = y[i] + z3[i];
791            }
792            problem.rhs(t + h, cont, &mut f3_temp);
793            stats.n_eval += 3;
794
795            // Recompute Z = T*W from W (we just clobbered z1 with F1).
796            for i in 0..dim {
797                z1_orig[i] = t11 * w1[i] + t12 * w2[i] + t13 * w3[i];
798                z2_orig[i] = t21 * w1[i] + t22 * w2[i] + t23 * w3[i];
799                z3_orig[i] = t31 * w1[i] + t32 * w2[i]; // T33 = 0
800            }
801
802            // M*Z for each stage. For identity mass M*Z = Z.
803            if let Some(m) = mass {
804                for i in 0..dim {
805                    mz1_buf[i] = S::ZERO;
806                    mz2_buf[i] = S::ZERO;
807                    mz3_buf[i] = S::ZERO;
808                }
809                for i in 0..dim {
810                    for j in 0..dim {
811                        let mij = m[i * dim + j];
812                        mz1_buf[i] = mz1_buf[i] + mij * z1_orig[j];
813                        mz2_buf[i] = mz2_buf[i] + mij * z2_orig[j];
814                        mz3_buf[i] = mz3_buf[i] + mij * z3_orig[j];
815                    }
816                }
817            } else {
818                mz1_buf.copy_from_slice(&z1_orig);
819                mz2_buf.copy_from_slice(&z2_orig);
820                mz3_buf.copy_from_slice(&z3_orig);
821            }
822
823            // Build transformed RHS.
824            //   RHS_real    = (TI*F)[0]  -  fac1  *  (TI*M*Z)[0]
825            //   RHS_complex = (TI*F)[1]  +  i*(TI*F)[2]
826            //               - (alphn+i*betan) * ((TI*M*Z)[1] + i*(TI*M*Z)[2])
827            for i in 0..dim {
828                let a1 = z1[i]; // F1 still here
829                let a2 = f2_temp[i];
830                let a3 = f3_temp[i];
831                let tf1 = ti11 * a1 + ti12 * a2 + ti13 * a3;
832                let tf2 = ti21 * a1 + ti22 * a2 + ti23 * a3;
833                let tf3 = ti31 * a1 + ti32 * a2 + ti33 * a3;
834
835                let tmz1 = ti11 * mz1_buf[i] + ti12 * mz2_buf[i] + ti13 * mz3_buf[i];
836                let tmz2 = ti21 * mz1_buf[i] + ti22 * mz2_buf[i] + ti23 * mz3_buf[i];
837                let tmz3 = ti31 * mz1_buf[i] + ti32 * mz2_buf[i] + ti33 * mz3_buf[i];
838
839                rhs1[i] = tf1 - fac1 * tmz1;
840                rhs2[i] = tf2 - alphn * tmz2 + betan * tmz3;
841                rhs3[i] = tf3 - alphn * tmz3 - betan * tmz2;
842            }
843
844            // Decoupled linear solves
845            let dw1 = lu_real.solve(&rhs1)?;
846
847            for i in 0..dim {
848                rhs_complex[i] = rhs2[i];
849                rhs_complex[dim + i] = rhs3[i];
850            }
851            let dw_complex = lu_complex.solve(&rhs_complex)?;
852
853            // DYNO: scaled RMS norm of correction (ΔW)
854            let mut dyno = S::ZERO;
855            for i in 0..dim {
856                let denom = scal[i];
857                dyno = dyno
858                    + (dw1[i] / denom) * (dw1[i] / denom)
859                    + (dw_complex[i] / denom) * (dw_complex[i] / denom)
860                    + (dw_complex[dim + i] / denom) * (dw_complex[dim + i] / denom);
861            }
862            dyno = (dyno / n3).sqrt();
863
864            // FIX 4: Convergence-rate check matches Hairer's NEWT > 1 in
865            // 1-based Fortran; in our 0-based loop, `newt >= 1` means
866            // "from the second iteration onwards".
867            if (1..MAX_NEWTON_ITER - 1).contains(&newt) {
868                let thq = dyno / dynold;
869                let theta = if newt == 1 {
870                    thq
871                } else {
872                    (thq * thqold).sqrt()
873                };
874                thqold = thq;
875
876                if theta < S::from_f64(0.99) {
877                    faccon = theta / (S::ONE - theta);
878                    let dyth =
879                        faccon * dyno * theta.powf(S::from_usize(MAX_NEWTON_ITER - 1 - newt))
880                            / fnewt;
881                    if dyth >= S::ONE {
882                        return Ok((false, newt + 1));
883                    }
884                } else {
885                    return Ok((false, newt + 1));
886                }
887            }
888            dynold = dyno.max(uround);
889
890            // Accumulate: W += dW (transformed space)
891            for i in 0..dim {
892                w1[i] = w1[i] + dw1[i];
893                w2[i] = w2[i] + dw_complex[i];
894                w3[i] = w3[i] + dw_complex[dim + i];
895            }
896
897            // Back-transform: Z = T * W
898            for i in 0..dim {
899                z1[i] = t11 * w1[i] + t12 * w2[i] + t13 * w3[i];
900                z2[i] = t21 * w1[i] + t22 * w2[i] + t23 * w3[i];
901                z3[i] = t31 * w1[i] + t32 * w2[i]; // T33 = 0
902            }
903
904            // Convergence test
905            if faccon * dyno <= fnewt {
906                return Ok((true, newt + 1));
907            }
908        }
909
910        Ok((false, MAX_NEWTON_ITER))
911    }
912
913    /// Hairer's ESTRAD error estimate with refinement on first/rejected steps.
914    ///
915    /// The error estimator solves
916    ///   (U1/h * M - J) * ê  =  f(t, y)  +  M (DD1*Z1 + DD2*Z2 + DD3*Z3) / h
917    /// and returns the scaled RMS norm of ê.
918    ///
919    /// FIX 1 (CRITICAL): the forcing term is f(t, y), the RHS, not y itself.
920    ///   The previous version used `y[i]` (and `M*y` in the mass-matrix branch),
921    ///   which is dimensionally wrong and made the estimator return ~|y|
922    ///   regardless of the actual local truncation error.
923    ///
924    /// FIX 1b: scale uses max(|y|, |y_new|), per Hairer's ESTRAD.
925    #[allow(clippy::too_many_arguments)]
926    fn error_estimate<S, Sys>(
927        problem: &Sys,
928        t: S,
929        f0: &[S],
930        z1: &[S],
931        z2: &[S],
932        z3: &[S],
933        y: &[S],
934        y_new: &[S],
935        h: S,
936        options: &SolverOptions<S>,
937        lu_real: &LUFactorization<S>,
938        err: &mut [S],
939        dim: usize,
940        first: bool,
941        reject: bool,
942        stats: &mut SolverStats,
943        mass: Option<&[S]>,
944    ) -> S
945    where
946        S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
947        Sys: OdeSystem<S>,
948    {
949        let dd1 = S::from_f64(coefficients::DD1);
950        let dd2 = S::from_f64(coefficients::DD2);
951        let dd3 = S::from_f64(coefficients::DD3);
952
953        // f2 = DD1*Z1 + DD2*Z2 + DD3*Z3 (the integrated stage residual)
954        let mut f2 = vec![S::ZERO; dim];
955        for i in 0..dim {
956            f2[i] = dd1 * z1[i] + dd2 * z2[i] + dd3 * z3[i];
957        }
958
959        let mut cont = vec![S::ZERO; dim];
960        if let Some(m) = mass {
961            // Mass-matrix case: cont = (M*f2)/h + f0
962            let mut mf2 = vec![S::ZERO; dim];
963            for i in 0..dim {
964                for j in 0..dim {
965                    mf2[i] = mf2[i] + m[i * dim + j] * f2[j];
966                }
967            }
968            for i in 0..dim {
969                cont[i] = mf2[i] / h + f0[i]; // FIX 1: f0, not M*y
970            }
971            // Save (M*f2)/h in f2 for the refinement branch below
972            for i in 0..dim {
973                f2[i] = mf2[i] / h;
974            }
975        } else {
976            // Identity mass: cont = f2/h + f0
977            for i in 0..dim {
978                f2[i] = f2[i] / h;
979                cont[i] = f2[i] + f0[i]; // FIX 1: f0, not y
980            }
981        }
982
983        // Solve E1 * ê = cont (E1 = U1/h*M - J already factored)
984        let solved = match lu_real.solve(&cont) {
985            Ok(s) => s,
986            Err(_) => return S::from_f64(1e6),
987        };
988
989        // FIX 1b: scale by max(|y|, |y_new|).
990        let mut err_norm = S::ZERO;
991        for i in 0..dim {
992            err[i] = solved[i];
993            let y_max = y[i].abs().max(y_new[i].abs());
994            let scale = options.atol + options.rtol * y_max;
995            let r = solved[i] / scale;
996            err_norm = err_norm + r * r;
997        }
998        let err_norm = (err_norm / S::from_usize(dim)).sqrt();
999        let err_norm = err_norm.max(S::from_f64(1e-10));
1000
1001        // Refinement: on first/rejected steps, if err >= 1, redo the solve
1002        // using f(t, y + ê) on the right-hand side. This typically halves the
1003        // estimate when transients are dominating.
1004        if err_norm >= S::ONE && (first || reject) {
1005            for i in 0..dim {
1006                cont[i] = y[i] + solved[i];
1007            }
1008            let mut f1 = vec![S::ZERO; dim];
1009            problem.rhs(t, &cont, &mut f1);
1010            stats.n_eval += 1;
1011
1012            for i in 0..dim {
1013                cont[i] = f1[i] + f2[i];
1014            }
1015            let solved2 = match lu_real.solve(&cont) {
1016                Ok(s) => s,
1017                Err(_) => return S::from_f64(1e6),
1018            };
1019
1020            let mut err_norm2 = S::ZERO;
1021            for i in 0..dim {
1022                err[i] = solved2[i];
1023                let y_max = y[i].abs().max(y_new[i].abs());
1024                let scale = options.atol + options.rtol * y_max;
1025                let r = solved2[i] / scale;
1026                err_norm2 = err_norm2 + r * r;
1027            }
1028            let err_norm2 = (err_norm2 / S::from_usize(dim)).sqrt();
1029            return err_norm2.max(S::from_f64(1e-10));
1030        }
1031
1032        err_norm
1033    }
1034}
1035
#[cfg(test)]
mod tests {
    //! Accuracy and regression tests for `Radau5` on ODEs and mass-matrix DAEs.

    use super::*;
    use crate::problem::{DaeProblem, OdeProblem};

    #[test]
    fn test_radau5_stiff_decay() {
        // y' = -100 y, y(0) = 1 => y(0.1) = e^{-10}.
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -100.0 * y[0];
            },
            0.0,
            0.1,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-2).atol(1e-4);
        let result = Radau5::solve(&problem, 0.0, 0.1, &[1.0], &options).unwrap();
        assert!(result.success);
        let y_final = result.y_final().unwrap();
        let exact = (-10.0_f64).exp();
        assert!(
            (y_final[0] - exact).abs() < 1e-4,
            "Error: {}",
            (y_final[0] - exact).abs()
        );
    }

    #[test]
    fn test_radau5_exponential() {
        // y' = y, y(0) = 1 => y(1) = e.
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[0];
            },
            0.0,
            1.0,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-6).atol(1e-8);
        let result = Radau5::solve(&problem, 0.0, 1.0, &[1.0], &options).unwrap();
        assert!(result.success);
        let y_final = result.y_final().unwrap();
        let exact = 1.0_f64.exp();
        assert!((y_final[0] - exact).abs() < 1e-5);
    }

    #[test]
    fn test_radau5_linear_2d() {
        // y' = A y with A = [[-1, 1], [-1, -1]] (eigenvalues -1 ± i).
        // Exact solution from y(0) = (1, 0): y(t) = e^{-t} (cos t, -sin t).
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0] + y[1];
                dydt[1] = -y[0] - y[1];
            },
            0.0,
            1.0,
            vec![1.0, 0.0],
        );
        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let result = Radau5::solve(&problem, 0.0, 1.0, &[1.0, 0.0], &options).unwrap();
        assert!(result.success);

        let yf = result.y_final().unwrap();
        let decay = (-1.0_f64).exp();
        let exact0 = decay * 1.0_f64.cos();
        let exact1 = -decay * 1.0_f64.sin();
        assert!(
            (yf[0] - exact0).abs() < 1e-3,
            "y0 deviated: got {}, expected {}",
            yf[0],
            exact0
        );
        assert!(
            (yf[1] - exact1).abs() < 1e-3,
            "y1 deviated: got {}, expected {}",
            yf[1],
            exact1
        );
    }

    #[test]
    fn test_radau5_van_der_pol_mild() {
        let mu = 10.0;
        let problem = OdeProblem::new(
            move |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = mu * (1.0 - y[0] * y[0]) * y[1] - y[0];
            },
            0.0,
            2.0,
            vec![2.0, 0.0],
        );
        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let result = Radau5::solve(&problem, 0.0, 2.0, &[2.0, 0.0], &options);
        assert!(
            result.is_ok(),
            "Van der Pol μ=10 failed: {:?}",
            result.err()
        );
        assert!(result.unwrap().success);
    }

    #[test]
    fn test_radau5_van_der_pol_stiff() {
        let mu = 100.0;
        let problem = OdeProblem::new(
            move |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = mu * (1.0 - y[0] * y[0]) * y[1] - y[0];
            },
            0.0,
            20.0,
            vec![2.0, 0.0],
        );
        let options = SolverOptions::default().rtol(1e-3).atol(1e-5);
        let result = Radau5::solve(&problem, 0.0, 20.0, &[2.0, 0.0], &options);
        assert!(
            result.is_ok(),
            "Van der Pol μ=100 failed: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_radau5_step_efficiency() {
        // After the corrections (Hairer's algorithm with Gustafsson controller,
        // extrapolated Newton initial guess, and the FIX 1 error-estimator
        // bugfix), this problem should accept ~15 steps — comparable to SciPy
        // (~11) and Hairer's reference. Before the fix, it took thousands.
        let mu = 100.0;
        let problem = OdeProblem::new(
            move |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = mu * (1.0 - y[0] * y[0]) * y[1] - y[0];
            },
            0.0,
            20.0,
            vec![2.0, 0.0],
        );
        let options = SolverOptions::default().rtol(1e-3).atol(1e-5);
        let result = Radau5::solve(&problem, 0.0, 20.0, &[2.0, 0.0], &options).unwrap();

        // Loose upper bound (would fail dramatically if any FIX regressed).
        assert!(
            result.stats.n_accept < 200,
            "Too many accepted steps: {} (expected < 200, ~15 typical)",
            result.stats.n_accept
        );
        assert!(result.success);
    }

    #[test]
    fn test_radau5_simple_dae() {
        // y1' = -y1 + y2       (differential equation)
        // 0   =  y1 - y2       (algebraic constraint: y2 = y1)
        // Mass matrix: diag(1, 0). Analytical solution: y1 = y2 = 1 (constant).
        let dae = DaeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0] + y[1];
                dydt[1] = y[0] - y[1];
            },
            |mass: &mut [f64]| {
                mass[..4].fill(0.0);
                mass[0] = 1.0;
            },
            0.0,
            1.0,
            vec![1.0, 1.0],
            vec![1],
        );

        let options = SolverOptions::default()
            .rtol(1e-4)
            .atol(1e-6)
            .max_steps(500_000);
        let result = Radau5::solve(&dae, 0.0, 1.0, &[1.0, 1.0], &options);

        assert!(result.is_ok(), "DAE solve failed: {:?}", result.err());
        let sol = result.unwrap();

        let yf = sol.y_final().unwrap();
        assert!(
            (yf[0] - 1.0).abs() < 1e-4,
            "y1 deviated: {} (expected 1.0)",
            yf[0]
        );
        assert!(
            (yf[1] - 1.0).abs() < 1e-4,
            "y2 deviated: {} (expected 1.0)",
            yf[1]
        );
        let constraint = yf[0] - yf[1];
        assert!(
            constraint.abs() < 1e-4,
            "Constraint violated: {} (y1={}, y2={})",
            constraint,
            yf[0],
            yf[1]
        );
    }

    #[test]
    fn test_radau5_dae_with_mass_identity() {
        // dy/dt = -y, y(0)=1 ⇒ y(t)=exp(-t). Identity mass via DaeProblem.
        let dae = DaeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            |mass: &mut [f64]| {
                mass[0] = 1.0;
            },
            0.0,
            1.0,
            vec![1.0],
            vec![],
        );

        let options = SolverOptions::default().rtol(1e-6).atol(1e-8);
        let result = Radau5::solve(&dae, 0.0, 1.0, &[1.0], &options);

        assert!(
            result.is_ok(),
            "DAE with identity mass failed: {:?}",
            result.err()
        );
        let sol = result.unwrap();
        let yf = sol.y_final().unwrap();
        let exact = (-1.0_f64).exp();
        assert!(
            (yf[0] - exact).abs() < 1e-5,
            "Error: {} (expected {}, got {})",
            (yf[0] - exact).abs(),
            exact,
            yf[0]
        );
    }

    #[test]
    fn test_radau5_dae_scaled_mass() {
        // 2*y' = -y ⇒ y(t) = exp(-t/2).
        let dae = DaeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            |mass: &mut [f64]| {
                mass[0] = 2.0;
            },
            0.0,
            1.0,
            vec![1.0],
            vec![],
        );

        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let result = Radau5::solve(&dae, 0.0, 1.0, &[1.0], &options);

        assert!(
            result.is_ok(),
            "DAE with scaled mass failed: {:?}",
            result.err()
        );
        let sol = result.unwrap();
        let yf = sol.y_final().unwrap();
        let exact = (-0.5_f64).exp();
        assert!(
            (yf[0] - exact).abs() < 1e-3,
            "Error: {} (expected {}, got {})",
            (yf[0] - exact).abs(),
            exact,
            yf[0]
        );
    }
}