use alga::general::ComplexField;
use nalgebra::{allocator::Allocator, DefaultAllocator, DimName, VectorN, U6};
use num_traits::Zero;
mod adams;
mod bdf;
mod rk;
pub use adams::*;
pub use bdf::*;
pub use rk::*;
/// Outcome of a single call to [`IVPSolver::step`].
pub enum IVPStatus<N: ComplexField, S: DimName>
where
DefaultAllocator: Allocator<N, S>,
{
/// No new points were produced, but integration should continue: the default
/// `solve_ivp` simply calls `step` again. Presumably emitted by adaptive
/// solvers when a step is rejected — confirm in the `adams`/`bdf`/`rk` modules.
Redo,
/// Newly accepted `(time, state)` points, in order; the default `solve_ivp`
/// appends them to the returned path.
Ok(Vec<(N::RealField, VectorN<N, S>)>),
/// The solver has finished; the default `solve_ivp` stops stepping.
/// (`Euler` returns this once its current time has reached the end time.)
Done,
}
/// Result of a full integration: on success, the `(time, state)` points of the
/// solution in the order they were produced; on failure, an error message.
type Path<Complex, Real, S> = Result<Vec<(Real, VectorN<Complex, S>)>, String>;
/// Common interface of the initial value problem (IVP) solvers in this module.
///
/// A solver is configured with the builder-style `with_*` methods, finalized
/// with [`build`](IVPSolver::build), and then either driven one step at a time
/// with [`step`](IVPSolver::step) or run to completion with
/// [`solve_ivp`](IVPSolver::solve_ivp).
///
/// The right-hand side `f(t, y, params)` receives the current time, the current
/// state as a slice and the mutable user parameters, and returns the derivative
/// `dy/dt` or an error message.
pub trait IVPSolver<N: ComplexField, S: DimName>: Sized
where
    DefaultAllocator: Allocator<N, S>,
{
    /// Advances the solver by one step, evaluating `f` as needed.
    ///
    /// Returns [`IVPStatus::Done`] once integration has finished,
    /// [`IVPStatus::Ok`] with the newly produced `(time, state)` points, or
    /// [`IVPStatus::Redo`] when no points were produced but stepping should
    /// continue.
    fn step<T: Clone, F: FnMut(N::RealField, &[N], &mut T) -> Result<VectorN<N, S>, String>>(
        &mut self,
        f: F,
        params: &mut T,
    ) -> Result<IVPStatus<N, S>, String>;
    /// Sets the error tolerance (fixed-step solvers such as `Euler` ignore it).
    fn with_tolerance(self, tol: N::RealField) -> Result<Self, String>;
    /// Sets the maximum step size (for `Euler`, the fixed step size).
    fn with_dt_max(self, max: N::RealField) -> Result<Self, String>;
    /// Sets the minimum step size (fixed-step solvers such as `Euler` ignore it).
    fn with_dt_min(self, min: N::RealField) -> Result<Self, String>;
    /// Sets the time at which integration starts.
    fn with_start(self, t_initial: N::RealField) -> Result<Self, String>;
    /// Sets the time at which integration ends.
    fn with_end(self, t_final: N::RealField) -> Result<Self, String>;
    /// Sets the state at the start time.
    fn with_initial_conditions(self, start: &[N]) -> Result<Self, String>;
    /// Finishes configuration and returns the solver ready for use.
    fn build(self) -> Self;
    /// Returns the solver's state vector, if set. For `Euler` this is the
    /// current state, which equals the initial conditions until the first step.
    fn get_initial_conditions(&self) -> Option<VectorN<N, S>>;
    /// Returns the solver's current time, if set.
    fn get_time(&self) -> Option<N::RealField>;
    /// Checks that the solver is configured well enough to start integrating.
    fn check_start(&self) -> Result<(), String>;
    /// Integrates from the start time to the end time and returns the path.
    ///
    /// The path begins with the starting `(time, state)` point, followed by
    /// every point reported through [`IVPStatus::Ok`], in order.
    ///
    /// # Errors
    ///
    /// Returns `Err` if [`check_start`](IVPSolver::check_start) fails, if the
    /// solver reports no start time or no initial conditions, or if any call
    /// to [`step`](IVPSolver::step) (including `f`) fails.
    fn solve_ivp<
        T: Clone,
        F: FnMut(N::RealField, &[N], &mut T) -> Result<VectorN<N, S>, String>,
    >(
        mut self,
        mut f: F,
        params: &mut T,
    ) -> Path<N, N::RealField, S> {
        self.check_start()?;
        // `check_start` should guarantee both values are present, but report a
        // misbehaving implementation as an error instead of panicking.
        let initial_state = self
            .get_initial_conditions()
            .ok_or_else(|| "solve_ivp: solver has no initial conditions".to_owned())?;
        let start_time = self
            .get_time()
            .ok_or_else(|| "solve_ivp: solver has no start time".to_owned())?;

        let mut path = vec![(start_time, initial_state)];
        loop {
            match self.step(&mut f, params)? {
                IVPStatus::Done => break,
                IVPStatus::Redo => {}
                IVPStatus::Ok(points) => path.extend(points),
            }
        }
        Ok(path)
    }
}
/// Fixed-step explicit (forward) Euler solver.
///
/// Every setting starts as `None` and is filled in by the `with_*` builder
/// methods of [`IVPSolver`].
#[derive(Debug, Clone, Default)]
// NOTE(review): this tests a bare cfg named `serialize`, not
// `feature = "serialize"`, and `Serialize`/`Deserialize` are not imported in
// this file's visible `use` lines — confirm which is intended.
#[cfg_attr(serialize, derive(Serialize, Deserialize))]
pub struct Euler<N: ComplexField, S: DimName>
where
DefaultAllocator: Allocator<N, S>,
{
// Step size, set by `with_dt_max`.
dt: Option<N::RealField>,
// Current time: set by `with_start`, advanced by `step`.
time: Option<N::RealField>,
// Time at which integration stops, set by `with_end`.
end: Option<N::RealField>,
// Current state: set by `with_initial_conditions`, advanced by `step`.
state: Option<VectorN<N, S>>,
}
impl<N: ComplexField, S: DimName> Euler<N, S>
where
    DefaultAllocator: Allocator<N, S>,
{
    /// Creates an unconfigured Euler solver.
    ///
    /// Step size, start time, end time and state all start out unset; supply
    /// them through the `with_*` builder methods before integrating.
    pub fn new() -> Self {
        Self {
            state: None,
            end: None,
            time: None,
            dt: None,
        }
    }
}
impl<N: ComplexField, S: DimName> IVPSolver<N, S> for Euler<N, S>
where
    DefaultAllocator: Allocator<N, S>,
{
    /// Takes one explicit Euler step: `y <- y + h * f(t, y)`.
    ///
    /// Normally `h` is the configured `dt`. If a full step would reach or pass
    /// the end time, `h` is shortened so the solver lands exactly on the end
    /// time. Returns [`IVPStatus::Done`] once the current time has reached the
    /// end time, and otherwise [`IVPStatus::Ok`] with the single new point.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the start time, end time, step size or initial
    /// conditions are not set, or if `f` fails.
    fn step<T: Clone, F: FnMut(N::RealField, &[N], &mut T) -> Result<VectorN<N, S>, String>>(
        &mut self,
        mut f: F,
        params: &mut T,
    ) -> Result<IVPStatus<N, S>, String> {
        let (time, end) = match (self.time, self.end) {
            (Some(time), Some(end)) => (time, end),
            (None, _) => return Err("Euler step: No initial time".to_owned()),
            (_, None) => return Err("Euler step: No end time".to_owned()),
        };
        if time >= end {
            return Ok(IVPStatus::Done);
        }
        let dt = self.dt.ok_or_else(|| "Euler step: No dt".to_owned())?;
        let state = self
            .state
            .as_mut()
            .ok_or_else(|| "Euler step: No initial conditions".to_owned())?;

        // Shorten only this step; the configured `dt` stays unchanged so it is
        // not silently shrunk if the solver is reused with a later end time.
        let last_step = time + dt >= end;
        let h = if last_step { end - time } else { dt };

        let deriv = f(time, state.as_slice(), params)?;
        *state += deriv * N::from_real(h);

        // On the last step assign `end` directly: `time + (end - time)` can round
        // to a value just below `end` and force an extra, tiny step.
        let new_time = if last_step { end } else { time + h };
        self.time = Some(new_time);

        Ok(IVPStatus::Ok(vec![(new_time, state.clone())]))
    }

    /// Accepted for interface compatibility; Euler uses a fixed step and ignores it.
    fn with_tolerance(self, _tol: N::RealField) -> Result<Self, String> {
        Ok(self)
    }

    /// Sets the fixed step size.
    ///
    /// # Errors
    ///
    /// Returns `Err` unless `max` is strictly positive. A zero, negative or NaN
    /// step would never reach the end time, so `solve_ivp` would loop forever.
    fn with_dt_max(mut self, max: N::RealField) -> Result<Self, String> {
        if max > N::RealField::zero() {
            self.dt = Some(max);
            Ok(self)
        } else {
            Err("Euler with_dt_max: dt must be positive".to_owned())
        }
    }

    /// Accepted for interface compatibility; Euler uses a fixed step and ignores it.
    fn with_dt_min(self, _min: N::RealField) -> Result<Self, String> {
        Ok(self)
    }

    /// Sets the start time.
    ///
    /// # Errors
    ///
    /// Returns `Err` if an end time is already set and is not after `t_initial`.
    fn with_start(mut self, t_initial: N::RealField) -> Result<Self, String> {
        if let Some(end) = self.end {
            if end <= t_initial {
                return Err("Euler with_start: Start must be before end".to_owned());
            }
        }
        self.time = Some(t_initial);
        Ok(self)
    }

    /// Sets the end time.
    ///
    /// # Errors
    ///
    /// Returns `Err` if a start time is already set and is not before `t_final`.
    fn with_end(mut self, t_final: N::RealField) -> Result<Self, String> {
        if let Some(start) = self.time {
            if start >= t_final {
                return Err("Euler with_end: Start must be before end".to_owned());
            }
        }
        self.end = Some(t_final);
        Ok(self)
    }

    /// Sets the initial state.
    ///
    /// # Errors
    ///
    /// Returns `Err` if `start` does not contain exactly one value per state
    /// component. Without this check, `from_column_slice` panics.
    fn with_initial_conditions(mut self, start: &[N]) -> Result<Self, String> {
        if start.len() != S::dim() {
            return Err(format!(
                "Euler with_initial_conditions: expected {} values, got {}",
                S::dim(),
                start.len()
            ));
        }
        self.state = Some(VectorN::from_column_slice(start));
        Ok(self)
    }

    /// Euler needs no extra setup, so this returns the solver unchanged.
    fn build(self) -> Self {
        self
    }

    /// Returns a copy of the current state, which equals the initial conditions
    /// until the first step.
    fn get_initial_conditions(&self) -> Option<VectorN<N, S>> {
        self.state.clone()
    }

    /// Returns the current time.
    fn get_time(&self) -> Option<N::RealField> {
        self.time
    }

    /// Checks that the start time, end time, initial conditions and step size are all set.
    fn check_start(&self) -> Result<(), String> {
        if self.time.is_none() {
            Err("Euler check_start: No initial time".to_owned())
        } else if self.end.is_none() {
            Err("Euler check_start: No end time".to_owned())
        } else if self.state.is_none() {
            Err("Euler check_start: No initial conditions".to_owned())
        } else if self.dt.is_none() {
            Err("Euler check_start: No dt".to_owned())
        } else {
            Ok(())
        }
    }
}
/// Solves an initial value problem, falling back through several solvers.
///
/// Tries `Adams`, then `RK45`, then `BDF6`. Each solver gets the same time span
/// `(start, end)`, step bounds `(dt_max, dt_min)`, tolerance `tol` and initial
/// conditions `y_0`. The path of the first solver that succeeds is returned.
///
/// `f(t, y, params)` computes the derivative `dy/dt`. Each attempt runs on its
/// own clone of `params`, and that clone is written back to `params` only when
/// the attempt succeeds. A failed attempt therefore never leaks partial updates
/// into the next attempt or to the caller, and the caller always sees the
/// parameters left by the solver whose path is returned.
///
/// # Errors
///
/// Returns `Err` immediately if configuring a solver fails, for example when
/// `start >= end`. If all three solvers fail, the message contains the error
/// reported by each of them.
pub fn solve_ivp<
    N: ComplexField,
    S: DimName,
    T: Clone,
    F: FnMut(N::RealField, &[N], &mut T) -> Result<VectorN<N, S>, String>,
>(
    (start, end): (N::RealField, N::RealField),
    (dt_max, dt_min): (N::RealField, N::RealField),
    y_0: &[N],
    mut f: F,
    tol: N::RealField,
    params: &mut T,
) -> Path<N, N::RealField, S>
where
    DefaultAllocator: Allocator<N, S>,
    DefaultAllocator: Allocator<N, U6>,
    DefaultAllocator: Allocator<N, S, U6>,
    DefaultAllocator: Allocator<N, U6, U6>,
    DefaultAllocator: Allocator<N::RealField, U6>,
    DefaultAllocator: Allocator<N::RealField, U6, U6>,
{
    let solver = Adams::new()
        .with_start(start)?
        .with_end(end)?
        .with_dt_max(dt_max)?
        .with_dt_min(dt_min)?
        .with_tolerance(tol)?
        .with_initial_conditions(y_0)?
        .build();
    let mut trial_params = params.clone();
    let adams_err = match solver.solve_ivp(&mut f, &mut trial_params) {
        Ok(path) => {
            *params = trial_params;
            return Ok(path);
        }
        Err(err) => err,
    };

    let solver: RK45<N, S> = RK45::new()
        .with_initial_conditions(y_0)?
        .with_start(start)?
        .with_end(end)?
        .with_dt_max(dt_max)?
        .with_dt_min(dt_min)?
        .with_tolerance(tol)?
        .build();
    let mut trial_params = params.clone();
    let rk45_err = match solver.solve_ivp(&mut f, &mut trial_params) {
        Ok(path) => {
            *params = trial_params;
            return Ok(path);
        }
        Err(err) => err,
    };

    let solver = BDF6::new()
        .with_start(start)?
        .with_end(end)?
        .with_dt_max(dt_max)?
        .with_dt_min(dt_min)?
        .with_tolerance(tol)?
        .with_initial_conditions(y_0)?
        .build();
    let mut trial_params = params.clone();
    match solver.solve_ivp(&mut f, &mut trial_params) {
        Ok(path) => {
            *params = trial_params;
            Ok(path)
        }
        Err(bdf6_err) => Err(format!(
            "solve_ivp: all solvers failed (Adams: {}; RK45: {}; BDF6: {})",
            adams_err, rk45_err, bdf6_err
        )),
    }
}