use crate::base::Potential4;
use crate::math::Vector;
/// Six-coefficient cosine-polynomial torsion potential,
/// `V = c0 + c1·cos φ + c2·cos²φ + c3·cos³φ + c4·cos⁴φ + c5·cos⁵φ`.
///
/// The name presumably abbreviates Ryckaert–Bellemans. Every coefficient is
/// stored already broadcast into the lane type `T`, so evaluation works purely
/// in `T` arithmetic.
///
/// NOTE(review): `cos φ` is whatever the caller passes in; confirm which
/// dihedral convention (IUPAC φ vs. polymer ψ = φ − 180°) the coefficients
/// were fitted for.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Rb<T> {
    /// Constant term.
    c0: T,
    /// Coefficient of `cos φ`.
    c1: T,
    /// Coefficient of `cos²φ`.
    c2: T,
    /// Coefficient of `cos³φ`.
    c3: T,
    /// Coefficient of `cos⁴φ`.
    c4: T,
    /// Coefficient of `cos⁵φ`.
    c5: T,
}
impl<T: Vector> Rb<T> {
    /// Builds the potential from its six coefficients `c0..=c5`.
    ///
    /// Each scalar is broadcast into `T` with `T::splat`.
    #[inline]
    pub fn new(c0: f64, c1: f64, c2: f64, c3: f64, c4: f64, c5: f64) -> Self {
        Self::from_array([c0, c1, c2, c3, c4, c5])
    }

    /// Builds the potential from coefficients ordered `[c0, c1, c2, c3, c4, c5]`.
    #[inline]
    pub fn from_array(c: [f64; 6]) -> Self {
        let [k0, k1, k2, k3, k4, k5] = c;
        Self {
            c0: T::splat(k0),
            c1: T::splat(k1),
            c2: T::splat(k2),
            c3: T::splat(k3),
            c4: T::splat(k4),
            c5: T::splat(k5),
        }
    }

    /// Builds a truncated potential from `c0..=c3`.
    ///
    /// The `cos⁴φ` and `cos⁵φ` coefficients are set to zero.
    #[inline]
    pub fn four_term(c0: f64, c1: f64, c2: f64, c3: f64) -> Self {
        Self::from_array([c0, c1, c2, c3, 0.0, 0.0])
    }
}
impl<T: Vector> Potential4<T> for Rb<T> {
    /// Energy `V = Σₙ cₙ·cosⁿφ` for `n = 0..=5`.
    ///
    /// Evaluated with Horner's scheme (five multiply-adds). `sin φ` is not needed.
    #[inline(always)]
    fn energy(&self, cos_phi: T, _sin_phi: T) -> T {
        let acc = self.c5;
        let acc = acc * cos_phi + self.c4;
        let acc = acc * cos_phi + self.c3;
        let acc = acc * cos_phi + self.c2;
        let acc = acc * cos_phi + self.c1;
        acc * cos_phi + self.c0
    }

    /// Angular derivative `dV/dφ = (dV/d cosφ) · (−sin φ)`.
    ///
    /// `dV/d cosφ = Σₙ n·cₙ·cosⁿ⁻¹φ` is also evaluated with Horner's scheme.
    #[inline(always)]
    fn derivative(&self, cos_phi: T, sin_phi: T) -> T {
        let five = T::splat(5.0);
        let four = T::splat(4.0);
        let three = T::splat(3.0);
        let two = T::splat(2.0);
        let dv_dcos = five * self.c5;
        let dv_dcos = dv_dcos * cos_phi + four * self.c4;
        let dv_dcos = dv_dcos * cos_phi + three * self.c3;
        let dv_dcos = dv_dcos * cos_phi + two * self.c2;
        let dv_dcos = dv_dcos * cos_phi + self.c1;
        // Chain rule: d(cos φ)/dφ = −sin φ. Negation is written as `0 − x`.
        T::zero() - dv_dcos * sin_phi
    }

    /// Returns `(energy, derivative)` in one call.
    ///
    /// Both halves reuse the Horner evaluations of `energy` and `derivative`.
    /// The pair is therefore bit-for-bit identical to calling those methods
    /// separately. An explicit power expansion (`cos²φ … cos⁵φ`) rounds in a
    /// different order and makes the results drift apart in the last bits.
    #[inline(always)]
    fn energy_derivative(&self, cos_phi: T, sin_phi: T) -> (T, T) {
        (
            self.energy(cos_phi, sin_phi),
            self.derivative(cos_phi, sin_phi),
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use core::f64::consts::PI;

    #[test]
    fn test_rb_at_zero() {
        let rb: Rb<f64> = Rb::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0);
        let e0 = rb.energy(1.0, 0.0);
        assert_relative_eq!(e0, 21.0, epsilon = 1e-10);
    }

    #[test]
    fn test_rb_at_180() {
        let rb: Rb<f64> = Rb::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0);
        let e180 = rb.energy(-1.0, 0.0);
        assert_relative_eq!(e180, -3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_rb_quadratic() {
        let rb: Rb<f64> = Rb::new(1.0, 0.0, 2.0, 0.0, 0.0, 0.0);
        let phi = PI / 4.0;
        let cos_phi = phi.cos();
        let energy = rb.energy(cos_phi, phi.sin());
        let expected = 1.0 + 2.0 * cos_phi * cos_phi;
        assert_relative_eq!(energy, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_rb_numerical_derivative() {
        let rb: Rb<f64> = Rb::new(9.28, 12.16, -13.12, -3.06, 26.24, -31.5);
        // Explicit `f64`: calling `.cos()` on an unannotated float literal is
        // an ambiguous-numeric-type compile error (E0689).
        let phi: f64 = 1.1;
        let h = 1e-7;
        let e_plus = rb.energy((phi + h).cos(), (phi + h).sin());
        let e_minus = rb.energy((phi - h).cos(), (phi - h).sin());
        let deriv_numerical = (e_plus - e_minus) / (2.0 * h);
        let deriv_analytical = rb.derivative(phi.cos(), phi.sin());
        assert_relative_eq!(deriv_analytical, deriv_numerical, epsilon = 1e-6);
    }

    #[test]
    fn test_rb_from_array() {
        let rb1: Rb<f64> = Rb::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0);
        let rb2: Rb<f64> = Rb::from_array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        // Explicit `f64` for the same E0689 reason as above.
        let phi: f64 = 0.7;
        assert_relative_eq!(
            rb1.energy(phi.cos(), phi.sin()),
            rb2.energy(phi.cos(), phi.sin()),
            epsilon = 1e-10
        );
    }

    #[test]
    fn test_rb_derivative_at_90() {
        // cos φ = 0, sin φ = 1: only the linear term survives, so dV/dφ = -c1.
        let rb: Rb<f64> = Rb::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0);
        assert_relative_eq!(rb.derivative(0.0, 1.0), -2.0, epsilon = 1e-12);
    }

    #[test]
    fn test_rb_energy_derivative_consistent() {
        // The combined call must agree with the separate energy/derivative calls.
        let rb: Rb<f64> = Rb::new(9.28, 12.16, -13.12, -3.06, 26.24, -31.5);
        let angles: [f64; 7] = [-2.9, -1.3, 0.0, 0.4, 1.1, 2.5, PI];
        for &phi in &angles {
            let (c, s) = (phi.cos(), phi.sin());
            let (e, d) = rb.energy_derivative(c, s);
            assert_relative_eq!(e, rb.energy(c, s), epsilon = 1e-10);
            assert_relative_eq!(d, rb.derivative(c, s), epsilon = 1e-10);
        }
    }

    #[test]
    fn test_rb_four_term() {
        let a: Rb<f64> = Rb::four_term(1.0, 2.0, 3.0, 4.0);
        let b: Rb<f64> = Rb::new(1.0, 2.0, 3.0, 4.0, 0.0, 0.0);
        assert_eq!(a, b);
    }
}