use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use oxicuda_backend::{
BackendResult, BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp,
};
use crate::WebGpuBackend;
use crate::error::{WebGpuError, WebGpuResult};
use crate::memory::WebGpuBufferInfo;
/// A wgpu instance, adapter, logical device and queue bundled together with the
/// adapter's reported name.
///
/// Constructed by [`WasmGpuDevice::from_adapter`].
#[derive(Debug)]
pub struct WasmGpuDevice {
/// The wgpu instance passed to `from_adapter`.
// Not read anywhere in this file, hence the lint allowance.
// NOTE(review): presumably retained so the instance lives as long as the device — confirm.
#[allow(dead_code)]
pub(crate) instance: wgpu::Instance,
/// The adapter the device was requested from.
// Not read anywhere in this file after construction, hence the lint allowance.
#[allow(dead_code)]
pub(crate) adapter: wgpu::Adapter,
/// Logical device used to create buffers and command encoders and to poll for completion.
pub(crate) device: wgpu::Device,
/// Queue used for buffer writes and command submission.
pub(crate) queue: wgpu::Queue,
/// Name taken from the adapter's `get_info()` when the device was created.
pub adapter_name: String,
}
impl WasmGpuDevice {
    /// Requests a logical device and queue from `adapter` and bundles them, together
    /// with `instance` and the adapter's name, into a [`WasmGpuDevice`].
    ///
    /// The device is requested with the label `"oxicuda-webgpu-wasm"`, no optional
    /// features, and default limits and memory hints.
    ///
    /// # Errors
    ///
    /// Returns [`WebGpuError::DeviceRequest`] carrying the stringified wgpu error when
    /// the device request fails.
    pub async fn from_adapter(
        instance: wgpu::Instance,
        adapter: wgpu::Adapter,
    ) -> WebGpuResult<Self> {
        // `get_info` hands back an owned info struct, so the name can be moved out directly.
        let adapter_name = adapter.get_info().name;

        let descriptor = wgpu::DeviceDescriptor {
            label: Some("oxicuda-webgpu-wasm"),
            required_features: wgpu::Features::empty(),
            required_limits: wgpu::Limits::default(),
            memory_hints: wgpu::MemoryHints::default(),
            ..Default::default()
        };

        // Finish the request in its own statement so the borrow of `adapter` has ended
        // before `adapter` is moved into the returned struct.
        let requested = adapter.request_device(&descriptor).await;
        match requested {
            Ok((device, queue)) => Ok(Self {
                instance,
                adapter,
                device,
                queue,
                adapter_name,
            }),
            Err(err) => Err(WebGpuError::DeviceRequest(err.to_string())),
        }
    }
}
/// Creates a fresh wgpu instance (without a display handle) and asks it for a
/// high-performance adapter with no surface requirement and no forced fallback.
///
/// The instance is local to this function and is not returned to the caller.
///
/// # Errors
///
/// Returns [`WebGpuError::DeviceRequest`] carrying the stringified wgpu error when no
/// adapter could be obtained.
pub async fn request_adapter() -> WebGpuResult<wgpu::Adapter> {
    let descriptor = wgpu::InstanceDescriptor::new_without_display_handle();
    let instance = wgpu::Instance::new(descriptor);

    let options = wgpu::RequestAdapterOptions {
        power_preference: wgpu::PowerPreference::HighPerformance,
        compatible_surface: None,
        force_fallback_adapter: false,
    };

    match instance.request_adapter(&options).await {
        Ok(adapter) => Ok(adapter),
        Err(err) => Err(WebGpuError::DeviceRequest(err.to_string())),
    }
}
/// Handle-keyed table of GPU buffers created on a shared [`WasmGpuDevice`].
///
/// Buffers are referred to by opaque `u64` handles rather than by reference.
pub struct WasmMemoryManager {
/// Device wrapper whose `device` and `queue` are used for buffer creation and transfers.
device: Arc<WasmGpuDevice>,
/// Live buffers keyed by the handle returned from `alloc`.
buffers: Mutex<HashMap<u64, WebGpuBufferInfo>>,
/// Next handle to hand out; initialised to 1 in `new`.
next_handle: AtomicU64,
}
impl WasmMemoryManager {
pub fn new(device: Arc<WasmGpuDevice>) -> Self {
Self {
device,
buffers: Mutex::new(HashMap::new()),
next_handle: AtomicU64::new(1),
}
}
pub fn alloc(&self, bytes: usize) -> WebGpuResult<u64> {
let size = bytes as u64;
let buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("oxicuda-wasm-buffer"),
size,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let handle = self.next_handle.fetch_add(1, Ordering::Relaxed);
self.buffers
.lock()
.map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?
.insert(handle, WebGpuBufferInfo { buffer, size });
Ok(handle)
}
pub fn free(&self, handle: u64) -> WebGpuResult<()> {
self.buffers
.lock()
.map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?
.remove(&handle);
Ok(())
}
pub fn copy_htod(&self, handle: u64, src: &[u8]) -> WebGpuResult<()> {
let buffers = self
.buffers
.lock()
.map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?;
let buf_info = buffers
.get(&handle)
.ok_or_else(|| WebGpuError::InvalidArgument(format!("unknown handle {handle}")))?;
self.device.queue.write_buffer(&buf_info.buffer, 0, src);
Ok(())
}
pub fn copy_dtoh(&self, dst: &mut [u8], handle: u64) -> WebGpuResult<()> {
let staging = {
let buffers = self
.buffers
.lock()
.map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?;
let buf_info = buffers
.get(&handle)
.ok_or_else(|| WebGpuError::InvalidArgument(format!("unknown handle {handle}")))?;
let staging = self.device.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("oxicuda-wasm-staging"),
size: buf_info.size,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let mut encoder =
self.device
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("oxicuda-wasm-readback"),
});
encoder.copy_buffer_to_buffer(&buf_info.buffer, 0, &staging, 0, buf_info.size);
self.device.queue.submit(std::iter::once(encoder.finish()));
staging
};
let slice = staging.slice(..);
let (tx, rx) = std::sync::mpsc::channel();
slice.map_async(wgpu::MapMode::Read, move |result| {
let _ = tx.send(result);
});
let _ = self.device.device.poll(wgpu::PollType::wait_indefinitely());
rx.recv()
.map_err(|_| WebGpuError::BufferMapping("channel closed before map completed".into()))?
.map_err(|e| WebGpuError::BufferMapping(format!("{e:?}")))?;
let data = slice.get_mapped_range();
let copy_len = dst.len().min(data.len());
dst[..copy_len].copy_from_slice(&data[..copy_len]);
drop(data);
staging.unmap();
Ok(())
}
}
impl std::fmt::Debug for WasmMemoryManager {
    /// Formats the manager as `WasmMemoryManager(buffers=N)`, where `N` is the number
    /// of buffers currently in the table.
    ///
    /// A poisoned mutex is read through rather than reported as zero buffers, so the
    /// diagnostic output stays accurate after a panic elsewhere.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = match self.buffers.lock() {
            Ok(guard) => guard.len(),
            // Reading the length is harmless even if a writer panicked mid-update.
            Err(poisoned) => poisoned.into_inner().len(),
        };
        write!(f, "WasmMemoryManager(buffers={count})")
    }
}
/// [`ComputeBackend`] implementation named `"webgpu-wasm"` that wraps a
/// [`WebGpuBackend`] and forwards every operation to it.
#[derive(Debug)]
pub struct WasmBackend {
/// The wrapped backend that all trait methods delegate to.
inner: WebGpuBackend,
}
impl WasmBackend {
    /// Creates a backend around a freshly constructed [`WebGpuBackend`].
    ///
    /// No initialisation is performed here; see `init` / `init_from_canvas`.
    pub fn new() -> Self {
        let inner = WebGpuBackend::new();
        Self { inner }
    }

    /// Creates a backend and initialises the wrapped [`WebGpuBackend`].
    ///
    /// `_canvas_id` is currently ignored, and the body contains no `.await` point.
    ///
    /// # Errors
    ///
    /// Returns [`WebGpuError::DeviceRequest`] carrying the stringified error when the
    /// wrapped backend fails to initialise.
    pub async fn init_from_canvas(_canvas_id: &str) -> Result<Self, WebGpuError> {
        let mut backend = Self::new();
        let outcome = backend.inner.init();
        match outcome {
            Ok(()) => Ok(backend),
            Err(err) => Err(WebGpuError::DeviceRequest(err.to_string())),
        }
    }
}
impl Default for WasmBackend {
fn default() -> Self {
Self::new()
}
}
/// Every method except `name` forwards its arguments unchanged to the wrapped
/// [`WebGpuBackend`] and returns that backend's result.
impl ComputeBackend for WasmBackend {
    /// Returns the fixed backend identifier `"webgpu-wasm"`.
    fn name(&self) -> &str {
        "webgpu-wasm"
    }

    /// Initialises the wrapped backend.
    fn init(&mut self) -> BackendResult<()> {
        let webgpu = &mut self.inner;
        webgpu.init()
    }

    /// Reports whether the wrapped backend has been initialised.
    fn is_initialized(&self) -> bool {
        let webgpu = &self.inner;
        webgpu.is_initialized()
    }

    /// Forwards a GEMM call to the wrapped backend.
    // The argument list is dictated by the `ComputeBackend` trait.
    #[allow(clippy::too_many_arguments)]
    fn gemm(
        &self,
        trans_a: BackendTranspose,
        trans_b: BackendTranspose,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        lda: usize,
        b_ptr: u64,
        ldb: usize,
        beta: f64,
        c_ptr: u64,
        ldc: usize,
    ) -> BackendResult<()> {
        let webgpu = &self.inner;
        webgpu.gemm(
            trans_a, trans_b, m, n, k, alpha, a_ptr, lda, b_ptr, ldb, beta, c_ptr, ldc,
        )
    }

    /// Forwards a 2-D convolution forward pass to the wrapped backend.
    // The argument list is dictated by the `ComputeBackend` trait.
    #[allow(clippy::too_many_arguments)]
    fn conv2d_forward(
        &self,
        input_ptr: u64,
        input_shape: &[usize],
        filter_ptr: u64,
        filter_shape: &[usize],
        output_ptr: u64,
        output_shape: &[usize],
        stride: &[usize],
        padding: &[usize],
    ) -> BackendResult<()> {
        let webgpu = &self.inner;
        webgpu.conv2d_forward(
            input_ptr,
            input_shape,
            filter_ptr,
            filter_shape,
            output_ptr,
            output_shape,
            stride,
            padding,
        )
    }

    /// Forwards an attention call to the wrapped backend.
    // The argument list is dictated by the `ComputeBackend` trait.
    #[allow(clippy::too_many_arguments)]
    fn attention(
        &self,
        q_ptr: u64,
        k_ptr: u64,
        v_ptr: u64,
        o_ptr: u64,
        batch: usize,
        heads: usize,
        seq_q: usize,
        seq_kv: usize,
        head_dim: usize,
        scale: f64,
        causal: bool,
    ) -> BackendResult<()> {
        let webgpu = &self.inner;
        webgpu.attention(
            q_ptr, k_ptr, v_ptr, o_ptr, batch, heads, seq_q, seq_kv, head_dim, scale, causal,
        )
    }

    /// Forwards a reduction along `axis` to the wrapped backend.
    fn reduce(
        &self,
        op: ReduceOp,
        input_ptr: u64,
        output_ptr: u64,
        shape: &[usize],
        axis: usize,
    ) -> BackendResult<()> {
        let webgpu = &self.inner;
        webgpu.reduce(op, input_ptr, output_ptr, shape, axis)
    }

    /// Forwards a unary operation to the wrapped backend.
    fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()> {
        let webgpu = &self.inner;
        webgpu.unary(op, input_ptr, output_ptr, n)
    }

    /// Forwards a binary operation to the wrapped backend.
    fn binary(
        &self,
        op: BinaryOp,
        a_ptr: u64,
        b_ptr: u64,
        output_ptr: u64,
        n: usize,
    ) -> BackendResult<()> {
        let webgpu = &self.inner;
        webgpu.binary(op, a_ptr, b_ptr, output_ptr, n)
    }

    /// Forwards a synchronisation request to the wrapped backend.
    fn synchronize(&self) -> BackendResult<()> {
        let webgpu = &self.inner;
        webgpu.synchronize()
    }

    /// Forwards an allocation of `bytes` bytes to the wrapped backend.
    fn alloc(&self, bytes: usize) -> BackendResult<u64> {
        let webgpu = &self.inner;
        webgpu.alloc(bytes)
    }

    /// Forwards the release of `ptr` to the wrapped backend.
    fn free(&self, ptr: u64) -> BackendResult<()> {
        let webgpu = &self.inner;
        webgpu.free(ptr)
    }

    /// Forwards a host-to-device copy of `src` into `dst` to the wrapped backend.
    fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()> {
        let webgpu = &self.inner;
        webgpu.copy_htod(dst, src)
    }

    /// Forwards a device-to-host copy from `src` into `dst` to the wrapped backend.
    fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()> {
        let webgpu = &self.inner;
        webgpu.copy_dtoh(dst, src)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use oxicuda_backend::BackendError;

    /// A fresh backend is uninitialised, reports its name and has a `Debug` rendering.
    #[test]
    fn wasm_module_compiles() {
        let fresh = WasmBackend::new();
        assert_eq!(fresh.name(), "webgpu-wasm");
        assert!(!fresh.is_initialized());
        let rendered = format!("{fresh:?}");
        assert!(rendered.contains("WasmBackend"));
    }

    /// The backend is usable as a `ComputeBackend` trait object and implements `Default`.
    #[test]
    fn wasm_feature_flag_gating() {
        let _default = WasmBackend::default();
        let concrete = WasmBackend::new();
        let _: &dyn ComputeBackend = &concrete;
    }

    /// Compile-time check that the public items of this module are nameable.
    #[test]
    fn wasm_public_api_accessible() {
        fn _assert_wasm_memory_manager_exists(_: &WasmMemoryManager) {}
        fn _assert_wasm_gpu_device_exists(_: &WasmGpuDevice) {}
        let _from_default = WasmBackend::default();
        let _from_new = WasmBackend::new();
        let _fn_ptr: fn() -> _ = || request_adapter();
    }

    /// Memory and synchronisation calls on an uninitialised backend report `NotInitialized`.
    #[test]
    fn wasm_backend_not_initialized_guards() {
        let uninit = WasmBackend::new();
        assert_eq!(uninit.alloc(1024), Err(BackendError::NotInitialized));
        assert_eq!(uninit.free(1), Err(BackendError::NotInitialized));
        assert_eq!(
            uninit.copy_htod(1, b"hello"),
            Err(BackendError::NotInitialized)
        );
        let mut scratch = [0u8; 4];
        assert_eq!(
            uninit.copy_dtoh(&mut scratch, 1),
            Err(BackendError::NotInitialized)
        );
        assert_eq!(uninit.synchronize(), Err(BackendError::NotInitialized));
    }

    /// `init` must return without panicking; its outcome is deliberately not inspected.
    #[test]
    fn wasm_backend_init_graceful() {
        let mut backend = WasmBackend::new();
        let _outcome = backend.init();
    }
}