Skip to main content

oxiphysics_core/
numerical_ode.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Numerical ODE solvers with event detection and dense output.
5//!
6//! Provides classic RK4, adaptive Dormand-Prince RK45 (with FSAL), implicit
7//! Euler, Crank-Nicolson (trapezoidal), BDF2 for stiff systems, zero-crossing
8//! event detection, and trajectory storage with interpolation.
9
10#![allow(dead_code)]
11#![allow(clippy::too_many_arguments)]
12
13use std::f64::consts::PI;
14
15// ─────────────────────────────────────────────────────────────────────────────
16// OdeState
17// ─────────────────────────────────────────────────────────────────────────────
18
/// A snapshot of an ODE system: a time `t` paired with the state vector `y(t)`.
///
/// Besides holding the data, the type offers a few small helpers: the
/// Euclidean norm, the dimension, an all-zero constructor, and linear
/// blending between two snapshots.
#[derive(Debug, Clone, PartialEq)]
pub struct OdeState {
    /// Time at which the state is sampled.
    pub t: f64,
    /// Components of the state vector at time `t`.
    pub y: Vec<f64>,
}

impl OdeState {
    /// Build an [`OdeState`] from a time `t` and an owned state vector `y`.
    pub fn new(t: f64, y: Vec<f64>) -> Self {
        OdeState { t, y }
    }

    /// Euclidean (L2) norm of the state vector.
    pub fn norm(&self) -> f64 {
        let sum_of_squares: f64 = self.y.iter().map(|c| c * c).sum();
        sum_of_squares.sqrt()
    }

    /// Number of components stored in `y`.
    pub fn dim(&self) -> usize {
        self.y.len()
    }

    /// State of dimension `n` at time `t` whose components are all zero.
    pub fn zeros(t: f64, n: usize) -> Self {
        Self::new(t, vec![0.0; n])
    }

    /// Blend `self` towards `other` by the factor `alpha` ∈ \[0, 1\].
    ///
    /// Both the time and every component are interpolated linearly.  If the
    /// two state vectors differ in length, only the common leading components
    /// are produced.
    pub fn lerp(&self, other: &OdeState, alpha: f64) -> OdeState {
        let blend = |from: f64, to: f64| from + alpha * (to - from);

        let mut y = Vec::with_capacity(self.y.len().min(other.y.len()));
        for (&a, &b) in self.y.iter().zip(&other.y) {
            y.push(blend(a, b));
        }

        OdeState {
            t: blend(self.t, other.t),
            y,
        }
    }
}
65
66// ─────────────────────────────────────────────────────────────────────────────
67// Helper: vector arithmetic
68// ─────────────────────────────────────────────────────────────────────────────
69
/// Element-wise `a * x + y`; the result is as long as the shorter input.
#[inline]
fn vec_axpy(a: f64, x: &[f64], y: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(x.len().min(y.len()));
    for (&xi, &yi) in x.iter().zip(y) {
        out.push(a * xi + yi);
    }
    out
}
74
/// Element-wise `a * x`, returned as a new vector.
#[inline]
fn vec_scale(a: f64, x: &[f64]) -> Vec<f64> {
    let mut scaled = x.to_vec();
    for value in scaled.iter_mut() {
        *value = a * *value;
    }
    scaled
}
79
/// Element-wise sum `x + y` over the common leading length of the inputs.
#[inline]
fn vec_add(x: &[f64], y: &[f64]) -> Vec<f64> {
    let common = x.len().min(y.len());
    (0..common).map(|i| x[i] + y[i]).collect()
}
84
/// Element-wise difference `x - y` over the common leading length of the inputs.
#[inline]
fn vec_sub(x: &[f64], y: &[f64]) -> Vec<f64> {
    let mut diff = Vec::with_capacity(x.len().min(y.len()));
    for (&xi, &yi) in x.iter().zip(y) {
        diff.push(xi - yi);
    }
    diff
}
89
/// Root-mean-square of the entries of `v`; an empty slice yields `0.0`.
#[inline]
fn rms_norm(v: &[f64]) -> f64 {
    match v.len() {
        0 => 0.0,
        count => {
            let sum_of_squares: f64 = v.iter().map(|x| x * x).sum();
            (sum_of_squares / count as f64).sqrt()
        }
    }
}
97
98// ─────────────────────────────────────────────────────────────────────────────
99// RK4Integrator
100// ─────────────────────────────────────────────────────────────────────────────
101
102/// Classic 4th-order Runge-Kutta integrator with optional adaptive step size.
103///
104/// In fixed-step mode each call to [`RK4Integrator::step`] advances the state
105/// by exactly `dt`.  The adaptive driver [`RK4Integrator::integrate_adaptive`]
106/// doubles or halves the step based on an embedded 2nd-order (midpoint) error
107/// estimate.
108pub struct RK4Integrator {
109    /// Absolute tolerance used by the adaptive driver.
110    pub atol: f64,
111    /// Relative tolerance used by the adaptive driver.
112    pub rtol: f64,
113}
114
115impl RK4Integrator {
116    /// Construct an integrator with the given tolerances.
117    pub fn new(atol: f64, rtol: f64) -> Self {
118        Self { atol, rtol }
119    }
120
121    /// Construct with default tolerances (1e-6 absolute, 1e-6 relative).
122    pub fn default_tolerances() -> Self {
123        Self {
124            atol: 1e-6,
125            rtol: 1e-6,
126        }
127    }
128
129    /// Perform one fixed RK4 step from state `s` using step size `dt`.
130    ///
131    /// `f(t, y)` is the right-hand side of `dy/dt = f(t, y)`.
132    pub fn step<F>(&self, s: &OdeState, dt: f64, f: &F) -> OdeState
133    where
134        F: Fn(f64, &[f64]) -> Vec<f64>,
135    {
136        let t = s.t;
137        let y = &s.y;
138        let k1 = f(t, y);
139        let y2 = vec_axpy(0.5 * dt, &k1, y);
140        let k2 = f(t + 0.5 * dt, &y2);
141        let y3 = vec_axpy(0.5 * dt, &k2, y);
142        let k3 = f(t + 0.5 * dt, &y3);
143        let y4 = vec_axpy(dt, &k3, y);
144        let k4 = f(t + dt, &y4);
145
146        let n = y.len();
147        let y_new: Vec<f64> = (0..n)
148            .map(|i| y[i] + (dt / 6.0) * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]))
149            .collect();
150        OdeState::new(t + dt, y_new)
151    }
152
153    /// Integrate from `s0` to time `t_end` using a fixed step `dt`.
154    ///
155    /// Returns all intermediate states including the initial state.
156    pub fn integrate<F>(&self, s0: &OdeState, t_end: f64, dt: f64, f: &F) -> Vec<OdeState>
157    where
158        F: Fn(f64, &[f64]) -> Vec<f64>,
159    {
160        let mut states = vec![s0.clone()];
161        let mut s = s0.clone();
162        while s.t < t_end - 1e-14 {
163            let h = dt.min(t_end - s.t);
164            s = self.step(&s, h, f);
165            states.push(s.clone());
166        }
167        states
168    }
169
170    /// Integrate from `s0` to `t_end` with adaptive step size.
171    ///
172    /// Uses an embedded RK2 error estimate to control the step.  The step is
173    /// accepted when the RMS error is below `atol + rtol * ||y||`.
174    pub fn integrate_adaptive<F>(
175        &self,
176        s0: &OdeState,
177        t_end: f64,
178        dt_init: f64,
179        f: &F,
180    ) -> Vec<OdeState>
181    where
182        F: Fn(f64, &[f64]) -> Vec<f64>,
183    {
184        let mut states = vec![s0.clone()];
185        let mut s = s0.clone();
186        let mut dt = dt_init;
187        let dt_min = 1e-12;
188        let dt_max = t_end - s0.t;
189
190        while s.t < t_end - 1e-14 {
191            let h = dt.min(t_end - s.t).max(dt_min);
192            let s_rk4 = self.step(&s, h, f);
193
194            // Embedded RK2 (midpoint) for error estimate
195            let k1 = f(s.t, &s.y);
196            let y_mid = vec_axpy(0.5 * h, &k1, &s.y);
197            let k2 = f(s.t + 0.5 * h, &y_mid);
198            let y_rk2: Vec<f64> =
199                s.y.iter()
200                    .zip(k2.iter())
201                    .map(|(yi, ki)| yi + h * ki)
202                    .collect();
203
204            let err: Vec<f64> = s_rk4
205                .y
206                .iter()
207                .zip(y_rk2.iter())
208                .map(|(a, b)| a - b)
209                .collect();
210            let tol = self.atol + self.rtol * s_rk4.norm();
211            let e = rms_norm(&err);
212
213            if e <= tol || h <= dt_min {
214                s = s_rk4;
215                states.push(s.clone());
216                // Increase step
217                if e > 0.0 {
218                    dt = (h * (tol / e).powf(0.2)).min(dt_max);
219                } else {
220                    dt = (h * 2.0).min(dt_max);
221                }
222            } else {
223                // Reject, reduce step
224                dt = (h * 0.9 * (tol / e).powf(0.25)).max(dt_min);
225            }
226        }
227        states
228    }
229}
230
231// ─────────────────────────────────────────────────────────────────────────────
232// DormandPrince45 (RK45 with FSAL)
233// ─────────────────────────────────────────────────────────────────────────────
234
235/// Dormand-Prince RK45 adaptive integrator with FSAL (First Same As Last).
236///
237/// This is the method underlying MATLAB's `ode45` and SciPy's `RK45`.  The
238/// 5th-order solution is used to advance the state; the 4th-order solution is
239/// used only for error estimation.  The FSAL property means only 5 new
240/// function evaluations per accepted step (the last stage of step n equals the
241/// first stage of step n+1).
242pub struct DormandPrince45 {
243    /// Absolute error tolerance.
244    pub atol: f64,
245    /// Relative error tolerance.
246    pub rtol: f64,
247    /// Minimum allowed step size.
248    pub dt_min: f64,
249    /// Maximum allowed step size.
250    pub dt_max: f64,
251}
252
253impl DormandPrince45 {
254    // Butcher tableau coefficients
255    const C2: f64 = 1.0 / 5.0;
256    const C3: f64 = 3.0 / 10.0;
257    const C4: f64 = 4.0 / 5.0;
258    const C5: f64 = 8.0 / 9.0;
259    const A21: f64 = 1.0 / 5.0;
260    const A31: f64 = 3.0 / 40.0;
261    const A32: f64 = 9.0 / 40.0;
262    const A41: f64 = 44.0 / 45.0;
263    const A42: f64 = -56.0 / 15.0;
264    const A43: f64 = 32.0 / 9.0;
265    const A51: f64 = 19372.0 / 6561.0;
266    const A52: f64 = -25360.0 / 2187.0;
267    const A53: f64 = 64448.0 / 6561.0;
268    const A54: f64 = -212.0 / 729.0;
269    const A61: f64 = 9017.0 / 3168.0;
270    const A62: f64 = -355.0 / 33.0;
271    const A63: f64 = 46732.0 / 5247.0;
272    const A64: f64 = 49.0 / 176.0;
273    const A65: f64 = -5103.0 / 18656.0;
274
275    // 5th-order weights
276    const B1: f64 = 35.0 / 384.0;
277    const B3: f64 = 500.0 / 1113.0;
278    const B4: f64 = 125.0 / 192.0;
279    const B5: f64 = -2187.0 / 6784.0;
280    const B6: f64 = 11.0 / 84.0;
281
282    // Error coefficients (difference of 5th and 4th order weights)
283    const E1: f64 = 71.0 / 57600.0;
284    const E3: f64 = -71.0 / 16695.0;
285    const E4: f64 = 71.0 / 1920.0;
286    const E5: f64 = -17253.0 / 339200.0;
287    const E6: f64 = 22.0 / 525.0;
288    const E7: f64 = -1.0 / 40.0;
289
290    /// Construct with specified tolerances and step bounds.
291    pub fn new(atol: f64, rtol: f64, dt_min: f64, dt_max: f64) -> Self {
292        Self {
293            atol,
294            rtol,
295            dt_min,
296            dt_max,
297        }
298    }
299
300    /// Construct with default tolerances (1e-6 / 1e-6).
301    pub fn default_tolerances() -> Self {
302        Self {
303            atol: 1e-6,
304            rtol: 1e-6,
305            dt_min: 1e-12,
306            dt_max: f64::INFINITY,
307        }
308    }
309
310    /// Perform one FSAL Dormand-Prince step from state `s` with step `h`.
311    ///
312    /// Returns `(new_state, error_rms, k7)` where `k7` is the last stage
313    /// (reusable as the first stage of the next step via FSAL).
314    pub fn step<F>(
315        &self,
316        s: &OdeState,
317        h: f64,
318        f: &F,
319        k1_in: Option<&Vec<f64>>,
320    ) -> (OdeState, f64, Vec<f64>)
321    where
322        F: Fn(f64, &[f64]) -> Vec<f64>,
323    {
324        let t = s.t;
325        let y = &s.y;
326        let n = y.len();
327
328        let k1 = match k1_in {
329            Some(k) => k.clone(),
330            None => f(t, y),
331        };
332
333        let y2: Vec<f64> = (0..n).map(|i| y[i] + h * Self::A21 * k1[i]).collect();
334        let k2 = f(t + Self::C2 * h, &y2);
335
336        let y3: Vec<f64> = (0..n)
337            .map(|i| y[i] + h * (Self::A31 * k1[i] + Self::A32 * k2[i]))
338            .collect();
339        let k3 = f(t + Self::C3 * h, &y3);
340
341        let y4: Vec<f64> = (0..n)
342            .map(|i| y[i] + h * (Self::A41 * k1[i] + Self::A42 * k2[i] + Self::A43 * k3[i]))
343            .collect();
344        let k4 = f(t + Self::C4 * h, &y4);
345
346        let y5: Vec<f64> = (0..n)
347            .map(|i| {
348                y[i] + h
349                    * (Self::A51 * k1[i]
350                        + Self::A52 * k2[i]
351                        + Self::A53 * k3[i]
352                        + Self::A54 * k4[i])
353            })
354            .collect();
355        let k5 = f(t + Self::C5 * h, &y5);
356
357        let y6: Vec<f64> = (0..n)
358            .map(|i| {
359                y[i] + h
360                    * (Self::A61 * k1[i]
361                        + Self::A62 * k2[i]
362                        + Self::A63 * k3[i]
363                        + Self::A64 * k4[i]
364                        + Self::A65 * k5[i])
365            })
366            .collect();
367        let k6 = f(t + h, &y6);
368
369        let y_new: Vec<f64> = (0..n)
370            .map(|i| {
371                y[i] + h
372                    * (Self::B1 * k1[i]
373                        + Self::B3 * k3[i]
374                        + Self::B4 * k4[i]
375                        + Self::B5 * k5[i]
376                        + Self::B6 * k6[i])
377            })
378            .collect();
379        let k7 = f(t + h, &y_new);
380
381        // Error estimate
382        let err: Vec<f64> = (0..n)
383            .map(|i| {
384                h * (Self::E1 * k1[i]
385                    + Self::E3 * k3[i]
386                    + Self::E4 * k4[i]
387                    + Self::E5 * k5[i]
388                    + Self::E6 * k6[i]
389                    + Self::E7 * k7[i])
390            })
391            .collect();
392
393        let sc: Vec<f64> = y_new
394            .iter()
395            .zip(y.iter())
396            .map(|(yn, y0)| self.atol + self.rtol * yn.abs().max(y0.abs()))
397            .collect();
398        let err_norm = (err
399            .iter()
400            .zip(sc.iter())
401            .map(|(e, s)| (e / s).powi(2))
402            .sum::<f64>()
403            / n as f64)
404            .sqrt();
405
406        (OdeState::new(t + h, y_new), err_norm, k7)
407    }
408
409    /// Integrate from `s0` to `t_end` with adaptive step control.
410    ///
411    /// Returns an [`OdeSolution`] containing all accepted states.
412    pub fn integrate<F>(&self, s0: &OdeState, t_end: f64, dt_init: f64, f: &F) -> OdeSolution
413    where
414        F: Fn(f64, &[f64]) -> Vec<f64>,
415    {
416        let mut states = vec![s0.clone()];
417        let mut s = s0.clone();
418        let mut h = dt_init;
419        let mut k1 = f(s.t, &s.y);
420        let max_steps = 1_000_000usize;
421        let mut n_steps = 0;
422
423        while s.t < t_end - 1e-14 && n_steps < max_steps {
424            h = h.min(t_end - s.t).max(self.dt_min).min(self.dt_max);
425            let (s_new, err, k7) = self.step(&s, h, f, Some(&k1));
426
427            if err <= 1.0 || h <= self.dt_min {
428                s = s_new;
429                k1 = k7; // FSAL
430                states.push(s.clone());
431                // PI step size control
432                if err > 0.0 {
433                    h = (h * 0.9 * err.powf(-0.2)).min(self.dt_max).max(self.dt_min);
434                } else {
435                    h = (h * 5.0).min(self.dt_max);
436                }
437            } else {
438                h = (h * 0.9 * err.powf(-0.25)).max(self.dt_min);
439            }
440            n_steps += 1;
441        }
442
443        OdeSolution::new(states)
444    }
445}
446
447// ─────────────────────────────────────────────────────────────────────────────
448// ImplicitEuler (backward Euler)
449// ─────────────────────────────────────────────────────────────────────────────
450
451/// Implicit (backward) Euler integrator for stiff ODEs.
452///
453/// Solves the nonlinear system `y_{n+1} - y_n - h * f(t_{n+1}, y_{n+1}) = 0`
454/// using fixed-point (Picard) iteration followed by Newton correction steps.
455pub struct ImplicitEuler {
456    /// Maximum Newton iterations per step.
457    pub max_iter: usize,
458    /// Convergence tolerance for Newton iteration.
459    pub tol: f64,
460    /// Finite-difference step size for the Jacobian.
461    pub fd_eps: f64,
462}
463
464impl ImplicitEuler {
465    /// Construct with specified parameters.
466    pub fn new(max_iter: usize, tol: f64, fd_eps: f64) -> Self {
467        Self {
468            max_iter,
469            tol,
470            fd_eps,
471        }
472    }
473
474    /// Construct with default parameters.
475    pub fn default_params() -> Self {
476        Self {
477            max_iter: 50,
478            tol: 1e-10,
479            fd_eps: 1e-7,
480        }
481    }
482
483    /// Perform one backward Euler step from state `s` with step `h`.
484    ///
485    /// Uses simple fixed-point / Picard iteration (no Jacobian required).
486    /// For strongly stiff systems prefer Newton iteration via [`ImplicitEuler::step_newton`].
487    pub fn step<F>(&self, s: &OdeState, h: f64, f: &F) -> OdeState
488    where
489        F: Fn(f64, &[f64]) -> Vec<f64>,
490    {
491        let t_new = s.t + h;
492        let mut y = s.y.clone();
493
494        for _ in 0..self.max_iter {
495            let rhs = f(t_new, &y);
496            let y_new: Vec<f64> =
497                s.y.iter()
498                    .zip(rhs.iter())
499                    .map(|(y0, r)| y0 + h * r)
500                    .collect();
501            let diff = rms_norm(&vec_sub(&y_new, &y));
502            y = y_new;
503            if diff < self.tol {
504                break;
505            }
506        }
507
508        OdeState::new(t_new, y)
509    }
510
511    /// Perform one backward Euler step using finite-difference Newton iteration.
512    ///
513    /// More robust than fixed-point for stiff problems.  Approximates the
514    /// Jacobian column-by-column with forward differences.
515    pub fn step_newton<F>(&self, s: &OdeState, h: f64, f: &F) -> OdeState
516    where
517        F: Fn(f64, &[f64]) -> Vec<f64>,
518    {
519        let t_new = s.t + h;
520        let n = s.y.len();
521        let mut y = s.y.clone();
522
523        for _ in 0..self.max_iter {
524            let fy = f(t_new, &y);
525            // Residual g(y) = y - y_n - h*f(t_new, y)
526            let g: Vec<f64> = (0..n).map(|i| y[i] - s.y[i] - h * fy[i]).collect();
527
528            let g_norm = rms_norm(&g);
529            if g_norm < self.tol {
530                break;
531            }
532
533            // Build approximate Jacobian dg/dy = I - h * df/dy via FD
534            // For simplicity use a diagonal approximation
535            let mut jac_diag = vec![1.0f64; n];
536            for j in 0..n {
537                let mut yp = y.clone();
538                yp[j] += self.fd_eps;
539                let fyp = f(t_new, &yp);
540                jac_diag[j] = 1.0 - h * (fyp[j] - fy[j]) / self.fd_eps;
541                if jac_diag[j].abs() < 1e-14 {
542                    jac_diag[j] = 1.0;
543                }
544            }
545
546            // Newton update: y <- y - J^{-1} g (diagonal approx)
547            for i in 0..n {
548                y[i] -= g[i] / jac_diag[i];
549            }
550        }
551
552        OdeState::new(t_new, y)
553    }
554
555    /// Integrate from `s0` to `t_end` with fixed step `dt`.
556    ///
557    /// Uses Newton iteration per step for robust handling of stiff problems.
558    pub fn integrate<F>(&self, s0: &OdeState, t_end: f64, dt: f64, f: &F) -> Vec<OdeState>
559    where
560        F: Fn(f64, &[f64]) -> Vec<f64>,
561    {
562        let mut states = vec![s0.clone()];
563        let mut s = s0.clone();
564        while s.t < t_end - 1e-14 {
565            let h = dt.min(t_end - s.t);
566            s = self.step_newton(&s, h, f);
567            states.push(s.clone());
568        }
569        states
570    }
571}
572
573// ─────────────────────────────────────────────────────────────────────────────
574// Trapezoidal (Crank-Nicolson)
575// ─────────────────────────────────────────────────────────────────────────────
576
577/// Trapezoidal (Crank-Nicolson) integrator — second-order accurate, A-stable.
578///
579/// Solves `y_{n+1} = y_n + h/2 * [f(t_n, y_n) + f(t_{n+1}, y_{n+1})]`
580/// iteratively via fixed-point iteration.
581pub struct Trapezoidal {
582    /// Maximum iterations per step.
583    pub max_iter: usize,
584    /// Convergence tolerance.
585    pub tol: f64,
586}
587
588impl Trapezoidal {
589    /// Construct with given parameters.
590    pub fn new(max_iter: usize, tol: f64) -> Self {
591        Self { max_iter, tol }
592    }
593
594    /// Construct with default parameters.
595    pub fn default_params() -> Self {
596        Self {
597            max_iter: 50,
598            tol: 1e-10,
599        }
600    }
601
602    /// Perform one Crank-Nicolson step from state `s` with step `h`.
603    pub fn step<F>(&self, s: &OdeState, h: f64, f: &F) -> OdeState
604    where
605        F: Fn(f64, &[f64]) -> Vec<f64>,
606    {
607        let t_new = s.t + h;
608        let f0 = f(s.t, &s.y);
609        // Predictor: explicit Euler
610        let mut y = vec_axpy(h, &f0, &s.y);
611
612        for _ in 0..self.max_iter {
613            let f1 = f(t_new, &y);
614            let y_new: Vec<f64> = (0..s.y.len())
615                .map(|i| s.y[i] + 0.5 * h * (f0[i] + f1[i]))
616                .collect();
617            let diff = rms_norm(&vec_sub(&y_new, &y));
618            y = y_new;
619            if diff < self.tol {
620                break;
621            }
622        }
623
624        OdeState::new(t_new, y)
625    }
626
627    /// Integrate from `s0` to `t_end` with fixed step `dt`.
628    pub fn integrate<F>(&self, s0: &OdeState, t_end: f64, dt: f64, f: &F) -> Vec<OdeState>
629    where
630        F: Fn(f64, &[f64]) -> Vec<f64>,
631    {
632        let mut states = vec![s0.clone()];
633        let mut s = s0.clone();
634        while s.t < t_end - 1e-14 {
635            let h = dt.min(t_end - s.t);
636            s = self.step(&s, h, f);
637            states.push(s.clone());
638        }
639        states
640    }
641}
642
643// ─────────────────────────────────────────────────────────────────────────────
644// BDF2 — 2nd-order backward differentiation formula
645// ─────────────────────────────────────────────────────────────────────────────
646
647/// Second-order Backward Differentiation Formula (BDF2) for stiff ODEs.
648///
649/// Uses the formula `(3/2) y_{n+1} - 2 y_n + (1/2) y_{n-1} = h * f(t_{n+1}, y_{n+1})`.
650/// The first step is taken with implicit Euler to obtain `y_1`.
651pub struct BDF2 {
652    /// Maximum iterations per step.
653    pub max_iter: usize,
654    /// Convergence tolerance for fixed-point iteration.
655    pub tol: f64,
656}
657
658impl BDF2 {
659    /// Construct with given parameters.
660    pub fn new(max_iter: usize, tol: f64) -> Self {
661        Self { max_iter, tol }
662    }
663
664    /// Construct with default parameters.
665    pub fn default_params() -> Self {
666        Self {
667            max_iter: 50,
668            tol: 1e-10,
669        }
670    }
671
672    /// Perform one BDF2 step given `y_n` (`s_curr`) and `y_{n-1}` (`s_prev`).
673    ///
674    /// Step size `h` must be the same as the previous step (constant step BDF2).
675    /// Uses Newton iteration with a diagonal finite-difference Jacobian for
676    /// robust handling of stiff problems.
677    pub fn step<F>(&self, s_curr: &OdeState, s_prev: &OdeState, h: f64, f: &F) -> OdeState
678    where
679        F: Fn(f64, &[f64]) -> Vec<f64>,
680    {
681        let t_new = s_curr.t + h;
682        let n = s_curr.y.len();
683        // Predictor: linear extrapolation
684        let mut y: Vec<f64> = (0..n).map(|i| 2.0 * s_curr.y[i] - s_prev.y[i]).collect();
685        let fd_eps = 1e-7_f64;
686
687        for _ in 0..self.max_iter {
688            let fy = f(t_new, &y);
689            // Residual: g(y) = (3/2)y - 2*y_n + (1/2)*y_{n-1} - h*f(y)
690            let g: Vec<f64> = (0..n)
691                .map(|i| 1.5 * y[i] - 2.0 * s_curr.y[i] + 0.5 * s_prev.y[i] - h * fy[i])
692                .collect();
693            let g_norm = rms_norm(&g);
694            if g_norm < self.tol {
695                break;
696            }
697            // Diagonal Jacobian: dg_i/dy_i = 3/2 - h * df_i/dy_i (FD)
698            let mut jac_diag = vec![1.5_f64; n];
699            for j in 0..n {
700                let mut yp = y.clone();
701                yp[j] += fd_eps;
702                let fyp = f(t_new, &yp);
703                jac_diag[j] = 1.5 - h * (fyp[j] - fy[j]) / fd_eps;
704                if jac_diag[j].abs() < 1e-14 {
705                    jac_diag[j] = 1.5;
706                }
707            }
708            // Newton update
709            for i in 0..n {
710                y[i] -= g[i] / jac_diag[i];
711            }
712        }
713
714        OdeState::new(t_new, y)
715    }
716
717    /// Integrate from `s0` to `t_end` with fixed step `dt`.
718    ///
719    /// The first step uses implicit Euler; subsequent steps use BDF2.
720    pub fn integrate<F>(&self, s0: &OdeState, t_end: f64, dt: f64, f: &F) -> Vec<OdeState>
721    where
722        F: Fn(f64, &[f64]) -> Vec<f64>,
723    {
724        if s0.t >= t_end - 1e-14 {
725            return vec![s0.clone()];
726        }
727        let ie = ImplicitEuler::new(self.max_iter, self.tol, 1e-7);
728        let h = dt.min(t_end - s0.t);
729        let s1 = ie.step_newton(s0, h, f);
730        let mut states = vec![s0.clone(), s1.clone()];
731        let mut s_prev = s0.clone();
732        let mut s_curr = s1;
733
734        while s_curr.t < t_end - 1e-14 {
735            let step = dt.min(t_end - s_curr.t);
736            let s_next = self.step(&s_curr, &s_prev, step, f);
737            states.push(s_next.clone());
738            s_prev = s_curr;
739            s_curr = s_next;
740        }
741        states
742    }
743}
744
745// ─────────────────────────────────────────────────────────────────────────────
746// EventDetection
747// ─────────────────────────────────────────────────────────────────────────────
748
/// Record of one zero crossing of an event function found during integration.
#[derive(Clone, Debug)]
pub struct CrossingEvent {
    /// Time at which the event function crossed zero.
    pub t: f64,
    /// State vector at the crossing time, obtained by interpolation.
    pub y: Vec<f64>,
    /// Sign of the event function at the start of the searched interval
    /// (-1.0 or +1.0).
    pub sign_before: f64,
    /// Position of the triggering event function in the list passed to the detector.
    pub event_index: usize,
}
761
762/// Zero-crossing event detection via bisection root-finding.
763///
764/// After each ODE step, each registered event function `g_i(t, y)` is
765/// evaluated.  If the sign of `g_i` changes, bisection is used to locate the
766/// crossing time to within `tol`.
767pub struct EventDetection {
768    /// Bisection tolerance for the event time.
769    pub tol: f64,
770    /// Maximum bisection iterations.
771    pub max_iter: usize,
772}
773
774impl EventDetection {
775    /// Construct with specified tolerance and iteration limit.
776    pub fn new(tol: f64, max_iter: usize) -> Self {
777        Self { tol, max_iter }
778    }
779
780    /// Construct with default parameters.
781    pub fn default_params() -> Self {
782        Self {
783            tol: 1e-10,
784            max_iter: 50,
785        }
786    }
787
788    /// Detect zero crossings of `events[i](t, y)` between states `s_a` and `s_b`.
789    ///
790    /// Returns a list of [`CrossingEvent`] sorted by time.  Uses linear
791    /// interpolation of the state and bisection on each event function.
792    pub fn detect<E>(&self, s_a: &OdeState, s_b: &OdeState, events: &[E]) -> Vec<CrossingEvent>
793    where
794        E: Fn(f64, &[f64]) -> f64,
795    {
796        let mut crossings = Vec::new();
797
798        for (idx, evt) in events.iter().enumerate() {
799            let ga = evt(s_a.t, &s_a.y);
800            let gb = evt(s_b.t, &s_b.y);
801            if ga * gb > 0.0 {
802                continue; // no sign change
803            }
804
805            // Bisect in [ta, tb]
806            let mut lo = 0.0f64;
807            let mut hi = 1.0f64;
808            let ga_sign = ga.signum();
809
810            for _ in 0..self.max_iter {
811                let mid = 0.5 * (lo + hi);
812                let s_mid = s_a.lerp(s_b, mid);
813                let gm = evt(s_mid.t, &s_mid.y);
814                if gm.signum() == ga_sign {
815                    lo = mid;
816                } else {
817                    hi = mid;
818                }
819                if hi - lo < self.tol {
820                    break;
821                }
822            }
823
824            let alpha = 0.5 * (lo + hi);
825            let s_cross = s_a.lerp(s_b, alpha);
826            crossings.push(CrossingEvent {
827                t: s_cross.t,
828                y: s_cross.y,
829                sign_before: ga_sign,
830                event_index: idx,
831            });
832        }
833
834        crossings.sort_by(|a, b| a.t.partial_cmp(&b.t).unwrap_or(std::cmp::Ordering::Equal));
835        crossings
836    }
837}
838
839// ─────────────────────────────────────────────────────────────────────────────
840// OdeSolution — trajectory storage with interpolation
841// ─────────────────────────────────────────────────────────────────────────────
842
843/// Stored ODE trajectory with dense output via linear interpolation.
844///
845/// Holds all accepted integration steps and provides interpolation at
846/// arbitrary times within the integration interval.
847#[derive(Debug, Clone)]
848pub struct OdeSolution {
849    /// Ordered sequence of states (by increasing time).
850    pub states: Vec<OdeState>,
851}
852
853impl OdeSolution {
854    /// Construct from a vector of states (assumed sorted by time).
855    pub fn new(states: Vec<OdeState>) -> Self {
856        Self { states }
857    }
858
859    /// Number of stored states.
860    pub fn len(&self) -> usize {
861        self.states.len()
862    }
863
864    /// Return `true` if no states are stored.
865    pub fn is_empty(&self) -> bool {
866        self.states.is_empty()
867    }
868
869    /// Interpolate the state at time `t` using linear (dense output) interpolation.
870    ///
871    /// Returns `None` if `t` is outside the stored time interval.
872    pub fn interpolate(&self, t: f64) -> Option<OdeState> {
873        if self.states.is_empty() {
874            return None;
875        }
876        let t0 = self.states.first()?.t;
877        let t1 = self.states.last()?.t;
878        if t < t0 - 1e-14 || t > t1 + 1e-14 {
879            return None;
880        }
881        // Binary search for the interval
882        let idx = self.states.partition_point(|s| s.t <= t).saturating_sub(1);
883        let idx = idx.min(self.states.len() - 1);
884
885        if idx + 1 >= self.states.len() {
886            return Some(self.states[idx].clone());
887        }
888
889        let sa = &self.states[idx];
890        let sb = &self.states[idx + 1];
891        let dt = sb.t - sa.t;
892        if dt < 1e-15 {
893            return Some(sa.clone());
894        }
895        let alpha = (t - sa.t) / dt;
896        Some(sa.lerp(sb, alpha))
897    }
898
899    /// Extract all times as a `Vec`f64`.
900    pub fn times(&self) -> Vec<f64> {
901        self.states.iter().map(|s| s.t).collect()
902    }
903
904    /// Extract the trajectory of component `i` as a `Vec`f64`.
905    ///
906    /// Returns an empty vector if `i` is out of bounds.
907    pub fn component(&self, i: usize) -> Vec<f64> {
908        self.states
909            .iter()
910            .filter_map(|s| s.y.get(i).copied())
911            .collect()
912    }
913
914    /// Evaluate a scalar observable `g(t, y)` along the trajectory.
915    pub fn map_observable<G>(&self, g: G) -> Vec<f64>
916    where
917        G: Fn(f64, &[f64]) -> f64,
918    {
919        self.states.iter().map(|s| g(s.t, &s.y)).collect()
920    }
921
922    /// Resample the trajectory at `n` uniformly spaced times in \[t0, t1\].
923    pub fn resample(&self, n: usize) -> Vec<OdeState> {
924        if self.states.len() < 2 || n < 2 {
925            return self.states.clone();
926        }
927        let t0 = self
928            .states
929            .first()
930            .expect("states has at least 2 entries")
931            .t;
932        let t1 = self.states.last().expect("states has at least 2 entries").t;
933        (0..n)
934            .filter_map(|k| {
935                let t = t0 + (t1 - t0) * k as f64 / (n - 1) as f64;
936                self.interpolate(t)
937            })
938            .collect()
939    }
940}
941
// ─────────────────────────────────────────────────────────────────────────────
// Keep the file-level `PI` import referenced
// ─────────────────────────────────────────────────────────────────────────────
/// References the imported `PI` so that the `use std::f64::consts::PI;` line at
/// the top of the file is never reported as an unused import.
const _PI_CHECK: f64 = PI;
947
948// ─────────────────────────────────────────────────────────────────────────────
949// Tests
950// ─────────────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `actual` lies strictly within `tol` of `expected`.
    fn assert_near(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected} within {tol}, got {actual}"
        );
    }

    // ------------------------------------------------------------------
    // OdeState
    // ------------------------------------------------------------------

    /// `new` stores the time verbatim and `norm` is the Euclidean length.
    #[test]
    fn test_ode_state_new_and_norm() {
        let state = OdeState::new(1.0, vec![3.0, 4.0]);
        assert_eq!(state.t, 1.0);
        assert_near(state.norm(), 5.0, 1e-12);
    }

    /// `zeros` yields a vector of the requested length with zero norm.
    #[test]
    fn test_ode_state_zeros() {
        let state = OdeState::zeros(0.0, 5);
        assert_eq!(state.y.len(), 5);
        assert_eq!(state.norm(), 0.0);
    }

    /// `dim` reports the number of state components.
    #[test]
    fn test_ode_state_dim() {
        let state = OdeState::new(0.0, vec![1.0, 2.0, 3.0]);
        assert_eq!(state.dim(), 3);
    }

    /// `lerp` at alpha = 0.5 returns the midpoint in both time and state.
    #[test]
    fn test_ode_state_lerp() {
        let start = OdeState::new(0.0, vec![0.0, 0.0]);
        let end = OdeState::new(1.0, vec![2.0, 4.0]);
        let halfway = start.lerp(&end, 0.5);
        assert_near(halfway.t, 0.5, 1e-12);
        assert_near(halfway.y[0], 1.0, 1e-12);
        assert_near(halfway.y[1], 2.0, 1e-12);
    }

    /// `lerp` at alpha = 0 and alpha = 1 reproduces the two endpoints.
    #[test]
    fn test_ode_state_lerp_endpoints() {
        let start = OdeState::new(0.0, vec![1.0]);
        let end = OdeState::new(2.0, vec![3.0]);
        let at_start = start.lerp(&end, 0.0);
        let at_end = start.lerp(&end, 1.0);
        assert_near(at_start.y[0], 1.0, 1e-12);
        assert_near(at_end.y[0], 3.0, 1e-12);
    }

    // ------------------------------------------------------------------
    // Helper functions
    // ------------------------------------------------------------------

    /// The RMS norm of an empty slice is zero.
    #[test]
    fn test_rms_norm_empty() {
        assert_eq!(rms_norm(&[]), 0.0);
    }

    /// The RMS norm of an all-ones vector is one.
    #[test]
    fn test_rms_norm_ones() {
        let ones = vec![1.0, 1.0, 1.0, 1.0];
        assert_near(rms_norm(&ones), 1.0, 1e-12);
    }

    /// `vec_axpy(a, x, y)` computes `a * x + y` component-wise.
    #[test]
    fn test_vec_axpy() {
        let x = vec![1.0, 2.0];
        let y = vec![3.0, 4.0];
        let result = vec_axpy(2.0, &x, &y);
        assert_near(result[0], 5.0, 1e-12);
        assert_near(result[1], 8.0, 1e-12);
    }

    /// `vec_scale` multiplies every component by the scalar.
    #[test]
    fn test_vec_scale() {
        let x = vec![1.0, 2.0, 3.0];
        let scaled = vec_scale(3.0, &x);
        assert_near(scaled[2], 9.0, 1e-12);
    }

    /// `vec_add` and `vec_sub` operate component-wise.
    #[test]
    fn test_vec_add_sub() {
        let a = vec![1.0, 2.0];
        let b = vec![3.0, 1.0];
        let sum = vec_add(&a, &b);
        let diff = vec_sub(&b, &a);
        assert_near(sum[0], 4.0, 1e-12);
        assert_near(diff[1], -1.0, 1e-12);
    }

    // ------------------------------------------------------------------
    // RK4 — exponential decay: dy/dt = -y, y(0)=1
    // ------------------------------------------------------------------

    /// Right-hand side of the scalar decay problem `dy/dt = -y`.
    fn f_decay(_t: f64, y: &[f64]) -> Vec<f64> {
        vec![-y[0]]
    }

    /// One RK4 step of size 0.1 matches `exp(-0.1)` to 1e-7.
    #[test]
    fn test_rk4_single_step_accuracy() {
        let integrator = RK4Integrator::default_tolerances();
        let initial = OdeState::new(0.0, vec![1.0]);
        let stepped = integrator.step(&initial, 0.1, &f_decay);
        assert_near(stepped.y[0], (-0.1f64).exp(), 1e-7);
    }

    /// Fixed-step RK4 to t = 1 with h = 0.01 matches `exp(-1)` to 1e-6.
    #[test]
    fn test_rk4_integrate_fixed() {
        let integrator = RK4Integrator::default_tolerances();
        let initial = OdeState::new(0.0, vec![1.0]);
        let trajectory = integrator.integrate(&initial, 1.0, 0.01, &f_decay);
        let final_state = trajectory.last().unwrap();
        assert_near(final_state.y[0], (-1.0f64).exp(), 1e-6);
    }

    /// Adaptive RK4 to t = 2 matches `exp(-2)` to 1e-5.
    #[test]
    fn test_rk4_adaptive() {
        let integrator = RK4Integrator::new(1e-8, 1e-8);
        let initial = OdeState::new(0.0, vec![1.0]);
        let trajectory = integrator.integrate_adaptive(&initial, 2.0, 0.1, &f_decay);
        let final_state = trajectory.last().unwrap();
        assert_near(final_state.y[0], (-2.0f64).exp(), 1e-5);
    }

    /// RK4 on the unit harmonic oscillator returns `sin(pi) ~ 0` at t = pi.
    #[test]
    fn test_rk4_harmonic_oscillator() {
        // dy0/dt = y1, dy1/dt = -y0  (omega = 1)
        let rhs = |_t: f64, y: &[f64]| vec![y[1], -y[0]];
        let integrator = RK4Integrator::default_tolerances();
        // Initial condition selecting the sin(t) solution for y0.
        let initial = OdeState::new(0.0, vec![0.0, 1.0]);
        let trajectory = integrator.integrate(&initial, std::f64::consts::PI, 0.01, &rhs);
        let final_state = trajectory.last().unwrap();
        assert_near(final_state.y[0], 0.0, 1e-5);
    }

    // ------------------------------------------------------------------
    // DormandPrince45
    // ------------------------------------------------------------------

    /// DP45 with default tolerances matches `exp(-1)` to 1e-5.
    #[test]
    fn test_dp45_exponential_decay() {
        let solver = DormandPrince45::default_tolerances();
        let initial = OdeState::new(0.0, vec![1.0]);
        let solution = solver.integrate(&initial, 1.0, 0.1, &f_decay);
        let final_state = solution.states.last().unwrap();
        // The bound allows roughly a factor of ten over the solver's default
        // tolerance (stated as 1e-6 by the original author of this test).
        assert_near(final_state.y[0], (-1.0f64).exp(), 1e-5);
    }

    /// DP45 with tight tolerances returns `cos(2*pi) = 1` after one period.
    #[test]
    fn test_dp45_harmonic_oscillator() {
        let rhs = |_t: f64, y: &[f64]| vec![y[1], -y[0]];
        let solver = DormandPrince45::new(1e-9, 1e-9, 1e-12, 1.0);
        // Initial condition selecting the cos(t) solution for y0.
        let initial = OdeState::new(0.0, vec![1.0, 0.0]);
        let solution = solver.integrate(&initial, 2.0 * std::f64::consts::PI, 0.1, &rhs);
        let final_state = solution.states.last().unwrap();
        assert_near(final_state.y[0], 1.0, 1e-6);
    }

    /// A DP45 integration over a non-empty interval stores more than one state.
    #[test]
    fn test_dp45_solution_len() {
        let solver = DormandPrince45::default_tolerances();
        let initial = OdeState::new(0.0, vec![1.0]);
        let solution = solver.integrate(&initial, 1.0, 0.1, &f_decay);
        assert!(solution.len() > 1);
    }

    /// A single DP45 step reports a non-negative error and matches `exp(-0.1)`.
    #[test]
    fn test_dp45_fsal_step() {
        let solver = DormandPrince45::default_tolerances();
        let initial = OdeState::new(0.0, vec![1.0]);
        let (stepped, error_estimate, _k7) = solver.step(&initial, 0.1, &f_decay, None);
        assert!(error_estimate >= 0.0);
        assert_near(stepped.y[0], (-0.1f64).exp(), 1e-9);
    }

    // ------------------------------------------------------------------
    // ImplicitEuler — stiff decay dy/dt = -100y
    // ------------------------------------------------------------------

    /// Implicit Euler on `y' = -100 y` ends within 0.01 of `exp(-100)`.
    #[test]
    fn test_implicit_euler_stiff_decay() {
        let rhs = |_t: f64, y: &[f64]| vec![-100.0 * y[0]];
        let solver = ImplicitEuler::default_params();
        let initial = OdeState::new(0.0, vec![1.0]);
        let trajectory = solver.integrate(&initial, 1.0, 0.05, &rhs);
        let final_state = trajectory.last().unwrap();
        assert_near(final_state.y[0], (-100.0f64).exp(), 0.01);
    }

    /// A Newton-based implicit Euler step on `y' = -y` gives `1 / (1 + h)`.
    #[test]
    fn test_implicit_euler_newton_step() {
        let rhs = |_t: f64, y: &[f64]| vec![-y[0]];
        let solver = ImplicitEuler::default_params();
        let initial = OdeState::new(0.0, vec![1.0]);
        let stepped = solver.step_newton(&initial, 0.1, &rhs);
        // Analytical implicit Euler update for h = 0.1.
        let analytic = 1.0 / 1.1;
        assert_near(stepped.y[0], analytic, 1e-8);
    }

    /// With a zero right-hand side the state is left unchanged by a step.
    #[test]
    fn test_implicit_euler_zero_rhs() {
        let rhs = |_t: f64, y: &[f64]| vec![0.0 * y[0]];
        let solver = ImplicitEuler::default_params();
        let initial = OdeState::new(0.0, vec![5.0]);
        let stepped = solver.step(&initial, 1.0, &rhs);
        assert_near(stepped.y[0], 5.0, 1e-12);
    }

    // ------------------------------------------------------------------
    // Trapezoidal
    // ------------------------------------------------------------------

    /// Trapezoidal integration to t = 1 with h = 0.01 matches `exp(-1)` to 1e-5.
    #[test]
    fn test_trapezoidal_decay() {
        let solver = Trapezoidal::default_params();
        let initial = OdeState::new(0.0, vec![1.0]);
        let trajectory = solver.integrate(&initial, 1.0, 0.01, &f_decay);
        let final_state = trajectory.last().unwrap();
        // Second-order method, hence the tighter bound than implicit Euler.
        assert_near(final_state.y[0], (-1.0f64).exp(), 1e-5);
    }

    /// One trapezoidal step on `y' = -y` gives `(1 - h/2) / (1 + h/2)`.
    #[test]
    fn test_trapezoidal_single_step() {
        let solver = Trapezoidal::new(100, 1e-12);
        let initial = OdeState::new(0.0, vec![1.0]);
        let stepped = solver.step(&initial, 0.1, &f_decay);
        // Closed-form trapezoidal update for h = 0.1.
        let analytic = (1.0 - 0.05) / (1.0 + 0.05);
        assert_near(stepped.y[0], analytic, 1e-10);
    }

    // ------------------------------------------------------------------
    // BDF2
    // ------------------------------------------------------------------

    /// BDF2 to t = 1 with h = 0.01 matches `exp(-1)` to 1e-4.
    #[test]
    fn test_bdf2_decay() {
        let solver = BDF2::default_params();
        let initial = OdeState::new(0.0, vec![1.0]);
        let trajectory = solver.integrate(&initial, 1.0, 0.01, &f_decay);
        let final_state = trajectory.last().unwrap();
        assert_near(final_state.y[0], (-1.0f64).exp(), 1e-4);
    }

    /// BDF2 on the stiff problem `y' = -100 y` with h = 0.05 stays bounded.
    ///
    /// The exact value `exp(-50) ~ 1.9e-22` is effectively zero, so the test
    /// only checks that the final value has not blown up and that the second
    /// stored state has dropped below the initial value of 1.0.
    #[test]
    fn test_bdf2_stiff_lambda_100() {
        let rhs = |_t: f64, y: &[f64]| vec![-100.0 * y[0]];
        let solver = BDF2::default_params();
        let initial = OdeState::new(0.0, vec![1.0]);
        let trajectory = solver.integrate(&initial, 0.5, 0.05, &rhs);
        let final_state = trajectory.last().unwrap();
        // Stability check: the solution must not blow up.
        assert!(
            final_state.y[0].abs() < 0.5,
            "BDF2 stiff result out of bounds: {}",
            final_state.y[0]
        );
        // The first step must already have decayed from 1.0.
        assert!(trajectory[1].y[0] < 1.0);
    }

    /// Integrating when already at `t_end` returns only the initial state.
    #[test]
    fn test_bdf2_short_interval() {
        let solver = BDF2::default_params();
        let initial = OdeState::new(5.0, vec![1.0]);
        let trajectory = solver.integrate(&initial, 5.0, 0.1, &f_decay);
        assert_eq!(trajectory.len(), 1);
    }

    // ------------------------------------------------------------------
    // EventDetection
    // ------------------------------------------------------------------

    /// A sign change of `y[0]` between t = 0.9 and t = 1.1 is located at t ~ 1.
    #[test]
    fn test_event_detection_crossing_zero() {
        let detector = EventDetection::default_params();
        let before = OdeState::new(0.9, vec![0.1]);
        let after = OdeState::new(1.1, vec![-0.1]);
        // Event function: the first state component itself.
        let events: Vec<fn(f64, &[f64]) -> f64> = vec![|_t, y| y[0]];
        let hits = detector.detect(&before, &after, &events);
        assert_eq!(hits.len(), 1);
        assert_near(hits[0].t, 1.0, 1e-8);
    }

    /// No event is reported when the event function keeps its sign.
    #[test]
    fn test_event_detection_no_crossing() {
        let detector = EventDetection::default_params();
        let before = OdeState::new(0.0, vec![1.0]);
        let after = OdeState::new(1.0, vec![2.0]);
        let events: Vec<fn(f64, &[f64]) -> f64> = vec![|_t, y| y[0]];
        let hits = detector.detect(&before, &after, &events);
        assert!(hits.is_empty());
    }

    /// A purely time-based event `t - 1` is located at t ~ 1.
    #[test]
    fn test_event_detection_time_event() {
        let detector = EventDetection::default_params();
        let before = OdeState::new(0.8, vec![0.0]);
        let after = OdeState::new(1.2, vec![0.0]);
        // Event function crosses zero at t = 1.0.
        let events: Vec<fn(f64, &[f64]) -> f64> = vec![|t, _y| t - 1.0];
        let hits = detector.detect(&before, &after, &events);
        assert_eq!(hits.len(), 1);
        assert_near(hits[0].t, 1.0, 1e-8);
    }

    /// Two event functions that both change sign yield two crossings.
    #[test]
    fn test_event_detection_multiple_events() {
        let detector = EventDetection::default_params();
        let before = OdeState::new(0.0, vec![2.0, -1.0]);
        let after = OdeState::new(2.0, vec![-2.0, 1.0]);
        let first_component: fn(f64, &[f64]) -> f64 = |_t, y| y[0];
        let second_component: fn(f64, &[f64]) -> f64 = |_t, y| y[1];
        let hits = detector.detect(&before, &after, &[first_component, second_component]);
        assert_eq!(hits.len(), 2);
    }

    // ------------------------------------------------------------------
    // OdeSolution
    // ------------------------------------------------------------------

    /// Interpolation inside the first interval is linear between its ends.
    #[test]
    fn test_ode_solution_interpolate() {
        let samples = vec![
            OdeState::new(0.0, vec![0.0]),
            OdeState::new(1.0, vec![1.0]),
            OdeState::new(2.0, vec![4.0]),
        ];
        let solution = OdeSolution::new(samples);
        let halfway = solution.interpolate(0.5).unwrap();
        assert_near(halfway.y[0], 0.5, 1e-12);
    }

    /// Queries outside the stored time interval return `None`.
    #[test]
    fn test_ode_solution_out_of_range() {
        let samples = vec![OdeState::new(0.0, vec![1.0]), OdeState::new(1.0, vec![2.0])];
        let solution = OdeSolution::new(samples);
        assert!(solution.interpolate(-0.5).is_none());
        assert!(solution.interpolate(1.5).is_none());
    }

    /// `times` and `component(0)` have one entry per stored state.
    #[test]
    fn test_ode_solution_times_and_component() {
        let solver = DormandPrince45::default_tolerances();
        let initial = OdeState::new(0.0, vec![1.0, 0.0]);
        let rhs = |_t: f64, y: &[f64]| vec![y[1], -y[0]];
        let solution = solver.integrate(&initial, 1.0, 0.1, &rhs);
        let stamps = solution.times();
        let first_component = solution.component(0);
        assert_eq!(stamps.len(), first_component.len());
    }

    /// Resampling to 20 points returns exactly 20 states.
    #[test]
    fn test_ode_solution_resample() {
        let solver = DormandPrince45::default_tolerances();
        let initial = OdeState::new(0.0, vec![1.0]);
        let solution = solver.integrate(&initial, 1.0, 0.1, &f_decay);
        let resampled = solution.resample(20);
        assert_eq!(resampled.len(), 20);
    }

    /// An empty solution reports `is_empty` and cannot be interpolated.
    #[test]
    fn test_ode_solution_empty() {
        let solution = OdeSolution::new(vec![]);
        assert!(solution.is_empty());
        assert!(solution.interpolate(0.5).is_none());
    }

    /// `map_observable` evaluates the closure once per stored state.
    #[test]
    fn test_ode_solution_map_observable() {
        let samples = vec![OdeState::new(0.0, vec![1.0]), OdeState::new(1.0, vec![2.0])];
        let solution = OdeSolution::new(samples);
        let doubled = solution.map_observable(|_t, y| y[0] * 2.0);
        assert_near(doubled[0], 2.0, 1e-12);
        assert_near(doubled[1], 4.0, 1e-12);
    }

    // ------------------------------------------------------------------
    // Cross-method comparison
    // ------------------------------------------------------------------

    /// RK4 (h = 0.01) and DP45 (default tolerances) each reach 1e-6 accuracy
    /// on the decay problem; the two errors are checked independently.
    #[test]
    fn test_rk4_vs_dp45_accuracy() {
        let rhs = |_t: f64, y: &[f64]| vec![-y[0]];
        let rk4 = RK4Integrator::default_tolerances();
        let dp45 = DormandPrince45::default_tolerances();
        let initial = OdeState::new(0.0, vec![1.0]);

        let rk4_trajectory = rk4.integrate(&initial, 1.0, 0.01, &rhs);
        let dp45_solution = dp45.integrate(&initial, 1.0, 0.1, &rhs);

        let exact = (-1.0f64).exp();
        let rk4_error = (rk4_trajectory.last().unwrap().y[0] - exact).abs();
        let dp45_error = (dp45_solution.states.last().unwrap().y[0] - exact).abs();

        assert!(dp45_error < 1e-6);
        assert!(rk4_error < 1e-6);
    }

    /// Implicit Euler on the very stiff problem `y' = -1000 y` stays in [0, 1].
    #[test]
    fn test_implicit_vs_explicit_stiff() {
        // An explicit method can be unstable here at large step sizes; the
        // implicit scheme is expected to decay towards zero without blow-up.
        let rhs = |_t: f64, y: &[f64]| vec![-1000.0 * y[0]];
        let solver = ImplicitEuler::default_params();
        let initial = OdeState::new(0.0, vec![1.0]);
        let trajectory = solver.integrate(&initial, 0.01, 0.001, &rhs);
        let final_value = trajectory.last().unwrap().y[0];
        assert!((0.0..=1.0).contains(&final_value));
    }
}