use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SymmetryFunctionParams {
pub radial_cutoff: f64,
pub angular_cutoff: f64,
pub radial_shifts: Vec<f64>,
pub radial_etas: Vec<f64>,
pub angular_shifts: Vec<f64>,
pub angular_etas: Vec<f64>,
pub angular_zetas: Vec<f64>,
}
impl Default for SymmetryFunctionParams {
fn default() -> Self {
let radial_shifts: Vec<f64> = (0..16).map(|i| 0.9 + 0.25 * i as f64).collect();
let angular_shifts: Vec<f64> = (0..4).map(|i| 0.9 + 0.5 * i as f64).collect();
Self {
radial_cutoff: 5.2,
angular_cutoff: 3.5,
radial_shifts,
radial_etas: vec![16.0; 16],
angular_shifts,
angular_etas: vec![8.0; 4],
angular_zetas: vec![1.0, 2.0, 4.0, 8.0],
}
}
}
#[derive(Debug, Clone)]
pub struct Aev {
pub radial: Vec<f64>,
pub angular: Vec<f64>,
}
impl Aev {
pub fn len(&self) -> usize {
self.radial.len() + self.angular.len()
}
pub fn is_empty(&self) -> bool {
self.radial.is_empty() && self.angular.is_empty()
}
pub fn to_vec(&self) -> Vec<f64> {
let mut v = self.radial.clone();
v.extend_from_slice(&self.angular);
v
}
}
#[inline]
fn cutoff_fn(r: f64, rc: f64) -> f64 {
if r >= rc {
0.0
} else {
0.5 * (1.0 + (std::f64::consts::PI * r / rc).cos())
}
}
pub fn compute_aevs(
elements: &[u8],
positions: &[[f64; 3]],
params: &SymmetryFunctionParams,
) -> Vec<Aev> {
let n = elements.len();
let n_radial = params.radial_shifts.len() * params.radial_etas.len();
let n_angular =
params.angular_shifts.len() * params.angular_etas.len() * params.angular_zetas.len();
#[cfg(feature = "parallel")]
{
use rayon::prelude::*;
(0..n)
.into_par_iter()
.map(|i| compute_single_aev(i, n, positions, params, n_radial, n_angular))
.collect()
}
#[cfg(not(feature = "parallel"))]
{
(0..n)
.map(|i| compute_single_aev(i, n, positions, params, n_radial, n_angular))
.collect()
}
}
fn compute_single_aev(
i: usize,
n: usize,
positions: &[[f64; 3]],
params: &SymmetryFunctionParams,
n_radial: usize,
n_angular: usize,
) -> Aev {
let mut radial = vec![0.0; n_radial];
let mut angular = vec![0.0; n_angular];
for j in 0..n {
if i == j {
continue;
}
let r_ij = distance(&positions[i], &positions[j]);
if r_ij >= params.radial_cutoff {
continue;
}
let fc = cutoff_fn(r_ij, params.radial_cutoff);
for (s_idx, (&rs, &eta)) in params
.radial_shifts
.iter()
.zip(params.radial_etas.iter())
.enumerate()
{
radial[s_idx] += (-eta * (r_ij - rs).powi(2)).exp() * fc;
}
}
for j in 0..n {
if i == j {
continue;
}
let r_ij = distance(&positions[i], &positions[j]);
if r_ij >= params.angular_cutoff {
continue;
}
for k in (j + 1)..n {
if i == k {
continue;
}
let r_ik = distance(&positions[i], &positions[k]);
if r_ik >= params.angular_cutoff {
continue;
}
let theta = angle(&positions[i], &positions[j], &positions[k]);
let fc_ij = cutoff_fn(r_ij, params.angular_cutoff);
let fc_ik = cutoff_fn(r_ik, params.angular_cutoff);
let mut idx = 0;
for &rs in ¶ms.angular_shifts {
for &eta in ¶ms.angular_etas {
for &zeta in ¶ms.angular_zetas {
let r_mid = 0.5 * (r_ij + r_ik);
let g = 2.0f64.powf(1.0 - zeta)
* (1.0 + theta.cos()).powf(zeta)
* (-eta * (r_mid - rs).powi(2)).exp()
* fc_ij
* fc_ik;
angular[idx] += g;
idx += 1;
}
}
}
}
}
Aev { radial, angular }
}
fn distance(a: &[f64; 3], b: &[f64; 3]) -> f64 {
let dx = a[0] - b[0];
let dy = a[1] - b[1];
let dz = a[2] - b[2];
(dx * dx + dy * dy + dz * dz).sqrt()
}
fn angle(center: &[f64; 3], a: &[f64; 3], b: &[f64; 3]) -> f64 {
let va = [a[0] - center[0], a[1] - center[1], a[2] - center[2]];
let vb = [b[0] - center[0], b[1] - center[1], b[2] - center[2]];
let dot = va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2];
let ma = (va[0] * va[0] + va[1] * va[1] + va[2] * va[2]).sqrt();
let mb = (vb[0] * vb[0] + vb[1] * vb[1] + vb[2] * vb[2]).sqrt();
(dot / (ma * mb + 1e-30)).clamp(-1.0, 1.0).acos()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn aev_default_params() {
let p = SymmetryFunctionParams::default();
assert!(!p.radial_etas.is_empty());
assert!(!p.radial_shifts.is_empty());
assert!(p.radial_cutoff > 0.0);
}
#[test]
fn aev_single_atom_is_zero() {
let params = SymmetryFunctionParams::default();
let aevs = compute_aevs(&[1], &[[0.0, 0.0, 0.0]], ¶ms);
assert_eq!(aevs.len(), 1);
for v in &aevs[0].radial {
assert!(v.abs() < 1e-15, "lone atom radial AEV should be zero");
}
}
#[test]
fn aev_h2_has_nonzero_radial() {
let params = SymmetryFunctionParams::default();
let aevs = compute_aevs(&[1, 1], &[[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]], ¶ms);
assert_eq!(aevs.len(), 2);
let radial_sum: f64 = aevs[0].radial.iter().sum();
assert!(radial_sum > 0.0, "H2 should have nonzero radial AEV");
}
#[test]
fn aev_symmetric_molecule_equal_descriptors() {
let params = SymmetryFunctionParams::default();
let aevs = compute_aevs(&[1, 1], &[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], ¶ms);
let v0 = aevs[0].to_vec();
let v1 = aevs[1].to_vec();
assert_eq!(v0.len(), v1.len());
for i in 0..v0.len() {
assert!(
(v0[i] - v1[i]).abs() < 1e-12,
"symmetric atoms should have equal AEVs at index {i}"
);
}
}
#[test]
fn aev_length_consistent() {
let params = SymmetryFunctionParams::default();
let aevs = compute_aevs(&[1, 6], &[[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]], ¶ms);
assert!(!aevs[0].is_empty());
assert_eq!(aevs[0].len(), aevs[0].radial.len() + aevs[0].angular.len());
}
}