Skip to main content

oxicuda_webgpu/
backend.rs

1//! [`WebGpuBackend`] — the main entry point for the oxicuda-webgpu crate.
2//!
3//! Implements the [`ComputeBackend`] trait from `oxicuda-backend` using
4//! `wgpu` for cross-platform GPU compute (Vulkan, Metal, DX12, WebGPU).
5
6use std::sync::Arc;
7
8use oxicuda_backend::{
9    BackendError, BackendResult, BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp,
10};
11use wgpu;
12
13use crate::{device::WebGpuDevice, memory::WebGpuMemoryManager};
14
15// ─── Backend struct ──────────────────────────────────────────────────────────
16
17/// Cross-platform GPU compute backend backed by `wgpu`.
18///
19/// # Lifecycle
20///
21/// 1. `WebGpuBackend::new()` — create an uninitialised backend.
22/// 2. `init()` — select the best available adapter and create the device.
23/// 3. Use `alloc`, `copy_htod`, compute ops, `copy_dtoh`, `free`.
24/// 4. `synchronize()` — wait for all pending GPU work to finish.
25#[derive(Debug)]
26pub struct WebGpuBackend {
27    device: Option<Arc<WebGpuDevice>>,
28    memory: Option<Arc<WebGpuMemoryManager>>,
29    initialized: bool,
30}
31
32impl WebGpuBackend {
33    /// Create a new, uninitialised WebGPU backend.
34    pub fn new() -> Self {
35        Self {
36            device: None,
37            memory: None,
38            initialized: false,
39        }
40    }
41
42    /// Return an error if the backend is not yet initialised.
43    fn check_init(&self) -> BackendResult<()> {
44        if self.initialized {
45            Ok(())
46        } else {
47            Err(BackendError::NotInitialized)
48        }
49    }
50
51    /// Convenience accessor: get the memory manager or return `NotInitialized`.
52    fn memory(&self) -> BackendResult<&Arc<WebGpuMemoryManager>> {
53        self.memory.as_ref().ok_or(BackendError::NotInitialized)
54    }
55}
56
57impl Default for WebGpuBackend {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63// ─── ComputeBackend impl ─────────────────────────────────────────────────────
64
65impl ComputeBackend for WebGpuBackend {
66    fn name(&self) -> &str {
67        "webgpu"
68    }
69
70    fn init(&mut self) -> BackendResult<()> {
71        if self.initialized {
72            return Ok(());
73        }
74
75        match WebGpuDevice::new() {
76            Ok(dev) => {
77                let dev = Arc::new(dev);
78                tracing::info!("WebGPU backend initialised on: {}", dev.adapter_name);
79                let memory = WebGpuMemoryManager::new(Arc::clone(&dev));
80                self.device = Some(dev);
81                self.memory = Some(Arc::new(memory));
82                self.initialized = true;
83                Ok(())
84            }
85            Err(e) => Err(BackendError::from(e)),
86        }
87    }
88
89    fn is_initialized(&self) -> bool {
90        self.initialized
91    }
92
93    // ── Compute operations ────────────────────────────────────────────────────
94
95    fn gemm(
96        &self,
97        _trans_a: BackendTranspose,
98        _trans_b: BackendTranspose,
99        m: usize,
100        n: usize,
101        k: usize,
102        _alpha: f64,
103        _a_ptr: u64,
104        _lda: usize,
105        _b_ptr: u64,
106        _ldb: usize,
107        _beta: f64,
108        _c_ptr: u64,
109        _ldc: usize,
110    ) -> BackendResult<()> {
111        self.check_init()?;
112        // Zero-dimension matrices are trivially done.
113        if m == 0 || n == 0 || k == 0 {
114            return Ok(());
115        }
116        Err(BackendError::Unsupported(
117            "WebGPU GEMM shader dispatch not yet wired".into(),
118        ))
119    }
120
121    fn conv2d_forward(
122        &self,
123        _input_ptr: u64,
124        input_shape: &[usize],
125        _filter_ptr: u64,
126        filter_shape: &[usize],
127        _output_ptr: u64,
128        output_shape: &[usize],
129        stride: &[usize],
130        padding: &[usize],
131    ) -> BackendResult<()> {
132        self.check_init()?;
133
134        if input_shape.len() != 4 {
135            return Err(BackendError::InvalidArgument(
136                "input_shape must have 4 elements (NCHW)".into(),
137            ));
138        }
139        if filter_shape.len() != 4 {
140            return Err(BackendError::InvalidArgument(
141                "filter_shape must have 4 elements (KCFHFW)".into(),
142            ));
143        }
144        if output_shape.len() != 4 {
145            return Err(BackendError::InvalidArgument(
146                "output_shape must have 4 elements (NKOhOw)".into(),
147            ));
148        }
149        if stride.len() != 2 {
150            return Err(BackendError::InvalidArgument(
151                "stride must have 2 elements [sh, sw]".into(),
152            ));
153        }
154        if padding.len() != 2 {
155            return Err(BackendError::InvalidArgument(
156                "padding must have 2 elements [ph, pw]".into(),
157            ));
158        }
159
160        Err(BackendError::Unsupported(
161            "WebGPU conv2d not yet wired".into(),
162        ))
163    }
164
165    fn attention(
166        &self,
167        _q_ptr: u64,
168        _k_ptr: u64,
169        _v_ptr: u64,
170        _o_ptr: u64,
171        _batch: usize,
172        _heads: usize,
173        seq_q: usize,
174        seq_kv: usize,
175        head_dim: usize,
176        scale: f64,
177        _causal: bool,
178    ) -> BackendResult<()> {
179        self.check_init()?;
180
181        if seq_q == 0 || seq_kv == 0 || head_dim == 0 {
182            return Err(BackendError::InvalidArgument(
183                "seq_q, seq_kv, and head_dim must all be > 0".into(),
184            ));
185        }
186        if scale <= 0.0 || !scale.is_finite() {
187            return Err(BackendError::InvalidArgument(format!(
188                "scale must be a positive finite number, got {scale}"
189            )));
190        }
191
192        Err(BackendError::Unsupported(
193            "WebGPU attention not yet wired".into(),
194        ))
195    }
196
197    fn reduce(
198        &self,
199        _op: ReduceOp,
200        _input_ptr: u64,
201        _output_ptr: u64,
202        shape: &[usize],
203        axis: usize,
204    ) -> BackendResult<()> {
205        self.check_init()?;
206
207        if shape.is_empty() {
208            return Err(BackendError::InvalidArgument(
209                "shape must not be empty".into(),
210            ));
211        }
212        if axis >= shape.len() {
213            return Err(BackendError::InvalidArgument(format!(
214                "axis {axis} is out of bounds for shape of length {}",
215                shape.len()
216            )));
217        }
218
219        Err(BackendError::Unsupported(
220            "WebGPU reduce not yet wired".into(),
221        ))
222    }
223
224    fn unary(
225        &self,
226        _op: UnaryOp,
227        _input_ptr: u64,
228        _output_ptr: u64,
229        n: usize,
230    ) -> BackendResult<()> {
231        self.check_init()?;
232        if n == 0 {
233            return Ok(());
234        }
235        Err(BackendError::Unsupported(
236            "WebGPU unary not yet wired".into(),
237        ))
238    }
239
240    fn binary(
241        &self,
242        _op: BinaryOp,
243        _a_ptr: u64,
244        _b_ptr: u64,
245        _output_ptr: u64,
246        n: usize,
247    ) -> BackendResult<()> {
248        self.check_init()?;
249        if n == 0 {
250            return Ok(());
251        }
252        Err(BackendError::Unsupported(
253            "WebGPU binary not yet wired".into(),
254        ))
255    }
256
257    // ── Synchronisation ───────────────────────────────────────────────────────
258
259    fn synchronize(&self) -> BackendResult<()> {
260        self.check_init()?;
261        if let Some(dev) = &self.device {
262            let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
263        }
264        Ok(())
265    }
266
267    // ── Memory management ─────────────────────────────────────────────────────
268
269    fn alloc(&self, bytes: usize) -> BackendResult<u64> {
270        self.check_init()?;
271        if bytes == 0 {
272            return Err(BackendError::InvalidArgument(
273                "cannot allocate 0 bytes".into(),
274            ));
275        }
276        self.memory()?.alloc(bytes).map_err(BackendError::from)
277    }
278
279    fn free(&self, ptr: u64) -> BackendResult<()> {
280        self.check_init()?;
281        self.memory()?.free(ptr).map_err(BackendError::from)
282    }
283
284    fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()> {
285        self.check_init()?;
286        if src.is_empty() {
287            return Ok(());
288        }
289        self.memory()?
290            .copy_to_device(dst, src)
291            .map_err(BackendError::from)
292    }
293
294    fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()> {
295        self.check_init()?;
296        if dst.is_empty() {
297            return Ok(());
298        }
299        self.memory()?
300            .copy_from_device(dst, src)
301            .map_err(BackendError::from)
302    }
303}
304
305// ─── Tests ───────────────────────────────────────────────────────────────────
306
#[cfg(test)]
mod tests {
    use super::*;
    use oxicuda_backend::{BackendTranspose, BinaryOp, ReduceOp, UnaryOp};

    // ── Helpers ───────────────────────────────────────────────────────────────

    /// Return an initialised backend, or `None` when `init()` fails (for
    /// example on a machine without a usable GPU). Tests that need a live
    /// backend return early on `None` instead of failing.
    fn try_init() -> Option<WebGpuBackend> {
        let mut backend = WebGpuBackend::new();
        backend.init().ok()?;
        Some(backend)
    }

    /// Shorthand for the `InvalidArgument` error carrying `msg`.
    fn invalid_argument(msg: &str) -> BackendError {
        BackendError::InvalidArgument(msg.into())
    }

    /// Call `gemm` with `m = n = k = dim`, untransposed operands, null
    /// pointers, `alpha = 1`, `beta = 0` and `ld` for every leading dimension.
    fn square_gemm(backend: &WebGpuBackend, dim: usize, ld: usize) -> BackendResult<()> {
        backend.gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            dim,
            dim,
            dim,
            1.0,
            0,
            ld,
            0,
            ld,
            0.0,
            0,
            ld,
        )
    }

    /// Call `conv2d_forward` with null pointers, a `[1, 16, 30, 30]` output
    /// shape and zero padding; the remaining arguments vary per test.
    fn conv2d_with(
        backend: &WebGpuBackend,
        input_shape: &[usize],
        filter_shape: &[usize],
        stride: &[usize],
    ) -> BackendResult<()> {
        let output_shape = [1, 16, 30, 30];
        let padding = [0, 0];
        backend.conv2d_forward(
            0,
            input_shape,
            0,
            filter_shape,
            0,
            &output_shape,
            stride,
            &padding,
        )
    }

    /// Call `attention` with null pointers, batch 1, one head, `seq_kv = 8`,
    /// `head_dim = 64` and no causal mask.
    fn attention_with(backend: &WebGpuBackend, seq_q: usize, scale: f64) -> BackendResult<()> {
        backend.attention(0, 0, 0, 0, 1, 1, seq_q, 8, 64, scale, false)
    }

    // ── Construction ──────────────────────────────────────────────────────────

    #[test]
    fn webgpu_backend_new_uninitialized() {
        assert!(!WebGpuBackend::new().is_initialized());
    }

    #[test]
    fn webgpu_backend_name() {
        assert_eq!(WebGpuBackend::new().name(), "webgpu");
    }

    #[test]
    fn webgpu_backend_default() {
        let backend = WebGpuBackend::default();
        assert!(!backend.is_initialized());
        assert_eq!(backend.name(), "webgpu");
    }

    #[test]
    fn backend_debug_impl() {
        let rendered = format!("{:?}", WebGpuBackend::new());
        assert!(rendered.contains("WebGpuBackend"));
    }

    // ── Object-safety smoke test ──────────────────────────────────────────────

    #[test]
    fn backend_object_safe() {
        let boxed: Box<dyn ComputeBackend> = Box::new(WebGpuBackend::new());
        assert_eq!(boxed.name(), "webgpu");
    }

    // ── Not-initialized guards ────────────────────────────────────────────────

    #[test]
    fn backend_not_initialized_gemm() {
        let backend = WebGpuBackend::new();
        assert_eq!(
            square_gemm(&backend, 4, 4),
            Err(BackendError::NotInitialized)
        );
    }

    #[test]
    fn backend_not_initialized_alloc() {
        let backend = WebGpuBackend::new();
        assert_eq!(backend.alloc(1024), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_synchronize() {
        let backend = WebGpuBackend::new();
        assert_eq!(backend.synchronize(), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_free() {
        let backend = WebGpuBackend::new();
        assert_eq!(backend.free(1), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_copy_htod() {
        let backend = WebGpuBackend::new();
        assert_eq!(
            backend.copy_htod(1, b"hello"),
            Err(BackendError::NotInitialized)
        );
    }

    #[test]
    fn backend_not_initialized_copy_dtoh() {
        let backend = WebGpuBackend::new();
        let mut host = [0u8; 4];
        assert_eq!(
            backend.copy_dtoh(&mut host, 1),
            Err(BackendError::NotInitialized)
        );
    }

    // ── Zero-size / trivial-OK paths ──────────────────────────────────────────
    //
    // These exercise the "no-op for zero size" branches. They need an
    // initialised backend and return early when `try_init()` yields `None`.

    #[test]
    fn gemm_zero_size_after_init() {
        let Some(backend) = try_init() else {
            return;
        };
        assert_eq!(square_gemm(&backend, 0, 1), Ok(()));
    }

    #[test]
    fn unary_zero_elements_after_init() {
        let Some(backend) = try_init() else {
            return;
        };
        assert_eq!(backend.unary(UnaryOp::Relu, 0, 0, 0), Ok(()));
    }

    #[test]
    fn binary_zero_elements_after_init() {
        let Some(backend) = try_init() else {
            return;
        };
        assert_eq!(backend.binary(BinaryOp::Add, 0, 0, 0, 0), Ok(()));
    }

    #[test]
    fn copy_htod_empty_noop() {
        let Some(backend) = try_init() else {
            return;
        };
        let nothing: [u8; 0] = [];
        assert_eq!(backend.copy_htod(0, &nothing), Ok(()));
    }

    #[test]
    fn copy_dtoh_empty_noop() {
        let Some(backend) = try_init() else {
            return;
        };
        let mut nothing: [u8; 0] = [];
        assert_eq!(backend.copy_dtoh(&mut nothing, 0), Ok(()));
    }

    #[test]
    fn alloc_zero_bytes_error() {
        let Some(backend) = try_init() else {
            return;
        };
        assert_eq!(
            backend.alloc(0),
            Err(invalid_argument("cannot allocate 0 bytes"))
        );
    }

    #[test]
    fn synchronize_after_init() {
        let Some(backend) = try_init() else {
            return;
        };
        assert_eq!(backend.synchronize(), Ok(()));
    }

    // ── Argument validation (post-init) ───────────────────────────────────────

    #[test]
    fn reduce_empty_shape_error() {
        let Some(backend) = try_init() else {
            return;
        };
        assert_eq!(
            backend.reduce(ReduceOp::Sum, 0, 0, &[], 0),
            Err(invalid_argument("shape must not be empty"))
        );
    }

    #[test]
    fn reduce_axis_out_of_bounds_error() {
        let Some(backend) = try_init() else {
            return;
        };
        assert_eq!(
            backend.reduce(ReduceOp::Sum, 0, 0, &[4, 4], 5),
            Err(invalid_argument(
                "axis 5 is out of bounds for shape of length 2"
            ))
        );
    }

    #[test]
    fn attention_zero_seq_error() {
        let Some(backend) = try_init() else {
            return;
        };
        assert_eq!(
            attention_with(&backend, 0, 0.125),
            Err(invalid_argument(
                "seq_q, seq_kv, and head_dim must all be > 0"
            ))
        );
    }

    #[test]
    fn attention_nonpositive_scale_error() {
        let Some(backend) = try_init() else {
            return;
        };
        assert_eq!(
            attention_with(&backend, 8, 0.0),
            Err(invalid_argument(
                "scale must be a positive finite number, got 0"
            ))
        );
        assert_eq!(
            attention_with(&backend, 8, -1.0),
            Err(invalid_argument(
                "scale must be a positive finite number, got -1"
            ))
        );
        assert!(attention_with(&backend, 8, f64::INFINITY).is_err());
    }

    #[test]
    fn conv2d_wrong_input_shape_error() {
        let Some(backend) = try_init() else {
            return;
        };
        // Three-element input shape: rejected before anything else is checked.
        assert_eq!(
            conv2d_with(&backend, &[1, 3, 32], &[16, 3, 3, 3], &[1, 1]),
            Err(invalid_argument("input_shape must have 4 elements (NCHW)"))
        );
    }

    #[test]
    fn conv2d_wrong_filter_shape_error() {
        let Some(backend) = try_init() else {
            return;
        };
        assert_eq!(
            conv2d_with(&backend, &[1, 3, 32, 32], &[16, 3, 3], &[1, 1]),
            Err(invalid_argument(
                "filter_shape must have 4 elements (KCFHFW)"
            ))
        );
    }

    #[test]
    fn conv2d_wrong_stride_shape_error() {
        let Some(backend) = try_init() else {
            return;
        };
        // One-element stride is the only invalid argument here.
        assert_eq!(
            conv2d_with(&backend, &[1, 3, 32, 32], &[16, 3, 3, 3], &[1]),
            Err(invalid_argument("stride must have 2 elements [sh, sw]"))
        );
    }

    // ── Init is idempotent ────────────────────────────────────────────────────

    #[test]
    fn init_idempotent() {
        let Some(mut backend) = try_init() else {
            return;
        };
        // A second call must succeed and leave the backend initialised.
        assert_eq!(backend.init(), Ok(()));
        assert!(backend.is_initialized());
    }

    // ── Graceful failure ──────────────────────────────────────────────────────

    #[test]
    fn webgpu_init_graceful_failure() {
        // A failure cannot be forced from here; this only checks that
        // `init()` returns (either `Ok` or `Err`) instead of panicking.
        let mut backend = WebGpuBackend::new();
        let _outcome = backend.init();
    }
}