use nalgebra::Scalar;
use num_traits::{Float, FromPrimitive, NumCast, One, Zero};
use simba::scalar::{
ClosedAdd, ClosedAddAssign, ClosedDiv, ClosedDivAssign, ClosedMul, ClosedMulAssign, ClosedNeg,
ClosedSub, ClosedSubAssign, SubsetOf,
};
use std::fmt;
use thiserror::Error;
/// A system evaluated by a solver.
///
/// Given an abscissa `x` and a state `y`, an implementor writes the system's value
/// (presumably the derivative `dy/dx`) into `dy`.
pub trait System<T: FloatNumber, V> {
    /// Evaluates the system at `(x, y)`, storing the result in `dy`.
    fn system(&self, x: T, y: &V, dy: &mut V);

    /// Optional hook that receives the current `(x, y, dy)`.
    ///
    /// The default implementation ignores its arguments and always answers `false`.
    ///
    /// NOTE(review): presumably returning `true` asks the solver to stop early.
    /// Confirm this against the solver code that calls the hook.
    fn solout(&mut self, _x_now: T, _state: &V, _derivative: &V) -> bool {
        false
    }
}
/// Solver output stored as two parallel vectors.
///
/// Field `.0` holds the abscissae and field `.1` holds the state recorded at each
/// one. Index `i` of one vector corresponds to index `i` of the other.
#[derive(Clone, Debug)]
pub struct SolverResult<T, V>(Vec<T>, Vec<V>);
/// Floating-point scalar usable as the `T` parameter of [`System`].
///
/// The bounds combine two groups of traits:
/// - the `num_traits` numeric traits (`Float`, conversions, `Zero`/`One`);
/// - the `nalgebra`/`simba` scalar traits, covering the closed arithmetic operators
///   and their in-place (`*Assign`) forms.
pub trait FloatNumber:
    // num_traits numeric core and conversions
    Float
    + Copy
    + Zero
    + One
    + NumCast
    + FromPrimitive
    // nalgebra / simba interoperability
    + SubsetOf<f64>
    + Scalar
    // closed arithmetic, by value
    + ClosedNeg
    + ClosedAdd
    + ClosedSub
    + ClosedMul
    + ClosedDiv
    // closed arithmetic, in place
    + ClosedAddAssign
    + ClosedSubAssign
    + ClosedMulAssign
    + ClosedDivAssign
{
}

// The two primitive floating-point types are the supported scalars.
impl FloatNumber for f64 {}
impl FloatNumber for f32 {}
impl<T, V> SolverResult<T, V> {
pub fn new(x: Vec<T>, y: Vec<V>) -> Self {
SolverResult(x, y)
}
pub fn with_capacity(n: usize) -> Self {
SolverResult(Vec::with_capacity(n), Vec::with_capacity(n))
}
pub fn push(&mut self, x: T, y: V) {
self.0.push(x);
self.1.push(y);
}
pub fn append(&mut self, mut other: SolverResult<T, V>) {
self.0.append(&mut other.0);
self.1.append(&mut other.1);
}
pub fn get(&self) -> (&Vec<T>, &Vec<V>) {
(&self.0, &self.1)
}
}
impl<T, V> Default for SolverResult<T, V> {
fn default() -> Self {
Self(Default::default(), Default::default())
}
}
/// Output-mode selector for a solver.
///
/// NOTE(review): this presumably controls how a solver records samples into a
/// `SolverResult`. The solvers that interpret each variant are not visible here.
/// Confirm the exact semantics of `Dense`, `Sparse` and `Continuous` against them.
// `Debug` lets the mode be printed and embedded in types that derive `Debug`.
// `Clone`/`Copy` let this fieldless tag be passed by value without moving.
// `Hash` allows use as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OutputType {
    Dense,
    Sparse,
    Continuous,
}
#[derive(Debug, Error)]
pub enum IntegrationError {
#[error("Stopped at x = {x}. Need more than {n_step} steps.")]
MaxNumStepReached { x: f64, n_step: u32 },
#[error("Stopped at x = {x}. Step size underflow.")]
StepSizeUnderflow { x: f64 },
#[error("The problem seems to become stiff at x = {x}.")]
StiffnessDetected { x: f64 },
}
/// Counters describing the work done by a solver run.
///
/// `Default` produces all-zero counters. It is derived because `Stats::new` is
/// `pub(crate)`, and code outside the crate otherwise has no zero constructor.
/// `PartialEq`/`Eq` allow statistics to be compared, for example in tests.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct Stats {
    /// Number of function evaluations.
    pub num_eval: u32,
    /// Number of accepted steps.
    pub accepted_steps: u32,
    /// Number of rejected steps.
    pub rejected_steps: u32,
}
impl Stats {
pub(crate) fn new() -> Stats {
Stats {
num_eval: 0,
accepted_steps: 0,
rejected_steps: 0,
}
}
}
impl fmt::Display for Stats {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!(f, "Number of function evaluations: {}", self.num_eval)?;
writeln!(f, "Number of accepted steps: {}", self.accepted_steps)?;
write!(f, "Number of rejected steps: {}", self.rejected_steps)
}
}
/// Transfers the sign of `b` onto the magnitude of `a`, like Fortran's `SIGN(A, B)`.
///
/// Returns `|a|` when `b >= 0` and `-|a|` when `b < 0`.
///
/// A zero `b` counts as non-negative and yields `|a|`, as the Fortran intrinsic does
/// for `+0.0`. A NaN `b` fails the comparison and yields `-|a|`.
pub(crate) fn sign<T: FloatNumber>(a: T, b: T) -> T {
    // `>=` rather than `>`: a strict comparison sent `b == 0` down the negative
    // branch, so `sign(a, 0)` returned `-|a|` instead of Fortran's `|a|`.
    if b >= T::zero() {
        a.abs()
    } else {
        -a.abs()
    }
}