use crate::error::{IntegrateError, IntegrateResult};
use crate::ode::utils::interpolation::{
cubic_hermite_interpolation, linear_interpolation, ContinuousOutputMethod,
};
use crate::IntegrateFloat;
use scirs2_core::ndarray::{Array1, ArrayView1};
use std::fmt::Debug;
/// Boxed derivative function, called as `f(t, y)`; its results are used as the
/// `dy/dt` samples for Hermite interpolation.
type DerivativeFunction<F> = Box<dyn Fn(F, ArrayView1<F>) -> Array1<F>>;
/// Sampled ODE solution plus the data needed to interpolate between samples.
pub struct DenseSolution<F: IntegrateFloat> {
    /// Sample times; `evaluate` treats the first and last entries as the valid range.
    pub t: Vec<F>,
    /// State vector at each sample, indexed in parallel with `t`.
    pub y: Vec<Array1<F>>,
    /// Optional derivative at each sample, used by Hermite interpolation.
    pub dydt: Option<Vec<Array1<F>>>,
    /// Interpolation method used by `evaluate`.
    pub method: ContinuousOutputMethod,
    /// Optional derivative function used when `dydt` is not available.
    pub f: Option<DerivativeFunction<F>>,
}
impl<F: IntegrateFloat> Debug for DenseSolution<F> {
    // The boxed closure in `f` has no `Debug` impl, so it is rendered as
    // `Some("<closure>")` when present and `None` when absent, rather than
    // always claiming a closure exists.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DenseSolution")
            .field("t", &self.t)
            .field("y", &self.y)
            .field("dydt", &self.dydt)
            .field("method", &self.method)
            .field("f", &self.f.as_ref().map(|_| "<closure>"))
            .finish()
    }
}
impl<F: IntegrateFloat> DenseSolution<F> {
pub fn new(
t: Vec<F>,
y: Vec<Array1<F>>,
dydt: Option<Vec<Array1<F>>>,
method: Option<ContinuousOutputMethod>,
f: Option<DerivativeFunction<F>>,
) -> Self {
let interp_method = method.unwrap_or_default();
if interp_method == ContinuousOutputMethod::CubicHermite && dydt.is_none() && f.is_none() {
return DenseSolution {
t,
y,
dydt: None,
method: ContinuousOutputMethod::Linear,
f: None,
};
}
DenseSolution {
t,
y,
dydt,
method: interp_method,
f,
}
}
pub fn evaluate(&self, t: F) -> IntegrateResult<Array1<F>> {
let t_min = self
.t
.first()
.ok_or_else(|| IntegrateError::ComputationError("Empty solution".to_string()))?;
let t_max = self
.t
.last()
.ok_or_else(|| IntegrateError::ComputationError("Empty solution".to_string()))?;
if t < *t_min || t > *t_max {
return Err(IntegrateError::ValueError(format!(
"Evaluation time {t} is outside the solution range [{t_min}, {t_max}]"
)));
}
for (i, &ti) in self.t.iter().enumerate() {
if (t - ti).abs() < F::from_f64(1e-14).expect("Operation failed") {
return Ok(self.y[i].clone());
}
}
match self.method {
ContinuousOutputMethod::Linear => {
Ok(linear_interpolation(&self.t, &self.y, t))
}
ContinuousOutputMethod::CubicHermite => {
if let Some(ref dydt) = self.dydt {
Ok(cubic_hermite_interpolation(&self.t, &self.y, dydt, t))
} else if let Some(ref f) = self.f {
let mut dydt = Vec::with_capacity(self.t.len());
for i in 0..self.t.len() {
dydt.push(f(self.t[i], self.y[i].view()));
}
Ok(cubic_hermite_interpolation(&self.t, &self.y, &dydt, t))
} else {
Ok(linear_interpolation(&self.t, &self.y, t))
}
}
ContinuousOutputMethod::MethodSpecific => {
if let Some(ref dydt) = self.dydt {
Ok(cubic_hermite_interpolation(&self.t, &self.y, dydt, t))
} else if let Some(ref f) = self.f {
let mut dydt = Vec::with_capacity(self.t.len());
for i in 0..self.t.len() {
dydt.push(f(self.t[i], self.y[i].view()));
}
Ok(cubic_hermite_interpolation(&self.t, &self.y, &dydt, t))
} else {
Ok(linear_interpolation(&self.t, &self.y, t))
}
}
}
}
pub fn dense_output(&self, npoints: usize) -> IntegrateResult<(Vec<F>, Vec<Array1<F>>)> {
if self.t.is_empty() {
return Err(IntegrateError::ComputationError(
"Empty solution".to_string(),
));
}
let t_min = *self.t.first().expect("Operation failed");
let t_max = *self.t.last().expect("Operation failed");
let dt = (t_max - t_min) / F::from_usize(npoints - 1).expect("Operation failed");
let mut times = Vec::with_capacity(npoints);
let mut values = Vec::with_capacity(npoints);
for i in 0..npoints {
let t = t_min + dt * F::from_usize(i).expect("Operation failed");
times.push(t);
values.push(self.evaluate(t)?);
}
Ok((times, values))
}
pub fn extract_component(
&self,
component: usize,
n_points: usize,
) -> IntegrateResult<(Vec<F>, Vec<F>)> {
if self.y.is_empty() {
return Err(IntegrateError::ComputationError(
"Empty solution".to_string(),
));
}
let dim = self.y[0].len();
if component >= dim {
return Err(IntegrateError::ValueError(format!(
"Component index {} is out of bounds (0-{})",
component,
dim - 1
)));
}
let (times, values) = self.dense_output(n_points)?;
let component_values = values.iter().map(|v| v[component]).collect();
Ok((times, component_values))
}
}
/// Single-step interpolant built from stage vectors (see `evaluate` for the formula).
#[derive(Debug, Clone)]
pub struct DOP853Interpolant<F: IntegrateFloat> {
    /// Start time of the step.
    pub t0: F,
    /// Step size; `evaluate` accepts `t` with `(t - t0) / h` in `[0, 1]`.
    pub h: F,
    /// State at `t0` (the interpolant's value at `theta = 0`).
    pub y0: Array1<F>,
    /// Stage vectors; `evaluate` uses the first six and an optional seventh.
    pub k: Vec<Array1<F>>,
}
impl<F: IntegrateFloat> DOP853Interpolant<F> {
    /// Creates an interpolant for the step starting at `t0` with step size `h`.
    ///
    /// `y0` is the state at `t0`. `k` holds the stage vectors consumed by
    /// [`DOP853Interpolant::evaluate`], which needs at least six of them.
    pub fn new(t0: F, h: F, y0: Array1<F>, k: Vec<Array1<F>>) -> Self {
        DOP853Interpolant { t0, h, y0, k }
    }

    /// Evaluates the interpolant at time `t` inside the step.
    ///
    /// Computes `y0 + sum_{j=1..m} (k[j-1] * h) * theta^j / j!`, where
    /// `theta = (t - t0) / h` and `m = min(k.len(), 7)`.
    ///
    /// NOTE(review): this is a Taylor-style polynomial in `theta`, not the
    /// published DOP853 dense-output formula (which uses dedicated
    /// coefficients). Confirm this is the intended approximation.
    ///
    /// # Errors
    ///
    /// `ValueError` if `theta` is NaN (e.g. `h == 0`) or outside `[0, 1]`, or if
    /// fewer than six stage vectors are stored.
    pub fn evaluate(&self, t: F) -> IntegrateResult<Array1<F>> {
        const REQUIRED_STAGES: usize = 6;
        const MAX_STAGES: usize = 7;

        let theta = (t - self.t0) / self.h;
        // NaN compares false against both bounds, so it must be rejected explicitly.
        if theta.is_nan() || theta < F::zero() || theta > F::one() {
            return Err(IntegrateError::ValueError(
                "Evaluation point is outside of the step".to_string(),
            ));
        }
        if self.k.len() < REQUIRED_STAGES {
            return Err(IntegrateError::ValueError(format!(
                "DOP853 interpolant needs at least {} stage vectors, got {}",
                REQUIRED_STAGES,
                self.k.len()
            )));
        }

        let mut result = self.y0.clone();
        // theta^j and j! are built incrementally. The factorials are exact small
        // integers, so each weight equals theta*...*theta / j!.
        let mut power = F::one();
        let mut factorial = F::one();
        for (j, stage) in self.k.iter().take(MAX_STAGES).enumerate() {
            power = power * theta;
            factorial = factorial * F::from_usize(j + 1).expect("Operation failed");
            let weight = power / factorial;
            result += &(stage * self.h * weight);
        }
        Ok(result)
    }
}
/// Cubic Hermite interpolant over a single step, between `y0` and `y1`.
#[derive(Debug, Clone)]
pub struct RadauInterpolant<F: IntegrateFloat> {
    /// Start time of the step.
    pub t0: F,
    /// Step size; `evaluate` accepts `t` with `(t - t0) / h` in `[0, 1]`.
    pub h: F,
    /// State at the start of the step (`theta = 0`).
    pub y0: Array1<F>,
    /// State at the end of the step (`theta = 1`).
    pub y1: Array1<F>,
    /// Stage vectors; `evaluate` uses the first as the start slope and the last
    /// as the end slope.
    pub k: Vec<Array1<F>>,
}
impl<F: IntegrateFloat> RadauInterpolant<F> {
    /// Creates an interpolant for the step `[t0, t0 + h]` with endpoint states
    /// `y0`/`y1` and stage vectors `k`.
    pub fn new(t0: F, h: F, y0: Array1<F>, y1: Array1<F>, k: Vec<Array1<F>>) -> Self {
        RadauInterpolant { t0, h, y0, y1, k }
    }

    /// Evaluates the cubic Hermite interpolant at time `t` inside the step.
    ///
    /// The endpoint values are `y0` and `y1`. The endpoint slopes are `k[0]` and
    /// the last entry of `k`, both scaled by `h`. With a single stage vector, it
    /// serves as both slopes.
    ///
    /// # Errors
    ///
    /// * `ValueError` if `theta = (t - t0) / h` is NaN or outside `[0, 1]`, or if
    ///   `k` is empty.
    /// * `DimensionMismatch` if `y1` or the slope vectors differ in length from `y0`.
    pub fn evaluate(&self, t: F) -> IntegrateResult<Array1<F>> {
        let theta = (t - self.t0) / self.h;
        // NaN compares false against both bounds, so it must be rejected explicitly.
        if theta.is_nan() || theta < F::zero() || theta > F::one() {
            return Err(IntegrateError::ValueError(
                "Evaluation point is outside of the step".to_string(),
            ));
        }
        let (dy0, dy1) = match (self.k.first(), self.k.last()) {
            (Some(first), Some(last)) => (first, last),
            _ => {
                return Err(IntegrateError::ValueError(
                    "Radau interpolant needs at least one stage vector".to_string(),
                ))
            }
        };
        let n = self.y0.len();
        if self.y1.len() != n || dy0.len() != n || dy1.len() != n {
            return Err(IntegrateError::DimensionMismatch(
                "Radau interpolant vectors must all have the same length".to_string(),
            ));
        }

        let two = F::from_f64(2.0).expect("Operation failed");
        let three = F::from_f64(3.0).expect("Operation failed");
        let theta2 = theta.powi(2);
        let theta3 = theta.powi(3);
        // Cubic Hermite basis functions on [0, 1].
        let h00 = two * theta3 - three * theta2 + F::one();
        let h10 = theta3 - two * theta2 + theta;
        let h01 = three * theta2 - two * theta3;
        let h11 = theta3 - theta2;

        let result: Array1<F> = self
            .y0
            .iter()
            .zip(self.y1.iter())
            .zip(dy0.iter().zip(dy1.iter()))
            .map(|((&a, &b), (&d0, &d1))| {
                h00 * a + h10 * self.h * d0 + h01 * b + h11 * self.h * d1
            })
            .collect();
        Ok(result)
    }
}
/// Builds a [`DenseSolution`] from sampled times `t`, states `y` and the
/// derivative function `f`.
///
/// For the Hermite-based methods (`CubicHermite` and `MethodSpecific`),
/// derivatives are computed once per sample here. Otherwise `evaluate` would
/// have to call `f` again on every query. `method == None` selects
/// `ContinuousOutputMethod::default()`.
///
/// # Errors
///
/// * `ComputationError` if `t` or `y` is empty.
/// * `DimensionMismatch` if `t` and `y` have different lengths.
#[allow(dead_code)]
pub fn create_dense_solution<F, Func>(
    t: Vec<F>,
    y: Vec<Array1<F>>,
    f: Func,
    method: Option<ContinuousOutputMethod>,
) -> IntegrateResult<DenseSolution<F>>
where
    F: IntegrateFloat,
    Func: 'static + Fn(F, ArrayView1<F>) -> Array1<F>,
{
    if t.is_empty() || y.is_empty() {
        return Err(IntegrateError::ComputationError(
            "Empty solution cannot be converted to dense output".to_string(),
        ));
    }
    if t.len() != y.len() {
        return Err(IntegrateError::DimensionMismatch(
            "Time and solution vectors must have the same length".to_string(),
        ));
    }
    let interp_method = method.unwrap_or_default();
    // `evaluate` uses Hermite interpolation for both of these methods.
    let dydt = match interp_method {
        ContinuousOutputMethod::CubicHermite | ContinuousOutputMethod::MethodSpecific => Some(
            t.iter()
                .zip(&y)
                .map(|(&ti, yi)| f(ti, yi.view()))
                .collect::<Vec<_>>(),
        ),
        ContinuousOutputMethod::Linear => None,
    };
    Ok(DenseSolution::new(
        t,
        y,
        dydt,
        Some(interp_method),
        Some(Box::new(f)),
    ))
}