use std::sync::Mutex;
use std::collections::BTreeMap;
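/// A single allocation in a simulated device address space: where it lives,
/// how large it is, whether it is pinned, and when it was handed out.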
#[derive(Debug, Clone)]
pub struct MemoryAllocation {
    pub device_ptr: u64,
    pub size: usize,
    pub is_pinned: bool,
    pub allocated_at: std::time::SystemTime,
}
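/// A per-device memory pool that tracks live allocations, a size-keyed free
/// list, and a running used-byte counter behind separate mutexes. Device
/// pointers are simulated offsets; no real driver calls are made. All methods
/// take the locks in the same order (allocations, free list, used counter) to
/// avoid deadlock. `max_fragmentation` is a configurable ceiling that callers
/// can consult; `allocate` itself does not enforce it.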
#[derive(Debug)]
pub struct GpuMemoryPool {
    pub device_id: usize,
    pub total_memory: u64,
    pub max_fragmentation: f64,
    allocations: Mutex<BTreeMap<u64, MemoryAllocation>>,
    free_blocks: Mutex<BTreeMap<usize, Vec<u64>>>,
    used_memory: Mutex<u64>,
}
impl GpuMemoryPool {
    pub fn new(device_id: usize, total_memory: u64) -> Self {
        GpuMemoryPool {
            device_id,
            total_memory,
            max_fragmentation: 0.5,
            allocations: Mutex::new(BTreeMap::new()),
            free_blocks: Mutex::new(BTreeMap::new()),
            used_memory: Mutex::new(0),
        }
    }
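    /// Reserves `size` bytes, preferring the smallest free block that fits and
    /// falling back to bump allocation past the highest address handed out so
    /// far. Returns an error if the pool does not have `size` free bytes.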
    pub fn allocate(&self, size: usize) -> Result<MemoryAllocation, String> {
        let mut allocations = self.allocations.lock().unwrap();
        let mut free_blocks = self.free_blocks.lock().unwrap();
        let mut used = self.used_memory.lock().unwrap();
        let total_free = self.total_memory - *used;
        if size as u64 > total_free {
            return Err(format!(
                "Insufficient GPU memory: requested {} bytes, available {} bytes",
                size, total_free
            ));
        }
        // Best fit: the free list is keyed by block size, so the first key that
        // is >= `size` is the smallest block that fits.
        let device_ptr = if let Some(found_size) = free_blocks
            .iter()
            .find(|(s, _)| **s >= size)
            .map(|(s, _)| *s)
        {
            let ptrs = free_blocks.get_mut(&found_size).unwrap();
            let ptr = ptrs.pop().unwrap();
            if ptrs.is_empty() {
                free_blocks.remove(&found_size);
            }
            // Return the unused tail of the block to the free list so those
            // bytes are not silently lost.
            let remainder = found_size - size;
            if remainder > 0 {
                free_blocks
                    .entry(remainder)
                    .or_insert_with(Vec::new)
                    .push(ptr + size as u64);
            }
            ptr
        } else {
            // Bump allocation: place the block past everything handed out so
            // far, whether still live or sitting in the free list, so fresh
            // allocations never overlap freed blocks.
            let live_end = allocations
                .last_key_value()
                .map(|(ptr, alloc)| ptr + alloc.size as u64)
                .unwrap_or(0);
            let free_end = free_blocks
                .iter()
                .flat_map(|(block_size, ptrs)| {
                    ptrs.iter().map(move |ptr| *ptr + *block_size as u64)
                })
                .max()
                .unwrap_or(0);
            live_end.max(free_end)
        };
        let allocation = MemoryAllocation {
            device_ptr,
            size,
            is_pinned: false,
            allocated_at: std::time::SystemTime::now(),
        };
        allocations.insert(device_ptr, allocation.clone());
        *used += size as u64;
        Ok(allocation)
    }
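    /// Releases the allocation at `ptr`, subtracting its size from the used
    /// counter and returning the block to the size-keyed free list.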
    pub fn deallocate(&self, ptr: u64) -> Result<(), String> {
        let mut allocations = self.allocations.lock().unwrap();
        let mut free_blocks = self.free_blocks.lock().unwrap();
        let mut used = self.used_memory.lock().unwrap();
        if let Some(alloc) = allocations.remove(&ptr) {
            *used -= alloc.size as u64;
            free_blocks
                .entry(alloc.size)
                .or_insert_with(Vec::new)
                .push(ptr);
            Ok(())
        } else {
            Err(format!("Pointer {} not found in allocations", ptr))
        }
    }
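    /// Coalesces free blocks that are adjacent in the device address space.
    /// Both locks are held for the whole pass so the free list cannot change
    /// while it is being rebuilt.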
    pub fn defragment(&self) -> Result<(), String> {
        // Take the allocations lock first (matching allocate/deallocate) so no
        // other path can touch the free list mid-rebuild, then keep the
        // free-list lock until the rebuilt map is swapped in.
        let _allocations = self.allocations.lock().unwrap();
        let mut free_blocks = self.free_blocks.lock().unwrap();
        // Flatten the free list into (ptr, size) pairs ordered by address.
        let mut merged: Vec<(u64, usize)> = free_blocks
            .iter()
            .flat_map(|(size, ptrs)| ptrs.iter().map(move |ptr| (*ptr, *size)))
            .collect();
        merged.sort_by_key(|(ptr, _)| *ptr);
        // Walk the sorted blocks and merge runs whose addresses are contiguous.
        let mut final_blocks: BTreeMap<usize, Vec<u64>> = BTreeMap::new();
        let mut current_ptr = 0u64;
        let mut current_size = 0usize;
        for (ptr, size) in merged {
            if current_size > 0 && ptr == current_ptr + current_size as u64 {
                current_size += size;
            } else {
                if current_size > 0 {
                    final_blocks
                        .entry(current_size)
                        .or_insert_with(Vec::new)
                        .push(current_ptr);
                }
                current_ptr = ptr;
                current_size = size;
            }
        }
        if current_size > 0 {
            final_blocks
                .entry(current_size)
                .or_insert_with(Vec::new)
                .push(current_ptr);
        }
        *free_blocks = final_blocks;
        Ok(())
    }
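    /// Percentage of `total_memory` currently allocated.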
    pub fn utilization(&self) -> f64 {
        let used = *self.used_memory.lock().unwrap();
        (used as f64 / self.total_memory as f64) * 100.0
    }
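    /// Rough fragmentation measure: free blocks as a percentage of all tracked
    /// blocks (live allocations plus free blocks). Returns 0.0 when there are
    /// no live allocations.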
    pub fn fragmentation(&self) -> f64 {
        let allocations = self.allocations.lock().unwrap();
        let free_blocks = self.free_blocks.lock().unwrap();
        let n_allocs = allocations.len();
        if n_allocs == 0 {
            return 0.0;
        }
        let n_free_blocks: usize = free_blocks.values().map(|v| v.len()).sum();
        (n_free_blocks as f64 / (n_allocs + n_free_blocks) as f64) * 100.0
    }
}
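/// Describes a host-to-device copy. `execute` only validates the endpoints;
/// it stands in for the real driver call (for example a cudaMemcpy-style copy)
/// that a backend would issue.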
#[derive(Debug, Clone)]
pub struct HostToDeviceTransfer {
    pub host_ptr: *const u8,
    pub device_ptr: u64,
    pub size: usize,
    pub async_transfer: bool,
}
impl HostToDeviceTransfer {
    pub fn new(host_ptr: *const u8, device_ptr: u64, size: usize, async_transfer: bool) -> Self {
        HostToDeviceTransfer {
            host_ptr,
            device_ptr,
            size,
            async_transfer,
        }
    }
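    /// Checks that both endpoints look valid; the actual copy is left to a
    /// real backend.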
    pub fn execute(&self) -> Result<(), String> {
        if self.host_ptr.is_null() {
            return Err("Invalid host pointer".to_string());
        }
        if self.device_ptr == 0 {
            return Err("Invalid device pointer".to_string());
        }
        Ok(())
    }
}
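/// Describes a device-to-host copy; the mirror image of `HostToDeviceTransfer`,
/// with the same validate-only `execute`.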
#[derive(Debug, Clone)]
pub struct DeviceToHostTransfer {
    pub device_ptr: u64,
    pub host_ptr: *mut u8,
    pub size: usize,
    pub async_transfer: bool,
}
impl DeviceToHostTransfer {
    pub fn new(device_ptr: u64, host_ptr: *mut u8, size: usize, async_transfer: bool) -> Self {
        DeviceToHostTransfer {
            device_ptr,
            host_ptr,
            size,
            async_transfer,
        }
    }
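    /// Checks that both endpoints look valid; the actual copy is left to a
    /// real backend.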
    pub fn execute(&self) -> Result<(), String> {
        if self.device_ptr == 0 {
            return Err("Invalid device pointer".to_string());
        }
        if self.host_ptr.is_null() {
            return Err("Invalid host pointer".to_string());
        }
        Ok(())
    }
}
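/// One `GpuMemoryPool` per device, with a simple placement policy that sends
/// each new allocation to the least fragmented pool.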
#[derive(Debug)]
pub struct MultiGpuMemory {
    pools: Vec<GpuMemoryPool>,
}
impl MultiGpuMemory {
    pub fn new(devices: &[(usize, u64)]) -> Self {
        let pools = devices
            .iter()
            .map(|(device_id, total_mem)| GpuMemoryPool::new(*device_id, *total_mem))
            .collect();
        MultiGpuMemory { pools }
    }
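    /// Allocates `size` bytes on the pool with the lowest fragmentation and
    /// returns the chosen pool index together with the allocation.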
    pub fn allocate_balanced(&self, size: usize) -> Result<(usize, MemoryAllocation), String> {
        if self.pools.is_empty() {
            return Err("No GPU memory pools configured".to_string());
        }
        let mut best_pool_idx = 0;
        let mut best_fragmentation = f64::INFINITY;
        for (idx, pool) in self.pools.iter().enumerate() {
            let frag = pool.fragmentation();
            if frag < best_fragmentation {
                best_fragmentation = frag;
                best_pool_idx = idx;
            }
        }
        let alloc = self.pools[best_pool_idx].allocate(size)?;
        Ok((best_pool_idx, alloc))
    }
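    /// Looks up the pool for a given device id, if one was configured.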
    pub fn pool(&self, device_id: usize) -> Option<&GpuMemoryPool> {
        self.pools.iter().find(|p| p.device_id == device_id)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_memory_pool_allocation() {
        let pool = GpuMemoryPool::new(0, 1024 * 1024 * 1024);
        let alloc = pool.allocate(1024).unwrap();
        assert_eq!(alloc.size, 1024);
        assert_eq!(pool.utilization(), 1024.0 / (1024 * 1024 * 1024) as f64 * 100.0);
    }
    #[test]
    fn test_memory_deallocation() {
        let pool = GpuMemoryPool::new(0, 1024 * 1024 * 1024);
        let alloc = pool.allocate(1024).unwrap();
        let ptr = alloc.device_ptr;
        pool.deallocate(ptr).unwrap();
        assert!(pool.allocate(1024).is_ok());
    }
    #[test]
    fn test_fragmentation() {
        let pool = GpuMemoryPool::new(0, 10000);
        let a1 = pool.allocate(1000).unwrap();
        // The middle allocation stays live so the freed neighbours cannot merge.
        let _a2 = pool.allocate(1000).unwrap();
        let a3 = pool.allocate(1000).unwrap();
        pool.deallocate(a1.device_ptr).unwrap();
        pool.deallocate(a3.device_ptr).unwrap();
        let frag = pool.fragmentation();
        assert!(frag > 0.0);
    }
    #[test]
    fn test_multi_gpu_memory() {
        let devices = vec![(0, 1024 * 1024 * 1024), (1, 2 * 1024 * 1024 * 1024)];
        let mgpu = MultiGpuMemory::new(&devices);
        let (dev_id, alloc) = mgpu.allocate_balanced(4096).unwrap();
        assert!(dev_id < 2);
        assert_eq!(alloc.size, 4096);
    }
    #[test]
    fn test_host_device_transfer() {
        let host_data = vec![1u8, 2, 3, 4];
        let transfer = HostToDeviceTransfer::new(
            host_data.as_ptr(),
            0x10000000,
            4,
            false,
        );
        assert!(transfer.execute().is_ok());
    }
}