use crate::common::IntegrateFloat;
use crate::error::IntegrateResult;
use crate::symplectic::{HamiltonianFn, SymplecticIntegrator};
use scirs2_core::ndarray::Array1;
use std::f64::consts::PI;
use std::marker::PhantomData;
/// Störmer-Verlet integrator for Hamiltonian systems.
///
/// The type holds no runtime state: its only field is a zero-sized
/// [`PhantomData`] that ties the integrator to the floating-point type `F`
/// it operates on.
#[derive(Debug, Clone)]
pub struct StormerVerlet<F>
where
    F: IntegrateFloat,
{
    /// Zero-sized marker recording the float type `F`; never read.
    _marker: PhantomData<F>,
}
impl<F> StormerVerlet<F>
where
    F: IntegrateFloat,
{
    /// Creates a new Störmer-Verlet integrator.
    ///
    /// The integrator carries no configuration, so construction only fills
    /// in the zero-sized type marker.
    pub fn new() -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}
impl<F: IntegrateFloat> Default for StormerVerlet<F> {
fn default() -> Self {
Self::new()
}
}
impl<F> SymplecticIntegrator<F> for StormerVerlet<F>
where
    F: IntegrateFloat,
{
    /// Advances `(q, p)` by a single step of size `dt`.
    ///
    /// The update is performed in three stages:
    ///
    /// 1. half-step momentum update using `dp_dt` evaluated at `(t, q, p)`;
    /// 2. full-step position update using `dq_dt` evaluated at the midpoint
    ///    time `t + dt/2` with the old position and the half-updated momentum;
    /// 3. half-step momentum update using `dp_dt` evaluated at `t + dt` with
    ///    the new position and the half-updated momentum.
    ///
    /// # Returns
    ///
    /// The pair `(q_new, p_new)` after the step.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `system.dp_dt` or `system.dq_dt`.
    fn step(
        &self,
        system: &dyn HamiltonianFn<F>,
        t: F,
        q: &Array1<F>,
        p: &Array1<F>,
        dt: F,
    ) -> IntegrateResult<(Array1<F>, Array1<F>)> {
        // The half step is needed by both momentum updates and by the
        // midpoint time, so compute it once.
        let two = F::one() + F::one();
        let half_dt = dt / two;

        // Stage 1: first half-step in momentum.
        let force_start = system.dp_dt(t, q, p)?;
        let p_mid = p + &(&force_start * half_dt);

        // Stage 2: full step in position, evaluated at the midpoint time.
        let velocity = system.dq_dt(t + half_dt, q, &p_mid)?;
        let q_next = q + &(&velocity * dt);

        // Stage 3: second half-step in momentum at the end of the interval.
        let force_end = system.dp_dt(t + dt, &q_next, &p_mid)?;
        let p_next = p_mid + &(&force_end * half_dt);

        Ok((q_next, p_next))
    }
}
/// Performs one velocity-Verlet step of size `dt` as a free function.
///
/// The update is: half-step in momentum from `dp_dt(t, q, p)`, full step in
/// position from `dq_dt(t + dt/2, q, p_half)`, then a second half-step in
/// momentum from `dp_dt(t + dt, q_new, p_half)`.
///
/// # Arguments
///
/// * `system` - Hamiltonian system supplying `dq_dt` and `dp_dt`.
/// * `t` - time at the start of the step.
/// * `q` - position at time `t`.
/// * `p` - momentum at time `t`.
/// * `dt` - step size.
///
/// # Returns
///
/// The pair `(q_new, p_new)` after the step.
///
/// # Errors
///
/// Propagates any error returned by `system.dp_dt` or `system.dq_dt`.
#[allow(dead_code)]
pub fn velocity_verlet<F: IntegrateFloat>(
    system: &dyn HamiltonianFn<F>,
    t: F,
    q: &Array1<F>,
    p: &Array1<F>,
    dt: F,
) -> IntegrateResult<(Array1<F>, Array1<F>)> {
    // Build the constant 2 as `1 + 1` (the same way `StormerVerlet::step`
    // does) instead of a fallible `F::from(2.0)` conversion: this removes
    // three panic paths from a function that otherwise reports failures
    // through `IntegrateResult`, and computes the half step only once.
    let half_dt = dt / (F::one() + F::one());

    // First half-step in momentum.
    let dp_old = system.dp_dt(t, q, p)?;
    let p_half = p + &(&dp_old * half_dt);

    // Full step in position, evaluated at the midpoint time.
    let dq = system.dq_dt(t + half_dt, q, &p_half)?;
    let q_new = q + &(&dq * dt);

    // Second half-step in momentum at the end of the interval; `p_half` is
    // no longer needed afterwards, so it is consumed by value here.
    let dp_new = system.dp_dt(t + dt, &q_new, &p_half)?;
    let p_new = p_half + &(&dp_new * half_dt);

    Ok((q_new, p_new))
}
/// Performs one integration step by delegating to [`StormerVerlet::step`].
///
/// Arguments, return value and errors are exactly those of
/// `StormerVerlet::step`: the state `(q, p)` at time `t` is advanced by `dt`
/// and the new `(q, p)` pair is returned.
///
/// NOTE(review): despite the name, this runs the same momentum-position-momentum
/// update as `StormerVerlet::step` rather than a position-first
/// (drift-kick-drift) scheme — confirm that the alias is intended.
#[allow(dead_code)]
pub fn position_verlet<F: IntegrateFloat>(
    system: &dyn HamiltonianFn<F>,
    t: F,
    q: &Array1<F>,
    p: &Array1<F>,
    dt: F,
) -> IntegrateResult<(Array1<F>, Array1<F>)> {
    let integrator = StormerVerlet::<F>::new();
    integrator.step(system, t, q, p, dt)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symplectic::potential::SeparableHamiltonian;
    use scirs2_core::ndarray::array;

    /// Integrates the unit harmonic oscillator `H = p²/2 + q²/2` for roughly
    /// one period and checks that the state returns close to where it started.
    #[test]
    fn test_verlet_harmonic() {
        let system = SeparableHamiltonian::new(
            |_t, p| -> f64 { 0.5 * p.dot(p) },
            |_t, q| -> f64 { 0.5 * q.dot(q) },
        );
        let q0 = array![1.0];
        let p0 = array![0.0];
        let t0 = 0.0;
        let dt = 0.1;
        let period = 2.0 * PI;
        let steps = (period / dt).round() as usize;

        let integrator = StormerVerlet::new();
        let mut q = q0.clone();
        let mut p = p0.clone();
        let mut t = t0;
        for _ in 0..steps {
            let (q_new, p_new) = integrator
                .step(&system, t, &q, &p, dt)
                .expect("Operation failed");
            q = q_new;
            p = p_new;
            t += dt;
        }

        assert!((q[0] - q0[0]).abs() < 0.1);
        assert!((p[0] - p0[0]).abs() < 0.1);
    }

    /// Runs `StormerVerlet::step` and `velocity_verlet` on the same harmonic
    /// oscillator and checks that both conserve energy and return the
    /// position close to its initial value after roughly one period.
    #[test]
    fn test_compare_velocity_verlet() {
        let system = SeparableHamiltonian::new(
            |_t, p| -> f64 { 0.5 * p.dot(p) },
            |_t, q| -> f64 { 0.5 * q.dot(q) },
        );
        let q0 = array![1.0];
        let p0 = array![0.0];
        let t0 = 0.0;
        let dt = 0.01;
        let initial_energy = 0.5 * p0.dot(&p0) + 0.5 * q0.dot(&q0);
        let period = 2.0 * PI;
        let steps = (period / dt).round() as usize;

        // Trajectory 1: trait-based integrator. The integrator is built once
        // and the time argument is advanced with every step, matching how
        // `test_verlet_harmonic` drives `step`.
        let integrator = StormerVerlet::new();
        let mut q1 = q0.clone();
        let mut p1 = p0.clone();
        let mut t1 = t0;
        for _ in 0..steps {
            let (q_new, p_new) = integrator
                .step(&system, t1, &q1, &p1, dt)
                .expect("Operation failed");
            q1 = q_new;
            p1 = p_new;
            t1 += dt;
        }
        let energy1 = 0.5 * p1.dot(&p1) + 0.5 * q1.dot(&q1);

        // Trajectory 2: free-function variant, driven the same way.
        let mut q2 = q0.clone();
        let mut p2 = p0.clone();
        let mut t2 = t0;
        for _ in 0..steps {
            let (q_new, p_new) =
                velocity_verlet(&system, t2, &q2, &p2, dt).expect("Operation failed");
            q2 = q_new;
            p2 = p_new;
            t2 += dt;
        }
        let energy2 = 0.5 * p2.dot(&p2) + 0.5 * q2.dot(&q2);

        assert!((energy1 - initial_energy).abs() < 1e-6);
        assert!((energy2 - initial_energy).abs() < 1e-6);
        assert!((q1[0] - q0[0]).abs() < 0.01);
        assert!((q2[0] - q0[0]).abs() < 0.01);
    }

    /// Integrates a planar `-1/r` potential from `q = (1, 0)`, `p = (0, 1)`
    /// and checks the reported relative energy error and that every stored
    /// position stays at radius ≈ 1.
    #[test]
    fn test_energy_conservation() {
        let kepler = SeparableHamiltonian::new(
            |_t, p| -> f64 { 0.5 * p.dot(p) },
            |_t, q| -> f64 {
                let r = (q[0] * q[0] + q[1] * q[1]).sqrt();
                // Guard against division by zero at the origin.
                if r < 1e-10 {
                    0.0
                } else {
                    -1.0 / r
                }
            },
        );
        let q0 = array![1.0, 0.0];
        let p0 = array![0.0, 1.0];
        let t0 = 0.0;
        let tf = 10.0;
        let dt = 0.01;

        let integrator = StormerVerlet::new();
        // The initial state is not needed afterwards, so it is moved in
        // rather than cloned.
        let result = integrator
            .integrate(&kepler, t0, tf, dt, q0, p0)
            .expect("Operation failed");

        // The energy check only runs when the integrator reports an error
        // estimate.
        if let Some(error) = result.energy_relative_error {
            assert!(error < 1e-3, "Energy error too large: {error}");
        }

        for q in result.q.iter() {
            let r = (q[0] * q[0] + q[1] * q[1]).sqrt();
            assert!((r - 1.0).abs() < 0.01, "Orbit not circular, r = {r}");
        }
    }
}