use std::collections::VecDeque;
use crate::{InitialCondition, collections::stable_index_deque::StableIndexVecDeque};
/// Access to the state of a running integration of an `N`-dimensional system:
/// the current and previous point, the retained history, and evaluation of the
/// already computed solution at arbitrary times.
pub trait State<const N: usize> {
/// Order of the stepping method.
fn method_order(&self) -> usize;
/// Order of the interpolant used for evaluation between steps.
fn interpolation_order(&self) -> usize;
/// Current time.
fn t(&self) -> f64;
/// Mutable access to the current time.
fn t_mut(&mut self) -> &mut f64;
/// Current position (returned by value).
fn x(&self) -> [f64; N];
/// Mutable access to the current position.
fn x_mut(&mut self) -> &mut [f64; N];
/// Mutable access to the current time and position at once.
fn tx_mut(&mut self) -> (&mut f64, &mut [f64; N]);
/// Time at the start of the last step.
fn t_prev(&self) -> f64;
/// Position at the start of the last step.
fn x_prev(&self) -> [f64; N];
/// Derivative associated with the start of the last step.
fn d_prev(&self) -> [f64; N];
/// Initial time of the integration.
fn t_init(&self) -> f64;
/// Length of the past time span for which history is retained.
fn t_span(&self) -> f64;
/// Stored step times, oldest first.
fn t_seq(&self) -> &VecDeque<f64>;
/// Mutable access to the stored step times.
fn t_seq_mut(&mut self) -> &mut VecDeque<f64>;
/// Stored positions, one per entry of `t_seq`.
fn x_seq(&self) -> &VecDeque<[f64; N]>;
/// Mutable access to the stored positions.
fn x_seq_mut(&mut self) -> &mut VecDeque<[f64; N]>;
/// Queue of `(time, order)` pairs.
// NOTE(review): presumably the times and orders of discontinuities - confirm.
fn disco_seq(&self) -> &StableIndexVecDeque<(f64, usize)>;
/// Mutable access to the `(time, order)` queue.
fn disco_seq_mut(&mut self) -> &mut StableIndexVecDeque<(f64, usize)>;
/// Evaluates all coordinates of the solution at time `t`.
fn eval_all(&self, t: f64) -> [f64; N];
/// Evaluates a single coordinate of the solution at time `t`.
fn eval(&self, t: f64, coordinate: usize) -> f64;
/// Evaluates the time derivative of a single coordinate at time `t`.
fn eval_derivative(&self, t: f64, coordinate: usize) -> f64;
/// Returns one callable per coordinate, each borrowing this state.
fn coord_fns<'b>(&'b self) -> [StateCoordFn<'b, N, Self>; N];
/// Makes a step of zero length: the previous point becomes the current one.
fn make_zero_step(&mut self);
/// Advances the state by `t_step` using the right-hand side `rhs`.
fn make_step(&mut self, rhs: &mut impl StateFnMut<N, Output = [f64; N]>, t_step: f64);
/// Reverts the current point to the previous one.
fn undo_step(&mut self);
/// Appends the current point to the stored history.
fn push_current(&mut self);
}
/// State of an `S`-stage Runge-Kutta integration of an `N`-dimensional system.
///
/// Holds the current and previous point, the stage values of the last step, and
/// queues of past points and stages used to evaluate the solution at earlier times.
pub struct RKState<'a, const N: usize, const S: usize, IC: InitialCondition<N>>
where
// NOTE(review): presumably required by `RungeKuttaTable<S>` (size of its `a` array) - confirm.
[(); S * (S - 1) / 2]:,
{
/// Current time.
pub t: f64,
/// Time at the start of the last step.
pub t_prev: f64,
/// Initial time; evaluation at `t <= t_init` is delegated to `x_init`.
pub t_init: f64,
/// Length of the past time span for which history is retained.
pub t_span: f64,
/// Stored step times, oldest first; starts with `t_init`.
pub t_seq: std::collections::VecDeque<f64>,
/// Current position.
pub x: [f64; N],
/// Position at the start of the last step.
pub x_prev: [f64; N],
/// Initial condition, evaluated for times up to `t_init`.
pub x_init: IC,
/// Stored positions, one per entry of `t_seq`.
pub x_seq: std::collections::VecDeque<[f64; N]>,
/// Queue of `(time, order)` pairs; starts with `(t_init, 0)`.
pub disco_seq: StableIndexVecDeque<(f64, usize)>,
/// Runge-Kutta table providing the coefficients and the interpolant.
pub rk: &'a crate::rk::RungeKuttaTable<S>,
/// Stage values of the last step.
pub k: [[f64; N]; S],
/// Stage values of stored steps; entry `i` belongs to the step from
/// `t_seq[i]` to `t_seq[i + 1]`, so it is one shorter than `t_seq`.
pub k_seq: std::collections::VecDeque<[[f64; N]; S]>,
}
impl<'a, const N: usize, const S: usize, IC: InitialCondition<N>> RKState<'a, N, S, IC>
where
[(); S * (S - 1) / 2]:,
{
pub fn new(
t_init: f64,
x_init: IC,
t_span: f64,
rk: &'a crate::rk::RungeKuttaTable<S>,
) -> Self {
let x = x_init.eval::<0>(t_init);
Self {
t_init,
t: t_init,
t_prev: t_init,
t_span,
t_seq: std::collections::VecDeque::from([t_init]),
x_init,
x,
x_prev: x.clone(),
x_seq: std::collections::VecDeque::from([x.clone()]),
disco_seq: StableIndexVecDeque::from([(t_init, 0)]),
k: [[0.; N]; S],
k_seq: std::collections::VecDeque::new(),
rk,
}
}
}
impl<'a, const N: usize, const S: usize, IC: InitialCondition<N>> RKState<'a, N, S, IC>
where
    [(); S * (S - 1) / 2]:,
{
    /// Finds the stored step that contains `t`.
    ///
    /// Returns the index `i` of the first entry of `t_seq` that is not less than `t`.
    /// The step to interpolate on then starts at `t_seq[i - 1]`, ends at `t_seq[i]`,
    /// and uses `x_seq[i - 1]` and `k_seq[i - 1]`.
    ///
    /// # Panics
    /// Panics if no stored time is less than `t` (the requested history was already
    /// dropped) or if every stored time is less than `t` (not computed yet).
    fn stored_step_index(&self, t: f64) -> usize {
        let i = self.t_seq.partition_point(|t_i| t_i < &t);
        if i == 0 {
            panic!(
                "Evaluation of state in deleted time range. Try adding .with_delay({}) to your equation.",
                self.t - t
            );
        } else if i == self.t_seq.len() {
            panic!(
                "Evaluation of state in a not yet computed time range at {t} while state.t is {}.",
                self.t
            );
        }
        i
    }

    /// Returns `(t_prev, t_next, x_prev, k)` of the step used to evaluate the solution at `t`.
    ///
    /// The step currently being made (`self.t_prev..=self.t`) takes precedence;
    /// otherwise the step is looked up in the stored history.
    ///
    /// # Panics
    /// Panics under the same conditions as [`Self::stored_step_index`].
    fn step_containing(&self, t: f64) -> (f64, f64, &[f64; N], &[[f64; N]; S]) {
        if self.t_prev <= t && t <= self.t {
            (self.t_prev, self.t, &self.x_prev, &self.k)
        } else {
            let i = self.stored_step_index(t);
            (
                self.t_seq[i - 1],
                self.t_seq[i],
                &self.x_seq[i - 1],
                &self.k_seq[i - 1],
            )
        }
    }
}

/// [`State`] implementation backed by a Runge-Kutta table with an interpolant.
impl<'a, const N: usize, const S: usize, IC: InitialCondition<N>> State<N> for RKState<'a, N, S, IC>
where
    [(); S * (S - 1) / 2]:,
{
    /// Order of the Runge-Kutta method, taken from the table.
    fn method_order(&self) -> usize {
        self.rk.order
    }

    /// Order of the interpolant, taken from the table.
    fn interpolation_order(&self) -> usize {
        self.rk.order_interpolant
    }

    fn t(&self) -> f64 {
        self.t
    }

    fn t_mut(&mut self) -> &mut f64 {
        &mut self.t
    }

    fn x(&self) -> [f64; N] {
        self.x
    }

    fn x_mut(&mut self) -> &mut [f64; N] {
        &mut self.x
    }

    fn tx_mut(&mut self) -> (&mut f64, &mut [f64; N]) {
        (&mut self.t, &mut self.x)
    }

    fn t_prev(&self) -> f64 {
        self.t_prev
    }

    fn x_prev(&self) -> [f64; N] {
        self.x_prev
    }

    /// Returns the first stage value of the last step.
    // NOTE(review): this is the derivative at the previous point only if `rk.c[0] == 0` - confirm.
    fn d_prev(&self) -> [f64; N] {
        self.k[0]
    }

    fn t_init(&self) -> f64 {
        self.t_init
    }

    fn t_span(&self) -> f64 {
        self.t_span
    }

    fn t_seq(&self) -> &VecDeque<f64> {
        &self.t_seq
    }

    fn t_seq_mut(&mut self) -> &mut VecDeque<f64> {
        &mut self.t_seq
    }

    fn x_seq(&self) -> &VecDeque<[f64; N]> {
        &self.x_seq
    }

    fn x_seq_mut(&mut self) -> &mut VecDeque<[f64; N]> {
        &mut self.x_seq
    }

    fn disco_seq(&self) -> &StableIndexVecDeque<(f64, usize)> {
        &self.disco_seq
    }

    fn disco_seq_mut(&mut self) -> &mut StableIndexVecDeque<(f64, usize)> {
        &mut self.disco_seq
    }

    /// Evaluates all coordinates of the solution at time `t`.
    ///
    /// For `t <= t_init` the initial condition is used. Otherwise the value is
    /// interpolated on the step containing `t` using the table's interpolant `bi`.
    ///
    /// # Panics
    /// Panics if `t` lies in a dropped or not yet computed time range.
    fn eval_all(&self, t: f64) -> [f64; N] {
        if t <= self.t_init {
            return self.x_init.eval::<0>(t);
        }
        let (t_prev, t_next, x_prev, k) = self.step_containing(t);
        let t_step = t_next - t_prev;
        if t_step == 0. {
            // Zero-length step: nothing to interpolate, and `theta` would be 0/0.
            return *x_prev;
        }
        let theta = (t - t_prev) / t_step;
        std::array::from_fn(|c| {
            x_prev[c] + t_step * (0..S).fold(0., |acc, j| acc + self.rk.bi[j](theta) * k[j][c])
        })
    }

    /// Evaluates one coordinate of the solution at time `t`.
    ///
    /// Uses the same rules as [`State::eval_all`], but computes only `coordinate`.
    ///
    /// # Panics
    /// Panics if `t` lies in a dropped or not yet computed time range, or if
    /// `coordinate >= N`.
    fn eval(&self, t: f64, coordinate: usize) -> f64 {
        if t <= self.t_init {
            return self.x_init.eval::<0>(t)[coordinate];
        }
        let (t_prev, t_next, x_prev, k) = self.step_containing(t);
        let x_prev = x_prev[coordinate];
        let t_step = t_next - t_prev;
        if t_step == 0. {
            return x_prev;
        }
        let theta = (t - t_prev) / t_step;
        x_prev + t_step * (0..S).fold(0., |acc, j| acc + self.rk.bi[j](theta) * k[j][coordinate])
    }

    /// Evaluates the time derivative of one coordinate at time `t`.
    ///
    /// For `t <= t_init` the first derivative of the initial condition is used.
    /// Otherwise the derivative of the interpolant on the step containing `t` is
    /// returned.
    ///
    /// # Panics
    /// Panics if `t` lies in a dropped or not yet computed time range, or if
    /// `coordinate >= N`.
    fn eval_derivative(&self, t: f64, coordinate: usize) -> f64 {
        if t <= self.t_init {
            return self.x_init.eval::<1>(t)[coordinate];
        }
        // Unlike the value, the derivative is undefined on a zero-length current
        // step, so that case falls back to the stored history.
        let (t_prev, t_next, k) = if self.t_prev <= t && t <= self.t && self.t != self.t_prev {
            (self.t_prev, self.t, &self.k)
        } else {
            let i = self.stored_step_index(t);
            (self.t_seq[i - 1], self.t_seq[i], &self.k_seq[i - 1])
        };
        let t_step = t_next - t_prev;
        let theta = (t - t_prev) / t_step;
        (0..S).fold(0., |acc, j| acc + self.rk.bi[j].d(theta) * k[j][coordinate])
    }

    /// Returns one callable per coordinate, each borrowing this state.
    fn coord_fns<'b>(&'b self) -> [StateCoordFn<'b, N, Self>; N] {
        std::array::from_fn(|i| StateCoordFn::<'b, N, Self> {
            state: self,
            coord: i,
        })
    }

    /// Appends the current point and stages to the history and drops entries
    /// that are older than the retained time range.
    fn push_current(&mut self) {
        self.t_seq.push_back(self.t);
        self.x_seq.push_back(self.x);
        self.k_seq.push_back(self.k);

        // Keep `t_span` of history before the previous point, plus one extra step length.
        let t_tail = self.t_prev - self.t_span - (self.t - self.t_prev);
        while t_tail
            > *self
                .t_seq
                .front()
                .expect("Last element won't pop for non-negative t_span")
        {
            self.t_seq.pop_front();
            self.x_seq.pop_front();
            self.k_seq.pop_front();
        }
        while let Some((t, _order)) = self.disco_seq.front()
            && &t_tail > t
        {
            self.disco_seq.pop_front();
        }
    }

    /// Advances the state by `t_step` with one explicit Runge-Kutta step.
    ///
    /// While the stages are computed, `self.t` and `self.x` temporarily hold the
    /// stage time and stage position, so that `rhs` can read them from the state.
    fn make_step(&mut self, rhs: &mut impl StateFnMut<N, Output = [f64; N]>, t_step: f64) {
        self.t_prev = self.t;
        self.x_prev = self.x;

        // `rk.a` stores the strictly lower triangle row by row; `row_start` is the
        // offset of row `i`.
        let mut row_start = 0;
        for i in 0..S {
            self.t = self.t_prev + self.rk.c[i] * t_step;
            self.x = std::array::from_fn(|k| {
                self.x_prev[k]
                    + t_step
                        * (0..i).fold(0., |acc, j| acc + self.rk.a[row_start + j] * self.k[j][k])
            });
            row_start += i;
            self.k[i] = rhs.eval(self);
        }

        self.x = std::array::from_fn(|k| {
            self.x_prev[k] + t_step * (0..S).fold(0., |acc, j| acc + self.rk.b[j] * self.k[j][k])
        });
        self.t = self.t_prev + t_step;
    }

    /// Makes a step of zero length: the previous point becomes the current one
    /// and the stage values are zeroed.
    fn make_zero_step(&mut self) {
        self.t_prev = self.t;
        self.x_prev = self.x;
        self.k = [[0.; N]; S];
    }

    /// Reverts the current point to the previous one; the stage values are left as is.
    fn undo_step(&mut self) {
        self.t = self.t_prev;
        self.x = self.x_prev;
    }
}
/// A function of the state that may mutate itself but only reads the state.
pub trait StateFnMut<const N: usize> {
/// Type of the value produced by an evaluation.
type Output;
/// Evaluates at the current point of `state`.
fn eval(&mut self, state: &impl State<N>) -> Self::Output;
/// Evaluates at the previous point of `state`.
fn eval_prev(&mut self, state: &impl State<N>) -> Self::Output;
/// Evaluates at time `t`.
fn eval_at(&mut self, state: &impl State<N>, t: f64) -> Self::Output;
}
/// A function of the state that is given mutable access to the state.
pub trait MutStateFnMut<const N: usize> {
/// Type of the value produced by an evaluation.
type Output;
/// Evaluates at the current point of `state`, with mutable access to it.
fn eval_mut(&mut self, state: &mut impl State<N>) -> Self::Output;
}
#[derive(Clone, Copy)]
pub struct ConstantStateFnMut<F: FnMut<(), Output = Ret>, Ret>(pub F);
impl<F: FnMut<(), Output = Ret>, Ret, const N: usize> StateFnMut<N> for ConstantStateFnMut<F, Ret> {
type Output = Ret;
fn eval(&mut self, _: &impl State<N>) -> Ret {
(self.0)()
}
fn eval_prev(&mut self, _: &impl State<N>) -> Ret {
(self.0)()
}
fn eval_at(&mut self, _: &impl State<N>, _: f64) -> Ret {
(self.0)()
}
}
impl<F: FnMut<(), Output = Ret>, Ret, const N: usize> MutStateFnMut<N>
for ConstantStateFnMut<F, Ret>
{
type Output = Ret;
fn eval_mut(&mut self, _: &mut impl State<N>) -> Ret {
(self.0)()
}
}
#[derive(Clone, Copy)]
pub struct TimeStateFnMut<F: FnMut<(f64,), Output = Ret>, Ret>(pub F);
impl<F: FnMut<(f64,), Output = Ret>, Ret, const N: usize> StateFnMut<N> for TimeStateFnMut<F, Ret> {
type Output = Ret;
fn eval(&mut self, state: &impl State<N>) -> Ret {
(self.0)(state.t())
}
fn eval_prev(&mut self, state: &impl State<N>) -> Ret {
(self.0)(state.t_prev())
}
fn eval_at(&mut self, _: &impl State<N>, t: f64) -> Ret {
(self.0)(t)
}
}
impl<F: FnMut<(f64,), Output = Ret>, Ret, const N: usize> MutStateFnMut<N>
for TimeStateFnMut<F, Ret>
{
type Output = Ret;
fn eval_mut(&mut self, state: &mut impl State<N>) -> Ret {
(self.0)(state.t())
}
}
#[derive(Clone, Copy)]
pub struct TimeMutStateFnMut<F: for<'a> FnMut<(&'a mut f64,), Output = Ret>, Ret>(pub F);
impl<F: for<'a> FnMut<(&'a mut f64,), Output = Ret>, Ret, const N: usize> MutStateFnMut<N>
for TimeMutStateFnMut<F, Ret>
{
type Output = Ret;
fn eval_mut(&mut self, state: &mut impl State<N>) -> Ret {
(self.0)(state.t_mut())
}
}
#[derive(Clone, Copy)]
pub struct ODEStateFnMut<const N: usize, F: FnMut<([f64; N],), Output = Ret>, Ret>(pub F);
impl<F: FnMut<([f64; N],), Output = Ret>, Ret, const N: usize> StateFnMut<N>
for ODEStateFnMut<N, F, Ret>
{
type Output = Ret;
fn eval(&mut self, state: &impl State<N>) -> Ret {
(self.0)(state.x())
}
fn eval_prev(&mut self, state: &impl State<N>) -> Ret {
(self.0)(state.x_prev())
}
fn eval_at(&mut self, state: &impl State<N>, t: f64) -> Ret {
(self.0)(state.eval_all(t))
}
}
impl<F: for<'a> FnMut<([f64; N],), Output = Ret>, Ret, const N: usize> MutStateFnMut<N>
for ODEStateFnMut<N, F, Ret>
{
type Output = Ret;
fn eval_mut(&mut self, state: &mut impl State<N>) -> Ret {
(self.0)(state.x())
}
}
#[derive(Clone, Copy)]
pub struct ODEMutStateFnMut<
const N: usize,
F: for<'a> FnMut<(&'a mut [f64; N],), Output = Ret>,
Ret,
>(pub F);
impl<F: for<'a> FnMut<(&'a mut [f64; N],), Output = Ret>, Ret, const N: usize> MutStateFnMut<N>
for ODEMutStateFnMut<N, F, Ret>
{
type Output = Ret;
fn eval_mut(&mut self, state: &mut impl State<N>) -> Ret {
(self.0)(state.x_mut())
}
}
#[derive(Clone, Copy)]
pub struct ODE2StateFnMut<const N: usize, F: FnMut<(f64, [f64; N]), Output = Ret>, Ret>(pub F);
impl<F: FnMut<(f64, [f64; N]), Output = Ret>, Ret, const N: usize> StateFnMut<N>
for ODE2StateFnMut<N, F, Ret>
{
type Output = Ret;
fn eval(&mut self, state: &impl State<N>) -> Ret {
(self.0)(state.t(), state.x())
}
fn eval_prev(&mut self, state: &impl State<N>) -> Ret {
(self.0)(state.t_prev(), state.x_prev())
}
fn eval_at(&mut self, state: &impl State<N>, t: f64) -> Ret {
(self.0)(t, state.eval_all(t))
}
}
impl<F: for<'a> FnMut<(f64, [f64; N]), Output = Ret>, Ret, const N: usize> MutStateFnMut<N>
for ODE2StateFnMut<N, F, Ret>
{
type Output = Ret;
fn eval_mut(&mut self, state: &mut impl State<N>) -> Ret {
(self.0)(state.t(), state.x())
}
}
#[derive(Clone, Copy)]
pub struct ODE2MutStateFnMut<
const N: usize,
F: for<'a> FnMut<(&'a mut f64, &'a mut [f64; N]), Output = Ret>,
Ret,
>(pub F);
impl<F: for<'a> FnMut<(&'a mut f64, &'a mut [f64; N]), Output = Ret>, Ret, const N: usize>
MutStateFnMut<N> for ODE2MutStateFnMut<N, F, Ret>
{
type Output = Ret;
fn eval_mut(&mut self, state: &mut impl State<N>) -> Ret {
let (t, x) = state.tx_mut();
(self.0)(t, x)
}
}
/// Adapter for a closure of time, position and per-coordinate solution functions.
///
/// The third argument holds one callable per coordinate, which the closure can
/// use to evaluate the solution at other times.
#[derive(Clone, Copy)]
pub struct DDEStateFnMut<
    const N: usize,
    F: for<'a> FnMut<(f64, [f64; N], [&'a dyn StateCoordFnTrait; N]), Output = Ret>,
    Ret,
>(pub F);

impl<F, Ret, const N: usize> StateFnMut<N> for DDEStateFnMut<N, F, Ret>
where
    F: for<'a> FnMut<(f64, [f64; N], [&'a dyn StateCoordFnTrait; N]), Output = Ret>,
{
    type Output = Ret;

    /// Calls the closure with the current time and position.
    fn eval(&mut self, state: &impl State<N>) -> Self::Output {
        let owned: [StateCoordFn<'_, N, _>; N] = state.coord_fns();
        let history = std::array::from_fn(|c| &owned[c] as &dyn StateCoordFnTrait);
        (self.0)(state.t(), state.x(), history)
    }

    /// Calls the closure with the previous time and position.
    fn eval_prev(&mut self, state: &impl State<N>) -> Self::Output {
        let owned: [StateCoordFn<'_, N, _>; N] = state.coord_fns();
        let history = std::array::from_fn(|c| &owned[c] as &dyn StateCoordFnTrait);
        (self.0)(state.t_prev(), state.x_prev(), history)
    }

    /// Calls the closure with `t` and the position evaluated at `t`.
    fn eval_at(&mut self, state: &impl State<N>, t: f64) -> Self::Output {
        let owned: [StateCoordFn<'_, N, _>; N] = state.coord_fns();
        let history = std::array::from_fn(|c| &owned[c] as &dyn StateCoordFnTrait);
        (self.0)(t, state.eval_all(t), history)
    }
}

impl<F, Ret, const N: usize> MutStateFnMut<N> for DDEStateFnMut<N, F, Ret>
where
    F: for<'a> FnMut<(f64, [f64; N], [&'a dyn StateCoordFnTrait; N]), Output = Ret>,
{
    type Output = Ret;

    /// Calls the closure with the current time and position; the state is only read.
    fn eval_mut(&mut self, state: &mut impl State<N>) -> Self::Output {
        let owned: [StateCoordFn<'_, N, _>; N] = state.coord_fns();
        let history = std::array::from_fn(|c| &owned[c] as &dyn StateCoordFnTrait);
        (self.0)(state.t(), state.x(), history)
    }
}
/// Adapter for a closure that may modify time and position and also receives
/// per-coordinate solution functions.
pub struct DDEMutStateFnMut<
    const N: usize,
    F: for<'a> FnMut<
        (
            &'a mut f64,
            &'a mut [f64; N],
            [&'a dyn StateCoordFnTrait; N],
        ),
        Output = Ret,
    >,
    Ret,
>(pub F);

impl<F, Ret, const N: usize> MutStateFnMut<N> for DDEMutStateFnMut<N, F, Ret>
where
    F: for<'a> FnMut<
        (
            &'a mut f64,
            &'a mut [f64; N],
            [&'a dyn StateCoordFnTrait; N],
        ),
        Output = Ret,
    >,
{
    type Output = Ret;

    /// Calls the closure on copies of the current time and position, then
    /// writes the possibly modified values back into the state.
    fn eval_mut(&mut self, state: &mut impl State<N>) -> Self::Output {
        let owned: [StateCoordFn<'_, N, _>; N] = state.coord_fns();
        let history = std::array::from_fn(|c| &owned[c] as &dyn StateCoordFnTrait);

        // The coordinate functions borrow the state, so the closure works on
        // local copies and the state is updated only after it returns.
        let mut time = state.t();
        let mut position = state.x();
        let result = (self.0)(&mut time, &mut position, history);

        *state.t_mut() = time;
        *state.x_mut() = position;
        result
    }
}
/// Composition of a state function (second field) with a post-processing
/// function (first field) applied to its output.
pub struct StateFnMutComposition<F, SF>(pub F, pub SF);

impl<'a, 'b, Ret1, Ret2, SF, F, const N: usize> StateFnMut<N>
    for StateFnMutComposition<&'a mut F, &'b mut SF>
where
    SF: StateFnMut<N, Output = Ret1>,
    F: FnMut(Ret1) -> Ret2,
{
    type Output = Ret2;

    /// Evaluates the inner state function at the current point and maps the result.
    fn eval(&mut self, state: &impl State<N>) -> Self::Output {
        let inner = self.1.eval(state);
        (self.0)(inner)
    }

    /// Evaluates the inner state function at the previous point and maps the result.
    fn eval_prev(&mut self, state: &impl State<N>) -> Self::Output {
        let inner = self.1.eval_prev(state);
        (self.0)(inner)
    }

    /// Evaluates the inner state function at time `t` and maps the result.
    fn eval_at(&mut self, state: &impl State<N>, t: f64) -> Self::Output {
        let inner = self.1.eval_at(state, t);
        (self.0)(inner)
    }
}
/// A single coordinate of the solution, callable as a function of time.
///
/// Borrows the state and evaluates coordinate `coord` of it when called.
pub struct StateCoordFn<'a, const N: usize, S: State<N> + ?Sized> {
/// The state whose solution is evaluated.
pub state: &'a S,
/// Index of the coordinate this function evaluates.
pub coord: usize,
}
/// Object-safe interface of a solution coordinate: callable as `f(t)`, with
/// access to its derivative and to the values at the previous point.
pub trait StateCoordFnTrait: Fn(f64) -> f64 {
/// Time derivative of the coordinate at time `t`.
fn d(&self, t: f64) -> f64;
/// Value of the coordinate at the previous point.
fn prev(&self) -> f64;
/// Derivative of the coordinate at the previous point.
fn d_prev(&self) -> f64;
}
/// Calling a `StateCoordFn` with a time evaluates its coordinate of the state.
impl<'a, const N: usize, S: State<N>> FnOnce<(f64,)> for StateCoordFn<'a, N, S> {
    type Output = f64;

    #[inline]
    extern "rust-call" fn call_once(self, args: (f64,)) -> Self::Output {
        let (t,) = args;
        self.state.eval(t, self.coord)
    }
}

impl<'a, const N: usize, S: State<N>> FnMut<(f64,)> for StateCoordFn<'a, N, S> {
    #[inline]
    extern "rust-call" fn call_mut(&mut self, args: (f64,)) -> Self::Output {
        let (t,) = args;
        self.state.eval(t, self.coord)
    }
}

impl<'a, const N: usize, S: State<N>> Fn<(f64,)> for StateCoordFn<'a, N, S> {
    extern "rust-call" fn call(&self, args: (f64,)) -> Self::Output {
        let (t,) = args;
        self.state.eval(t, self.coord)
    }
}
impl<'a, const N: usize, S: State<N>> StateCoordFnTrait for StateCoordFn<'a, N, S> {
    /// Derivative of this coordinate at time `t`, as evaluated by the state.
    fn d(&self, t: f64) -> f64 {
        let Self { state, coord } = self;
        state.eval_derivative(t, *coord)
    }

    /// This coordinate of the state's previous position.
    fn prev(&self) -> f64 {
        let Self { state, coord } = self;
        state.x_prev()[*coord]
    }

    /// This coordinate of the state's previous derivative.
    fn d_prev(&self) -> f64 {
        let Self { state, coord } = self;
        state.d_prev()[*coord]
    }
}
/// Wraps a closure into the state-function adapter that matches its argument list.
///
/// Every arm except the empty one accepts an optional leading identifier
/// (intended for `move`), which is placed in front of the generated closure.
#[macro_export]
macro_rules! state_fn {
// No arguments at all: a closure that does nothing.
() => {
$crate::state::ConstantStateFnMut(|| {})
};
// `|| expr`: ignores the state.
($($move:ident)? || $expr:expr) => {
$crate::state::ConstantStateFnMut($($move)? || {$expr})
};
// `|t| expr`: function of time.
($($move:ident)? |$t:ident| $expr:expr) => {
$crate::state::TimeStateFnMut($($move)? |$t| $expr)
};
// `|[x, ...]| expr`: function of position.
($($move:ident)? |[$($x:pat),+]| $expr:expr) => {
$crate::state::ODEStateFnMut($($move)? |[$($x),+]| $expr)
};
// `|t, [x, ...]| expr`: function of time and position.
($($move:ident)? |$t:pat, [$($x:pat),+]| $expr:expr) => {
$crate::state::ODE2StateFnMut($($move)? |$t, [$($x),+]| $expr)
};
// `|t, [x, ...], [x_, ...]| expr`: additionally receives the per-coordinate solution functions.
($($move:ident)? |$t:pat, [$($x:pat),+], [$($x_:pat),+]| $expr:expr) => {
$crate::state::DDEStateFnMut($($move)? |$t, [$($x),+], [$($x_),+]| $expr)
};
}
/// Wraps a closure into the mutable state-function adapter that matches its
/// argument list.
///
/// The arms mirror those of `state_fn!`: every arm except the empty one accepts
/// an optional leading identifier (intended for `move`), and the array
/// arguments accept arbitrary patterns.
#[macro_export]
macro_rules! mut_state_fn {
    // No arguments at all: a closure that does nothing.
    () => {
        $crate::state::ConstantStateFnMut(|| {})
    };
    // `|| expr`: ignores the state.
    ($($move:ident)? || $expr:expr) => {
        $crate::state::ConstantStateFnMut($($move)? || {$expr})
    };
    // `|t| expr`: receives mutable access to the time.
    ($($move:ident)? |$t:ident| $expr:expr) => {
        $crate::state::TimeMutStateFnMut($($move)? |$t| $expr)
    };
    // `|[x, ...]| expr`: receives mutable access to the position.
    ($($move:ident)? |[$($x:pat),+]| $expr:expr) => {
        $crate::state::ODEMutStateFnMut($($move)? |[$($x),+]| $expr)
    };
    // `|t, [x, ...]| expr`: receives mutable access to time and position.
    ($($move:ident)? |$t:pat, [$($x:pat),+]| $expr:expr) => {
        $crate::state::ODE2MutStateFnMut($($move)? |$t, [$($x),+]| $expr)
    };
    // `|t, [x, ...], [x_, ...]| expr`: additionally receives the per-coordinate
    // solution functions.
    ($($move:ident)? |$t:pat, [$($x:pat),+], [$($x_:pat),+]| $expr:expr) => {
        $crate::state::DDEMutStateFnMut($($move)? |$t, [$($x),+], [$($x_),+]| $expr)
    };
}