use nalgebra::DMatrix;
use super::{assemble_periodic_operators, KMesh};
use crate::gpu::context::GpuContext;
/// Minimum number of k-points at which offloading periodic assembly to the GPU is
/// considered; smaller meshes are always assembled on the CPU.
const GPU_PERIODIC_THRESHOLD: usize = 64;

/// GPU-aware entry point for assembling the k-space operators `H(k)` and `S(k)`.
///
/// Accepts the same real-space `(matrix, [i32; 3] offset)` pairs and k-mesh as
/// [`assemble_periodic_operators`], and returns that function's `(H(k), S(k))` vectors.
///
/// A workload counts as GPU-eligible only when `ctx` reports an available device and
/// the mesh has at least `GPU_PERIODIC_THRESHOLD` points. No device kernel is wired up
/// here yet, so eligible and ineligible workloads are currently both delegated to the
/// CPU assembler and produce identical results.
pub fn assemble_periodic_operators_gpu(
    ctx: &GpuContext,
    h_real_space: &[(&DMatrix<f64>, [i32; 3])],
    s_real_space: &[(&DMatrix<f64>, [i32; 3])],
    kmesh: &KMesh,
) -> (Vec<DMatrix<f64>>, Vec<DMatrix<f64>>) {
    // Same short-circuit order as a `!available || too_small` guard: the mesh size is
    // only inspected when a device is available.
    let gpu_eligible =
        ctx.capabilities.gpu_available && kmesh.points.len() >= GPU_PERIODIC_THRESHOLD;

    if !gpu_eligible {
        // No device, or the mesh is below the offload threshold: CPU is the intended path.
        return assemble_periodic_operators(h_real_space, s_real_space, kmesh);
    }

    // TODO(gpu): dispatch a device kernel here. Until one exists, eligible workloads
    // take the CPU path as well.
    assemble_periodic_operators(h_real_space, s_real_space, kmesh)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alpha::periodic_linear::{KMeshCentering, KMeshConfig};

    /// Maximum allowed Frobenius-norm difference between GPU and CPU results per k-point.
    const PARITY_TOL: f64 = 1e-12;

    /// Asserts that `gpu` and `cpu` hold the same number of matrices and that each
    /// corresponding pair agrees to within `PARITY_TOL`. `label` names the operator
    /// (e.g. "H(k)") in failure messages.
    fn assert_parity(label: &str, gpu: &[DMatrix<f64>], cpu: &[DMatrix<f64>]) {
        assert_eq!(gpu.len(), cpu.len(), "{} k-point count mismatch", label);
        for (k, (g, c)) in gpu.iter().zip(cpu).enumerate() {
            let diff = (g - c).norm();
            assert!(
                diff < PARITY_TOL,
                "GPU/CPU {} parity violated at k-point {}: |diff| = {:e}",
                label,
                k,
                diff
            );
        }
    }

    #[test]
    fn gpu_assembly_falls_back_to_cpu() {
        let ctx = GpuContext::cpu_fallback();
        let h = DMatrix::<f64>::identity(4, 4);
        let h_pairs: Vec<(&DMatrix<f64>, [i32; 3])> = vec![(&h, [0, 0, 0])];
        let s_pairs: Vec<(&DMatrix<f64>, [i32; 3])> = vec![(&h, [0, 0, 0])];
        let mesh = crate::alpha::periodic_linear::monkhorst_pack_mesh(&KMeshConfig {
            grid: [3, 3, 3],
            centering: KMeshCentering::MonkhorstPack,
        })
        .unwrap();

        let (hk_gpu, sk_gpu) = assemble_periodic_operators_gpu(&ctx, &h_pairs, &s_pairs, &mesh);
        let (hk_cpu, sk_cpu) = assemble_periodic_operators(&h_pairs, &s_pairs, &mesh);

        // Guard against a vacuous pass: the zip-based comparison checks nothing on empty output.
        assert!(!hk_cpu.is_empty(), "k-mesh produced no points; parity check would be vacuous");

        // Previously only H(k) was compared elementwise; S(k) was checked by length only.
        assert_parity("H(k)", &hk_gpu, &hk_cpu);
        assert_parity("S(k)", &sk_gpu, &sk_cpu);
    }
}