use std::collections::BTreeMap;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use xcx::{Functional, FunctionalId, Spin, XcInput};
/// Seed for the random portion of the sweep; it is also printed in the coverage
/// report so a failing run can be reproduced exactly.
const SEED: u64 = 0xF02D_C0FF_EE5E_ED42;
/// How many random unpolarized `(n, sigma)` points are drawn.
const RANDOM_UNPOL: usize = 2_000;
/// How many random polarized `(na, nb, saa, sab, sbb)` points are drawn.
const RANDOM_POL: usize = 3_000;
/// `eval_fxc` is not checked at a point where any `|sigma|` component exceeds this.
const FXC_SIGMA_CAP: f64 = 1.0e12;
/// `eval_fxc` is not checked at a point whose reduced gradient (as computed by
/// `fxc_in_domain`) exceeds this.
const FXC_REDUCED_GRAD_MAX: f64 = 5.0e1;
/// Running tally of what the fuzz sweep actually exercised.
///
/// Every counter is advanced by [`Census::check`]. The per-region maps credit each
/// evaluated point once to every region label attached to it.
struct Census {
    /// Points passed to `eval` and fully checked.
    evals: u64,
    /// Individual `eval` output values asserted finite.
    components: u64,
    /// Per-region count of points passed to `eval`.
    by_region: BTreeMap<&'static str, u64>,
    /// Points passed to `eval_fxc` and fully checked.
    fxc_evals: u64,
    /// Individual `eval_fxc` output values asserted finite.
    fxc_components: u64,
    /// Points where `eval_fxc` was not called because `fxc_in_domain` rejected them.
    fxc_skipped: u64,
    /// Per-region count of points passed to `eval_fxc`.
    fxc_by_region: BTreeMap<&'static str, u64>,
}

impl Census {
    /// Creates an empty tally.
    fn new() -> Self {
        Self {
            evals: 0,
            components: 0,
            by_region: BTreeMap::new(),
            fxc_evals: 0,
            fxc_components: 0,
            fxc_skipped: 0,
            fxc_by_region: BTreeMap::new(),
        }
    }

    /// Evaluates `f` at a single point, asserts every output is finite and records coverage.
    ///
    /// The input is built as LDA when `sigma` is `None` (any `tau` is then ignored), as GGA
    /// when only `sigma` is given, and as GGA plus `tau` when both are given. Second
    /// derivatives from `eval_fxc` are additionally checked when `fxc_in_domain` accepts
    /// the point; otherwise the point is counted in `fxc_skipped`.
    ///
    /// # Panics
    ///
    /// If `eval`/`eval_fxc` return an error or any output component is non-finite. The
    /// message names the functional, spin, region labels and the full input (`rho`,
    /// `sigma` and `tau`) so the failing point can be replayed.
    // The argument list is exactly what is needed to rebuild and report one point.
    #[allow(clippy::too_many_arguments)]
    fn check(
        &mut self,
        f: &Functional,
        id: FunctionalId,
        spin: Spin,
        rho: &[f64],
        sigma: Option<&[f64]>,
        tau: Option<&[f64]>,
        regions: &[&'static str],
    ) {
        let input = match (sigma, tau) {
            (Some(s), Some(t)) => XcInput::gga(rho, s).with_tau(t),
            (Some(s), None) => XcInput::gga(rho, s),
            (None, _) => XcInput::lda(rho),
        };
        // Only formatted when an assertion is about to fail. `tau` is included because
        // a meta-GGA failure cannot be reproduced without it.
        let point = || {
            format!(
                "{} (id {}) spin {spin:?} region {regions:?}\n rho = [{}]\n sigma = {}\n tau = {}",
                f.info().name,
                id.as_u32(),
                fmt_slice(rho),
                Self::fmt_opt(sigma),
                Self::fmt_opt(tau),
            )
        };

        let r = f
            .eval(1, &input)
            .unwrap_or_else(|e| panic!("eval errored on finite input {e:?} :: {}", point()));
        let vxc: [(&str, &[f64]); 5] = [
            ("exc", &r.exc),
            ("vrho", &r.vrho),
            ("vsigma", &r.vsigma),
            ("vtau", &r.vtau),
            ("vlapl", &r.vlapl),
        ];
        self.components += Self::assert_all_finite("", &vxc, &point);

        if fxc_in_domain(rho, sigma, f.info().dens_threshold) {
            let r = f.eval_fxc(1, &input).unwrap_or_else(|e| {
                panic!("eval_fxc errored on finite input {e:?} :: {}", point())
            });
            let fxc: [(&str, &[f64]); 6] = [
                ("v2rho2", &r.v2rho2),
                ("v2rhosigma", &r.v2rhosigma),
                ("v2sigma2", &r.v2sigma2),
                ("v2rhotau", &r.v2rhotau),
                ("v2sigmatau", &r.v2sigmatau),
                ("v2tau2", &r.v2tau2),
            ];
            self.fxc_components += Self::assert_all_finite("fxc ", &fxc, &point);
            self.fxc_evals += 1;
            for &tag in regions {
                *self.fxc_by_region.entry(tag).or_default() += 1;
            }
        } else {
            self.fxc_skipped += 1;
        }

        self.evals += 1;
        for &tag in regions {
            *self.by_region.entry(tag).or_default() += 1;
        }
    }

    /// Asserts every value in `groups` is finite and returns how many values were checked.
    ///
    /// `prefix` is prepended to the component name in the panic message; `context`
    /// describes the input point and is only invoked when an assertion fails.
    fn assert_all_finite(
        prefix: &str,
        groups: &[(&str, &[f64])],
        context: &dyn Fn() -> String,
    ) -> u64 {
        let mut checked = 0;
        for &(name, values) in groups {
            for (k, &v) in values.iter().enumerate() {
                checked += 1;
                assert!(
                    v.is_finite(),
                    "NON-FINITE {prefix}{name}[{k}] = {v} :: {}",
                    context()
                );
            }
        }
        checked
    }

    /// Formats an optional slice as `[a, b, ...]`, or `None` when absent.
    fn fmt_opt(values: Option<&[f64]>) -> String {
        values.map_or_else(|| "None".to_string(), |v| format!("[{}]", fmt_slice(v)))
    }
}
/// Decides whether the second-derivative (`eval_fxc`) outputs should be checked at a point.
///
/// A point is always accepted when its total density is below `thr` or when it carries no
/// `sigma`. Otherwise it is rejected when any `sigma` component exceeds `FXC_SIGMA_CAP` in
/// magnitude, or when a reduced gradient `sqrt(sigma) / n^(4/3)` exceeds
/// `FXC_REDUCED_GRAD_MAX`. In that ratio `sigma` is floored at `thr^(8/3)` and the density
/// at `thr`.
///
/// `rho` has one entry for unpolarized input and two (`na`, `nb`) for polarized input. In the
/// polarized case `sigma` is `(saa, sab, sbb)`, and the total gradient uses `sab` clamped to
/// the mean of the floored `saa` and `sbb`.
fn fxc_in_domain(rho: &[f64], sigma: Option<&[f64]>, thr: f64) -> bool {
    let total: f64 = rho.iter().sum();
    let sig = match (total < thr, sigma) {
        (false, Some(s)) => s,
        // Sub-threshold density, or no gradient input at all: always check.
        _ => return true,
    };

    let exceeds_cap = sig.iter().map(|x| x.abs()).any(|a| a > FXC_SIGMA_CAP);
    if exceeds_cap {
        return false;
    }

    let thr_43 = thr.powf(4.0 / 3.0);
    let sigma_floor = thr_43 * thr_43;
    let reduced = |s: f64, n: f64| s.max(sigma_floor).sqrt() / n.max(thr).powf(4.0 / 3.0);
    let within = |x: f64| x <= FXC_REDUCED_GRAD_MAX;

    match *rho {
        // Unpolarized: check both the per-channel (n/2, sigma/4) and total gradients.
        [n] => within(reduced(sig[0] / 4.0, n / 2.0)) && within(reduced(sig[0], n)),
        _ => {
            let (na, nb) = (rho[0], rho[1]);
            let (saa, sab, sbb) = (sig[0], sig[1], sig[2]);
            let saa_fl = saa.max(sigma_floor);
            let sbb_fl = sbb.max(sigma_floor);
            let half_sum = 0.5 * (saa_fl + sbb_fl);
            let sab_cl = sab.clamp(-half_sum, half_sum);
            let sigma_total = (saa_fl + 2.0 * sab_cl + sbb_fl).max(0.0);
            within(reduced(saa, na))
                && within(reduced(sbb, nb))
                && within(sigma_total.sqrt() / (na + nb).max(thr).powf(4.0 / 3.0))
        }
    }
}
/// Renders `s` as comma-separated scientific-notation values with 17 fractional digits,
/// so a failing input can be copied back into a test exactly.
fn fmt_slice(s: &[f64]) -> String {
    let mut out = String::new();
    for (idx, value) in s.iter().enumerate() {
        if idx > 0 {
            out.push_str(", ");
        }
        out.push_str(&format!("{value:.17e}"));
    }
    out
}
/// Deterministic unpolarized densities, each tagged with the coverage region it probes.
///
/// The grid has three parts, in this order:
/// - a log-spaced sweep from 1e-14 to 1e3 (`extreme_rs`);
/// - five points straddling each of several density thresholds
///   (`dens_threshold_straddle`);
/// - exact zero and values far below any threshold (`sub_threshold`).
fn unpol_densities() -> Vec<(f64, &'static str)> {
    let sweep = [
        1e-14, 1e-13, 1e-12, 1e-10, 1e-8, 1e-6, 1e-4, 1e-3, 1e-2, 0.1, 0.3, 1.0, 3.0, 10.0, 31.6,
        100.0, 316.0, 1000.0,
    ];
    let thresholds = [1e-15, 1e-14, 1e-12];
    let below = [0.0, 1e-300, 1e-18];

    let mut points: Vec<(f64, &'static str)> =
        sweep.iter().map(|&n| (n, "extreme_rs")).collect();
    for &thr in thresholds.iter() {
        let straddle = [
            0.5 * thr,
            thr * (1.0 - 1e-6),
            thr,
            thr * (1.0 + 1e-6),
            2.0 * thr,
        ];
        points.extend(straddle.iter().map(|&n| (n, "dens_threshold_straddle")));
    }
    points.extend(below.iter().map(|&n| (n, "sub_threshold")));
    points
}
/// Unpolarized `sigma` values to pair with density `n`, each tagged with its region.
///
/// The list covers zero and tiny gradients, a small absolute band, and values `s^2 n^(8/3)`
/// scaled to the density (with `n` floored at 1e-12). It ends with two huge absolute values.
fn unpol_sigmas(n: f64) -> Vec<(f64, &'static str)> {
    let n_eff = n.max(1e-12);
    let fixed_low = [
        (0.0, "sigma_zero"),
        (1e-30, "sigma_tiny"),
        (1e-6, "sigma_small_band"),
        (1e-8, "sigma_small_band"),
        (1e-10, "sigma_small_band"),
    ];
    let reduced = [0.01, 0.1, 1.0, 10.0, 100.0, 1000.0];
    let scaled = reduced
        .iter()
        .map(|&s| (s * s * n_eff.powf(8.0 / 3.0), "reduced_grad_scaled"));
    let fixed_high = [(1e6, "sigma_huge_abs"), (1e12, "sigma_huge_abs")];

    fixed_low
        .iter()
        .copied()
        .chain(scaled)
        .chain(fixed_high.iter().copied())
        .collect()
}
/// Deterministic polarized density pairs `(na, nb)`, each tagged with the region it probes.
///
/// The groups, in order, are:
/// - exact full polarization;
/// - near-full polarization at several ratios;
/// - equal spin channels;
/// - a few asymmetric splits;
/// - pairs straddling density thresholds, including one sub-threshold channel;
/// - extreme totals and a far-below-threshold pair.
fn pol_density_pairs() -> Vec<(f64, f64, &'static str)> {
    let mut pairs = Vec::new();

    for &n in [1e-12, 1e-6, 1e-3, 1.0, 1e3].iter() {
        pairs.push((n, 0.0, "full_pol_exact"));
        pairs.push((0.0, n, "full_pol_exact"));
    }

    for &ratio in [1e-4, 1e-6, 1e-8, 1e-10, 1e-12].iter() {
        pairs.extend([
            (1.0, ratio, "near_full_pol"),
            (ratio, 1.0, "near_full_pol"),
            (1e3, 1e3 * ratio, "near_full_pol"),
        ]);
    }

    pairs.extend(
        [1e-12, 1e-6, 1e-2, 1.0, 100.0, 1000.0]
            .iter()
            .map(|&n| (n / 2.0, n / 2.0, "equal_spin")),
    );

    pairs.extend(
        [(0.7, 0.3), (0.3, 0.7), (0.9, 0.1), (0.1, 0.9), (0.6, 0.4)]
            .iter()
            .map(|&(a, b)| (a, b, "asymmetric")),
    );

    for &thr in [1e-15, 1e-14, 1e-12].iter() {
        pairs.extend([
            (0.6 * thr, 0.6 * thr, "dens_threshold_straddle"),
            (0.4 * thr, 0.4 * thr, "dens_threshold_straddle"),
            (thr, thr, "dens_threshold_straddle"),
            (0.5, 0.1 * thr, "one_channel_subthreshold"),
            (0.1 * thr, 0.5, "one_channel_subthreshold"),
        ]);
    }

    pairs.extend([
        (5e-15, 5e-15, "extreme_rs"),
        (500.0, 500.0, "extreme_rs"),
        (1e-300, 1e-300, "sub_threshold"),
    ]);
    pairs
}
/// Base `(saa, sbb)` pairs for a polarized point with channel densities `na`, `nb`.
///
/// The list holds zero gradients, three density-scaled pairs `s^2 n^(8/3)` (with `n`
/// floored at 1e-10), and a set of fixed absolute pairs that include strongly lopsided
/// ones.
fn base_sigma_pairs(na: f64, nb: f64) -> Vec<(f64, f64)> {
    fn grad_sq(n: f64, s: f64) -> f64 {
        s * s * n.max(1e-10).powf(8.0 / 3.0)
    }

    let mut pairs = vec![(0.0, 0.0)];
    pairs.extend(
        [0.1, 1.0, 100.0]
            .iter()
            .map(|&s| (grad_sq(na, s), grad_sq(nb, s))),
    );
    pairs.extend([
        (1e-6, 1e-6),
        (1e-8, 1e-8),
        (1e-3, 0.0),
        (0.0, 1e-3),
        (1.0, 1e-12),
        (1e-12, 1.0),
        (1e6, 1e-6),
        (1e-6, 1e6),
    ]);
    pairs
}
/// Polarized `(saa, sab, sbb)` triples for channel densities `na`, `nb`, tagged by region.
///
/// For each base pair from `base_sigma_pairs`, `sab` is swept over several values:
/// - zero;
/// - on, just inside, just outside and far outside `±(saa + sbb)/2`;
/// - on and around the Cauchy-Schwarz bound `±sqrt(saa*sbb)`.
///
/// Two fixed triples with zero total gradient are appended at the end.
fn pol_sigma_triples(na: f64, nb: f64) -> Vec<(f64, f64, f64, &'static str)> {
    let mut out = Vec::new();
    for (saa, sbb) in base_sigma_pairs(na, nb) {
        let avg = 0.5 * (saa + sbb);
        let geo = (saa * sbb).sqrt();

        let mut sab_cases = vec![
            (0.0, "sab_zero"),
            (avg, "sab_clamp_edge_hi"),
            (-avg, "sab_clamp_edge_lo"),
        ];
        for &eps in [1e-12, 1e-3].iter() {
            sab_cases.push((avg * (1.0 + eps), "sab_clamp_outside_hi"));
            sab_cases.push((avg * (1.0 - eps), "sab_clamp_inside_hi"));
            sab_cases.push((-avg * (1.0 + eps), "sab_clamp_outside_lo"));
            sab_cases.push((-avg * (1.0 - eps), "sab_clamp_inside_lo"));
        }
        sab_cases.extend([
            (10.0 * avg, "sab_clamp_outside_hi"),
            (-10.0 * avg, "sab_clamp_outside_lo"),
            (geo, "gm_boundary"),
            (-geo, "gm_boundary"),
            (geo * (1.0 + 1e-3), "gm_outside"),
            (geo * (1.0 - 1e-3), "gm_inside"),
            (-geo * (1.0 + 1e-3), "gm_outside"),
            (-geo * (1.0 - 1e-3), "gm_inside"),
        ]);

        out.extend(
            sab_cases
                .into_iter()
                .map(|(sab, tag)| (saa, sab, sbb, tag)),
        );
    }
    out.push((0.2, -0.2, 0.2, "sigma_tot_zero_exact"));
    out.push((0.0, 0.0, 0.0, "sigma_tot_zero_exact"));
    out
}
/// Uniform-electron-gas kinetic-energy prefactor for a single spin channel,
/// `(3/10) (6π²)^(2/3)`. A channel of density `n_σ` has uniform-gas
/// `τ_σ = K_FACTOR_C · n_σ^(5/3)`.
const K_FACTOR_C: f64 = 4.557_799_872_345_596;

/// Kinetic-energy densities to probe for an unpolarized point with density `n` and `sigma`.
///
/// `tw = sigma / (8 n)` is the von Weizsäcker value, and `tunif` is the unpolarized
/// uniform-gas value. Therefore `tw + c · tunif` places the iso-orbital indicator
/// `α = (τ − τ_W) / τ_unif` at `c`. The density is floored at 1e-300 to avoid dividing
/// by zero.
fn unpol_taus(n: f64, sigma: f64) -> Vec<(f64, &'static str)> {
    let nn = n.max(1e-300);
    let tw = sigma / (8.0 * nn);
    // An unpolarized density n puts n/2 in each spin channel, so the uniform-gas value is
    // 2 · K_FACTOR_C · (n/2)^(5/3) = (3/10)(3π²)^(2/3) n^(5/3). Applying the per-channel
    // constant to the total density would overshoot by 2^(2/3), which puts the "alpha_one"
    // probe at α ≈ 1.59 instead of α = 1.
    let tunif = 2.0 * K_FACTOR_C * (0.5 * nn).powf(5.0 / 3.0);
    vec![
        (0.0, "tau_zero"),
        (tw, "tau_vw_edge"),
        (tw + tunif, "alpha_one"),
        (tw + 5.0 * tunif, "alpha_large"),
        (1e8, "tau_huge"),
    ]
}
fn pol_taus(na: f64, nb: f64) -> Vec<(f64, f64, &'static str)> {
let ka = K_FACTOR_C * na.max(1e-300).powf(5.0 / 3.0);
let kb = K_FACTOR_C * nb.max(1e-300).powf(5.0 / 3.0);
vec![
(0.0, 0.0, "tau_zero"),
(ka, kb, "alpha_one"),
(5.0 * ka, 5.0 * kb, "alpha_large"),
(ka, 0.0, "tau_minority_floor"),
(1e8, 1e8, "tau_huge"),
]
}
/// Draws one value log-uniformly between `lo` and `hi`. It consumes exactly one `f64`
/// from `rng`.
fn loguniform(rng: &mut StdRng, lo: f64, hi: f64) -> f64 {
    let log_lo = lo.ln();
    let log_span = hi.ln() - log_lo;
    let u: f64 = rng.gen();
    (log_lo + log_span * u).exp()
}
/// Draws a random `sigma`. It is exactly zero with probability 0.15; otherwise it is
/// log-uniform between 1e-20 and 1e8.
fn pick_sigma(rng: &mut StdRng) -> f64 {
    let coin: f64 = rng.gen();
    let is_zero = coin < 0.15;
    if !is_zero {
        return loguniform(rng, 1e-20, 1e8);
    }
    0.0
}
fn random_unpol(rng: &mut StdRng, count: usize) -> Vec<(f64, f64)> {
(0..count)
.map(|_| {
let n = loguniform(rng, 1e-16, 1e4);
let s = if rng.gen::<f64>() < 0.15 {
0.0
} else {
loguniform(rng, 1e-20, 1e8)
};
(n, s)
})
.collect()
}
/// Draws `count` random polarized points `(na, nb, saa, sab, sbb)`.
///
/// Densities are log-uniform between 1e-16 and 1e4, and `saa`/`sbb` come from `pick_sigma`.
/// `sab` is uniform in `±scale · (saa + sbb)/2`, where `scale` is 50 with probability 0.1
/// and 3 otherwise. This deliberately reaches beyond the physical range of `sab`.
fn random_pol(rng: &mut StdRng, count: usize) -> Vec<(f64, f64, f64, f64, f64)> {
    let mut points = Vec::with_capacity(count);
    for _ in 0..count {
        let na = loguniform(rng, 1e-16, 1e4);
        let nb = loguniform(rng, 1e-16, 1e4);
        let saa = pick_sigma(rng);
        let sbb = pick_sigma(rng);
        let mean = 0.5 * (saa + sbb);
        let wide: bool = rng.gen::<f64>() < 0.1;
        let scale = if wide { 50.0 } else { 3.0 };
        let u: f64 = rng.gen();
        let sab = (2.0 * u - 1.0) * scale * mean;
        points.push((na, nb, saa, sab, sbb));
    }
    points
}
#[test]
fn fuzz_all_functionals_finite() {
let mut census = Census::new();
let unpol_dens = unpol_densities();
let pol_pairs = pol_density_pairs();
let mut rng = StdRng::seed_from_u64(SEED);
let unpol_rand = random_unpol(&mut rng, RANDOM_UNPOL);
let pol_rand = random_pol(&mut rng, RANDOM_POL);
for &id in FunctionalId::ALL {
for &spin in &[Spin::Unpolarized, Spin::Polarized] {
let f = Functional::new(id, spin).expect("v0.1 functional must build");
let needs_sigma = f.info().needs_sigma;
let needs_tau = f.info().needs_tau;
match spin {
Spin::Unpolarized => {
for &(n, dreg) in &unpol_dens {
if needs_sigma {
for (s, sreg) in unpol_sigmas(n) {
if needs_tau {
for (t, treg) in unpol_taus(n, s) {
census.check(
&f,
id,
spin,
&[n],
Some(&[s]),
Some(&[t]),
&[dreg, sreg, treg],
);
}
} else {
census.check(
&f,
id,
spin,
&[n],
Some(&[s]),
None,
&[dreg, sreg],
);
}
}
} else {
census.check(&f, id, spin, &[n], None, None, &[dreg]);
}
}
for &(n, s) in &unpol_rand {
if needs_sigma {
if needs_tau {
for (t, _) in unpol_taus(n, s).into_iter().take(3) {
census.check(
&f,
id,
spin,
&[n],
Some(&[s]),
Some(&[t]),
&["random"],
);
}
} else {
census.check(&f, id, spin, &[n], Some(&[s]), None, &["random"]);
}
} else {
census.check(&f, id, spin, &[n], None, None, &["random"]);
}
}
}
Spin::Polarized => {
for &(na, nb, dreg) in &pol_pairs {
if needs_sigma {
for (saa, sab, sbb, sreg) in pol_sigma_triples(na, nb) {
if needs_tau {
for (ta, tb, treg) in pol_taus(na, nb) {
census.check(
&f,
id,
spin,
&[na, nb],
Some(&[saa, sab, sbb]),
Some(&[ta, tb]),
&[dreg, sreg, treg],
);
}
} else {
census.check(
&f,
id,
spin,
&[na, nb],
Some(&[saa, sab, sbb]),
None,
&[dreg, sreg],
);
}
}
} else {
census.check(&f, id, spin, &[na, nb], None, None, &[dreg]);
}
}
for &(na, nb, saa, sab, sbb) in &pol_rand {
if needs_sigma {
if needs_tau {
for (ta, tb, _) in pol_taus(na, nb).into_iter().take(3) {
census.check(
&f,
id,
spin,
&[na, nb],
Some(&[saa, sab, sbb]),
Some(&[ta, tb]),
&["random"],
);
}
} else {
census.check(
&f,
id,
spin,
&[na, nb],
Some(&[saa, sab, sbb]),
None,
&["random"],
);
}
} else {
census.check(&f, id, spin, &[na, nb], None, None, &["random"]);
}
}
}
_ => unreachable!("unknown Spin variant"),
}
}
}
println!("\n=== xcx fuzz coverage (seed = 0x{SEED:016X}) ===");
println!("functionals: {} × spins: 2", FunctionalId::ALL.len());
println!(
"total evaluations: {} finiteness checks: {}",
census.evals, census.components
);
println!("evaluations crediting each region (summed over all functionals/spins):");
for (region, n) in &census.by_region {
let fxc = census.fxc_by_region.get(region).copied().unwrap_or(0);
println!(" {region:<28} vxc {n:>9} fxc {fxc:>9}");
}
println!(
"fxc: {} evaluations checked, {} finiteness checks, {} skipped (out of physical domain)",
census.fxc_evals, census.fxc_components, census.fxc_skipped
);
println!("=== all outputs finite ===\n");
}
/// Checks that evaluating all points in one polarized batch is bit-for-bit identical to
/// evaluating each point on its own, for every functional.
#[test]
fn batch_matches_single_point() {
    // Per-point inputs `(na, nb, saa, sab, sbb)`. They include:
    // - a fully polarized point;
    // - `sab = -sqrt(saa*sbb)`;
    // - near-threshold densities;
    // - a huge sigma.
    let pol: &[(f64, f64, f64, f64, f64)] = &[
        (0.6, 0.3, 0.1, 0.05, 0.08),
        (1.0, 0.0, 0.2, 0.0, 0.0),
        (0.5, 0.5, 0.2, -0.2, 0.2),
        (1.0, 1e-10, 0.1, 0.0, 1e-8),
        (1e-12, 1e-12, 0.0, 0.0, 0.0),
        (2.0, 1.0, 1e6, 0.0, 1e6),
    ];
    // `(tau_a, tau_b)` for the point at the same index in `pol`.
    let taus: &[(f64, f64)] = &[
        (0.4, 0.25),
        (0.5, 1e-20),
        (0.3, 0.3),
        (0.6, 1e-12),
        (1e-20, 1e-20),
        (5.0, 3.0),
    ];
    // `zip` below would silently drop unmatched points while `np` still counts them.
    assert_eq!(
        pol.len(),
        taus.len(),
        "pol and taus must describe the same points"
    );
    let np = pol.len();

    // The packed batch arrays depend only on the tables above, so build them once.
    let mut rho = Vec::with_capacity(2 * np);
    let mut sig = Vec::with_capacity(3 * np);
    let mut tau = Vec::with_capacity(2 * np);
    for (&(na, nb, saa, sab, sbb), &(ta, tb)) in pol.iter().zip(taus) {
        rho.extend_from_slice(&[na, nb]);
        sig.extend_from_slice(&[saa, sab, sbb]);
        tau.extend_from_slice(&[ta, tb]);
    }

    for &id in FunctionalId::ALL {
        let f = Functional::new(id, Spin::Polarized)
            .unwrap_or_else(|e| panic!("functional id {} failed to build: {e:?}", id.as_u32()));
        let name = f.info().name.to_string();
        let needs_sigma = f.info().needs_sigma;
        let needs_tau = f.info().needs_tau;

        let input = match (needs_sigma, needs_tau) {
            (true, true) => XcInput::gga(&rho, &sig).with_tau(&tau),
            (true, false) => XcInput::gga(&rho, &sig),
            _ => XcInput::lda(&rho),
        };
        let batch = f
            .eval(np, &input)
            .unwrap_or_else(|e| panic!("{name}: batched eval failed: {e:?}"));
        for v in batch
            .exc
            .iter()
            .chain(&batch.vrho)
            .chain(&batch.vsigma)
            .chain(&batch.vtau)
            .chain(&batch.vlapl)
        {
            assert!(v.is_finite(), "{name}: non-finite batch output {v}");
        }

        for (i, (&(na, nb, saa, sab, sbb), &(ta, tb))) in pol.iter().zip(taus).enumerate() {
            let rho1 = [na, nb];
            let sig1 = [saa, sab, sbb];
            let tau1 = [ta, tb];
            let single_input = match (needs_sigma, needs_tau) {
                (true, true) => XcInput::gga(&rho1, &sig1).with_tau(&tau1),
                (true, false) => XcInput::gga(&rho1, &sig1),
                _ => XcInput::lda(&rho1),
            };
            let single = f
                .eval(1, &single_input)
                .unwrap_or_else(|e| panic!("{name}: single-point eval {i} failed: {e:?}"));

            assert_eq!(
                batch.exc[i].to_bits(),
                single.exc[0].to_bits(),
                "{name} exc[{i}]"
            );
            assert_point_bits(&name, "vrho", i, 2, &batch.vrho, &single.vrho);
            if needs_sigma {
                assert_point_bits(&name, "vsigma", i, 3, &batch.vsigma, &single.vsigma);
            }
            if needs_tau {
                assert_point_bits(&name, "vtau", i, 2, &batch.vtau, &single.vtau);
            }
        }
    }
}

/// Asserts that point `i` of a batched output array is bit-identical to the single-point
/// output, for a quantity with `width` components per point.
fn assert_point_bits(name: &str, what: &str, i: usize, width: usize, batch: &[f64], one: &[f64]) {
    for c in 0..width {
        assert_eq!(
            batch[width * i + c].to_bits(),
            one[c].to_bits(),
            "{name} {what}[{i}][{c}]"
        );
    }
}