bacon_sci/ivp/
bdf.rs

1/* This file is part of bacon.
2 * Copyright (c) Wyatt Campbell.
3 *
4 * See repository LICENSE for information.
5 */
6
7use super::{Derivative, IVPError, IVPIterator, IVPSolver, IVPStatus, IVPStepper, Step, UserError};
8use crate::{BMatrix, BSVector, BVector, Dimension};
9use nalgebra::{
10    allocator::Allocator, ComplexField, DefaultAllocator, Dim, DimMin, DimName, RealField, U1,
11};
12use num_traits::{FromPrimitive, One, Zero};
13use std::collections::VecDeque;
14use std::marker::PhantomData;
15
16/// This trait defines an BDF solver
17/// The [`BDF`] struct takes an implemetation of this trait
18/// as a type argument since the algorithm is the same for
19/// all the orders, just the constants are different.
20pub trait BDFCoefficients<const O: usize> {
21    type RealField: RealField;
22
23    /// The polynomial interpolation coefficients for the higher-order
24    /// method. Should start
25    /// with the coefficient for the derivative
26    /// function without h, then n - 1. The
27    /// coefficients for the previous terms
28    /// should have the sign as if they're on the
29    /// same side of the = as the next state.
30    fn higher_coefficients() -> Option<BSVector<Self::RealField, O>>;
31
32    /// The polynomial interpolation coefficients for the lower-order
33    /// method. Must be
34    /// one less in length than higher_coefficients.
35    /// Should start with the coefficient for the
36    /// derivative function without h, then n-1. The
37    /// coefficients for the previous terms
38    /// should have the sign as if they're on the
39    /// same side of the = as the next state.
40    fn lower_coefficients() -> Option<BSVector<Self::RealField, O>>;
41}
42
43/// The nuts and bolts BDF solver
44/// Users won't use this directly if they aren't defining their own BDF
45/// Used as a common struct for the specific implementations
46pub struct BDF<'a, N, D, const O: usize, T, F, B>
47where
48    N: ComplexField + Copy,
49    D: Dimension,
50    T: Clone,
51    F: Derivative<N, D, T> + 'a,
52    B: BDFCoefficients<O, RealField = N::RealField>,
53    D: DimMin<D, Output = D>,
54    DefaultAllocator: Allocator<N, D>,
55    DefaultAllocator: Allocator<N, D, D>,
56{
57    init_dt_max: Option<N::RealField>,
58    init_dt_min: Option<N::RealField>,
59    init_time: Option<N::RealField>,
60    init_end: Option<N::RealField>,
61    init_tolerance: Option<N::RealField>,
62    init_state: Option<BVector<N, D>>,
63    init_derivative: Option<F>,
64    dim: D,
65    _data: PhantomData<&'a (T, B)>,
66}
67
68/// The solver for any BDF predictor-corrector
69/// Users should not use this type directly, and should
70/// instead get it from a specific BDF method struct
71/// (wrapped in an [`IVPIterator`])
72pub struct BDFSolver<'a, N, D, const O: usize, T, F>
73where
74    D: Dimension,
75    N: ComplexField + Copy,
76    T: Clone,
77    F: Derivative<N, D, T> + 'a,
78    D: DimMin<D, Output = D>,
79    DefaultAllocator: Allocator<N, D>,
80    DefaultAllocator: Allocator<N, D, D>,
81{
82    // Parameters set by the user
83    dt_max: N,
84    dt_min: N,
85    time: N,
86    end: N,
87    tolerance: N,
88    derivative: F,
89    dim: D,
90    data: T,
91
92    // Current solution at t = self.time
93    dt: N,
94    state: BVector<N, D>,
95
96    // Per-order constants set by an BDFCoefficients
97    higher_coefficients: BSVector<N, O>,
98    lower_coefficients: BSVector<N, O>,
99
100    // Previous steps to interpolate with
101    prev_values: VecDeque<(N::RealField, BVector<N, D>)>,
102
103    // A scratch vector to use during the algorithm (to avoid allocating & de-allocating every step)
104    scratch_pad: BVector<N, D>,
105    // A place to store solver state while taking speculative steps trying to find a good timestep
106    save_state: BVector<N, D>,
107
108    // Constants for the particular field
109    one_tenth: N,
110    one_sixth: N,
111    half: N,
112    two: N,
113
114    // generic parameter O in the type N
115    order: N,
116
117    // The number of items in prev_values that need to be yielded to the iterator
118    // due to a previous runge-kutta step
119    yield_memory: usize,
120
121    _lifetime: PhantomData<&'a ()>,
122}
123
124impl<'a, N, D, const O: usize, T, F, B> IVPSolver<'a, D> for BDF<'a, N, D, O, T, F, B>
125where
126    N: ComplexField + Copy,
127    D: Dimension,
128    T: Clone,
129    F: Derivative<N, D, T> + 'a,
130    B: BDFCoefficients<O, RealField = N::RealField>,
131    D: DimMin<D, Output = D>,
132    DefaultAllocator: Allocator<N, D>,
133    DefaultAllocator: Allocator<N, U1, D>,
134    DefaultAllocator: Allocator<N, D, D>,
135    DefaultAllocator: Allocator<(usize, usize), D>,
136{
137    type Error = IVPError;
138    type Field = N;
139    type RealField = N::RealField;
140    type Derivative = F;
141    type UserData = T;
142    type Solver = BDFSolver<'a, N, D, O, T, F>;
143
144    fn new() -> Result<Self, Self::Error> {
145        Ok(Self {
146            init_dt_max: None,
147            init_dt_min: None,
148            init_time: None,
149            init_end: None,
150            init_tolerance: None,
151            init_state: None,
152            init_derivative: None,
153            dim: D::dim()?,
154            _data: PhantomData,
155        })
156    }
157
158    fn new_dyn(size: usize) -> Result<Self, Self::Error> {
159        Ok(Self {
160            init_dt_max: None,
161            init_dt_min: None,
162            init_time: None,
163            init_end: None,
164            init_tolerance: None,
165            init_state: None,
166            init_derivative: None,
167            dim: D::dim_dyn(size)?,
168            _data: PhantomData,
169        })
170    }
171
172    fn dim(&self) -> D {
173        self.dim
174    }
175
176    fn with_tolerance(mut self, tol: Self::RealField) -> Result<Self, Self::Error> {
177        if tol <= <Self::RealField as Zero>::zero() {
178            return Err(IVPError::ToleranceOOB);
179        }
180        self.init_tolerance = Some(tol);
181        Ok(self)
182    }
183
184    /// Will overwrite any previously set value
185    /// If the provided maximum is less than a previously set minimum, then the minimum
186    /// is set to this value as well.
187    fn with_maximum_dt(mut self, max: Self::RealField) -> Result<Self, Self::Error> {
188        if max <= <Self::RealField as Zero>::zero() {
189            return Err(IVPError::TimeDeltaOOB);
190        }
191
192        self.init_dt_max = Some(max.clone());
193        if let Some(dt_min) = self.init_dt_min.as_mut() {
194            if *dt_min > max {
195                *dt_min = max;
196            }
197        }
198
199        Ok(self)
200    }
201
202    /// Will overwrite any previously set value
203    /// If the provided minimum is greatear than a previously set maximum, then the maximum
204    /// is set to this value as well.
205    fn with_minimum_dt(mut self, min: Self::RealField) -> Result<Self, Self::Error> {
206        if min <= <Self::RealField as Zero>::zero() {
207            return Err(IVPError::TimeDeltaOOB);
208        }
209
210        self.init_dt_min = Some(min.clone());
211        if let Some(dt_max) = self.init_dt_max.as_mut() {
212            if *dt_max < min {
213                *dt_max = min;
214            }
215        }
216
217        Ok(self)
218    }
219
220    fn with_initial_time(mut self, initial: Self::RealField) -> Result<Self, Self::Error> {
221        self.init_time = Some(initial.clone());
222
223        if let Some(end) = self.init_end.as_ref() {
224            if *end <= initial {
225                return Err(IVPError::TimeStartOOB);
226            }
227        }
228
229        Ok(self)
230    }
231
232    fn with_ending_time(mut self, ending: Self::RealField) -> Result<Self, Self::Error> {
233        self.init_end = Some(ending.clone());
234
235        if let Some(initial) = self.init_time.as_ref() {
236            if *initial >= ending {
237                return Err(IVPError::TimeEndOOB);
238            }
239        }
240
241        Ok(self)
242    }
243
244    fn with_initial_conditions(
245        mut self,
246        start: BVector<Self::Field, D>,
247    ) -> Result<Self, Self::Error> {
248        self.init_state = Some(start);
249        Ok(self)
250    }
251
252    fn with_derivative(mut self, derivative: Self::Derivative) -> Self {
253        self.init_derivative = Some(derivative);
254        self
255    }
256
257    fn solve(self, data: Self::UserData) -> Result<IVPIterator<D, Self::Solver>, Self::Error> {
258        let dt_max = self.init_dt_max.ok_or(IVPError::MissingParameters)?;
259        let dt_min = self.init_dt_min.ok_or(IVPError::MissingParameters)?;
260        let tolerance = self.init_tolerance.ok_or(IVPError::MissingParameters)?;
261        let time = self.init_time.ok_or(IVPError::MissingParameters)?;
262        let end = self.init_end.ok_or(IVPError::MissingParameters)?;
263        let state = self.init_state.ok_or(IVPError::MissingParameters)?;
264        let derivative = self.init_derivative.ok_or(IVPError::MissingParameters)?;
265
266        let two = Self::Field::from_u8(2).ok_or(IVPError::FromPrimitiveFailure)?;
267        let half = two.recip();
268        let one_sixth = Self::Field::from_u8(6)
269            .ok_or(IVPError::FromPrimitiveFailure)?
270            .recip();
271        let one_tenth = Self::Field::from_u8(10)
272            .ok_or(IVPError::FromPrimitiveFailure)?
273            .recip();
274
275        let higher_coefficients = BSVector::from_iterator(
276            B::higher_coefficients()
277                .ok_or(IVPError::FromPrimitiveFailure)?
278                .as_slice()
279                .iter()
280                .cloned()
281                .map(Self::Field::from_real),
282        );
283
284        let lower_coefficients = BSVector::from_iterator(
285            B::lower_coefficients()
286                .ok_or(IVPError::FromPrimitiveFailure)?
287                .as_slice()
288                .iter()
289                .cloned()
290                .map(Self::Field::from_real),
291        );
292
293        let order = Self::Field::from_usize(O).ok_or(IVPError::FromPrimitiveFailure)?;
294
295        Ok(IVPIterator {
296            solver: BDFSolver {
297                dt_max: Self::Field::from_real(dt_max.clone()),
298                dt_min: Self::Field::from_real(dt_min.clone()),
299                time: Self::Field::from_real(time),
300                end: Self::Field::from_real(end),
301                tolerance: Self::Field::from_real(tolerance),
302                dt: Self::Field::from_real(dt_max + dt_min) * half,
303                state,
304                derivative,
305                dim: self.dim,
306                data,
307                higher_coefficients,
308                lower_coefficients,
309                prev_values: VecDeque::new(),
310                scratch_pad: BVector::from_element_generic(
311                    self.dim,
312                    U1::name(),
313                    Self::Field::zero(),
314                ),
315                save_state: BVector::from_element_generic(
316                    self.dim,
317                    U1::from_usize(1),
318                    Self::Field::zero(),
319                ),
320                one_tenth,
321                one_sixth,
322                half,
323                two,
324                order,
325                yield_memory: 0,
326                _lifetime: PhantomData,
327            },
328            finished: false,
329            _dim: PhantomData,
330        })
331    }
332}
333
334impl<'a, N, D, const O: usize, T, F> BDFSolver<'a, N, D, O, T, F>
335where
336    N: ComplexField + Copy,
337    D: Dimension,
338    T: Clone,
339    F: Derivative<N, D, T> + 'a,
340    D: DimMin<D, Output = D>,
341    DefaultAllocator: Allocator<N, D>,
342    DefaultAllocator: Allocator<N, U1, D>,
343    DefaultAllocator: Allocator<N, D, D>,
344    DefaultAllocator: Allocator<(usize, usize), D>,
345{
346    fn runge_kutta(&mut self, iterations: usize) -> Result<(), IVPError> {
347        for i in 0..iterations {
348            let k1 = (self.derivative)(
349                self.time.real(),
350                self.state.as_slice(),
351                &mut self.data.clone(),
352            )? * self.dt;
353            let intermediate = &self.state + &k1 * self.half;
354
355            let k2 = (self.derivative)(
356                (self.time + self.half * self.dt).real(),
357                intermediate.as_slice(),
358                &mut self.data.clone(),
359            )? * self.dt;
360            let intermediate = &self.state + &k2 * self.half;
361
362            let k3 = (self.derivative)(
363                (self.time + self.half * self.dt).real(),
364                intermediate.as_slice(),
365                &mut self.data.clone(),
366            )? * self.dt;
367            let intermediate = &self.state + &k3;
368
369            let k4 = (self.derivative)(
370                (self.time + self.dt).real(),
371                intermediate.as_slice(),
372                &mut self.data.clone(),
373            )? * self.dt;
374
375            if i != 0 {
376                self.prev_values
377                    .push_back((self.time.real(), self.state.clone()));
378            }
379
380            self.state += (k1 + k2 * self.two + k3 * self.two + k4) * self.one_sixth;
381            self.time += self.dt;
382        }
383        self.prev_values
384            .push_back((self.time.real(), self.state.clone()));
385
386        Ok(())
387    }
388
389    // Used for secant method
390    fn jac_finite_diff<G>(
391        &mut self,
392        x: &mut BVector<N, D>,
393        g: &mut G,
394    ) -> Result<BMatrix<N, D, D>, IVPError>
395    where
396        G: FnMut(&mut Self, N::RealField, &[N], &mut T) -> Result<BVector<N, D>, UserError>,
397    {
398        let mut mat = BMatrix::from_element_generic(self.dim, self.dim, N::zero());
399        let denom = (self.two * self.dt).recip();
400
401        for (ind, mut col) in mat.column_iter_mut().enumerate() {
402            x[ind] += self.dt;
403            let above = g(self, self.time.real(), x.as_slice(), &mut self.data.clone())?;
404            x[ind] -= self.two * self.dt;
405            let below = g(self, self.time.real(), x.as_slice(), &mut self.data.clone())?;
406            x[ind] += self.dt;
407            col.set_column(0, &((above + below) * denom));
408        }
409
410        Ok(mat)
411    }
412
413    // Secant method for performing the BDF
414    fn secant<G>(&mut self, g: &mut G) -> Result<BVector<N, D>, IVPError>
415    where
416        G: FnMut(&mut Self, N::RealField, &[N], &mut T) -> Result<BVector<N, D>, UserError>,
417    {
418        let mut n = 2;
419
420        let mut guess = self.state.clone();
421        let mut derivative = g(
422            self,
423            self.time.real(),
424            guess.as_slice(),
425            &mut self.data.clone(),
426        )?;
427
428        let jac = self.jac_finite_diff(&mut guess, g)?;
429        let lu = jac.clone().lu();
430        let mut jac_inv = if let Some(inv) = lu.try_inverse() {
431            inv
432        } else {
433            let lu = jac.clone().full_piv_lu();
434            if let Some(inv) = lu.try_inverse() {
435                inv
436            } else {
437                let qr = jac.qr();
438                if let Some(inv) = qr.try_inverse() {
439                    inv
440                } else {
441                    return Err(IVPError::SingularMatrix);
442                }
443            }
444        };
445
446        let mut shift = -&jac_inv * &derivative;
447        guess += &shift;
448
449        while n < 1000 {
450            let derivative_last = derivative;
451            derivative = g(
452                self,
453                self.time.real(),
454                guess.as_slice(),
455                &mut self.data.clone(),
456            )?;
457
458            let difference = &derivative - &derivative_last;
459            let adjustment = -&jac_inv * difference;
460            let s_transpose = shift.clone().transpose();
461            let p = (-&s_transpose * &adjustment)[(0, 0)];
462            let u = s_transpose * &jac_inv;
463
464            jac_inv += (shift + adjustment) * u / p;
465            shift = -&jac_inv * &derivative;
466            guess += &shift;
467
468            if shift.norm() <= self.tolerance.real() {
469                return Ok(guess);
470            }
471            n += 1;
472        }
473
474        Err(IVPError::MaximumIterationsExceeded)
475    }
476}
477
478impl<'a, N, D, const O: usize, T, F> IVPStepper<D> for BDFSolver<'a, N, D, O, T, F>
479where
480    N: ComplexField + Copy,
481    D: Dimension,
482    T: Clone,
483    F: Derivative<N, D, T> + 'a,
484    D: DimMin<D, Output = D>,
485    DefaultAllocator: Allocator<N, D>,
486    DefaultAllocator: Allocator<N, U1, D>,
487    DefaultAllocator: Allocator<N, D, D>,
488    DefaultAllocator: Allocator<(usize, usize), D>,
489{
490    type Error = IVPError;
491    type Field = N;
492    type RealField = N::RealField;
493    type UserData = T;
494
495    fn step(&mut self) -> Step<Self::RealField, Self::Field, D, Self::Error> {
496        // If yield_memory is in [1, Order] then we have taken a runge-kutta step
497        // and committed to it (i.e. determined that we are within error bounds)
498        // If yield_memory is Order+1 then we have taken a runge-kutta step but haven't
499        // checked if it is correct, so we don't want to yield the steps to the Iterator yet
500        if self.yield_memory > 0 && self.yield_memory <= O {
501            let get_item = O - self.yield_memory;
502            self.yield_memory -= 1;
503
504            // If this is the last runge-kutta step to be yielded,
505            // set yield_memory to the sentinel value O+2 so that the next step() call
506            // will yield the value in self.state (the bdf step that was within
507            // tolerance after these runge-kutta steps)
508            if self.yield_memory == 0 {
509                self.yield_memory = O + 2;
510            }
511            return Ok(self.prev_values[get_item].clone());
512        }
513
514        // Sentinel value to signify that the runge-kutta steps are yielded
515        // and the solver can yield the bdf step and continue as normal.
516        // The current state needs to be returned and pushed onto the memory deque.
517        // The derivatives memory deque already has the derivatives for this step,
518        // since the derivatives deque is unused while yielding runge-kutta steps
519        if self.yield_memory == O + 2 {
520            self.yield_memory = 0;
521            self.prev_values
522                .push_back((self.time.real(), self.state.clone()));
523            self.prev_values.pop_front();
524            return Ok((self.time.real(), self.state.clone()));
525        }
526
527        if self.time.real() >= self.end.real() {
528            return Err(IVPStatus::Done);
529        }
530
531        if self.time.real() + self.dt.real() >= self.end.real() {
532            self.dt = self.end - self.time;
533            self.runge_kutta(1)?;
534            return Ok((self.time.real(), self.prev_values.back().unwrap().1.clone()));
535        }
536
537        if self.prev_values.is_empty() {
538            self.save_state = self.state.clone();
539            if self.time.real() + self.dt.real() * self.order.real() >= self.end.real() {
540                self.dt = (self.end - self.time) / self.order;
541            }
542            self.runge_kutta(O)?;
543            self.yield_memory = O + 1;
544
545            return Err(IVPStatus::Redo);
546        }
547
548        let mut higher_func = |bdf: &mut Self,
549                               t: Self::RealField,
550                               y: &[N],
551                               data: &mut T|
552         -> Result<BVector<N, D>, UserError> {
553            bdf.scratch_pad = -(bdf.derivative)(t, y, data)? * bdf.dt * bdf.higher_coefficients[0];
554            for (ind, &coeff) in bdf.higher_coefficients.column(0).iter().enumerate().skip(1) {
555                bdf.scratch_pad += &bdf.prev_values[O - ind].1 * coeff;
556            }
557            Ok(
558                bdf.scratch_pad.clone()
559                    + BVector::from_column_slice_generic(bdf.dim, U1::name(), y),
560            )
561        };
562        let mut lower_func = |bdf: &mut Self,
563                              t: Self::RealField,
564                              y: &[N],
565                              data: &mut T|
566         -> Result<BVector<N, D>, UserError> {
567            bdf.scratch_pad = -(bdf.derivative)(t, y, data)? * bdf.dt * bdf.lower_coefficients[0];
568            for (ind, &coeff) in bdf.higher_coefficients.column(0).iter().enumerate().skip(1) {
569                bdf.scratch_pad += &bdf.prev_values[O - ind].1 * coeff;
570            }
571            Ok(
572                bdf.scratch_pad.clone()
573                    + BVector::from_column_slice_generic(bdf.dim, U1::name(), y),
574            )
575        };
576
577        let higher_step = self.secant(&mut higher_func)?;
578        let lower_step = self.secant(&mut lower_func)?;
579
580        let difference = &higher_step - &lower_step;
581        let error = difference.norm();
582
583        if error <= self.tolerance.real() {
584            self.state = higher_step;
585            self.time += self.dt;
586
587            // We have determined that this step passes the tolerance bounds.
588            // If yield_memory is non-zero, then we still need to yield the runge-kutta
589            // steps to the Iterator. We store the successful bdf step in self.state,
590            // and self.time, decrement yield memory, and return (we never want to adjust the dt
591            // the step after adjusting it down). We return IVPStatus::Redo so IVPIterator
592            // calls again, yielding the runge-kutta steps.
593            if self.yield_memory == O + 1 {
594                self.yield_memory -= 1;
595                return Err(IVPStatus::Redo);
596            }
597
598            self.prev_values
599                .push_back((self.time.real(), self.state.clone()));
600            self.prev_values.pop_front();
601
602            if error < self.one_tenth.real() * self.tolerance.real() {
603                self.dt *= self.two;
604                if self.dt.real() > self.dt_max.real() {
605                    self.dt = self.dt_max;
606                }
607
608                // Clear the saved steps since we have changed the timestep
609                // so we can no longer use linear interpolation.
610                self.prev_values.clear();
611            }
612
613            return Ok((self.time.real(), self.state.clone()));
614        }
615
616        // yield_memory can be Order+1 here, meaning we speculatively tried a timestep and the lower timestep
617        // still didn't pass the tolerances.
618        // In this case, we need to return the state to what it was previously, before the runge-kutta steps,
619        // and reset the time to what it was previously.
620        if self.yield_memory == O + 1 {
621            // We took Order - 1 runge kutta steps at this dt
622            self.time -= self.dt - self.order;
623            self.state = self.save_state.clone();
624        }
625
626        self.dt *= self.half;
627
628        if self.dt.real() < self.dt_min.real() {
629            return Err(IVPStatus::Failure(IVPError::MinimumTimeDeltaExceeded));
630        }
631
632        self.prev_values.clear();
633        Err(IVPStatus::Redo)
634    }
635
636    fn time(&self) -> Self::RealField {
637        self.time.real()
638    }
639}
640
641pub struct BDF6Coefficients<N: ComplexField>(PhantomData<N>);
642
643impl<N: ComplexField> BDFCoefficients<7> for BDF6Coefficients<N> {
644    type RealField = N::RealField;
645
646    fn higher_coefficients() -> Option<BSVector<Self::RealField, 7>> {
647        let one_hundred_forty_seven = Self::RealField::from_u8(147)?;
648
649        Some(BSVector::from_column_slice(&[
650            Self::RealField::from_u8(60)? / one_hundred_forty_seven.clone(),
651            -Self::RealField::from_u16(360)? / one_hundred_forty_seven.clone(),
652            Self::RealField::from_u16(450)? / one_hundred_forty_seven.clone(),
653            -Self::RealField::from_u16(400)? / one_hundred_forty_seven.clone(),
654            Self::RealField::from_u8(225)? / one_hundred_forty_seven.clone(),
655            -Self::RealField::from_u8(72)? / one_hundred_forty_seven.clone(),
656            Self::RealField::from_u8(10)? / one_hundred_forty_seven,
657        ]))
658    }
659
660    fn lower_coefficients() -> Option<BSVector<Self::RealField, 7>> {
661        let one_hundred_thirty_seven = Self::RealField::from_u8(137)?;
662
663        Some(BSVector::from_column_slice(&[
664            Self::RealField::from_u8(60)? / one_hundred_thirty_seven.clone(),
665            -Self::RealField::from_u16(300)? / one_hundred_thirty_seven.clone(),
666            Self::RealField::from_u16(300)? / one_hundred_thirty_seven.clone(),
667            -Self::RealField::from_u8(200)? / one_hundred_thirty_seven.clone(),
668            Self::RealField::from_u8(75)? / one_hundred_thirty_seven.clone(),
669            -Self::RealField::from_u8(12)? / one_hundred_thirty_seven,
670            Self::RealField::zero(),
671        ]))
672    }
673}
674
675/// 6th order backwards differentiation formula method for
676/// solving an initial value problem.
677///
678/// Defines the higher and lower order coefficients. Uses
679/// [`BDFInfo`] for the actual solving.
680///
681/// # Examples
682/// ```
683/// use std::error::Error;
684/// use bacon_sci::{BSVector, ivp::{IVPSolver, IVPError, bdf::BDF6}};
685///
686/// fn derivatives(_t: f64, state: &[f64], _p: &mut ()) -> Result<BSVector<f64, 1>, Box<dyn Error>> {
687///     Ok(-BSVector::from_column_slice(state))
688/// }
689///
690/// fn example() -> Result<(), IVPError> {
691///     let bdf = BDF6::new()?
692///         .with_maximum_dt(0.1)?
693///         .with_minimum_dt(0.00001)?
694///         .with_tolerance(0.00001)?
695///         .with_initial_time(0.0)?
696///         .with_ending_time(10.0)?
697///         .with_initial_conditions_slice(&[1.0])?
698///         .with_derivative(derivatives)
699///         .solve(())?;
700///     let path = bdf.collect_vec()?;
701///     for (time, state) in &path {
702///         assert!(((-time).exp() - state.column(0)[0]).abs() < 0.001);
703///     }
704///     Ok(())
705/// }
706pub type BDF6<'a, N, D, T, F> = BDF<'a, N, D, 7, T, F, BDF6Coefficients<N>>;
707
708pub struct BDF2Coefficients<N: ComplexField>(PhantomData<N>);
709
710impl<N: ComplexField> BDFCoefficients<3> for BDF2Coefficients<N> {
711    type RealField = N::RealField;
712
713    fn higher_coefficients() -> Option<BSVector<Self::RealField, 3>> {
714        let three = Self::RealField::from_u8(3)?;
715
716        Some(BSVector::from_column_slice(&[
717            Self::RealField::from_u8(2)? / three.clone(),
718            -Self::RealField::from_u8(4)? / three.clone(),
719            three.recip(),
720        ]))
721    }
722
723    fn lower_coefficients() -> Option<BSVector<Self::RealField, 3>> {
724        Some(BSVector::from_column_slice(&[
725            Self::RealField::one(),
726            -Self::RealField::one(),
727            Self::RealField::zero(),
728        ]))
729    }
730}
731
732/// 2nd order backwards differentiation formula method for
733/// solving an initial value problem.
734///
735/// Defines the higher and lower order coefficients. Uses
736/// [`BDFInfo`] for the actual solving.
737///
738/// # Examples
739/// ```
740/// use std::error::Error;
741/// use bacon_sci::{BSVector, ivp::{IVPSolver, IVPError, bdf::BDF2}};
742/// fn derivatives(_t: f64, state: &[f64], _p: &mut ()) -> Result<BSVector<f64, 1>, Box<dyn Error>> {
743///     Ok(-BSVector::from_column_slice(state))
744/// }
745///
746/// fn example() -> Result<(), IVPError> {
747///     let bdf = BDF2::new()?
748///         .with_maximum_dt(0.1)?
749///         .with_minimum_dt(0.00001)?
750///         .with_tolerance(0.00001)?
751///         .with_initial_time(0.0)?
752///         .with_ending_time(10.0)?
753///         .with_initial_conditions_slice(&[1.0])?
754///         .with_derivative(derivatives)
755///         .solve(())?;
756///     let path = bdf.collect_vec()?;
757///     for (time, state) in &path {
758///         assert!(((-time).exp() - state.column(0)[0]).abs() < 0.001);
759///     }
760///     Ok(())
761/// }
762pub type BDF2<'a, N, D, T, F> = BDF<'a, N, D, 3, T, F, BDF2Coefficients<N>>;
763
764#[cfg(test)]
765mod test {
766    use super::*;
767    use crate::BSVector;
768
769    fn exp_deriv(_: f64, y: &[f64], _: &mut ()) -> Result<BSVector<f64, 1>, UserError> {
770        Ok(BSVector::from_column_slice(y))
771    }
772
773    fn quadratic_deriv(t: f64, _y: &[f64], _: &mut ()) -> Result<BSVector<f64, 1>, UserError> {
774        Ok(BSVector::from_column_slice(&[-2.0 * t]))
775    }
776
777    fn sine_deriv(t: f64, y: &[f64], _: &mut ()) -> Result<BSVector<f64, 1>, UserError> {
778        Ok(BSVector::from_iterator(y.iter().map(|_| t.cos())))
779    }
780
781    fn unstable_deriv(_: f64, y: &[f64], _: &mut ()) -> Result<BSVector<f64, 1>, UserError> {
782        Ok(-BSVector::from_column_slice(y))
783    }
784
785    #[test]
786    fn bdf6_exp() {
787        let t_initial = 0.0;
788        let t_final = 7.0;
789
790        let solver = BDF6::new()
791            .unwrap()
792            .with_minimum_dt(1e-5)
793            .unwrap()
794            .with_maximum_dt(0.1)
795            .unwrap()
796            .with_tolerance(0.00001)
797            .unwrap()
798            .with_initial_time(t_initial)
799            .unwrap()
800            .with_ending_time(t_final)
801            .unwrap()
802            .with_initial_conditions_slice(&[1.0])
803            .unwrap()
804            .with_derivative(exp_deriv)
805            .solve(())
806            .unwrap();
807
808        let path = solver.collect_vec().unwrap();
809
810        for step in &path {
811            assert!(approx_eq!(
812                f64,
813                step.1.column(0)[0],
814                step.0.exp(),
815                epsilon = 0.01
816            ));
817        }
818    }
819
820    #[test]
821    fn bdf6_unstable() {
822        let t_initial = 0.0;
823        let t_final = 10.0;
824
825        let solver = BDF6::new()
826            .unwrap()
827            .with_minimum_dt(1e-5)
828            .unwrap()
829            .with_maximum_dt(0.1)
830            .unwrap()
831            .with_tolerance(0.00001)
832            .unwrap()
833            .with_initial_time(t_initial)
834            .unwrap()
835            .with_ending_time(t_final)
836            .unwrap()
837            .with_initial_conditions_slice(&[1.0])
838            .unwrap()
839            .with_derivative(unstable_deriv)
840            .solve(())
841            .unwrap();
842
843        let path = solver.collect_vec().unwrap();
844
845        for step in &path {
846            assert!(approx_eq!(
847                f64,
848                step.1.column(0)[0],
849                (-step.0).exp(),
850                epsilon = 0.01
851            ));
852        }
853    }
854
855    #[test]
856    fn bdf6_quadratic() {
857        let t_initial = 0.0;
858        let t_final = 2.0;
859
860        let solver = BDF6::new()
861            .unwrap()
862            .with_minimum_dt(1e-5)
863            .unwrap()
864            .with_maximum_dt(0.1)
865            .unwrap()
866            .with_tolerance(0.00001)
867            .unwrap()
868            .with_initial_time(t_initial)
869            .unwrap()
870            .with_ending_time(t_final)
871            .unwrap()
872            .with_initial_conditions_slice(&[1.0])
873            .unwrap()
874            .with_derivative(quadratic_deriv)
875            .solve(())
876            .unwrap();
877
878        let path = solver.collect_vec().unwrap();
879
880        for step in &path {
881            assert!(approx_eq!(
882                f64,
883                step.1.column(0)[0],
884                1.0 - step.0.powi(2),
885                epsilon = 0.01
886            ));
887        }
888    }
889
890    #[test]
891    fn bdf6_sin() {
892        let t_initial = 0.0;
893        let t_final = 6.0;
894
895        let solver = BDF6::new()
896            .unwrap()
897            .with_minimum_dt(1e-5)
898            .unwrap()
899            .with_maximum_dt(0.1)
900            .unwrap()
901            .with_tolerance(0.00001)
902            .unwrap()
903            .with_initial_time(t_initial)
904            .unwrap()
905            .with_ending_time(t_final)
906            .unwrap()
907            .with_initial_conditions_slice(&[0.0])
908            .unwrap()
909            .with_derivative(sine_deriv)
910            .solve(())
911            .unwrap();
912
913        let path = solver.collect_vec().unwrap();
914
915        for step in &path {
916            assert!(approx_eq!(
917                f64,
918                step.1.column(0)[0],
919                step.0.sin(),
920                epsilon = 0.01
921            ));
922        }
923    }
924
925    #[test]
926    fn bdf2_exp() {
927        let t_initial = 0.0;
928        let t_final = 7.0;
929
930        let solver = BDF2::new()
931            .unwrap()
932            .with_minimum_dt(1e-5)
933            .unwrap()
934            .with_maximum_dt(0.1)
935            .unwrap()
936            .with_tolerance(0.00001)
937            .unwrap()
938            .with_initial_time(t_initial)
939            .unwrap()
940            .with_ending_time(t_final)
941            .unwrap()
942            .with_initial_conditions_slice(&[1.0])
943            .unwrap()
944            .with_derivative(exp_deriv)
945            .solve(())
946            .unwrap();
947
948        let path = solver.collect_vec().unwrap();
949
950        for step in &path {
951            assert!(approx_eq!(
952                f64,
953                step.1.column(0)[0],
954                step.0.exp(),
955                epsilon = 0.01
956            ));
957        }
958    }
959
960    #[test]
961    fn bdf2_unstable() {
962        let t_initial = 0.0;
963        let t_final = 10.0;
964
965        let solver = BDF2::new()
966            .unwrap()
967            .with_minimum_dt(1e-5)
968            .unwrap()
969            .with_maximum_dt(0.1)
970            .unwrap()
971            .with_tolerance(0.00001)
972            .unwrap()
973            .with_initial_time(t_initial)
974            .unwrap()
975            .with_ending_time(t_final)
976            .unwrap()
977            .with_initial_conditions_slice(&[1.0])
978            .unwrap()
979            .with_derivative(unstable_deriv)
980            .solve(())
981            .unwrap();
982
983        let path = solver.collect_vec().unwrap();
984
985        for step in &path {
986            assert!(approx_eq!(
987                f64,
988                step.1.column(0)[0],
989                (-step.0).exp(),
990                epsilon = 0.01
991            ));
992        }
993    }
994
995    #[test]
996    fn bdf2_quadratic() {
997        let t_initial = 0.0;
998        let t_final = 1.0;
999
1000        let solver = BDF2::new()
1001            .unwrap()
1002            .with_minimum_dt(1e-5)
1003            .unwrap()
1004            .with_maximum_dt(0.1)
1005            .unwrap()
1006            .with_tolerance(0.00001)
1007            .unwrap()
1008            .with_initial_time(t_initial)
1009            .unwrap()
1010            .with_ending_time(t_final)
1011            .unwrap()
1012            .with_initial_conditions_slice(&[1.0])
1013            .unwrap()
1014            .with_derivative(quadratic_deriv)
1015            .solve(())
1016            .unwrap();
1017
1018        let path = solver.collect_vec().unwrap();
1019
1020        for step in &path {
1021            assert!(approx_eq!(
1022                f64,
1023                step.1.column(0)[0],
1024                1.0 - step.0.powi(2),
1025                epsilon = 0.01
1026            ));
1027        }
1028    }
1029
1030    #[test]
1031    fn bdf2_sin() {
1032        let t_initial = 0.0;
1033        let t_final = 6.0;
1034
1035        let solver = BDF2::new()
1036            .unwrap()
1037            .with_minimum_dt(1e-5)
1038            .unwrap()
1039            .with_maximum_dt(0.1)
1040            .unwrap()
1041            .with_tolerance(0.00001)
1042            .unwrap()
1043            .with_initial_time(t_initial)
1044            .unwrap()
1045            .with_ending_time(t_final)
1046            .unwrap()
1047            .with_initial_conditions_slice(&[0.0])
1048            .unwrap()
1049            .with_derivative(sine_deriv)
1050            .solve(())
1051            .unwrap();
1052
1053        let path = solver.collect_vec().unwrap();
1054
1055        for step in &path {
1056            assert!(approx_eq!(
1057                f64,
1058                step.1.column(0)[0],
1059                step.0.sin(),
1060                epsilon = 0.01
1061            ));
1062        }
1063    }
1064}