use crate::common::IntegrateFloat;
use crate::error::{IntegrateError, IntegrateResult};
use crate::ode::types::{ODEMethod, ODEOptions, ODEResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use std::fmt::Debug;
/// Tuning parameters for the Gragg–Bulirsch–Stoer extrapolation integrator.
///
/// Defaults (see the `Default` impl): `max_order = 10`, `min_order = 3`,
/// `base_method = ModifiedMidpoint`, `extrapolation_tol = 1e-12`, `safety_factor = 0.9`,
/// `max_increase_factor = 1.5`, `max_decrease_factor = 0.5`.
#[derive(Debug, Clone)]
pub struct ExtrapolationOptions<F: IntegrateFloat> {
/// Maximum number of rows of the extrapolation tableau per macro step; row `i`
/// (0-based) runs the base method with `2 * (i + 1)` substeps.
pub max_order: usize,
/// Minimum number of tableau rows before extrapolation may stop early. After an
/// accepted step, the step size is only re-scaled when that step converged with at
/// least this many rows.
pub min_order: usize,
/// Base integrator whose results are extrapolated.
pub base_method: ExtrapolationBaseMethod,
/// Relative tolerance on the difference between successive extrapolated estimates;
/// once it is met, the extrapolation of the current macro step stops early.
pub extrapolation_tol: F,
/// Multiplier (0.9 by default) applied to the theoretical step-size factor.
pub safety_factor: F,
/// Upper bound on the factor applied to the step size after an accepted, converged step.
pub max_increase_factor: F,
/// Lower bound on the factor applied to the step size after a rejected step
/// (e.g. `0.5` means a rejection never shrinks the step below half its size).
pub max_decrease_factor: F,
}
impl<F: IntegrateFloat> Default for ExtrapolationOptions<F> {
    /// Conservative defaults: up to 10 tableau rows (minimum 3), the modified midpoint
    /// base method, an extrapolation tolerance of `1e-12`, a safety factor of `0.9`,
    /// and step-size changes limited to growth by `1.5` and shrinkage by `0.5`.
    fn default() -> Self {
        // Every literal below is an ordinary finite f64, so the conversion cannot fail
        // for a floating-point `F`.
        let constant = |value: f64| F::from_f64(value).expect("Operation failed");
        Self {
            min_order: 3,
            max_order: 10,
            base_method: ExtrapolationBaseMethod::ModifiedMidpoint,
            safety_factor: constant(0.9),
            extrapolation_tol: constant(1e-12),
            max_decrease_factor: constant(0.5),
            max_increase_factor: constant(1.5),
        }
    }
}
/// Base integrator that is run with increasing substep counts inside each
/// extrapolated macro step.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExtrapolationBaseMethod {
    /// Gragg's modified midpoint rule with a final smoothing step (the classic
    /// Bulirsch–Stoer choice).
    ModifiedMidpoint,
    /// Explicit (forward) Euler method.
    Euler,
    /// Classical fourth-order Runge–Kutta method.
    RungeKutta4,
}
/// Outcome of one extrapolated macro step.
#[derive(Debug, Clone)]
pub struct ExtrapolationResult<F: IntegrateFloat> {
/// Approximate state at the end of the macro step.
pub y: Array1<F>,
/// Estimated local error of `y`; `F::infinity()` when no estimate could be formed.
pub error_estimate: F,
/// Scalar summary of the extrapolation tableau, derived from Euclidean norms of the
/// tableau entries: entry `[i, j]` corresponds to row `i`, column `j`, and entries
/// that were not computed stay zero.
pub table: Array2<F>,
/// Number of substeps used by the finest (last) base-method run (2 if no row was computed).
pub n_substeps: usize,
/// Number of tableau rows that were computed.
pub final_order: usize,
/// Whether the early-termination tolerance test was satisfied.
pub converged: bool,
}
#[allow(dead_code)]
pub fn gragg_bulirsch_stoer_method<F, Func>(
f: Func,
t_span: [F; 2],
y0: Array1<F>,
opts: ODEOptions<F>,
ext_opts: Option<ExtrapolationOptions<F>>,
) -> IntegrateResult<ODEResult<F>>
where
F: IntegrateFloat,
Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
let [t_start, t_end] = t_span;
let ext_options = ext_opts.unwrap_or_default();
let mut h = opts.h0.unwrap_or_else(|| {
let _span = t_end - t_start;
_span / F::from_usize(100).expect("Operation failed")
});
let min_step = opts.min_step.unwrap_or_else(|| {
let _span = t_end - t_start;
_span / F::from_usize(1_000_000).expect("Operation failed")
});
let max_step = opts.max_step.unwrap_or_else(|| {
let _span = t_end - t_start;
_span / F::from_usize(10).expect("Operation failed")
});
let mut t_values = vec![t_start];
let mut y_values = vec![y0.clone()];
let mut t = t_start;
let mut y = y0;
let mut steps = 0;
let mut func_evals = 0;
let mut rejected_steps = 0;
while t < t_end {
if t + h > t_end {
h = t_end - t;
}
let result = extrapolation_step(&f, t, &y, h, &ext_options)?;
func_evals += result.n_substeps * (result.n_substeps + 1);
let error_estimate = result.error_estimate;
let tolerance =
opts.atol + opts.rtol * y.iter().map(|&x| x.abs()).fold(F::zero(), |a, b| a.max(b));
if error_estimate <= tolerance {
t += h;
y = result.y;
steps += 1;
t_values.push(t);
y_values.push(y.clone());
if result.converged && result.final_order >= ext_options.min_order {
h *= ext_options.max_increase_factor.min(
(tolerance / error_estimate.max(F::from_f64(1e-14).expect("Operation failed")))
.powf(
F::one()
/ F::from_usize(result.final_order + 1).expect("Operation failed"),
)
* ext_options.safety_factor,
);
}
} else {
rejected_steps += 1;
h *= ext_options.max_decrease_factor.max(
(tolerance / error_estimate).powf(
F::one() / F::from_usize(result.final_order + 1).expect("Operation failed"),
) * ext_options.safety_factor,
);
}
if h < min_step {
return Err(IntegrateError::StepSizeTooSmall(
"Step size became too small in extrapolation method".to_string(),
));
}
h = h.min(max_step);
if steps > 100000 {
return Err(IntegrateError::ComputationError(
"Maximum number of steps exceeded in extrapolation method".to_string(),
));
}
}
Ok(ODEResult {
t: t_values,
y: y_values,
success: true,
message: Some("Integration completed successfully".to_string()),
n_eval: func_evals,
n_steps: steps,
n_accepted: steps,
n_rejected: rejected_steps,
n_lu: 0,
n_jac: 0,
method: ODEMethod::RK45, })
}
#[allow(dead_code)]
fn extrapolation_step<F, Func>(
f: &Func,
t: F,
y: &Array1<F>,
h: F,
options: &ExtrapolationOptions<F>,
) -> IntegrateResult<ExtrapolationResult<F>>
where
F: IntegrateFloat,
Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
let _n_dim = y.len();
let max_order = options.max_order;
let step_sequence: Vec<usize> = (1..=max_order).map(|i| 2 * i).collect();
let mut table = Array2::zeros((max_order, max_order));
let mut y_table = Vec::new();
let mut converged = false;
let mut final_order = 0;
let mut error_estimate = F::infinity();
for (i, &n_steps) in step_sequence.iter().enumerate() {
if i >= max_order {
break;
}
let h_sub = h / F::from_usize(n_steps).expect("Operation failed");
let y_approx = match options.base_method {
ExtrapolationBaseMethod::ModifiedMidpoint => {
modified_midpoint_sequence(f, t, y, h_sub, n_steps)?
}
ExtrapolationBaseMethod::Euler => euler_sequence(f, t, y, h_sub, n_steps)?,
ExtrapolationBaseMethod::RungeKutta4 => rk4_sequence(f, t, y, h_sub, n_steps)?,
};
y_table.push(y_approx.clone());
let norm = y_approx
.iter()
.map(|&x| x * x)
.fold(F::zero(), |a, b| a + b)
.sqrt();
table[[i, 0]] = norm;
for j in 1..=i {
if j >= max_order {
break;
}
let ratio = F::from_usize(step_sequence[i]).expect("Operation failed")
/ F::from_usize(step_sequence[i - 1]).expect("Operation failed");
let denominator =
ratio.powf(F::from_usize(2 * j).expect("Operation failed")) - F::one();
if denominator.abs() > F::from_f64(1e-14).expect("Operation failed") {
table[[i, j]] =
table[[i, j - 1]] + (table[[i, j - 1]] - table[[i - 1, j - 1]]) / denominator;
} else {
table[[i, j]] = table[[i, j - 1]];
}
}
if i >= options.min_order - 1 {
let current_order = i;
if current_order > 0 {
let current_est = table[[current_order, current_order]];
let prev_est = table[[current_order - 1, current_order - 1]];
error_estimate = (current_est - prev_est).abs();
if error_estimate <= options.extrapolation_tol * current_est.abs() {
converged = true;
final_order = current_order + 1;
break;
}
}
}
final_order = i + 1;
}
let final_y = if final_order > 0 && !y_table.is_empty() {
y_table[final_order - 1].clone()
} else {
y.clone()
};
Ok(ExtrapolationResult {
y: final_y,
error_estimate,
table,
n_substeps: step_sequence
.get(final_order.saturating_sub(1))
.copied()
.unwrap_or(2),
final_order,
converged,
})
}
/// Integrate from `t0` over `n_steps` substeps of size `h_sub` with Gragg's modified
/// midpoint rule.
///
/// With `z_0 = y0`:
/// - `z_1 = z_0 + h_sub * f(t0, z_0)` (explicit Euler start-up);
/// - `z_{m+1} = z_{m-1} + 2 h_sub * f(t0 + m h_sub, z_m)` for `m = 1 .. n_steps - 1`
///   (leapfrog);
/// - for `n_steps > 1`, the result is the smoothed value
///   `(z_n + z_{n-1} + h_sub * f(t0 + n h_sub, z_n)) / 2`; for `n_steps == 1` it is `z_1`.
///
/// `n_steps == 0` returns a copy of `y0`. For `n_steps > 1`, `f` is evaluated
/// `n_steps + 1` times. The function never fails; the `Result` keeps the signature
/// uniform with the other base integrators.
#[allow(dead_code)]
fn modified_midpoint_sequence<F, Func>(
    f: &Func,
    t0: F,
    y0: &Array1<F>,
    h_sub: F,
    n_steps: usize,
) -> IntegrateResult<Array1<F>>
where
    F: IntegrateFloat,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    if n_steps == 0 {
        return Ok(y0.clone());
    }
    let two_h = F::from_f64(2.0).expect("Operation failed") * h_sub;

    // Start-up: z_0 = y0, z_1 = z_0 + h f(t0, z_0).
    let mut z_prev = y0.clone();
    let mut z = y0 + &f(t0, y0.view()) * h_sub;
    let mut t = t0 + h_sub;

    // Leapfrog: z_{m+1} = z_{m-1} + 2 h f(t_m, z_m).
    for _ in 1..n_steps {
        let slope = f(t, z.view());
        let z_next = &z_prev + &slope * two_h;
        // Shift the window (z_{m-1}, z_m) forward by moving, not cloning, the vectors.
        z_prev = std::mem::replace(&mut z, z_next);
        t += h_sub;
    }

    // Gragg's smoothing step.
    if n_steps > 1 {
        let slope = f(t, z.view());
        z = (&z + &z_prev + &slope * h_sub) * F::from_f64(0.5).expect("Operation failed");
    }
    Ok(z)
}
/// Apply `n_steps` explicit Euler steps of size `h_sub`, starting from `(t0, y0)`.
///
/// Each step evaluates `f` once at the current point and moves along that slope:
/// `y <- y + h_sub * f(t, y)`, then `t <- t + h_sub`. With `n_steps == 0` a copy of `y0`
/// is returned. The function never fails; the `Result` keeps the signature uniform with
/// the other base integrators.
#[allow(dead_code)]
fn euler_sequence<F, Func>(
    f: &Func,
    t0: F,
    y0: &Array1<F>,
    h_sub: F,
    n_steps: usize,
) -> IntegrateResult<Array1<F>>
where
    F: IntegrateFloat,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    let (_, state) = (0..n_steps).fold((t0, y0.clone()), |(time, state), _| {
        let slope = f(time, state.view());
        let advanced = &state + &slope * h_sub;
        (time + h_sub, advanced)
    });
    Ok(state)
}
/// Apply `n_steps` classical fourth-order Runge–Kutta steps of size `h_sub`, starting
/// from `(t0, y0)`.
///
/// Each step evaluates the four stages `k1..k4` in order and combines them as
/// `y + (k1 + 2 k2 + 2 k3 + k4) * h_sub / 6`. With `n_steps == 0` a copy of `y0` is
/// returned. The function never fails; the `Result` keeps the signature uniform with
/// the other base integrators.
#[allow(dead_code)]
fn rk4_sequence<F, Func>(
    f: &Func,
    t0: F,
    y0: &Array1<F>,
    h_sub: F,
    n_steps: usize,
) -> IntegrateResult<Array1<F>>
where
    F: IntegrateFloat,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    let two = F::from_f64(2.0).expect("Operation failed");
    let h_half = h_sub * F::from_f64(0.5).expect("Operation failed");
    let h_sixth = h_sub / F::from_f64(6.0).expect("Operation failed");

    let mut state = y0.clone();
    let mut time = t0;
    for _ in 0..n_steps {
        let t_mid = time + h_half;

        let k1 = f(time, state.view());
        let probe2 = &state + &k1 * h_half;
        let k2 = f(t_mid, probe2.view());
        let probe3 = &state + &k2 * h_half;
        let k3 = f(t_mid, probe3.view());
        let probe4 = &state + &k3 * h_sub;
        let k4 = f(time + h_sub, probe4.view());

        let weighted = &k1 + &k2 * two + &k3 * two + &k4;
        state = &state + weighted * h_sixth;
        time += h_sub;
    }
    Ok(state)
}
#[allow(dead_code)]
pub fn richardson_extrapolation_step<F, Func, Method>(
method: Method,
f: &Func,
t: F,
y: &Array1<F>,
h: F,
) -> IntegrateResult<(Array1<F>, F)>
where
F: IntegrateFloat,
Func: Fn(F, ArrayView1<F>) -> Array1<F> + ?Sized,
Method: Fn(&Func, F, &Array1<F>, F) -> IntegrateResult<Array1<F>>,
{
let y1 = method(f, t, y, h)?;
let h_half = h * F::from_f64(0.5).expect("Operation failed");
let y_mid = method(f, t, y, h_half)?;
let y2 = method(f, t + h_half, &y_mid, h_half)?;
let y_extrapolated = (&y2 * F::from_f64(4.0).expect("Operation failed") - &y1)
/ F::from_f64(3.0).expect("Operation failed");
let error_estimate = (&y2 - &y1)
.iter()
.map(|&x| x.abs())
.fold(F::zero(), |a, b| a.max(b))
/ F::from_f64(3.0).expect("Operation failed");
Ok((y_extrapolated, error_estimate))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Right-hand side of `y' = -y`; with `y(0) = 1` the exact solution is `exp(-t)`.
    fn decay(_t: f64, y: ArrayView1<f64>) -> Array1<f64> {
        -y.to_owned()
    }

    /// One explicit midpoint step (a second-order method), used as a Richardson base
    /// method.
    fn explicit_midpoint<G>(f: &G, t: f64, y: &Array1<f64>, h: f64) -> IntegrateResult<Array1<f64>>
    where
        G: Fn(f64, ArrayView1<f64>) -> Array1<f64> + ?Sized,
    {
        let k1 = f(t, y.view());
        let y_half = y + &(&k1 * (0.5 * h));
        let k2 = f(t + 0.5 * h, y_half.view());
        Ok(y + &(&k2 * h))
    }

    #[test]
    fn test_modified_midpoint_sequence() {
        let y0 = Array1::from_vec(vec![1.0]);
        let h = 0.1;
        let n_steps = 10;
        let result = modified_midpoint_sequence(&decay, 0.0, &y0, h / n_steps as f64, n_steps)
            .expect("Operation failed");
        assert_relative_eq!(result[0], (-h).exp(), epsilon = 1e-3);
    }

    #[test]
    fn test_gragg_bulirsch_stoer_exponential_decay() {
        let y0 = Array1::from_vec(vec![1.0]);
        let h = 0.1;
        let result = gragg_bulirsch_stoer_method(decay, [0.0, h], y0, ODEOptions::default(), None)
            .expect("Operation failed");
        let final_value = result.y.last().expect("Operation failed")[0];
        assert!(result.success);
        assert_relative_eq!(final_value, (-h).exp(), epsilon = 1e-6);
    }

    #[test]
    fn test_richardson_extrapolation_step() {
        let y0 = Array1::from_vec(vec![1.0]);
        let h = 0.1;
        let (y_extrapolated, error_estimate) =
            richardson_extrapolation_step(explicit_midpoint, &decay, 0.0, &y0, h)
                .expect("Operation failed");
        // Extrapolating a second-order method cancels its leading error term:
        // (4 * 0.9048765625 - 0.905) / 3 ≈ 0.9048354 vs exp(-0.1) ≈ 0.9048374.
        assert_relative_eq!(y_extrapolated[0], (-h).exp(), epsilon = 1e-5);
        assert!(error_estimate > 0.0 && error_estimate < 1e-4);
    }

    #[test]
    fn test_extrapolation_options_default() {
        let opts: ExtrapolationOptions<f64> = Default::default();
        assert_eq!(opts.max_order, 10);
        assert_eq!(opts.min_order, 3);
        assert_eq!(opts.safety_factor, 0.9);
    }
}