use super::gaussian_integrals::{distance_squared, gaussian_product_center};
use std::f64::consts::PI;
/// Error function `erf(x)`, accurate to roughly machine precision.
///
/// For `|x| < 6` this sums the everywhere-convergent series
/// `erf(x) = 2/√π · e^{-x²} · Σ_{k≥0} 2^k x^{2k+1} / (1·3·5⋯(2k+1))`
/// (Abramowitz & Stegun 7.1.6). Every term is positive, so there is no
/// cancellation. The result is exactly odd, `erf(0) == 0`, and the relative
/// accuracy holds even for tiny `x`. That matters because `boys_function`
/// scales `erf(√x)` by `√(π/x)/2`, which magnifies any absolute error floor
/// as `x → 0`.
///
/// For `|x| >= 6`, `erfc(|x|) < 2.2e-17` is below half an ulp of `1.0`, so
/// `±1` is returned directly. NaN input is returned unchanged.
fn erf_approx(x: f64) -> f64 {
    if x.is_nan() {
        return x;
    }
    let a = x.abs();
    if a >= 6.0 {
        return 1.0_f64.copysign(x);
    }
    let a2 = a * a;
    let mut term = a;
    let mut sum = a;
    let mut denom = 1.0;
    // The term ratio 2a²/denom drops below 1 once denom > 2a² (at most 72 here),
    // after which the terms shrink super-exponentially, so the loop is short.
    while term > sum * f64::EPSILON {
        denom += 2.0;
        term *= 2.0 * a2 / denom;
        sum += term;
    }
    let value = std::f64::consts::FRAC_2_SQRT_PI * (-a2).exp() * sum;
    // Rounding can nudge the value a hair above 1 near the cutoff; clamp it.
    value.min(1.0).copysign(x)
}
/// Pair quantities derived from two Gaussian exponents and their centres, as
/// built by [`ShellPairData::new`].
#[derive(Debug, Clone)]
pub struct ShellPairData {
    /// Combined exponent `p = alpha + beta`.
    pub p: f64,
    /// Reduced exponent `mu = alpha * beta / p`.
    pub mu: f64,
    /// Product centre; each component comes from `gaussian_product_center`.
    pub center_p: [f64; 3],
    /// Pre-exponential factor `exp(-mu * |A - B|^2)`.
    pub k_ab: f64,
    /// Centre `A` associated with `alpha`.
    pub center_a: [f64; 3],
    /// Centre `B` associated with `beta`.
    pub center_b: [f64; 3],
    /// First exponent.
    pub alpha: f64,
    /// Second exponent.
    pub beta: f64,
}

impl ShellPairData {
    /// Builds the pair data for exponent `alpha` at `center_a` and exponent
    /// `beta` at `center_b`.
    ///
    /// `alpha + beta` is used as a divisor, so when it is zero `mu` is not finite.
    pub fn new(alpha: f64, center_a: [f64; 3], beta: f64, center_b: [f64; 3]) -> Self {
        let p = alpha + beta;
        let mu = alpha * beta / p;
        let ab2 = distance_squared(&center_a, &center_b);
        let k_ab = (-mu * ab2).exp();
        let center_p = [
            gaussian_product_center(alpha, center_a[0], beta, center_b[0]),
            gaussian_product_center(alpha, center_a[1], beta, center_b[1]),
            gaussian_product_center(alpha, center_a[2], beta, center_b[2]),
        ];
        Self {
            p,
            mu,
            center_p,
            k_ab,
            center_a,
            center_b,
            alpha,
            beta,
        }
    }
}
/// Boys function `F_n(x) = ∫_0^1 t^{2n} exp(-x t^2) dt`.
///
/// Evaluation strategy:
/// * NaN `x`: returned unchanged.
/// * `x < 1e-12`: the `x → 0` limit `1 / (2n + 1)`.
/// * `x > 30`: the large-`x` asymptotic form `(2n-1)!! / (2x)^n · √(π/x) / 2`,
///   which omits terms of order `e^{-x}`.
/// * `n == 0`: the closed form `√(π/x) / 2 · erf(√x)`, whose accuracy is that of
///   `erf_approx`.
/// * otherwise: the positive power series summed by `boys_series`.
///
/// Higher orders are summed directly instead of being obtained by upward
/// recursion from `F_0`. That recursion divides by `2x` at every step, so any
/// error in `F_0` is amplified catastrophically when `x` is small.
pub fn boys_function(n: usize, x: f64) -> f64 {
    if x.is_nan() {
        return x;
    }
    if x < 1e-12 {
        return 1.0 / (2 * n + 1) as f64;
    }
    if x > 30.0 {
        let pi_over_x = (PI / x).sqrt();
        let mut result = pi_over_x / 2.0;
        for i in 1..=n {
            result *= (2 * i - 1) as f64 / (2.0 * x);
        }
        return result;
    }
    if n == 0 {
        return (PI / x).sqrt() * erf_approx(x.sqrt()) / 2.0;
    }
    boys_series(n, x)
}

/// Sums `F_n(x) = e^{-x} Σ_{k≥0} (2x)^k / ((2n+1)(2n+3)⋯(2n+2k+1))`.
///
/// Every term is positive, so the sum has no cancellation. Terms start
/// shrinking once `2n + 2k + 1 > 2x`, which keeps the loop short for the
/// `x <= 30` range it is called on. `x` must not be NaN, or the loop
/// condition never becomes false.
fn boys_series(n: usize, x: f64) -> f64 {
    let two_x = 2.0 * x;
    let mut denom = (2 * n + 1) as f64;
    let mut term = 1.0 / denom;
    let mut sum = term;
    while term > sum * f64::EPSILON {
        denom += 2.0;
        term *= two_x / denom;
        sum += term;
    }
    (-x).exp() * sum
}
/// The `(ss|ss)` electron-repulsion integral between two shell pairs.
///
/// With `p = sp_ab.p`, `q = sp_cd.p` and `rho = p q / (p + q)`, this evaluates
/// `2 π^{5/2} / (p q √(p+q)) · K_ab · K_cd · F_0(rho · |P − Q|²)`,
/// where `P`/`Q` are the pairs' `center_p` and `K` their `k_ab` factors.
pub fn eri_ssss(sp_ab: &ShellPairData, sp_cd: &ShellPairData) -> f64 {
    let (p, q) = (sp_ab.p, sp_cd.p);
    let total = p + q;
    let rho = p * q / total;
    let boys_arg = rho * distance_squared(&sp_ab.center_p, &sp_cd.center_p);
    // 2 · π² · √π = 2 π^{5/2}
    let prefactor = 2.0 * PI.powi(2) * PI.sqrt() / (p * q * total.sqrt());
    prefactor * sp_ab.k_ab * sp_cd.k_ab * boys_function(0, boys_arg)
}
/// Schwarz factor `Q_ab = sqrt(|(ab|ab)|)` for a single pair of s-type primitives.
///
/// The pair is built with [`ShellPairData::new`] and contracted with itself via
/// [`eri_ssss`]. Products `Q_ab * Q_cd` bound `|(ab|cd)|` by the Cauchy–Schwarz
/// inequality.
pub fn schwarz_bound(alpha: f64, center_a: [f64; 3], beta: f64, center_b: [f64; 3]) -> f64 {
    let pair = ShellPairData::new(alpha, center_a, beta, center_b);
    eri_ssss(&pair, &pair).abs().sqrt()
}
/// Computes the `(ss|ss)` integrals between shell pairs that survive Schwarz screening.
///
/// `schwarz_q[i]` is presumably the Schwarz factor of `shell_pairs[i]` (e.g. from
/// `schwarz_bound`), so that `|(ij|kl)| <= schwarz_q[ij] * schwarz_q[kl]`. A
/// combination is evaluated only when that bound reaches `threshold`. Only
/// `kl >= ij` is visited.
///
/// Returns `(ij, kl, 0, 0, value)` for every evaluated integral with
/// `|value| > threshold`. The third and fourth indices are always zero.
///
/// # Panics
/// Panics if `schwarz_q` has fewer entries than `shell_pairs`.
pub fn compute_eris_screened(
    shell_pairs: &[ShellPairData],
    schwarz_q: &[f64],
    threshold: f64,
) -> Vec<(usize, usize, usize, usize, f64)> {
    let n_pairs = shell_pairs.len();
    assert!(
        schwarz_q.len() >= n_pairs,
        "schwarz_q has {} entries but there are {} shell pairs",
        schwarz_q.len(),
        n_pairs
    );
    let schwarz_q = &schwarz_q[..n_pairs];
    // A whole row can only be skipped when even the largest partner factor cannot
    // lift the bound to the threshold. Testing q_ij alone is wrong whenever some
    // Q exceeds 1.
    let max_q = schwarz_q.iter().copied().fold(0.0, f64::max);
    let mut eris = Vec::new();
    for (ij, &q_ij) in schwarz_q.iter().enumerate() {
        if q_ij * max_q < threshold {
            continue;
        }
        for (kl, &q_kl) in schwarz_q.iter().enumerate().skip(ij) {
            if q_ij * q_kl < threshold {
                continue;
            }
            let val = eri_ssss(&shell_pairs[ij], &shell_pairs[kl]);
            if val.abs() > threshold {
                eris.push((ij, kl, 0, 0, val));
            }
        }
    }
    eris
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn boys_f0_is_correct() {
        // Below the small-argument cutoff F_0 is exactly its x = 0 limit.
        assert!((boys_function(0, 0.0) - 1.0).abs() < 1e-10);

        // At moderate x, F_0 must reproduce its closed form in terms of erf.
        let x = 1.0_f64;
        let closed_form = (PI / x).sqrt() * erf_approx(x.sqrt()) / 2.0;
        assert!((boys_function(0, x) - closed_form).abs() < 1e-8);
    }

    #[test]
    fn eri_ssss_hydrogen_molecule() {
        // Same-centre pairs with unit exponents, placed 1.4 apart along x.
        let pair_at = |x: f64| ShellPairData::new(1.0, [x, 0.0, 0.0], 1.0, [x, 0.0, 0.0]);
        let origin = pair_at(0.0);
        let shifted = pair_at(1.4);
        assert!(eri_ssss(&origin, &shifted) > 0.0);
    }
}