use std::arch::x86_64::{_CMP_LT_OQ, _mm256_blendv_ps, _mm256_cmp_ps, _mm256_set1_ps};
use lin_alg::f32::{Vec3, Vec3x8, Vec3x16, f32x8, f32x16};
use statrs::function::erf::erfc;
use crate::INV_SQRT_PI;
/// Evaluates the erfc-screened Coulomb interaction for a single pair of charges.
///
/// With `qq = q_0 * q_1` and `x = α * dist`, the computed quantities are:
/// - energy: `qq * inv_dist * erfc(x)`
/// - force magnitude:
///   `qq * (erfc(x) * inv_dist * inv_dist + 2 * α * exp(-x * x) * INV_SQRT_PI * inv_dist)`
///
/// Returns `(dir * force_magnitude, energy)`. When `dist >= cutoff_dist` the result is a zero
/// vector and zero energy, and none of the terms above are evaluated.
///
/// `inv_dist` is used as given; it is presumably `1 / dist` precomputed by the caller, but this
/// function does not verify that.
pub fn force_coulomb_short_range(
    dir: Vec3,
    dist: f32,
    inv_dist: f32,
    q_0: f32,
    q_1: f32,
    cutoff_dist: f32,
    α: f32,
) -> (Vec3, f32) {
    // Pairs at or beyond the cutoff contribute nothing.
    if dist >= cutoff_dist {
        return (Vec3::new_zero(), 0.);
    }

    let qq = q_0 * q_1;
    let scaled_r = α * dist;

    // `erfc` operates on f64, so widen for the call and narrow the result back to f32.
    let screening = erfc(f64::from(scaled_r)) as f32;
    let gaussian = (-scaled_r * scaled_r).exp();

    let potential = qq * inv_dist * screening;

    // The two contributions to the force magnitude, kept separate for readability.
    let erfc_part = screening * inv_dist * inv_dist;
    let gauss_part = 2.0 * α * gaussian * INV_SQRT_PI * inv_dist;
    let magnitude = qq * (erfc_part + gauss_part);

    (dir * magnitude, potential)
}
/// Eight-lane version of the erfc-screened Coulomb interaction.
///
/// Per lane, with `qq = q_0 * q_1` and `x = α * dist`:
/// - energy: `qq * inv_dist * erfc(x)`
/// - force magnitude:
///   `qq * (erfc(x) * inv_dist * inv_dist + 2 * α * exp(-x * x) * INV_SQRT_PI * inv_dist)`
///
/// Returns `(dir * force_magnitude, energy)`. Every lane is evaluated unconditionally; afterwards
/// lanes for which `dist < cutoff_dist` does not hold are replaced by zero in both the force
/// components and the energy. The comparison is the ordered less-than predicate, so a lane with
/// a NaN in either operand is also zeroed.
///
/// `inv_dist` is used as given; it is presumably `1 / dist` precomputed by the caller.
#[cfg(target_arch = "x86_64")]
pub fn force_coulomb_short_range_x8(
    dir: Vec3x8,
    dist: f32x8,
    inv_dist: f32x8,
    q_0: f32x8,
    q_1: f32x8,
    cutoff_dist: f32x8,
    α: f32x8,
) -> (Vec3x8, f32x8) {
    let qq = q_0 * q_1;
    let scaled_r = α * dist;

    let screening = scaled_r.erfc();
    let gaussian = (-scaled_r * scaled_r).exp();

    let potential = qq * inv_dist * screening;

    // The two contributions to the force magnitude, kept separate for readability.
    let erfc_part = screening * inv_dist * inv_dist;
    let gauss_part = f32x8::splat(2.) * α * gaussian * f32x8::splat(INV_SQRT_PI) * inv_dist;
    let magnitude = qq * (erfc_part + gauss_part);

    let unmasked = dir * magnitude;

    // SAFETY: these intrinsics operate purely on register values and touch no memory; their only
    // requirement is AVX support on the executing CPU.
    // NOTE(review): this function is gated on `target_arch` only -- confirm that AVX is
    // guaranteed by the build flags or by the callers.
    let (fx, fy, fz, en) = unsafe {
        // All-ones in lanes that are strictly inside the cutoff, all-zeros elsewhere.
        let within = _mm256_cmp_ps::<{ _CMP_LT_OQ }>(dist.0, cutoff_dist.0);
        let zeros = _mm256_set1_ps(0.0);

        // `blendv` picks the second operand where the mask is set, the first otherwise.
        (
            _mm256_blendv_ps(zeros, unmasked.x.0, within),
            _mm256_blendv_ps(zeros, unmasked.y.0, within),
            _mm256_blendv_ps(zeros, unmasked.z.0, within),
            _mm256_blendv_ps(zeros, potential.0, within),
        )
    };

    let force = Vec3x8 {
        x: f32x8(fx),
        y: f32x8(fy),
        z: f32x8(fz),
    };

    (force, f32x8(en))
}
/// Sixteen-lane version of the erfc-screened Coulomb interaction.
///
/// Per lane, with `qq = q_0 * q_1` and `x = α * dist`:
/// - energy: `qq * inv_dist * erfc(x)`
/// - force magnitude:
///   `qq * (erfc(x) * inv_dist * inv_dist + 2 * α * exp(-x * x) * INV_SQRT_PI * inv_dist)`
///
/// Returns `(dir * force_magnitude, energy)`. Every lane is evaluated unconditionally; afterwards
/// lanes for which `dist < cutoff_dist` does not hold are zeroed in both the force components
/// and the energy. The comparison is the ordered less-than predicate, so a lane with a NaN in
/// either operand is also zeroed.
///
/// `inv_dist` is used as given; it is presumably `1 / dist` precomputed by the caller.
#[cfg(target_arch = "x86_64")]
pub fn force_coulomb_short_range_x16(
    dir: Vec3x16,
    dist: f32x16,
    inv_dist: f32x16,
    q_0: f32x16,
    q_1: f32x16,
    cutoff_dist: f32x16,
    α: f32x16,
) -> (Vec3x16, f32x16) {
    let qq = q_0 * q_1;
    let scaled_r = α * dist;

    let screening = scaled_r.erfc();
    let gaussian = (-scaled_r * scaled_r).exp();

    let potential = qq * inv_dist * screening;

    // The two contributions to the force magnitude, kept separate for readability.
    let erfc_part = screening * inv_dist * inv_dist;
    let gauss_part = f32x16::splat(2.) * α * gaussian * f32x16::splat(INV_SQRT_PI) * inv_dist;
    let magnitude = qq * (erfc_part + gauss_part);

    let unmasked = dir * magnitude;

    // SAFETY: these intrinsics operate purely on register values and touch no memory; their only
    // requirement is AVX-512F support on the executing CPU.
    // NOTE(review): this function is gated on `target_arch` only -- confirm that AVX-512F is
    // guaranteed by the build flags or by the callers.
    let (fx, fy, fz, en) = unsafe {
        use core::arch::x86_64::*;

        // One bit per lane, set when the lane is strictly inside the cutoff.
        let within: __mmask16 = _mm512_cmp_ps_mask::<{ _CMP_LT_OQ }>(dist.0, cutoff_dist.0);

        // `maskz_mov` copies lanes whose mask bit is set and writes zero to the rest.
        (
            _mm512_maskz_mov_ps(within, unmasked.x.0),
            _mm512_maskz_mov_ps(within, unmasked.y.0),
            _mm512_maskz_mov_ps(within, unmasked.z.0),
            _mm512_maskz_mov_ps(within, potential.0),
        )
    };

    let force = Vec3x16 {
        x: f32x16(fx),
        y: f32x16(fy),
        z: f32x16(fz),
    };

    (force, f32x16(en))
}