use std::cell::RefCell;
use integral_math::am::{cart_components, cart_index, n_cart};
use integral_math::boys::boys_array;
use crate::os::{Vec3, MAX_L};
/// Borrowed, copyable description of one contracted shell as consumed by the ERI driver.
///
/// The driver walks `exps` and `coeffs` pairwise with `zip`, so entry `i` of each slice
/// describes the same primitive; if the slices differ in length the surplus entries of the
/// longer one are silently ignored.
#[derive(Debug, Clone, Copy)]
pub struct ShellRef<'a> {
/// Center of the shell.
pub center: Vec3,
/// Angular momentum of the shell; the driver debug-asserts `l <= MAX_L`.
pub l: usize,
/// Primitive exponents.
pub exps: &'a [f64],
/// Contraction coefficients, matched to `exps` by position.
pub coeffs: &'a [f64],
}
/// Number of Cartesian exponent triples whose total degree is at most `lmax`.
///
/// Degree `n` contributes `(n + 1)(n + 2) / 2` triples; summing over `0..=lmax` yields the
/// binomial coefficient `C(lmax + 3, 3)` evaluated here.
#[inline]
fn n_addr(lmax: usize) -> usize {
    let top = lmax + 3;
    top * (top - 1) * (top - 2) / 6
}
/// Number of exponent triples with total degree strictly below `d`.
///
/// This equals the address that `addr` assigns to the first triple of degree `d`.
#[inline]
fn tri_below(d: usize) -> usize {
    let product: usize = [d, d + 1, d + 2].iter().product();
    product / 6
}
/// Flat address of the exponent triple `t` in a layout that stores all degrees contiguously.
///
/// Triples of lower total degree come first; inside one degree the position is
/// `(y + z)(y + z + 1) / 2 + z` for `t = [x, y, z]`.
#[inline]
fn addr(t: [usize; 3]) -> usize {
    let [x, y, z] = t;
    let degree = x + y + z;
    // `degree - x` in the compact formula is just `y + z`.
    let tail = y + z;
    let degree_base = degree * (degree + 1) * (degree + 2) / 6;
    degree_base + tail * (tail + 1) / 2 + z
}
/// Reusable work buffers for the shell-quartet ERI driver.
///
/// Buffers only ever grow (see `ensure_len`), so keeping one instance alive across calls
/// avoids reallocating them for every quartet.
#[derive(Debug, Default, Clone)]
pub struct EriScratch {
    /// Rotating storage for the vertical build-up (three slots).
    levels: Vec<f64>,
    /// Contracted accumulator indexed by `bra address * n_f + ket address`.
    c_ef: Vec<f64>,
    /// Output of the bra transfer stage, input of the ket transfer stage.
    bra: Vec<f64>,
    /// Ping-pong pair used while transferring on the bra side.
    bra_prev: Vec<f64>,
    bra_cur: Vec<f64>,
    /// Ping-pong pair used while transferring on the ket side.
    ket_prev: Vec<f64>,
    ket_cur: Vec<f64>,
}

impl EriScratch {
    /// Creates a scratch set with every buffer empty.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Current length of every buffer, in declaration order.
    fn buffer_lens(&self) -> [usize; 7] {
        [
            self.levels.len(),
            self.c_ef.len(),
            self.bra.len(),
            self.bra_prev.len(),
            self.bra_cur.len(),
            self.ket_prev.len(),
            self.ket_cur.len(),
        ]
    }

    /// Total number of `f64` elements currently held across all buffers.
    #[must_use]
    pub fn resident_f64(&self) -> usize {
        self.buffer_lens().iter().sum()
    }

    /// Length, in `f64` elements, of the longest single buffer (0 when all are empty).
    #[must_use]
    pub fn largest_buffer_f64(&self) -> usize {
        self.buffer_lens().iter().copied().fold(0, usize::max)
    }
}
// Per-thread scratch buffers backing `coulomb_shell_into`. Each thread owns its own
// `EriScratch`, which starts out empty; the `RefCell` is borrowed mutably for the duration of
// one driver call, so a re-entrant call on the same thread would panic on `borrow_mut`.
thread_local! {
static ERI_SCRATCH: RefCell<EriScratch> = RefCell::new(EriScratch::new());
}
/// Grows `v` with zeros until it holds at least `n` elements; never shrinks it.
#[inline]
fn ensure_len(v: &mut Vec<f64>, n: usize) {
    // Resizing to the current length is a no-op, so taking the max keeps longer buffers intact.
    let target = v.len().max(n);
    v.resize(target, 0.0);
}
/// Computes the Coulomb integral block for the shell quartet `(a b | c d)` using this thread's
/// cached scratch buffers.
///
/// Thin wrapper around [`coulomb_shell_into_scratch`]; see there for the layout of `out` and
/// the debug-only precondition checks.
pub fn coulomb_shell_into(
    a: ShellRef<'_>,
    b: ShellRef<'_>,
    c: ShellRef<'_>,
    d: ShellRef<'_>,
    out: &mut [f64],
) {
    ERI_SCRATCH.with(|cell| {
        let mut scratch = cell.borrow_mut();
        coulomb_shell_into_scratch(&mut scratch, a, b, c, d, out);
    });
}
/// Computes the Coulomb integral block for the shell quartet `(a b | c d)` with caller-owned
/// scratch buffers.
///
/// Every combination of primitives from the four shells is run through `vrr_primitive`, which
/// adds its result, weighted by the product of the four coefficients, to `scratch.c_ef`.
/// `hrr_and_scatter` then turns that accumulator into the final block.
///
/// The block is *added* to `out` at index `((ia * nb + ib) * nc + ic) * nd + id`; nothing here
/// clears `out`, so pass a zeroed slice to obtain the bare integrals.
///
/// # Panics
///
/// In debug builds, panics if any shell's `l` exceeds `MAX_L` or if `out` is shorter than
/// `n_cart(a.l) * n_cart(b.l) * n_cart(c.l) * n_cart(d.l)`.
pub fn coulomb_shell_into_scratch(
    scratch: &mut EriScratch,
    a: ShellRef<'_>,
    b: ShellRef<'_>,
    c: ShellRef<'_>,
    d: ShellRef<'_>,
    out: &mut [f64],
) {
    debug_assert!(
        [a.l, b.l, c.l, d.l].iter().all(|&l| l <= MAX_L),
        "angular momentum exceeds MAX_L"
    );
    let block_len = n_cart(a.l) * n_cart(b.l) * n_cart(c.l) * n_cart(d.l);
    debug_assert!(out.len() >= block_len, "ERI output block too short");

    // Combined angular momenta on each side and the sizes of the matching address spaces.
    let bra_l = a.l + b.l;
    let ket_l = c.l + d.l;
    let total_l = bra_l + ket_l;
    let bra_addrs = n_addr(bra_l);
    let ket_addrs = n_addr(ket_l);

    // The contracted accumulator must start from zero for every quartet.
    let acc_len = bra_addrs * ket_addrs;
    ensure_len(&mut scratch.c_ef, acc_len);
    scratch.c_ef[..acc_len].fill(0.0);

    let bra_components: Vec<Vec<[usize; 3]>> = (0..=bra_l).map(cart_components).collect();
    let ket_components: Vec<Vec<[usize; 3]>> = (0..=ket_l).map(cart_components).collect();

    // For ket degree `k`, `m_offsets[k][ae]` is where the run of auxiliary orders for bra
    // address `ae` starts, and `slab_sizes[k]` is the space one ket component occupies.
    // A bra entry of degree `deg` keeps `total_l - deg - k + 1` orders.
    let mut m_offsets: Vec<Vec<usize>> = Vec::with_capacity(ket_l + 1);
    let mut slab_sizes: Vec<usize> = Vec::with_capacity(ket_l + 1);
    for k in 0..=ket_l {
        let mut row = Vec::with_capacity(bra_addrs);
        let mut cursor = 0usize;
        for degree in 0..=bra_l {
            let run = total_l - degree - k + 1;
            let count = n_cart(degree);
            row.extend((0..count).map(|i| cursor + i * run));
            cursor += count * run;
        }
        m_offsets.push(row);
        slab_sizes.push(cursor);
    }

    // One slot must hold the largest ket level; three slots are kept in rotation.
    let level_len = slab_sizes
        .iter()
        .enumerate()
        .map(|(k, &slab)| n_cart(k) * slab)
        .max()
        .unwrap_or(1);
    ensure_len(&mut scratch.levels, 3 * level_len);
    // Poison the slots in debug builds so reads of unwritten entries surface as NaN.
    #[cfg(debug_assertions)]
    scratch.levels[..3 * level_len].fill(f64::NAN);

    // Squared centre separations do not depend on the primitive being processed.
    let ab2 = dist2(a.center, b.center);
    let cd2 = dist2(c.center, d.center);

    for (&alpha, &wa) in a.exps.iter().zip(a.coeffs.iter()) {
        for (&beta, &wb) in b.exps.iter().zip(b.coeffs.iter()) {
            let p = alpha + beta;
            let p_center = combine(alpha, a.center, beta, b.center, p);
            let k_ab = (-(alpha * beta / p) * ab2).exp();
            for (&gamma, &wc) in c.exps.iter().zip(c.coeffs.iter()) {
                for (&delta, &wd) in d.exps.iter().zip(d.coeffs.iter()) {
                    let q = gamma + delta;
                    let q_center = combine(gamma, c.center, delta, d.center, q);
                    let k_cd = (-(gamma * delta / q) * cd2).exp();
                    vrr_primitive(
                        p,
                        q,
                        p_center,
                        q_center,
                        k_ab,
                        k_cd,
                        a.center,
                        c.center,
                        bra_l,
                        ket_l,
                        total_l,
                        bra_addrs,
                        ket_addrs,
                        level_len,
                        &m_offsets,
                        &slab_sizes,
                        &bra_components,
                        &ket_components,
                        &mut scratch.levels,
                        wa * wb * wc * wd,
                        &mut scratch.c_ef,
                    );
                }
            }
        }
    }

    hrr_and_scatter(
        a.l,
        b.l,
        c.l,
        d.l,
        ket_addrs,
        &scratch.c_ef,
        sub(a.center, b.center),
        sub(c.center, d.center),
        out,
        &mut scratch.bra,
        &mut scratch.bra_prev,
        &mut scratch.bra_cur,
        &mut scratch.ket_prev,
        &mut scratch.ket_cur,
    );
}
/// Builds the auxiliary integrals for one primitive quartet in `levels` and adds `scale` times
/// their order-0 entries to `c_ef`.
///
/// (Judging by the function name and the `crate::os` import this is the Obara–Saika vertical
/// recurrence — NOTE(review): confirm.)
///
/// Arguments, as supplied by `coulomb_shell_into_scratch`:
/// - `p`, `q`: sums of the two bra / two ket primitive exponents.
/// - `pc`, `qc`: exponent-weighted centers of the bra and ket pairs (from `combine`).
/// - `kab`, `kcd`: exponential pair prefactors.
/// - `a_center`, `c_center`: centers of shells `a` and `c`.
/// - `ne`, `nf`: combined angular momentum of the bra and ket side; `l_total = ne + nf`.
/// - `n_e`, `n_f`: `n_addr(ne)` and `n_addr(nf)`.
/// - `maxlevel`: size of one slot in `levels`, which must hold at least `3 * maxlevel` values.
/// - `eoff[k][ae]`: start of the order run for bra address `ae` at ket degree `k`;
///   `slab[k]`: space taken by one ket component at degree `k`.
/// - `tri_e`, `tri_f`: exponent triples per degree for the bra and ket side.
/// - `scale`: product of the four contraction coefficients.
/// - `c_ef`: accumulator indexed by `ae * n_f + af`.
// Too many arguments is accepted here: the caller precomputes all tables once per quartet.
#[allow(clippy::too_many_arguments)]
fn vrr_primitive(
p: f64,
q: f64,
pc: Vec3,
qc: Vec3,
kab: f64,
kcd: f64,
a_center: Vec3,
c_center: Vec3,
ne: usize,
nf: usize,
l_total: usize,
n_e: usize,
n_f: usize,
maxlevel: usize,
eoff: &[Vec<usize>],
slab: &[usize],
tri_e: &[Vec<[usize; 3]>],
tri_f: &[Vec<[usize; 3]>],
levels: &mut [f64],
scale: f64,
c_ef: &mut [f64],
) {
let pq = p + q;
// Reduced exponent; note rho / p == q / pq and rho / q == p / pq, used below.
let rho = p * q / pq;
let pq_vec = sub(pc, qc); let t = rho * norm2(pq_vec);
// Exponent-weighted mean of the two pair centers.
let w = [
(p * pc[0] + q * qc[0]) / pq,
(p * pc[1] + q * qc[1]) / pq,
(p * pc[2] + q * qc[2]) / pq,
];
// Displacement vectors that multiply the recurrence terms.
let pa = sub(pc, a_center); let wp = sub(w, pc); let qcen = sub(qc, c_center); let wq = sub(w, qc);
// Index into `levels`: slot (k mod 3), then ket component `lf`, then bra address `ae`,
// then auxiliary order `m`.
let off = |k: usize, lf: usize, ae: usize, m: usize| -> usize {
(k % 3) * maxlevel + lf * slab[k] + eoff[k][ae] + m
};
use std::f64::consts::PI;
// 2 * pi^(5/2) / (p * q * sqrt(p + q)) times the two pair prefactors.
let pref = 2.0 * PI * PI * PI.sqrt() / (p * q * pq.sqrt()) * kab * kcd;
// Sized for the largest total angular momentum, 4 * MAX_L.
let mut fm = [0.0f64; 4 * MAX_L + 1];
// Presumably fills fm[m] with the Boys function of order m at argument t, m = 0..=l_total —
// TODO confirm against `integral_math::boys`.
boys_array(l_total, t, &mut fm[..=l_total]);
// Seed: bra address 0, ket degree 0, all orders.
for m in 0..=l_total {
levels[off(0, 0, 0, m)] = pref * fm[m];
}
let inv_2p = 0.5 / p;
let inv_2q = 0.5 / q;
let inv_2pq = 0.5 / pq;
let q_over_pq = q / pq;
let p_over_pq = p / pq;
// Ket degree 0 lives in slot 0, so its offsets can be used without a slot base.
let eoff0 = &eoff[0];
// Bra build-up at ket degree 0: each triple `te` of degree `na >= 1` is obtained from the
// triple with one less quantum along its first non-zero axis `i`.
for (na, te_list) in tri_e.iter().enumerate().take(ne + 1).skip(1) {
for &te in te_list {
let i = lower_axis(te);
let s1a = addr(dec(te, i));
// Highest order kept for degree `na`; sources of lower degree keep one order more,
// so reading `m + 1` stays inside their run.
let mmax = l_total - na;
// te[i] >= 1 by construction of `i`, so the subtraction cannot underflow.
let coef2 = ((te[i] - 1) as f64) * inv_2p; let has2 = te[i] >= 2;
let s1_0 = eoff0[s1a];
let s2_0 = if has2 {
eoff0[addr(dec(dec(te, i), i))]
} else {
0
};
let dst_0 = eoff0[addr(te)];
for m in 0..=mmax {
let mut v = pa[i] * levels[s1_0 + m] + wp[i] * levels[s1_0 + m + 1];
if has2 {
v += coef2 * (levels[s2_0 + m] - q_over_pq * levels[s2_0 + m + 1]);
}
levels[dst_0 + m] = v;
}
}
}
// Contract the ket-degree-0 column (ket address 0).
for ae in 0..n_e {
c_ef[ae * n_f] += scale * levels[off(0, 0, ae, 0)];
}
// Ket build-up: degree `k` is written into slot k mod 3 and reads degrees k - 1 and k - 2,
// which occupy the other two slots.
for k in 1..=nf {
let slot_k = (k % 3) * maxlevel;
let slot_k1 = ((k - 1) % 3) * maxlevel;
let (slab_k, slab_k1) = (slab[k], slab[k - 1]);
let eoff_k = &eoff[k];
let eoff_k1 = &eoff[k - 1];
// Degree k - 2 only exists for k >= 2; the placeholders are never read otherwise because
// `has2` below implies k >= 2.
let (slot_k2, slab_k2, eoff_k2): (usize, usize, &[usize]) = if k >= 2 {
(((k - 2) % 3) * maxlevel, slab[k - 2], &eoff[k - 2])
} else {
(0, 0, &[])
};
for &tf in &tri_f[k] {
let j = lower_axis(tf);
let f1 = dec(tf, j);
let lf1 = cart_index(f1); let coef2 = ((tf[j] - 1) as f64) * inv_2q;
let has2 = tf[j] >= 2;
let lf2 = if has2 { cart_index(dec(f1, j)) } else { 0 }; let lf = cart_index(tf); let sk = slot_k + lf * slab_k;
let sk1 = slot_k1 + lf1 * slab_k1;
let sk2 = slot_k2 + lf2 * slab_k2;
for (nadeg, te_list) in tri_e.iter().enumerate().take(ne + 1) {
for &te in te_list {
let ea = addr(te);
// Cross term couples to the bra triple lowered along the same axis `j`.
let has_cross = te[j] >= 1;
let (cross_coef, cross_0) = if has_cross {
(te[j] as f64 * inv_2pq, sk1 + eoff_k1[addr(dec(te, j))])
} else {
(0.0, 0)
};
let mmax = l_total - nadeg - k;
let dst_0 = sk + eoff_k[ea];
let src1_0 = sk1 + eoff_k1[ea];
let src2_0 = if has2 { sk2 + eoff_k2[ea] } else { 0 };
for m in 0..=mmax {
let mut v = qcen[j] * levels[src1_0 + m] + wq[j] * levels[src1_0 + m + 1];
if has2 {
v += coef2 * (levels[src2_0 + m] - p_over_pq * levels[src2_0 + m + 1]);
}
if has_cross {
v += cross_coef * levels[cross_0 + m + 1];
}
levels[dst_0 + m] = v;
}
}
}
}
// Contract the finished degree-`k` level: order 0 only, ket column given by `addr(tf)`.
for &tf in &tri_f[k] {
let lf = cart_index(tf);
let af = addr(tf);
for ae in 0..n_e {
c_ef[ae * n_f + af] += scale * levels[off(k, lf, ae, 0)];
}
}
}
}
/// Transfers angular momentum from the accumulated `c_ef` table onto shells `b` and `d`, then
/// adds the resulting `(la lb | lc ld)` block to `out`.
///
/// (Judging by the name this is the horizontal recurrence — NOTE(review): confirm.)
///
/// - `c_ef` is indexed by `ae * n_f + af`, with `ae`/`af` the `addr` of the bra/ket triple.
/// - `ab`, `cd` are the center differences `A - B` and `C - D` passed by the caller.
/// - `out` is indexed by `((ia * nb + ib) * nc + ic) * nd + id` and is accumulated with `+=`;
///   it is not cleared here.
/// - `bra`, `prev`/`cur`, `kprev`/`kcur` are scratch buffers; they are grown as needed and
///   their previous contents are irrelevant.
///
/// NOTE(review): layers are written at the enumeration position within `tri[kb]` / `tri[kd]`
/// but read back through `cart_index`; this assumes `cart_components` yields triples in
/// `cart_index` order — confirm against `integral_math::am`.
// Too many arguments is accepted here: the scratch buffers are passed individually so they can
// be borrowed independently of `c_ef`.
#[allow(clippy::too_many_arguments)]
fn hrr_and_scatter<'a>(
la: usize,
lb: usize,
lc: usize,
ld: usize,
n_f: usize,
c_ef: &[f64],
ab: Vec3,
cd: Vec3,
out: &mut [f64],
bra: &mut Vec<f64>,
mut prev: &'a mut Vec<f64>,
mut cur: &'a mut Vec<f64>,
mut kprev: &'a mut Vec<f64>,
mut kcur: &'a mut Vec<f64>,
) {
let (na, nb, nc, nd) = (n_cart(la), n_cart(lb), n_cart(lc), n_cart(ld));
let ne = la + lb; let nf = lc + ld; let n_e = n_addr(ne);
let tri: Vec<Vec<[usize; 3]>> = (0..=ne.max(nf)).map(cart_components).collect();
// Only ket addresses of degree lc..=nf are needed; `f_base` is the first of them.
let f_base = tri_below(lc);
let nf_range = n_f - f_base;
let bra_len = na * nb * nf_range;
ensure_len(bra, bra_len);
// Poison in debug builds so an entry that is read before being written shows up as NaN.
#[cfg(debug_assertions)]
bra[..bra_len].fill(f64::NAN);
// One layer holds every b-component (up to degree lb) times every bra address.
let layer_len = n_cart(lb) * n_e;
ensure_len(prev, layer_len);
ensure_len(cur, layer_len);
// Stage 1: bra-side transfer, done independently for each ket address.
for f_global in f_base..n_f {
let jf = f_global - f_base;
// Layer for b of degree 0: copy the column of `c_ef` for degrees la..=ne.
for &a in tri[la..=ne].iter().flatten() {
let ae = addr(a);
prev[ae] = c_ef[ae * n_f + f_global];
}
for kb in 1..=lb {
for (ibw, &b) in tri[kb].iter().enumerate() {
let i = lower_axis(b);
let b1w = cart_index(dec(b, i));
// Raising b by one along `i` consumes one degree of a, hence the shrinking range.
for &a in tri[la..=(ne - kb)].iter().flatten() {
let ae = addr(a);
let a1e = addr(inc(a, i));
cur[ibw * n_e + ae] = prev[b1w * n_e + a1e] + ab[i] * prev[b1w * n_e + ae];
}
}
// The layer just written becomes the source for the next degree of b.
std::mem::swap(&mut prev, &mut cur);
}
// `prev` now holds b of degree lb; keep only a of degree la.
for (ib, &b) in tri[lb].iter().enumerate() {
let ibw = cart_index(b); for (ia, &a) in tri[la].iter().enumerate() {
bra[(ia * nb + ib) * nf_range + jf] = prev[ibw * n_e + addr(a)];
}
}
}
let klayer_len = n_cart(ld) * n_f;
ensure_len(kprev, klayer_len);
ensure_len(kcur, klayer_len);
// Stage 2: ket-side transfer for each (ia, ib) row of `bra`, same scheme with `cd`.
for ia in 0..na {
for ib in 0..nb {
let brarow = (ia * nb + ib) * nf_range;
for &c in tri[lc..=nf].iter().flatten() {
let ce = addr(c);
kprev[ce] = bra[brarow + (ce - f_base)];
}
for kd in 1..=ld {
for (idw, &d) in tri[kd].iter().enumerate() {
let j = lower_axis(d);
let d1w = cart_index(dec(d, j));
for &c in tri[lc..=(nf - kd)].iter().flatten() {
let ce = addr(c);
let c1e = addr(inc(c, j));
kcur[idw * n_f + ce] =
kprev[d1w * n_f + c1e] + cd[j] * kprev[d1w * n_f + ce];
}
}
std::mem::swap(&mut kprev, &mut kcur);
}
// Accumulate (not overwrite) the finished entries into the caller's block.
for (ic, &c) in tri[lc].iter().enumerate() {
let ce = addr(c);
for id in 0..nd {
out[((ia * nb + ib) * nc + ic) * nd + id] += kprev[id * n_f + ce];
}
}
}
}
}
/// Picks the axis along which a triple is lowered: the first axis with a non-zero exponent,
/// falling back to axis 2 when none of the first two is non-zero.
#[inline]
fn lower_axis(t: [usize; 3]) -> usize {
    match t {
        [x, _, _] if x > 0 => 0,
        [_, y, _] if y > 0 => 1,
        _ => 2,
    }
}
/// Returns `t` with the exponent on axis `i` reduced by one.
///
/// The caller must ensure `t[i] >= 1`; the subtraction is a plain `usize` subtraction.
#[inline]
fn dec(t: [usize; 3], i: usize) -> [usize; 3] {
    let mut lowered = t;
    lowered[i] -= 1;
    lowered
}
/// Returns `t` with the exponent on axis `i` raised by one.
#[inline]
fn inc(t: [usize; 3], i: usize) -> [usize; 3] {
    let mut raised = t;
    raised[i] += 1;
    raised
}
/// Weighted combination `(a * ca + b * cb) / p` of two points, evaluated per axis.
///
/// `p` is taken from the caller rather than recomputed; the driver passes `a + b`.
#[inline]
fn combine(a: f64, ca: Vec3, b: f64, cb: Vec3, p: f64) -> Vec3 {
    let axis = |k: usize| (a * ca[k] + b * cb[k]) / p;
    [axis(0), axis(1), axis(2)]
}
/// Component-wise difference `u - v`.
#[inline]
fn sub(u: Vec3, v: Vec3) -> Vec3 {
    let diff = |k: usize| u[k] - v[k];
    [diff(0), diff(1), diff(2)]
}
/// Squared Euclidean distance between `u` and `v`.
#[inline]
fn dist2(u: Vec3, v: Vec3) -> f64 {
    let separation = sub(u, v);
    norm2(separation)
}
/// Squared Euclidean length of `u`.
#[inline]
fn norm2(u: Vec3) -> f64 {
    let [x, y, z] = u;
    x * x + y * y + z * z
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::os::Prim;
    use crate::rys::coulomb_into;

    /// Bundles borrowed exponent and coefficient slices into a `ShellRef`.
    fn shell<'a>(center: Vec3, l: usize, exps: &'a [f64], coeffs: &'a [f64]) -> ShellRef<'a> {
        ShellRef {
            center,
            l,
            exps,
            coeffs,
        }
    }

    /// A quartet of s-type primitives must reproduce the closed-form expression built from the
    /// same prefactor and the order-0 value returned by `boys_array`.
    #[test]
    fn ssss_matches_closed_form() {
        use std::f64::consts::PI;

        let ac = [0.0, 0.0, 0.0];
        let bc = [0.0, 0.0, 0.7];
        let cc = [0.4, 0.0, 0.0];
        let dc = [0.0, 0.6, 0.2];
        let (ae, be, ce, de) = ([0.8], [1.3], [0.5], [1.1]);
        let unit = [1.0];

        let mut out = [0.0; 1];
        coulomb_shell_into(
            shell(ac, 0, &ae, &unit),
            shell(bc, 0, &be, &unit),
            shell(cc, 0, &ce, &unit),
            shell(dc, 0, &de, &unit),
            &mut out,
        );

        let p = ae[0] + be[0];
        let q = ce[0] + de[0];
        let pcen = combine(ae[0], ac, be[0], bc, p);
        let qcen = combine(ce[0], cc, de[0], dc, q);
        let kab = (-(ae[0] * be[0] / p) * dist2(ac, bc)).exp();
        let kcd = (-(ce[0] * de[0] / q) * dist2(cc, dc)).exp();
        let rho = p * q / (p + q);
        let t = rho * dist2(pcen, qcen);
        let mut fm = [0.0; 1];
        boys_array(0, t, &mut fm);
        let expect = 2.0 * PI * PI * PI.sqrt() / (p * q * (p + q).sqrt()) * kab * kcd * fm[0];
        assert!(
            (out[0] - expect).abs() < 1e-14 * expect.abs(),
            "ssss {} vs {}",
            out[0],
            expect
        );
    }

    /// Evaluates one uncontracted quartet with unit coefficients through the OS driver.
    /// Each entry of `quartet` is `(angular momentum, exponent, center)`.
    fn os_block(quartet: [(usize, f64, Vec3); 4]) -> Vec<f64> {
        let [(la, ea, ca), (lb, eb, cb), (lc, ec, cc), (ld, ed, cd)] = quartet;
        let mut out = vec![0.0; n_cart(la) * n_cart(lb) * n_cart(lc) * n_cart(ld)];
        let exps = [[ea], [eb], [ec], [ed]];
        let one = [1.0];
        coulomb_shell_into(
            shell(ca, la, &exps[0], &one),
            shell(cb, lb, &exps[1], &one),
            shell(cc, lc, &exps[2], &one),
            shell(cd, ld, &exps[3], &one),
            &mut out,
        );
        out
    }

    /// Element-wise comparison of the two engines with a mixed absolute/relative tolerance.
    fn assert_cross_engine_close(os: &[f64], rys: &[f64], tag: &str) {
        const ATOL: f64 = 1e-11;
        const RTOL: f64 = 1e-10;
        for (o, r) in os.iter().zip(rys.iter()) {
            assert!(
                (o - r).abs() <= ATOL + RTOL * r.abs(),
                "{tag} OS vs Rys mismatch: {o} vs {r} (Δ={:e})",
                (o - r).abs()
            );
        }
    }

    /// Primitive quartets over a range of angular momenta must agree with the Rys engine.
    #[test]
    fn matches_rys_engine_primitive_sweep() {
        let ca = [0.0, 0.0, 0.0];
        let cb = [0.5, -0.3, 0.2];
        let cc = [-0.4, 0.6, -0.1];
        let cd = [0.2, 0.4, 0.8];
        let (ea, eb, ec, ed) = (0.9, 1.3, 0.7, 1.1);
        let quartets = [
            (0usize, 0usize, 0usize, 0usize),
            (1, 0, 0, 0),
            (0, 0, 1, 0),
            (1, 1, 0, 0),
            (1, 0, 1, 0),
            (1, 1, 1, 1),
            (2, 0, 0, 0),
            (2, 1, 0, 0),
            (2, 1, 2, 1),
            (0, 1, 2, 3),
            (2, 2, 3, 3),
            (3, 0, 0, 1),
            (4, 4, 4, 1),
            (6, 6, 1, 0),
            (3, 3, 3, 3),
        ];
        for (la, lb, lc, ld) in quartets {
            let os = os_block([(la, ea, ca), (lb, eb, cb), (lc, ec, cc), (ld, ed, cd)]);
            let mut rys = vec![0.0; os.len()];
            coulomb_into(
                Prim::new(ea, ca, la),
                Prim::new(eb, cb, lb),
                Prim::new(ec, cc, lc),
                Prim::new(ed, cd, ld),
                1.0,
                &mut rys,
            );
            assert_cross_engine_close(&os, &rys, &format!("({la}{lb}|{lc}{ld})"));
        }
    }

    /// A contracted quartet must equal the coefficient-weighted sum of Rys primitive blocks.
    #[test]
    fn contracted_quartet_matches_rys_sum() {
        let ca = [0.0, 0.0, 0.0];
        let cb = [0.6, -0.2, 0.1];
        let cc = [-0.3, 0.5, -0.2];
        let cd = [0.2, 0.3, 0.7];
        let (la, lb, lc, ld) = (1usize, 1usize, 2usize, 0usize);
        let (ax, acf) = ([1.4, 0.45], [0.6, 0.5]);
        let (bx, bcf) = ([0.9, 0.3], [0.55, 0.5]);
        let (cx, ccf) = ([1.1, 0.4], [0.7, 0.4]);
        let (dx, dcf) = ([0.8], [1.0]);

        let mut os = vec![0.0; n_cart(la) * n_cart(lb) * n_cart(lc) * n_cart(ld)];
        coulomb_shell_into(
            shell(ca, la, &ax, &acf),
            shell(cb, lb, &bx, &bcf),
            shell(cc, lc, &cx, &ccf),
            shell(cd, ld, &dx, &dcf),
            &mut os,
        );

        // The Rys entry point accumulates, so one buffer collects every primitive term.
        let mut rys = vec![0.0; os.len()];
        for (&ea, &wa) in ax.iter().zip(acf.iter()) {
            for (&eb, &wb) in bx.iter().zip(bcf.iter()) {
                for (&ec, &wc) in cx.iter().zip(ccf.iter()) {
                    for (&ed, &wd) in dx.iter().zip(dcf.iter()) {
                        coulomb_into(
                            Prim::new(ea, ca, la),
                            Prim::new(eb, cb, lb),
                            Prim::new(ec, cc, lc),
                            Prim::new(ed, cd, ld),
                            wa * wb * wc * wd,
                            &mut rys,
                        );
                    }
                }
            }
        }

        for (o, r) in os.iter().zip(rys.iter()) {
            assert!(
                (o - r).abs() <= 1e-10 * r.abs().max(1e-12),
                "contracted OS vs Rys mismatch: {o} vs {r}"
            );
        }
    }

    /// `addr` must map the triples of degree `0..=lmax` one-to-one onto `0..n_addr(lmax)`.
    #[test]
    fn addr_is_bijective() {
        for lmax in 0..=6 {
            let mut seen = vec![false; n_addr(lmax)];
            for t in (0..=lmax).flat_map(cart_components) {
                let a = addr(t);
                assert!(a < n_addr(lmax), "addr {a} out of range for lmax {lmax}");
                assert!(!seen[a], "addr collision at {t:?}");
                seen[a] = true;
            }
            assert!(
                seen.iter().all(|&x| x),
                "addr not surjective at lmax {lmax}"
            );
        }
    }
}