use super::rhs_fn::{check_nan, Integrator, IntegratorOutput, RhsFn};
use crate::error::{Error, Result};
use crate::vector3::Vector3;
/// Second-order velocity Verlet integrator in kick–drift–kick form.
///
/// The integrator carries no state: the private unit field only prevents
/// construction via a struct literal outside this module. Use
/// [`VelocityVerlet::new`] or [`Default::default`].
#[derive(Debug, Clone, Copy)]
pub struct VelocityVerlet {
    _private: (),
}

impl VelocityVerlet {
    /// Creates a velocity Verlet integrator.
    #[must_use]
    pub const fn new() -> Self {
        Self { _private: () }
    }
}

impl Default for VelocityVerlet {
    fn default() -> Self {
        Self::new()
    }
}
impl Integrator for VelocityVerlet {
fn step(
&mut self,
state: &[Vector3<f64>],
t: f64,
dt: f64,
f: &RhsFn<'_>,
) -> Result<IntegratorOutput> {
if state.len() % 2 != 0 {
return Err(Error::DimensionMismatch {
expected: "even number of state components (q, p pairs)".to_string(),
actual: format!("{}", state.len()),
});
}
let half = state.len() / 2;
let a_old = f(state, t);
let mut new_state = state.to_vec();
for i in 0..half {
new_state[half + i] = state[half + i] + a_old[half + i] * (dt * 0.5);
new_state[i] = state[i] + new_state[half + i] * dt;
}
let a_new = f(&new_state, t + dt);
for i in 0..half {
new_state[half + i] = new_state[half + i] + a_new[half + i] * (dt * 0.5);
}
check_nan(&new_state)?;
Ok(IntegratorOutput {
new_state,
error_estimate: None,
suggested_dt: None,
})
}
}
/// Fourth-order integrator built from Yoshida's triple-jump composition of
/// drift and kick sub-steps.
///
/// The integrator carries no state: the private unit field only prevents
/// construction via a struct literal outside this module. Use
/// [`Yoshida4::new`] or [`Default::default`].
#[derive(Debug, Clone, Copy)]
pub struct Yoshida4 {
    _private: (),
}

impl Yoshida4 {
    /// Creates a fourth-order Yoshida integrator.
    #[must_use]
    pub const fn new() -> Self {
        Self { _private: () }
    }
}

impl Default for Yoshida4 {
    fn default() -> Self {
        Self::new()
    }
}
/// Yoshida's fourth-order triple-jump outer weight, `w1 = 1 / (2 - 2^(1/3))`.
const YOSHIDA_W1: f64 = 1.351_207_191_959_657_6;
/// Yoshida's (negative) middle weight, `w0 = -2^(1/3) / (2 - 2^(1/3))`.
///
/// This equals `1 - 2*w1` in exact arithmetic. It is computed that way instead of
/// being written as a literal. The subtraction is exact in `f64`, so the weights
/// satisfy the consistency condition `2*w1 + w0 == 1` bit-for-bit, and the sub-steps
/// of one step add up to a full `dt`.
const YOSHIDA_W0: f64 = 1.0 - 2.0 * YOSHIDA_W1;
impl Integrator for Yoshida4 {
fn step(
&mut self,
state: &[Vector3<f64>],
t: f64,
dt: f64,
f: &RhsFn<'_>,
) -> Result<IntegratorOutput> {
if state.len() % 2 != 0 {
return Err(Error::DimensionMismatch {
expected: "even number of state components (q, p pairs)".to_string(),
actual: format!("{}", state.len()),
});
}
let c1 = YOSHIDA_W1 / 2.0;
let c2 = (YOSHIDA_W0 + YOSHIDA_W1) / 2.0;
let c3 = c2;
let c4 = c1;
let d1 = YOSHIDA_W1;
let d2 = YOSHIDA_W0;
let d3 = YOSHIDA_W1;
let c_coeffs = [c1, c2, c3, c4];
let d_coeffs = [d1, d2, d3];
let half = state.len() / 2;
let mut cur = state.to_vec();
let mut time = t;
for i in 0..half {
cur[i] = cur[i] + cur[half + i] * (c_coeffs[0] * dt);
}
time += c_coeffs[0] * dt;
for step_idx in 0..3 {
let accel = f(&cur, time);
for i in 0..half {
cur[half + i] = cur[half + i] + accel[half + i] * (d_coeffs[step_idx] * dt);
}
for i in 0..half {
cur[i] = cur[i] + cur[half + i] * (c_coeffs[step_idx + 1] * dt);
}
time += c_coeffs[step_idx + 1] * dt;
}
check_nan(&cur)?;
Ok(IntegratorOutput {
new_state: cur,
error_estimate: None,
suggested_dt: None,
})
}
}
/// Fourth-order Forest–Ruth integrator: three kicks interleaved with four drifts.
///
/// Here `θ` is the same triple-jump weight [`Yoshida4`] uses, so the stage
/// coefficients of the two schemes coincide. Their results should therefore agree
/// up to floating-point rounding.
///
/// The integrator carries no state: the private unit field only prevents
/// construction via a struct literal outside this module. Use
/// [`ForestRuth::new`] or [`Default::default`].
#[derive(Debug, Clone, Copy)]
pub struct ForestRuth {
    _private: (),
}

impl ForestRuth {
    /// Creates a Forest–Ruth integrator.
    #[must_use]
    pub const fn new() -> Self {
        Self { _private: () }
    }
}

impl Default for ForestRuth {
    fn default() -> Self {
        Self::new()
    }
}
/// Forest–Ruth weight `θ = 1 / (2 - 2^(1/3))`; numerically the same value as the
/// Yoshida triple-jump outer weight.
const FR_THETA: f64 = 1.351_207_191_959_657_6_f64;
impl Integrator for ForestRuth {
fn step(
&mut self,
state: &[Vector3<f64>],
t: f64,
dt: f64,
f: &RhsFn<'_>,
) -> Result<IntegratorOutput> {
if state.len() % 2 != 0 {
return Err(Error::DimensionMismatch {
expected: "even number of state components (q, p pairs)".to_string(),
actual: format!("{}", state.len()),
});
}
let half = state.len() / 2;
let theta = FR_THETA;
let mut cur = state.to_vec();
for i in 0..half {
cur[i] = cur[i] + cur[half + i] * (theta * dt * 0.5);
}
let a1 = f(&cur, t + theta * dt * 0.5);
for i in 0..half {
cur[half + i] = cur[half + i] + a1[half + i] * (theta * dt);
}
for i in 0..half {
cur[i] = cur[i] + cur[half + i] * ((1.0 - theta) * dt * 0.5);
}
let a2 = f(&cur, t + dt * 0.5);
for i in 0..half {
cur[half + i] = cur[half + i] + a2[half + i] * ((1.0 - 2.0 * theta) * dt);
}
for i in 0..half {
cur[i] = cur[i] + cur[half + i] * ((1.0 - theta) * dt * 0.5);
}
let a3 = f(&cur, t + (1.0 - theta * 0.5) * dt);
for i in 0..half {
cur[half + i] = cur[half + i] + a3[half + i] * (theta * dt);
}
for i in 0..half {
cur[i] = cur[i] + cur[half + i] * (theta * dt * 0.5);
}
check_nan(&cur)?;
Ok(IntegratorOutput {
new_state: cur,
error_estimate: None,
suggested_dt: None,
})
}
}