Skip to main content

numra_ode/
dopri5.rs

1//! Dormand-Prince 5(4) explicit Runge-Kutta solver.
2//!
3//! DoPri5 is one of the most widely used adaptive RK methods for non-stiff ODEs.
4//! It uses a 5th order method for advancing the solution and a 4th order embedded
5//! method for error estimation.
6//!
7//! ## Features
8//! - 5th order accuracy with embedded 4th order error estimator
9//! - 7 stages with FSAL (First Same As Last) property
10//! - Free 4th order interpolant for dense output
11//! - PI step size control for smooth adaptation
12//!
13//! ## Reference
14//! Dormand, J. R.; Prince, P. J. (1980), "A family of embedded Runge-Kutta formulae",
15//! Journal of Computational and Applied Mathematics, 6 (1): 19–26
16//!
17//! Author: Moussa Leblouba
18//! Date: 5 March 2026
19//! Modified: 2 May 2026
20
21use crate::dense::{DenseOutput, DenseSegment, DoPri5Interpolant};
22use crate::error::SolverError;
23use crate::events::{find_event_time, Event, EventAction};
24use crate::problem::OdeSystem;
25use crate::solver::{Solver, SolverOptions, SolverResult, SolverStats};
26use crate::step_control::{PIController, StepController};
27use crate::t_eval::{validate_grid, TEvalEmitter};
28use numra_core::Scalar;
29
/// Dormand-Prince 5(4) solver.
///
/// `DoPri5` is a stateless marker type: all integration state lives inside
/// [`Solver::solve`], so values of this type are trivially copyable and
/// comparable.
///
/// # Example
///
/// ```rust
/// use numra_ode::{OdeProblem, DoPri5, Solver, SolverOptions};
///
/// // Exponential decay: dy/dt = -y
/// let problem = OdeProblem::new(
///     |_t, y: &[f64], dydt: &mut [f64]| { dydt[0] = -y[0]; },
///     0.0,
///     5.0,
///     vec![1.0]
/// );
///
/// let options = SolverOptions::default();
/// let result = DoPri5::solve(&problem, 0.0, 5.0, &[1.0], &options).unwrap();
///
/// // y(5) ≈ e^(-5) ≈ 0.00674
/// let y_final = result.y_final().unwrap();
/// assert!((y_final[0] - (-5.0_f64).exp()).abs() < 1e-5);
/// ```
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct DoPri5;

impl DoPri5 {
    /// Create a new DoPri5 solver.
    ///
    /// Equivalent to [`DoPri5::default`]; usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
61
/// Dormand-Prince 5(4) Butcher tableau, stored as individual `f64` constants.
///
/// Layout of the tableau (nodes `c` on the left, stage matrix `A` on the
/// right, then the two weight rows):
/// ```text
/// 0    |
/// 1/5  | 1/5
/// 3/10 | 3/40      9/40
/// 4/5  | 44/45     -56/15    32/9
/// 8/9  | 19372/6561 -25360/2187 64448/6561 -212/729
/// 1    | 9017/3168  -355/33    46732/5247  49/176   -5103/18656
/// 1    | 35/384     0          500/1113    125/192  -2187/6784   11/84
/// -----+-------------------------------------------------------------------
/// b    | 35/384     0          500/1113    125/192  -2187/6784   11/84    0
/// b*   | 5179/57600 0          7571/16695  393/640  -92097/339200 187/2100 1/40
/// ```
///
/// The last row of `A` equals the 5th order weights `b`; that identity is what
/// gives the method its FSAL (First Same As Last) property.
// The full tableau is spelled out for reference, including zero entries and
// rows the solver does not read directly, hence the blanket `dead_code` allow.
#[allow(dead_code)]
mod tableau {
    // ---- Nodes c_i (stage abscissae) ----
    pub const C2: f64 = 1.0 / 5.0;
    pub const C3: f64 = 3.0 / 10.0;
    pub const C4: f64 = 4.0 / 5.0;
    pub const C5: f64 = 8.0 / 9.0;
    pub const C6: f64 = 1.0;
    pub const C7: f64 = 1.0;

    // ---- Stage matrix A, one row per stage ----

    // Stage 2
    pub const A21: f64 = 1.0 / 5.0;

    // Stage 3
    pub const A31: f64 = 3.0 / 40.0;
    pub const A32: f64 = 9.0 / 40.0;

    // Stage 4
    pub const A41: f64 = 44.0 / 45.0;
    pub const A42: f64 = -56.0 / 15.0;
    pub const A43: f64 = 32.0 / 9.0;

    // Stage 5
    pub const A51: f64 = 19372.0 / 6561.0;
    pub const A52: f64 = -25360.0 / 2187.0;
    pub const A53: f64 = 64448.0 / 6561.0;
    pub const A54: f64 = -212.0 / 729.0;

    // Stage 6
    pub const A61: f64 = 9017.0 / 3168.0;
    pub const A62: f64 = -355.0 / 33.0;
    pub const A63: f64 = 46732.0 / 5247.0;
    pub const A64: f64 = 49.0 / 176.0;
    pub const A65: f64 = -5103.0 / 18656.0;

    // Stage 7 (evaluated at the 5th order solution)
    pub const A71: f64 = 35.0 / 384.0;
    pub const A72: f64 = 0.0;
    pub const A73: f64 = 500.0 / 1113.0;
    pub const A74: f64 = 125.0 / 192.0;
    pub const A75: f64 = -2187.0 / 6784.0;
    pub const A76: f64 = 11.0 / 84.0;

    // ---- 5th order weights b: identical to row 7 of A ----
    pub const B1: f64 = A71;
    pub const B2: f64 = A72;
    pub const B3: f64 = A73;
    pub const B4: f64 = A74;
    pub const B5: f64 = A75;
    pub const B6: f64 = A76;
    pub const B7: f64 = 0.0;

    // ---- Embedded 4th order weights b* (used only for the error estimate) ----
    pub const B1_HAT: f64 = 5179.0 / 57600.0;
    pub const B2_HAT: f64 = 0.0;
    pub const B3_HAT: f64 = 7571.0 / 16695.0;
    pub const B4_HAT: f64 = 393.0 / 640.0;
    pub const B5_HAT: f64 = -92097.0 / 339200.0;
    pub const B6_HAT: f64 = 187.0 / 2100.0;
    pub const B7_HAT: f64 = 1.0 / 40.0;

    // ---- Error weights e_i = b_i - b*_i ----
    pub const E1: f64 = B1 - B1_HAT;
    pub const E2: f64 = B2 - B2_HAT;
    pub const E3: f64 = B3 - B3_HAT;
    pub const E4: f64 = B4 - B4_HAT;
    pub const E5: f64 = B5 - B5_HAT;
    pub const E6: f64 = B6 - B6_HAT;
    pub const E7: f64 = B7 - B7_HAT;
}
142
143impl<S: Scalar> Solver<S> for DoPri5 {
144    fn solve<Sys: OdeSystem<S>>(
145        problem: &Sys,
146        t0: S,
147        tf: S,
148        y0: &[S],
149        options: &SolverOptions<S>,
150    ) -> Result<SolverResult<S>, SolverError> {
151        let dim = problem.dim();
152        if y0.len() != dim {
153            return Err(SolverError::DimensionMismatch {
154                expected: dim,
155                actual: y0.len(),
156            });
157        }
158
159        // Direction of integration
160        let direction = if tf >= t0 { S::ONE } else { -S::ONE };
161
162        // Optional user-requested output grid. When set, the solver emits
163        // (t, y) only at these times, interpolated from each accepted step
164        // via Hermite cubic. Validated up front so misconfigured inputs
165        // surface before any RHS calls.
166        if let Some(grid) = options.t_eval.as_deref() {
167            validate_grid(grid, t0, tf)?;
168        }
169        let mut grid_emitter = options
170            .t_eval
171            .as_deref()
172            .map(|g| TEvalEmitter::new(g, direction));
173
174        // Initialize step size controller
175        let mut controller = PIController::for_order(5);
176
177        // Initialize step size
178        let mut h = match options.h0 {
179            Some(h0) => direction * h0.abs(),
180            None => estimate_initial_step(problem, t0, y0, direction, options),
181        };
182
183        // Clamp step size
184        h = direction * h.abs().min(options.h_max).max(options.h_min);
185
186        // Working arrays
187        let mut t = t0;
188        let mut y = y0.to_vec();
189        let mut y_new = vec![S::ZERO; dim];
190
191        // Stage vectors (FSAL: k7 of one step = k1 of next)
192        let mut k1 = vec![S::ZERO; dim];
193        let mut k2 = vec![S::ZERO; dim];
194        let mut k3 = vec![S::ZERO; dim];
195        let mut k4 = vec![S::ZERO; dim];
196        let mut k5 = vec![S::ZERO; dim];
197        let mut k6 = vec![S::ZERO; dim];
198        let mut k7 = vec![S::ZERO; dim];
199
200        // Temporary vector for stage computation
201        let mut y_stage = vec![S::ZERO; dim];
202
203        // Error vector
204        let mut err = vec![S::ZERO; dim];
205
206        // Pre-allocated workspace for dense output (avoids allocation per step)
207        let mut k_all = if options.dense_output {
208            vec![S::ZERO; 7 * dim]
209        } else {
210            Vec::new()
211        };
212
213        // Output storage. In t_eval mode, the result vectors start empty
214        // and are populated only at the requested times by `grid_emitter`.
215        let (mut t_out, mut y_out) = if grid_emitter.is_some() {
216            (Vec::new(), Vec::new())
217        } else {
218            (vec![t0], y0.to_vec())
219        };
220
221        // Event tracking
222        let has_events = !options.events.is_empty();
223        let mut detected_events: Vec<Event<S>> = Vec::new();
224
225        // Previous event function values (for sign-change detection)
226        let mut g_prev: Vec<S> = options
227            .events
228            .iter()
229            .map(|ef| ef.evaluate(t0, y0))
230            .collect();
231
232        // Statistics
233        let mut stats = SolverStats::new();
234
235        // Dense output
236        let mut dense = if options.dense_output {
237            DenseOutput::new(dim, direction)
238        } else {
239            DenseOutput::new(0, direction)
240        };
241
242        // Compute initial k1
243        problem.rhs(t, &y, &mut k1);
244        stats.n_eval += 1;
245
246        // Pre-allocated tolerance weights buffer (avoids allocation per step)
247        let mut tol_weights = vec![S::ZERO; dim];
248        let update_tol_weights = |weights: &mut [S], y: &[S]| {
249            for (w, &yi) in weights.iter_mut().zip(y.iter()) {
250                *w = options.atol + options.rtol * yi.abs();
251            }
252        };
253
254        // Main integration loop
255        let mut step_count = 0;
256        let mut last_step = false;
257
258        while !last_step {
259            // Check step limit
260            if step_count >= options.max_steps {
261                return Err(SolverError::MaxIterationsExceeded { t: t.to_f64() });
262            }
263
264            // Adjust final step to hit tf exactly
265            if direction * (t + h - tf) > S::ZERO {
266                h = tf - t;
267                last_step = true;
268            }
269
270            // Compute stages
271            // k1 is already computed (FSAL)
272
273            // k2
274            for i in 0..dim {
275                y_stage[i] = y[i] + h * S::from_f64(tableau::A21) * k1[i];
276            }
277            problem.rhs(t + h * S::from_f64(tableau::C2), &y_stage, &mut k2);
278
279            // k3
280            for i in 0..dim {
281                y_stage[i] = y[i]
282                    + h * (S::from_f64(tableau::A31) * k1[i] + S::from_f64(tableau::A32) * k2[i]);
283            }
284            problem.rhs(t + h * S::from_f64(tableau::C3), &y_stage, &mut k3);
285
286            // k4
287            for i in 0..dim {
288                y_stage[i] = y[i]
289                    + h * (S::from_f64(tableau::A41) * k1[i]
290                        + S::from_f64(tableau::A42) * k2[i]
291                        + S::from_f64(tableau::A43) * k3[i]);
292            }
293            problem.rhs(t + h * S::from_f64(tableau::C4), &y_stage, &mut k4);
294
295            // k5
296            for i in 0..dim {
297                y_stage[i] = y[i]
298                    + h * (S::from_f64(tableau::A51) * k1[i]
299                        + S::from_f64(tableau::A52) * k2[i]
300                        + S::from_f64(tableau::A53) * k3[i]
301                        + S::from_f64(tableau::A54) * k4[i]);
302            }
303            problem.rhs(t + h * S::from_f64(tableau::C5), &y_stage, &mut k5);
304
305            // k6
306            for i in 0..dim {
307                y_stage[i] = y[i]
308                    + h * (S::from_f64(tableau::A61) * k1[i]
309                        + S::from_f64(tableau::A62) * k2[i]
310                        + S::from_f64(tableau::A63) * k3[i]
311                        + S::from_f64(tableau::A64) * k4[i]
312                        + S::from_f64(tableau::A65) * k5[i]);
313            }
314            problem.rhs(t + h * S::from_f64(tableau::C6), &y_stage, &mut k6);
315
316            // Compute y_new (5th order)
317            for i in 0..dim {
318                y_new[i] = y[i]
319                    + h * (S::from_f64(tableau::B1) * k1[i]
320                        + S::from_f64(tableau::B3) * k3[i]
321                        + S::from_f64(tableau::B4) * k4[i]
322                        + S::from_f64(tableau::B5) * k5[i]
323                        + S::from_f64(tableau::B6) * k6[i]);
324            }
325
326            // k7 (FSAL: this will be k1 of next step)
327            problem.rhs(t + h, &y_new, &mut k7);
328
329            stats.n_eval += 6; // We computed k2..k7
330
331            // Error estimate
332            for i in 0..dim {
333                err[i] = h
334                    * (S::from_f64(tableau::E1) * k1[i]
335                        + S::from_f64(tableau::E3) * k3[i]
336                        + S::from_f64(tableau::E4) * k4[i]
337                        + S::from_f64(tableau::E5) * k5[i]
338                        + S::from_f64(tableau::E6) * k6[i]
339                        + S::from_f64(tableau::E7) * k7[i]);
340            }
341
342            // Compute scaled error norm using pre-allocated buffer
343            update_tol_weights(&mut tol_weights, &y);
344            let err_norm = weighted_rms_norm(&err, &tol_weights);
345
346            // Detect NaN in error norm (propagated from NaN inputs/RHS)
347            if err_norm.is_nan() {
348                return Err(SolverError::Other(
349                    "NaN detected in error estimate (check inputs and RHS function)".to_string(),
350                ));
351            }
352
353            // Step size control
354            let proposal = controller.propose(h, err_norm, 5);
355
356            if proposal.accept {
357                // Accept step
358                stats.n_accept += 1;
359                controller.accept(h, err_norm);
360
361                // Build interpolation coefficients (needed for dense output)
362                // Uses pre-allocated k_all buffer to avoid per-step allocation.
363                let interp_coeffs = if options.dense_output {
364                    k_all[0..dim].copy_from_slice(&k1);
365                    k_all[dim..2 * dim].copy_from_slice(&k2);
366                    k_all[2 * dim..3 * dim].copy_from_slice(&k3);
367                    k_all[3 * dim..4 * dim].copy_from_slice(&k4);
368                    k_all[4 * dim..5 * dim].copy_from_slice(&k5);
369                    k_all[5 * dim..6 * dim].copy_from_slice(&k6);
370                    k_all[6 * dim..7 * dim].copy_from_slice(&k7);
371                    Some(DoPri5Interpolant::build_coefficients(
372                        &y, &y_new, &k_all, h, dim,
373                    ))
374                } else {
375                    None
376                };
377
378                // Store dense output if enabled
379                if options.dense_output {
380                    if let Some(ref coeffs) = interp_coeffs {
381                        dense.add_segment(DenseSegment::new(t, t + h, coeffs.clone(), dim));
382                    }
383                }
384
385                // Event detection
386                if has_events {
387                    let t_new = t + h;
388
389                    let mut stop_event = false;
390                    let mut earliest_event_t = t_new;
391                    let mut earliest_event_y: Option<Vec<S>> = None;
392
393                    for (idx, event_fn) in options.events.iter().enumerate() {
394                        let g_curr = event_fn.evaluate(t_new, &y_new);
395
396                        // Check for sign change
397                        if g_prev[idx] * g_curr < S::ZERO {
398                            // Build Hermite cubic interpolation between step endpoints.
399                            // Uses y, y_new (values) and k1, k7 (derivatives) for
400                            // accurate cubic interpolation.
401                            let y_ref = &y;
402                            let y_new_ref = &y_new;
403                            let k1_ref = &k1;
404                            let k7_ref = &k7;
405                            let t_start = t;
406                            let h_step = h;
407                            let interpolate = move |t_interp: S| -> Vec<S> {
408                                let theta = (t_interp - t_start) / h_step;
409                                let theta2 = theta * theta;
410                                let theta3 = theta2 * theta;
411                                // Hermite basis functions
412                                let h00 = S::TWO * theta3 - S::from_f64(3.0) * theta2 + S::ONE;
413                                let h10 = theta3 - S::TWO * theta2 + theta;
414                                let h01 = -S::TWO * theta3 + S::from_f64(3.0) * theta2;
415                                let h11 = theta3 - theta2;
416                                let mut y_interp = vec![S::ZERO; dim];
417                                for i in 0..dim {
418                                    y_interp[i] = h00 * y_ref[i]
419                                        + h10 * h_step * k1_ref[i]
420                                        + h01 * y_new_ref[i]
421                                        + h11 * h_step * k7_ref[i];
422                                }
423                                y_interp
424                            };
425
426                            if let Some((t_event, y_event)) = find_event_time(
427                                event_fn.as_ref(),
428                                t,
429                                &y,
430                                t_new,
431                                &y_new,
432                                &interpolate,
433                            ) {
434                                // Track the earliest event in case of multiple
435                                if earliest_event_y.is_none()
436                                    || (direction * (t_event - earliest_event_t) < S::ZERO)
437                                {
438                                    earliest_event_t = t_event;
439                                    earliest_event_y = Some(y_event.clone());
440                                }
441
442                                detected_events.push(Event {
443                                    t: t_event,
444                                    y: y_event,
445                                    event_index: idx,
446                                });
447
448                                if event_fn.action() == EventAction::Stop {
449                                    stop_event = true;
450                                }
451                            }
452                        }
453
454                        g_prev[idx] = g_curr;
455                    }
456
457                    if stop_event {
458                        // Terminate integration at the earliest stop event.
459                        // Filter out events that occur after the earliest stop event.
460                        let ev_t = earliest_event_t;
461                        // Safety: stop_event is only set when an event is found,
462                        // which guarantees earliest_event_y is Some.
463                        let ev_y = match earliest_event_y {
464                            Some(y) => y,
465                            None => {
466                                return Err(SolverError::Other(
467                                    "Internal error: stop event without event data".into(),
468                                ))
469                            }
470                        };
471
472                        detected_events.retain(|e| direction * (e.t - ev_t) <= S::ZERO);
473
474                        t_out.push(ev_t);
475                        y_out.extend_from_slice(&ev_y);
476
477                        let mut result = SolverResult::new(t_out, y_out, dim, stats);
478                        result.events = detected_events;
479                        result.terminated_by_event = true;
480                        if options.dense_output && !dense.is_empty() {
481                            result.dense_output = Some(dense);
482                        }
483                        return Ok(result);
484                    }
485                }
486
487                // Store output. In t_eval mode, emit Hermite-interpolated
488                // values at any requested grid points covered by this step;
489                // otherwise push the natural step endpoint as before. Both
490                // paths are closed-form in (t, y, k1, y_new, k7), which we
491                // already have in scope.
492                let t_new = t + h;
493                if let Some(ref mut emitter) = grid_emitter {
494                    emitter.emit_step(t, &y, &k1, t_new, &y_new, &k7, &mut t_out, &mut y_out);
495                } else {
496                    t_out.push(t_new);
497                    y_out.extend_from_slice(&y_new);
498                }
499
500                // Update state
501                t = t_new;
502                y.copy_from_slice(&y_new);
503
504                // FSAL: k7 becomes k1
505                k1.copy_from_slice(&k7);
506
507                step_count += 1;
508            } else {
509                // Reject step
510                stats.n_reject += 1;
511                controller.reject(h, err_norm);
512                last_step = false; // Need to retry
513            }
514
515            // Update step size
516            h = direction * proposal.h_new.abs().min(options.h_max).max(options.h_min);
517        }
518
519        let mut result = SolverResult::new(t_out, y_out, dim, stats);
520        result.events = detected_events;
521        if options.dense_output && !dense.is_empty() {
522            result.dense_output = Some(dense);
523        }
524        Ok(result)
525    }
526}
527
528/// Compute weighted RMS norm for error control.
529fn weighted_rms_norm<S: Scalar>(err: &[S], weights: &[S]) -> S {
530    let n = S::from_usize(err.len());
531    let mut sum = S::ZERO;
532    for (e, w) in err.iter().zip(weights.iter()) {
533        let scaled = *e / *w;
534        sum = sum + scaled * scaled;
535    }
536    (sum / n).sqrt()
537}
538
539/// Estimate initial step size using the algorithm from Hairer-Wanner.
540fn estimate_initial_step<S: Scalar, Sys: OdeSystem<S>>(
541    problem: &Sys,
542    t0: S,
543    y0: &[S],
544    direction: S,
545    options: &SolverOptions<S>,
546) -> S {
547    let dim = problem.dim();
548
549    // Compute f0 = f(t0, y0)
550    let mut f0 = vec![S::ZERO; dim];
551    problem.rhs(t0, y0, &mut f0);
552
553    // Compute scale = atol + |y0| * rtol
554    let scale: Vec<S> = y0
555        .iter()
556        .map(|&yi| options.atol + options.rtol * yi.abs())
557        .collect();
558
559    // d0 = ||y0 / scale||
560    let d0 = weighted_rms_norm(y0, &scale);
561
562    // d1 = ||f0 / scale||
563    let d1 = weighted_rms_norm(&f0, &scale);
564
565    // First guess: h0 = 0.01 * d0/d1
566    let h0 = if d0 < S::EPSILON.sqrt() || d1 < S::EPSILON.sqrt() {
567        S::from_f64(1e-6)
568    } else {
569        S::from_f64(0.01) * d0 / d1
570    };
571
572    // Perform explicit Euler step
573    let mut y1 = vec![S::ZERO; dim];
574    for i in 0..dim {
575        y1[i] = y0[i] + direction * h0 * f0[i];
576    }
577
578    // Compute f1 = f(t0 + h0, y1)
579    let mut f1 = vec![S::ZERO; dim];
580    problem.rhs(t0 + direction * h0, &y1, &mut f1);
581
582    // d2 = ||f1 - f0|| / h0
583    let mut df = vec![S::ZERO; dim];
584    for i in 0..dim {
585        df[i] = (f1[i] - f0[i]) / h0;
586    }
587    let d2 = weighted_rms_norm(&df, &scale);
588
589    // h1 = (0.01 / max(d1, d2))^(1/5)
590    let max_d = d1.max(d2);
591    let h1 = if max_d <= S::from_f64(1e-15) {
592        (h0 * S::from_f64(1e-3)).max(S::from_f64(1e-6))
593    } else {
594        (S::from_f64(0.01) / max_d).powf(S::from_f64(0.2))
595    };
596
597    // h = min(100 * h0, h1)
598    let h = (S::from_f64(100.0) * h0).min(h1);
599
600    direction * h
601}
602
603#[cfg(test)]
604mod tests {
605    use super::*;
606    use crate::problem::OdeProblem;
607
608    #[test]
609    fn test_exponential_decay() {
610        // dy/dt = -y, y(0) = 1, exact: y(t) = e^(-t)
611        let problem = OdeProblem::new(
612            |_t: f64, y: &[f64], dydt: &mut [f64]| {
613                dydt[0] = -y[0];
614            },
615            0.0,
616            5.0,
617            vec![1.0],
618        );
619
620        let options = SolverOptions::default().rtol(1e-8).atol(1e-10);
621        let result = DoPri5::solve(&problem, 0.0, 5.0, &[1.0], &options).unwrap();
622
623        assert!(result.success);
624        let y_final = result.y_final().unwrap();
625        let exact = (-5.0_f64).exp();
626        let error = (y_final[0] - exact).abs();
627        assert!(error < 1e-7, "Error {} too large", error);
628    }
629
630    #[test]
631    fn test_dense_output_returned_when_requested() {
632        use crate::dense::DenseInterpolant;
633        // Regression: when SolverOptions::dense() is set, the integration
634        // builds a DenseOutput; that DenseOutput must be returned to the
635        // caller via SolverResult.dense_output, not silently dropped.
636        let problem = OdeProblem::new(
637            |_t: f64, y: &[f64], dydt: &mut [f64]| {
638                dydt[0] = -y[0];
639            },
640            0.0,
641            5.0,
642            vec![1.0],
643        );
644
645        let options = SolverOptions::default().rtol(1e-8).atol(1e-10).dense();
646        let result = DoPri5::solve(&problem, 0.0, 5.0, &[1.0], &options).unwrap();
647
648        let dense = result
649            .dense_output
650            .as_ref()
651            .expect("dense() requested; SolverResult.dense_output must be Some");
652        assert!(!dense.is_empty(), "dense output should contain segments");
653
654        let t_mid = 2.5;
655        let segment = dense
656            .find_segment(t_mid)
657            .expect("midpoint should fall inside an integrated segment");
658        let mut y_mid = vec![0.0; 1];
659        DoPri5Interpolant.interpolate(segment, t_mid, &mut y_mid);
660        let exact = (-t_mid).exp();
661        assert!(
662            (y_mid[0] - exact).abs() < 1e-3,
663            "interpolated value {} too far from exact {}",
664            y_mid[0],
665            exact
666        );
667
668        // Symmetric: when dense() is NOT requested, dense_output stays None.
669        let options_no_dense = SolverOptions::default().rtol(1e-8).atol(1e-10);
670        let result_no_dense = DoPri5::solve(&problem, 0.0, 5.0, &[1.0], &options_no_dense).unwrap();
671        assert!(result_no_dense.dense_output.is_none());
672    }
673
674    #[test]
675    fn test_harmonic_oscillator() {
676        // d²x/dt² = -x => y' = [y[1], -y[0]]
677        // x(0) = 1, x'(0) = 0
678        // Exact: x(t) = cos(t)
679        let problem = OdeProblem::new(
680            |_t: f64, y: &[f64], dydt: &mut [f64]| {
681                dydt[0] = y[1];
682                dydt[1] = -y[0];
683            },
684            0.0,
685            10.0,
686            vec![1.0, 0.0],
687        );
688
689        let options = SolverOptions::default().rtol(1e-8).atol(1e-10);
690        let result = DoPri5::solve(&problem, 0.0, 10.0, &[1.0, 0.0], &options).unwrap();
691
692        assert!(result.success);
693        let y_final = result.y_final().unwrap();
694        let exact_x = 10.0_f64.cos();
695        let exact_v = -10.0_f64.sin();
696
697        let error_x = (y_final[0] - exact_x).abs();
698        let error_v = (y_final[1] - exact_v).abs();
699
700        assert!(error_x < 1e-6, "Position error {} too large", error_x);
701        assert!(error_v < 1e-6, "Velocity error {} too large", error_v);
702    }
703
704    #[test]
705    fn test_lorenz_stability() {
706        // Lorenz system - just check that it runs without blowing up
707        let problem = OdeProblem::new(
708            |_t: f64, y: &[f64], dydt: &mut [f64]| {
709                let sigma = 10.0;
710                let rho = 28.0;
711                let beta = 8.0 / 3.0;
712                dydt[0] = sigma * (y[1] - y[0]);
713                dydt[1] = y[0] * (rho - y[2]) - y[1];
714                dydt[2] = y[0] * y[1] - beta * y[2];
715            },
716            0.0,
717            20.0,
718            vec![1.0, 1.0, 1.0],
719        );
720
721        let options = SolverOptions::default();
722        let result = DoPri5::solve(&problem, 0.0, 20.0, &[1.0, 1.0, 1.0], &options).unwrap();
723
724        assert!(result.success);
725        let y_final = result.y_final().unwrap();
726
727        // Solution should remain bounded
728        for &yi in y_final.iter() {
729            assert!(yi.abs() < 100.0, "Solution blew up");
730        }
731    }
732
733    #[test]
734    fn test_backward_integration() {
735        // dy/dt = -y, solve backward from t=5 to t=0
736        // y(5) = e^(-5), exact: y(0) = 1
737        let y5 = (-5.0_f64).exp();
738
739        let problem = OdeProblem::new(
740            |_t: f64, y: &[f64], dydt: &mut [f64]| {
741                dydt[0] = -y[0];
742            },
743            5.0,
744            0.0,
745            vec![y5],
746        );
747
748        let options = SolverOptions::default().rtol(1e-8).atol(1e-10);
749        let result = DoPri5::solve(&problem, 5.0, 0.0, &[y5], &options).unwrap();
750
751        assert!(result.success);
752        let y_final = result.y_final().unwrap();
753        let error = (y_final[0] - 1.0).abs();
754        assert!(error < 1e-6, "Error {} too large", error);
755    }
756
757    #[test]
758    fn test_stats() {
759        let problem = OdeProblem::new(
760            |_t: f64, y: &[f64], dydt: &mut [f64]| {
761                dydt[0] = -y[0];
762            },
763            0.0,
764            1.0,
765            vec![1.0],
766        );
767
768        let options = SolverOptions::default();
769        let result = DoPri5::solve(&problem, 0.0, 1.0, &[1.0], &options).unwrap();
770
771        assert!(result.stats.n_accept > 0);
772        assert!(result.stats.n_eval > 0);
773    }
774
775    // ============================================================================
776    // Edge Case Tests
777    // ============================================================================
778
779    #[test]
780    fn test_zero_interval() {
781        // t0 == t_end, should return immediately
782        let problem = OdeProblem::new(
783            |_t: f64, y: &[f64], dydt: &mut [f64]| {
784                dydt[0] = -y[0];
785            },
786            0.0,
787            0.0,
788            vec![1.0],
789        );
790
791        let options = SolverOptions::default();
792        let result = DoPri5::solve(&problem, 0.0, 0.0, &[1.0], &options).unwrap();
793
794        assert!(result.success);
795        let y_final = result.y_final().unwrap();
796        assert!((y_final[0] - 1.0).abs() < 1e-15);
797    }
798
799    #[test]
800    fn test_very_short_interval() {
801        // Very small integration interval
802        let problem = OdeProblem::new(
803            |_t: f64, y: &[f64], dydt: &mut [f64]| {
804                dydt[0] = -y[0];
805            },
806            0.0,
807            1e-10,
808            vec![1.0],
809        );
810
811        let options = SolverOptions::default();
812        let result = DoPri5::solve(&problem, 0.0, 1e-10, &[1.0], &options).unwrap();
813
814        assert!(result.success);
815        let y_final = result.y_final().unwrap();
816        // Should be very close to initial
817        assert!((y_final[0] - 1.0).abs() < 1e-8);
818    }
819
820    #[test]
821    fn test_constant_zero_rhs() {
822        // dy/dt = 0, solution should stay constant
823        let problem = OdeProblem::new(
824            |_t: f64, _y: &[f64], dydt: &mut [f64]| {
825                dydt[0] = 0.0;
826            },
827            0.0,
828            10.0,
829            vec![42.0],
830        );
831
832        let options = SolverOptions::default();
833        let result = DoPri5::solve(&problem, 0.0, 10.0, &[42.0], &options).unwrap();
834
835        assert!(result.success);
836        let y_final = result.y_final().unwrap();
837        assert!((y_final[0] - 42.0).abs() < 1e-12);
838    }
839
840    #[test]
841    fn test_single_step_only() {
842        // Use max_steps = 1 to force early termination
843        let problem = OdeProblem::new(
844            |_t: f64, y: &[f64], dydt: &mut [f64]| {
845                dydt[0] = -y[0];
846            },
847            0.0,
848            10.0,
849            vec![1.0],
850        );
851
852        let options = SolverOptions::default().max_steps(1);
853        let result = DoPri5::solve(&problem, 0.0, 10.0, &[1.0], &options);
854
855        // Should return an error when max_steps is exceeded
856        assert!(result.is_err());
857        assert!(matches!(
858            result.unwrap_err(),
859            crate::error::SolverError::MaxIterationsExceeded { .. }
860        ));
861    }
862
863    #[test]
864    fn test_tight_tolerance() {
865        // Very tight tolerances
866        let problem = OdeProblem::new(
867            |_t: f64, y: &[f64], dydt: &mut [f64]| {
868                dydt[0] = -y[0];
869            },
870            0.0,
871            1.0,
872            vec![1.0],
873        );
874
875        let options = SolverOptions::default().rtol(1e-12).atol(1e-14);
876        let result = DoPri5::solve(&problem, 0.0, 1.0, &[1.0], &options).unwrap();
877
878        assert!(result.success);
879        let y_final = result.y_final().unwrap();
880        let exact = (-1.0_f64).exp();
881        let error = (y_final[0] - exact).abs();
882        assert!(
883            error < 1e-11,
884            "Error {} too large for tight tolerance",
885            error
886        );
887    }
888
889    #[test]
890    fn test_loose_tolerance() {
891        // Very loose tolerances - should complete quickly
892        let problem = OdeProblem::new(
893            |_t: f64, y: &[f64], dydt: &mut [f64]| {
894                dydt[0] = -y[0];
895            },
896            0.0,
897            1.0,
898            vec![1.0],
899        );
900
901        let options = SolverOptions::default().rtol(1e-2).atol(1e-3);
902        let result = DoPri5::solve(&problem, 0.0, 1.0, &[1.0], &options).unwrap();
903
904        assert!(result.success);
905        // Should have fewer steps than tight tolerance
906        assert!(result.stats.n_accept < 50);
907    }
908
909    #[test]
910    fn test_zero_initial_condition() {
911        // y(0) = 0, dy/dt = 1, exact: y(t) = t
912        let problem = OdeProblem::new(
913            |_t: f64, _y: &[f64], dydt: &mut [f64]| {
914                dydt[0] = 1.0;
915            },
916            0.0,
917            5.0,
918            vec![0.0],
919        );
920
921        let options = SolverOptions::default();
922        let result = DoPri5::solve(&problem, 0.0, 5.0, &[0.0], &options).unwrap();
923
924        assert!(result.success);
925        let y_final = result.y_final().unwrap();
926        assert!((y_final[0] - 5.0).abs() < 1e-8);
927    }
928
929    #[test]
930    fn test_large_initial_condition() {
931        // Large initial value
932        let problem = OdeProblem::new(
933            |_t: f64, y: &[f64], dydt: &mut [f64]| {
934                dydt[0] = -0.1 * y[0];
935            },
936            0.0,
937            1.0,
938            vec![1e10],
939        );
940
941        let options = SolverOptions::default();
942        let result = DoPri5::solve(&problem, 0.0, 1.0, &[1e10], &options).unwrap();
943
944        assert!(result.success);
945        let y_final = result.y_final().unwrap();
946        let exact = 1e10 * (-0.1_f64).exp();
947        let rel_error = (y_final[0] - exact).abs() / exact;
948        assert!(rel_error < 1e-5, "Relative error {} too large", rel_error);
949    }
950
951    #[test]
952    fn test_high_dimension() {
953        // Higher dimensional system (10D linear decay)
954        let problem = OdeProblem::new(
955            |_t: f64, y: &[f64], dydt: &mut [f64]| {
956                for (i, &yi) in y.iter().enumerate() {
957                    dydt[i] = -(i as f64 + 1.0) * 0.1 * yi;
958                }
959            },
960            0.0,
961            1.0,
962            vec![1.0; 10],
963        );
964
965        let options = SolverOptions::default();
966        let result = DoPri5::solve(&problem, 0.0, 1.0, &[1.0; 10], &options).unwrap();
967
968        assert!(result.success);
969        let y_final = result.y_final().unwrap();
970        assert_eq!(y_final.len(), 10);
971
972        // Check each component
973        for (i, &yi) in y_final.iter().enumerate() {
974            let rate = (i as f64 + 1.0) * 0.1;
975            let exact = (-rate).exp();
976            let error = (yi - exact).abs();
977            assert!(error < 1e-5, "Component {} error {} too large", i, error);
978        }
979    }
980
981    // ============================================================================
982    // Event Detection Tests
983    // ============================================================================
984
985    #[test]
986    fn test_event_detection_bouncing_ball() {
987        // Bouncing ball: y'' = -g  =>  y[0]' = y[1], y[1]' = -g
988        // Event: y[0] = 0 (ground contact, falling direction)
989        use crate::events::{EventAction, EventDirection, EventFunction};
990
991        struct GroundContact;
992
993        impl EventFunction<f64> for GroundContact {
994            fn evaluate(&self, _t: f64, y: &[f64]) -> f64 {
995                y[0] // Event when height = 0
996            }
997
998            fn direction(&self) -> EventDirection {
999                EventDirection::Falling // Only when falling (y goes from + to -)
1000            }
1001
1002            fn action(&self) -> EventAction {
1003                EventAction::Stop // Stop at first bounce
1004            }
1005        }
1006
1007        let g = 9.81_f64;
1008        let problem = OdeProblem::new(
1009            |_t, y: &[f64], dydt: &mut [f64]| {
1010                dydt[0] = y[1]; // dy/dt = v
1011                dydt[1] = -g; // dv/dt = -g
1012            },
1013            0.0,
1014            10.0,
1015            vec![10.0, 0.0],
1016        );
1017
1018        let y0 = vec![10.0, 0.0]; // height=10, velocity=0
1019
1020        let options = SolverOptions::default()
1021            .rtol(1e-8)
1022            .atol(1e-10)
1023            .event(Box::new(GroundContact));
1024
1025        let result = DoPri5::solve(&problem, 0.0, 10.0, &y0, &options).unwrap();
1026
1027        // Should stop at first bounce
1028        assert!(
1029            result.terminated_by_event,
1030            "Should have terminated by event"
1031        );
1032        assert!(!result.events.is_empty(), "Should have detected events");
1033
1034        // At event, height should be ~0
1035        let event = &result.events[0];
1036        assert!(
1037            event.y[0].abs() < 1e-4,
1038            "Event should occur at y=0, got y={}",
1039            event.y[0]
1040        );
1041
1042        // Time to fall from height 10: t = sqrt(2h/g) ~ 1.428
1043        let expected_t = (2.0 * 10.0 / g).sqrt();
1044        assert!(
1045            (event.t - expected_t).abs() < 0.01,
1046            "Expected t={:.3}, got t={:.3}",
1047            expected_t,
1048            event.t
1049        );
1050    }
1051
1052    #[test]
1053    fn test_event_continue_action() {
1054        // Event with Continue action should record but not stop
1055        use crate::events::{EventAction, EventDirection, EventFunction};
1056
1057        struct ZeroCrossing;
1058
1059        impl EventFunction<f64> for ZeroCrossing {
1060            fn evaluate(&self, _t: f64, y: &[f64]) -> f64 {
1061                y[0] // Event when y[0] crosses zero
1062            }
1063
1064            fn direction(&self) -> EventDirection {
1065                EventDirection::Both
1066            }
1067
1068            fn action(&self) -> EventAction {
1069                EventAction::Continue // Record but continue
1070            }
1071        }
1072
1073        // Harmonic oscillator: y(t) = cos(t), crosses zero at pi/2, 3pi/2, etc.
1074        let problem = OdeProblem::new(
1075            |_t: f64, y: &[f64], dydt: &mut [f64]| {
1076                dydt[0] = y[1];
1077                dydt[1] = -y[0];
1078            },
1079            0.0,
1080            10.0,
1081            vec![1.0, 0.0],
1082        );
1083
1084        let options = SolverOptions::default()
1085            .rtol(1e-8)
1086            .atol(1e-10)
1087            .event(Box::new(ZeroCrossing));
1088
1089        let result = DoPri5::solve(&problem, 0.0, 10.0, &[1.0, 0.0], &options).unwrap();
1090
1091        // Should NOT terminate by event
1092        assert!(
1093            !result.terminated_by_event,
1094            "Should not have terminated by event"
1095        );
1096
1097        // Should have detected multiple zero crossings of cos(t)
1098        // In [0, 10], cos(t) = 0 at t = pi/2, 3pi/2, 5pi/2, 7pi/2, 9pi/2
1099        // That's roughly at t = 1.571, 4.712, 7.854 (and 3 more approaching 10)
1100        assert!(
1101            result.events.len() >= 3,
1102            "Should have detected at least 3 events, got {}",
1103            result.events.len()
1104        );
1105
1106        // Check first event is near pi/2
1107        let first = &result.events[0];
1108        let expected_t = std::f64::consts::FRAC_PI_2;
1109        assert!(
1110            (first.t - expected_t).abs() < 0.01,
1111            "First event expected at t={:.3}, got t={:.3}",
1112            expected_t,
1113            first.t
1114        );
1115    }
1116
1117    #[test]
1118    fn test_event_rising_only_integration() {
1119        // Harmonic oscillator: y(t) = cos(t), y'(t) = -sin(t)
1120        // y(t) crosses zero at pi/2 (falling) and 3pi/2 (rising).
1121        // With Rising direction, should only detect rising crossings.
1122        use crate::events::{EventAction, EventDirection, EventFunction};
1123
1124        struct RisingZeroCrossing;
1125
1126        impl EventFunction<f64> for RisingZeroCrossing {
1127            fn evaluate(&self, _t: f64, y: &[f64]) -> f64 {
1128                y[0]
1129            }
1130            fn direction(&self) -> EventDirection {
1131                EventDirection::Rising
1132            }
1133            fn action(&self) -> EventAction {
1134                EventAction::Continue
1135            }
1136        }
1137
1138        let problem = OdeProblem::new(
1139            |_t: f64, y: &[f64], dydt: &mut [f64]| {
1140                dydt[0] = y[1];
1141                dydt[1] = -y[0];
1142            },
1143            0.0,
1144            10.0,
1145            vec![1.0, 0.0],
1146        );
1147
1148        let options = SolverOptions::default()
1149            .rtol(1e-8)
1150            .atol(1e-10)
1151            .event(Box::new(RisingZeroCrossing));
1152
1153        let result = DoPri5::solve(&problem, 0.0, 10.0, &[1.0, 0.0], &options).unwrap();
1154
1155        // cos(t) = 0 rising at t = 3pi/2, 7pi/2 (approximately 4.712, 10.996)
1156        // In [0, 10], only 3pi/2 ≈ 4.712 qualifies as rising
1157        // At pi/2, cos goes from + to - (falling), should be filtered out
1158        for event in &result.events {
1159            // At each event, y[0] ~ 0 and velocity y[1] should be positive (rising)
1160            assert!(
1161                event.y[1] > -0.1,
1162                "Rising event should have positive velocity, got y[1]={}",
1163                event.y[1]
1164            );
1165        }
1166        assert!(
1167            !result.events.is_empty(),
1168            "Should detect at least one rising zero crossing"
1169        );
1170    }
1171
1172    #[test]
1173    fn test_event_simultaneous_events() {
1174        // Two event functions that detect the same zero crossing
1175        use crate::events::{EventAction, EventDirection, EventFunction};
1176
1177        struct ZeroCross1;
1178        impl EventFunction<f64> for ZeroCross1 {
1179            fn evaluate(&self, _t: f64, y: &[f64]) -> f64 {
1180                y[0]
1181            }
1182            fn direction(&self) -> EventDirection {
1183                EventDirection::Both
1184            }
1185            fn action(&self) -> EventAction {
1186                EventAction::Continue
1187            }
1188        }
1189
1190        struct ZeroCross2;
1191        impl EventFunction<f64> for ZeroCross2 {
1192            fn evaluate(&self, _t: f64, y: &[f64]) -> f64 {
1193                y[0]
1194            }
1195            fn direction(&self) -> EventDirection {
1196                EventDirection::Both
1197            }
1198            fn action(&self) -> EventAction {
1199                EventAction::Continue
1200            }
1201        }
1202
1203        let problem = OdeProblem::new(
1204            |_t: f64, y: &[f64], dydt: &mut [f64]| {
1205                dydt[0] = y[1];
1206                dydt[1] = -y[0];
1207            },
1208            0.0,
1209            5.0,
1210            vec![1.0, 0.0],
1211        );
1212
1213        let options = SolverOptions::default()
1214            .rtol(1e-8)
1215            .atol(1e-10)
1216            .event(Box::new(ZeroCross1))
1217            .event(Box::new(ZeroCross2));
1218
1219        let result = DoPri5::solve(&problem, 0.0, 5.0, &[1.0, 0.0], &options).unwrap();
1220
1221        // Both event functions detect the same crossings, so we expect
1222        // pairs of events at each crossing time
1223        assert!(
1224            result.events.len() >= 4,
1225            "Should detect events from both functions, got {}",
1226            result.events.len()
1227        );
1228
1229        // Check that events from both indices are present
1230        let has_idx_0 = result.events.iter().any(|e| e.event_index == 0);
1231        let has_idx_1 = result.events.iter().any(|e| e.event_index == 1);
1232        assert!(has_idx_0, "Should have events from function 0");
1233        assert!(has_idx_1, "Should have events from function 1");
1234    }
1235
1236    #[test]
1237    fn test_event_backward_integration() {
1238        // Event detection during backward integration
1239        use crate::events::{EventAction, EventDirection, EventFunction};
1240
1241        struct ZeroCross;
1242        impl EventFunction<f64> for ZeroCross {
1243            fn evaluate(&self, _t: f64, y: &[f64]) -> f64 {
1244                y[0]
1245            }
1246            fn direction(&self) -> EventDirection {
1247                EventDirection::Both
1248            }
1249            fn action(&self) -> EventAction {
1250                EventAction::Stop
1251            }
1252        }
1253
1254        // Harmonic oscillator backward from t=5 to t=0
1255        // cos(t) crosses zero going backward
1256        let y5 = [5.0_f64.cos(), -5.0_f64.sin()];
1257        let problem = OdeProblem::new(
1258            |_t: f64, y: &[f64], dydt: &mut [f64]| {
1259                dydt[0] = y[1];
1260                dydt[1] = -y[0];
1261            },
1262            5.0,
1263            0.0,
1264            y5.to_vec(),
1265        );
1266
1267        let options = SolverOptions::default()
1268            .rtol(1e-8)
1269            .atol(1e-10)
1270            .event(Box::new(ZeroCross));
1271
1272        let result = DoPri5::solve(&problem, 5.0, 0.0, &y5, &options).unwrap();
1273
1274        // Should detect a zero crossing during backward integration
1275        assert!(
1276            result.terminated_by_event,
1277            "Should terminate at event during backward integration"
1278        );
1279        assert!(
1280            !result.events.is_empty(),
1281            "Should detect events during backward integration"
1282        );
1283
1284        // The event time should be between 0 and 5
1285        let event = &result.events[0];
1286        assert!(
1287            event.t > 0.0 && event.t < 5.0,
1288            "Event time {} should be between 0 and 5",
1289            event.t
1290        );
1291        assert!(
1292            event.y[0].abs() < 0.01,
1293            "y at event should be ~0, got {}",
1294            event.y[0]
1295        );
1296    }
1297
1298    #[test]
1299    fn test_no_event_when_no_crossing() {
1300        // Exponential decay never crosses zero, so no events should be detected
1301        use crate::events::{EventAction, EventDirection, EventFunction};
1302
1303        struct ZeroCheck;
1304
1305        impl EventFunction<f64> for ZeroCheck {
1306            fn evaluate(&self, _t: f64, y: &[f64]) -> f64 {
1307                y[0] // y[0] = e^(-t) > 0 always
1308            }
1309
1310            fn direction(&self) -> EventDirection {
1311                EventDirection::Both
1312            }
1313
1314            fn action(&self) -> EventAction {
1315                EventAction::Stop
1316            }
1317        }
1318
1319        let problem = OdeProblem::new(
1320            |_t: f64, y: &[f64], dydt: &mut [f64]| {
1321                dydt[0] = -y[0];
1322            },
1323            0.0,
1324            5.0,
1325            vec![1.0],
1326        );
1327
1328        let options = SolverOptions::default().event(Box::new(ZeroCheck));
1329
1330        let result = DoPri5::solve(&problem, 0.0, 5.0, &[1.0], &options).unwrap();
1331
1332        assert!(!result.terminated_by_event);
1333        assert!(result.events.is_empty());
1334    }
1335
1336    // ============================================================================
1337    // f32 Scalar Tests
1338    // ============================================================================
1339
1340    #[test]
1341    fn test_exponential_decay_f32() {
1342        // Verify DoPri5 works with f32 scalars
1343        let problem = OdeProblem::new(
1344            |_t: f32, y: &[f32], dydt: &mut [f32]| {
1345                dydt[0] = -y[0];
1346            },
1347            0.0f32,
1348            5.0f32,
1349            vec![1.0f32],
1350        );
1351
1352        let options: SolverOptions<f32> = SolverOptions::default().rtol(1e-4).atol(1e-6);
1353        let result = DoPri5::solve(&problem, 0.0f32, 5.0f32, &[1.0f32], &options).unwrap();
1354
1355        assert!(result.success);
1356        let y_final = result.y_final().unwrap();
1357        let exact = (-5.0f32).exp();
1358        let error = (y_final[0] - exact).abs();
1359        assert!(error < 1e-3, "f32 error {} too large", error);
1360    }
1361
1362    #[test]
1363    fn test_harmonic_oscillator_f32() {
1364        let problem = OdeProblem::new(
1365            |_t: f32, y: &[f32], dydt: &mut [f32]| {
1366                dydt[0] = y[1];
1367                dydt[1] = -y[0];
1368            },
1369            0.0f32,
1370            6.0f32,
1371            vec![1.0f32, 0.0f32],
1372        );
1373
1374        let options: SolverOptions<f32> = SolverOptions::default().rtol(1e-4).atol(1e-6);
1375        let result = DoPri5::solve(&problem, 0.0f32, 6.0f32, &[1.0f32, 0.0f32], &options).unwrap();
1376
1377        assert!(result.success);
1378        let y_final = result.y_final().unwrap();
1379        let exact_x = 6.0f32.cos();
1380        let error = (y_final[0] - exact_x).abs();
1381        assert!(error < 1e-3, "f32 harmonic error {} too large", error);
1382    }
1383
1384    // ============================================================================
1385    // NaN / Infinity Input Tests
1386    // ============================================================================
1387
1388    #[test]
1389    fn test_nan_initial_condition() {
1390        // NaN in initial condition — solver should detect and return error
1391        let problem = OdeProblem::new(
1392            |_t: f64, y: &[f64], dydt: &mut [f64]| {
1393                dydt[0] = -y[0];
1394            },
1395            0.0,
1396            1.0,
1397            vec![f64::NAN],
1398        );
1399
1400        let options = SolverOptions::default();
1401        let result = DoPri5::solve(&problem, 0.0, 1.0, &[f64::NAN], &options);
1402        assert!(
1403            result.is_err(),
1404            "NaN initial condition should produce error"
1405        );
1406    }
1407
1408    #[test]
1409    fn test_infinity_initial_condition() {
1410        // Infinity in initial condition — solver should detect NaN from computation
1411        let problem = OdeProblem::new(
1412            |_t: f64, y: &[f64], dydt: &mut [f64]| {
1413                dydt[0] = -y[0];
1414            },
1415            0.0,
1416            1.0,
1417            vec![f64::INFINITY],
1418        );
1419
1420        let options = SolverOptions::default();
1421        let result = DoPri5::solve(&problem, 0.0, 1.0, &[f64::INFINITY], &options);
1422        // Infinity * 0 = NaN in error estimate, so should fail
1423        assert!(
1424            result.is_err(),
1425            "Infinity initial condition should produce error"
1426        );
1427    }
1428
1429    #[test]
1430    fn test_rhs_produces_nan() {
1431        // RHS function produces NaN — solver should detect and return error
1432        let problem = OdeProblem::new(
1433            |_t: f64, _y: &[f64], dydt: &mut [f64]| {
1434                dydt[0] = f64::NAN;
1435            },
1436            0.0,
1437            1.0,
1438            vec![1.0],
1439        );
1440
1441        let options = SolverOptions::default();
1442        let result = DoPri5::solve(&problem, 0.0, 1.0, &[1.0], &options);
1443        assert!(result.is_err(), "NaN in RHS should produce error");
1444    }
1445}