use crate::common::IntegrateFloat;
use crate::error::{IntegrateError, IntegrateResult};
use crate::ode::methods::{
bdf_method, dop853_method, enhanced_bdf_method, enhanced_lsoda_method, euler_method,
lsoda_method, radau_method, radau_method_with_mass, rk23_method, rk45_method, rk4_method,
};
use crate::ode::types::{MassMatrix, MassMatrixType, ODEMethod, ODEOptions, ODEResult};
use crate::ode::utils::dense_output::DenseSolution;
use crate::ode::utils::events::{
EventAction, EventHandler, ODEOptionsWithEvents, ODEResultWithEvents,
};
use crate::ode::utils::interpolation::ContinuousOutputMethod;
use crate::ode::utils::mass_matrix;
use scirs2_core::ndarray::{Array1, ArrayView1};
/// Solves the initial value problem `y' = f(t, y)`, `y(t_span[0]) = y0`, over `t_span`.
///
/// When `options` is `None`, `ODEOptions::default()` is used. If the options carry a mass
/// matrix, the problem is forwarded to `solve_ivp_with_mass_internal`. Otherwise the
/// integrator selected by `opts.method` is called. Defaults are filled in only where the
/// caller left a value unset; below, `span = t_span[1] - t_span[0]`:
///
/// - `Euler`, `RK4`: fixed step `h0 = 0.01 * span`.
/// - `LSODA`: `h0 = 0.05 * span`, `min_step = 1e-4 * span`.
/// - `EnhancedLSODA`: same as `LSODA`; in addition, a `max_steps` of exactly 500 is raised
///   to 1000.
/// - `EnhancedBDF`: `h0 = 0.01 * span`, `max_order = 3`; a `max_steps` of exactly 500 is
///   raised to 1000.
///
/// # Errors
/// Returns whatever error the selected integrator (or the mass-matrix path) reports.
#[allow(dead_code)]
pub fn solve_ivp<F, Func>(
    f: Func,
    t_span: [F; 2],
    y0: Array1<F>,
    options: Option<ODEOptions<F>>,
) -> IntegrateResult<ODEResult<F>>
where
    F: IntegrateFloat + std::iter::Sum + std::default::Default,
    Func: Fn(F, ArrayView1<F>) -> Array1<F> + Clone,
{
    let opts = options.unwrap_or_default();

    // Problems with a mass matrix have their own dispatcher.
    if let Some(mass) = &opts.mass_matrix {
        return solve_ivp_with_mass_internal(f, t_span, y0, mass.clone(), opts);
    }

    let span = t_span[1] - t_span[0];
    // `fraction * span`; only invoked when a default actually has to be filled in.
    let part_of_span = |fraction: f64| span * F::from_f64(fraction).expect("Operation failed");

    // Step size used by the fixed-step integrators (Euler, RK4).
    let fixed_step = opts.h0.unwrap_or_else(|| part_of_span(0.01));

    match opts.method {
        ODEMethod::Euler => euler_method(f, t_span, y0, fixed_step, opts),
        ODEMethod::RK4 => rk4_method(f, t_span, y0, fixed_step, opts),
        ODEMethod::RK45 => rk45_method(f, t_span, y0, opts),
        ODEMethod::RK23 => rk23_method(f, t_span, y0, opts),
        ODEMethod::Bdf => bdf_method(f, t_span, y0, opts),
        ODEMethod::DOP853 => dop853_method(f, t_span, y0, opts),
        ODEMethod::Radau => radau_method(f, t_span, y0, opts),
        ODEMethod::LSODA => {
            let tuned = ODEOptions {
                h0: opts.h0.or_else(|| Some(part_of_span(0.05))),
                min_step: opts.min_step.or_else(|| Some(part_of_span(0.0001))),
                ..opts
            };
            lsoda_method(f, t_span, y0, tuned)
        }
        ODEMethod::EnhancedLSODA => {
            let tuned = ODEOptions {
                h0: opts.h0.or_else(|| Some(part_of_span(0.05))),
                min_step: opts.min_step.or_else(|| Some(part_of_span(0.0001))),
                max_steps: match opts.max_steps {
                    500 => 1000,
                    other => other,
                },
                ..opts
            };
            enhanced_lsoda_method(f, t_span, y0, tuned)
        }
        ODEMethod::EnhancedBDF => {
            let tuned = ODEOptions {
                h0: opts.h0.or_else(|| Some(part_of_span(0.01))),
                max_steps: match opts.max_steps {
                    500 => 1000,
                    other => other,
                },
                max_order: opts.max_order.or(Some(3)),
                ..opts
            };
            enhanced_bdf_method(f, t_span, y0, tuned)
        }
    }
}
#[allow(dead_code)]
fn solve_ivp_with_mass_internal<F, Func>(
f: Func,
t_span: [F; 2],
y0: Array1<F>,
mass_matrix: MassMatrix<F>,
opts: ODEOptions<F>,
) -> IntegrateResult<ODEResult<F>>
where
F: IntegrateFloat + std::iter::Sum + std::default::Default,
Func: Fn(F, ArrayView1<F>) -> Array1<F> + Clone,
{
mass_matrix::check_mass_compatibility(&mass_matrix, t_span[0], y0.view())?;
match mass_matrix.matrix_type {
MassMatrixType::Identity => {
let mut new_opts = opts.clone();
new_opts.mass_matrix = None;
solve_ivp(f, t_span, y0, Some(new_opts))
}
MassMatrixType::Constant | MassMatrixType::TimeDependent => {
match opts.method {
ODEMethod::Radau => {
crate::ode::methods::radau_method_with_mass(f, t_span, y0, mass_matrix, opts)
}
ODEMethod::Bdf | ODEMethod::EnhancedBDF => {
solve_bdf_with_mass_matrix(f, t_span, y0, mass_matrix, opts)
}
_ => {
let f_clone = f.clone();
let transformed_f =
mass_matrix::transform_to_standard_form(f_clone, &mass_matrix);
let wrapper_f = move |t: F, y: ArrayView1<F>| -> Array1<F> {
transformed_f(t, y).unwrap_or_else(|_| {
Array1::zeros(y.len())
})
};
let mut new_opts = opts.clone();
new_opts.mass_matrix = None;
let [t_start, t_end] = t_span;
let h0 = new_opts.h0.unwrap_or_else(|| {
let _span = t_end - t_start;
_span * F::from_f64(0.01).expect("Operation failed") });
match new_opts.method {
ODEMethod::Euler => euler_method(wrapper_f, t_span, y0, h0, new_opts),
ODEMethod::RK4 => rk4_method(wrapper_f, t_span, y0, h0, new_opts),
ODEMethod::RK45 => rk45_method(wrapper_f, t_span, y0, new_opts),
ODEMethod::RK23 => rk23_method(wrapper_f, t_span, y0, new_opts),
ODEMethod::DOP853 => dop853_method(wrapper_f, t_span, y0, new_opts),
ODEMethod::Radau => radau_method_with_mass(
wrapper_f,
t_span,
y0,
mass_matrix.clone(),
new_opts,
),
ODEMethod::Bdf => bdf_method(wrapper_f, t_span, y0, new_opts),
ODEMethod::EnhancedBDF => {
enhanced_bdf_method(wrapper_f, t_span, y0, new_opts)
}
ODEMethod::LSODA => lsoda_method(wrapper_f, t_span, y0, new_opts),
ODEMethod::EnhancedLSODA => {
enhanced_lsoda_method(wrapper_f, t_span, y0, new_opts)
}
}
}
}
}
MassMatrixType::StateDependent => {
match opts.method {
ODEMethod::Radau => {
crate::ode::methods::radau_method_with_mass(f, t_span, y0, mass_matrix, opts)
}
ODEMethod::Bdf | ODEMethod::EnhancedBDF => {
solve_bdf_with_state_dependent_mass_matrix(f, t_span, y0, mass_matrix, opts)
}
_ => {
let wrapper_f = move |t: F, y: ArrayView1<F>| -> Array1<F> {
let rhs = f(t, y);
match mass_matrix::solve_mass_system(&mass_matrix, t, y, rhs.view()) {
Ok(result) => result,
Err(_) => {
Array1::zeros(y.len())
}
}
};
let mut new_opts = opts.clone();
new_opts.mass_matrix = None;
let h0 = new_opts.h0.unwrap_or_else(|| {
let _span = t_span[1] - t_span[0];
_span * F::from_f64(0.01).expect("Operation failed") });
match new_opts.method {
ODEMethod::Euler => euler_method(wrapper_f, t_span, y0, h0, new_opts),
ODEMethod::RK4 => rk4_method(wrapper_f, t_span, y0, h0, new_opts),
ODEMethod::RK45 => rk45_method(wrapper_f, t_span, y0, new_opts),
ODEMethod::RK23 => rk23_method(wrapper_f, t_span, y0, new_opts),
ODEMethod::DOP853 => dop853_method(wrapper_f, t_span, y0, new_opts),
ODEMethod::Radau => radau_method(wrapper_f, t_span, y0, new_opts),
ODEMethod::LSODA => lsoda_method(wrapper_f, t_span, y0, new_opts),
ODEMethod::EnhancedLSODA => {
enhanced_lsoda_method(wrapper_f, t_span, y0, new_opts)
}
ODEMethod::EnhancedBDF => {
enhanced_bdf_method(wrapper_f, t_span, y0, new_opts)
}
ODEMethod::Bdf => {
Err(IntegrateError::NotImplementedError(
"BDF method should not reach this case".to_string(),
))
}
}
}
}
}
}
}
/// Solves an IVP and runs event detection along the computed trajectory.
///
/// The base solution is computed with `solve_ivp`, with `dense_output` forced on. A
/// cubic-Hermite `DenseSolution` is built from the output points and `f`, and every output
/// point after the first is passed to an `EventHandler` built from `options.event_specs`,
/// together with that dense solution. If an event requests `EventAction::Stop`:
/// - the trajectory is truncated at the last recorded event;
/// - the event time/state is appended unless an output point already coincides with it
///   (within `1e-10`).
///
/// Truncation respects the direction of integration (`t_span[1]` below `t_span[0]` is
/// treated as backward integration).
///
/// Problems whose options carry a mass matrix are delegated to
/// `solve_ivp_with_events_and_mass`.
///
/// # Errors
/// Propagates solver and event-handler errors. Returns `IntegrateError::ValueError` in two
/// cases: the solver produced no output points, or a terminal event was signalled without
/// any event being recorded.
#[allow(dead_code)]
pub fn solve_ivp_with_events<F, Func, EventFunc>(
    f: Func,
    t_span: [F; 2],
    y0: Array1<F>,
    event_funcs: Vec<EventFunc>,
    options: ODEOptionsWithEvents<F>,
) -> IntegrateResult<ODEResultWithEvents<F>>
where
    F: IntegrateFloat + std::default::Default,
    Func: Fn(F, ArrayView1<F>) -> Array1<F> + Clone + 'static,
    EventFunc: Fn(F, ArrayView1<F>) -> F,
{
    let mut base_options = options.base_options.clone();
    base_options.dense_output = true;
    if let Some(mass_matrix) = &base_options.mass_matrix {
        return solve_ivp_with_events_and_mass(
            f,
            t_span,
            y0,
            event_funcs,
            options,
            mass_matrix.clone(),
        );
    }

    let base_result = solve_ivp(f.clone(), t_span, y0, Some(base_options))?;
    // The event handler is seeded from the first output point; without one there is
    // nothing to inspect (this used to panic on `t[0]`).
    if base_result.t.is_empty() || base_result.y.is_empty() {
        return Err(IntegrateError::ValueError(
            "ODE solver returned no output points".to_string(),
        ));
    }

    let dense_output = Some(DenseSolution::new(
        base_result.t.clone(),
        base_result.y.clone(),
        None,
        Some(ContinuousOutputMethod::CubicHermite),
        Some(Box::new(f)),
    ));

    let mut event_handler = EventHandler::new(options.event_specs.clone());
    event_handler.initialize(base_result.t[0], &base_result.y[0], &event_funcs)?;

    let mut event_termination = false;
    for (&t, y) in base_result.t.iter().zip(&base_result.y).skip(1) {
        let action = event_handler.check_events(t, y, dense_output.as_ref(), &event_funcs)?;
        if action == EventAction::Stop {
            event_termination = true;
            break;
        }
    }

    let final_result = if event_termination {
        let last_event = event_handler.record.events.last().ok_or_else(|| {
            IntegrateError::ValueError("No event found for termination".to_string())
        })?;
        let event_time = last_event.time;
        let event_state = last_event.state.clone();
        let tol = F::from_f64(1e-10).expect("Operation failed");
        let forward = t_span[1] >= t_span[0];

        // `keep` = number of output points retained. Stop right after a point that coincides
        // with the event, or right before the first point lying past the event in the
        // direction of integration.
        let (keep, exact_match) = base_result
            .t
            .iter()
            .enumerate()
            .find_map(|(i, &t)| {
                let past_event = if forward { t > event_time } else { t < event_time };
                if (t - event_time).abs() < tol {
                    Some((i + 1, true))
                } else if past_event {
                    Some((i, false))
                } else {
                    None
                }
            })
            .unwrap_or((base_result.t.len(), false));

        let mut truncated = base_result;
        truncated.t.truncate(keep);
        truncated.y.truncate(keep);
        // Close the trajectory at the event itself unless an output point already lies on it.
        if !exact_match
            && matches!(truncated.t.last(), Some(&last_t) if (last_t - event_time).abs() > tol)
        {
            truncated.t.push(event_time);
            truncated.y.push(event_state);
        }
        truncated
    } else {
        base_result
    };

    Ok(ODEResultWithEvents::new(
        final_result,
        event_handler.record,
        dense_output,
        event_termination,
    ))
}
#[allow(dead_code)]
fn solve_ivp_with_events_and_mass<F, Func, EventFunc>(
f: Func,
t_span: [F; 2],
y0: Array1<F>,
event_funcs: Vec<EventFunc>,
options: ODEOptionsWithEvents<F>,
mass_matrix: MassMatrix<F>,
) -> IntegrateResult<ODEResultWithEvents<F>>
where
F: IntegrateFloat + std::iter::Sum + std::default::Default,
Func: Fn(F, ArrayView1<F>) -> Array1<F> + Clone + 'static,
EventFunc: Fn(F, ArrayView1<F>) -> F,
{
mass_matrix::check_mass_compatibility(&mass_matrix, t_span[0], y0.view())?;
match mass_matrix.matrix_type {
MassMatrixType::Identity => {
let mut modified_options = options;
modified_options.base_options.mass_matrix = None;
solve_ivp_with_events(f, t_span, y0, event_funcs, modified_options)
}
MassMatrixType::Constant
| MassMatrixType::TimeDependent
| MassMatrixType::StateDependent => {
match options.base_options.method {
ODEMethod::Radau => solve_ivp_with_events_radau_mass(
f,
t_span,
y0,
event_funcs,
options,
mass_matrix,
),
_ => {
match mass_matrix.matrix_type {
MassMatrixType::Constant | MassMatrixType::TimeDependent => {
let f_clone = f.clone();
let mass_clone = mass_matrix.clone();
let transformed_f = move |t: F, y: ArrayView1<F>| -> Array1<F> {
let rhs = f_clone(t, y);
match mass_matrix::solve_mass_system(&mass_clone, t, y, rhs.view()) {
Ok(result) => result,
Err(_) => Array1::zeros(y.len()), }
};
let mut modified_options = options;
modified_options.base_options.mass_matrix = None;
let base_result = solve_ivp(transformed_f, t_span, y0, Some(modified_options.base_options))?;
let empty_events = crate::ode::utils::events::EventRecord::new();
Ok(ODEResultWithEvents {
base_result,
events: empty_events,
dense_output: None,
event_termination: false,
})
}
MassMatrixType::StateDependent => {
Err(IntegrateError::NotImplementedError(
"Event detection with state-dependent mass matrices is only supported with the Radau method".to_string()
))
}
MassMatrixType::Identity => unreachable!(),
}
}
}
}
}
}
#[allow(dead_code)]
fn solve_ivp_with_events_radau_mass<F, Func, EventFunc>(
f: Func,
t_span: [F; 2],
y0: Array1<F>,
event_funcs: Vec<EventFunc>,
options: ODEOptionsWithEvents<F>,
mass_matrix: MassMatrix<F>,
) -> IntegrateResult<ODEResultWithEvents<F>>
where
F: IntegrateFloat + std::iter::Sum + std::default::Default,
Func: Fn(F, ArrayView1<F>) -> Array1<F> + Clone + 'static,
EventFunc: Fn(F, ArrayView1<F>) -> F,
{
use crate::ode::methods::radau_mass::radau_method_with_mass;
use crate::ode::utils::events::{EventAction, EventHandler};
let mut event_handler = EventHandler::new(options.event_specs.clone());
let mut all_t = vec![t_span[0]];
let mut all_y = vec![y0.clone()];
let _all_dy: Vec<Array1<F>> = Vec::new();
let mut current_t = t_span[0];
let mut current_y = y0.clone();
let t_end = t_span[1];
let mut total_n_eval = 0;
let mut total_n_jac = 0;
let mut total_n_lu = 0;
let mut total_n_steps = 0;
let mut total_n_accepted = 0;
let mut total_n_rejected = 0;
event_handler.initialize(current_t, ¤t_y, &event_funcs)?;
let max_step_size = (t_end - current_t) / F::from_f64(100.0).expect("Operation failed");
while current_t < t_end {
let next_t = (current_t + max_step_size).min(t_end);
let step_result = radau_method_with_mass(
f.clone(),
[current_t, next_t],
current_y.clone(),
mass_matrix.clone(),
options.base_options.clone(),
)?;
total_n_eval += step_result.n_eval;
total_n_jac += step_result.n_jac;
total_n_lu += step_result.n_lu;
total_n_steps += step_result.n_steps;
total_n_accepted += step_result.n_accepted;
total_n_rejected += step_result.n_rejected;
for i in 1..step_result.t.len() {
let step_t = step_result.t[i];
let step_y = &step_result.y[i];
let action = event_handler.check_events(step_t, step_y, None, &event_funcs)?;
all_t.push(step_t);
all_y.push(step_y.clone());
if action == EventAction::Stop {
let base_result = ODEResult {
t: all_t,
y: all_y,
success: true,
message: Some("Integration stopped due to event".to_string()),
n_eval: total_n_eval,
n_steps: total_n_steps,
n_accepted: total_n_accepted,
n_rejected: total_n_rejected,
n_lu: total_n_lu,
n_jac: total_n_jac,
method: ODEMethod::Radau,
};
return Ok(ODEResultWithEvents::new(
base_result,
event_handler.record,
None, true, ));
}
}
current_t = step_result.t.last().copied().unwrap_or(current_t);
current_y = step_result.y.last().cloned().unwrap_or(current_y);
}
let base_result = ODEResult {
t: all_t,
y: all_y,
success: current_t >= t_end,
message: Some("Integration completed successfully".to_string()),
n_eval: total_n_eval,
n_steps: total_n_steps,
n_accepted: total_n_accepted,
n_rejected: total_n_rejected,
n_lu: total_n_lu,
n_jac: total_n_jac,
method: ODEMethod::Radau,
};
Ok(ODEResultWithEvents::new(
base_result,
event_handler.record,
None, false, ))
}
#[allow(dead_code)]
fn solve_bdf_with_mass_matrix<F, FFunc>(
f: FFunc,
t_span: [F; 2],
y0: Array1<F>,
mass_matrix: MassMatrix<F>,
mut opts: ODEOptions<F>,
) -> IntegrateResult<ODEResult<F>>
where
F: IntegrateFloat + std::default::Default,
FFunc: Fn(F, ArrayView1<F>) -> Array1<F> + Clone,
{
use crate::ode::methods::enhanced_bdf_method;
let mass_f = move |t: F, y: ArrayView1<F>| -> Array1<F> {
let rhs = f(t, y);
match mass_matrix.matrix_type {
MassMatrixType::Identity => {
rhs
}
MassMatrixType::Constant => {
rhs }
MassMatrixType::TimeDependent => {
rhs
}
MassMatrixType::StateDependent => {
rhs
}
}
};
opts.mass_matrix = Some(mass_matrix);
enhanced_bdf_method(mass_f, t_span, y0, opts)
}
/// Solves `M(t, y) · y' = f(t, y)` for a state-dependent mass matrix with
/// `enhanced_bdf_method`.
///
/// The mass matrix is handed over in `opts.mass_matrix` and `f` is passed through
/// unchanged. The relative tolerance is capped: any `rtol` larger than `1e-8` is replaced
/// by `1e-8`.
///
/// NOTE(review): relies on `enhanced_bdf_method` honouring `opts.mass_matrix`; confirm in
/// the BDF implementation.
#[allow(dead_code)]
fn solve_bdf_with_state_dependent_mass_matrix<F, FFunc>(
    f: FFunc,
    t_span: [F; 2],
    y0: Array1<F>,
    mass_matrix: MassMatrix<F>,
    mut opts: ODEOptions<F>,
) -> IntegrateResult<ODEResult<F>>
where
    F: IntegrateFloat + std::default::Default,
    FFunc: Fn(F, ArrayView1<F>) -> Array1<F> + Clone,
{
    opts.mass_matrix = Some(mass_matrix);
    let rtol_ceiling = F::from_f64(1e-8).unwrap_or(opts.rtol);
    opts.rtol = opts.rtol.min(rtol_ceiling);
    enhanced_bdf_method(f, t_span, y0, opts)
}