use super::d4_data::*;
// D4 model parameters. The two-body values (s8 = 2.7, a1 = 0.52, a2 = 5.0) and
// s9 = 5.0 look like the GFN2-xTB D4 parametrization — NOTE(review): confirm.
/// Scale of the damped C6 term (`t6`) in the two-body dispersion matrix.
const D4_S6: f64 = 1.0;
/// Scale of the damped C8 term (`rrij * t8`) in the two-body dispersion matrix.
const D4_S8: f64 = 2.7;
/// Rational-damping slope: `r0 = a1 * sqrt(rrij) + a2`.
const D4_A1: f64 = 0.52;
/// Rational-damping offset: `r0 = a1 * sqrt(rrij) + a2`.
const D4_A2: f64 = 5.0;
/// Scale of the three-body (ATM) term; a value of (near) zero disables it.
const D4_S9: f64 = 5.0;
/// Width factor of the Gaussian CN weights; multiplied by the Gaussian index.
const D4_WF: f64 = 6.0;
/// Parameter `a` passed to `zeta`/`dzeta` for the charge scaling.
const D4_GA: f64 = 3.0;
/// Multiplier applied to the chemical hardness to form the `zeta` steepness `c`.
const D4_GC: f64 = 2.0;
/// Distance cutoff for coordination numbers; also reused as the three-body
/// cutoff (presumably in bohr — confirm).
const D4_CN_CUTOFF: f64 = 25.0;
/// Distance cutoff for pairs in the two-body dispersion matrix (presumably bohr).
const D4_DISP2_CUTOFF: f64 = 50.0;
/// Precomputed D4 dispersion model for one molecular geometry.
///
/// Built by `D4Model::new`. Charge-dependent reference weights are obtained
/// from `D4Model::weight_references` and contracted with the stored
/// dispersion matrix to give energies and potentials.
#[derive(Debug, Clone)]
pub struct D4Model {
    /// Number of atoms.
    pub nat: usize,
    /// Atomic numbers, one per atom.
    pub elements: Vec<u8>,
    /// Coordination number of each atom.
    pub cn: Vec<f64>,
    /// Damped two-body dispersion matrix, flattened as
    /// `((iref * nat + iat) * mref + jref) * nat + jat`.
    dispmat_flat: Vec<f64>,
    /// Reference-slot stride used by the flat layouts (set to `MAX_REF`).
    mref: usize,
    // Per element type, per reference, per frequency polarizabilities. No method
    // reads this after construction, hence the lint allowance; presumably kept
    // for debugging/inspection — confirm before removing.
    #[allow(dead_code)]
    scaled_alpha: Vec<Vec<Vec<f64>>>,
    /// Reference C6 coefficients per element-type pair, flattened as
    /// `(it * n_types + jt) * mref * mref + iref * mref + jref`.
    c6_ref_flat: Vec<f64>,
    /// Distinct atomic numbers in order of first appearance.
    elem_types: Vec<u8>,
    /// Index into `elem_types` for each atom.
    atom_to_type: Vec<usize>,
}
/// Charge-dependent reference weights produced by `D4Model::weight_references`.
///
/// Both tables have one row per atom and `MAX_REF` columns; columns beyond an
/// element's reference count (and rows of unsupported elements) stay zero.
#[derive(Debug, Clone)]
pub struct D4Weights {
    /// `gwvec[iat][iref]`: weight of reference `iref` for atom `iat`.
    pub gwvec: Vec<Vec<f64>>,
    /// `dgwdq[iat][iref]`: derivative of `gwvec[iat][iref]` with respect to the
    /// partial charge of atom `iat`.
    pub dgwdq: Vec<Vec<f64>>,
}
impl D4Model {
pub fn new(elements: &[u8], positions: &[[f64; 3]]) -> Self {
let nat = elements.len();
let mref = MAX_REF;
let cn = compute_d4_cn(elements, positions);
let mut elem_types: Vec<u8> = Vec::new();
let mut atom_to_type = vec![0usize; nat];
for (iat, &z) in elements.iter().enumerate() {
if let Some(pos) = elem_types.iter().position(|&e| e == z) {
atom_to_type[iat] = pos;
} else {
atom_to_type[iat] = elem_types.len();
elem_types.push(z);
}
}
let scaled_alpha = compute_scaled_alpha(elements);
let n_types = elem_types.len();
let mut c6_ref_flat = vec![0.0f64; n_types * n_types * mref * mref];
for (it, &zi) in elem_types.iter().enumerate() {
let nref_i = get_nref(zi);
for (jt, &zj) in elem_types.iter().enumerate() {
let nref_j = get_nref(zj);
for iref in 0..nref_i {
let alpha_i = &scaled_alpha[it][iref];
if alpha_i.iter().all(|&v| v == 0.0) {
continue;
}
for jref in 0..nref_j {
let alpha_j = &scaled_alpha[jt][jref];
if alpha_j.iter().all(|&v| v == 0.0) {
continue;
}
let mut c6 = 0.0;
for k in 0..NFREQ {
c6 += CP_WEIGHTS[k] * alpha_i[k] * alpha_j[k];
}
c6 *= 3.0 / std::f64::consts::PI;
let idx = (it * n_types + jt) * mref * mref + iref * mref + jref;
c6_ref_flat[idx] = c6;
}
}
}
}
let mut dispmat_flat = vec![0.0f64; mref * nat * mref * nat];
let cutoff2 = D4_DISP2_CUTOFF * D4_DISP2_CUTOFF;
for iat in 0..nat {
let iz = elements[iat];
let it = atom_to_type[iat];
let nref_i = get_nref(iz);
for jat in 0..=iat {
let jz = elements[jat];
let jt = atom_to_type[jat];
let nref_j = get_nref(jz);
let dx = positions[iat][0] - positions[jat][0];
let dy = positions[iat][1] - positions[jat][1];
let dz = positions[iat][2] - positions[jat][2];
let r2 = dx * dx + dy * dy + dz * dz;
if r2 > cutoff2 || r2 < 1e-15 {
continue;
}
let rrij = 3.0 * R4R2[iz as usize - 1] * R4R2[jz as usize - 1];
let r0ij = D4_A1 * rrij.sqrt() + D4_A2;
let t6 = 1.0 / (r2.powi(3) + r0ij.powi(6));
let t8 = 1.0 / (r2.powi(4) + r0ij.powi(8));
let de = -(D4_S6 * t6 + D4_S8 * rrij * t8);
for iref in 0..nref_i {
for jref in 0..nref_j {
let c6_idx = (it * n_types + jt) * mref * mref + iref * mref + jref;
let c6 = c6_ref_flat[c6_idx];
let val = de * c6;
let idx_ij = ((iref * nat + iat) * mref + jref) * nat + jat;
let idx_ji = ((jref * nat + jat) * mref + iref) * nat + iat;
dispmat_flat[idx_ij] = val;
dispmat_flat[idx_ji] = val;
}
}
}
}
D4Model {
nat,
elements: elements.to_vec(),
cn,
dispmat_flat,
mref,
scaled_alpha,
c6_ref_flat,
elem_types,
atom_to_type,
}
}
pub fn weight_references(&self, charges: &[f64]) -> D4Weights {
let nat = self.nat;
let mut gwvec = vec![vec![0.0f64; MAX_REF]; nat];
let mut dgwdq = vec![vec![0.0f64; MAX_REF]; nat];
for iat in 0..nat {
let z = self.elements[iat];
let zi = z as usize;
if zi == 0 || zi > MAX_ELEM {
continue;
}
let nref = get_nref(z);
if nref == 0 {
continue;
}
let cn_val = self.cn[iat];
let q_val = charges[iat];
let zeff_i = EFFECTIVE_NUCLEAR_CHARGE[zi - 1];
let gi = CHEMICAL_HARDNESS[zi - 1] * D4_GC;
let mut ngw = vec![0usize; nref];
{
let max_cn_int: usize = 19;
let mut cnc = vec![0usize; max_cn_int + 1];
cnc[0] = 1; for iref in 0..nref {
let rcn = get_refcn(z, iref);
let icn = (rcn.round() as usize).min(max_cn_int);
cnc[icn] += 1;
}
for iref in 0..nref {
let rcn = get_refcn(z, iref);
let icn = (rcn.round() as usize).min(max_cn_int);
let c = cnc[icn];
ngw[iref] = c * (c + 1) / 2;
}
}
let mut covcn = vec![0.0f64; nref];
let mut refq = vec![0.0f64; nref];
for iref in 0..nref {
covcn[iref] = get_refcovcn(z, iref);
refq[iref] = get_refq(z, iref);
}
let mut norm = 0.0f64;
for iref in 0..nref {
for igw in 1..=ngw[iref] {
let wf = igw as f64 * D4_WF;
norm += weight_cn(wf, cn_val, covcn[iref]);
}
}
let norm_inv = if norm.abs() > 1e-150 { 1.0 / norm } else { 0.0 };
for iref in 0..nref {
let mut expw = 0.0f64;
for igw in 1..=ngw[iref] {
let wf = igw as f64 * D4_WF;
expw += weight_cn(wf, cn_val, covcn[iref]);
}
let mut gwk = expw * norm_inv;
if !gwk.is_finite() || norm_inv == 0.0 {
let max_covcn = covcn[..nref]
.iter()
.cloned()
.fold(f64::NEG_INFINITY, f64::max);
gwk = if (max_covcn - covcn[iref]).abs() < 1e-12 {
1.0
} else {
0.0
};
}
let z_val = zeta(D4_GA, gi, refq[iref] + zeff_i, q_val + zeff_i);
let dz_val = dzeta(D4_GA, gi, refq[iref] + zeff_i, q_val + zeff_i);
gwvec[iat][iref] = gwk * z_val;
dgwdq[iat][iref] = gwk * dz_val;
}
}
D4Weights { gwvec, dgwdq }
}
pub fn get_potential(&self, weights: &D4Weights) -> Vec<f64> {
let nat = self.nat;
let mref = self.mref;
let mut vat = vec![0.0f64; nat];
for iat in 0..nat {
let nref_i = get_nref(self.elements[iat]);
let mut vvec = vec![0.0f64; nref_i];
for iref in 0..nref_i {
for jat in 0..nat {
let nref_j = get_nref(self.elements[jat]);
for jref in 0..nref_j {
let idx = ((iref * nat + iat) * mref + jref) * nat + jat;
vvec[iref] += self.dispmat_flat[idx] * weights.gwvec[jat][jref];
}
}
}
for iref in 0..nref_i {
vat[iat] += vvec[iref] * weights.dgwdq[iat][iref];
}
}
vat
}
pub fn get_energy(&self, weights: &D4Weights) -> f64 {
let nat = self.nat;
let mref = self.mref;
let mut energy = 0.0f64;
for iat in 0..nat {
let nref_i = get_nref(self.elements[iat]);
let mut vvec = vec![0.0f64; nref_i];
for iref in 0..nref_i {
for jat in 0..nat {
let nref_j = get_nref(self.elements[jat]);
for jref in 0..nref_j {
let idx = ((iref * nat + iat) * mref + jref) * nat + jat;
vvec[iref] += self.dispmat_flat[idx] * weights.gwvec[jat][jref];
}
}
}
for iref in 0..nref_i {
energy += 0.5 * vvec[iref] * weights.gwvec[iat][iref];
}
}
energy
}
pub fn get_atm_energy(&self, positions: &[[f64; 3]]) -> f64 {
let nat = self.nat;
if nat < 3 || D4_S9.abs() < 1e-15 {
return 0.0;
}
let zero_charges = vec![0.0f64; nat];
let w0 = self.weight_references(&zero_charges);
let c6 = self.get_c6_matrix(&w0);
let cutoff2 = D4_CN_CUTOFF * D4_CN_CUTOFF;
let alp3 = 16.0 / 3.0;
let mut energy = 0.0f64;
for iat in 0..nat {
let iz = self.elements[iat] as usize;
for jat in 0..iat {
let jz = self.elements[jat] as usize;
let c6ij = c6[jat * nat + iat];
let r0ij = D4_A1 * (3.0 * R4R2[iz - 1] * R4R2[jz - 1]).sqrt() + D4_A2;
let vij = [
positions[jat][0] - positions[iat][0],
positions[jat][1] - positions[iat][1],
positions[jat][2] - positions[iat][2],
];
let r2ij = vij[0] * vij[0] + vij[1] * vij[1] + vij[2] * vij[2];
if r2ij > cutoff2 || r2ij < 1e-15 {
continue;
}
for kat in 0..jat {
let kz = self.elements[kat] as usize;
let c6ik = c6[kat * nat + iat];
let c6jk = c6[kat * nat + jat];
let c9 = -D4_S9 * (c6ij * c6ik * c6jk).abs().sqrt();
let r0ik = D4_A1 * (3.0 * R4R2[kz - 1] * R4R2[iz - 1]).sqrt() + D4_A2;
let r0jk = D4_A1 * (3.0 * R4R2[kz - 1] * R4R2[jz - 1]).sqrt() + D4_A2;
let r0 = r0ij * r0ik * r0jk;
let triple = triple_scale(iat, jat, kat);
let vik = [
positions[kat][0] - positions[iat][0],
positions[kat][1] - positions[iat][1],
positions[kat][2] - positions[iat][2],
];
let r2ik = vik[0] * vik[0] + vik[1] * vik[1] + vik[2] * vik[2];
if r2ik > cutoff2 || r2ik < 1e-15 {
continue;
}
let vjk = [vik[0] - vij[0], vik[1] - vij[1], vik[2] - vij[2]];
let r2jk = vjk[0] * vjk[0] + vjk[1] * vjk[1] + vjk[2] * vjk[2];
if r2jk > cutoff2 || r2jk < 1e-15 {
continue;
}
let r2 = r2ij * r2ik * r2jk;
let r1 = r2.sqrt();
let r3 = r2 * r1;
let r5 = r3 * r2;
let fdmp = 1.0 / (1.0 + 6.0 * (r0 / r1).powf(alp3));
let ang =
0.375 * (r2ij + r2jk - r2ik) * (r2ij - r2jk + r2ik) * (-r2ij + r2jk + r2ik)
/ r5
+ 1.0 / r3;
let rr = ang * fdmp;
let de = rr * c9 * triple / 6.0;
energy -= 6.0 * de;
}
}
}
energy
}
fn get_c6_matrix(&self, weights: &D4Weights) -> Vec<f64> {
let nat = self.nat;
let n_types = self.elem_types.len();
let mref = self.mref;
let mut c6 = vec![0.0f64; nat * nat];
for iat in 0..nat {
let it = self.atom_to_type[iat];
let nref_i = get_nref(self.elements[iat]);
for jat in 0..nat {
let jt = self.atom_to_type[jat];
let nref_j = get_nref(self.elements[jat]);
let mut val = 0.0;
for iref in 0..nref_i {
for jref in 0..nref_j {
let c6_idx = (it * n_types + jt) * mref * mref + iref * mref + jref;
val += weights.gwvec[iat][iref]
* weights.gwvec[jat][jref]
* self.c6_ref_flat[c6_idx];
}
}
c6[jat * nat + iat] = val;
}
}
c6
}
}
/// Number of reference systems tabulated for atomic number `z`.
///
/// Unsupported atomic numbers (`0` or above `MAX_ELEM`) have no references.
fn get_nref(z: u8) -> usize {
    let zi = usize::from(z);
    if (1..=MAX_ELEM).contains(&zi) {
        REFN[zi - 1]
    } else {
        0
    }
}
/// Reference coordination number `iref` of atomic number `z` (from `REFCN`).
///
/// Returns `0.0` for atomic numbers outside `1..=MAX_ELEM`.
fn get_refcn(z: u8, iref: usize) -> f64 {
    let zi = z as usize;
    if zi == 0 || zi > MAX_ELEM {
        return 0.0;
    }
    // The table is flattened with a stride of MAX_REF per element, so an
    // out-of-range `iref` would silently read the next element's row.
    debug_assert!(
        iref < MAX_REF,
        "get_refcn: reference index {} out of range (MAX_REF = {})",
        iref,
        MAX_REF
    );
    REFCN[(zi - 1) * MAX_REF + iref]
}
/// Reference CN `iref` of atomic number `z` used for the Gaussian weights
/// (from `REFCOVCN`).
///
/// Returns `0.0` for atomic numbers outside `1..=MAX_ELEM`.
fn get_refcovcn(z: u8, iref: usize) -> f64 {
    let zi = z as usize;
    if zi == 0 || zi > MAX_ELEM {
        return 0.0;
    }
    // The table is flattened with a stride of MAX_REF per element, so an
    // out-of-range `iref` would silently read the next element's row.
    debug_assert!(
        iref < MAX_REF,
        "get_refcovcn: reference index {} out of range (MAX_REF = {})",
        iref,
        MAX_REF
    );
    REFCOVCN[(zi - 1) * MAX_REF + iref]
}
/// Reference charge `iref` of atomic number `z` (from `REFQ_GFN2`).
///
/// Returns `0.0` for atomic numbers outside `1..=MAX_ELEM`.
fn get_refq(z: u8, iref: usize) -> f64 {
    let zi = z as usize;
    if zi == 0 || zi > MAX_ELEM {
        return 0.0;
    }
    // The table is flattened with a stride of MAX_REF per element, so an
    // out-of-range `iref` would silently read the next element's row.
    debug_assert!(
        iref < MAX_REF,
        "get_refq: reference index {} out of range (MAX_REF = {})",
        iref,
        MAX_REF
    );
    REFQ_GFN2[(zi - 1) * MAX_REF + iref]
}
/// Unnormalized Gaussian weight `exp(-wf · (cn - cnref)²)`.
fn weight_cn(wf: f64, cn: f64, cnref: f64) -> f64 {
    let delta = cn - cnref;
    f64::exp(-wf * delta * delta)
}
/// Charge-scaling function `exp(a · (1 - exp(c · (1 - qref / qmod))))`.
///
/// For negative `qmod` the inner term is dropped and `exp(a)` is returned.
fn zeta(a: f64, c: f64, qref: f64, qmod: f64) -> f64 {
    let exponent = if qmod < 0.0 {
        a
    } else {
        a * (1.0 - (c * (1.0 - qref / qmod)).exp())
    };
    exponent.exp()
}
/// Derivative of `zeta(a, c, qref, qmod)` with respect to `qmod`.
///
/// Zero for negative `qmod`, where `zeta` is constant.
fn dzeta(a: f64, c: f64, qref: f64, qmod: f64) -> f64 {
    if qmod < 0.0 {
        return 0.0;
    }
    // The inner exponential appears both in zeta itself and in its chain-rule
    // factor; evaluate it once.
    let inner = (c * (1.0 - qref / qmod)).exp();
    let z = (a * (1.0 - inner)).exp();
    -a * c * inner * z * qref / (qmod * qmod)
}
/// Multiplicity correction for a triple of atom indices.
///
/// Returns `1/6` when all three indices coincide, `1/2` when exactly two
/// coincide, and `1` when all are distinct.
fn triple_scale(ii: usize, jj: usize, kk: usize) -> f64 {
    match (ii == jj, ii == kk, jj == kk) {
        (true, true, _) => 1.0 / 6.0,
        (false, false, false) => 1.0,
        _ => 0.5,
    }
}
/// Coordination number of every atom, from an error-function counting function.
///
/// Each neighbour `j` within `D4_CN_CUTOFF` (and not at zero distance)
/// contributes `0.5 * (1 + erf(-7.5 * (r_ij / (R_i + R_j) - 1)))`, where `R` are
/// the `COVRAD_D3` covalent radii. Atoms whose atomic number is outside
/// `1..=MAX_ELEM` neither receive nor contribute counts.
fn compute_d4_cn(elements: &[u8], positions: &[[f64; 3]]) -> Vec<f64> {
    const KA: f64 = 7.5;
    let cutoff2 = D4_CN_CUTOFF * D4_CN_CUTOFF;
    let radius_of = |z: u8| -> Option<f64> {
        let zi = z as usize;
        if zi == 0 || zi > MAX_ELEM {
            None
        } else {
            Some(COVRAD_D3[zi - 1])
        }
    };
    let mut cn = Vec::with_capacity(elements.len());
    for (iat, &z_i) in elements.iter().enumerate() {
        let mut count = 0.0f64;
        if let Some(rcov_i) = radius_of(z_i) {
            let p_i = positions[iat];
            for (jat, &z_j) in elements.iter().enumerate() {
                if jat == iat {
                    continue;
                }
                let rcov_j = match radius_of(z_j) {
                    Some(r) => r,
                    None => continue,
                };
                let p_j = positions[jat];
                let (dx, dy, dz) = (p_i[0] - p_j[0], p_i[1] - p_j[1], p_i[2] - p_j[2]);
                let r2 = dx * dx + dy * dy + dz * dz;
                if r2 > cutoff2 || r2 < 1e-15 {
                    continue;
                }
                let rcov_sum = rcov_i + rcov_j;
                count += 0.5 * erf(-KA * (r2.sqrt() / rcov_sum - 1.0)) + 0.5;
            }
        }
        cn.push(count);
    }
    cn
}
/// Error function via the Abramowitz–Stegun 7.1.26 rational approximation.
///
/// Absolute error is about 1.5e-7; the result is odd-symmetric in `x`.
fn erf(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    const COEFFS: [f64; 5] = [
        0.254829592,
        -0.284496736,
        1.421413741,
        -1.453152027,
        1.061405429,
    ];
    let sign = if x >= 0.0 { 1.0 } else { -1.0 };
    let ax = x.abs();
    let t = 1.0 / (1.0 + P * ax);
    // Sum c_k · t^(k+1), building successive powers of t by repeated multiplication.
    let (poly, _) = COEFFS.iter().fold((0.0f64, 1.0f64), |(acc, power), &c| {
        let power = power * t;
        (acc + c * power, power)
    });
    sign * (1.0 - poly * (-ax * ax).exp())
}
/// Reference polarizabilities for each distinct element in `elements`.
///
/// The result is indexed `[type][iref][k]`. `type` follows the order in which
/// elements first appear, `iref < MAX_REF` and `k < NFREQ`; unused reference
/// slots are zero. For a reference with a secondary system (`REFSYS != 0`),
/// `HCOUNT` copies of that system's polarizability (`SSCALE * SECAIW`, scaled by
/// `zeta` evaluated at the `REFH` charge) are subtracted before the `ASCALE`
/// factor is applied. Every value is clamped to be non-negative.
fn compute_scaled_alpha(elements: &[u8]) -> Vec<Vec<Vec<f64>>> {
    let mut distinct: Vec<u8> = Vec::new();
    for &z in elements {
        if !distinct.contains(&z) {
            distinct.push(z);
        }
    }
    distinct
        .iter()
        .map(|&z| {
            let zi = z as usize;
            let mut table = vec![vec![0.0f64; NFREQ]; MAX_REF];
            for iref in 0..get_nref(z) {
                let slot = (zi - 1) * MAX_REF + iref;
                let is_sys = REFSYS[slot];
                let hc = HCOUNT[slot];
                let asc = ASCALE[slot];
                let rh = REFH[slot];
                let alpha_base = (zi - 1) * MAX_REF * NFREQ + iref * NFREQ;
                let row = &mut table[iref];
                if is_sys == 0 {
                    // No secondary system to subtract: only the ASCALE factor applies.
                    for (k, a) in row.iter_mut().enumerate() {
                        *a = (asc * ALPHAIW[alpha_base + k]).max(0.0);
                    }
                } else {
                    let ss = if is_sys <= MAX_SEC {
                        SSCALE[is_sys - 1]
                    } else {
                        0.0
                    };
                    let iz_sec = EFFECTIVE_NUCLEAR_CHARGE[is_sys - 1];
                    let eta_sec = CHEMICAL_HARDNESS[is_sys - 1] * D4_GC;
                    let z_scale = zeta(D4_GA, eta_sec, iz_sec, rh + iz_sec);
                    let sec_base = (is_sys - 1) * NFREQ;
                    for (k, a) in row.iter_mut().enumerate() {
                        let sec_val = if is_sys <= MAX_SEC && sec_base + k < SECAIW.len() {
                            ss * SECAIW[sec_base + k] * z_scale
                        } else {
                            0.0
                        };
                        *a = (asc * (ALPHAIW[alpha_base + k] - hc * sec_val)).max(0.0);
                    }
                }
            }
            table
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Water molecule shared by the tests in this module.
    fn water() -> ([u8; 3], [[f64; 3]; 3]) {
        let elements = [8u8, 1, 1];
        let positions = [
            [0.0, 0.0, 0.221228620],
            [0.0, 1.430453160, -0.885762480],
            [0.0, -1.430453160, -0.885762480],
        ];
        (elements, positions)
    }

    #[test]
    fn test_water_d4_cn() {
        let (elements, positions) = water();
        let cn = compute_d4_cn(&elements, &positions);
        let cn_o = cn[0];
        assert!(
            cn_o > 1.0 && cn_o < 2.5,
            "O CN={:.6}, expected 1.0–2.5",
            cn_o
        );
        for &ih in &[1usize, 2] {
            assert!(
                cn[ih] > 0.5 && cn[ih] < 1.5,
                "H CN={:.6}, expected 0.5–1.5",
                cn[ih]
            );
        }
    }

    #[test]
    fn test_water_d4_potential_at_zero_charges() {
        let (elements, positions) = water();
        let model = D4Model::new(&elements, &positions);
        let w = model.weight_references(&[0.0; 3]);
        let vat = model.get_potential(&w);
        let e_sc = model.get_energy(&w);
        eprintln!("D4 vat (q=0): {:?}", vat);
        eprintln!("D4 SC energy (q=0): {:.10e}", e_sc);
        assert!(vat[0].abs() > 1e-6, "O vat should be non-zero");
        assert!((vat[1] - vat[2]).abs() < 1e-12, "H vat should be symmetric");
        assert!(e_sc < 0.0, "SC energy should be negative");
        let reference = -2.506e-4;
        assert!(
            (e_sc - reference).abs() < 5e-5,
            "SC energy should match Python: got {:.6e}",
            e_sc
        );
    }
}