bacon_sci/ivp/
adams.rs

1/* This file is part of bacon.
2 * Copyright (c) Wyatt Campbell.
3 *
4 * See repository LICENSE for information.
5 */
6
7use super::{Derivative, IVPError, IVPIterator, IVPSolver, IVPStatus, IVPStepper, Step};
8use crate::{BSVector, BVector, Dimension};
9use nalgebra::{
10    allocator::Allocator, ComplexField, Const, DefaultAllocator, DimName, RealField, U1,
11};
12use num_traits::{FromPrimitive, One, Zero};
13use std::collections::VecDeque;
14use std::marker::PhantomData;
15
/// This trait defines an Adams predictor-corrector solver.
/// The [`Adams`] struct takes an implementation of this trait
/// as a type argument since the algorithm is the same for
/// all the predictor-correctors; just the order and these functions
/// need to be different.
///
/// Every function returns an `Option`. When the solver is built, a `None`
/// from any of them is reported as `IVPError::FromPrimitiveFailure`.
pub trait AdamsCoefficients<const O: usize> {
    /// The real field associated with the solver's Field.
    type RealField: RealField;

    /// The polynomial interpolation coefficients for the predictor. Should start
    /// with the coefficient for n - 1 (the most recent stored derivative).
    ///
    /// The stepper in this file only reads entries `0..=O - 2`; the
    /// implementations in this file pad the final entry with zero.
    fn predictor_coefficients() -> Option<BSVector<Self::RealField, O>>;

    /// The polynomial interpolation coefficients for the corrector. Must be
    /// the same length as predictor_coefficients. Should start with the
    /// implicit coefficient (the one applied to the derivative evaluated at
    /// the predicted state), followed by the coefficients for n - 1, n - 2, ...
    fn corrector_coefficients() -> Option<BSVector<Self::RealField, O>>;

    /// Coefficient for multiplying error by.
    ///
    /// The stepper computes its error estimate as
    /// `error_coefficient / dt * norm(corrector - predictor)`.
    fn error_coefficient() -> Option<Self::RealField>;
}
37
/// The nuts and bolts Adams solver.
/// Users won't use this directly if they aren't defining their own Adams predictor-corrector.
/// Used as a common struct for the specific implementations.
///
/// This is the builder half of the solver: each `init_*` field starts as `None`
/// and is filled in by the matching `with_*` method of the `IVPSolver`
/// implementation. `solve` fails with `IVPError::MissingParameters` if any of
/// them is still `None`.
pub struct Adams<'a, N, D, const O: usize, T, F, A>
where
    D: Dimension,
    N: ComplexField + Copy,
    T: Clone,
    F: Derivative<N, D, T> + 'a,
    A: AdamsCoefficients<O, RealField = N::RealField>,
    DefaultAllocator: Allocator<N, D>,
{
    // Largest timestep the solver may use
    init_dt_max: Option<N::RealField>,
    // Smallest timestep the solver may use before it reports a failure
    init_dt_min: Option<N::RealField>,
    // Time at which the initial conditions hold
    init_time: Option<N::RealField>,
    // Time at which the solve stops
    init_end: Option<N::RealField>,
    // Bound on the per-step error estimate
    init_tolerance: Option<N::RealField>,
    // State vector at init_time
    init_state: Option<BVector<N, D>>,
    // The derivative function of the IVP
    init_derivative: Option<F>,
    // Dimension of the state vector
    dim: D,
    // Ties the otherwise unused parameters 'a, T and A to the type
    _data: PhantomData<&'a (T, A)>,
}
60
/// The solver for any Adams predictor-corrector.
/// Users should not use this type directly, and should
/// instead get it from a specific Adams method struct
/// (wrapped in an IVPIterator).
pub struct AdamsSolver<'a, N, D, const O: usize, T, F>
where
    D: Dimension,
    N: ComplexField + Copy,
    T: Clone,
    F: Derivative<N, D, T> + 'a,
    DefaultAllocator: Allocator<N, D>,
{
    // Parameters set by the user (real values are stored lifted into the field N)
    dt_max: N,
    dt_min: N,
    time: N,
    end: N,
    tolerance: N,
    derivative: F,
    data: T,

    // Current solution at t = self.time
    dt: N,
    state: BVector<N, D>,

    // Per-order constants set by an AdamsCoefficients
    predictor_coefficients: BSVector<N, O>,
    corrector_coefficients: BSVector<N, O>,
    error_coefficient: N,

    // Previous steps to interpolate with, oldest entry at the front
    prev_values: VecDeque<(N::RealField, BVector<N, D>)>,
    prev_derivatives: VecDeque<BVector<N, D>>,

    // A scratch vector to use during the algorithm (to avoid allocating & de-allocating every step)
    scratch_pad: BVector<N, D>,
    // Another scratch vector, used to store values for the implicit step
    implicit_derivs: BVector<N, D>,
    // A place to store solver state while taking speculative steps trying to find a good timestep
    save_state: BVector<N, D>,

    // Constants for the particular field
    one_tenth: N,
    one_sixth: N,
    half: N,
    two: N,
    four: N,

    // generic parameter O in the type N
    order: N,

    // The number of items in prev_values that need to be yielded to the iterator
    // due to a previous runge-kutta step.
    // As used by step():
    //   0          => nothing buffered
    //   1..=O - 1  => that many runge-kutta steps are still to be yielded
    //   O          => runge-kutta steps were taken but are not yet validated
    //   O + 1      => sentinel: the runge-kutta steps are out, yield the adams step next
    yield_memory: usize,

    // Ties the lifetime parameter 'a to the type
    _lifetime: PhantomData<&'a ()>,
}
118
119impl<'a, N, D, const O: usize, T, F, A> IVPSolver<'a, D> for Adams<'a, N, D, O, T, F, A>
120where
121    D: Dimension,
122    N: ComplexField + Copy,
123    T: Clone,
124    F: Derivative<N, D, T> + 'a,
125    A: AdamsCoefficients<O, RealField = N::RealField>,
126    DefaultAllocator: Allocator<N, D>,
127    DefaultAllocator: Allocator<N, Const<O>>,
128{
129    type Error = IVPError;
130    type Field = N;
131    type RealField = N::RealField;
132    type Derivative = F;
133    type UserData = T;
134    type Solver = AdamsSolver<'a, N, D, O, T, F>;
135
136    fn new() -> Result<Self, IVPError> {
137        Ok(Self {
138            init_dt_max: None,
139            init_dt_min: None,
140            init_time: None,
141            init_end: None,
142            init_tolerance: None,
143            init_state: None,
144            init_derivative: None,
145            dim: D::dim()?,
146            _data: PhantomData,
147        })
148    }
149
150    fn new_dyn(size: usize) -> Result<Self, Self::Error> {
151        Ok(Self {
152            init_dt_max: None,
153            init_dt_min: None,
154            init_time: None,
155            init_end: None,
156            init_tolerance: None,
157            init_state: None,
158            init_derivative: None,
159            dim: D::dim_dyn(size)?,
160            _data: PhantomData,
161        })
162    }
163
164    fn dim(&self) -> D {
165        self.dim
166    }
167
168    fn with_tolerance(mut self, tol: Self::RealField) -> Result<Self, Self::Error> {
169        if tol <= <Self::RealField as Zero>::zero() {
170            return Err(IVPError::ToleranceOOB);
171        }
172        self.init_tolerance = Some(tol);
173        Ok(self)
174    }
175
176    /// Will overwrite any previously set value
177    /// If the provided maximum is less than a previously set minimum, then the minimum
178    /// is set to this value as well.
179    fn with_maximum_dt(mut self, max: Self::RealField) -> Result<Self, Self::Error> {
180        if max <= <Self::RealField as Zero>::zero() {
181            return Err(IVPError::TimeDeltaOOB);
182        }
183
184        self.init_dt_max = Some(max.clone());
185        if let Some(dt_min) = self.init_dt_min.as_mut() {
186            if *dt_min > max {
187                *dt_min = max;
188            }
189        }
190
191        Ok(self)
192    }
193
194    /// Will overwrite any previously set value
195    /// If the provided minimum is greatear than a previously set maximum, then the maximum
196    /// is set to this value as well.
197    fn with_minimum_dt(mut self, min: Self::RealField) -> Result<Self, Self::Error> {
198        if min <= <Self::RealField as Zero>::zero() {
199            return Err(IVPError::TimeDeltaOOB);
200        }
201
202        self.init_dt_min = Some(min.clone());
203        if let Some(dt_max) = self.init_dt_max.as_mut() {
204            if *dt_max < min {
205                *dt_max = min;
206            }
207        }
208
209        Ok(self)
210    }
211
212    fn with_initial_time(mut self, initial: Self::RealField) -> Result<Self, Self::Error> {
213        self.init_time = Some(initial.clone());
214
215        if let Some(end) = self.init_end.as_ref() {
216            if *end <= initial {
217                return Err(IVPError::TimeStartOOB);
218            }
219        }
220
221        Ok(self)
222    }
223
224    fn with_ending_time(mut self, ending: Self::RealField) -> Result<Self, Self::Error> {
225        self.init_end = Some(ending.clone());
226
227        if let Some(initial) = self.init_time.as_ref() {
228            if *initial >= ending {
229                return Err(IVPError::TimeEndOOB);
230            }
231        }
232
233        Ok(self)
234    }
235
236    fn with_initial_conditions(
237        mut self,
238        start: BVector<Self::Field, D>,
239    ) -> Result<Self, Self::Error> {
240        self.init_state = Some(start);
241        Ok(self)
242    }
243
244    fn with_derivative(mut self, derivative: Self::Derivative) -> Self {
245        self.init_derivative = Some(derivative);
246        self
247    }
248
249    fn solve(self, data: Self::UserData) -> Result<IVPIterator<D, Self::Solver>, Self::Error> {
250        let dt_max = self.init_dt_max.ok_or(IVPError::MissingParameters)?;
251        let dt_min = self.init_dt_min.ok_or(IVPError::MissingParameters)?;
252        let tolerance = self.init_tolerance.ok_or(IVPError::MissingParameters)?;
253        let time = self.init_time.ok_or(IVPError::MissingParameters)?;
254        let end = self.init_end.ok_or(IVPError::MissingParameters)?;
255        let state = self.init_state.ok_or(IVPError::MissingParameters)?;
256        let derivative = self.init_derivative.ok_or(IVPError::MissingParameters)?;
257
258        let two = Self::Field::from_u8(2).ok_or(IVPError::FromPrimitiveFailure)?;
259        let half = two.recip();
260        let one_sixth = Self::Field::from_u8(6)
261            .ok_or(IVPError::FromPrimitiveFailure)?
262            .recip();
263        let one_tenth = Self::Field::from_u8(10)
264            .ok_or(IVPError::FromPrimitiveFailure)?
265            .recip();
266        let four = two * two;
267
268        let predictor_coefficients = BSVector::from_iterator(
269            A::predictor_coefficients()
270                .ok_or(IVPError::FromPrimitiveFailure)?
271                .as_slice()
272                .iter()
273                .cloned()
274                .map(Self::Field::from_real),
275        );
276
277        let corrector_coefficients = BSVector::from_iterator(
278            A::corrector_coefficients()
279                .ok_or(IVPError::FromPrimitiveFailure)?
280                .as_slice()
281                .iter()
282                .cloned()
283                .map(Self::Field::from_real),
284        );
285
286        let order = Self::Field::from_usize(O).ok_or(IVPError::FromPrimitiveFailure)?;
287
288        Ok(IVPIterator {
289            solver: AdamsSolver {
290                dt_max: Self::Field::from_real(dt_max.clone()),
291                dt_min: Self::Field::from_real(dt_min.clone()),
292                time: Self::Field::from_real(time),
293                end: Self::Field::from_real(end),
294                tolerance: Self::Field::from_real(tolerance),
295                dt: Self::Field::from_real(dt_max + dt_min) * half,
296                state,
297                derivative,
298                data,
299                predictor_coefficients,
300                corrector_coefficients,
301                error_coefficient: Self::Field::from_real(
302                    A::error_coefficient().ok_or(IVPError::FromPrimitiveFailure)?,
303                ),
304                prev_values: VecDeque::new(),
305                prev_derivatives: VecDeque::new(),
306                scratch_pad: BVector::from_element_generic(
307                    self.dim,
308                    U1::name(),
309                    Self::Field::zero(),
310                ),
311                implicit_derivs: BVector::from_element_generic(
312                    self.dim,
313                    U1::name(),
314                    Self::Field::zero(),
315                ),
316                save_state: BVector::from_element_generic(
317                    self.dim,
318                    U1::name(),
319                    Self::Field::zero(),
320                ),
321                one_tenth,
322                one_sixth,
323                half,
324                two,
325                four,
326                order,
327                yield_memory: 0,
328                _lifetime: PhantomData,
329            },
330            finished: false,
331            _dim: PhantomData,
332        })
333    }
334}
335
336impl<'a, N, D, const O: usize, T, F> AdamsSolver<'a, N, D, O, T, F>
337where
338    D: Dimension,
339    N: ComplexField + Copy,
340    T: Clone,
341    F: Derivative<N, D, T> + 'a,
342    DefaultAllocator: Allocator<N, D>,
343{
344    fn runge_kutta(&mut self, iterations: usize) -> Result<(), IVPError> {
345        for i in 0..iterations {
346            let k1 = (self.derivative)(
347                self.time.real(),
348                self.state.as_slice(),
349                &mut self.data.clone(),
350            )? * self.dt;
351            let intermediate = &self.state + &k1 * self.half;
352
353            let k2 = (self.derivative)(
354                (self.time + self.half * self.dt).real(),
355                intermediate.as_slice(),
356                &mut self.data.clone(),
357            )? * self.dt;
358            let intermediate = &self.state + &k2 * self.half;
359
360            let k3 = (self.derivative)(
361                (self.time + self.half * self.dt).real(),
362                intermediate.as_slice(),
363                &mut self.data.clone(),
364            )? * self.dt;
365            let intermediate = &self.state + &k3;
366
367            let k4 = (self.derivative)(
368                (self.time + self.dt).real(),
369                intermediate.as_slice(),
370                &mut self.data.clone(),
371            )? * self.dt;
372
373            if i != 0 {
374                self.prev_derivatives.push_back((self.derivative)(
375                    self.time.real(),
376                    self.state.as_slice(),
377                    &mut self.data,
378                )?);
379                self.prev_values
380                    .push_back((self.time.real(), self.state.clone()));
381            }
382
383            self.state += (k1 + k2 * self.two + k3 * self.two + k4) * self.one_sixth;
384            self.time += self.dt;
385        }
386        self.prev_derivatives.push_back((self.derivative)(
387            self.time.real(),
388            self.state.as_slice(),
389            &mut self.data,
390        )?);
391        self.prev_values
392            .push_back((self.time.real(), self.state.clone()));
393
394        Ok(())
395    }
396}
397
398impl<'a, N, D, const O: usize, T, F> IVPStepper<D> for AdamsSolver<'a, N, D, O, T, F>
399where
400    D: Dimension,
401    N: ComplexField + Copy,
402    T: Clone,
403    F: Derivative<N, D, T> + 'a,
404    DefaultAllocator: Allocator<N, D>,
405{
406    type Error = IVPError;
407    type Field = N;
408    type RealField = N::RealField;
409    type UserData = T;
410
411    fn step(&mut self) -> Step<Self::RealField, Self::Field, D, Self::Error> {
412        // If yield_memory is in [1, Order) then we have taken a runge-kutta step
413        // and committed to it (i.e. determined that we are within error bounds)
414        // If yield_memory is Order then we have taken a runge-kutta step but haven't
415        // checked if it is correct, so we don't want to yield the steps to the Iterator yet
416        if self.yield_memory > 0 && self.yield_memory < O {
417            let get_item = O - self.yield_memory - 1;
418            self.yield_memory -= 1;
419
420            // If this is the last runge-kutta step to be yielded,
421            // set yield_memory to the sentinel value O+1 so that the next step() call
422            // will yield the value in self.state (the adams step that was within
423            // tolerance after these runge-kutta steps)
424            if self.yield_memory == 0 {
425                self.yield_memory = O + 1;
426            }
427            return Ok(self.prev_values[get_item].clone());
428        }
429
430        // Sentinel value to signify that the runge-kutta steps are yielded
431        // and the solver can yield the adams step and continue as normal.
432        // The current state needs to be returned and pushed onto the memory deque.
433        // The derivatives memory deque already has the derivatives for this step,
434        // since the derivatives deque is unused while yielding runge-kutta steps
435        if self.yield_memory == O + 1 {
436            self.yield_memory = 0;
437            self.prev_values
438                .push_back((self.time.real(), self.state.clone()));
439            self.prev_values.pop_front();
440            return Ok((self.time.real(), self.state.clone()));
441        }
442
443        if self.time.real() >= self.end.real() {
444            return Err(IVPStatus::Done);
445        }
446
447        if self.time.real() + self.dt.real() >= self.end.real() {
448            self.dt = self.end - self.time;
449            self.runge_kutta(1)?;
450            return Ok((self.time.real(), self.prev_values.back().unwrap().1.clone()));
451        }
452
453        if self.prev_values.is_empty() {
454            self.save_state = self.state.clone();
455            if self.time.real() + self.dt.real() * (self.order - Self::Field::one()).real()
456                >= self.end.real()
457            {
458                self.dt = (self.end - self.time) / (self.order - Self::Field::one());
459            }
460            self.runge_kutta(O - 1)?;
461            self.yield_memory = O;
462
463            return Err(IVPStatus::Redo);
464        }
465
466        self.scratch_pad = &self.prev_derivatives[0] * self.predictor_coefficients[O - 2];
467        for i in 1..O - 1 {
468            let coefficient = self.predictor_coefficients[O - i - 2];
469            self.scratch_pad += &self.prev_derivatives[i] * coefficient;
470        }
471        let predictor = &self.state + &self.scratch_pad * self.dt;
472
473        self.implicit_derivs = (self.derivative)(
474            self.time.real() + self.dt.real(),
475            predictor.as_slice(),
476            &mut self.data.clone(),
477        )?;
478        self.scratch_pad = &self.implicit_derivs * self.corrector_coefficients[0];
479
480        for i in 0..O - 1 {
481            let coefficient = self.corrector_coefficients[O - i - 1];
482            self.scratch_pad += &self.prev_derivatives[i] * coefficient;
483        }
484        let corrector = &self.state + &self.scratch_pad * self.dt;
485
486        let difference = &corrector - &predictor;
487        let error = self.error_coefficient.real() / self.dt.real() * difference.norm();
488
489        if error <= self.tolerance.real() {
490            self.state = corrector;
491            self.time += self.dt;
492
493            // We have determined that this step passes the tolerance bounds.
494            // If yield_memory is non-zero, then we still need to yield the runge-kutta
495            // steps to the Iterator. We store the successful adams step in self.state,
496            // and self.time, decrement yield memory, and return (we never want to adjust the dt
497            // the step after adjusting it down). We return IVPStatus::Redo so IVPIterator
498            // calls again, yielding the runge-kutta steps.
499            if self.yield_memory == O {
500                self.yield_memory -= 1;
501                return Err(IVPStatus::Redo);
502            }
503
504            self.prev_derivatives
505                .push_back(self.implicit_derivs.clone());
506            self.prev_values
507                .push_back((self.time.real(), self.state.clone()));
508
509            self.prev_values.pop_front();
510            self.prev_derivatives.pop_front();
511
512            if error < self.one_tenth.real() * self.tolerance.real() {
513                let q = (self.tolerance.real() / (self.two.real() * error))
514                    .powf(self.order.recip().real());
515
516                if q > self.four.real() {
517                    self.dt *= self.four;
518                } else {
519                    self.dt *= Self::Field::from_real(q);
520                }
521
522                if self.dt.real() > self.dt_max.real() {
523                    self.dt = self.dt_max;
524                }
525
526                // Clear the saved steps since we have changed the timestep
527                // so we can no longer use linear interpolation.
528                self.prev_values.clear();
529                self.prev_derivatives.clear();
530            }
531
532            return Ok((self.time.real(), self.state.clone()));
533        }
534
535        // yield_memory can be Order here, meaning we speculatively tried a timestep and the lower timestep
536        // still didn't pass the tolerances.
537        // In this case, we need to return the state to what it was previously, before the runge-kutta steps,
538        // and reset the time to what it was previously.
539        if self.yield_memory == O {
540            // We took Order - 1 runge kutta steps at this dt
541            self.time -= self.dt * (self.order - Self::Field::one());
542            self.state = self.save_state.clone();
543        }
544
545        let q = (self.tolerance.real() / (self.two.real() * error.real()))
546            .powf(self.order.recip().real());
547
548        if q < self.one_tenth.real() {
549            self.dt *= self.one_tenth;
550        } else {
551            self.dt *= Self::Field::from_real(q);
552        }
553
554        if self.dt.real() < self.dt_min.real() {
555            return Err(IVPStatus::Failure(IVPError::MinimumTimeDeltaExceeded));
556        }
557
558        self.prev_values.clear();
559        self.prev_derivatives.clear();
560        Err(IVPStatus::Redo)
561    }
562
563    fn time(&self) -> Self::RealField {
564        self.time.real()
565    }
566}
567
568pub struct AdamsCoefficients5<N: ComplexField>(PhantomData<N>);
569
570impl<N: ComplexField> AdamsCoefficients<5> for AdamsCoefficients5<N> {
571    type RealField = N::RealField;
572
573    fn predictor_coefficients() -> Option<BSVector<Self::RealField, 5>> {
574        let twenty_four = Self::RealField::from_u8(24)?;
575
576        Some(BSVector::from_column_slice(&[
577            Self::RealField::from_u8(55)? / twenty_four.clone(),
578            -Self::RealField::from_u8(59)? / twenty_four.clone(),
579            Self::RealField::from_u8(37)? / twenty_four.clone(),
580            -Self::RealField::from_u8(9)? / twenty_four,
581            Self::RealField::zero(),
582        ]))
583    }
584
585    fn corrector_coefficients() -> Option<BSVector<Self::RealField, 5>> {
586        let seven_hundred_twenty = Self::RealField::from_u16(720)?;
587
588        Some(BSVector::from_column_slice(&[
589            Self::RealField::from_u8(251)? / seven_hundred_twenty.clone(),
590            Self::RealField::from_u16(646)? / seven_hundred_twenty.clone(),
591            -Self::RealField::from_u16(264)? / seven_hundred_twenty.clone(),
592            Self::RealField::from_u8(106)? / seven_hundred_twenty.clone(),
593            -Self::RealField::from_u8(19)? / seven_hundred_twenty,
594        ]))
595    }
596
597    fn error_coefficient() -> Option<Self::RealField> {
598        Some(Self::RealField::from_u8(19)? / Self::RealField::from_u16(270)?)
599    }
600}
601
/// 5th order Adams predictor-corrector method for solving an IVP.
///
/// This is [`Adams`] instantiated with order 5 and [`AdamsCoefficients5`],
/// which defines the predictor and corrector coefficients as well as
/// the error coefficient. [`Adams`] does the actual solving.
///
/// # Examples
/// ```
/// use std::error::Error;
/// use bacon_sci::{BSVector, ivp::{IVPSolver, IVPError, adams::Adams5}};
///
/// fn derivatives(_t: f64, state: &[f64], _p: &mut ()) -> Result<BSVector<f64, 1>, Box<dyn Error>> {
///     Ok(BSVector::from_column_slice(state))
/// }
///
/// fn example() -> Result<(), IVPError> {
///     let adams = Adams5::new()?
///         .with_maximum_dt(0.1)?
///         .with_minimum_dt(0.00001)?
///         .with_tolerance(0.00001)?
///         .with_initial_time(0.0)?
///         .with_ending_time(1.0)?
///         .with_initial_conditions_slice(&[1.0])?
///         .with_derivative(derivatives)
///         .solve(())?;
///     let path = adams.collect_vec()?;
///     for (time, state) in &path {
///         assert!((time.exp() - state.column(0)[0]).abs() < 0.001);
///     }
///     Ok(())
/// }
/// ```
pub type Adams5<'a, N, D, T, F> = Adams<'a, N, D, 5, T, F, AdamsCoefficients5<N>>;
634
635pub struct AdamsCoefficients3<N: ComplexField>(PhantomData<N>);
636
637impl<N: ComplexField + Copy> AdamsCoefficients<3> for AdamsCoefficients3<N> {
638    type RealField = N::RealField;
639
640    fn predictor_coefficients() -> Option<BSVector<Self::RealField, 3>> {
641        Some(BSVector::from_column_slice(&[
642            Self::RealField::one() + Self::RealField::from_u8(2)?.recip(),
643            -Self::RealField::from_u8(2)?.recip(),
644            Self::RealField::zero(),
645        ]))
646    }
647
648    fn corrector_coefficients() -> Option<BSVector<Self::RealField, 3>> {
649        Some(BSVector::from_column_slice(&[
650            Self::RealField::from_u8(5)? / Self::RealField::from_u8(12)?,
651            Self::RealField::from_u8(2)? / Self::RealField::from_u8(3)?,
652            -Self::RealField::from_u8(12)?.recip(),
653        ]))
654    }
655
656    fn error_coefficient() -> Option<Self::RealField> {
657        Some(Self::RealField::from_u8(19)? / Self::RealField::from_u16(270)?)
658    }
659}
660
/// 3rd order Adams predictor-corrector method for solving an IVP.
///
/// This is [`Adams`] instantiated with order 3 and [`AdamsCoefficients3`],
/// which defines the predictor and corrector coefficients as well as
/// the error coefficient. [`Adams`] does the actual solving.
///
/// # Examples
/// ```
/// use std::error::Error;
/// use bacon_sci::{BSVector, ivp::{IVPSolver, IVPError, adams::Adams3}};
///
/// fn derivatives(_t: f64, state: &[f64], _p: &mut ()) -> Result<BSVector<f64, 1>, Box<dyn Error>> {
///     Ok(BSVector::from_column_slice(state))
/// }
///
///
/// fn example() -> Result<(), IVPError> {
///     let adams = Adams3::new()?
///         .with_maximum_dt(0.1)?
///         .with_minimum_dt(0.00001)?
///         .with_tolerance(0.00001)?
///         .with_initial_time(0.0)?
///         .with_ending_time(1.0)?
///         .with_initial_conditions_slice(&[1.0])?
///         .with_derivative(derivatives)
///         .solve(())?;
///     let path = adams.collect_vec()?;
///     for (time, state) in &path {
///         assert!((time.exp() - state.column(0)[0]).abs() < 0.001);
///     }
///     Ok(())
/// }
/// ```
pub type Adams3<'a, N, D, T, F> = Adams<'a, N, D, 3, T, F, AdamsCoefficients3<N>>;
694
#[cfg(test)]
mod test {
    use super::*;
    use crate::{ivp::IVPSolver, BSVector};
    use std::error::Error;

    // y' = y, solved by y = exp(t) for y(0) = 1
    fn exp_deriv(_: f64, y: &[f64], _: &mut ()) -> Result<BSVector<f64, 1>, Box<dyn Error>> {
        Ok(BSVector::from_column_slice(y))
    }

    // y' = -2t, solved by y = 1 - t^2 for y(0) = 1
    fn quadratic_deriv(t: f64, _y: &[f64], _: &mut ()) -> Result<BSVector<f64, 1>, Box<dyn Error>> {
        Ok(BSVector::from_column_slice(&[-2.0 * t]))
    }

    // y' = cos(t), solved by y = sin(t) for y(0) = 0
    fn sine_deriv(t: f64, y: &[f64], _: &mut ()) -> Result<BSVector<f64, 1>, Box<dyn Error>> {
        Ok(BSVector::from_iterator(y.iter().map(|_| t.cos())))
    }

    // Generates one #[test] that integrates `$deriv` with the solver `$method`
    // and compares every yielded point against the closed-form solution `$exact`
    // to within 0.01.
    macro_rules! adams_case {
        (
            $name:ident,
            $method:ident,
            $deriv:ident,
            span: ($t_initial:expr, $t_final:expr),
            dt: ($dt_min:expr, $dt_max:expr),
            tol: $tolerance:expr,
            y0: $y0:expr,
            exact: $exact:expr $(,)?
        ) => {
            #[test]
            fn $name() {
                let t_initial = $t_initial;
                let t_final = $t_final;

                let solver = $method::new()
                    .unwrap()
                    .with_minimum_dt($dt_min)
                    .unwrap()
                    .with_maximum_dt($dt_max)
                    .unwrap()
                    .with_tolerance($tolerance)
                    .unwrap()
                    .with_initial_time(t_initial)
                    .unwrap()
                    .with_ending_time(t_final)
                    .unwrap()
                    .with_initial_conditions_slice(&[$y0])
                    .unwrap()
                    .with_derivative($deriv)
                    .solve(())
                    .unwrap();

                let path = solver.collect_vec().unwrap();
                let exact = $exact;

                for step in &path {
                    assert!(approx_eq!(
                        f64,
                        step.1.column(0)[0],
                        exact(step.0),
                        epsilon = 0.01
                    ));
                }
            }
        };
    }

    // Test predictor-corrector for y=exp(t)
    adams_case!(
        adams5_exp,
        Adams5,
        exp_deriv,
        span: (0.0, 2.0),
        dt: (1e-5, 0.1),
        tol: 0.0005,
        y0: 1.0,
        exact: |t: f64| t.exp(),
    );

    adams_case!(
        adams5_quadratic,
        Adams5,
        quadratic_deriv,
        span: (0.0, 5.0),
        dt: (1e-7, 0.001),
        tol: 0.01,
        y0: 1.0,
        exact: |t: f64| 1.0 - t.powi(2),
    );

    adams_case!(
        adams5_sine,
        Adams5,
        sine_deriv,
        span: (0.0, std::f64::consts::TAU),
        dt: (1e-5, 0.001),
        tol: 0.01,
        y0: 0.0,
        exact: |t: f64| t.sin(),
    );

    adams_case!(
        adams3_exp,
        Adams3,
        exp_deriv,
        span: (0.0, 2.0),
        dt: (1e-5, 0.1),
        tol: 0.001,
        y0: 1.0,
        exact: |t: f64| t.exp(),
    );

    adams_case!(
        adams3_quadratic,
        Adams3,
        quadratic_deriv,
        span: (0.0, 5.0),
        dt: (1e-5, 0.001),
        tol: 0.1,
        y0: 1.0,
        exact: |t: f64| 1.0 - t.powi(2),
    );

    adams_case!(
        adams3_sine,
        Adams3,
        sine_deriv,
        span: (0.0, std::f64::consts::TAU),
        dt: (1e-5, 0.001),
        tol: 0.01,
        y0: 0.0,
        exact: |t: f64| t.sin(),
    );
}