use crate::error::{LinalgError, LinalgResult};
use std::collections::HashMap;
use std::time::Instant;
/// Per-GPU memory manager: owns a set of typed memory pools, applies a
/// configurable allocation strategy, and tracks garbage-collection state.
#[derive(Debug)]
pub struct GpuMemoryManager {
/// Identifier of the GPU this manager serves.
pub gpu_id: usize,
/// Registered pools; `allocate`/`deallocate` pick the first pool whose
/// `pool_type` matches the request.
pub memory_pools: Vec<MemoryPool>,
/// Strategy used by `allocate` to choose a free block.
pub allocation_strategy: MemoryAllocationStrategy,
/// Collector configuration plus accumulated GC statistics.
pub garbage_collector: MemoryGarbageCollector,
}
/// A region of GPU memory of a single `MemoryPoolType`, tracked as two
/// block lists: free spans and currently allocated spans.
#[derive(Debug)]
pub struct MemoryPool {
/// Total capacity of the pool, in bytes.
pub size: usize,
/// Unallocated spans; adjacent entries are merged by coalescing.
pub free_blocks: Vec<MemoryBlock>,
/// Spans currently handed out to callers.
pub allocated_blocks: Vec<MemoryBlock>,
/// Kind of memory this pool represents; used to look the pool up.
pub pool_type: MemoryPoolType,
}
/// A span of memory inside a pool, identified by start offset and length.
#[derive(Debug, Clone)]
pub struct MemoryBlock {
/// Byte offset of the block within its pool.
pub start: usize,
/// Length of the block, in bytes.
pub size: usize,
/// Whether the block is currently allocated.
pub in_use: bool,
/// When the block was allocated; `None` while the block is free.
pub allocated_at: Option<Instant>,
}
/// The kind of GPU memory a `MemoryPool` manages. Pools are selected by
/// this tag when allocating; variants mirror the usual GPU memory spaces.
#[derive(Debug, Clone, PartialEq)]
pub enum MemoryPoolType {
Global,
Shared,
Constant,
Texture,
Unified,
}
/// Policy used by `GpuMemoryManager::allocate` to choose a free block.
#[derive(Debug, Clone)]
pub enum MemoryAllocationStrategy {
/// Use the first free block large enough for the request.
FirstFit,
/// Use the smallest free block large enough (the manager's default).
BestFit,
/// Use the largest free block large enough.
WorstFit,
/// Buddy-system allocation — currently falls back to first fit.
Buddy,
/// Size-class segregated lists — currently falls back to first fit.
Segregated,
/// Prediction-driven placement — currently falls back to best fit.
Predictive,
}
/// Configuration and running statistics for reclaiming stale allocations.
#[derive(Debug)]
pub struct MemoryGarbageCollector {
/// Collection algorithm. NOTE(review): configuration only —
/// `collect_garbage` does not currently branch on it.
pub strategy: GCStrategy,
/// Threshold for triggering collection; presumably a pool-utilization
/// fraction in 0.0–1.0 (default 0.8) — confirm with callers.
pub threshold: f64,
/// Whether collection is meant to run automatically.
pub auto_collect: bool,
/// Totals updated after each `collect_garbage` pass.
pub stats: GCStats,
}
/// Garbage-collection algorithms the collector can be configured with.
/// Only recorded for now; no variant-specific behavior exists in this file.
#[derive(Debug, Clone)]
pub enum GCStrategy {
MarkAndSweep,
Generational,
Incremental,
Concurrent,
}
/// Cumulative garbage-collection statistics.
#[derive(Debug, Clone)]
pub struct GCStats {
/// Number of collection passes performed so far.
pub collections_performed: usize,
/// Total memory reclaimed, as reported by `collect_garbage`.
pub memory_reclaimed: usize,
/// Total wall-clock time spent collecting, in milliseconds.
pub total_gc_time_ms: f64,
/// `total_gc_time_ms / collections_performed`; 0.0 before the first pass.
pub avg_collection_time_ms: f64,
}
/// Estimates achievable memory bandwidth for a planned transfer.
#[derive(Debug)]
pub struct BandwidthPredictor {
/// Candidate prediction models. NOTE(review): not yet consulted by
/// `predict_bandwidth`, which uses a static heuristic.
pub models: Vec<BandwidthPredictionModel>,
/// Sliding window of recent measurements (capped at 1000 entries).
pub history: std::collections::VecDeque<BandwidthMeasurement>,
/// Estimated prediction accuracy (initialized to 0.85).
pub accuracy: f64,
}
/// Model families a `BandwidthPredictor` could use. Currently descriptive
/// only; no per-variant logic exists in this file.
#[derive(Debug, Clone)]
pub enum BandwidthPredictionModel {
LinearRegression,
NeuralNetwork,
TimeSeries,
Ensemble,
}
/// A single observed bandwidth sample recorded in the predictor's history.
#[derive(Debug, Clone)]
pub struct BandwidthMeasurement {
/// When the measurement was taken.
pub timestamp: Instant,
/// Observed bandwidth, in gigabytes per second.
pub bandwidth_gbps: f64,
/// Access pattern in effect during the measurement.
pub access_pattern: MemoryAccessPattern,
/// Size of the measured transfer, in bytes.
pub data_size: usize,
}
/// Memory access patterns distinguished by the bandwidth predictor.
#[derive(Debug, Clone)]
pub enum MemoryAccessPattern {
Sequential,
Random,
/// Fixed-stride access; payload is the stride (units not specified here —
/// presumably elements; confirm with callers).
Strided(usize),
Coalesced,
Broadcast,
}
/// Numeric precisions for tensor-core style matrix units. Not referenced
/// elsewhere in this file; presumably consumed by other modules — confirm.
#[derive(Debug, Clone)]
pub enum TensorCorePrecision {
FP16,
BF16,
FP32,
FP64,
TF32,
INT8,
}
impl GpuMemoryManager {
pub fn new(gpu_id: usize) -> LinalgResult<Self> {
Ok(Self {
gpu_id,
memory_pools: Vec::new(),
allocation_strategy: MemoryAllocationStrategy::BestFit,
garbage_collector: MemoryGarbageCollector::new(),
})
}
pub fn add_memory_pool(&mut self, pool: MemoryPool) {
self.memory_pools.push(pool);
}
pub fn allocate(
&mut self,
size: usize,
pool_type: MemoryPoolType,
) -> LinalgResult<MemoryBlock> {
let pool_index = self
.memory_pools
.iter()
.position(|p| p.pool_type == pool_type)
.ok_or_else(|| {
LinalgError::ComputationError(format!("No pool found for type {:?}", pool_type))
})?;
let pool = &mut self.memory_pools[pool_index];
match self.allocation_strategy {
MemoryAllocationStrategy::FirstFit => Self::allocate_first_fit(pool, size),
MemoryAllocationStrategy::BestFit => Self::allocate_best_fit(pool, size),
MemoryAllocationStrategy::WorstFit => Self::allocate_worst_fit(pool, size),
MemoryAllocationStrategy::Buddy => Self::allocate_buddy(pool, size),
MemoryAllocationStrategy::Segregated => Self::allocate_segregated(pool, size),
MemoryAllocationStrategy::Predictive => Self::allocate_predictive(pool, size),
}
}
pub fn deallocate(
&mut self,
block: MemoryBlock,
pool_type: MemoryPoolType,
) -> LinalgResult<()> {
let pool_index = self
.memory_pools
.iter()
.position(|p| p.pool_type == pool_type)
.ok_or_else(|| {
LinalgError::ComputationError(format!("No pool found for type {:?}", pool_type))
})?;
let pool = &mut self.memory_pools[pool_index];
pool.allocated_blocks.retain(|b| b.start != block.start);
let mut free_block = block;
free_block.in_use = false;
free_block.allocated_at = None;
pool.free_blocks.push(free_block);
Self::coalesce_free_blocks(pool);
Ok(())
}
pub fn collect_garbage(&mut self) -> LinalgResult<usize> {
let start_time = Instant::now();
let mut total_reclaimed = 0;
for pool in &mut self.memory_pools {
total_reclaimed += Self::collect_pool_garbage(pool)?;
}
let gc_time = start_time.elapsed().as_millis() as f64;
self.garbage_collector.stats.collections_performed += 1;
self.garbage_collector.stats.memory_reclaimed += total_reclaimed;
self.garbage_collector.stats.total_gc_time_ms += gc_time;
self.garbage_collector.stats.avg_collection_time_ms =
self.garbage_collector.stats.total_gc_time_ms
/ self.garbage_collector.stats.collections_performed as f64;
Ok(total_reclaimed)
}
pub fn get_memory_stats(&self) -> MemoryStats {
let mut total_allocated = 0;
let mut total_free = 0;
let mut total_fragmented = 0;
for pool in &self.memory_pools {
total_allocated += pool.allocated_blocks.iter().map(|b| b.size).sum::<usize>();
total_free += pool.free_blocks.iter().map(|b| b.size).sum::<usize>();
total_fragmented += pool.free_blocks.len().saturating_sub(1);
}
MemoryStats {
total_allocated,
total_free,
fragmentation_count: total_fragmented,
pool_count: self.memory_pools.len(),
gc_stats: self.garbage_collector.stats.clone(),
}
}
fn allocate_first_fit(pool: &mut MemoryPool, size: usize) -> LinalgResult<MemoryBlock> {
for (i, block) in pool.free_blocks.iter().enumerate() {
if block.size >= size {
let mut allocated_block = block.clone();
allocated_block.size = size;
allocated_block.in_use = true;
allocated_block.allocated_at = Some(Instant::now());
if block.size > size {
let remaining_block = MemoryBlock {
start: block.start + size,
size: block.size - size,
in_use: false,
allocated_at: None,
};
pool.free_blocks[i] = remaining_block;
} else {
pool.free_blocks.remove(i);
}
pool.allocated_blocks.push(allocated_block.clone());
return Ok(allocated_block);
}
}
Err(LinalgError::ComputationError(
"No suitable block found".to_string(),
))
}
fn allocate_best_fit(pool: &mut MemoryPool, size: usize) -> LinalgResult<MemoryBlock> {
let mut best_fit_index = None;
let mut best_fit_size = usize::MAX;
for (i, block) in pool.free_blocks.iter().enumerate() {
if block.size >= size && block.size < best_fit_size {
best_fit_index = Some(i);
best_fit_size = block.size;
}
}
if let Some(index) = best_fit_index {
let block = &pool.free_blocks[index];
let mut allocated_block = block.clone();
allocated_block.size = size;
allocated_block.in_use = true;
allocated_block.allocated_at = Some(Instant::now());
if block.size > size {
let remaining_block = MemoryBlock {
start: block.start + size,
size: block.size - size,
in_use: false,
allocated_at: None,
};
pool.free_blocks[index] = remaining_block;
} else {
pool.free_blocks.remove(index);
}
pool.allocated_blocks.push(allocated_block.clone());
Ok(allocated_block)
} else {
Err(LinalgError::ComputationError(
"No suitable block found".to_string(),
))
}
}
fn allocate_worst_fit(pool: &mut MemoryPool, size: usize) -> LinalgResult<MemoryBlock> {
Self::allocate_first_fit(pool, size) }
fn allocate_buddy(pool: &mut MemoryPool, size: usize) -> LinalgResult<MemoryBlock> {
Self::allocate_first_fit(pool, size) }
fn allocate_segregated(pool: &mut MemoryPool, size: usize) -> LinalgResult<MemoryBlock> {
Self::allocate_first_fit(pool, size) }
fn allocate_predictive(pool: &mut MemoryPool, size: usize) -> LinalgResult<MemoryBlock> {
Self::allocate_best_fit(pool, size) }
fn coalesce_free_blocks(pool: &mut MemoryPool) {
pool.free_blocks.sort_by_key(|b| b.start);
let mut i = 0;
while i < pool.free_blocks.len().saturating_sub(1) {
let current_end = pool.free_blocks[i].start + pool.free_blocks[i].size;
let next_start = pool.free_blocks[i + 1].start;
if current_end == next_start {
pool.free_blocks[i].size += pool.free_blocks[i + 1].size;
pool.free_blocks.remove(i + 1);
} else {
i += 1;
}
}
}
fn collect_pool_garbage(pool: &mut MemoryPool) -> LinalgResult<usize> {
let before_count = pool.allocated_blocks.len();
pool.allocated_blocks.retain(|block| {
if let Some(allocated_at) = block.allocated_at {
allocated_at.elapsed().as_secs() < 300 } else {
true
}
});
let reclaimed_count = before_count - pool.allocated_blocks.len();
Ok(reclaimed_count * 1024) }
}
impl MemoryPool {
    /// Builds a pool of `size` bytes whose entire range starts out as a
    /// single free block at offset 0, with no allocations recorded.
    pub fn new(size: usize, pool_type: MemoryPoolType) -> Self {
        Self {
            size,
            free_blocks: vec![MemoryBlock {
                start: 0,
                size,
                in_use: false,
                allocated_at: None,
            }],
            allocated_blocks: Vec::new(),
            pool_type,
        }
    }

    /// Percentage of the pool currently allocated, in 0.0–100.0.
    /// A zero-capacity pool reports 0.0 rather than dividing by zero.
    pub fn utilization(&self) -> f64 {
        let used: usize = self.allocated_blocks.iter().map(|b| b.size).sum();
        match self.size {
            0 => 0.0,
            total => used as f64 / total as f64 * 100.0,
        }
    }
}
impl MemoryGarbageCollector {
    /// Default collector: mark-and-sweep strategy, automatic collection
    /// enabled, zeroed statistics.
    pub fn new() -> Self {
        Self {
            strategy: GCStrategy::MarkAndSweep,
            // Intended trigger point: 80% utilization — see the struct docs.
            threshold: 0.8,
            auto_collect: true,
            stats: GCStats::new(),
        }
    }
}

// `new()` takes no arguments, so provide `Default` as well
// (clippy::new_without_default).
impl Default for MemoryGarbageCollector {
    fn default() -> Self {
        Self::new()
    }
}
impl GCStats {
pub fn new() -> Self {
Self {
collections_performed: 0,
memory_reclaimed: 0,
total_gc_time_ms: 0.0,
avg_collection_time_ms: 0.0,
}
}
}
/// Aggregate snapshot returned by `GpuMemoryManager::get_memory_stats`.
#[derive(Debug, Clone)]
pub struct MemoryStats {
/// Sum of allocated block sizes across all pools, in bytes.
pub total_allocated: usize,
/// Sum of free block sizes across all pools, in bytes.
pub total_free: usize,
/// Number of free-list gaps (free blocks beyond the first, per pool).
pub fragmentation_count: usize,
/// Number of registered pools.
pub pool_count: usize,
/// Copy of the garbage collector's statistics at snapshot time.
pub gc_stats: GCStats,
}
impl BandwidthPredictor {
    /// Maximum number of measurements kept in the sliding history window.
    const HISTORY_CAPACITY: usize = 1000;

    /// Predictor with a single linear-regression model, an empty history,
    /// and an initial accuracy estimate of 0.85.
    pub fn new() -> Self {
        Self {
            models: vec![BandwidthPredictionModel::LinearRegression],
            history: std::collections::VecDeque::new(),
            accuracy: 0.85,
        }
    }

    /// Appends a measurement, evicting the oldest entry once the window
    /// holds `HISTORY_CAPACITY` measurements.
    pub fn add_measurement(&mut self, measurement: BandwidthMeasurement) {
        self.history.push_back(measurement);
        if self.history.len() > Self::HISTORY_CAPACITY {
            self.history.pop_front();
        }
    }

    /// Estimates achievable bandwidth (GB/s) for a transfer of `data_size`
    /// bytes under `access_pattern`.
    ///
    /// NOTE(review): this is a static heuristic — `self.models`,
    /// `self.history` and `self.accuracy` are not consulted yet.
    pub fn predict_bandwidth(&self, data_size: usize, access_pattern: MemoryAccessPattern) -> f64 {
        // Baseline GB/s per pattern (heuristic constants).
        let base_bandwidth = match access_pattern {
            MemoryAccessPattern::Sequential => 800.0,
            MemoryAccessPattern::Coalesced => 750.0,
            MemoryAccessPattern::Strided(_) => 400.0,
            MemoryAccessPattern::Random => 200.0,
            MemoryAccessPattern::Broadcast => 600.0,
        };
        // Transfers above the threshold are derated by 10%. The cutoff is
        // smaller on 32-bit targets so the constant stays within usize.
        #[cfg(target_pointer_width = "32")]
        let threshold = 256 * 1024 * 1024;
        #[cfg(target_pointer_width = "64")]
        let threshold = 1024 * 1024 * 1024;
        let size_factor = if data_size > threshold { 0.9 } else { 1.0 };
        base_bandwidth * size_factor
    }
}

// `new()` takes no arguments, so provide `Default` as well
// (clippy::new_without_default).
impl Default for BandwidthPredictor {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
use super::*;
// A fresh pool must expose its whole capacity as one free block.
#[test]
fn test_memory_pool_creation() {
let pool = MemoryPool::new(1024 * 1024, MemoryPoolType::Global);
assert_eq!(pool.size, 1024 * 1024);
assert_eq!(pool.free_blocks.len(), 1);
assert_eq!(pool.allocated_blocks.len(), 0);
}
// A new manager records its GPU id and starts with no pools registered.
#[test]
fn test_memory_manager_creation() {
let manager = GpuMemoryManager::new(0).expect("Operation failed");
assert_eq!(manager.gpu_id, 0);
assert_eq!(manager.memory_pools.len(), 0);
}
// Sequential access must always yield a positive bandwidth estimate.
#[test]
fn test_bandwidth_predictor() {
let predictor = BandwidthPredictor::new();
let bandwidth = predictor.predict_bandwidth(1024, MemoryAccessPattern::Sequential);
assert!(bandwidth > 0.0);
}
}