use super::rhs_fn::{check_nan, max_error_norm, Integrator, IntegratorOutput, RhsFn};
use crate::error::Result;
use crate::vector3::Vector3;
/// Implicit-midpoint integrator whose stage equation
/// `y1 = y0 + dt * f((y0 + y1) / 2, t + dt / 2)` is solved by fixed-point iteration.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SemiImplicit {
    /// Maximum number of fixed-point sweeps attempted per step.
    pub max_iterations: usize,
    /// Convergence threshold on the `max_error_norm` difference between
    /// successive fixed-point iterates.
    pub tolerance: f64,
}

impl SemiImplicit {
    /// Creates an integrator with the given iteration cap and convergence tolerance.
    #[must_use]
    pub fn new(max_iterations: usize, tolerance: f64) -> Self {
        Self {
            max_iterations,
            tolerance,
        }
    }
}

impl Default for SemiImplicit {
    /// Returns an integrator allowing 50 fixed-point sweeps with a tolerance of `1e-12`.
    fn default() -> Self {
        Self::new(50, 1e-12)
    }
}
impl Integrator for SemiImplicit {
    /// Advances `state` from `t` to `t + dt` with the implicit midpoint rule.
    ///
    /// The stage equation `y1 = state + dt * f((state + y1) / 2, t + dt / 2)` is solved
    /// by fixed-point iteration seeded with the explicit Euler predictor
    /// `state + dt * f(state, t)`. Iteration stops when two successive iterates differ
    /// by less than `self.tolerance` (measured with `max_error_norm`) or after
    /// `self.max_iterations` sweeps.
    ///
    /// # Returns
    /// * Converged: the midpoint solution, with `error_estimate` set to
    ///   `max_error_norm(midpoint, euler_predictor)` and no suggested step size.
    /// * Not converged (including `max_iterations == 0`): the last iterate, with
    ///   `error_estimate` set to the change produced by one further sweep and
    ///   `suggested_dt = dt * 0.5`. The step is still returned as `Ok`; the caller
    ///   decides whether to accept it.
    ///
    /// # Errors
    /// Propagates the error returned by `check_nan` on the final iterate.
    fn step(
        &mut self,
        state: &[Vector3<f64>],
        t: f64,
        dt: f64,
        f: &RhsFn<'_>,
    ) -> Result<IntegratorOutput> {
        let t_mid = t + dt * 0.5;

        // One fixed-point sweep: y -> state + dt * f((state + y) / 2, t_mid).
        // Shared by the main iteration and the non-converged residual check.
        let sweep = |y: &[Vector3<f64>]| -> Vec<Vector3<f64>> {
            let y_mid: Vec<Vector3<f64>> = state
                .iter()
                .zip(y)
                .map(|(&si, &yi)| (si + yi) * 0.5)
                .collect();
            let f_mid = f(&y_mid, t_mid);
            state
                .iter()
                .zip(f_mid.iter())
                .map(|(&si, &fi)| si + fi * dt)
                .collect()
        };

        // Explicit Euler predictor: seeds the iteration and is kept for the
        // error estimate instead of being recomputed afterwards.
        let f0 = f(state, t);
        let y_euler: Vec<Vector3<f64>> = state
            .iter()
            .zip(f0.iter())
            .map(|(&si, &fi)| si + fi * dt)
            .collect();

        let mut y_new = y_euler.clone();
        let mut converged = false;
        for _ in 0..self.max_iterations {
            let y_next = sweep(&y_new);
            let diff = max_error_norm(&y_next, &y_new);
            y_new = y_next;
            if diff < self.tolerance {
                converged = true;
                break;
            }
        }
        check_nan(&y_new)?;

        if !converged {
            // Residual of the final iterate: how far one more sweep would move it.
            let residual = max_error_norm(&sweep(&y_new), &y_new);
            return Ok(IntegratorOutput {
                new_state: y_new,
                error_estimate: Some(residual),
                suggested_dt: Some(dt * 0.5),
            });
        }

        let error = max_error_norm(&y_new, &y_euler);
        Ok(IntegratorOutput {
            new_state: y_new,
            error_estimate: Some(error),
            suggested_dt: None,
        })
    }
}