1use crate::dense::{DenseOutput, DenseSegment, DoPri5Interpolant};
22use crate::error::SolverError;
23use crate::events::{find_event_time, Event, EventAction};
24use crate::problem::OdeSystem;
25use crate::solver::{Solver, SolverOptions, SolverResult, SolverStats};
26use crate::step_control::{PIController, StepController};
27use crate::t_eval::{validate_grid, TEvalEmitter};
28use numra_core::Scalar;
29
/// Explicit adaptive Dormand–Prince 5(4) Runge–Kutta integrator.
///
/// The type carries no state of its own: tolerances, step limits, events and
/// output options are all supplied per call through `SolverOptions`.
#[derive(Clone, Debug, Default)]
pub struct DoPri5;

impl DoPri5 {
    /// Returns the (stateless) solver value; equivalent to `DoPri5::default()`.
    pub fn new() -> Self {
        DoPri5
    }
}
61
// Butcher tableau of the Dormand–Prince 5(4) embedded pair.
//
// Values are stored as `f64` and converted to the working scalar with
// `S::from_f64` at each use site in the stepper. `dead_code` is allowed because
// several entries (C7, the A7x row, E2) are not referenced by the stepper in
// this file; presumably they are kept so the tableau reads as complete.
#[allow(dead_code)]
mod tableau {
    // Nodes c_i: stage i is evaluated at time t + c_i * h.
    pub const C2: f64 = 1.0 / 5.0;
    pub const C3: f64 = 3.0 / 10.0;
    pub const C4: f64 = 4.0 / 5.0;
    pub const C5: f64 = 8.0 / 9.0;
    pub const C6: f64 = 1.0;
    pub const C7: f64 = 1.0;

    // Coupling coefficients a_ij: stage i uses y + h * sum_j a_ij * k_j.
    pub const A21: f64 = 1.0 / 5.0;

    pub const A31: f64 = 3.0 / 40.0;
    pub const A32: f64 = 9.0 / 40.0;

    pub const A41: f64 = 44.0 / 45.0;
    pub const A42: f64 = -56.0 / 15.0;
    pub const A43: f64 = 32.0 / 9.0;

    pub const A51: f64 = 19372.0 / 6561.0;
    pub const A52: f64 = -25360.0 / 2187.0;
    pub const A53: f64 = 64448.0 / 6561.0;
    pub const A54: f64 = -212.0 / 729.0;

    pub const A61: f64 = 9017.0 / 3168.0;
    pub const A62: f64 = -355.0 / 33.0;
    pub const A63: f64 = 46732.0 / 5247.0;
    pub const A64: f64 = 49.0 / 176.0;
    pub const A65: f64 = -5103.0 / 18656.0;

    // Row 7 equals the solution weights B below (first-same-as-last): the
    // seventh stage is f(t + h, y_new), which the stepper reuses as the next
    // step's k1 instead of re-evaluating it.
    pub const A71: f64 = 35.0 / 384.0;
    pub const A72: f64 = 0.0;
    pub const A73: f64 = 500.0 / 1113.0;
    pub const A74: f64 = 125.0 / 192.0;
    pub const A75: f64 = -2187.0 / 6784.0;
    pub const A76: f64 = 11.0 / 84.0;

    // Weights of the propagated (5th-order) solution y_new = y + h * sum b_i k_i.
    pub const B1: f64 = 35.0 / 384.0;
    pub const B2: f64 = 0.0;
    pub const B3: f64 = 500.0 / 1113.0;
    pub const B4: f64 = 125.0 / 192.0;
    pub const B5: f64 = -2187.0 / 6784.0;
    pub const B6: f64 = 11.0 / 84.0;
    pub const B7: f64 = 0.0;

    // Weights of the embedded (4th-order) comparison solution.
    pub const B1_HAT: f64 = 5179.0 / 57600.0;
    pub const B2_HAT: f64 = 0.0;
    pub const B3_HAT: f64 = 7571.0 / 16695.0;
    pub const B4_HAT: f64 = 393.0 / 640.0;
    pub const B5_HAT: f64 = -92097.0 / 339200.0;
    pub const B6_HAT: f64 = 187.0 / 2100.0;
    pub const B7_HAT: f64 = 1.0 / 40.0;

    // Error weights e_i = b_i - b_hat_i; the local error estimate is
    // err = h * sum_i e_i * k_i (E2 is zero and is skipped by the stepper).
    pub const E1: f64 = B1 - B1_HAT;
    pub const E2: f64 = B2 - B2_HAT;
    pub const E3: f64 = B3 - B3_HAT;
    pub const E4: f64 = B4 - B4_HAT;
    pub const E5: f64 = B5 - B5_HAT;
    pub const E6: f64 = B6 - B6_HAT;
    pub const E7: f64 = B7 - B7_HAT;
}
142
143impl<S: Scalar> Solver<S> for DoPri5 {
144 fn solve<Sys: OdeSystem<S>>(
145 problem: &Sys,
146 t0: S,
147 tf: S,
148 y0: &[S],
149 options: &SolverOptions<S>,
150 ) -> Result<SolverResult<S>, SolverError> {
151 let dim = problem.dim();
152 if y0.len() != dim {
153 return Err(SolverError::DimensionMismatch {
154 expected: dim,
155 actual: y0.len(),
156 });
157 }
158
159 let direction = if tf >= t0 { S::ONE } else { -S::ONE };
161
162 if let Some(grid) = options.t_eval.as_deref() {
167 validate_grid(grid, t0, tf)?;
168 }
169 let mut grid_emitter = options
170 .t_eval
171 .as_deref()
172 .map(|g| TEvalEmitter::new(g, direction));
173
174 let mut controller = PIController::for_order(5);
176
177 let mut h = match options.h0 {
179 Some(h0) => direction * h0.abs(),
180 None => estimate_initial_step(problem, t0, y0, direction, options),
181 };
182
183 h = direction * h.abs().min(options.h_max).max(options.h_min);
185
186 let mut t = t0;
188 let mut y = y0.to_vec();
189 let mut y_new = vec![S::ZERO; dim];
190
191 let mut k1 = vec![S::ZERO; dim];
193 let mut k2 = vec![S::ZERO; dim];
194 let mut k3 = vec![S::ZERO; dim];
195 let mut k4 = vec![S::ZERO; dim];
196 let mut k5 = vec![S::ZERO; dim];
197 let mut k6 = vec![S::ZERO; dim];
198 let mut k7 = vec![S::ZERO; dim];
199
200 let mut y_stage = vec![S::ZERO; dim];
202
203 let mut err = vec![S::ZERO; dim];
205
206 let mut k_all = if options.dense_output {
208 vec![S::ZERO; 7 * dim]
209 } else {
210 Vec::new()
211 };
212
213 let (mut t_out, mut y_out) = if grid_emitter.is_some() {
216 (Vec::new(), Vec::new())
217 } else {
218 (vec![t0], y0.to_vec())
219 };
220
221 let has_events = !options.events.is_empty();
223 let mut detected_events: Vec<Event<S>> = Vec::new();
224
225 let mut g_prev: Vec<S> = options
227 .events
228 .iter()
229 .map(|ef| ef.evaluate(t0, y0))
230 .collect();
231
232 let mut stats = SolverStats::new();
234
235 let mut dense = if options.dense_output {
237 DenseOutput::new(dim, direction)
238 } else {
239 DenseOutput::new(0, direction)
240 };
241
242 problem.rhs(t, &y, &mut k1);
244 stats.n_eval += 1;
245
246 let mut tol_weights = vec![S::ZERO; dim];
248 let update_tol_weights = |weights: &mut [S], y: &[S]| {
249 for (w, &yi) in weights.iter_mut().zip(y.iter()) {
250 *w = options.atol + options.rtol * yi.abs();
251 }
252 };
253
254 let mut step_count = 0;
256 let mut last_step = false;
257
258 while !last_step {
259 if step_count >= options.max_steps {
261 return Err(SolverError::MaxIterationsExceeded { t: t.to_f64() });
262 }
263
264 if direction * (t + h - tf) > S::ZERO {
266 h = tf - t;
267 last_step = true;
268 }
269
270 for i in 0..dim {
275 y_stage[i] = y[i] + h * S::from_f64(tableau::A21) * k1[i];
276 }
277 problem.rhs(t + h * S::from_f64(tableau::C2), &y_stage, &mut k2);
278
279 for i in 0..dim {
281 y_stage[i] = y[i]
282 + h * (S::from_f64(tableau::A31) * k1[i] + S::from_f64(tableau::A32) * k2[i]);
283 }
284 problem.rhs(t + h * S::from_f64(tableau::C3), &y_stage, &mut k3);
285
286 for i in 0..dim {
288 y_stage[i] = y[i]
289 + h * (S::from_f64(tableau::A41) * k1[i]
290 + S::from_f64(tableau::A42) * k2[i]
291 + S::from_f64(tableau::A43) * k3[i]);
292 }
293 problem.rhs(t + h * S::from_f64(tableau::C4), &y_stage, &mut k4);
294
295 for i in 0..dim {
297 y_stage[i] = y[i]
298 + h * (S::from_f64(tableau::A51) * k1[i]
299 + S::from_f64(tableau::A52) * k2[i]
300 + S::from_f64(tableau::A53) * k3[i]
301 + S::from_f64(tableau::A54) * k4[i]);
302 }
303 problem.rhs(t + h * S::from_f64(tableau::C5), &y_stage, &mut k5);
304
305 for i in 0..dim {
307 y_stage[i] = y[i]
308 + h * (S::from_f64(tableau::A61) * k1[i]
309 + S::from_f64(tableau::A62) * k2[i]
310 + S::from_f64(tableau::A63) * k3[i]
311 + S::from_f64(tableau::A64) * k4[i]
312 + S::from_f64(tableau::A65) * k5[i]);
313 }
314 problem.rhs(t + h * S::from_f64(tableau::C6), &y_stage, &mut k6);
315
316 for i in 0..dim {
318 y_new[i] = y[i]
319 + h * (S::from_f64(tableau::B1) * k1[i]
320 + S::from_f64(tableau::B3) * k3[i]
321 + S::from_f64(tableau::B4) * k4[i]
322 + S::from_f64(tableau::B5) * k5[i]
323 + S::from_f64(tableau::B6) * k6[i]);
324 }
325
326 problem.rhs(t + h, &y_new, &mut k7);
328
329 stats.n_eval += 6; for i in 0..dim {
333 err[i] = h
334 * (S::from_f64(tableau::E1) * k1[i]
335 + S::from_f64(tableau::E3) * k3[i]
336 + S::from_f64(tableau::E4) * k4[i]
337 + S::from_f64(tableau::E5) * k5[i]
338 + S::from_f64(tableau::E6) * k6[i]
339 + S::from_f64(tableau::E7) * k7[i]);
340 }
341
342 update_tol_weights(&mut tol_weights, &y);
344 let err_norm = weighted_rms_norm(&err, &tol_weights);
345
346 if err_norm.is_nan() {
348 return Err(SolverError::Other(
349 "NaN detected in error estimate (check inputs and RHS function)".to_string(),
350 ));
351 }
352
353 let proposal = controller.propose(h, err_norm, 5);
355
356 if proposal.accept {
357 stats.n_accept += 1;
359 controller.accept(h, err_norm);
360
361 let interp_coeffs = if options.dense_output {
364 k_all[0..dim].copy_from_slice(&k1);
365 k_all[dim..2 * dim].copy_from_slice(&k2);
366 k_all[2 * dim..3 * dim].copy_from_slice(&k3);
367 k_all[3 * dim..4 * dim].copy_from_slice(&k4);
368 k_all[4 * dim..5 * dim].copy_from_slice(&k5);
369 k_all[5 * dim..6 * dim].copy_from_slice(&k6);
370 k_all[6 * dim..7 * dim].copy_from_slice(&k7);
371 Some(DoPri5Interpolant::build_coefficients(
372 &y, &y_new, &k_all, h, dim,
373 ))
374 } else {
375 None
376 };
377
378 if options.dense_output {
380 if let Some(ref coeffs) = interp_coeffs {
381 dense.add_segment(DenseSegment::new(t, t + h, coeffs.clone(), dim));
382 }
383 }
384
385 if has_events {
387 let t_new = t + h;
388
389 let mut stop_event = false;
390 let mut earliest_event_t = t_new;
391 let mut earliest_event_y: Option<Vec<S>> = None;
392
393 for (idx, event_fn) in options.events.iter().enumerate() {
394 let g_curr = event_fn.evaluate(t_new, &y_new);
395
396 if g_prev[idx] * g_curr < S::ZERO {
398 let y_ref = &y;
402 let y_new_ref = &y_new;
403 let k1_ref = &k1;
404 let k7_ref = &k7;
405 let t_start = t;
406 let h_step = h;
407 let interpolate = move |t_interp: S| -> Vec<S> {
408 let theta = (t_interp - t_start) / h_step;
409 let theta2 = theta * theta;
410 let theta3 = theta2 * theta;
411 let h00 = S::TWO * theta3 - S::from_f64(3.0) * theta2 + S::ONE;
413 let h10 = theta3 - S::TWO * theta2 + theta;
414 let h01 = -S::TWO * theta3 + S::from_f64(3.0) * theta2;
415 let h11 = theta3 - theta2;
416 let mut y_interp = vec![S::ZERO; dim];
417 for i in 0..dim {
418 y_interp[i] = h00 * y_ref[i]
419 + h10 * h_step * k1_ref[i]
420 + h01 * y_new_ref[i]
421 + h11 * h_step * k7_ref[i];
422 }
423 y_interp
424 };
425
426 if let Some((t_event, y_event)) = find_event_time(
427 event_fn.as_ref(),
428 t,
429 &y,
430 t_new,
431 &y_new,
432 &interpolate,
433 ) {
434 if earliest_event_y.is_none()
436 || (direction * (t_event - earliest_event_t) < S::ZERO)
437 {
438 earliest_event_t = t_event;
439 earliest_event_y = Some(y_event.clone());
440 }
441
442 detected_events.push(Event {
443 t: t_event,
444 y: y_event,
445 event_index: idx,
446 });
447
448 if event_fn.action() == EventAction::Stop {
449 stop_event = true;
450 }
451 }
452 }
453
454 g_prev[idx] = g_curr;
455 }
456
457 if stop_event {
458 let ev_t = earliest_event_t;
461 let ev_y = match earliest_event_y {
464 Some(y) => y,
465 None => {
466 return Err(SolverError::Other(
467 "Internal error: stop event without event data".into(),
468 ))
469 }
470 };
471
472 detected_events.retain(|e| direction * (e.t - ev_t) <= S::ZERO);
473
474 t_out.push(ev_t);
475 y_out.extend_from_slice(&ev_y);
476
477 let mut result = SolverResult::new(t_out, y_out, dim, stats);
478 result.events = detected_events;
479 result.terminated_by_event = true;
480 if options.dense_output && !dense.is_empty() {
481 result.dense_output = Some(dense);
482 }
483 return Ok(result);
484 }
485 }
486
487 let t_new = t + h;
493 if let Some(ref mut emitter) = grid_emitter {
494 emitter.emit_step(t, &y, &k1, t_new, &y_new, &k7, &mut t_out, &mut y_out);
495 } else {
496 t_out.push(t_new);
497 y_out.extend_from_slice(&y_new);
498 }
499
500 t = t_new;
502 y.copy_from_slice(&y_new);
503
504 k1.copy_from_slice(&k7);
506
507 step_count += 1;
508 } else {
509 stats.n_reject += 1;
511 controller.reject(h, err_norm);
512 last_step = false; }
514
515 h = direction * proposal.h_new.abs().min(options.h_max).max(options.h_min);
517 }
518
519 let mut result = SolverResult::new(t_out, y_out, dim, stats);
520 result.events = detected_events;
521 if options.dense_output && !dense.is_empty() {
522 result.dense_output = Some(dense);
523 }
524 Ok(result)
525 }
526}
527
528fn weighted_rms_norm<S: Scalar>(err: &[S], weights: &[S]) -> S {
530 let n = S::from_usize(err.len());
531 let mut sum = S::ZERO;
532 for (e, w) in err.iter().zip(weights.iter()) {
533 let scaled = *e / *w;
534 sum = sum + scaled * scaled;
535 }
536 (sum / n).sqrt()
537}
538
539fn estimate_initial_step<S: Scalar, Sys: OdeSystem<S>>(
541 problem: &Sys,
542 t0: S,
543 y0: &[S],
544 direction: S,
545 options: &SolverOptions<S>,
546) -> S {
547 let dim = problem.dim();
548
549 let mut f0 = vec![S::ZERO; dim];
551 problem.rhs(t0, y0, &mut f0);
552
553 let scale: Vec<S> = y0
555 .iter()
556 .map(|&yi| options.atol + options.rtol * yi.abs())
557 .collect();
558
559 let d0 = weighted_rms_norm(y0, &scale);
561
562 let d1 = weighted_rms_norm(&f0, &scale);
564
565 let h0 = if d0 < S::EPSILON.sqrt() || d1 < S::EPSILON.sqrt() {
567 S::from_f64(1e-6)
568 } else {
569 S::from_f64(0.01) * d0 / d1
570 };
571
572 let mut y1 = vec![S::ZERO; dim];
574 for i in 0..dim {
575 y1[i] = y0[i] + direction * h0 * f0[i];
576 }
577
578 let mut f1 = vec![S::ZERO; dim];
580 problem.rhs(t0 + direction * h0, &y1, &mut f1);
581
582 let mut df = vec![S::ZERO; dim];
584 for i in 0..dim {
585 df[i] = (f1[i] - f0[i]) / h0;
586 }
587 let d2 = weighted_rms_norm(&df, &scale);
588
589 let max_d = d1.max(d2);
591 let h1 = if max_d <= S::from_f64(1e-15) {
592 (h0 * S::from_f64(1e-3)).max(S::from_f64(1e-6))
593 } else {
594 (S::from_f64(0.01) / max_d).powf(S::from_f64(0.2))
595 };
596
597 let h = (S::from_f64(100.0) * h0).min(h1);
599
600 direction * h
601}
602
#[cfg(test)]
mod tests {
    use super::*;
    use crate::problem::OdeProblem;

    /// y' = -y on [0, 5]; compares y(5) with exp(-5).
    #[test]
    fn test_exponential_decay() {
        let problem = OdeProblem::new(
            |_t: f64, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            5.0,
            vec![1.0],
        );

        let options = SolverOptions::default().rtol(1e-8).atol(1e-10);
        let result = DoPri5::solve(&problem, 0.0, 5.0, &[1.0], &options).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        let exact = (-5.0_f64).exp();
        let error = (y_final[0] - exact).abs();
        assert!(error < 1e-7, "Error {} too large", error);
    }

    /// Dense output is attached only when requested and interpolates accurately.
    #[test]
    fn test_dense_output_returned_when_requested() {
        use crate::dense::DenseInterpolant;
        let problem = OdeProblem::new(
            |_t: f64, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            5.0,
            vec![1.0],
        );

        let options = SolverOptions::default().rtol(1e-8).atol(1e-10).dense();
        let result = DoPri5::solve(&problem, 0.0, 5.0, &[1.0], &options).unwrap();

        let dense = result
            .dense_output
            .as_ref()
            .expect("dense() requested; SolverResult.dense_output must be Some");
        assert!(!dense.is_empty(), "dense output should contain segments");

        let t_mid = 2.5;
        let segment = dense
            .find_segment(t_mid)
            .expect("midpoint should fall inside an integrated segment");
        let mut y_mid = vec![0.0; 1];
        DoPri5Interpolant.interpolate(segment, t_mid, &mut y_mid);
        let exact = (-t_mid).exp();
        assert!(
            (y_mid[0] - exact).abs() < 1e-3,
            "interpolated value {} too far from exact {}",
            y_mid[0],
            exact
        );

        let options_no_dense = SolverOptions::default().rtol(1e-8).atol(1e-10);
        let result_no_dense =
            DoPri5::solve(&problem, 0.0, 5.0, &[1.0], &options_no_dense).unwrap();
        assert!(result_no_dense.dense_output.is_none());
    }

    /// x'' = -x with x(0) = 1, x'(0) = 0; exact solution (cos t, -sin t).
    #[test]
    fn test_harmonic_oscillator() {
        let problem = OdeProblem::new(
            |_t: f64, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = -y[0];
            },
            0.0,
            10.0,
            vec![1.0, 0.0],
        );

        let options = SolverOptions::default().rtol(1e-8).atol(1e-10);
        let result = DoPri5::solve(&problem, 0.0, 10.0, &[1.0, 0.0], &options).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        let exact_x = 10.0_f64.cos();
        let exact_v = -10.0_f64.sin();

        let error_x = (y_final[0] - exact_x).abs();
        let error_v = (y_final[1] - exact_v).abs();

        assert!(error_x < 1e-6, "Position error {} too large", error_x);
        assert!(error_v < 1e-6, "Velocity error {} too large", error_v);
    }

    /// The chaotic Lorenz system stays on its bounded attractor.
    #[test]
    fn test_lorenz_stability() {
        let problem = OdeProblem::new(
            |_t: f64, y: &[f64], dydt: &mut [f64]| {
                let sigma = 10.0;
                let rho = 28.0;
                let beta = 8.0 / 3.0;
                dydt[0] = sigma * (y[1] - y[0]);
                dydt[1] = y[0] * (rho - y[2]) - y[1];
                dydt[2] = y[0] * y[1] - beta * y[2];
            },
            0.0,
            20.0,
            vec![1.0, 1.0, 1.0],
        );

        let options = SolverOptions::default();
        let result = DoPri5::solve(&problem, 0.0, 20.0, &[1.0, 1.0, 1.0], &options).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();

        for &yi in y_final.iter() {
            assert!(yi.abs() < 100.0, "Solution blew up");
        }
    }

    /// Integrating y' = -y backward from t = 5 recovers y(0) = 1.
    #[test]
    fn test_backward_integration() {
        let y5 = (-5.0_f64).exp();

        let problem = OdeProblem::new(
            |_t: f64, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            5.0,
            0.0,
            vec![y5],
        );

        let options = SolverOptions::default().rtol(1e-8).atol(1e-10);
        let result = DoPri5::solve(&problem, 5.0, 0.0, &[y5], &options).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        let error = (y_final[0] - 1.0).abs();
        assert!(error < 1e-6, "Error {} too large", error);
    }

    /// Step and evaluation counters are populated.
    #[test]
    fn test_stats() {
        let problem = OdeProblem::new(
            |_t: f64, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            1.0,
            vec![1.0],
        );

        let options = SolverOptions::default();
        let result = DoPri5::solve(&problem, 0.0, 1.0, &[1.0], &options).unwrap();

        assert!(result.stats.n_accept > 0);
        assert!(result.stats.n_eval > 0);
    }

    /// t0 == tf succeeds and leaves the state unchanged.
    #[test]
    fn test_zero_interval() {
        let problem = OdeProblem::new(
            |_t: f64, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            0.0,
            vec![1.0],
        );

        let options = SolverOptions::default();
        let result = DoPri5::solve(&problem, 0.0, 0.0, &[1.0], &options).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        assert!((y_final[0] - 1.0).abs() < 1e-15);
    }

    /// An interval much shorter than the initial step estimate.
    #[test]
    fn test_very_short_interval() {
        let problem = OdeProblem::new(
            |_t: f64, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            1e-10,
            vec![1.0],
        );

        let options = SolverOptions::default();
        let result = DoPri5::solve(&problem, 0.0, 1e-10, &[1.0], &options).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        assert!((y_final[0] - 1.0).abs() < 1e-8);
    }

    /// A zero right-hand side keeps the state constant.
    #[test]
    fn test_constant_zero_rhs() {
        let problem = OdeProblem::new(
            |_t: f64, _y: &[f64], dydt: &mut [f64]| {
                dydt[0] = 0.0;
            },
            0.0,
            10.0,
            vec![42.0],
        );

        let options = SolverOptions::default();
        let result = DoPri5::solve(&problem, 0.0, 10.0, &[42.0], &options).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        assert!((y_final[0] - 42.0).abs() < 1e-12);
    }

    /// With max_steps = 1 the solver cannot reach tf and must report it.
    #[test]
    fn test_single_step_only() {
        let problem = OdeProblem::new(
            |_t: f64, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            10.0,
            vec![1.0],
        );

        let options = SolverOptions::default().max_steps(1);
        let result = DoPri5::solve(&problem, 0.0, 10.0, &[1.0], &options);

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            crate::error::SolverError::MaxIterationsExceeded { .. }
        ));
    }

    /// Tight tolerances yield a correspondingly small global error.
    #[test]
    fn test_tight_tolerance() {
        let problem = OdeProblem::new(
            |_t: f64, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            1.0,
            vec![1.0],
        );

        let options = SolverOptions::default().rtol(1e-12).atol(1e-14);
        let result = DoPri5::solve(&problem, 0.0, 1.0, &[1.0], &options).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        let exact = (-1.0_f64).exp();
        let error = (y_final[0] - exact).abs();
        assert!(
            error < 1e-11,
            "Error {} too large for tight tolerance",
            error
        );
    }

    /// Loose tolerances need only a handful of steps.
    #[test]
    fn test_loose_tolerance() {
        let problem = OdeProblem::new(
            |_t: f64, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            1.0,
            vec![1.0],
        );

        let options = SolverOptions::default().rtol(1e-2).atol(1e-3);
        let result = DoPri5::solve(&problem, 0.0, 1.0, &[1.0], &options).unwrap();

        assert!(result.success);
        assert!(result.stats.n_accept < 50);
    }

    /// y' = 1 from y(0) = 0 exercises the zero-state branch of the step heuristic.
    #[test]
    fn test_zero_initial_condition() {
        let problem = OdeProblem::new(
            |_t: f64, _y: &[f64], dydt: &mut [f64]| {
                dydt[0] = 1.0;
            },
            0.0,
            5.0,
            vec![0.0],
        );

        let options = SolverOptions::default();
        let result = DoPri5::solve(&problem, 0.0, 5.0, &[0.0], &options).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        assert!((y_final[0] - 5.0).abs() < 1e-8);
    }

    /// A large-magnitude state is handled through the relative tolerance.
    #[test]
    fn test_large_initial_condition() {
        let problem = OdeProblem::new(
            |_t: f64, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -0.1 * y[0];
            },
            0.0,
            1.0,
            vec![1e10],
        );

        let options = SolverOptions::default();
        let result = DoPri5::solve(&problem, 0.0, 1.0, &[1e10], &options).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        let exact = 1e10 * (-0.1_f64).exp();
        let rel_error = (y_final[0] - exact).abs() / exact;
        assert!(rel_error < 1e-5, "Relative error {} too large", rel_error);
    }

    /// Ten decoupled decays with different rates.
    #[test]
    fn test_high_dimension() {
        let problem = OdeProblem::new(
            |_t: f64, y: &[f64], dydt: &mut [f64]| {
                for (i, &yi) in y.iter().enumerate() {
                    dydt[i] = -(i as f64 + 1.0) * 0.1 * yi;
                }
            },
            0.0,
            1.0,
            vec![1.0; 10],
        );

        let options = SolverOptions::default();
        let result = DoPri5::solve(&problem, 0.0, 1.0, &[1.0; 10], &options).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        assert_eq!(y_final.len(), 10);

        for (i, &yi) in y_final.iter().enumerate() {
            let rate = (i as f64 + 1.0) * 0.1;
            let exact = (-rate).exp();
            let error = (yi - exact).abs();
            assert!(error < 1e-5, "Component {} error {} too large", i, error);
        }
    }

    /// A falling ball stops at ground contact, at t = sqrt(2 h / g).
    #[test]
    fn test_event_detection_bouncing_ball() {
        use crate::events::{EventAction, EventDirection, EventFunction};

        struct GroundContact;

        impl EventFunction<f64> for GroundContact {
            fn evaluate(&self, _t: f64, y: &[f64]) -> f64 {
                y[0]
            }

            fn direction(&self) -> EventDirection {
                EventDirection::Falling
            }

            fn action(&self) -> EventAction {
                EventAction::Stop
            }
        }

        let g = 9.81_f64;
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = -g;
            },
            0.0,
            10.0,
            vec![10.0, 0.0],
        );

        let y0 = vec![10.0, 0.0];
        let options = SolverOptions::default()
            .rtol(1e-8)
            .atol(1e-10)
            .event(Box::new(GroundContact));

        let result = DoPri5::solve(&problem, 0.0, 10.0, &y0, &options).unwrap();

        assert!(
            result.terminated_by_event,
            "Should have terminated by event"
        );
        assert!(!result.events.is_empty(), "Should have detected events");

        let event = &result.events[0];
        assert!(
            event.y[0].abs() < 1e-4,
            "Event should occur at y=0, got y={}",
            event.y[0]
        );

        let expected_t = (2.0 * 10.0 / g).sqrt();
        assert!(
            (event.t - expected_t).abs() < 0.01,
            "Expected t={:.3}, got t={:.3}",
            expected_t,
            event.t
        );
    }

    /// Continue events are recorded without stopping the integration.
    #[test]
    fn test_event_continue_action() {
        use crate::events::{EventAction, EventDirection, EventFunction};

        struct ZeroCrossing;

        impl EventFunction<f64> for ZeroCrossing {
            fn evaluate(&self, _t: f64, y: &[f64]) -> f64 {
                y[0]
            }

            fn direction(&self) -> EventDirection {
                EventDirection::Both
            }

            fn action(&self) -> EventAction {
                EventAction::Continue
            }
        }

        let problem = OdeProblem::new(
            |_t: f64, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = -y[0];
            },
            0.0,
            10.0,
            vec![1.0, 0.0],
        );

        let options = SolverOptions::default()
            .rtol(1e-8)
            .atol(1e-10)
            .event(Box::new(ZeroCrossing));

        let result = DoPri5::solve(&problem, 0.0, 10.0, &[1.0, 0.0], &options).unwrap();

        assert!(
            !result.terminated_by_event,
            "Should not have terminated by event"
        );

        // cos(t) crosses zero at pi/2, 3pi/2, 5pi/2 within [0, 10].
        assert!(
            result.events.len() >= 3,
            "Should have detected at least 3 events, got {}",
            result.events.len()
        );

        let first = &result.events[0];
        let expected_t = std::f64::consts::FRAC_PI_2;
        assert!(
            (first.t - expected_t).abs() < 0.01,
            "First event expected at t={:.3}, got t={:.3}",
            expected_t,
            first.t
        );
    }

    /// A Rising-only event ignores downward crossings.
    #[test]
    fn test_event_rising_only_integration() {
        use crate::events::{EventAction, EventDirection, EventFunction};

        struct RisingZeroCrossing;

        impl EventFunction<f64> for RisingZeroCrossing {
            fn evaluate(&self, _t: f64, y: &[f64]) -> f64 {
                y[0]
            }
            fn direction(&self) -> EventDirection {
                EventDirection::Rising
            }
            fn action(&self) -> EventAction {
                EventAction::Continue
            }
        }

        let problem = OdeProblem::new(
            |_t: f64, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = -y[0];
            },
            0.0,
            10.0,
            vec![1.0, 0.0],
        );

        let options = SolverOptions::default()
            .rtol(1e-8)
            .atol(1e-10)
            .event(Box::new(RisingZeroCrossing));

        let result = DoPri5::solve(&problem, 0.0, 10.0, &[1.0, 0.0], &options).unwrap();

        assert!(
            !result.events.is_empty(),
            "Should detect at least one rising zero crossing"
        );

        // x = cos(t) crosses zero upward only at t = 3*pi/2 in [0, 10], where
        // x' = -sin(t) = 1. The downward crossings at pi/2 and 5*pi/2 (x' = -1)
        // must be filtered out. A loose bound such as x' > -0.1 would not prove that.
        let expected_t = 3.0 * std::f64::consts::FRAC_PI_2;
        for event in &result.events {
            assert!(
                event.y[1] > 0.5,
                "Rising event should have positive velocity, got y[1]={}",
                event.y[1]
            );
            assert!(
                (event.t - expected_t).abs() < 0.01,
                "Rising event expected at t={:.3}, got t={:.3}",
                expected_t,
                event.t
            );
        }
    }

    /// Two identical event functions each report their own crossings.
    #[test]
    fn test_event_simultaneous_events() {
        use crate::events::{EventAction, EventDirection, EventFunction};

        struct ZeroCross1;
        impl EventFunction<f64> for ZeroCross1 {
            fn evaluate(&self, _t: f64, y: &[f64]) -> f64 {
                y[0]
            }
            fn direction(&self) -> EventDirection {
                EventDirection::Both
            }
            fn action(&self) -> EventAction {
                EventAction::Continue
            }
        }

        struct ZeroCross2;
        impl EventFunction<f64> for ZeroCross2 {
            fn evaluate(&self, _t: f64, y: &[f64]) -> f64 {
                y[0]
            }
            fn direction(&self) -> EventDirection {
                EventDirection::Both
            }
            fn action(&self) -> EventAction {
                EventAction::Continue
            }
        }

        let problem = OdeProblem::new(
            |_t: f64, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = -y[0];
            },
            0.0,
            5.0,
            vec![1.0, 0.0],
        );

        let options = SolverOptions::default()
            .rtol(1e-8)
            .atol(1e-10)
            .event(Box::new(ZeroCross1))
            .event(Box::new(ZeroCross2));

        let result = DoPri5::solve(&problem, 0.0, 5.0, &[1.0, 0.0], &options).unwrap();

        // Two crossings (pi/2, 3pi/2) in [0, 5], reported by each function.
        assert!(
            result.events.len() >= 4,
            "Should detect events from both functions, got {}",
            result.events.len()
        );

        let has_idx_0 = result.events.iter().any(|e| e.event_index == 0);
        let has_idx_1 = result.events.iter().any(|e| e.event_index == 1);
        assert!(has_idx_0, "Should have events from function 0");
        assert!(has_idx_1, "Should have events from function 1");
    }

    /// A Stop event is located correctly while integrating backward in time.
    #[test]
    fn test_event_backward_integration() {
        use crate::events::{EventAction, EventDirection, EventFunction};

        struct ZeroCross;
        impl EventFunction<f64> for ZeroCross {
            fn evaluate(&self, _t: f64, y: &[f64]) -> f64 {
                y[0]
            }
            fn direction(&self) -> EventDirection {
                EventDirection::Both
            }
            fn action(&self) -> EventAction {
                EventAction::Stop
            }
        }

        let y5 = [5.0_f64.cos(), -5.0_f64.sin()];
        let problem = OdeProblem::new(
            |_t: f64, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = -y[0];
            },
            5.0,
            0.0,
            y5.to_vec(),
        );

        let options = SolverOptions::default()
            .rtol(1e-8)
            .atol(1e-10)
            .event(Box::new(ZeroCross));

        let result = DoPri5::solve(&problem, 5.0, 0.0, &y5, &options).unwrap();

        assert!(
            result.terminated_by_event,
            "Should terminate at event during backward integration"
        );
        assert!(
            !result.events.is_empty(),
            "Should detect events during backward integration"
        );

        let event = &result.events[0];
        assert!(
            event.t > 0.0 && event.t < 5.0,
            "Event time {} should be between 0 and 5",
            event.t
        );
        assert!(
            event.y[0].abs() < 0.01,
            "y at event should be ~0, got {}",
            event.y[0]
        );
    }

    /// An event function that never changes sign produces no events.
    #[test]
    fn test_no_event_when_no_crossing() {
        use crate::events::{EventAction, EventDirection, EventFunction};

        struct ZeroCheck;

        impl EventFunction<f64> for ZeroCheck {
            fn evaluate(&self, _t: f64, y: &[f64]) -> f64 {
                y[0]
            }

            fn direction(&self) -> EventDirection {
                EventDirection::Both
            }

            fn action(&self) -> EventAction {
                EventAction::Stop
            }
        }

        let problem = OdeProblem::new(
            |_t: f64, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            5.0,
            vec![1.0],
        );

        let options = SolverOptions::default().event(Box::new(ZeroCheck));

        let result = DoPri5::solve(&problem, 0.0, 5.0, &[1.0], &options).unwrap();

        assert!(!result.terminated_by_event);
        assert!(result.events.is_empty());
    }

    /// The solver is generic over the scalar type: f32 decay.
    #[test]
    fn test_exponential_decay_f32() {
        let problem = OdeProblem::new(
            |_t: f32, y: &[f32], dydt: &mut [f32]| {
                dydt[0] = -y[0];
            },
            0.0f32,
            5.0f32,
            vec![1.0f32],
        );

        let options: SolverOptions<f32> = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let result = DoPri5::solve(&problem, 0.0f32, 5.0f32, &[1.0f32], &options).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        let exact = (-5.0f32).exp();
        let error = (y_final[0] - exact).abs();
        assert!(error < 1e-3, "f32 error {} too large", error);
    }

    /// f32 harmonic oscillator.
    #[test]
    fn test_harmonic_oscillator_f32() {
        let problem = OdeProblem::new(
            |_t: f32, y: &[f32], dydt: &mut [f32]| {
                dydt[0] = y[1];
                dydt[1] = -y[0];
            },
            0.0f32,
            6.0f32,
            vec![1.0f32, 0.0f32],
        );

        let options: SolverOptions<f32> = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let result =
            DoPri5::solve(&problem, 0.0f32, 6.0f32, &[1.0f32, 0.0f32], &options).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        let exact_x = 6.0f32.cos();
        let error = (y_final[0] - exact_x).abs();
        assert!(error < 1e-3, "f32 harmonic error {} too large", error);
    }

    /// A NaN initial state is reported as an error rather than silently propagated.
    #[test]
    fn test_nan_initial_condition() {
        let problem = OdeProblem::new(
            |_t: f64, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            1.0,
            vec![f64::NAN],
        );

        let options = SolverOptions::default();
        let result = DoPri5::solve(&problem, 0.0, 1.0, &[f64::NAN], &options);
        assert!(
            result.is_err(),
            "NaN initial condition should produce error"
        );
    }

    /// An infinite initial state is reported as an error.
    #[test]
    fn test_infinity_initial_condition() {
        let problem = OdeProblem::new(
            |_t: f64, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            1.0,
            vec![f64::INFINITY],
        );

        let options = SolverOptions::default();
        let result = DoPri5::solve(&problem, 0.0, 1.0, &[f64::INFINITY], &options);
        assert!(
            result.is_err(),
            "Infinity initial condition should produce error"
        );
    }

    /// A right-hand side that returns NaN is reported as an error.
    #[test]
    fn test_rhs_produces_nan() {
        let problem = OdeProblem::new(
            |_t: f64, _y: &[f64], dydt: &mut [f64]| {
                dydt[0] = f64::NAN;
            },
            0.0,
            1.0,
            vec![1.0],
        );

        let options = SolverOptions::default();
        let result = DoPri5::solve(&problem, 0.0, 1.0, &[1.0], &options);
        assert!(result.is_err(), "NaN in RHS should produce error");
    }
}