use crate::common::IntegrateFloat;
use crate::error::IntegrateResult;
use crate::ode::types::{ODEMethod, ODEOptions, ODEResult};
use scirs2_core::ndarray::{Array1, ArrayView1};
/// Integrates `dy/dt = f(t, y)` over `t_span` with the adaptive Dormand–Prince
/// 5(4) pair ("RK45").
///
/// Every step computes a fifth-order solution (which is propagated) and an
/// embedded fourth-order solution. Their component-wise difference, scaled by
/// `atol + rtol * |y5|` and reduced with the max norm, decides acceptance
/// (`<= 1`) and the next step size. A scaled difference that is NaN (for example
/// because `f` produced a non-finite value) is treated as an infinite error, so
/// such a step is rejected instead of being silently accepted.
///
/// The seventh stage is evaluated at the new point, so after an accepted step it
/// is reused as the first stage of the next step (first-same-as-last); after a
/// rejected step the first stage is reused because the current point did not
/// change. `n_eval` counts the evaluations of `f` actually performed.
///
/// Step sizes start at `opts.h0` (default: 1/100 of the span) and are kept inside
/// `[opts.min_step, opts.max_step]` (defaults: `span * 1e-8` and the whole span),
/// except that the final step is shortened so the last output time is exactly
/// `t_end`.
///
/// # Errors
///
/// Returns `IntegrateError::StepSizeTooSmall` when a rejected step would have to
/// shrink below `min_step`.
///
/// If `opts.max_steps` accepted steps are taken before reaching `t_end`, the
/// result has `success == false` and an explanatory `message`. Only forward
/// integration is performed: for `t_end < t_start` no step is taken and the
/// result likewise reports `success == false`.
#[allow(dead_code)]
pub fn rk45_method<F, Func>(
    f: Func,
    t_span: [F; 2],
    y0: Array1<F>,
    opts: ODEOptions<F>,
) -> IntegrateResult<ODEResult<F>>
where
    F: IntegrateFloat,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    /// Returns `y + h * Σ coef * k` over `terms`, evaluated element by element.
    fn combine<T: IntegrateFloat>(y: &Array1<T>, h: T, terms: &[(T, &Array1<T>)]) -> Array1<T> {
        let mut out = y.clone();
        for (i, yi) in out.iter_mut().enumerate() {
            let incr = terms
                .iter()
                .fold(T::zero(), |acc, &(coef, k)| acc + coef * k[i]);
            *yi = *yi + h * incr;
        }
        out
    }

    // Converts an `f64` constant to `F` (panics if `F` cannot represent it).
    let cst = |v: f64| F::from_f64(v).expect("Operation failed");

    let [t_start, t_end] = t_span;
    let span = t_end - t_start;
    let h0 = opts
        .h0
        .unwrap_or_else(|| span / F::from_usize(100).expect("Operation failed"));
    let min_step = opts.min_step.unwrap_or_else(|| span * cst(1e-8));
    let max_step = opts.max_step.unwrap_or(span);

    // Dormand–Prince 5(4) tableau, converted to `F` once rather than on every
    // element of every step.
    let (c2, c3, c4, c5) = (cst(1.0 / 5.0), cst(3.0 / 10.0), cst(4.0 / 5.0), cst(8.0 / 9.0));
    let a21 = cst(1.0 / 5.0);
    let (a31, a32) = (cst(3.0 / 40.0), cst(9.0 / 40.0));
    let (a41, a42, a43) = (cst(44.0 / 45.0), cst(-56.0 / 15.0), cst(32.0 / 9.0));
    let (a51, a52, a53, a54) = (
        cst(19372.0 / 6561.0),
        cst(-25360.0 / 2187.0),
        cst(64448.0 / 6561.0),
        cst(-212.0 / 729.0),
    );
    let (a61, a62, a63, a64, a65) = (
        cst(9017.0 / 3168.0),
        cst(-355.0 / 33.0),
        cst(46732.0 / 5247.0),
        cst(49.0 / 176.0),
        cst(-5103.0 / 18656.0),
    );
    // Fifth-order weights (b2 = b7 = 0). They are also the coefficients of the
    // stage k7, whose input is therefore exactly the fifth-order solution.
    let (b1, b3, b4, b5, b6) = (
        cst(35.0 / 384.0),
        cst(500.0 / 1113.0),
        cst(125.0 / 192.0),
        cst(-2187.0 / 6784.0),
        cst(11.0 / 84.0),
    );
    // Embedded fourth-order weights (e2 = 0).
    let (e1, e3, e4, e5, e6, e7) = (
        cst(5179.0 / 57600.0),
        cst(7571.0 / 16695.0),
        cst(393.0 / 640.0),
        cst(-92097.0 / 339200.0),
        cst(187.0 / 2100.0),
        cst(1.0 / 40.0),
    );

    // Step-size controller: factor = safety * (1 / err)^(1/6), clamped to
    // [0.2, 5]; very accurate steps (err <= 0.1) grow by at least 2x.
    let exponent = F::one() / (cst(5.0) + F::one());
    let safety = cst(0.9);
    let factor_min = cst(0.2);
    let factor_max = cst(5.0);
    let small_err = cst(0.1);
    let min_growth = cst(2.0);

    let mut t = t_start;
    let mut y = y0.clone();
    let mut h = h0;
    let mut t_values = vec![t_start];
    let mut y_values = vec![y0];
    let mut func_evals = 0;
    let mut step_count = 0;
    let mut accepted_steps = 0;
    let mut rejected_steps = 0;
    // f(t, y) at the current point, when it is already known.
    let mut k1_cached: Option<Array1<F>> = None;

    while t < t_end && step_count < opts.max_steps {
        // Apply the user bounds first, then shorten the step so it never passes t_end.
        h = h.min(max_step).max(min_step);
        let reaches_end = t + h >= t_end;
        if reaches_end {
            h = t_end - t;
        }

        let k1 = match k1_cached.take() {
            Some(k) => k,
            None => {
                func_evals += 1;
                f(t, y.view())
            }
        };
        let k2 = f(t + c2 * h, combine(&y, h, &[(a21, &k1)]).view());
        let k3 = f(t + c3 * h, combine(&y, h, &[(a31, &k1), (a32, &k2)]).view());
        let k4_input = combine(&y, h, &[(a41, &k1), (a42, &k2), (a43, &k3)]);
        let k4 = f(t + c4 * h, k4_input.view());
        let k5_input = combine(&y, h, &[(a51, &k1), (a52, &k2), (a53, &k3), (a54, &k4)]);
        let k5 = f(t + c5 * h, k5_input.view());
        let k6_input = combine(
            &y,
            h,
            &[(a61, &k1), (a62, &k2), (a63, &k3), (a64, &k4), (a65, &k5)],
        );
        let k6 = f(t + h, k6_input.view());
        let y5 = combine(&y, h, &[(b1, &k1), (b3, &k3), (b4, &k4), (b5, &k5), (b6, &k6)]);
        let k7 = f(t + h, y5.view());
        func_evals += 6;
        let y4 = combine(
            &y,
            h,
            &[(e1, &k1), (e3, &k3), (e4, &k4), (e5, &k5), (e6, &k6), (e7, &k7)],
        );

        // Max-norm of the scaled local error estimate.
        let mut err_norm = F::zero();
        for (&hi, &lo) in y5.iter().zip(y4.iter()) {
            let diff = (hi - lo).abs();
            if diff == F::zero() {
                // Exact agreement; also avoids 0/0 when atol == 0 and y == 0.
                continue;
            }
            let err = diff / (opts.atol + opts.rtol * hi.abs());
            if err.is_nan() {
                // `Float::max` would silently drop a NaN; force a rejection instead.
                err_norm = F::infinity();
                break;
            }
            err_norm = err_norm.max(err);
        }

        let factor = (safety * (F::one() / err_norm).powf(exponent))
            .min(factor_max)
            .max(factor_min);

        if err_norm <= F::one() {
            // Snap onto t_end on the final step so rounding in `t + h` cannot leave
            // a sliver of the interval that would force an extra (overshooting) step.
            t = if reaches_end { t_end } else { t + h };
            y = y5;
            t_values.push(t);
            y_values.push(y.clone());
            // k7 = f(t + h, y5) is f at the new point: reuse it as the next k1.
            k1_cached = Some(k7);
            h *= if err_norm <= small_err {
                factor.max(min_growth)
            } else {
                factor
            };
            step_count += 1;
            accepted_steps += 1;
        } else {
            // The point did not move, so f(t, y) is still valid.
            k1_cached = Some(k1);
            h *= factor.min(F::one());
            rejected_steps += 1;
            if h < min_step {
                return Err(crate::error::IntegrateError::StepSizeTooSmall(format!(
                    "Step size {h} too small at t {t}"
                )));
            }
        }
    }

    let success = t >= t_end;
    let message = if !success {
        Some(format!(
            "Maximum number of steps ({}) reached",
            opts.max_steps
        ))
    } else {
        None
    };
    Ok(ODEResult {
        t: t_values,
        y: y_values,
        success,
        message,
        n_eval: func_evals,
        n_steps: step_count,
        n_accepted: accepted_steps,
        n_rejected: rejected_steps,
        n_lu: 0,
        n_jac: 0,
        method: ODEMethod::RK45,
    })
}
#[allow(dead_code)]
pub fn rk23_method<F, Func>(
f: Func,
t_span: [F; 2],
y0: Array1<F>,
opts: ODEOptions<F>,
) -> IntegrateResult<ODEResult<F>>
where
F: IntegrateFloat,
Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
let [t_start, t_end] = t_span;
let h0 = opts.h0.unwrap_or_else(|| {
let _span = t_end - t_start;
_span / F::from_usize(100).expect("Operation failed")
});
let min_step = opts.min_step.unwrap_or_else(|| {
let _span = t_end - t_start;
_span * F::from_f64(1e-8).expect("Operation failed") });
let max_step = opts.max_step.unwrap_or_else(|| {
t_end - t_start });
let mut t = t_start;
let mut y = y0.clone();
let mut h = h0;
let mut t_values = vec![t_start];
let mut y_values = vec![y0.clone()];
let mut func_evals = 0;
let mut step_count = 0;
let mut accepted_steps = 0;
let rejected_steps = 0;
while t < t_end && step_count < opts.max_steps {
if t + h > t_end {
h = t_end - t;
}
h = h.min(max_step).max(min_step);
let k1 = f(t, y.view());
func_evals += 1;
let y_next = y.clone() + k1.clone() * h;
t += h;
y = y_next;
t_values.push(t);
y_values.push(y.clone());
step_count += 1;
accepted_steps += 1;
}
let success = t >= t_end;
let message = if !success {
Some(format!(
"Maximum number of steps ({}) reached",
opts.max_steps
))
} else {
None
};
Ok(ODEResult {
t: t_values,
y: y_values,
success,
message,
n_eval: func_evals,
n_steps: step_count,
n_accepted: accepted_steps,
n_rejected: rejected_steps,
n_lu: 0, n_jac: 0, method: ODEMethod::RK23,
})
}
/// Integrates `dy/dt = f(t, y)` over `t_span` and labels the result `DOP853`.
///
/// NOTE(review): despite the name, this is NOT the Dormand–Prince 8(5,3) scheme.
/// Each step is a single explicit Euler update `y += h * f(t, y)` with a fixed
/// step size: `opts.rtol`/`opts.atol` are not used, no step is ever rejected,
/// and the global error is only first order in `h`. Treat it as a placeholder
/// until the real DOP853 tableau and error control are implemented.
///
/// The step size is `opts.h0` (default: 1/100 of the span), kept inside
/// `[opts.min_step, opts.max_step]` (defaults: `span * 1e-8` and the whole span),
/// except that the last step is shortened so the final output time is exactly
/// `t_end`.
///
/// If `opts.max_steps` steps are taken before reaching `t_end`, the result has
/// `success == false` and an explanatory `message`. Only forward integration is
/// performed: for `t_end < t_start` no step is taken and the result likewise
/// reports `success == false`. This function never returns `Err`.
#[allow(dead_code)]
pub fn dop853_method<F, Func>(
    f: Func,
    t_span: [F; 2],
    y0: Array1<F>,
    opts: ODEOptions<F>,
) -> IntegrateResult<ODEResult<F>>
where
    F: IntegrateFloat,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    let [t_start, t_end] = t_span;
    let span = t_end - t_start;
    let h0 = opts
        .h0
        .unwrap_or_else(|| span / F::from_usize(100).expect("Operation failed"));
    let min_step = opts
        .min_step
        .unwrap_or_else(|| span * F::from_f64(1e-8).expect("Operation failed"));
    let max_step = opts.max_step.unwrap_or(span);

    let mut t = t_start;
    let mut y = y0.clone();
    let mut h = h0;
    let mut t_values = vec![t_start];
    let mut y_values = vec![y0];
    let mut func_evals = 0;
    let mut step_count = 0;
    let mut accepted_steps = 0;
    // Fixed-step Euler never rejects a step.
    let rejected_steps = 0;

    while t < t_end && step_count < opts.max_steps {
        // Apply the user bounds first, then shorten the step so it never passes t_end
        // (clamping after clipping could push the last step beyond t_end).
        h = h.min(max_step).max(min_step);
        let reaches_end = t + h >= t_end;
        if reaches_end {
            h = t_end - t;
        }

        let slope = f(t, y.view());
        func_evals += 1;
        y = y + slope * h;
        // Snap onto t_end on the final step so rounding in `t + h` cannot leave a
        // sliver of the interval that would force an extra step.
        t = if reaches_end { t_end } else { t + h };
        t_values.push(t);
        y_values.push(y.clone());
        step_count += 1;
        accepted_steps += 1;
    }

    let success = t >= t_end;
    let message = if !success {
        Some(format!(
            "Maximum number of steps ({}) reached",
            opts.max_steps
        ))
    } else {
        None
    };
    Ok(ODEResult {
        t: t_values,
        y: y_values,
        success,
        message,
        n_eval: func_evals,
        n_steps: step_count,
        n_accepted: accepted_steps,
        n_rejected: rejected_steps,
        n_lu: 0,
        n_jac: 0,
        method: ODEMethod::DOP853,
    })
}