use crate::error::SolverError;
use crate::problem::OdeSystem;
use crate::solver::{Solver, SolverOptions, SolverResult, SolverStats};
use crate::t_eval::{validate_grid, TEvalEmitter};
use numra_core::Scalar;
/// Adaptive explicit Runge–Kutta solver driven by the `vern6_tableau` coefficients.
///
/// The type carries no state; all configuration lives in `SolverOptions`.
#[derive(Clone, Debug, Default)]
pub struct Vern6;

impl Vern6 {
    /// Creates the solver; identical to `Vern6::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
// Coefficients consumed by `Vern6::solve`.
//
// Layout: `C[s]` is the node of stage `s + 1`; `Aij` is the coefficient of
// stage `j` inside stage `i` (both 1-based). Explicit zero entries such as
// `A42` are kept so every row is complete. `B` weights the propagated
// solution, `B_HAT` the embedded comparison solution, and `E = B - B_HAT`
// weights the local error estimate.
//
// Row 9 (`A91`..`A98`) is identical to the first eight entries of `B`, so
// stage 9 is evaluated at the new solution itself; `Vern6::solve` copies that
// derivative into stage 1 of the next step (first-same-as-last).
//
// `dead_code` is allowed because several items are never read by the solver:
// the explicit zero coefficients and the whole stage-10 row `A101`..`A109`
// (node `C[9]`). NOTE(review): stage 10 is presumably meant for a dense-output
// interpolant; nothing in this file uses it — confirm or remove.
#[allow(dead_code)]
mod vern6_tableau {
pub const C: [f64; 10] = [
0.0,
0.06,
0.09593333333333333,
0.1439,
0.4973,
0.9725,
0.9995,
1.0,
1.0,
0.5,
];
pub const A21: f64 = 0.06;
pub const A31: f64 = 0.019239962962962962;
pub const A32: f64 = 0.07669337037037037;
pub const A41: f64 = 0.035975;
pub const A42: f64 = 0.0;
pub const A43: f64 = 0.107925;
pub const A51: f64 = 1.3186834152331484;
pub const A52: f64 = 0.0;
pub const A53: f64 = -5.042058063628562;
pub const A54: f64 = 4.220674648395414;
pub const A61: f64 = -41.872591664327516;
pub const A62: f64 = 0.0;
pub const A63: f64 = 159.4325621631375;
pub const A64: f64 = -122.11921356501003;
pub const A65: f64 = 5.531743066200054;
pub const A71: f64 = -54.430156935316504;
pub const A72: f64 = 0.0;
pub const A73: f64 = 207.06725136501848;
pub const A74: f64 = -158.61081378459;
pub const A75: f64 = 6.991816585950242;
pub const A76: f64 = -0.018597231062203234;
pub const A81: f64 = -54.66374178728198;
pub const A82: f64 = 0.0;
pub const A83: f64 = 207.95280625538936;
pub const A84: f64 = -159.2889574744995;
pub const A85: f64 = 7.018743740796944;
pub const A86: f64 = -0.018338785905045722;
pub const A87: f64 = -0.0005119484997882099;
// Stage 9 row: equal to B[0..8] (see module comment).
pub const A91: f64 = 0.03438957868357036;
pub const A92: f64 = 0.0;
pub const A93: f64 = 0.0;
pub const A94: f64 = 0.2582624555633503;
pub const A95: f64 = 0.4209371189673537;
pub const A96: f64 = 4.40539646966931;
pub const A97: f64 = -176.48311902429865;
pub const A98: f64 = 172.36413340141507;
// Stage 10 row: not referenced anywhere in this file.
pub const A101: f64 = 0.016524159013572806;
pub const A102: f64 = 0.0;
pub const A103: f64 = 0.0;
pub const A104: f64 = 0.3053128187514179;
pub const A105: f64 = 0.2071200938201979;
pub const A106: f64 = -1.293879140655123;
pub const A107: f64 = 57.11988411588149;
pub const A108: f64 = -55.87979207510932;
pub const A109: f64 = 0.024830028297766014;
// Weights of the propagated solution (entries sum to 1).
pub const B: [f64; 10] = [
0.03438957868357036,
0.0,
0.0,
0.2582624555633503,
0.4209371189673537,
4.40539646966931,
-176.48311902429865,
172.36413340141507,
0.0,
0.0,
];
// Weights of the embedded solution; its only use is through `E` below.
pub const B_HAT: [f64; 10] = [
0.0490996764838249,
0.0,
0.0,
0.22511122295165242,
0.4694682253029562,
0.8065792249988868,
0.0,
-0.607119489177796,
0.056861139440475696,
0.0,
];
// Error weights `B - B_HAT`; indices 1, 2 and 9 are written as 0.0 because
// both `B` and `B_HAT` are zero there.
pub const E: [f64; 10] = [
B[0] - B_HAT[0], 0.0,
0.0,
B[3] - B_HAT[3], B[4] - B_HAT[4], B[5] - B_HAT[5], B[6] - B_HAT[6], B[7] - B_HAT[7], B[8] - B_HAT[8], 0.0,
];
}
impl<S: Scalar> Solver<S> for Vern6 {
    /// Integrates `problem` from `t0` to `tf` with the `vern6_tableau` coefficients.
    ///
    /// Stage 9 is evaluated at the propagated solution and reused as stage 1 of the
    /// next step (first-same-as-last), so each attempted step costs 8 new RHS
    /// evaluations. Without `options.t_eval` the output holds `t0` plus every
    /// accepted step; with it, only what the `TEvalEmitter` produces.
    ///
    /// The initial step is `options.h0` when given, otherwise `0.01 / ||f(t0, y0)||_2`
    /// (or `0.01` for a near-zero slope), clamped to `[h_min, h_max]`.
    ///
    /// # Errors
    /// * `DimensionMismatch` if `problem.dim() == 0` or `y0.len() != problem.dim()`.
    /// * Whatever `validate_grid` reports for an invalid `t_eval` grid.
    /// * `MaxIterationsExceeded` once `options.max_steps` steps have been attempted.
    /// * `StepSizeTooSmall` if a step is rejected while `|h| <= options.h_min`; the
    ///   step cannot shrink further, so retrying would repeat it forever.
    fn solve<Sys: OdeSystem<S>>(
        problem: &Sys,
        t0: S,
        tf: S,
        y0: &[S],
        options: &SolverOptions<S>,
    ) -> Result<SolverResult<S>, SolverError> {
        use vern6_tableau::*;

        let dim = problem.dim();
        if dim == 0 {
            return Err(SolverError::DimensionMismatch {
                expected: 1,
                actual: 0,
            });
        }
        // A mismatched initial state would otherwise panic on a slice copy/index
        // inside the step loop.
        if y0.len() != dim {
            return Err(SolverError::DimensionMismatch {
                expected: dim,
                actual: y0.len(),
            });
        }

        let mut t = t0;
        let mut y = y0.to_vec();
        let direction = if tf >= t0 { S::ONE } else { -S::ONE };
        // Stage derivatives: stage `s` (0-based) lives in `k[s * dim..(s + 1) * dim]`.
        let mut k = vec![S::ZERO; 10 * dim];
        let mut y_stage = vec![S::ZERO; dim];
        let mut y_new = vec![S::ZERO; dim];
        let mut err = vec![S::ZERO; dim];
        let mut stats = SolverStats::default();

        if let Some(grid) = options.t_eval.as_deref() {
            validate_grid(grid, t0, tf)?;
        }
        let mut grid_emitter = options
            .t_eval
            .as_deref()
            .map(|g| TEvalEmitter::new(g, direction));
        let (mut t_out, mut y_out) = if grid_emitter.is_some() {
            (Vec::new(), Vec::new())
        } else {
            (vec![t0], y0.to_vec())
        };

        // Tableau coefficients are loop-invariant: convert them to `S` once.
        let c2 = S::from_f64(C[1]);
        let c3 = S::from_f64(C[2]);
        let c4 = S::from_f64(C[3]);
        let c5 = S::from_f64(C[4]);
        let c6 = S::from_f64(C[5]);
        let c7 = S::from_f64(C[6]);
        let c8 = S::from_f64(C[7]);
        let a21 = S::from_f64(A21);
        let a31 = S::from_f64(A31);
        let a32 = S::from_f64(A32);
        let a41 = S::from_f64(A41);
        let a43 = S::from_f64(A43);
        let a51 = S::from_f64(A51);
        let a53 = S::from_f64(A53);
        let a54 = S::from_f64(A54);
        let a61 = S::from_f64(A61);
        let a63 = S::from_f64(A63);
        let a64 = S::from_f64(A64);
        let a65 = S::from_f64(A65);
        let a71 = S::from_f64(A71);
        let a73 = S::from_f64(A73);
        let a74 = S::from_f64(A74);
        let a75 = S::from_f64(A75);
        let a76 = S::from_f64(A76);
        let a81 = S::from_f64(A81);
        let a83 = S::from_f64(A83);
        let a84 = S::from_f64(A84);
        let a85 = S::from_f64(A85);
        let a86 = S::from_f64(A86);
        let a87 = S::from_f64(A87);
        let a91 = S::from_f64(A91);
        let a94 = S::from_f64(A94);
        let a95 = S::from_f64(A95);
        let a96 = S::from_f64(A96);
        let a97 = S::from_f64(A97);
        let a98 = S::from_f64(A98);
        let e1 = S::from_f64(E[0]);
        let e4 = S::from_f64(E[3]);
        let e5 = S::from_f64(E[4]);
        let e6 = S::from_f64(E[5]);
        let e7 = S::from_f64(E[6]);
        let e8 = S::from_f64(E[7]);
        let e9 = S::from_f64(E[8]);

        // Step-size controller constants.
        let safety = S::from_f64(0.9);
        let min_factor = S::from_f64(0.2);
        let max_factor = S::from_f64(10.0);
        let exponent = S::from_f64(1.0 / 7.0);

        problem.rhs(t, &y, &mut k[0..dim]);
        stats.n_eval += 1;

        // Honour a user-supplied initial step; otherwise use a slope-based guess.
        let mut h = if let Some(h0) = options.h0 {
            h0.abs()
        } else {
            let mut f_norm = S::ZERO;
            for i in 0..dim {
                f_norm = f_norm + k[i] * k[i];
            }
            f_norm = f_norm.sqrt();
            if f_norm > S::from_f64(1e-10) {
                S::from_f64(0.01) * (S::ONE / f_norm)
            } else {
                S::from_f64(0.01)
            }
        };
        h = h.min(options.h_max).max(options.h_min) * direction;

        let max_steps = options.max_steps;
        let mut step_count = 0;
        while (tf - t) * direction > S::from_f64(1e-14) * (tf - t0).abs() {
            if step_count >= max_steps {
                return Err(SolverError::MaxIterationsExceeded { t: t.to_f64() });
            }
            if (t + h - tf) * direction > S::ZERO {
                h = tf - t;
            }

            // Stage 2.
            for i in 0..dim {
                y_stage[i] = y[i] + h * a21 * k[i];
            }
            problem.rhs(t + c2 * h, &y_stage, &mut k[dim..2 * dim]);

            // Stage 3.
            for i in 0..dim {
                y_stage[i] = y[i] + h * (a31 * k[i] + a32 * k[dim + i]);
            }
            problem.rhs(t + c3 * h, &y_stage, &mut k[2 * dim..3 * dim]);

            // Stage 4.
            for i in 0..dim {
                y_stage[i] = y[i] + h * (a41 * k[i] + a43 * k[2 * dim + i]);
            }
            problem.rhs(t + c4 * h, &y_stage, &mut k[3 * dim..4 * dim]);

            // Stage 5.
            for i in 0..dim {
                y_stage[i] =
                    y[i] + h * (a51 * k[i] + a53 * k[2 * dim + i] + a54 * k[3 * dim + i]);
            }
            problem.rhs(t + c5 * h, &y_stage, &mut k[4 * dim..5 * dim]);

            // Stage 6.
            for i in 0..dim {
                y_stage[i] = y[i]
                    + h * (a61 * k[i]
                        + a63 * k[2 * dim + i]
                        + a64 * k[3 * dim + i]
                        + a65 * k[4 * dim + i]);
            }
            problem.rhs(t + c6 * h, &y_stage, &mut k[5 * dim..6 * dim]);

            // Stage 7.
            for i in 0..dim {
                y_stage[i] = y[i]
                    + h * (a71 * k[i]
                        + a73 * k[2 * dim + i]
                        + a74 * k[3 * dim + i]
                        + a75 * k[4 * dim + i]
                        + a76 * k[5 * dim + i]);
            }
            problem.rhs(t + c7 * h, &y_stage, &mut k[6 * dim..7 * dim]);

            // Stage 8.
            for i in 0..dim {
                y_stage[i] = y[i]
                    + h * (a81 * k[i]
                        + a83 * k[2 * dim + i]
                        + a84 * k[3 * dim + i]
                        + a85 * k[4 * dim + i]
                        + a86 * k[5 * dim + i]
                        + a87 * k[6 * dim + i]);
            }
            problem.rhs(t + c8 * h, &y_stage, &mut k[7 * dim..8 * dim]);

            // Stage 9: its row equals `B`, so this is the propagated solution and
            // its derivative doubles as stage 1 of the next step.
            for i in 0..dim {
                y_new[i] = y[i]
                    + h * (a91 * k[i]
                        + a94 * k[3 * dim + i]
                        + a95 * k[4 * dim + i]
                        + a96 * k[5 * dim + i]
                        + a97 * k[6 * dim + i]
                        + a98 * k[7 * dim + i]);
            }
            problem.rhs(t + h, &y_new, &mut k[8 * dim..9 * dim]);
            stats.n_eval += 8;

            for i in 0..dim {
                err[i] = h
                    * (e1 * k[i]
                        + e4 * k[3 * dim + i]
                        + e5 * k[4 * dim + i]
                        + e6 * k[5 * dim + i]
                        + e7 * k[6 * dim + i]
                        + e8 * k[7 * dim + i]
                        + e9 * k[8 * dim + i]);
            }
            // Shared helper: same weighted RMS norm, plus a 1e-15 floor on the scale
            // so `atol == 0` with a zero component cannot divide by zero.
            let err_norm = error_norm(&err, &y, &y_new, options, dim);

            if err_norm <= S::ONE {
                stats.n_accept += 1;
                let t_new = t + h;
                if let Some(ref mut emitter) = grid_emitter {
                    let (dy_start, later_stages) = k.split_at(dim);
                    // Stage 9 (`k[8 * dim..9 * dim]`) is f(t_new, y_new).
                    let dy_end = &later_stages[7 * dim..8 * dim];
                    emitter.emit_step(
                        t, &y, dy_start, t_new, &y_new, dy_end, &mut t_out, &mut y_out,
                    );
                } else {
                    t_out.push(t_new);
                    y_out.extend_from_slice(&y_new);
                }
                t = t_new;
                y.copy_from_slice(&y_new);
                k.copy_within(8 * dim..9 * dim, 0);
            } else {
                stats.n_reject += 1;
                // The controller never goes below h_min, so a rejection at h_min
                // would be retried with the identical step until max_steps.
                if h.abs() <= options.h_min {
                    return Err(SolverError::StepSizeTooSmall {
                        t: t.to_f64(),
                        h: h.to_f64(),
                        h_min: options.h_min.to_f64(),
                    });
                }
            }

            // A NaN error norm (non-finite RHS output) fails both comparisons and
            // falls through to `min_factor`, so a broken step shrinks h rather than
            // growing it tenfold.
            let factor = if err_norm > S::from_f64(1e-10) {
                safety * (S::ONE / err_norm).powf(exponent)
            } else if err_norm <= S::from_f64(1e-10) {
                max_factor
            } else {
                min_factor
            };
            let factor = factor.min(max_factor).max(min_factor);
            h = h * factor;
            let h_abs = h.abs().min(options.h_max).max(options.h_min);
            h = h_abs * direction;
            step_count += 1;
        }

        Ok(SolverResult::new(t_out, y_out, dim, stats))
    }
}
/// Adaptive explicit Runge–Kutta solver using the 10-stage `vern7_tableau`,
/// integrated through the generic `solve_erk` driver.
#[derive(Clone, Debug, Default)]
pub struct Vern7;

impl Vern7 {
    /// Creates the (stateless) solver; identical to `Vern7::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
// Coefficients consumed by `Vern7::solve` through `solve_erk`.
//
// Layout: `C[s]` is the node of stage `s` (0-based, 10 stages). `A[r]` holds
// the coefficients of stage `r + 1`: `solve_erk` reads `a[s - 1][j]` as the
// weight of stage `j` when building stage `s`. `B` weights the propagated
// solution; `E` weights the local error estimate.
mod vern7_tableau {
pub const C: [f64; 10] = [
0.0,
0.005,
0.10888888888888888888888888888888888889,
0.16333333333333333333333333333333333333,
0.4555,
0.60950944899783813170870044214860249496,
0.884,
0.925,
1.0,
1.0,
];
// Row r = coefficients for stage r + 1; entries at or beyond column r + 1 are zero padding.
pub const A: [[f64; 9]; 9] = [
[0.005, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[
-1.076790123456790123456790123456790123457,
1.185679012345679012345679012345679012346,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
],
[
0.04083333333333333333333333333333333333333,
0.0,
0.1225,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
],
[
0.6389139236255726780508121615993336109954,
0.0,
-2.455672638223656809662640566430653894211,
2.272258714598084131611828404831320283215,
0.0,
0.0,
0.0,
0.0,
0.0,
],
[
-2.661577375018757131119259297861818119279,
0.0,
10.80451388645613769565396655365532838482,
-8.353914657396199411968048547819291691541,
0.8204875949566569791420417341743839209619,
0.0,
0.0,
0.0,
0.0,
],
[
6.067741434696770992718360183877276714679,
0.0,
-24.71127363591108579734203485290746001803,
20.42751793078889394045773111748346612697,
-1.906157978816647150624096784352757010879,
1.006172249242068014790040335899474187268,
0.0,
0.0,
0.0,
],
[
12.05467007625320299509109452892778311648,
0.0,
-49.75478495046898932807257615331444758322,
41.14288863860467663259698416710157354209,
-4.461760149974004185641911603484815375051,
2.042334822239174959821717077708608543738,
-0.09834843665406107379530801693870224403537,
0.0,
0.0,
],
[
10.13814652288180787641845141981689030769,
0.0,
-42.64113603171750214622846006736635730625,
35.76384003992257007135021178023160054034,
-4.348022840392907653340370296908245943710,
2.009862268377035895441943593011827554771,
0.3487490460338272405953822853053145879140,
-0.2714390051048312842371587140910297407572,
0.0,
],
[
-45.03007203429867712435322405073769635151,
0.0,
187.3272437654588840752418206154201997384,
-154.0288236935018690596728621034510402582,
18.56465306347536233859492332958439136765,
-7.141809679295078854925420496823551192821,
1.308808578161378625114762706007696696508,
0.0,
0.0,
],
];
// Weights of the propagated solution; `B[9] == 0`, so stage 10 only matters
// for the error estimate.
pub const B: [f64; 10] = [
0.047155618486272221704317651088381756796,
0.0,
0.0,
0.25750564298434151895964361010376875810,
0.26216653977412620477138630957645277111,
0.15216092656738557403231331991651175355,
0.49399691700324842469071758932278768443,
-0.29430311714032504415572447440927034291,
0.081317472324951099997345994401367618925,
0.0,
];
// Error weights `B[s] - b_hat[s]`, where the embedded weights `b_hat` only
// appear as the subtrahends spelled out below (there is no `B_HAT` constant).
// NOTE(review): these literals must be kept in sync with `B` by hand.
pub const E: [f64; 10] = [
0.047155618486272221704317651088381756796 - 0.044608606606341176287318175974791977814, 0.0,
0.0,
0.25750564298434151895964361010376875810 - 0.26716403785713726805091022609438378997, 0.26216653977412620477138630957645277111 - 0.22010183001772930199797157766507530963, 0.15216092656738557403231331991651175355 - 0.21884317031431568309831208335128938246, 0.49399691700324842469071758932278768443 - 0.22898717054112028833781738897635523654, -0.29430311714032504415572447440927034291 - 0.0, 0.081317472324951099997345994401367618925 - 0.0, 0.0 - 0.020295184663356282227670547938104303586, ];
}
impl<S: Scalar> Solver<S> for Vern7 {
fn solve<Sys: OdeSystem<S>>(
problem: &Sys,
t0: S,
tf: S,
y0: &[S],
options: &SolverOptions<S>,
) -> Result<SolverResult<S>, SolverError> {
let a_slices: Vec<&[f64]> = vern7_tableau::A.iter().map(|row| row.as_slice()).collect();
solve_erk(
problem,
t0,
tf,
y0,
options,
&vern7_tableau::C,
&a_slices,
&vern7_tableau::B,
&vern7_tableau::E,
7, 10, )
}
}
/// Adaptive explicit Runge–Kutta solver using the 13-stage `vern8_tableau`,
/// integrated by the unrolled `solve_vern8` driver.
#[derive(Clone, Debug, Default)]
pub struct Vern8;

impl Vern8 {
    /// Creates the (stateless) solver; identical to `Vern8::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
// Coefficients consumed by `solve_vern8`.
//
// Layout: `C[s]` is the node of stage `s + 1` (13 stages). `An` is the row for
// stage `n` (1-based) and holds `n - 1` entries, one per earlier stage.
// `B` weights the propagated solution, `B_HAT` the embedded solution, and `E`
// weights the local error estimate.
mod vern8_tableau {
pub const C: [f64; 13] = [
0.0,
0.05,
0.1065625,
0.15984375,
0.39,
0.465,
0.155,
0.943,
0.9018020417358569582597079406783721499560,
0.909,
0.94,
1.0,
1.0,
];
pub const A2: [f64; 1] = [0.05];
pub const A3: [f64; 2] = [-0.0069931640625, 0.1135556640625];
pub const A4: [f64; 3] = [0.0399609375, 0.0, 0.1198828125];
pub const A5: [f64; 4] = [
0.3613975628004575124052940721184028345129,
0.0,
-1.341524066700492771819987788202715834917,
1.370126503900035259414693716084313000404,
];
pub const A6: [f64; 5] = [
0.04904720279720279720279720279720279720280,
0.0,
0.0,
0.2350972042214404739862988335493427143122,
0.1808555929813567288109039636534544884850,
];
pub const A7: [f64; 6] = [
0.06169289044289044289044289044289044289044,
0.0,
0.0,
0.1123656831464027662262557035130015442303,
-0.03885046071451366767049048108111244567456,
0.01979188712522045855379188712522045855379,
];
pub const A8: [f64; 7] = [
-1.767630240222326875735597119572145586714,
0.0,
0.0,
-62.5,
-6.061889377376669100821361459659331999758,
5.650823198222763138561298030600840174201,
65.62169641937623283799566054863063741227,
];
pub const A9: [f64; 8] = [
-1.180945066554970799825116282628297957882,
0.0,
0.0,
-41.50473441114320841606641502701994225874,
-4.434438319103725011225169229846100211776,
4.260408188586133024812193710744693240761,
43.75364022446171584987676829438379303004,
0.007871425489912310687446475044226307550860,
];
pub const A10: [f64; 9] = [
-1.281405999441488405459510291182054246266,
0.0,
0.0,
-45.04713996013986630220754257136007322267,
-4.731362069449576477311464265491282810943,
4.514967016593807841185851584597240996214,
47.44909557172985134869022392235929015114,
0.01059228297111661135687393955516542875228,
-0.005746842263844616254432318478286296232021,
];
pub const A11: [f64; 10] = [
-1.724470134262485191756709817484481861731,
0.0,
0.0,
-60.92349008483054016518434619253765246063,
-5.951518376222392455202832767061854868290,
5.556523730698456235979791650843592496839,
63.98301198033305336837536378635995939281,
0.01464202825041496159275921391759452676003,
0.06460408772358203603621865144977650714892,
-0.07930323169008878984024452548693373291447,
];
pub const A12: [f64; 11] = [
-3.301622667747079016353994789790983625569,
0.0,
0.0,
-118.0112723597525085666923303957898868510,
-10.14142238845611248642783916034510897595,
9.139311332232057923544012273556827000619,
123.3759428284042683684847180986501894364,
4.623244378874580474839807625067630924792,
-3.383277738068201923652550971536811240814,
4.527592100324618189451265339351129035325,
-5.828495485811622963193088019162985703755,
];
pub const A13: [f64; 12] = [
-3.039515033766309030040102851821200251056,
0.0,
0.0,
-109.2608680894176254686444192322164623352,
-9.290642497400293449717665542656897549158,
8.430504981764911142134299253836167803454,
114.2010010378331313557424041095523427476,
-0.9637271342145479358162375658987901652762,
-5.034884088802189791198680336183332323118,
5.958130824002923177540402165388172072794,
0.0,
0.0,
];
// Weights of the propagated solution; `B[12] == 0`, so stage 13 only feeds
// the error estimate.
pub const B: [f64; 13] = [
0.04427989419007951074716746668098518862111,
0.0,
0.0,
0.0,
0.0,
0.3541049391724448744815552028733568354121,
0.2479692154956437828667629415370663023884,
-15.69420203883808405099207034271191213468,
25.08406496555856261343930031237186278518,
-31.73836778626027646833156112007297739997,
22.93828327398878395231483560344797018313,
-0.2361324633071542145259900641263517600737,
0.0,
];
// Embedded weights, kept for reference only: `E` below repeats these numbers
// as literals instead of reading this constant, hence `dead_code` is allowed.
#[allow(dead_code)]
pub const B_HAT: [f64; 13] = [
0.04431261522908979212486436510209029764893,
0.0,
0.0,
0.0,
0.0,
0.3546095642343226447863179350895055038855,
0.2478480431366653069619986721504458660016,
4.448134732475784492725128317159648871312,
19.84688636611873369930932399297687935291,
-23.58162337746561841969517960870394965085,
0.0,
0.0,
-0.3601679437289775162124536737746202409110,
];
// Error weights `B - B_HAT`, written out literally.
// NOTE(review): unlike `vern6_tableau`, this does not compute `B[i] - B_HAT[i]`,
// so edits to `B`/`B_HAT` must be mirrored here by hand.
pub const E: [f64; 13] = [
0.04427989419007951074716746668098518862111 - 0.04431261522908979212486436510209029764893,
0.0,
0.0,
0.0,
0.0,
0.3541049391724448744815552028733568354121 - 0.3546095642343226447863179350895055038855,
0.2479692154956437828667629415370663023884 - 0.2478480431366653069619986721504458660016,
-15.69420203883808405099207034271191213468 - 4.448134732475784492725128317159648871312,
25.08406496555856261343930031237186278518 - 19.84688636611873369930932399297687935291,
-31.73836778626027646833156112007297739997 - (-23.58162337746561841969517960870394965085),
22.93828327398878395231483560344797018313 - 0.0,
-0.2361324633071542145259900641263517600737 - 0.0,
0.0 - (-0.3601679437289775162124536737746202409110),
];
}
impl<S: Scalar> Solver<S> for Vern8 {
    /// Integrates `problem` from `t0` to `tf` with the 13-stage `vern8_tableau`;
    /// all the work happens in the unrolled `solve_vern8` driver.
    ///
    /// # Errors
    /// Propagates every error produced by `solve_vern8`.
    fn solve<Sys: OdeSystem<S>>(
        problem: &Sys,
        t0: S,
        tf: S,
        y0: &[S],
        options: &SolverOptions<S>,
    ) -> Result<SolverResult<S>, SolverError> {
        solve_vern8::<S, Sys>(problem, t0, tf, y0, options)
    }
}
/// Generic adaptive explicit Runge–Kutta driver with an embedded error estimate.
///
/// Tableau layout: `c[s]` is the node of stage `s` (0-based); `a[s - 1][j]` is the
/// weight of stage `j` when building stage `s` (missing rows or trailing entries
/// count as zero); `b` weights the propagated solution and `e` the local error.
/// `c`, `b` and `e` must provide at least `stages` entries.
///
/// Step control uses the weighted RMS norm from `error_norm`:
/// * accepted: `h *= clamp(0.9 * err^(-1/(order + 1)), 0.2, 4)`;
/// * rejected: `h *= max(0.9 * err^(-1/order), 0.2)`, or `0.2` if `err` is NaN.
///
/// The method is not treated as FSAL: each accepted step spends one extra RHS
/// evaluation at the new point, reused as stage 0 of the next step and as the end
/// slope for `t_eval` interpolation.
///
/// # Errors
/// * `DimensionMismatch` if `y0.len() != problem.dim()` or the system is empty.
/// * Whatever `validate_grid` reports for an invalid `t_eval` grid.
/// * `MaxIterationsExceeded` after `options.max_steps` attempted steps.
/// * `StepSizeTooSmall` when a rejection drives `|h|` below `options.h_min`.
#[allow(clippy::too_many_arguments)]
fn solve_erk<S, Sys>(
    problem: &Sys,
    t0: S,
    tf: S,
    y0: &[S],
    options: &SolverOptions<S>,
    c: &[f64],
    a: &[&[f64]],
    b: &[f64],
    e: &[f64],
    order: usize,
    stages: usize,
) -> Result<SolverResult<S>, SolverError>
where
    S: Scalar,
    Sys: OdeSystem<S>,
{
    let dim = problem.dim();
    if y0.len() != dim {
        return Err(SolverError::DimensionMismatch {
            expected: dim,
            actual: y0.len(),
        });
    }
    // An empty system has no meaningful error norm (0/0); reject it up front,
    // matching `Vern6::solve`.
    if dim == 0 {
        return Err(SolverError::DimensionMismatch {
            expected: 1,
            actual: 0,
        });
    }
    let mut t = t0;
    let mut y = y0.to_vec();
    // Stage derivatives: stage `s` lives in `k[s * dim..(s + 1) * dim]`.
    let mut k = vec![S::ZERO; stages * dim];
    let mut y_stage = vec![S::ZERO; dim];
    let mut y_new = vec![S::ZERO; dim];
    let mut err = vec![S::ZERO; dim];
    let mut dy_old = vec![S::ZERO; dim];
    let mut dy_new = vec![S::ZERO; dim];
    let mut stats = SolverStats::default();
    problem.rhs(t, &y, &mut k[0..dim]);
    stats.n_eval += 1;
    let mut h = initial_step_size(&y, &k[0..dim], options, dim);
    let h_min = options.h_min;
    let h_max = options.h_max.min((tf - t0).abs());
    let direction = if tf > t0 { S::ONE } else { -S::ONE };
    if let Some(grid) = options.t_eval.as_deref() {
        validate_grid(grid, t0, tf)?;
    }
    let mut grid_emitter = options
        .t_eval
        .as_deref()
        .map(|g| TEvalEmitter::new(g, direction));
    let (mut t_out, mut y_out) = if grid_emitter.is_some() {
        (Vec::new(), Vec::new())
    } else {
        (vec![t0], y0.to_vec())
    };

    let safety = S::from_f64(0.9);
    let fac_max = S::from_f64(4.0);
    let fac_min = S::from_f64(0.2);
    let order_f = S::from_usize(order + 1);

    let mut step_count = 0_usize;
    while (tf - t) * direction > S::from_f64(1e-10) * (tf - t0).abs() {
        if step_count >= options.max_steps {
            return Err(SolverError::MaxIterationsExceeded { t: t.to_f64() });
        }
        // Clamp |h| into [h_min, h_max] first and only then shorten the step to
        // land exactly on `tf`, so the h_min floor can never carry us past `tf`.
        h = h.abs().max(h_min).min(h_max) * direction;
        if (t + h - tf) * direction > S::ZERO {
            h = tf - t;
        }

        for s in 1..stages {
            // Missing rows/entries behave as zero coefficients.
            let row: &[f64] = a.get(s - 1).copied().unwrap_or(&[]);
            let row = &row[..row.len().min(s)];
            for i in 0..dim {
                let sum = row
                    .iter()
                    .enumerate()
                    .fold(S::ZERO, |acc, (j, &a_sj)| acc + S::from_f64(a_sj) * k[j * dim + i]);
                y_stage[i] = y[i] + h * sum;
            }
            problem.rhs(
                t + S::from_f64(c[s]) * h,
                &y_stage,
                &mut k[s * dim..(s + 1) * dim],
            );
        }
        stats.n_eval += stages - 1;

        for i in 0..dim {
            let mut sum_b = S::ZERO;
            let mut sum_e = S::ZERO;
            for s in 0..stages {
                sum_b = sum_b + S::from_f64(b[s]) * k[s * dim + i];
                sum_e = sum_e + S::from_f64(e[s]) * k[s * dim + i];
            }
            y_new[i] = y[i] + h * sum_b;
            err[i] = h * sum_e;
        }
        let err_norm = error_norm(&err, &y, &y_new, options, dim);

        if err_norm <= S::ONE {
            stats.n_accept += 1;
            let t_new = t + h;
            dy_old.copy_from_slice(&k[0..dim]);
            problem.rhs(t_new, &y_new, &mut dy_new);
            stats.n_eval += 1;
            if let Some(ref mut emitter) = grid_emitter {
                emitter.emit_step(
                    t, &y, &dy_old, t_new, &y_new, &dy_new, &mut t_out, &mut y_out,
                );
            } else {
                t_out.push(t_new);
                y_out.extend_from_slice(&y_new);
            }
            t = t_new;
            y.copy_from_slice(&y_new);
            k[0..dim].copy_from_slice(&dy_new);
            let err_safe = err_norm.max(S::from_f64(1e-10));
            let fac = safety * err_safe.powf(-S::ONE / order_f);
            let fac = fac.min(fac_max).max(fac_min);
            h = h * fac;
            // No h_min check here: the next iteration re-clamps |h| up to h_min, and
            // checking after the final accepted step would report a spurious error.
        } else {
            stats.n_reject += 1;
            // Here err_norm is either > 1 or NaN (non-finite RHS values). NaN must
            // shrink the step, never feed the controller (which could grow it).
            let fac = if err_norm > S::ONE {
                (safety * err_norm.powf(-S::ONE / (order_f - S::ONE))).max(fac_min)
            } else {
                fac_min
            };
            h = h * fac;
            if h.abs() < h_min {
                return Err(SolverError::StepSizeTooSmall {
                    t: t.to_f64(),
                    h: h.to_f64(),
                    h_min: h_min.to_f64(),
                });
            }
        }
        step_count += 1;
    }
    Ok(SolverResult::new(t_out, y_out, dim, stats))
}
/// Adaptive driver for the 13-stage `vern8_tableau`, with every stage unrolled so
/// that the tableau's zero entries cost nothing.
///
/// Step control uses the weighted RMS norm from `error_norm`:
/// * accepted: `h *= clamp(0.9 * err^(-1/9), 0.2, 3)`;
/// * rejected: `h *= max(0.9 * err^(-1/8), 0.2)`, or `0.2` if `err` is NaN.
///
/// The tableau is not FSAL: each accepted step spends one extra RHS evaluation at
/// the new point, reused as stage 1 of the next step and as the end slope for
/// `t_eval` interpolation.
///
/// # Errors
/// * `DimensionMismatch` if `y0.len() != problem.dim()` or the system is empty.
/// * Whatever `validate_grid` reports for an invalid `t_eval` grid.
/// * `MaxIterationsExceeded` after `options.max_steps` attempted steps.
/// * `StepSizeTooSmall` when a rejection drives `|h|` below `options.h_min`.
fn solve_vern8<S, Sys>(
    problem: &Sys,
    t0: S,
    tf: S,
    y0: &[S],
    options: &SolverOptions<S>,
) -> Result<SolverResult<S>, SolverError>
where
    S: Scalar,
    Sys: OdeSystem<S>,
{
    let dim = problem.dim();
    if y0.len() != dim {
        return Err(SolverError::DimensionMismatch {
            expected: dim,
            actual: y0.len(),
        });
    }
    // An empty system has no meaningful error norm (0/0); reject it up front,
    // matching `Vern6::solve`.
    if dim == 0 {
        return Err(SolverError::DimensionMismatch {
            expected: 1,
            actual: 0,
        });
    }
    let mut t = t0;
    let mut y = y0.to_vec();
    // Stage derivatives: stage `s` (0-based) lives in `k[s * dim..(s + 1) * dim]`.
    let mut k = vec![S::ZERO; 13 * dim];
    let mut y_stage = vec![S::ZERO; dim];
    let mut y_new = vec![S::ZERO; dim];
    let mut err = vec![S::ZERO; dim];
    let mut dy_old = vec![S::ZERO; dim];
    let mut dy_new = vec![S::ZERO; dim];
    let mut stats = SolverStats::default();
    problem.rhs(t, &y, &mut k[0..dim]);
    stats.n_eval += 1;
    let mut h = initial_step_size(&y, &k[0..dim], options, dim);
    let h_min = options.h_min;
    let h_max = options.h_max.min((tf - t0).abs());
    let direction = if tf > t0 { S::ONE } else { -S::ONE };
    if let Some(grid) = options.t_eval.as_deref() {
        validate_grid(grid, t0, tf)?;
    }
    let mut grid_emitter = options
        .t_eval
        .as_deref()
        .map(|g| TEvalEmitter::new(g, direction));
    let (mut t_out, mut y_out) = if grid_emitter.is_some() {
        (Vec::new(), Vec::new())
    } else {
        (vec![t0], y0.to_vec())
    };

    let safety = S::from_f64(0.9);
    let fac_max = S::from_f64(3.0);
    let fac_min = S::from_f64(0.2);

    let mut step_count = 0_usize;
    while (tf - t) * direction > S::from_f64(1e-10) * (tf - t0).abs() {
        if step_count >= options.max_steps {
            return Err(SolverError::MaxIterationsExceeded { t: t.to_f64() });
        }
        // Clamp |h| into [h_min, h_max] first and only then shorten the step to
        // land exactly on `tf`, so the h_min floor can never carry us past `tf`.
        h = h.abs().max(h_min).min(h_max) * direction;
        if (t + h - tf) * direction > S::ZERO {
            h = tf - t;
        }

        // Stage 2.
        for i in 0..dim {
            y_stage[i] = y[i] + h * S::from_f64(vern8_tableau::A2[0]) * k[i];
        }
        problem.rhs(
            t + S::from_f64(vern8_tableau::C[1]) * h,
            &y_stage,
            &mut k[dim..2 * dim],
        );
        // Stage 3.
        for i in 0..dim {
            y_stage[i] = y[i]
                + h * (S::from_f64(vern8_tableau::A3[0]) * k[i]
                    + S::from_f64(vern8_tableau::A3[1]) * k[dim + i]);
        }
        problem.rhs(
            t + S::from_f64(vern8_tableau::C[2]) * h,
            &y_stage,
            &mut k[2 * dim..3 * dim],
        );
        // Stage 4.
        for i in 0..dim {
            y_stage[i] = y[i]
                + h * (S::from_f64(vern8_tableau::A4[0]) * k[i]
                    + S::from_f64(vern8_tableau::A4[2]) * k[2 * dim + i]);
        }
        problem.rhs(
            t + S::from_f64(vern8_tableau::C[3]) * h,
            &y_stage,
            &mut k[3 * dim..4 * dim],
        );
        // Stage 5.
        for i in 0..dim {
            y_stage[i] = y[i]
                + h * (S::from_f64(vern8_tableau::A5[0]) * k[i]
                    + S::from_f64(vern8_tableau::A5[2]) * k[2 * dim + i]
                    + S::from_f64(vern8_tableau::A5[3]) * k[3 * dim + i]);
        }
        problem.rhs(
            t + S::from_f64(vern8_tableau::C[4]) * h,
            &y_stage,
            &mut k[4 * dim..5 * dim],
        );
        // Stage 6.
        for i in 0..dim {
            y_stage[i] = y[i]
                + h * (S::from_f64(vern8_tableau::A6[0]) * k[i]
                    + S::from_f64(vern8_tableau::A6[3]) * k[3 * dim + i]
                    + S::from_f64(vern8_tableau::A6[4]) * k[4 * dim + i]);
        }
        problem.rhs(
            t + S::from_f64(vern8_tableau::C[5]) * h,
            &y_stage,
            &mut k[5 * dim..6 * dim],
        );
        // Stage 7.
        for i in 0..dim {
            y_stage[i] = y[i]
                + h * (S::from_f64(vern8_tableau::A7[0]) * k[i]
                    + S::from_f64(vern8_tableau::A7[3]) * k[3 * dim + i]
                    + S::from_f64(vern8_tableau::A7[4]) * k[4 * dim + i]
                    + S::from_f64(vern8_tableau::A7[5]) * k[5 * dim + i]);
        }
        problem.rhs(
            t + S::from_f64(vern8_tableau::C[6]) * h,
            &y_stage,
            &mut k[6 * dim..7 * dim],
        );
        // Stage 8.
        for i in 0..dim {
            y_stage[i] = y[i]
                + h * (S::from_f64(vern8_tableau::A8[0]) * k[i]
                    + S::from_f64(vern8_tableau::A8[3]) * k[3 * dim + i]
                    + S::from_f64(vern8_tableau::A8[4]) * k[4 * dim + i]
                    + S::from_f64(vern8_tableau::A8[5]) * k[5 * dim + i]
                    + S::from_f64(vern8_tableau::A8[6]) * k[6 * dim + i]);
        }
        problem.rhs(
            t + S::from_f64(vern8_tableau::C[7]) * h,
            &y_stage,
            &mut k[7 * dim..8 * dim],
        );
        // Stage 9.
        for i in 0..dim {
            y_stage[i] = y[i]
                + h * (S::from_f64(vern8_tableau::A9[0]) * k[i]
                    + S::from_f64(vern8_tableau::A9[3]) * k[3 * dim + i]
                    + S::from_f64(vern8_tableau::A9[4]) * k[4 * dim + i]
                    + S::from_f64(vern8_tableau::A9[5]) * k[5 * dim + i]
                    + S::from_f64(vern8_tableau::A9[6]) * k[6 * dim + i]
                    + S::from_f64(vern8_tableau::A9[7]) * k[7 * dim + i]);
        }
        problem.rhs(
            t + S::from_f64(vern8_tableau::C[8]) * h,
            &y_stage,
            &mut k[8 * dim..9 * dim],
        );
        // Stage 10.
        for i in 0..dim {
            y_stage[i] = y[i]
                + h * (S::from_f64(vern8_tableau::A10[0]) * k[i]
                    + S::from_f64(vern8_tableau::A10[3]) * k[3 * dim + i]
                    + S::from_f64(vern8_tableau::A10[4]) * k[4 * dim + i]
                    + S::from_f64(vern8_tableau::A10[5]) * k[5 * dim + i]
                    + S::from_f64(vern8_tableau::A10[6]) * k[6 * dim + i]
                    + S::from_f64(vern8_tableau::A10[7]) * k[7 * dim + i]
                    + S::from_f64(vern8_tableau::A10[8]) * k[8 * dim + i]);
        }
        problem.rhs(
            t + S::from_f64(vern8_tableau::C[9]) * h,
            &y_stage,
            &mut k[9 * dim..10 * dim],
        );
        // Stage 11.
        for i in 0..dim {
            y_stage[i] = y[i]
                + h * (S::from_f64(vern8_tableau::A11[0]) * k[i]
                    + S::from_f64(vern8_tableau::A11[3]) * k[3 * dim + i]
                    + S::from_f64(vern8_tableau::A11[4]) * k[4 * dim + i]
                    + S::from_f64(vern8_tableau::A11[5]) * k[5 * dim + i]
                    + S::from_f64(vern8_tableau::A11[6]) * k[6 * dim + i]
                    + S::from_f64(vern8_tableau::A11[7]) * k[7 * dim + i]
                    + S::from_f64(vern8_tableau::A11[8]) * k[8 * dim + i]
                    + S::from_f64(vern8_tableau::A11[9]) * k[9 * dim + i]);
        }
        problem.rhs(
            t + S::from_f64(vern8_tableau::C[10]) * h,
            &y_stage,
            &mut k[10 * dim..11 * dim],
        );
        // Stage 12.
        for i in 0..dim {
            y_stage[i] = y[i]
                + h * (S::from_f64(vern8_tableau::A12[0]) * k[i]
                    + S::from_f64(vern8_tableau::A12[3]) * k[3 * dim + i]
                    + S::from_f64(vern8_tableau::A12[4]) * k[4 * dim + i]
                    + S::from_f64(vern8_tableau::A12[5]) * k[5 * dim + i]
                    + S::from_f64(vern8_tableau::A12[6]) * k[6 * dim + i]
                    + S::from_f64(vern8_tableau::A12[7]) * k[7 * dim + i]
                    + S::from_f64(vern8_tableau::A12[8]) * k[8 * dim + i]
                    + S::from_f64(vern8_tableau::A12[9]) * k[9 * dim + i]
                    + S::from_f64(vern8_tableau::A12[10]) * k[10 * dim + i]);
        }
        problem.rhs(
            t + S::from_f64(vern8_tableau::C[11]) * h,
            &y_stage,
            &mut k[11 * dim..12 * dim],
        );
        // Stage 13 (only enters the error estimate, since B[12] == 0).
        for i in 0..dim {
            y_stage[i] = y[i]
                + h * (S::from_f64(vern8_tableau::A13[0]) * k[i]
                    + S::from_f64(vern8_tableau::A13[3]) * k[3 * dim + i]
                    + S::from_f64(vern8_tableau::A13[4]) * k[4 * dim + i]
                    + S::from_f64(vern8_tableau::A13[5]) * k[5 * dim + i]
                    + S::from_f64(vern8_tableau::A13[6]) * k[6 * dim + i]
                    + S::from_f64(vern8_tableau::A13[7]) * k[7 * dim + i]
                    + S::from_f64(vern8_tableau::A13[8]) * k[8 * dim + i]
                    + S::from_f64(vern8_tableau::A13[9]) * k[9 * dim + i]);
        }
        problem.rhs(
            t + S::from_f64(vern8_tableau::C[12]) * h,
            &y_stage,
            &mut k[12 * dim..13 * dim],
        );
        stats.n_eval += 12;

        for i in 0..dim {
            let mut sum_b = S::ZERO;
            let mut sum_e = S::ZERO;
            for s in 0..13 {
                sum_b = sum_b + S::from_f64(vern8_tableau::B[s]) * k[s * dim + i];
                sum_e = sum_e + S::from_f64(vern8_tableau::E[s]) * k[s * dim + i];
            }
            y_new[i] = y[i] + h * sum_b;
            err[i] = h * sum_e;
        }
        let err_norm = error_norm(&err, &y, &y_new, options, dim);

        if err_norm <= S::ONE {
            stats.n_accept += 1;
            let t_new = t + h;
            dy_old.copy_from_slice(&k[0..dim]);
            problem.rhs(t_new, &y_new, &mut dy_new);
            stats.n_eval += 1;
            if let Some(ref mut emitter) = grid_emitter {
                emitter.emit_step(
                    t, &y, &dy_old, t_new, &y_new, &dy_new, &mut t_out, &mut y_out,
                );
            } else {
                t_out.push(t_new);
                y_out.extend_from_slice(&y_new);
            }
            t = t_new;
            y.copy_from_slice(&y_new);
            k[0..dim].copy_from_slice(&dy_new);
            let err_safe = err_norm.max(S::from_f64(1e-10));
            let fac = safety * err_safe.powf(S::from_f64(-1.0 / 9.0));
            let fac = fac.min(fac_max).max(fac_min);
            h = h * fac;
            // No h_min check here: the next iteration re-clamps |h| up to h_min, and
            // checking after the final accepted step would report a spurious error.
        } else {
            stats.n_reject += 1;
            // Here err_norm is either > 1 or NaN (non-finite RHS values). NaN must
            // shrink the step, never feed the controller (which could grow it).
            let fac = if err_norm > S::ONE {
                (safety * err_norm.powf(S::from_f64(-1.0 / 8.0))).max(fac_min)
            } else {
                fac_min
            };
            h = h * fac;
            if h.abs() < h_min {
                return Err(SolverError::StepSizeTooSmall {
                    t: t.to_f64(),
                    h: h.to_f64(),
                    h_min: h_min.to_f64(),
                });
            }
        }
        step_count += 1;
    }
    Ok(SolverResult::new(t_out, y_out, dim, stats))
}
/// Chooses the first trial step size (unsigned; callers apply the direction).
///
/// Returns `options.h0` unchanged when it is set. Otherwise it computes weighted
/// RMS norms of `y0` and `f0 = f(t0, y0)`. Each component is scaled by
/// `atol + rtol * |y0_i|`, floored at `1e-15` exactly as in `error_norm`, so a zero
/// tolerance on a zero component cannot divide by zero. The result is
/// `0.01 * ||y0|| / ||f0||` capped at `options.h_max`, or `1e-6` when either norm
/// is below `1e-5`. An empty system (`dim == 0`) yields `1e-6` instead of 0/0.
fn initial_step_size<S: Scalar>(y0: &[S], f0: &[S], options: &SolverOptions<S>, dim: usize) -> S {
    if let Some(h0) = options.h0 {
        return h0;
    }
    let floor = S::from_f64(1e-15);
    let mut y_norm = S::ZERO;
    let mut f_norm = S::ZERO;
    for i in 0..dim {
        let sc = (options.atol + options.rtol * y0[i].abs()).max(floor);
        y_norm = y_norm + (y0[i] / sc) * (y0[i] / sc);
        f_norm = f_norm + (f0[i] / sc) * (f0[i] / sc);
    }
    // `max(1)` keeps the empty case at 0/1 = 0 rather than 0/0 = NaN.
    let n = S::from_usize(dim.max(1));
    y_norm = (y_norm / n).sqrt();
    f_norm = (f_norm / n).sqrt();
    if y_norm < S::from_f64(1e-5) || f_norm < S::from_f64(1e-5) {
        S::from_f64(1e-6)
    } else {
        (S::from_f64(0.01) * y_norm / f_norm).min(options.h_max)
    }
}
/// Weighted RMS norm of a local error estimate.
///
/// Component `i` is divided by `atol + rtol * max(|y_i|, |y_new_i|)`, floored at
/// `1e-15` so a zero tolerance on a zero component cannot divide by zero. A value
/// `<= 1` means the step meets the requested tolerances. An empty system
/// (`dim == 0`) returns zero instead of `0/0 = NaN`.
///
/// # Panics
/// Panics if `err`, `y` or `y_new` holds fewer than `dim` entries.
fn error_norm<S: Scalar>(
    err: &[S],
    y: &[S],
    y_new: &[S],
    options: &SolverOptions<S>,
    dim: usize,
) -> S {
    let floor = S::from_f64(1e-15);
    let mut sum_sq = S::ZERO;
    for i in 0..dim {
        let sc = (options.atol + options.rtol * y[i].abs().max(y_new[i].abs())).max(floor);
        let scaled_err = err[i] / sc;
        sum_sq = sum_sq + scaled_err * scaled_err;
    }
    (sum_sq / S::from_usize(dim.max(1))).sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::problem::OdeProblem;

    /// Decay `y' = -y`, `y(0) = 1`, integrated to `t = 5` with Vern6 at tight tolerances.
    #[test]
    fn test_vern6_exponential() {
        let t_end: f64 = 5.0;
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            t_end,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-6).atol(1e-8);
        let result = Vern6::solve(&problem, 0.0, t_end, &[1.0], &options).unwrap();
        assert!(result.success);

        let got = result.y_final().unwrap()[0];
        let exact = (-t_end).exp();
        assert!(
            (got - exact).abs() < 1e-5,
            "Vern6 exponential: got {}, expected {}",
            got,
            exact
        );
    }

    /// Harmonic oscillator `x'' = -x`, `x(0) = 1`, `x'(0) = 0`: `x(t) = cos t`.
    #[test]
    fn test_vern7_harmonic() {
        let t_end: f64 = 10.0;
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = -y[0];
            },
            0.0,
            t_end,
            vec![1.0, 0.0],
        );
        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let result = Vern7::solve(&problem, 0.0, t_end, &[1.0, 0.0], &options).unwrap();
        assert!(result.success);

        let got = result.y_final().unwrap()[0];
        let exact = t_end.cos();
        assert!(
            (got - exact).abs() < 0.01,
            "Vern7 harmonic: got {}, expected {}",
            got,
            exact
        );
    }

    /// Same decay problem as the Vern6 test, checked with Vern8 against a relative bound.
    #[test]
    fn test_vern8_high_accuracy() {
        let t_end: f64 = 5.0;
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            t_end,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let result = Vern8::solve(&problem, 0.0, t_end, &[1.0], &options).unwrap();
        assert!(result.success);

        let got = result.y_final().unwrap()[0];
        let exact = (-t_end).exp();
        assert!(
            (got - exact).abs() < exact * 0.05,
            "Vern8 exponential: got {}, expected {}",
            got,
            exact
        );
    }

    /// All three solvers should reproduce `exp(-2)` on the decay problem.
    #[test]
    fn test_vern_methods_agree() {
        let t_end: f64 = 2.0;
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            t_end,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let exact = (-t_end).exp();

        let y6 = Vern6::solve(&problem, 0.0, t_end, &[1.0], &options)
            .unwrap()
            .y_final()
            .unwrap()[0];
        let y7 = Vern7::solve(&problem, 0.0, t_end, &[1.0], &options)
            .unwrap()
            .y_final()
            .unwrap()[0];
        let y8 = Vern8::solve(&problem, 0.0, t_end, &[1.0], &options)
            .unwrap()
            .y_final()
            .unwrap()[0];

        assert!(
            (y6 - exact).abs() < exact * 0.01,
            "Vern6 disagrees: {}",
            y6
        );
        assert!(
            (y7 - exact).abs() < exact * 0.01,
            "Vern7 disagrees: {}",
            y7
        );
        assert!(
            (y8 - exact).abs() < exact * 0.05,
            "Vern8 disagrees: {}",
            y8
        );
    }
}