use crate::error::IntegrateResult;
use crate::ode::types::{ODEMethod, ODEOptions, ODEResult};
use crate::IntegrateFloat;
use scirs2_core::ndarray::{Array1, ArrayView1};
/// Integrates `dy/dt = f(t, y)` over `t_span` with the explicit (forward) Euler method using a
/// fixed step size.
///
/// Each step advances the state by `h_actual * f(t, y)`. The final step is shortened so that
/// integration ends exactly at `t_span[1]`. When the remaining interval differs from `h` only
/// by floating-point drift, the final step is instead stretched by that drift.
///
/// # Arguments
/// * `f` - right-hand side, called as `f(t, y)`.
/// * `t_span` - `[t_start, t_end]`; steps are taken while `t < t_end`.
/// * `y0` - initial state at `t_start`.
/// * `h` - nominal step size.
/// * `opts` - solver options; only `max_steps` is read here.
///
/// # Returns
/// An [`ODEResult`] holding every visited time point and state, including the initial one.
/// `success` is `false`, with an explanatory `message`, when `opts.max_steps` steps were taken
/// before `t_end` was reached. This function itself always returns `Ok`.
#[allow(dead_code)]
pub fn euler_method<F, Func>(
    f: Func,
    t_span: [F; 2],
    y0: Array1<F>,
    h: F,
    opts: ODEOptions<F>,
) -> IntegrateResult<ODEResult<F>>
where
    F: IntegrateFloat,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    let [t_start, t_end] = t_span;
    let step_size = h;
    // Repeated `t + h` accumulates rounding error, so `t` can end up a few ulps short of
    // `t_end` (ten steps of 0.1 reach 0.9999999999999999). Without a tolerance that triggers
    // an extra step of length ~1e-16, which wastes an `f` evaluation and emits a
    // near-duplicate output point. Remainders below this tiny fraction of `h` are folded into
    // the current step instead.
    let snap_tol = step_size
        * F::from_f64(1e-8).expect("1e-8 is representable in any floating-point type");
    let mut t = t_start;
    let mut t_values = vec![t_start];
    let mut y_values = vec![y0.clone()];
    let mut y = y0;
    let mut func_evals = 0;
    let mut step_count = 0;
    while t < t_end && step_count < opts.max_steps {
        // Clamp (or snap) the step so that the last one lands exactly on `t_end`.
        let next_t = if t + step_size >= t_end - snap_tol {
            t_end
        } else {
            t + step_size
        };
        let h_actual = next_t - t;
        let dy = f(t, y.view());
        func_evals += 1;
        y = y + dy * h_actual;
        t = next_t;
        t_values.push(t);
        y_values.push(y.clone());
        step_count += 1;
    }
    let success = t >= t_end;
    let message = if !success {
        Some(format!(
            "Maximum number of steps ({}) reached",
            opts.max_steps
        ))
    } else {
        None
    };
    Ok(ODEResult {
        t: t_values,
        y: y_values,
        success,
        message,
        n_eval: func_evals,
        n_steps: step_count,
        n_accepted: step_count,
        n_rejected: 0,
        n_lu: 0,
        n_jac: 0,
        method: ODEMethod::Euler,
    })
}
/// Integrates `dy/dt = f(t, y)` over `t_span` with the classical fourth-order Runge–Kutta
/// method using a fixed step size.
///
/// Each step evaluates `f` four times: at the start of the step, twice at the midpoint, and
/// at the end. It then advances with the weighted slope `(k1 + 2·k2 + 2·k3 + k4) / 6`. The
/// final step is shortened so that integration ends exactly at `t_span[1]`. When the
/// remaining interval differs from `h` only by floating-point drift, the final step is
/// instead stretched by that drift.
///
/// # Arguments
/// * `f` - right-hand side, called as `f(t, y)`.
/// * `t_span` - `[t_start, t_end]`; steps are taken while `t < t_end`.
/// * `y0` - initial state at `t_start`.
/// * `h` - nominal step size.
/// * `opts` - solver options; only `max_steps` is read here.
///
/// # Returns
/// An [`ODEResult`] holding every visited time point and state, including the initial one.
/// `success` is `false`, with an explanatory `message`, when `opts.max_steps` steps were taken
/// before `t_end` was reached. This function itself always returns `Ok`.
#[allow(dead_code)]
pub fn rk4_method<F, Func>(
    f: Func,
    t_span: [F; 2],
    y0: Array1<F>,
    h: F,
    opts: ODEOptions<F>,
) -> IntegrateResult<ODEResult<F>>
where
    F: IntegrateFloat,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    let [t_start, t_end] = t_span;
    let step_size = h;
    let two = F::from_f64(2.0).expect("Operation failed");
    let six = F::from_f64(6.0).expect("Operation failed");
    // Repeated `t + h` accumulates rounding error, so `t` can end up a few ulps short of
    // `t_end` (ten steps of 0.1 reach 0.9999999999999999). Without a tolerance that triggers
    // an extra step of length ~1e-16, which costs four `f` evaluations and emits a
    // near-duplicate output point. Remainders below this tiny fraction of `h` are folded into
    // the current step instead.
    let snap_tol = step_size
        * F::from_f64(1e-8).expect("1e-8 is representable in any floating-point type");
    let mut t = t_start;
    let mut t_values = vec![t_start];
    let mut y_values = vec![y0.clone()];
    let mut y = y0;
    let mut func_evals = 0;
    let mut step_count = 0;
    while t < t_end && step_count < opts.max_steps {
        // Clamp (or snap) the step so that the last one lands exactly on `t_end`.
        let next_t = if t + step_size >= t_end - snap_tol {
            t_end
        } else {
            t + step_size
        };
        let h_actual = next_t - t;
        let half_step = h_actual / two;
        let k1 = f(t, y.view());
        let k2 = f(t + half_step, (y.clone() + k1.clone() * half_step).view());
        let k3 = f(t + half_step, (y.clone() + k2.clone() * half_step).view());
        let k4 = f(t + h_actual, (y.clone() + k3.clone() * h_actual).view());
        func_evals += 4;
        // The stage slopes are not needed after this point, so they are consumed directly.
        let slope = (k1 + k2 * two + k3 * two + k4) / six;
        y = y + slope * h_actual;
        t = next_t;
        t_values.push(t);
        y_values.push(y.clone());
        step_count += 1;
    }
    let success = t >= t_end;
    let message = if !success {
        Some(format!(
            "Maximum number of steps ({}) reached",
            opts.max_steps
        ))
    } else {
        None
    };
    Ok(ODEResult {
        t: t_values,
        y: y_values,
        success,
        message,
        n_eval: func_evals,
        n_steps: step_count,
        n_accepted: step_count,
        n_rejected: 0,
        n_lu: 0,
        n_jac: 0,
        method: ODEMethod::RK4,
    })
}
/// Advances `state` by one step of the three-stage, third-order strong-stability-preserving
/// Runge–Kutta method SSPRK(3,3) of Shu and Osher.
///
/// Each stage is a convex combination of `state` and a forward-Euler update of the previous
/// stage:
///
/// ```text
/// u1      = u + dt·L(u, t)
/// u2      = 3/4·u + 1/4·(u1 + dt·L(u1, t + dt))
/// u(t+dt) = 1/3·u + 2/3·(u2 + dt·L(u2, t + dt/2))
/// ```
///
/// `rhs` is called as `rhs(&stage, time)` exactly three times. `state` is only borrowed and
/// cloned, never modified.
pub fn ssprk3_step<S, F>(state: &S, t: f64, dt: f64, rhs: &F) -> S
where
    S: Clone + std::ops::Add<Output = S> + std::ops::Mul<f64, Output = S>,
    F: Fn(&S, f64) -> S,
{
    // Stage 1: plain forward Euler from the incoming state.
    let first = state.clone() + rhs(state, t) * dt;

    // Stage 2: blend the incoming state with an Euler step taken from stage 1 (at t + dt).
    let slope_first = rhs(&first, t + dt);
    let euler_from_first = first + slope_first * dt;
    let second = state.clone() * (3.0 / 4.0) + euler_from_first * (1.0 / 4.0);

    // Stage 3: Euler step from stage 2 evaluated at the midpoint, blended with the state.
    let slope_second = rhs(&second, t + 0.5 * dt);
    let euler_from_second = second + slope_second * dt;
    state.clone() * (1.0 / 3.0) + euler_from_second * (2.0 / 3.0)
}
pub fn ssprk4_step<S, F>(state: &S, t: f64, dt: f64, rhs: &F) -> S
where
S: Clone + std::ops::Add<Output = S> + std::ops::Mul<f64, Output = S>,
F: Fn(&S, f64) -> S,
{
const C1: f64 = 0.391_752_226_571_89;
const C2: f64 = 0.586_079_152_584_48;
const C3: f64 = 0.474_542_363_121_968;
const C4: f64 = 0.935_010_630_967_653;
let l0 = rhs(state, t);
let u1 = state.clone() + l0 * (C1 * dt);
let l1 = rhs(&u1, t + C1 * dt);
let u2 = state.clone() * 0.444_370_493_651_235
+ (u1.clone() + l1 * (C1 * dt)) * 0.555_629_506_348_765;
let l2 = rhs(&u2, t + C2 * dt);
let u3 = state.clone() * 0.620_101_851_488_403
+ (u2 + l2 * (0.251_891_774_271_694 * dt)) * 0.379_898_148_511_597;
let l3 = rhs(&u3, t + C3 * dt);
let u4 = state.clone() * 0.178_079_954_393_132
+ (u3 + l3 * (0.544_974_750_228_521 * dt)) * 0.821_920_045_606_868;
let l4 = rhs(&u4, t + C4 * dt);
state.clone() * 0.517_231_671_970_585
+ u1 * 0.096_059_710_526_147
+ (u4 + l4 * (0.226_007_483_236_906 * dt)) * 0.386_708_617_503_268
}