Skip to main content

oxicuda_webgpu/
wasm.rs

1//! WASM target support for browser-based GPU compute via WebGPU.
2//!
3//! This module is conditionally compiled on `wasm32` targets (or when the `wasm`
4//! feature is enabled for native testing) and provides browser-friendly wrappers
5//! around the `wgpu` WebGPU backend.
6//!
7//! # Architecture
8//!
9//! ```text
10//! +-------------------------------------------+
11//! |         JavaScript / Browser              |
12//! +-------------------+-----------------------+
13//!                     |
14//! +-------------------v-----------------------+
15//! |   WasmGpuDevice / WasmBackend (wasm32)    |
16//! +-------------------+-----------------------+
17//!                     |  delegates to
18//! +-------------------v-----------------------+
19//! |   WebGpuBackend (wgpu web-sys backend)    |
20//! +-------------------------------------------+
21//! ```
22//!
23//! # Usage
24//!
25//! The [`WasmBackend`] wraps the existing [`WebGpuBackend`]
26//! and adds browser-specific initialisation methods such as
27//! [`init_from_canvas`](WasmBackend::init_from_canvas).
28//!
29//! The [`WasmMemoryManager`] provides async-friendly buffer staging suited to the
30//! browser event loop.
31
32use std::collections::HashMap;
33use std::sync::atomic::{AtomicU64, Ordering};
34use std::sync::{Arc, Mutex};
35
36use oxicuda_backend::{
37    BackendResult, BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp,
38};
39
40use crate::WebGpuBackend;
41use crate::error::{WebGpuError, WebGpuResult};
42use crate::memory::WebGpuBufferInfo;
43
44// ---- WasmGpuDevice --------------------------------------------------------
45
46/// A WebGPU device obtained from the browser's `navigator.gpu` API.
47///
48/// Wraps the `wgpu` adapter and device objects and provides async construction
49/// methods appropriate for the browser environment.
50#[derive(Debug)]
51pub struct WasmGpuDevice {
52    /// The wgpu instance.
53    #[allow(dead_code)]
54    pub(crate) instance: wgpu::Instance,
55    /// The selected GPU adapter.
56    #[allow(dead_code)]
57    pub(crate) adapter: wgpu::Adapter,
58    /// The logical device.
59    pub(crate) device: wgpu::Device,
60    /// The queue for submitting command buffers.
61    pub(crate) queue: wgpu::Queue,
62    /// Human-readable adapter name.
63    pub adapter_name: String,
64}
65
66impl WasmGpuDevice {
67    /// Create a new [`WasmGpuDevice`] from an already-obtained adapter.
68    ///
69    /// This is the async path used by browser callers. On native targets this
70    /// may not be exercised directly, but it is the intended entry point for
71    /// WASM builds.
72    pub async fn from_adapter(
73        instance: wgpu::Instance,
74        adapter: wgpu::Adapter,
75    ) -> WebGpuResult<Self> {
76        let adapter_name = adapter.get_info().name.clone();
77
78        let (device, queue) = adapter
79            .request_device(&wgpu::DeviceDescriptor {
80                label: Some("oxicuda-webgpu-wasm"),
81                required_features: wgpu::Features::empty(),
82                required_limits: wgpu::Limits::default(),
83                memory_hints: wgpu::MemoryHints::default(),
84                ..Default::default()
85            })
86            .await
87            .map_err(|e| WebGpuError::DeviceRequest(e.to_string()))?;
88
89        Ok(Self {
90            instance,
91            adapter,
92            device,
93            queue,
94            adapter_name,
95        })
96    }
97}
98
99// ---- request_adapter -------------------------------------------------------
100
101/// Request a WebGPU adapter from the browser.
102///
103/// On `wasm32` this goes through the browser's `navigator.gpu` API via the
104/// `wgpu` web-sys backend.
105pub async fn request_adapter() -> WebGpuResult<wgpu::Adapter> {
106    let instance = wgpu::Instance::new(wgpu::InstanceDescriptor::new_without_display_handle());
107
108    instance
109        .request_adapter(&wgpu::RequestAdapterOptions {
110            power_preference: wgpu::PowerPreference::HighPerformance,
111            compatible_surface: None,
112            force_fallback_adapter: false,
113        })
114        .await
115        .map_err(|e| WebGpuError::DeviceRequest(e.to_string()))
116}
117
118// ---- WasmMemoryManager -----------------------------------------------------
119
120/// Browser-side buffer manager that uses async `map_async` staging.
121///
122/// This mirrors [`WebGpuMemoryManager`](crate::memory::WebGpuMemoryManager) but
123/// is designed to work within the single-threaded browser event loop where
124/// blocking calls are not allowed.
125pub struct WasmMemoryManager {
126    device: Arc<WasmGpuDevice>,
127    buffers: Mutex<HashMap<u64, WebGpuBufferInfo>>,
128    next_handle: AtomicU64,
129}
130
131impl WasmMemoryManager {
132    /// Create a new WASM memory manager backed by `device`.
133    pub fn new(device: Arc<WasmGpuDevice>) -> Self {
134        Self {
135            device,
136            buffers: Mutex::new(HashMap::new()),
137            next_handle: AtomicU64::new(1),
138        }
139    }
140
141    /// Allocate a device buffer of `bytes` bytes.
142    pub fn alloc(&self, bytes: usize) -> WebGpuResult<u64> {
143        let size = bytes as u64;
144        let buffer = self.device.device.create_buffer(&wgpu::BufferDescriptor {
145            label: Some("oxicuda-wasm-buffer"),
146            size,
147            usage: wgpu::BufferUsages::STORAGE
148                | wgpu::BufferUsages::COPY_SRC
149                | wgpu::BufferUsages::COPY_DST,
150            mapped_at_creation: false,
151        });
152
153        let handle = self.next_handle.fetch_add(1, Ordering::Relaxed);
154
155        self.buffers
156            .lock()
157            .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?
158            .insert(handle, WebGpuBufferInfo { buffer, size });
159
160        Ok(handle)
161    }
162
163    /// Free the buffer identified by `handle`.
164    pub fn free(&self, handle: u64) -> WebGpuResult<()> {
165        self.buffers
166            .lock()
167            .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?
168            .remove(&handle);
169        Ok(())
170    }
171
172    /// Upload host bytes to the device buffer (host-to-device copy).
173    ///
174    /// Uses `Queue::write_buffer` which is available in both native and WASM.
175    pub fn copy_htod(&self, handle: u64, src: &[u8]) -> WebGpuResult<()> {
176        let buffers = self
177            .buffers
178            .lock()
179            .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?;
180
181        let buf_info = buffers
182            .get(&handle)
183            .ok_or_else(|| WebGpuError::InvalidArgument(format!("unknown handle {handle}")))?;
184
185        self.device.queue.write_buffer(&buf_info.buffer, 0, src);
186        Ok(())
187    }
188
189    /// Download device buffer to host bytes (device-to-host copy).
190    ///
191    /// Uses a staging buffer with `map_async` and blocks via `pollster::block_on`.
192    /// On native WASM, callers should prefer the async variant or schedule this
193    /// on a web worker to avoid blocking the main thread.
194    pub fn copy_dtoh(&self, dst: &mut [u8], handle: u64) -> WebGpuResult<()> {
195        let staging = {
196            let buffers = self
197                .buffers
198                .lock()
199                .map_err(|_| WebGpuError::BufferMapping("mutex poisoned".into()))?;
200
201            let buf_info = buffers
202                .get(&handle)
203                .ok_or_else(|| WebGpuError::InvalidArgument(format!("unknown handle {handle}")))?;
204
205            let staging = self.device.device.create_buffer(&wgpu::BufferDescriptor {
206                label: Some("oxicuda-wasm-staging"),
207                size: buf_info.size,
208                usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
209                mapped_at_creation: false,
210            });
211
212            let mut encoder =
213                self.device
214                    .device
215                    .create_command_encoder(&wgpu::CommandEncoderDescriptor {
216                        label: Some("oxicuda-wasm-readback"),
217                    });
218
219            encoder.copy_buffer_to_buffer(&buf_info.buffer, 0, &staging, 0, buf_info.size);
220            self.device.queue.submit(std::iter::once(encoder.finish()));
221
222            staging
223        };
224
225        let slice = staging.slice(..);
226        let (tx, rx) = std::sync::mpsc::channel();
227        slice.map_async(wgpu::MapMode::Read, move |result| {
228            let _ = tx.send(result);
229        });
230
231        let _ = self.device.device.poll(wgpu::PollType::wait_indefinitely());
232
233        rx.recv()
234            .map_err(|_| WebGpuError::BufferMapping("channel closed before map completed".into()))?
235            .map_err(|e| WebGpuError::BufferMapping(format!("{e:?}")))?;
236
237        let data = slice.get_mapped_range();
238        let copy_len = dst.len().min(data.len());
239        dst[..copy_len].copy_from_slice(&data[..copy_len]);
240        drop(data);
241        staging.unmap();
242
243        Ok(())
244    }
245}
246
247impl std::fmt::Debug for WasmMemoryManager {
248    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249        let count = self.buffers.lock().map(|b| b.len()).unwrap_or(0);
250        write!(f, "WasmMemoryManager(buffers={count})")
251    }
252}
253
254// ---- WasmBackend -----------------------------------------------------------
255
256/// WebGPU compute backend for WASM (browser) targets.
257///
258/// Wraps [`WebGpuBackend`] and adds browser-specific initialisation paths.
259/// Implements [`ComputeBackend`] by delegating all compute operations to the
260/// inner [`WebGpuBackend`], which already supports WASM via `wgpu`'s web-sys
261/// backend.
262///
263/// # Notes
264///
265/// Synchronous [`ComputeBackend`] trait methods use `pollster::block_on` to
266/// bridge async wgpu calls. In production browser deployments, prefer using
267/// the async initialisation helpers directly and scheduling GPU work on web
268/// workers where blocking is acceptable.
269#[derive(Debug)]
270pub struct WasmBackend {
271    inner: WebGpuBackend,
272}
273
274impl WasmBackend {
275    /// Create a new, uninitialised WASM backend.
276    pub fn new() -> Self {
277        Self {
278            inner: WebGpuBackend::new(),
279        }
280    }
281
282    /// Initialise the backend from an HTML canvas element by ID.
283    ///
284    /// This is the recommended browser entry point. The canvas is not used for
285    /// rendering but is required by some WebGPU implementations to obtain a
286    /// valid adapter.
287    ///
288    /// # Errors
289    ///
290    /// Returns an error if no WebGPU adapter is available or device creation fails.
291    pub async fn init_from_canvas(_canvas_id: &str) -> Result<Self, WebGpuError> {
292        // In the browser, wgpu's web-sys backend goes through navigator.gpu
293        // which does not actually require a canvas for compute-only usage.
294        // We accept the canvas_id for forward compatibility (e.g. surface-based
295        // adapters) but currently initialise via the standard path.
296        let mut backend = Self::new();
297        backend
298            .inner
299            .init()
300            .map_err(|e| WebGpuError::DeviceRequest(e.to_string()))?;
301        Ok(backend)
302    }
303}
304
305impl Default for WasmBackend {
306    fn default() -> Self {
307        Self::new()
308    }
309}
310
311// ---- ComputeBackend for WasmBackend ----------------------------------------
312
313impl ComputeBackend for WasmBackend {
314    fn name(&self) -> &str {
315        "webgpu-wasm"
316    }
317
318    fn init(&mut self) -> BackendResult<()> {
319        self.inner.init()
320    }
321
322    fn is_initialized(&self) -> bool {
323        self.inner.is_initialized()
324    }
325
326    #[allow(clippy::too_many_arguments)]
327    fn gemm(
328        &self,
329        trans_a: BackendTranspose,
330        trans_b: BackendTranspose,
331        m: usize,
332        n: usize,
333        k: usize,
334        alpha: f64,
335        a_ptr: u64,
336        lda: usize,
337        b_ptr: u64,
338        ldb: usize,
339        beta: f64,
340        c_ptr: u64,
341        ldc: usize,
342    ) -> BackendResult<()> {
343        self.inner.gemm(
344            trans_a, trans_b, m, n, k, alpha, a_ptr, lda, b_ptr, ldb, beta, c_ptr, ldc,
345        )
346    }
347
348    #[allow(clippy::too_many_arguments)]
349    fn conv2d_forward(
350        &self,
351        input_ptr: u64,
352        input_shape: &[usize],
353        filter_ptr: u64,
354        filter_shape: &[usize],
355        output_ptr: u64,
356        output_shape: &[usize],
357        stride: &[usize],
358        padding: &[usize],
359    ) -> BackendResult<()> {
360        self.inner.conv2d_forward(
361            input_ptr,
362            input_shape,
363            filter_ptr,
364            filter_shape,
365            output_ptr,
366            output_shape,
367            stride,
368            padding,
369        )
370    }
371
372    #[allow(clippy::too_many_arguments)]
373    fn attention(
374        &self,
375        q_ptr: u64,
376        k_ptr: u64,
377        v_ptr: u64,
378        o_ptr: u64,
379        batch: usize,
380        heads: usize,
381        seq_q: usize,
382        seq_kv: usize,
383        head_dim: usize,
384        scale: f64,
385        causal: bool,
386    ) -> BackendResult<()> {
387        self.inner.attention(
388            q_ptr, k_ptr, v_ptr, o_ptr, batch, heads, seq_q, seq_kv, head_dim, scale, causal,
389        )
390    }
391
392    fn reduce(
393        &self,
394        op: ReduceOp,
395        input_ptr: u64,
396        output_ptr: u64,
397        shape: &[usize],
398        axis: usize,
399    ) -> BackendResult<()> {
400        self.inner.reduce(op, input_ptr, output_ptr, shape, axis)
401    }
402
403    fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()> {
404        self.inner.unary(op, input_ptr, output_ptr, n)
405    }
406
407    fn binary(
408        &self,
409        op: BinaryOp,
410        a_ptr: u64,
411        b_ptr: u64,
412        output_ptr: u64,
413        n: usize,
414    ) -> BackendResult<()> {
415        self.inner.binary(op, a_ptr, b_ptr, output_ptr, n)
416    }
417
418    fn synchronize(&self) -> BackendResult<()> {
419        self.inner.synchronize()
420    }
421
422    fn alloc(&self, bytes: usize) -> BackendResult<u64> {
423        self.inner.alloc(bytes)
424    }
425
426    fn free(&self, ptr: u64) -> BackendResult<()> {
427        self.inner.free(ptr)
428    }
429
430    fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()> {
431        self.inner.copy_htod(dst, src)
432    }
433
434    fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()> {
435        self.inner.copy_dtoh(dst, src)
436    }
437}
438
439// ---- Tests -----------------------------------------------------------------
440
#[cfg(test)]
mod tests {
    use super::*;
    use oxicuda_backend::BackendError;

    /// Smoke test: a fresh backend reports its name, is uninitialised, and
    /// has a usable `Debug` representation.
    #[test]
    fn wasm_module_compiles() {
        let fresh = WasmBackend::new();
        assert!(!fresh.is_initialized());
        assert_eq!(fresh.name(), "webgpu-wasm");

        let rendered = format!("{fresh:?}");
        assert!(rendered.contains("WasmBackend"));
    }

    /// The backend is usable as a `ComputeBackend` trait object and
    /// provides `Default`.
    #[test]
    fn wasm_feature_flag_gating() {
        let concrete = WasmBackend::new();
        let _as_trait_object: &dyn ComputeBackend = &concrete;

        let _from_default = WasmBackend::default();
    }

    /// Every public item of the module can be named from here.
    #[test]
    fn wasm_public_api_accessible() {
        // Naming the types in a signature is enough to prove they are reachable.
        fn _assert_wasm_gpu_device_exists(_: &WasmGpuDevice) {}
        fn _assert_wasm_memory_manager_exists(_: &WasmMemoryManager) {}

        let _via_new = WasmBackend::new();
        let _via_default = WasmBackend::default();

        // The async entry point can be referenced without being awaited.
        let _fn_ptr: fn() -> _ = || request_adapter();
    }

    /// Memory and sync calls on an uninitialised backend fail with
    /// `NotInitialized`.
    #[test]
    fn wasm_backend_not_initialized_guards() {
        let uninit = WasmBackend::new();
        let mut scratch = [0u8; 4];

        assert_eq!(uninit.alloc(1024), Err(BackendError::NotInitialized));
        assert_eq!(uninit.free(1), Err(BackendError::NotInitialized));
        assert_eq!(
            uninit.copy_htod(1, b"hello"),
            Err(BackendError::NotInitialized)
        );
        assert_eq!(
            uninit.copy_dtoh(&mut scratch, 1),
            Err(BackendError::NotInitialized)
        );
        assert_eq!(uninit.synchronize(), Err(BackendError::NotInitialized));
    }

    /// `init` may report an error (e.g. no GPU available) but must not panic.
    #[test]
    fn wasm_backend_init_graceful() {
        let mut backend = WasmBackend::new();
        let _outcome = backend.init();
    }
}