use std::f64::consts::PI;
use num_dual::DualNum;
use crate::families::lda::{Lda, LdaEnergy, LdaVars};
use crate::families::XcEval;
use crate::func::{Family, FunctionalId, FunctionalInfo, Kind};
use crate::reduced::consts::FPP_VWN;
use crate::reduced::vars::{f_zeta, one_minus_z_pow4};
// VWN5 fit parameters. Each array holds three parameter sets, combined in `vwn5_ec`:
// index 0 -> `eps0` (the base term), index 1 -> `eps1` (enters as `eps1 - eps0`, weighted
// by z^4), index 2 -> `alpha` (weighted by (1 - z^4) / FPP_VWN).
// NOTE(review): presumably paramagnetic / ferromagnetic / spin-stiffness fits from the
// original VWN paper — confirm against the reference before editing any literal.
/// Amplitudes `A` passed to `vwn_f_aux`; the third entry is the closed form -1/(6*pi^2).
pub(crate) const A_VWN: [f64; 3] = [0.0310907, 0.01554535, -1.0 / (6.0 * PI * PI)];
/// Linear coefficients `b` of `X(rs) = rs + b*sqrt(rs) + c` in `vwn_f_aux`.
pub(crate) const B_VWN: [f64; 3] = [3.72744, 7.06042, 1.13107];
/// Constant terms `c` of `X(rs) = rs + b*sqrt(rs) + c` in `vwn_f_aux`.
pub(crate) const C_VWN: [f64; 3] = [12.9352, 18.0578, 13.0045];
/// Parameters `x0` used by `vwn_f_aux` in `f2`, `f3` and the `(sqrt(rs) - x0)^2` logarithm.
pub(crate) const X0_VWN: [f64; 3] = [-0.10498, -0.32500, -0.0047584];
/// Shared VWN interpolation form, evaluated for one parameter set `(a, b, c, x0)`:
///
/// `a * [ ln(rs/X) + (f1 - f2*f3) * atan(Q / (2*sqrt(rs) + b)) - f2 * ln((sqrt(rs) - x0)^2 / X) ]`
///
/// with `X = rs + b*sqrt(rs) + c`, `Q = sqrt(4c - b^2)`, `f1 = 2b/Q`,
/// `f2 = b*x0 / (x0^2 + b*x0 + c)` and `f3 = 2(2*x0 + b)/Q`.
///
/// Generic over `N` so the same code serves plain values and dual numbers
/// (derivatives with respect to `rs`).
pub(crate) fn vwn_f_aux<N: DualNum<f64> + Copy>(rs: N, a: f64, b: f64, c: f64, x0: f64) -> N {
    let y = rs.sqrt();
    let q = (4.0 * c - b * b).sqrt();
    let f1 = 2.0 * b / q;
    let f2 = b * x0 / (x0 * x0 + b * x0 + c);
    let f3 = 2.0 * (2.0 * x0 + b) / q;
    let xx = rs + N::from(b) * y + N::from(c);
    let inv = xx.recip();
    // ln(rs/X) = ln(1 + (rs - X)/X). The numerator rs - X = -(b*y + c) is formed
    // analytically so the ln_1p argument stays accurate when rs/X -> 1 (large rs).
    let l1 = (-(N::from(b) * y + N::from(c)) * inv).ln_1p();
    // ln((y - x0)^2/X) = ln(1 + ((y - x0)^2 - X)/X). Because y^2 = rs, the numerator
    // simplifies exactly to -((2*x0 + b)*y + c - x0^2); forming it this way avoids
    // subtracting two O(rs) quantities (catastrophic cancellation at large rs).
    let l2 = (-(N::from(2.0 * x0 + b) * y + N::from(c - x0 * x0)) * inv).ln_1p();
    let at = (N::from(q) / (y + y + N::from(b))).atan();
    N::from(a) * (l1 + N::from(f1 - f2 * f3) * at - N::from(f2) * l2)
}
/// VWN5 correlation as a function of `rs` and the polarization `z`.
///
/// Combines the three `vwn_f_aux` fits as
/// `fit0 + fit2 * f_zeta(z) * (1 - z^4) / FPP_VWN + (fit1 - fit0) * f_zeta(z) * z^4`,
/// where `f_zeta` is evaluated with `zeta_threshold`.
fn vwn5_ec<N: DualNum<f64> + Copy>(rs: N, z: N, zeta_threshold: f64) -> N {
    // Evaluate the shared VWN form with the i-th column of the parameter tables.
    let fit = |i: usize| vwn_f_aux(rs, A_VWN[i], B_VWN[i], C_VWN[i], X0_VWN[i]);
    let base = fit(0);
    let polarized = fit(1);
    let stiffness = fit(2);
    let weight = f_zeta(z, zeta_threshold);

    // Operand grouping matches a left-to-right evaluation of the full expression.
    let stiffness_term = stiffness * weight * one_minus_z_pow4(z) / N::from(FPP_VWN);
    let polarized_term = (polarized - base) * weight * z.powi(4);
    base + stiffness_term + polarized_term
}
/// The `lda_c_vwn` correlation functional, evaluated through `vwn5_ec`.
pub(crate) struct LdaCVwn {
    /// Static metadata returned from `LdaEnergy::info`.
    info: FunctionalInfo,
    /// Polarization threshold forwarded to `vwn5_ec` (and from there to `f_zeta`).
    zeta_threshold: f64,
}

impl LdaCVwn {
    /// Creates the functional with its default metadata and thresholds.
    fn new() -> Self {
        let info = FunctionalInfo {
            id: Some(FunctionalId::LdaCVwn),
            name: "lda_c_vwn",
            family: Family::Lda,
            kind: Kind::Correlation,
            needs_sigma: false,
            needs_lapl: false,
            needs_tau: false,
            dens_threshold: 1e-15,
            hybrid: None,
        };
        Self {
            info,
            zeta_threshold: f64::EPSILON,
        }
    }

    /// Wraps a default instance in `Lda` and returns it as a boxed `XcEval` trait object.
    pub(crate) fn boxed() -> Box<dyn XcEval> {
        let functional = Self::new();
        Box::new(Lda(functional))
    }
}
impl LdaEnergy for LdaCVwn {
    /// Metadata describing this functional.
    fn info(&self) -> &FunctionalInfo {
        &self.info
    }

    /// Evaluates the VWN5 correlation at the reduced variables `rs` and `z`.
    fn f<N: DualNum<f64> + Copy>(&self, v: LdaVars<N>) -> N {
        let (rs, zeta) = (v.rs, v.z);
        vwn5_ec(rs, zeta, self.zeta_threshold)
    }
}
#[cfg(test)]
mod tests {
    use crate::{Functional, FunctionalId, Spin, XcInput};

    #[test]
    fn unpol_vrho_matches_finite_difference() {
        let f = Functional::new(FunctionalId::LdaCVwn, Spin::Unpolarized).unwrap();
        // Energy density n * exc(n); vrho should be its derivative with respect to n.
        let edens = |x: f64| x * f.eval(1, &XcInput::lda(&[x])).unwrap().exc[0];
        for &n in &[0.02, 0.2, 2.0, 50.0] {
            // Step relative to n keeps the central difference well conditioned at each scale.
            let h = 1e-6 * n;
            let fd = (edens(n + h) - edens(n - h)) / (2.0 * h);
            let v = f.eval(1, &XcInput::lda(&[n])).unwrap().vrho[0];
            assert!(
                (v - fd).abs() <= 1e-6 * v.abs().max(1.0),
                "n={n}: {v} vs fd {fd}"
            );
        }
    }

    #[test]
    fn pol_vrho_matches_finite_difference() {
        let f = Functional::new(FunctionalId::LdaCVwn, Spin::Polarized).unwrap();
        // Energy density (n_a + n_b) * exc as a function of the two spin densities.
        let e = |a: f64, b: f64| (a + b) * f.eval(1, &XcInput::lda(&[a, b])).unwrap().exc[0];
        // Cover na > nb, na < nb, and a configuration close to full polarization.
        for &(na, nb) in &[(0.6, 0.25), (0.05, 0.3), (2.0, 0.01)] {
            let out = f.eval(1, &XcInput::lda(&[na, nb])).unwrap();
            let ha = 1e-6 * na;
            let hb = 1e-6 * nb;
            let fda = (e(na + ha, nb) - e(na - ha, nb)) / (2.0 * ha);
            let fdb = (e(na, nb + hb) - e(na, nb - hb)) / (2.0 * hb);
            assert!(
                (out.vrho[0] - fda).abs() <= 1e-6 * out.vrho[0].abs().max(1.0),
                "({na}, {nb}): vrho_a {} vs fd {fda}",
                out.vrho[0]
            );
            assert!(
                (out.vrho[1] - fdb).abs() <= 1e-6 * out.vrho[1].abs().max(1.0),
                "({na}, {nb}): vrho_b {} vs fd {fdb}",
                out.vrho[1]
            );
        }
    }

    #[test]
    fn unpol_pol_symmetry_at_zero_polarization() {
        let up = Functional::new(FunctionalId::LdaCVwn, Spin::Unpolarized).unwrap();
        let po = Functional::new(FunctionalId::LdaCVwn, Spin::Polarized).unwrap();
        let n = 0.9;
        let ou = up.eval(1, &XcInput::lda(&[n])).unwrap();
        // Equal spin densities must reproduce the unpolarized result.
        let op = po.eval(1, &XcInput::lda(&[n / 2.0, n / 2.0])).unwrap();
        assert!(
            (ou.exc[0] - op.exc[0]).abs() <= 1e-13 * ou.exc[0].abs(),
            "exc: unpol {} vs pol {}",
            ou.exc[0],
            op.exc[0]
        );
        assert!(
            (ou.vrho[0] - op.vrho[0]).abs() <= 1e-12 * ou.vrho[0].abs(),
            "vrho_a: unpol {} vs pol {}",
            ou.vrho[0],
            op.vrho[0]
        );
        assert!(
            (ou.vrho[0] - op.vrho[1]).abs() <= 1e-12 * ou.vrho[0].abs(),
            "vrho_b: unpol {} vs pol {}",
            ou.vrho[0],
            op.vrho[1]
        );
    }

    #[test]
    fn edge_energy_and_derivatives_finite() {
        let f = Functional::new(FunctionalId::LdaCVwn, Spin::Polarized).unwrap();
        // (n_a, n_b) pairs: fully polarized either way, a small one-spin density,
        // a near-threshold density, and a large density.
        let rho = [
            1.0, 0.0, 0.0, 1.0, 1e-3, 0.0, 1e-12, 1e-13, 100.0, 50.0,
        ];
        let out = f.eval(5, &XcInput::lda(&rho)).unwrap();
        for v in out.exc.iter().chain(&out.vrho) {
            assert!(v.is_finite(), "non-finite output: {v}");
        }
    }
}