1use std::collections::HashMap;
33use std::sync::atomic::{AtomicU64, Ordering};
34use std::sync::{Arc, Mutex};
35
36use oxicuda_backend::{
37 BackendResult, BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp,
38};
39
40use crate::WebGpuBackend;
41use crate::error::{WebGpuError, WebGpuResult};
42use crate::memory::WebGpuBufferInfo;
43
/// Owned WebGPU device state for the wasm target: the instance and adapter
/// the device was created from, plus the logical device, its submission
/// queue, and the adapter's human-readable name.
#[derive(Debug)]
pub struct WasmGpuDevice {
    // Not read after construction (hence the allow); presumably retained so
    // the instance outlives the device — TODO(review): confirm intent.
    #[allow(dead_code)]
    pub(crate) instance: wgpu::Instance,
    // Not read after device creation (hence the allow); retained alongside
    // the instance.
    #[allow(dead_code)]
    pub(crate) adapter: wgpu::Adapter,
    pub(crate) device: wgpu::Device,
    pub(crate) queue: wgpu::Queue,
    /// Adapter name captured at device-creation time (see `from_adapter`).
    pub adapter_name: String,
}
65
66impl WasmGpuDevice {
67 pub async fn from_adapter(
73 instance: wgpu::Instance,
74 adapter: wgpu::Adapter,
75 ) -> WebGpuResult<Self> {
76 let adapter_name = adapter.get_info().name.clone();
77
78 let (device, queue) = adapter
79 .request_device(&wgpu::DeviceDescriptor {
80 label: Some("oxicuda-webgpu-wasm"),
81 required_features: wgpu::Features::empty(),
82 required_limits: wgpu::Limits::default(),
83 memory_hints: wgpu::MemoryHints::default(),
84 ..Default::default()
85 })
86 .await
87 .map_err(|e| WebGpuError::DeviceRequest(e.to_string()))?;
88
89 Ok(Self {
90 instance,
91 adapter,
92 device,
93 queue,
94 adapter_name,
95 })
96 }
97}
98
99pub async fn request_adapter() -> WebGpuResult<wgpu::Adapter> {
106 let instance = wgpu::Instance::new(wgpu::InstanceDescriptor::new_without_display_handle());
107
108 instance
109 .request_adapter(&wgpu::RequestAdapterOptions {
110 power_preference: wgpu::PowerPreference::HighPerformance,
111 compatible_surface: None,
112 force_fallback_adapter: false,
113 })
114 .await
115 .map_err(|e| WebGpuError::DeviceRequest(e.to_string()))
116}
117
/// Tracks device buffer allocations for the wasm backend, keyed by an
/// opaque `u64` handle returned from `alloc`.
pub struct WasmMemoryManager {
    device: Arc<WasmGpuDevice>,
    /// Live allocations; entries are inserted by `alloc` and dropped by `free`.
    buffers: Mutex<HashMap<u64, WebGpuBufferInfo>>,
    /// Monotonically increasing handle source; starts at 1 (see `new`).
    next_handle: AtomicU64,
}
130
131impl WasmMemoryManager {
132 pub fn new(device: Arc<WasmGpuDevice>) -> Self {
134 Self {
135 device,
136 buffers: Mutex::new(HashMap::new()),
137 next_handle: AtomicU64::new(1),
138 }
139 }
140
141 pub fn alloc(&self, bytes: usize) -> WebGpuResult<u64> {
143 let size = bytes as u64;
144 let buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
145 label: Some("oxicuda-wasm-buffer"),
146 size,
147 usage: wgpu::BufferUsages::STORAGE
148 | wgpu::BufferUsages::COPY_SRC
149 | wgpu::BufferUsages::COPY_DST,
150 mapped_at_creation: false,
151 });
152
153 let handle = self.next_handle.fetch_add(1, Ordering::Relaxed);
154
155 self.buffers
156 .lock()
157 .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?
158 .insert(handle, WebGpuBufferInfo { buffer, size });
159
160 Ok(handle)
161 }
162
163 pub fn free(&self, handle: u64) -> WebGpuResult<()> {
165 self.buffers
166 .lock()
167 .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?
168 .remove(&handle);
169 Ok(())
170 }
171
172 pub fn copy_htod(&self, handle: u64, src: &[u8]) -> WebGpuResult<()> {
176 let buffers = self
177 .buffers
178 .lock()
179 .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?;
180
181 let buf_info = buffers
182 .get(&handle)
183 .ok_or_else(|| WebGpuError::InvalidArgument(format!("unknown handle {handle}")))?;
184
185 self.device.queue.write_buffer(&buf_info.buffer, 0, src);
186 Ok(())
187 }
188
189 pub fn copy_dtoh(&self, dst: &mut [u8], handle: u64) -> WebGpuResult<()> {
195 let staging = {
196 let buffers = self
197 .buffers
198 .lock()
199 .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?;
200
201 let buf_info = buffers
202 .get(&handle)
203 .ok_or_else(|| WebGpuError::InvalidArgument(format!("unknown handle {handle}")))?;
204
205 let staging = self.device.device.create_buffer(&wgpu::BufferDescriptor {
206 label: Some("oxicuda-wasm-staging"),
207 size: buf_info.size,
208 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
209 mapped_at_creation: false,
210 });
211
212 let mut encoder =
213 self.device
214 .device
215 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
216 label: Some("oxicuda-wasm-readback"),
217 });
218
219 encoder.copy_buffer_to_buffer(&buf_info.buffer, 0, &staging, 0, buf_info.size);
220 self.device.queue.submit(std::iter::once(encoder.finish()));
221
222 staging
223 };
224
225 let slice = staging.slice(..);
226 let (tx, rx) = std::sync::mpsc::channel();
227 slice.map_async(wgpu::MapMode::Read, move |result| {
228 let _ = tx.send(result);
229 });
230
231 let _ = self.device.device.poll(wgpu::PollType::wait_indefinitely());
232
233 rx.recv()
234 .map_err(|_| WebGpuError::BufferMapping("channel closed before map completed".into()))?
235 .map_err(|e| WebGpuError::BufferMapping(format!("{e:?}")))?;
236
237 let data = slice.get_mapped_range();
238 let copy_len = dst.len().min(data.len());
239 dst[..copy_len].copy_from_slice(&data[..copy_len]);
240 drop(data);
241 staging.unmap();
242
243 Ok(())
244 }
245}
246
247impl std::fmt::Debug for WasmMemoryManager {
248 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249 let count = self.buffers.lock().map(|b| b.len()).unwrap_or(0);
250 write!(f, "WasmMemoryManager(buffers={count})")
251 }
252}
253
/// Thin wrapper over the shared [`WebGpuBackend`] that exposes it under the
/// `webgpu-wasm` backend name.
#[derive(Debug)]
pub struct WasmBackend {
    inner: WebGpuBackend,
}
273
274impl WasmBackend {
275 pub fn new() -> Self {
277 Self {
278 inner: WebGpuBackend::new(),
279 }
280 }
281
282 pub async fn init_from_canvas(_canvas_id: &str) -> Result<Self, WebGpuError> {
292 let mut backend = Self::new();
297 backend
298 .inner
299 .init()
300 .map_err(|e| WebGpuError::DeviceRequest(e.to_string()))?;
301 Ok(backend)
302 }
303}
304
305impl Default for WasmBackend {
306 fn default() -> Self {
307 Self::new()
308 }
309}
310
// `ComputeBackend` implementation: every operation is forwarded verbatim to
// the wrapped `WebGpuBackend`. The only wasm-specific behavior is the
// backend name; all validation and error reporting (including the
// `NotInitialized` guards exercised by the tests below) lives in `inner`.
impl ComputeBackend for WasmBackend {
    // Identifier distinguishing wasm builds from the native webgpu backend.
    fn name(&self) -> &str {
        "webgpu-wasm"
    }

    fn init(&mut self) -> BackendResult<()> {
        self.inner.init()
    }

    fn is_initialized(&self) -> bool {
        self.inner.is_initialized()
    }

    // General matrix multiply: C = alpha * op(A) * op(B) + beta * C.
    // Argument order mirrors the BLAS-style trait signature.
    #[allow(clippy::too_many_arguments)]
    fn gemm(
        &self,
        trans_a: BackendTranspose,
        trans_b: BackendTranspose,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        lda: usize,
        b_ptr: u64,
        ldb: usize,
        beta: f64,
        c_ptr: u64,
        ldc: usize,
    ) -> BackendResult<()> {
        self.inner.gemm(
            trans_a, trans_b, m, n, k, alpha, a_ptr, lda, b_ptr, ldb, beta, c_ptr, ldc,
        )
    }

    // 2-D convolution forward pass over device buffers identified by handle.
    #[allow(clippy::too_many_arguments)]
    fn conv2d_forward(
        &self,
        input_ptr: u64,
        input_shape: &[usize],
        filter_ptr: u64,
        filter_shape: &[usize],
        output_ptr: u64,
        output_shape: &[usize],
        stride: &[usize],
        padding: &[usize],
    ) -> BackendResult<()> {
        self.inner.conv2d_forward(
            input_ptr,
            input_shape,
            filter_ptr,
            filter_shape,
            output_ptr,
            output_shape,
            stride,
            padding,
        )
    }

    // Scaled dot-product attention; `causal` selects masked attention.
    #[allow(clippy::too_many_arguments)]
    fn attention(
        &self,
        q_ptr: u64,
        k_ptr: u64,
        v_ptr: u64,
        o_ptr: u64,
        batch: usize,
        heads: usize,
        seq_q: usize,
        seq_kv: usize,
        head_dim: usize,
        scale: f64,
        causal: bool,
    ) -> BackendResult<()> {
        self.inner.attention(
            q_ptr, k_ptr, v_ptr, o_ptr, batch, heads, seq_q, seq_kv, head_dim, scale, causal,
        )
    }

    // Reduction of `input` over the given axis into `output`.
    fn reduce(
        &self,
        op: ReduceOp,
        input_ptr: u64,
        output_ptr: u64,
        shape: &[usize],
        axis: usize,
    ) -> BackendResult<()> {
        self.inner.reduce(op, input_ptr, output_ptr, shape, axis)
    }

    // Elementwise unary op over `n` elements.
    fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()> {
        self.inner.unary(op, input_ptr, output_ptr, n)
    }

    // Elementwise binary op over `n` elements.
    fn binary(
        &self,
        op: BinaryOp,
        a_ptr: u64,
        b_ptr: u64,
        output_ptr: u64,
        n: usize,
    ) -> BackendResult<()> {
        self.inner.binary(op, a_ptr, b_ptr, output_ptr, n)
    }

    fn synchronize(&self) -> BackendResult<()> {
        self.inner.synchronize()
    }

    // Device memory management; handles are opaque u64 values minted by
    // the inner backend.
    fn alloc(&self, bytes: usize) -> BackendResult<u64> {
        self.inner.alloc(bytes)
    }

    fn free(&self, ptr: u64) -> BackendResult<()> {
        self.inner.free(ptr)
    }

    fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()> {
        self.inner.copy_htod(dst, src)
    }

    fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()> {
        self.inner.copy_dtoh(dst, src)
    }
}
438
#[cfg(test)]
mod tests {
    use super::*;
    use oxicuda_backend::BackendError;

    /// A freshly constructed backend is inert and carries the wasm name.
    #[test]
    fn wasm_module_compiles() {
        let backend = WasmBackend::new();
        assert!(!backend.is_initialized());
        assert_eq!(backend.name(), "webgpu-wasm");

        let rendered = format!("{backend:?}");
        assert!(rendered.contains("WasmBackend"));
    }

    /// The type is usable both concretely and as a trait object.
    #[test]
    fn wasm_feature_flag_gating() {
        let concrete = WasmBackend::new();
        let _erased: &dyn ComputeBackend = &concrete;

        let _default = WasmBackend::default();
    }

    /// Compile-time check that the public items stay reachable.
    #[test]
    fn wasm_public_api_accessible() {
        fn _assert_wasm_gpu_device_exists(_: &WasmGpuDevice) {}

        fn _assert_wasm_memory_manager_exists(_: &WasmMemoryManager) {}

        let _via_new = WasmBackend::new();
        let _via_default = WasmBackend::default();

        let _fn_ptr: fn() -> _ = || request_adapter();
    }

    /// Every device-touching call fails cleanly before `init()`.
    #[test]
    fn wasm_backend_not_initialized_guards() {
        let backend = WasmBackend::new();
        assert_eq!(backend.alloc(1024), Err(BackendError::NotInitialized));
        assert_eq!(backend.free(1), Err(BackendError::NotInitialized));
        assert_eq!(
            backend.copy_htod(1, b"hello"),
            Err(BackendError::NotInitialized)
        );

        let mut scratch = [0u8; 4];
        assert_eq!(
            backend.copy_dtoh(&mut scratch, 1),
            Err(BackendError::NotInitialized)
        );
        assert_eq!(backend.synchronize(), Err(BackendError::NotInitialized));
    }

    /// `init()` may legitimately fail on machines without a GPU; it must
    /// simply not panic either way.
    #[test]
    fn wasm_backend_init_graceful() {
        let mut backend = WasmBackend::new();
        let _outcome = backend.init();
    }
}