/// Flattened outputs of [`make_env_mat`].
///
/// All buffers are row-major over `(frame, local atom, neighbor slot, ...)`.
#[derive(Debug, Clone)]
pub struct EnvMatOutput {
    /// Environment matrix, `nf * nloc * nnei * last_dim` values.
    pub env_mat: Vec<f32>,
    /// Neighbor offset vectors `x_j - x_i`, `nf * nloc * nnei * 3` values
    /// (zero for empty neighbor slots).
    pub diff: Vec<f32>,
    /// Switching weight per neighbor slot, `nf * nloc * nnei` values
    /// (zero for empty neighbor slots).
    pub sw: Vec<f32>,
    /// Width of the innermost `env_mat` axis: `1` when radial-only, else `4`.
    pub last_dim: usize,
}
/// Quintic switching function that decays from `1` at `rmin` to `0` at `rmax`.
///
/// `r` is clamped into `[rmin, rmax]`, mapped to `u = (r - rmin) / (rmax - rmin)`,
/// and the result is `u^3 * (-6u^2 + 15u - 10) + 1`. Distances below `rmin`
/// therefore yield `1` and distances above `rmax` yield `0`.
///
/// # Panics
/// Panics if `rmin >= rmax` (or either bound is NaN).
pub fn smooth_weight(r: f32, rmin: f32, rmax: f32) -> f32 {
    assert!(rmin < rmax, "rmin must be < rmax");
    let span = rmax - rmin;
    let u = (r.clamp(rmin, rmax) - rmin) / span;
    let u_sq = u * u;
    let u_cubed = u_sq * u;
    let tail = -6.0 * u_sq + 15.0 * u - 10.0;
    1.0 + u_cubed * tail
}
/// Double-exponential switching function `exp(-exp(20 / rmin * (r - rmin)))`.
///
/// `r` is first clamped into `[0, rmax]`. The weight is `exp(-1)` at `r == rmin`,
/// approaches `1` for `r` well below `rmin`, and decays rapidly towards `0` above it.
/// `rmin` is expected to be positive: the scale factor `20 / rmin` is infinite
/// for `rmin == 0` and negative for `rmin < 0`.
///
/// # Panics
/// Panics if `rmin >= rmax` (or either bound is NaN), and, via `f32::clamp`,
/// if `rmax < 0`.
pub fn exp_sw(r: f32, rmin: f32, rmax: f32) -> f32 {
    assert!(rmin < rmax, "rmin must be < rmax");
    const STEEPNESS: f32 = 20.0;
    let scale = STEEPNESS / rmin;
    let shifted = r.clamp(0.0, rmax) - rmin;
    let inner = (scale * shifted).exp();
    (-inner).exp()
}
/// Configuration for [`make_env_mat`].
#[derive(Debug, Clone, Copy)]
pub struct EnvMatParams {
    /// Outer cutoff radius; passed as `rmax` to the switching function.
    pub rcut: f32,
    /// Inner smoothing radius; passed as `rmin` to the switching function.
    pub rcut_smth: f32,
    /// Added to the neighbor distance before it is inverted.
    pub protection: f32,
    /// Select `exp_sw` instead of the quintic `smooth_weight`.
    pub use_exp_switch: bool,
    /// Emit only the radial channel (`last_dim == 1`) instead of four channels.
    pub radial_only: bool,
}
/// Builds the smooth environment matrix for every (frame, local atom, neighbor slot).
///
/// `coord` holds `nf` frames of `nall` atoms with three coordinates each,
/// flattened row-major; the first `nloc` atoms of each frame are the centers.
/// `nlist` holds `nf * nloc * nnei` neighbor indices into the same frame's
/// atoms; a negative entry marks an empty slot.
///
/// For each valid neighbor with offset `d = x_j - x_i` and distance `r = |d|`:
/// - `diff` receives `d`;
/// - `sw` receives the switching weight `w` (`smooth_weight`, or `exp_sw` when
///   `params.use_exp_switch`), evaluated with `rmin = rcut_smth`, `rmax = rcut`;
/// - `env_mat` receives `w / (r + protection)` and, unless `params.radial_only`,
///   the three components of `w * d / (r + protection)^2`.
///
/// Empty slots leave every output at zero. A valid neighbor at zero separation
/// is evaluated as if `r == 1.0` so the inverse distance stays finite.
///
/// # Panics
/// Panics if the slice lengths disagree with `nf`/`nall`/`nloc`/`nnei`, if
/// `nloc > nall`, if `params.rcut_smth >= params.rcut`, or if any non-negative
/// neighbor index is `>= nall`.
pub fn make_env_mat(
    coord: &[f32],
    nlist: &[i32],
    nf: usize,
    nall: usize,
    nloc: usize,
    nnei: usize,
    params: EnvMatParams,
) -> EnvMatOutput {
    assert_eq!(coord.len(), nf * nall * 3);
    assert_eq!(nlist.len(), nf * nloc * nnei);
    // Center atoms are read from the first `nloc` slots of each frame; without
    // this check a too-large `nloc` would read the next frame's coordinates.
    assert!(nloc <= nall, "nloc must be <= nall");
    // Validate once up front instead of relying on the per-neighbor check inside
    // the switching function (which is skipped for empty slots).
    assert!(params.rcut_smth < params.rcut, "rcut_smth must be < rcut");

    let last_dim = if params.radial_only { 1 } else { 4 };
    let mut env_mat = vec![0f32; nf * nloc * nnei * last_dim];
    let mut diff = vec![0f32; nf * nloc * nnei * 3];
    let mut sw = vec![0f32; nf * nloc * nnei];

    for f in 0..nf {
        // Restrict coordinate reads to this frame so a bad index cannot
        // silently pick up another frame's atoms.
        let frame = &coord[f * nall * 3..(f + 1) * nall * 3];
        let nlist_base = f * nloc * nnei;
        for i in 0..nloc {
            let xi = frame[i * 3];
            let yi = frame[i * 3 + 1];
            let zi = frame[i * 3 + 2];
            for j in 0..nnei {
                let pair = nlist_base + i * nnei + j;
                let raw_j = nlist[pair];
                // Empty slot: outputs are already zero-initialized.
                if raw_j < 0 {
                    continue;
                }
                let jj = raw_j as usize;
                assert!(jj < nall, "neighbor index out of range");

                let dx = frame[jj * 3] - xi;
                let dy = frame[jj * 3 + 1] - yi;
                let dz = frame[jj * 3 + 2] - zi;
                let r2 = dx * dx + dy * dy + dz * dz;
                // Coincident atoms would give r == 0; substitute 1.0 to keep
                // 1 / (r + protection) finite (the direction terms are zero anyway).
                let r = if r2 > 0.0 { r2.sqrt() } else { 1.0 };
                let weight = if params.use_exp_switch {
                    exp_sw(r, params.rcut_smth, params.rcut)
                } else {
                    smooth_weight(r, params.rcut_smth, params.rcut)
                };
                let denom = r + params.protection;
                let t0 = 1.0 / denom;
                let inv_denom2 = 1.0 / (denom * denom);

                diff[pair * 3..pair * 3 + 3].copy_from_slice(&[dx, dy, dz]);
                sw[pair] = weight;
                let em = &mut env_mat[pair * last_dim..(pair + 1) * last_dim];
                em[0] = t0 * weight;
                if !params.radial_only {
                    em[1] = dx * inv_denom2 * weight;
                    em[2] = dy * inv_denom2 * weight;
                    em[3] = dz * inv_denom2 * weight;
                }
            }
        }
    }

    EnvMatOutput {
        env_mat,
        diff,
        sw,
        last_dim,
    }
}
/// Normalizes `env_mat` in place with per-type statistics: `(v - davg) / dstd`.
///
/// `env_mat` is laid out as `nf * nloc` atoms, each owning a contiguous row of
/// `nnei * last_dim` values. `davg` and `dstd` hold one such row per atom type
/// (`ntypes` rows). The row used for an atom is selected by its entry in
/// `atype_loc`; negative types are clamped to type `0`.
///
/// # Panics
/// Panics if any slice length disagrees with the given sizes, or if an atom's
/// (clamped) type is `>= ntypes`.
pub fn apply_stats(
    env_mat: &mut [f32],
    atype_loc: &[i32],
    davg: &[f32],
    dstd: &[f32],
    nf: usize,
    nloc: usize,
    nnei: usize,
    last_dim: usize,
    ntypes: usize,
) {
    assert_eq!(env_mat.len(), nf * nloc * nnei * last_dim);
    assert_eq!(atype_loc.len(), nf * nloc);
    assert_eq!(davg.len(), ntypes * nnei * last_dim);
    assert_eq!(dstd.len(), ntypes * nnei * last_dim);

    let row_len = nnei * last_dim;
    // `atype_loc` has exactly `nf * nloc` entries, so a flat walk over it visits
    // atoms in the same (frame, local atom) order as a nested loop would.
    for (atom, &raw_type) in atype_loc.iter().enumerate() {
        let t = raw_type.max(0) as usize;
        assert!(t < ntypes, "atype out of range");

        let row_start = atom * row_len;
        let stat_start = t * row_len;
        let row = &mut env_mat[row_start..row_start + row_len];
        let mean = &davg[stat_start..stat_start + row_len];
        let std = &dstd[stat_start..stat_start + row_len];
        for ((value, &m), &s) in row.iter_mut().zip(mean).zip(std) {
            *value = (*value - m) / s;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parameters shared by the env-matrix tests; callers override fields as needed.
    fn base_params() -> EnvMatParams {
        EnvMatParams {
            rcut: 2.0,
            rcut_smth: 0.5,
            protection: 0.0,
            use_exp_switch: false,
            radial_only: false,
        }
    }

    #[test]
    fn env_mat_shapes_match() {
        let nf = 1;
        let nall = 4;
        let nloc = 2;
        let nnei = 3;
        let coord: Vec<f32> = vec![
            0.0, 0.0, 0.0, //
            1.0, 0.0, 0.0, //
            0.0, 1.0, 0.0, //
            0.0, 0.0, 1.0, //
        ];
        let nlist = vec![1i32, 2, 3, 0, 2, -1];
        let out = make_env_mat(&coord, &nlist, nf, nall, nloc, nnei, base_params());
        assert_eq!(out.last_dim, 4);
        assert_eq!(out.env_mat.len(), nf * nloc * nnei * 4);
        assert_eq!(out.diff.len(), nf * nloc * nnei * 3);
        assert_eq!(out.sw.len(), nf * nloc * nnei);

        // Frame 0, local atom 1, neighbor slot 2 is padding (-1).
        let (frame, atom, slot) = (0, 1, 2);
        let masked_idx = (frame * nloc + atom) * nnei + slot;
        assert_eq!(out.sw[masked_idx], 0.0);
        for k in 0..4 {
            assert_eq!(out.env_mat[masked_idx * 4 + k], 0.0);
        }
        for k in 0..3 {
            assert_eq!(out.diff[masked_idx * 3 + k], 0.0);
        }
    }

    #[test]
    fn env_mat_values_for_unit_distance_neighbor() {
        let coord: Vec<f32> = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let params = base_params();
        let out = make_env_mat(&coord, &[1], 1, 2, 1, 1, params);
        let w = smooth_weight(1.0, params.rcut_smth, params.rcut);
        assert!(w > 0.0 && w < 1.0);
        assert_eq!(out.sw, vec![w]);
        assert_eq!(out.diff, vec![1.0, 0.0, 0.0]);
        assert_eq!(out.env_mat, vec![w, w, 0.0, 0.0]);
    }

    #[test]
    fn env_mat_uses_exp_switch_when_requested() {
        let coord: Vec<f32> = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let params = EnvMatParams {
            rcut_smth: 1.0,
            use_exp_switch: true,
            ..base_params()
        };
        let out = make_env_mat(&coord, &[1], 1, 2, 1, 1, params);
        assert_eq!(out.sw, vec![(-1.0f32).exp()]);
    }

    #[test]
    fn env_mat_radial_only_has_single_channel() {
        let coord: Vec<f32> = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let params = EnvMatParams {
            radial_only: true,
            ..base_params()
        };
        let out = make_env_mat(&coord, &[1, -1], 1, 2, 1, 2, params);
        assert_eq!(out.last_dim, 1);
        assert_eq!(out.env_mat.len(), 2);
        // r == 1 and protection == 0, so the radial term equals the weight.
        assert_eq!(out.env_mat[0], out.sw[0]);
        assert_eq!(out.env_mat[1], 0.0);
    }

    #[test]
    fn smooth_weight_clamps() {
        let rmin = 1.0;
        let rmax = 6.0;
        assert_eq!(smooth_weight(0.5, rmin, rmax), 1.0);
        assert_eq!(smooth_weight(10.0, rmin, rmax), 0.0);
    }

    #[test]
    fn smooth_weight_midpoint_is_half() {
        assert_eq!(smooth_weight(2.0, 1.0, 3.0), 0.5);
    }

    #[test]
    #[should_panic(expected = "rmin must be < rmax")]
    fn smooth_weight_rejects_inverted_bounds() {
        smooth_weight(1.0, 2.0, 1.0);
    }

    #[test]
    fn exp_sw_at_rmin_is_exp_minus_one() {
        assert_eq!(exp_sw(1.0, 1.0, 6.0), (-1.0f32).exp());
        assert_eq!(exp_sw(10.0, 1.0, 6.0), exp_sw(6.0, 1.0, 6.0));
    }

    #[test]
    fn apply_stats_normalizes_per_type() {
        // nf=1, nloc=2, nnei=1, last_dim=2, ntypes=2; negative type clamps to 0.
        let mut em: Vec<f32> = vec![3.0, 5.0, 10.0, 20.0];
        let atype = vec![1i32, -1];
        let davg: Vec<f32> = vec![0.0, 0.0, 1.0, 1.0];
        let dstd: Vec<f32> = vec![1.0, 2.0, 2.0, 4.0];
        apply_stats(&mut em, &atype, &davg, &dstd, 1, 2, 1, 2, 2);
        assert_eq!(em, vec![1.0, 1.0, 10.0, 10.0]);
    }
}