use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use wgpu::{Buffer, BufferDescriptor, BufferUsages, Device, Queue};
use super::WgpuRuntime;
use super::device::{WgpuDevice, WgpuError, query_adapter_info_blocking};
use super::shaders::PipelineCache;
use crate::runtime::{Allocator, RuntimeClient};
/// Handle to a single WebGPU device: bundles the wgpu device/queue pair with
/// the buffer allocator and compiled-pipeline cache that operate on it.
///
/// Cloning is cheap — all heavyweight state lives behind `Arc`s.
#[derive(Clone)]
pub struct WgpuClient {
    /// Logical device identity (index plus adapter info).
    pub(crate) device_id: WgpuDevice,
    /// Underlying wgpu device handle.
    pub(crate) wgpu_device: Arc<Device>,
    /// Command submission queue for the device.
    pub(crate) queue: Arc<Queue>,
    /// Buffer allocator bound to this device/queue.
    pub(crate) allocator: WgpuAllocator,
    /// Raw device/queue handles exposed for interop.
    pub(crate) raw_handle: WgpuRawHandle,
    /// Cache of compiled compute pipelines for this device.
    pub(crate) pipeline_cache: Arc<PipelineCache>,
}
impl std::fmt::Debug for WgpuClient {
    /// Compact debug output: only the logical device id is printed; the wgpu
    /// handles and caches carry no useful `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("WgpuClient");
        dbg.field("device", &self.device_id);
        dbg.finish_non_exhaustive()
    }
}
impl WgpuClient {
    /// Get (or create) a client for `device` via the process-wide cache, so
    /// repeated calls for the same device share one wgpu device/queue.
    pub fn new(device: WgpuDevice) -> Result<Self, WgpuError> {
        super::cache::get_or_create_client(&device)
    }

    /// Build a brand-new client, bypassing the cache.
    ///
    /// Queries the adapter for `device.index`, requests a wgpu device/queue
    /// (opting into SUBGROUP support when the adapter offers it), then wires
    /// up the allocator, raw handle and pipeline cache.
    ///
    /// # Errors
    /// Propagates adapter-query errors and returns `WgpuError::DeviceError`
    /// when the device request fails.
    pub(super) fn new_uncached(device: WgpuDevice) -> Result<Self, WgpuError> {
        let (adapter, info) = query_adapter_info_blocking(device.index)?;
        let (wgpu_device, queue) = pollster::block_on(async {
            // Request SUBGROUP only when available; asking for an unsupported
            // feature would fail device creation outright.
            let features = adapter.features();
            let required_features = if features.contains(wgpu::Features::SUBGROUP) {
                wgpu::Features::SUBGROUP
            } else {
                wgpu::Features::empty()
            };
            adapter
                .request_device(&wgpu::DeviceDescriptor {
                    label: Some("numr WebGPU Device"),
                    required_features,
                    required_limits: wgpu::Limits::default(),
                    memory_hints: wgpu::MemoryHints::Performance,
                    trace: wgpu::Trace::Off,
                    experimental_features: wgpu::ExperimentalFeatures::default(),
                })
                .await
        })
        .map_err(|e| WgpuError::DeviceError(format!("{:?}", e)))?;
        let wgpu_device = Arc::new(wgpu_device);
        let queue = Arc::new(queue);
        let allocator = WgpuAllocator {
            device: wgpu_device.clone(),
            queue: queue.clone(),
        };
        let raw_handle = WgpuRawHandle {
            device: wgpu_device.clone(),
            queue: queue.clone(),
        };
        let pipeline_cache = Arc::new(PipelineCache::new(wgpu_device.clone(), queue.clone()));
        // Re-attach the freshly queried adapter info to the device id.
        let device_with_info = WgpuDevice::with_info(device.index, info);
        Ok(Self {
            device_id: device_with_info,
            wgpu_device,
            queue,
            allocator,
            raw_handle,
            pipeline_cache,
        })
    }

    /// Underlying wgpu device.
    #[inline]
    pub fn wgpu_device(&self) -> &Device {
        &self.wgpu_device
    }

    /// Shared handle to the wgpu device, for callers that need to clone it.
    #[inline]
    pub fn wgpu_device_arc(&self) -> &Arc<Device> {
        &self.wgpu_device
    }

    /// Command submission queue for this device.
    #[inline]
    pub fn wgpu_queue(&self) -> &Queue {
        &self.queue
    }

    /// Cache of compiled compute pipelines for this device.
    #[inline]
    pub fn pipeline_cache(&self) -> &PipelineCache {
        &self.pipeline_cache
    }

    /// Create a GPU-resident storage buffer usable as a copy source/target.
    pub fn create_storage_buffer(&self, label: &str, size: u64) -> Buffer {
        self.wgpu_device.create_buffer(&BufferDescriptor {
            label: Some(label),
            size,
            usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        })
    }

    /// Create a CPU-mappable buffer for reading results back from the GPU.
    pub fn create_staging_buffer(&self, label: &str, size: u64) -> Buffer {
        self.wgpu_device.create_buffer(&BufferDescriptor {
            label: Some(label),
            size,
            usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
            mapped_at_creation: false,
        })
    }

    /// Create a uniform buffer (e.g. for shader parameter blocks).
    pub fn create_uniform_buffer(&self, label: &str, size: u64) -> Buffer {
        self.wgpu_device.create_buffer(&BufferDescriptor {
            label: Some(label),
            size,
            usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
            mapped_at_creation: false,
        })
    }

    /// Queue a write of `data` into `buffer` starting at byte offset 0.
    pub fn write_buffer<T: bytemuck::Pod>(&self, buffer: &Buffer, data: &[T]) {
        self.queue
            .write_buffer(buffer, 0, bytemuck::cast_slice(data));
    }

    /// Submit the encoder's commands and block (up to 60s) until that
    /// specific submission completes. Poll errors are ignored, matching the
    /// fire-and-forget contract of this helper.
    pub fn submit_and_wait(&self, encoder: wgpu::CommandEncoder) {
        let submission = self.queue.submit(std::iter::once(encoder.finish()));
        let _ = self.wgpu_device.poll(wgpu::PollType::Wait {
            submission_index: Some(submission),
            timeout: Some(Duration::from_secs(60)),
        });
    }

    /// Read the contents of a mappable `staging` buffer into `output`.
    ///
    /// Maps the buffer for reading, waits for the map to complete (up to
    /// 60s), copies the first `output.len()` elements out, then unmaps the
    /// buffer — including on the error paths after a successful map.
    ///
    /// # Errors
    /// Returns a `Backend` error when polling or mapping fails, when the
    /// mapped bytes cannot be reinterpreted as `[T]`, or when the staging
    /// buffer holds fewer than `output.len()` elements. (The latter two
    /// conditions previously panicked via `cast_slice` / slice indexing.)
    pub fn read_buffer<T: bytemuck::Pod>(
        &self,
        staging: &Buffer,
        output: &mut [T],
    ) -> crate::error::Result<()> {
        let slice = staging.slice(..);
        let (sender, receiver) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |result| {
            let _ = sender.send(result);
        });
        self.wgpu_device
            .poll(wgpu::PollType::Wait {
                submission_index: None,
                timeout: Some(Duration::from_secs(60)),
            })
            .map_err(|e| {
                crate::error::Error::Backend(format!("GPU poll failed during buffer read: {e}"))
            })?;
        let map_result = receiver.recv().map_err(|_| {
            crate::error::Error::Backend(
                "map_async callback was not invoked during buffer read".into(),
            )
        })?;
        map_result.map_err(|e| {
            crate::error::Error::Backend(format!("map_async failed during buffer read: {e}"))
        })?;
        // Perform the copy in a scope so the mapped range is dropped before
        // unmap, and capture the outcome so the buffer is unmapped even when
        // the cast or length check fails.
        let copy_result = {
            let data = slice.get_mapped_range();
            match bytemuck::try_cast_slice::<u8, T>(&data) {
                Ok(src) if src.len() >= output.len() => {
                    output.copy_from_slice(&src[..output.len()]);
                    Ok(())
                }
                Ok(src) => Err(crate::error::Error::Backend(format!(
                    "staging buffer too small during buffer read: have {} elements, need {}",
                    src.len(),
                    output.len()
                ))),
                Err(e) => Err(crate::error::Error::Backend(format!(
                    "failed to reinterpret mapped bytes during buffer read: {e}"
                ))),
            }
        };
        staging.unmap();
        copy_result
    }
}
impl RuntimeClient<WgpuRuntime> for WgpuClient {
    /// Logical device this client was created for.
    fn device(&self) -> &WgpuDevice {
        &self.device_id
    }

    /// Block until all previously submitted GPU work has completed (up to
    /// 60s). A poll failure is deliberately swallowed because this trait
    /// method has no way to report errors.
    fn synchronize(&self) {
        let wait_all = wgpu::PollType::Wait {
            submission_index: None,
            timeout: Some(Duration::from_secs(60)),
        };
        let _ = self.wgpu_device.poll(wait_all);
    }

    /// Allocator bound to this client's device.
    fn allocator(&self) -> &WgpuAllocator {
        &self.allocator
    }
}
/// Allocator that creates GPU storage buffers and tracks them in the global
/// buffer registry, handing out opaque `u64` ids in place of raw pointers.
#[derive(Clone)]
pub struct WgpuAllocator {
    // Device used to create buffers.
    device: Arc<Device>,
    // Not read anywhere in this file; presumably reserved for future staged
    // uploads — TODO confirm before removing.
    #[allow(dead_code)] queue: Arc<Queue>,
}
// Process-wide map from opaque buffer id to the live wgpu buffer. Entries are
// inserted by `WgpuAllocator::allocate` and removed by `deallocate`.
static BUFFER_REGISTRY: std::sync::OnceLock<parking_lot::Mutex<HashMap<u64, Arc<Buffer>>>> =
    std::sync::OnceLock::new();
// Monotonic id source; starts at 1 so 0 can serve as the zero-size sentinel.
static BUFFER_ID_COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1);
// Lazily initialize and return the global buffer registry.
fn get_buffer_registry() -> &'static parking_lot::Mutex<HashMap<u64, Arc<Buffer>>> {
    BUFFER_REGISTRY.get_or_init(|| parking_lot::Mutex::new(HashMap::new()))
}
/// Look up a live buffer by its allocator-assigned id.
///
/// Returns `None` for the zero sentinel (zero-sized allocations) and for ids
/// that have already been deallocated.
pub fn get_buffer(id: u64) -> Option<Arc<Buffer>> {
    match id {
        0 => None,
        live => get_buffer_registry().lock().get(&live).cloned(),
    }
}
impl Allocator for WgpuAllocator {
    /// Create a GPU storage buffer of at least `size_bytes` bytes, register
    /// it, and return its opaque id. Zero-sized requests return the sentinel
    /// id `0` without touching the GPU.
    fn allocate(&self, size_bytes: usize) -> crate::error::Result<u64> {
        if size_bytes == 0 {
            return Ok(0);
        }
        // Round up to a multiple of 4, wgpu's copy alignment requirement.
        let padded = size_bytes.div_ceil(4) * 4;
        let descriptor = BufferDescriptor {
            label: Some("numr tensor buffer"),
            size: padded as u64,
            usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        };
        let buffer = self.device.create_buffer(&descriptor);
        let id = BUFFER_ID_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        get_buffer_registry().lock().insert(id, Arc::new(buffer));
        Ok(id)
    }

    /// Drop the registry's reference to the buffer; the GPU memory is freed
    /// once every outstanding `Arc<Buffer>` clone is gone. The zero sentinel
    /// is a no-op.
    fn deallocate(&self, ptr: u64, _size_bytes: usize) {
        if ptr != 0 {
            get_buffer_registry().lock().remove(&ptr);
        }
    }

    /// This allocator never freezes; allocation is always permitted.
    fn is_frozen(&self) -> bool {
        false
    }

    /// Freezing is accepted but has no effect.
    fn freeze(&self) -> bool {
        true
    }

    /// No-op counterpart to `freeze`.
    fn unfreeze(&self) {}
}
/// Raw wgpu device/queue handles, exposed for interop with code that needs
/// direct access outside the `WgpuClient` API.
#[derive(Clone)]
pub struct WgpuRawHandle {
    /// Shared wgpu device handle.
    pub device: Arc<Device>,
    /// Shared submission queue handle.
    pub queue: Arc<Queue>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::{Allocator, Device, RuntimeClient};

    // All tests skip gracefully (with a message) when no GPU adapter exists.

    #[test]
    fn test_wgpu_client_creation() {
        match WgpuClient::new(WgpuDevice::new(0)) {
            Err(e) => println!("No GPU available, skipping test: {}", e),
            Ok(client) => {
                println!("Client created for: {}", client.device().name());
                assert_eq!(client.device().id(), 0);
            }
        }
    }

    #[test]
    fn test_wgpu_allocator() {
        let client = match WgpuClient::new(WgpuDevice::new(0)) {
            Ok(c) => c,
            Err(e) => {
                println!("No GPU available, skipping test: {}", e);
                return;
            }
        };
        let allocator = client.allocator();
        // A live allocation gets a nonzero id and shows up in the registry.
        let buf_id = allocator.allocate(1024).expect("allocation should succeed");
        assert_ne!(buf_id, 0);
        assert!(get_buffer(buf_id).is_some());
        // Deallocation removes it again.
        allocator.deallocate(buf_id, 1024);
        assert!(get_buffer(buf_id).is_none());
    }

    #[test]
    fn test_wgpu_buffer_roundtrip() {
        let client = match WgpuClient::new(WgpuDevice::new(0)) {
            Ok(c) => c,
            Err(e) => {
                println!("No GPU available, skipping test: {}", e);
                return;
            }
        };
        // Upload, copy storage -> staging on the GPU, read back, compare.
        let input: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let byte_len = (input.len() * std::mem::size_of::<f32>()) as u64;
        let storage = client.create_storage_buffer("test", byte_len);
        client.write_buffer(&storage, &input);
        let staging = client.create_staging_buffer("staging", byte_len);
        let mut encoder = client
            .wgpu_device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("copy") });
        encoder.copy_buffer_to_buffer(&storage, 0, &staging, 0, byte_len);
        client.submit_and_wait(encoder);
        let mut readback = vec![0.0f32; input.len()];
        client
            .read_buffer(&staging, &mut readback)
            .expect("readback should succeed");
        assert_eq!(input, readback);
        println!("Buffer roundtrip successful: {:?}", readback);
    }
}