use scirs2_core::ndarray::{Array1, Array2};
#[cfg(feature = "symbolic")]
use scirs2_symbolic::{
cas::{solve_ode, OdeKind, OdeSolution, SolveOdeError},
eml::{eval_real, EvalCtx, LoweredOp},
};
/// Strategy used by `solve_ode_symbolic_or_numerical` to choose between solvers.
///
/// Only the `symbolic`-feature build of the solver consults this value; without that
/// feature the problem is always integrated numerically.
#[derive(Default, Clone, Debug, Eq, PartialEq)]
pub enum SolvePreference {
    /// Try a closed-form symbolic solution first (when a symbolic right-hand side is
    /// supplied) and fall back to numerical integration otherwise. This is the default.
    #[default]
    SymbolicFirst,
    /// Skip the symbolic attempt entirely and always integrate numerically.
    ForceNumerical,
}
/// Tuning options for `solve_ode_symbolic_or_numerical`.
#[derive(Debug, Clone)]
pub struct OdeOpts {
    /// Number of equal-sized steps taken by the fixed-step numerical integrator.
    /// Despite the name, the integrator always takes exactly this many steps.
    pub max_steps: usize,
    /// Relative tolerance.
    /// NOTE(review): not read by the fixed-step solver code in this module — presumably
    /// reserved for an adaptive integrator; confirm before relying on it.
    pub rtol: f64,
    /// Absolute tolerance.
    /// NOTE(review): likewise not read by the fixed-step solver code in this module.
    pub atol: f64,
    /// Whether a symbolic (closed-form) solution should be attempted before integrating.
    pub preferred: SolvePreference,
}

impl Default for OdeOpts {
    /// `max_steps = 10_000`, `rtol = 1e-6`, `atol = 1e-8`, symbolic solution tried first.
    fn default() -> Self {
        Self {
            preferred: SolvePreference::default(),
            max_steps: 10_000,
            atol: 1e-8,
            rtol: 1e-6,
        }
    }
}
/// Outcome of `solve_ode_symbolic_or_numerical` when the `symbolic` feature is enabled.
#[cfg(feature = "symbolic")]
pub enum SymbolicOrNumericalResult {
    /// A closed-form solution obtained from the symbolic CAS (`solve_ode`).
    Symbolic {
        /// Expression for the solution `x(t)`.
        x_of_t: LoweredOp,
        /// Classification of the ODE as reported by the symbolic solver.
        kind: OdeKind,
        /// NOTE(review): presumably the variable indices of integration constants that
        /// remain free in `x_of_t` — confirm against the `OdeSolution` documentation.
        integration_constants: Vec<usize>,
    },
    /// A trajectory produced by the fixed-step numerical integrator in this module.
    Numerical {
        /// One row per sample with columns `[t, x]`; the first row is the initial condition.
        trajectory: Array2<f64>,
        /// The sample times (the same values as the first column of `trajectory`).
        time: Array1<f64>,
    },
}
/// Outcome of `solve_ode_symbolic_or_numerical` when the `symbolic` feature is disabled;
/// only a numerical trajectory can be produced in this configuration.
#[cfg(not(feature = "symbolic"))]
pub enum SymbolicOrNumericalResult {
    /// A trajectory produced by the fixed-step numerical integrator in this module.
    Numerical {
        /// One row per sample with columns `[t, x]`; the first row is the initial condition.
        trajectory: Array2<f64>,
        /// The sample times (the same values as the first column of `trajectory`).
        time: Array1<f64>,
    },
}
/// Errors returned by `solve_ode_symbolic_or_numerical`.
#[derive(Debug, Clone)]
pub enum SymbolicFirstError {
    /// The integration interval `[t0, t_end]` was rejected, e.g. because `t_end <= t0`.
    InvalidInterval,
    /// The numerical integrator failed; the payload is a human-readable description.
    NumericalFailed(String),
}

impl std::fmt::Display for SymbolicFirstError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidInterval => f.write_str("invalid ODE interval: t_end must be > t0"),
            Self::NumericalFailed(msg) => write!(f, "numerical solver failed: {msg}"),
        }
    }
}

/// Plain leaf error: no underlying `source()`.
impl std::error::Error for SymbolicFirstError {}
/// Solves the scalar initial-value problem `dx/dt = f(t, x)`, `x(t0) = x0`, where
/// `ic = (t0, x0)`, preferring a closed-form solution when one is available.
///
/// When `opts.preferred` is `SolvePreference::SymbolicFirst` and `rhs_symbolic` is
/// provided, the symbolic CAS (`solve_ode`) is tried first; `x_var` and `t_var` identify
/// the state and time variables inside the expression. On success the symbolic solution
/// is returned (it does not depend on `t_end`). In every other case — no symbolic
/// right-hand side, `ForceNumerical`, or the CAS could not solve the equation —
/// `rhs_numeric` is integrated over `[t0, t_end]` with fixed-step RK4 using
/// `opts.max_steps` steps.
///
/// # Errors
///
/// - [`SymbolicFirstError::InvalidInterval`] if `t0` or `t_end` is not finite, or if
///   `t_end <= t0`. This is checked before any solver runs.
/// - [`SymbolicFirstError::NumericalFailed`] if the numerical integrator reports a failure.
#[cfg(feature = "symbolic")]
pub fn solve_ode_symbolic_or_numerical<F>(
    rhs_symbolic: Option<&LoweredOp>,
    rhs_numeric: F,
    x_var: usize,
    t_var: usize,
    ic: (f64, f64),
    t_end: f64,
    opts: &OdeOpts,
) -> Result<SymbolicOrNumericalResult, SymbolicFirstError>
where
    F: Fn(f64, f64) -> f64,
{
    let (t0, x0) = ic;
    // `t_end <= t0` alone would accept NaN (every comparison with NaN is false) and
    // infinite endpoints, both of which poison the step size, so check finiteness too.
    let interval_ok = t0.is_finite() && t_end.is_finite() && t_end > t0;
    if !interval_ok {
        return Err(SymbolicFirstError::InvalidInterval);
    }
    if opts.preferred == SolvePreference::SymbolicFirst {
        if let Some(sym_rhs) = rhs_symbolic {
            // A symbolic failure is deliberately not surfaced: it only means no closed
            // form was found, so we fall through to numerical integration below.
            if let Ok(OdeSolution {
                x_of_t,
                kind,
                integration_constants,
            }) = solve_ode(sym_rhs, x_var, t_var, Some(ic))
            {
                return Ok(SymbolicOrNumericalResult::Symbolic {
                    x_of_t,
                    kind,
                    integration_constants,
                });
            }
        }
    }
    let (trajectory, time) = rk4_fixed(&rhs_numeric, t0, x0, t_end, opts.max_steps)?;
    Ok(SymbolicOrNumericalResult::Numerical { trajectory, time })
}
/// Solves the scalar initial-value problem `dx/dt = rhs_numeric(t, x)`, `x(t0) = x0`,
/// where `ic = (t0, x0)`, over `[t0, t_end]`.
///
/// This build has the `symbolic` feature disabled, so no closed form is attempted and
/// `opts.preferred` is ignored. The problem is always integrated with fixed-step RK4
/// using `opts.max_steps` steps.
///
/// # Errors
///
/// - [`SymbolicFirstError::InvalidInterval`] if `t0` or `t_end` is not finite, or if
///   `t_end <= t0`.
/// - [`SymbolicFirstError::NumericalFailed`] if the numerical integrator reports a failure.
#[cfg(not(feature = "symbolic"))]
pub fn solve_ode_symbolic_or_numerical<F>(
    rhs_numeric: F,
    ic: (f64, f64),
    t_end: f64,
    opts: &OdeOpts,
) -> Result<SymbolicOrNumericalResult, SymbolicFirstError>
where
    F: Fn(f64, f64) -> f64,
{
    let (t0, x0) = ic;
    // `t_end <= t0` alone would accept NaN (every comparison with NaN is false) and
    // infinite endpoints, both of which poison the step size, so check finiteness too.
    let interval_ok = t0.is_finite() && t_end.is_finite() && t_end > t0;
    if !interval_ok {
        return Err(SymbolicFirstError::InvalidInterval);
    }
    let (trajectory, time) = rk4_fixed(&rhs_numeric, t0, x0, t_end, opts.max_steps)?;
    Ok(SymbolicOrNumericalResult::Numerical { trajectory, time })
}
/// Wraps a symbolic right-hand side as a plain numeric closure `(t, x) -> dx/dt`.
///
/// On every call a zero-filled binding vector of length `max(x_var, t_var) + 1` is built.
/// `t` is stored at index `t_var`, then `x` at index `x_var` (so `x` wins if the indices
/// coincide). `sym_rhs` is then evaluated with `eval_real`. Any other slot in the vector
/// is `0.0`.
///
/// Evaluation failures are not propagated; they are mapped to `f64::NAN`.
/// NOTE(review): callers integrating this closure should be prepared for NaN outputs.
#[cfg(feature = "symbolic")]
pub fn rhs_from_symbolic_only(
    sym_rhs: LoweredOp,
    x_var: usize,
    t_var: usize,
) -> impl Fn(f64, f64) -> f64 {
    let slot_count = std::cmp::max(x_var, t_var) + 1;
    move |t: f64, x: f64| {
        let mut slots = vec![0.0_f64; slot_count];
        slots[t_var] = t;
        slots[x_var] = x;
        let eval_ctx = EvalCtx::new(&slots);
        eval_real(&sym_rhs, &eval_ctx).unwrap_or(f64::NAN)
    }
}
/// Integrates the scalar ODE `dx/dt = f(t, x)` from `(t0, x0)` to `t_end` with the
/// classical fourth-order Runge–Kutta method, using exactly `n_steps` equal steps.
///
/// The return value is `(trajectory, time)`:
/// - `trajectory` has `n_steps + 1` rows with columns `[t, x]`, and its first row is the
///   initial condition.
/// - `time` holds the same `t` values as a 1-D array.
///
/// The last sample time is exactly `t_end`.
///
/// # Errors
///
/// Returns [`SymbolicFirstError::NumericalFailed`] when:
/// - `n_steps` is zero, so the integration could never reach `t_end`;
/// - the step size `(t_end - t0) / n_steps` is not finite; or
/// - the state becomes non-finite, e.g. because `f` returned `NaN` or the solution blew up.
fn rk4_fixed<F>(
    f: &F,
    t0: f64,
    x0: f64,
    t_end: f64,
    n_steps: usize,
) -> Result<(Array2<f64>, Array1<f64>), SymbolicFirstError>
where
    F: Fn(f64, f64) -> f64,
{
    // With zero steps the loop never runs and the caller would silently receive only the
    // initial point instead of a trajectory that reaches `t_end`.
    if n_steps == 0 {
        return Err(SymbolicFirstError::NumericalFailed(
            "at least one integration step is required (max_steps = 0)".to_string(),
        ));
    }
    let h = (t_end - t0) / n_steps as f64;
    if !h.is_finite() {
        return Err(SymbolicFirstError::NumericalFailed(format!(
            "non-finite step size for interval [{t0}, {t_end}]"
        )));
    }
    let half_h = h / 2.0;

    let rows = n_steps + 1;
    let mut flat: Vec<f64> = Vec::with_capacity(rows * 2);
    let mut time_vec: Vec<f64> = Vec::with_capacity(rows);
    let mut t = t0;
    let mut x = x0;
    flat.push(t);
    flat.push(x);
    time_vec.push(t);

    for step in 1..=n_steps {
        let k1 = f(t, x);
        let k2 = f(t + half_h, x + half_h * k1);
        let k3 = f(t + half_h, x + half_h * k2);
        let k4 = f(t + h, x + h * k3);
        x += h * (k1 + 2.0 * k2 + 2.0 * k3 + k4) / 6.0;
        // Derive t from the step index instead of accumulating `t += h`. Repeated addition
        // compounds rounding error, and the final sample would then miss `t_end`.
        t = if step == n_steps {
            t_end
        } else {
            t0 + step as f64 * h
        };
        // A NaN/inf state makes the rest of the trajectory meaningless; report it instead
        // of returning it as a successful result.
        if !x.is_finite() {
            return Err(SymbolicFirstError::NumericalFailed(format!(
                "state became non-finite ({x}) at t = {t}"
            )));
        }
        flat.push(t);
        flat.push(x);
        time_vec.push(t);
    }

    // `flat` holds exactly `rows` (t, x) pairs, so this reshape cannot fail in practice.
    let trajectory = Array2::from_shape_vec((rows, 2), flat)
        .map_err(|e| SymbolicFirstError::NumericalFailed(e.to_string()))?;
    Ok((trajectory, Array1::from_vec(time_vec)))
}