use crate::integrals::{
canonical_shell_pairs, effective_coeffs, quartet_into_scratch, Engine, QuartetScratch, PERMS8,
};
use crate::shell::{Basis, Shell};
use crate::spherical::shell_transform;
/// Precomputed, read-only state for evaluating the full ERI tensor of a basis.
///
/// Every field is filled once by [`EriBuilder::with_engine`]. The tensor the
/// builder targets is a flat buffer of `nao⁴` elements, addressed as `nao²`
/// rows of `nao²` elements each (see [`EriBuilder::partition`]).
#[derive(Debug)]
pub struct EriBuilder<'b> {
/// Shells borrowed from the basis (`basis.shells()`).
shells: &'b [Shell],
/// Engine selector forwarded unchanged to `quartet_into_scratch`.
engine: Engine,
/// Per-shell value from `basis.offsets()`; added to an in-shell function
/// index to obtain a global function index.
offs: Vec<usize>,
/// Per-shell function count (`Shell::n_func`).
nfunc: Vec<usize>,
/// Total number of basis functions (`basis.nao()`).
nao: usize,
/// Per-shell result of `effective_coeffs`, handed to the quartet kernel.
eff: Vec<Vec<f64>>,
/// Per-shell result of `shell_transform`, handed to the quartet kernel as an
/// optional slice. NOTE(review): presumably a Cartesian-to-spherical
/// transform that is `None` for shells needing none — confirm in `spherical`.
c2s: Vec<Option<Vec<f64>>>,
/// Shell pairs from `canonical_shell_pairs`; used both as the list of
/// bra-pairs and as the list of ket-pairs.
pairs: Vec<(usize, usize)>,
}
impl<'b> EriBuilder<'b> {
    /// Creates a builder over `basis` that uses [`Engine::Auto`].
    #[must_use]
    pub fn new(basis: &'b Basis) -> Self {
        Self::with_engine(basis, Engine::Auto)
    }

    /// Creates a builder over `basis` that evaluates shell quartets with `engine`.
    ///
    /// Everything that depends only on the basis (offsets, function counts,
    /// the per-shell tables handed to the quartet kernel, and the canonical
    /// shell-pair list) is computed once here, so the fill routines only read
    /// from `self`.
    #[must_use]
    pub fn with_engine(basis: &'b Basis, engine: Engine) -> Self {
        let shells = basis.shells();
        let offs = basis.offsets();
        let nfunc: Vec<usize> = shells.iter().map(Shell::n_func).collect();
        let nao = basis.nao();
        let eff: Vec<Vec<f64>> = shells.iter().map(effective_coeffs).collect();
        let c2s: Vec<Option<Vec<f64>>> = shells.iter().map(shell_transform).collect();
        let pairs = canonical_shell_pairs(shells.len());
        EriBuilder {
            shells,
            engine,
            offs,
            nfunc,
            nao,
            eff,
            c2s,
            pairs,
        }
    }

    /// Returns the canonical shell pairs, in the order [`Self::partition`]
    /// emits its tasks.
    #[must_use]
    pub fn bra_pairs(&self) -> &[(usize, usize)] {
        &self.pairs
    }

    /// Returns the number of elements of the full tensor, `nao⁴`.
    ///
    /// # Panics
    ///
    /// Panics if `nao⁴` does not fit in `usize`. Checked arithmetic is used so
    /// that release builds fail here with a clear message instead of wrapping
    /// to a too-small length.
    #[must_use]
    pub fn output_len(&self) -> usize {
        let n = self.nao;
        n.checked_mul(n)
            .and_then(|plane| plane.checked_mul(plane))
            .expect("ERI tensor size nao⁴ overflows usize")
    }

    /// Splits `out` into one [`BraPairFill`] per canonical bra-pair.
    ///
    /// `out` is treated as `nao²` rows of `nao²` elements; row `mu * nao + nu`
    /// holds every element whose first two indices are `(mu, nu)`. The task
    /// for bra-pair `(i, j)` receives exclusive access to the rows with `mu`
    /// in shell `i` and `nu` in shell `j`, plus the transposed rows when
    /// `i != j`. Because each row is handed out at most once, the returned
    /// tasks hold disjoint mutable slices of `out`.
    ///
    /// An empty basis (`nao == 0`) with an empty `out` yields tasks that own
    /// no rows.
    ///
    /// # Panics
    ///
    /// Panics if `out.len()` is not `nao⁴`, if `nao⁴` overflows `usize`, or if
    /// two bra-pairs would claim the same row.
    #[must_use]
    pub fn partition<'o>(&self, out: &'o mut [f64]) -> Vec<BraPairFill<'o>> {
        let nao = self.nao;
        // `output_len` validates that nao⁴ fits, so `nao * nao` below cannot overflow.
        let expected = self.output_len();
        assert_eq!(
            out.len(),
            expected,
            "ERI output buffer must be nao⁴ = {} elements",
            expected
        );
        let plane = nao * nao;

        // `chunks_exact_mut` panics on a zero chunk size, so an empty basis
        // must bypass it; there are no rows to hand out in that case.
        let mut slabs: Vec<Option<&'o mut [f64]>> = if plane == 0 {
            Vec::new()
        } else {
            out.chunks_exact_mut(plane).map(Some).collect()
        };
        debug_assert_eq!(slabs.len(), plane);

        let mut tasks = Vec::with_capacity(self.pairs.len());
        for &(i, j) in &self.pairs {
            let (ni, nj) = (self.nfunc[i], self.nfunc[j]);
            let (oi, oj) = (self.offs[i], self.offs[j]);

            // Rows (mu in i, nu in j), stored at index `a * nj + b`.
            let mut ij_band = Vec::with_capacity(ni * nj);
            for a in 0..ni {
                for b in 0..nj {
                    let row = (oi + a) * nao + (oj + b);
                    ij_band.push(claim_row(&mut slabs, row));
                }
            }

            // Transposed rows (mu in j, nu in i), stored at index `b * ni + a`.
            // For a diagonal pair they coincide with `ij_band`, so none are claimed.
            let mut ji_band = Vec::new();
            if i != j {
                ji_band.reserve(nj * ni);
                for b in 0..nj {
                    for a in 0..ni {
                        let row = (oj + b) * nao + (oi + a);
                        ji_band.push(claim_row(&mut slabs, row));
                    }
                }
            }

            tasks.push(BraPairFill {
                bra: (i, j),
                ij_band,
                ji_band,
            });
        }

        debug_assert!(
            slabs.iter().all(Option::is_none),
            "partition left {} output rows unclaimed",
            slabs.iter().filter(|s| s.is_some()).count()
        );
        tasks
    }

    /// Evaluates every quartet whose bra is `task.bra()` and stores the values
    /// into the rows owned by `task`, overwriting what was there.
    ///
    /// `task` is expected to come from [`Self::partition`] on this builder;
    /// a task with differently shaped bands will index out of bounds and panic.
    pub fn fill(&self, task: &mut BraPairFill<'_>) {
        let (i, j) = task.bra;
        let mut sink = BandSink {
            nao: self.nao,
            off_i: self.offs[i],
            off_j: self.offs[j],
            n_i: self.nfunc[i],
            n_j: self.nfunc[j],
            ij: &mut task.ij_band,
            ji: &mut task.ji_band,
        };
        self.run_bra_pair(i, j, &mut sink);
    }

    /// Allocates the full `nao⁴` tensor and fills it serially, one bra-pair
    /// task after another.
    ///
    /// # Panics
    ///
    /// Panics if `nao⁴` overflows `usize`.
    #[must_use]
    pub fn build(&self) -> Vec<f64> {
        let mut out = vec![0.0; self.output_len()];
        let mut tasks = self.partition(&mut out);
        for task in &mut tasks {
            self.fill(task);
        }
        out
    }

    /// Computes the quartet block for bra `(i, j)` against every canonical
    /// ket-pair and scatters each block into `sink`.
    ///
    /// One scratch buffer is created per call and reused across all kets.
    fn run_bra_pair<S: EriSink>(&self, i: usize, j: usize, sink: &mut S) {
        let s = self.shells;
        let (sa, sb) = (&s[i], &s[j]);
        let mut scratch = QuartetScratch::default();
        for &(k, l) in &self.pairs {
            let (sc, sd) = (&s[k], &s[l]);
            let len = quartet_into_scratch(
                &mut scratch,
                self.engine,
                [sa, sb, sc, sd],
                [&self.eff[i], &self.eff[j], &self.eff[k], &self.eff[l]],
                [
                    self.c2s[i].as_deref(),
                    self.c2s[j].as_deref(),
                    self.c2s[k].as_deref(),
                    self.c2s[l].as_deref(),
                ],
            );
            scatter_4fold(
                sink,
                [i, j, k, l],
                &self.offs,
                [self.nfunc[i], self.nfunc[j], self.nfunc[k], self.nfunc[l]],
                &scratch.block[..len],
            );
        }
    }
}
/// One unit of fill work: a bra shell pair together with exclusive access to
/// the output rows that belong to it.
///
/// Instances are produced by `EriBuilder::partition` and consumed by
/// `EriBuilder::fill`.
#[derive(Debug)]
pub struct BraPairFill<'o> {
/// Shell indices `(i, j)` of the bra pair.
bra: (usize, usize),
/// Rows whose first index lies in shell `i` and second in shell `j`; the row
/// for in-shell indices `(a, b)` is stored at `a * n_j + b`.
ij_band: Vec<&'o mut [f64]>,
/// Transposed rows (first index in shell `j`, second in shell `i`), stored at
/// `b * n_i + a`. Empty when `i == j`.
ji_band: Vec<&'o mut [f64]>,
}
impl BraPairFill<'_> {
    /// Returns the shell indices `(i, j)` of the bra pair this task covers.
    #[must_use]
    pub fn bra(&self) -> (usize, usize) {
        let (i, j) = self.bra;
        (i, j)
    }
}
impl Basis {
    /// Convenience constructor: returns an [`EriBuilder`] borrowing this basis,
    /// exactly as `EriBuilder::new(self)` would.
    #[must_use]
    pub fn eri_builder(&self) -> EriBuilder<'_> {
        let basis: &Self = self;
        EriBuilder::new(basis)
    }
}
/// Removes output row `row` from `slabs` and returns it, leaving `None` behind
/// so that a second claim of the same row is detected.
///
/// # Panics
///
/// Panics if `row` is out of range or has already been claimed.
fn claim_row<'o>(slabs: &mut [Option<&'o mut [f64]>], row: usize) -> &'o mut [f64] {
    match slabs[row].take() {
        Some(slab) => slab,
        None => panic!("output row claimed by two bra-pairs (disjointness violated)"),
    }
}
/// Receiver for individual tensor elements emitted by `scatter_4fold`.
///
/// Abstracting the destination lets the same scatter code write into real
/// output rows or into a test-only recorder.
trait EriSink {
/// Delivers value `v` for the element with global function indices
/// `(mu, nu, la, sg)`.
fn put(&mut self, mu: usize, nu: usize, la: usize, sg: usize, v: f64);
}
/// `EriSink` that stores values into the output rows owned by one bra-pair.
///
/// `'o` is the lifetime of the output buffer; `'a` is the shorter borrow of
/// the row lists for the duration of one fill.
struct BandSink<'a, 'o> {
/// Total number of basis functions; a row's column index is `la * nao + sg`.
nao: usize,
/// Global index of the first function of bra shell `i`.
off_i: usize,
/// Global index of the first function of bra shell `j`.
off_j: usize,
/// Number of functions in bra shell `i`.
n_i: usize,
/// Number of functions in bra shell `j`.
n_j: usize,
/// Rows for `(mu in i, nu in j)`, indexed by `a * n_j + b`.
ij: &'a mut [&'o mut [f64]],
/// Rows for `(mu in j, nu in i)`, indexed by `b * n_i + a`.
ji: &'a mut [&'o mut [f64]],
}
impl EriSink for BandSink<'_, '_> {
    /// Stores `v` at column `la * nao + sg` of the row selected by `(mu, nu)`.
    ///
    /// When `mu` lies in shell `i` and `nu` in shell `j` the direct band is
    /// used; every other index pair is taken to be the transposed orientation
    /// and goes to the `ji` band.
    #[inline]
    fn put(&mut self, mu: usize, nu: usize, la: usize, sg: usize, v: f64) {
        let col = la * self.nao + sg;

        // In-shell indices (a, b) if and only if (mu, nu) is in the direct band.
        let direct = mu
            .checked_sub(self.off_i)
            .filter(|&a| a < self.n_i)
            .zip(nu.checked_sub(self.off_j).filter(|&b| b < self.n_j));

        match direct {
            Some((a, b)) => self.ij[a * self.n_j + b][col] = v,
            None => {
                let b = mu - self.off_j;
                let a = nu - self.off_i;
                self.ji[b * self.n_i + a][col] = v;
            }
        }
    }
}
/// Emits one computed shell-quartet block into `sink` under each of the first
/// four index permutations in `PERMS8`.
///
/// * `sidx` – shell indices of the quartet, in the order the block was computed.
/// * `offs` – per-shell global function offsets.
/// * `n` – function counts of the four shells, in the same order as `sidx`.
/// * `block` – the values, laid out with the fourth index fastest.
///
/// A permutation that yields the same shell tuple as an earlier one (which
/// happens when two shells of the quartet coincide) is skipped, so no element
/// is emitted twice.
fn scatter_4fold<S: EriSink>(
    sink: &mut S,
    sidx: [usize; 4],
    offs: &[usize],
    n: [usize; 4],
    block: &[f64],
) {
    let [n_a, n_b, n_c, n_d] = n;
    for (rank, perm) in PERMS8[..4].iter().enumerate() {
        let tup = [sidx[perm[0]], sidx[perm[1]], sidx[perm[2]], sidx[perm[3]]];

        // Skip this image if any earlier permutation already produced it.
        let duplicate = PERMS8[..rank]
            .iter()
            .any(|prev| [sidx[prev[0]], sidx[prev[1]], sidx[prev[2]], sidx[prev[3]]] == tup);
        if duplicate {
            continue;
        }

        let base = [offs[tup[0]], offs[tup[1]], offs[tup[2]], offs[tup[3]]];
        // Which of (a, b) feeds mu / nu, and which of (c, d) feeds la / sg.
        let (bra_mu, bra_nu) = (perm[0], perm[1]);
        let (ket_la, ket_sg) = (perm[2] - 2, perm[3] - 2);

        let mut cursor = 0usize;
        for a in 0..n_a {
            for b in 0..n_b {
                let bra = [a, b];
                let (mu, nu) = (base[0] + bra[bra_mu], base[1] + bra[bra_nu]);
                for c in 0..n_c {
                    for d in 0..n_d {
                        let ket = [c, d];
                        let (la, sg) = (base[2] + ket[ket_la], base[3] + ket[ket_sg]);
                        sink.put(mu, nu, la, sg, block[cursor]);
                        cursor += 1;
                    }
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test sink that records which bra-pair wrote each tensor element and
    /// fails as soon as any element is written a second time.
    struct CountSink<'a> {
        nao: usize,
        /// `-1` means "not written yet"; otherwise the index of the writing bra-pair.
        owner: &'a mut [i64],
        current: i64,
    }

    impl EriSink for CountSink<'_> {
        fn put(&mut self, mu: usize, nu: usize, la: usize, sg: usize, _v: f64) {
            let n = self.nao;
            let idx = ((mu * n + nu) * n + la) * n + sg;
            assert_eq!(
                self.owner[idx], -1,
                "element {idx} written twice: now by bra-pair {}, previously by {}",
                self.current, self.owner[idx]
            );
            self.owner[idx] = self.current;
        }
    }

    /// Small basis with four shells of differing kinds; two of them share a centre.
    fn mixed_basis() -> Basis {
        let shells = vec![
            Shell::new(0, [0.0, 0.0, 0.0], vec![1.4, 0.5], vec![0.6, 0.5]).unwrap(),
            Shell::new(1, [0.6, -0.3, 0.2], vec![0.9], vec![1.0]).unwrap(),
            Shell::new(1, [0.6, -0.3, 0.2], vec![0.4], vec![1.0]).unwrap(),
            Shell::new_spherical(2, [-0.4, 0.7, -0.1], vec![1.1], vec![1.0]).unwrap(),
        ];
        Basis::new(shells)
    }

    /// Running every bra-pair must write each tensor element exactly once.
    #[test]
    fn write_coverage_exactly_once_and_disjoint() {
        let basis = mixed_basis();
        let builder = EriBuilder::new(&basis);
        let mut owner = vec![-1i64; builder.output_len()];

        let bra_pairs = builder.bra_pairs();
        for (p, &(i, j)) in bra_pairs.iter().enumerate() {
            let mut sink = CountSink {
                nao: builder.nao,
                owner: &mut owner,
                current: p as i64,
            };
            builder.run_bra_pair(i, j, &mut sink);
        }

        let unwritten = owner.iter().copied().filter(|&o| o == -1).count();
        assert_eq!(unwritten, 0, "{unwritten} output elements never written");
    }

    /// The partition must hand out all `nao²` rows, one task per bra-pair.
    #[test]
    fn partition_claims_every_row_once() {
        let basis = mixed_basis();
        let builder = EriBuilder::new(&basis);
        let mut out = vec![0.0; builder.output_len()];
        let tasks = builder.partition(&mut out);
        let nao = builder.nao;

        let mut total_rows = 0usize;
        for task in &tasks {
            total_rows += task.ij_band.len() + task.ji_band.len();
        }

        assert_eq!(total_rows, nao * nao, "bra-pairs do not cover all rows");
        assert_eq!(tasks.len(), builder.bra_pairs().len());
    }

    /// The serial build must agree with `Basis::eri` to tight tolerances.
    #[test]
    fn serial_build_matches_basis_eri_tolerance() {
        let basis = mixed_basis();
        let reference = basis.eri();
        let built = EriBuilder::new(&basis).build();
        assert_eq!(reference.len(), built.len());

        let peak = reference.iter().map(|x| x.abs()).fold(0.0_f64, f64::max);
        // Relative error is only judged on elements that are not negligible.
        let floor = 1e-3 * peak;

        let (worst_sig, worst_abs) = reference.iter().zip(&built).fold(
            (0.0_f64, 0.0_f64),
            |(sig, abs), (&r, &b)| {
                let dv = (r - b).abs();
                let sig = if r.abs() >= floor {
                    sig.max(dv / r.abs())
                } else {
                    sig
                };
                (sig, abs.max(dv))
            },
        );

        assert!(
            worst_sig < 1e-11,
            "worst significant relative diff {worst_sig:e}"
        );
        assert!(
            worst_abs < 1e-11 * peak.max(1.0) + 1e-12,
            "worst absolute diff {worst_abs:e} (peak {peak:e})"
        );
    }
}