use crate::error::{ClusteringError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Instant;
/// Selectable strategy for managing GPU memory.
///
/// NOTE(review): nothing in this file reads the chosen strategy, so the
/// per-variant notes below are inferred from the variant names only — confirm
/// them against the code that consumes this enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryStrategy {
/// Presumably a single address space shared by host and device — TODO confirm.
Unified,
/// Presumably explicit host/device copies issued by the caller — TODO confirm.
Explicit,
/// Pool-based allocation; carries a pool size (the field name suggests
/// megabytes — TODO confirm the unit with the consumer).
Pooled { pool_size_mb: usize },
/// Presumably direct device access to host memory without copies — TODO confirm.
ZeroCopy,
/// Strategy picked automatically (presumably at run time — TODO confirm);
/// this is the value produced by `Default`.
Adaptive,
}
impl Default for MemoryStrategy {
fn default() -> Self {
MemoryStrategy::Adaptive
}
}
/// Bookkeeping for GPU memory blocks with per-size-class reuse pools.
///
/// Blocks handed back through `deallocate` are cached in `pools` so later
/// `allocate` calls of a compatible size can reuse them.
#[derive(Debug)]
pub struct GpuMemoryManager {
/// Cached blocks, keyed by size class (the block size rounded up to a power of two).
pools: HashMap<usize, Vec<GpuMemoryBlock>>,
/// Bytes obtained from the device allocator and not yet freed; blocks sitting
/// in `pools` are still counted here.
total_allocated: usize,
/// Highest value `total_allocated` has reached.
peak_usage: usize,
/// Alignment, in bytes, that requested sizes are rounded up to.
alignment: usize,
/// Maximum number of blocks kept in each size-class pool.
max_pool_size: usize,
/// Running counters exposed through `get_stats`.
stats: MemoryStats,
}
/// Handle describing one block of device memory managed by `GpuMemoryManager`.
///
/// The handle is `Clone`, so copies of the same block can exist; nothing in
/// this type prevents a block from being released more than once.
#[derive(Debug, Clone)]
pub struct GpuMemoryBlock {
/// Device address of the block, stored as a plain integer.
pub device_ptr: usize,
/// Size of the block in bytes (the requested size after alignment).
pub size: usize,
/// `true` while the block is handed out to a caller.
pub in_use: bool,
/// Time at which the block was handed out by `GpuMemoryManager::allocate`.
pub allocated_at: Instant,
}
/// Counters describing the activity of a `GpuMemoryManager`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MemoryStats {
/// Bytes currently held from the device allocator (mirrors the manager's total).
pub total_allocated: usize,
/// Highest value `total_allocated` has reached.
pub peak_usage: usize,
/// Number of allocations recorded by `GpuMemoryManager::allocate`.
pub allocation_count: usize,
/// Number of blocks handed back through `GpuMemoryManager::deallocate`.
pub deallocation_count: usize,
/// Allocations served from a pool.
pub pool_hits: usize,
/// Allocations that required new device memory.
pub pool_misses: usize,
}
impl GpuMemoryManager {
    /// Largest size accepted for a single allocation: 16 GiB.
    ///
    /// Stored as `u64` so the constant can be evaluated on targets where
    /// `usize` is narrower than 64 bits.
    const MAX_ALLOCATION_BYTES: u64 = 16 * 1024 * 1024 * 1024;

    /// Creates an empty manager.
    ///
    /// * `alignment` - byte multiple that requested sizes are rounded up to.
    ///   Any value is accepted; `0` is treated as `1` (no rounding).
    /// * `max_pool_size` - maximum number of free blocks cached per size class.
    pub fn new(alignment: usize, max_pool_size: usize) -> Self {
        Self {
            pools: HashMap::new(),
            total_allocated: 0,
            peak_usage: 0,
            alignment,
            max_pool_size,
            stats: MemoryStats::default(),
        }
    }

    /// Hands out a block of at least `size` bytes, reusing a pooled block when
    /// one is large enough.
    ///
    /// The request is rounded up to the configured alignment. A reused block is
    /// removed from its pool, so a pool only ever contains free blocks and the
    /// same block cannot be handed out twice.
    ///
    /// # Errors
    ///
    /// Returns `ClusteringError::InvalidInput` when the aligned size is zero,
    /// exceeds 16 GiB, or cannot be represented in `usize`. Rejected requests
    /// are not counted in the statistics.
    pub fn allocate(&mut self, size: usize) -> Result<GpuMemoryBlock> {
        // Validate before touching any counter so failed requests do not skew
        // `allocation_count` (and with it the leak heuristic in `MemoryStats`).
        let aligned_size = self.align_up(size)?;
        Self::check_allocation_size(aligned_size)?;
        let size_class = self.get_size_class(aligned_size);

        if let Some(pool) = self.pools.get_mut(&size_class) {
            if let Some(index) = pool.iter().position(|cached| cached.size >= aligned_size) {
                // Take the block out of the pool instead of copying it; leaving
                // it behind would create a stale duplicate on the next
                // `deallocate`.
                let mut block = pool.swap_remove(index);
                block.in_use = true;
                block.allocated_at = Instant::now();
                self.stats.allocation_count += 1;
                self.stats.pool_hits += 1;
                return Ok(block);
            }
        }

        let device_ptr = self.allocate_device_memory(aligned_size)?;
        self.stats.allocation_count += 1;
        self.stats.pool_misses += 1;
        self.total_allocated = self.total_allocated.saturating_add(aligned_size);
        self.peak_usage = self.peak_usage.max(self.total_allocated);
        self.stats.total_allocated = self.total_allocated;
        self.stats.peak_usage = self.peak_usage;

        Ok(GpuMemoryBlock {
            device_ptr,
            size: aligned_size,
            in_use: true,
            allocated_at: Instant::now(),
        })
    }

    /// Takes a block back from the caller.
    ///
    /// The block is cached in the pool of its size class while that pool holds
    /// fewer than `max_pool_size` blocks; otherwise its device memory is freed
    /// and the byte counters are reduced.
    ///
    /// # Errors
    ///
    /// Propagates any error from freeing the device memory.
    pub fn deallocate(&mut self, mut block: GpuMemoryBlock) -> Result<()> {
        block.in_use = false;
        self.stats.deallocation_count += 1;

        let size_class = self.get_size_class(block.size);
        let pool = self.pools.entry(size_class).or_default();
        if pool.len() < self.max_pool_size {
            pool.push(block);
            return Ok(());
        }

        self.free_device_memory(block.device_ptr, block.size)?;
        // Saturate rather than underflow if a block that was never counted is
        // handed in.
        self.total_allocated = self.total_allocated.saturating_sub(block.size);
        self.stats.total_allocated = self.total_allocated;
        Ok(())
    }

    /// Frees every cached block and empties all pools.
    ///
    /// # Errors
    ///
    /// Propagates the first error from freeing device memory. Blocks that were
    /// not released yet stay in their pools, so the call can be retried without
    /// freeing anything twice.
    pub fn clear_pools(&mut self) -> Result<()> {
        // Work on the map by value so the free routine (which borrows `self`)
        // can be called while individual pools are being drained.
        let mut pools = std::mem::take(&mut self.pools);
        let mut outcome = Ok(());

        'release: for pool in pools.values_mut() {
            while let Some(block) = pool.pop() {
                if let Err(err) = self.free_device_memory(block.device_ptr, block.size) {
                    pool.push(block);
                    outcome = Err(err);
                    break 'release;
                }
                self.total_allocated = self.total_allocated.saturating_sub(block.size);
            }
        }

        if outcome.is_err() {
            // Keep only what is still cached.
            pools.retain(|_, pool| !pool.is_empty());
            self.pools = pools;
        }
        self.stats.total_allocated = self.total_allocated;
        outcome
    }

    /// Returns the running statistics.
    pub fn get_stats(&self) -> &MemoryStats {
        &self.stats
    }

    /// Fraction of recorded allocations that were served from a pool
    /// (`0.0` when nothing has been allocated yet).
    pub fn pool_efficiency(&self) -> f64 {
        if self.stats.allocation_count == 0 {
            0.0
        } else {
            self.stats.pool_hits as f64 / self.stats.allocation_count as f64
        }
    }

    /// Bytes currently held from the device allocator, including cached blocks.
    pub fn current_usage(&self) -> usize {
        self.total_allocated
    }

    /// Highest value `current_usage` has reached.
    pub fn peak_usage(&self) -> usize {
        self.peak_usage
    }

    /// Rounds `size` up to the next multiple of the configured alignment.
    ///
    /// Uses division instead of bit masking so alignments that are not powers
    /// of two are handled correctly; for power-of-two alignments the result is
    /// the same as the mask form.
    fn align_up(&self, size: usize) -> Result<usize> {
        let alignment = self.alignment.max(1);
        size.checked_add(alignment - 1)
            .map(|padded| padded / alignment * alignment)
            .ok_or_else(|| ClusteringError::InvalidInput("Allocation too large".to_string()))
    }

    /// Rejects sizes the device allocator would refuse (zero or above 16 GiB).
    fn check_allocation_size(size: usize) -> Result<()> {
        if size == 0 {
            return Err(ClusteringError::InvalidInput(
                "Cannot allocate zero bytes".to_string(),
            ));
        }
        if size as u64 > Self::MAX_ALLOCATION_BYTES {
            return Err(ClusteringError::InvalidInput(
                "Allocation too large".to_string(),
            ));
        }
        Ok(())
    }

    /// Maps a size to its pool key: the smallest power of two that is at least
    /// `size` (`1` for zero). Sizes above the largest representable power of
    /// two map to `usize::MAX` instead of looping forever.
    fn get_size_class(&self, size: usize) -> usize {
        size.checked_next_power_of_two().unwrap_or(usize::MAX)
    }

    /// Placeholder for the real device allocation.
    ///
    /// No device call is made: the returned address is derived from `size`, so
    /// two blocks of equal size share the same value.
    fn allocate_device_memory(&self, size: usize) -> Result<usize> {
        Self::check_allocation_size(size)?;
        0x1000_0000_usize
            .checked_add(size)
            .ok_or_else(|| ClusteringError::InvalidInput("Allocation too large".to_string()))
    }

    /// Placeholder for the real device free; currently always succeeds.
    fn free_device_memory(&self, _device_ptr: usize, _size: usize) -> Result<()> {
        Ok(())
    }
}
impl MemoryStats {
    /// Ratio of recorded deallocations to recorded allocations.
    ///
    /// Yields `1.0` while no allocation has been recorded.
    pub fn allocation_efficiency(&self) -> f64 {
        match self.allocation_count {
            0 => 1.0,
            allocations => self.deallocation_count as f64 / allocations as f64,
        }
    }

    /// `total_allocated` divided by the number of recorded allocations.
    ///
    /// Yields `0.0` while no allocation has been recorded. Note that the
    /// numerator is the bytes currently held, not a cumulative sum.
    pub fn average_allocation_size(&self) -> f64 {
        match self.allocation_count {
            0 => 0.0,
            allocations => self.total_allocated as f64 / allocations as f64,
        }
    }

    /// `true` when more allocations than deallocations have been recorded.
    pub fn has_potential_leaks(&self) -> bool {
        self.deallocation_count < self.allocation_count
    }
}
/// Description of one memory copy, identified by its direction.
///
/// Host addresses are raw pointers and device addresses are plain integers;
/// every variant carries the number of bytes to copy in `size`.
#[derive(Debug, Clone)]
pub enum MemoryTransfer {
    /// Copy from host memory to device memory.
    HostToDevice {
        host_ptr: *const u8,
        device_ptr: usize,
        size: usize,
    },
    /// Copy from device memory to host memory.
    DeviceToHost {
        device_ptr: usize,
        host_ptr: *mut u8,
        size: usize,
    },
    /// Copy between two device addresses.
    DeviceToDevice {
        src_device_ptr: usize,
        dst_device_ptr: usize,
        size: usize,
    },
}

impl MemoryTransfer {
    /// Number of bytes this transfer covers.
    pub fn size(&self) -> usize {
        match *self {
            Self::HostToDevice { size, .. }
            | Self::DeviceToHost { size, .. }
            | Self::DeviceToDevice { size, .. } => size,
        }
    }

    /// Carries out the transfer.
    ///
    /// Currently a placeholder: no data is copied, the pointers are never
    /// dereferenced, and every variant returns `Ok(())`.
    pub fn execute(&self) -> Result<()> {
        // Every variant is listed so that adding a new one forces this match
        // to be revisited.
        match self {
            Self::HostToDevice { .. }
            | Self::DeviceToHost { .. }
            | Self::DeviceToDevice { .. } => Ok(()),
        }
    }
}
/// Accumulates transfer sizes and durations to report average throughput.
#[derive(Debug, Clone, Default)]
pub struct BandwidthMonitor {
    /// Sum of all recorded transfer sizes, in bytes.
    pub total_transferred: usize,
    /// Number of recorded transfers.
    pub transfer_count: usize,
    /// Sum of all recorded durations, in microseconds.
    pub total_time_us: u64,
}

impl BandwidthMonitor {
    /// Adds one transfer of `size` bytes that took `duration_us` microseconds.
    pub fn record_transfer(&mut self, size: usize, duration_us: u64) {
        self.total_transferred += size;
        self.transfer_count += 1;
        self.total_time_us += duration_us;
    }

    /// Average throughput over everything recorded so far.
    ///
    /// Despite the name, the unit is GiB (2^30 bytes) per second. Returns
    /// `0.0` while the accumulated duration is zero.
    pub fn average_bandwidth_gbps(&self) -> f64 {
        const BYTES_PER_GIB: f64 = 1024.0 * 1024.0 * 1024.0;
        const MICROS_PER_SECOND: f64 = 1_000_000.0;

        if self.total_time_us == 0 {
            return 0.0;
        }
        let gibibytes = self.total_transferred as f64 / BYTES_PER_GIB;
        let seconds = self.total_time_us as f64 / MICROS_PER_SECOND;
        gibibytes / seconds
    }

    /// Mean number of bytes per recorded transfer (`0.0` when nothing was recorded).
    pub fn average_transfer_size(&self) -> f64 {
        if self.transfer_count == 0 {
            return 0.0;
        }
        self.total_transferred as f64 / self.transfer_count as f64
    }
}
// Unit tests for the memory manager, statistics helpers, bandwidth monitor and
// transfer descriptor defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// A new manager stores its configuration and starts with zero usage.
#[test]
fn test_memory_manager_creation() {
let manager = GpuMemoryManager::new(256, 10);
assert_eq!(manager.alignment, 256);
assert_eq!(manager.max_pool_size, 10);
assert_eq!(manager.current_usage(), 0);
}
// The first allocation cannot come from a pool, so it is recorded as a miss.
#[test]
fn test_memory_allocation() {
let mut manager = GpuMemoryManager::new(256, 10);
let block = manager.allocate(1024).expect("Operation failed");
assert!(block.size >= 1024);
assert!(block.in_use);
assert_eq!(manager.get_stats().allocation_count, 1);
assert_eq!(manager.get_stats().pool_misses, 1);
}
// A deallocated block is reused by the next allocation of the same size.
#[test]
fn test_memory_pooling() {
let mut manager = GpuMemoryManager::new(256, 10);
let block = manager.allocate(1024).expect("Operation failed");
manager.deallocate(block).expect("Operation failed");
let _block2 = manager.allocate(1024).expect("Operation failed");
assert!(manager.get_stats().pool_hits > 0);
}
// Derived ratios are computed from hand-built counters.
#[test]
fn test_memory_stats() {
let stats = MemoryStats {
allocation_count: 10,
deallocation_count: 8,
total_allocated: 1024,
pool_hits: 5,
..Default::default()
};
assert_eq!(stats.allocation_efficiency(), 0.8);
assert_eq!(stats.average_allocation_size(), 102.4);
assert!(stats.has_potential_leaks());
}
// 1 GiB transferred in one second yields a bandwidth of exactly 1.0.
#[test]
fn test_bandwidth_monitor() {
let mut monitor = BandwidthMonitor::default();
monitor.record_transfer(1024 * 1024 * 1024, 1_000_000);
assert_eq!(monitor.average_bandwidth_gbps(), 1.0);
assert_eq!(monitor.average_transfer_size(), 1024.0 * 1024.0 * 1024.0);
}
// `execute` never touches the pointers, so a null host pointer is fine here.
#[test]
fn test_memory_transfer() {
let transfer = MemoryTransfer::HostToDevice {
host_ptr: std::ptr::null(),
device_ptr: 0x1000,
size: 1024,
};
assert_eq!(transfer.size(), 1024);
assert!(transfer.execute().is_ok());
}
}