use crate::error::{ScirsError, ScirsResult};
use scirs2_core::gpu::{GpuBuffer, GpuContext};
pub type OptimGpuArray<T> = GpuBuffer<T>;
pub type OptimGpuBuffer<T> = GpuBuffer<T>;
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
/// Snapshot of device memory occupancy, in bytes.
/// NOTE(review): invariant `used + free == total` is implied but not enforced
/// anywhere in this file — confirm with producers of this struct.
#[derive(Debug, Clone)]
pub struct GpuMemoryInfo {
/// Total device memory, in bytes.
pub total: usize,
/// Currently free device memory, in bytes.
pub free: usize,
/// Currently used device memory, in bytes.
pub used: usize,
}
/// Size-classed pool of reusable GPU memory blocks.
///
/// Idle blocks are kept in per-size queues (`pools`) and reused on an exact
/// size match; `current_usage` tracks bytes charged against the optional
/// `memory_limit`. The `pools` map is shared with `GpuWorkspace`, which
/// returns blocks to it on drop.
pub struct GpuMemoryPool {
/// GPU context used to create fresh device buffers.
context: Arc<GpuContext>,
/// Idle blocks keyed by block size in bytes; shared with workspaces.
pools: Arc<Mutex<HashMap<usize, VecDeque<GpuMemoryBlock>>>>,
/// NOTE(review): never read or written in this file — looks like
/// dead/reserved bookkeeping; confirm before removing.
allocated_blocks: Arc<Mutex<Vec<AllocatedBlock>>>,
/// Optional cap, in bytes, on total charged memory (`None` = unlimited).
memory_limit: Option<usize>,
/// Bytes currently charged to the pool (live blocks + idle pooled blocks).
current_usage: Arc<Mutex<usize>>,
/// Allocation / garbage-collection counters.
allocation_stats: Arc<Mutex<AllocationStats>>,
}
impl GpuMemoryPool {
    /// Creates a pool that allocates through `context`, optionally capped at
    /// `memory_limit` bytes of total charged usage.
    pub fn new(context: Arc<GpuContext>, memory_limit: Option<usize>) -> ScirsResult<Self> {
        Ok(Self {
            context,
            pools: Arc::new(Mutex::new(HashMap::new())),
            allocated_blocks: Arc::new(Mutex::new(Vec::new())),
            memory_limit,
            current_usage: Arc::new(Mutex::new(0)),
            allocation_stats: Arc::new(Mutex::new(AllocationStats::new())),
        })
    }

    /// Creates an uncapped pool backed by the CPU fallback backend; intended
    /// for tests and environments without a GPU.
    ///
    /// # Panics
    /// Panics if the CPU backend cannot be constructed.
    pub fn new_stub() -> Self {
        use scirs2_core::gpu::GpuBackend;
        let context = GpuContext::new(GpuBackend::Cpu).expect("CPU backend should always work");
        Self {
            context: Arc::new(context),
            pools: Arc::new(Mutex::new(HashMap::new())),
            allocated_blocks: Arc::new(Mutex::new(Vec::new())),
            memory_limit: None,
            current_usage: Arc::new(Mutex::new(0)),
            allocation_stats: Arc::new(Mutex::new(AllocationStats::new())),
        }
    }

    /// Allocates a workspace holding one block of `size` bytes; the workspace
    /// returns the block to this pool's idle queues when dropped.
    ///
    /// # Errors
    /// Propagates `allocate_block` failures (e.g. memory limit exceeded).
    pub fn allocate_workspace(&mut self, size: usize) -> ScirsResult<GpuWorkspace> {
        let block = self.allocate_block(size)?;
        Ok(GpuWorkspace::new(block, Arc::clone(&self.pools)))
    }

    /// Obtains a block of exactly `size` bytes: reuses an idle pooled block
    /// when possible, otherwise creates a fresh device buffer.
    ///
    /// Locks are acquired one at a time and released before the next one is
    /// taken, so no two pool mutexes are ever held together.
    ///
    /// # Errors
    /// `MemoryError` when the configured limit would be exceeded even after a
    /// garbage-collection pass.
    fn allocate_block(&mut self, size: usize) -> ScirsResult<GpuMemoryBlock> {
        self.allocation_stats
            .lock()
            .expect("Operation failed")
            .total_allocations += 1;

        // Enforce the optional limit; try one GC pass before giving up.
        // saturating_add avoids a debug-mode overflow panic on huge requests.
        if let Some(limit) = self.memory_limit {
            let current = *self.current_usage.lock().expect("Operation failed");
            if current.saturating_add(size) > limit {
                self.garbage_collect()?;
                let current = *self.current_usage.lock().expect("Operation failed");
                if current.saturating_add(size) > limit {
                    return Err(ScirsError::MemoryError(
                        scirs2_core::error::ErrorContext::new(format!(
                            "Would exceed memory limit: {} + {} > {}",
                            current, size, limit
                        ))
                        .with_location(scirs2_core::error::ErrorLocation::new(file!(), line!())),
                    ));
                }
            }
        }

        // Fast path: exact-size reuse from the idle pool. Usage stays charged
        // while a block sits in the pool, so no accounting change here.
        let reused = {
            let mut pools = self.pools.lock().expect("Operation failed");
            pools.get_mut(&size).and_then(|queue| queue.pop_front())
        };
        if let Some(block) = reused {
            self.allocation_stats
                .lock()
                .expect("Operation failed")
                .pool_hits += 1;
            return Ok(block);
        }

        // Slow path: create a fresh device buffer and charge its size.
        self.allocation_stats
            .lock()
            .expect("Operation failed")
            .new_allocations += 1;
        let gpu_buffer = self.context.create_buffer::<u8>(size);
        let block = GpuMemoryBlock {
            size,
            // NOTE(review): placeholder — no raw device pointer is extracted
            // from the buffer, so `ptr` is always null; never dereference it.
            ptr: std::ptr::null_mut(),
            gpu_buffer: Some(gpu_buffer),
        };
        *self.current_usage.lock().expect("Operation failed") += size;
        Ok(block)
    }

    /// Puts `block` back on its size-class queue for later reuse.
    ///
    /// NOTE(review): not called anywhere in this file — `GpuWorkspace::drop`
    /// returns blocks through the shared `pools` handle directly.
    fn return_block(&self, block: GpuMemoryBlock) {
        let mut pools = self.pools.lock().expect("Operation failed");
        pools.entry(block.size).or_default().push_back(block);
    }

    /// Frees every idle pooled block (live blocks held by workspaces are
    /// untouched) and un-charges the freed bytes from `current_usage`.
    fn garbage_collect(&mut self) -> ScirsResult<()> {
        // Drain all size-class queues, totalling the bytes released.
        let freed_memory: usize = {
            let mut pools = self.pools.lock().expect("Operation failed");
            let mut freed = 0usize;
            for (size, queue) in pools.iter_mut() {
                freed += size * queue.len();
                queue.clear();
            }
            freed
        };

        // BUG FIX: the original updated usage with
        //   *self.current_usage.lock()… = self.current_usage.lock()…
        // in ONE statement. Both MutexGuard temporaries live until the end of
        // that statement, so the second `lock()` re-locks a mutex this thread
        // already holds — a self-deadlock on std::sync::Mutex (re-entrant
        // locking is not supported). Lock once and update via one guard.
        {
            let mut usage = self.current_usage.lock().expect("Operation failed");
            let reduced = usage.saturating_sub(freed_memory);
            *usage = reduced;
        }

        let mut stats = self.allocation_stats.lock().expect("Operation failed");
        stats.garbage_collections += 1;
        stats.total_freed_memory += freed_memory;
        Ok(())
    }

    /// Snapshot of current usage, limit, allocation counters and per-size
    /// idle-pool depths.
    pub fn memory_stats(&self) -> MemoryStats {
        let current_usage = *self.current_usage.lock().expect("Operation failed");
        let allocation_stats = self
            .allocation_stats
            .lock()
            .expect("Operation failed")
            .clone();
        let pool_sizes: HashMap<usize, usize> = self
            .pools
            .lock()
            .expect("Operation failed")
            .iter()
            .map(|(&size, queue)| (size, queue.len()))
            .collect();
        MemoryStats {
            current_usage,
            memory_limit: self.memory_limit,
            allocation_stats,
            pool_sizes,
        }
    }
}
/// A single pooled allocation: a byte size plus (optionally) the backing
/// device buffer.
pub struct GpuMemoryBlock {
/// Block size in bytes (also the pool's size-class key).
size: usize,
/// NOTE(review): always set to null in this file — a placeholder, not a
/// real device pointer; must not be dereferenced.
ptr: *mut u8,
/// Backing device buffer, if one was created for this block.
gpu_buffer: Option<OptimGpuBuffer<u8>>,
}
impl std::fmt::Debug for GpuMemoryBlock {
    /// Manual `Debug`: the buffer is reported as a presence flag rather than
    /// by its contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("GpuMemoryBlock");
        dbg.field("size", &self.size);
        dbg.field("ptr", &self.ptr);
        dbg.field("gpu_buffer", &self.gpu_buffer.is_some());
        dbg.finish()
    }
}
// SAFETY: NOTE(review) — unverified. These impls assert GpuMemoryBlock can be
// moved/shared across threads. `ptr` is only ever null in this file, but the
// soundness of sharing `gpu_buffer` depends on scirs2_core's GpuBuffer being
// thread-safe — confirm before relying on cross-thread use.
unsafe impl Send for GpuMemoryBlock {}
unsafe impl Sync for GpuMemoryBlock {}
impl GpuMemoryBlock {
    /// Size of the block in bytes.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Raw pointer for this block. Always null in this file (placeholder) —
    /// do not dereference.
    pub fn ptr(&self) -> *mut u8 {
        self.ptr
    }

    /// Attempts to view the underlying byte buffer as elements of `T`.
    ///
    /// # Errors
    /// Always fails: reinterpreting the buffer is not implemented. Returns
    /// `ComputationError` when a buffer is present and `InvalidInput` when the
    /// block holds no buffer.
    pub fn as_typed<T: scirs2_core::GpuDataType>(&self) -> ScirsResult<&OptimGpuBuffer<T>> {
        // Plain presence check instead of `if let Some(ref buffer)`, which
        // bound a variable that was never used (unused-variable warning).
        if self.gpu_buffer.is_some() {
            Err(ScirsError::ComputationError(
                scirs2_core::error::ErrorContext::new("Type casting not supported".to_string()),
            ))
        } else {
            Err(ScirsError::InvalidInput(
                scirs2_core::error::ErrorContext::new("Memory block not available".to_string()),
            ))
        }
    }
}
/// Scoped set of memory blocks; on drop, every block is returned to the
/// owning pool's per-size idle queues.
pub struct GpuWorkspace {
/// Blocks currently owned by this workspace.
blocks: Vec<GpuMemoryBlock>,
/// Shared handle to the pool's size-classed idle queues.
pool: Arc<Mutex<HashMap<usize, VecDeque<GpuMemoryBlock>>>>,
}
impl GpuWorkspace {
fn new(
initial_block: GpuMemoryBlock,
pool: Arc<Mutex<HashMap<usize, VecDeque<GpuMemoryBlock>>>>,
) -> Self {
Self {
blocks: vec![initial_block],
pool,
}
}
pub fn get_block(&mut self, size: usize) -> ScirsResult<&GpuMemoryBlock> {
for block in &self.blocks {
if block.size >= size {
return Ok(block);
}
}
Err(ScirsError::MemoryError(
scirs2_core::error::ErrorContext::new("No suitable block available".to_string()),
))
}
pub fn get_buffer<T: scirs2_core::GpuDataType>(
&mut self,
size: usize,
) -> ScirsResult<&OptimGpuBuffer<T>> {
let size_bytes = size * std::mem::size_of::<T>();
let block = self.get_block(size_bytes)?;
block.as_typed::<T>()
}
pub fn create_array<T>(&mut self, dimensions: &[usize]) -> ScirsResult<OptimGpuArray<T>>
where
T: Clone + Default + 'static + scirs2_core::GpuDataType,
{
let total_elements: usize = dimensions.iter().product();
let buffer = self.get_buffer::<T>(total_elements)?;
Err(ScirsError::ComputationError(
scirs2_core::error::ErrorContext::new("Array creation not supported".to_string()),
))
}
pub fn total_size(&self) -> usize {
self.blocks.iter().map(|b| b.size).sum()
}
}
impl Drop for GpuWorkspace {
    /// Hands every block owned by this workspace back to the shared pool so
    /// later allocations of the same size can reuse them.
    fn drop(&mut self) {
        let mut size_pools = self.pool.lock().expect("Operation failed");
        for returned in self.blocks.drain(..) {
            size_pools
                .entry(returned.size)
                .or_insert_with(VecDeque::new)
                .push_back(returned);
        }
    }
}
/// Bookkeeping record for a live allocation (size + allocation timestamp).
/// NOTE(review): this type is stored in `GpuMemoryPool::allocated_blocks` but
/// never constructed or read in this file — it appears to be dead/reserved
/// code; confirm before removing.
#[derive(Debug)]
struct AllocatedBlock {
size: usize,
allocated_at: std::time::Instant,
}
/// Counters describing pool allocation behavior over time.
#[derive(Debug, Clone, Default)]
pub struct AllocationStats {
    /// Total allocation requests served (pool hits + fresh allocations).
    pub total_allocations: u64,
    /// Requests satisfied by reusing an idle pooled block.
    pub pool_hits: u64,
    /// Requests that required a fresh device buffer.
    pub new_allocations: u64,
    /// Number of garbage-collection passes performed.
    pub garbage_collections: u64,
    /// Total bytes released by garbage collection.
    pub total_freed_memory: usize,
}

impl AllocationStats {
    /// Zeroed counters. Kept for internal callers; the derived `Default`
    /// yields the same value (the original zero-initialized by hand).
    fn new() -> Self {
        Self::default()
    }

    /// Fraction of requests served from the pool, in `[0.0, 1.0]`.
    /// Returns 0.0 when no allocations have been recorded (avoids 0/0).
    pub fn hit_rate(&self) -> f64 {
        if self.total_allocations == 0 {
            0.0
        } else {
            self.pool_hits as f64 / self.total_allocations as f64
        }
    }
}

/// Point-in-time snapshot of pool state, as produced by
/// `GpuMemoryPool::memory_stats`.
#[derive(Debug, Clone)]
pub struct MemoryStats {
    /// Bytes currently charged to the pool.
    pub current_usage: usize,
    /// Optional cap on charged bytes; `None` means unlimited.
    pub memory_limit: Option<usize>,
    /// Allocation counters at snapshot time.
    pub allocation_stats: AllocationStats,
    /// Idle-block count per size class (size in bytes -> block count).
    pub pool_sizes: HashMap<usize, usize>,
}

impl MemoryStats {
    /// Usage as a fraction of the limit; `None` when no limit is set and
    /// 0.0 when the limit itself is zero (avoids division by zero).
    pub fn utilization(&self) -> Option<f64> {
        self.memory_limit.map(|limit| {
            if limit == 0 {
                0.0
            } else {
                self.current_usage as f64 / limit as f64
            }
        })
    }

    /// Renders a human-readable multi-line report of this snapshot.
    pub fn generate_report(&self) -> String {
        // writeln! appends in place; the original's push_str(&format!(..))
        // allocated a throwaway String for every line. Writing to a String
        // is infallible, hence the ignored results.
        use std::fmt::Write;

        let mut report = String::from("GPU Memory Usage Report\n");
        report.push_str("=======================\n\n");
        let _ = writeln!(
            report,
            "Current Usage: {} bytes ({:.2} MB)",
            self.current_usage,
            self.current_usage as f64 / 1024.0 / 1024.0
        );
        if let Some(limit) = self.memory_limit {
            let _ = writeln!(
                report,
                "Memory Limit: {} bytes ({:.2} MB)",
                limit,
                limit as f64 / 1024.0 / 1024.0
            );
            if let Some(util) = self.utilization() {
                let _ = writeln!(report, "Utilization: {:.1}%", util * 100.0);
            }
        }
        report.push('\n');
        report.push_str("Allocation Statistics:\n");
        let _ = writeln!(
            report,
            "  Total Allocations: {}",
            self.allocation_stats.total_allocations
        );
        let _ = writeln!(
            report,
            "  Pool Hits: {} ({:.1}%)",
            self.allocation_stats.pool_hits,
            self.allocation_stats.hit_rate() * 100.0
        );
        let _ = writeln!(
            report,
            "  New Allocations: {}",
            self.allocation_stats.new_allocations
        );
        let _ = writeln!(
            report,
            "  Garbage Collections: {}",
            self.allocation_stats.garbage_collections
        );
        let _ = writeln!(
            report,
            "  Total Freed: {} bytes",
            self.allocation_stats.total_freed_memory
        );
        if !self.pool_sizes.is_empty() {
            report.push('\n');
            report.push_str("Memory Pools:\n");
            // Sort size classes so the report is deterministic.
            let mut pools: Vec<_> = self.pool_sizes.iter().collect();
            pools.sort_by_key(|&(size, _)| size); // was `size_` — misleading name
            for (&size, &count) in pools {
                let _ = writeln!(report, "  {} bytes: {} blocks", size, count);
            }
        }
        report
    }
}
pub mod optimization {
    use super::*;

    /// Tuning knobs for the memory optimizer.
    #[derive(Debug, Clone)]
    pub struct MemoryOptimizationConfig {
        /// Desired fraction of the memory limit to keep in use.
        pub target_utilization: f64,
        /// Maximum number of idle blocks kept per size class.
        pub max_pool_size: usize,
        /// Utilization fraction above which garbage collection is triggered.
        pub gc_threshold: f64,
        /// Whether speculative prefetching is enabled.
        pub use_prefetching: bool,
    }

    impl Default for MemoryOptimizationConfig {
        fn default() -> Self {
            Self {
                target_utilization: 0.8,
                max_pool_size: 100,
                gc_threshold: 0.9,
                use_prefetching: true,
            }
        }
    }

    /// Observes a `GpuMemoryPool` and records optimization decisions.
    pub struct MemoryOptimizer {
        config: MemoryOptimizationConfig,
        pool: Arc<GpuMemoryPool>,
        optimization_stats: OptimizationStats,
    }

    impl MemoryOptimizer {
        /// Builds an optimizer over `pool` with the given `config`.
        pub fn new(config: MemoryOptimizationConfig, pool: Arc<GpuMemoryPool>) -> Self {
            Self {
                config,
                pool,
                optimization_stats: OptimizationStats::new(),
            }
        }

        /// Runs one optimization pass: a GC step when utilization exceeds
        /// the threshold, followed by pool-size bookkeeping.
        pub fn optimize(&mut self) -> ScirsResult<()> {
            let snapshot = self.pool.memory_stats();
            let over_threshold = snapshot
                .utilization()
                .map_or(false, |u| u > self.config.gc_threshold);
            if over_threshold {
                self.perform_garbage_collection()?;
                self.optimization_stats.gc_triggered += 1;
            }
            self.optimize_pool_sizes(&snapshot)
        }

        /// Records a GC operation. NOTE: bookkeeping only — no memory is
        /// actually reclaimed here.
        fn perform_garbage_collection(&mut self) -> ScirsResult<()> {
            self.optimization_stats.gc_operations += 1;
            Ok(())
        }

        /// Counts size classes whose idle-block count exceeds the configured
        /// maximum. NOTE: bookkeeping only — pools are not trimmed.
        fn optimize_pool_sizes(&mut self, stats: &MemoryStats) -> ScirsResult<()> {
            let oversized = stats
                .pool_sizes
                .values()
                .filter(|&&count| count > self.config.max_pool_size)
                .count() as u64;
            self.optimization_stats.pool_optimizations += oversized;
            Ok(())
        }

        /// Read-only view of the accumulated optimizer statistics.
        pub fn stats(&self) -> &OptimizationStats {
            &self.optimization_stats
        }
    }

    /// Counters accumulated by `MemoryOptimizer`.
    #[derive(Debug, Clone)]
    pub struct OptimizationStats {
        pub gc_triggered: u64,
        pub gc_operations: u64,
        pub pool_optimizations: u64,
    }

    impl OptimizationStats {
        /// All counters start at zero.
        fn new() -> Self {
            Self {
                gc_triggered: 0,
                gc_operations: 0,
                pool_optimizations: 0,
            }
        }
    }
}
pub mod utils {
    use super::*;

    /// Picks an allocation strategy from the estimated footprint relative to
    /// `available_memory`:
    /// - footprint > available       → `Chunked`
    /// - footprint > available / 2   → `Conservative`
    /// - otherwise                   → `Aggressive`
    pub fn calculate_allocation_strategy(
        problem_size: usize,
        batch_size: usize,
        available_memory: usize,
    ) -> AllocationStrategy {
        let estimated_usage = estimate_memory_usage(problem_size, batch_size);
        if estimated_usage > available_memory {
            AllocationStrategy::Chunked {
                chunk_size: available_memory / 2,
                overlap: true,
            }
        } else if estimated_usage > available_memory / 2 {
            AllocationStrategy::Conservative {
                pool_size_limit: available_memory / 4,
            }
        } else {
            AllocationStrategy::Aggressive {
                prefetch_size: estimated_usage * 2,
            }
        }
    }

    /// Rough memory estimate in bytes: input + output + one input-sized
    /// scratch buffer, assuming 8-byte elements.
    ///
    /// The original named this parameter `_problem_size`, which marks it as
    /// unused even though it is used — renamed to avoid the misleading hint.
    pub fn estimate_memory_usage(problem_size: usize, batch_size: usize) -> usize {
        // 8 == size_of::<f64>() — presumably elements are f64; confirm with callers.
        let input_size = batch_size * problem_size * 8;
        let output_size = batch_size * 8;
        let temp_size = input_size; // scratch mirrors the input buffer
        input_size + output_size + temp_size
    }

    /// How GPU memory should be allocated for a given workload.
    #[derive(Debug, Clone)]
    pub enum AllocationStrategy {
        /// Process in chunks because the whole problem does not fit.
        Chunked { chunk_size: usize, overlap: bool },
        /// Fits, but tightly: cap pool growth.
        Conservative { pool_size_limit: usize },
        /// Plenty of headroom: prefetch aggressively.
        Aggressive { prefetch_size: usize },
    }

    /// Returns whether `required_memory` fits into free device memory after
    /// reserving a safety margin.
    pub fn check_memory_availability(
        required_memory: usize,
        memory_info: &GpuMemoryInfo,
    ) -> ScirsResult<bool> {
        /// Fraction of free memory kept in reserve.
        const SAFETY_MARGIN: f64 = 0.1;
        let usable = (memory_info.free as f64 * (1.0 - SAFETY_MARGIN)) as usize;
        Ok(required_memory <= usable)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // hit_rate = pool_hits / total_allocations.
    #[test]
    fn test_allocation_stats() {
        let mut stats = AllocationStats::new();
        stats.total_allocations = 100;
        stats.pool_hits = 70;
        assert_eq!(stats.hit_rate(), 0.7);
    }

    // 800 used of a 1000-byte limit is 80% utilization.
    #[test]
    fn test_memory_stats_utilization() {
        let stats = MemoryStats {
            current_usage: 800,
            memory_limit: Some(1000),
            allocation_stats: AllocationStats::new(),
            pool_sizes: HashMap::new(),
        };
        assert_eq!(stats.utilization(), Some(0.8));
    }

    // Estimates are positive and grow with problem/batch size.
    #[test]
    fn test_memory_usage_estimation() {
        let small = utils::estimate_memory_usage(10, 100);
        let large = utils::estimate_memory_usage(20, 200);
        assert!(small > 0);
        assert!(large > small);
    }

    // A problem far larger than available memory must fall back to chunking.
    #[test]
    fn test_allocation_strategy() {
        let strategy = utils::calculate_allocation_strategy(1000, 1000, 500_000);
        assert!(
            matches!(strategy, utils::AllocationStrategy::Chunked { .. }),
            "Expected chunked strategy for large problem"
        );
    }
}