use std::cell::RefCell;
use integral_math::am::{cart_components, cart_index, n_cart};
use integral_math::boys::{boys_array, boys_array2, boys_array4};
use crate::os::{Vec3, MAX_L};
/// Borrowed view of one contracted Cartesian Gaussian shell.
///
/// `exps[i]` and `coeffs[i]` describe the same primitive: the pair builders
/// zip the two slices positionally.
#[derive(Debug, Clone, Copy)]
pub struct ShellRef<'a> {
    /// Shell centre.
    pub center: Vec3,
    /// Angular momentum; the kernels debug-assert `l <= MAX_L`.
    pub l: usize,
    /// Primitive exponents.
    pub exps: &'a [f64],
    /// Contraction coefficients, one per exponent.
    pub coeffs: &'a [f64],
}
/// Number of Cartesian exponent triples of total degree `0..=lmax`, i.e. the
/// size of the address space produced by `addr` up to that degree.
#[inline]
fn n_addr(lmax: usize) -> usize {
    let (lo, mid, hi) = (lmax + 1, lmax + 2, lmax + 3);
    lo * mid * hi / 6
}
/// Number of Cartesian exponent triples of total degree strictly below `d`;
/// equivalently the first address of degree `d` in the `addr` ordering.
#[inline]
fn tri_below(d: usize) -> usize {
    let product = d * (d + 1) * (d + 2);
    product / 6
}
/// Flat address of the exponent triple `t = [x, y, z]`.
///
/// Shells are concatenated by total degree. Within a shell the triples run with
/// `x` descending, then `y` descending, which matches the `TRI1`..`TRI3` tables.
#[inline]
fn addr(t: [usize; 3]) -> usize {
    let [x, y, z] = t;
    let degree = x + y + z;
    // y + z: how far this triple is from the pure-x component of its shell.
    let yz = degree - x;
    let shell_start = degree * (degree + 1) * (degree + 2) / 6;
    shell_start + yz * (yz + 1) / 2 + z
}
/// Reusable work buffers for the scalar ERI path.
///
/// Buffers only ever grow (see `ensure_len` / `ensure_usize`), so one instance
/// can be reused across many shell quartets without reallocating.
#[derive(Debug, Default, Clone)]
pub struct EriScratch {
    /// Three rotating ket levels of the generic vertical recurrence (`vrr_primitive`).
    levels: Vec<f64>,
    /// Contracted `[e0|f0]` table, `n_addr(la + lb)` rows by `n_addr(lc + ld)` columns.
    c_ef: Vec<f64>,
    /// Bra-HRR result for ket addresses of degree `lc..=lc + ld`.
    bra: Vec<f64>,
    /// Ping-pong layers of the bra horizontal recurrence.
    bra_prev: Vec<f64>,
    bra_cur: Vec<f64>,
    /// Ping-pong layers of the ket horizontal recurrence.
    ket_prev: Vec<f64>,
    ket_cur: Vec<f64>,
    /// Primitive-pair buffers reused by `coulomb_shell_into_scratch`.
    bra_pairs: Vec<PrimPair>,
    ket_pairs: Vec<PrimPair>,
    /// `cart_components(l)` for `l = 0..=2 * MAX_L`, built on first use.
    tri_all: Vec<Vec<[usize; 3]>>,
    /// Generic-VRR offset of each bra address inside a ket-level slab, per level.
    eoff: Vec<usize>,
    /// Generic-VRR slab length per ket level.
    slab: Vec<usize>,
}
/// One surviving primitive pair of a shell pair `(s1, s2)`, as built by `build_pairs`.
#[derive(Debug, Clone, Copy, Default)]
struct PrimPair {
    /// Combined exponent `e1 + e2`.
    zeta: f64,
    /// Gaussian product centre `(e1 A + e2 B) / zeta`.
    center: Vec3,
    /// `exp(-(e1 e2 / zeta) |A - B|^2)`.
    kappa: f64,
    /// Contraction coefficient of the first primitive.
    c1: f64,
    /// Contraction coefficient of the second primitive.
    c2: f64,
    /// `0.5 / zeta`.
    inv_2zeta: f64,
    /// `center - s1.center`, i.e. P - A.
    r1: Vec3,
}
/// Screening threshold: a primitive pair with `kappa * |c1 c2|` below this is
/// dropped by `build_pairs` (and not counted by `surviving_pair_count`).
const PAIR_NEGLIGIBLE: f64 = 1e-32;
/// Fills `out` with the screened primitive pairs of the shell pair `(s1, s2)`.
///
/// `out` is cleared and then reused. For each primitive combination the pair
/// stores:
/// - `zeta = e1 + e2`
/// - the product centre `P = (e1 A + e2 B) / zeta`
/// - `kappa = exp(-(e1 e2 / zeta) |A - B|^2)`
/// - both contraction coefficients
/// - `1 / (2 zeta)`
/// - `P - A`
///
/// Combinations with `kappa * |c1 c2|` below [`PAIR_NEGLIGIBLE`] are skipped.
///
/// Exponents and coefficients are paired positionally. A length mismatch is a
/// caller bug. Debug builds catch it; release builds fall back to `zip`
/// semantics and use the shorter slice.
fn build_pairs(out: &mut Vec<PrimPair>, s1: ShellRef<'_>, s2: ShellRef<'_>) {
    debug_assert_eq!(
        s1.exps.len(),
        s1.coeffs.len(),
        "shell exponent/coefficient count mismatch"
    );
    debug_assert_eq!(
        s2.exps.len(),
        s2.coeffs.len(),
        "shell exponent/coefficient count mismatch"
    );
    out.clear();
    out.reserve(s1.exps.len() * s2.exps.len());
    let d2 = dist2(s1.center, s2.center);
    for (&e1, &c1) in s1.exps.iter().zip(s1.coeffs.iter()) {
        for (&e2, &c2) in s2.exps.iter().zip(s2.coeffs.iter()) {
            let zeta = e1 + e2;
            let kappa = (-(e1 * e2 / zeta) * d2).exp();
            // Written as `< threshold => skip`, so a NaN magnitude is kept rather than dropped.
            if kappa * (c1 * c2).abs() < PAIR_NEGLIGIBLE {
                continue;
            }
            let center = combine(e1, s1.center, e2, s2.center, zeta);
            out.push(PrimPair {
                zeta,
                center,
                kappa,
                c1,
                c2,
                inv_2zeta: 0.5 / zeta,
                r1: sub(center, s1.center),
            });
        }
    }
}
/// Precomputed, screened primitive pairs of one shell pair.
///
/// Built by `shell_pair_data` and consumed by the `*_pairs_into*` entry points.
/// It must be built from the same shells it is later used with; nothing here
/// checks that.
#[derive(Debug, Clone, Default)]
pub struct ShellPairData {
    /// Surviving primitive pairs in `build_pairs` order.
    pairs: Vec<PrimPair>,
}
impl ShellPairData {
    /// Number of primitive pairs that survived screening.
    #[must_use]
    pub fn len(&self) -> usize {
        self.pairs.as_slice().len()
    }

    /// `true` when every primitive pair was screened out.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.pairs.as_slice().is_empty()
    }
}
#[must_use]
pub fn shell_pair_data(s1: ShellRef<'_>, s2: ShellRef<'_>) -> ShellPairData {
let mut pairs = Vec::new();
build_pairs(&mut pairs, s1, s2);
ShellPairData { pairs }
}
impl EriScratch {
    /// Creates an empty scratch; buffers grow on first use.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Lengths of every `f64` work buffer, in field declaration order.
    fn f64_buffer_lens(&self) -> [usize; 7] {
        [
            self.levels.len(),
            self.c_ef.len(),
            self.bra.len(),
            self.bra_prev.len(),
            self.bra_cur.len(),
            self.ket_prev.len(),
            self.ket_cur.len(),
        ]
    }

    /// Total number of `f64` elements currently held (lengths, not capacities).
    #[must_use]
    pub fn resident_f64(&self) -> usize {
        self.f64_buffer_lens().iter().sum()
    }

    /// Length of the largest `f64` work buffer.
    #[must_use]
    pub fn largest_buffer_f64(&self) -> usize {
        self.f64_buffer_lens().into_iter().max().unwrap_or(0)
    }
}
// Per-thread scratch used by `coulomb_shell_into` and `coulomb_shell_pairs_into`,
// so buffers are grown once per thread rather than on every call.
thread_local! {
    static ERI_SCRATCH: RefCell<EriScratch> = RefCell::new(EriScratch::new());
}
/// Grows `v` with zeros to at least `n` elements; never shrinks it.
#[inline]
fn ensure_len(v: &mut Vec<f64>, n: usize) {
    let target = v.len().max(n);
    v.resize(target, 0.0);
}
/// Grows `v` with zeros to at least `n` elements; never shrinks it.
#[inline]
fn ensure_usize(v: &mut Vec<usize>, n: usize) {
    let target = v.len().max(n);
    v.resize(target, 0);
}
/// Adds the contracted Cartesian ERI block `(ab|cd)` into `out`, using this
/// thread's cached [`EriScratch`].
///
/// `out` is indexed `((ia * nb + ib) * nc + ic) * nd + id` and is accumulated
/// into, not overwritten.
///
/// # Panics
/// Panics if the thread-local scratch is already mutably borrowed (`RefCell`).
pub fn coulomb_shell_into(
    a: ShellRef<'_>,
    b: ShellRef<'_>,
    c: ShellRef<'_>,
    d: ShellRef<'_>,
    out: &mut [f64],
) {
    ERI_SCRATCH.with(|cell| {
        let mut guard = cell.borrow_mut();
        coulomb_shell_into_scratch(&mut guard, a, b, c, d, out);
    });
}
/// Like [`coulomb_shell_into`], but reuses precomputed primitive pairs for the
/// bra (`a`, `b`) and ket (`c`, `d`) from [`shell_pair_data`].
///
/// The pair data must have been built from the same shells; this is not checked.
///
/// # Panics
/// Panics if the thread-local scratch is already mutably borrowed (`RefCell`).
pub fn coulomb_shell_pairs_into(
    a: ShellRef<'_>,
    b: ShellRef<'_>,
    c: ShellRef<'_>,
    d: ShellRef<'_>,
    bra_pairs: &ShellPairData,
    ket_pairs: &ShellPairData,
    out: &mut [f64],
) {
    ERI_SCRATCH.with(|cell| {
        let mut guard = cell.borrow_mut();
        coulomb_shell_pairs_into_scratch(&mut guard, a, b, c, d, bra_pairs, ket_pairs, out);
    });
}
/// Adds the `(ab|cd)` block into `out` using a caller-supplied scratch.
///
/// Primitive pairs are rebuilt into the scratch's own pair buffers, which are
/// returned to the scratch afterwards so their allocations are reused.
pub fn coulomb_shell_into_scratch(
    scratch: &mut EriScratch,
    a: ShellRef<'_>,
    b: ShellRef<'_>,
    c: ShellRef<'_>,
    d: ShellRef<'_>,
    out: &mut [f64],
) {
    // Detach the pair buffers so `scratch` can be lent mutably to the core.
    let (mut bra_buf, mut ket_buf) = (
        std::mem::take(&mut scratch.bra_pairs),
        std::mem::take(&mut scratch.ket_pairs),
    );
    build_pairs(&mut bra_buf, a, b);
    build_pairs(&mut ket_buf, c, d);
    coulomb_shell_core(scratch, a, b, c, d, &bra_buf, &ket_buf, out);
    // Hand the (possibly grown) buffers back for the next call.
    scratch.bra_pairs = bra_buf;
    scratch.ket_pairs = ket_buf;
}
/// Adds the `(ab|cd)` block into `out` from precomputed pair data, using a
/// caller-supplied scratch.
#[allow(clippy::too_many_arguments)]
pub fn coulomb_shell_pairs_into_scratch(
    scratch: &mut EriScratch,
    a: ShellRef<'_>,
    b: ShellRef<'_>,
    c: ShellRef<'_>,
    d: ShellRef<'_>,
    bra_pairs: &ShellPairData,
    ket_pairs: &ShellPairData,
    out: &mut [f64],
) {
    let bra = bra_pairs.pairs.as_slice();
    let ket = ket_pairs.pairs.as_slice();
    coulomb_shell_core(scratch, a, b, c, d, bra, ket, out);
}
/// Shared driver for one contracted shell quartet `(ab|cd)`.
///
/// Steps:
/// 1. Build the contracted table `c_ef[e * n_f + f]`. `e` runs over every bra
///    address of total degree `0..=la+lb` and `f` over every ket address of
///    degree `0..=lc+ld` (addresses as given by `addr`).
/// 2. Move angular momentum onto `b` and `d` with `hrr_and_scatter`, which adds
///    the final block into `out` at `((ia * nb + ib) * nc + ic) * nd + id`.
///
/// Dispatch:
/// - the all-s case (`l_total == 0`) is summed directly into `out[0]`;
/// - classes with `ne, nf <= 3` use the const-generic `contract_class`;
/// - anything larger uses the generic `vrr_primitive`.
///
/// `out` is accumulated into (`+=`), never cleared here.
#[allow(clippy::too_many_arguments)]
fn coulomb_shell_core(
    scratch: &mut EriScratch,
    a: ShellRef<'_>,
    b: ShellRef<'_>,
    c: ShellRef<'_>,
    d: ShellRef<'_>,
    bra_pairs: &[PrimPair],
    ket_pairs: &[PrimPair],
    out: &mut [f64],
) {
    let (la, lb, lc, ld) = (a.l, b.l, c.l, d.l);
    debug_assert!(
        la <= MAX_L && lb <= MAX_L && lc <= MAX_L && ld <= MAX_L,
        "angular momentum exceeds MAX_L"
    );
    let (na, nb, nc, nd) = (n_cart(la), n_cart(lb), n_cart(lc), n_cart(ld));
    debug_assert!(out.len() >= na * nb * nc * nd, "ERI output block too short");
    // Total bra / ket angular momentum the vertical recurrence must reach.
    let ne = la + lb;
    let nf = lc + ld;
    let l_total = ne + nf;
    // Number of bra / ket addresses (all degrees 0..=ne / 0..=nf).
    let n_e = n_addr(ne);
    let n_f = n_addr(nf);
    // Disjoint borrows of the reusable buffers. The scratch pair buffers are
    // not needed because the pairs arrive as slices.
    let EriScratch {
        levels,
        c_ef,
        bra,
        bra_prev,
        bra_cur,
        ket_prev,
        ket_cur,
        tri_all,
        eoff,
        slab,
        ..
    } = scratch;
    if l_total == 0 {
        // (ss|ss): each primitive quartet contributes
        //   c1_ab c2_ab c1_cd c2_cd * 2 pi^{5/2} / (p q sqrt(p + q)) * K_ab K_cd * F_0(T),
        // where T = p q / (p + q) |P - Q|^2.
        // Kets are paired so one `boys_array2` call serves two quartets; an odd
        // leftover ket uses `boys_array`.
        let two_pi_2_5 =
            2.0 * std::f64::consts::PI * std::f64::consts::PI * std::f64::consts::PI.sqrt();
        let mut acc = 0.0;
        let mut f0 = [0.0f64; 1];
        let mut f1 = [0.0f64; 1];
        for bra in bra_pairs.iter() {
            let bc = bra.c1 * bra.c2;
            let p = bra.zeta;
            let mut kets = ket_pairs.chunks_exact(2);
            for pair in kets.by_ref() {
                let (k0, k1) = (&pair[0], &pair[1]);
                let pq0 = p + k0.zeta;
                let pq1 = p + k1.zeta;
                let t0 = (p * k0.zeta / pq0) * dist2(bra.center, k0.center);
                let t1 = (p * k1.zeta / pq1) * dist2(bra.center, k1.center);
                let pref0 = two_pi_2_5 / (p * k0.zeta * pq0.sqrt()) * bra.kappa * k0.kappa;
                let pref1 = two_pi_2_5 / (p * k1.zeta * pq1.sqrt()) * bra.kappa * k1.kappa;
                boys_array2(0, t0, t1, &mut f0, &mut f1);
                acc += (bc * k0.c1) * k0.c2 * (pref0 * f0[0]);
                acc += (bc * k1.c1) * k1.c2 * (pref1 * f1[0]);
            }
            for ket in kets.remainder() {
                let q = ket.zeta;
                let pq = p + q;
                let t = (p * q / pq) * dist2(bra.center, ket.center);
                let pref = two_pi_2_5 / (p * q * pq.sqrt()) * bra.kappa * ket.kappa;
                boys_array(0, t, &mut f0);
                acc += (bc * ket.c1) * ket.c2 * (pref * f0[0]);
            }
        }
        out[0] += acc;
        return;
    }
    ensure_len(c_ef, n_e * n_f);
    c_ef[..n_e * n_f].fill(0.0);
    // Lazily build the per-degree Cartesian component lists for degrees
    // 0..=2*MAX_L; the generic VRR and the HRR both index into them.
    if tri_all.len() <= 2 * MAX_L {
        *tri_all = (0..=2 * MAX_L).map(cart_components).collect();
    }
    // Classes with ne, nf <= 3 go to the const-generic kernel; everything else
    // uses the generic per-primitive recurrence below.
    match (ne, nf) {
        (0, 1) => contract_class::<0, 1>(bra_pairs, ket_pairs, c_ef),
        (1, 0) => contract_class::<1, 0>(bra_pairs, ket_pairs, c_ef),
        (1, 1) => contract_class::<1, 1>(bra_pairs, ket_pairs, c_ef),
        (0, 2) => contract_class::<0, 2>(bra_pairs, ket_pairs, c_ef),
        (2, 0) => contract_class::<2, 0>(bra_pairs, ket_pairs, c_ef),
        (1, 2) => contract_class::<1, 2>(bra_pairs, ket_pairs, c_ef),
        (2, 1) => contract_class::<2, 1>(bra_pairs, ket_pairs, c_ef),
        (2, 2) => contract_class::<2, 2>(bra_pairs, ket_pairs, c_ef),
        (0, 3) => contract_class::<0, 3>(bra_pairs, ket_pairs, c_ef),
        (3, 0) => contract_class::<3, 0>(bra_pairs, ket_pairs, c_ef),
        (1, 3) => contract_class::<1, 3>(bra_pairs, ket_pairs, c_ef),
        (3, 1) => contract_class::<3, 1>(bra_pairs, ket_pairs, c_ef),
        (2, 3) => contract_class::<2, 3>(bra_pairs, ket_pairs, c_ef),
        (3, 2) => contract_class::<3, 2>(bra_pairs, ket_pairs, c_ef),
        (3, 3) => contract_class::<3, 3>(bra_pairs, ket_pairs, c_ef),
        _ => {
            // Layout of ket level k inside `levels`:
            // - each of the n_cart(k) ket components owns a slab of `slab[k]` values;
            // - inside a slab, bra address `ae` starts at `eoff[k * n_e + ae]`;
            // - each address then holds the auxiliary orders m = 0..=l_total - deg(ae) - k.
            ensure_usize(slab, nf + 1);
            ensure_usize(eoff, (nf + 1) * n_e);
            #[allow(clippy::needless_range_loop)]
            for k in 0..=nf {
                let base = k * n_e;
                let mut run = 0usize;
                let mut ae = 0usize;
                for de in 0..=ne {
                    // Orders kept for every bra address of degree `de` at level k.
                    let mlen = l_total - de - k + 1;
                    for _ in 0..n_cart(de) {
                        eoff[base + ae] = run;
                        ae += 1;
                        run += mlen;
                    }
                }
                slab[k] = run;
            }
            // `levels` holds three rotating ket levels, each `maxlevel` long.
            let maxlevel = (0..=nf).map(|k| n_cart(k) * slab[k]).max().unwrap_or(1);
            ensure_len(levels, 3 * maxlevel);
            // Debug builds poison the work area with NaN so reads of unwritten
            // entries are visible in the results.
            #[cfg(debug_assertions)]
            levels[..3 * maxlevel].fill(f64::NAN);
            let (eoff_s, slab_s, tri_s) = (&eoff[..], &slab[..], &tri_all[..]);
            for bra in bra_pairs.iter() {
                let bra_coef = bra.c1 * bra.c2;
                for ket in ket_pairs.iter() {
                    // Contraction weight of this primitive quartet.
                    let scale = (bra_coef * ket.c1) * ket.c2;
                    vrr_primitive(
                        bra.zeta,
                        ket.zeta,
                        bra.center,
                        ket.center,
                        bra.kappa,
                        ket.kappa,
                        bra.r1,
                        ket.r1,
                        bra.inv_2zeta,
                        ket.inv_2zeta,
                        ne,
                        nf,
                        l_total,
                        n_e,
                        n_f,
                        maxlevel,
                        eoff_s,
                        slab_s,
                        tri_s,
                        levels,
                        scale,
                        c_ef,
                    );
                }
            }
        }
    }
    // Centre differences A - B and C - D for the horizontal recurrences.
    let ab = sub(a.center, b.center);
    let cd = sub(c.center, d.center);
    hrr_and_scatter(
        la,
        lb,
        lc,
        ld,
        n_f,
        c_ef,
        ab,
        cd,
        out,
        bra,
        bra_prev,
        bra_cur,
        ket_prev,
        ket_cur,
        &tri_all[..],
    );
}
// Cartesian exponent triples of degree 1, 2 and 3 in the within-shell order used
// by `addr`: x descending, then y descending. Position k in TRIn is the
// within-shell index, so `addr(TRIn[k]) == tri_below(n) + k`.
const TRI1: [[usize; 3]; 3] = [[1, 0, 0], [0, 1, 0], [0, 0, 1]];
const TRI2: [[usize; 3]; 6] = [
    [2, 0, 0],
    [1, 1, 0],
    [1, 0, 1],
    [0, 2, 0],
    [0, 1, 1],
    [0, 0, 2],
];
const TRI3: [[usize; 3]; 10] = [
    [3, 0, 0],
    [2, 1, 0],
    [2, 0, 1],
    [1, 2, 0],
    [1, 1, 1],
    [1, 0, 2],
    [0, 3, 0],
    [0, 2, 1],
    [0, 1, 2],
    [0, 0, 3],
];
// Every exponent triple of degree 0..=3 paired with its total degree, listed in
// address order (`E_COMP[addr(t)].0 == t`). The small VRR kernels use it to map
// a bra address back to its exponents.
const E_COMP: [([usize; 3], usize); 20] = [
    ([0, 0, 0], 0),
    ([1, 0, 0], 1),
    ([0, 1, 0], 1),
    ([0, 0, 1], 1),
    ([2, 0, 0], 2),
    ([1, 1, 0], 2),
    ([1, 0, 1], 2),
    ([0, 2, 0], 2),
    ([0, 1, 1], 2),
    ([0, 0, 2], 2),
    ([3, 0, 0], 3),
    ([2, 1, 0], 3),
    ([2, 0, 1], 3),
    ([1, 2, 0], 3),
    ([1, 1, 1], 3),
    ([1, 0, 2], 3),
    ([0, 3, 0], 3),
    ([0, 2, 1], 3),
    ([0, 1, 2], 3),
    ([0, 0, 3], 3),
];
/// Per-primitive-quartet quantities shared by the scalar small-class VRR.
/// Filled by the `header` closure in `contract_class`.
#[derive(Clone, Copy, Default)]
struct QuartetHeader {
    /// W - P, where W = (p P + q Q) / (p + q).
    wp: [f64; 3],
    /// W - Q.
    wq: [f64; 3],
    /// 2 pi^{5/2} / (p q sqrt(p + q)) * K_ab * K_cd.
    pref: f64,
    /// 1 / (2 (p + q)).
    inv_2pq: f64,
    /// q / (p + q).
    q_over_pq: f64,
    /// p / (p + q).
    p_over_pq: f64,
    /// Contraction weight c1_ab c2_ab c1_cd c2_cd.
    scale: f64,
    /// Boys argument T = p q / (p + q) |P - Q|^2.
    t: f64,
}
/// Number of kets per strip in `contract_class`: headers and Boys values for a
/// whole strip are computed before its per-ket VRR calls.
const STRIP: usize = 4;
/// Contracts every primitive quartet of one `(NE, NF)` class (both `<= 3`) into
/// `c_ef[e * n_addr(NF) + f]`, accumulating with `+=`.
///
/// For each bra pair, the kets are processed in strips of `STRIP`:
/// 1. Fill the per-quartet headers and evaluate the Boys values `F_m(T)` for
///    `m = 0..=lt`, two kets per `boys_array2` call (an odd trailing ket uses
///    `boys_array`).
/// 2. Run `vrr_small` for each ket of the strip.
fn contract_class<const NE: usize, const NF: usize>(
    bra_pairs: &[PrimPair],
    ket_pairs: &[PrimPair],
    c_ef: &mut [f64],
) {
    // Work arrays for `vrr_small`, sized for the largest class NE = NF = 3
    // (lt = 6): 20 bra addresses x {7, 3 x 6, 6 x 5, 10 x 4} entries.
    let mut ea = [0.0f64; 140];
    let mut eb1 = [0.0f64; 360];
    let mut eb2 = [0.0f64; 600];
    let mut eb3 = [0.0f64; 800];
    let lt = NE + NF;
    let mut hdr = [QuartetHeader::default(); STRIP];
    // Boys values per strip slot; 7 = lt + 1 for the largest class.
    let mut fm = [[0.0f64; 7]; STRIP];
    for bra in bra_pairs {
        let bra_coef = bra.c1 * bra.c2;
        let (p, pc) = (bra.zeta, bra.center);
        let mut start = 0;
        while start < ket_pairs.len() {
            // Up to STRIP kets; the final strip may be shorter.
            let strip = &ket_pairs[start..(start + STRIP).min(ket_pairs.len())];
            // Fills the quartet quantities for (bra, ket):
            // - T = rho |P - Q|^2;
            // - W = (p P + q Q) / (p + q), stored as WP = W - P and WQ = W - Q;
            // - pref = 2 pi^{5/2} / (p q sqrt(p + q)) * K_ab * K_cd;
            // - the contraction weight of the quartet.
            let header = |h: &mut QuartetHeader, ket: &PrimPair| {
                let (q, qc) = (ket.zeta, ket.center);
                let pq = p + q;
                let rho = p * q / pq;
                let pq_vec = sub(pc, qc);
                h.t = rho * norm2(pq_vec);
                let w = [
                    (p * pc[0] + q * qc[0]) / pq,
                    (p * pc[1] + q * qc[1]) / pq,
                    (p * pc[2] + q * qc[2]) / pq,
                ];
                h.wp = sub(w, pc);
                h.wq = sub(w, qc);
                use std::f64::consts::PI;
                h.pref = 2.0 * PI * PI * PI.sqrt() / (p * q * pq.sqrt()) * bra.kappa * ket.kappa;
                h.inv_2pq = 0.5 / pq;
                h.q_over_pq = q / pq;
                h.p_over_pq = p / pq;
                h.scale = (bra_coef * ket.c1) * ket.c2;
            };
            let len = strip.len();
            let mut k = 0;
            while k + 2 <= len {
                header(&mut hdr[k], &strip[k]);
                header(&mut hdr[k + 1], &strip[k + 1]);
                // Split `fm` so slots k and k + 1 can be borrowed mutably together.
                let (f0, f1) = fm.split_at_mut(k + 1);
                boys_array2(
                    lt,
                    hdr[k].t,
                    hdr[k + 1].t,
                    &mut f0[k][..=lt],
                    &mut f1[0][..=lt],
                );
                k += 2;
            }
            // Odd leftover ket of the strip.
            if k < len {
                header(&mut hdr[k], &strip[k]);
                boys_array(lt, hdr[k].t, &mut fm[k][..=lt]);
            }
            for (k, ket) in strip.iter().enumerate() {
                vrr_small::<NE, NF>(
                    bra, ket, &hdr[k], &fm[k], &mut ea, &mut eb1, &mut eb2, &mut eb3, c_ef,
                );
            }
            start += STRIP;
        }
    }
}
/// Scalar vertical recurrence for one primitive quartet of the `(NE, NF)` class
/// (`NE, NF <= 3`), adding `hdr.scale * [e0|f0]^(0)` into
/// `c_ef[e * n_addr(NF) + f]`.
///
/// Bra build (`ea`, stride `sa = lt + 1` per bra address):
/// `[e+1_i|00]^m = PA_i [e]^m + WP_i [e]^(m+1) + e_i/(2p) ([e-1_i]^m - q/(p+q) [e-1_i]^(m+1))`.
///
/// Ket build (`eb1` / `eb2` / `eb3` for ket degree 1 / 2 / 3, strides `lt`,
/// `lt - 1`, `lt - 2`):
/// `[e|f+1_j]^m = QC_j [e|f]^m + WQ_j [e|f]^(m+1)`
/// `  + f_j/(2q) ([e|f-1_j]^m - p/(p+q) [e|f-1_j]^(m+1)) + e_j/(2(p+q)) [e-1_j|f]^(m+1)`.
/// At ket degree `k` only the orders `m <= NF - k` still needed are formed.
///
/// The work arrays are owned by the caller (`contract_class`) and reused across
/// quartets. Their fixed sizes cover at most `NE = NF = 3`; this is enforced at
/// compile time.
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn vrr_small<const NE: usize, const NF: usize>(
    bra: &PrimPair,
    ket: &PrimPair,
    hdr: &QuartetHeader,
    fm: &[f64; 7],
    ea: &mut [f64; 140],
    eb1: &mut [f64; 360],
    eb2: &mut [f64; 600],
    eb3: &mut [f64; 800],
    c_ef: &mut [f64],
) {
    // Same guard as `vrr_small4`: the hand-unrolled degree levels and the fixed
    // array sizes only cover NE, NF <= 3.
    const { assert!(NE <= 3 && NF <= 3, "scalar kernel covers NE, NF <= 3") };
    let lt = NE + NF;
    let n_e = n_addr(NE);
    let n_f = n_addr(NF);
    // Per-address strides: bra table `ea`, then the ket degree 1 / 2 / 3 tables.
    let sa = lt + 1;
    let s1 = lt;
    let s2 = lt.saturating_sub(1);
    let s3 = lt.saturating_sub(2);
    // P - A, Q - C, 1/(2p), 1/(2q).
    let pa = bra.r1;
    let qcen = ket.r1;
    let (inv_2p, inv_2q) = (bra.inv_2zeta, ket.inv_2zeta);
    let QuartetHeader {
        wp,
        wq,
        pref,
        inv_2pq,
        q_over_pq,
        p_over_pq,
        scale,
        t: _,
    } = *hdr;
    // [00|00]^(m) = pref * F_m(T).
    for m in 0..=lt {
        ea[m] = pref * fm[m];
    }
    // Bra degree 1: [1_i|00]^m = PA_i [0]^m + WP_i [0]^(m+1).
    if NE >= 1 {
        for (iw, te) in TRI1.iter().enumerate() {
            let i = lower_axis(*te);
            for m in 0..=(lt - 1) {
                ea[(1 + iw) * sa + m] = pa[i] * ea[m] + wp[i] * ea[m + 1];
            }
        }
    }
    // Bra degree 2; the e_i/(2p) term appears when the target has te[i] >= 2.
    if NE >= 2 {
        for (iw, te) in TRI2.iter().enumerate() {
            let i = lower_axis(*te);
            let s1a = addr(dec(*te, i));
            let has2 = te[i] >= 2;
            let coef2 = ((te[i] - 1) as f64) * inv_2p;
            let s2a = if has2 { addr(dec(dec(*te, i), i)) } else { 0 };
            for m in 0..=(lt - 2) {
                let mut v = pa[i] * ea[s1a * sa + m] + wp[i] * ea[s1a * sa + m + 1];
                if has2 {
                    v += coef2 * (ea[s2a * sa + m] - q_over_pq * ea[s2a * sa + m + 1]);
                }
                ea[(4 + iw) * sa + m] = v;
            }
        }
    }
    // Bra degree 3.
    if NE >= 3 {
        for (iw, te) in TRI3.iter().enumerate() {
            let i = lower_axis(*te);
            let s1a = addr(dec(*te, i));
            let has2 = te[i] >= 2;
            let coef2 = ((te[i] - 1) as f64) * inv_2p;
            let s2a = if has2 { addr(dec(dec(*te, i), i)) } else { 0 };
            for m in 0..=(lt - 3) {
                let mut v = pa[i] * ea[s1a * sa + m] + wp[i] * ea[s1a * sa + m + 1];
                if has2 {
                    v += coef2 * (ea[s2a * sa + m] - q_over_pq * ea[s2a * sa + m + 1]);
                }
                ea[(10 + iw) * sa + m] = v;
            }
        }
    }
    // Ket degree 0: accumulate [e0|00]^(0).
    for ae in 0..n_e {
        c_ef[ae * n_f] += scale * ea[ae * sa];
    }
    // Ket degree 1: [e|1_j]^m = QC_j [e]^m + WQ_j [e]^(m+1) + e_j/(2(p+q)) [e-1_j]^(m+1).
    if NF >= 1 {
        for (fw, tf) in TRI1.iter().enumerate() {
            let j = lower_axis(*tf);
            for ae in 0..n_e {
                let (te, _) = E_COMP[ae];
                let has_cross = te[j] >= 1;
                let cross_coef = te[j] as f64 * inv_2pq;
                let cs = if has_cross { addr(dec(te, j)) } else { 0 };
                let mtop = NF - 1;
                for m in 0..=mtop {
                    let mut v = qcen[j] * ea[ae * sa + m] + wq[j] * ea[ae * sa + m + 1];
                    if has_cross {
                        v += cross_coef * ea[cs * sa + m + 1];
                    }
                    eb1[(ae * 3 + fw) * s1 + m] = v;
                }
            }
        }
        for ae in 0..n_e {
            for fw in 0..3 {
                c_ef[ae * n_f + 1 + fw] += scale * eb1[(ae * 3 + fw) * s1];
            }
        }
    }
    // Ket degree 2; lowering twice along j lands on ket degree 0, i.e. `ea`.
    if NF >= 2 {
        for (fw, tf) in TRI2.iter().enumerate() {
            let j = lower_axis(*tf);
            let f1 = dec(*tf, j);
            let lf1 = cart_index(f1);
            let has2 = tf[j] >= 2;
            let coef2 = ((tf[j] - 1) as f64) * inv_2q;
            for ae in 0..n_e {
                let (te, _) = E_COMP[ae];
                let has_cross = te[j] >= 1;
                let cross_coef = te[j] as f64 * inv_2pq;
                let cs = if has_cross { addr(dec(te, j)) } else { 0 };
                let mtop = NF - 2;
                for m in 0..=mtop {
                    let mut v = qcen[j] * eb1[(ae * 3 + lf1) * s1 + m]
                        + wq[j] * eb1[(ae * 3 + lf1) * s1 + m + 1];
                    if has2 {
                        v += coef2 * (ea[ae * sa + m] - p_over_pq * ea[ae * sa + m + 1]);
                    }
                    if has_cross {
                        v += cross_coef * eb1[(cs * 3 + lf1) * s1 + m + 1];
                    }
                    eb2[(ae * 6 + fw) * s2 + m] = v;
                }
            }
        }
        for ae in 0..n_e {
            for fw in 0..6 {
                c_ef[ae * n_f + 4 + fw] += scale * eb2[(ae * 6 + fw) * s2];
            }
        }
    }
    // Ket degree 3; only order m = 0 is needed (NF - 3 = 0).
    if NF >= 3 {
        for (fw, tf) in TRI3.iter().enumerate() {
            let j = lower_axis(*tf);
            let f1 = dec(*tf, j);
            let lf1 = cart_index(f1);
            let has2 = tf[j] >= 2;
            let coef2 = ((tf[j] - 1) as f64) * inv_2q;
            let lf2 = if has2 { cart_index(dec(f1, j)) } else { 0 };
            for ae in 0..n_e {
                let (te, _) = E_COMP[ae];
                let has_cross = te[j] >= 1;
                let cross_coef = te[j] as f64 * inv_2pq;
                let cs = if has_cross { addr(dec(te, j)) } else { 0 };
                let mtop = 0;
                for m in 0..=mtop {
                    let mut v = qcen[j] * eb2[(ae * 6 + lf1) * s2 + m]
                        + wq[j] * eb2[(ae * 6 + lf1) * s2 + m + 1];
                    if has2 {
                        v += coef2
                            * (eb1[(ae * 3 + lf2) * s1 + m]
                                - p_over_pq * eb1[(ae * 3 + lf2) * s1 + m + 1]);
                    }
                    if has_cross {
                        v += cross_coef * eb2[(cs * 6 + lf1) * s2 + m + 1];
                    }
                    eb3[(ae * 10 + fw) * s3 + m] = v;
                }
            }
        }
        for ae in 0..n_e {
            for fw in 0..10 {
                c_ef[ae * n_f + 10 + fw] += scale * eb3[(ae * 10 + fw) * s3];
            }
        }
    }
}
/// Four independent `f64` lanes processed element-wise by the batch kernels.
type V4 = [f64; 4];

/// Lane-wise product `a * b`.
#[inline(always)]
fn v4_mul(a: V4, b: V4) -> V4 {
    let [a0, a1, a2, a3] = a;
    let [b0, b1, b2, b3] = b;
    [a0 * b0, a1 * b1, a2 * b2, a3 * b3]
}

/// Lane-wise sum `a + b`.
#[inline(always)]
fn v4_add(a: V4, b: V4) -> V4 {
    let [a0, a1, a2, a3] = a;
    let [b0, b1, b2, b3] = b;
    [a0 + b0, a1 + b1, a2 + b2, a3 + b3]
}

/// Lane-wise difference `a - b`.
#[inline(always)]
fn v4_sub(a: V4, b: V4) -> V4 {
    let [a0, a1, a2, a3] = a;
    let [b0, b1, b2, b3] = b;
    [a0 - b0, a1 - b1, a2 - b2, a3 - b3]
}

/// Multiplies every lane of `a` by the scalar `s`.
#[inline(always)]
fn v4_scale(s: f64, a: V4) -> V4 {
    let [a0, a1, a2, a3] = a;
    [s * a0, s * a1, s * a2, s * a3]
}
/// Lane-wise (`V4`) form of `QuartetHeader`, filled by `contract_class4` and
/// read by `vrr_small4`. Each field holds the scalar quantity for four quartets.
#[derive(Clone, Copy, Default)]
struct QuartetHeader4 {
    /// W - P per axis.
    wp: [V4; 3],
    /// W - Q per axis.
    wq: [V4; 3],
    /// 2 pi^{5/2} / (p q sqrt(p + q)) * K_ab * K_cd.
    pref: V4,
    /// 1 / (2 (p + q)).
    inv_2pq: V4,
    /// q / (p + q).
    q_over_pq: V4,
    /// p / (p + q).
    p_over_pq: V4,
    /// Contraction weight per lane.
    scale: V4,
    /// Boys argument T per lane.
    t: V4,
}
/// Four-lane version of `contract_class`: contracts one `(NE, NF)` class for
/// four quartets in lockstep, adding lane `l`'s contribution into `c_ef[l]`.
///
/// All lanes must have the same number of bra pairs and the same number of ket
/// pairs (debug-asserted). Lane `l` combines `bra_pairs[l][ib]` with
/// `ket_pairs[l][ik]`. The header is filled lane by lane, then one
/// `boys_array4` call and one `vrr_small4` call handle all four lanes.
#[allow(clippy::needless_range_loop)]
fn contract_class4<const NE: usize, const NF: usize>(
    bra_pairs: [&[PrimPair]; 4],
    ket_pairs: [&[PrimPair]; 4],
    c_ef: &mut [&mut [f64]; 4],
) {
    let lt = NE + NF;
    let n_bra = bra_pairs[0].len();
    let n_ket = ket_pairs[0].len();
    debug_assert!(bra_pairs.iter().all(|b| b.len() == n_bra));
    debug_assert!(ket_pairs.iter().all(|k| k.len() == n_ket));
    // Lane-wise work arrays, sized like the scalar ones in `contract_class`.
    let mut ea = [[0.0f64; 4]; 140];
    let mut eb1 = [[0.0f64; 4]; 360];
    let mut eb2 = [[0.0f64; 4]; 600];
    let mut eb3 = [[0.0f64; 4]; 800];
    let mut hdr = QuartetHeader4::default();
    let mut fm = [[0.0f64; 4]; 7];
    for ib in 0..n_bra {
        let bras = [
            &bra_pairs[0][ib],
            &bra_pairs[1][ib],
            &bra_pairs[2][ib],
            &bra_pairs[3][ib],
        ];
        // Bra-only quantities, transposed to lanes: P - A, 1/(2p), c1 c2.
        let mut pa = [[0.0f64; 4]; 3];
        let mut inv_2p = [0.0f64; 4];
        let mut bra_coef = [0.0f64; 4];
        for (lane, bra) in bras.iter().enumerate() {
            for (dst, &src) in pa.iter_mut().zip(bra.r1.iter()) {
                dst[lane] = src;
            }
            inv_2p[lane] = bra.inv_2zeta;
            bra_coef[lane] = bra.c1 * bra.c2;
        }
        for ik in 0..n_ket {
            let kets = [
                &ket_pairs[0][ik],
                &ket_pairs[1][ik],
                &ket_pairs[2][ik],
                &ket_pairs[3][ik],
            ];
            let mut qcen = [[0.0f64; 4]; 3];
            let mut inv_2q = [0.0f64; 4];
            // Same per-quartet quantities as the scalar `header` closure, one lane at a time.
            for (lane, (bra, ket)) in bras.iter().zip(kets.iter()).enumerate() {
                let (p, pc) = (bra.zeta, bra.center);
                let (q, qc) = (ket.zeta, ket.center);
                let pq = p + q;
                let rho = p * q / pq;
                let pq_vec = sub(pc, qc);
                hdr.t[lane] = rho * norm2(pq_vec);
                let w = [
                    (p * pc[0] + q * qc[0]) / pq,
                    (p * pc[1] + q * qc[1]) / pq,
                    (p * pc[2] + q * qc[2]) / pq,
                ];
                let wp = sub(w, pc);
                let wq = sub(w, qc);
                for ax in 0..3 {
                    hdr.wp[ax][lane] = wp[ax];
                    hdr.wq[ax][lane] = wq[ax];
                    qcen[ax][lane] = ket.r1[ax];
                }
                use std::f64::consts::PI;
                hdr.pref[lane] =
                    2.0 * PI * PI * PI.sqrt() / (p * q * pq.sqrt()) * bra.kappa * ket.kappa;
                hdr.inv_2pq[lane] = 0.5 / pq;
                hdr.q_over_pq[lane] = q / pq;
                hdr.p_over_pq[lane] = p / pq;
                hdr.scale[lane] = (bra_coef[lane] * ket.c1) * ket.c2;
                inv_2q[lane] = ket.inv_2zeta;
            }
            // Boys values F_m(T), m = 0..=lt, for all four lanes at once.
            boys_array4(lt, hdr.t, &mut fm[..=lt]);
            vrr_small4::<NE, NF>(
                &pa, &qcen, inv_2p, inv_2q, &hdr, &fm, &mut ea, &mut eb1, &mut eb2, &mut eb3, c_ef,
            );
        }
    }
}
/// Four-lane version of `vrr_small`: the same bra and ket vertical recurrences,
/// evaluated element-wise on `V4` lanes. Lane `l` then adds
/// `scale[l] * [e0|f0]^(0)` into `c_ef[l][e * n_addr(NF) + f]`.
///
/// Inputs, per axis and lane:
/// - `pa` / `qcen`: P - A / Q - C;
/// - `inv_2p` / `inv_2q`: 1/(2p) / 1/(2q);
/// - `fm[m]`: the Boys values.
///
/// The array sizes cover at most `NE = NF = 3`, enforced at compile time.
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn vrr_small4<const NE: usize, const NF: usize>(
    pa: &[V4; 3],
    qcen: &[V4; 3],
    inv_2p: V4,
    inv_2q: V4,
    hdr: &QuartetHeader4,
    fm: &[V4; 7],
    ea: &mut [V4; 140],
    eb1: &mut [V4; 360],
    eb2: &mut [V4; 600],
    eb3: &mut [V4; 800],
    c_ef: &mut [&mut [f64]; 4],
) {
    const { assert!(NE <= 3 && NF <= 3, "batch kernel covers NE, NF <= 3") };
    let lt = NE + NF;
    let n_e = n_addr(NE);
    let n_f = n_addr(NF);
    // Per-address strides: bra table `ea`, then the ket degree 1 / 2 / 3 tables.
    let sa = lt + 1;
    let s1 = lt;
    let s2 = lt.saturating_sub(1);
    let s3 = lt.saturating_sub(2);
    let QuartetHeader4 {
        wp,
        wq,
        pref,
        inv_2pq,
        q_over_pq,
        p_over_pq,
        scale,
        t: _,
    } = *hdr;
    // [00|00]^(m) = pref * F_m(T).
    for m in 0..=lt {
        ea[m] = v4_mul(pref, fm[m]);
    }
    // Bra degree 1: [1_i|00]^m = PA_i [0]^m + WP_i [0]^(m+1).
    if NE >= 1 {
        for (iw, te) in TRI1.iter().enumerate() {
            let i = lower_axis(*te);
            for m in 0..=(lt - 1) {
                ea[(1 + iw) * sa + m] = v4_add(v4_mul(pa[i], ea[m]), v4_mul(wp[i], ea[m + 1]));
            }
        }
    }
    // Bra degree 2; the e_i/(2p) term appears when the target has te[i] >= 2.
    if NE >= 2 {
        for (iw, te) in TRI2.iter().enumerate() {
            let i = lower_axis(*te);
            let s1a = addr(dec(*te, i));
            let has2 = te[i] >= 2;
            let coef2 = v4_scale((te[i] - 1) as f64, inv_2p);
            let s2a = if has2 { addr(dec(dec(*te, i), i)) } else { 0 };
            for m in 0..=(lt - 2) {
                let mut v = v4_add(
                    v4_mul(pa[i], ea[s1a * sa + m]),
                    v4_mul(wp[i], ea[s1a * sa + m + 1]),
                );
                if has2 {
                    v = v4_add(
                        v,
                        v4_mul(
                            coef2,
                            v4_sub(ea[s2a * sa + m], v4_mul(q_over_pq, ea[s2a * sa + m + 1])),
                        ),
                    );
                }
                ea[(4 + iw) * sa + m] = v;
            }
        }
    }
    // Bra degree 3.
    if NE >= 3 {
        for (iw, te) in TRI3.iter().enumerate() {
            let i = lower_axis(*te);
            let s1a = addr(dec(*te, i));
            let has2 = te[i] >= 2;
            let coef2 = v4_scale((te[i] - 1) as f64, inv_2p);
            let s2a = if has2 { addr(dec(dec(*te, i), i)) } else { 0 };
            for m in 0..=(lt - 3) {
                let mut v = v4_add(
                    v4_mul(pa[i], ea[s1a * sa + m]),
                    v4_mul(wp[i], ea[s1a * sa + m + 1]),
                );
                if has2 {
                    v = v4_add(
                        v,
                        v4_mul(
                            coef2,
                            v4_sub(ea[s2a * sa + m], v4_mul(q_over_pq, ea[s2a * sa + m + 1])),
                        ),
                    );
                }
                ea[(10 + iw) * sa + m] = v;
            }
        }
    }
    // Ket degree 0: accumulate [e0|00]^(0) per lane.
    for ae in 0..n_e {
        let v = ea[ae * sa];
        for lane in 0..4 {
            c_ef[lane][ae * n_f] += scale[lane] * v[lane];
        }
    }
    // Ket degree 1: [e|1_j]^m = QC_j [e]^m + WQ_j [e]^(m+1) + e_j/(2(p+q)) [e-1_j]^(m+1).
    if NF >= 1 {
        for (fw, tf) in TRI1.iter().enumerate() {
            let j = lower_axis(*tf);
            for ae in 0..n_e {
                let (te, _) = E_COMP[ae];
                let has_cross = te[j] >= 1;
                let cross_coef = v4_scale(te[j] as f64, inv_2pq);
                let cs = if has_cross { addr(dec(te, j)) } else { 0 };
                let mtop = NF - 1;
                for m in 0..=mtop {
                    let mut v = v4_add(
                        v4_mul(qcen[j], ea[ae * sa + m]),
                        v4_mul(wq[j], ea[ae * sa + m + 1]),
                    );
                    if has_cross {
                        v = v4_add(v, v4_mul(cross_coef, ea[cs * sa + m + 1]));
                    }
                    eb1[(ae * 3 + fw) * s1 + m] = v;
                }
            }
        }
        for ae in 0..n_e {
            for fw in 0..3 {
                let v = eb1[(ae * 3 + fw) * s1];
                for lane in 0..4 {
                    c_ef[lane][ae * n_f + 1 + fw] += scale[lane] * v[lane];
                }
            }
        }
    }
    // Ket degree 2; lowering twice along j lands on ket degree 0, i.e. `ea`.
    if NF >= 2 {
        for (fw, tf) in TRI2.iter().enumerate() {
            let j = lower_axis(*tf);
            let f1 = dec(*tf, j);
            let lf1 = cart_index(f1);
            let has2 = tf[j] >= 2;
            let coef2 = v4_scale((tf[j] - 1) as f64, inv_2q);
            for ae in 0..n_e {
                let (te, _) = E_COMP[ae];
                let has_cross = te[j] >= 1;
                let cross_coef = v4_scale(te[j] as f64, inv_2pq);
                let cs = if has_cross { addr(dec(te, j)) } else { 0 };
                let mtop = NF - 2;
                for m in 0..=mtop {
                    let mut v = v4_add(
                        v4_mul(qcen[j], eb1[(ae * 3 + lf1) * s1 + m]),
                        v4_mul(wq[j], eb1[(ae * 3 + lf1) * s1 + m + 1]),
                    );
                    if has2 {
                        v = v4_add(
                            v,
                            v4_mul(
                                coef2,
                                v4_sub(ea[ae * sa + m], v4_mul(p_over_pq, ea[ae * sa + m + 1])),
                            ),
                        );
                    }
                    if has_cross {
                        v = v4_add(v, v4_mul(cross_coef, eb1[(cs * 3 + lf1) * s1 + m + 1]));
                    }
                    eb2[(ae * 6 + fw) * s2 + m] = v;
                }
            }
        }
        for ae in 0..n_e {
            for fw in 0..6 {
                let v = eb2[(ae * 6 + fw) * s2];
                for lane in 0..4 {
                    c_ef[lane][ae * n_f + 4 + fw] += scale[lane] * v[lane];
                }
            }
        }
    }
    // Ket degree 3; only order m = 0 is needed (NF - 3 = 0).
    if NF >= 3 {
        for (fw, tf) in TRI3.iter().enumerate() {
            let j = lower_axis(*tf);
            let f1 = dec(*tf, j);
            let lf1 = cart_index(f1);
            let has2 = tf[j] >= 2;
            let coef2 = v4_scale((tf[j] - 1) as f64, inv_2q);
            let lf2 = if has2 { cart_index(dec(f1, j)) } else { 0 };
            for ae in 0..n_e {
                let (te, _) = E_COMP[ae];
                let has_cross = te[j] >= 1;
                let cross_coef = v4_scale(te[j] as f64, inv_2pq);
                let cs = if has_cross { addr(dec(te, j)) } else { 0 };
                let mtop = 0;
                for m in 0..=mtop {
                    let mut v = v4_add(
                        v4_mul(qcen[j], eb2[(ae * 6 + lf1) * s2 + m]),
                        v4_mul(wq[j], eb2[(ae * 6 + lf1) * s2 + m + 1]),
                    );
                    if has2 {
                        v = v4_add(
                            v,
                            v4_mul(
                                coef2,
                                v4_sub(
                                    eb1[(ae * 3 + lf2) * s1 + m],
                                    v4_mul(p_over_pq, eb1[(ae * 3 + lf2) * s1 + m + 1]),
                                ),
                            ),
                        );
                    }
                    if has_cross {
                        v = v4_add(v, v4_mul(cross_coef, eb2[(cs * 6 + lf1) * s2 + m + 1]));
                    }
                    eb3[(ae * 10 + fw) * s3 + m] = v;
                }
            }
        }
        for ae in 0..n_e {
            for fw in 0..10 {
                let v = eb3[(ae * 10 + fw) * s3];
                for lane in 0..4 {
                    c_ef[lane][ae * n_f + 10 + fw] += scale[lane] * v[lane];
                }
            }
        }
    }
}
/// Four-lane `(ss|ss)` kernel: for each lane, adds the contracted integral into
/// `outs[lane][0]`.
///
/// Each primitive quartet contributes
/// `c1_ab c2_ab c1_cd c2_cd * 2 pi^{5/2} / (p q sqrt(p + q)) * K_ab K_cd * F_0(T)`,
/// with `T = p q / (p + q) |P - Q|^2`.
///
/// All lanes must have equal bra and equal ket pair counts (debug-asserted).
/// One `boys_array4` call evaluates `F_0` for the four lanes at once.
#[allow(clippy::needless_range_loop)]
fn contract_ssss4(
    bra_pairs: [&[PrimPair]; 4],
    ket_pairs: [&[PrimPair]; 4],
    outs: &mut [&mut [f64]; 4],
) {
    let bra_count = bra_pairs[0].len();
    let ket_count = ket_pairs[0].len();
    debug_assert!(bra_pairs.iter().all(|b| b.len() == bra_count));
    debug_assert!(ket_pairs.iter().all(|k| k.len() == ket_count));
    use std::f64::consts::PI;
    let two_pi_2_5 = 2.0 * PI * PI * PI.sqrt();
    let mut sums = [0.0f64; 4];
    let mut boys = [[0.0f64; 4]; 1];
    for ib in 0..bra_count {
        let bras: [&PrimPair; 4] = std::array::from_fn(|lane| &bra_pairs[lane][ib]);
        let coef_ab: [f64; 4] = std::array::from_fn(|lane| bras[lane].c1 * bras[lane].c2);
        for ik in 0..ket_count {
            let kets: [&PrimPair; 4] = std::array::from_fn(|lane| &ket_pairs[lane][ik]);
            let mut targ = [0.0f64; 4];
            let mut prefac = [0.0f64; 4];
            for lane in 0..4 {
                let (p, q) = (bras[lane].zeta, kets[lane].zeta);
                let pq = p + q;
                targ[lane] = (p * q / pq) * dist2(bras[lane].center, kets[lane].center);
                prefac[lane] =
                    two_pi_2_5 / (p * q * pq.sqrt()) * bras[lane].kappa * kets[lane].kappa;
            }
            boys_array4(0, targ, &mut boys);
            for lane in 0..4 {
                let ket = kets[lane];
                sums[lane] += (coef_ab[lane] * ket.c1) * ket.c2 * (prefac[lane] * boys[0][lane]);
            }
        }
    }
    for (out, total) in outs.iter_mut().zip(sums) {
        out[0] += total;
    }
}
/// Counts the primitive pairs of `(s1, s2)` that `build_pairs` would keep.
///
/// It applies exactly the screening predicate of `build_pairs`: a pair is
/// dropped only when `kappa * |c1 c2| < PAIR_NEGLIGIBLE`. The test is the
/// negation of that comparison rather than `>=`. The two differ when the
/// magnitude is NaN (for example zero exponents give `0 / 0`). `build_pairs`
/// keeps such a pair, so it must be counted here too.
#[must_use]
pub fn surviving_pair_count(s1: ShellRef<'_>, s2: ShellRef<'_>) -> usize {
    debug_assert_eq!(
        s1.exps.len(),
        s1.coeffs.len(),
        "shell exponent/coefficient count mismatch"
    );
    debug_assert_eq!(
        s2.exps.len(),
        s2.coeffs.len(),
        "shell exponent/coefficient count mismatch"
    );
    let d2 = dist2(s1.center, s2.center);
    let mut n = 0;
    for (&e1, &c1) in s1.exps.iter().zip(s1.coeffs.iter()) {
        for (&e2, &c2) in s2.exps.iter().zip(s2.coeffs.iter()) {
            let zeta = e1 + e2;
            let kappa = (-(e1 * e2 / zeta) * d2).exp();
            let negligible = kappa * (c1 * c2).abs() < PAIR_NEGLIGIBLE;
            if !negligible {
                n += 1;
            }
        }
    }
    n
}
/// Reusable buffers for the four-lane entry points
/// (`coulomb_shell_batch4_into_scratch` and friends).
#[derive(Debug, Default)]
pub struct EriBatch4Scratch {
    /// Per-lane primitive-pair buffers rebuilt by `coulomb_shell_batch4_into_scratch`.
    bra_pairs: [Vec<PrimPair>; 4],
    ket_pairs: [Vec<PrimPair>; 4],
    /// Per-lane contracted `[e0|f0]` tables for the lockstep path.
    c_ef: [Vec<f64>; 4],
    /// Scalar scratch for the per-lane fallback and the HRR buffers.
    scalar: EriScratch,
}
/// Adds the `(ab|cd)` block of each of four quartets into the matching `outs`
/// slice. Primitive pairs are rebuilt into the scratch's per-lane buffers, which
/// are handed back to the scratch afterwards for reuse.
pub fn coulomb_shell_batch4_into_scratch(
    scratch: &mut EriBatch4Scratch,
    quartets: &[[ShellRef<'_>; 4]; 4],
    outs: &mut [&mut [f64]; 4],
) {
    // Detach the pair buffers so `scratch` can be lent mutably to the core.
    let mut bra_bufs = std::mem::take(&mut scratch.bra_pairs);
    let mut ket_bufs = std::mem::take(&mut scratch.ket_pairs);
    for ((quartet, bra_buf), ket_buf) in quartets
        .iter()
        .zip(bra_bufs.iter_mut())
        .zip(ket_bufs.iter_mut())
    {
        let [a, b, c, d] = *quartet;
        build_pairs(bra_buf, a, b);
        build_pairs(ket_buf, c, d);
    }
    let bra = bra_bufs.each_ref().map(|v| &v[..]);
    let ket = ket_bufs.each_ref().map(|v| &v[..]);
    coulomb_shell_batch4_core(scratch, quartets, bra, ket, outs);
    scratch.bra_pairs = bra_bufs;
    scratch.ket_pairs = ket_bufs;
}
/// Four-quartet variant that reuses precomputed pair data for every lane's bra
/// and ket; the pair data must match the corresponding shells (not checked).
pub fn coulomb_shell_batch4_pairs_into_scratch(
    scratch: &mut EriBatch4Scratch,
    quartets: &[[ShellRef<'_>; 4]; 4],
    bra_pairs: [&ShellPairData; 4],
    ket_pairs: [&ShellPairData; 4],
    outs: &mut [&mut [f64]; 4],
) {
    let bra = bra_pairs.map(|data| data.pairs.as_slice());
    let ket = ket_pairs.map(|data| data.pairs.as_slice());
    coulomb_shell_batch4_core(scratch, quartets, bra, ket, outs);
}
fn coulomb_shell_batch4_core(
scratch: &mut EriBatch4Scratch,
quartets: &[[ShellRef<'_>; 4]; 4],
bra: [&[PrimPair]; 4],
ket: [&[PrimPair]; 4],
outs: &mut [&mut [f64]; 4],
) {
let [a0, b0, c0, d0] = quartets[0];
let ne = a0.l + b0.l;
let nf = c0.l + d0.l;
let lanes_match = quartets
.iter()
.all(|[a, b, c, d]| a.l + b.l == ne && c.l + d.l == nf);
let n_bra = bra[0].len();
let n_ket = ket[0].len();
let lockstep = lanes_match
&& ne <= 3
&& nf <= 3
&& bra.iter().all(|p| p.len() == n_bra)
&& ket.iter().all(|p| p.len() == n_ket);
if !lockstep {
for (lane, &[a, b, c, d]) in quartets.iter().enumerate() {
coulomb_shell_core(
&mut scratch.scalar,
a,
b,
c,
d,
bra[lane],
ket[lane],
outs[lane],
);
}
return;
}
if ne + nf == 0 {
contract_ssss4(bra, ket, outs);
return;
}
let n_e = n_addr(ne);
let n_f = n_addr(nf);
for lane in 0..4 {
ensure_len(&mut scratch.c_ef[lane], n_e * n_f);
scratch.c_ef[lane][..n_e * n_f].fill(0.0);
}
{
let [c0, c1, c2, c3] = &mut scratch.c_ef;
let mut c_ef: [&mut [f64]; 4] = [
&mut c0[..n_e * n_f],
&mut c1[..n_e * n_f],
&mut c2[..n_e * n_f],
&mut c3[..n_e * n_f],
];
match (ne, nf) {
(0, 1) => contract_class4::<0, 1>(bra, ket, &mut c_ef),
(1, 0) => contract_class4::<1, 0>(bra, ket, &mut c_ef),
(1, 1) => contract_class4::<1, 1>(bra, ket, &mut c_ef),
(0, 2) => contract_class4::<0, 2>(bra, ket, &mut c_ef),
(2, 0) => contract_class4::<2, 0>(bra, ket, &mut c_ef),
(1, 2) => contract_class4::<1, 2>(bra, ket, &mut c_ef),
(2, 1) => contract_class4::<2, 1>(bra, ket, &mut c_ef),
(2, 2) => contract_class4::<2, 2>(bra, ket, &mut c_ef),
(0, 3) => contract_class4::<0, 3>(bra, ket, &mut c_ef),
(3, 0) => contract_class4::<3, 0>(bra, ket, &mut c_ef),
(1, 3) => contract_class4::<1, 3>(bra, ket, &mut c_ef),
(3, 1) => contract_class4::<3, 1>(bra, ket, &mut c_ef),
(2, 3) => contract_class4::<2, 3>(bra, ket, &mut c_ef),
(3, 2) => contract_class4::<3, 2>(bra, ket, &mut c_ef),
(3, 3) => contract_class4::<3, 3>(bra, ket, &mut c_ef),
_ => unreachable!("guarded above"),
}
}
let EriScratch {
bra,
bra_prev,
bra_cur,
ket_prev,
ket_cur,
tri_all,
..
} = &mut scratch.scalar;
if tri_all.len() <= 2 * MAX_L {
*tri_all = (0..=2 * MAX_L).map(cart_components).collect();
}
for (lane, &[a, b, c, d]) in quartets.iter().enumerate() {
let ab = sub(a.center, b.center);
let cd = sub(c.center, d.center);
hrr_and_scatter(
a.l,
b.l,
c.l,
d.l,
n_f,
&scratch.c_ef[lane],
ab,
cd,
outs[lane],
bra,
bra_prev,
bra_cur,
ket_prev,
ket_cur,
&tri_all[..],
);
}
}
/// Generic-order vertical recurrence for one primitive quartet. It adds
/// `scale * [e0|f0]^(0)` into `c_ef[e * n_f + f]` for every bra address
/// `e < n_e` and ket address `f < n_f`.
///
/// `levels` holds three rotating ket levels: slot `k % 3`, each `maxlevel` long.
/// Within level `k`:
/// - ket component `lf` of `tri[k]` starts at `lf * slab[k]`;
/// - bra address `ae` starts at `eoff[k * n_e + ae]`;
/// - the auxiliary orders `m` follow contiguously.
///
/// Level 0 is seeded with `pref * F_m(T)` and raised in the bra index. Each ket
/// level `k >= 1` is built from levels `k - 1` and `k - 2`.
///
/// Other inputs:
/// - `pa` = P - A, `qcen` = Q - C;
/// - `inv_2p` = 1/(2p), `inv_2q` = 1/(2q);
/// - `kab`, `kcd` are the pair `kappa` factors.
#[allow(clippy::too_many_arguments)]
fn vrr_primitive(
    p: f64,
    q: f64,
    pc: Vec3,
    qc: Vec3,
    kab: f64,
    kcd: f64,
    pa: Vec3,
    qcen: Vec3,
    inv_2p: f64,
    inv_2q: f64,
    ne: usize,
    nf: usize,
    l_total: usize,
    n_e: usize,
    n_f: usize,
    maxlevel: usize,
    eoff: &[usize],
    slab: &[usize],
    tri: &[Vec<[usize; 3]>],
    levels: &mut [f64],
    scale: f64,
    c_ef: &mut [f64],
) {
    let pq = p + q;
    let rho = p * q / pq;
    let pq_vec = sub(pc, qc);
    let t = rho * norm2(pq_vec);
    // W = (p P + q Q) / (p + q).
    let w = [
        (p * pc[0] + q * qc[0]) / pq,
        (p * pc[1] + q * qc[1]) / pq,
        (p * pc[2] + q * qc[2]) / pq,
    ];
    let wp = sub(w, pc);
    let wq = sub(w, qc);
    // Flat index of (ket level k, ket component lf, bra address ae, order m).
    let off = |k: usize, lf: usize, ae: usize, m: usize| -> usize {
        (k % 3) * maxlevel + lf * slab[k] + eoff[k * n_e + ae] + m
    };
    use std::f64::consts::PI;
    let pref = 2.0 * PI * PI * PI.sqrt() / (p * q * pq.sqrt()) * kab * kcd;
    let mut fm = [0.0f64; 4 * MAX_L + 1];
    boys_array(l_total, t, &mut fm[..=l_total]);
    // [00|00]^(m) = pref * F_m(T).
    for m in 0..=l_total {
        levels[off(0, 0, 0, m)] = pref * fm[m];
    }
    let inv_2pq = 0.5 / pq;
    let q_over_pq = q / pq;
    let p_over_pq = p / pq;
    // Level 0 lives in slot 0 with ket component 0, so off(0, 0, ae, m) == eoff0[ae] + m.
    let eoff0 = &eoff[..n_e];
    // Bra build, degree by degree:
    // [e+1_i|00]^m = PA_i [e]^m + WP_i [e]^(m+1) + e_i/(2p) ([e-1_i]^m - q/(p+q) [e-1_i]^(m+1)).
    for (na, te_list) in tri.iter().enumerate().take(ne + 1).skip(1) {
        for &te in te_list {
            let i = lower_axis(te);
            let s1a = addr(dec(te, i));
            let mmax = l_total - na;
            let coef2 = ((te[i] - 1) as f64) * inv_2p;
            let has2 = te[i] >= 2;
            let s1_0 = eoff0[s1a];
            let s2_0 = if has2 {
                eoff0[addr(dec(dec(te, i), i))]
            } else {
                0
            };
            let dst_0 = eoff0[addr(te)];
            for m in 0..=mmax {
                let mut v = pa[i] * levels[s1_0 + m] + wp[i] * levels[s1_0 + m + 1];
                if has2 {
                    v += coef2 * (levels[s2_0 + m] - q_over_pq * levels[s2_0 + m + 1]);
                }
                levels[dst_0 + m] = v;
            }
        }
    }
    // Ket degree 0: accumulate [e0|00]^(0).
    for ae in 0..n_e {
        c_ef[ae * n_f] += scale * levels[off(0, 0, ae, 0)];
    }
    // Ket build for levels k = 1..=nf, each written to slot k % 3:
    // [e|f+1_j]^m = QC_j [e|f]^m + WQ_j [e|f]^(m+1)
    //   + f_j/(2q) ([e|f-1_j]^m - p/(p+q) [e|f-1_j]^(m+1))
    //   + e_j/(2(p+q)) [e-1_j|f]^(m+1).
    for k in 1..=nf {
        let slot_k = (k % 3) * maxlevel;
        let slot_k1 = ((k - 1) % 3) * maxlevel;
        let (slab_k, slab_k1) = (slab[k], slab[k - 1]);
        let eoff_k = &eoff[k * n_e..];
        let eoff_k1 = &eoff[(k - 1) * n_e..];
        // Level k - 2 exists only from k = 2; the placeholders are never read otherwise.
        let (slot_k2, slab_k2, eoff_k2): (usize, usize, &[usize]) = if k >= 2 {
            (
                ((k - 2) % 3) * maxlevel,
                slab[k - 2],
                &eoff[(k - 2) * n_e..],
            )
        } else {
            (0, 0, &[])
        };
        for &tf in &tri[k] {
            let j = lower_axis(tf);
            let f1 = dec(tf, j);
            let lf1 = cart_index(f1);
            let coef2 = ((tf[j] - 1) as f64) * inv_2q;
            let has2 = tf[j] >= 2;
            let lf2 = if has2 { cart_index(dec(f1, j)) } else { 0 };
            let lf = cart_index(tf);
            // Slab starts of the target component and its two lowered sources.
            let sk = slot_k + lf * slab_k;
            let sk1 = slot_k1 + lf1 * slab_k1;
            let sk2 = slot_k2 + lf2 * slab_k2;
            for (nadeg, te_list) in tri.iter().enumerate().take(ne + 1) {
                for &te in te_list {
                    let ea = addr(te);
                    let has_cross = te[j] >= 1;
                    let (cross_coef, cross_0) = if has_cross {
                        (te[j] as f64 * inv_2pq, sk1 + eoff_k1[addr(dec(te, j))])
                    } else {
                        (0.0, 0)
                    };
                    // Orders still needed by later levels for a bra address of this degree.
                    let mmax = l_total - nadeg - k;
                    let dst_0 = sk + eoff_k[ea];
                    let src1_0 = sk1 + eoff_k1[ea];
                    let src2_0 = if has2 { sk2 + eoff_k2[ea] } else { 0 };
                    for m in 0..=mmax {
                        let mut v = qcen[j] * levels[src1_0 + m] + wq[j] * levels[src1_0 + m + 1];
                        if has2 {
                            v += coef2 * (levels[src2_0 + m] - p_over_pq * levels[src2_0 + m + 1]);
                        }
                        if has_cross {
                            v += cross_coef * levels[cross_0 + m + 1];
                        }
                        levels[dst_0 + m] = v;
                    }
                }
            }
        }
        // Accumulate order 0 of this ket level into the contracted table.
        for &tf in &tri[k] {
            let lf = cart_index(tf);
            let af = addr(tf);
            for ae in 0..n_e {
                c_ef[ae * n_f + af] += scale * levels[off(k, lf, ae, 0)];
            }
        }
    }
}
/// Horizontal recurrences. Turns the contracted `[e0|f0]` table `c_ef` into the
/// `(ab|cd)` block and adds it into `out` at
/// `((ia * nb + ib) * nc + ic) * nd + id`.
///
/// `c_ef` has row stride `n_f`; `e` spans degrees `0..=la+lb` and `f` spans
/// degrees `0..=lc+ld`.
///
/// Bra step, for each ket address `f` of degree `lc..=lc+ld`:
/// - `(a|b + 1_i) = (a + 1_i|b) + ab_i (a|b)`, with `ab = A - B`;
/// - uses the ping-pong layers `prev` / `cur`, indexed
///   `b_component * n_e + a_address`;
/// - stores results in `bra[(ia * nb + ib) * nf_range + (f - f_base)]`.
///
/// Ket step, for each `(ia, ib)`:
/// - the same recurrence with `cd = C - D`;
/// - uses the layers `kprev` / `kcur`, indexed `d_component * n_f + c_address`.
///
/// `tri[l]` must list the Cartesian components of degree `l` up to
/// `max(la + lb, lc + ld)`.
#[allow(clippy::too_many_arguments)]
fn hrr_and_scatter<'a>(
    la: usize,
    lb: usize,
    lc: usize,
    ld: usize,
    n_f: usize,
    c_ef: &[f64],
    ab: Vec3,
    cd: Vec3,
    out: &mut [f64],
    bra: &mut Vec<f64>,
    mut prev: &'a mut Vec<f64>,
    mut cur: &'a mut Vec<f64>,
    mut kprev: &'a mut Vec<f64>,
    mut kcur: &'a mut Vec<f64>,
    tri: &[Vec<[usize; 3]>],
) {
    let (na, nb, nc, nd) = (n_cart(la), n_cart(lb), n_cart(lc), n_cart(ld));
    let ne = la + lb;
    let nf = lc + ld;
    let n_e = n_addr(ne);
    // Only ket addresses of degree lc..=nf feed the ket step; they form the
    // contiguous address range f_base..n_f.
    let f_base = tri_below(lc);
    let nf_range = n_f - f_base;
    let bra_len = na * nb * nf_range;
    ensure_len(bra, bra_len);
    #[cfg(debug_assertions)]
    bra[..bra_len].fill(f64::NAN);
    // One bra layer: n_cart(lb) b-components x n_e bra addresses.
    let layer_len = n_cart(lb) * n_e;
    ensure_len(prev, layer_len);
    ensure_len(cur, layer_len);
    for f_global in f_base..n_f {
        let jf = f_global - f_base;
        // Seed layer b = 0 with [a0|f] for every a of degree la..=ne.
        for &a in tri[la..=ne].iter().flatten() {
            let ae = addr(a);
            prev[ae] = c_ef[ae * n_f + f_global];
        }
        // Raise b one degree per step; each step needs one fewer a-degree.
        for kb in 1..=lb {
            for (ibw, &b) in tri[kb].iter().enumerate() {
                let i = lower_axis(b);
                let b1w = cart_index(dec(b, i));
                for &a in tri[la..=(ne - kb)].iter().flatten() {
                    let ae = addr(a);
                    let a1e = addr(inc(a, i));
                    cur[ibw * n_e + ae] = prev[b1w * n_e + a1e] + ab[i] * prev[b1w * n_e + ae];
                }
            }
            std::mem::swap(&mut prev, &mut cur);
        }
        // `prev` now holds b of degree lb; copy out the (a, b) shell block.
        for (ib, &b) in tri[lb].iter().enumerate() {
            let ibw = cart_index(b);
            for (ia, &a) in tri[la].iter().enumerate() {
                bra[(ia * nb + ib) * nf_range + jf] = prev[ibw * n_e + addr(a)];
            }
        }
    }
    // One ket layer: n_cart(ld) d-components x n_f ket addresses.
    let klayer_len = n_cart(ld) * n_f;
    ensure_len(kprev, klayer_len);
    ensure_len(kcur, klayer_len);
    for ia in 0..na {
        for ib in 0..nb {
            let brarow = (ia * nb + ib) * nf_range;
            // Seed layer d = 0 with (ab|c0) for every c of degree lc..=nf.
            for &c in tri[lc..=nf].iter().flatten() {
                let ce = addr(c);
                kprev[ce] = bra[brarow + (ce - f_base)];
            }
            // (ab|c, d + 1_j) = (ab|c + 1_j, d) + CD_j (ab|c, d).
            for kd in 1..=ld {
                for (idw, &d) in tri[kd].iter().enumerate() {
                    let j = lower_axis(d);
                    let d1w = cart_index(dec(d, j));
                    for &c in tri[lc..=(nf - kd)].iter().flatten() {
                        let ce = addr(c);
                        let c1e = addr(inc(c, j));
                        kcur[idw * n_f + ce] =
                            kprev[d1w * n_f + c1e] + cd[j] * kprev[d1w * n_f + ce];
                    }
                }
                std::mem::swap(&mut kprev, &mut kcur);
            }
            // Add the finished (ab|cd) entries into `out`.
            for (ic, &c) in tri[lc].iter().enumerate() {
                let ce = addr(c);
                for id in 0..nd {
                    out[((ia * nb + ib) * nc + ic) * nd + id] += kprev[id * n_f + ce];
                }
            }
        }
    }
}
/// Returns the index of the first non-zero component of `t`.
/// Falls back to axis 2 when every component is zero.
#[inline]
fn lower_axis(t: [usize; 3]) -> usize {
    t.iter().position(|&n| n != 0).unwrap_or(2)
}
/// Returns a copy of `t` with component `i` lowered by one.
/// Panics on a zero component in debug builds (integer underflow).
#[inline]
fn dec(t: [usize; 3], i: usize) -> [usize; 3] {
    let mut lowered = t;
    lowered[i] = t[i] - 1;
    lowered
}
/// Returns a copy of `t` with component `i` raised by one.
#[inline]
fn inc(t: [usize; 3], i: usize) -> [usize; 3] {
    let mut raised = t;
    raised[i] = t[i] + 1;
    raised
}
/// Computes the weighted centre `(a * ca + b * cb) / p`, component by component.
/// Callers pass `p = a + b` to obtain the Gaussian product centre.
#[inline]
fn combine(a: f64, ca: Vec3, b: f64, cb: Vec3, p: f64) -> Vec3 {
    std::array::from_fn(|k| (a * ca[k] + b * cb[k]) / p)
}
/// Component-wise difference `u - v`.
#[inline]
fn sub(u: Vec3, v: Vec3) -> Vec3 {
    std::array::from_fn(|k| u[k] - v[k])
}
#[inline]
fn dist2(u: Vec3, v: Vec3) -> f64 {
norm2(sub(u, v))
}
/// Squared Euclidean norm `|u|^2`, summed in x, y, z order.
#[inline]
fn norm2(u: Vec3) -> f64 {
    u.iter().fold(0.0, |acc, &x| acc + x * x)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::os::Prim;
    use crate::rys::coulomb_into;
    use std::f64::consts::PI;

    /// Builds a [`ShellRef`] that borrows the given exponent/coefficient slices.
    fn shell<'a>(center: Vec3, l: usize, exps: &'a [f64], coeffs: &'a [f64]) -> ShellRef<'a> {
        ShellRef {
            center,
            l,
            exps,
            coeffs,
        }
    }

    /// Evaluates `(ab|cd)` with `coulomb_shell_into` for four single-primitive
    /// shells, each with a unit contraction coefficient.
    #[allow(clippy::too_many_arguments)] // one (l, exponent, center) triple per shell
    fn os_block(
        la: usize,
        ea: f64,
        ca: Vec3,
        lb: usize,
        eb: f64,
        cb: Vec3,
        lc: usize,
        ec: f64,
        cc: Vec3,
        ld: usize,
        ed: f64,
        cd: Vec3,
    ) -> Vec<f64> {
        let mut out = vec![0.0; n_cart(la) * n_cart(lb) * n_cart(lc) * n_cart(ld)];
        let (xa, xb, xc, xd) = ([ea], [eb], [ec], [ed]);
        let one = [1.0];
        coulomb_shell_into(
            shell(ca, la, &xa, &one),
            shell(cb, lb, &xb, &one),
            shell(cc, lc, &xc, &one),
            shell(cd, ld, &xd, &one),
            &mut out,
        );
        out
    }

    /// Asserts element-wise agreement between the OS and Rys engines.
    ///
    /// The tolerance is `ATOL + RTOL * |rys|`. The absolute floor keeps
    /// near-zero integrals from being held to a purely relative bound. The
    /// block lengths must match, so a truncated block cannot pass silently.
    fn assert_cross_engine_close(os: &[f64], rys: &[f64], tag: &str) {
        const ATOL: f64 = 1e-11;
        const RTOL: f64 = 1e-10;
        assert_eq!(
            os.len(),
            rys.len(),
            "{tag} OS vs Rys block length mismatch"
        );
        for (idx, (o, r)) in os.iter().zip(rys).enumerate() {
            assert!(
                (o - r).abs() <= ATOL + RTOL * r.abs(),
                "{tag} OS vs Rys mismatch at {idx}: {o} vs {r} (Δ={:e})",
                (o - r).abs()
            );
        }
    }

    #[test]
    fn ssss_matches_closed_form() {
        let (ac, bc) = ([0.0, 0.0, 0.0], [0.0, 0.0, 0.7]);
        let (cc, dc) = ([0.4, 0.0, 0.0], [0.0, 0.6, 0.2]);
        let (ea, eb, ec, ed) = (0.8, 1.3, 0.5, 1.1);
        let out = os_block(0, ea, ac, 0, eb, bc, 0, ec, cc, 0, ed, dc);
        assert_eq!(out.len(), 1, "an (ss|ss) block has exactly one element");

        let p = ea + eb;
        let q = ec + ed;
        let pcen = combine(ea, ac, eb, bc, p);
        let qcen = combine(ec, cc, ed, dc, q);
        let kab = (-(ea * eb / p) * dist2(ac, bc)).exp();
        let kcd = (-(ec * ed / q) * dist2(cc, dc)).exp();
        let rho = p * q / (p + q);
        let t = rho * dist2(pcen, qcen);
        let mut fm = [0.0; 1];
        boys_array(0, t, &mut fm);
        let expect = 2.0 * PI * PI * PI.sqrt() / (p * q * (p + q).sqrt()) * kab * kcd * fm[0];
        assert!(
            (out[0] - expect).abs() < 1e-14 * expect.abs(),
            "ssss {} vs {}",
            out[0],
            expect
        );
    }

    #[test]
    fn matches_rys_engine_primitive_sweep() {
        let ca = [0.0, 0.0, 0.0];
        let cb = [0.5, -0.3, 0.2];
        let cc = [-0.4, 0.6, -0.1];
        let cd = [0.2, 0.4, 0.8];
        let (ea, eb, ec, ed) = (0.9, 1.3, 0.7, 1.1);
        let quartets = [
            (0usize, 0usize, 0usize, 0usize),
            (1, 0, 0, 0),
            (0, 0, 1, 0),
            (1, 1, 0, 0),
            (1, 0, 1, 0),
            (1, 1, 1, 1),
            (2, 0, 0, 0),
            (2, 1, 0, 0),
            (2, 1, 2, 1),
            (0, 1, 2, 3),
            (2, 2, 3, 3),
            (3, 0, 0, 1),
            (4, 4, 4, 1),
            (6, 6, 1, 0),
            (3, 3, 3, 3),
        ];
        for (la, lb, lc, ld) in quartets {
            let os = os_block(la, ea, ca, lb, eb, cb, lc, ec, cc, ld, ed, cd);
            let mut rys = vec![0.0; os.len()];
            coulomb_into(
                Prim::new(ea, ca, la),
                Prim::new(eb, cb, lb),
                Prim::new(ec, cc, lc),
                Prim::new(ed, cd, ld),
                1.0,
                &mut rys,
            );
            assert_cross_engine_close(&os, &rys, &format!("({la}{lb}|{lc}{ld})"));
        }
    }

    #[test]
    fn contracted_quartet_matches_rys_sum() {
        let ca = [0.0, 0.0, 0.0];
        let cb = [0.6, -0.2, 0.1];
        let cc = [-0.3, 0.5, -0.2];
        let cd = [0.2, 0.3, 0.7];
        let (la, lb, lc, ld) = (1usize, 1usize, 2usize, 0usize);
        let ax = [1.4, 0.45];
        let acf = [0.6, 0.5];
        let bx = [0.9, 0.3];
        let bcf = [0.55, 0.5];
        let cx = [1.1, 0.4];
        let ccf = [0.7, 0.4];
        let dx = [0.8];
        let dcf = [1.0];

        let mut os = vec![0.0; n_cart(la) * n_cart(lb) * n_cart(lc) * n_cart(ld)];
        coulomb_shell_into(
            shell(ca, la, &ax, &acf),
            shell(cb, lb, &bx, &bcf),
            shell(cc, lc, &cx, &ccf),
            shell(cd, ld, &dx, &dcf),
            &mut os,
        );

        // Reference: the Rys engine summed over every primitive quartet with
        // the product of the four contraction coefficients as the weight.
        let mut rys = vec![0.0; os.len()];
        for (&ea, &wa) in ax.iter().zip(acf.iter()) {
            for (&eb, &wb) in bx.iter().zip(bcf.iter()) {
                for (&ec, &wc) in cx.iter().zip(ccf.iter()) {
                    for (&ed, &wd) in dx.iter().zip(dcf.iter()) {
                        coulomb_into(
                            Prim::new(ea, ca, la),
                            Prim::new(eb, cb, lb),
                            Prim::new(ec, cc, lc),
                            Prim::new(ed, cd, ld),
                            wa * wb * wc * wd,
                            &mut rys,
                        );
                    }
                }
            }
        }
        assert_cross_engine_close(&os, &rys, &format!("contracted ({la}{lb}|{lc}{ld})"));
    }

    #[test]
    fn addr_is_bijective() {
        for lmax in 0..=6 {
            let mut seen = vec![false; n_addr(lmax)];
            for n in 0..=lmax {
                for t in cart_components(n) {
                    let a = addr(t);
                    assert!(a < n_addr(lmax), "addr {a} out of range for lmax {lmax}");
                    assert!(!seen[a], "addr collision at {t:?}");
                    seen[a] = true;
                }
            }
            assert!(
                seen.iter().all(|&x| x),
                "addr not surjective at lmax {lmax}"
            );
        }
    }
}