bacon_sci_1/ivp/
mod.rs

1/* This file is part of bacon.
2 * Copyright (c) Wyatt Campbell.
3 *
4 * See repository LICENSE for information.
5 */
6
7use nalgebra::{
8    allocator::Allocator, dimension::DimMin, ComplexField, DefaultAllocator, DimName, VectorN, U1,
9    U6, U7,
10};
11use num_traits::Zero;
12
13mod adams;
14mod bdf;
15mod rk;
16pub use adams::*;
17pub use bdf::*;
18pub use rk::*;
19
20/// Status of an IVP Solver after a step
21pub enum IVPStatus<N: ComplexField, S: DimName>
22where
23    DefaultAllocator: Allocator<N, S>,
24{
25    Redo,
26    Ok(Vec<(N::RealField, VectorN<N, S>)>),
27    Done,
28}
29
30type Path<Complex, Real, S> = Result<Vec<(Real, VectorN<Complex, S>)>, String>;
31
32/// Trait defining what it means to be an IVP solver.
33/// solve_ivp is automatically implemented based on your step implementation.
34pub trait IVPSolver<N: ComplexField, S: DimName>: Sized
35where
36    DefaultAllocator: Allocator<N, S>,
37{
38    /// Step forward in the solver. Returns if the solver is finished, produced
39    /// an acceptable state, or needs to be redone.
40    fn step<T: Clone, F: FnMut(N::RealField, &[N], &mut T) -> Result<VectorN<N, S>, String>>(
41        &mut self,
42        f: F,
43        params: &mut T,
44    ) -> Result<IVPStatus<N, S>, String>;
45    /// Set the error tolerance for this solver.
46    fn with_tolerance(self, tol: N::RealField) -> Result<Self, String>;
47    /// Set the maximum time step for this solver.
48    fn with_dt_max(self, max: N::RealField) -> Result<Self, String>;
49    /// Set the minimum time step for this solver.
50    fn with_dt_min(self, min: N::RealField) -> Result<Self, String>;
51    /// Set the initial time for this solver.
52    fn with_start(self, t_initial: N::RealField) -> Result<Self, String>;
53    /// Set the end time for this solver.
54    fn with_end(self, t_final: N::RealField) -> Result<Self, String>;
55    /// Set the initial conditions for this solver.
56    fn with_initial_conditions(self, start: &[N]) -> Result<Self, String>;
57    /// Build this solver.
58    fn build(self) -> Self;
59
60    /// Return the initial conditions. Called once at the very start
61    /// of solving.
62    fn get_initial_conditions(&self) -> Option<VectorN<N, S>>;
63    /// Get the current time of the solver.
64    fn get_time(&self) -> Option<N::RealField>;
65    /// Make sure that every value that needs to be set
66    /// is set before the solver starts
67    fn check_start(&self) -> Result<(), String>;
68
69    /// Solve an initial value problem, consuming the solver
70    fn solve_ivp<
71        T: Clone,
72        F: FnMut(N::RealField, &[N], &mut T) -> Result<VectorN<N, S>, String>,
73    >(
74        mut self,
75        mut f: F,
76        params: &mut T,
77    ) -> Path<N, N::RealField, S> {
78        self.check_start()?;
79        let mut path = vec![];
80        let init_conditions = self.get_initial_conditions();
81        let time = self.get_time();
82        path.push((time.unwrap(), init_conditions.unwrap()));
83
84        'out: loop {
85            let step = self.step(&mut f, params)?;
86            match step {
87                IVPStatus::Done => break 'out,
88                IVPStatus::Redo => {}
89                IVPStatus::Ok(mut state) => path.append(&mut state),
90            }
91        }
92
93        Ok(path)
94    }
95}
96
97/// Euler solver for an IVP.
98///
99/// Solves an initial value problem using Euler's method.
100///
101/// # Examples
102/// ```
103/// use nalgebra::{VectorN, U1};
104/// use bacon_sci::ivp::{Euler, IVPSolver};
105/// fn derivative(_t: f64, state: &[f64], _p: &mut ()) -> Result<VectorN<f64, U1>, String> {
106///     Ok(VectorN::<f64, U1>::from_column_slice(state))
107/// }
108///
109/// fn example() -> Result<(), String> {
110///     let solver = Euler::new()
111///         .with_dt_max(0.001)?
112///         .with_initial_conditions(&[1.0])?
113///         .with_start(0.0)?
114///         .with_end(1.0)?
115///         .build();
116///     let path = solver.solve_ivp(derivative, &mut ())?;
117///
118///     for (time, state) in &path {
119///         assert!((time.exp() - state.column(0)[0]).abs() <= 0.001);
120///     }
121///     Ok(())
122/// }
123/// ```
124#[derive(Debug, Clone, Default)]
125#[cfg_attr(serialize, derive(Serialize, Deserialize))]
126pub struct Euler<N: ComplexField, S: DimName>
127where
128    DefaultAllocator: Allocator<N, S>,
129{
130    dt: Option<N::RealField>,
131    time: Option<N::RealField>,
132    end: Option<N::RealField>,
133    state: Option<VectorN<N, S>>,
134}
135
136impl<N: ComplexField, S: DimName> Euler<N, S>
137where
138    DefaultAllocator: Allocator<N, S>,
139{
140    pub fn new() -> Self {
141        Euler {
142            dt: None,
143            time: None,
144            end: None,
145            state: None,
146        }
147    }
148}
149
150impl<N: ComplexField, S: DimName> IVPSolver<N, S> for Euler<N, S>
151where
152    DefaultAllocator: Allocator<N, S>,
153{
154    fn step<T: Clone, F: FnMut(N::RealField, &[N], &mut T) -> Result<VectorN<N, S>, String>>(
155        &mut self,
156        mut f: F,
157        params: &mut T,
158    ) -> Result<IVPStatus<N, S>, String> {
159        if self.time >= self.end {
160            return Ok(IVPStatus::Done);
161        }
162        if self.time.unwrap() + self.dt.unwrap() >= self.end.unwrap() {
163            self.dt = Some(self.end.unwrap() - self.time.unwrap());
164        }
165
166        let deriv = f(
167            self.time.unwrap(),
168            self.state.as_ref().unwrap().as_slice(),
169            params,
170        )?;
171
172        *self.state.get_or_insert(VectorN::from_iterator(
173            [N::zero()].repeat(self.state.as_ref().unwrap().as_slice().len()),
174        )) += deriv * N::from_real(self.dt.unwrap());
175        *self.time.get_or_insert(N::RealField::zero()) += self.dt.unwrap();
176        Ok(IVPStatus::Ok(vec![(
177            self.time.unwrap(),
178            self.state.clone().unwrap(),
179        )]))
180    }
181
182    fn with_tolerance(self, _tol: N::RealField) -> Result<Self, String> {
183        Ok(self)
184    }
185
186    fn with_dt_max(mut self, max: N::RealField) -> Result<Self, String> {
187        self.dt = Some(max);
188        Ok(self)
189    }
190
191    fn with_dt_min(self, _min: N::RealField) -> Result<Self, String> {
192        Ok(self)
193    }
194
195    fn with_start(mut self, t_initial: N::RealField) -> Result<Self, String> {
196        if let Some(end) = self.end {
197            if end <= t_initial {
198                return Err("Euler with_end: Start must be after end".to_owned());
199            }
200        }
201        self.time = Some(t_initial);
202        Ok(self)
203    }
204
205    fn with_end(mut self, t_final: N::RealField) -> Result<Self, String> {
206        if let Some(start) = self.time {
207            if start >= t_final {
208                return Err("Euler with_end: Start must be after end".to_owned());
209            }
210        }
211        self.end = Some(t_final);
212        Ok(self)
213    }
214
215    fn with_initial_conditions(mut self, start: &[N]) -> Result<Self, String> {
216        self.state = Some(VectorN::from_column_slice(start));
217        Ok(self)
218    }
219
220    fn build(self) -> Self {
221        self
222    }
223
224    fn get_initial_conditions(&self) -> Option<VectorN<N, S>> {
225        if let Some(state) = &self.state {
226            Some(state.clone())
227        } else {
228            None
229        }
230    }
231
232    fn get_time(&self) -> Option<N::RealField> {
233        self.time
234    }
235
236    fn check_start(&self) -> Result<(), String> {
237        if self.time == None {
238            Err("Euler check_start: No initial time".to_owned())
239        } else if self.end == None {
240            Err("Euler check_start: No end time".to_owned())
241        } else if self.state == None {
242            Err("Euler check_start: No initial conditions".to_owned())
243        } else if self.dt == None {
244            Err("Euler check_start: No dt".to_owned())
245        } else {
246            Ok(())
247        }
248    }
249}
250
251/// Solve an initial value problem of y'(t) = f(t, y) numerically.
252///
253/// Tries to solve an initial value problem with an Adams predictor-corrector,
254/// the Runge-Kutta-Fehlberg method, and finally a backwards differentiation formula.
255/// This is probably what you want to use.
256///
257/// # Params
258/// `(start, end)` The start and end times for the IVP
259///
260/// `(dt_max, dt_min)` The maximum and minimum time step for solving
261///
262/// `y_0` The initial conditions at `start`
263///
264/// `f` the derivative function
265///
266/// `tol` acceptable error between steps.
267///
268/// `params` parameters to pass to the derivative function
269///
270/// # Examples
271/// ```
272/// use nalgebra::{VectorN, U1};
273/// use bacon_sci::ivp::solve_ivp;
274/// fn derivatives(_: f64, y: &[f64], _: &mut ()) -> Result<VectorN<f64, U1>, String> {
275///     Ok(-VectorN::<f64, U1>::from_column_slice(y))
276/// }
277///
278/// fn example() -> Result<(), String> {
279///     let path = solve_ivp((0.0, 10.0), (0.1, 0.001), &[1.0], derivatives, 0.00001, &mut ())?;
280///
281///     for step in path {
282///         assert!(((-step.0).exp() - step.1.column(0)[0]).abs() < 0.001);
283///     }
284///
285///     Ok(())
286/// }
287/// ```
288pub fn solve_ivp<N, S, T, F>(
289    (start, end): (N::RealField, N::RealField),
290    (dt_max, dt_min): (N::RealField, N::RealField),
291    y_0: &[N],
292    mut f: F,
293    tol: N::RealField,
294    params: &mut T,
295) -> Path<N, N::RealField, S>
296where
297    N: ComplexField,
298    S: DimName + DimMin<S, Output = S>,
299    T: Clone,
300    F: FnMut(N::RealField, &[N], &mut T) -> Result<VectorN<N, S>, String>,
301    DefaultAllocator: Allocator<N, S>
302        + Allocator<N, U6>
303        + Allocator<N, S, U6>
304        + Allocator<N, U6, U6>
305        + Allocator<N::RealField, U6>
306        + Allocator<N::RealField, U6, U6>
307        + Allocator<N, U7>
308        + Allocator<N, S, S>
309        + Allocator<N, U1, S>
310        + Allocator<(usize, usize), S>,
311{
312    let solver = Adams::new()
313        .with_start(start)?
314        .with_end(end)?
315        .with_dt_max(dt_max)?
316        .with_dt_min(dt_min)?
317        .with_tolerance(tol)?
318        .with_initial_conditions(y_0)?
319        .build();
320
321    let path = solver.solve_ivp(&mut f, &mut params.clone());
322
323    if let Ok(path) = path {
324        return Ok(path);
325    }
326
327    let solver: RK45<N, S> = RK45::new()
328        .with_initial_conditions(y_0)?
329        .with_start(start)?
330        .with_end(end)?
331        .with_dt_max(dt_max)?
332        .with_dt_min(dt_min)?
333        .with_tolerance(tol)?
334        .build();
335
336    let path = solver.solve_ivp(&mut f, &mut params.clone());
337
338    if let Ok(path) = path {
339        return Ok(path);
340    }
341
342    let solver = BDF6::new()
343        .with_start(start)?
344        .with_end(end)?
345        .with_dt_max(dt_max)?
346        .with_dt_min(dt_min)?
347        .with_tolerance(tol)?
348        .with_initial_conditions(y_0)?
349        .build();
350
351    solver.solve_ivp(&mut f, params)
352}