use crate::base::Potential2;
use crate::math::{Mask, Vector};
/// Energy-shifted wrapper around a pair potential `P` evaluated on values of type `T`.
///
/// For `r² < rc²` the energy is `inner(r²) - inner(rc²)`, so (for a continuous
/// inner potential) it tends to zero as `r` approaches the cutoff. At and beyond the
/// cutoff both the energy and the force factor are exactly zero.
///
/// Only the energy is shifted: inside the cutoff the force factor is the inner one
/// unchanged, so it still drops to zero at `rc` if the inner force there is non-zero.
#[derive(Clone, Copy, Debug)]
pub struct Shift<P, T> {
    /// The wrapped, unshifted pair potential.
    inner: P,
    /// Squared cutoff radius, stored in the evaluation type `T`.
    rc_sq: T,
    /// Energy of `inner` at the cutoff; subtracted from every energy inside it.
    e_rc: T,
}
impl<P: Potential2<T>, T: Vector> Shift<P, T> {
    /// Wraps `inner` so that its energy is shifted by `inner.energy(rc²)` inside the
    /// cutoff radius `rc` and is zero at or beyond it.
    ///
    /// The cutoff energy is evaluated once here and stored. Per-pair evaluations then
    /// only pay for a subtraction and a mask.
    ///
    /// # Panics
    ///
    /// Panics if `rc` is not strictly positive, which includes `NaN`.
    /// - A NaN cutoff makes `rc²` NaN, so (assuming `Vector::lt` follows IEEE
    ///   comparison semantics) the `r² < rc²` mask would never be true.
    /// - A zero cutoff likewise masks out every pair.
    ///
    /// In both cases the potential would silently evaluate to zero everywhere.
    /// A negative `rc` would otherwise be squared into a positive cutoff unnoticed.
    #[inline]
    pub fn new(inner: P, rc: f64) -> Self {
        // `rc > 0.0` is false for NaN, so this single check also rejects NaN.
        assert!(rc > 0.0, "Shift cutoff must be positive, got {}", rc);
        let rc_sq = T::splat(rc * rc);
        let e_rc = inner.energy(rc_sq);
        Self { inner, rc_sq, e_rc }
    }

    /// Returns a reference to the wrapped, unshifted potential.
    #[inline]
    pub fn inner(&self) -> &P {
        &self.inner
    }

    /// Returns the energy offset `inner.energy(rc²)` that is subtracted inside the cutoff.
    #[inline]
    pub fn shift(&self) -> T {
        self.e_rc
    }
}
impl<P, T> Potential2<T> for Shift<P, T>
where
    P: Potential2<T>,
    T: Vector,
{
    /// Shifted energy: `inner.energy(r²) - inner.energy(rc²)` where `r² < rc²`,
    /// and exactly zero at or beyond the cutoff.
    #[inline(always)]
    fn energy(&self, r_sq: T) -> T {
        let inside = r_sq.lt(self.rc_sq);
        let shifted = self.inner.energy(r_sq) - self.e_rc;
        inside.select(shifted, T::zero())
    }

    /// Force factor of the inner potential, unchanged by the energy shift,
    /// and zero at or beyond the cutoff.
    #[inline(always)]
    fn force_factor(&self, r_sq: T) -> T {
        let inside = r_sq.lt(self.rc_sq);
        inside.select(self.inner.force_factor(r_sq), T::zero())
    }

    /// Combined evaluation: returns `(energy, force_factor)` with the same
    /// shifting and cutoff rules as the two separate methods.
    #[inline(always)]
    fn energy_force(&self, r_sq: T) -> (T, T) {
        let inside = r_sq.lt(self.rc_sq);
        let (energy, factor) = self.inner.energy_force(r_sq);
        let zero = T::zero();
        (
            inside.select(energy - self.e_rc, zero),
            inside.select(factor, zero),
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pair::Lj;
    use approx::assert_relative_eq;

    /// Just inside the cutoff the shifted energy should be close to zero.
    #[test]
    fn test_shift_at_cutoff() {
        let lj: Lj<f64> = Lj::new(1.0, 3.4);
        let rc = 10.0;
        let lj_shift = Shift::new(lj, rc);
        let r = rc - 0.0001;
        let e = lj_shift.energy(r * r);
        assert!(e.abs() < 1e-5, "Energy near cutoff = {}, expected ~0", e);
    }

    /// At exactly `r == rc` the strict `r² < rc²` mask excludes the pair,
    /// so both energy and force factor are zero.
    #[test]
    fn test_shift_exactly_at_cutoff() {
        let lj: Lj<f64> = Lj::new(1.0, 3.4);
        let rc = 10.0;
        let lj_shift = Shift::new(lj, rc);
        assert_relative_eq!(lj_shift.energy(rc * rc), 0.0, epsilon = 1e-12);
        assert_relative_eq!(lj_shift.force_factor(rc * rc), 0.0, epsilon = 1e-12);
    }

    /// Beyond the cutoff the shifted energy is zero.
    #[test]
    fn test_shift_zero_outside() {
        let lj: Lj<f64> = Lj::new(1.0, 3.4);
        let lj_shift = Shift::new(lj, 10.0);
        let r = 11.0;
        let e = lj_shift.energy(r * r);
        assert_relative_eq!(e, 0.0, epsilon = 1e-10);
    }

    /// Beyond the cutoff the force factor is zero as well.
    #[test]
    fn test_shift_force_zero_outside() {
        let lj: Lj<f64> = Lj::new(1.0, 3.4);
        let lj_shift = Shift::new(lj, 10.0);
        let r = 11.0;
        assert_relative_eq!(lj_shift.force_factor(r * r), 0.0, epsilon = 1e-10);
    }

    /// Inside the cutoff the shifted energy equals `E(r) - E(rc)`.
    #[test]
    fn test_shift_vs_unshifted() {
        let lj: Lj<f64> = Lj::new(1.0, 3.4);
        let rc = 10.0;
        let lj_shift = Shift::new(lj, rc);
        let r = 5.0;
        let e_base = lj.energy(r * r);
        let e_rc = lj.energy(rc * rc);
        let e_shift = lj_shift.energy(r * r);
        assert_relative_eq!(e_shift, e_base - e_rc, epsilon = 1e-10);
    }

    /// Shifting the energy must not change the force inside the cutoff.
    #[test]
    fn test_shift_force_unchanged() {
        let lj: Lj<f64> = Lj::new(1.0, 3.4);
        let lj_shift = Shift::new(lj, 10.0);
        let r = 5.0;
        let s_base = lj.force_factor(r * r);
        let s_shift = lj_shift.force_factor(r * r);
        assert_relative_eq!(s_base, s_shift, epsilon = 1e-10);
    }

    /// `shift()` exposes the inner energy at the cutoff, and `inner()` returns
    /// a potential that behaves like the one passed to `new`.
    #[test]
    fn test_shift_accessors() {
        let lj: Lj<f64> = Lj::new(1.0, 3.4);
        let rc = 10.0;
        let lj_shift = Shift::new(lj, rc);
        assert_relative_eq!(lj_shift.shift(), lj.energy(rc * rc), epsilon = 1e-12);
        let r = 5.0;
        assert_relative_eq!(
            lj_shift.inner().energy(r * r),
            lj.energy(r * r),
            epsilon = 1e-12
        );
    }

    /// The fused `energy_force` path must agree with the separate `energy`
    /// and `force_factor` calls, inside, at, and beyond the cutoff.
    #[test]
    fn test_shift_energy_force_matches_separate_calls() {
        let lj: Lj<f64> = Lj::new(1.0, 3.4);
        let lj_shift = Shift::new(lj, 10.0);
        for &r in &[5.0_f64, 9.9, 10.0, 11.0] {
            let (e, s) = lj_shift.energy_force(r * r);
            assert_relative_eq!(e, lj_shift.energy(r * r), epsilon = 1e-10);
            assert_relative_eq!(s, lj_shift.force_factor(r * r), epsilon = 1e-10);
        }
    }
}