use std::marker::PhantomData;
use std::mem::size_of;
use super::executor::GpuError;
/// Typed handle describing a buffer of `T` elements on a GPU device.
///
/// The handle only records metadata — the device index it was created with
/// and the buffer length in bytes — and stores no `T` values itself;
/// `PhantomData<T>` carries the element type at the type level.
pub struct GpuBuffer<T> {
    device_index: usize,
    size: usize,
    _marker: PhantomData<T>,
}

impl<T> GpuBuffer<T> {
    /// Creates a handle for a buffer of `size` bytes on device `device_index`.
    ///
    /// Crate-private, so only code inside this crate can mint handles.
    pub(crate) fn new(device_index: usize, size: usize) -> Self {
        Self {
            device_index,
            size,
            _marker: PhantomData,
        }
    }

    /// Returns the buffer length in bytes.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Returns the device index this handle was created with.
    pub fn device_index(&self) -> usize {
        self.device_index
    }

    /// Returns how many whole `T` elements fit in the buffer.
    ///
    /// Trailing bytes that do not form a complete element are ignored.
    /// For a zero-sized `T` this returns `0` instead of dividing by zero.
    pub fn element_count(&self) -> usize {
        self.size.checked_div(size_of::<T>()).unwrap_or(0)
    }
}

// Hand-written rather than `#[derive(Debug)]`: the derive would add a
// `T: Debug` bound even though no `T` is stored, so e.g.
// `Result<GpuBuffer<MyPod>, GpuError>::unwrap_err()` or `{:?}` on an
// allocation result would not compile for element types lacking `Debug`.
// The produced text is the same as the derived implementation's.
impl<T> std::fmt::Debug for GpuBuffer<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GpuBuffer")
            .field("device_index", &self.device_index)
            .field("size", &self.size)
            .field("_marker", &self._marker)
            .finish()
    }
}
/// Per-device allocation bookkeeping for `GpuBuffer` handles.
///
/// In this build the pool is a stub: [`allocate`](Self::allocate) always
/// fails with `GpuError::NotAvailable` and [`deallocate`](Self::deallocate)
/// does nothing, so the counters keep whatever values they were given.
#[derive(Debug)]
pub struct GpuMemoryPool {
    /// Device this pool was created for. Kept `pub` for backward
    /// compatibility; [`device_index`](Self::device_index) reads the same value.
    pub device_index: usize,
    /// Bytes currently accounted as allocated.
    total_allocated_bytes: usize,
    /// Recorded peak of allocated bytes; only `new` and `reset_peak` write it
    /// in this stub.
    peak_allocated_bytes: usize,
    /// Number of allocations recorded.
    allocation_count: usize,
}

impl Default for GpuMemoryPool {
    /// Creates an empty pool for device `0`.
    fn default() -> Self {
        Self::new(0)
    }
}

impl GpuMemoryPool {
    /// Creates a pool for `device_index` with every counter at zero.
    pub fn new(device_index: usize) -> Self {
        Self {
            device_index,
            total_allocated_bytes: 0,
            peak_allocated_bytes: 0,
            allocation_count: 0,
        }
    }

    /// Returns the device index this pool was created for.
    ///
    /// Accessor mirroring `GpuBuffer::device_index`; equivalent to reading the
    /// public `device_index` field.
    pub fn device_index(&self) -> usize {
        self.device_index
    }

    /// Requests a buffer holding `count` elements of `T`.
    ///
    /// # Errors
    ///
    /// Always returns `GpuError::NotAvailable` in this stub build; `count` is
    /// ignored and no counters are modified.
    pub fn allocate<T>(&mut self, _count: usize) -> Result<GpuBuffer<T>, GpuError> {
        Err(GpuError::NotAvailable {
            reason: "CUDA not available in stub mode".to_string(),
        })
    }

    /// Releases `buffer`, consuming the handle.
    ///
    /// # Errors
    ///
    /// Never fails in this stub build: the handle is dropped and the pool's
    /// counters are left unchanged.
    pub fn deallocate<T>(&mut self, _buffer: GpuBuffer<T>) -> Result<(), GpuError> {
        Ok(())
    }

    /// Returns the number of bytes currently accounted as allocated.
    pub fn allocated_bytes(&self) -> usize {
        self.total_allocated_bytes
    }

    /// Returns the recorded peak of allocated bytes.
    pub fn peak_bytes(&self) -> usize {
        self.peak_allocated_bytes
    }

    /// Returns the number of allocations recorded.
    pub fn allocation_count(&self) -> usize {
        self.allocation_count
    }

    /// Sets the recorded peak to the current allocated-byte total.
    pub fn reset_peak(&mut self) {
        self.peak_allocated_bytes = self.total_allocated_bytes;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- GpuMemoryPool construction -------------------------------------

    #[test]
    fn test_gpu_memory_pool_new() {
        let pool = GpuMemoryPool::new(2);
        assert_eq!(pool.device_index, 2);
        assert_eq!(pool.allocated_bytes(), 0);
        assert_eq!(pool.peak_bytes(), 0);
        assert_eq!(pool.allocation_count(), 0);
    }

    #[test]
    fn test_default_pool_targets_device_zero() {
        let pool = GpuMemoryPool::default();
        assert_eq!(pool.device_index, 0);
        assert_eq!(pool.allocated_bytes(), 0);
        assert_eq!(pool.peak_bytes(), 0);
        assert_eq!(pool.allocation_count(), 0);
    }

    // --- Stub allocate / deallocate -------------------------------------

    #[test]
    fn test_allocate_returns_not_available() {
        let mut pool = GpuMemoryPool::new(0);
        let result: Result<GpuBuffer<f32>, GpuError> = pool.allocate(1024);
        // The catch-all arm already fails on `Ok` or any other variant.
        match result {
            Err(GpuError::NotAvailable { reason }) => {
                assert!(!reason.is_empty());
            }
            other => panic!("Expected NotAvailable, got {:?}", other),
        }
        // A failed allocation must not be counted.
        assert_eq!(pool.allocated_bytes(), 0);
        assert_eq!(pool.allocation_count(), 0);
    }

    #[test]
    fn test_deallocate_no_panic() {
        let mut pool = GpuMemoryPool::new(0);
        let buf: GpuBuffer<f64> = GpuBuffer::new(0, 256);
        let result = pool.deallocate(buf);
        assert!(result.is_ok());
        // The stub deallocation leaves the bookkeeping untouched.
        assert_eq!(pool.allocated_bytes(), 0);
        assert_eq!(pool.peak_bytes(), 0);
        assert_eq!(pool.allocation_count(), 0);
    }

    // --- Counters --------------------------------------------------------

    #[test]
    fn test_peak_tracking_initial() {
        let pool = GpuMemoryPool::new(0);
        assert_eq!(pool.allocated_bytes(), 0);
        assert_eq!(pool.peak_bytes(), 0);
    }

    #[test]
    fn test_allocation_count() {
        let pool = GpuMemoryPool::new(0);
        assert_eq!(pool.allocation_count(), 0);
    }

    #[test]
    fn test_reset_peak() {
        let mut pool = GpuMemoryPool::new(0);
        pool.peak_allocated_bytes = 4096;
        pool.total_allocated_bytes = 2048;
        pool.reset_peak();
        assert_eq!(pool.peak_bytes(), 2048);
    }

    // --- GpuBuffer --------------------------------------------------------

    #[test]
    fn test_gpu_buffer_accessors() {
        let buf: GpuBuffer<f32> = GpuBuffer::new(3, 1024);
        assert_eq!(buf.device_index(), 3);
        assert_eq!(buf.size(), 1024);
        assert_eq!(buf.element_count(), 256);
    }

    #[test]
    fn test_element_count_truncates_partial_elements() {
        // 10 bytes hold one whole u64; the 2 trailing bytes are ignored.
        let buf: GpuBuffer<u64> = GpuBuffer::new(0, 10);
        assert_eq!(buf.element_count(), 1);
    }

    #[test]
    fn test_element_count_zero_sized_type() {
        // Zero-sized element types must not divide by zero.
        let buf: GpuBuffer<()> = GpuBuffer::new(0, 16);
        assert_eq!(buf.element_count(), 0);
    }
}