#[cfg(feature = "cuda")]
use cudarc::driver::CudaContext;
use ewald::{PmeRecip, force_coulomb_short_range, get_grid_n};
use lin_alg::f32::Vec3;
use crate::non_bonded::{CHARGE_UNIT_SCALER, EWALD_ALPHA, LONG_RANGE_CUTOFF};
/// Coulomb prefactor for charges given in unscaled units (presumably elementary
/// charges): the square of `CHARGE_UNIT_SCALER`, so `K_ELEC * q1 * q2 / r` is
/// mathematically the product of the two *scaled* charges divided by `r`.
/// The tests in this file label the resulting energies kcal/mol — NOTE(review): the
/// unit convention is inherited from `CHARGE_UNIT_SCALER`; confirm there.
pub const K_ELEC: f32 = CHARGE_UNIT_SCALER * CHARGE_UNIT_SCALER;
/// Wraps a coordinate into the half-open periodic interval `[0, |l|)`.
///
/// `f32::rem_euclid` is documented to return exactly `|l|` (not something strictly
/// below it) when `x` is a tiny negative number, because `x + |l|` rounds up to `|l|`.
/// That value lies outside the box, so it is folded back to `0.0`, which is the same
/// point under periodicity. A zero `l` still yields NaN, as before.
fn wrap1(x: f32, l: f32) -> f32 {
    let r = x.rem_euclid(l);
    if r >= l.abs() { 0.0 } else { r }
}
/// Wraps each Cartesian component of `p` into its periodic box interval.
///
/// Each component is wrapped independently with `wrap1`, using the matching edge
/// length from `l = (lx, ly, lz)`.
fn wrap_pos(p: Vec3, l: (f32, f32, f32)) -> Vec3 {
    let (lx, ly, lz) = l;
    let x = wrap1(p.x, lx);
    let y = wrap1(p.y, ly);
    let z = wrap1(p.z, lz);
    Vec3::new(x, y, z)
}
/// Builds a reciprocal-space PME solver for a box with edge lengths `l` and Ewald
/// parameter `alpha`, bound to the default stream of CUDA device 0.
///
/// The mesh dimensions are derived from `mesh_spacing` via `get_grid_n`.
///
/// # Panics
///
/// Panics if CUDA device 0 cannot be initialised.
#[cfg(feature = "cuda")]
fn make_pme(l: (f32, f32, f32), alpha: f32, mesh_spacing: f32) -> PmeRecip {
    // The temporary context handle is released at the end of this statement; only
    // the stream is kept.
    let stream = CudaContext::new(0).unwrap().default_stream();
    PmeRecip::new(Some(&stream), get_grid_n(l, mesh_spacing), l, alpha)
}
/// Builds a CPU reciprocal-space PME solver for a box with edge lengths `l` and Ewald
/// parameter `alpha`; the mesh dimensions are derived from `mesh_spacing` via
/// `get_grid_n`.
#[cfg(not(feature = "cuda"))]
fn make_pme(l: (f32, f32, f32), alpha: f32, mesh_spacing: f32) -> PmeRecip {
    PmeRecip::new(get_grid_n(l, mesh_spacing), l, alpha)
}
/// Computes SPME forces and energies for a single pair of point charges in a cubic
/// periodic box of edge `box_len`.
///
/// * `r1`, `r2` — particle positions; they need not lie inside the box.
/// * `q1_e`, `q2_e` — unscaled charges (presumably elementary-charge units); both
///   are multiplied by `CHARGE_UNIT_SCALER` before use.
/// * `alpha` — Ewald splitting parameter, shared by the real- and reciprocal-space
///   terms.
///
/// The short-range term uses the minimum-image separation and `LONG_RANGE_CUTOFF`.
/// The reciprocal term runs a PME solve (mesh spacing 1.0) on positions wrapped with
/// `wrap_pos`.
///
/// Returns `((f_sr_1, f_recip_1, f_recip_2), (e_sr, e_recip, e_total))`:
/// `f_sr_1` is the short-range force on particle 1 only, `f_recip_1` / `f_recip_2`
/// are the reciprocal-space forces on particles 1 and 2, and
/// `e_total = e_sr + e_recip`.
///
/// # Panics
///
/// Panics if the minimum-image separation is zero or NaN; without this check the
/// short-range term would silently yield NaN forces. With the `cuda` feature, also
/// panics if CUDA device 0 cannot be initialised.
pub fn spme_pair_forces_energy(
    r1: Vec3,
    r2: Vec3,
    q1_e: f32,
    q2_e: f32,
    box_len: f32,
    alpha: f32,
) -> ((Vec3, Vec3, Vec3), (f32, f32, f32)) {
    let q1 = q1_e * CHARGE_UNIT_SCALER;
    let q2 = q2_e * CHARGE_UNIT_SCALER;

    // Minimum-image separation vector r1 − r2 in the cubic box.
    let diff = {
        let mut v = r1 - r2;
        v.x -= box_len * (v.x / box_len).round();
        v.y -= box_len * (v.y / box_len).round();
        v.z -= box_len * (v.z / box_len).round();
        v
    };
    let dist = diff.magnitude();
    // `dist > 0.0` is false for both 0 and NaN. Either would make `1 / dist`
    // infinite/NaN and poison the short-range force without any diagnostic.
    assert!(
        dist > 0.0,
        "spme_pair_forces_energy: particles coincide (minimum-image distance = {dist})"
    );
    let inv_dist = 1.0 / dist;
    let dir = diff * inv_dist;

    let (f_sr_1, e_sr) =
        force_coulomb_short_range(dir, dist, inv_dist, q1, q2, LONG_RANGE_CUTOFF, alpha);

    // Reciprocal-space part on positions wrapped into the box.
    let l = (box_len, box_len, box_len);
    let mut pme = make_pme(l, alpha, 1.0);
    let pos = vec![wrap_pos(r1, l), wrap_pos(r2, l)];
    let q_arr = vec![q1, q2];

    #[cfg(feature = "cuda")]
    let stream = CudaContext::new(0).unwrap().default_stream();
    #[cfg(feature = "cuda")]
    let (f_recip, e_lr) = pme.forces_gpu(&stream, &pos, &q_arr);
    #[cfg(not(feature = "cuda"))]
    let (f_recip, e_lr) = pme.forces(&pos, &q_arr);

    let e_total = e_sr + e_lr;
    ((f_sr_1, f_recip[0], f_recip[1]), (e_sr, e_lr, e_total))
}
/// Bare Coulomb energy of two point charges `q1_e`, `q2_e` (unscaled units, same
/// convention as `K_ELEC`) separated by `dist`.
///
/// There are no periodic images and no screening. This is the reference the SPME
/// totals are compared against.
pub fn vacuum_coulomb_energy(q1_e: f32, q2_e: f32, dist: f32) -> f32 {
    // Evaluated left to right as (K_ELEC * q1_e) * q2_e to keep rounding unchanged.
    let numerator = K_ELEC * q1_e * q2_e;
    numerator / dist
}
/// Default relative tolerance (1 %) for comparing SPME results against the vacuum
/// Coulomb references in these tests. Some force tests double it at larger
/// separations.
const REL_TOL: f32 = 0.01;
/// Panics unless `got` matches `expected` to within the relative tolerance `tol`.
///
/// A relative error is meaningless when the reference is (nearly) zero. So for
/// `|expected| < 1e-6` the check becomes the absolute test `|got| < 1e-4`, and `tol`
/// is not used. `label` prefixes the panic message. A NaN in either value fails the
/// check.
fn assert_rel_close(got: f32, expected: f32, tol: f32, label: &str) {
    let near_zero = expected.abs() < 1e-6;
    if near_zero {
        assert!(got.abs() < 1e-4, "{label}: expected ≈ 0, got {got:.6e}");
    } else {
        // |a / b| == |a| / |b| exactly in IEEE arithmetic.
        let rel = (got - expected).abs() / expected.abs();
        assert!(
            rel < tol,
            "{label}: got {got:.6}, expected {expected:.6}, rel_err = {rel:.4} (tol = {tol})"
        );
    }
}
/// Opposite unit charges (+1/−1) centred in a 50 Å cubic box. The total SPME energy
/// (short-range + reciprocal) must match the vacuum Coulomb energy within `REL_TOL`
/// at 3, 5 and 8 Å separation.
#[test]
fn test_spme_energy_opposite_charges() {
    let box_len: f32 = 50.0;
    let center = box_len / 2.0;

    let separations = [(3.0, "3Å"), (5.0, "5Å"), (8.0, "8Å")];
    for (dist, tag) in separations {
        let half = dist / 2.0;
        let r1 = Vec3::new(center - half, center, center);
        let r2 = Vec3::new(center + half, center, center);

        let (_, (e_sr, e_lr, e_total)) =
            spme_pair_forces_energy(r1, r2, 1.0, -1.0, box_len, EWALD_ALPHA);
        let e_vac = vacuum_coulomb_energy(1.0, -1.0, dist);

        println!(
            "+1/−1 at {tag}: e_sr={e_sr:.4} e_lr={e_lr:.4} \
             e_total={e_total:.4} e_vac={e_vac:.4} kcal/mol"
        );
        assert_rel_close(e_total, e_vac, REL_TOL, &format!("energy +1/-1 at {tag}"));
    }
}
/// Fractional opposite charges (±0.5 and ±0.25) 5 Å apart in a 50 Å cubic box. The
/// SPME total energy must agree with the vacuum Coulomb energy within `REL_TOL`.
#[test]
fn test_spme_energy_fractional_charges() {
    let box_len = 50.0;
    let dist = 5.0;
    let c = box_len / 2.0;
    let half = dist / 2.0;
    let r1 = Vec3::new(c - half, c, c);
    let r2 = Vec3::new(c + half, c, c);

    let charge_pairs = [(0.5f32, -0.5f32, "q=±0.5"), (0.25, -0.25, "q=±0.25")];
    for &(q1e, q2e, tag) in &charge_pairs {
        let (_, (_, _, e_total)) =
            spme_pair_forces_energy(r1, r2, q1e, q2e, box_len, EWALD_ALPHA);
        let e_vac = vacuum_coulomb_energy(q1e, q2e, dist);
        println!("{tag}: e_total={e_total:.4} e_vac={e_vac:.4} kcal/mol");
        assert_rel_close(e_total, e_vac, REL_TOL, tag);
    }
}
/// Box-size sweep for a +1/−1 pair 5 Å apart.
///
/// Prints the relative deviation from the vacuum Coulomb energy for L = 20, 30 and
/// 50 Å as a diagnostic only. It then asserts agreement within `REL_TOL` for the
/// L = 50 Å box.
#[test]
fn test_spme_energy_box_convergence() {
    /// Box length whose energy is asserted on; must be one of the swept sizes.
    const CHECKED_BOX: f32 = 50.0;

    let dist = 5.0;
    let alpha = EWALD_ALPHA;
    let e_vac = vacuum_coulomb_energy(1.0, -1.0, dist);
    println!("Box-size convergence test (e_vac = {e_vac:.4} kcal/mol):");

    // Keep the checked box's energy from the sweep instead of repeating the whole
    // setup and PME solve just to assert on it.
    let mut e_checked = None;
    for box_len in [20.0, 30.0, CHECKED_BOX] {
        let c = box_len / 2.0;
        let r1 = Vec3::new(c - dist / 2.0, c, c);
        let r2 = Vec3::new(c + dist / 2.0, c, c);
        let (_, (_, _, e_total)) = spme_pair_forces_energy(r1, r2, 1.0, -1.0, box_len, alpha);
        let rel = ((e_total - e_vac) / e_vac).abs();
        println!(" L={box_len:.0} Å: e_total={e_total:.4} rel_err={rel:.4}");
        if box_len == CHECKED_BOX {
            e_checked = Some(e_total);
        }
    }

    let e_checked = e_checked.expect("CHECKED_BOX is part of the sweep");
    assert_rel_close(e_checked, e_vac, REL_TOL, "energy at L=50 Å");
}
/// Opposite unit charges (+1/−1) in a 50 Å box.
///
/// The x-component of the total SPME force on q1 must match the attractive vacuum
/// force `K_ELEC / d²`, which points +x toward q2. The transverse components must
/// stay below 1 % of that value. The tolerance on Fx is doubled at separations ≥ 7 Å.
#[test]
fn test_spme_force_magnitude() {
    let box_len: f32 = 50.0;
    let center = box_len / 2.0;

    for (dist, tag) in [(3.0, "3Å"), (5.0, "5Å"), (8.0, "8Å")] {
        let offset = dist / 2.0;
        let r1 = Vec3::new(center - offset, center, center);
        let r2 = Vec3::new(center + offset, center, center);

        let ((f_sr_1, f_lr_1, _), _) =
            spme_pair_forces_energy(r1, r2, 1.0, -1.0, box_len, EWALD_ALPHA);
        let fx_total = f_sr_1.x + f_lr_1.x;
        let fy = f_sr_1.y + f_lr_1.y;
        let fz = f_sr_1.z + f_lr_1.z;
        let f_vac_x = K_ELEC / (dist * dist);
        let transverse_limit = 0.01 * f_vac_x;

        println!(
            "+1/−1 at {tag}: f_sr_x={:.4} f_lr_x={:.4} \
             f_total_x={fx_total:.4} f_vac_x={f_vac_x:.4} kcal/(mol·Å)",
            f_sr_1.x, f_lr_1.x
        );

        let ftol = if dist < 7.0 { REL_TOL } else { 2.0 * REL_TOL };
        assert_rel_close(fx_total, f_vac_x, ftol, &format!("Fx on q1 at {tag}"));
        assert!(
            fy.abs() < transverse_limit,
            "{tag}: Fy should be ~0, got {fy:.4e}"
        );
        assert!(
            fz.abs() < transverse_limit,
            "{tag}: Fz should be ~0, got {fz:.4e}"
        );
    }
}
/// Newton's third law for a +1/−1 pair in a 50 Å box.
///
/// The total forces (short-range plus reciprocal) on the two particles must cancel
/// to within 2 % of |f1|. The short-range force is evaluated from each particle's
/// side, with direction and charge order swapped for particle 2. The reciprocal
/// forces come from one `PmeRecip` solve per separation.
#[test]
fn test_spme_force_newton3() {
    let box_len: f32 = 50.0;
    let alpha = EWALD_ALPHA;
    let l = (box_len, box_len, box_len);
    let center = box_len / 2.0;
    let q1 = 1.0 * CHARGE_UNIT_SCALER;
    let q2 = -1.0 * CHARGE_UNIT_SCALER;
    let q_arr = vec![q1, q2];

    for (dist, tag) in [(3.0, "3Å"), (5.0, "5Å"), (8.0, "8Å")] {
        let r1 = Vec3::new(center - dist / 2.0, center, center);
        let r2 = Vec3::new(center + dist / 2.0, center, center);

        // Short-range pair force seen from each particle.
        let inv_d = 1.0 / dist;
        let dir = (r1 - r2) * inv_d;
        let (f_sr_1, _) =
            force_coulomb_short_range(dir, dist, inv_d, q1, q2, LONG_RANGE_CUTOFF, alpha);
        let (f_sr_2, _) =
            force_coulomb_short_range(-dir, dist, inv_d, q2, q1, LONG_RANGE_CUTOFF, alpha);

        // Reciprocal-space forces on both particles.
        let mut pme = make_pme(l, alpha, 1.0);
        let pos = vec![wrap_pos(r1, l), wrap_pos(r2, l)];
        let (f_recip, _) = pme.forces(&pos, &q_arr);

        let f1 = f_sr_1 + f_recip[0];
        let f2 = f_sr_2 + f_recip[1];
        let (sx, sy, sz) = (f1.x + f2.x, f1.y + f2.y, f1.z + f2.z);
        let sum_mag = (sx * sx + sy * sy + sz * sz).sqrt();
        let f1_mag = f1.magnitude();

        println!("Newton3 at {tag}: |f1|={f1_mag:.4} |f1+f2|={sum_mag:.4e}");
        assert!(
            sum_mag < 0.02 * f1_mag,
            "{tag}: Newton 3rd law violated: |f1+f2| = {sum_mag:.4e}, |f1| = {f1_mag:.4}"
        );
    }
}
/// The short-range (real-space) Coulomb term must be exactly zero, for both force
/// and energy, at `LONG_RANGE_CUTOFF` and at two distances beyond it.
#[test]
fn test_short_range_cutoff() {
    let q = CHARGE_UNIT_SCALER;
    let x_hat = Vec3::new(1.0, 0.0, 0.0);
    let distances = [0.0, 0.1, 1.0].map(|extra| LONG_RANGE_CUTOFF + extra);

    for dist in distances {
        let inv_dist = 1.0 / dist;
        let (f, e) = force_coulomb_short_range(
            x_hat,
            dist,
            inv_dist,
            q,
            q,
            LONG_RANGE_CUTOFF,
            EWALD_ALPHA,
        );
        assert_eq!(
            f.magnitude_squared(),
            0.0,
            "force should be 0 at dist={dist:.2}: {f:?}"
        );
        assert_eq!(e, 0.0, "energy should be 0 at dist={dist:.2}: {e}");
    }
}
/// A +1/−1 pair 5 Å apart along x in a 60 × 30 × 30 Å box. The SPME total energy
/// must match the vacuum Coulomb energy within `REL_TOL`.
///
/// `spme_pair_forces_energy` only takes a cubic box edge, so the reciprocal and
/// short-range terms are assembled here directly.
#[test]
fn test_spme_energy_non_cubic_box() {
    let alpha = EWALD_ALPHA;
    let dist = 5.0;
    let l = (60.0, 30.0, 30.0);
    let (lx, ly, lz) = l;
    let r1 = Vec3::new(lx / 2.0 - dist / 2.0, ly / 2.0, lz / 2.0);
    let r2 = Vec3::new(lx / 2.0 + dist / 2.0, ly / 2.0, lz / 2.0);
    let charges = vec![1.0 * CHARGE_UNIT_SCALER, -1.0 * CHARGE_UNIT_SCALER];
    let (q1, q2) = (charges[0], charges[1]);

    // Reciprocal-space energy from a PME solve on the wrapped positions.
    let mut pme = make_pme(l, alpha, 1.0);
    let wrapped = vec![wrap_pos(r1, l), wrap_pos(r2, l)];
    let (_, e_lr) = pme.forces(&wrapped, &charges);

    // Real-space energy of the pair. The plain difference r1 − r2 is used, since
    // both particles sit near the middle of the box.
    let inv_d = 1.0 / dist;
    let dir = (r1 - r2) * inv_d;
    let (_, e_sr) = force_coulomb_short_range(dir, dist, inv_d, q1, q2, LONG_RANGE_CUTOFF, alpha);

    let e_total = e_sr + e_lr;
    let e_vac = vacuum_coulomb_energy(1.0, -1.0, dist);
    println!("Non-cubic box (60×30×30): e_total={e_total:.4} e_vac={e_vac:.4} kcal/mol");
    assert_rel_close(e_total, e_vac, REL_TOL, "non-cubic box energy");
}
/// Like unit charges (+1/+1) in a 50 Å box.
///
/// The x-component of the total SPME force on q1 must match the repulsive vacuum
/// force `-K_ELEC / d²`, which points −x away from q2. The transverse components
/// must stay below 1 % of its magnitude. The tolerance on Fx is doubled at
/// separations ≥ 7 Å.
#[test]
fn test_spme_force_like_charges() {
    let box_len: f32 = 50.0;
    let center = box_len / 2.0;

    for (dist, tag) in [(3.0, "3Å"), (5.0, "5Å"), (8.0, "8Å")] {
        let r1 = Vec3::new(center - dist / 2.0, center, center);
        let r2 = Vec3::new(center + dist / 2.0, center, center);

        let ((f_sr_1, f_lr_1, _), _) =
            spme_pair_forces_energy(r1, r2, 1.0, 1.0, box_len, EWALD_ALPHA);
        let fx_total = f_sr_1.x + f_lr_1.x;
        let fy = f_sr_1.y + f_lr_1.y;
        let fz = f_sr_1.z + f_lr_1.z;
        let f_vac_x = -K_ELEC / (dist * dist);
        let transverse_limit = 0.01 * f_vac_x.abs();

        println!(
            "+1/+1 at {tag}: f_sr_x={:.4} f_lr_x={:.4} \
             f_total_x={fx_total:.4} f_vac_x={f_vac_x:.4} kcal/(mol·Å)",
            f_sr_1.x, f_lr_1.x
        );

        let ftol = if dist < 7.0 { REL_TOL } else { 2.0 * REL_TOL };
        assert_rel_close(fx_total, f_vac_x, ftol, &format!("Fx on q1 (like) at {tag}"));
        assert!(
            fy.abs() < transverse_limit,
            "{tag}: Fy should be ~0, got {fy:.4e}"
        );
        assert!(
            fz.abs() < transverse_limit,
            "{tag}: Fz should be ~0, got {fz:.4e}"
        );
    }
}
/// Consistency check between the analytic SPME force and the energy.
///
/// The force on q1 along x must equal the negative central-difference derivative
/// `-(E(x + δ) - E(x - δ)) / 2δ` of the total energy, with δ = 0.01. The comparison
/// uses a 2 % relative tolerance.
#[test]
fn test_spme_force_matches_energy_gradient() {
    let box_len = 50.0;
    let alpha = EWALD_ALPHA;
    let delta = 0.01;
    let center = box_len / 2.0;

    // Total SPME energy of the +1/−1 pair for the given positions.
    let energy_at = |p1: Vec3, p2: Vec3| {
        let (_, (_, _, e)) = spme_pair_forces_energy(p1, p2, 1.0, -1.0, box_len, alpha);
        e
    };

    for (dist, tag) in [(3.0, "3Å"), (5.0, "5Å"), (8.0, "8Å")] {
        let r1 = Vec3::new(center - dist / 2.0, center, center);
        let r2 = Vec3::new(center + dist / 2.0, center, center);

        let e_plus = energy_at(Vec3::new(r1.x + delta, r1.y, r1.z), r2);
        let e_minus = energy_at(Vec3::new(r1.x - delta, r1.y, r1.z), r2);
        let fx_numerical = -(e_plus - e_minus) / (2.0 * delta);

        let ((f_sr, f_lr, _), _) = spme_pair_forces_energy(r1, r2, 1.0, -1.0, box_len, alpha);
        let fx_computed = f_sr.x + f_lr.x;

        println!(
            "Force–gradient check at {tag}: fx_computed={fx_computed:.4} \
             fx_numerical={fx_numerical:.4} kcal/(mol·Å)"
        );
        assert_rel_close(
            fx_computed,
            fx_numerical,
            0.02,
            &format!("force = -dE/dx at {tag}"),
        );
    }
}