use integral_math::am::{cart_components, n_cart};
use integral_math::rys::{rys_roots_weights, MAX_RYS_ROOTS};
use crate::os::{Prim, Vec3, MAX_L};
/// Extent of a table dimension indexed by a single angular momentum `0..=MAX_L`.
const L1: usize = MAX_L + 1;
/// Extent of a table dimension indexed by the sum of two angular momenta
/// `0..=2 * MAX_L` (used for the intermediate recurrence tables in `build_axis`).
const LT1: usize = 2 * MAX_L + 1;
/// Per-axis table indexed as `[i][j][k][l]`, one index per center, each in `0..L1`.
type Axis4 = [[[[f64; L1]; L1]; L1]; L1];
pub fn coulomb_into(a: Prim, b: Prim, c: Prim, d: Prim, scale: f64, out: &mut [f64]) {
let (la, lb, lc, ld) = (a.l, b.l, c.l, d.l);
debug_assert!(
la <= MAX_L && lb <= MAX_L && lc <= MAX_L && ld <= MAX_L,
"angular momentum exceeds MAX_L"
);
let (na, nb, nc, nd) = (n_cart(la), n_cart(lb), n_cart(lc), n_cart(ld));
debug_assert!(out.len() >= na * nb * nc * nd, "ERI output block too short");
let p = a.exp + b.exp;
let q = c.exp + d.exp;
let p_ctr = combine(a.exp, a.center, b.exp, b.center, p);
let q_ctr = combine(c.exp, c.center, d.exp, d.center, q);
let kab = (-(a.exp * b.exp / p) * dist2(a.center, b.center)).exp();
let kcd = (-(c.exp * d.exp / q) * dist2(c.center, d.center)).exp();
let pq = p + q;
let rho = p * q / pq;
let pq_vec = sub(p_ctr, q_ctr);
let t = rho * norm2(pq_vec);
use std::f64::consts::PI;
let pref = 2.0 * PI * PI * PI.sqrt() / (p * q * pq.sqrt()) * kab * kcd;
let l_total = la + lb + lc + ld;
let nroots = l_total / 2 + 1;
let mut roots = [0.0f64; MAX_RYS_ROOTS];
let mut wts = [0.0f64; MAX_RYS_ROOTS];
rys_roots_weights(nroots, t, &mut roots, &mut wts);
let pa = sub(p_ctr, a.center); let qc = sub(q_ctr, c.center); let ab = sub(a.center, b.center); let cd = sub(c.center, d.center);
let comps_a = cart_components(la);
let comps_b = cart_components(lb);
let comps_c = cart_components(lc);
let comps_d = cart_components(ld);
let mut ix: Axis4 = zeroed();
let mut iy: Axis4 = zeroed();
let mut iz: Axis4 = zeroed();
for r in 0..nroots {
let x = roots[r];
let b00 = x / (2.0 * pq);
let b10 = (1.0 - x * q / pq) / (2.0 * p);
let b01 = (1.0 - x * p / pq) / (2.0 * q);
let cq = x * q / pq; let cp = x * p / pq;
for (axis, out4) in [&mut ix, &mut iy, &mut iz].into_iter().enumerate() {
let c00 = pa[axis] - cq * pq_vec[axis];
let cp0 = qc[axis] + cp * pq_vec[axis];
build_axis(
la, lb, lc, ld, c00, cp0, b10, b01, b00, ab[axis], cd[axis], out4,
);
}
let pw = scale * pref * wts[r];
for (ia, ca) in comps_a.iter().enumerate() {
for (ib, cb) in comps_b.iter().enumerate() {
for (ic, cc) in comps_c.iter().enumerate() {
for (id, cd_) in comps_d.iter().enumerate() {
let v = ix[ca[0]][cb[0]][cc[0]][cd_[0]]
* iy[ca[1]][cb[1]][cc[1]][cd_[1]]
* iz[ca[2]][cb[2]][cc[2]][cd_[2]];
out[((ia * nb + ib) * nc + ic) * nd + id] += pw * v;
}
}
}
}
}
}
/// Fills one per-axis table for a single quadrature root.
///
/// On return `out4[i][j][k][l]` is set for every `i <= la`, `j <= lb`,
/// `k <= lc`, `l <= ld`; all other entries of `out4` are left untouched.
///
/// The work happens in three stages:
///
/// 1. A two-index table `base[n][m]` (`n <= la + lb`, `m <= lc + ld`) is built
///    from `base[0][0] = 1` by recurrences in `n` (coefficients `c00`, `b10`)
///    and in `m` (coefficients `cp0`, `b01`, `b00`).
/// 2. Index weight is moved from the first to the second position with
///    `v[i][j] = v[i + 1][j - 1] + ab * v[i][j - 1]`.
/// 3. The same transfer, with `cd`, moves weight from the third to the fourth
///    position.
// The recurrences index several tables with the same counters, and the
// parameter list mirrors the recurrence coefficients one-to-one.
#[allow(clippy::too_many_arguments, clippy::needless_range_loop)]
fn build_axis(
    la: usize,
    lb: usize,
    lc: usize,
    ld: usize,
    c00: f64,
    cp0: f64,
    b10: f64,
    b01: f64,
    b00: f64,
    ab: f64,
    cd: f64,
    out4: &mut Axis4,
) {
    let bra_max = la + lb;
    let ket_max = lc + ld;

    // Stage 1a: recurrence in the first index at m = 0.
    let mut base = [[0.0f64; LT1]; LT1];
    base[0][0] = 1.0;
    if bra_max > 0 {
        base[1][0] = c00;
    }
    for n in 2..=bra_max {
        let lower = (n - 1) as f64;
        base[n][0] = c00 * base[n - 1][0] + lower * b10 * base[n - 2][0];
    }

    // Stage 1b: raise the second index one step at a time for every n.
    for m in 0..ket_max {
        for n in 0..=bra_max {
            let mut value = cp0 * base[n][m];
            if m > 0 {
                value += (m as f64) * b01 * base[n][m - 1];
            }
            if n > 0 {
                value += (n as f64) * b00 * base[n - 1][m];
            }
            base[n][m + 1] = value;
        }
    }

    // Stage 2: split the first index into (i, j) using `ab`.
    let mut bra = [[[0.0f64; LT1]; L1]; LT1];
    for i in 0..=bra_max {
        bra[i][0][..=ket_max].copy_from_slice(&base[i][..=ket_max]);
    }
    for j in 1..=lb {
        for i in 0..=(bra_max - j) {
            for m in 0..=ket_max {
                bra[i][j][m] = bra[i + 1][j - 1][m] + ab * bra[i][j - 1][m];
            }
        }
    }

    // Stage 3: for every needed (i, j), split the second index into (k, l)
    // using `cd` and store the requested range.
    for i in 0..=la {
        for j in 0..=lb {
            let mut ket = [[0.0f64; L1]; LT1];
            for k in 0..=ket_max {
                ket[k][0] = bra[i][j][k];
            }
            for l in 1..=ld {
                for k in 0..=(ket_max - l) {
                    ket[k][l] = ket[k + 1][l - 1] + cd * ket[k][l - 1];
                }
            }
            for k in 0..=lc {
                out4[i][j][k][..=ld].copy_from_slice(&ket[k][..=ld]);
            }
        }
    }
}
/// Returns an [`Axis4`] table with every entry set to `0.0`.
#[inline]
fn zeroed() -> Axis4 {
    let line = [0.0; L1];
    let plane = [line; L1];
    let cube = [plane; L1];
    [cube; L1]
}
/// Returns the component-wise weighted combination `(a * ca + b * cb) / p`.
///
/// `p` is supplied by the caller; it is not recomputed from `a` and `b`.
#[inline]
fn combine(a: f64, ca: Vec3, b: f64, cb: Vec3, p: f64) -> Vec3 {
    let mix = |k: usize| (a * ca[k] + b * cb[k]) / p;
    [mix(0), mix(1), mix(2)]
}
/// Returns the component-wise difference `u - v`.
#[inline]
fn sub(u: Vec3, v: Vec3) -> Vec3 {
    let [ux, uy, uz] = u;
    let [vx, vy, vz] = v;
    [ux - vx, uy - vy, uz - vz]
}
/// Returns the squared Euclidean distance between `u` and `v`.
#[inline]
fn dist2(u: Vec3, v: Vec3) -> f64 {
    let delta = sub(u, v);
    norm2(delta)
}
/// Returns the squared Euclidean length of `u`.
#[inline]
fn norm2(u: Vec3) -> f64 {
    let [x, y, z] = u;
    x * x + y * y + z * z
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Four s-type primitives on distinct centers, shared by the s-only tests.
    fn s_quartet() -> [Prim; 4] {
        [
            Prim::new(0.8, [0.0, 0.0, 0.0], 0),
            Prim::new(1.3, [0.0, 0.0, 0.7], 0),
            Prim::new(0.5, [0.4, 0.0, 0.0], 0),
            Prim::new(1.1, [0.0, 0.6, 0.2], 0),
        ]
    }

    /// The all-s block must agree with the closed-form expression built from
    /// the order-zero value returned by `boys_array`.
    #[test]
    fn ssss_matches_closed_form() {
        use std::f64::consts::PI;

        let [a, b, c, d] = s_quartet();
        let mut out = [0.0; 1];
        coulomb_into(a, b, c, d, 1.0, &mut out);

        let p = a.exp + b.exp;
        let q = c.exp + d.exp;
        let bra_center = combine(a.exp, a.center, b.exp, b.center, p);
        let ket_center = combine(c.exp, c.center, d.exp, d.center, q);
        let kab = (-(a.exp * b.exp / p) * dist2(a.center, b.center)).exp();
        let kcd = (-(c.exp * d.exp / q) * dist2(c.center, d.center)).exp();
        let rho = p * q / (p + q);
        let t = rho * dist2(bra_center, ket_center);

        let mut fm = [0.0; 1];
        integral_math::boys::boys_array(0, t, &mut fm);

        let expect = 2.0 * PI * PI * PI.sqrt() / (p * q * (p + q).sqrt()) * kab * kcd * fm[0];
        assert!(
            (out[0] - expect).abs() < 1e-14 * expect.abs(),
            "ssss {} vs {}",
            out[0],
            expect
        );
    }

    /// Two calls with scales 2 and 3 into the same buffer must add up to five
    /// times a single unit-scale call.
    #[test]
    fn scale_and_accumulate() {
        let [a, b, c, d] = s_quartet();

        let mut one = [0.0; 1];
        coulomb_into(a, b, c, d, 1.0, &mut one);

        let mut acc = [0.0; 1];
        coulomb_into(a, b, c, d, 2.0, &mut acc);
        coulomb_into(a, b, c, d, 3.0, &mut acc);

        assert!((acc[0] - 5.0 * one[0]).abs() < 1e-12 * one[0].abs());
    }

    /// Swapping the (a, b) pair with the (c, d) pair must give the same three
    /// values when `a` is the only primitive with `l = 1`.
    #[test]
    fn bra_ket_exchange_primitive() {
        let a = Prim::new(0.9, [0.1, 0.0, 0.0], 1);
        let b = Prim::new(0.7, [0.0, 0.2, 0.0], 0);
        let c = Prim::new(1.2, [0.0, 0.0, 0.3], 0);
        let d = Prim::new(0.6, [0.2, 0.1, 0.0], 0);

        let mut abcd = [0.0; 3];
        coulomb_into(a, b, c, d, 1.0, &mut abcd);
        let mut cdab = [0.0; 3];
        coulomb_into(c, d, a, b, 1.0, &mut cdab);

        for (i, (lhs, rhs)) in abcd.iter().zip(cdab.iter()).enumerate() {
            assert!(
                (lhs - rhs).abs() < 1e-13 * lhs.abs().max(1e-300),
                "bra-ket exchange mismatch at {i}: {} vs {}",
                lhs,
                rhs
            );
        }
    }
}