use crate::error::{Error, Result};
use crate::vector3::Vector3;
/// Value returned by a successful [`Integrator::step`] call.
#[derive(Debug, Clone)]
pub struct IntegratorOutput {
    /// State produced by the step.
    /// NOTE(review): presumably the same length as the input state — not enforced here, confirm
    /// against implementors.
    pub new_state: Vec<Vector3<f64>>,
    /// Estimate of the step's error, if the integrator computes one.
    /// NOTE(review): presumably `None` for fixed-step methods — confirm against implementors.
    pub error_estimate: Option<f64>,
    /// Step size the integrator proposes for the next step, if any.
    /// NOTE(review): presumably only set by adaptive methods — confirm.
    pub suggested_dt: Option<f64>,
}
/// Right-hand-side callback used by [`Integrator::step`]: takes the current state slice and a
/// time value, and returns a freshly allocated vector of `Vector3<f64>`.
///
/// NOTE(review): presumably the returned vector is the time derivative of each state entry
/// (`dy/dt = f(y, t)`) with the same length as the input — confirm against callers.
///
/// The `+ 'a` bound on the trait object allows closures that borrow non-`'static` data to be
/// passed as `&RhsFn<'_>`.
pub type RhsFn<'a> = dyn Fn(&[Vector3<f64>], f64) -> Vec<Vector3<f64>> + 'a;
/// Interface for a time-stepping integrator.
pub trait Integrator {
    /// Performs one step starting from `state` at time `t` with step size `dt`, evaluating the
    /// right-hand side through `f`.
    ///
    /// The receiver is `&mut self`, so implementors may update internal state between steps.
    /// NOTE(review): whether any implementor relies on this is not visible here.
    ///
    /// # Errors
    ///
    /// Returns the crate's [`Error`] when the step fails; which conditions count as failure is
    /// up to each implementor.
    fn step(
        &mut self,
        state: &[Vector3<f64>],
        t: f64,
        dt: f64,
        f: &RhsFn<'_>,
    ) -> Result<IntegratorOutput>;
}
/// Element-wise `a[i] + b[i] * s`.
///
/// The result has as many elements as the shorter of `a` and `b`; any trailing elements of the
/// longer slice are ignored.
#[inline]
pub(super) fn vec_add_scaled(a: &[Vector3<f64>], b: &[Vector3<f64>], s: f64) -> Vec<Vector3<f64>> {
    let mut result = Vec::with_capacity(a.len().min(b.len()));
    for (lhs, rhs) in a.iter().zip(b) {
        result.push(*lhs + *rhs * s);
    }
    result
}
/// Linear combination `base[i] + Σ_j ks[j][i] * coeffs[j]` for every index `i` of `base`.
///
/// Stages are paired with coefficients in order. Only the first
/// `min(ks.len(), coeffs.len())` pairs contribute; extra stages or coefficients are ignored.
/// For each element the terms are accumulated in stage order, matching a per-element
/// left-to-right sum.
///
/// The output always has `base.len()` elements.
///
/// # Panics
///
/// Panics if any contributing stage in `ks` is shorter than `base` (and `base` is non-empty).
#[inline]
pub(super) fn vec_combine(
    base: &[Vector3<f64>],
    ks: &[&[Vector3<f64>]],
    coeffs: &[f64],
) -> Vec<Vector3<f64>> {
    let n = base.len();
    let mut out = base.to_vec();
    for (&k, &c) in ks.iter().zip(coeffs) {
        // Slicing once up front enforces the "stage at least as long as base" contract and lets
        // the inner loop run without a bounds check per element.
        for (acc, &kj) in out.iter_mut().zip(&k[..n]) {
            *acc = *acc + kj * c;
        }
    }
    out
}
/// Largest pointwise distance `(a[i] - b[i]).magnitude()` between two state slices.
///
/// Pairs beyond the end of the shorter slice are ignored, and empty input yields `0.0`.
///
/// If any distance is NaN, NaN is returned. NaN is propagated deliberately: `f64::max`
/// returns the non-NaN operand, so a plain max-fold would hide a diverged component (e.g.
/// `inf - inf`) behind a small or zero error. An adaptive controller could then wrongly accept
/// the step.
#[inline]
pub(super) fn max_error_norm(a: &[Vector3<f64>], b: &[Vector3<f64>]) -> f64 {
    let mut worst = 0.0_f64;
    for (&ai, &bi) in a.iter().zip(b) {
        let dist = (ai - bi).magnitude();
        if dist.is_nan() {
            return f64::NAN;
        }
        worst = worst.max(dist);
    }
    worst
}
pub(super) fn check_nan(state: &[Vector3<f64>]) -> Result<()> {
for v in state {
if v.x.is_nan() || v.y.is_nan() || v.z.is_nan() {
return Err(Error::NumericalError {
description: "NaN detected in integrator output".to_string(),
});
}
}
Ok(())
}