use super::FixedPoint;
use super::FixedVector;
use super::linalg::compute_tier_dot_raw;
use crate::fixed_point::universal::fasc::stack_evaluator::BinaryStorage;
use crate::fixed_point::core_types::errors::OverflowDetected;
/// A first-order ODE system `dx/dt = eval(t, x)`.
pub trait OdeSystem {
    /// Evaluates the right-hand side (the slope of the state) at time `t` and state `x`.
    ///
    /// The integrators in this module index the returned vector over `0..x.len()`, so it is
    /// expected to have the same length as `x`.
    fn eval(&self, t: FixedPoint, x: &FixedVector) -> FixedVector;
}
/// Adapter that lets a closure `Fn(t, &x) -> slope` be used as an [`OdeSystem`].
///
/// The `Fn` bound lives on the `OdeSystem` impl (and on [`ode_fn`]) rather than on the
/// struct itself, so the bound is not duplicated on the data type. Every `OdeFn<F>` that
/// was valid before remains valid.
pub struct OdeFn<F> {
    /// The wrapped right-hand-side closure.
    pub f: F,
}
impl<F> OdeSystem for OdeFn<F>
where
    F: Fn(FixedPoint, &FixedVector) -> FixedVector,
{
    /// Forwards directly to the wrapped closure.
    fn eval(&self, t: FixedPoint, x: &FixedVector) -> FixedVector {
        let rhs = &self.f;
        rhs(t, x)
    }
}
/// Wraps the closure `f` in an [`OdeFn`] so it can be handed to the integrators.
pub fn ode_fn<F>(f: F) -> OdeFn<F>
where
    F: Fn(FixedPoint, &FixedVector) -> FixedVector,
{
    OdeFn { f }
}
/// One sample of an ODE trajectory produced by the integrators in this module.
#[derive(Clone, Debug)]
pub struct OdePoint {
    /// Time of the sample.
    pub t: FixedPoint,
    /// State vector at time `t`.
    pub x: FixedVector,
}
/// Advances `x` from `t` to `t + h` with one classic fourth-order Runge–Kutta step.
///
/// `sys` is evaluated four times: at `t`, twice at `t + h/2`, and at `t + h`. The result is
/// `x + (h/6) * (k1 + 2*k2 + 2*k3 + k4)`, where `h/6` is computed as `(h/2) / 3`. The weighted
/// stage sum for each component is formed with a single `compute_tier_dot_raw` call.
pub fn rk4_step<S: OdeSystem>(
    sys: &S,
    t: FixedPoint,
    x: &FixedVector,
    h: FixedPoint,
) -> FixedVector {
    let mid = h_half(h);
    let scale = mid / FixedPoint::from_int(3);
    let one = FixedPoint::one();
    let two = FixedPoint::from_int(2);

    let s1 = sys.eval(t, x);
    let s2 = sys.eval(t + mid, &(x + &(&s1 * mid)));
    let s3 = sys.eval(t + mid, &(x + &(&s2 * mid)));
    let s4 = sys.eval(t + h, &(x + &(&s3 * h)));

    // 1-2-2-1 stage weights, shared by every component.
    let weights: Vec<BinaryStorage> = vec![one.raw(), two.raw(), two.raw(), one.raw()];
    let dim = x.len();
    let mut out = FixedVector::new(dim);
    // One scratch buffer reused for every component's stage values.
    let mut stage_raw: Vec<BinaryStorage> = Vec::with_capacity(4);
    for i in 0..dim {
        stage_raw.clear();
        stage_raw.extend([s1[i].raw(), s2[i].raw(), s3[i].raw(), s4[i].raw()]);
        let combo = FixedPoint::from_raw(compute_tier_dot_raw(&weights, &stage_raw));
        out[i] = x[i] + scale * combo;
    }
    out
}
pub fn rk4_integrate<S: OdeSystem>(
sys: &S,
x0: &FixedVector,
t0: FixedPoint,
t_end: FixedPoint,
h: FixedPoint,
) -> Vec<OdePoint> {
let mut trajectory = Vec::new();
let mut t = t0;
let mut x = x0.clone();
trajectory.push(OdePoint { t, x: x.clone() });
while t < t_end {
let remaining = t_end - t;
let step = if remaining < h { remaining } else { h };
if step.is_zero() { break; }
x = rk4_step(sys, t, &x, step);
t = t + step;
trajectory.push(OdePoint { t, x: x.clone() });
}
trajectory
}
/// Step-size control settings for [`rk45_integrate`].
pub struct Rk45Config {
    /// Largest accepted infinity-norm difference between the 5th- and 4th-order solutions
    /// of a single step.
    pub tol: FixedPoint,
    /// Step size used for the first attempt.
    pub h_init: FixedPoint,
    /// Lower clamp applied when a rejected step is halved.
    pub h_min: FixedPoint,
    /// Upper clamp applied when a very accurate step is doubled.
    pub h_max: FixedPoint,
    /// Maximum number of step attempts, counting both accepted and rejected ones.
    pub max_steps: usize,
}
impl Rk45Config {
    /// Builds a configuration from a tolerance and an initial step.
    ///
    /// Defaults: `h_min` is one raw quantum of the fixed-point format, `h_max` is `16 * h_init`,
    /// and `max_steps` is 100 000.
    pub fn new(tol: FixedPoint, h_init: FixedPoint) -> Self {
        let smallest = FixedPoint::from_raw(quantum_raw());
        let growth_cap = FixedPoint::from_int(16);
        Self {
            tol,
            h_init,
            h_min: smallest,
            h_max: h_init * growth_cap,
            max_steps: 100_000,
        }
    }
}
/// Integrates `sys` from `t0` to `t_end` with adaptive Dormand–Prince 5(4) steps.
///
/// Each attempt uses `step = min(h, t_end - t)`. It is accepted when the infinity norm of
/// `x5 - x4` (fifth- minus fourth-order solution) is at most `config.tol`, and the fifth-order
/// solution becomes the new state. On rejection, the next `h` is half of the step that was
/// actually attempted, clamped below at `config.h_min`. After an accepted step whose error is
/// below `tol / 32`, `h` is doubled and clamped above at `config.h_max`.
///
/// Returns the accepted trajectory, which starts with `(t0, x0)`, and the number of rejected
/// attempts. The trajectory stops short of `t_end` in three cases:
/// - all `config.max_steps` attempts are used up;
/// - a rejected step cannot be made any smaller (it is already at `h_min`);
/// - the step is zero or negative.
///
/// # Errors
/// This implementation never returns `Err`; the `Result` is part of the public signature.
pub fn rk45_integrate<S: OdeSystem>(
    sys: &S,
    x0: &FixedVector,
    t0: FixedPoint,
    t_end: FixedPoint,
    config: &Rk45Config,
) -> Result<(Vec<OdePoint>, usize), OverflowDetected> {
    let mut trajectory = Vec::new();
    let mut t = t0;
    let mut x = x0.clone();
    let mut h = config.h_init;
    let mut rejected = 0usize;
    trajectory.push(OdePoint { t, x: x.clone() });
    let tol_loose = config.tol / FixedPoint::from_int(32);
    for _ in 0..config.max_steps {
        if t >= t_end {
            break;
        }
        let remaining = t_end - t;
        let step = if remaining < h { remaining } else { h };
        // A non-positive step can never advance towards `t_end`.
        if step <= FixedPoint::ZERO {
            break;
        }
        let (x4, x5) = dp45_pair(sys, t, &x, step);
        let err = inf_norm_diff(&x5, &x4);
        if err > config.tol {
            rejected += 1;
            // Shrink relative to the step actually attempted. When `step` was clipped to
            // `remaining`, halving `h` alone could leave the next attempt unchanged.
            h = h_half(step);
            if h < config.h_min {
                h = config.h_min;
            }
            // If the step cannot get smaller, every further attempt from this state would
            // repeat (or enlarge) the rejected one. Stop instead of burning `max_steps`.
            if h >= step {
                break;
            }
            continue;
        }
        x = x5;
        t = t + step;
        trajectory.push(OdePoint { t, x: x.clone() });
        if err < tol_loose {
            h = h + h;
            if h > config.h_max {
                h = config.h_max;
            }
        }
    }
    Ok((trajectory, rejected))
}
/// Computes one Dormand–Prince 5(4) step of size `h` from state `x` at time `t`.
///
/// Returns `(x4, x5)`: the embedded fourth-order solution and the fifth-order solution at
/// `t + h`. `rk45_integrate` keeps `x5` as the new state and uses `max_i |x5[i] - x4[i]|` as
/// its error estimate.
///
/// The embedded fourth-order weights b* include the FSAL stage `k7 = f(t + h, x5)` with weight
/// 1/40; its fifth-order weight is 0. Without that term the remaining b* weights sum to 39/40.
/// `x5 - x4` then degenerates to roughly `h * k / 40`, a first-order quantity that typically
/// over-states the local error and forces needlessly small steps. Including it costs a
/// seventh `sys.eval` per call.
fn dp45_pair<S: OdeSystem>(
    sys: &S,
    t: FixedPoint,
    x: &FixedVector,
    h: FixedPoint,
) -> (FixedVector, FixedVector) {
    // Nodes c_i of the Butcher tableau (c1 = 0, c6 = c7 = 1).
    let c2 = FixedPoint::one() / FixedPoint::from_int(5);
    let c3 = FixedPoint::from_int(3) / FixedPoint::from_int(10);
    let c4 = FixedPoint::from_int(4) / FixedPoint::from_int(5);
    let c5 = FixedPoint::from_int(8) / FixedPoint::from_int(9);

    // Stage 1.
    let k1 = sys.eval(t, x);

    // Stage 2 (a21 = c2).
    let x2 = x + &(&k1 * (h * c2));
    let k2 = sys.eval(t + c2 * h, &x2);

    // Stage 3.
    let a31 = FixedPoint::from_int(3) / FixedPoint::from_int(40);
    let a32 = FixedPoint::from_int(9) / FixedPoint::from_int(40);
    let x3 = x + &(&(&k1 * (h * a31)) + &(&k2 * (h * a32)));
    let k3 = sys.eval(t + c3 * h, &x3);

    // Stage 4.
    let a41 = FixedPoint::from_int(44) / FixedPoint::from_int(45);
    let a42 = FixedPoint::from_int(-56) / FixedPoint::from_int(15);
    let a43 = FixedPoint::from_int(32) / FixedPoint::from_int(9);
    let x4_tmp = x + &(&(&k1 * (h * a41)) + &(&(&k2 * (h * a42)) + &(&k3 * (h * a43))));
    let k4 = sys.eval(t + c4 * h, &x4_tmp);

    // Stage 5.
    let a51 = FixedPoint::from_int(19372) / FixedPoint::from_int(6561);
    let a52 = FixedPoint::from_int(-25360) / FixedPoint::from_int(2187);
    let a53 = FixedPoint::from_int(64448) / FixedPoint::from_int(6561);
    let a54 = FixedPoint::from_int(-212) / FixedPoint::from_int(729);
    let x5_tmp = x + &(&(&k1 * (h * a51))
        + &(&(&k2 * (h * a52)) + &(&(&k3 * (h * a53)) + &(&k4 * (h * a54)))));
    let k5 = sys.eval(t + c5 * h, &x5_tmp);

    // Stage 6 (c6 = 1).
    let a61 = FixedPoint::from_int(9017) / FixedPoint::from_int(3168);
    let a62 = FixedPoint::from_int(-355) / FixedPoint::from_int(33);
    let a63 = FixedPoint::from_int(46732) / FixedPoint::from_int(5247);
    let a64 = FixedPoint::from_int(49) / FixedPoint::from_int(176);
    let a65 = FixedPoint::from_int(-5103) / FixedPoint::from_int(18656);
    let x6_tmp = x + &(&(&k1 * (h * a61))
        + &(&(&k2 * (h * a62))
            + &(&(&k3 * (h * a63))
                + &(&(&k4 * (h * a64)) + &(&k5 * (h * a65))))));
    let k6 = sys.eval(t + h, &x6_tmp);

    // Fifth-order weights b (b2 = b7 = 0).
    let b1 = FixedPoint::from_int(35) / FixedPoint::from_int(384);
    let b3 = FixedPoint::from_int(500) / FixedPoint::from_int(1113);
    let b4 = FixedPoint::from_int(125) / FixedPoint::from_int(192);
    let b5 = FixedPoint::from_int(-2187) / FixedPoint::from_int(6784);
    let b6 = FixedPoint::from_int(11) / FixedPoint::from_int(84);
    // Embedded fourth-order weights b* (bs2 = 0; bs7 multiplies the FSAL stage k7).
    let bs1 = FixedPoint::from_int(5179) / FixedPoint::from_int(57600);
    let bs3 = FixedPoint::from_int(7571) / FixedPoint::from_int(16695);
    let bs4 = FixedPoint::from_int(393) / FixedPoint::from_int(640);
    let bs5 = FixedPoint::from_int(-92097) / FixedPoint::from_int(339200);
    let bs6 = FixedPoint::from_int(187) / FixedPoint::from_int(2100);
    let bs7 = FixedPoint::one() / FixedPoint::from_int(40);

    let n = x.len();

    // Fifth-order solution from [k1, k3, k4, k5, k6].
    let w5: Vec<BinaryStorage> = vec![b1.raw(), b3.raw(), b4.raw(), b5.raw(), b6.raw()];
    let mut x5_out = FixedVector::new(n);
    for i in 0..n {
        let k_vals: Vec<BinaryStorage> =
            vec![k1[i].raw(), k3[i].raw(), k4[i].raw(), k5[i].raw(), k6[i].raw()];
        let increment5 = h * FixedPoint::from_raw(compute_tier_dot_raw(&w5, &k_vals));
        x5_out[i] = x[i] + increment5;
    }

    // FSAL stage, evaluated at the fifth-order solution.
    let k7 = sys.eval(t + h, &x5_out);

    // Embedded fourth-order solution from [k1, k3, k4, k5, k6, k7].
    let w4: Vec<BinaryStorage> =
        vec![bs1.raw(), bs3.raw(), bs4.raw(), bs5.raw(), bs6.raw(), bs7.raw()];
    let mut x4_out = FixedVector::new(n);
    for i in 0..n {
        let k_vals: Vec<BinaryStorage> = vec![
            k1[i].raw(),
            k3[i].raw(),
            k4[i].raw(),
            k5[i].raw(),
            k6[i].raw(),
            k7[i].raw(),
        ];
        let increment4 = h * FixedPoint::from_raw(compute_tier_dot_raw(&w4, &k_vals));
        x4_out[i] = x[i] + increment4;
    }
    (x4_out, x5_out)
}
/// A Hamiltonian system given by its equations of motion, as consumed by [`verlet_step`].
pub trait HamiltonianSystem {
    /// Time derivative of the momenta, `dp/dt`, at `(q, p)`.
    fn force(&self, q: &FixedVector, p: &FixedVector) -> FixedVector;
    /// Time derivative of the positions, `dq/dt`, at `(q, p)`.
    fn velocity(&self, q: &FixedVector, p: &FixedVector) -> FixedVector;
    /// Energy at `(q, p)` (presumably the Hamiltonian `H(q, p)`), recorded at every point by
    /// [`verlet_integrate`].
    fn energy(&self, q: &FixedVector, p: &FixedVector) -> FixedPoint;
}
/// One sample of a Hamiltonian trajectory produced by [`verlet_integrate`].
#[derive(Clone, Debug)]
pub struct HamiltonianPoint {
    /// Time of the sample.
    pub t: FixedPoint,
    /// Positions at time `t`.
    pub q: FixedVector,
    /// Momenta at time `t`.
    pub p: FixedVector,
    /// Value of `HamiltonianSystem::energy(q, p)` at this sample.
    pub energy: FixedPoint,
}
/// Performs one kick–drift–kick (velocity Verlet / leapfrog) step of size `h`.
///
/// The sequence is:
/// 1. Half-kick `p` with the force at `(q, p)`.
/// 2. Drift `q` a full step using the velocity at `(q, p_half)`.
/// 3. Half-kick again with the force at `(q_new, p_half)`.
///
/// Returns `(q_new, p_new)`. `sys.force` is called twice and `sys.velocity` once.
pub fn verlet_step<H: HamiltonianSystem>(
    sys: &H,
    q: &FixedVector,
    p: &FixedVector,
    h: FixedPoint,
) -> (FixedVector, FixedVector) {
    let dt_half = h_half(h);
    // First half kick with the force at the starting point.
    let p_mid = p + &(&sys.force(q, p) * dt_half);
    // Full drift using the velocity at the half-kicked momentum.
    let q_next = q + &(&sys.velocity(q, &p_mid) * h);
    // Second half kick with the force at the new position.
    let p_next = &p_mid + &(&sys.force(&q_next, &p_mid) * dt_half);
    (q_next, p_next)
}
pub fn verlet_integrate<H: HamiltonianSystem>(
sys: &H,
q0: &FixedVector,
p0: &FixedVector,
t0: FixedPoint,
t_end: FixedPoint,
h: FixedPoint,
) -> Vec<HamiltonianPoint> {
let mut trajectory = Vec::new();
let mut t = t0;
let mut q = q0.clone();
let mut p = p0.clone();
trajectory.push(HamiltonianPoint {
t,
q: q.clone(),
p: p.clone(),
energy: sys.energy(&q, &p),
});
while t < t_end {
let remaining = t_end - t;
let step = if remaining < h { remaining } else { h };
if step.is_zero() { break; }
let (q_new, p_new) = verlet_step(sys, &q, &p, step);
q = q_new;
p = p_new;
t = t + step;
trajectory.push(HamiltonianPoint {
t,
q: q.clone(),
p: p.clone(),
energy: sys.energy(&q, &p),
});
}
trajectory
}
/// Measures how far a supposedly conserved quantity drifts along an ODE trajectory.
///
/// The drift of point *i* is `|invariant(x_i) - invariant(x_0)|`. Returns
/// `(max_drift, drifts)`, where `drifts` has one entry per trajectory point; the first entry
/// compares `x_0` with itself. An empty trajectory yields `(ZERO, [])` without calling
/// `invariant`.
pub fn monitor_invariant<F: Fn(&FixedVector) -> FixedPoint>(
    invariant: F,
    trajectory: &[OdePoint],
) -> (FixedPoint, Vec<FixedPoint>) {
    let origin = match trajectory.first() {
        Some(point) => point,
        None => return (FixedPoint::ZERO, Vec::new()),
    };
    let reference = invariant(&origin.x);
    let drifts: Vec<FixedPoint> = trajectory
        .iter()
        .map(|point| (invariant(&point.x) - reference).abs())
        .collect();
    let max_drift = drifts
        .iter()
        .fold(FixedPoint::ZERO, |acc, &d| if d > acc { d } else { acc });
    (max_drift, drifts)
}
/// Returns `h / 2`, computed by shifting the raw fixed-point representation right by one bit.
///
/// Exactly one `table_format` cfg is expected to be active. If none of the listed ones is
/// active, the body is empty and the function will not type-check. The shift operand type
/// differs per storage backend.
///
/// For the primitive backends (`q16_16`, `q32_32`, `q64_64`), `>>` on a signed integer is an
/// arithmetic shift, so an odd raw value is rounded toward negative infinity. For the wide
/// `I256`/`I512` backends the rounding depends on their `Shr` impl (TODO confirm).
#[inline]
fn h_half(h: FixedPoint) -> FixedPoint {
// q32_32 / q16_16: primitive signed storage, shift by an integer literal.
#[cfg(any(table_format = "q32_32", table_format = "q16_16"))]
{ FixedPoint::from_raw(h.raw() >> 1) }
// q64_64: i128 storage, shift amount given as u32.
#[cfg(table_format = "q64_64")]
{ FixedPoint::from_raw(h.raw() >> 1u32) }
// q128_128: I256 storage, shift amount given as u32.
#[cfg(table_format = "q128_128")]
{ FixedPoint::from_raw(h.raw() >> 1u32) }
// q256_256: I512 storage, shift amount given as usize.
#[cfg(table_format = "q256_256")]
{ FixedPoint::from_raw(h.raw() >> 1usize) }
}
/// Infinity norm of `a - b`: the largest absolute component-wise difference.
///
/// Returns `ZERO` for empty vectors.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
fn inf_norm_diff(a: &FixedVector, b: &FixedVector) -> FixedPoint {
    assert_eq!(a.len(), b.len());
    (0..a.len())
        .map(|idx| (a[idx] - b[idx]).abs())
        .fold(FixedPoint::ZERO, |best, d| if d > best { d } else { best })
}
// Raw value 1 for the active `table_format` backend, i.e. the smallest positive value
// representable by `FixedPoint::from_raw`. `Rk45Config::new` uses it as the default `h_min`.
// Exactly one of these cfg variants is expected to be active.
/// `q32_32`: `BinaryStorage` is `i64`.
#[cfg(table_format = "q32_32")]
fn quantum_raw() -> BinaryStorage { 1i64 }
/// `q16_16`: `BinaryStorage` is `i32`.
#[cfg(table_format = "q16_16")]
fn quantum_raw() -> BinaryStorage { 1i32 }
/// `q64_64`: `BinaryStorage` is `i128`.
#[cfg(table_format = "q64_64")]
fn quantum_raw() -> BinaryStorage { 1i128 }
/// `q128_128`: `BinaryStorage` is the project's `I256`.
#[cfg(table_format = "q128_128")]
fn quantum_raw() -> BinaryStorage {
use crate::fixed_point::I256;
I256::from_i128(1)
}
/// `q256_256`: `BinaryStorage` is the project's `I512`.
#[cfg(table_format = "q256_256")]
fn quantum_raw() -> BinaryStorage {
use crate::fixed_point::I512;
I512::from_i128(1)
}