use crate::error::IntegrateResult;
use crate::ode::types::{ODEMethod, ODEOptions, ODEResult};
use crate::IntegrateFloat;
use scirs2_core::ndarray::{Array1, ArrayView1};
/// Solve `dy/dt = f(t, y)` on `t_span` with the explicit (forward) Euler method and a
/// fixed step size.
///
/// Each step computes `y_{n+1} = y_n + h_n * f(t_n, y_n)`. Here `h_n` equals `h` except for
/// the final step, which is adjusted so that the integration ends exactly at `t_end`.
///
/// # Arguments
///
/// * `f` - right-hand side, evaluated as `f(t, y)`. It is expected to return an array that
///   can be added to `y` (presumably the same length as `y0`; this is not checked here).
/// * `t_span` - `[t_start, t_end]`; nothing is integrated unless `t_start < t_end`.
/// * `y0` - state at `t_start`.
/// * `h` - nominal step size; must be strictly positive when `t_start < t_end`.
/// * `opts` - only `opts.max_steps` is consulted, as a cap on the number of steps.
///
/// # Returns
///
/// Always `Ok`. The [`ODEResult`] holds every time point, starting with `t_start`, and the
/// matching states. `success` is `false`, with an explanatory `message`, in two cases:
/// `h` is not positive (NaN included), or `max_steps` steps were taken without reaching
/// `t_end`.
#[allow(dead_code)]
pub fn euler_method<F, Func>(
    f: Func,
    t_span: [F; 2],
    y0: Array1<F>,
    h: F,
    opts: ODEOptions<F>,
) -> IntegrateResult<ODEResult<F>>
where
    F: IntegrateFloat,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    let [t_start, t_end] = t_span;
    let step_size = h;

    // A zero, negative or NaN step can never move `t` towards `t_end`. Without this guard
    // the loop below would burn through `max_steps` iterations and then report a
    // misleading "maximum number of steps" failure.
    if t_start < t_end && (step_size.is_nan() || step_size <= F::zero()) {
        return Ok(ODEResult {
            t: vec![t_start],
            y: vec![y0],
            success: false,
            message: Some("Step size must be positive".to_string()),
            n_eval: 0,
            n_steps: 0,
            n_accepted: 0,
            n_rejected: 0,
            n_lu: 0,
            n_jac: 0,
            method: ODEMethod::Euler,
        });
    }

    // Repeatedly adding `h` accumulates rounding error, so `t` can land a hair short of
    // `t_end` (ten steps of 0.1 reach 0.9999999999999999, not 1.0). Any remainder this small
    // is folded into the current step instead of being taken as an extra, near-zero step.
    let end_tol = step_size * F::epsilon().sqrt();

    let mut t = t_start;
    let mut y = y0;
    let mut t_values = vec![t_start];
    let mut y_values = vec![y.clone()];
    let mut func_evals = 0;
    let mut step_count = 0;

    while t < t_end && step_count < opts.max_steps {
        let candidate = t + step_size;
        let next_t = if candidate >= t_end - end_tol {
            t_end
        } else {
            candidate
        };
        let h_actual = next_t - t;

        let dy = f(t, y.view());
        func_evals += 1;
        // `y` is owned, so the update reuses its buffer instead of cloning it first.
        y = y + dy * h_actual;

        t = next_t;
        t_values.push(t);
        y_values.push(y.clone());
        step_count += 1;
    }

    let success = t >= t_end;
    let message = if !success {
        Some(format!(
            "Maximum number of steps ({}) reached",
            opts.max_steps
        ))
    } else {
        None
    };

    Ok(ODEResult {
        t: t_values,
        y: y_values,
        success,
        message,
        n_eval: func_evals,
        n_steps: step_count,
        n_accepted: step_count,
        n_rejected: 0,
        n_lu: 0,
        n_jac: 0,
        method: ODEMethod::Euler,
    })
}
/// Solve `dy/dt = f(t, y)` on `t_span` with the classical fourth-order Runge-Kutta method
/// and a fixed step size.
///
/// Each step evaluates `f` four times:
/// - `k1` at `t`;
/// - `k2` and `k3` at `t + h_n/2`;
/// - `k4` at `t + h_n`.
///
/// It then advances `y_{n+1} = y_n + h_n * (k1 + 2*k2 + 2*k3 + k4) / 6`. Here `h_n` equals
/// `h` except for the final step, which is adjusted so that the integration ends exactly
/// at `t_end`.
///
/// # Arguments
///
/// * `f` - right-hand side, evaluated as `f(t, y)`. It is expected to return an array that
///   can be added to `y` (presumably the same length as `y0`; this is not checked here).
/// * `t_span` - `[t_start, t_end]`; nothing is integrated unless `t_start < t_end`.
/// * `y0` - state at `t_start`.
/// * `h` - nominal step size; must be strictly positive when `t_start < t_end`.
/// * `opts` - only `opts.max_steps` is consulted, as a cap on the number of steps.
///
/// # Returns
///
/// Always `Ok`. The [`ODEResult`] holds every time point, starting with `t_start`, and the
/// matching states. `n_eval` counts four evaluations per step. `success` is `false`, with
/// an explanatory `message`, in two cases: `h` is not positive (NaN included), or
/// `max_steps` steps were taken without reaching `t_end`.
#[allow(dead_code)]
pub fn rk4_method<F, Func>(
    f: Func,
    t_span: [F; 2],
    y0: Array1<F>,
    h: F,
    opts: ODEOptions<F>,
) -> IntegrateResult<ODEResult<F>>
where
    F: IntegrateFloat,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    let [t_start, t_end] = t_span;
    let step_size = h;

    // A zero, negative or NaN step can never move `t` towards `t_end`. Without this guard
    // the loop below would burn through `max_steps` iterations and then report a
    // misleading "maximum number of steps" failure.
    if t_start < t_end && (step_size.is_nan() || step_size <= F::zero()) {
        return Ok(ODEResult {
            t: vec![t_start],
            y: vec![y0],
            success: false,
            message: Some("Step size must be positive".to_string()),
            n_eval: 0,
            n_steps: 0,
            n_accepted: 0,
            n_rejected: 0,
            n_lu: 0,
            n_jac: 0,
            method: ODEMethod::RK4,
        });
    }

    // Built from `one()` so no fallible conversion (and no panic path) is needed;
    // both values are exact in any binary floating-point type.
    let two = F::one() + F::one();
    let six = two * (two + F::one());

    // Repeatedly adding `h` accumulates rounding error, so `t` can land a hair short of
    // `t_end` (ten steps of 0.1 reach 0.9999999999999999, not 1.0). Any remainder this small
    // is folded into the current step instead of being taken as an extra, near-zero step
    // that would cost four more RHS evaluations.
    let end_tol = step_size * F::epsilon().sqrt();

    let mut t = t_start;
    let mut y = y0;
    let mut t_values = vec![t_start];
    let mut y_values = vec![y.clone()];
    let mut func_evals = 0;
    let mut step_count = 0;

    while t < t_end && step_count < opts.max_steps {
        let candidate = t + step_size;
        let next_t = if candidate >= t_end - end_tol {
            t_end
        } else {
            candidate
        };
        let h_actual = next_t - t;
        let half_step = h_actual / two;

        // Stage inputs are built from borrows; the `k`s are only consumed in `slope` below.
        let k1 = f(t, y.view());
        let k2 = f(t + half_step, (&y + &(&k1 * half_step)).view());
        let k3 = f(t + half_step, (&y + &(&k2 * half_step)).view());
        let k4 = f(t + h_actual, (&y + &(&k3 * h_actual)).view());
        func_evals += 4;

        let slope = (k1 + k2 * two + k3 * two + k4) / six;
        // `y` is owned, so the update reuses its buffer instead of cloning it first.
        y = y + slope * h_actual;

        t = next_t;
        t_values.push(t);
        y_values.push(y.clone());
        step_count += 1;
    }

    let success = t >= t_end;
    let message = if !success {
        Some(format!(
            "Maximum number of steps ({}) reached",
            opts.max_steps
        ))
    } else {
        None
    };

    Ok(ODEResult {
        t: t_values,
        y: y_values,
        success,
        message,
        n_eval: func_evals,
        n_steps: step_count,
        n_accepted: step_count,
        n_rejected: 0,
        n_lu: 0,
        n_jac: 0,
        method: ODEMethod::RK4,
    })
}