use crate::error::{SolverError, SolverResult};
use super::types::{OdeConfig, OdeSolution, OdeSystem};
use super::utils::validate_ode_inputs;
pub struct EulerSolver;
impl EulerSolver {
pub fn solve(
system: &dyn OdeSystem,
y0: &[f64],
config: &OdeConfig,
) -> SolverResult<OdeSolution> {
let n = system.dim();
validate_ode_inputs(n, y0, config)?;
let mut t = config.t_start;
let dt = config.dt;
let mut y = y0.to_vec();
let mut k = vec![0.0; n];
let mut times = vec![t];
let mut states = vec![y.clone()];
let mut num_steps = 0_usize;
let mut num_rhs = 0_usize;
while t < config.t_end - dt * 1e-10 && num_steps < config.max_steps {
let h = dt.min(config.t_end - t);
system.rhs(t, &y, &mut k)?;
num_rhs += 1;
for i in 0..n {
y[i] += h * k[i];
}
t += h;
num_steps += 1;
times.push(t);
states.push(y.clone());
}
Ok(OdeSolution {
times,
states,
num_steps,
num_rejected: 0,
num_rhs_evals: num_rhs,
})
}
}
pub struct Rk4Solver;
impl Rk4Solver {
pub fn solve(
system: &dyn OdeSystem,
y0: &[f64],
config: &OdeConfig,
) -> SolverResult<OdeSolution> {
let n = system.dim();
validate_ode_inputs(n, y0, config)?;
let mut t = config.t_start;
let dt = config.dt;
let mut y = y0.to_vec();
let mut k1 = vec![0.0; n];
let mut k2 = vec![0.0; n];
let mut k3 = vec![0.0; n];
let mut k4 = vec![0.0; n];
let mut tmp = vec![0.0; n];
let mut times = vec![t];
let mut states = vec![y.clone()];
let mut num_steps = 0_usize;
let mut num_rhs = 0_usize;
while t < config.t_end - dt * 1e-10 && num_steps < config.max_steps {
let h = dt.min(config.t_end - t);
system.rhs(t, &y, &mut k1)?;
num_rhs += 1;
for i in 0..n {
tmp[i] = y[i] + 0.5 * h * k1[i];
}
system.rhs(t + 0.5 * h, &tmp, &mut k2)?;
num_rhs += 1;
for i in 0..n {
tmp[i] = y[i] + 0.5 * h * k2[i];
}
system.rhs(t + 0.5 * h, &tmp, &mut k3)?;
num_rhs += 1;
for i in 0..n {
tmp[i] = y[i] + h * k3[i];
}
system.rhs(t + h, &tmp, &mut k4)?;
num_rhs += 1;
for i in 0..n {
y[i] += h / 6.0 * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]);
}
t += h;
num_steps += 1;
times.push(t);
states.push(y.clone());
}
Ok(OdeSolution {
times,
states,
num_steps,
num_rejected: 0,
num_rhs_evals: num_rhs,
})
}
}
pub struct Rk45Solver;
impl Rk45Solver {
const A21: f64 = 1.0 / 5.0;
const A31: f64 = 3.0 / 40.0;
const A32: f64 = 9.0 / 40.0;
const A41: f64 = 44.0 / 45.0;
const A42: f64 = -56.0 / 15.0;
const A43: f64 = 32.0 / 9.0;
const A51: f64 = 19372.0 / 6561.0;
const A52: f64 = -25360.0 / 2187.0;
const A53: f64 = 64448.0 / 6561.0;
const A54: f64 = -212.0 / 729.0;
const A61: f64 = 9017.0 / 3168.0;
const A62: f64 = -355.0 / 33.0;
const A63: f64 = 46732.0 / 5247.0;
const A64: f64 = 49.0 / 176.0;
const A65: f64 = -5103.0 / 18656.0;
const B1: f64 = 35.0 / 384.0;
const B3: f64 = 500.0 / 1113.0;
const B4: f64 = 125.0 / 192.0;
const B5: f64 = -2187.0 / 6784.0;
const B6: f64 = 11.0 / 84.0;
const E1: f64 = 71.0 / 57600.0;
const E3: f64 = -71.0 / 16695.0;
const E4: f64 = 71.0 / 1920.0;
const E5: f64 = -17253.0 / 339200.0;
const E6: f64 = 22.0 / 525.0;
const E7: f64 = -1.0 / 40.0;
pub fn solve(
system: &dyn OdeSystem,
y0: &[f64],
config: &OdeConfig,
) -> SolverResult<OdeSolution> {
let n = system.dim();
validate_ode_inputs(n, y0, config)?;
let mut t = config.t_start;
let mut h = config.dt;
let mut y = y0.to_vec();
let mut k1 = vec![0.0; n];
let mut k2 = vec![0.0; n];
let mut k3 = vec![0.0; n];
let mut k4 = vec![0.0; n];
let mut k5 = vec![0.0; n];
let mut k6 = vec![0.0; n];
let mut k7 = vec![0.0; n];
let mut tmp = vec![0.0; n];
let mut y_new = vec![0.0; n];
let mut times = vec![t];
let mut states = vec![y.clone()];
let mut num_steps = 0_usize;
let mut num_rejected = 0_usize;
let mut num_rhs = 0_usize;
let safety = 0.9;
let min_factor = 0.2;
let max_factor = 5.0;
system.rhs(t, &y, &mut k1)?;
num_rhs += 1;
while t < config.t_end - 1e-14 * config.t_end.abs().max(1.0)
&& num_steps + num_rejected < config.max_steps
{
h = h.min(config.t_end - t);
for i in 0..n {
tmp[i] = y[i] + h * Self::A21 * k1[i];
}
system.rhs(t + h / 5.0, &tmp, &mut k2)?;
for i in 0..n {
tmp[i] = y[i] + h * (Self::A31 * k1[i] + Self::A32 * k2[i]);
}
system.rhs(t + 3.0 / 10.0 * h, &tmp, &mut k3)?;
for i in 0..n {
tmp[i] = y[i] + h * (Self::A41 * k1[i] + Self::A42 * k2[i] + Self::A43 * k3[i]);
}
system.rhs(t + 4.0 / 5.0 * h, &tmp, &mut k4)?;
for i in 0..n {
tmp[i] = y[i]
+ h * (Self::A51 * k1[i]
+ Self::A52 * k2[i]
+ Self::A53 * k3[i]
+ Self::A54 * k4[i]);
}
system.rhs(t + 8.0 / 9.0 * h, &tmp, &mut k5)?;
for i in 0..n {
tmp[i] = y[i]
+ h * (Self::A61 * k1[i]
+ Self::A62 * k2[i]
+ Self::A63 * k3[i]
+ Self::A64 * k4[i]
+ Self::A65 * k5[i]);
}
system.rhs(t + h, &tmp, &mut k6)?;
num_rhs += 5;
for i in 0..n {
y_new[i] = y[i]
+ h * (Self::B1 * k1[i]
+ Self::B3 * k3[i]
+ Self::B4 * k4[i]
+ Self::B5 * k5[i]
+ Self::B6 * k6[i]);
}
system.rhs(t + h, &y_new, &mut k7)?;
num_rhs += 1;
let mut err_norm = 0.0;
for i in 0..n {
let err_i = h
* (Self::E1 * k1[i]
+ Self::E3 * k3[i]
+ Self::E4 * k4[i]
+ Self::E5 * k5[i]
+ Self::E6 * k6[i]
+ Self::E7 * k7[i]);
let scale = config.atol + config.rtol * y_new[i].abs().max(y[i].abs());
err_norm += (err_i / scale).powi(2);
}
err_norm = (err_norm / n as f64).sqrt();
if err_norm <= 1.0 {
t += h;
y.copy_from_slice(&y_new);
num_steps += 1;
times.push(t);
states.push(y.clone());
k1.copy_from_slice(&k7);
let factor = if err_norm > 1e-15 {
(safety / err_norm.powf(0.2)).clamp(min_factor, max_factor)
} else {
max_factor
};
h *= factor;
} else {
num_rejected += 1;
let factor = (safety / err_norm.powf(0.2)).clamp(min_factor, 1.0);
h *= factor;
}
}
if num_steps + num_rejected >= config.max_steps && t < config.t_end - 1e-10 {
return Err(SolverError::ConvergenceFailure {
iterations: config.max_steps as u32,
residual: (config.t_end - t).abs(),
});
}
Ok(OdeSolution {
times,
states,
num_steps,
num_rejected,
num_rhs_evals: num_rhs,
})
}
}