/// Device limits that buffer sizes and compute dispatches are validated against.
///
/// Each field mirrors the WebGPU limit of the same meaning; the [`Default`] impl yields the
/// WebGPU specification's default values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GpuMemoryLimits {
    /// Largest storage-buffer binding, in bytes (WebGPU `maxStorageBufferBindingSize`).
    pub max_storage_buffer_binding_size: u64,
    /// Largest buffer that can be created, in bytes (WebGPU `maxBufferSize`).
    pub max_buffer_size: u64,
    /// Largest uniform-buffer binding, in bytes (WebGPU `maxUniformBufferBindingSize`).
    pub max_uniform_buffer_binding_size: u64,
    /// Storage buffers bindable per shader stage (WebGPU `maxStorageBuffersPerShaderStage`).
    pub max_storage_buffers_per_stage: u32,
    /// Workgroup-shared storage, in bytes (WebGPU `maxComputeWorkgroupStorageSize`).
    pub max_workgroup_storage_size: u32,
    /// Upper bound on `x * y * z` for one workgroup
    /// (WebGPU `maxComputeInvocationsPerWorkgroup`).
    pub max_invocations_per_workgroup: u32,
    /// Largest workgroup extent along x (WebGPU `maxComputeWorkgroupSizeX`).
    pub max_workgroup_size_x: u32,
    /// Largest workgroup extent along y (WebGPU `maxComputeWorkgroupSizeY`).
    pub max_workgroup_size_y: u32,
    /// Largest workgroup extent along z (WebGPU `maxComputeWorkgroupSizeZ`).
    pub max_workgroup_size_z: u32,
    /// Largest workgroup count along any single dispatch axis
    /// (WebGPU `maxComputeWorkgroupsPerDimension`).
    pub max_workgroups_per_dimension: u32,
}
impl Default for GpuMemoryLimits {
fn default() -> Self {
Self {
max_storage_buffer_binding_size: 134_217_728, max_buffer_size: 268_435_456, max_uniform_buffer_binding_size: 65_536, max_storage_buffers_per_stage: 8,
max_workgroup_storage_size: 16_384, max_invocations_per_workgroup: 256,
max_workgroup_size_x: 256,
max_workgroup_size_y: 256,
max_workgroup_size_z: 64,
max_workgroups_per_dimension: 65535,
}
}
}
/// Validates GPU buffer sizes, storage-binding counts and compute dispatches against a set
/// of [`GpuMemoryLimits`].
#[derive(Debug, Clone)]
pub struct GpuMemoryBudget {
    /// Device limits every check compares against.
    pub limits: GpuMemoryLimits,
    /// Bytes recorded as allocated. Starts at 0; no check in this file reads or updates it.
    pub allocated_bytes: u64,
    /// Storage-buffer bindings already in use in the current stage. `check_storage_binding`
    /// fails once this reaches `limits.max_storage_buffers_per_stage`. Nothing in this file
    /// increments it — NOTE(review): presumably callers maintain it; confirm.
    pub storage_bindings_used: u32,
}
impl GpuMemoryBudget {
    /// Creates a budget for `limits` with no bytes allocated and no storage bindings in use.
    pub fn new(limits: GpuMemoryLimits) -> Self {
        Self {
            limits,
            allocated_bytes: 0,
            storage_bindings_used: 0,
        }
    }

    /// Creates a budget using the WebGPU default limits ([`GpuMemoryLimits::default`]).
    pub fn webgpu_default() -> Self {
        Self::new(GpuMemoryLimits::default())
    }

    /// Converts an element count to `u64`, saturating if `usize` were ever wider than 64 bits.
    fn count_u64(count: usize) -> u64 {
        u64::try_from(count).unwrap_or(u64::MAX)
    }

    /// Returns `count * bytes_per_element` in `u64`, saturating at `u64::MAX`.
    ///
    /// Doing this in `u64` (not `usize`) matters on 32-bit targets such as `wasm32`, where
    /// e.g. `n_atoms * n_atoms * 4` overflows `usize` for a few tens of thousands of atoms.
    fn byte_size(count: usize, bytes_per_element: u64) -> u64 {
        Self::count_u64(count).saturating_mul(bytes_per_element)
    }

    /// Sums buffer sizes, saturating at `u64::MAX`.
    fn saturating_total(parts: &[u64]) -> u64 {
        parts.iter().fold(0u64, |acc, &b| acc.saturating_add(b))
    }

    /// Checks that a buffer of `size_bytes` can be created and bound as a storage buffer.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryError::BufferTooLarge`] if `size_bytes` exceeds
    /// `limits.max_buffer_size`, otherwise [`MemoryError::BindingTooLarge`] if it exceeds
    /// `limits.max_storage_buffer_binding_size`.
    pub fn check_buffer(&self, size_bytes: u64) -> Result<(), MemoryError> {
        if size_bytes > self.limits.max_buffer_size {
            return Err(MemoryError::BufferTooLarge {
                requested: size_bytes,
                max: self.limits.max_buffer_size,
            });
        }
        if size_bytes > self.limits.max_storage_buffer_binding_size {
            return Err(MemoryError::BindingTooLarge {
                requested: size_bytes,
                max: self.limits.max_storage_buffer_binding_size,
            });
        }
        Ok(())
    }

    /// Checks that one more storage buffer can be bound in the current shader stage.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryError::TooManyBindings`] when `storage_bindings_used` has already
    /// reached `limits.max_storage_buffers_per_stage`.
    pub fn check_storage_binding(&self) -> Result<(), MemoryError> {
        if self.storage_bindings_used >= self.limits.max_storage_buffers_per_stage {
            return Err(MemoryError::TooManyBindings {
                current: self.storage_bindings_used,
                max: self.limits.max_storage_buffers_per_stage,
            });
        }
        Ok(())
    }

    /// Validates a compute dispatch of `count` workgroups, each `size` invocations per axis.
    ///
    /// # Errors
    ///
    /// Checked in this order; the first failure is returned:
    /// 1. [`MemoryError::WorkgroupSizeExceeded`] if any axis of `size` exceeds its limit.
    /// 2. [`MemoryError::TooManyInvocations`] if `size[0] * size[1] * size[2]` exceeds
    ///    `limits.max_invocations_per_workgroup` (the reported count saturates at `u32::MAX`).
    /// 3. [`MemoryError::TooManyWorkgroups`] for the first axis of `count` above
    ///    `limits.max_workgroups_per_dimension`.
    pub fn check_workgroup(&self, size: [u32; 3], count: [u32; 3]) -> Result<(), MemoryError> {
        let max_size = [
            self.limits.max_workgroup_size_x,
            self.limits.max_workgroup_size_y,
            self.limits.max_workgroup_size_z,
        ];
        if size.iter().zip(max_size.iter()).any(|(s, m)| s > m) {
            return Err(MemoryError::WorkgroupSizeExceeded {
                requested: size,
                max: max_size,
            });
        }
        // Saturating u64 product: with custom limits three u32 axes can overflow even u64.
        let total = size
            .iter()
            .fold(1u64, |acc, &s| acc.saturating_mul(u64::from(s)));
        if total > u64::from(self.limits.max_invocations_per_workgroup) {
            return Err(MemoryError::TooManyInvocations {
                // Saturate rather than truncate so the reported count is never understated.
                requested: u32::try_from(total).unwrap_or(u32::MAX),
                max: self.limits.max_invocations_per_workgroup,
            });
        }
        for (dimension, &requested) in (0u32..).zip(count.iter()) {
            if requested > self.limits.max_workgroups_per_dimension {
                return Err(MemoryError::TooManyWorkgroups {
                    dimension,
                    requested,
                    max: self.limits.max_workgroups_per_dimension,
                });
            }
        }
        Ok(())
    }

    /// Estimates the buffers needed to evaluate orbitals on `grid_points` grid points.
    ///
    /// Sizes used:
    /// - 32 bytes per basis function;
    /// - 4 bytes per MO coefficient (`n_basis` of them);
    /// - 8 bytes per primitive (`n_basis * max_primitives_per_basis` of them);
    /// - a fixed 32-byte params block;
    /// - 4 bytes of output per grid point.
    ///
    /// All arithmetic saturates at `u64::MAX`, so oversized inputs surface as
    /// [`MemoryError::BufferTooLarge`] instead of overflowing.
    ///
    /// `fits_in_webgpu` compares the *sum* of all buffers with the single-buffer
    /// `max_buffer_size`, so it is a conservative indicator.
    ///
    /// # Errors
    ///
    /// Propagates [`check_buffer`](Self::check_buffer) failures for the basis, MO,
    /// primitive and output buffers (in that order). Then returns
    /// [`MemoryError::TooManyBindings`] if the stage allows fewer than the 5 bindings this
    /// estimate assumes.
    pub fn estimate_orbital_grid_memory(
        &self,
        n_basis: usize,
        grid_points: usize,
        max_primitives_per_basis: usize,
    ) -> Result<OrbitalGridMemoryEstimate, MemoryError> {
        // One binding each for basis, MO coefficients, primitives, params and output.
        const REQUIRED_BINDINGS: u32 = 5;

        let basis_bytes = Self::byte_size(n_basis, 32);
        let mo_bytes = Self::byte_size(n_basis, 4);
        let prim_bytes = Self::byte_size(n_basis, 8)
            .saturating_mul(Self::count_u64(max_primitives_per_basis));
        let params_bytes = 32u64;
        let output_bytes = Self::byte_size(grid_points, 4);
        let total = Self::saturating_total(&[
            basis_bytes,
            mo_bytes,
            prim_bytes,
            params_bytes,
            output_bytes,
        ]);

        self.check_buffer(basis_bytes)?;
        self.check_buffer(mo_bytes)?;
        self.check_buffer(prim_bytes)?;
        self.check_buffer(output_bytes)?;
        if REQUIRED_BINDINGS > self.limits.max_storage_buffers_per_stage {
            return Err(MemoryError::TooManyBindings {
                current: REQUIRED_BINDINGS,
                max: self.limits.max_storage_buffers_per_stage,
            });
        }

        Ok(OrbitalGridMemoryEstimate {
            basis_bytes,
            mo_coefficients_bytes: mo_bytes,
            primitives_bytes: prim_bytes,
            params_bytes,
            output_bytes,
            total_bytes: total,
            fits_in_webgpu: total <= self.limits.max_buffer_size,
        })
    }

    /// Estimates the buffers for a D4 dispersion kernel with a full `n_atoms x n_atoms`
    /// pairwise buffer.
    ///
    /// Sizes used:
    /// - 16 bytes of position per atom;
    /// - 32 bytes of parameters per atom;
    /// - a fixed 32-byte config block;
    /// - 4 bytes per atom pair;
    /// - 4 bytes of output per atom.
    ///
    /// Arithmetic saturates at `u64::MAX`, so oversized inputs surface as
    /// [`MemoryError::BufferTooLarge`] instead of overflowing.
    ///
    /// `max_atoms_for_limit` is `floor(sqrt(max_storage_buffer_binding_size / 4))`: the
    /// largest atom count whose pairwise buffer fits in one storage binding.
    ///
    /// # Errors
    ///
    /// Propagates [`check_buffer`](Self::check_buffer) failures for the pairwise buffer
    /// (the only buffer checked here).
    pub fn estimate_d4_dispersion_memory(
        &self,
        n_atoms: usize,
    ) -> Result<PairwiseMemoryEstimate, MemoryError> {
        let pos_bytes = Self::byte_size(n_atoms, 16);
        let params_bytes = Self::byte_size(n_atoms, 32);
        let config_bytes = 32u64;
        let pairwise_bytes = Self::byte_size(n_atoms, 4).saturating_mul(Self::count_u64(n_atoms));
        let output_bytes = Self::byte_size(n_atoms, 4);
        let total = Self::saturating_total(&[
            pos_bytes,
            params_bytes,
            config_bytes,
            pairwise_bytes,
            output_bytes,
        ]);

        self.check_buffer(pairwise_bytes)?;

        let pair_slots = self.limits.max_storage_buffer_binding_size / 4;
        Ok(PairwiseMemoryEstimate {
            positions_bytes: pos_bytes,
            params_bytes,
            pairwise_bytes,
            total_bytes: total,
            fits_in_webgpu: total <= self.limits.max_buffer_size,
            max_atoms_for_limit: (pair_slots as f64).sqrt() as usize,
        })
    }

    /// Picks a 3-D workgroup size and count covering a grid of `dims` points.
    ///
    /// Starts from an 8x8x4 tile clamped to the per-axis limits (and to at least 1). It then
    /// halves the largest axis until the tile respects `max_invocations_per_workgroup`. On
    /// ties the later axis is halved, so x stays widest. Counts are `ceil(dims / size)` per
    /// axis. They are not checked against `max_workgroups_per_dimension`; use
    /// [`check_workgroup`](Self::check_workgroup) for that.
    pub fn optimal_grid_dispatch(&self, dims: [u32; 3]) -> ([u32; 3], [u32; 3]) {
        // `max(1)` keeps a zero limit from turning into a divide-by-zero in `div_ceil`.
        let mut wg_size = [
            8u32.min(self.limits.max_workgroup_size_x).max(1),
            8u32.min(self.limits.max_workgroup_size_y).max(1),
            4u32.min(self.limits.max_workgroup_size_z).max(1),
        ];
        let max_invocations = u64::from(self.limits.max_invocations_per_workgroup.max(1));
        // Terminates: while the product exceeds the limit (>= 1) some axis is > 1, and the
        // largest axis is the one halved.
        while wg_size.iter().map(|&s| u64::from(s)).product::<u64>() > max_invocations {
            let axis = (0..3).max_by_key(|&i| wg_size[i]).unwrap_or(0);
            wg_size[axis] /= 2;
        }
        let wg_count: [u32; 3] = std::array::from_fn(|i| dims[i].div_ceil(wg_size[i]));
        (wg_size, wg_count)
    }

    /// Picks a 1-D workgroup size (64, clamped to the x-axis and invocation limits, and to
    /// at least 1) and the number of workgroups (`ceil(n / size)`) needed to cover `n` items.
    ///
    /// The count is not checked against `max_workgroups_per_dimension`.
    pub fn optimal_1d_dispatch(&self, n: u32) -> (u32, u32) {
        let wg_size = 64u32
            .min(self.limits.max_workgroup_size_x)
            .min(self.limits.max_invocations_per_workgroup)
            .max(1);
        let wg_count = n.div_ceil(wg_size);
        (wg_size, wg_count)
    }
}
/// Per-buffer byte breakdown returned by
/// [`GpuMemoryBudget::estimate_orbital_grid_memory`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrbitalGridMemoryEstimate {
    /// Basis-function buffer size in bytes (32 bytes per basis function).
    pub basis_bytes: u64,
    /// MO-coefficient buffer size in bytes (4 bytes per basis function).
    pub mo_coefficients_bytes: u64,
    /// Primitive buffer size in bytes (8 bytes per primitive).
    pub primitives_bytes: u64,
    /// Fixed-size parameter block in bytes.
    pub params_bytes: u64,
    /// Output buffer size in bytes (4 bytes per grid point).
    pub output_bytes: u64,
    /// Sum of all the buffers above.
    pub total_bytes: u64,
    /// Whether `total_bytes` is within `max_buffer_size`. The sum of all buffers is compared
    /// against the single-buffer limit, so this is a conservative indicator.
    pub fits_in_webgpu: bool,
}
/// Byte breakdown returned by [`GpuMemoryBudget::estimate_d4_dispersion_memory`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PairwiseMemoryEstimate {
    /// Position buffer size in bytes (16 bytes per atom).
    pub positions_bytes: u64,
    /// Per-atom parameter buffer size in bytes (32 bytes per atom).
    pub params_bytes: u64,
    /// Pairwise buffer size in bytes (4 bytes for each of the `n_atoms * n_atoms` pairs).
    pub pairwise_bytes: u64,
    /// Total of all buffers. This also includes the fixed config block and the per-atom
    /// output buffer, which have no field of their own.
    pub total_bytes: u64,
    /// Whether `total_bytes` is within `max_buffer_size`. The sum of all buffers is compared
    /// against the single-buffer limit, so this is a conservative indicator.
    pub fits_in_webgpu: bool,
    /// Largest atom count whose pairwise buffer fits in one storage binding, computed as
    /// `floor(sqrt(max_storage_buffer_binding_size / 4))`.
    pub max_atoms_for_limit: usize,
}
/// Reasons a requested GPU resource or dispatch does not fit within [`GpuMemoryLimits`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MemoryError {
    /// A buffer of `requested` bytes exceeds `max_buffer_size` (`max`, in bytes).
    BufferTooLarge {
        requested: u64,
        max: u64,
    },
    /// A buffer of `requested` bytes exceeds `max_storage_buffer_binding_size`
    /// (`max`, in bytes).
    BindingTooLarge {
        requested: u64,
        max: u64,
    },
    /// Not enough storage-buffer bindings per stage. `current` is the number already in use
    /// when raised by `GpuMemoryBudget::check_storage_binding`. It is the number required
    /// when raised by `GpuMemoryBudget::estimate_orbital_grid_memory`.
    TooManyBindings {
        current: u32,
        max: u32,
    },
    /// At least one workgroup axis (x, y, z) exceeds its per-axis limit.
    WorkgroupSizeExceeded {
        requested: [u32; 3],
        max: [u32; 3],
    },
    /// The workgroup's total invocation count (`x * y * z`) exceeds the limit.
    TooManyInvocations {
        requested: u32,
        max: u32,
    },
    /// The workgroup count along axis `dimension` (0 = x, 1 = y, 2 = z) exceeds the limit.
    TooManyWorkgroups {
        dimension: u32,
        requested: u32,
        max: u32,
    },
}
impl std::fmt::Display for MemoryError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MemoryError::BufferTooLarge { requested, max } => {
write!(f, "Buffer {requested} bytes exceeds max {max} bytes")
}
MemoryError::BindingTooLarge { requested, max } => {
write!(f, "Binding {requested} bytes exceeds max {max} bytes")
}
MemoryError::TooManyBindings { current, max } => {
write!(f, "Need {current} bindings, max {max}")
}
MemoryError::WorkgroupSizeExceeded { requested, max } => {
write!(
f,
"Workgroup [{},{},{}] exceeds max [{},{},{}]",
requested[0], requested[1], requested[2], max[0], max[1], max[2]
)
}
MemoryError::TooManyInvocations { requested, max } => {
write!(f, "{requested} invocations exceeds max {max}")
}
MemoryError::TooManyWorkgroups {
dimension,
requested,
max,
} => {
write!(
f,
"Dimension {dimension}: {requested} workgroups exceeds max {max}"
)
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_webgpu_defaults() {
        let limits = GpuMemoryLimits::default();
        assert_eq!(limits.max_storage_buffer_binding_size, 128 * 1024 * 1024);
        assert_eq!(limits.max_buffer_size, 256 * 1024 * 1024);
        assert_eq!(limits.max_uniform_buffer_binding_size, 64 * 1024);
    }

    #[test]
    fn test_buffer_check_within_limits() {
        let budget = GpuMemoryBudget::webgpu_default();
        assert!(budget.check_buffer(1024 * 1024).is_ok());
    }

    #[test]
    fn test_buffer_check_exceeds_limits() {
        let budget = GpuMemoryBudget::webgpu_default();
        // Above the 256 MiB max_buffer_size, so the buffer check fires first.
        assert!(matches!(
            budget.check_buffer(300_000_000),
            Err(MemoryError::BufferTooLarge { .. })
        ));
    }

    #[test]
    fn test_buffer_check_exceeds_binding_limit() {
        let budget = GpuMemoryBudget::webgpu_default();
        // Fits in a 256 MiB buffer but not in a 128 MiB storage binding.
        assert!(matches!(
            budget.check_buffer(200_000_000),
            Err(MemoryError::BindingTooLarge { .. })
        ));
    }

    #[test]
    fn test_storage_binding_limit() {
        let mut budget = GpuMemoryBudget::webgpu_default();
        budget.storage_bindings_used = 7;
        assert!(budget.check_storage_binding().is_ok());
        budget.storage_bindings_used = 8;
        assert!(matches!(
            budget.check_storage_binding(),
            Err(MemoryError::TooManyBindings { current: 8, max: 8 })
        ));
    }

    #[test]
    fn test_orbital_grid_small_molecule() {
        let budget = GpuMemoryBudget::webgpu_default();
        let est = budget.estimate_orbital_grid_memory(7, 125_000, 3).unwrap();
        assert!(est.fits_in_webgpu);
        assert!(est.total_bytes < 1_000_000);
    }

    #[test]
    fn test_orbital_grid_breakdown() {
        let budget = GpuMemoryBudget::webgpu_default();
        let est = budget.estimate_orbital_grid_memory(7, 125_000, 3).unwrap();
        assert_eq!(est.basis_bytes, 7 * 32);
        assert_eq!(est.mo_coefficients_bytes, 7 * 4);
        assert_eq!(est.primitives_bytes, 7 * 3 * 8);
        assert_eq!(est.params_bytes, 32);
        assert_eq!(est.output_bytes, 125_000 * 4);
        assert_eq!(est.total_bytes, 224 + 28 + 168 + 32 + 500_000);
    }

    #[test]
    fn test_workgroup_check_valid() {
        let budget = GpuMemoryBudget::webgpu_default();
        assert!(budget.check_workgroup([8, 8, 4], [100, 100, 50]).is_ok());
    }

    #[test]
    fn test_workgroup_check_too_large() {
        let budget = GpuMemoryBudget::webgpu_default();
        assert!(budget.check_workgroup([512, 1, 1], [1, 1, 1]).is_err());
    }

    #[test]
    fn test_workgroup_too_many_invocations() {
        let budget = GpuMemoryBudget::webgpu_default();
        // Each axis is within its own limit, but 16 * 16 * 2 = 512 > 256 invocations.
        assert!(matches!(
            budget.check_workgroup([16, 16, 2], [1, 1, 1]),
            Err(MemoryError::TooManyInvocations { requested: 512, max: 256 })
        ));
    }

    #[test]
    fn test_workgroup_too_many_workgroups() {
        let budget = GpuMemoryBudget::webgpu_default();
        assert!(matches!(
            budget.check_workgroup([8, 8, 4], [1, 70_000, 1]),
            Err(MemoryError::TooManyWorkgroups { dimension: 1, requested: 70_000, max: 65_535 })
        ));
    }

    #[test]
    fn test_d4_memory_small_system() {
        let budget = GpuMemoryBudget::webgpu_default();
        let est = budget.estimate_d4_dispersion_memory(100).unwrap();
        assert!(est.fits_in_webgpu);
        assert_eq!(est.pairwise_bytes, 40_000);
        assert_eq!(est.total_bytes, 1600 + 3200 + 32 + 40_000 + 400);
    }

    #[test]
    fn test_d4_max_atoms_calculable() {
        let budget = GpuMemoryBudget::webgpu_default();
        let est = budget.estimate_d4_dispersion_memory(10).unwrap();
        assert!(est.max_atoms_for_limit > 5000);
        // floor(sqrt(128 MiB / 4)) = floor(5792.6...) = 5792.
        assert_eq!(est.max_atoms_for_limit, 5792);
    }

    #[test]
    fn test_optimal_grid_dispatch() {
        let budget = GpuMemoryBudget::webgpu_default();
        let (wg_size, wg_count) = budget.optimal_grid_dispatch([64, 64, 64]);
        assert_eq!(wg_size, [8, 8, 4]);
        assert_eq!(wg_count, [8, 8, 16]);
    }

    #[test]
    fn test_optimal_1d_dispatch() {
        let budget = GpuMemoryBudget::webgpu_default();
        let (wg_size, wg_count) = budget.optimal_1d_dispatch(1000);
        assert_eq!(wg_size, 64);
        assert_eq!(wg_count, 16);
    }

    #[test]
    fn test_error_display() {
        let err = MemoryError::BufferTooLarge {
            requested: 10,
            max: 5,
        };
        assert_eq!(err.to_string(), "Buffer 10 bytes exceeds max 5 bytes");
    }
}