use super::memory_budget::GpuMemoryBudget;
/// Coarse classification of how much a kernel is expected to gain from
/// running on the GPU, from the largest expected speedup (`Tier1`) down to
/// kernels where the CPU is preferred (`Tier4`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpuTier {
    /// Rendered as "Tier 1 (massive speedup)".
    Tier1,
    /// Rendered as "Tier 2 (significant speedup)".
    Tier2,
    /// Rendered as "Tier 3 (moderate speedup)".
    Tier3,
    /// Rendered as "Tier 4 (CPU preferred)".
    Tier4,
}

impl GpuTier {
    /// Returns `true` for [`GpuTier::Tier1`] and [`GpuTier::Tier2`], and
    /// `false` for the remaining tiers.
    pub fn gpu_recommended(&self) -> bool {
        // Exhaustive on purpose: adding a tier forces a decision here.
        match self {
            Self::Tier1 | Self::Tier2 => true,
            Self::Tier3 | Self::Tier4 => false,
        }
    }
}

impl std::fmt::Display for GpuTier {
    /// Writes the human-readable tier label, e.g. `Tier 1 (massive speedup)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Tier1 => "Tier 1 (massive speedup)",
            Self::Tier2 => "Tier 2 (significant speedup)",
            Self::Tier3 => "Tier 3 (moderate speedup)",
            Self::Tier4 => "Tier 4 (CPU preferred)",
        };
        f.write_str(label)
    }
}
/// Static metadata describing one compute shader: its name, tier, workgroup
/// shape, entry point and binding counts.
#[derive(Debug, Clone)]
pub struct ShaderDescriptor {
    /// Shader name (e.g. `"orbital_grid"`).
    pub name: &'static str,
    /// Expected GPU benefit tier for this shader.
    pub tier: GpuTier,
    /// Workgroup dimensions; the product of the three components is the
    /// number of invocations per workgroup.
    pub workgroup_size: [u32; 3],
    /// Name of the shader entry point.
    pub entry_point: &'static str,
    /// Number of storage bindings.
    pub storage_bindings: u32,
    /// Number of uniform bindings.
    pub uniform_bindings: u32,
    /// Human-readable one-line summary.
    pub description: &'static str,
}

impl ShaderDescriptor {
    /// Extra bindings tolerated on top of `max_storage_buffers_per_stage`
    /// when validating the combined binding count (presumably headroom for
    /// uniform buffers — TODO confirm).
    const EXTRA_BINDING_ALLOWANCE: u64 = 4;

    /// Returns the combined number of storage and uniform bindings.
    pub fn total_bindings(&self) -> u32 {
        self.storage_bindings + self.uniform_bindings
    }

    /// Validates this descriptor against the limits in `budget`.
    ///
    /// Two checks are performed, in order:
    /// 1. storage + uniform bindings must not exceed
    ///    `max_storage_buffers_per_stage + 4`;
    /// 2. the product of the workgroup dimensions must not exceed
    ///    `max_invocations_per_workgroup`.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message naming the shader, the requested
    /// amount and the limit that was actually enforced.
    pub fn check_against_budget(&self, budget: &GpuMemoryBudget) -> Result<(), String> {
        // Widen before adding so neither extreme descriptor values nor a
        // limit of `u32::MAX` can overflow.
        let total = u64::from(self.storage_bindings) + u64::from(self.uniform_bindings);
        let allowed = u64::from(budget.limits.max_storage_buffers_per_stage)
            + Self::EXTRA_BINDING_ALLOWANCE;
        // NOTE(review): this is a combined check, so storage bindings alone
        // may exceed the storage limit by up to the allowance; confirm
        // whether a separate storage-only check is wanted.
        if total > allowed {
            // Report the limit that was enforced, not the raw storage limit.
            return Err(format!(
                "Shader '{}' needs {} bindings, budget allows {}",
                self.name, total, allowed
            ));
        }

        // Three u32 factors need up to 96 bits, so multiply in u128.
        let invocations: u128 = self
            .workgroup_size
            .iter()
            .map(|&dim| u128::from(dim))
            .product();
        if invocations > budget.limits.max_invocations_per_workgroup as u128 {
            return Err(format!(
                "Shader '{}' workgroup has {} invocations, max {}",
                self.name, invocations, budget.limits.max_invocations_per_workgroup
            ));
        }

        Ok(())
    }
}
/// Catalogue entry for the `vector_add` shader: element-wise vector
/// addition, described as a GPU smoke test (Tier 4).
pub const SHADER_VECTOR_ADD: ShaderDescriptor = ShaderDescriptor {
    name: "vector_add",
    description: "Element-wise vector addition (GPU smoke test)",
    tier: GpuTier::Tier4,
    entry_point: "main",
    workgroup_size: [64, 1, 1],
    storage_bindings: 3,
    uniform_bindings: 1,
};
/// Catalogue entry for the `orbital_grid` shader: MO wavefunction evaluated
/// on a 3D grid (Tier 1).
pub const SHADER_ORBITAL_GRID: ShaderDescriptor = ShaderDescriptor {
    name: "orbital_grid",
    description: "MO wavefunction on 3D grid (GPU Tier 1: O(grid × N_basis))",
    tier: GpuTier::Tier1,
    entry_point: "main",
    workgroup_size: [8, 8, 4],
    storage_bindings: 4,
    uniform_bindings: 1,
};
/// Catalogue entry for the `marching_cubes` shader: isosurface extraction
/// via marching cubes (Tier 2).
pub const SHADER_MARCHING_CUBES: ShaderDescriptor = ShaderDescriptor {
    name: "marching_cubes",
    description: "Isosurface extraction via marching cubes (GPU Tier 2: O(voxels))",
    tier: GpuTier::Tier2,
    entry_point: "main",
    workgroup_size: [4, 4, 4],
    storage_bindings: 5,
    uniform_bindings: 1,
};
/// Catalogue entry for the `esp_grid` shader: electrostatic potential
/// evaluated on a 3D grid (Tier 1).
pub const SHADER_ESP_GRID: ShaderDescriptor = ShaderDescriptor {
    name: "esp_grid",
    description: "Electrostatic potential on 3D grid (GPU Tier 1: O(grid × N²))",
    tier: GpuTier::Tier1,
    entry_point: "main",
    workgroup_size: [8, 8, 4],
    storage_bindings: 4,
    uniform_bindings: 1,
};
/// Catalogue entry for the `d4_dispersion` shader: D4 pairwise dispersion
/// (Tier 3).
pub const SHADER_D4_DISPERSION: ShaderDescriptor = ShaderDescriptor {
    name: "d4_dispersion",
    description: "D4 pairwise dispersion (GPU Tier 3: O(N²))",
    tier: GpuTier::Tier3,
    entry_point: "main",
    workgroup_size: [16, 16, 1],
    storage_bindings: 3,
    uniform_bindings: 1,
};
/// Catalogue entry for the `eeq_coulomb` shader: EEQ damped Coulomb matrix
/// `gamma_ij` (Tier 3).
pub const SHADER_EEQ_COULOMB: ShaderDescriptor = ShaderDescriptor {
    name: "eeq_coulomb",
    description: "EEQ damped Coulomb matrix gamma_ij (GPU Tier 3: O(N²))",
    tier: GpuTier::Tier3,
    entry_point: "main",
    workgroup_size: [16, 16, 1],
    storage_bindings: 3,
    uniform_bindings: 1,
};
/// Catalogue entry for the `density_grid` shader: electron density
/// evaluated on a 3D grid (Tier 1).
pub const SHADER_DENSITY_GRID: ShaderDescriptor = ShaderDescriptor {
    name: "density_grid",
    description: "Electron density on 3D grid (GPU Tier 1: O(grid × N²))",
    tier: GpuTier::Tier1,
    entry_point: "main",
    workgroup_size: [8, 8, 4],
    storage_bindings: 4,
    uniform_bindings: 1,
};
/// Catalogue entry for the `two_electron_eri` shader: two-electron
/// repulsion integrals (Tier 1).
pub const SHADER_TWO_ELECTRON: ShaderDescriptor = ShaderDescriptor {
    name: "two_electron_eri",
    description: "Two-electron repulsion integrals (GPU Tier 1: O(N⁴))",
    tier: GpuTier::Tier1,
    entry_point: "main",
    workgroup_size: [64, 1, 1],
    storage_bindings: 4,
    uniform_bindings: 1,
};
/// Catalogue entry for the `fock_build` shader: Fock matrix construction
/// `G(P)` (Tier 1).
pub const SHADER_FOCK_BUILD: ShaderDescriptor = ShaderDescriptor {
    name: "fock_build",
    description: "Fock matrix construction G(P) (GPU Tier 1: O(N⁴))",
    tier: GpuTier::Tier1,
    entry_point: "main",
    workgroup_size: [16, 16, 1],
    storage_bindings: 4,
    uniform_bindings: 1,
};
/// Catalogue entry for the `one_electron` shader: one-electron matrices
/// S, T and V (Tier 2).
pub const SHADER_ONE_ELECTRON: ShaderDescriptor = ShaderDescriptor {
    name: "one_electron",
    description: "One-electron matrices S,T,V (GPU Tier 2: O(N²))",
    tier: GpuTier::Tier2,
    entry_point: "main",
    workgroup_size: [16, 16, 1],
    storage_bindings: 4,
    uniform_bindings: 1,
};
/// Catalogue entry for the `gamma_matrix` shader: SCC-DFTB gamma matrix
/// (Tier 3).
pub const SHADER_GAMMA_MATRIX: ShaderDescriptor = ShaderDescriptor {
    name: "gamma_matrix",
    description: "SCC-DFTB gamma matrix (GPU Tier 3: O(N²) pairwise Coulomb)",
    tier: GpuTier::Tier3,
    entry_point: "main",
    workgroup_size: [16, 16, 1],
    storage_bindings: 3,
    uniform_bindings: 1,
};
/// Catalogue entry for the `alpb_born_radii` shader: ALPB Born radii
/// (Tier 3).
pub const SHADER_ALPB_BORN_RADII: ShaderDescriptor = ShaderDescriptor {
    name: "alpb_born_radii",
    description: "ALPB Born radii (GPU Tier 3: O(N²) descreening)",
    tier: GpuTier::Tier3,
    entry_point: "main",
    workgroup_size: [64, 1, 1],
    storage_bindings: 3,
    uniform_bindings: 1,
};
/// Catalogue entry for the `cpm_coulomb` shader: CPM Coulomb matrix `J_ij`
/// (Tier 3).
pub const SHADER_CPM_COULOMB: ShaderDescriptor = ShaderDescriptor {
    name: "cpm_coulomb",
    description: "CPM Coulomb matrix J_ij (GPU Tier 3: O(N²) pairwise electrostatics)",
    tier: GpuTier::Tier3,
    entry_point: "main",
    workgroup_size: [16, 16, 1],
    storage_bindings: 2,
    uniform_bindings: 1,
};
/// Registry of every shader descriptor in the catalogue.
///
/// Lookups and per-tier listings iterate this slice front to back, so the
/// order here is the order in which shaders are found and reported.
pub const ALL_SHADERS: &[&ShaderDescriptor] = &[
    &SHADER_VECTOR_ADD,
    &SHADER_ORBITAL_GRID,
    &SHADER_MARCHING_CUBES,
    &SHADER_ESP_GRID,
    &SHADER_D4_DISPERSION,
    &SHADER_EEQ_COULOMB,
    &SHADER_DENSITY_GRID,
    &SHADER_TWO_ELECTRON,
    &SHADER_FOCK_BUILD,
    &SHADER_ONE_ELECTRON,
    &SHADER_GAMMA_MATRIX,
    &SHADER_ALPB_BORN_RADII,
    &SHADER_CPM_COULOMB,
];
/// Looks up a shader descriptor by exact name.
///
/// Returns the first entry of [`ALL_SHADERS`] whose `name` equals `name`,
/// or `None` when no entry matches.
pub fn find_shader(name: &str) -> Option<&'static ShaderDescriptor> {
    ALL_SHADERS
        .iter()
        .copied()
        .find(|descriptor| descriptor.name == name)
}
/// Collects every descriptor in [`ALL_SHADERS`] whose tier equals `tier`,
/// preserving registry order. The result is empty when no shader matches.
pub fn shaders_by_tier(tier: GpuTier) -> Vec<&'static ShaderDescriptor> {
    let mut matching = Vec::new();
    for &descriptor in ALL_SHADERS {
        if descriptor.tier == tier {
            matching.push(descriptor);
        }
    }
    matching
}
pub fn shader_catalogue_report() -> String {
let mut report = String::from("GPU Shader Catalogue\n====================\n\n");
for tier in &[
GpuTier::Tier1,
GpuTier::Tier2,
GpuTier::Tier3,
GpuTier::Tier4,
] {
let shaders = shaders_by_tier(*tier);
if shaders.is_empty() {
continue;
}
report.push_str(&format!("{tier}\n"));
for s in &shaders {
report.push_str(&format!(
" {} — wg[{},{},{}], {} bindings — {}\n",
s.name,
s.workgroup_size[0],
s.workgroup_size[1],
s.workgroup_size[2],
s.total_bindings(),
s.description,
));
}
report.push('\n');
}
report
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tier 1 renders its label and recommends the GPU; Tier 4 does not.
    #[test]
    fn test_tier_display() {
        let top = GpuTier::Tier1;
        assert_eq!(top.to_string(), "Tier 1 (massive speedup)");
        assert!(top.gpu_recommended());
        assert!(!GpuTier::Tier4.gpu_recommended());
    }

    /// A registered name resolves to its descriptor.
    #[test]
    fn test_shader_lookup() {
        let descriptor = find_shader("orbital_grid").unwrap();
        assert_eq!(descriptor.tier, GpuTier::Tier1);
        assert_eq!(descriptor.workgroup_size, [8, 8, 4]);
    }

    /// An unknown name yields `None`.
    #[test]
    fn test_shader_lookup_missing() {
        assert!(find_shader("nonexistent").is_none());
    }

    /// The Tier 1 listing has at least three entries, all of them Tier 1.
    #[test]
    fn test_shaders_by_tier() {
        let tier_one = shaders_by_tier(GpuTier::Tier1);
        assert!(tier_one.len() >= 3);
        assert!(tier_one.iter().all(|s| s.tier == GpuTier::Tier1));
    }

    /// Two representative shaders fit the default WebGPU budget.
    #[test]
    fn test_budget_check_passes() {
        let budget = GpuMemoryBudget::webgpu_default();
        for shader in [&SHADER_ORBITAL_GRID, &SHADER_MARCHING_CUBES] {
            assert!(shader.check_against_budget(&budget).is_ok());
        }
    }

    /// The report mentions a known shader and the expected tier headings.
    #[test]
    fn test_catalogue_report() {
        let report = shader_catalogue_report();
        for needle in ["orbital_grid", "Tier 1", "Tier 3"] {
            assert!(report.contains(needle));
        }
    }

    /// Binding totals are the sum of storage and uniform bindings.
    #[test]
    fn test_total_bindings() {
        assert_eq!(SHADER_ORBITAL_GRID.total_bindings(), 5);
        assert_eq!(SHADER_VECTOR_ADD.total_bindings(), 4);
    }

    /// Guards against adding a descriptor without registering it.
    #[test]
    fn test_all_shaders_registered() {
        assert_eq!(ALL_SHADERS.len(), 13);
    }
}