use crate::base::Potential4;
use crate::math::Vector;
/// Periodic cosine four-body (`Potential4`) potential
/// `E(ξ) = k·(1 + cos(n·ξ − d))`, evaluated from `cos ξ` and `sin ξ`.
///
/// The coefficients are stored already broadcast into the vector type `T`
/// (via `T::splat` in `Cos::new`), so the evaluation methods use them directly.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Cos<T> {
/// Force constant `k`; scales the whole energy.
k: T,
/// Multiplicity `n` of the cosine term.
n: i32,
/// `cos(d)` of the phase offset `d` (radians).
cos_d: T,
/// `sin(d)` of the phase offset `d` (radians).
sin_d: T,
/// `-n` as a broadcast float; the prefactor of `dE/dξ = −k·n·sin(n·ξ − d)`.
neg_n: T,
}
impl<T: Vector> Cos<T> {
    /// Builds the potential `E(ξ) = k·(1 + cos(n·ξ − d))`.
    ///
    /// * `k` – force constant.
    /// * `n` – multiplicity of the cosine term.
    /// * `d` – phase offset in radians.
    ///
    /// Every coefficient is broadcast into `T` once, here, so the evaluation
    /// methods never need to re-splat scalars.
    #[inline]
    pub fn new(k: f64, n: i32, d: f64) -> Self {
        let phase_cos = d.cos();
        let phase_sin = d.sin();
        let minus_n = -f64::from(n);
        Self {
            k: T::splat(k),
            n,
            cos_d: T::splat(phase_cos),
            sin_d: T::splat(phase_sin),
            neg_n: T::splat(minus_n),
        }
    }

    /// Two-fold term with phase `π`: `E(ξ) = k·(1 − cos 2ξ)`.
    ///
    /// Zero at `ξ = 0`, maximal (`2k`) at `ξ = ±π/2`.
    #[inline]
    pub fn planar(k: f64) -> Self {
        let half_turn = core::f64::consts::PI;
        Self::new(k, 2, half_turn)
    }

    /// Three-fold term with zero phase: `E(ξ) = k·(1 + cos 3ξ)`.
    ///
    /// Maximal (`2k`) at `ξ = 0`, zero at `ξ = π/3`.
    #[inline]
    pub fn trigonal(k: f64) -> Self {
        let no_phase = 0.0;
        Self::new(k, 3, no_phase)
    }
}
impl<T: Vector> Potential4<T> for Cos<T> {
    /// `E(ξ) = k·(1 + cos(n·ξ − d))`, evaluated from `cos ξ` and `sin ξ`.
    #[inline(always)]
    fn energy(&self, cos_xi: T, sin_xi: T) -> T {
        let (shifted_cos, _) = self.shifted_cos_sin(cos_xi, sin_xi);
        self.k * (T::splat(1.0) + shifted_cos)
    }

    /// `dE/dξ = −k·n·sin(n·ξ − d)`, evaluated from `cos ξ` and `sin ξ`.
    #[inline(always)]
    fn derivative(&self, cos_xi: T, sin_xi: T) -> T {
        let (_, shifted_sin) = self.shifted_cos_sin(cos_xi, sin_xi);
        self.k * self.neg_n * shifted_sin
    }

    /// Energy and derivative together, sharing one multiple-angle evaluation.
    #[inline(always)]
    fn energy_derivative(&self, cos_xi: T, sin_xi: T) -> (T, T) {
        let (shifted_cos, shifted_sin) = self.shifted_cos_sin(cos_xi, sin_xi);
        (
            self.k * (T::splat(1.0) + shifted_cos),
            self.k * self.neg_n * shifted_sin,
        )
    }
}

impl<T: Vector> Cos<T> {
    /// Returns `(cos(n·ξ − d), sin(n·ξ − d))` using the angle-subtraction
    /// identities on `cos(n·ξ)`, `sin(n·ξ)` and the stored `cos d`, `sin d`.
    #[inline(always)]
    fn shifted_cos_sin(&self, cos_xi: T, sin_xi: T) -> (T, T) {
        let (cos_n, sin_n) = chebyshev_cos_sin(self.n, cos_xi, sin_xi);
        let cos_shift = cos_n * self.cos_d + sin_n * self.sin_d;
        let sin_shift = sin_n * self.cos_d - cos_n * self.sin_d;
        (cos_shift, sin_shift)
    }
}
/// Computes `(cos(n·x), sin(n·x))` from `cos(x)` and `sin(x)` using only
/// arithmetic (no transcendental calls).
///
/// Multiplicities `|n| <= 3` use closed-form multiple-angle identities.
/// Larger ones use the Chebyshev three-term recurrence
/// `f((k+1)·x) = 2·cos(x)·f(k·x) − f((k−1)·x)`, which holds for both
/// `f = cos` and `f = sin`.
///
/// Negative `n` is handled with the parity identities
/// `cos(−m·x) = cos(m·x)` and `sin(−m·x) = −sin(m·x)`. Previously a negative
/// `n` fell through to the recurrence arm with an empty loop and silently
/// returned `(cos x, sin x)`.
#[inline(always)]
fn chebyshev_cos_sin<T: Vector>(n: i32, cos_x: T, sin_x: T) -> (T, T) {
    let zero = T::zero();
    let one = T::splat(1.0);
    let two = T::splat(2.0);
    // `unsigned_abs` avoids the overflow that `-n` would hit for `i32::MIN`.
    let (cos_m, sin_m) = match n.unsigned_abs() {
        0 => (one, zero),
        1 => (cos_x, sin_x),
        2 => {
            let cos2 = two * cos_x * cos_x - one;
            let sin2 = two * sin_x * cos_x;
            (cos2, sin2)
        }
        3 => {
            let four = T::splat(4.0);
            let three = T::splat(3.0);
            let cos2 = cos_x * cos_x;
            let cos3 = four * cos2 * cos_x - three * cos_x;
            // sin 3x = sin x · (3 − 4 sin²x) = sin x · (4 cos²x − 1)
            let sin3 = sin_x * (four * cos2 - one);
            (cos3, sin3)
        }
        m => {
            let mut cos_prev = one;
            let mut sin_prev = zero;
            let mut cos_curr = cos_x;
            let mut sin_curr = sin_x;
            // Each step advances from angle k·x to (k+1)·x; m − 1 steps reach m·x.
            for _ in 1..m {
                let cos_next = two * cos_x * cos_curr - cos_prev;
                let sin_next = two * cos_x * sin_curr - sin_prev;
                cos_prev = cos_curr;
                sin_prev = sin_curr;
                cos_curr = cos_next;
                sin_curr = sin_next;
            }
            (cos_curr, sin_curr)
        }
    };
    if n < 0 {
        // cos is even and sin is odd, so only the sine changes sign.
        (cos_m, zero - sin_m)
    } else {
        (cos_m, sin_m)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use core::f64::consts::PI;

    // Angles bound from bare float literals carry an explicit `f64`
    // annotation: calling `.cos()`/`.sin()` on an un-annotated `{float}`
    // binding is compile error E0689 (method call on an ambiguous numeric type).

    #[test]
    fn test_cos_planar_at_zero() {
        let cos: Cos<f64> = Cos::planar(10.0);
        let e = cos.energy(1.0, 0.0);
        assert_relative_eq!(e, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_cos_planar_at_90() {
        let k = 10.0;
        let cos: Cos<f64> = Cos::planar(k);
        let xi = PI / 2.0;
        let e = cos.energy(xi.cos(), xi.sin());
        let expected = 2.0 * k;
        assert_relative_eq!(e, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_cos_n1() {
        let k = 5.0;
        let d = 0.0;
        let cos: Cos<f64> = Cos::new(k, 1, d);
        let xi: f64 = 1.0;
        let e = cos.energy(xi.cos(), xi.sin());
        let expected = k * (1.0 + xi.cos());
        assert_relative_eq!(e, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_cos_n3() {
        let k = 8.0;
        let cos: Cos<f64> = Cos::trigonal(k);
        let e = cos.energy(1.0, 0.0);
        assert_relative_eq!(e, 2.0 * k, epsilon = 1e-10);
        let xi = PI / 3.0;
        let e60 = cos.energy(xi.cos(), xi.sin());
        assert_relative_eq!(e60, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_cos_numerical_derivative() {
        let cos: Cos<f64> = Cos::new(15.0, 2, 0.5);
        let xi: f64 = 0.7;
        let h: f64 = 1e-7;
        let e_plus = cos.energy((xi + h).cos(), (xi + h).sin());
        let e_minus = cos.energy((xi - h).cos(), (xi - h).sin());
        let deriv_numerical = (e_plus - e_minus) / (2.0 * h);
        let deriv_analytical = cos.derivative(xi.cos(), xi.sin());
        assert_relative_eq!(deriv_analytical, deriv_numerical, epsilon = 1e-6);
    }

    #[test]
    fn test_cos_numerical_derivative_recurrence_path() {
        // n = 5 goes through the general Chebyshev recurrence, not a closed form.
        let cos: Cos<f64> = Cos::new(3.0, 5, 1.1);
        let xi: f64 = 0.4;
        let h: f64 = 1e-7;
        let e_plus = cos.energy((xi + h).cos(), (xi + h).sin());
        let e_minus = cos.energy((xi - h).cos(), (xi - h).sin());
        let deriv_numerical = (e_plus - e_minus) / (2.0 * h);
        let (_, deriv_analytical) = cos.energy_derivative(xi.cos(), xi.sin());
        assert_relative_eq!(deriv_analytical, deriv_numerical, epsilon = 1e-6);
    }

    #[test]
    fn test_chebyshev_n4() {
        let xi: f64 = 0.8;
        let (cos4, sin4) = chebyshev_cos_sin(4, xi.cos(), xi.sin());
        let expected_cos = (4.0 * xi).cos();
        let expected_sin = (4.0 * xi).sin();
        assert_relative_eq!(cos4, expected_cos, epsilon = 1e-10);
        assert_relative_eq!(sin4, expected_sin, epsilon = 1e-10);
    }

    #[test]
    fn test_chebyshev_small_n_matches_trig() {
        let xi: f64 = 0.8;
        for n in 0..=6 {
            let (c, s) = chebyshev_cos_sin(n, xi.cos(), xi.sin());
            let angle = f64::from(n) * xi;
            assert_relative_eq!(c, angle.cos(), epsilon = 1e-10);
            assert_relative_eq!(s, angle.sin(), epsilon = 1e-10);
        }
    }

    #[test]
    fn test_cos_energy_derivative_consistency() {
        let cos: Cos<f64> = Cos::new(20.0, 3, PI / 4.0);
        let xi: f64 = 1.2;
        let e1 = cos.energy(xi.cos(), xi.sin());
        let d1 = cos.derivative(xi.cos(), xi.sin());
        let (e2, d2) = cos.energy_derivative(xi.cos(), xi.sin());
        assert_relative_eq!(e1, e2, epsilon = 1e-10);
        assert_relative_eq!(d1, d2, epsilon = 1e-10);
    }
}