use nalgebra::SVector;
use crate::Tolerances;
/// Arithmetic and diagnostic operations on an ODE state value.
///
/// Implementors must be `Clone + Sized`. `State` in this file is one
/// implementation; the per-method notes below describe what that
/// implementation does. Presumably consumed by an integrator defined
/// elsewhere — confirm against the callers.
pub trait OdeState: Clone + Sized {
/// Returns a new state derived from `self`'s shape; the `State`
/// implementation in this file fills every entry with zero.
fn zero_like(&self) -> Self;
/// Returns a new state combining `self` and `other`; the `State`
/// implementation in this file computes `self + scale * other`.
fn axpy(&self, scale: f64, other: &Self) -> Self;
/// Returns a new state; the `State` implementation in this file multiplies
/// every entry of `self` by `factor`.
fn scale(&self, factor: f64) -> Self;
/// Reports whether the state is finite; the `State` implementation in this
/// file requires every entry to satisfy `f64::is_finite`.
fn is_finite(&self) -> bool;
/// Returns a scalar measure of `error`, given the current state (`self`),
/// a second state `y_next` and the tolerances `tol`. The `State`
/// implementation in this file returns the root-mean-square of
/// `error / (atol + rtol * max(|self|, |y_next|))` over all entries.
fn error_norm(&self, y_next: &Self, error: &Self, tol: &Tolerances) -> f64;
/// Hook that receives the state mutably together with a time `_t`.
/// The default body is empty, so implementors that do not override it are
/// left unchanged.
// NOTE(review): presumably meant for re-imposing constraints on the state
// after a step — confirm against the integrator that calls it.
fn project(&mut self, _t: f64) {}
}
/// Fixed-size state made of `ORDER` vectors, each holding `DIM` `f64` values.
///
/// For `ORDER == 2` the accessors in this file expose `components[0]` as `y`
/// and `components[1]` as `dy`.
#[derive(Debug, Clone, PartialEq)]
pub struct State<const DIM: usize, const ORDER: usize> {
/// The `ORDER` component vectors, stored in index order.
pub components: [SVector<f64, DIM>; ORDER],
}
impl<const DIM: usize, const ORDER: usize> OdeState for State<DIM, ORDER> {
    /// Returns a state of the same `DIM`/`ORDER` with every entry set to zero.
    fn zero_like(&self) -> Self {
        State {
            components: [SVector::zeros(); ORDER],
        }
    }

    /// Returns `self + scale * other`, evaluated component vector by
    /// component vector. Neither input is modified.
    fn axpy(&self, scale: f64, other: &Self) -> Self {
        // Start from a copy of `self` and accumulate the scaled `other` into it.
        let mut components = self.components;
        for (acc, rhs) in components.iter_mut().zip(other.components.iter()) {
            *acc += scale * rhs;
        }
        State { components }
    }

    /// Returns a copy of `self` with every entry multiplied by `factor`.
    fn scale(&self, factor: f64) -> Self {
        let mut components = self.components;
        for entry in components.iter_mut() {
            *entry = factor * *entry;
        }
        State { components }
    }

    /// Returns `true` only when every entry of every component vector is
    /// finite (neither NaN nor infinite).
    fn is_finite(&self) -> bool {
        self.components
            .iter()
            .flat_map(|c| c.iter())
            .all(|v| v.is_finite())
    }

    /// Root-mean-square of the error scaled by a mixed absolute/relative
    /// tolerance.
    ///
    /// Each entry of `error` is divided by
    /// `tol.atol + tol.rtol * max(|self|, |y_next|)` (using the matching
    /// entries of `self` and `y_next`); the result is the square root of the
    /// mean of the squared quotients over all `DIM * ORDER` entries.
    ///
    /// An empty state (`DIM == 0` or `ORDER == 0`) has no entries to be wrong
    /// about, so the norm is `0.0` rather than the `NaN` that `0.0 / 0.0`
    /// would produce. A zero scale (for example `atol == 0` with both
    /// magnitudes zero) is not guarded and can still yield a non-finite
    /// result.
    fn error_norm(&self, y_next: &Self, error: &Self, tol: &Tolerances) -> f64 {
        let n = DIM * ORDER;
        if n == 0 {
            return 0.0;
        }

        let mut sum_sq = 0.0;
        // Walk the three states in lock-step: component vectors first, then
        // the entries inside each vector, in index order.
        let vectors = self
            .components
            .iter()
            .zip(y_next.components.iter())
            .zip(error.components.iter());
        for ((current, next), err) in vectors {
            for ((y0, y1), e) in current.iter().zip(next.iter()).zip(err.iter()) {
                let sc = tol.atol + tol.rtol * y0.abs().max(y1.abs());
                let scaled = *e / sc;
                sum_sq += scaled * scaled;
            }
        }
        (sum_sq / n as f64).sqrt()
    }
}
impl<const DIM: usize> State<DIM, 2> {
pub fn new(y: SVector<f64, DIM>, dy: SVector<f64, DIM>) -> Self {
State {
components: [y, dy],
}
}
pub fn y(&self) -> &SVector<f64, DIM> {
&self.components[0]
}
pub fn dy(&self) -> &SVector<f64, DIM> {
&self.components[1]
}
pub fn y_mut(&mut self) -> &mut SVector<f64, DIM> {
&mut self.components[0]
}
pub fn dy_mut(&mut self) -> &mut SVector<f64, DIM> {
&mut self.components[1]
}
pub fn from_derivative(dy: SVector<f64, DIM>, ddy: SVector<f64, DIM>) -> Self {
State {
components: [dy, ddy],
}
}
}
#[cfg(test)]
mod tests {
    use nalgebra::SVector;
    use proptest::prelude::*;

    use super::*;

    /// Builds a one-dimensional, two-component state from two scalars.
    fn scalar_state(first: f64, second: f64) -> State<1, 2> {
        State::<1, 2>::new(SVector::from([first]), SVector::from([second]))
    }

    /// Tight absolute comparison used by the identity properties.
    fn matches_exactly(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-15
    }

    /// Looser comparison that scales with the magnitude of `expected`.
    fn matches_closely(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-10 * (1.0 + expected.abs())
    }

    proptest! {
        /// Adding zero times any other state leaves the state unchanged.
        #[test]
        fn axpy_zero_is_identity(
            x in -100.0f64..100.0,
            v in -100.0f64..100.0,
            ox in -100.0f64..100.0,
            ov in -100.0f64..100.0,
        ) {
            let combined = scalar_state(x, v).axpy(0.0, &scalar_state(ox, ov));
            prop_assert!(matches_exactly(combined.y()[0], x));
            prop_assert!(matches_exactly(combined.dy()[0], v));
        }

        /// Scaling by one leaves the state unchanged.
        #[test]
        fn scale_one_is_identity(
            x in -100.0f64..100.0,
            v in -100.0f64..100.0,
        ) {
            let scaled = scalar_state(x, v).scale(1.0);
            prop_assert!(matches_exactly(scaled.y()[0], x));
            prop_assert!(matches_exactly(scaled.dy()[0], v));
        }

        /// Scaling by `a` then `b` agrees with scaling once by `a * b`.
        #[test]
        fn scale_is_multiplicative(
            x in -100.0f64..100.0,
            v in -100.0f64..100.0,
            a in -10.0f64..10.0,
            b in -10.0f64..10.0,
        ) {
            let base = scalar_state(x, v);
            let stepwise = base.scale(a).scale(b);
            let direct = base.scale(a * b);
            prop_assert!(matches_closely(stepwise.y()[0], direct.y()[0]));
            prop_assert!(matches_closely(stepwise.dy()[0], direct.dy()[0]));
        }

        /// `axpy` agrees with the scalar formula `self + s * other`.
        #[test]
        fn axpy_is_linear(
            x in -100.0f64..100.0,
            v in -100.0f64..100.0,
            ox in -100.0f64..100.0,
            ov in -100.0f64..100.0,
            s in -10.0f64..10.0,
        ) {
            let combined = scalar_state(x, v).axpy(s, &scalar_state(ox, ov));
            let want_y = x + s * ox;
            let want_dy = v + s * ov;
            prop_assert!(matches_closely(combined.y()[0], want_y));
            prop_assert!(matches_closely(combined.dy()[0], want_dy));
        }

        /// `zero_like` yields exact zeros regardless of the source values.
        #[test]
        fn zero_like_is_zero(
            x in -100.0f64..100.0,
            v in -100.0f64..100.0,
        ) {
            let blank = scalar_state(x, v).zero_like();
            prop_assert_eq!(blank.y()[0], 0.0);
            prop_assert_eq!(blank.dy()[0], 0.0);
        }

        /// Any state built from finite scalars reports itself as finite.
        #[test]
        fn is_finite_for_finite_values(
            x in -1e300f64..1e300,
            v in -1e300f64..1e300,
        ) {
            prop_assert!(scalar_state(x, v).is_finite());
        }
    }
}
/// A system described by a state type and a function that maps a time and a
/// state to another value of that same state type.
pub trait DynamicalSystem {
/// State representation used by this system; must implement `OdeState`.
type State: OdeState;
/// Evaluates the system at time `t` for the given `state` and returns the
/// result as a `Self::State`.
// NOTE(review): presumably the returned value holds the time derivative of
// each component of `state` — confirm against the implementors.
fn derivatives(&self, t: f64, state: &Self::State) -> Self::State;
}