use crate::base::Potential4;
use crate::math::Vector;
/// OPLS-style four-term Fourier torsion potential.
///
/// The `Potential4` implementation for this type evaluates
///
/// ```text
/// E(φ) = c1 (1 + cos φ) + c2 (1 − cos 2φ) + c3 (1 + cos 3φ) + c4 (1 − cos 4φ)
/// ```
///
/// from `cos φ` alone (with `sin φ` used only for the derivative). Every term is
/// zero at φ = π, so the energy is zero there for any coefficients.
///
/// NOTE(review): no factor of 1/2 is applied to the coefficients. If parameters
/// are taken from an `E = ½V1(1 + cos φ) + …` style table, `cN` should be `VN / 2`.
/// Confirm against the parameter source.
///
/// `T` is any `Vector` type; `new` broadcasts each `f64` coefficient into it
/// with `T::splat`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Opls<T> {
    /// Coefficient of the onefold `(1 + cos φ)` term.
    c1: T,
    /// Coefficient of the twofold `(1 − cos 2φ)` term.
    c2: T,
    /// Coefficient of the threefold `(1 + cos 3φ)` term.
    c3: T,
    /// Coefficient of the fourfold `(1 − cos 4φ)` term.
    c4: T,
    /// Constant term `c1 + c2 + c3 + c4`, computed once in `new` so the energy
    /// expression only has to add the cosine-dependent parts.
    offset: T,
}
impl<T: Vector> Opls<T> {
    /// Builds the potential from its four Fourier coefficients.
    ///
    /// Every coefficient is broadcast into `T` via `T::splat`. The constant
    /// offset `c1 + c2 + c3 + c4` is summed here in `f64`, once, and stored
    /// alongside the coefficients.
    #[inline]
    pub fn new(c1: f64, c2: f64, c3: f64, c4: f64) -> Self {
        let constant_term = c1 + c2 + c3 + c4;
        Self {
            c1: T::splat(c1),
            c2: T::splat(c2),
            c3: T::splat(c3),
            c4: T::splat(c4),
            offset: T::splat(constant_term),
        }
    }

    /// Two-term form with only the onefold (`c1`) and threefold (`c3`) terms.
    ///
    /// Equivalent to `Opls::new(c1, 0.0, c3, 0.0)`.
    #[inline]
    pub fn simple(c1: f64, c3: f64) -> Self {
        let (no_twofold, no_fourfold) = (0.0, 0.0);
        Self::new(c1, no_twofold, c3, no_fourfold)
    }

    /// Pure threefold torsion: `c3` is the only non-zero coefficient.
    ///
    /// Equivalent to `Opls::new(0.0, 0.0, c3, 0.0)`.
    #[inline]
    pub fn threefold(c3: f64) -> Self {
        Self::simple(0.0, c3)
    }
}
/// Private evaluation kernels shared by every `Potential4` method, so the
/// energy and derivative formulas are written exactly once.
///
/// Both take `cos2 = cos φ · cos φ` precomputed, which lets
/// `energy_derivative` square the cosine a single time.
impl<T: Vector> Opls<T> {
    /// Energy from `cos φ` and `cos² φ`, with no trigonometric calls.
    ///
    /// The multiple-angle cosines come from the Chebyshev identities
    /// `cos 2φ = 2c² − 1`, `cos 3φ = 4c³ − 3c`, `cos 4φ = 8c⁴ − 8c² + 1`.
    #[inline(always)]
    fn energy_with_cos2(&self, cos_phi: T, cos2: T) -> T {
        let cos3 = cos2 * cos_phi;
        let cos4 = cos2 * cos2;
        let two = T::splat(2.0);
        let three = T::splat(3.0);
        let four = T::splat(4.0);
        let eight = T::splat(8.0);
        let cos_2phi = two * cos2 - T::one();
        let cos_3phi = four * cos3 - three * cos_phi;
        let cos_4phi = eight * cos4 - eight * cos2 + T::one();
        // offset = c1 + c2 + c3 + c4 turns each term into c_n (1 ± cos nφ).
        self.offset + self.c1 * cos_phi - self.c2 * cos_2phi + self.c3 * cos_3phi
            - self.c4 * cos_4phi
    }

    /// dE/dφ from `cos φ`, `sin φ` and `cos² φ`.
    ///
    /// The multiple-angle sines come from `sin 2φ = 2sc`,
    /// `sin 3φ = s(4c² − 1)` and `sin 4φ = 4sc(2c² − 1)`.
    #[inline(always)]
    fn derivative_with_cos2(&self, cos_phi: T, sin_phi: T, cos2: T) -> T {
        let two = T::splat(2.0);
        let three = T::splat(3.0);
        let four = T::splat(4.0);
        let sin_2phi = two * sin_phi * cos_phi;
        let sin_3phi = sin_phi * (four * cos2 - T::one());
        let sin_4phi = four * sin_phi * cos_phi * (two * cos2 - T::one());
        // The vector type is only known to support `-` as a binary op here,
        // so the leading negation is written as `0 - x`.
        T::zero() - self.c1 * sin_phi + two * self.c2 * sin_2phi - three * self.c3 * sin_3phi
            + four * self.c4 * sin_4phi
    }
}

impl<T: Vector> Potential4<T> for Opls<T> {
    /// `E(φ) = c1(1 + cos φ) + c2(1 − cos 2φ) + c3(1 + cos 3φ) + c4(1 − cos 4φ)`.
    ///
    /// Depends only on `cos φ`; `sin φ` is part of the trait signature but unused.
    #[inline(always)]
    fn energy(&self, cos_phi: T, _sin_phi: T) -> T {
        self.energy_with_cos2(cos_phi, cos_phi * cos_phi)
    }

    /// Derivative with respect to the angle φ (not `cos φ`):
    /// `dE/dφ = −c1 sin φ + 2c2 sin 2φ − 3c3 sin 3φ + 4c4 sin 4φ`.
    #[inline(always)]
    fn derivative(&self, cos_phi: T, sin_phi: T) -> T {
        self.derivative_with_cos2(cos_phi, sin_phi, cos_phi * cos_phi)
    }

    /// `(energy, dE/dφ)` in one call, sharing the `cos² φ` computation.
    ///
    /// Results are bit-identical to calling `energy` and `derivative` separately.
    #[inline(always)]
    fn energy_derivative(&self, cos_phi: T, sin_phi: T) -> (T, T) {
        let cos2 = cos_phi * cos_phi;
        (
            self.energy_with_cos2(cos_phi, cos2),
            self.derivative_with_cos2(cos_phi, sin_phi, cos2),
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use core::f64::consts::PI;

    /// At φ = 0 the even-fold terms vanish, leaving 2·c1 + 2·c3 = 2 + 6 = 8.
    #[test]
    fn test_opls_at_zero() {
        let opls: Opls<f64> = Opls::new(1.0, 2.0, 3.0, 4.0);
        let e0 = opls.energy(1.0, 0.0);
        assert_relative_eq!(e0, 8.0, epsilon = 1e-10);
    }

    /// A pure threefold torsion is 2·c3 at φ = 0 and zero at φ = 60°.
    #[test]
    fn test_opls_threefold() {
        let c3 = 2.0;
        let opls: Opls<f64> = Opls::threefold(c3);
        let e0 = opls.energy(1.0, 0.0);
        assert_relative_eq!(e0, 4.0, epsilon = 1e-10);
        let phi = PI / 3.0;
        let e60 = opls.energy(phi.cos(), phi.sin());
        assert_relative_eq!(e60, 0.0, epsilon = 1e-10);
    }

    /// Every term is at its zero for φ = π, whatever the coefficients.
    #[test]
    fn test_opls_zero_at_pi() {
        let opls: Opls<f64> = Opls::new(1.0, 2.0, 3.0, 4.0);
        assert_relative_eq!(opls.energy(-1.0, 0.0), 0.0, epsilon = 1e-12);
        assert_relative_eq!(opls.derivative(-1.0, 0.0), 0.0, epsilon = 1e-12);
    }

    /// φ = π/4 gives every term a distinct non-zero value, pinning the signs of
    /// the even-fold (c2, c4) terms that φ = 0 and φ = π cannot see:
    /// E  = c1(1 + √2/2) + c2·1 + c3(1 − √2/2) + c4·2  = 14 − √2,
    /// E' = −c1·√2/2 + 2c2·1 − 3c3·√2/2 + 4c4·0      = 4 − 5√2.
    #[test]
    fn test_opls_all_terms_at_45_degrees() {
        let opls: Opls<f64> = Opls::new(1.0, 2.0, 3.0, 4.0);
        let phi = PI / 4.0;
        let (e, d) = (
            opls.energy(phi.cos(), phi.sin()),
            opls.derivative(phi.cos(), phi.sin()),
        );
        assert_relative_eq!(e, 14.0 - core::f64::consts::SQRT_2, epsilon = 1e-12);
        assert_relative_eq!(d, 4.0 - 5.0 * core::f64::consts::SQRT_2, epsilon = 1e-12);
    }

    /// `simple` and `threefold` are shorthands for `new` with zeroed terms.
    #[test]
    fn test_opls_shorthand_constructors() {
        assert_eq!(
            Opls::<f64>::simple(1.5, -0.7),
            Opls::<f64>::new(1.5, 0.0, -0.7, 0.0)
        );
        assert_eq!(
            Opls::<f64>::threefold(2.0),
            Opls::<f64>::new(0.0, 0.0, 2.0, 0.0)
        );
    }

    /// Central-difference check of dE/dφ at a single angle.
    #[test]
    fn test_opls_numerical_derivative() {
        let opls: Opls<f64> = Opls::new(1.0, 0.5, 1.5, 0.2);
        // Explicit `f64`: calling `.cos()` on a binding whose type is still the
        // ambiguous `{float}` is a compile error (E0689).
        let phi: f64 = 0.8;
        let h = 1e-7;
        let e_plus = opls.energy((phi + h).cos(), (phi + h).sin());
        let e_minus = opls.energy((phi - h).cos(), (phi - h).sin());
        let deriv_numerical = (e_plus - e_minus) / (2.0 * h);
        let deriv_analytical = opls.derivative(phi.cos(), phi.sin());
        assert_relative_eq!(deriv_analytical, deriv_numerical, epsilon = 1e-6);
    }

    /// Central-difference check of dE/dφ across all four quadrants.
    #[test]
    fn test_opls_numerical_derivative_sweep() {
        let opls: Opls<f64> = Opls::new(1.0, 0.5, 1.5, 0.2);
        let h = 1e-6;
        for &phi in &[-2.9_f64, -1.6, -0.4, 0.3, 1.2, 2.2, 3.0] {
            let e_plus = opls.energy((phi + h).cos(), (phi + h).sin());
            let e_minus = opls.energy((phi - h).cos(), (phi - h).sin());
            let numerical = (e_plus - e_minus) / (2.0 * h);
            let analytical = opls.derivative(phi.cos(), phi.sin());
            assert_relative_eq!(analytical, numerical, epsilon = 1e-6);
        }
    }

    /// `energy_derivative` must agree with the separate `energy` / `derivative`.
    #[test]
    fn test_opls_energy_derivative_consistency() {
        let opls: Opls<f64> = Opls::new(1.0, 0.5, 1.5, 0.2);
        let phi: f64 = 1.2;
        let (e1, d1) = opls.energy_derivative(phi.cos(), phi.sin());
        let e2 = opls.energy(phi.cos(), phi.sin());
        let d2 = opls.derivative(phi.cos(), phi.sin());
        assert_relative_eq!(e1, e2, epsilon = 1e-10);
        assert_relative_eq!(d1, d2, epsilon = 1e-10);
    }
}