bacon_sci/
ivp.rs

/* This file is part of bacon.
 * Copyright (c) Wyatt Campbell.
 *
 * See repository LICENSE for information.
 */
6
7use crate::{BVector, Dimension, DimensionError};
8use nalgebra::{allocator::Allocator, ComplexField, DefaultAllocator, Dim, RealField, U1};
9use num_traits::{FromPrimitive, Zero};
10use std::{error::Error, marker::PhantomData};
11use thiserror::Error;
12
13pub mod adams;
14pub mod bdf;
15pub mod rk;
16
17/// Status returned from the [`IVPStepper`]
18/// Used by the [`IVPIterator`] struct to correctly step through
19/// the IVP solution.
20#[derive(Error, Clone, Debug)]
21pub enum IVPStatus<T: Error> {
22    #[error("the solver needs the step to be re-done")]
23    Redo,
24    #[error("the solver is complete")]
25    Done,
26    #[error("unspecified solver error: {0}")]
27    Failure(#[from] T),
28}
29
/// Boxed, type-erased error that a user-supplied derivative function may return.
pub type UserError = Box<dyn Error>;
32
33/// A function that can be used as a derivative for the solver
34pub trait Derivative<N: ComplexField + Copy, D: Dim, T: Clone>:
35    FnMut(N::RealField, &[N], &mut T) -> Result<BVector<N, D>, UserError>
36where
37    DefaultAllocator: Allocator<N, D>,
38{
39}
40
41impl<N, D: Dim, T, F> Derivative<N, D, T> for F
42where
43    N: ComplexField + Copy,
44    T: Clone,
45    F: FnMut(N::RealField, &[N], &mut T) -> Result<BVector<N, D>, UserError>,
46    DefaultAllocator: Allocator<N, D>,
47{
48}
49
50#[derive(Error, Debug)]
51pub enum IVPError {
52    #[error("the solver does not have all required parameters set")]
53    MissingParameters,
54    #[error("the solver hit an error from the user-provided derivative function: {0}")]
55    UserError(#[from] UserError),
56    #[error("the provided tolerance was out-of-bounds")]
57    ToleranceOOB,
58    #[error("the provided time delta was out-of-bounds")]
59    TimeDeltaOOB,
60    #[error("the provided ending time was before the provided starting time")]
61    TimeEndOOB,
62    #[error("the provided starting time was after the provided ending time")]
63    TimeStartOOB,
64    #[error("a conversion from a necessary primitive failed")]
65    FromPrimitiveFailure,
66    #[error("the time step fell below the paramater minimum allowed value")]
67    MinimumTimeDeltaExceeded,
68    #[error("the number of iterations exceeded the maximum allowable")]
69    MaximumIterationsExceeded,
70    #[error("a matrix was unable to be inverted")]
71    SingularMatrix,
72    #[error("attempted to build a dynamic solver with static dimension")]
73    DynamicOnStatic,
74    #[error("attempted to build a static solver with dynamic dimension")]
75    StaticOnDynamic,
76}
77
78impl From<UserError> for IVPStatus<IVPError> {
79    fn from(value: UserError) -> Self {
80        Self::Failure(IVPError::UserError(value))
81    }
82}
83
84impl From<DimensionError> for IVPError {
85    fn from(value: DimensionError) -> Self {
86        match value {
87            DimensionError::DynamicOnStatic => Self::DynamicOnStatic,
88            DimensionError::StaticOnDynamic => Self::StaticOnDynamic,
89        }
90    }
91}
92
93/// A type alias for a Result of a [`IVPStepper`] step
94/// Ok is a tuple of the time and solution at that time
95/// Err is an IVPError
96pub type Step<R, C, D, E> = Result<(R, BVector<C, D>), IVPStatus<E>>;
97
98/// Implementing this trait is providing the main functionality of
99/// an initial value problem solver. This should be used only when
100/// implementing an [`IVPSolver`], users should use the solver via the [`IVPSolver`]
101/// trait's interface.
102pub trait IVPStepper<D: Dimension>: Sized
103where
104    DefaultAllocator: Allocator<Self::Field, D>,
105{
106    /// Error type. IVPError must be able to convert to the error type.
107    type Error: Error + From<IVPError>;
108    /// The field, complex or real, that the solver is operating on.
109    type Field: ComplexField + Copy;
110    /// The real field associated with the solver's Field.
111    type RealField: RealField;
112    /// Arbitrary data provided by the user for the derivative function
113    /// It must be clone because for any intermediate time steps (e.g. in runge-kutta)
114    /// gives the derivative function a clone of the params: only normal time steps get to update
115    /// the internal UserData state
116    type UserData: Clone;
117
118    /// Step forward in the solver.
119    /// The solver may request a step be redone, signal that the
120    /// solution is finished, or give the value of the next solution value.
121    fn step(&mut self) -> Step<Self::RealField, Self::Field, D, Self::Error>;
122
123    /// Get the current time of the solver.
124    fn time(&self) -> Self::RealField;
125}
126
127/// Trait covering all initial value problem solvers.
128/// Build up the solver using the parameter builder functions and then use solve.
129///
130/// This is used as a builder pattern, setting parameters of the solver.
131/// [`IVPSolver`] implementations should implement a step function that
132/// returns an IVPStatus, then a blanket impl will allow it to be used as an
133/// IntoIterator for the user to iterate over the results.
134pub trait IVPSolver<'a, D: Dimension>: Sized
135where
136    DefaultAllocator: Allocator<Self::Field, D>,
137{
138    /// Error type. IVPError must be able to convert to the error type.
139    type Error: Error + From<IVPError>;
140    /// The field, complex or real, that the solver is operating on.
141    type Field: ComplexField + Copy;
142    /// The real field associated with the solver's Field.
143    type RealField: RealField;
144    /// Arbitrary data provided by the user for the derivative function
145    type UserData: Clone;
146    /// The type signature of the derivative function to use
147    type Derivative: Derivative<Self::Field, D, Self::UserData> + 'a;
148    /// The type that actually does the solving.
149    type Solver: IVPStepper<
150        D,
151        Error = Self::Error,
152        Field = Self::Field,
153        RealField = Self::RealField,
154        UserData = Self::UserData,
155    >;
156
157    /// Create the solver.
158    /// Will fail for dynamically sized solvers
159    fn new() -> Result<Self, Self::Error>;
160
161    /// Create the solver with a run-time dimension.
162    /// Will fail for statically sized solvers
163    fn new_dyn(size: usize) -> Result<Self, Self::Error>;
164
165    /// Gets the dimension of the solver
166    fn dim(&self) -> D;
167
168    /// Set the error tolerance for any condition needing needing a float epsilon
169    fn with_tolerance(self, tol: Self::RealField) -> Result<Self, Self::Error>;
170
171    fn with_maximum_dt(self, max: Self::RealField) -> Result<Self, Self::Error>;
172    fn with_minimum_dt(self, min: Self::RealField) -> Result<Self, Self::Error>;
173    fn with_initial_time(self, initial: Self::RealField) -> Result<Self, Self::Error>;
174    fn with_ending_time(self, ending: Self::RealField) -> Result<Self, Self::Error>;
175
176    /// The initial conditions of the problem, should reset any previous values.
177    fn with_initial_conditions_slice(self, start: &[Self::Field]) -> Result<Self, Self::Error> {
178        let svector = BVector::from_column_slice_generic(self.dim(), U1::from_usize(1), start);
179        self.with_initial_conditions(svector)
180    }
181
182    /// The initial conditions of the problem, in a BVector. Should reset any previous values.
183    fn with_initial_conditions(self, start: BVector<Self::Field, D>) -> Result<Self, Self::Error>;
184
185    /// Sets the derivative function to use during the solve
186    fn with_derivative(self, derivative: Self::Derivative) -> Self;
187
188    /// Turns the solver into an iterator over the solution, using IVPStep::step
189    fn solve(self, data: Self::UserData) -> Result<IVPIterator<D, Self::Solver>, Self::Error>;
190}
191
192pub struct IVPIterator<D: Dimension, T: IVPStepper<D>>
193where
194    DefaultAllocator: Allocator<T::Field, D>,
195{
196    solver: T,
197    finished: bool,
198    _dim: PhantomData<D>,
199}
200
201/// A type alias for collecting all Steps into a Result
202/// of a Vec of the solution ((time, system state))
203pub type Path<R, C, D, E> = Result<Vec<(R, BVector<C, D>)>, E>;
204
205impl<D: Dimension, T: IVPStepper<D>> IVPIterator<D, T>
206where
207    DefaultAllocator: Allocator<T::Field, D>,
208{
209    pub fn collect_vec(self) -> Path<T::RealField, T::Field, D, T::Error> {
210        self.collect::<Result<Vec<_>, _>>()
211    }
212}
213
214impl<D: Dimension, T: IVPStepper<D>> Iterator for IVPIterator<D, T>
215where
216    DefaultAllocator: Allocator<T::Field, D>,
217{
218    type Item = Result<(T::RealField, BVector<T::Field, D>), T::Error>;
219
220    fn next(&mut self) -> Option<Self::Item> {
221        use IVPStatus as IE;
222
223        if self.finished {
224            return None;
225        }
226
227        loop {
228            match self.solver.step() {
229                Ok(vec) => break Some(Ok(vec)),
230                Err(IE::Done) => break None,
231                Err(IE::Redo) => continue,
232                Err(IE::Failure(e)) => {
233                    self.finished = true;
234                    break Some(Err(e));
235                }
236            }
237        }
238    }
239}
240
241/// Euler solver for an IVP.
242///
243/// Solves an initial value problem using Euler's method.
244///
245/// # Examples
246/// ```
247/// use std::error::Error;
248/// use bacon_sci::{BSVector, ivp::{Euler, IVPSolver, IVPError}};
249/// fn derivative(_t: f64, state: &[f64], _p: &mut ()) -> Result<BSVector<f64, 1>, Box<dyn Error>> {
250///     Ok(BSVector::<f64, 1>::from_column_slice(state))
251/// }
252///
253/// fn example() -> Result<(), IVPError> {
254///     let solver = Euler::new()?
255///         .with_maximum_dt(0.001)?
256///         .with_initial_conditions_slice(&[1.0])?
257///         .with_initial_time(0.0)?
258///         .with_ending_time(1.0)?
259///         .with_derivative(derivative)
260///         .solve(())?;
261///     let path = solver.collect_vec()?;
262///
263///     for (time, state) in &path {
264///         assert!((time.exp() - state.column(0)[0]).abs() <= 0.001);
265///     }
266///     Ok(())
267/// }
268/// ```
269pub struct Euler<'a, N, D, T, F>
270where
271    N: ComplexField + Copy,
272    D: Dimension,
273    T: Clone,
274    F: Derivative<N, D, T> + 'a,
275    DefaultAllocator: Allocator<N, D>,
276{
277    init_dt: Option<N::RealField>,
278    init_time: Option<N::RealField>,
279    init_end: Option<N::RealField>,
280    init_state: Option<BVector<N, D>>,
281    init_derivative: Option<F>,
282    dim: D,
283    _data: PhantomData<&'a T>,
284}
285
286/// The struct that actually solves an IVP with Euler's method
287/// Is the associated [`IVPStepper`] for Euler (the IVPSolver)
288/// You should use Euler and not this type directly
289pub struct EulerSolver<'a, N, D, T, F>
290where
291    N: ComplexField + Copy,
292    D: Dimension,
293    T: Clone,
294    F: Derivative<N, D, T> + 'a,
295    DefaultAllocator: Allocator<N, D>,
296{
297    dt: N,
298    time: N,
299    end: N,
300    state: BVector<N, D>,
301    derivative: F,
302    data: T,
303    _lifetime: PhantomData<&'a ()>,
304}
305
306impl<'a, N, D, T, F> IVPStepper<D> for EulerSolver<'a, N, D, T, F>
307where
308    N: ComplexField + Copy,
309    D: Dimension,
310    T: Clone,
311    F: Derivative<N, D, T> + 'a,
312    DefaultAllocator: Allocator<N, D>,
313{
314    type Error = IVPError;
315    type Field = N;
316    type RealField = N::RealField;
317    type UserData = T;
318
319    fn step(
320        &mut self,
321    ) -> Result<(Self::RealField, BVector<Self::Field, D>), IVPStatus<Self::Error>> {
322        if self.time.real() >= self.end.real() {
323            return Err(IVPStatus::Done);
324        }
325        if (self.time + self.dt).real() >= self.end.real() {
326            self.dt = self.end - self.time;
327        }
328
329        let derivative = (self.derivative)(self.time.real(), self.state.as_slice(), &mut self.data)
330            .map_err(IVPError::UserError)?;
331
332        let old_time = self.time.real();
333        let old_state = self.state.clone();
334
335        self.state += derivative * self.dt;
336        self.time += self.dt;
337
338        Ok((old_time, old_state))
339    }
340
341    fn time(&self) -> Self::RealField {
342        self.time.real()
343    }
344}
345
346impl<'a, N, D, T, F> IVPSolver<'a, D> for Euler<'a, N, D, T, F>
347where
348    N: ComplexField + Copy,
349    D: Dimension,
350    T: Clone,
351    F: Derivative<N, D, T> + 'a,
352    DefaultAllocator: Allocator<N, D>,
353{
354    type Error = IVPError;
355    type Field = N;
356    type RealField = N::RealField;
357    type Derivative = F;
358    type UserData = T;
359    type Solver = EulerSolver<'a, N, D, T, F>;
360
361    fn new() -> Result<Self, Self::Error> {
362        Ok(Self {
363            init_dt: None,
364            init_time: None,
365            init_end: None,
366            init_state: None,
367            init_derivative: None,
368            dim: D::dim()?,
369            _data: PhantomData,
370        })
371    }
372
373    fn new_dyn(size: usize) -> Result<Self, Self::Error> {
374        Ok(Self {
375            init_dt: None,
376            init_time: None,
377            init_end: None,
378            init_state: None,
379            init_derivative: None,
380            dim: D::dim_dyn(size)?,
381            _data: PhantomData,
382        })
383    }
384
385    fn dim(&self) -> D {
386        self.dim
387    }
388
389    /// Unused for Euler, call is a no-op
390    fn with_tolerance(self, _tol: Self::RealField) -> Result<Self, Self::Error> {
391        Ok(self)
392    }
393
394    /// If there is not time step already, set, then set the time step.
395    /// If there is, set the time step to the average of that and the max passed in.
396    fn with_maximum_dt(mut self, max: Self::RealField) -> Result<Self, Self::Error> {
397        if max <= <Self::RealField as Zero>::zero() {
398            return Err(IVPError::TimeDeltaOOB);
399        }
400
401        self.init_dt = if let Some(dt) = self.init_dt {
402            Some((dt + max) / Self::RealField::from_u8(2).ok_or(IVPError::FromPrimitiveFailure)?)
403        } else {
404            Some(max)
405        };
406        Ok(self)
407    }
408
409    /// If there is not time step already, set, then set the time step.
410    /// If there is, set the time step to the average of that and the max passed in.
411    fn with_minimum_dt(mut self, min: Self::RealField) -> Result<Self, Self::Error> {
412        if min <= <Self::RealField as Zero>::zero() {
413            return Err(IVPError::TimeDeltaOOB);
414        }
415
416        self.init_dt = if let Some(dt) = self.init_dt {
417            Some((dt + min) / Self::RealField::from_u8(2).ok_or(IVPError::FromPrimitiveFailure)?)
418        } else {
419            Some(min)
420        };
421        Ok(self)
422    }
423
424    fn with_initial_time(mut self, initial: Self::RealField) -> Result<Self, Self::Error> {
425        self.init_time = Some(initial.clone());
426
427        if let Some(end) = self.init_end.as_ref() {
428            if *end <= initial {
429                return Err(IVPError::TimeStartOOB);
430            }
431        }
432
433        Ok(self)
434    }
435
436    fn with_ending_time(mut self, ending: Self::RealField) -> Result<Self, Self::Error> {
437        self.init_end = Some(ending.clone());
438
439        if let Some(initial) = self.init_time.as_ref() {
440            if *initial >= ending {
441                return Err(IVPError::TimeEndOOB);
442            }
443        }
444
445        Ok(self)
446    }
447
448    fn with_initial_conditions(
449        mut self,
450        start: BVector<Self::Field, D>,
451    ) -> Result<Self, Self::Error> {
452        self.init_state = Some(start);
453        Ok(self)
454    }
455
456    fn with_derivative(mut self, derivative: Self::Derivative) -> Self {
457        self.init_derivative = Some(derivative);
458        self
459    }
460
461    fn solve(mut self, data: Self::UserData) -> Result<IVPIterator<D, Self::Solver>, Self::Error> {
462        let dt = self.init_dt.ok_or(IVPError::MissingParameters)?;
463        let time = self.init_time.ok_or(IVPError::MissingParameters)?;
464        let end = self.init_end.ok_or(IVPError::MissingParameters)?;
465        let state = self.init_state.take().ok_or(IVPError::MissingParameters)?;
466        let derivative = self
467            .init_derivative
468            .take()
469            .ok_or(IVPError::MissingParameters)?;
470
471        Ok(IVPIterator {
472            solver: EulerSolver {
473                dt: N::from_real(dt),
474                time: N::from_real(time),
475                end: N::from_real(end),
476                state,
477                derivative,
478                data,
479                _lifetime: PhantomData,
480            },
481            finished: false,
482            _dim: PhantomData,
483        })
484    }
485}
486
#[cfg(test)]
mod test {
    use super::*;
    use crate::BSVector;
    use nalgebra::{DimName, Dyn};

    type Path<D> = Vec<(f64, BVector<f64, D>)>;

    /// Solve an IVP with a statically constructed Euler solver and collect the whole path.
    fn solve_ivp<'a, D, F>(
        (initial, end): (f64, f64),
        dt: f64,
        initial_conds: &[f64],
        derivative: F,
    ) -> Result<Path<D>, IVPError>
    where
        D: Dimension,
        F: Derivative<f64, D, ()> + 'a,
        DefaultAllocator: Allocator<f64, D>,
    {
        Euler::new()?
            .with_initial_time(initial)?
            .with_ending_time(end)?
            .with_maximum_dt(dt)?
            .with_initial_conditions_slice(initial_conds)?
            .with_derivative(derivative)
            .solve(())?
            .collect()
    }

    /// Assert that the first state component of every point is within 0.01 of `exact(time)`.
    fn assert_first_component_tracks<D>(path: Path<D>, exact: impl Fn(f64) -> f64)
    where
        D: Dimension,
        DefaultAllocator: Allocator<f64, D>,
    {
        for (time, state) in path {
            assert!(approx_eq!(
                f64,
                state.column(0)[0],
                exact(time),
                epsilon = 0.01
            ));
        }
    }

    // y' = y
    fn exp_deriv(_: f64, y: &[f64], _: &mut ()) -> Result<BSVector<f64, 1>, UserError> {
        Ok(BSVector::from_column_slice(y))
    }

    // y' = -2t
    fn quadratic_deriv(t: f64, _y: &[f64], _: &mut ()) -> Result<BSVector<f64, 1>, UserError> {
        Ok(BSVector::from_column_slice(&[-2.0 * t]))
    }

    // y' = cos(t)
    fn sine_deriv(t: f64, y: &[f64], _: &mut ()) -> Result<BSVector<f64, 1>, UserError> {
        Ok(BSVector::from_iterator(y.iter().map(|_| t.cos())))
    }

    // (y0, y1)' = (y1, -y0)
    fn cos_deriv(_t: f64, y: &[f64], _: &mut ()) -> Result<BSVector<f64, 2>, UserError> {
        Ok(BSVector::from_column_slice(&[y[1], -y[0]]))
    }

    // Same system as `cos_deriv`, but returning a dynamically sized vector.
    fn dynamic_cos_deriv(_t: f64, y: &[f64], _: &mut ()) -> Result<BVector<f64, Dyn>, UserError> {
        Ok(BVector::from_column_slice_generic(
            Dyn::from_usize(y.len()),
            U1::name(),
            &[y[1], -y[0]],
        ))
    }

    // `solve_ivp` builds the solver with `Euler::new`, which is documented to fail
    // for dynamically sized solvers, so a panic is expected here.
    #[test]
    #[should_panic]
    fn euler_dynamic_cos_panics() {
        let (t_initial, t_final) = (0.0, 1.0);

        let path = solve_ivp((t_initial, t_final), 0.01, &[1.0, 0.0], dynamic_cos_deriv).unwrap();

        assert_first_component_tracks(path, f64::cos);
    }

    #[test]
    fn euler_dynamic_cos() {
        let (t_initial, t_final) = (0.0, 1.0);

        let ivp = Euler::new_dyn(2)
            .unwrap()
            .with_initial_time(t_initial)
            .unwrap()
            .with_ending_time(t_final)
            .unwrap()
            .with_maximum_dt(0.01)
            .unwrap()
            .with_initial_conditions_slice(&[1.0, 0.0])
            .unwrap()
            .with_derivative(dynamic_cos_deriv)
            .solve(())
            .unwrap();
        let path = ivp.collect_vec().unwrap();

        assert_first_component_tracks(path, f64::cos);
    }

    #[test]
    fn euler_cos() {
        let (t_initial, t_final) = (0.0, 1.0);

        let path = solve_ivp((t_initial, t_final), 0.01, &[1.0, 0.0], cos_deriv).unwrap();

        assert_first_component_tracks(path, f64::cos);
    }

    #[test]
    fn euler_exp() {
        let (t_initial, t_final) = (0.0, 1.0);

        let path = solve_ivp((t_initial, t_final), 0.005, &[1.0], exp_deriv).unwrap();

        assert_first_component_tracks(path, f64::exp);
    }

    #[test]
    fn euler_quadratic() {
        let (t_initial, t_final) = (0.0, 1.0);

        let path = solve_ivp((t_initial, t_final), 0.01, &[1.0], quadratic_deriv).unwrap();

        assert_first_component_tracks(path, |t| 1.0 - t.powi(2));
    }

    #[test]
    fn euler_sin() {
        let (t_initial, t_final) = (0.0, 1.0);

        let path = solve_ivp((t_initial, t_final), 0.01, &[0.0], sine_deriv).unwrap();

        assert_first_component_tracks(path, f64::sin);
    }
}