use crate::integrals::{
canonical_shell_pairs, effective_coeffs, quartet_into_scratch, Engine, QuartetScratch,
};
use crate::shell::{Basis, Shell};
use crate::spherical::shell_transform;
/// Builds the dummy shell `Shell::new(0, center, [0.0], [1.0])` located at `center`.
///
/// It stands in for the missing index when two- and three-center integrals are
/// evaluated through the four-center quartet routine. The data arrays presumably
/// mean "exponent 0, coefficient 1", i.e. the constant function 1. Confirm this
/// against `Shell::new`.
fn unit_s(center: [f64; 3]) -> Shell {
    let built = Shell::new(0, center, vec![0.0], vec![1.0]);
    built.expect("unit s dummy shell is always valid")
}
impl Basis {
    /// Two-center ERI matrix of this basis, evaluated with [`Engine::Auto`].
    ///
    /// See [`Basis::eri_2c_with`] for the layout of the result.
    #[must_use]
    pub fn eri_2c(&self) -> Vec<f64> {
        Self::eri_2c_with(self, Engine::Auto)
    }

    /// Dense two-center ERI matrix `(P|Q)` over every function of this basis.
    ///
    /// The result is row-major with `nao * nao` entries (`nao = self.nao()`).
    /// Each shell pair with `Q <= P` is evaluated once as the quartet
    /// `(P 1_P | Q 1_Q)`. Here `1_X` is a unit s dummy shell (`unit_s`) centred on
    /// shell `X`. The resulting block is written to both the `(P, Q)` and the
    /// `(Q, P)` positions, so the matrix is symmetric by construction.
    #[must_use]
    pub fn eri_2c_with(&self, engine: Engine) -> Vec<f64> {
        let dim = self.nao();
        let starts = self.offsets();
        let shells = self.shells();
        let contractions: Vec<Vec<f64>> = shells.iter().map(effective_coeffs).collect();
        let to_spherical: Vec<Option<Vec<f64>>> = shells.iter().map(shell_transform).collect();
        let partners: Vec<Shell> = shells.iter().map(|shell| unit_s(shell.center())).collect();
        let one = [1.0];
        let mut metric = vec![0.0; dim * dim];
        let mut scratch = QuartetScratch::default();
        for (p, bra) in shells.iter().enumerate() {
            let bra_len = bra.n_func();
            let bra_start = starts[p];
            for (q, ket) in shells[..=p].iter().enumerate() {
                let written = quartet_into_scratch(
                    &mut scratch,
                    engine,
                    [bra, &partners[p], ket, &partners[q]],
                    [&contractions[p], &one, &contractions[q], &one],
                    [to_spherical[p].as_deref(), None, to_spherical[q].as_deref(), None],
                );
                let ket_len = ket.n_func();
                debug_assert_eq!(written, bra_len * ket_len);
                let ket_start = starts[q];
                let block = &scratch.block;
                // The order of writes matters on diagonal blocks (p == q): the
                // (a, b) sweep below decides which element survives for each
                // mirrored position.
                for a in 0..bra_len {
                    let bra_row = (bra_start + a) * dim;
                    for b in 0..ket_len {
                        let value = block[a * ket_len + b];
                        metric[bra_row + ket_start + b] = value;
                        metric[(ket_start + b) * dim + bra_start + a] = value;
                    }
                }
            }
        }
        metric
    }

    /// Three-center block `(ij|P)` evaluated with [`Engine::Auto`].
    ///
    /// See [`Basis::eri_3c_block_with`].
    #[must_use]
    pub fn eri_3c_block(&self, aux: &Basis, ish: usize, jsh: usize, psh: usize) -> Vec<f64> {
        Self::eri_3c_block_with(self, Engine::Auto, aux, ish, jsh, psh)
    }

    /// Three-center block for main shells `ish`, `jsh` (of `self`) and auxiliary
    /// shell `psh` (of `aux`).
    ///
    /// It is computed as the quartet `(i j | P 1_P)`, where `1_P` is a unit s
    /// dummy shell centred on `P`. The returned vector is the first `len` entries
    /// of the quartet scratch block, with `len` as reported by
    /// `quartet_into_scratch`.
    ///
    /// # Panics
    ///
    /// Panics if any shell index is out of range for its basis.
    #[must_use]
    pub fn eri_3c_block_with(
        &self,
        engine: Engine,
        aux: &Basis,
        ish: usize,
        jsh: usize,
        psh: usize,
    ) -> Vec<f64> {
        let main_shells = self.shells();
        let first = &main_shells[ish];
        let second = &main_shells[jsh];
        let fitted = &aux.shells()[psh];
        let partner = unit_s(fitted.center());
        let transforms = [
            shell_transform(first),
            shell_transform(second),
            shell_transform(fitted),
        ];
        let coeffs = [
            effective_coeffs(first),
            effective_coeffs(second),
            effective_coeffs(fitted),
        ];
        let mut scratch = QuartetScratch::default();
        let written = quartet_into_scratch(
            &mut scratch,
            engine,
            [first, second, fitted, &partner],
            [&coeffs[0], &coeffs[1], &coeffs[2], &[1.0]],
            [
                transforms[0].as_deref(),
                transforms[1].as_deref(),
                transforms[2].as_deref(),
                None,
            ],
        );
        scratch.block[..written].to_vec()
    }

    /// Per-shell Schwarz factors evaluated with [`Engine::Auto`].
    ///
    /// See [`Basis::schwarz_aux_bounds_with`].
    #[must_use]
    pub fn schwarz_aux_bounds(&self) -> Vec<f64> {
        Self::schwarz_aux_bounds_with(self, Engine::Auto)
    }

    /// One factor per shell of this basis: `sqrt(max_mu |(mu|mu)|)` over the
    /// functions `mu` of that shell.
    ///
    /// The diagonal comes from the quartet `(P 1_P | P 1_P)`, where `1_P` is a
    /// unit s dummy shell centred on `P`.
    #[must_use]
    pub fn schwarz_aux_bounds_with(&self, engine: Engine) -> Vec<f64> {
        let shells = self.shells();
        let contractions: Vec<Vec<f64>> = shells.iter().map(effective_coeffs).collect();
        let to_spherical: Vec<Option<Vec<f64>>> = shells.iter().map(shell_transform).collect();
        let one = [1.0];
        let mut scratch = QuartetScratch::default();
        shells
            .iter()
            .enumerate()
            .map(|(p, shell)| {
                let partner = unit_s(shell.center());
                let written = quartet_into_scratch(
                    &mut scratch,
                    engine,
                    [shell, &partner, shell, &partner],
                    [&contractions[p], &one, &contractions[p], &one],
                    [to_spherical[p].as_deref(), None, to_spherical[p].as_deref(), None],
                );
                let width = shell.n_func();
                debug_assert_eq!(written, width * width);
                // Diagonal element mu sits at mu * width + mu.
                (0..width)
                    .map(|mu| scratch.block[mu * (width + 1)].abs())
                    .fold(0.0_f64, f64::max)
                    .sqrt()
            })
            .collect()
    }

    /// Builder for the full `(ij|P)` tensor, with `self` as the main basis and
    /// `aux` as the auxiliary basis.
    #[must_use]
    pub fn eri_3c_builder<'a>(&'a self, aux: &'a Basis) -> Eri3cBuilder<'a> {
        Eri3cBuilder::new(self, aux)
    }
}
/// Precomputed per-shell state for assembling the three-index tensor `(ij|P)`.
///
/// `i` and `j` run over the shells of a main basis and `P` over the shells of an
/// auxiliary basis. Work is split per bra shell pair: `partition` hands out
/// disjoint output rows, and `fill` / `fill_filtered` compute them.
#[derive(Debug)]
pub struct Eri3cBuilder<'b> {
    // Shells of the main (bra) basis.
    main: &'b [Shell],
    // Shells of the auxiliary basis.
    aux: &'b [Shell],
    // Engine passed to every `quartet_into_scratch` call.
    engine: Engine,
    // First function index of each main shell (from `Basis::offsets`).
    offs: Vec<usize>,
    // Number of functions in each main shell.
    nfunc: Vec<usize>,
    // Total number of main-basis functions.
    nao: usize,
    // First function index of each auxiliary shell.
    aux_offs: Vec<usize>,
    // Total number of auxiliary functions; this is the length of every output row.
    naux: usize,
    // `effective_coeffs` of each main shell.
    eff: Vec<Vec<f64>>,
    // `effective_coeffs` of each auxiliary shell.
    aux_eff: Vec<Vec<f64>>,
    // `shell_transform` of each main shell (NOTE(review): `None` presumably means
    // no transform is applied; confirm).
    c2s: Vec<Option<Vec<f64>>>,
    // `shell_transform` of each auxiliary shell.
    aux_c2s: Vec<Option<Vec<f64>>>,
    // One unit s dummy shell centred on each auxiliary shell; used as the fourth
    // quartet index.
    dummies: Vec<Shell>,
    // Bra shell pairs from `canonical_shell_pairs`, one per partition task.
    pairs: Vec<(usize, usize)>,
}
impl<'b> Eri3cBuilder<'b> {
    /// Creates a builder for `(ij|P)` over `main` (bra) and `aux` (auxiliary)
    /// bases. Quartets are evaluated with [`Engine::Auto`].
    #[must_use]
    pub fn new(main: &'b Basis, aux: &'b Basis) -> Self {
        Self::with_engine(main, aux, Engine::Auto)
    }

    /// Creates a builder that evaluates every quartet with `engine`.
    ///
    /// All per-shell data is computed once here and reused by every
    /// [`fill`](Self::fill):
    /// - offsets and function counts,
    /// - `effective_coeffs` and `shell_transform` outputs,
    /// - one unit s dummy per auxiliary shell,
    /// - the bra pair list from `canonical_shell_pairs`.
    #[must_use]
    pub fn with_engine(main: &'b Basis, aux: &'b Basis, engine: Engine) -> Self {
        let shells = main.shells();
        let aux_shells = aux.shells();
        Eri3cBuilder {
            main: shells,
            aux: aux_shells,
            engine,
            offs: main.offsets(),
            nfunc: shells.iter().map(Shell::n_func).collect(),
            nao: main.nao(),
            aux_offs: aux.offsets(),
            naux: aux.nao(),
            eff: shells.iter().map(effective_coeffs).collect(),
            aux_eff: aux_shells.iter().map(effective_coeffs).collect(),
            c2s: shells.iter().map(shell_transform).collect(),
            aux_c2s: aux_shells.iter().map(shell_transform).collect(),
            dummies: aux_shells.iter().map(|s| unit_s(s.center())).collect(),
            pairs: canonical_shell_pairs(shells.len()),
        }
    }

    /// Bra shell pairs `(i, j)`, in the order [`partition`](Self::partition)
    /// emits its tasks.
    #[must_use]
    pub fn bra_pairs(&self) -> &[(usize, usize)] {
        &self.pairs
    }

    /// Required length of the output buffer: `nao * nao * naux`.
    #[must_use]
    pub fn output_len(&self) -> usize {
        self.nao * self.nao * self.naux
    }

    /// Splits `out` into one [`Bra3cFill`] task per bra shell pair.
    ///
    /// `out` is laid out as `out[(mu * nao + nu) * naux + P]`. That is one
    /// contiguous row of `naux` auxiliary entries per main-function pair
    /// `(mu, nu)`. Each task takes the rows of its `(i, j)` block. When `i != j`
    /// it also takes the rows of the transposed `(j, i)` block. The tasks therefore
    /// borrow disjoint parts of `out` and can be filled independently. An empty
    /// auxiliary basis (`naux == 0`) is allowed and yields tasks whose rows are
    /// all empty.
    ///
    /// # Panics
    ///
    /// - If `out.len() != self.output_len()`.
    /// - If two bra pairs claim the same row.
    /// - In debug builds, if any row is left unclaimed.
    #[must_use]
    pub fn partition<'o>(&self, out: &'o mut [f64]) -> Vec<Bra3cFill<'o>> {
        let nao = self.nao;
        let naux = self.naux;
        assert_eq!(
            out.len(),
            nao * nao * naux,
            "3c output buffer must be nao²·naux = {} elements",
            nao * nao * naux
        );
        // One slab per (mu, nu) row. `chunks_exact_mut` panics on a chunk size of 0.
        // With no auxiliary functions (`out` is then empty) we hand out `nao²` empty
        // rows instead, so every bra pair can still claim its rows.
        let mut slabs: Vec<Option<&'o mut [f64]>> = if naux == 0 {
            (0..nao * nao).map(|_| Some(Default::default())).collect()
        } else {
            out.chunks_exact_mut(naux).map(Some).collect()
        };
        debug_assert_eq!(slabs.len(), nao * nao);
        let mut tasks = Vec::with_capacity(self.pairs.len());
        for &(i, j) in &self.pairs {
            let (ni, nj) = (self.nfunc[i], self.nfunc[j]);
            let (oi, oj) = (self.offs[i], self.offs[j]);
            // Rows (oi + a, oj + b), ordered a * nj + b.
            let mut ij_band = Vec::with_capacity(ni * nj);
            for a in 0..ni {
                for b in 0..nj {
                    ij_band.push(claim_row(&mut slabs, (oi + a) * nao + (oj + b)));
                }
            }
            // Transposed rows (oj + b, oi + a), ordered b * ni + a. A diagonal pair
            // already owns its whole block through `ij_band`.
            let mut ji_band = Vec::new();
            if i != j {
                ji_band.reserve(nj * ni);
                for b in 0..nj {
                    for a in 0..ni {
                        ji_band.push(claim_row(&mut slabs, (oj + b) * nao + (oi + a)));
                    }
                }
            }
            tasks.push(Bra3cFill {
                bra: (i, j),
                ij_band,
                ji_band,
            });
        }
        debug_assert!(
            slabs.iter().all(Option::is_none),
            "partition left {} output rows unclaimed",
            slabs.iter().filter(|s| s.is_some()).count()
        );
        tasks
    }

    /// Computes every auxiliary shell's contribution to `task`.
    ///
    /// Equivalent to [`fill_filtered`](Self::fill_filtered) with a filter that
    /// keeps all shells.
    pub fn fill(&self, task: &mut Bra3cFill<'_>) {
        self.fill_filtered(task, |_| true);
    }

    /// Computes `(ij|P)` for `task`'s bra pair and every auxiliary shell `P` with
    /// `keep(P) == true`. Results are written into the rows the task owns.
    ///
    /// When `i != j`, the same values are also copied into the transposed `(ji|P)`
    /// rows. Columns belonging to auxiliary shells rejected by `keep` are left
    /// untouched.
    ///
    /// # Panics
    ///
    /// Indexing or `copy_from_slice` panics if `task`'s rows do not match this
    /// builder's shapes, e.g. when `task` came from a builder over other bases.
    pub fn fill_filtered(&self, task: &mut Bra3cFill<'_>, keep: impl Fn(usize) -> bool) {
        let (i, j) = task.bra;
        let (si, sj) = (&self.main[i], &self.main[j]);
        let (ni, nj) = (self.nfunc[i], self.nfunc[j]);
        let mut scratch = QuartetScratch::default();
        for (p, sp) in self.aux.iter().enumerate() {
            if !keep(p) {
                continue;
            }
            let len = quartet_into_scratch(
                &mut scratch,
                self.engine,
                [si, sj, sp, &self.dummies[p]],
                [&self.eff[i], &self.eff[j], &self.aux_eff[p], &[1.0]],
                [
                    self.c2s[i].as_deref(),
                    self.c2s[j].as_deref(),
                    self.aux_c2s[p].as_deref(),
                    None,
                ],
            );
            let np = sp.n_func();
            debug_assert_eq!(len, ni * nj * np);
            let op = self.aux_offs[p];
            let block = &scratch.block[..len];
            for a in 0..ni {
                for b in 0..nj {
                    // The quartet block holds np contiguous auxiliary values per (a, b).
                    let row = &block[(a * nj + b) * np..(a * nj + b + 1) * np];
                    task.ij_band[a * nj + b][op..op + np].copy_from_slice(row);
                    if i != j {
                        task.ji_band[b * ni + a][op..op + np].copy_from_slice(row);
                    }
                }
            }
        }
    }

    /// Allocates the full `nao * nao * naux` tensor (see
    /// [`partition`](Self::partition) for the layout) and fills every task in
    /// sequence.
    #[must_use]
    pub fn build(&self) -> Vec<f64> {
        let mut out = vec![0.0; self.output_len()];
        let mut tasks = self.partition(&mut out);
        for task in &mut tasks {
            self.fill(task);
        }
        out
    }
}
/// One unit of work from `Eri3cBuilder::partition`: the output rows owned by a
/// single bra shell pair.
///
/// Each row holds one entry per auxiliary function.
#[derive(Debug)]
pub struct Bra3cFill<'o> {
    // Bra shell pair `(i, j)` this task fills.
    bra: (usize, usize),
    // Rows (mu, nu) with mu in shell i and nu in shell j, ordered a * nj + b.
    ij_band: Vec<&'o mut [f64]>,
    // Transposed rows (nu, mu), ordered b * ni + a; empty when i == j.
    ji_band: Vec<&'o mut [f64]>,
}
impl<'o> Bra3cFill<'o> {
    /// The bra shell pair `(i, j)` whose output rows this task owns.
    #[must_use]
    pub fn bra(&self) -> (usize, usize) {
        let (i, j) = self.bra;
        (i, j)
    }
}
/// Moves output row `row` out of `slabs`, leaving `None` in its place.
///
/// Panics if the row was already taken, i.e. two bra pairs overlap, or if `row`
/// is out of range.
fn claim_row<'o>(slabs: &mut [Option<&'o mut [f64]>], row: usize) -> &'o mut [f64] {
    match slabs[row].take() {
        Some(slab) => slab,
        None => panic!("output row claimed by two bra-pairs (disjointness violated)"),
    }
}