use crate::base::Potential2;
use crate::math::{Mask, Vector};
/// A pair potential whose energy is smoothly switched off between a
/// switching distance `rs` and a cutoff `rc`.
///
/// For `r < rs` the wrapped potential is used unchanged, for `rs <= r < rc`
/// it is multiplied by a polynomial switching function of `r^2`, and for
/// `r >= rc` energy and force are zero. All radii are kept squared so that
/// evaluation works directly on `r^2`.
#[derive(Clone, Copy, Debug)]
pub struct Switch<P, T> {
    /// The wrapped, unswitched pair potential.
    inner: P,
    /// Squared switching distance, `rs^2`.
    rs_sq: T,
    /// Squared cutoff distance, `rc^2`.
    rc_sq: T,
    /// Precomputed normalisation `1 / (rc^2 - rs^2)^3` of the switching polynomial.
    inv_denom: T,
    /// Precomputed `3 * rs^2`, a constant term of the switching polynomial.
    three_rs_sq: T,
}
impl<P, T: Vector> Switch<P, T> {
    /// Wraps `inner` so that its energy is smoothly switched to zero between
    /// the switching distance `rs` and the cutoff `rc`.
    ///
    /// The squared radii and the polynomial constants are precomputed once and
    /// broadcast with `T::splat`, so per-call evaluation needs no square root
    /// and no division.
    ///
    /// # Panics
    ///
    /// Panics if `rs` is negative, if `rs >= rc`, or if either argument is NaN.
    /// These checks run in every build profile. With `rs >= rc` every distance
    /// below `rs` would take the "inside `rs`" branch and return the
    /// unswitched potential even beyond the cutoff, instead of zero.
    #[inline]
    pub fn new(inner: P, rs: f64, rc: f64) -> Self {
        // `rs >= 0.0` also rejects NaN; a negative `rs` could otherwise give
        // `rs^2 > rc^2` even though `rs < rc`.
        assert!(rs >= 0.0, "Switch distance must be non-negative");
        assert!(rs < rc, "Switch distance must be less than cutoff");
        let rs_sq = rs * rs;
        let rc_sq = rc * rc;
        let width = rc_sq - rs_sq;
        Self {
            inner,
            rs_sq: T::splat(rs_sq),
            rc_sq: T::splat(rc_sq),
            inv_denom: T::splat(1.0 / (width * width * width)),
            three_rs_sq: T::splat(3.0 * rs_sq),
        }
    }

    /// Switching polynomial evaluated at `x = r^2`:
    ///
    /// `S(x) = (rc^2 - x)^2 * (rc^2 + 2x - 3 rs^2) / (rc^2 - rs^2)^3`
    ///
    /// `S(rs^2) = 1` and `S(rc^2) = 0`. The value is only meaningful for
    /// `rs^2 <= x <= rc^2`; the `Potential2` methods discard other lanes.
    #[inline(always)]
    fn switch_value(&self, r_sq: T) -> T {
        let two = T::splat(2.0);
        let rc_minus_r = self.rc_sq - r_sq;
        let term1 = rc_minus_r * rc_minus_r;
        let term2 = self.rc_sq + two * r_sq - self.three_rs_sq;
        term1 * term2 * self.inv_denom
    }

    /// Derivative of the switching polynomial with respect to `x = r^2`:
    ///
    /// `dS/dx = 6 (rc^2 - x)(rs^2 - x) / (rc^2 - rs^2)^3`
    ///
    /// It vanishes at both `x = rs^2` and `x = rc^2`. Because of this, the
    /// switched force equals the inner force at `rs` and is zero at `rc`.
    #[inline(always)]
    fn switch_derivative(&self, r_sq: T) -> T {
        let six = T::splat(6.0);
        let term = six * (self.rc_sq - r_sq) * (self.rs_sq - r_sq);
        term * self.inv_denom
    }
}
impl<P: Potential2<T>, T: Vector> Potential2<T> for Switch<P, T> {
    /// Switched energy, with `V` the inner energy and `S` the switching polynomial:
    /// - `V(r)` for `r < rs`
    /// - `V(r) * S(r^2)` for `rs <= r < rc`
    /// - zero for `r >= rc`
    #[inline(always)]
    fn energy(&self, r_sq: T) -> T {
        let e = self.inner.energy(r_sq);
        let inside_rs = r_sq.lt(self.rs_sq);
        let inside_rc = r_sq.lt(self.rc_sq);
        let in_switch = inside_rc & !inside_rs;
        let s = self.switch_value(r_sq);
        inside_rs.select(e, in_switch.select(e * s, T::zero()))
    }

    /// Switched force factor; see [`Potential2::energy_force`] for the formula.
    #[inline(always)]
    fn force_factor(&self, r_sq: T) -> T {
        // The switched force needs the inner energy as well as the inner force
        // (product rule). Sharing one `inner.energy_force` call avoids
        // evaluating `inner.energy` and `inner.force_factor` separately.
        // After inlining, the unused energy component is dead code.
        self.energy_force(r_sq).1
    }

    /// Switched energy and force factor from a single inner evaluation.
    ///
    /// Assumes the inner force factor follows `f = -(1/r) dV/dr = -2 dV/d(r^2)`.
    /// Under that convention the product rule on `V * S` gives, inside the
    /// switching window, `f * S - 2 * V * dS/d(r^2)`.
    #[inline(always)]
    fn energy_force(&self, r_sq: T) -> (T, T) {
        let two = T::splat(2.0);
        let (e, f) = self.inner.energy_force(r_sq);
        let inside_rs = r_sq.lt(self.rs_sq);
        let inside_rc = r_sq.lt(self.rc_sq);
        let in_switch = inside_rc & !inside_rs;
        let s_func = self.switch_value(r_sq);
        let ds_dr2 = self.switch_derivative(r_sq);
        let e_switched = e * s_func;
        let f_switched = f * s_func - two * e * ds_dr2;
        (
            inside_rs.select(e, in_switch.select(e_switched, T::zero())),
            inside_rs.select(f, in_switch.select(f_switched, T::zero())),
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pair::Lj;
    use approx::assert_relative_eq;

    #[test]
    fn test_switch_inside_rs() {
        let lj: Lj<f64> = Lj::new(1.0, 3.4);
        let lj_sw = Switch::new(lj, 8.0, 10.0);
        let r_sq = 25.0;
        let e_base = lj.energy(r_sq);
        let e_sw = lj_sw.energy(r_sq);
        assert_relative_eq!(e_base, e_sw, epsilon = 1e-10);
    }

    #[test]
    fn test_switch_at_rs() {
        let lj: Lj<f64> = Lj::new(1.0, 3.4);
        let rs = 8.0;
        let lj_sw = Switch::new(lj, rs, 10.0);
        let e_base = lj.energy(rs * rs);
        let e_sw = lj_sw.energy(rs * rs);
        assert_relative_eq!(e_base, e_sw, epsilon = 1e-10);
    }

    #[test]
    fn test_switch_at_rc() {
        let lj: Lj<f64> = Lj::new(1.0, 3.4);
        let rc = 10.0;
        let lj_sw = Switch::new(lj, 8.0, rc);
        let r = rc - 0.0001;
        let e = lj_sw.energy(r * r);
        assert!(e.abs() < 1e-5);
    }

    #[test]
    fn test_switch_outside_rc() {
        let lj: Lj<f64> = Lj::new(1.0, 3.4);
        let lj_sw = Switch::new(lj, 8.0, 10.0);
        let r_sq = 121.0;
        let e = lj_sw.energy(r_sq);
        let f = lj_sw.force_factor(r_sq);
        assert_relative_eq!(e, 0.0, epsilon = 1e-10);
        assert_relative_eq!(f, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_switch_numerical_derivative() {
        let lj: Lj<f64> = Lj::new(1.0, 3.4);
        let lj_sw = Switch::new(lj, 8.0, 10.0);
        let r = 9.0;
        let h = 1e-7;
        let e_plus = lj_sw.energy((r + h) * (r + h));
        let e_minus = lj_sw.energy((r - h) * (r - h));
        let dv_dr_numerical = (e_plus - e_minus) / (2.0 * h);
        let f = lj_sw.force_factor(r * r);
        let dv_dr_analytical = -f * r;
        assert_relative_eq!(dv_dr_analytical, dv_dr_numerical, epsilon = 1e-6);
    }

    /// `S` is 1 at `rs`, 0 at `rc` and exactly 1/2 halfway between them in
    /// `r^2` (82 for rs = 8, rc = 10); its derivative vanishes at both edges.
    #[test]
    fn test_switch_polynomial_endpoints() {
        let lj: Lj<f64> = Lj::new(1.0, 3.4);
        let sw: Switch<Lj<f64>, f64> = Switch::new(lj, 8.0, 10.0);
        assert_relative_eq!(sw.switch_value(64.0), 1.0, epsilon = 1e-12);
        assert_relative_eq!(sw.switch_value(100.0), 0.0, epsilon = 1e-12);
        assert_relative_eq!(sw.switch_value(82.0), 0.5, epsilon = 1e-12);
        assert_relative_eq!(sw.switch_derivative(64.0), 0.0, epsilon = 1e-12);
        assert_relative_eq!(sw.switch_derivative(100.0), 0.0, epsilon = 1e-12);
    }

    /// `switch_derivative` is dS/d(r^2); check it against a central difference.
    #[test]
    fn test_switch_derivative_matches_finite_difference() {
        let lj: Lj<f64> = Lj::new(1.0, 3.4);
        let sw: Switch<Lj<f64>, f64> = Switch::new(lj, 8.0, 10.0);
        let h = 1e-5;
        for &x in &[66.0, 75.0, 82.0, 90.0, 98.0] {
            let numerical = (sw.switch_value(x + h) - sw.switch_value(x - h)) / (2.0 * h);
            assert_relative_eq!(sw.switch_derivative(x), numerical, epsilon = 1e-8);
        }
    }

    /// The switched force must match the unswitched one at `rs` and stay
    /// continuous just inside the switching window.
    #[test]
    fn test_switch_force_continuous_at_rs() {
        let lj: Lj<f64> = Lj::new(1.0, 3.4);
        let rs = 8.0;
        let lj_sw = Switch::new(lj, rs, 10.0);
        assert_relative_eq!(lj_sw.force_factor(rs * rs), lj.force_factor(rs * rs), epsilon = 1e-12);
        let r = rs + 1e-6;
        assert_relative_eq!(lj_sw.force_factor(r * r), lj.force_factor(r * r), epsilon = 1e-7);
    }

    /// The switched force must go to zero as `r` approaches `rc` from below.
    #[test]
    fn test_switch_force_vanishes_at_rc() {
        let lj: Lj<f64> = Lj::new(1.0, 3.4);
        let rc = 10.0;
        let lj_sw = Switch::new(lj, 8.0, rc);
        let r = rc - 0.0001;
        assert!(lj_sw.force_factor(r * r).abs() < 1e-5);
    }

    /// `energy_force` must agree with the separate `energy` / `force_factor`
    /// calls in every region: inside `rs`, at `rs`, in the window, at `rc`, beyond.
    #[test]
    fn test_switch_energy_force_matches_separate_calls() {
        let lj: Lj<f64> = Lj::new(1.0, 3.4);
        let lj_sw = Switch::new(lj, 8.0, 10.0);
        for &r in &[5.0, 8.0, 8.5, 9.0, 9.9, 10.0, 11.0] {
            let r_sq = r * r;
            let (e, f) = lj_sw.energy_force(r_sq);
            assert_relative_eq!(e, lj_sw.energy(r_sq), epsilon = 1e-12, max_relative = 1e-12);
            assert_relative_eq!(f, lj_sw.force_factor(r_sq), epsilon = 1e-12, max_relative = 1e-12);
        }
    }

    /// Analytical force agrees with -dV/dr at several points across the window.
    #[test]
    fn test_switch_numerical_derivative_across_window() {
        let lj: Lj<f64> = Lj::new(1.0, 3.4);
        let lj_sw = Switch::new(lj, 8.0, 10.0);
        let h = 1e-7;
        for &r in &[8.1, 8.5, 9.5, 9.9] {
            let e_plus = lj_sw.energy((r + h) * (r + h));
            let e_minus = lj_sw.energy((r - h) * (r - h));
            let dv_dr_numerical = (e_plus - e_minus) / (2.0 * h);
            let dv_dr_analytical = -lj_sw.force_factor(r * r) * r;
            assert_relative_eq!(dv_dr_analytical, dv_dr_numerical, epsilon = 1e-6);
        }
    }
}