use super::{GpuCapabilities, GpuError, GpuResult};
use crate::core::{PlottingError, Result};
use crate::data::{PooledVec, SharedMemoryPool};
use bytemuck::{Pod, cast_slice, try_cast_slice};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use wgpu::util::DeviceExt;
/// A GPU-side buffer plus the metadata the pool needs to cache and reuse it.
///
/// `size` and `usage` together form the pool's cache key, so they are stored
/// alongside the raw `wgpu::Buffer` handle.
pub struct GpuBuffer {
    /// Underlying wgpu buffer handle.
    buffer: wgpu::Buffer,
    /// Allocated size in bytes (after any alignment padding by the pool).
    size: u64,
    /// Usage flags the buffer was created with.
    usage: wgpu::BufferUsages,
    /// Optional debug label (owned copy of the label passed at creation).
    label: Option<String>,
}
impl GpuBuffer {
    /// Create a buffer whose contents are initialized from `data`.
    pub fn new(
        device: &wgpu::Device,
        data: &[u8],
        usage: wgpu::BufferUsages,
        label: Option<&str>,
    ) -> Self {
        let size = data.len() as u64;
        let descriptor = wgpu::util::BufferInitDescriptor {
            label,
            contents: data,
            usage,
        };
        Self {
            buffer: device.create_buffer_init(&descriptor),
            size,
            usage,
            label: label.map(str::to_string),
        }
    }

    /// Create an uninitialized buffer of `size` bytes.
    pub fn new_empty(
        device: &wgpu::Device,
        size: u64,
        usage: wgpu::BufferUsages,
        label: Option<&str>,
    ) -> Self {
        let descriptor = wgpu::BufferDescriptor {
            label,
            size,
            usage,
            mapped_at_creation: false,
        };
        Self {
            buffer: device.create_buffer(&descriptor),
            size,
            usage,
            label: label.map(str::to_string),
        }
    }

    /// Borrow the underlying wgpu buffer handle.
    pub fn buffer(&self) -> &wgpu::Buffer {
        &self.buffer
    }

    /// Size in bytes as recorded at creation.
    pub fn size(&self) -> u64 {
        self.size
    }

    /// Usage flags the buffer was created with.
    pub fn usage(&self) -> wgpu::BufferUsages {
        self.usage
    }
}
/// Pools GPU buffers by `(aligned size, usage)` so repeated allocations of
/// the same shape can be reused instead of recreated, with a soft memory
/// budget and usage statistics.
#[allow(clippy::type_complexity)] pub struct GpuMemoryPool {
    /// Device used for all buffer creation.
    device: Arc<wgpu::Device>,
    /// Queue used for uploads (`write_buffer`) and copy submissions.
    queue: Arc<wgpu::Queue>,
    /// Returned buffers keyed by `(aligned size, usage)`, awaiting reuse.
    buffer_cache: Arc<Mutex<HashMap<(u64, wgpu::BufferUsages), Vec<GpuBuffer>>>>,
    /// Bytes currently accounted against `memory_limit`.
    total_allocated: Arc<Mutex<u64>>,
    /// Soft budget in bytes (80% of the device's max buffer size, per `new`).
    memory_limit: u64,
    /// All buffer sizes are rounded up to this (the device's minimum
    /// uniform-buffer offset alignment) so cache keys match more often.
    alignment: u64,
    /// Counters reported via `get_stats`.
    stats: Arc<Mutex<GpuMemoryStats>>,
}
/// Snapshot of the pool's allocation counters (see `GpuMemoryPool::get_stats`).
#[derive(Debug, Clone, Default)]
pub struct GpuMemoryStats {
    /// Bytes currently accounted against the pool budget.
    pub total_allocated: u64,
    /// Buffers created fresh on the device.
    pub buffers_created: usize,
    /// Buffers served from the reuse cache.
    pub buffers_reused: usize,
    /// Cache lookups that found a matching buffer.
    pub cache_hits: usize,
    /// Cache lookups that fell through to a fresh allocation.
    pub cache_misses: usize,
    /// The pool's soft budget in bytes.
    pub memory_limit: u64,
}
impl GpuMemoryPool {
pub fn new(
device: Arc<wgpu::Device>,
queue: Arc<wgpu::Queue>,
capabilities: &GpuCapabilities,
) -> Result<Self> {
let memory_limit = (capabilities.max_buffer_size as f64 * 0.8) as u64;
let alignment = device.limits().min_uniform_buffer_offset_alignment as u64;
Ok(Self {
device,
queue,
buffer_cache: Arc::new(Mutex::new(HashMap::new())),
total_allocated: Arc::new(Mutex::new(0)),
memory_limit,
alignment,
stats: Arc::new(Mutex::new(GpuMemoryStats {
memory_limit,
..Default::default()
})),
})
}
pub fn create_buffer<T: Pod>(
&self,
data: &[T],
usage: wgpu::BufferUsages,
) -> Result<GpuBuffer> {
let bytes = cast_slice(data);
self.create_buffer_bytes(bytes, usage, None)
}
pub fn create_buffer_empty<T: Pod>(
&self,
count: usize,
usage: wgpu::BufferUsages,
) -> Result<GpuBuffer> {
let size = (count * std::mem::size_of::<T>()) as u64;
let aligned_size = self.align_buffer_size(size);
self.create_buffer_empty_bytes(aligned_size, usage, None)
}
pub fn create_buffer_bytes(
&self,
data: &[u8],
usage: wgpu::BufferUsages,
label: Option<&str>,
) -> Result<GpuBuffer> {
let size = data.len() as u64;
let aligned_size = self.align_buffer_size(size);
{
let total_allocated = self.total_allocated.lock().unwrap();
if *total_allocated + aligned_size > self.memory_limit {
return Err(PlottingError::GpuMemoryError {
requested: aligned_size as usize,
available: Some((self.memory_limit - *total_allocated) as usize),
});
}
}
let cache_key = (aligned_size, usage);
if let Some(buffer) = self.try_reuse_buffer(&cache_key) {
self.queue.write_buffer(buffer.buffer(), 0, data);
{
let mut stats = self.stats.lock().unwrap();
stats.buffers_reused += 1;
stats.cache_hits += 1;
}
return Ok(buffer);
}
let buffer = if data.is_empty() {
GpuBuffer::new_empty(&self.device, aligned_size, usage, label)
} else {
let mut padded_data = data.to_vec();
padded_data.resize(aligned_size as usize, 0);
GpuBuffer::new(&self.device, &padded_data, usage, label)
};
{
let mut total_allocated = self.total_allocated.lock().unwrap();
*total_allocated += aligned_size;
}
{
let mut stats = self.stats.lock().unwrap();
stats.total_allocated += aligned_size;
stats.buffers_created += 1;
stats.cache_misses += 1;
}
Ok(buffer)
}
pub fn create_buffer_empty_bytes(
&self,
size: u64,
usage: wgpu::BufferUsages,
label: Option<&str>,
) -> Result<GpuBuffer> {
let aligned_size = self.align_buffer_size(size);
{
let total_allocated = self.total_allocated.lock().unwrap();
if *total_allocated + aligned_size > self.memory_limit {
return Err(PlottingError::GpuMemoryError {
requested: aligned_size as usize,
available: Some((self.memory_limit - *total_allocated) as usize),
});
}
}
let buffer = GpuBuffer::new_empty(&self.device, aligned_size, usage, label);
{
let mut total_allocated = self.total_allocated.lock().unwrap();
*total_allocated += aligned_size;
}
{
let mut stats = self.stats.lock().unwrap();
stats.total_allocated += aligned_size;
stats.buffers_created += 1;
stats.cache_misses += 1;
}
Ok(buffer)
}
fn try_reuse_buffer(&self, cache_key: &(u64, wgpu::BufferUsages)) -> Option<GpuBuffer> {
let mut cache = self.buffer_cache.lock().unwrap();
if let Some(buffers) = cache.get_mut(cache_key) {
buffers.pop()
} else {
None
}
}
pub fn return_buffer(&self, buffer: GpuBuffer) {
let cache_key = (buffer.size(), buffer.usage());
let mut cache = self.buffer_cache.lock().unwrap();
cache.entry(cache_key).or_default().push(buffer);
}
fn align_buffer_size(&self, size: u64) -> u64 {
size.div_ceil(self.alignment) * self.alignment
}
fn cast_slice_safe<T: Pod + Clone>(bytes: &[u8], element_count: usize) -> Vec<T> {
let element_size = std::mem::size_of::<T>();
if let Ok(aligned) = try_cast_slice::<u8, T>(bytes) {
return aligned.to_vec();
}
let mut result = Vec::with_capacity(element_count);
for i in 0..element_count {
let offset = i * element_size;
let element_bytes = &bytes[offset..offset + element_size];
let mut aligned_bytes = vec![0u8; element_size];
aligned_bytes.copy_from_slice(element_bytes);
let element: &T = bytemuck::from_bytes(&aligned_bytes);
result.push(*element);
}
result
}
pub fn read_buffer<T: Pod + Clone>(&self, buffer: &GpuBuffer) -> GpuResult<Vec<T>> {
if !buffer.usage().contains(wgpu::BufferUsages::COPY_SRC) {
return Err(GpuError::OperationFailed(
"Buffer was not created with COPY_SRC usage".to_string(),
));
}
let element_size = std::mem::size_of::<T>();
let element_count = (buffer.size() as usize) / element_size;
let staging_buffer = self
.create_buffer_empty_bytes(
buffer.size(),
wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
Some("GPU Readback Staging"),
)
.map_err(|e| GpuError::BufferCreationFailed(format!("{}", e)))?;
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("GPU Buffer Copy"),
});
encoder.copy_buffer_to_buffer(
buffer.buffer(),
0,
staging_buffer.buffer(),
0,
buffer.size(),
);
let submission = self.queue.submit(Some(encoder.finish()));
let buffer_slice = staging_buffer.buffer().slice(..);
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
sender.send(result).ok();
});
let _ = self.device.poll(wgpu::PollType::Wait {
submission_index: Some(submission),
timeout: None,
});
pollster::block_on(receiver.receive())
.ok_or_else(|| GpuError::OperationFailed("Buffer mapping failed".to_string()))?
.map_err(|e| GpuError::OperationFailed(format!("Buffer mapping error: {:?}", e)))?;
let mapped_data = buffer_slice.get_mapped_range();
let byte_slice = &mapped_data[..element_count * element_size];
let result_data = Self::cast_slice_safe::<T>(byte_slice, element_count);
drop(mapped_data);
staging_buffer.buffer().unmap();
Ok(result_data)
}
pub fn create_buffer_from_pooled<T: Pod>(
&self,
pooled_data: &PooledVec<T>,
usage: wgpu::BufferUsages,
) -> Result<GpuBuffer> {
let slice: &[T] = pooled_data.as_slice();
self.create_buffer(slice, usage)
}
pub fn read_buffer_to_pooled<T: Pod + Clone + Default>(
&self,
buffer: &GpuBuffer,
cpu_pool: SharedMemoryPool<T>,
) -> GpuResult<PooledVec<T>> {
let gpu_data = self.read_buffer::<T>(buffer)?;
let mut pooled_result = PooledVec::with_capacity(gpu_data.len(), cpu_pool);
for item in gpu_data {
pooled_result.push(item);
}
Ok(pooled_result)
}
pub fn get_stats(&self) -> GpuMemoryStats {
self.stats.lock().unwrap().clone()
}
pub fn clear_cache(&self) {
let mut cache = self.buffer_cache.lock().unwrap();
let mut total_allocated = self.total_allocated.lock().unwrap();
cache.clear();
*total_allocated = 0;
let mut stats = self.stats.lock().unwrap();
*stats = GpuMemoryStats {
memory_limit: stats.memory_limit,
..Default::default()
};
}
pub fn memory_usage_fraction(&self) -> f32 {
let total_allocated = *self.total_allocated.lock().unwrap();
total_allocated as f32 / self.memory_limit as f32
}
pub fn is_memory_pressure(&self) -> bool {
self.memory_usage_fraction() > 0.85 }
}
/// RAII wrapper that returns its `GpuBuffer` to the owning pool on drop.
pub struct PooledGpuBuffer {
    /// `Some` until `drop` takes the buffer back; `None` only during/after drop.
    buffer: Option<GpuBuffer>,
    /// Pool that receives the buffer when this wrapper is dropped.
    pool: Arc<GpuMemoryPool>,
}
impl PooledGpuBuffer {
    /// Take ownership of `buffer`; it is handed back to `pool` on drop.
    pub fn new(buffer: GpuBuffer, pool: Arc<GpuMemoryPool>) -> Self {
        let buffer = Some(buffer);
        Self { buffer, pool }
    }

    /// Borrow the wrapped buffer. The `Option` is only `None` while `drop`
    /// runs, which safe code cannot observe, so the unwrap cannot fire.
    pub fn buffer(&self) -> &GpuBuffer {
        self.buffer.as_ref().unwrap()
    }
}
impl Drop for PooledGpuBuffer {
    /// Hand the buffer back to the pool's reuse cache instead of letting it
    /// be destroyed.
    fn drop(&mut self) {
        if let Some(returned) = self.buffer.take() {
            self.pool.return_buffer(returned);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Mirrors GpuMemoryPool::align_buffer_size (round up to a multiple of
    // the alignment) without needing a real device.
    #[test]
    fn test_buffer_alignment() {
        let alignment = 256u64; let pool_alignment = |size: u64| -> u64 { size.div_ceil(alignment) * alignment };
        assert_eq!(pool_alignment(100), 256);
        assert_eq!(pool_alignment(256), 256);
        assert_eq!(pool_alignment(300), 512);
    }

    // Struct-update syntax with Default must leave unset counters at zero.
    #[test]
    fn test_memory_stats() {
        let stats = GpuMemoryStats {
            total_allocated: 1024,
            buffers_created: 5,
            cache_hits: 3,
            ..Default::default()
        };
        assert_eq!(stats.total_allocated, 1024);
        assert_eq!(stats.buffers_created, 5);
    }

    // Aligned input should take the zero-copy fast path and round-trip.
    #[test]
    fn test_cast_slice_safe_aligned() {
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let bytes: &[u8] = bytemuck::cast_slice(&data);
        let result = GpuMemoryPool::cast_slice_safe::<f32>(bytes, 4);
        assert_eq!(result, vec![1.0f32, 2.0, 3.0, 4.0]);
    }

    // A 1-byte offset forces the unaligned fallback path for f32.
    #[test]
    fn test_cast_slice_safe_unaligned() {
        let original: Vec<f32> = vec![1.0, 2.0, 3.0];
        let original_bytes: &[u8] = bytemuck::cast_slice(&original);
        let mut bytes = vec![0u8]; bytes.extend_from_slice(original_bytes);
        let unaligned = &bytes[1..];
        let result = GpuMemoryPool::cast_slice_safe::<f32>(unaligned, 3);
        assert_eq!(result, vec![1.0f32, 2.0, 3.0]);
    }

    // Same round-trip check for a u32 element type.
    #[test]
    fn test_cast_slice_safe_u32() {
        let data: Vec<u32> = vec![100, 200, 300, 400, 500];
        let bytes: &[u8] = bytemuck::cast_slice(&data);
        let result = GpuMemoryPool::cast_slice_safe::<u32>(bytes, 5);
        assert_eq!(result, vec![100u32, 200, 300, 400, 500]);
    }

    // A 3-byte offset exercises the fallback with an 8-byte-aligned type.
    #[test]
    fn test_cast_slice_safe_unaligned_u64() {
        let original: Vec<u64> = vec![0x1234567890ABCDEF, 0xFEDCBA0987654321];
        let original_bytes: &[u8] = bytemuck::cast_slice(&original);
        let mut bytes = vec![0u8; 3];
        bytes.extend_from_slice(original_bytes);
        let unaligned = &bytes[3..];
        let result = GpuMemoryPool::cast_slice_safe::<u64>(unaligned, 2);
        assert_eq!(result, vec![0x1234567890ABCDEFu64, 0xFEDCBA0987654321]);
    }

    // Empty input must produce an empty vector, not panic.
    #[test]
    fn test_cast_slice_safe_empty() {
        let bytes: &[u8] = &[];
        let result = GpuMemoryPool::cast_slice_safe::<f32>(bytes, 0);
        assert!(result.is_empty());
    }

    // Minimal non-empty case: a single element.
    #[test]
    fn test_cast_slice_safe_single_element() {
        let data: Vec<f32> = vec![42.5];
        let bytes: &[u8] = bytemuck::cast_slice(&data);
        let result = GpuMemoryPool::cast_slice_safe::<f32>(bytes, 1);
        assert_eq!(result, vec![42.5f32]);
    }
}