use crate::fuga::ConvToMat;
use crate::traits::math::{InnerProduct, Norm, Normed, Vector};
use crate::util::non_macro::eye;
use anyhow::{bail, Result};
/// A first-order system of ordinary differential equations `dy/dt = f(t, y)`.
///
/// Integrators in this file only interact with a problem through [`ODEProblem::rhs`].
pub trait ODEProblem {
/// Evaluates the right-hand side at time `t` and state `y`, storing the
/// derivative in `dy`.
///
/// Integrators in this file pass a `dy` buffer with the same length as `y`
/// and propagate any returned error to their caller.
fn rhs(&self, t: f64, y: &[f64], dy: &mut [f64]) -> Result<()>;
}
/// A single-step time integration scheme.
pub trait ODEIntegrator {
/// Advances the state `y` in place by one step that starts at time `t`
/// with requested step size `dt`.
///
/// The returned value is the step size to request for the following step;
/// `BasicODESolver::solve` feeds it straight back in as the next `dt`.
///
/// # Errors
///
/// Implementations in this file forward errors from [`ODEProblem::rhs`].
fn step<P: ODEProblem>(&self, problem: &P, t: f64, y: &mut [f64], dt: f64) -> Result<f64>;
}
/// Errors that ODE problems and integrators can report.
#[derive(Debug, Clone)]
pub enum ODEError {
    /// A constraint was violated; the payload is rendered by `Display` as
    /// `t`, `y` and `dy`, in that order.
    ConstraintViolation(f64, Vec<f64>, Vec<f64>),
    /// An adaptive step was retried the maximum allowed number of times
    /// without being accepted.
    ReachedMaxStepIter,
}

impl std::fmt::Display for ODEError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ReachedMaxStepIter => f.write_str("Reached maximum number of steps per step"),
            Self::ConstraintViolation(time, state, derivative) => write!(
                f,
                "Constraint violation at t = {}, y = {:?}, dy = {:?}",
                time, state, derivative
            ),
        }
    }
}
/// Drives an integration over a time span and collects the trajectory.
pub trait ODESolver {
/// Integrates `problem` starting from `initial_conditions` at `t_span.0`,
/// using `dt` as the first step size.
///
/// The implementation in this file (`BasicODESolver`) returns the visited
/// times together with the state at each of those times, the first entry
/// being the initial time and state.
fn solve<P: ODEProblem>(
&self,
problem: &P,
t_span: (f64, f64),
dt: f64,
initial_conditions: &[f64],
) -> Result<(Vec<f64>, Vec<Vec<f64>>)>;
}
/// Solver that repeatedly applies a single [`ODEIntegrator`] until the end of
/// the requested time span is reached.
pub struct BasicODESolver<I: ODEIntegrator> {
    /// Scheme used for every step.
    integrator: I,
}

impl<I: ODEIntegrator> BasicODESolver<I> {
    /// Wraps `integrator` in a solver.
    pub fn new(integrator: I) -> Self {
        BasicODESolver { integrator }
    }
}
impl<I: ODEIntegrator> ODESolver for BasicODESolver<I> {
fn solve<P: ODEProblem>(
&self,
problem: &P,
t_span: (f64, f64),
dt: f64,
initial_conditions: &[f64],
) -> Result<(Vec<f64>, Vec<Vec<f64>>)> {
let mut t = t_span.0;
let mut dt = dt;
let mut y = initial_conditions.to_vec();
let mut t_vec = vec![t];
let mut y_vec = vec![y.clone()];
while t < t_span.1 {
let dt_step = self.integrator.step(problem, t, &mut y, dt)?;
t += dt;
t_vec.push(t);
y_vec.push(y.clone());
dt = dt_step;
}
Ok((t_vec, y_vec))
}
}
/// Coefficients and step-size controls of an explicit Runge–Kutta scheme.
///
/// Every implementor automatically becomes an `ODEIntegrator` through the
/// blanket impl in this file. When `BE` is empty that impl takes fixed steps
/// and never calls the control accessors below; when `BE` is non-empty it
/// uses them for adaptive step-size selection.
pub trait ButcherTableau {
/// Stage nodes: stage `i` is evaluated at `t + dt * C[i]`. The length of
/// `C` is taken as the number of stages.
const C: &'static [f64];
/// Stage coupling coefficients: row `i` holds the weights applied to the
/// slopes of stages `0..i` when building the state for stage `i`.
const A: &'static [&'static [f64]];
/// Weights used to combine the stage slopes into the propagated solution.
const BU: &'static [f64];
/// Weights of the embedded comparison solution; the error estimate is
/// built from `BU[j] - BE[j]`. Leave empty for a fixed-step scheme.
const BE: &'static [f64];
/// Error level a step must stay strictly below to be accepted.
///
/// The default panics via `unimplemented!()`; adaptive tableaus must override it.
fn tol(&self) -> f64 {
unimplemented!()
}
/// Factor multiplied into every proposed new step size.
///
/// The default panics via `unimplemented!()`; adaptive tableaus must override it.
fn safety_factor(&self) -> f64 {
unimplemented!()
}
/// Upper clamp for proposed step sizes.
///
/// The default panics via `unimplemented!()`; adaptive tableaus must override it.
fn max_step_size(&self) -> f64 {
unimplemented!()
}
/// Lower clamp for proposed step sizes.
///
/// The default panics via `unimplemented!()`; adaptive tableaus must override it.
fn min_step_size(&self) -> f64 {
unimplemented!()
}
/// Number of rejected attempts after which a step gives up with
/// `ODEError::ReachedMaxStepIter`.
///
/// The default panics via `unimplemented!()`; adaptive tableaus must override it.
fn max_step_iter(&self) -> usize {
unimplemented!()
}
/// Order `p` used in the step-size update exponent `1 / (p + 1)`.
/// Defaults to 4.
fn order(&self) -> usize {
4
}
}
impl<BU: ButcherTableau> ODEIntegrator for BU {
    /// Performs one explicit Runge–Kutta step described by the tableau.
    ///
    /// * With an empty `BE` the step is taken with exactly `dt`, and `dt` is
    ///   returned unchanged.
    /// * Otherwise the largest component of `dt * |Σ (BU[j] - BE[j]) k_j|` is
    ///   used as the error estimate. A new step size is proposed from it,
    ///   scaled by the safety factor and clamped to the configured bounds.
    ///   If the estimate is below `tol()` the state is updated and the
    ///   proposal is returned; if not, the attempt is repeated with the
    ///   proposal as step size.
    ///
    /// # Errors
    ///
    /// Propagates errors from `problem.rhs` and fails with
    /// `ODEError::ReachedMaxStepIter` once `max_step_iter()` attempts have
    /// been rejected.
    fn step<P: ODEProblem>(&self, problem: &P, t: f64, y: &mut [f64], dt: f64) -> Result<f64> {
        let dim = y.len();
        let stage_count = Self::C.len();
        let mut rejections = 0usize;
        let mut h = dt;

        loop {
            // Stage slopes start from zero on every attempt.
            let mut slopes = vec![vec![0.0; dim]; stage_count];
            let mut y_stage = y.to_vec();

            for s in 0..stage_count {
                for (i, slot) in y_stage.iter_mut().enumerate() {
                    let incr = (0..s).fold(0.0_f64, |acc, j| acc + Self::A[s][j] * slopes[j][i]);
                    *slot = y[i] + h * incr;
                }
                problem.rhs(t + h * Self::C[s], &y_stage, &mut slopes[s])?;
            }

            // `Some(next_dt)` when this attempt is accepted, `None` to retry.
            let accepted = if Self::BE.is_empty() {
                Some(h)
            } else {
                let err = (0..dim)
                    .map(|i| {
                        let gap = (0..stage_count).fold(0.0_f64, |acc, j| {
                            acc + (Self::BU[j] - Self::BE[j]) * slopes[j][i]
                        });
                        h * gap.abs()
                    })
                    .fold(0f64, f64::max);

                let scale = (self.tol() / err).powf(1.0 / (self.order() as f64 + 1.0));
                let proposal = (self.safety_factor() * h * scale)
                    .clamp(self.min_step_size(), self.max_step_size());

                if err < self.tol() {
                    Some(proposal)
                } else {
                    rejections += 1;
                    if rejections >= self.max_step_iter() {
                        bail!(ODEError::ReachedMaxStepIter);
                    }
                    h = proposal;
                    None
                }
            };

            if let Some(next_dt) = accepted {
                for (i, yi) in y.iter_mut().enumerate() {
                    let incr =
                        (0..stage_count).fold(0.0_f64, |acc, j| acc + Self::BU[j] * slopes[j][i]);
                    *yi += h * incr;
                }
                return Ok(next_dt);
            }
        }
    }
}
/// Fixed-step, three-stage explicit Runge–Kutta scheme; the coefficients match
/// Ralston's third-order method.
#[derive(Debug, Clone, Copy, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
feature = "rkyv",
derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct RALS3;
impl ButcherTableau for RALS3 {
const C: &'static [f64] = &[0.0, 0.5, 0.75];
const A: &'static [&'static [f64]] = &[&[], &[0.5], &[0.0, 0.75]];
const BU: &'static [f64] = &[2.0 / 9.0, 1.0 / 3.0, 4.0 / 9.0];
// No embedded weights: the shared `step` takes the fixed-step branch.
const BE: &'static [f64] = &[];
}
/// Fixed-step, four-stage explicit Runge–Kutta scheme; the coefficients are
/// those of the classical fourth-order method.
#[derive(Debug, Clone, Copy, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
feature = "rkyv",
derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct RK4;
impl ButcherTableau for RK4 {
const C: &'static [f64] = &[0.0, 0.5, 0.5, 1.0];
const A: &'static [&'static [f64]] = &[&[], &[0.5], &[0.0, 0.5], &[0.0, 0.0, 1.0]];
const BU: &'static [f64] = &[1.0 / 6.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 6.0];
// No embedded weights: the shared `step` takes the fixed-step branch.
const BE: &'static [f64] = &[];
}
/// Fixed-step, four-stage explicit Runge–Kutta scheme using Ralston's
/// fourth-order coefficients (rounded to eight decimals).
#[derive(Debug, Clone, Copy, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct RALS4;

impl ButcherTableau for RALS4 {
    const C: &'static [f64] = &[0.0, 0.4, 0.45573725, 1.0];
    // Each row of `A` must sum to the matching entry of `C`:
    //   0.29697761 + 0.15875964              = 0.45573725
    //   0.21810040 - 3.05096516 + 3.83286476 = 1.0
    // The previous values (0.158575964 and -3.050965616) were digit typos
    // that broke both identities.
    const A: &'static [&'static [f64]] = &[
        &[],
        &[0.4],
        &[0.29697761, 0.15875964],
        &[0.21810040, -3.05096516, 3.83286476],
    ];
    const BU: &'static [f64] = &[0.17476028, -0.55148066, 1.20553560, 0.17118478];
    // No embedded weights: the shared `step` takes the fixed-step branch.
    const BE: &'static [f64] = &[];
}
/// Fixed-step, seven-stage explicit Runge–Kutta scheme that shares its `C`
/// and `A` coefficients with `DP45`.
#[derive(Debug, Clone, Copy, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
feature = "rkyv",
derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct RK5;
impl ButcherTableau for RK5 {
const C: &'static [f64] = &[0.0, 0.2, 0.3, 0.8, 8.0 / 9.0, 1.0, 1.0];
const A: &'static [&'static [f64]] = &[
&[],
&[0.2],
&[0.075, 0.225],
&[44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0],
&[
19372.0 / 6561.0,
-25360.0 / 2187.0,
64448.0 / 6561.0,
-212.0 / 729.0,
],
&[
9017.0 / 3168.0,
-355.0 / 33.0,
46732.0 / 5247.0,
49.0 / 176.0,
-5103.0 / 18656.0,
],
&[
35.0 / 384.0,
0.0,
500.0 / 1113.0,
125.0 / 192.0,
-2187.0 / 6784.0,
11.0 / 84.0,
],
];
// NOTE(review): these weights are identical to `DP45::BE` (the comparison
// solution there), not to `DP45::BU` — confirm this is the intended row of
// the Dormand–Prince tableau for a scheme named RK5.
const BU: &'static [f64] = &[
5179.0 / 57600.0,
0.0,
7571.0 / 16695.0,
393.0 / 640.0,
-92097.0 / 339200.0,
187.0 / 2100.0,
1.0 / 40.0,
];
// No embedded weights: the shared `step` takes the fixed-step branch.
const BE: &'static [f64] = &[];
}
/// Adaptive-step settings for the `BS23` tableau.
///
/// The fields are handed to the shared Runge–Kutta `step` through the
/// `ButcherTableau` accessors of the same names.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct BS23 {
    /// Error level a step must stay below to be accepted.
    pub tol: f64,
    /// Factor applied to every proposed step size.
    pub safety_factor: f64,
    /// Smallest step size that may be proposed.
    pub min_step_size: f64,
    /// Largest step size that may be proposed.
    pub max_step_size: f64,
    /// Maximum number of rejected attempts per step.
    pub max_step_iter: usize,
}

impl BS23 {
    /// Builds the settings from explicit values, stored unchanged.
    pub fn new(
        tol: f64,
        safety_factor: f64,
        min_step_size: f64,
        max_step_size: f64,
        max_step_iter: usize,
    ) -> Self {
        BS23 {
            tol,
            safety_factor,
            min_step_size,
            max_step_size,
            max_step_iter,
        }
    }
}

impl Default for BS23 {
    /// `tol = 1e-3`, `safety_factor = 0.9`, step sizes within
    /// `[1e-6, 1e-1]`, at most 100 rejected attempts per step.
    fn default() -> Self {
        Self::new(1e-3, 0.9, 1e-6, 1e-1, 100)
    }
}
/// Four-stage embedded tableau; the coefficients match the Bogacki–Shampine
/// 3(2) pair. The control accessors simply return the struct fields.
impl ButcherTableau for BS23 {
const C: &'static [f64] = &[0.0, 0.5, 0.75, 1.0];
const A: &'static [&'static [f64]] = &[
&[],
&[0.5],
&[0.0, 0.75],
// The last stage row repeats the first three `BU` weights.
&[2.0 / 9.0, 1.0 / 3.0, 4.0 / 9.0],
];
// Weights of the solution that is propagated.
const BU: &'static [f64] = &[2.0 / 9.0, 1.0 / 3.0, 4.0 / 9.0, 0.0];
// Weights of the comparison solution used for the error estimate.
const BE: &'static [f64] = &[7.0 / 24.0, 0.25, 1.0 / 3.0, 0.125];
fn tol(&self) -> f64 {
self.tol
}
fn safety_factor(&self) -> f64 {
self.safety_factor
}
fn min_step_size(&self) -> f64 {
self.min_step_size
}
fn max_step_size(&self) -> f64 {
self.max_step_size
}
fn max_step_iter(&self) -> usize {
self.max_step_iter
}
// Overrides the default of 4, so the step-size exponent becomes 1/3.
fn order(&self) -> usize {
2
}
}
/// Adaptive-step settings for the `RKF45` tableau.
///
/// The fields are handed to the shared Runge–Kutta `step` through the
/// `ButcherTableau` accessors of the same names.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct RKF45 {
    /// Error level a step must stay below to be accepted.
    pub tol: f64,
    /// Factor applied to every proposed step size.
    pub safety_factor: f64,
    /// Smallest step size that may be proposed.
    pub min_step_size: f64,
    /// Largest step size that may be proposed.
    pub max_step_size: f64,
    /// Maximum number of rejected attempts per step.
    pub max_step_iter: usize,
}

impl RKF45 {
    /// Builds the settings from explicit values, stored unchanged.
    pub fn new(
        tol: f64,
        safety_factor: f64,
        min_step_size: f64,
        max_step_size: f64,
        max_step_iter: usize,
    ) -> Self {
        RKF45 {
            tol,
            safety_factor,
            min_step_size,
            max_step_size,
            max_step_iter,
        }
    }
}

impl Default for RKF45 {
    /// `tol = 1e-6`, `safety_factor = 0.9`, step sizes within
    /// `[1e-6, 1e-1]`, at most 100 rejected attempts per step.
    fn default() -> Self {
        Self::new(1e-6, 0.9, 1e-6, 1e-1, 100)
    }
}
/// Six-stage embedded tableau; the coefficients match the
/// Runge–Kutta–Fehlberg 4(5) pair. `order()` is not overridden, so the
/// trait default of 4 drives the step-size exponent.
impl ButcherTableau for RKF45 {
const C: &'static [f64] = &[0.0, 1.0 / 4.0, 3.0 / 8.0, 12.0 / 13.0, 1.0, 1.0 / 2.0];
const A: &'static [&'static [f64]] = &[
&[],
&[0.25],
&[3.0 / 32.0, 9.0 / 32.0],
&[1932.0 / 2197.0, -7200.0 / 2197.0, 7296.0 / 2197.0],
&[439.0 / 216.0, -8.0, 3680.0 / 513.0, -845.0 / 4104.0],
&[
-8.0 / 27.0,
2.0,
-3544.0 / 2565.0,
1859.0 / 4104.0,
-11.0 / 40.0,
],
];
// Weights of the solution that is propagated.
const BU: &'static [f64] = &[
16.0 / 135.0,
0.0,
6656.0 / 12825.0,
28561.0 / 56430.0,
-9.0 / 50.0,
2.0 / 55.0,
];
// Weights of the comparison solution used for the error estimate.
const BE: &'static [f64] = &[
25.0 / 216.0,
0.0,
1408.0 / 2565.0,
2197.0 / 4104.0,
-1.0 / 5.0,
0.0,
];
fn tol(&self) -> f64 {
self.tol
}
fn safety_factor(&self) -> f64 {
self.safety_factor
}
fn min_step_size(&self) -> f64 {
self.min_step_size
}
fn max_step_size(&self) -> f64 {
self.max_step_size
}
fn max_step_iter(&self) -> usize {
self.max_step_iter
}
}
/// Adaptive-step settings for the `DP45` tableau.
///
/// The fields are handed to the shared Runge–Kutta `step` through the
/// `ButcherTableau` accessors of the same names.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct DP45 {
    /// Error level a step must stay below to be accepted.
    pub tol: f64,
    /// Factor applied to every proposed step size.
    pub safety_factor: f64,
    /// Smallest step size that may be proposed.
    pub min_step_size: f64,
    /// Largest step size that may be proposed.
    pub max_step_size: f64,
    /// Maximum number of rejected attempts per step.
    pub max_step_iter: usize,
}

impl DP45 {
    /// Builds the settings from explicit values, stored unchanged.
    pub fn new(
        tol: f64,
        safety_factor: f64,
        min_step_size: f64,
        max_step_size: f64,
        max_step_iter: usize,
    ) -> Self {
        DP45 {
            tol,
            safety_factor,
            min_step_size,
            max_step_size,
            max_step_iter,
        }
    }
}

impl Default for DP45 {
    /// `tol = 1e-6`, `safety_factor = 0.9`, step sizes within
    /// `[1e-6, 1e-1]`, at most 100 rejected attempts per step.
    fn default() -> Self {
        Self::new(1e-6, 0.9, 1e-6, 1e-1, 100)
    }
}
/// Seven-stage embedded tableau; the coefficients match the Dormand–Prince
/// 5(4) pair. `order()` is not overridden, so the trait default of 4 drives
/// the step-size exponent.
impl ButcherTableau for DP45 {
const C: &'static [f64] = &[0.0, 0.2, 0.3, 0.8, 8.0 / 9.0, 1.0, 1.0];
const A: &'static [&'static [f64]] = &[
&[],
&[0.2],
&[0.075, 0.225],
&[44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0],
&[
19372.0 / 6561.0,
-25360.0 / 2187.0,
64448.0 / 6561.0,
-212.0 / 729.0,
],
&[
9017.0 / 3168.0,
-355.0 / 33.0,
46732.0 / 5247.0,
49.0 / 176.0,
-5103.0 / 18656.0,
],
// The last stage row repeats the first six `BU` weights.
&[
35.0 / 384.0,
0.0,
500.0 / 1113.0,
125.0 / 192.0,
-2187.0 / 6784.0,
11.0 / 84.0,
],
];
// Weights of the solution that is propagated.
const BU: &'static [f64] = &[
35.0 / 384.0,
0.0,
500.0 / 1113.0,
125.0 / 192.0,
-2187.0 / 6784.0,
11.0 / 84.0,
0.0,
];
// Weights of the comparison solution used for the error estimate.
const BE: &'static [f64] = &[
5179.0 / 57600.0,
0.0,
7571.0 / 16695.0,
393.0 / 640.0,
-92097.0 / 339200.0,
187.0 / 2100.0,
1.0 / 40.0,
];
fn tol(&self) -> f64 {
self.tol
}
fn safety_factor(&self) -> f64 {
self.safety_factor
}
fn min_step_size(&self) -> f64 {
self.min_step_size
}
fn max_step_size(&self) -> f64 {
self.max_step_size
}
fn max_step_iter(&self) -> usize {
self.max_step_iter
}
}
/// Adaptive-step settings for the `TSIT45` tableau.
///
/// The fields are handed to the shared Runge–Kutta `step` through the
/// `ButcherTableau` accessors of the same names.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct TSIT45 {
    /// Error level a step must stay below to be accepted.
    pub tol: f64,
    /// Factor applied to every proposed step size.
    pub safety_factor: f64,
    /// Smallest step size that may be proposed.
    pub min_step_size: f64,
    /// Largest step size that may be proposed.
    pub max_step_size: f64,
    /// Maximum number of rejected attempts per step.
    pub max_step_iter: usize,
}

impl TSIT45 {
    /// Builds the settings from explicit values, stored unchanged.
    pub fn new(
        tol: f64,
        safety_factor: f64,
        min_step_size: f64,
        max_step_size: f64,
        max_step_iter: usize,
    ) -> Self {
        TSIT45 {
            tol,
            safety_factor,
            min_step_size,
            max_step_size,
            max_step_iter,
        }
    }
}

impl Default for TSIT45 {
    /// `tol = 1e-6`, `safety_factor = 0.9`, step sizes within
    /// `[1e-6, 1e-1]`, at most 100 rejected attempts per step.
    fn default() -> Self {
        Self::new(1e-6, 0.9, 1e-6, 1e-1, 100)
    }
}
/// Seven-stage embedded tableau (presumably the Tsitouras 5(4) pair, judging
/// by the name and node values — verify against the reference). `order()` is
/// not overridden, so the trait default of 4 drives the step-size exponent.
impl ButcherTableau for TSIT45 {
const C: &'static [f64] = &[0.0, 0.161, 0.327, 0.9, 0.9800255409045097, 1.0, 1.0];
// In rows 1..=5 the first entry is computed as `C[i]` minus the remaining
// entries of the row, which forces each row to sum to its node.
const A: &'static [&'static [f64]] = &[
&[],
&[Self::C[1]],
&[Self::C[2] - 0.335480655492357, 0.335480655492357],
&[
Self::C[3] - (-6.359448489975075 + 4.362295432869581),
-6.359448489975075,
4.362295432869581,
],
&[
Self::C[4] - (-11.74888356406283 + 7.495539342889836 - 0.09249506636175525),
-11.74888356406283,
7.495539342889836,
-0.09249506636175525,
],
&[
Self::C[5]
- (-12.92096931784711 + 8.159367898576159
- 0.0715849732814010
- 0.02826905039406838),
-12.92096931784711,
8.159367898576159,
-0.0715849732814010,
-0.02826905039406838,
],
// The last stage row repeats the first six `BU` weights.
&[
Self::BU[0],
Self::BU[1],
Self::BU[2],
Self::BU[3],
Self::BU[4],
Self::BU[5],
],
];
// Weights of the solution that is propagated.
const BU: &'static [f64] = &[
0.09646076681806523,
0.01,
0.4798896504144996,
1.379008574103742,
-3.290069515436081,
2.324710524099774,
0.0,
];
// NOTE(review): these seven values sum to about 2/66 (≈ 0.0303), not 1, so
// they do not look like the weights of an embedded solution. The shared
// `step` forms its error estimate from `BU[j] - BE[j]`; if this array is
// meant to hold error coefficients directly, that estimate is off — confirm
// against the published tableau.
const BE: &'static [f64] = &[
0.001780011052226,
0.000816434459657,
-0.007880878010262,
0.144711007173263,
-0.582357165452555,
0.458082105929187,
1.0 / 66.0,
];
fn tol(&self) -> f64 {
self.tol
}
fn safety_factor(&self) -> f64 {
self.safety_factor
}
fn min_step_size(&self) -> f64 {
self.min_step_size
}
fn max_step_size(&self) -> f64 {
self.max_step_size
}
fn max_step_iter(&self) -> usize {
self.max_step_iter
}
}
/// Adaptive-step settings for the `RKF78` tableau.
///
/// The fields are handed to the shared Runge–Kutta `step` through the
/// `ButcherTableau` accessors of the same names.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct RKF78 {
    /// Error level a step must stay below to be accepted.
    pub tol: f64,
    /// Factor applied to every proposed step size.
    pub safety_factor: f64,
    /// Smallest step size that may be proposed.
    pub min_step_size: f64,
    /// Largest step size that may be proposed.
    pub max_step_size: f64,
    /// Maximum number of rejected attempts per step.
    pub max_step_iter: usize,
}

impl RKF78 {
    /// Builds the settings from explicit values, stored unchanged.
    pub fn new(
        tol: f64,
        safety_factor: f64,
        min_step_size: f64,
        max_step_size: f64,
        max_step_iter: usize,
    ) -> Self {
        RKF78 {
            tol,
            safety_factor,
            min_step_size,
            max_step_size,
            max_step_iter,
        }
    }
}

impl Default for RKF78 {
    /// `tol = 1e-7`, `safety_factor = 0.9`, step sizes within
    /// `[1e-10, 1e-1]`, at most 100 rejected attempts per step.
    fn default() -> Self {
        Self::new(1e-7, 0.9, 1e-10, 1e-1, 100)
    }
}
/// Thirteen-stage embedded tableau; the coefficients match the
/// Runge–Kutta–Fehlberg 7(8) pair. The control accessors simply return the
/// struct fields.
impl ButcherTableau for RKF78 {
const C: &'static [f64] = &[
0.0,
2.0 / 27.0,
1.0 / 9.0,
1.0 / 6.0,
5.0 / 12.0,
1.0 / 2.0,
5.0 / 6.0,
1.0 / 6.0,
2.0 / 3.0,
1.0 / 3.0,
1.0,
// Nodes of stages 12 and 13; only these two stages carry a different
// weight in `BU` than in `BE` besides stages 1 and 11.
0.0, 1.0, ];
const A: &'static [&'static [f64]] = &[
&[],
&[2.0 / 27.0],
&[1.0 / 36.0, 3.0 / 36.0],
&[1.0 / 24.0, 0.0, 3.0 / 24.0],
&[20.0 / 48.0, 0.0, -75.0 / 48.0, 75.0 / 48.0],
&[1.0 / 20.0, 0.0, 0.0, 5.0 / 20.0, 4.0 / 20.0],
&[
-25.0 / 108.0,
0.0,
0.0,
125.0 / 108.0,
-260.0 / 108.0,
250.0 / 108.0,
],
&[
31.0 / 300.0,
0.0,
0.0,
0.0,
61.0 / 225.0,
-2.0 / 9.0,
13.0 / 900.0,
],
&[
2.0,
0.0,
0.0,
-53.0 / 6.0,
704.0 / 45.0,
-107.0 / 9.0,
67.0 / 90.0,
3.0,
],
&[
-91.0 / 108.0,
0.0,
0.0,
23.0 / 108.0,
-976.0 / 135.0,
311.0 / 54.0,
-19.0 / 60.0,
17.0 / 6.0,
-1.0 / 12.0,
],
&[
2383.0 / 4100.0,
0.0,
0.0,
-341.0 / 164.0,
4496.0 / 1025.0,
-301.0 / 82.0,
2133.0 / 4100.0,
45.0 / 82.0,
45.0 / 164.0,
18.0 / 41.0,
],
&[
3.0 / 205.0,
0.0,
0.0,
0.0,
0.0,
-6.0 / 41.0,
-3.0 / 205.0,
-3.0 / 41.0,
3.0 / 41.0,
6.0 / 41.0,
0.0,
],
&[
-1777.0 / 4100.0,
0.0,
0.0,
-341.0 / 164.0,
4496.0 / 1025.0,
-289.0 / 82.0,
2193.0 / 4100.0,
51.0 / 82.0,
33.0 / 164.0,
12.0 / 41.0,
0.0,
1.0,
],
];
// Weights of the solution that is propagated (uses stages 12 and 13).
const BU: &'static [f64] = &[
0.0,
0.0,
0.0,
0.0,
0.0,
34.0 / 105.0,
9.0 / 35.0,
9.0 / 35.0,
9.0 / 280.0,
9.0 / 280.0,
0.0,
41.0 / 840.0,
41.0 / 840.0,
];
// Weights of the comparison solution (uses stages 1 and 11 instead).
const BE: &'static [f64] = &[
41.0 / 840.0,
0.0,
0.0,
0.0,
0.0,
34.0 / 105.0,
9.0 / 35.0,
9.0 / 35.0,
9.0 / 280.0,
9.0 / 280.0,
41.0 / 840.0,
0.0,
0.0,
];
fn tol(&self) -> f64 {
self.tol
}
fn safety_factor(&self) -> f64 {
self.safety_factor
}
fn min_step_size(&self) -> f64 {
self.min_step_size
}
fn max_step_size(&self) -> f64 {
self.max_step_size
}
fn max_step_iter(&self) -> usize {
self.max_step_iter
}
// Overrides the default of 4, so the step-size exponent becomes 1/8.
fn order(&self) -> usize {
7
}
}
// Coefficients of the two-stage scheme used by `GL4::step` and `compute_F`.
// Square root of three, spelled out because it is needed in const context.
const SQRT3: f64 = 1.7320508075688772;
// Stage nodes: the stages are evaluated at `t + C1 * dt` and `t + C2 * dt`.
const C1: f64 = 0.5 - SQRT3 / 6.0;
const C2: f64 = 0.5 + SQRT3 / 6.0;
// Coupling matrix: stage 1 uses (A11, A12), stage 2 uses (A21, A22).
const A11: f64 = 0.25;
const A12: f64 = 0.25 - SQRT3 / 6.0;
const A21: f64 = 0.25 + SQRT3 / 6.0;
const A22: f64 = 0.25;
// Weights that combine the two stage slopes into the state update.
const B1: f64 = 0.5;
const B2: f64 = 0.5;
/// Strategy `GL4` uses to solve its implicit stage equations.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
feature = "rkyv",
derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub enum ImplicitSolver {
/// Re-evaluate the stage slopes repeatedly until they stop changing.
FixedPoint,
/// Quasi-Newton iteration that maintains an approximate inverse Jacobian.
Broyden,
}
/// Settings for the implicit two-stage `GL4` integrator.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)
)]
pub struct GL4 {
    /// Method used to solve the implicit stage equations.
    pub solver: ImplicitSolver,
    /// Convergence threshold of the implicit iteration.
    pub tol: f64,
    /// Maximum number of implicit iterations per step.
    pub max_step_iter: usize,
}

impl GL4 {
    /// Builds the settings from explicit values, stored unchanged.
    pub fn new(solver: ImplicitSolver, tol: f64, max_step_iter: usize) -> Self {
        Self {
            solver,
            tol,
            max_step_iter,
        }
    }
}

impl Default for GL4 {
    /// Fixed-point iteration with `tol = 1e-8` and at most 100 iterations.
    fn default() -> Self {
        Self::new(ImplicitSolver::FixedPoint, 1e-8, 100)
    }
}
impl ODEIntegrator for GL4 {
#[allow(non_snake_case)]
#[inline]
fn step<P: ODEProblem>(&self, problem: &P, t: f64, y: &mut [f64], dt: f64) -> Result<f64> {
let n = y.len();
let mut k1 = vec![0.0; n];
let mut k2 = vec![0.0; n];
problem.rhs(t, y, &mut k1)?;
k2.copy_from_slice(&k1);
match self.solver {
ImplicitSolver::FixedPoint => {
let mut y1 = vec![0.0; n];
let mut y2 = vec![0.0; n];
for _ in 0..self.max_step_iter {
let k1_old = k1.clone();
let k2_old = k2.clone();
for i in 0..n {
y1[i] = y[i] + dt * (A11 * k1[i] + A12 * k2[i]);
y2[i] = y[i] + dt * (A11 * k1[i] + A12 * k2[i]);
}
problem.rhs(t + C1 * dt, &y1, &mut k1)?;
problem.rhs(t + C2 * dt, &y2, &mut k2)?;
let mut max_diff = 0f64;
for i in 0..n {
max_diff = max_diff.max((k1[i] - k1_old[i]).abs());
max_diff = max_diff.max((k2[i] - k2_old[i]).abs());
}
if max_diff < self.tol {
break;
}
}
}
ImplicitSolver::Broyden => {
let m = 2 * n;
let mut U = vec![0.0; m];
U[..n].copy_from_slice(&k1);
U[n..].copy_from_slice(&k2);
let mut F_vec = vec![0.0; m];
compute_F(problem, t, y, dt, &U, &mut F_vec)?;
let mut J_inv = eye(m);
for _ in 0..self.max_step_iter {
let delta = (&J_inv * &F_vec).mul_scalar(-1.0);
U.iter_mut().zip(delta.iter()).for_each(|(u, d)| *u += *d);
let mut F_new = vec![0.0; m];
compute_F(problem, t, y, dt, &U, &mut F_new)?;
if F_new.norm(Norm::LInf) < self.tol {
break;
}
let delta_F = F_new.sub_vec(&F_vec);
let J_inv_delta_F = &J_inv * &delta_F;
let denom = delta.dot(&J_inv_delta_F);
if denom.abs() < 1e-12 {
break;
}
let delta_minus_J_inv_delta_F = delta.sub_vec(&J_inv_delta_F).to_col();
let delta_T_J_inv = &delta.to_row() * &J_inv;
let update = (delta_minus_J_inv_delta_F * delta_T_J_inv) / denom;
J_inv = J_inv + update;
F_vec = F_new;
}
k1.copy_from_slice(&U[..n]);
k2.copy_from_slice(&U[n..]);
}
}
for i in 0..n {
y[i] += dt * (B1 * k1[i] + B2 * k2[i]);
}
Ok(dt)
}
}
/// Evaluates the residual of the two implicit stage equations.
///
/// `U` holds the stacked stage slopes `[k1; k2]` (each of length `y.len()`).
/// On return the first half of `F` contains `k1 - f(t + C1*dt, y1)` and the
/// second half `k2 - f(t + C2*dt, y2)`, where `y1`/`y2` are the stage states
/// built from `y`, `dt` and the coupling constants `A11..A22`.
///
/// # Errors
///
/// Propagates errors from `problem.rhs`.
#[allow(non_snake_case)]
fn compute_F<P: ODEProblem>(
    problem: &P,
    t: f64,
    y: &[f64],
    dt: f64,
    U: &[f64],
    F: &mut [f64],
) -> Result<()> {
    let dim = y.len();
    let (slope_a, slope_b) = U.split_at(dim);

    // State at which a stage is evaluated, given its row of coupling weights.
    let stage_state = |w_a: f64, w_b: f64| -> Vec<f64> {
        (0..dim)
            .map(|i| y[i] + dt * (w_a * slope_a[i] + w_b * slope_b[i]))
            .collect()
    };
    let state_a = stage_state(A11, A12);
    let state_b = stage_state(A21, A22);

    let (res_a, res_b) = F.split_at_mut(dim);
    problem.rhs(t + C1 * dt, &state_a, res_a)?;
    problem.rhs(t + C2 * dt, &state_b, res_b)?;

    // Turn the freshly evaluated slopes into residuals `k - f(...)`.
    for (r, k) in res_a.iter_mut().zip(slope_a) {
        *r = k - *r;
    }
    for (r, k) in res_b[..dim].iter_mut().zip(&slope_b[..dim]) {
        *r = k - *r;
    }
    Ok(())
}