bacon_sci/ivp/
rk.rs

1/* This file is part of bacon.
2 * Copyright (c) Wyatt Campbell.
3 *
4 * See repository LICENSE for information.
5 */
6
7use super::{Derivative, IVPError, IVPIterator, IVPSolver, IVPStatus, IVPStepper, Step};
8use crate::{BMatrix, BSMatrix, BSVector, BVector, Dimension};
9use nalgebra::{
10    allocator::Allocator, ComplexField, Const, DefaultAllocator, Dim, DimName, RealField, U1,
11};
12use num_traits::{FromPrimitive, One, Zero};
13use std::marker::PhantomData;
14
/// This trait defines the coefficients of a Runge-Kutta method with `O` stages.
///
/// The [`RungeKutta`] struct takes an implementation of this trait
/// as a type argument since the algorithm is the same for
/// all the methods; only the number of stages and the values returned by
/// these functions need to be different.
///
/// Every function returns an `Option`; the solver treats a `None` from any of
/// them as a failure when it is built.
pub trait RungeKuttaCoefficients<const O: usize> {
    /// The real field associated with the solver's Field.
    type RealField: RealField;

    /// Returns the coefficients to multiply the time step by when getting
    /// intermediate results. Upper-left portion of the Butcher tableau.
    /// Entry `i` is used for the `i`-th intermediate result.
    fn t_coefficients() -> Option<BSVector<Self::RealField, O>>;

    /// Returns the coefficients to use on the k_i's when finding another
    /// k_i. Upper-right portion of the Butcher tableau, returned as an
    /// `O`x`O` matrix (where `O` is the number of k_i's the method computes).
    ///
    /// The solver iterates over the *rows* of this matrix: entry `(i, j)` is the
    /// weight applied to k_j when building the state at which k_i is evaluated.
    fn k_coefficients() -> Option<BSMatrix<Self::RealField, O, O>>;

    /// Coefficients to use when calculating the final step to take.
    /// These are the weights of the weighted average of k_i's. Bottom
    /// portion of the Butcher tableau. For adaptive methods, this is the first
    /// row of the bottom portion.
    fn avg_coefficients() -> Option<BSVector<Self::RealField, O>>;

    /// Coefficients to use on
    /// the k_i's to find the error between the two orders
    /// of Runge-Kutta methods. In the Butcher tableau, this is
    /// the first row of the bottom portion minus the second row.
    ///
    /// The solver only uses the norm of the weighted sum, so the overall sign
    /// of this vector does not affect the error estimate.
    fn error_coefficients() -> Option<BSVector<Self::RealField, O>>;
}
46
/// The nuts and bolts Runge-Kutta solver
/// Users won't use this directly if they aren't defining their own Runge-Kutta solver
/// Used as a common struct for the specific implementations
///
/// This struct only collects the parameters of a solve. Every `init_*` field
/// starts out as `None` and is filled in by the matching `with_*` method of its
/// [`IVPSolver`] implementation; `solve` fails if any of them is still `None`.
pub struct RungeKutta<'a, N, D, const O: usize, T, F, R>
where
    D: Dimension,
    N: ComplexField + Copy,
    T: Clone,
    F: Derivative<N, D, T> + 'a,
    R: RungeKuttaCoefficients<O, RealField = N::RealField>,
    DefaultAllocator: Allocator<N, D>,
    DefaultAllocator: Allocator<N, Const<O>>,
{
    /// Largest time step the solver may take.
    init_dt_max: Option<N::RealField>,
    /// Smallest time step the solver may take.
    init_dt_min: Option<N::RealField>,
    /// Time at which the solve starts.
    init_time: Option<N::RealField>,
    /// Time at which the solve ends.
    init_end: Option<N::RealField>,
    /// Error tolerance used to accept or reject a step.
    init_tolerance: Option<N::RealField>,
    /// State of the system at the initial time.
    init_state: Option<BVector<N, D>>,
    /// Function that evaluates the derivative of the state.
    init_derivative: Option<F>,
    /// Dimension of the state vector.
    dim: D,
    /// Marker for the lifetime and the type parameters that no field stores.
    _data: PhantomData<&'a (T, R)>,
}
70
/// The solver for any Runge-Kutta method
/// Users should not use this type directly, and should
/// instead get it from a specific [`RungeKutta`] struct
/// (wrapped in an [`IVPIterator`])
///
/// All values are stored in the solver's field `N`; real-valued quantities
/// (times, step sizes, tolerance) are read back through `.real()` when stepping.
pub struct RungeKuttaSolver<'a, N, D, const O: usize, T, F>
where
    D: Dimension,
    N: ComplexField + Copy,
    T: Clone,
    F: Derivative<N, D, T> + 'a,
    DefaultAllocator: Allocator<N, D>,
    DefaultAllocator: Allocator<N, Const<O>>,
    DefaultAllocator: Allocator<N, D, Const<O>>,
{
    // Parameters set by the user
    // Upper and lower bounds on the time step
    dt_max: N,
    dt_min: N,
    // Current time and the time at which stepping stops
    time: N,
    end: N,
    // A step is accepted when its error estimate does not exceed this value
    tolerance: N,
    derivative: F,
    // Cloned and handed to the derivative on every evaluation
    data: T,

    // The current state of the solver
    // Time step that the next call to `step` will try
    dt: N,
    state: BVector<N, D>,

    // Per-order constants set by [`RungeKuttaCoefficients`]
    t_coefficients: BSVector<N, O>,
    k_coefficients: BSMatrix<N, O, O>,
    avg_coefficients: BSVector<N, O>,
    error_coefficients: BSVector<N, O>,

    // Scratch space to store the partial steps needed for the algorithm
    // Column i holds the i-th partial step (derivative value times dt)
    half_steps: BMatrix<N, D, Const<O>>,
    step: BVector<N, D>,
    scratch_pad: BVector<N, D>,

    // Constants needed for algorithm
    // `one_tenth` and `four` bound the factor by which dt is rescaled after a step;
    // `point_eighty_four` multiplies, and `one_fourth` is the exponent of, the
    // tolerance/error ratio that produces that factor.
    one_tenth: N,
    one_fourth: N,
    point_eighty_four: N,
    four: N,

    _lifetime: PhantomData<&'a ()>,
}
117
118impl<'a, N, D, const O: usize, T, F, R> IVPSolver<'a, D> for RungeKutta<'a, N, D, O, T, F, R>
119where
120    D: Dimension,
121    N: ComplexField + Copy,
122    T: Clone,
123    F: Derivative<N, D, T> + 'a,
124    R: RungeKuttaCoefficients<O, RealField = N::RealField>,
125    DefaultAllocator: Allocator<N, D>,
126    DefaultAllocator: Allocator<N, Const<O>>,
127    DefaultAllocator: Allocator<N, D, Const<O>>,
128{
129    type Error = IVPError;
130    type Field = N;
131    type RealField = N::RealField;
132    type Derivative = F;
133    type UserData = T;
134    type Solver = RungeKuttaSolver<'a, N, D, O, T, F>;
135
136    fn new() -> Result<Self, IVPError> {
137        Ok(Self {
138            init_dt_max: None,
139            init_dt_min: None,
140            init_time: None,
141            init_end: None,
142            init_tolerance: None,
143            init_state: None,
144            init_derivative: None,
145            dim: D::dim()?,
146            _data: PhantomData,
147        })
148    }
149
150    fn new_dyn(size: usize) -> Result<Self, IVPError> {
151        Ok(Self {
152            init_dt_max: None,
153            init_dt_min: None,
154            init_time: None,
155            init_end: None,
156            init_tolerance: None,
157            init_state: None,
158            init_derivative: None,
159            dim: D::dim_dyn(size)?,
160            _data: PhantomData,
161        })
162    }
163
164    fn dim(&self) -> D {
165        self.dim
166    }
167
168    fn with_tolerance(mut self, tol: Self::RealField) -> Result<Self, Self::Error> {
169        if tol <= <Self::RealField as Zero>::zero() {
170            return Err(IVPError::ToleranceOOB);
171        }
172        self.init_tolerance = Some(tol);
173        Ok(self)
174    }
175
176    /// Will overwrite any previously set value
177    /// If the provided maximum is less than a previously set minimum, then the minimum
178    /// is set to this value as well.
179    fn with_maximum_dt(mut self, max: Self::RealField) -> Result<Self, Self::Error> {
180        if max <= <Self::RealField as Zero>::zero() {
181            return Err(IVPError::TimeDeltaOOB);
182        }
183
184        self.init_dt_max = Some(max.clone());
185        if let Some(dt_min) = self.init_dt_min.as_mut() {
186            if *dt_min > max {
187                *dt_min = max;
188            }
189        }
190
191        Ok(self)
192    }
193
194    /// Will overwrite any previously set value
195    /// If the provided minimum is greater than a previously set maximum, then the maximum
196    /// is set to this value as well.
197    fn with_minimum_dt(mut self, min: Self::RealField) -> Result<Self, Self::Error> {
198        if min <= <Self::RealField as Zero>::zero() {
199            return Err(IVPError::TimeDeltaOOB);
200        }
201
202        self.init_dt_min = Some(min.clone());
203        if let Some(dt_max) = self.init_dt_max.as_mut() {
204            if *dt_max < min {
205                *dt_max = min;
206            }
207        }
208
209        Ok(self)
210    }
211
212    fn with_initial_time(mut self, initial: Self::RealField) -> Result<Self, Self::Error> {
213        self.init_time = Some(initial.clone());
214
215        if let Some(end) = self.init_end.as_ref() {
216            if *end <= initial {
217                return Err(IVPError::TimeStartOOB);
218            }
219        }
220
221        Ok(self)
222    }
223
224    fn with_ending_time(mut self, ending: Self::RealField) -> Result<Self, Self::Error> {
225        self.init_end = Some(ending.clone());
226
227        if let Some(initial) = self.init_time.as_ref() {
228            if *initial >= ending {
229                return Err(IVPError::TimeEndOOB);
230            }
231        }
232
233        Ok(self)
234    }
235
236    fn with_initial_conditions(
237        mut self,
238        start: BVector<Self::Field, D>,
239    ) -> Result<Self, Self::Error> {
240        self.init_state = Some(start);
241        Ok(self)
242    }
243
244    fn with_derivative(mut self, derivative: Self::Derivative) -> Self {
245        self.init_derivative = Some(derivative);
246        self
247    }
248
249    fn solve(self, data: Self::UserData) -> Result<IVPIterator<D, Self::Solver>, Self::Error> {
250        let dt_max = self.init_dt_max.ok_or(IVPError::MissingParameters)?;
251        let dt_min = self.init_dt_min.ok_or(IVPError::MissingParameters)?;
252        let tolerance = self.init_tolerance.ok_or(IVPError::MissingParameters)?;
253        let time = self.init_time.ok_or(IVPError::MissingParameters)?;
254        let end = self.init_end.ok_or(IVPError::MissingParameters)?;
255        let state = self.init_state.ok_or(IVPError::MissingParameters)?;
256        let derivative = self.init_derivative.ok_or(IVPError::MissingParameters)?;
257
258        let two = Self::Field::from_u8(2).ok_or(IVPError::FromPrimitiveFailure)?;
259        let half = Self::Field::one() / two;
260
261        let one_tenth =
262            Self::Field::one() / Self::Field::from_u8(10).ok_or(IVPError::FromPrimitiveFailure)?;
263        let four = Self::Field::from_u8(4).ok_or(IVPError::FromPrimitiveFailure)?;
264        let one_fourth = Self::Field::one() / four;
265
266        let one_hundred = Self::Field::from_u8(100).ok_or(IVPError::FromPrimitiveFailure)?;
267        let eighty_four = Self::Field::from_u8(100).ok_or(IVPError::FromPrimitiveFailure)?;
268        let point_eighty_four = eighty_four / one_hundred;
269
270        let t_coefficients = BSVector::from_iterator(
271            R::t_coefficients()
272                .ok_or(IVPError::FromPrimitiveFailure)?
273                .as_slice()
274                .iter()
275                .cloned()
276                .map(Self::Field::from_real),
277        );
278
279        let k_coefficients = BSMatrix::<N, O, O>::from_iterator_generic(
280            <Const<O> as Dim>::from_usize(O),
281            <Const<O> as Dim>::from_usize(O),
282            R::k_coefficients()
283                .ok_or(IVPError::FromPrimitiveFailure)?
284                .as_slice()
285                .iter()
286                .cloned()
287                .map(Self::Field::from_real),
288        );
289
290        let avg_coefficients = BSVector::from_iterator(
291            R::avg_coefficients()
292                .ok_or(IVPError::FromPrimitiveFailure)?
293                .as_slice()
294                .iter()
295                .cloned()
296                .map(Self::Field::from_real),
297        );
298
299        let error_coefficients = BSVector::from_iterator(
300            R::error_coefficients()
301                .ok_or(IVPError::FromPrimitiveFailure)?
302                .as_slice()
303                .iter()
304                .cloned()
305                .map(Self::Field::from_real),
306        );
307
308        Ok(IVPIterator {
309            solver: RungeKuttaSolver {
310                dt_max: Self::Field::from_real(dt_max.clone()),
311                dt_min: Self::Field::from_real(dt_min.clone()),
312                time: Self::Field::from_real(time),
313                end: Self::Field::from_real(end),
314                tolerance: Self::Field::from_real(tolerance),
315                dt: Self::Field::from_real(dt_max + dt_min) * half,
316                state,
317                derivative,
318                data,
319                t_coefficients,
320                k_coefficients,
321                avg_coefficients,
322                error_coefficients,
323                half_steps: BMatrix::from_element_generic(
324                    self.dim,
325                    <Const<O> as DimName>::name(),
326                    Self::Field::zero(),
327                ),
328                scratch_pad: BVector::from_element_generic(
329                    self.dim,
330                    U1::name(),
331                    Self::Field::zero(),
332                ),
333                step: BVector::from_element_generic(self.dim, U1::name(), Self::Field::zero()),
334                one_tenth,
335                one_fourth,
336                point_eighty_four,
337                four,
338                _lifetime: PhantomData,
339            },
340            finished: false,
341            _dim: PhantomData,
342        })
343    }
344}
345
/// Adaptive stepping shared by every Runge-Kutta method.
impl<'a, N, D, const O: usize, T, F> IVPStepper<D> for RungeKuttaSolver<'a, N, D, O, T, F>
where
    D: Dimension,
    N: ComplexField + Copy,
    T: Clone,
    F: Derivative<N, D, T> + 'a,
    DefaultAllocator: Allocator<N, D>,
    DefaultAllocator: Allocator<N, Const<O>>,
    DefaultAllocator: Allocator<N, D, Const<O>>,
{
    type Error = IVPError;
    type Field = N;
    type RealField = N::RealField;
    type UserData = T;

    /// Attempts one step of size `dt` and then rescales `dt` for the next attempt.
    ///
    /// Returns `Ok((time, state))` when the step's error estimate is within the
    /// tolerance, in which case the time and state have been advanced.
    ///
    /// # Errors
    /// * `IVPStatus::Done` if the current time has reached the ending time.
    /// * `IVPStatus::Redo` if the error estimate exceeded the tolerance; time and
    ///   state are unchanged and `dt` has been rescaled.
    /// * `IVPStatus::Failure(IVPError::MinimumTimeDeltaExceeded)` if the rescaled `dt`
    ///   falls below the minimum step before the ending time is reached. This is
    ///   checked before the accept/reject result is returned.
    /// * Any failure produced by the derivative function, propagated through `?`.
    fn step(&mut self) -> Step<Self::RealField, Self::Field, D, Self::Error> {
        if self.time.real() >= self.end.real() {
            return Err(IVPStatus::Done);
        }

        // Shorten the step so that it lands exactly on the ending time.
        if self.time.real() + self.dt.real() >= self.end.real() {
            self.dt = self.end - self.time;
        }

        // Row i of k_coefficients gives the weights of the partial steps (columns of
        // half_steps) that are added to the state before the i-th derivative evaluation.
        for (i, k_row) in self.k_coefficients.row_iter().enumerate() {
            self.scratch_pad = self.state.clone();
            for (j, &k_coeff) in k_row.iter().enumerate() {
                self.scratch_pad += self.half_steps.column(j) * k_coeff;
            }

            let step_time = self.time + self.t_coefficients[i] * self.dt;
            // The derivative gets a fresh clone of the user data on every call, so
            // changes it makes to that data are not kept.
            self.step = (self.derivative)(
                step_time.real(),
                self.scratch_pad.as_slice(),
                &mut self.data.clone(),
            )? * self.dt;

            self.half_steps.set_column(i, &self.step);
        }

        // Error estimate: norm of the error-weighted sum of partial steps, per unit time.
        self.scratch_pad = self.half_steps.column(0) * self.error_coefficients[0];
        for (ind, &e_coeff) in self.error_coefficients.iter().enumerate().skip(1) {
            self.scratch_pad += self.half_steps.column(ind) * e_coeff;
        }
        let error = self.scratch_pad.norm() / self.dt.real();

        // Accept the step: advance time and add the weighted average of partial steps.
        if error <= self.tolerance.real() {
            self.time += self.dt;

            for (ind, &avg_coeff) in self.avg_coefficients.iter().enumerate() {
                self.state += self.half_steps.column(ind) * avg_coeff;
            }
        }

        // Scale factor for the next step:
        // point_eighty_four * (tolerance / error)^one_fourth, clamped to [one_tenth, four].
        let delta = self.point_eighty_four.real()
            * (self.tolerance.real() / error.clone()).powf(self.one_fourth.real());
        if delta <= self.one_tenth.real() {
            self.dt *= self.one_tenth;
        } else if delta >= self.four.real() {
            self.dt *= self.four;
        } else {
            self.dt *= Self::Field::from_real(delta);
        }

        if self.dt.real() > self.dt_max.real() {
            self.dt = self.dt_max;
        }

        if self.dt.real() < self.dt_min.real() && self.time.real() < self.end.real() {
            return Err(IVPStatus::Failure(IVPError::MinimumTimeDeltaExceeded));
        }

        if error <= self.tolerance.real() {
            Ok((self.time.real(), self.state.clone()))
        } else {
            Err(IVPStatus::Redo)
        }
    }

    /// Returns the current time of the solver.
    fn time(&self) -> Self::RealField {
        self.time.real()
    }
}
429
430pub struct RKCoefficients45<N: ComplexField>(PhantomData<N>);
431
432impl<N: ComplexField> RungeKuttaCoefficients<6> for RKCoefficients45<N> {
433    type RealField = N::RealField;
434
435    fn t_coefficients() -> Option<BSVector<Self::RealField, 6>> {
436        let one_fourth = Self::RealField::from_u8(4)?.recip();
437        let one_half = Self::RealField::from_u8(2)?.recip();
438        let three = Self::RealField::from_u8(3)?;
439        let eight = Self::RealField::from_u8(8)?;
440        let twelve = Self::RealField::from_u8(12)?;
441        let thirteen = Self::RealField::from_u8(13)?;
442
443        Some(BSVector::from_column_slice(&[
444            Self::RealField::zero(),
445            one_fourth,
446            three / eight,
447            twelve / thirteen,
448            Self::RealField::one(),
449            one_half,
450        ]))
451    }
452
453    fn k_coefficients() -> Option<BSMatrix<Self::RealField, 6, 6>> {
454        let zero = Self::RealField::zero();
455        let one_fourth = Self::RealField::from_u8(4)?.recip();
456        let thirty_two = Self::RealField::from_u8(32)?;
457        let two_one_nine_seven = Self::RealField::from_u16(2197)?;
458
459        Some(BSMatrix::from_vec(vec![
460            // Row 0
461            zero.clone(),
462            zero.clone(),
463            zero.clone(),
464            zero.clone(),
465            zero.clone(),
466            zero.clone(),
467            // Row 1
468            one_fourth,
469            zero.clone(),
470            zero.clone(),
471            zero.clone(),
472            zero.clone(),
473            zero.clone(),
474            // Row 2
475            Self::RealField::from_u8(3)? / thirty_two.clone(),
476            Self::RealField::from_u8(9)? / thirty_two.clone(),
477            zero.clone(),
478            zero.clone(),
479            zero.clone(),
480            zero.clone(),
481            // Row 3
482            Self::RealField::from_u16(1932)? / two_one_nine_seven.clone(),
483            -Self::RealField::from_u16(7200)? / two_one_nine_seven.clone(),
484            Self::RealField::from_u16(7296)? / two_one_nine_seven,
485            zero.clone(),
486            zero.clone(),
487            zero.clone(),
488            // Row 4
489            Self::RealField::from_u16(439)? / Self::RealField::from_u8(216)?,
490            -Self::RealField::from_u8(8)?,
491            Self::RealField::from_u16(3680)? / Self::RealField::from_u16(513)?,
492            -Self::RealField::from_u16(845)? / Self::RealField::from_u16(4104)?,
493            zero.clone(),
494            zero.clone(),
495            // Row 5
496            -Self::RealField::from_u8(8)? / Self::RealField::from_u8(27)?,
497            Self::RealField::from_u8(2)?,
498            -Self::RealField::from_u16(3544)? / Self::RealField::from_u16(2565)?,
499            Self::RealField::from_u16(1859)? / Self::RealField::from_u16(4014)?,
500            -Self::RealField::from_u8(11)? / Self::RealField::from_u8(40)?,
501            zero,
502        ]))
503    }
504
505    fn avg_coefficients() -> Option<BSVector<Self::RealField, 6>> {
506        Some(BSVector::from_column_slice(&[
507            Self::RealField::from_u8(25)? / Self::RealField::from_u8(216)?,
508            Self::RealField::zero(),
509            Self::RealField::from_u16(1408)? / Self::RealField::from_u16(2565)?,
510            Self::RealField::from_u16(2197)? / Self::RealField::from_u16(4104)?,
511            -Self::RealField::from_u8(5)?.recip(),
512            Self::RealField::zero(),
513        ]))
514    }
515
516    fn error_coefficients() -> Option<BSVector<Self::RealField, 6>> {
517        Some(BSVector::from_column_slice(&[
518            Self::RealField::from_u16(360)?.recip(),
519            Self::RealField::from_f64(0.0).unwrap(),
520            Self::RealField::from_f64(-128.0 / 4275.0).unwrap(),
521            Self::RealField::from_f64(-2197.0 / 75240.0).unwrap(),
522            Self::RealField::from_f64(1.0 / 50.0).unwrap(),
523            Self::RealField::from_f64(2.0 / 55.0).unwrap(),
524        ]))
525    }
526}
527
/// Runge-Kutta-Fehlberg method for solving an IVP.
///
/// Defines the Butcher tableau for a 5(4) order adaptive
/// Runge-Kutta method. Uses [`RungeKutta`] to do the actual solving.
/// Provides an implementation of the [`IVPSolver`] trait.
///
/// This alias fixes the number of stages to 6 and takes its coefficients
/// from [`RKCoefficients45`].
///
/// # Examples
/// ```
/// use std::error::Error;
/// use bacon_sci::{BSVector, ivp::{IVPSolver, IVPError, rk::RungeKutta45}};
///
/// fn derivatives(_t: f64, state: &[f64], _p: &mut ()) -> Result<BSVector<f64, 1>, Box<dyn Error>> {
///     Ok(BSVector::from_column_slice(state))
/// }
///
/// fn example() -> Result<(), IVPError> {
///     let rk45 = RungeKutta45::new()?
///         .with_maximum_dt(0.1)?
///         .with_minimum_dt(0.001)?
///         .with_initial_time(0.0)?
///         .with_ending_time(10.0)?
///         .with_tolerance(0.0001)?
///         .with_initial_conditions_slice(&[1.0])?
///         .with_derivative(derivatives)
///         .solve(())?;
///
///     let path = rk45.collect_vec()?;
///     for (time, state) in &path {
///         assert!((time.exp() - state.column(0)[0]).abs() < 0.001);
///     }
///     Ok(())
/// }
/// ```
pub type RungeKutta45<'a, N, D, T, F> = RungeKutta<'a, N, D, 6, T, F, RKCoefficients45<N>>;
562
563pub struct RK23Coefficients<N: ComplexField>(PhantomData<N>);
564
565impl<N: ComplexField> RungeKuttaCoefficients<4> for RK23Coefficients<N> {
566    type RealField = N::RealField;
567
568    fn t_coefficients() -> Option<BSVector<Self::RealField, 4>> {
569        Some(BSVector::from_column_slice(&[
570            Self::RealField::zero(),
571            Self::RealField::from_u8(2)?.recip(),
572            Self::RealField::from_u8(3)? / Self::RealField::from_u8(4)?,
573            Self::RealField::one(),
574        ]))
575    }
576
577    fn k_coefficients() -> Option<BSMatrix<Self::RealField, 4, 4>> {
578        let zero = Self::RealField::zero();
579
580        Some(BSMatrix::from_vec(vec![
581            // Row 0
582            zero.clone(),
583            zero.clone(),
584            zero.clone(),
585            zero.clone(),
586            // Row 1
587            Self::RealField::from_u8(2)?.recip(),
588            zero.clone(),
589            zero.clone(),
590            zero.clone(),
591            // Row 2
592            zero.clone(),
593            Self::RealField::from_u8(3)? / Self::RealField::from_u8(4)?,
594            zero.clone(),
595            zero.clone(),
596            // Row 3
597            Self::RealField::from_u8(2)? / Self::RealField::from_u8(9)?,
598            Self::RealField::from_u8(3)?.recip(),
599            Self::RealField::from_u8(4)? / Self::RealField::from_u8(9)?,
600            zero,
601        ]))
602    }
603
604    fn avg_coefficients() -> Option<BSVector<Self::RealField, 4>> {
605        Some(BSVector::from_column_slice(&[
606            Self::RealField::from_u8(2)? / Self::RealField::from_u8(9)?,
607            Self::RealField::from_u8(3)?.recip(),
608            Self::RealField::from_u8(4)? / Self::RealField::from_u8(9)?,
609            Self::RealField::zero(),
610        ]))
611    }
612
613    fn error_coefficients() -> Option<BSVector<Self::RealField, 4>> {
614        Some(BSVector::from_column_slice(&[
615            -Self::RealField::from_u8(5)? / Self::RealField::from_u8(72)?,
616            Self::RealField::from_u8(12)?.recip(),
617            Self::RealField::from_u8(9)?.recip(),
618            -Self::RealField::from_u8(8)?.recip(),
619        ]))
620    }
621}
622
/// Bogacki-Shampine method for solving an IVP.
///
/// Defines the Butcher tableau for a 3(2) order adaptive
/// Runge-Kutta method. Uses [`RungeKutta`] to do the actual solving.
/// Provides an implementation of the [`IVPSolver`] trait.
///
/// This alias fixes the number of stages to 4 and takes its coefficients
/// from [`RK23Coefficients`].
///
/// # Examples
/// ```
/// use std::error::Error;
/// use bacon_sci::{BSVector, ivp::{IVPSolver, IVPError, rk::RungeKutta23}};
///
/// fn derivatives(_t: f64, state: &[f64], _p: &mut ()) -> Result<BSVector<f64, 1>, Box<dyn Error>> {
///     Ok(BSVector::from_column_slice(state))
/// }
///
/// fn example() -> Result<(), IVPError> {
///     let rk23 = RungeKutta23::new()?
///         .with_maximum_dt(0.1)?
///         .with_minimum_dt(0.001)?
///         .with_initial_time(0.0)?
///         .with_ending_time(10.0)?
///         .with_tolerance(0.0001)?
///         .with_initial_conditions_slice(&[1.0])?
///         .with_derivative(derivatives)
///         .solve(())?;
///
///     let path = rk23.collect_vec()?;
///     for (time, state) in &path {
///         assert!((time.exp() - state.column(0)[0]).abs() < 0.001);
///     }
///     Ok(())
/// }
/// ```
pub type RungeKutta23<'a, N, D, T, F> = RungeKutta<'a, N, D, 4, T, F, RK23Coefficients<N>>;
657
#[cfg(test)]
mod test {
    use super::*;
    use crate::{ivp::UserError, BSVector};
    use rstest::rstest;

    /// Derivative of `1 - t^2`; ignores the state and the user data.
    fn quadratic_deriv(t: f64, _y: &[f64], _: &mut ()) -> Result<BSVector<f64, 1>, UserError> {
        let slope = -2.0 * t;
        Ok(BSVector::from_column_slice(&[slope]))
    }

    /// Derivative of `sin(t)`; ignores the state and the user data.
    fn sine_deriv(t: f64, _y: &[f64], _: &mut ()) -> Result<BSVector<f64, 1>, UserError> {
        let slope = t.cos();
        Ok(BSVector::from_column_slice(&[slope]))
    }

    /// One-dimensional `f64` solver without user data, generic over the method.
    type TestRK<'a, const O: usize, R> = RungeKutta<
        'a,
        f64,
        Const<1>,
        O,
        (),
        fn(f64, &[f64], &mut ()) -> Result<BVector<f64, Const<1>>, UserError>,
        R,
    >;

    /// Integrates y' = -2t from y(0) = 1 and compares every step with 1 - t^2.
    #[rstest]
    #[case::rk23(RungeKutta23::new().unwrap())]
    #[case::rk45(RungeKutta45::new().unwrap())]
    fn rungekutta_quadratic<'a, const O: usize, R>(#[case] rk: TestRK<'a, O, R>)
    where
        R: RungeKuttaCoefficients<O, RealField = f64>,
    {
        let (t_initial, t_final) = (0.0, 10.0);

        let configured = rk
            .with_minimum_dt(0.0001)
            .unwrap()
            .with_maximum_dt(0.1)
            .unwrap()
            .with_initial_time(t_initial)
            .unwrap()
            .with_ending_time(t_final)
            .unwrap()
            .with_tolerance(1e-5)
            .unwrap()
            .with_initial_conditions_slice(&[1.0])
            .unwrap()
            .with_derivative(quadratic_deriv);

        let path = configured.solve(()).unwrap().collect_vec().unwrap();

        for (time, state) in &path {
            let expected = 1.0 - time.powi(2);
            assert!(approx_eq!(
                f64,
                state.column(0)[0],
                expected,
                epsilon = 0.0001
            ));
        }
    }

    /// Integrates y' = cos(t) from y(0) = 0 and compares every step with sin(t).
    #[rstest]
    #[case::rk23(RungeKutta23::new().unwrap())]
    #[case::rk45(RungeKutta45::new().unwrap())]
    fn rungekutta_sine<'a, const O: usize, R>(#[case] rk: TestRK<'a, O, R>)
    where
        R: RungeKuttaCoefficients<O, RealField = f64>,
    {
        let (t_initial, t_final) = (0.0, 10.0);

        let configured = rk
            .with_minimum_dt(0.001)
            .unwrap()
            .with_maximum_dt(0.01)
            .unwrap()
            .with_tolerance(0.0001)
            .unwrap()
            .with_initial_time(t_initial)
            .unwrap()
            .with_ending_time(t_final)
            .unwrap()
            .with_initial_conditions_slice(&[0.0])
            .unwrap()
            .with_derivative(sine_deriv);

        let path = configured.solve(()).unwrap().collect_vec().unwrap();

        for (time, state) in &path {
            let expected = time.sin();
            assert!(approx_eq!(
                f64,
                state.column(0)[0],
                expected,
                epsilon = 0.01
            ));
        }
    }
}