use crate::error::{MemoryError, NumRs2Error, OperationContext, Result};
use crate::gpu::context::GpuContextRef;
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Size in bytes (16 MiB) at or above which `TransferOptimizer::is_large_transfer` returns `true`.
const LARGE_TRANSFER_THRESHOLD: u64 = 16 * 1024 * 1024;
/// Value used for `PoolConfig::max_pool_size` by `PoolConfig::default`.
const DEFAULT_MAX_POOL_SIZE: usize = 100;
/// Value used for `PoolConfig::buffer_expiration` by `PoolConfig::default` (300 seconds).
const DEFAULT_BUFFER_EXPIRATION: Duration = Duration::from_secs(300);
/// Pool of reusable GPU buffers, grouped by requested size and usage bits.
///
/// Buffers are handed out as `ManagedBuffer` values; each one holds a handle
/// to the same shared pool state and pushes its `wgpu::Buffer` back into it
/// when returned or dropped.
#[repr(align(64))]
pub struct GpuMemoryPool {
/// GPU context used to create a new buffer when the pool has no match.
context: GpuContextRef,
/// Pooled buffers plus running totals, shared with every `ManagedBuffer` handed out.
pools: Arc<Mutex<BufferPools>>,
/// Pool settings supplied at construction.
config: PoolConfig,
}
/// Settings for `GpuMemoryPool`.
#[derive(Debug, Clone)]
pub struct PoolConfig {
/// Pool size limit; `Default` sets it to `DEFAULT_MAX_POOL_SIZE`.
/// NOTE(review): no non-test code in this file reads this field, so no limit
/// is enforced here — confirm the intended use.
pub max_pool_size: usize,
/// Idle time after which `GpuMemoryPool::collect_garbage` frees a pooled buffer.
pub buffer_expiration: Duration,
/// Automatic garbage-collection switch; `Default` sets it to `true`.
/// NOTE(review): no non-test code in this file reads this field — confirm
/// where it takes effect.
pub auto_gc: bool,
/// Retention rate for garbage collection; `Default` sets it to `0.8`.
/// NOTE(review): `collect_garbage` takes its rate as an argument and does not
/// read this field — presumably callers pass it in; confirm.
pub gc_retention_rate: f32,
}
impl Default for PoolConfig {
fn default() -> Self {
Self {
max_pool_size: DEFAULT_MAX_POOL_SIZE,
buffer_expiration: DEFAULT_BUFFER_EXPIRATION,
auto_gc: true,
gc_retention_rate: 0.8,
}
}
}
/// Shared pool state guarded by the mutex in `GpuMemoryPool` and `ManagedBuffer`.
struct BufferPools {
/// Idle buffers keyed by `(size in bytes, usage bits)`.
pools: HashMap<(u64, u32), VecDeque<PooledBuffer>>,
/// Running count of idle buffers across all queues.
total_buffers: usize,
/// Running sum of the sizes of all idle buffers.
total_bytes: u64,
}
/// An idle buffer waiting in the pool for reuse.
struct PooledBuffer {
/// The GPU buffer itself.
buffer: wgpu::Buffer,
/// Size recorded for the buffer when it was returned to the pool.
size: u64,
/// Time at which the buffer was returned to the pool.
last_used: Instant,
}
impl GpuMemoryPool {
    /// Creates a pool that uses `PoolConfig::default()`.
    pub fn new(context: GpuContextRef) -> Self {
        Self::with_config(context, PoolConfig::default())
    }

    /// Creates an empty pool that uses the supplied configuration.
    pub fn with_config(context: GpuContextRef, config: PoolConfig) -> Self {
        let empty = BufferPools {
            pools: HashMap::new(),
            total_buffers: 0,
            total_bytes: 0,
        };
        Self {
            context,
            pools: Arc::new(Mutex::new(empty)),
            config,
        }
    }

    /// Hands out a buffer of `size` bytes with the given `usage`.
    ///
    /// An idle buffer stored under the same `(size, usage bits)` key is reused
    /// when one is available; otherwise a new buffer is created through the
    /// GPU context. The returned `ManagedBuffer` puts the buffer back into
    /// this pool when it is returned or dropped.
    ///
    /// # Errors
    ///
    /// Returns a GPU memory error if the pool mutex is poisoned.
    pub fn allocate(&mut self, size: u64, usage: wgpu::BufferUsages) -> Result<ManagedBuffer> {
        let usage_bits = usage.bits();
        let recycled = {
            let mut pools = self.pools.lock().map_err(|e| {
                NumRs2Error::from(MemoryError::gpu_memory_error(
                    &format!("Failed to lock buffer pool: {}", e),
                    None,
                ))
            })?;
            let popped = pools
                .pools
                .get_mut(&(size, usage_bits))
                .and_then(|queue| queue.pop_front());
            if popped.is_some() {
                pools.total_buffers = pools.total_buffers.saturating_sub(1);
                pools.total_bytes = pools.total_bytes.saturating_sub(size);
            }
            popped.map(|pooled| pooled.buffer)
        };
        // The lock is released before creating a buffer through the context.
        let buffer =
            recycled.unwrap_or_else(|| self.context.create_empty_buffer(size, usage));
        Ok(ManagedBuffer {
            buffer: Some(buffer),
            size,
            usage_bits,
            pool: Arc::clone(&self.pools),
            returned: false,
        })
    }

    /// Frees idle buffers and returns `(buffers freed, bytes freed)`.
    ///
    /// For each queue, at most `ceil(len * retention_rate)` buffers are kept,
    /// scanning from the front, and only buffers idle for less than
    /// `buffer_expiration` are eligible to be kept. `retention_rate` is
    /// clamped to `0.0..=1.0`. Queues left empty are removed from the map.
    ///
    /// # Errors
    ///
    /// Returns a GPU memory error if the pool mutex is poisoned.
    pub fn collect_garbage(&mut self, retention_rate: f32) -> Result<(usize, u64)> {
        let retention_rate = retention_rate.clamp(0.0, 1.0);
        let expiration = self.config.buffer_expiration;
        let mut pools = self.pools.lock().map_err(|e| {
            NumRs2Error::from(MemoryError::gpu_memory_error(
                &format!("Failed to lock buffer pool during GC: {}", e),
                None,
            ))
        })?;
        // Idle time is measured per buffer rather than by computing
        // `now - expiration`, because subtracting a `Duration` from an
        // `Instant` panics when the result cannot be represented.
        let now = Instant::now();
        let mut freed_buffers = 0usize;
        let mut freed_bytes = 0u64;
        for pool in pools.pools.values_mut() {
            let keep_count = ((pool.len() as f32) * retention_rate).ceil() as usize;
            let mut kept = 0usize;
            pool.retain(|pooled| {
                let is_fresh = now.saturating_duration_since(pooled.last_used) < expiration;
                if is_fresh && kept < keep_count {
                    kept += 1;
                    true
                } else {
                    freed_buffers += 1;
                    freed_bytes += pooled.size;
                    false
                }
            });
        }
        pools.total_buffers = pools.total_buffers.saturating_sub(freed_buffers);
        pools.total_bytes = pools.total_bytes.saturating_sub(freed_bytes);
        pools.pools.retain(|_, pool| !pool.is_empty());
        Ok((freed_buffers, freed_bytes))
    }

    /// Returns a snapshot of the pool's counters.
    ///
    /// # Errors
    ///
    /// Returns a GPU memory error if the pool mutex is poisoned.
    pub fn statistics(&self) -> Result<PoolStatistics> {
        let pools = self.pools.lock().map_err(|e| {
            NumRs2Error::from(MemoryError::gpu_memory_error(
                &format!("Failed to lock buffer pool: {}", e),
                None,
            ))
        })?;
        Ok(PoolStatistics {
            total_buffers: pools.total_buffers,
            total_bytes: pools.total_bytes,
            pool_count: pools.pools.len(),
        })
    }

    /// Drops every idle buffer and resets the counters to zero.
    ///
    /// # Errors
    ///
    /// Returns a GPU memory error if the pool mutex is poisoned.
    pub fn clear(&mut self) -> Result<()> {
        let mut pools = self.pools.lock().map_err(|e| {
            NumRs2Error::from(MemoryError::gpu_memory_error(
                &format!("Failed to lock buffer pool during clear: {}", e),
                None,
            ))
        })?;
        pools.pools.clear();
        pools.total_buffers = 0;
        pools.total_bytes = 0;
        Ok(())
    }
}
/// Snapshot of pool counters returned by `GpuMemoryPool::statistics`.
#[derive(Debug, Clone, Copy)]
pub struct PoolStatistics {
/// Number of idle buffers tracked by the pool.
pub total_buffers: usize,
/// Sum of the sizes of the idle buffers tracked by the pool.
pub total_bytes: u64,
/// Number of `(size, usage bits)` entries currently in the pool map.
pub pool_count: usize,
}
/// A buffer handed out by `GpuMemoryPool::allocate` that goes back into the
/// pool when `return_to_pool` is called or the value is dropped.
pub struct ManagedBuffer {
/// The held buffer; `None` once it has been returned to the pool.
buffer: Option<wgpu::Buffer>,
/// Size requested at allocation; part of the pool key.
size: u64,
/// Usage bits requested at allocation; part of the pool key.
usage_bits: u32,
/// Shared pool state the buffer is pushed back into.
pool: Arc<Mutex<BufferPools>>,
/// Set once `return_to_pool` has run so the buffer is not returned twice.
returned: bool,
}
impl ManagedBuffer {
    /// Borrows the underlying GPU buffer.
    ///
    /// # Panics
    ///
    /// Panics if the buffer has already been handed back via
    /// [`ManagedBuffer::return_to_pool`].
    pub fn buffer(&self) -> &wgpu::Buffer {
        let held = self.buffer.as_ref();
        held.expect("Buffer has been returned to pool")
    }

    /// Size in bytes that was requested when this buffer was allocated.
    pub fn size(&self) -> u64 {
        self.size
    }

    /// Pushes the buffer back into its pool under the `(size, usage bits)` key.
    ///
    /// Only the first call has any effect. If the pool mutex is poisoned the
    /// buffer is dropped instead of being pooled.
    pub fn return_to_pool(&mut self) {
        if self.returned {
            return;
        }
        self.returned = true;

        let Some(buffer) = self.buffer.take() else {
            return;
        };
        let Ok(mut pools) = self.pool.lock() else {
            return;
        };

        let size = self.size;
        let queue = pools.pools.entry((size, self.usage_bits)).or_default();
        queue.push_back(PooledBuffer {
            buffer,
            size,
            last_used: Instant::now(),
        });
        pools.total_buffers += 1;
        pools.total_bytes += size;
    }
}
impl Drop for ManagedBuffer {
fn drop(&mut self) {
self.return_to_pool();
}
}
/// Selects how `TransferOptimizer` handles a write into a `wgpu::Buffer`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferStrategy {
/// Issue the queue write as soon as the transfer is queued.
Immediate,
/// Copy the data and hold it until `TransferOptimizer::flush` is called.
Batched,
/// Issue the queue write right away and record a tracking entry that stays
/// pending until `TransferOptimizer::wait_for_completion` marks it completed.
Async,
}
/// Routes buffer writes according to a `TransferStrategy`.
#[repr(align(64))]
pub struct TransferOptimizer {
/// GPU context providing the queue and device used for transfers.
context: GpuContextRef,
/// Strategy applied to each newly queued transfer.
strategy: TransferStrategy,
/// Transfers queued under `TransferStrategy::Batched`, waiting for `flush`.
batch_queue: Vec<PendingTransfer>,
/// Tracking entries for transfers queued under `TransferStrategy::Async`.
async_queue: Arc<Mutex<Vec<AsyncTransfer>>>,
}
/// A batched write that has been queued but not yet issued.
struct PendingTransfer {
/// Owned copy of the bytes to write.
source: Vec<u8>,
/// Buffer the bytes will be written into.
destination: wgpu::Buffer,
/// Offset passed to `write_buffer` when the transfer is flushed.
offset: u64,
}
/// Tracking record for a transfer queued under `TransferStrategy::Async`.
struct AsyncTransfer {
/// Identifier assigned when the transfer was queued.
id: u64,
/// Number of bytes written.
size: u64,
/// Time at which the tracking record was created.
submitted: std::time::Instant,
/// Set to `true` by `TransferOptimizer::wait_for_completion`.
completed: bool,
}
impl TransferOptimizer {
    /// Creates an optimizer with empty queues and the given initial strategy.
    pub fn new(context: GpuContextRef, strategy: TransferStrategy) -> Self {
        Self {
            context,
            strategy,
            batch_queue: Vec::new(),
            async_queue: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Returns `true` when `size` is at least `LARGE_TRANSFER_THRESHOLD` bytes.
    pub fn is_large_transfer(size: u64) -> bool {
        size >= LARGE_TRANSFER_THRESHOLD
    }

    /// Queues a write of `data` to the start of `buffer`.
    ///
    /// Equivalent to [`TransferOptimizer::queue_transfer_with_offset`] with an
    /// offset of `0`.
    ///
    /// # Errors
    ///
    /// See [`TransferOptimizer::queue_transfer_with_offset`].
    pub fn queue_transfer<T: bytemuck::Pod>(
        &mut self,
        data: &[T],
        buffer: &wgpu::Buffer,
    ) -> Result<()> {
        self.queue_transfer_with_offset(data, buffer, 0)
    }

    /// Queues a write of `data` into `buffer` at `offset`, following the
    /// current strategy:
    ///
    /// * `Immediate` issues the queue write right away.
    /// * `Batched` copies the bytes and holds them until
    ///   [`TransferOptimizer::flush`] is called.
    /// * `Async` issues the queue write right away and records a tracking
    ///   entry that stays pending until
    ///   [`TransferOptimizer::wait_for_completion`] is called.
    ///
    /// # Errors
    ///
    /// Under `Async`, returns a GPU memory error if the tracking-queue mutex
    /// is poisoned.
    pub fn queue_transfer_with_offset<T: bytemuck::Pod>(
        &mut self,
        data: &[T],
        buffer: &wgpu::Buffer,
        offset: u64,
    ) -> Result<()> {
        let byte_data: &[u8] = bytemuck::cast_slice(data);
        match self.strategy {
            TransferStrategy::Immediate => {
                self.context.queue().write_buffer(buffer, offset, byte_data);
                Ok(())
            }
            TransferStrategy::Batched => {
                self.batch_queue.push(PendingTransfer {
                    source: byte_data.to_vec(),
                    destination: buffer.clone(),
                    offset,
                });
                Ok(())
            }
            TransferStrategy::Async => {
                // Ids come from a process-wide counter rather than the wall
                // clock: clock readings can repeat between two quick
                // submissions, and a clock set before the Unix epoch would
                // make the transfer fail for an unrelated reason.
                static NEXT_TRANSFER_ID: std::sync::atomic::AtomicU64 =
                    std::sync::atomic::AtomicU64::new(0);
                let transfer_id =
                    NEXT_TRANSFER_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                self.context.queue().write_buffer(buffer, offset, byte_data);
                let mut async_queue = self.async_queue.lock().map_err(|e| {
                    NumRs2Error::from(MemoryError::gpu_memory_error(
                        &format!("Failed to lock async queue: {}", e),
                        None,
                    ))
                })?;
                async_queue.push(AsyncTransfer {
                    id: transfer_id,
                    size: byte_data.len() as u64,
                    submitted: Instant::now(),
                    completed: false,
                });
                Ok(())
            }
        }
    }

    /// Issues every batched transfer in queue order and empties the batch queue.
    ///
    /// Currently always returns `Ok(())`.
    pub fn flush(&mut self) -> Result<()> {
        for transfer in self.batch_queue.drain(..) {
            self.context.queue().write_buffer(
                &transfer.destination,
                transfer.offset,
                &transfer.source,
            );
        }
        Ok(())
    }

    /// Returns the number of tracked async transfers not yet marked completed.
    ///
    /// # Errors
    ///
    /// Returns a GPU memory error if the tracking-queue mutex is poisoned.
    pub fn pending_transfers(&self) -> Result<usize> {
        let async_queue = self.async_queue.lock().map_err(|e| {
            NumRs2Error::from(MemoryError::gpu_memory_error(
                &format!("Failed to lock async queue: {}", e),
                None,
            ))
        })?;
        Ok(async_queue.iter().filter(|t| !t.completed).count())
    }

    /// Polls the device with `PollType::wait_indefinitely()` and then marks
    /// every tracked async transfer as completed.
    ///
    /// # Errors
    ///
    /// Returns a GPU memory error if polling fails or the tracking-queue mutex
    /// is poisoned.
    pub fn wait_for_completion(&mut self) -> Result<()> {
        self.context
            .device()
            .poll(wgpu::PollType::wait_indefinitely())
            .map_err(|e| {
                NumRs2Error::from(MemoryError::gpu_memory_error(
                    &format!("Failed to poll device: {:?}", e),
                    None,
                ))
            })?;
        let mut async_queue = self.async_queue.lock().map_err(|e| {
            NumRs2Error::from(MemoryError::gpu_memory_error(
                &format!("Failed to lock async queue: {}", e),
                None,
            ))
        })?;
        for transfer in async_queue.iter_mut() {
            transfer.completed = true;
        }
        Ok(())
    }

    /// Drops tracking entries already marked completed and returns how many
    /// were removed.
    ///
    /// # Errors
    ///
    /// Returns a GPU memory error if the tracking-queue mutex is poisoned.
    pub fn clear_completed(&mut self) -> Result<usize> {
        let mut async_queue = self.async_queue.lock().map_err(|e| {
            NumRs2Error::from(MemoryError::gpu_memory_error(
                &format!("Failed to lock async queue: {}", e),
                None,
            ))
        })?;
        let before_count = async_queue.len();
        async_queue.retain(|t| !t.completed);
        Ok(before_count - async_queue.len())
    }

    /// Returns the strategy applied to newly queued transfers.
    pub fn strategy(&self) -> TransferStrategy {
        self.strategy
    }

    /// Changes the strategy for subsequently queued transfers.
    ///
    /// Transfers already held in the batch queue stay there until
    /// [`TransferOptimizer::flush`] is called.
    pub fn set_strategy(&mut self, strategy: TransferStrategy) {
        self.strategy = strategy;
    }
}
/// A pair of equally sized GPU buffers with a switchable "current" slot.
pub struct DoubleBuffer {
/// GPU context used to create the buffers and to issue writes.
context: GpuContextRef,
/// The two buffers, both created with the same size and usage.
buffers: [wgpu::Buffer; 2],
/// Index (0 or 1) of the buffer returned by `current`.
current_index: usize,
/// Size both buffers were created with.
size: u64,
/// Usage flags both buffers were created with.
usage: wgpu::BufferUsages,
}
impl DoubleBuffer {
    /// Creates two buffers of `size` bytes with identical `usage`; slot 0
    /// starts out as the current buffer.
    pub fn new(context: GpuContextRef, size: u64, usage: wgpu::BufferUsages) -> Self {
        let buffers = [
            context.create_empty_buffer(size, usage),
            context.create_empty_buffer(size, usage),
        ];
        Self {
            context,
            buffers,
            current_index: 0,
            size,
            usage,
        }
    }

    /// Borrows the buffer in the current slot.
    pub fn current(&self) -> &wgpu::Buffer {
        let slot = self.current_index;
        &self.buffers[slot]
    }

    /// Borrows the buffer in the other slot, which becomes current after
    /// [`DoubleBuffer::swap`].
    pub fn next(&self) -> &wgpu::Buffer {
        let other_slot = 1 - self.current_index;
        &self.buffers[other_slot]
    }

    /// Exchanges the roles of the current and next buffers.
    pub fn swap(&mut self) {
        let other_slot = 1 - self.current_index;
        self.current_index = other_slot;
    }

    /// Size both buffers were created with.
    pub fn size(&self) -> u64 {
        self.size
    }

    /// Writes `data` at offset 0 of the current buffer through the GPU queue.
    pub fn write_current<T: bytemuck::Pod>(&self, data: &[T]) {
        let queue = self.context.queue();
        let target = self.current();
        queue.write_buffer(target, 0, bytemuck::cast_slice(data));
    }

    /// Writes `data` at offset 0 of the next buffer through the GPU queue.
    pub fn write_next<T: bytemuck::Pod>(&self, data: &[T]) {
        let queue = self.context.queue();
        let target = self.next();
        queue.write_buffer(target, 0, bytemuck::cast_slice(data));
    }
}
/// Hands out shared, reference-counted buffers looked up by size and usage.
pub struct BufferAliasManager {
/// GPU context used to create a buffer when no matching alias exists.
context: GpuContextRef,
/// Alias records keyed by buffer size; each list holds one record per usage.
aliases: Arc<Mutex<HashMap<u64, Vec<BufferAlias>>>>,
}
/// Record of one shared buffer tracked by `BufferAliasManager`.
struct BufferAlias {
/// The shared buffer; callers receive clones of this handle.
buffer: wgpu::Buffer,
/// Always set to 0 by `get_or_create_buffer`; not read elsewhere in this file.
offset: u64,
/// Size the buffer was created with.
size: u64,
/// Usage flags the buffer was created with; matched on lookup and release.
usage: wgpu::BufferUsages,
/// Number of outstanding `get_or_create_buffer` calls not yet released.
ref_count: usize,
}
impl BufferAliasManager {
    /// Creates a manager with no tracked aliases.
    pub fn new(context: GpuContextRef) -> Self {
        let aliases = Arc::new(Mutex::new(HashMap::new()));
        Self { context, aliases }
    }

    /// Returns a handle to a buffer of `size` bytes with the given `usage`.
    ///
    /// If an alias with the same size and usage is already tracked, its
    /// reference count is incremented and a clone of its buffer handle is
    /// returned. Otherwise a new buffer is created and tracked with a
    /// reference count of 1.
    ///
    /// # Errors
    ///
    /// Returns a GPU memory error if the alias mutex is poisoned.
    pub fn get_or_create_buffer(
        &mut self,
        size: u64,
        usage: wgpu::BufferUsages,
    ) -> Result<wgpu::Buffer> {
        let mut aliases = self.aliases.lock().map_err(|e| {
            NumRs2Error::from(MemoryError::gpu_memory_error(
                &format!("Failed to lock alias manager: {}", e),
                None,
            ))
        })?;

        let existing = aliases
            .get_mut(&size)
            .and_then(|list| list.iter_mut().find(|alias| alias.usage == usage));
        if let Some(alias) = existing {
            alias.ref_count += 1;
            return Ok(alias.buffer.clone());
        }

        let buffer = self.context.create_empty_buffer(size, usage);
        let record = BufferAlias {
            buffer: buffer.clone(),
            offset: 0,
            size,
            usage,
            ref_count: 1,
        };
        aliases.entry(size).or_default().push(record);
        Ok(buffer)
    }

    /// Releases one reference to the alias matching `size` and `usage`.
    ///
    /// Aliases whose count reaches zero are dropped, and a size entry left
    /// with no aliases is removed. Releasing an unknown alias is a no-op.
    ///
    /// # Errors
    ///
    /// Returns a GPU memory error if the alias mutex is poisoned.
    pub fn release_buffer(&mut self, size: u64, usage: wgpu::BufferUsages) -> Result<()> {
        let mut aliases = self.aliases.lock().map_err(|e| {
            NumRs2Error::from(MemoryError::gpu_memory_error(
                &format!("Failed to lock alias manager: {}", e),
                None,
            ))
        })?;

        let Some(alias_list) = aliases.get_mut(&size) else {
            return Ok(());
        };
        let held = alias_list
            .iter_mut()
            .find(|alias| alias.usage == usage && alias.ref_count > 0);
        if let Some(alias) = held {
            alias.ref_count -= 1;
        }
        alias_list.retain(|alias| alias.ref_count > 0);
        if alias_list.is_empty() {
            aliases.remove(&size);
        }
        Ok(())
    }

    /// Returns counts of tracked aliases, outstanding references, and
    /// distinct buffer sizes.
    ///
    /// # Errors
    ///
    /// Returns a GPU memory error if the alias mutex is poisoned.
    pub fn statistics(&self) -> Result<AliasStatistics> {
        let aliases = self.aliases.lock().map_err(|e| {
            NumRs2Error::from(MemoryError::gpu_memory_error(
                &format!("Failed to lock alias manager: {}", e),
                None,
            ))
        })?;
        let total_aliases: usize = aliases.values().map(Vec::len).sum();
        let total_references: usize = aliases
            .values()
            .flatten()
            .map(|alias| alias.ref_count)
            .sum();
        Ok(AliasStatistics {
            total_aliases,
            total_references,
            buffer_sizes: aliases.len(),
        })
    }

    /// Forgets every tracked alias regardless of its reference count.
    ///
    /// # Errors
    ///
    /// Returns a GPU memory error if the alias mutex is poisoned.
    pub fn clear(&mut self) -> Result<()> {
        let mut aliases = self.aliases.lock().map_err(|e| {
            NumRs2Error::from(MemoryError::gpu_memory_error(
                &format!("Failed to lock alias manager: {}", e),
                None,
            ))
        })?;
        aliases.clear();
        Ok(())
    }
}
/// Snapshot of counters returned by `BufferAliasManager::statistics`.
#[derive(Debug, Clone, Copy)]
pub struct AliasStatistics {
/// Number of alias records across all sizes.
pub total_aliases: usize,
/// Sum of the reference counts of all alias records.
pub total_references: usize,
/// Number of distinct buffer sizes that have at least one alias record.
pub buffer_sizes: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the file-level defaults.
    #[test]
    fn test_pool_config_default() {
        let defaults = PoolConfig::default();
        assert_eq!(defaults.max_pool_size, DEFAULT_MAX_POOL_SIZE);
        assert!(defaults.auto_gc);
        assert_eq!(defaults.gc_retention_rate, 0.8);
    }

    /// Sizes below the threshold are small; the threshold itself and anything
    /// above it are large.
    #[test]
    fn test_is_large_transfer() {
        for below in [1024, LARGE_TRANSFER_THRESHOLD - 1] {
            assert!(!TransferOptimizer::is_large_transfer(below));
        }
        for at_or_above in [LARGE_TRANSFER_THRESHOLD, LARGE_TRANSFER_THRESHOLD + 1] {
            assert!(TransferOptimizer::is_large_transfer(at_or_above));
        }
    }

    /// Every pair of distinct strategies compares unequal.
    #[test]
    fn test_transfer_strategy() {
        let pairs = [
            (TransferStrategy::Immediate, TransferStrategy::Batched),
            (TransferStrategy::Immediate, TransferStrategy::Async),
            (TransferStrategy::Batched, TransferStrategy::Async),
        ];
        for (left, right) in pairs {
            assert_ne!(left, right);
        }
    }
}