use crate::math::Vector;
/// Pair potential with a 12-10 radial part and a `cos^N` angular factor:
///
/// `V(r, θ) = D0 · [5·(R0/r)^12 − 6·(R0/r)^10] · cos^N θ`
///
/// This is the functional form of the DREIDING force field's hydrogen-bond
/// term; `N` defaults to 4. Distances are supplied squared (`r_sq = r²`) and
/// angles as `cos θ`, so no square roots or trigonometric calls are needed.
///
/// Along `r` the minimum sits at `r = R0`, where `V = -D0 · cos^N θ`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Dreid<T, const N: u32 = 4> {
    /// Well depth `D0`.
    d0: T,
    /// Squared equilibrium distance `R0²`.
    r0_sq: T,
    /// Precomputed `-60·D0`, the prefactor of the radial derivative term.
    neg_60_d0: T,
}

impl<T: Vector, const N: u32> Dreid<T, N> {
    /// Creates the potential from well depth `d0` and equilibrium distance `r0`.
    ///
    /// Each value is broadcast with `T::splat`; `R0²` and `-60·D0` are
    /// precomputed once here instead of on every evaluation.
    #[inline]
    pub fn new(d0: f64, r0: f64) -> Self {
        Self {
            d0: T::splat(d0),
            r0_sq: T::splat(r0 * r0),
            neg_60_d0: T::splat(-60.0 * d0),
        }
    }

    /// Returns `((R0/r)^10, (R0/r)^12)` for `r_sq = r²`.
    ///
    /// Shared by every evaluation method so the radial powers are always
    /// formed by the same multiplication chain.
    #[inline(always)]
    fn radial_ratios(&self, r_sq: T) -> (T, T) {
        let ratio2 = self.r0_sq / r_sq;
        let ratio4 = ratio2 * ratio2;
        let ratio8 = ratio4 * ratio4;
        let ratio10 = ratio8 * ratio2;
        let ratio12 = ratio10 * ratio2;
        (ratio10, ratio12)
    }

    /// Radial shape `5·(R0/r)^12 − 6·(R0/r)^10`, which equals `-1` at `r = R0`.
    #[inline(always)]
    fn radial_shape(ratio10: T, ratio12: T) -> T {
        T::splat(5.0) * ratio12 - T::splat(6.0) * ratio10
    }

    /// Evaluates the energy `V(r, θ)`.
    ///
    /// * `r_sq` — squared distance `r²`.
    /// * `cos_theta` — `cos θ`.
    #[inline(always)]
    pub fn energy(&self, r_sq: T, cos_theta: T) -> T {
        let (ratio10, ratio12) = self.radial_ratios(r_sq);
        self.d0 * Self::radial_shape(ratio10, ratio12) * cos_power::<T, N>(cos_theta)
    }

    /// Evaluates the derivatives of `V` without returning the energy.
    ///
    /// Returns `(s, dv_dcos)`:
    /// * `s = -(dV/dr) / r`, i.e. `dV/dr = -s · r`;
    /// * `dv_dcos = ∂V/∂(cos θ) = N · D0 · radial · cos^(N-1) θ`.
    ///
    /// These are exactly the last two components of
    /// [`Self::energy_derivative`]; the energy product is discarded and is
    /// dead code once inlined.
    #[inline(always)]
    pub fn derivative(&self, r_sq: T, cos_theta: T) -> (T, T) {
        let (_, s, dv_dcos) = self.energy_derivative(r_sq, cos_theta);
        (s, dv_dcos)
    }

    /// Evaluates `(energy, s, dv_dcos)` in one pass, sharing the radial powers
    /// and `cos^N` between the three results.
    ///
    /// The components have the same meaning as in [`Self::energy`] and
    /// [`Self::derivative`].
    #[inline(always)]
    pub fn energy_derivative(&self, r_sq: T, cos_theta: T) -> (T, T, T) {
        let (ratio10, ratio12) = self.radial_ratios(r_sq);
        let radial = Self::radial_shape(ratio10, ratio12);
        let cos_n = cos_power::<T, N>(cos_theta);
        let energy = self.d0 * radial * cos_n;
        // dV/dr = 60·D0·cos^N·(ratio10 − ratio12)/r, so
        // s = -(dV/dr)/r = -60·D0·cos^N·(ratio10 − ratio12)/r².
        let s = self.neg_60_d0 * (ratio10 - ratio12) * cos_n / r_sq;
        // For N = 0, cos_power_m1 yields 1 as a placeholder; the factor N = 0
        // makes the angular derivative vanish as it should.
        let n = T::splat(f64::from(N));
        let dv_dcos = self.d0 * radial * n * cos_power_m1::<T, N>(cos_theta);
        (energy, s, dv_dcos)
    }
}
#[inline(always)]
fn cos_power<T: Vector, const N: u32>(cos_theta: T) -> T {
match N {
0 => T::splat(1.0),
1 => cos_theta,
2 => cos_theta * cos_theta,
3 => cos_theta * cos_theta * cos_theta,
4 => {
let c2 = cos_theta * cos_theta;
c2 * c2
}
5 => {
let c2 = cos_theta * cos_theta;
c2 * c2 * cos_theta
}
6 => {
let c2 = cos_theta * cos_theta;
c2 * c2 * c2
}
_ => {
let mut result = T::splat(1.0);
let mut base = cos_theta;
let mut exp = N;
while exp > 0 {
if exp & 1 == 1 {
result = result * base;
}
base = base * base;
exp >>= 1;
}
result
}
}
}
#[inline(always)]
fn cos_power_m1<T: Vector, const N: u32>(cos_theta: T) -> T {
match N {
0 | 1 => T::splat(1.0), 2 => cos_theta,
3 => cos_theta * cos_theta,
4 => cos_theta * cos_theta * cos_theta,
5 => {
let c2 = cos_theta * cos_theta;
c2 * c2
}
6 => {
let c2 = cos_theta * cos_theta;
c2 * c2 * cos_theta
}
7 => {
let c2 = cos_theta * cos_theta;
c2 * c2 * c2
}
_ => {
let mut result = T::splat(1.0);
let mut base = cos_theta;
let mut exp = N - 1;
while exp > 0 {
if exp & 1 == 1 {
result = result * base;
}
base = base * base;
exp >>= 1;
}
result
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_dreid_at_minimum() {
        let d0 = 8.0;
        let r0 = 2.75;
        let dreid: Dreid<f64> = Dreid::new(d0, r0);
        let e = dreid.energy(r0 * r0, 1.0);
        assert_relative_eq!(e, -d0, epsilon = 1e-10);
    }

    #[test]
    fn test_dreid_angular_dependence() {
        let dreid: Dreid<f64, 4> = Dreid::new(8.0, 2.75);
        let r_sq = 2.75 * 2.75;
        let e0 = dreid.energy(r_sq, 1.0);
        let e90 = dreid.energy(r_sq, 0.0);
        assert_relative_eq!(e90, 0.0, epsilon = 1e-10);
        let e60 = dreid.energy(r_sq, 0.5);
        let expected_ratio = 0.5_f64.powi(4);
        assert_relative_eq!(e60 / e0, expected_ratio, epsilon = 1e-10);
    }

    #[test]
    fn test_dreid_cos2_vs_cos4() {
        let d0 = 8.0;
        let r0 = 2.75;
        let cos2: Dreid<f64, 2> = Dreid::new(d0, r0);
        let cos4: Dreid<f64, 4> = Dreid::new(d0, r0);
        let r_sq = r0 * r0;
        let cos_theta = 0.7;
        let e2 = cos2.energy(r_sq, cos_theta);
        let e4 = cos4.energy(r_sq, cos_theta);
        let e2_linear = cos2.energy(r_sq, 1.0);
        let e4_linear = cos4.energy(r_sq, 1.0);
        assert_relative_eq!(e2_linear, e4_linear, epsilon = 1e-10);
        assert!(e4.abs() < e2.abs());
        assert_relative_eq!(e2 / e2_linear, cos_theta.powi(2), epsilon = 1e-10);
        assert_relative_eq!(e4 / e4_linear, cos_theta.powi(4), epsilon = 1e-10);
    }

    #[test]
    fn test_dreid_numerical_radial_derivative() {
        let dreid: Dreid<f64, 4> = Dreid::new(8.0, 2.75);
        let r = 2.9;
        let cos_theta = 0.9;
        let h = 1e-7;
        let e_plus = dreid.energy((r + h) * (r + h), cos_theta);
        let e_minus = dreid.energy((r - h) * (r - h), cos_theta);
        let dv_dr_numerical = (e_plus - e_minus) / (2.0 * h);
        let (s, _) = dreid.derivative(r * r, cos_theta);
        let dv_dr_analytical = -s * r;
        assert_relative_eq!(dv_dr_analytical, dv_dr_numerical, epsilon = 1e-6);
    }

    #[test]
    fn test_dreid_numerical_angular_derivative() {
        let dreid: Dreid<f64, 4> = Dreid::new(8.0, 2.75);
        let r_sq = 3.0 * 3.0;
        let cos_theta = 0.85;
        let h = 1e-7;
        let e_plus = dreid.energy(r_sq, cos_theta + h);
        let e_minus = dreid.energy(r_sq, cos_theta - h);
        let dv_dcos_numerical = (e_plus - e_minus) / (2.0 * h);
        let (_, dv_dcos_analytical) = dreid.derivative(r_sq, cos_theta);
        assert_relative_eq!(dv_dcos_analytical, dv_dcos_numerical, epsilon = 1e-6);
    }

    #[test]
    fn test_dreid_cos2_numerical_derivatives() {
        let dreid: Dreid<f64, 2> = Dreid::new(6.0, 2.5);
        let r = 2.7;
        let r_sq = r * r;
        let cos_theta = 0.8;
        let h = 1e-7;
        let e_plus = dreid.energy((r + h) * (r + h), cos_theta);
        let e_minus = dreid.energy((r - h) * (r - h), cos_theta);
        let dv_dr_numerical = (e_plus - e_minus) / (2.0 * h);
        let (s, _) = dreid.derivative(r_sq, cos_theta);
        let dv_dr_analytical = -s * r;
        assert_relative_eq!(dv_dr_analytical, dv_dr_numerical, epsilon = 1e-6);
        let e_plus = dreid.energy(r_sq, cos_theta + h);
        let e_minus = dreid.energy(r_sq, cos_theta - h);
        let dv_dcos_numerical = (e_plus - e_minus) / (2.0 * h);
        let (_, dv_dcos_analytical) = dreid.derivative(r_sq, cos_theta);
        assert_relative_eq!(dv_dcos_analytical, dv_dcos_numerical, epsilon = 1e-6);
    }

    #[test]
    fn test_dreid_energy_derivative_consistency() {
        let dreid: Dreid<f64, 4> = Dreid::new(6.0, 2.5);
        let r_sq = 2.8 * 2.8;
        let cos_theta = 0.75;
        let e1 = dreid.energy(r_sq, cos_theta);
        let (s1, dc1) = dreid.derivative(r_sq, cos_theta);
        let (e2, s2, dc2) = dreid.energy_derivative(r_sq, cos_theta);
        assert_relative_eq!(e1, e2, epsilon = 1e-10);
        assert_relative_eq!(s1, s2, epsilon = 1e-10);
        assert_relative_eq!(dc1, dc2, epsilon = 1e-10);
    }

    #[test]
    fn test_dreid_various_powers() {
        let d0 = 7.0;
        let r0 = 2.6;
        let r_sq = 2.8 * 2.8;
        let cos_theta = 0.9;
        let dreid2: Dreid<f64, 2> = Dreid::new(d0, r0);
        let dreid4: Dreid<f64, 4> = Dreid::new(d0, r0);
        let dreid6: Dreid<f64, 6> = Dreid::new(d0, r0);
        let e2 = dreid2.energy(r_sq, cos_theta);
        let e4 = dreid4.energy(r_sq, cos_theta);
        let e6 = dreid6.energy(r_sq, cos_theta);
        let e2_lin = dreid2.energy(r_sq, 1.0);
        let e4_lin = dreid4.energy(r_sq, 1.0);
        let e6_lin = dreid6.energy(r_sq, 1.0);
        assert_relative_eq!(e2_lin, e4_lin, epsilon = 1e-10);
        assert_relative_eq!(e4_lin, e6_lin, epsilon = 1e-10);
        assert_relative_eq!(e2 / e2_lin, cos_theta.powi(2), epsilon = 1e-10);
        assert_relative_eq!(e4 / e4_lin, cos_theta.powi(4), epsilon = 1e-10);
        assert_relative_eq!(e6 / e6_lin, cos_theta.powi(6), epsilon = 1e-10);
    }

    /// Checks, for an arbitrary exponent `N`:
    /// * the energy scales as `cos^N` relative to `cos θ = 1`;
    /// * both analytical derivatives agree with central differences;
    /// * `derivative` and `energy_derivative` agree with `energy`.
    ///
    /// Used to reach the special-cased arms and the generic
    /// square-and-multiply paths of `cos_power` / `cos_power_m1`.
    fn check_power<const N: u32>() {
        let dreid: Dreid<f64, N> = Dreid::new(7.0, 2.6);
        let r = 2.8;
        let r_sq = r * r;
        let cos_theta = 0.8;
        let h = 1e-7;

        let e = dreid.energy(r_sq, cos_theta);
        let e_lin = dreid.energy(r_sq, 1.0);
        assert_relative_eq!(e / e_lin, cos_theta.powi(N as i32), epsilon = 1e-10);

        let (s, dv_dcos) = dreid.derivative(r_sq, cos_theta);

        let e_plus = dreid.energy((r + h) * (r + h), cos_theta);
        let e_minus = dreid.energy((r - h) * (r - h), cos_theta);
        let dv_dr_numerical = (e_plus - e_minus) / (2.0 * h);
        assert_relative_eq!(-s * r, dv_dr_numerical, epsilon = 1e-6);

        let e_plus = dreid.energy(r_sq, cos_theta + h);
        let e_minus = dreid.energy(r_sq, cos_theta - h);
        let dv_dcos_numerical = (e_plus - e_minus) / (2.0 * h);
        assert_relative_eq!(dv_dcos, dv_dcos_numerical, epsilon = 1e-6);

        let (e2, s2, dv_dcos2) = dreid.energy_derivative(r_sq, cos_theta);
        assert_relative_eq!(e, e2, epsilon = 1e-10);
        assert_relative_eq!(s, s2, epsilon = 1e-10);
        assert_relative_eq!(dv_dcos, dv_dcos2, epsilon = 1e-10);
    }

    #[test]
    fn test_dreid_all_power_branches() {
        // N = 0: no angular dependence, angular derivative must vanish.
        check_power::<0>();
        check_power::<1>();
        check_power::<3>();
        check_power::<5>();
        // N = 7: generic path in cos_power, special arm in cos_power_m1.
        check_power::<7>();
        // N >= 8: generic path in both helpers.
        check_power::<8>();
        check_power::<9>();
        check_power::<12>();
    }

    #[test]
    fn test_dreid_odd_power_keeps_sign_of_cosine() {
        let dreid: Dreid<f64, 3> = Dreid::new(8.0, 2.75);
        let r_sq = 2.75 * 2.75;
        let e_pos = dreid.energy(r_sq, 0.5);
        let e_neg = dreid.energy(r_sq, -0.5);
        assert_relative_eq!(e_neg, -e_pos, epsilon = 1e-12);
    }
}