use integral_core::os_eri::{self, ShellRef};
use integral_core::{os, rys};
use crate::shell::{Basis, Shell};
use crate::spherical::{shell_transform, transform_block};
/// Converts a contracted Cartesian one-electron block for the shell pair
/// `(sa, sb)` into each shell's function basis.
///
/// `block` is laid out row-major as `(sa.n_cart(), sb.n_cart())`. The
/// per-shell matrices come from `shell_transform`; a `None` there presumably
/// means the shell is left Cartesian — confirm in `crate::spherical`.
pub(crate) fn to_func_1e(block: Vec<f64>, sa: &Shell, sb: &Shell) -> Vec<f64> {
    let left = shell_transform(sa);
    let right = shell_transform(sb);
    let cart_dims = [sa.n_cart(), sb.n_cart()];
    transform_block(block, &cart_dims, &[left.as_deref(), right.as_deref()])
}
pub(crate) fn to_func_eri(
block: Vec<f64>,
sa: &Shell,
sb: &Shell,
sc: &Shell,
sd: &Shell,
) -> Vec<f64> {
let mats = [
shell_transform(sa),
shell_transform(sb),
shell_transform(sc),
shell_transform(sd),
];
to_func_eri_cached(
block,
[sa, sb, sc, sd],
[
mats[0].as_deref(),
mats[1].as_deref(),
mats[2].as_deref(),
mats[3].as_deref(),
],
)
}
/// Converts a Cartesian ERI block to the function basis with caller-supplied
/// per-shell transforms.
///
/// `block` is row-major over the four shells' Cartesian counts, in the order of
/// `s`. `mats[q]` is the transform for `s[q]`, as produced by `shell_transform`.
pub(crate) fn to_func_eri_cached(
    block: Vec<f64>,
    s: [&Shell; 4],
    mats: [Option<&[f64]>; 4],
) -> Vec<f64> {
    let [a, b, c, d] = s;
    let cart_dims = [a.n_cart(), b.n_cart(), c.n_cart(), d.n_cart()];
    transform_block(block, &cart_dims, &mats)
}
/// Which algorithm evaluates contracted two-electron (ERI) shell quartets.
///
/// `Auto`, the default, lets `select_engine` choose for each quartet. The other
/// two variants force a single engine for every quartet.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Engine {
    /// Choose per quartet from total angular momentum and contraction degree.
    #[default]
    Auto,
    /// Always use the shell-level `os_eri` engine (presumably Obara–Saika /
    /// Head-Gordon–Pople recursion).
    OsHgp,
    /// Always use the primitive-level `rys` engine (presumably Rys quadrature).
    Rys,
}
/// Heuristic engine choice for one shell quartet.
///
/// `l_total` is the summed angular momentum of the four shells and
/// `contraction_degree` is the product of their primitive counts; this is how
/// `contract_quartet_cached` calls it.
///
/// The rules are:
/// - Above `l_total = 16`, Rys is always chosen.
/// - Otherwise OS/HGP is chosen once the contraction degree reaches 1 (for
///   `l_total ≤ 5`) or 16 (for `6..=16`).
#[must_use]
pub fn select_engine(l_total: usize, contraction_degree: usize) -> Engine {
    if l_total > 16 {
        return Engine::Rys;
    }
    // Minimum contraction degree at which OS/HGP is preferred.
    let min_degree = if l_total <= 5 { 1 } else { 16 };
    if contraction_degree >= min_degree {
        Engine::OsHgp
    } else {
        Engine::Rys
    }
}
/// Copies a row-major block with `nb` columns into the row-major `n`-column
/// matrix `mat`. The block's top-left element goes to `(row_off, col_off)`.
///
/// The block's row count is `block.len() / nb`; a trailing partial row is ignored.
///
/// # Panics
/// If `nb == 0`, or if the block does not fit inside `mat`.
pub(crate) fn place_block(
    mat: &mut [f64],
    n: usize,
    row_off: usize,
    col_off: usize,
    block: &[f64],
    nb: usize,
) {
    for (r, row) in block.chunks_exact(nb).enumerate() {
        let start = (row_off + r) * n + col_off;
        mat[start..start + nb].copy_from_slice(row);
    }
}
/// Builds a contracted Cartesian one-electron block for the shell pair `(sa, sb)`.
///
/// `prim_op` is called once per primitive pair `(i, j)`, in `i`-major order. It
/// receives the two primitives, the product of their contraction coefficients,
/// and the `sa.n_cart() × sb.n_cart()` accumulator it should add into.
fn contract_pair<F>(sa: &Shell, sb: &Shell, mut prim_op: F) -> Vec<f64>
where
    F: FnMut(os::Prim, os::Prim, f64, &mut [f64]),
{
    let mut acc = vec![0.0; sa.n_cart() * sb.n_cart()];
    for i in 0..sa.n_prim() {
        let ci = sa.primitive_coeff(i);
        for j in 0..sb.n_prim() {
            let weight = ci * sb.primitive_coeff(j);
            prim_op(sa.prim(i), sb.prim(j), weight, &mut acc);
        }
    }
    acc
}
impl Basis {
#[must_use]
pub fn overlap(&self) -> Vec<f64> {
let n = self.nao();
let offs = self.offsets();
let mut mat = vec![0.0; n * n];
for (si, sa) in self.shells().iter().enumerate() {
for (sj, sb) in self.shells().iter().enumerate() {
let block = to_func_1e(contract_pair(sa, sb, os::overlap_into), sa, sb);
place_block(&mut mat, n, offs[si], offs[sj], &block, sb.n_func());
}
}
mat
}
#[must_use]
pub fn kinetic(&self) -> Vec<f64> {
let n = self.nao();
let offs = self.offsets();
let mut mat = vec![0.0; n * n];
for (si, sa) in self.shells().iter().enumerate() {
for (sj, sb) in self.shells().iter().enumerate() {
let block = to_func_1e(contract_pair(sa, sb, os::kinetic_into), sa, sb);
place_block(&mut mat, n, offs[si], offs[sj], &block, sb.n_func());
}
}
mat
}
#[must_use]
pub fn nuclear(&self, charges: &[([f64; 3], f64)]) -> Vec<f64> {
let n = self.nao();
let offs = self.offsets();
let mut mat = vec![0.0; n * n];
for (si, sa) in self.shells().iter().enumerate() {
for (sj, sb) in self.shells().iter().enumerate() {
let block = to_func_1e(
contract_pair(sa, sb, |a, b, scale, out| {
os::nuclear_into(a, b, charges, scale, out);
}),
sa,
sb,
);
place_block(&mut mat, n, offs[si], offs[sj], &block, sb.n_func());
}
}
mat
}
#[must_use]
pub fn dipole(&self, o: [f64; 3]) -> [Vec<f64>; 3] {
let n = self.nao();
let offs = self.offsets();
let mut dx = vec![0.0; n * n];
let mut dy = vec![0.0; n * n];
let mut dz = vec![0.0; n * n];
for (si, sa) in self.shells().iter().enumerate() {
for (sj, sb) in self.shells().iter().enumerate() {
let (na, nb) = (sa.n_cart(), sb.n_cart());
let (mut bx, mut by, mut bz) =
(vec![0.0; na * nb], vec![0.0; na * nb], vec![0.0; na * nb]);
for pi in 0..sa.n_prim() {
for pj in 0..sb.n_prim() {
let scale = sa.primitive_coeff(pi) * sb.primitive_coeff(pj);
os::dipole_into(
sa.prim(pi),
sb.prim(pj),
o,
scale,
&mut bx,
&mut by,
&mut bz,
);
}
}
let bx = to_func_1e(bx, sa, sb);
let by = to_func_1e(by, sa, sb);
let bz = to_func_1e(bz, sa, sb);
let nbf = sb.n_func();
place_block(&mut dx, n, offs[si], offs[sj], &bx, nbf);
place_block(&mut dy, n, offs[si], offs[sj], &by, nbf);
place_block(&mut dz, n, offs[si], offs[sj], &bz, nbf);
}
}
[dx, dy, dz]
}
#[must_use]
pub fn eri_block(&self, i: usize, j: usize, k: usize, l: usize) -> Vec<f64> {
self.eri_block_with(Engine::Auto, i, j, k, l)
}
#[must_use]
pub fn eri_block_with(
&self,
engine: Engine,
i: usize,
j: usize,
k: usize,
l: usize,
) -> Vec<f64> {
let s = self.shells();
let (sa, sb, sc, sd) = (&s[i], &s[j], &s[k], &s[l]);
let block = contract_quartet(engine, sa, sb, sc, sd);
to_func_eri(block, sa, sb, sc, sd)
}
#[must_use]
pub fn eri(&self) -> Vec<f64> {
self.eri_with(Engine::Auto)
}
#[must_use]
pub fn eri_with(&self, engine: Engine) -> Vec<f64> {
let nao = self.nao();
let offs = self.offsets();
let shells = self.shells();
let eff: Vec<Vec<f64>> = shells.iter().map(effective_coeffs).collect();
let c2s: Vec<Option<Vec<f64>>> = shells.iter().map(shell_transform).collect();
let mut out = vec![0.0; nao * nao * nao * nao];
for (si, sa) in shells.iter().enumerate() {
for (sj, sb) in shells.iter().enumerate().take(si + 1) {
for (sk, sc) in shells.iter().enumerate().take(si + 1) {
let l_top = if sk == si { sj } else { sk };
for (sl, sd) in shells.iter().enumerate().take(l_top + 1) {
let block = to_func_eri_cached(
contract_quartet_cached(
engine, sa, &eff[si], sb, &eff[sj], sc, &eff[sk], sd, &eff[sl],
),
[sa, sb, sc, sd],
[
c2s[si].as_deref(),
c2s[sj].as_deref(),
c2s[sk].as_deref(),
c2s[sl].as_deref(),
],
);
scatter_eri_block_s8(
&mut out,
nao,
[si, sj, sk, sl],
&offs,
[sa.n_func(), sb.n_func(), sc.n_func(), sd.n_func()],
&block,
);
}
}
}
}
out
}
#[must_use]
pub fn schwarz_bounds(&self) -> Vec<f64> {
self.schwarz_bounds_with(Engine::Auto)
}
#[must_use]
pub fn schwarz_bounds_with(&self, engine: Engine) -> Vec<f64> {
let shells = self.shells();
let nsh = shells.len();
let mut q = vec![0.0; nsh * nsh];
for i in 0..nsh {
for j in 0..nsh {
let (ni, nj) = (shells[i].n_func(), shells[j].n_func());
let block = self.eri_block_with(engine, i, j, i, j);
let mut mx = 0.0_f64;
for mu in 0..ni {
for nu in 0..nj {
let idx = ((mu * nj + nu) * ni + mu) * nj + nu;
mx = mx.max(block[idx].abs());
}
}
q[i * nsh + j] = mx.sqrt();
}
}
q
}
#[must_use]
pub fn eri_screened(&self, tau: f64) -> (Vec<f64>, ScreeningStats) {
self.eri_screened_with(Engine::Auto, tau)
}
#[must_use]
pub fn eri_screened_with(&self, engine: Engine, tau: f64) -> (Vec<f64>, ScreeningStats) {
let nao = self.nao();
let offs = self.offsets();
let shells = self.shells();
let nsh = shells.len();
let q = self.schwarz_bounds_with(engine);
let eff: Vec<Vec<f64>> = shells.iter().map(effective_coeffs).collect();
let c2s: Vec<Option<Vec<f64>>> = shells.iter().map(shell_transform).collect();
let mut out = vec![0.0; nao * nao * nao * nao];
let mut total = 0_usize;
let mut skipped = 0_usize;
for si in 0..nsh {
for sj in 0..=si {
let qij = q[si * nsh + sj];
for sk in 0..=si {
let l_top = if sk == si { sj } else { sk };
for sl in 0..=l_top {
total += 1;
if qij * q[sk * nsh + sl] < tau {
skipped += 1;
continue;
}
let block = to_func_eri_cached(
contract_quartet_cached(
engine,
&shells[si],
&eff[si],
&shells[sj],
&eff[sj],
&shells[sk],
&eff[sk],
&shells[sl],
&eff[sl],
),
[&shells[si], &shells[sj], &shells[sk], &shells[sl]],
[
c2s[si].as_deref(),
c2s[sj].as_deref(),
c2s[sk].as_deref(),
c2s[sl].as_deref(),
],
);
scatter_eri_block_s8(
&mut out,
nao,
[si, sj, sk, sl],
&offs,
[
shells[si].n_func(),
shells[sj].n_func(),
shells[sk].n_func(),
shells[sl].n_func(),
],
&block,
);
}
}
}
}
(
out,
ScreeningStats {
shell_quartets_total: total,
shell_quartets_skipped: skipped,
tau,
},
)
}
}
/// Bookkeeping returned alongside a Schwarz-screened ERI tensor.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct ScreeningStats {
    /// Number of canonical (8-fold unique) shell quartets visited.
    pub shell_quartets_total: usize,
    /// Number of those quartets whose bound product fell below `tau` and were not computed.
    pub shell_quartets_skipped: usize,
    /// Screening threshold that was applied.
    pub tau: f64,
}

impl ScreeningStats {
    /// Ratio `shell_quartets_skipped / shell_quartets_total`, or `0.0` when no
    /// quartets were visited.
    #[must_use]
    pub fn skipped_fraction(&self) -> f64 {
        match self.shell_quartets_total {
            0 => 0.0,
            total => self.shell_quartets_skipped as f64 / total as f64,
        }
    }
}
fn contract_quartet(engine: Engine, sa: &Shell, sb: &Shell, sc: &Shell, sd: &Shell) -> Vec<f64> {
contract_quartet_cached(
engine,
sa,
&effective_coeffs(sa),
sb,
&effective_coeffs(sb),
sc,
&effective_coeffs(sc),
sd,
&effective_coeffs(sd),
)
}
/// Contracts one Cartesian ERI shell quartet `(sa sb|sc sd)` with the requested engine.
///
/// `ea`..`ed` are per-primitive contraction weights for each shell; callers in
/// this file pass `effective_coeffs(shell)`. Passing them in lets bulk loops
/// compute them once per shell rather than once per quartet.
///
/// [`Engine::Auto`] is resolved through [`select_engine`]. Its inputs are the
/// quartet's summed angular momentum and its product of primitive counts.
#[allow(clippy::too_many_arguments)] // four (shell, weights) pairs plus the engine
pub(crate) fn contract_quartet_cached(
    engine: Engine,
    sa: &Shell,
    ea: &[f64],
    sb: &Shell,
    eb: &[f64],
    sc: &Shell,
    ec: &[f64],
    sd: &Shell,
    ed: &[f64],
) -> Vec<f64> {
    let resolved = match engine {
        Engine::Auto => select_engine(
            sa.l() + sb.l() + sc.l() + sd.l(),
            sa.n_prim() * sb.n_prim() * sc.n_prim() * sd.n_prim(),
        ),
        forced @ (Engine::OsHgp | Engine::Rys) => forced,
    };
    // Exhaustive on purpose: a new `Engine` variant must be routed explicitly
    // here rather than silently falling into the Rys path.
    match resolved {
        Engine::OsHgp => contract_quartet_oshgp(sa, ea, sb, eb, sc, ec, sd, ed),
        // `select_engine` never yields `Auto`; it is listed only to keep the
        // match exhaustive and keeps the previous Rys fallback.
        Engine::Rys | Engine::Auto => contract_quartet_rys(sa, ea, sb, eb, sc, ec, sd, ed),
    }
}
/// Per-primitive contraction weights of `s`. These are `s.primitive_coeff(p)`
/// for every primitive index `p`, in order.
pub(crate) fn effective_coeffs(s: &Shell) -> Vec<f64> {
    let n_prim = s.n_prim();
    let mut weights = Vec::with_capacity(n_prim);
    weights.extend((0..n_prim).map(|p| s.primitive_coeff(p)));
    weights
}
/// Rys-engine contraction of one Cartesian shell quartet.
///
/// `rys::coulomb_into` is called once per primitive quartet, in `a`-major
/// order. It receives the product of the four primitive weights and the
/// row-major Cartesian accumulator.
#[allow(clippy::too_many_arguments)]
fn contract_quartet_rys(
    sa: &Shell,
    ea: &[f64],
    sb: &Shell,
    eb: &[f64],
    sc: &Shell,
    ec: &[f64],
    sd: &Shell,
    ed: &[f64],
) -> Vec<f64> {
    let len = sa.n_cart() * sb.n_cart() * sc.n_cart() * sd.n_cart();
    let mut acc = vec![0.0; len];
    for (ia, &wa) in ea.iter().enumerate() {
        for (ib, &wb) in eb.iter().enumerate() {
            // Partial products keep the left-to-right ((wa*wb)*wc)*wd association.
            let wab = wa * wb;
            for (ic, &wc) in ec.iter().enumerate() {
                let wabc = wab * wc;
                for (id, &wd) in ed.iter().enumerate() {
                    rys::coulomb_into(
                        sa.prim(ia),
                        sb.prim(ib),
                        sc.prim(ic),
                        sd.prim(id),
                        wabc * wd,
                        &mut acc,
                    );
                }
            }
        }
    }
    acc
}
/// OS/HGP-engine contraction of one Cartesian shell quartet.
///
/// Each shell is wrapped in an `os_eri::ShellRef` that carries its own weights.
/// The whole contracted quartet is then evaluated by `os_eri::coulomb_shell_into`
/// into a row-major Cartesian block.
#[allow(clippy::too_many_arguments)]
fn contract_quartet_oshgp(
    sa: &Shell,
    ea: &[f64],
    sb: &Shell,
    eb: &[f64],
    sc: &Shell,
    ec: &[f64],
    sd: &Shell,
    ed: &[f64],
) -> Vec<f64> {
    let bra_a = ShellRef {
        center: sa.center(),
        l: sa.l(),
        exps: sa.exponents(),
        coeffs: ea,
    };
    let bra_b = ShellRef {
        center: sb.center(),
        l: sb.l(),
        exps: sb.exponents(),
        coeffs: eb,
    };
    let ket_c = ShellRef {
        center: sc.center(),
        l: sc.l(),
        exps: sc.exponents(),
        coeffs: ec,
    };
    let ket_d = ShellRef {
        center: sd.center(),
        l: sd.l(),
        exps: sd.exponents(),
        coeffs: ed,
    };
    let len = sa.n_cart() * sb.n_cart() * sc.n_cart() * sd.n_cart();
    let mut out = vec![0.0; len];
    os_eri::coulomb_shell_into(bra_a, bra_b, ket_c, ket_d, &mut out);
    out
}
/// The eight index permutations under which a two-electron integral `(ab|cd)`
/// over real functions is invariant.
///
/// Each entry maps an output position to the source axis it reads from.
/// Together they cover the swap within the bra, the swap within the ket, and
/// the exchange of bra with ket.
pub(crate) const PERMS8: [[usize; 4]; 8] = [
    // (ab|cd) with bra/ket-internal swaps
    [0, 1, 2, 3],
    [1, 0, 2, 3],
    [0, 1, 3, 2],
    [1, 0, 3, 2],
    // (cd|ab) with bra/ket-internal swaps
    [2, 3, 0, 1],
    [2, 3, 1, 0],
    [3, 2, 0, 1],
    [3, 2, 1, 0],
];
/// Lists every shell pair `(i, j)` with `j ≤ i < nsh`, ordered by `i` and then `j`.
pub(crate) fn canonical_shell_pairs(nsh: usize) -> Vec<(usize, usize)> {
    let mut pairs = Vec::with_capacity(nsh * (nsh + 1) / 2);
    pairs.extend((0..nsh).flat_map(|i| (0..=i).map(move |j| (i, j))));
    pairs
}
fn scatter_eri_block_s8(
out: &mut [f64],
nao: usize,
sidx: [usize; 4],
offs: &[usize],
n: [usize; 4],
block: &[f64],
) {
let mut seen: [[usize; 4]; 8] = [[usize::MAX; 4]; 8];
let mut n_seen = 0;
for perm in &PERMS8 {
let tup = [sidx[perm[0]], sidx[perm[1]], sidx[perm[2]], sidx[perm[3]]];
if seen[..n_seen].contains(&tup) {
continue;
}
seen[n_seen] = tup;
n_seen += 1;
let mut stride = [0usize; 4];
let mut base = 0usize;
for (q, &src_axis) in perm.iter().enumerate() {
let s = nao.pow(3 - q as u32);
stride[src_axis] = s;
base += offs[sidx[src_axis]] * s;
}
let mut src = 0usize;
for a in 0..n[0] {
let oa = base + a * stride[0];
for b in 0..n[1] {
let ob = oa + b * stride[1];
for c in 0..n[2] {
let oc = ob + c * stride[2];
for d in 0..n[3] {
out[oc + d * stride[3]] = block[src];
src += 1;
}
}
}
}
}
}