use integral::{Basis, BraPairFill, Engine, EriBuilder, Shell};
/// Test fixture: a four-shell basis built entirely with `Shell::new`.
///
/// The first argument to `Shell::new` runs 0, 1, 2, 3 across the shells
/// (presumably the angular momentum — confirm against `Shell::new`), each
/// shell sits on its own center, and only the first shell has two
/// exponent/coefficient entries.
fn mixed_basis() -> Basis {
    let shells = vec![
        Shell::new(0, [0.0, 0.0, 0.0], vec![1.4, 0.5], vec![0.6, 0.5]).unwrap(),
        Shell::new(1, [0.6, -0.3, 0.2], vec![0.9], vec![1.0]).unwrap(),
        Shell::new(2, [-0.4, 0.7, -0.1], vec![1.1], vec![1.0]).unwrap(),
        Shell::new(3, [0.2, 0.5, 0.8], vec![0.7], vec![1.0]).unwrap(),
    ];
    Basis::new(shells)
}
/// Test fixture: a four-shell basis mixing `Shell::new` and
/// `Shell::new_spherical` constructors.
///
/// The two middle shells are created with `Shell::new_spherical`; the last
/// two shells share the same center.
fn spherical_basis() -> Basis {
    let shells = vec![
        Shell::new(0, [0.0, 0.0, 0.0], vec![1.2, 0.4], vec![0.5, 0.6]).unwrap(),
        Shell::new_spherical(1, [0.5, 0.1, -0.2], vec![0.8], vec![1.0]).unwrap(),
        Shell::new_spherical(2, [-0.3, 0.4, 0.6], vec![1.0], vec![1.0]).unwrap(),
        Shell::new(0, [-0.3, 0.4, 0.6], vec![0.7], vec![1.0]).unwrap(),
    ];
    Basis::new(shells)
}
/// Returns a lookup table with one entry per basis function: entry `k` is the
/// index (within `basis.shells()`) of the shell that function `k` belongs to.
///
/// Each shell contributes `shell.n_func()` consecutive entries, in shell order.
fn ao_to_shell(basis: &Basis) -> Vec<usize> {
    basis
        .shells()
        .iter()
        .enumerate()
        .flat_map(|(shell_idx, shell)| (0..shell.n_func()).map(move |_| shell_idx))
        .collect()
}
/// Orders an index pair so the larger index comes first.
fn canon(a: usize, b: usize) -> (usize, usize) {
    (a.max(b), a.min(b))
}
/// Returns `true` when pair `p` is lexicographically greater than or equal to
/// pair `q` (first components compared first, second components break ties).
fn pair_ge(p: (usize, usize), q: (usize, usize)) -> bool {
    // Tuple ordering in Rust is lexicographic, which is exactly this relation.
    p >= q
}
/// Asserts that `candidate` agrees with `reference` and returns how many
/// elements outside the bit-identical subset differ at the bit level.
///
/// Both slices are read as flat arrays of length `basis.nao().pow(4)`,
/// indexed as `((mu * nao + nu) * nao + la) * nao + sg`.
///
/// For every element `(mu nu | la sg)`:
/// * the difference `|ref - cand|` must not be NaN;
/// * if the canonical bra shell pair is lexicographically `>=` the canonical
///   ket shell pair, the two values must be bit-identical;
/// * otherwise a bit-level difference is tolerated and counted.
///
/// After the sweep, two tolerances are enforced over all elements:
/// * the relative difference of every element with `|ref| >= 1e-3 * peak`
///   must be below `1e-11`, where `peak` is the largest `|ref|`;
/// * the largest absolute difference must be below
///   `1e-11 * max(peak, 1.0) + 1e-12`.
///
/// # Panics
///
/// Panics if the slice lengths disagree with each other or with `nao⁴`, or
/// if any of the checks above fails.
fn assert_matches_reference(basis: &Basis, candidate: &[f64], reference: &[f64]) -> usize {
    assert_eq!(candidate.len(), reference.len());
    let nao = basis.nao();
    assert_eq!(reference.len(), nao.pow(4));
    let shell_of = ao_to_shell(basis);
    // Largest reference magnitude; relative error is only judged on elements
    // within three orders of magnitude of it.
    let peak = reference.iter().fold(0.0_f64, |m, &x| m.max(x.abs()));
    let floor = 1e-3 * peak;
    let mut worst_sig = 0.0_f64;
    let mut worst_abs = 0.0_f64;
    let mut bit_mismatches = 0usize;
    for mu in 0..nao {
        for nu in 0..nao {
            // The bra shell pair only depends on (mu, nu); compute it once.
            let bp = canon(shell_of[mu], shell_of[nu]);
            for la in 0..nao {
                for sg in 0..nao {
                    let idx = ((mu * nao + nu) * nao + la) * nao + sg;
                    let (r, c) = (reference[idx], candidate[idx]);
                    let dv = (r - c).abs();
                    // `f64::max` discards a NaN operand, so without this check
                    // a NaN difference would never reach `worst_abs` or
                    // `worst_sig` and could slip past both tolerance asserts.
                    assert!(
                        !dv.is_nan(),
                        "element ({mu}{nu}|{la}{sg}) has a NaN difference: \
                         ref={r:e} cand={c:e}"
                    );
                    worst_abs = worst_abs.max(dv);
                    if r.abs() >= floor {
                        worst_sig = worst_sig.max(dv / r.abs());
                    }
                    let kp = canon(shell_of[la], shell_of[sg]);
                    if pair_ge(bp, kp) {
                        assert_eq!(
                            r.to_bits(),
                            c.to_bits(),
                            "non-swapped element ({mu}{nu}|{la}{sg}) must be bit-identical: \
                             ref={r:e} cand={c:e}"
                        );
                    } else if r.to_bits() != c.to_bits() {
                        bit_mismatches += 1;
                    }
                }
            }
        }
    }
    assert!(
        worst_sig < 1e-11,
        "worst significant-element relative diff {worst_sig:e} exceeds 1e-11"
    );
    assert!(
        worst_abs < 1e-11 * peak.max(1.0) + 1e-12,
        "worst absolute diff {worst_abs:e} exceeds floor (peak {peak:e})"
    );
    bit_mismatches
}
/// `EriBuilder::new(&basis).build()` must agree with `basis.eri()` on the
/// mixed basis, and at least one element outside the bit-identical subset
/// must actually differ at the bit level.
#[test]
fn build_matches_eri_auto() {
    let basis = mixed_basis();
    let expected = basis.eri();
    let actual = EriBuilder::new(&basis).build();
    let swapped_bit_diffs = assert_matches_reference(&basis, &actual, &expected);
    // A zero count would mean the tolerant branch of the comparison was never
    // meaningfully exercised by this fixture.
    assert!(
        swapped_bit_diffs > 0,
        "expected the bra↔ket-swapped subset to differ at the bit level"
    );
}
/// For each explicitly selected engine, `EriBuilder::with_engine(..).build()`
/// must agree with `basis.eri_with(engine)` on the mixed basis.
#[test]
fn build_matches_eri_forced_engines() {
    let basis = mixed_basis();
    let engines = [Engine::OsHgp, Engine::Rys];
    for engine in engines {
        let expected = basis.eri_with(engine);
        let actual = EriBuilder::with_engine(&basis, engine).build();
        // The bit-mismatch count is not checked here, only the assertions
        // performed inside the comparison.
        assert_matches_reference(&basis, &actual, &expected);
    }
}
/// `EriBuilder::new(&basis).build()` must agree with `basis.eri()` on the
/// basis that includes `Shell::new_spherical` shells.
#[test]
fn build_matches_eri_spherical() {
    let basis = spherical_basis();
    let expected = basis.eri();
    let actual = EriBuilder::new(&basis).build();
    assert_matches_reference(&basis, &actual, &expected);
}
/// Filling the tasks returned by `partition` one by one must reproduce the
/// output of `build()` exactly, whether the tasks are visited front-to-back
/// or back-to-front.
#[test]
fn partition_fill_equals_build_any_order() {
    let basis = mixed_basis();
    let builder = EriBuilder::new(&basis);
    let serial = builder.build();

    // Front-to-back. The task list is a temporary, so its borrow of the
    // output buffer ends with the statement.
    let mut fwd = vec![0.0; builder.output_len()];
    builder.partition(&mut fwd).iter_mut().for_each(|task| {
        builder.fill(task);
    });
    assert_eq!(fwd, serial, "forward partition+fill differs from build()");

    // Back-to-front.
    let mut rev = vec![0.0; builder.output_len()];
    builder
        .partition(&mut rev)
        .iter_mut()
        .rev()
        .for_each(|task| {
            builder.fill(task);
        });
    assert_eq!(
        rev, serial,
        "reverse-order partition+fill differs from build()"
    );
}
/// Filling every task over a buffer pre-loaded with a sentinel value must
/// leave it equal to the `build()` output, i.e. no sentinel survives.
#[test]
fn partition_into_prefilled_buffer_overwrites_cleanly() {
    let basis = mixed_basis();
    let builder = EriBuilder::new(&basis);
    let serial = builder.build();
    // Sentinel contents: any element the fill skips would keep this value.
    let mut buf = vec![1234.5; builder.output_len()];
    {
        let mut tasks = builder.partition(&mut buf);
        tasks.iter_mut().for_each(|task| {
            builder.fill(task);
        });
    }
    assert_eq!(buf, serial, "build did not overwrite every element");
}
/// `bra_pairs()` must list every shell pair `(i, j)` with `j <= i` in
/// row-major order, and `partition` must return one task per pair, in the
/// same order, with `task.bra()` equal to the pair.
#[test]
fn bra_pairs_are_canonical_and_align_with_tasks() {
    let basis = mixed_basis();
    let builder = EriBuilder::new(&basis);
    let nsh = basis.shells().len();

    // Lower-triangular enumeration: (0,0), (1,0), (1,1), (2,0), ...
    let expected: Vec<_> = (0..nsh)
        .flat_map(|i| (0..=i).map(move |j| (i, j)))
        .collect();
    assert_eq!(builder.bra_pairs(), expected.as_slice());
    assert_eq!(builder.bra_pairs().len(), nsh * (nsh + 1) / 2);

    let mut out = vec![0.0; builder.output_len()];
    let tasks = builder.partition(&mut out);
    assert_eq!(tasks.len(), builder.bra_pairs().len());
    for (task, &bra) in tasks.iter().zip(builder.bra_pairs()) {
        assert_eq!(task.bra(), bra);
        assert!(bra.0 >= bra.1, "bra-pair {bra:?} is not canonical");
    }
}
/// Compile-time check that `BraPairFill` is `Send` and `EriBuilder` is both
/// `Sync` and `Send`; the test body does no runtime work.
#[test]
fn types_are_thread_safe_for_external_drivers() {
    // These helpers exist only so the trait bound is checked at the call site.
    fn require_send<T>()
    where
        T: Send,
    {
    }
    fn require_sync<T>()
    where
        T: Sync,
    {
    }
    require_send::<BraPairFill<'_>>();
    require_sync::<EriBuilder<'_>>();
    require_send::<EriBuilder<'_>>();
}
/// `output_len()` must equal the fourth power of `basis.nao()`.
#[test]
fn output_len_is_nao_pow4() {
    let basis = mixed_basis();
    let builder = EriBuilder::new(&basis);
    let reported = builder.output_len();
    let expected = basis.nao().pow(4);
    assert_eq!(reported, expected);
}
/// `partition` must panic, with a message containing "nao⁴", when handed a
/// buffer one element shorter than `output_len()`.
#[test]
#[should_panic(expected = "nao⁴")]
fn partition_rejects_wrong_buffer_size() {
    let basis = mixed_basis();
    let builder = EriBuilder::new(&basis);
    let short_len = builder.output_len() - 1;
    let mut short_buf = vec![0.0; short_len];
    let _ = builder.partition(&mut short_buf);
}