use serde::{Deserialize, Serialize};
/// Atomic numbers of the elements supported by the AEV featurizer.
///
/// The position of an element in this table is its species index (see
/// [`species_index`]).
///
/// NOTE(review): the table is sorted by atomic number (H, C, N, O, F, S, Cl).
/// TorchANI's pretrained ANI-2x model orders its species H, C, N, O, S, F, Cl.
/// Confirm which order any loaded weights expect.
pub const ANI_ELEMENTS: [u8; 7] = [
    1,  // H
    6,  // C
    7,  // N
    8,  // O
    9,  // F
    16, // S
    17, // Cl
];

/// Number of supported species, derived from [`ANI_ELEMENTS`].
pub const N_SPECIES: usize = ANI_ELEMENTS.len();
/// Hyper-parameters of the atomic environment vector (AEV) featurizer.
///
/// Every grid is stored as a list. The number of AEV terms is the product of
/// the relevant list lengths (see `radial_length` / `angular_length`), so each
/// list should hold distinct values.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AevParams {
    /// Cutoff radius for radial terms. NOTE(review): units presumably Å — confirm.
    pub radial_cutoff: f64,
    /// Cutoff radius for angular terms. NOTE(review): units presumably Å — confirm.
    pub angular_cutoff: f64,
    /// Radial-term η values (presumably the ANI η_R Gaussian exponents).
    pub radial_eta: Vec<f64>,
    /// Radial-term shell shifts (presumably the ANI R_s centres).
    pub radial_rs: Vec<f64>,
    /// Angular-term η values (presumably the ANI η_A Gaussian exponents).
    pub angular_eta: Vec<f64>,
    /// Angular-term radial shifts (presumably the ANI R_s centres).
    pub angular_rs: Vec<f64>,
    /// Angular-term ζ exponents.
    pub angular_zeta: Vec<f64>,
    /// Angular-term section centres θ_s, in radians.
    pub angular_theta_s: Vec<f64>,
}
impl AevParams {
    /// Number of radial terms per neighbour species: one for every
    /// (`radial_eta`, `radial_rs`) combination.
    pub fn radial_length(&self) -> usize {
        let (n_eta, n_shifts) = (self.radial_eta.len(), self.radial_rs.len());
        n_eta * n_shifts
    }

    /// Number of angular terms per neighbour-species pair: one for every
    /// (`angular_eta`, `angular_rs`, `angular_zeta`, `angular_theta_s`) combination.
    pub fn angular_length(&self) -> usize {
        [
            self.angular_eta.len(),
            self.angular_rs.len(),
            self.angular_zeta.len(),
            self.angular_theta_s.len(),
        ]
        .iter()
        .product()
    }

    /// Total AEV length.
    ///
    /// This is one radial block per species plus one angular block per
    /// unordered species pair. Same-species pairs count, giving `N(N+1)/2` pairs.
    pub fn total_aev_length(&self) -> usize {
        let species_pairs = N_SPECIES * (N_SPECIES + 1) / 2;
        let radial_part = N_SPECIES * self.radial_length();
        let angular_part = species_pairs * self.angular_length();
        radial_part + angular_part
    }
}
/// Maps an atomic number to its species index, i.e. its position in
/// [`ANI_ELEMENTS`].
///
/// Returns `None` for any element outside the supported set.
pub fn species_index(z: u8) -> Option<usize> {
    for (idx, &element) in ANI_ELEMENTS.iter().enumerate() {
        if element == z {
            return Some(idx);
        }
    }
    None
}
/// Returns the AEV hyper-parameters of the published ANI-2x model.
///
/// The values follow TorchANI's `AEVComputer.like_2x()`, whose parameter file
/// is `rHCNOSFCl-5.1R_16-3.5A_a4-8.params`:
///
/// * radial: cutoff 5.1, η_R = 19.7, and 16 shifts evenly covering `[0.8, 5.1)`.
/// * angular: cutoff 3.5, η_A = 12.5, ζ = 14.1, and 8 radial shifts evenly
///   covering `[0.8, 3.5)`. There are 4 angle sections centred at
///   π/8, 3π/8, 5π/8 and 7π/8.
///
/// Each η/ζ list holds exactly one value. Repeating a value would only emit
/// identical, redundant AEV components, because term counts are products of
/// list lengths. With 7 species the resulting AEV has
/// 7·16 + 28·(8·4) = 1008 entries.
///
/// NOTE(review): AEV layout and network assignment also depend on species
/// order. TorchANI's ANI-2x uses H, C, N, O, S, F, Cl, so check `ANI_ELEMENTS`
/// before loading pretrained weights.
pub fn default_ani2x_params() -> AevParams {
    use std::f64::consts::PI;

    // `n` shifts evenly spaced from `start` (inclusive) towards `stop`
    // (exclusive), i.e. `linspace(start, stop, n + 1)[..n]`.
    fn cover_linearly(start: f64, stop: f64, n: usize) -> Vec<f64> {
        let step = (stop - start) / n as f64;
        (0..n).map(|i| start + step * i as f64).collect()
    }

    const RADIAL_CUTOFF: f64 = 5.1;
    const ANGULAR_CUTOFF: f64 = 3.5;
    const ANGLE_SECTIONS: usize = 4;

    // Section centres are offset by half a section so they cover (0, π)
    // symmetrically: π/8, 3π/8, 5π/8, 7π/8.
    let section_width = PI / ANGLE_SECTIONS as f64;
    let angular_theta_s: Vec<f64> = (0..ANGLE_SECTIONS)
        .map(|i| section_width / 2.0 + section_width * i as f64)
        .collect();

    AevParams {
        radial_cutoff: RADIAL_CUTOFF,
        angular_cutoff: ANGULAR_CUTOFF,
        radial_eta: vec![19.7],
        radial_rs: cover_linearly(0.8, RADIAL_CUTOFF, 16),
        angular_eta: vec![12.5],
        angular_rs: cover_linearly(0.8, ANGULAR_CUTOFF, 8),
        angular_zeta: vec![14.1],
        angular_theta_s,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_species_index() {
        assert_eq!(species_index(1), Some(0));
        assert_eq!(species_index(6), Some(1));
        assert_eq!(species_index(8), Some(3));
        assert_eq!(species_index(26), None);
    }

    /// Every supported element maps back to its own position in the table,
    /// and an unsupported atomic number is rejected.
    #[test]
    fn test_species_index_round_trip() {
        for (expected, &z) in ANI_ELEMENTS.iter().enumerate() {
            assert_eq!(species_index(z), Some(expected), "atomic number {}", z);
        }
        assert_eq!(species_index(0), None);
    }

    #[test]
    fn test_aev_dimensions() {
        let params = default_ani2x_params();
        assert!(params.radial_length() > 0);
        assert!(params.angular_length() > 0);
        assert!(params.total_aev_length() > 0);

        // Total = one radial block per species + one angular block per
        // unordered species pair (same-species pairs included).
        let pairs = N_SPECIES * (N_SPECIES + 1) / 2;
        assert_eq!(
            params.total_aev_length(),
            N_SPECIES * params.radial_length() + pairs * params.angular_length()
        );
    }
}