#![cfg(feature = "api-v7_1")]
use super::example_densities;
use libxc::prelude::{libxc_enum_items::*, *};
use rayon::prelude::*;
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicUsize, Ordering};
/// Key of one reference entry: the `(category, xc_name, species)` strings, in that order.
type RefKey = (String, String, String);
lazy_static::lazy_static! {
    /// Flattened contents of `tests/regression/reference.toml`.
    ///
    /// The file is parsed as four nested tables (category -> functional name -> species ->
    /// output key -> values); every innermost table becomes one
    /// `((category, xc_name, species), values)` entry.
    ///
    /// Panics on first access if the file cannot be read or parsed.
    static ref REF: Vec<(RefKey, HashMap<String, Vec<f64>>)> = {
        type RefData = HashMap<String, HashMap<String, HashMap<String, HashMap<String, Vec<f64>>>>>;
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/regression/reference.toml");
        let content = std::fs::read_to_string(path).expect("Failed to read reference.toml");
        let data: RefData = toml::from_str(&content).expect("Failed to parse reference.toml");
        let mut m = Vec::new();
        for (category, xc_map) in data {
            for (xc_name, species_map) in xc_map {
                // `species` is owned and used exactly once, so it is moved into the key;
                // only the two outer components still need a clone per entry.
                m.extend(
                    species_map
                        .into_iter()
                        .map(|(species, values)| ((category.clone(), xc_name.clone(), species), values)),
                );
            }
        }
        m
    };
    /// Case identifiers listed in `tests/regression/skipped_tests`, one per line.
    ///
    /// Lines are trimmed before use so that stray leading/trailing whitespace cannot keep an
    /// entry from matching; blank and whitespace-only lines are dropped.
    ///
    /// Panics on first access if the file cannot be read.
    static ref SKIPPED_CASES: HashSet<String> = {
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/regression/skipped_tests");
        let content = std::fs::read_to_string(path).expect("Failed to read skipped_tests");
        content.lines().map(str::trim).filter(|l| !l.is_empty()).map(str::to_owned).collect()
    };
}
/// Element-wise closeness check of `a` against `b`.
///
/// Returns `true` only when both slices have the same length and every pair satisfies
/// `|a_i - b_i| <= atol + rtol * |b_i|`. The relative part scales with `b` only, so the
/// check is not symmetric in its arguments. A pair involving NaN never satisfies the
/// inequality and therefore makes the result `false`.
fn allclose(a: &[f64], b: &[f64], rtol: f64, atol: f64) -> bool {
    if a.len() != b.len() {
        return false;
    }
    for (lhs, rhs) in a.iter().zip(b) {
        let tolerance = atol + rtol * rhs.abs();
        let within = (lhs - rhs).abs() <= tolerance;
        if !within {
            return false;
        }
    }
    true
}
/// Scale-aware error between two slices.
///
/// For equal-length inputs this is the largest absolute element-wise difference divided by
/// `1 + m`, where `m` is the largest absolute value found in either slice. Two empty slices
/// give `0.0`.
///
/// Slices of different length are not comparable element by element, so `f64::INFINITY` is
/// returned instead of silently comparing only the common prefix.
///
/// NaN entries do not contribute to either maximum, because `f64::max` returns the non-NaN
/// operand.
fn get_error(a: &[f64], b: &[f64]) -> f64 {
    if a.len() != b.len() {
        return f64::INFINITY;
    }
    let max_abs_diff = a.iter().zip(b).map(|(x, y)| (x - y).abs()).fold(0.0, f64::max);
    let max_abs = a.iter().chain(b).map(|v| v.abs()).fold(0.0, f64::max);
    max_abs_diff / (1.0 + max_abs)
}
/// Runs a single regression case and compares the computed output with `reference`.
///
/// # Arguments
/// * `category`, `xc_name` - joined as `"{category}_{xc_name}"` to form the functional
///   identifier passed to `LibXCFunctional::from_identifier_f`.
/// * `species` - name of the example density; also selects the spin treatment.
/// * `reference` - expected values per output key.
///
/// # Errors
/// * A message starting with `"SKIP"` when the case is listed in `SKIPPED_CASES` or the
///   functional identifier is not accepted.
/// * Any other message when the computation fails, an expected key is missing from the
///   output, the output length differs from the reference, or the values disagree.
fn test_regression_entry(
    category: &str,
    xc_name: &str,
    species: &str,
    reference: &HashMap<String, Vec<f64>>,
) -> Result<(), String> {
    let xc_id = format!("{}.{}.{}", category, xc_name, species);
    if SKIPPED_CASES.contains(xc_id.as_str()) {
        return Err("SKIP".to_string());
    }

    // "BrOH" and every species whose name contains "restr" are evaluated unpolarized;
    // everything else is evaluated polarized.
    let spin = if species == "BrOH" || species.contains("restr") { Unpolarized } else { Polarized };
    let input = example_densities::test_data(species.to_string(), spin);
    let input_ref = input.iter().map(|(k, v)| (k.clone(), v.as_slice())).collect();

    let xc_identifier = format!("{}_{}", category, xc_name);
    let xc = LibXCFunctional::from_identifier_f(&xc_identifier, spin)
        .map_err(|_| format!("SKIP (unknown functional {})", xc_identifier))?;
    let (out_buffer, out_layout) = xc
        .compute_xc(&input_ref, 1)
        .map_err(|e| format!("FAIL (failed to compute XC for {}), errmsg: {e}", xc_identifier))?;

    for (key, ref_values) in reference {
        let ref_out = match out_layout.get(key) {
            Some(layout) => &out_buffer[layout],
            None => return Err(format!("key {} not found in output for {}", key, xc_identifier)),
        };

        // A length mismatch must always fail: the error metric below only looks at paired
        // elements and could otherwise accept an output of the wrong size.
        if ref_out.len() != ref_values.len() {
            return Err(format!(
                "length mismatch for {}.{} key={}, len={} vs {}",
                xc_identifier,
                species,
                key,
                ref_out.len(),
                ref_values.len()
            ));
        }

        // The "zk" key is held to tighter tolerances than the other keys.
        let (rtol, atol) = if key == "zk" { (5e-8, 1e-10) } else { (5e-5, 1e-7) };

        // Two-tier acceptance: the element-wise check passes the key outright; otherwise
        // the key still passes if the global scaled error stays within `rtol`.
        if allclose(ref_out, ref_values, rtol, atol) {
            continue;
        }
        let error_metric = get_error(ref_out, ref_values);
        if error_metric > rtol {
            return Err(format!(
                "mismatch for {}.{} key={}, len={} vs {}, error_metric={:.2e}",
                xc_identifier,
                species,
                key,
                ref_out.len(),
                ref_values.len(),
                error_metric
            ));
        }
    }
    Ok(())
}
/// Runs every entry of `REF` in parallel through `test_regression_entry`, tallies passed,
/// skipped and failed cases, prints one line per skipped or failed case plus a summary to
/// stderr, and asserts that nothing failed.
///
/// An error message starting with `"SKIP"` counts as skipped; any other error counts as a
/// failure.
#[test]
#[ignore = r#"
probably problems in libxc:
- no exc capibilities: mgga_x_bj06, gga_x_lbm, gga_x_lb, mgga_x_rpp09, mgga_x_tb09, lda_xc_tih
- numerical problem: gga_x.pbepow.Li_restr.vsigma, error_metric=7.55e-4
"#]
fn test_regression() {
    let passed = AtomicUsize::new(0);
    let skipped = AtomicUsize::new(0);
    let failed = AtomicUsize::new(0);

    REF.par_iter().for_each(|(key, reference)| {
        let (category, xc_name, species) = key;
        let xc_id = format!("{}.{}.{}", category, xc_name, species);

        // Successful cases only bump the counter; everything else carries a message.
        let msg = match test_regression_entry(category, xc_name, species, reference) {
            Ok(()) => {
                passed.fetch_add(1, Ordering::Relaxed);
                return;
            },
            Err(msg) => msg,
        };

        if msg.starts_with("SKIP") {
            skipped.fetch_add(1, Ordering::Relaxed);
            eprintln!("SKIP: {} - {}", xc_id, msg);
        } else {
            failed.fetch_add(1, Ordering::Relaxed);
            eprintln!("FAIL: {} - {}", xc_id, msg);
        }
    });

    let passes = passed.load(Ordering::Relaxed);
    let skips = skipped.load(Ordering::Relaxed);
    let fails = failed.load(Ordering::Relaxed);
    eprintln!("\n=== Results: {} passed, {} skipped, {} failed ===", passes, skips, fails);
    assert_eq!(fails, 0, "{} regression test(s) failed", fails);
}