1#![allow(dead_code)]
11#![allow(clippy::too_many_arguments)]
12
13use std::f64::consts::PI;
14
/// A snapshot of an ODE solution: the independent variable `t` together with
/// the dependent-variable vector `y`.
#[derive(Debug, Clone, PartialEq)]
pub struct OdeState {
    /// Value of the independent variable at this snapshot.
    pub t: f64,
    /// Components of the state vector at `t`.
    pub y: Vec<f64>,
}

impl OdeState {
    /// Builds a state from a time and an owned component vector.
    pub fn new(t: f64, y: Vec<f64>) -> Self {
        OdeState { t, y }
    }

    /// Euclidean (L2) norm of the component vector.
    pub fn norm(&self) -> f64 {
        let sum_of_squares: f64 = self.y.iter().map(|&c| c * c).sum();
        sum_of_squares.sqrt()
    }

    /// Number of components in the state vector.
    pub fn dim(&self) -> usize {
        self.y.len()
    }

    /// Builds a state at time `t` whose `n` components are all zero.
    pub fn zeros(t: f64, n: usize) -> Self {
        Self::new(t, vec![0.0; n])
    }

    /// Linearly blends `self` towards `other`.
    ///
    /// Both the time and every component are computed as
    /// `from + alpha * (to - from)`, so `alpha = 0` reproduces `self` and
    /// `alpha = 1` reproduces `other` up to rounding. Components are paired
    /// positionally; if the two vectors differ in length the result has the
    /// shorter length.
    pub fn lerp(&self, other: &OdeState, alpha: f64) -> OdeState {
        let blend = |from: f64, to: f64| from + alpha * (to - from);

        let mut y = Vec::with_capacity(self.y.len().min(other.y.len()));
        for (&a, &b) in self.y.iter().zip(&other.y) {
            y.push(blend(a, b));
        }

        OdeState {
            t: blend(self.t, other.t),
            y,
        }
    }
}
65
/// Returns the vector whose `i`-th entry is `a * x[i] + y[i]`.
///
/// Entries are paired positionally; the output is as long as the shorter of
/// the two inputs.
#[inline]
fn vec_axpy(a: f64, x: &[f64], y: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(x.len().min(y.len()));
    for (&xi, &yi) in x.iter().zip(y) {
        out.push(a * xi + yi);
    }
    out
}
74
/// Returns a copy of `x` with every entry multiplied by the scalar `a`.
#[inline]
fn vec_scale(a: f64, x: &[f64]) -> Vec<f64> {
    let mut scaled = x.to_vec();
    for entry in &mut scaled {
        *entry = a * *entry;
    }
    scaled
}
79
/// Element-wise sum `x[i] + y[i]`, truncated to the shorter input.
#[inline]
fn vec_add(x: &[f64], y: &[f64]) -> Vec<f64> {
    let len = x.len().min(y.len());
    (0..len).map(|i| x[i] + y[i]).collect()
}
84
/// Element-wise difference `x[i] - y[i]`, truncated to the shorter input.
#[inline]
fn vec_sub(x: &[f64], y: &[f64]) -> Vec<f64> {
    let mut diff = Vec::with_capacity(x.len().min(y.len()));
    for (&minuend, &subtrahend) in x.iter().zip(y) {
        diff.push(minuend - subtrahend);
    }
    diff
}
89
/// Root-mean-square of the entries of `v`: `sqrt(sum(v[i]^2) / len)`.
///
/// An empty slice yields `0.0` rather than dividing by zero.
#[inline]
fn rms_norm(v: &[f64]) -> f64 {
    match v.len() {
        0 => 0.0,
        len => {
            let sum_of_squares: f64 = v.iter().map(|&x| x * x).sum();
            (sum_of_squares / len as f64).sqrt()
        }
    }
}
97
98pub struct RK4Integrator {
109 pub atol: f64,
111 pub rtol: f64,
113}
114
115impl RK4Integrator {
116 pub fn new(atol: f64, rtol: f64) -> Self {
118 Self { atol, rtol }
119 }
120
121 pub fn default_tolerances() -> Self {
123 Self {
124 atol: 1e-6,
125 rtol: 1e-6,
126 }
127 }
128
129 pub fn step<F>(&self, s: &OdeState, dt: f64, f: &F) -> OdeState
133 where
134 F: Fn(f64, &[f64]) -> Vec<f64>,
135 {
136 let t = s.t;
137 let y = &s.y;
138 let k1 = f(t, y);
139 let y2 = vec_axpy(0.5 * dt, &k1, y);
140 let k2 = f(t + 0.5 * dt, &y2);
141 let y3 = vec_axpy(0.5 * dt, &k2, y);
142 let k3 = f(t + 0.5 * dt, &y3);
143 let y4 = vec_axpy(dt, &k3, y);
144 let k4 = f(t + dt, &y4);
145
146 let n = y.len();
147 let y_new: Vec<f64> = (0..n)
148 .map(|i| y[i] + (dt / 6.0) * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]))
149 .collect();
150 OdeState::new(t + dt, y_new)
151 }
152
153 pub fn integrate<F>(&self, s0: &OdeState, t_end: f64, dt: f64, f: &F) -> Vec<OdeState>
157 where
158 F: Fn(f64, &[f64]) -> Vec<f64>,
159 {
160 let mut states = vec![s0.clone()];
161 let mut s = s0.clone();
162 while s.t < t_end - 1e-14 {
163 let h = dt.min(t_end - s.t);
164 s = self.step(&s, h, f);
165 states.push(s.clone());
166 }
167 states
168 }
169
170 pub fn integrate_adaptive<F>(
175 &self,
176 s0: &OdeState,
177 t_end: f64,
178 dt_init: f64,
179 f: &F,
180 ) -> Vec<OdeState>
181 where
182 F: Fn(f64, &[f64]) -> Vec<f64>,
183 {
184 let mut states = vec![s0.clone()];
185 let mut s = s0.clone();
186 let mut dt = dt_init;
187 let dt_min = 1e-12;
188 let dt_max = t_end - s0.t;
189
190 while s.t < t_end - 1e-14 {
191 let h = dt.min(t_end - s.t).max(dt_min);
192 let s_rk4 = self.step(&s, h, f);
193
194 let k1 = f(s.t, &s.y);
196 let y_mid = vec_axpy(0.5 * h, &k1, &s.y);
197 let k2 = f(s.t + 0.5 * h, &y_mid);
198 let y_rk2: Vec<f64> =
199 s.y.iter()
200 .zip(k2.iter())
201 .map(|(yi, ki)| yi + h * ki)
202 .collect();
203
204 let err: Vec<f64> = s_rk4
205 .y
206 .iter()
207 .zip(y_rk2.iter())
208 .map(|(a, b)| a - b)
209 .collect();
210 let tol = self.atol + self.rtol * s_rk4.norm();
211 let e = rms_norm(&err);
212
213 if e <= tol || h <= dt_min {
214 s = s_rk4;
215 states.push(s.clone());
216 if e > 0.0 {
218 dt = (h * (tol / e).powf(0.2)).min(dt_max);
219 } else {
220 dt = (h * 2.0).min(dt_max);
221 }
222 } else {
223 dt = (h * 0.9 * (tol / e).powf(0.25)).max(dt_min);
225 }
226 }
227 states
228 }
229}
230
231pub struct DormandPrince45 {
243 pub atol: f64,
245 pub rtol: f64,
247 pub dt_min: f64,
249 pub dt_max: f64,
251}
252
253impl DormandPrince45 {
254 const C2: f64 = 1.0 / 5.0;
256 const C3: f64 = 3.0 / 10.0;
257 const C4: f64 = 4.0 / 5.0;
258 const C5: f64 = 8.0 / 9.0;
259 const A21: f64 = 1.0 / 5.0;
260 const A31: f64 = 3.0 / 40.0;
261 const A32: f64 = 9.0 / 40.0;
262 const A41: f64 = 44.0 / 45.0;
263 const A42: f64 = -56.0 / 15.0;
264 const A43: f64 = 32.0 / 9.0;
265 const A51: f64 = 19372.0 / 6561.0;
266 const A52: f64 = -25360.0 / 2187.0;
267 const A53: f64 = 64448.0 / 6561.0;
268 const A54: f64 = -212.0 / 729.0;
269 const A61: f64 = 9017.0 / 3168.0;
270 const A62: f64 = -355.0 / 33.0;
271 const A63: f64 = 46732.0 / 5247.0;
272 const A64: f64 = 49.0 / 176.0;
273 const A65: f64 = -5103.0 / 18656.0;
274
275 const B1: f64 = 35.0 / 384.0;
277 const B3: f64 = 500.0 / 1113.0;
278 const B4: f64 = 125.0 / 192.0;
279 const B5: f64 = -2187.0 / 6784.0;
280 const B6: f64 = 11.0 / 84.0;
281
282 const E1: f64 = 71.0 / 57600.0;
284 const E3: f64 = -71.0 / 16695.0;
285 const E4: f64 = 71.0 / 1920.0;
286 const E5: f64 = -17253.0 / 339200.0;
287 const E6: f64 = 22.0 / 525.0;
288 const E7: f64 = -1.0 / 40.0;
289
290 pub fn new(atol: f64, rtol: f64, dt_min: f64, dt_max: f64) -> Self {
292 Self {
293 atol,
294 rtol,
295 dt_min,
296 dt_max,
297 }
298 }
299
300 pub fn default_tolerances() -> Self {
302 Self {
303 atol: 1e-6,
304 rtol: 1e-6,
305 dt_min: 1e-12,
306 dt_max: f64::INFINITY,
307 }
308 }
309
310 pub fn step<F>(
315 &self,
316 s: &OdeState,
317 h: f64,
318 f: &F,
319 k1_in: Option<&Vec<f64>>,
320 ) -> (OdeState, f64, Vec<f64>)
321 where
322 F: Fn(f64, &[f64]) -> Vec<f64>,
323 {
324 let t = s.t;
325 let y = &s.y;
326 let n = y.len();
327
328 let k1 = match k1_in {
329 Some(k) => k.clone(),
330 None => f(t, y),
331 };
332
333 let y2: Vec<f64> = (0..n).map(|i| y[i] + h * Self::A21 * k1[i]).collect();
334 let k2 = f(t + Self::C2 * h, &y2);
335
336 let y3: Vec<f64> = (0..n)
337 .map(|i| y[i] + h * (Self::A31 * k1[i] + Self::A32 * k2[i]))
338 .collect();
339 let k3 = f(t + Self::C3 * h, &y3);
340
341 let y4: Vec<f64> = (0..n)
342 .map(|i| y[i] + h * (Self::A41 * k1[i] + Self::A42 * k2[i] + Self::A43 * k3[i]))
343 .collect();
344 let k4 = f(t + Self::C4 * h, &y4);
345
346 let y5: Vec<f64> = (0..n)
347 .map(|i| {
348 y[i] + h
349 * (Self::A51 * k1[i]
350 + Self::A52 * k2[i]
351 + Self::A53 * k3[i]
352 + Self::A54 * k4[i])
353 })
354 .collect();
355 let k5 = f(t + Self::C5 * h, &y5);
356
357 let y6: Vec<f64> = (0..n)
358 .map(|i| {
359 y[i] + h
360 * (Self::A61 * k1[i]
361 + Self::A62 * k2[i]
362 + Self::A63 * k3[i]
363 + Self::A64 * k4[i]
364 + Self::A65 * k5[i])
365 })
366 .collect();
367 let k6 = f(t + h, &y6);
368
369 let y_new: Vec<f64> = (0..n)
370 .map(|i| {
371 y[i] + h
372 * (Self::B1 * k1[i]
373 + Self::B3 * k3[i]
374 + Self::B4 * k4[i]
375 + Self::B5 * k5[i]
376 + Self::B6 * k6[i])
377 })
378 .collect();
379 let k7 = f(t + h, &y_new);
380
381 let err: Vec<f64> = (0..n)
383 .map(|i| {
384 h * (Self::E1 * k1[i]
385 + Self::E3 * k3[i]
386 + Self::E4 * k4[i]
387 + Self::E5 * k5[i]
388 + Self::E6 * k6[i]
389 + Self::E7 * k7[i])
390 })
391 .collect();
392
393 let sc: Vec<f64> = y_new
394 .iter()
395 .zip(y.iter())
396 .map(|(yn, y0)| self.atol + self.rtol * yn.abs().max(y0.abs()))
397 .collect();
398 let err_norm = (err
399 .iter()
400 .zip(sc.iter())
401 .map(|(e, s)| (e / s).powi(2))
402 .sum::<f64>()
403 / n as f64)
404 .sqrt();
405
406 (OdeState::new(t + h, y_new), err_norm, k7)
407 }
408
409 pub fn integrate<F>(&self, s0: &OdeState, t_end: f64, dt_init: f64, f: &F) -> OdeSolution
413 where
414 F: Fn(f64, &[f64]) -> Vec<f64>,
415 {
416 let mut states = vec![s0.clone()];
417 let mut s = s0.clone();
418 let mut h = dt_init;
419 let mut k1 = f(s.t, &s.y);
420 let max_steps = 1_000_000usize;
421 let mut n_steps = 0;
422
423 while s.t < t_end - 1e-14 && n_steps < max_steps {
424 h = h.min(t_end - s.t).max(self.dt_min).min(self.dt_max);
425 let (s_new, err, k7) = self.step(&s, h, f, Some(&k1));
426
427 if err <= 1.0 || h <= self.dt_min {
428 s = s_new;
429 k1 = k7; states.push(s.clone());
431 if err > 0.0 {
433 h = (h * 0.9 * err.powf(-0.2)).min(self.dt_max).max(self.dt_min);
434 } else {
435 h = (h * 5.0).min(self.dt_max);
436 }
437 } else {
438 h = (h * 0.9 * err.powf(-0.25)).max(self.dt_min);
439 }
440 n_steps += 1;
441 }
442
443 OdeSolution::new(states)
444 }
445}
446
447pub struct ImplicitEuler {
456 pub max_iter: usize,
458 pub tol: f64,
460 pub fd_eps: f64,
462}
463
464impl ImplicitEuler {
465 pub fn new(max_iter: usize, tol: f64, fd_eps: f64) -> Self {
467 Self {
468 max_iter,
469 tol,
470 fd_eps,
471 }
472 }
473
474 pub fn default_params() -> Self {
476 Self {
477 max_iter: 50,
478 tol: 1e-10,
479 fd_eps: 1e-7,
480 }
481 }
482
483 pub fn step<F>(&self, s: &OdeState, h: f64, f: &F) -> OdeState
488 where
489 F: Fn(f64, &[f64]) -> Vec<f64>,
490 {
491 let t_new = s.t + h;
492 let mut y = s.y.clone();
493
494 for _ in 0..self.max_iter {
495 let rhs = f(t_new, &y);
496 let y_new: Vec<f64> =
497 s.y.iter()
498 .zip(rhs.iter())
499 .map(|(y0, r)| y0 + h * r)
500 .collect();
501 let diff = rms_norm(&vec_sub(&y_new, &y));
502 y = y_new;
503 if diff < self.tol {
504 break;
505 }
506 }
507
508 OdeState::new(t_new, y)
509 }
510
511 pub fn step_newton<F>(&self, s: &OdeState, h: f64, f: &F) -> OdeState
516 where
517 F: Fn(f64, &[f64]) -> Vec<f64>,
518 {
519 let t_new = s.t + h;
520 let n = s.y.len();
521 let mut y = s.y.clone();
522
523 for _ in 0..self.max_iter {
524 let fy = f(t_new, &y);
525 let g: Vec<f64> = (0..n).map(|i| y[i] - s.y[i] - h * fy[i]).collect();
527
528 let g_norm = rms_norm(&g);
529 if g_norm < self.tol {
530 break;
531 }
532
533 let mut jac_diag = vec![1.0f64; n];
536 for j in 0..n {
537 let mut yp = y.clone();
538 yp[j] += self.fd_eps;
539 let fyp = f(t_new, &yp);
540 jac_diag[j] = 1.0 - h * (fyp[j] - fy[j]) / self.fd_eps;
541 if jac_diag[j].abs() < 1e-14 {
542 jac_diag[j] = 1.0;
543 }
544 }
545
546 for i in 0..n {
548 y[i] -= g[i] / jac_diag[i];
549 }
550 }
551
552 OdeState::new(t_new, y)
553 }
554
555 pub fn integrate<F>(&self, s0: &OdeState, t_end: f64, dt: f64, f: &F) -> Vec<OdeState>
559 where
560 F: Fn(f64, &[f64]) -> Vec<f64>,
561 {
562 let mut states = vec![s0.clone()];
563 let mut s = s0.clone();
564 while s.t < t_end - 1e-14 {
565 let h = dt.min(t_end - s.t);
566 s = self.step_newton(&s, h, f);
567 states.push(s.clone());
568 }
569 states
570 }
571}
572
573pub struct Trapezoidal {
582 pub max_iter: usize,
584 pub tol: f64,
586}
587
588impl Trapezoidal {
589 pub fn new(max_iter: usize, tol: f64) -> Self {
591 Self { max_iter, tol }
592 }
593
594 pub fn default_params() -> Self {
596 Self {
597 max_iter: 50,
598 tol: 1e-10,
599 }
600 }
601
602 pub fn step<F>(&self, s: &OdeState, h: f64, f: &F) -> OdeState
604 where
605 F: Fn(f64, &[f64]) -> Vec<f64>,
606 {
607 let t_new = s.t + h;
608 let f0 = f(s.t, &s.y);
609 let mut y = vec_axpy(h, &f0, &s.y);
611
612 for _ in 0..self.max_iter {
613 let f1 = f(t_new, &y);
614 let y_new: Vec<f64> = (0..s.y.len())
615 .map(|i| s.y[i] + 0.5 * h * (f0[i] + f1[i]))
616 .collect();
617 let diff = rms_norm(&vec_sub(&y_new, &y));
618 y = y_new;
619 if diff < self.tol {
620 break;
621 }
622 }
623
624 OdeState::new(t_new, y)
625 }
626
627 pub fn integrate<F>(&self, s0: &OdeState, t_end: f64, dt: f64, f: &F) -> Vec<OdeState>
629 where
630 F: Fn(f64, &[f64]) -> Vec<f64>,
631 {
632 let mut states = vec![s0.clone()];
633 let mut s = s0.clone();
634 while s.t < t_end - 1e-14 {
635 let h = dt.min(t_end - s.t);
636 s = self.step(&s, h, f);
637 states.push(s.clone());
638 }
639 states
640 }
641}
642
643pub struct BDF2 {
652 pub max_iter: usize,
654 pub tol: f64,
656}
657
658impl BDF2 {
659 pub fn new(max_iter: usize, tol: f64) -> Self {
661 Self { max_iter, tol }
662 }
663
664 pub fn default_params() -> Self {
666 Self {
667 max_iter: 50,
668 tol: 1e-10,
669 }
670 }
671
672 pub fn step<F>(&self, s_curr: &OdeState, s_prev: &OdeState, h: f64, f: &F) -> OdeState
678 where
679 F: Fn(f64, &[f64]) -> Vec<f64>,
680 {
681 let t_new = s_curr.t + h;
682 let n = s_curr.y.len();
683 let mut y: Vec<f64> = (0..n).map(|i| 2.0 * s_curr.y[i] - s_prev.y[i]).collect();
685 let fd_eps = 1e-7_f64;
686
687 for _ in 0..self.max_iter {
688 let fy = f(t_new, &y);
689 let g: Vec<f64> = (0..n)
691 .map(|i| 1.5 * y[i] - 2.0 * s_curr.y[i] + 0.5 * s_prev.y[i] - h * fy[i])
692 .collect();
693 let g_norm = rms_norm(&g);
694 if g_norm < self.tol {
695 break;
696 }
697 let mut jac_diag = vec![1.5_f64; n];
699 for j in 0..n {
700 let mut yp = y.clone();
701 yp[j] += fd_eps;
702 let fyp = f(t_new, &yp);
703 jac_diag[j] = 1.5 - h * (fyp[j] - fy[j]) / fd_eps;
704 if jac_diag[j].abs() < 1e-14 {
705 jac_diag[j] = 1.5;
706 }
707 }
708 for i in 0..n {
710 y[i] -= g[i] / jac_diag[i];
711 }
712 }
713
714 OdeState::new(t_new, y)
715 }
716
717 pub fn integrate<F>(&self, s0: &OdeState, t_end: f64, dt: f64, f: &F) -> Vec<OdeState>
721 where
722 F: Fn(f64, &[f64]) -> Vec<f64>,
723 {
724 if s0.t >= t_end - 1e-14 {
725 return vec![s0.clone()];
726 }
727 let ie = ImplicitEuler::new(self.max_iter, self.tol, 1e-7);
728 let h = dt.min(t_end - s0.t);
729 let s1 = ie.step_newton(s0, h, f);
730 let mut states = vec![s0.clone(), s1.clone()];
731 let mut s_prev = s0.clone();
732 let mut s_curr = s1;
733
734 while s_curr.t < t_end - 1e-14 {
735 let step = dt.min(t_end - s_curr.t);
736 let s_next = self.step(&s_curr, &s_prev, step, f);
737 states.push(s_next.clone());
738 s_prev = s_curr;
739 s_curr = s_next;
740 }
741 states
742 }
743}
744
/// A located zero crossing of an event function, as produced by
/// `EventDetection::detect`.
#[derive(Debug, Clone)]
pub struct CrossingEvent {
    /// Time of the crossing, taken from the linearly interpolated state.
    pub t: f64,
    /// State vector at the crossing, linearly interpolated between the two
    /// bracketing states.
    pub y: Vec<f64>,
    /// `signum()` of the event function evaluated at the start of the
    /// bracketing interval.
    pub sign_before: f64,
    /// Position of the triggering event function in the slice passed to
    /// `EventDetection::detect`.
    pub event_index: usize,
}
761
762pub struct EventDetection {
768 pub tol: f64,
770 pub max_iter: usize,
772}
773
774impl EventDetection {
775 pub fn new(tol: f64, max_iter: usize) -> Self {
777 Self { tol, max_iter }
778 }
779
780 pub fn default_params() -> Self {
782 Self {
783 tol: 1e-10,
784 max_iter: 50,
785 }
786 }
787
788 pub fn detect<E>(&self, s_a: &OdeState, s_b: &OdeState, events: &[E]) -> Vec<CrossingEvent>
793 where
794 E: Fn(f64, &[f64]) -> f64,
795 {
796 let mut crossings = Vec::new();
797
798 for (idx, evt) in events.iter().enumerate() {
799 let ga = evt(s_a.t, &s_a.y);
800 let gb = evt(s_b.t, &s_b.y);
801 if ga * gb > 0.0 {
802 continue; }
804
805 let mut lo = 0.0f64;
807 let mut hi = 1.0f64;
808 let ga_sign = ga.signum();
809
810 for _ in 0..self.max_iter {
811 let mid = 0.5 * (lo + hi);
812 let s_mid = s_a.lerp(s_b, mid);
813 let gm = evt(s_mid.t, &s_mid.y);
814 if gm.signum() == ga_sign {
815 lo = mid;
816 } else {
817 hi = mid;
818 }
819 if hi - lo < self.tol {
820 break;
821 }
822 }
823
824 let alpha = 0.5 * (lo + hi);
825 let s_cross = s_a.lerp(s_b, alpha);
826 crossings.push(CrossingEvent {
827 t: s_cross.t,
828 y: s_cross.y,
829 sign_before: ga_sign,
830 event_index: idx,
831 });
832 }
833
834 crossings.sort_by(|a, b| a.t.partial_cmp(&b.t).unwrap_or(std::cmp::Ordering::Equal));
835 crossings
836 }
837}
838
839#[derive(Debug, Clone)]
848pub struct OdeSolution {
849 pub states: Vec<OdeState>,
851}
852
853impl OdeSolution {
854 pub fn new(states: Vec<OdeState>) -> Self {
856 Self { states }
857 }
858
859 pub fn len(&self) -> usize {
861 self.states.len()
862 }
863
864 pub fn is_empty(&self) -> bool {
866 self.states.is_empty()
867 }
868
869 pub fn interpolate(&self, t: f64) -> Option<OdeState> {
873 if self.states.is_empty() {
874 return None;
875 }
876 let t0 = self.states.first()?.t;
877 let t1 = self.states.last()?.t;
878 if t < t0 - 1e-14 || t > t1 + 1e-14 {
879 return None;
880 }
881 let idx = self.states.partition_point(|s| s.t <= t).saturating_sub(1);
883 let idx = idx.min(self.states.len() - 1);
884
885 if idx + 1 >= self.states.len() {
886 return Some(self.states[idx].clone());
887 }
888
889 let sa = &self.states[idx];
890 let sb = &self.states[idx + 1];
891 let dt = sb.t - sa.t;
892 if dt < 1e-15 {
893 return Some(sa.clone());
894 }
895 let alpha = (t - sa.t) / dt;
896 Some(sa.lerp(sb, alpha))
897 }
898
899 pub fn times(&self) -> Vec<f64> {
901 self.states.iter().map(|s| s.t).collect()
902 }
903
904 pub fn component(&self, i: usize) -> Vec<f64> {
908 self.states
909 .iter()
910 .filter_map(|s| s.y.get(i).copied())
911 .collect()
912 }
913
914 pub fn map_observable<G>(&self, g: G) -> Vec<f64>
916 where
917 G: Fn(f64, &[f64]) -> f64,
918 {
919 self.states.iter().map(|s| g(s.t, &s.y)).collect()
920 }
921
922 pub fn resample(&self, n: usize) -> Vec<OdeState> {
924 if self.states.len() < 2 || n < 2 {
925 return self.states.clone();
926 }
927 let t0 = self
928 .states
929 .first()
930 .expect("states has at least 2 entries")
931 .t;
932 let t1 = self.states.last().expect("states has at least 2 entries").t;
933 (0..n)
934 .filter_map(|k| {
935 let t = t0 + (t1 - t0) * k as f64 / (n - 1) as f64;
936 self.interpolate(t)
937 })
938 .collect()
939 }
940}
941
// References the imported `PI` so the file-level `use` stays in use. The
// leading underscore keeps this otherwise-unused constant from being reported
// as dead code. (An `#[allow(unused_imports)]` on this item would have no
// effect: that lint is reported on `use` declarations, not on constants.)
const _PI_CHECK: f64 = PI;
947
// Unit tests for the state type, the vector helpers, every integrator, event
// detection and the solution container.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- OdeState ----

    /// The constructor stores `t`, and `norm` is the Euclidean norm (3-4-5).
    #[test]
    fn test_ode_state_new_and_norm() {
        let s = OdeState::new(1.0, vec![3.0, 4.0]);
        assert_eq!(s.t, 1.0);
        assert!((s.norm() - 5.0).abs() < 1e-12);
    }

    /// `zeros` yields the requested dimension and a zero norm.
    #[test]
    fn test_ode_state_zeros() {
        let s = OdeState::zeros(0.0, 5);
        assert_eq!(s.y.len(), 5);
        assert_eq!(s.norm(), 0.0);
    }

    /// `dim` reports the number of components.
    #[test]
    fn test_ode_state_dim() {
        let s = OdeState::new(0.0, vec![1.0, 2.0, 3.0]);
        assert_eq!(s.dim(), 3);
    }

    /// Interpolating halfway blends both time and components.
    #[test]
    fn test_ode_state_lerp() {
        let s0 = OdeState::new(0.0, vec![0.0, 0.0]);
        let s1 = OdeState::new(1.0, vec![2.0, 4.0]);
        let mid = s0.lerp(&s1, 0.5);
        assert!((mid.t - 0.5).abs() < 1e-12);
        assert!((mid.y[0] - 1.0).abs() < 1e-12);
        assert!((mid.y[1] - 2.0).abs() < 1e-12);
    }

    /// `alpha = 0` and `alpha = 1` reproduce the two endpoints.
    #[test]
    fn test_ode_state_lerp_endpoints() {
        let s0 = OdeState::new(0.0, vec![1.0]);
        let s1 = OdeState::new(2.0, vec![3.0]);
        let at0 = s0.lerp(&s1, 0.0);
        let at1 = s0.lerp(&s1, 1.0);
        assert!((at0.y[0] - 1.0).abs() < 1e-12);
        assert!((at1.y[0] - 3.0).abs() < 1e-12);
    }

    // ---- Vector helpers ----

    /// The RMS norm of an empty slice is defined as zero.
    #[test]
    fn test_rms_norm_empty() {
        assert_eq!(rms_norm(&[]), 0.0);
    }

    /// The RMS norm of all-ones is one.
    #[test]
    fn test_rms_norm_ones() {
        let v = vec![1.0, 1.0, 1.0, 1.0];
        assert!((rms_norm(&v) - 1.0).abs() < 1e-12);
    }

    /// `vec_axpy` computes `a * x + y` element-wise.
    #[test]
    fn test_vec_axpy() {
        let x = vec![1.0, 2.0];
        let y = vec![3.0, 4.0];
        let r = vec_axpy(2.0, &x, &y);
        assert!((r[0] - 5.0).abs() < 1e-12);
        assert!((r[1] - 8.0).abs() < 1e-12);
    }

    /// `vec_scale` multiplies each entry by the scalar.
    #[test]
    fn test_vec_scale() {
        let x = vec![1.0, 2.0, 3.0];
        let r = vec_scale(3.0, &x);
        assert!((r[2] - 9.0).abs() < 1e-12);
    }

    /// `vec_add` and `vec_sub` operate element-wise.
    #[test]
    fn test_vec_add_sub() {
        let a = vec![1.0, 2.0];
        let b = vec![3.0, 1.0];
        let s = vec_add(&a, &b);
        let d = vec_sub(&b, &a);
        assert!((s[0] - 4.0).abs() < 1e-12);
        assert!((d[1] + 1.0).abs() < 1e-12);
    }

    /// Right-hand side of `y' = -y`, whose exact solution from `y(0) = 1` is
    /// `exp(-t)`.
    fn f_decay(_t: f64, y: &[f64]) -> Vec<f64> {
        vec![-y[0]]
    }

    // ---- RK4 ----

    /// One RK4 step of size 0.1 on the decay problem matches `exp(-0.1)`.
    #[test]
    fn test_rk4_single_step_accuracy() {
        let rk4 = RK4Integrator::default_tolerances();
        let s0 = OdeState::new(0.0, vec![1.0]);
        let s1 = rk4.step(&s0, 0.1, &f_decay);
        let exact = (-0.1f64).exp();
        assert!((s1.y[0] - exact).abs() < 1e-7);
    }

    /// Fixed-step RK4 integration to `t = 1` matches `exp(-1)`.
    #[test]
    fn test_rk4_integrate_fixed() {
        let rk4 = RK4Integrator::default_tolerances();
        let s0 = OdeState::new(0.0, vec![1.0]);
        let traj = rk4.integrate(&s0, 1.0, 0.01, &f_decay);
        let last = traj.last().unwrap();
        let exact = (-1.0f64).exp();
        assert!((last.y[0] - exact).abs() < 1e-6);
    }

    /// Adaptive RK4 integration to `t = 2` matches `exp(-2)`.
    #[test]
    fn test_rk4_adaptive() {
        let rk4 = RK4Integrator::new(1e-8, 1e-8);
        let s0 = OdeState::new(0.0, vec![1.0]);
        let traj = rk4.integrate_adaptive(&s0, 2.0, 0.1, &f_decay);
        let last = traj.last().unwrap();
        let exact = (-2.0f64).exp();
        assert!((last.y[0] - exact).abs() < 1e-5);
    }

    /// Harmonic oscillator started at `(0, 1)`: the first component is
    /// `sin(t)`, which returns to zero at `t = pi`.
    #[test]
    fn test_rk4_harmonic_oscillator() {
        let f = |_t: f64, y: &[f64]| vec![y[1], -y[0]];
        let rk4 = RK4Integrator::default_tolerances();
        let s0 = OdeState::new(0.0, vec![0.0, 1.0]);
        let traj = rk4.integrate(&s0, std::f64::consts::PI, 0.01, &f);
        let last = traj.last().unwrap();
        assert!(last.y[0].abs() < 1e-5);
    }

    // ---- Dormand–Prince 5(4) ----

    /// Adaptive DP45 integration of the decay problem to `t = 1`.
    #[test]
    fn test_dp45_exponential_decay() {
        let dp = DormandPrince45::default_tolerances();
        let s0 = OdeState::new(0.0, vec![1.0]);
        let sol = dp.integrate(&s0, 1.0, 0.1, &f_decay);
        let last = sol.states.last().unwrap();
        let exact = (-1.0f64).exp();
        assert!((last.y[0] - exact).abs() < 1e-5);
    }

    /// Harmonic oscillator started at `(1, 0)`: the first component is
    /// `cos(t)`, which returns to one after a full period.
    #[test]
    fn test_dp45_harmonic_oscillator() {
        let f = |_t: f64, y: &[f64]| vec![y[1], -y[0]];
        let dp = DormandPrince45::new(1e-9, 1e-9, 1e-12, 1.0);
        let s0 = OdeState::new(0.0, vec![1.0, 0.0]);
        let sol = dp.integrate(&s0, 2.0 * std::f64::consts::PI, 0.1, &f);
        let last = sol.states.last().unwrap();
        assert!((last.y[0] - 1.0).abs() < 1e-6);
    }

    /// The solution contains more than just the initial state.
    #[test]
    fn test_dp45_solution_len() {
        let dp = DormandPrince45::default_tolerances();
        let s0 = OdeState::new(0.0, vec![1.0]);
        let sol = dp.integrate(&s0, 1.0, 0.1, &f_decay);
        assert!(sol.len() > 1);
    }

    /// A single DP45 step without a supplied first slope: the error norm is
    /// non-negative and the result matches `exp(-0.1)`.
    #[test]
    fn test_dp45_fsal_step() {
        let dp = DormandPrince45::default_tolerances();
        let s0 = OdeState::new(0.0, vec![1.0]);
        let (s1, err, _k7) = dp.step(&s0, 0.1, &f_decay, None);
        assert!(err >= 0.0);
        let exact = (-0.1f64).exp();
        assert!((s1.y[0] - exact).abs() < 1e-9);
    }

    // ---- Implicit Euler ----

    /// Stiff decay `y' = -100 y` with a step of 0.05 ends close to
    /// `exp(-100)`.
    #[test]
    fn test_implicit_euler_stiff_decay() {
        let f = |_t: f64, y: &[f64]| vec![-100.0 * y[0]];
        let ie = ImplicitEuler::default_params();
        let s0 = OdeState::new(0.0, vec![1.0]);
        let traj = ie.integrate(&s0, 1.0, 0.05, &f);
        let last = traj.last().unwrap();
        let exact = (-100.0f64).exp();
        assert!((last.y[0] - exact).abs() < 0.01);
    }

    /// For `y' = -y` one backward Euler step satisfies
    /// `y1 = y0 / (1 + h)`.
    #[test]
    fn test_implicit_euler_newton_step() {
        let f_lin = |_t: f64, y: &[f64]| vec![-y[0]];
        let ie = ImplicitEuler::default_params();
        let s0 = OdeState::new(0.0, vec![1.0]);
        let s1 = ie.step_newton(&s0, 0.1, &f_lin);
        let expected = 1.0 / 1.1;
        assert!((s1.y[0] - expected).abs() < 1e-8);
    }

    /// A zero right-hand side leaves the state unchanged.
    #[test]
    fn test_implicit_euler_zero_rhs() {
        let f_zero = |_t: f64, y: &[f64]| vec![0.0 * y[0]];
        let ie = ImplicitEuler::default_params();
        let s0 = OdeState::new(0.0, vec![5.0]);
        let s1 = ie.step(&s0, 1.0, &f_zero);
        assert!((s1.y[0] - 5.0).abs() < 1e-12);
    }

    // ---- Trapezoidal ----

    /// Trapezoidal integration of the decay problem to `t = 1`.
    #[test]
    fn test_trapezoidal_decay() {
        let trap = Trapezoidal::default_params();
        let s0 = OdeState::new(0.0, vec![1.0]);
        let traj = trap.integrate(&s0, 1.0, 0.01, &f_decay);
        let last = traj.last().unwrap();
        let exact = (-1.0f64).exp();
        assert!((last.y[0] - exact).abs() < 1e-5);
    }

    /// For `y' = -y` one trapezoidal step satisfies
    /// `y1 = y0 * (1 - h/2) / (1 + h/2)`.
    #[test]
    fn test_trapezoidal_single_step() {
        let trap = Trapezoidal::new(100, 1e-12);
        let s0 = OdeState::new(0.0, vec![1.0]);
        let s1 = trap.step(&s0, 0.1, &f_decay);
        let expected = (1.0 - 0.05) / (1.0 + 0.05);
        assert!((s1.y[0] - expected).abs() < 1e-10);
    }

    // ---- BDF2 ----

    /// BDF2 integration of the decay problem to `t = 1`.
    #[test]
    fn test_bdf2_decay() {
        let bdf2 = BDF2::default_params();
        let s0 = OdeState::new(0.0, vec![1.0]);
        let traj = bdf2.integrate(&s0, 1.0, 0.01, &f_decay);
        let last = traj.last().unwrap();
        let exact = (-1.0f64).exp();
        assert!((last.y[0] - exact).abs() < 1e-4);
    }

    /// Stiff decay `y' = -100 y` with a step of 0.05: the result stays
    /// bounded and the first step already decreases.
    #[test]
    fn test_bdf2_stiff_lambda_100() {
        let f = |_t: f64, y: &[f64]| vec![-100.0 * y[0]];
        let bdf2 = BDF2::default_params();
        let s0 = OdeState::new(0.0, vec![1.0]);
        let traj = bdf2.integrate(&s0, 0.5, 0.05, &f);
        let last = traj.last().unwrap();
        assert!(
            last.y[0].abs() < 0.5,
            "BDF2 stiff result out of bounds: {}",
            last.y[0]
        );
        assert!(traj[1].y[0] < 1.0);
    }

    /// Starting exactly at `t_end` returns only the initial state.
    #[test]
    fn test_bdf2_short_interval() {
        let bdf2 = BDF2::default_params();
        let s0 = OdeState::new(5.0, vec![1.0]);
        let traj = bdf2.integrate(&s0, 5.0, 0.1, &f_decay);
        assert_eq!(traj.len(), 1);
    }

    // ---- Event detection ----

    /// A state component changing sign between two states is located at the
    /// interpolated zero, `t = 1`.
    #[test]
    fn test_event_detection_crossing_zero() {
        let ed = EventDetection::default_params();
        let s_a = OdeState::new(0.9, vec![0.1]);
        let s_b = OdeState::new(1.1, vec![-0.1]);
        let events: Vec<fn(f64, &[f64]) -> f64> = vec![|_t, y| y[0]];
        let crossings = ed.detect(&s_a, &s_b, &events);
        assert_eq!(crossings.len(), 1);
        assert!((crossings[0].t - 1.0).abs() < 1e-8);
    }

    /// No sign change between the states means no crossing is reported.
    #[test]
    fn test_event_detection_no_crossing() {
        let ed = EventDetection::default_params();
        let s_a = OdeState::new(0.0, vec![1.0]);
        let s_b = OdeState::new(1.0, vec![2.0]);
        let events: Vec<fn(f64, &[f64]) -> f64> = vec![|_t, y| y[0]];
        let crossings = ed.detect(&s_a, &s_b, &events);
        assert!(crossings.is_empty());
    }

    /// An event that depends only on time (`t - 1`) is located at `t = 1`.
    #[test]
    fn test_event_detection_time_event() {
        let ed = EventDetection::default_params();
        let s_a = OdeState::new(0.8, vec![0.0]);
        let s_b = OdeState::new(1.2, vec![0.0]);
        let events: Vec<fn(f64, &[f64]) -> f64> = vec![|t, _y| t - 1.0];
        let crossings = ed.detect(&s_a, &s_b, &events);
        assert_eq!(crossings.len(), 1);
        assert!((crossings[0].t - 1.0).abs() < 1e-8);
    }

    /// Two event functions that both change sign each yield a crossing.
    #[test]
    fn test_event_detection_multiple_events() {
        let ed = EventDetection::default_params();
        let s_a = OdeState::new(0.0, vec![2.0, -1.0]);
        let s_b = OdeState::new(2.0, vec![-2.0, 1.0]);
        let ev0: fn(f64, &[f64]) -> f64 = |_t, y| y[0];
        let ev1: fn(f64, &[f64]) -> f64 = |_t, y| y[1];
        let crossings = ed.detect(&s_a, &s_b, &[ev0, ev1]);
        assert_eq!(crossings.len(), 2);
    }

    // ---- OdeSolution ----

    /// Interpolation between the first two states is linear.
    #[test]
    fn test_ode_solution_interpolate() {
        let states = vec![
            OdeState::new(0.0, vec![0.0]),
            OdeState::new(1.0, vec![1.0]),
            OdeState::new(2.0, vec![4.0]),
        ];
        let sol = OdeSolution::new(states);
        let mid = sol.interpolate(0.5).unwrap();
        assert!((mid.y[0] - 0.5).abs() < 1e-12);
    }

    /// Queries outside the stored time range return `None`.
    #[test]
    fn test_ode_solution_out_of_range() {
        let states = vec![OdeState::new(0.0, vec![1.0]), OdeState::new(1.0, vec![2.0])];
        let sol = OdeSolution::new(states);
        assert!(sol.interpolate(-0.5).is_none());
        assert!(sol.interpolate(1.5).is_none());
    }

    /// `times` and `component` have one entry per stored state.
    #[test]
    fn test_ode_solution_times_and_component() {
        let dp = DormandPrince45::default_tolerances();
        let s0 = OdeState::new(0.0, vec![1.0, 0.0]);
        let f = |_t: f64, y: &[f64]| vec![y[1], -y[0]];
        let sol = dp.integrate(&s0, 1.0, 0.1, &f);
        let ts = sol.times();
        let c0 = sol.component(0);
        assert_eq!(ts.len(), c0.len());
    }

    /// Resampling yields exactly the requested number of states.
    #[test]
    fn test_ode_solution_resample() {
        let dp = DormandPrince45::default_tolerances();
        let s0 = OdeState::new(0.0, vec![1.0]);
        let sol = dp.integrate(&s0, 1.0, 0.1, &f_decay);
        let resampled = sol.resample(20);
        assert_eq!(resampled.len(), 20);
    }

    /// An empty solution reports empty and cannot be interpolated.
    #[test]
    fn test_ode_solution_empty() {
        let sol = OdeSolution::new(vec![]);
        assert!(sol.is_empty());
        assert!(sol.interpolate(0.5).is_none());
    }

    /// `map_observable` applies the closure to every stored state.
    #[test]
    fn test_ode_solution_map_observable() {
        let states = vec![OdeState::new(0.0, vec![1.0]), OdeState::new(1.0, vec![2.0])];
        let sol = OdeSolution::new(states);
        let obs = sol.map_observable(|_t, y| y[0] * 2.0);
        assert!((obs[0] - 2.0).abs() < 1e-12);
        assert!((obs[1] - 4.0).abs() < 1e-12);
    }

    // ---- Cross-method checks ----

    /// Fixed-step RK4 and adaptive DP45 both reach `exp(-1)` to within 1e-6.
    #[test]
    fn test_rk4_vs_dp45_accuracy() {
        let f = |_t: f64, y: &[f64]| vec![-y[0]];
        let rk4 = RK4Integrator::default_tolerances();
        let dp = DormandPrince45::default_tolerances();
        let s0 = OdeState::new(0.0, vec![1.0]);

        let traj_rk4 = rk4.integrate(&s0, 1.0, 0.01, &f);
        let sol_dp = dp.integrate(&s0, 1.0, 0.1, &f);

        let exact = (-1.0f64).exp();
        let err_rk4 = (traj_rk4.last().unwrap().y[0] - exact).abs();
        let err_dp = (sol_dp.states.last().unwrap().y[0] - exact).abs();

        assert!(err_dp < 1e-6);
        assert!(err_rk4 < 1e-6);
    }

    /// Implicit Euler on `y' = -1000 y` with a step of 0.001 stays within
    /// `[0, 1]`.
    #[test]
    fn test_implicit_vs_explicit_stiff() {
        let f = |_t: f64, y: &[f64]| vec![-1000.0 * y[0]];
        let ie = ImplicitEuler::default_params();
        let s0 = OdeState::new(0.0, vec![1.0]);
        let traj = ie.integrate(&s0, 0.01, 0.001, &f);
        let last = traj.last().unwrap().y[0];
        assert!((0.0..=1.0).contains(&last));
    }
}