use super::aev_params::AevParams;
use super::neighbor::CellList;
use super::nn::FeedForwardNet;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Atomic numbers of the elements supported by ANI-TM, in species-index order.
///
/// Only the first [`N_SPECIES_TM`] entries are species. The final `0` is padding and is
/// excluded wherever the table is sliced with `[..N_SPECIES_TM]`.
pub const ANI_TM_ELEMENTS: [u8; 25] = [
    1, 6, 7, 8, 9, 14, 15, 16, 17, 22, 24, 25, 26, 27, 28, 29, 30, 35, 44, 46, 47, 53, 78, 79,
    0,
];
/// Number of ANI-TM species: every entry of [`ANI_TM_ELEMENTS`] except the trailing padding.
///
/// Derived from the table rather than hard-coded, so adding or removing an element cannot
/// leave the two out of sync.
pub const N_SPECIES_TM: usize = ANI_TM_ELEMENTS.len() - 1;
/// Output record for an ANI-TM evaluation.
///
/// NOTE(review): nothing in this section constructs this type, so the field descriptions
/// below are inferred from names and types. Confirm them against the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AniTmResult {
    /// Total energy. Presumably the sum of `atomic_energies`; units are not established here.
    pub energy: f64,
    /// One 3-component vector per atom, presumably the x/y/z force on each atom.
    pub forces: Vec<[f64; 3]>,
    /// Per-atom energy contributions, presumably one entry per atom.
    pub atomic_energies: Vec<f64>,
    /// Per-atom species codes, presumably atomic numbers (Z) as in `ANI_TM_ELEMENTS`.
    pub species: Vec<u8>,
    /// Flag presumably set when any atom is a transition metal. Confirm with the producer.
    pub has_transition_metals: bool,
}
/// Maps an atomic number to its ANI-TM species index (its position in `ANI_TM_ELEMENTS`).
///
/// Returns `None` for any element outside the first `N_SPECIES_TM` table entries, which
/// includes the trailing padding `0`.
pub fn species_index_tm(z: u8) -> Option<usize> {
    let supported = &ANI_TM_ELEMENTS[..N_SPECIES_TM];
    supported
        .iter()
        .enumerate()
        .find(|&(_, &element)| element == z)
        .map(|(index, _)| index)
}
/// Returns `true` when atomic number `z` has an ANI-TM species index
/// (i.e. `species_index_tm(z)` is `Some`).
pub fn is_ani_tm_supported(z: u8) -> bool {
    matches!(species_index_tm(z), Some(_))
}
/// Default ANI-TM symmetry-function hyperparameters.
///
/// * Radial: cutoff 6.5; 12 shifts `0.8 + 0.5 * i`; eta 16.0 repeated 12 times.
/// * Angular: cutoff 4.5; 6 shifts `0.8 + 0.8 * i`; eta 8.0 repeated 6 times;
///   zeta values 14.1 and 6.3; 8 angle shifts `pi * k / 8` for `k = 0..8`.
///
/// NOTE(review): `radial_eta` and `angular_eta` repeat a single value. `compute_aevs_tm`
/// takes the outer product of the eta and shift lists, so each repeated eta produces an
/// identical block of features. Confirm the repetition is intended.
pub fn default_ani_tm_params() -> AevParams {
    use std::f64::consts::PI;
    // `count` evenly spaced values: first + step * i for i in 0..count.
    let evenly_spaced = |first: f64, step: f64, count: usize| -> Vec<f64> {
        (0..count).map(|i| first + step * i as f64).collect()
    };
    AevParams {
        radial_cutoff: 6.5,
        radial_eta: std::iter::repeat(16.0).take(12).collect(),
        radial_rs: evenly_spaced(0.8, 0.5, 12),
        angular_cutoff: 4.5,
        angular_eta: std::iter::repeat(8.0).take(6).collect(),
        angular_rs: evenly_spaced(0.8, 0.8, 6),
        angular_zeta: vec![14.1, 6.3],
        angular_theta_s: (0..8).map(|k| PI * k as f64 / 8.0).collect(),
    }
}
/// Computes ANI-TM atomic environment vectors (AEVs), one row per atom.
///
/// Every row has the same length, made of two consecutive blocks:
/// * `N_SPECIES_TM * n_rad` radial features, where
///   `n_rad = radial_eta.len() * radial_rs.len()`;
/// * `N_SPECIES_TM * (N_SPECIES_TM + 1) / 2 * n_ang` angular features, where
///   `n_ang = angular_eta.len() * angular_rs.len() * angular_zeta.len() * angular_theta_s.len()`.
///
/// Radial block: for each neighbour pair `(i, j)` with `0.1 <= r_ij <= radial_cutoff`,
/// the term `exp(-eta * (r_ij - rs)^2) * fc(r_ij)` is added to atom `i`'s row, in the slot
/// for `j`'s species.
///
/// Angular block: for each two neighbours `j < k` of atom `i`, both with distances in
/// `[0.1, angular_cutoff]`, the term
/// `(1 + cos(theta - theta_s))^zeta * exp(-eta * ((r_ij + r_ik) / 2 - rs)^2) * fc(r_ij) * fc(r_ik)`
/// is added in the slot for the unordered species pair of `j` and `k`.
///
/// Neighbour pairs come from a `CellList` built with `radial_cutoff`. A neighbour whose
/// element is not in `ANI_TM_ELEMENTS` contributes nothing, and a warning is printed to
/// stderr for every such atom.
///
/// NOTE(review): the published ANI angular term (as in TorchANI) carries an extra
/// `2^(1 - zeta)` prefactor that is not applied here. Confirm this is intentional before
/// comparing against reference AEVs.
///
/// # Panics
///
/// Panics if `elements.len() != positions.len()`.
pub fn compute_aevs_tm(
    elements: &[u8],
    positions: &[[f64; 3]],
    params: &AevParams,
) -> Vec<Vec<f64>> {
    assert_eq!(
        elements.len(),
        positions.len(),
        "compute_aevs_tm: elements and positions must have the same length"
    );
    let n_atoms = elements.len();
    for (i, &z) in elements.iter().enumerate() {
        if !is_ani_tm_supported(z) {
            eprintln!(
                "WARNING: ANI-TM unsupported element Z={z} at atom {i}, will be skipped in AEV"
            );
        }
    }

    let cell_list = CellList::new(positions, params.radial_cutoff);
    let neighbors = cell_list.find_neighbors(positions);

    // Row layout and strides, matching the nesting order of the accumulation loops below.
    let n_rad_rs = params.radial_rs.len();
    let n_rad_per_pair = params.radial_eta.len() * n_rad_rs;
    let n_theta = params.angular_theta_s.len();
    let stride_rs = params.angular_zeta.len() * n_theta;
    let stride_eta = params.angular_rs.len() * stride_rs;
    let n_ang_per_triple = params.angular_eta.len() * stride_eta;
    let n_rad_total = N_SPECIES_TM * n_rad_per_pair;
    let n_ang_total = N_SPECIES_TM * (N_SPECIES_TM + 1) / 2 * n_ang_per_triple;
    let aev_length = n_rad_total + n_ang_total;
    let mut aevs = vec![vec![0.0; aev_length]; n_atoms];

    // Radial terms.
    for pair in &neighbors {
        let r = pair.dist_sq.sqrt();
        if r > params.radial_cutoff || r < 0.1 {
            continue;
        }
        let sj = match species_index_tm(elements[pair.j]) {
            Some(s) => s,
            None => continue,
        };
        let fc = cutoff_function(r, params.radial_cutoff);
        let row = &mut aevs[pair.i][sj * n_rad_per_pair..(sj + 1) * n_rad_per_pair];
        for (ie, eta) in params.radial_eta.iter().enumerate() {
            for (ir, rs) in params.radial_rs.iter().enumerate() {
                row[ie * n_rad_rs + ir] += (-eta * (r - rs).powi(2)).exp() * fc;
            }
        }
    }

    // Group angular-eligible neighbours (distance in [0.1, angular_cutoff], supported
    // species) by central atom as (j, species_j, r_ij). Each bucket keeps the order of
    // `neighbors`, so contributions accumulate in the same order as a scan of the full
    // list. The j/k double loop then touches only atom i's own neighbours instead of
    // rescanning every pair in the system.
    let mut angular_neighbors: Vec<Vec<(usize, usize, f64)>> = vec![Vec::new(); n_atoms];
    for pair in &neighbors {
        let r = pair.dist_sq.sqrt();
        if r > params.angular_cutoff || r < 0.1 {
            continue;
        }
        if let Some(s) = species_index_tm(elements[pair.j]) {
            angular_neighbors[pair.i].push((pair.j, s, r));
        }
    }

    // Angular terms.
    for (i, nbrs) in angular_neighbors.iter().enumerate() {
        for &(j, sj, rij) in nbrs {
            let fc_j = cutoff_function(rij, params.angular_cutoff);
            for &(k, sk, rik) in nbrs {
                // Visit each unordered neighbour pair once.
                if k <= j {
                    continue;
                }
                let cos_theta = compute_cos_angle(positions, i, j, k, rij, rik);
                let theta = cos_theta.clamp(-1.0, 1.0).acos();
                let fc_k = cutoff_function(rik, params.angular_cutoff);
                let ravg = (rij + rik) / 2.0;
                // Upper-triangular index of the unordered species pair (s_lo <= s_hi).
                let (s_lo, s_hi) = if sj <= sk { (sj, sk) } else { (sk, sj) };
                let pair_idx = s_lo * N_SPECIES_TM - s_lo * (s_lo + 1) / 2 + s_hi;
                let base = n_rad_total + pair_idx * n_ang_per_triple;
                let row = &mut aevs[i][base..base + n_ang_per_triple];
                for (ie, eta) in params.angular_eta.iter().enumerate() {
                    for (ir, rs) in params.angular_rs.iter().enumerate() {
                        // Depends only on (eta, rs), so compute it once per pair.
                        let radial = (-eta * (ravg - rs).powi(2)).exp();
                        for (iz, zeta) in params.angular_zeta.iter().enumerate() {
                            for (it, theta_s) in params.angular_theta_s.iter().enumerate() {
                                let g = (1.0 + (theta - theta_s).cos()).powf(*zeta)
                                    * radial
                                    * fc_j
                                    * fc_k;
                                row[ie * stride_eta + ir * stride_rs + iz * n_theta + it] += g;
                            }
                        }
                    }
                }
            }
        }
    }
    aevs
}
/// Cosine cutoff function.
///
/// Returns `0.5 * (1 + cos(pi * r / rc))` below the cutoff, and `0.0` once `r >= rc`.
fn cutoff_function(r: f64, rc: f64) -> f64 {
    use std::f64::consts::PI;
    if r >= rc {
        0.0
    } else {
        let phase = PI * r / rc;
        0.5 * (phase.cos() + 1.0)
    }
}
/// Cosine of the angle at atom `i` between the vectors `i -> j` and `i -> k`.
///
/// `rij` and `rik` are the caller-supplied lengths of those vectors. `1e-30` is added to
/// their product so that a zero length does not cause a division by zero.
fn compute_cos_angle(
    positions: &[[f64; 3]],
    i: usize,
    j: usize,
    k: usize,
    rij: f64,
    rik: f64,
) -> f64 {
    let (origin, pj, pk) = (positions[i], positions[j], positions[k]);
    let dot = (0..3).fold(0.0, |acc, axis| {
        acc + (pj[axis] - origin[axis]) * (pk[axis] - origin[axis])
    });
    dot / (rij * rik + 1e-30)
}
/// Builds one test network per supported ANI-TM element, keyed by atomic number.
///
/// Each network is created by `super::weights::make_test_model(aev_length)`. The padding
/// entry at the end of `ANI_TM_ELEMENTS` gets no model.
pub fn make_tm_test_models(aev_length: usize) -> HashMap<u8, FeedForwardNet> {
    ANI_TM_ELEMENTS[..N_SPECIES_TM]
        .iter()
        .map(|&z| (z, super::weights::make_test_model(aev_length)))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tm_species_index() {
        assert_eq!(species_index_tm(1), Some(0));
        assert_eq!(species_index_tm(6), Some(1));
        assert_eq!(species_index_tm(26), Some(12));
        assert_eq!(species_index_tm(79), Some(23));
        assert_eq!(species_index_tm(2), None);
        // The trailing 0 in ANI_TM_ELEMENTS is padding, not a species.
        assert_eq!(species_index_tm(0), None);
    }

    #[test]
    fn test_ani_tm_params() {
        let params = default_ani_tm_params();
        assert!(params.radial_cutoff > 5.0);
        assert!(params.angular_cutoff > 3.5);
    }

    #[test]
    fn test_cutoff_function() {
        assert_eq!(cutoff_function(0.0, 5.0), 1.0);
        assert_eq!(cutoff_function(5.0, 5.0), 0.0);
        assert_eq!(cutoff_function(7.0, 5.0), 0.0);
        assert!((cutoff_function(2.5, 5.0) - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_compute_cos_angle() {
        let positions = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [-2.0, 0.0, 0.0]];
        assert!(compute_cos_angle(&positions, 0, 1, 2, 1.0, 1.0).abs() < 1e-12);
        assert!((compute_cos_angle(&positions, 0, 1, 3, 1.0, 2.0) + 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_ani_tm_water() {
        let elements = vec![8u8, 1, 1];
        let positions = vec![
            [0.0, 0.0, 0.117],
            [0.0, 0.757, -0.469],
            [0.0, -0.757, -0.469],
        ];
        let params = default_ani_tm_params();
        let aevs = compute_aevs_tm(&elements, &positions, &params);
        assert_eq!(aevs.len(), 3);
        assert!(!aevs[0].is_empty());
        assert!(aevs.iter().all(|row| row.len() == aevs[0].len()));
    }
}