use crate::common::IntegrateFloat;
use crate::error::IntegrateResult;
use crate::symplectic::{HamiltonianFn, SymplecticIntegrator};
use scirs2_core::ndarray::{Array1, Array2};
use std::f64::consts::PI;
use std::marker::PhantomData;
/// Two-stage Gauss–Legendre implicit Runge–Kutta integrator (classical order 4).
///
/// The integrator holds no runtime state; the type parameter `F` only selects the
/// floating-point type it operates on. Stepping is provided through the
/// [`SymplecticIntegrator`] trait.
#[derive(Debug, Clone)]
pub struct GaussLegendre4<F: IntegrateFloat> {
    _marker: PhantomData<F>,
}

impl<F: IntegrateFloat> GaussLegendre4<F> {
    /// Creates a fourth-order Gauss–Legendre integrator (same as `Default::default`).
    pub fn new() -> Self {
        Self::default()
    }
}

impl<F: IntegrateFloat> Default for GaussLegendre4<F> {
    /// Builds the (stateless) integrator.
    fn default() -> Self {
        GaussLegendre4 {
            _marker: PhantomData,
        }
    }
}
impl<F: IntegrateFloat> SymplecticIntegrator<F> for GaussLegendre4<F> {
    /// Advances `(q, p)` from time `t` to `t + dt` with the two-stage Gauss–Legendre
    /// implicit Runge–Kutta method.
    ///
    /// Butcher tableau, with `s = sqrt(3) / 6`:
    ///
    /// ```text
    /// 1/2 - s | 1/4       1/4 - s
    /// 1/2 + s | 1/4 + s   1/4
    /// --------+--------------------
    ///         | 1/2       1/2
    /// ```
    ///
    /// The implicit stage equations are solved by fixed-point iteration, starting from
    /// zero stage derivatives. Each sweep builds both stage states from the previous
    /// iterate, then evaluates `dq_dt` and `dp_dt` at each stage. Iteration stops once
    /// no stage derivative changes by `1e-12` (absolute) or more between sweeps, or
    /// after 10 sweeps. In the latter case the last iterate is used and
    /// non-convergence is not reported.
    ///
    /// A NaN produced by `system` is never mistaken for convergence. It propagates into
    /// the returned state rather than being silently dropped.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `system.dq_dt` or `system.dp_dt`.
    fn step(
        &self,
        system: &dyn HamiltonianFn<F>,
        t: F,
        q: &Array1<F>,
        p: &Array1<F>,
        dt: F,
    ) -> IntegrateResult<(Array1<F>, Array1<F>)> {
        // Upper bound on fixed-point sweeps over the implicit stage equations.
        const MAX_ITERATIONS: usize = 10;

        // NaN-propagating maximum. `Float::max` returns the non-NaN operand, which
        // would let a NaN evaluation look like a zero change, i.e. convergence.
        fn nan_max<T: IntegrateFloat>(acc: T, value: T) -> T {
            if value.is_nan() || value > acc {
                value
            } else {
                acc
            }
        }

        // Largest element-wise |a - b| (NaN-propagating), without allocating `a - b`.
        fn max_abs_diff<T: IntegrateFloat>(a: &Array1<T>, b: &Array1<T>) -> T {
            a.iter()
                .zip(b.iter())
                .map(|(&x, &y)| (x - y).abs())
                .fold(T::zero(), nan_max::<T>)
        }

        let tolerance = F::from_f64(1e-12).expect("Operation failed");

        let two = F::one() + F::one();
        let half = F::one() / two;
        let quarter = half / two;
        let sqrt_3 = F::from_f64(3.0_f64.sqrt()).expect("Operation failed");
        let sixth = F::one() / (F::from_f64(6.0).expect("Operation failed"));
        let s = sqrt_3 * sixth;

        let c = [half - s, half + s];
        // Rows of the Runge–Kutta matrix A, pre-multiplied by dt (loop-invariant).
        let h = [
            [quarter * dt, (quarter - s) * dt],
            [(quarter + s) * dt, quarter * dt],
        ];
        let times = [t + c[0] * dt, t + c[1] * dt];

        // Stage state (q, p) + dt * sum_j a_ij * k_j for one row of A.
        let stage = |row: &[F; 2], kq: &[Array1<F>; 2], kp: &[Array1<F>; 2]| {
            let qs = q + &(&kq[0] * row[0] + &kq[1] * row[1]);
            let ps = p + &(&kp[0] * row[0] + &kp[1] * row[1]);
            (qs, ps)
        };

        let zero = Array1::<F>::zeros(q.len());
        let mut k_q = [zero.clone(), zero.clone()];
        let mut k_p = [zero.clone(), zero];

        for _ in 0..MAX_ITERATIONS {
            // Both stages are built from the previous iterate (Jacobi-style sweep).
            let (q1, p1) = stage(&h[0], &k_q, &k_p);
            let (q2, p2) = stage(&h[1], &k_q, &k_p);

            let dq1 = system.dq_dt(times[0], &q1, &p1)?;
            let dp1 = system.dp_dt(times[0], &q1, &p1)?;
            let dq2 = system.dq_dt(times[1], &q2, &p2)?;
            let dp2 = system.dp_dt(times[1], &q2, &p2)?;

            let next_q = [dq1, dq2];
            let next_p = [dp1, dp2];
            let change = next_q
                .iter()
                .zip(k_q.iter())
                .chain(next_p.iter().zip(k_p.iter()))
                .map(|(fresh, prev)| max_abs_diff(fresh, prev))
                .fold(F::zero(), nan_max::<F>);

            // Keep the freshest evaluation even on convergence: it is one fixed-point
            // sweep closer to the true stage derivatives than the previous iterate.
            k_q = next_q;
            k_p = next_p;

            // NaN compares false, so a NaN change never counts as converged.
            if change < tolerance {
                break;
            }
        }

        let weight = half * dt;
        let q_new = q + &(&k_q[0] * weight + &k_q[1] * weight);
        let p_new = p + &(&k_p[0] * weight + &k_p[1] * weight);
        Ok((q_new, p_new))
    }
}
/// Three-stage Gauss–Legendre implicit Runge–Kutta integrator (classical order 6).
///
/// The integrator holds no runtime state; the type parameter `F` only selects the
/// floating-point type it operates on. Stepping is provided through the
/// [`SymplecticIntegrator`] trait.
#[derive(Debug, Clone)]
pub struct GaussLegendre6<F: IntegrateFloat> {
    _marker: PhantomData<F>,
}

impl<F: IntegrateFloat> GaussLegendre6<F> {
    /// Creates a sixth-order Gauss–Legendre integrator (same as `Default::default`).
    pub fn new() -> Self {
        Self::default()
    }
}

impl<F: IntegrateFloat> Default for GaussLegendre6<F> {
    /// Builds the (stateless) integrator.
    fn default() -> Self {
        GaussLegendre6 {
            _marker: PhantomData,
        }
    }
}
impl<F: IntegrateFloat> SymplecticIntegrator<F> for GaussLegendre6<F> {
    /// Advances `(q, p)` from time `t` to `t + dt` with the three-stage Gauss–Legendre
    /// implicit Runge–Kutta method.
    ///
    /// Butcher tableau, with `r = sqrt(15)`:
    ///
    /// ```text
    /// 1/2 - r/10 | 5/36          2/9 - r/15   5/36 - r/30
    /// 1/2        | 5/36 + r/24   2/9          5/36 - r/24
    /// 1/2 + r/10 | 5/36 + r/30   2/9 + r/15   5/36
    /// -----------+-----------------------------------------
    ///            | 5/18          4/9          5/18
    /// ```
    ///
    /// The implicit stage equations are solved by fixed-point iteration, starting from
    /// zero stage derivatives. Each sweep builds all three stage states from the
    /// previous iterate, then evaluates `dq_dt` and `dp_dt` at each stage. Iteration
    /// stops once no stage derivative changes by `1e-12` (absolute) or more between
    /// sweeps, or after 15 sweeps. In the latter case the last iterate is used and
    /// non-convergence is not reported.
    ///
    /// A NaN produced by `system` is never mistaken for convergence. It propagates into
    /// the returned state rather than being silently dropped.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `system.dq_dt` or `system.dp_dt`.
    fn step(
        &self,
        system: &dyn HamiltonianFn<F>,
        t: F,
        q: &Array1<F>,
        p: &Array1<F>,
        dt: F,
    ) -> IntegrateResult<(Array1<F>, Array1<F>)> {
        // Upper bound on fixed-point sweeps over the implicit stage equations.
        const MAX_ITERATIONS: usize = 15;

        // NaN-propagating maximum. `Float::max` returns the non-NaN operand, which
        // would let a NaN evaluation look like a zero change, i.e. convergence.
        fn nan_max<T: IntegrateFloat>(acc: T, value: T) -> T {
            if value.is_nan() || value > acc {
                value
            } else {
                acc
            }
        }

        // Largest element-wise |a - b| (NaN-propagating), without allocating `a - b`.
        fn max_abs_diff<T: IntegrateFloat>(a: &Array1<T>, b: &Array1<T>) -> T {
            a.iter()
                .zip(b.iter())
                .map(|(&x, &y)| (x - y).abs())
                .fold(T::zero(), nan_max::<T>)
        }

        let tolerance = F::from_f64(1e-12).expect("Operation failed");
        let coef = |x: f64| F::from_f64(x).expect("Operation failed");
        let sqrt_15 = 15.0_f64.sqrt();

        let c = [
            coef(0.5 - 0.1 * sqrt_15),
            coef(0.5),
            coef(0.5 + 0.1 * sqrt_15),
        ];
        let a = Array2::<F>::from_shape_vec(
            (3, 3),
            vec![
                coef(5.0 / 36.0),
                coef(2.0 / 9.0 - 1.0 / 15.0 * sqrt_15),
                coef(5.0 / 36.0 - 1.0 / 30.0 * sqrt_15),
                coef(5.0 / 36.0 + 1.0 / 24.0 * sqrt_15),
                coef(2.0 / 9.0),
                coef(5.0 / 36.0 - 1.0 / 24.0 * sqrt_15),
                coef(5.0 / 36.0 + 1.0 / 30.0 * sqrt_15),
                coef(2.0 / 9.0 + 1.0 / 15.0 * sqrt_15),
                coef(5.0 / 36.0),
            ],
        )
        .expect("Failed to create integrator");
        let b = [coef(5.0 / 18.0), coef(4.0 / 9.0), coef(5.0 / 18.0)];

        // Tableau entries pre-multiplied by dt (loop-invariant).
        let h = a.mapv(|entry| entry * dt);
        let weights = [b[0] * dt, b[1] * dt, b[2] * dt];
        let times = [t + c[0] * dt, t + c[1] * dt, t + c[2] * dt];

        // Stage state (q, p) + dt * sum_j a_ij * k_j for row `i` of A.
        let stage = |i: usize, kq: &[Array1<F>; 3], kp: &[Array1<F>; 3]| {
            let qs = q + &(&kq[0] * h[[i, 0]] + &kq[1] * h[[i, 1]] + &kq[2] * h[[i, 2]]);
            let ps = p + &(&kp[0] * h[[i, 0]] + &kp[1] * h[[i, 1]] + &kp[2] * h[[i, 2]]);
            (qs, ps)
        };

        let zero = Array1::<F>::zeros(q.len());
        let mut k_q = [zero.clone(), zero.clone(), zero.clone()];
        let mut k_p = [zero.clone(), zero.clone(), zero];

        for _ in 0..MAX_ITERATIONS {
            // All stages are built from the previous iterate (Jacobi-style sweep).
            let (q1, p1) = stage(0, &k_q, &k_p);
            let (q2, p2) = stage(1, &k_q, &k_p);
            let (q3, p3) = stage(2, &k_q, &k_p);

            let dq1 = system.dq_dt(times[0], &q1, &p1)?;
            let dp1 = system.dp_dt(times[0], &q1, &p1)?;
            let dq2 = system.dq_dt(times[1], &q2, &p2)?;
            let dp2 = system.dp_dt(times[1], &q2, &p2)?;
            let dq3 = system.dq_dt(times[2], &q3, &p3)?;
            let dp3 = system.dp_dt(times[2], &q3, &p3)?;

            let next_q = [dq1, dq2, dq3];
            let next_p = [dp1, dp2, dp3];
            let change = next_q
                .iter()
                .zip(k_q.iter())
                .chain(next_p.iter().zip(k_p.iter()))
                .map(|(fresh, prev)| max_abs_diff(fresh, prev))
                .fold(F::zero(), nan_max::<F>);

            // Keep the freshest evaluation even on convergence: it is one fixed-point
            // sweep closer to the true stage derivatives than the previous iterate.
            k_q = next_q;
            k_p = next_p;

            // NaN compares false, so a NaN change never counts as converged.
            if change < tolerance {
                break;
            }
        }

        let q_new = q + &(&k_q[0] * weights[0] + &k_q[1] * weights[1] + &k_q[2] * weights[2]);
        let p_new = p + &(&k_p[0] * weights[0] + &k_p[1] * weights[1] + &k_p[2] * weights[2]);
        Ok((q_new, p_new))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symplectic::leapfrog::StormerVerlet;
    use crate::symplectic::potential::SeparableHamiltonian;
    use scirs2_core::ndarray::array;

    /// Over one period of the unit harmonic oscillator, the higher-order
    /// Gauss–Legendre schemes must land closer to the exact final state (1, 0)
    /// than Störmer–Verlet, and GL6 closer than GL4.
    #[test]
    fn test_accuracy_comparison() {
        // H(q, p) = p·p / 2 + q·q / 2
        let system = SeparableHamiltonian::new(
            |_t, p| -> f64 { 0.5 * p.dot(p) },
            |_t, q| -> f64 { 0.5 * q.dot(q) },
        );
        let (q0, p0) = (array![1.0], array![0.0]);
        let (t0, tf, dt) = (0.0, 2.0 * PI, 0.1);

        // Euclidean phase-space distance of a final state from the exact point (1, 0).
        let distance =
            |q_end: f64, p_end: f64| -> f64 { ((q_end - 1.0).powi(2) + p_end.powi(2)).sqrt() };

        let verlet = StormerVerlet::new();
        let gl4 = GaussLegendre4::new();
        let gl6 = GaussLegendre6::new();

        let verlet_result = verlet
            .integrate(&system, t0, tf, dt, q0.clone(), p0.clone())
            .expect("Test: integration failed");
        let gl4_result = gl4
            .integrate(&system, t0, tf, dt, q0.clone(), p0.clone())
            .expect("Test: integration failed");
        let gl6_result = gl6
            .integrate(&system, t0, tf, dt, q0.clone(), p0.clone())
            .expect("Test: integration failed");

        let verlet_error = distance(
            verlet_result.q.last().expect("Operation failed")[0],
            verlet_result.p.last().expect("Operation failed")[0],
        );
        let gl4_error = distance(
            gl4_result.q.last().expect("Operation failed")[0],
            gl4_result.p.last().expect("Operation failed")[0],
        );
        let gl6_error = distance(
            gl6_result.q.last().expect("Operation failed")[0],
            gl6_result.p.last().expect("Operation failed")[0],
        );

        assert!(
            gl4_error < verlet_error,
            "GL4 error ({gl4_error}) should be smaller than Verlet error ({verlet_error})"
        );
        assert!(
            gl6_error < gl4_error,
            "GL6 error ({gl6_error}) should be smaller than GL4 error ({gl4_error})"
        );
    }

    /// On a nonlinear pendulum, GL4's reported relative energy error must be below
    /// Störmer–Verlet's. The comparison is skipped if either result reports none.
    #[test]
    fn test_energy_preservation() {
        // H(q, p) = p·p / 2 - cos(q[0])
        let pendulum = SeparableHamiltonian::new(
            |_t, p| -> f64 { 0.5 * p.dot(p) },
            |_t, q| -> f64 { -q[0].cos() },
        );
        let (q0, p0) = (array![0.0], array![1.5]);
        let (t0, tf, dt) = (0.0, 50.0, 0.1);

        let verlet = StormerVerlet::new();
        let gl4 = GaussLegendre4::new();

        let verlet_result = verlet
            .integrate(&pendulum, t0, tf, dt, q0.clone(), p0.clone())
            .expect("Test: integration failed");
        let gl4_result = gl4
            .integrate(&pendulum, t0, tf, dt, q0.clone(), p0.clone())
            .expect("Test: integration failed");

        if let Some((verlet_error, gl4_error)) = verlet_result
            .energy_relative_error
            .zip(gl4_result.energy_relative_error)
        {
            assert!(
                gl4_error < verlet_error,
                "GL4 energy error ({gl4_error}) should be smaller than Verlet error ({verlet_error})"
            );
        }
    }
}