use super::{MemoryPlan, PlanError};
use crate::optimizer::AdapterCaps;
/// Static memory limits, in bytes, that a memory plan is validated against
/// for a single backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DeviceMemoryBudget {
    /// Backend name, echoed into budget errors and reports.
    pub backend: &'static str,
    /// Largest size allowed for any single statically sized buffer.
    pub per_buffer_bytes: u64,
    /// Largest total static allocation allowed for the whole plan.
    pub peak_static_bytes: u64,
}
impl DeviceMemoryBudget {
    /// Factor applied to the per-buffer limit to derive the peak static limit
    /// in `from_adapter`.
    ///
    /// The value is hard-coded here rather than reported by the adapter.
    /// NOTE(review): confirm 16 is the intended headroom for every backend.
    const PEAK_STATIC_BUFFER_MULTIPLIER: u64 = 16;

    /// Derives a budget from the adapter's reported capabilities.
    ///
    /// `per_buffer_bytes` is the adapter's `max_storage_buffer_binding_size`,
    /// clamped to at least 1. `peak_static_bytes` is 16 times that value,
    /// saturating at `u64::MAX` instead of overflowing.
    #[must_use]
    pub fn from_adapter(caps: &AdapterCaps) -> Self {
        // NOTE(review): presumably clamped so a zero report never yields a
        // zero budget — confirm intent.
        let per_buffer_bytes = caps.max_storage_buffer_binding_size.max(1);
        Self {
            backend: caps.backend,
            per_buffer_bytes,
            peak_static_bytes: per_buffer_bytes
                .saturating_mul(Self::PEAK_STATIC_BUFFER_MULTIPLIER),
        }
    }

    /// Checks `plan` against the per-buffer and peak static limits.
    ///
    /// Buffers whose `static_size_bytes` is `None` are not size-checked.
    /// `plan.static_bytes` is used as-is for the aggregate check; it is not
    /// recomputed from the buffer list. Both limits are inclusive: a size equal
    /// to the limit is accepted.
    ///
    /// On success the report names the largest statically sized buffer (the
    /// first one in plan order on ties). If no statically sized buffer has a
    /// nonzero size, `largest_buffer_name` is empty and `largest_buffer_bytes`
    /// is 0.
    ///
    /// # Errors
    ///
    /// - [`PlanError::BufferBudgetExceeded`] for the first statically sized
    ///   buffer, in plan order, larger than `per_buffer_bytes`. Per-buffer
    ///   limits are checked before the aggregate limit.
    /// - [`PlanError::PeakBudgetExceeded`] when `plan.static_bytes` is larger
    ///   than `peak_static_bytes`.
    pub fn validate(&self, plan: &MemoryPlan) -> Result<MemoryBudgetReport, PlanError> {
        // Strict `>` below keeps the first buffer of the largest size and never
        // selects a zero-size buffer.
        let mut largest_buffer_name: &str = "";
        let mut largest_buffer_bytes = 0u64;
        for buffer in &plan.buffers {
            // Buffers without a static size have nothing to check here.
            let Some(size_bytes) = buffer.static_size_bytes else {
                continue;
            };
            if size_bytes > largest_buffer_bytes {
                largest_buffer_name = &buffer.name;
                largest_buffer_bytes = size_bytes;
            }
            if size_bytes > self.per_buffer_bytes {
                return Err(PlanError::BufferBudgetExceeded {
                    backend: self.backend,
                    name: buffer.name.clone(),
                    size_bytes,
                    budget_bytes: self.per_buffer_bytes,
                });
            }
        }
        if plan.static_bytes > self.peak_static_bytes {
            return Err(PlanError::PeakBudgetExceeded {
                backend: self.backend,
                planned_bytes: plan.static_bytes,
                budget_bytes: self.peak_static_bytes,
            });
        }
        Ok(MemoryBudgetReport {
            backend: self.backend,
            static_bytes: plan.static_bytes,
            peak_budget_bytes: self.peak_static_bytes,
            largest_buffer_name: largest_buffer_name.to_owned(),
            largest_buffer_bytes,
            per_buffer_budget_bytes: self.per_buffer_bytes,
            dynamic_buffers: plan.dynamic_buffers,
        })
    }
}
/// Summary of a memory plan that passed budget validation.
///
/// Field meanings below describe how `DeviceMemoryBudget::validate` fills them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryBudgetReport {
    /// Backend the plan was validated against.
    pub backend: &'static str,
    /// Total static bytes, copied from the plan.
    pub static_bytes: u64,
    /// Peak static limit the plan was checked against.
    pub peak_budget_bytes: u64,
    /// Name of the largest statically sized buffer; empty if none had a nonzero size.
    pub largest_buffer_name: String,
    /// Size of that largest buffer; 0 if none had a nonzero size.
    pub largest_buffer_bytes: u64,
    /// Per-buffer limit each statically sized buffer was checked against.
    pub per_buffer_budget_bytes: u64,
    /// Count of dynamic buffers, copied from the plan.
    pub dynamic_buffers: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution_plan::{BufferPlan, MemoryPlan};
    use crate::ir::{BufferAccess, DataType, MemoryKind};

    /// Budget for the fictional "test-gpu" backend with the given limits.
    fn test_budget(per_buffer_bytes: u64, peak_static_bytes: u64) -> DeviceMemoryBudget {
        DeviceMemoryBudget {
            backend: "test-gpu",
            per_buffer_bytes,
            peak_static_bytes,
        }
    }

    /// Read-write global `u32` buffer named `b{binding}` bound at `binding`.
    fn buffer(binding: usize, static_size_bytes: Option<u64>) -> BufferPlan {
        let size = static_size_bytes.unwrap_or(0);
        BufferPlan {
            name: format!("b{binding}"),
            binding: u32::try_from(binding).expect("test bindings fit in u32"),
            access: BufferAccess::ReadWrite,
            kind: MemoryKind::Global,
            element: DataType::U32,
            count: u32::try_from(size / 4).expect("test element counts fit in u32"),
            static_size_bytes,
            output_range: None,
        }
    }

    /// Plan whose buffers are all statically sized, with `static_bytes` equal
    /// to the sum of `static_sizes`.
    fn plan(static_sizes: &[u64]) -> MemoryPlan {
        MemoryPlan {
            buffers: static_sizes
                .iter()
                .enumerate()
                .map(|(index, &size)| buffer(index, Some(size)))
                .collect(),
            static_bytes: static_sizes.iter().copied().sum(),
            dynamic_buffers: 0,
            visible_readback_bytes: 0,
            avoided_readback_bytes: 0,
        }
    }

    #[test]
    fn validates_under_budget_and_reports_largest_buffer() {
        let report = test_budget(1024, 4096)
            .validate(&plan(&[128, 512, 256]))
            .expect("Fix: plan is below both per-buffer and peak budgets");
        assert_eq!(report.backend, "test-gpu");
        assert_eq!(report.static_bytes, 896);
        assert_eq!(report.largest_buffer_name, "b1");
        assert_eq!(report.largest_buffer_bytes, 512);
        assert_eq!(report.per_buffer_budget_bytes, 1024);
        assert_eq!(report.peak_budget_bytes, 4096);
    }

    #[test]
    fn accepts_sizes_exactly_at_budget() {
        let report = test_budget(1024, 2048)
            .validate(&plan(&[1024, 1024]))
            .expect("limits are inclusive");
        assert_eq!(report.static_bytes, 2048);
        // Ties keep the first buffer in plan order.
        assert_eq!(report.largest_buffer_name, "b0");
        assert_eq!(report.largest_buffer_bytes, 1024);
    }

    #[test]
    fn empty_plan_reports_no_largest_buffer() {
        let report = test_budget(1024, 4096)
            .validate(&plan(&[]))
            .expect("an empty plan fits any budget");
        assert_eq!(report.static_bytes, 0);
        assert_eq!(report.largest_buffer_name, "");
        assert_eq!(report.largest_buffer_bytes, 0);
    }

    #[test]
    fn skips_buffers_without_static_size() {
        let memory_plan = MemoryPlan {
            buffers: vec![buffer(0, Some(256)), buffer(1, None)],
            static_bytes: 256,
            dynamic_buffers: 1,
            visible_readback_bytes: 0,
            avoided_readback_bytes: 0,
        };
        let report = test_budget(1024, 4096)
            .validate(&memory_plan)
            .expect("unsized buffers have nothing to check");
        assert_eq!(report.dynamic_buffers, 1);
        assert_eq!(report.largest_buffer_name, "b0");
        assert_eq!(report.largest_buffer_bytes, 256);
    }

    #[test]
    fn rejects_single_buffer_over_budget_before_peak_check() {
        let error = test_budget(1024, 4096)
            .validate(&plan(&[128, 2048]))
            .expect_err("single oversize buffer must fail before dispatch");
        assert!(matches!(
            error,
            PlanError::BufferBudgetExceeded {
                backend: "test-gpu",
                name,
                size_bytes: 2048,
                budget_bytes: 1024,
            } if name == "b1"
        ));
    }

    #[test]
    fn rejects_peak_static_bytes_over_budget() {
        let error = test_budget(1024, 1500)
            .validate(&plan(&[800, 800]))
            .expect_err("aggregate peak memory must be bounded");
        assert!(matches!(
            error,
            PlanError::PeakBudgetExceeded {
                backend: "test-gpu",
                planned_bytes: 1600,
                budget_bytes: 1500,
            }
        ));
    }
}