use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::Mutex;
use ringkernel_core::control::ControlBlock;
use ringkernel_core::error::{Result, RingKernelError};
use ringkernel_core::memory::GpuBuffer;
use ringkernel_core::message::MessageHeader;
use crate::adapter::WgpuAdapter;
/// Running counters describing how a [`StagingBufferPool`] has been used.
#[derive(Debug, Clone, Default)]
pub struct StagingPoolStats {
    /// Number of calls to `acquire`.
    pub total_acquires: u64,
    /// Acquires that were satisfied by a cached buffer.
    pub cache_hits: u64,
    /// Acquires that had to create a brand-new staging buffer.
    pub allocations: u64,
    /// Buffers that were accepted back into the cache (not counted when the cache was full).
    pub returns: u64,
    /// Buffers evicted by `trim` because they had been idle too long.
    pub trimmed: u64,
}

impl StagingPoolStats {
    /// Fraction of acquires served from the cache, in `[0.0, 1.0]`; `0.0` before any acquire.
    pub fn hit_rate(&self) -> f64 {
        match self.total_acquires {
            0 => 0.0,
            acquires => self.cache_hits as f64 / acquires as f64,
        }
    }
}
/// An idle staging buffer parked in a [`StagingBufferPool`] until it is reused or trimmed.
struct StagingEntry {
    /// The cached buffer (created by the pool with `MAP_READ | COPY_DST`).
    buffer: wgpu::Buffer,
    /// Allocated size of `buffer` in bytes; used for best-fit matching.
    size: usize,
    /// When the buffer was handed back to the pool; `trim` evicts entries older than its limit.
    last_used: Instant,
}
/// A bounded cache of readback staging buffers shared by [`WgpuBuffer`]s.
///
/// Both the cache and the statistics live behind `parking_lot` mutexes, so the pool can be
/// used through a shared reference.
pub struct StagingBufferPool {
    /// Device used to create new staging buffers on a cache miss.
    device: Arc<wgpu::Device>,
    /// Idle buffers available for reuse.
    buffers: Mutex<Vec<StagingEntry>>,
    /// Maximum number of idle buffers kept; extra returned buffers are dropped.
    max_cached: usize,
    /// Usage counters, see [`StagingPoolStats`].
    stats: Mutex<StagingPoolStats>,
}
impl StagingBufferPool {
    /// Creates an empty pool on `device` that retains at most `max_cached` idle buffers.
    pub fn new(device: Arc<wgpu::Device>, max_cached: usize) -> Self {
        let buffers = Mutex::new(Vec::with_capacity(max_cached));
        let stats = Mutex::new(StagingPoolStats::default());
        Self {
            device,
            buffers,
            max_cached,
            stats,
        }
    }

    /// Borrows a staging buffer holding at least `min_size` bytes.
    ///
    /// Cached buffers are searched for the smallest one that fits, stopping early on an exact
    /// size match. On a miss, a new `MAP_READ | COPY_DST` buffer is created. Its size is
    /// `min_size` rounded up to the next power of two, and never smaller than 4096 bytes.
    /// Dropping the returned guard offers the buffer back to the pool.
    pub fn acquire(&self, min_size: usize) -> StagingBufferGuard<'_> {
        // Locks are always taken stats -> buffers (same order as `return_buffer` and `trim`).
        let mut stats = self.stats.lock();
        stats.total_acquires += 1;
        let mut cache = self.buffers.lock();

        let mut hit: Option<usize> = None;
        let mut hit_size = usize::MAX;
        for (slot, candidate) in cache.iter().enumerate() {
            if candidate.size < min_size || candidate.size >= hit_size {
                continue;
            }
            hit = Some(slot);
            hit_size = candidate.size;
            if hit_size == min_size {
                // An exact fit cannot be beaten.
                break;
            }
        }

        match hit {
            Some(slot) => {
                stats.cache_hits += 1;
                let StagingEntry { buffer, size, .. } = cache.remove(slot);
                drop(cache);
                drop(stats);
                StagingBufferGuard {
                    buffer: Some(buffer),
                    size,
                    pool: self,
                }
            }
            None => {
                stats.allocations += 1;
                // Release both locks before creating the buffer on the device.
                drop(cache);
                drop(stats);
                let alloc_size = min_size.next_power_of_two().max(4096);
                let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
                    label: Some("Staging Buffer (Pooled)"),
                    size: alloc_size as u64,
                    usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
                    mapped_at_creation: false,
                });
                StagingBufferGuard {
                    buffer: Some(buffer),
                    size: alloc_size,
                    pool: self,
                }
            }
        }
    }

    /// Offers `buffer` (of `size` bytes) back to the cache.
    ///
    /// If the cache already holds `max_cached` buffers, `buffer` is simply dropped.
    fn return_buffer(&self, buffer: wgpu::Buffer, size: usize) {
        let mut stats = self.stats.lock();
        let mut cache = self.buffers.lock();
        if cache.len() >= self.max_cached {
            return;
        }
        stats.returns += 1;
        cache.push(StagingEntry {
            buffer,
            size,
            last_used: Instant::now(),
        });
    }

    /// Evicts every cached buffer that has been idle for `max_age` or longer.
    pub fn trim(&self, max_age: Duration) {
        let mut stats = self.stats.lock();
        let mut cache = self.buffers.lock();
        let now = Instant::now();
        let count_before = cache.len();
        cache.retain(|entry| now.duration_since(entry.last_used) < max_age);
        let evicted = count_before - cache.len();
        stats.trimmed += evicted as u64;
    }

    /// Drops all cached buffers (statistics are left untouched).
    pub fn clear(&self) {
        self.buffers.lock().clear();
    }

    /// Number of idle buffers currently held by the pool.
    pub fn cached_count(&self) -> usize {
        let cache = self.buffers.lock();
        cache.len()
    }

    /// Snapshot of the pool's usage counters.
    pub fn stats(&self) -> StagingPoolStats {
        let guard = self.stats.lock();
        guard.clone()
    }
}
/// A staging buffer on loan from a [`StagingBufferPool`]; dropping it returns the buffer.
pub struct StagingBufferGuard<'a> {
    /// The loaned buffer. It is only `None` once `Drop` has taken it.
    buffer: Option<wgpu::Buffer>,
    /// Allocated size of the buffer in bytes.
    size: usize,
    /// Pool that receives the buffer back on drop.
    pool: &'a StagingBufferPool,
}
impl<'a> StagingBufferGuard<'a> {
    /// Borrows the loaned staging buffer.
    ///
    /// # Panics
    /// Panics with "buffer should exist" if the buffer was already taken. Only `Drop` takes
    /// it, so this cannot happen on a live guard.
    pub fn buffer(&self) -> &wgpu::Buffer {
        match self.buffer {
            Some(ref buffer) => buffer,
            None => panic!("buffer should exist"),
        }
    }

    /// Allocated size of the loaned buffer in bytes (may exceed the size requested).
    pub fn size(&self) -> usize {
        let allocated = self.size;
        allocated
    }
}
impl<'a> Drop for StagingBufferGuard<'a> {
    /// Hands the buffer back to its pool, which keeps it only if it has room.
    fn drop(&mut self) {
        let pool = self.pool;
        let size = self.size;
        if let Some(loaned) = self.buffer.take() {
            pool.return_buffer(loaned, size);
        }
    }
}
/// Reference-counted handle to a [`StagingBufferPool`], shareable between buffers.
pub type SharedStagingPool = Arc<StagingBufferPool>;

/// Builds a [`StagingBufferPool`] on `device` that caches up to `max_cached` idle buffers.
///
/// The pool is returned inside an [`Arc`] so several buffers can share it.
pub fn create_staging_pool(device: Arc<wgpu::Device>, max_cached: usize) -> SharedStagingPool {
    let pool = StagingBufferPool::new(device, max_cached);
    SharedStagingPool::new(pool)
}
/// A wgpu storage buffer that supports host uploads and readbacks.
pub struct WgpuBuffer {
    /// The device-side buffer (`STORAGE | COPY_DST | COPY_SRC`).
    buffer: wgpu::Buffer,
    /// Logical size in bytes as requested at construction.
    size: usize,
    /// Device used to create staging buffers and command encoders.
    device: Arc<wgpu::Device>,
    /// Queue used for uploads and for submitting readback copies.
    queue: Arc<wgpu::Queue>,
    /// Optional pool of readback staging buffers.
    ///
    /// With `None`, each readback creates its own staging buffer.
    staging_pool: Option<SharedStagingPool>,
}
impl WgpuBuffer {
    /// Usage flags for every buffer this type creates.
    ///
    /// The buffer is bindable as storage and can be the source or destination of copies,
    /// which host upload and readback need.
    fn storage_usages() -> wgpu::BufferUsages {
        wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::COPY_SRC
    }

    /// Creates an uninitialised storage buffer of `size` bytes on the adapter's device.
    fn create_device_buffer(adapter: &WgpuAdapter, size: usize, label: Option<&str>) -> wgpu::Buffer {
        adapter.device().create_buffer(&wgpu::BufferDescriptor {
            label,
            size: size as u64,
            usage: Self::storage_usages(),
            mapped_at_creation: false,
        })
    }

    /// Creates a storage buffer initialised with `data`.
    ///
    /// `create_buffer_init` pads the allocation up to `COPY_BUFFER_ALIGNMENT` and handles
    /// empty `data`. A raw `mapped_at_creation` buffer of exactly `data.len()` bytes would be
    /// rejected by wgpu when that length is not 4-byte aligned.
    fn create_device_buffer_init(
        adapter: &WgpuAdapter,
        data: &[u8],
        label: Option<&str>,
    ) -> wgpu::Buffer {
        <wgpu::Device as wgpu::util::DeviceExt>::create_buffer_init(
            adapter.device(),
            &wgpu::util::BufferInitDescriptor {
                label,
                contents: data,
                usage: Self::storage_usages(),
            },
        )
    }

    /// Wraps `buffer` together with the adapter's shared device and queue handles.
    fn from_parts(
        adapter: &WgpuAdapter,
        buffer: wgpu::Buffer,
        size: usize,
        staging_pool: Option<SharedStagingPool>,
    ) -> Self {
        Self {
            buffer,
            size,
            device: Arc::clone(adapter.device()),
            queue: Arc::clone(adapter.queue()),
            staging_pool,
        }
    }

    /// Allocates an uninitialised `size`-byte storage buffer without a staging pool.
    pub fn new(adapter: &WgpuAdapter, size: usize, label: Option<&str>) -> Self {
        let buffer = Self::create_device_buffer(adapter, size, label);
        Self::from_parts(adapter, buffer, size, None)
    }

    /// Same as [`WgpuBuffer::new`], but readbacks borrow staging buffers from `staging_pool`.
    pub fn new_with_pool(
        adapter: &WgpuAdapter,
        size: usize,
        label: Option<&str>,
        staging_pool: SharedStagingPool,
    ) -> Self {
        let buffer = Self::create_device_buffer(adapter, size, label);
        Self::from_parts(adapter, buffer, size, Some(staging_pool))
    }

    /// Allocates a storage buffer initialised with `data`; the logical size is `data.len()`.
    pub fn new_init(adapter: &WgpuAdapter, data: &[u8], label: Option<&str>) -> Self {
        let buffer = Self::create_device_buffer_init(adapter, data, label);
        Self::from_parts(adapter, buffer, data.len(), None)
    }

    /// Same as [`WgpuBuffer::new_init`], but readbacks borrow staging buffers from
    /// `staging_pool`.
    pub fn new_init_with_pool(
        adapter: &WgpuAdapter,
        data: &[u8],
        label: Option<&str>,
        staging_pool: SharedStagingPool,
    ) -> Self {
        let buffer = Self::create_device_buffer_init(adapter, data, label);
        Self::from_parts(adapter, buffer, data.len(), Some(staging_pool))
    }

    /// Attaches (or replaces) the staging pool used for readbacks.
    pub fn set_staging_pool(&mut self, pool: SharedStagingPool) {
        self.staging_pool = Some(pool);
    }

    /// The staging pool used for readbacks, if any.
    pub fn staging_pool(&self) -> Option<&SharedStagingPool> {
        self.staging_pool.as_ref()
    }

    /// Raw access to the underlying wgpu buffer.
    // Public accessor that may have no caller inside this crate, hence the allow.
    #[allow(dead_code)]
    pub fn inner(&self) -> &wgpu::Buffer {
        &self.buffer
    }

    /// Binding resource covering the whole underlying buffer.
    pub fn as_entire_binding(&self) -> wgpu::BindingResource<'_> {
        self.buffer.as_entire_binding()
    }
}
impl GpuBuffer for WgpuBuffer {
    /// Logical size in bytes.
    fn size(&self) -> usize {
        self.size
    }

    /// This backend does not expose a raw device address; always 0.
    fn device_ptr(&self) -> usize {
        0
    }

    /// Uploads `data` to the start of the buffer via the queue.
    ///
    /// Fails with `AllocationFailed` when `data` is longer than the buffer.
    fn copy_from_host(&self, data: &[u8]) -> Result<()> {
        let requested = data.len();
        if requested <= self.size {
            self.queue.write_buffer(&self.buffer, 0, data);
            return Ok(());
        }
        Err(RingKernelError::AllocationFailed {
            size: requested,
            reason: format!("buffer too small: {} > {}", requested, self.size),
        })
    }

    /// Reads the buffer back into `data`, through the staging pool when one is attached.
    ///
    /// Fails with `AllocationFailed` when `data` is longer than the buffer.
    fn copy_to_host(&self, data: &mut [u8]) -> Result<()> {
        let requested = data.len();
        if requested > self.size {
            return Err(RingKernelError::AllocationFailed {
                size: requested,
                reason: format!("buffer too small: {} > {}", requested, self.size),
            });
        }
        match self.staging_pool.as_ref() {
            Some(pool) => self.copy_to_host_pooled(data, pool),
            None => self.copy_to_host_unpooled(data),
        }
    }
}
impl WgpuBuffer {
    /// Reads the buffer back to the host using a staging buffer borrowed from `pool`.
    ///
    /// The staging buffer returns to the pool when the guard drops at the end of this call.
    fn copy_to_host_pooled(&self, data: &mut [u8], pool: &StagingBufferPool) -> Result<()> {
        if self.size == 0 || data.is_empty() {
            // Nothing to transfer. This also keeps `read_back_via` from building an empty
            // buffer slice.
            return Ok(());
        }
        let staging_guard = pool.acquire(self.size);
        self.read_back_via(staging_guard.buffer(), "Copy Encoder (Pooled)", data)
    }

    /// Reads the buffer back to the host using a freshly created, buffer-sized staging buffer.
    fn copy_to_host_unpooled(&self, data: &mut [u8]) -> Result<()> {
        if self.size == 0 || data.is_empty() {
            // Nothing to transfer. This also avoids creating and mapping an empty buffer.
            return Ok(());
        }
        let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Staging Buffer"),
            size: self.size as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        self.read_back_via(&staging, "Copy Encoder", data)
    }

    /// Shared readback path.
    ///
    /// Copies all `self.size` bytes into `staging`, blocks until that range is mapped, copies
    /// `min(self.size, data.len())` bytes into `data`, then unmaps `staging`.
    ///
    /// Requires `self.size > 0` and `staging` to be a `MAP_READ | COPY_DST` buffer of at
    /// least `self.size` bytes.
    ///
    /// # Errors
    /// `TransferFailed` if the map callback never reports back or the map itself fails.
    fn read_back_via(
        &self,
        staging: &wgpu::Buffer,
        encoder_label: &str,
        data: &mut [u8],
    ) -> Result<()> {
        let byte_len = self.size as u64;
        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some(encoder_label),
            });
        encoder.copy_buffer_to_buffer(&self.buffer, 0, staging, 0, byte_len);
        self.queue.submit(Some(encoder.finish()));

        // Map only the bytes that were just written; pooled staging buffers may be larger.
        let slice = staging.slice(..byte_len);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |result| {
            // If the receiver is already gone there is nobody to notify; never panic inside a
            // wgpu callback.
            let _ = tx.send(result);
        });
        self.device.poll(wgpu::Maintain::Wait);
        rx.recv()
            .map_err(|e| RingKernelError::TransferFailed(format!("Channel error: {}", e)))?
            .map_err(|e| RingKernelError::TransferFailed(format!("Map error: {}", e)))?;

        {
            // The mapped view must be dropped before `unmap`.
            let mapped = slice.get_mapped_range();
            let copy_len = self.size.min(data.len());
            data[..copy_len].copy_from_slice(&mapped[..copy_len]);
        }
        staging.unmap();
        Ok(())
    }
}
/// GPU-resident copy of a [`ControlBlock`], stored as raw bytes in a [`WgpuBuffer`].
pub struct WgpuControlBlock {
    /// Buffer sized to exactly `size_of::<ControlBlock>()` bytes.
    buffer: WgpuBuffer,
}
impl WgpuControlBlock {
    /// Allocates the control block on the GPU, initialised to all-zero bytes.
    pub fn new(adapter: &WgpuAdapter) -> Self {
        let zeroed = vec![0u8; std::mem::size_of::<ControlBlock>()];
        let buffer = WgpuBuffer::new_init(adapter, &zeroed, Some("Control Block"));
        Self { buffer }
    }

    /// Binding resource covering the whole control-block buffer.
    pub fn as_binding(&self) -> wgpu::BindingResource<'_> {
        self.buffer.as_entire_binding()
    }

    /// The underlying GPU buffer.
    // Public accessor that may have no caller inside this crate, hence the allow.
    #[allow(dead_code)]
    pub fn buffer(&self) -> &WgpuBuffer {
        &self.buffer
    }

    /// Reads the current control block back from the GPU.
    ///
    /// # Errors
    /// Propagates any readback failure from [`GpuBuffer::copy_to_host`].
    pub fn read(&self) -> Result<ControlBlock> {
        let mut bytes = vec![0u8; std::mem::size_of::<ControlBlock>()];
        self.buffer.copy_to_host(&mut bytes)?;
        // SAFETY: `bytes` holds exactly `size_of::<ControlBlock>()` initialised bytes.
        // A `Vec<u8>` is only guaranteed 1-byte alignment, so `read_unaligned` is required;
        // `ptr::read` through a possibly misaligned pointer is undefined behaviour.
        // NOTE(review): this also assumes `ControlBlock` is plain old data that is valid for
        // any bit pattern — confirm against its definition.
        Ok(unsafe { std::ptr::read_unaligned(bytes.as_ptr().cast::<ControlBlock>()) })
    }

    /// Uploads `cb` to the GPU, overwriting the current control block.
    ///
    /// # Errors
    /// Propagates any upload failure from [`GpuBuffer::copy_from_host`].
    pub fn write(&self, cb: &ControlBlock) -> Result<()> {
        // SAFETY: the slice covers exactly the bytes of `*cb`, which outlives this borrow.
        // NOTE(review): this is only sound if `ControlBlock` has no padding bytes, since
        // padding is uninitialised — confirm it is `#[repr(C)]` without padding.
        let data = unsafe {
            std::slice::from_raw_parts(
                (cb as *const ControlBlock).cast::<u8>(),
                std::mem::size_of::<ControlBlock>(),
            )
        };
        self.buffer.copy_from_host(data)
    }
}
/// Fixed-capacity ring of messages whose headers and payloads live in GPU buffers.
///
/// The ring indices are tracked on the host only.
pub struct WgpuMessageQueue {
    /// `capacity` consecutive slots of `size_of::<MessageHeader>()` bytes.
    headers: WgpuBuffer,
    /// `capacity` consecutive slots of `max_payload_size` bytes.
    payloads: WgpuBuffer,
    /// Number of slots in the ring (one is always left empty).
    capacity: usize,
    /// Byte size of each payload slot; longer payloads are truncated on enqueue.
    max_payload_size: usize,
    /// Slot index of the next message to dequeue.
    head: std::sync::atomic::AtomicU32,
    /// Slot index where the next message will be enqueued.
    tail: std::sync::atomic::AtomicU32,
}
impl WgpuMessageQueue {
    /// Allocates header and payload storage for a ring of `capacity` slots.
    ///
    /// One slot is always kept free to tell "full" from "empty", so at most `capacity - 1`
    /// messages can be queued at once.
    pub fn new(adapter: &WgpuAdapter, capacity: usize, max_payload_size: usize) -> Self {
        let header_size = std::mem::size_of::<MessageHeader>() * capacity;
        let payload_size = max_payload_size * capacity;
        Self {
            headers: WgpuBuffer::new(adapter, header_size, Some("Message Headers")),
            payloads: WgpuBuffer::new(adapter, payload_size, Some("Message Payloads")),
            capacity,
            max_payload_size,
            head: std::sync::atomic::AtomicU32::new(0),
            tail: std::sync::atomic::AtomicU32::new(0),
        }
    }

    /// Binding resource covering the whole header buffer.
    pub fn headers_binding(&self) -> wgpu::BindingResource<'_> {
        self.headers.as_entire_binding()
    }

    /// Binding resource covering the whole payload buffer.
    // Public accessor that may have no caller inside this crate, hence the allow.
    #[allow(dead_code)]
    pub fn payloads_binding(&self) -> wgpu::BindingResource<'_> {
        self.payloads.as_entire_binding()
    }

    /// Number of slots in the ring (not the number of usable message slots).
    // Public accessor that may have no caller inside this crate, hence the allow.
    #[allow(dead_code)]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Number of queued messages, derived from the host-side head/tail indices.
    pub fn len(&self) -> usize {
        let head = self.head.load(std::sync::atomic::Ordering::Acquire);
        let tail = self.tail.load(std::sync::atomic::Ordering::Acquire);
        if tail >= head {
            (tail - head) as usize
        } else {
            self.capacity - (head - tail) as usize
        }
    }

    /// `true` when no messages are queued.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// `true` when another enqueue would overwrite the reserved empty slot.
    pub fn is_full(&self) -> bool {
        // `saturating_sub` stops a zero-capacity queue from underflowing. Plain subtraction
        // panics in debug builds. In release builds it reports "not full", so `enqueue` then
        // indexes past the empty buffers. A zero-capacity queue is now simply always full.
        self.len() >= self.capacity.saturating_sub(1)
    }

    /// Writes `envelope` into the tail slot and advances the tail.
    ///
    /// Payloads longer than `max_payload_size` are truncated silently. The header is stored
    /// unchanged, so its `payload_size` may exceed what was stored.
    ///
    /// NOTE(review): each call reads back and re-uploads the entire header buffer (and the
    /// payload buffer when the payload is non-empty), i.e. O(capacity) transfer per message.
    /// This also overwrites any GPU-side changes made to other slots in between.
    ///
    /// # Errors
    /// `QueueFull` when the ring is full, or any transfer error from the buffers.
    pub fn enqueue(&self, envelope: &ringkernel_core::message::MessageEnvelope) -> Result<()> {
        if self.is_full() {
            return Err(RingKernelError::QueueFull {
                capacity: self.capacity,
            });
        }
        let tail = self.tail.load(std::sync::atomic::Ordering::Acquire) as usize;

        // Header slot: serialise into a full-size slot, then patch it into the host copy.
        let header_offset = tail * std::mem::size_of::<MessageHeader>();
        let header_bytes = envelope.header.as_bytes();
        let mut header_data = vec![0u8; std::mem::size_of::<MessageHeader>()];
        header_data[..header_bytes.len()].copy_from_slice(header_bytes);
        let mut all_headers = vec![0u8; self.headers.size()];
        self.headers.copy_to_host(&mut all_headers)?;
        all_headers[header_offset..header_offset + header_data.len()].copy_from_slice(&header_data);
        self.headers.copy_from_host(&all_headers)?;

        if !envelope.payload.is_empty() {
            let payload_offset = tail * self.max_payload_size;
            let payload_len = envelope.payload.len().min(self.max_payload_size);
            let mut all_payloads = vec![0u8; self.payloads.size()];
            self.payloads.copy_to_host(&mut all_payloads)?;
            all_payloads[payload_offset..payload_offset + payload_len]
                .copy_from_slice(&envelope.payload[..payload_len]);
            self.payloads.copy_from_host(&all_payloads)?;
        }

        let new_tail = ((tail + 1) % self.capacity) as u32;
        self.tail
            .store(new_tail, std::sync::atomic::Ordering::Release);
        Ok(())
    }

    /// Reads the message in the head slot and advances the head.
    ///
    /// The payload length is the header's `payload_size`, clamped to `max_payload_size`.
    ///
    /// # Errors
    /// `QueueEmpty` when nothing is queued, `DeserializationError` if the header bytes cannot
    /// be parsed, or any transfer error from the buffers.
    pub fn dequeue(&self) -> Result<ringkernel_core::message::MessageEnvelope> {
        if self.is_empty() {
            return Err(RingKernelError::QueueEmpty);
        }
        let head = self.head.load(std::sync::atomic::Ordering::Acquire) as usize;

        let header_offset = head * std::mem::size_of::<MessageHeader>();
        let mut all_headers = vec![0u8; self.headers.size()];
        self.headers.copy_to_host(&mut all_headers)?;
        let header_bytes =
            &all_headers[header_offset..header_offset + std::mem::size_of::<MessageHeader>()];
        let header = MessageHeader::read_from(header_bytes).ok_or_else(|| {
            RingKernelError::DeserializationError("Failed to read header".to_string())
        })?;

        let payload_offset = head * self.max_payload_size;
        let payload_size = (header.payload_size as usize).min(self.max_payload_size);
        let payload = if payload_size > 0 {
            let mut all_payloads = vec![0u8; self.payloads.size()];
            self.payloads.copy_to_host(&mut all_payloads)?;
            all_payloads[payload_offset..payload_offset + payload_size].to_vec()
        } else {
            Vec::new()
        };

        let new_head = ((head + 1) % self.capacity) as u32;
        self.head
            .store(new_head, std::sync::atomic::Ordering::Release);
        Ok(ringkernel_core::message::MessageEnvelope { header, payload })
    }

    /// Like [`WgpuMessageQueue::dequeue`], but maps every error (including "empty") to `None`.
    pub fn try_dequeue(&self) -> Option<ringkernel_core::message::MessageEnvelope> {
        // `dequeue` already reports an empty queue as `QueueEmpty`, so no separate check.
        self.dequeue().ok()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Uploads a buffer and reads it back, expecting identical bytes.
    ///
    /// Ignored by default: it needs `WgpuAdapter::new()` to succeed, i.e. a usable wgpu
    /// adapter on the test machine.
    #[tokio::test]
    #[ignore]
    async fn test_wgpu_buffer() {
        const LEN: usize = 1024;
        let adapter = WgpuAdapter::new().await.unwrap();
        let gpu_buffer = WgpuBuffer::new(&adapter, LEN, Some("Test Buffer"));

        let uploaded = vec![42u8; LEN];
        gpu_buffer.copy_from_host(&uploaded).unwrap();

        let mut downloaded = vec![0u8; LEN];
        gpu_buffer.copy_to_host(&mut downloaded).unwrap();

        assert_eq!(uploaded, downloaded);
    }
}