use crate::DType;
mod bdf;
mod bvp;
mod dae;
mod dae_helpers;
mod dae_ic;
mod dae_jacobian;
pub mod dense_output;
#[cfg(feature = "sparse")]
pub(crate) mod direct_solver;
#[cfg(feature = "sparse")]
pub(crate) mod direct_solver_config;
mod dop853;
pub mod events;
mod jacobian;
mod lsoda;
mod radau;
mod rk23;
mod rk45;
#[cfg(feature = "sparse")]
mod sparse_utils;
#[cfg(feature = "sparse")]
pub(crate) mod sparsity_detection;
#[cfg(feature = "sparse")]
pub(crate) mod symbolic_analysis;
#[cfg(feature = "sparse")]
pub use direct_solver::DirectSparseSolver;
#[cfg(feature = "sparse")]
pub use direct_solver_config::{DirectSolverConfig, SparseSolverStrategy};
#[cfg(feature = "sparse")]
pub use sparse_utils::SparseJacobianCache;
#[cfg(feature = "sparse")]
pub use sparsity_detection::{detect_jacobian_sparsity, sparsity_ratio};
mod step_control;
pub(crate) mod stiff_client;
mod symplectic;
pub use bdf::bdf_impl;
pub use bvp::bvp_impl;
pub use dae::dae_impl;
pub use dae_jacobian::{compute_dae_jacobian, eval_dae_primal};
pub use dense_output::{DenseOutputStep, dense_eval};
pub use dop853::dop853_impl;
pub use events::{EventCheckResult, check_events, evaluate_events, handle_terminal_event};
pub use jacobian::{
compute_iteration_matrix, compute_jacobian_autograd, compute_norm, compute_norm_scalar,
eval_primal,
};
pub use lsoda::lsoda_impl;
pub use radau::radau_impl;
pub use rk23::rk23_impl;
pub use rk45::{rk45_impl, rk45_with_events_impl};
pub use step_control::*;
pub use symplectic::{leapfrog_impl, verlet_impl};
use numr::error::Result;
use numr::ops::TensorOps;
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use crate::integrate::error::{IntegrateError, IntegrateResult};
use crate::integrate::{ODEMethod, ODEOptions};
/// Raw outputs collected by an ODE solver, consumed by [`build_ode_result`].
///
/// `build_ode_result` turns `t_values` into a 1-D tensor and stacks
/// `y_values` along a new leading axis; the two slices are intended to be
/// parallel (entry `i` of `y_values` paired with `t_values[i]`).
#[derive(Debug)]
pub struct ODEResultParams<'a, R: Runtime<DType = DType>> {
    /// Time points recorded by the solver.
    pub t_values: &'a [f64],
    /// State tensors, one per recorded time point.
    pub y_values: &'a [Tensor<R>],
    /// Solver success flag, copied verbatim into the result.
    pub success: bool,
    /// Optional solver status/diagnostic message, copied into the result.
    pub message: Option<String>,
    /// Function-evaluation count, copied into the result.
    pub nfev: usize,
    /// Accepted-step count, copied into the result.
    pub naccept: usize,
    /// Rejected-step count, copied into the result.
    pub nreject: usize,
}
/// Assemble an [`ODEResultTensor`] from the raw solver outputs in `params`.
///
/// `t_values` becomes a 1-D tensor of length `n_steps` on `client`'s device,
/// and `y_values` is stacked along dimension 0, so row `i` of the result's `y`
/// corresponds to `t_values[i]`. The remaining fields are copied through and
/// `method` is recorded on the result.
///
/// # Errors
///
/// Returns [`IntegrateError::InvalidInput`] if `t_values` and `y_values` have
/// different lengths, or if `client.stack` fails on the state tensors.
pub fn build_ode_result<R, C>(
    client: &C,
    params: ODEResultParams<R>,
    method: ODEMethod,
) -> IntegrateResult<ODEResultTensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    let n_steps = params.t_values.len();
    // `t` and `y` must describe the same trajectory; a length mismatch would
    // otherwise produce a result whose time axis does not line up with `y`.
    // Checked before any tensor is allocated on the device.
    if params.y_values.len() != n_steps {
        return Err(IntegrateError::InvalidInput {
            context: format!(
                "build_ode_result: {} time points but {} state tensors",
                n_steps,
                params.y_values.len()
            ),
        });
    }
    let t_tensor = Tensor::<R>::from_slice(params.t_values, &[n_steps], client.device());
    let y_refs: Vec<&Tensor<R>> = params.y_values.iter().collect();
    let y_tensor = client
        .stack(&y_refs, 0)
        .map_err(|e| IntegrateError::InvalidInput {
            context: format!("Failed to stack y tensors: {}", e),
        })?;
    Ok(ODEResultTensor {
        t: t_tensor,
        y: y_tensor,
        success: params.success,
        message: params.message,
        nfev: params.nfev,
        naccept: params.naccept,
        nreject: params.nreject,
        method,
    })
}
/// Tensor-valued result of an ODE integration.
///
/// As produced by [`build_ode_result`], `t` is a 1-D tensor of time points
/// and `y` holds the corresponding state tensors stacked along dimension 0.
#[derive(Clone, Debug)]
pub struct ODEResultTensor<R>
where
    R: Runtime<DType = DType>,
{
    /// Time points of the trajectory (1-D).
    pub t: Tensor<R>,
    /// States stacked along dimension 0, one row per entry of `t`.
    pub y: Tensor<R>,
    /// Solver success flag.
    pub success: bool,
    /// Optional solver status/diagnostic message.
    pub message: Option<String>,
    /// Function-evaluation count reported by the solver.
    pub nfev: usize,
    /// Accepted-step count reported by the solver.
    pub naccept: usize,
    /// Rejected-step count reported by the solver.
    pub nreject: usize,
    /// Integration method that produced this result.
    pub method: ODEMethod,
}
impl<R: Runtime<DType = DType>> ODEResultTensor<R> {
pub fn y_final(&self) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
{
let shape = self.y.shape();
if shape.len() != 2 || shape[0] == 0 {
return Err(numr::error::Error::InvalidArgument {
arg: "y",
reason: "Expected 2D tensor with at least one row".to_string(),
});
}
Ok(self.y.clone())
}
pub fn y_final_vec(&self) -> Vec<f64> {
let shape = self.y.shape();
if shape.len() != 2 || shape[0] == 0 {
return vec![];
}
let n_steps = shape[0];
let n_vars = shape[1];
let all_data: Vec<f64> = self.y.to_vec();
let last_row_start = (n_steps - 1) * n_vars;
all_data[last_row_start..].to_vec()
}
}
pub fn solve_ivp_impl<R, C, F>(
client: &C,
f: F,
t_span: [f64; 2],
y0: &Tensor<R>,
options: &ODEOptions,
) -> IntegrateResult<ODEResultTensor<R>>
where
R: Runtime<DType = DType>,
C: numr::ops::TensorOps<R> + numr::ops::ScalarOps<R> + numr::runtime::RuntimeClient<R>,
F: Fn(&Tensor<R>, &Tensor<R>) -> Result<Tensor<R>>,
{
let [t_start, t_end] = t_span;
if t_start >= t_end {
return Err(IntegrateError::InvalidInterval {
a: t_start,
b: t_end,
context: "solve_ivp".to_string(),
});
}
if y0.shape().is_empty() || y0.shape()[0] == 0 {
return Err(IntegrateError::InvalidInput {
context: "solve_ivp: initial condition cannot be empty".to_string(),
});
}
match options.method {
ODEMethod::RK23 => rk23_impl(client, f, t_span, y0, options),
ODEMethod::RK45 => rk45_impl(client, f, t_span, y0, options),
ODEMethod::DOP853 => dop853_impl(client, f, t_span, y0, options),
ODEMethod::BDF | ODEMethod::Radau | ODEMethod::LSODA => Err(IntegrateError::InvalidInput {
context: format!(
"Method {:?} requires using the dedicated solver function (e.g., solve_ivp_bdf)",
options.method
),
}),
ODEMethod::Verlet | ODEMethod::Leapfrog => Err(IntegrateError::InvalidInput {
context: format!(
"Symplectic method {:?} requires using verlet() or leapfrog() with q0, p0",
options.method
),
}),
}
}