use super::GpuDevice;
use anyhow::Result;
/// Default size of a single morsel, in bytes (128 MiB).
pub const DEFAULT_MORSEL_SIZE: usize = 128 << 20;
/// Memory budget used to plan GPU work, either produced by
/// `GpuMemoryLimits::detect` or constructed by hand.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GpuMemoryLimits {
    /// Total device memory budget in bytes.
    ///
    /// NOTE(review): `detect` fills this from the device's `max_buffer_size`
    /// limit, which is a per-buffer cap rather than physical VRAM — confirm
    /// this is the intended meaning.
    pub total_vram: u64,
    /// Budget in bytes that `fits_in_vram` checks sizes against; `detect` sets
    /// it to 70% of `total_vram`.
    pub usable_vram: u64,
    /// Size of one morsel in bytes; `morsels_needed` rounds sizes up to whole
    /// morsels of this size.
    pub morsel_size: usize,
    /// Number of whole morsels that fit in `usable_vram` (`detect` never sets
    /// this below 1).
    pub max_morsels: usize,
}
impl GpuMemoryLimits {
    /// Percentage of the reported device budget that `detect` treats as usable.
    const USABLE_VRAM_PERCENT: u64 = 70;
    /// Tile size `recommended_tile_size` returns when the node size is zero.
    const FALLBACK_TILE_SIZE: usize = 1000;
    /// Smallest tile size `recommended_tile_size` will return for a non-zero
    /// node size.
    const MIN_TILE_SIZE: usize = 100;

    /// Derives memory limits from the limits reported by `device`.
    ///
    /// * `total_vram` is the device's `max_buffer_size` limit.
    /// * `usable_vram` is 70% of `total_vram`, rounded down.
    /// * `morsel_size` is [`DEFAULT_MORSEL_SIZE`].
    /// * `max_morsels` is the number of whole morsels that fit in
    ///   `usable_vram`, never less than 1 and saturating at `usize::MAX`.
    ///
    /// NOTE(review): `max_buffer_size` is a per-buffer cap, not necessarily
    /// the physical amount of VRAM — confirm this is the intended budget.
    ///
    /// # Errors
    ///
    /// Currently never returns `Err`.
    pub fn detect(device: &GpuDevice) -> Result<Self> {
        let total_vram = device.device().limits().max_buffer_size;
        let usable_vram = Self::usable_share(total_vram);
        let morsel_size = DEFAULT_MORSEL_SIZE;
        let max_morsels = Self::morsel_capacity(usable_vram, morsel_size);
        Ok(Self {
            total_vram,
            usable_vram,
            morsel_size,
            max_morsels,
        })
    }

    /// Returns `floor(total * USABLE_VRAM_PERCENT / 100)`.
    ///
    /// Uses integer arithmetic so the result is exact for every `u64`. Going
    /// through `f64` loses precision above 2^53.
    fn usable_share(total: u64) -> u64 {
        let pct = Self::USABLE_VRAM_PERCENT;
        // Write `total` as `100 * q + r`; neither partial product can overflow.
        (total / 100) * pct + (total % 100) * pct / 100
    }

    /// Number of whole `morsel_size` morsels that fit in `usable_vram`.
    ///
    /// The result is clamped to at least 1 and saturates at `usize::MAX`.
    #[allow(clippy::cast_possible_truncation)] // value is clamped to usize::MAX before the cast
    fn morsel_capacity(usable_vram: u64, morsel_size: usize) -> usize {
        // Divide in u64. Narrowing `usable_vram` to usize first would truncate
        // budgets of 4 GiB or more on 32-bit targets (e.g. wasm32).
        let whole = usable_vram.checked_div(morsel_size as u64).unwrap_or(0);
        (whole.min(usize::MAX as u64) as usize).max(1)
    }

    /// Returns `true` if `graph_size_bytes` is no larger than `usable_vram`.
    #[must_use]
    pub fn fits_in_vram(&self, graph_size_bytes: usize) -> bool {
        // Widen the argument instead of narrowing `usable_vram` to usize.
        // Narrowing truncates on 32-bit targets and gives wrong answers.
        graph_size_bytes as u64 <= self.usable_vram
    }

    /// Number of morsels needed to hold `size_bytes`, rounded up.
    ///
    /// # Panics
    ///
    /// Panics if `morsel_size` is zero.
    #[must_use]
    pub fn morsels_needed(&self, size_bytes: usize) -> usize {
        size_bytes.div_ceil(self.morsel_size)
    }

    /// Suggests how many nodes of `node_size_bytes` bytes each to put in one
    /// tile.
    ///
    /// Returns the number of such nodes that fit in one morsel, but never
    /// fewer than 100. Returns 1000 when `node_size_bytes` is zero.
    #[must_use]
    pub fn recommended_tile_size(&self, node_size_bytes: usize) -> usize {
        if node_size_bytes == 0 {
            return Self::FALLBACK_TILE_SIZE;
        }
        (self.morsel_size / node_size_bytes).max(Self::MIN_TILE_SIZE)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One mebibyte, in bytes.
    const MIB: usize = 1024 * 1024;
    /// One gibibyte, in bytes.
    const GIB: u64 = 1024 * 1024 * 1024;

    /// Hand-built limits for an 8 GiB device with a 5 GiB usable budget, so the
    /// pure query methods can be tested without a GPU.
    fn sample_limits() -> GpuMemoryLimits {
        GpuMemoryLimits {
            total_vram: 8 * GIB,
            usable_vram: 5 * GIB,
            morsel_size: DEFAULT_MORSEL_SIZE,
            max_morsels: 40,
        }
    }

    #[tokio::test]
    async fn test_memory_limits_detection() {
        if !GpuDevice::is_gpu_available().await {
            eprintln!("⚠️ Skipping test_memory_limits_detection: GPU not available");
            return;
        }

        let gpu = GpuDevice::new().await.unwrap();
        let detected = GpuMemoryLimits::detect(&gpu).unwrap();

        println!("Total VRAM: {} bytes", detected.total_vram);
        println!("Usable VRAM: {} bytes", detected.usable_vram);
        println!("Max morsels: {}", detected.max_morsels);

        assert!(detected.total_vram > 0);
        assert!(detected.usable_vram > 0);
        assert!(detected.usable_vram <= detected.total_vram);
        assert!(detected.max_morsels > 0);
    }

    #[test]
    fn test_fits_in_vram() {
        let limits = sample_limits();
        // Well under, comfortably under, and over the 5 GiB budget.
        assert!(limits.fits_in_vram(100 * MIB));
        assert!(limits.fits_in_vram(4 * 1024 * MIB));
        assert!(!limits.fits_in_vram(6 * 1024 * MIB));
    }

    #[test]
    fn test_morsels_needed() {
        let limits = sample_limits();
        // Exact multiples of the morsel size, plus a partial morsel that rounds up.
        assert_eq!(limits.morsels_needed(128 * MIB), 1);
        assert_eq!(limits.morsels_needed(256 * MIB), 2);
        assert_eq!(limits.morsels_needed(200 * MIB), 2);
    }
}