1use std::sync::Arc;
7
8use oxicuda_backend::{
9 BackendError, BackendResult, BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp,
10};
11use wgpu;
12
13use crate::{device::WebGpuDevice, memory::WebGpuMemoryManager};
14
#[derive(Debug)]
/// WebGPU implementation of [`ComputeBackend`].
///
/// Constructed in an uninitialised state; `init` must succeed before any
/// compute or memory method can be used (they all return
/// `BackendError::NotInitialized` otherwise).
pub struct WebGpuBackend {
    // GPU device/adapter handle; `None` until `init` succeeds.
    device: Option<Arc<WebGpuDevice>>,
    // Buffer allocator bound to `device`; `None` until `init` succeeds.
    memory: Option<Arc<WebGpuMemoryManager>>,
    // Set once `init` completes; gates every trait method via `check_init`.
    initialized: bool,
}
31
32impl WebGpuBackend {
33 pub fn new() -> Self {
35 Self {
36 device: None,
37 memory: None,
38 initialized: false,
39 }
40 }
41
42 fn check_init(&self) -> BackendResult<()> {
44 if self.initialized {
45 Ok(())
46 } else {
47 Err(BackendError::NotInitialized)
48 }
49 }
50
51 fn memory(&self) -> BackendResult<&Arc<WebGpuMemoryManager>> {
53 self.memory.as_ref().ok_or(BackendError::NotInitialized)
54 }
55}
56
57impl Default for WebGpuBackend {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63impl ComputeBackend for WebGpuBackend {
66 fn name(&self) -> &str {
67 "webgpu"
68 }
69
70 fn init(&mut self) -> BackendResult<()> {
71 if self.initialized {
72 return Ok(());
73 }
74
75 match WebGpuDevice::new() {
76 Ok(dev) => {
77 let dev = Arc::new(dev);
78 tracing::info!("WebGPU backend initialised on: {}", dev.adapter_name);
79 let memory = WebGpuMemoryManager::new(Arc::clone(&dev));
80 self.device = Some(dev);
81 self.memory = Some(Arc::new(memory));
82 self.initialized = true;
83 Ok(())
84 }
85 Err(e) => Err(BackendError::from(e)),
86 }
87 }
88
89 fn is_initialized(&self) -> bool {
90 self.initialized
91 }
92
93 fn gemm(
96 &self,
97 _trans_a: BackendTranspose,
98 _trans_b: BackendTranspose,
99 m: usize,
100 n: usize,
101 k: usize,
102 _alpha: f64,
103 _a_ptr: u64,
104 _lda: usize,
105 _b_ptr: u64,
106 _ldb: usize,
107 _beta: f64,
108 _c_ptr: u64,
109 _ldc: usize,
110 ) -> BackendResult<()> {
111 self.check_init()?;
112 if m == 0 || n == 0 || k == 0 {
114 return Ok(());
115 }
116 Err(BackendError::Unsupported(
117 "WebGPU GEMM shader dispatch not yet wired".into(),
118 ))
119 }
120
121 fn conv2d_forward(
122 &self,
123 _input_ptr: u64,
124 input_shape: &[usize],
125 _filter_ptr: u64,
126 filter_shape: &[usize],
127 _output_ptr: u64,
128 output_shape: &[usize],
129 stride: &[usize],
130 padding: &[usize],
131 ) -> BackendResult<()> {
132 self.check_init()?;
133
134 if input_shape.len() != 4 {
135 return Err(BackendError::InvalidArgument(
136 "input_shape must have 4 elements (NCHW)".into(),
137 ));
138 }
139 if filter_shape.len() != 4 {
140 return Err(BackendError::InvalidArgument(
141 "filter_shape must have 4 elements (KCFHFW)".into(),
142 ));
143 }
144 if output_shape.len() != 4 {
145 return Err(BackendError::InvalidArgument(
146 "output_shape must have 4 elements (NKOhOw)".into(),
147 ));
148 }
149 if stride.len() != 2 {
150 return Err(BackendError::InvalidArgument(
151 "stride must have 2 elements [sh, sw]".into(),
152 ));
153 }
154 if padding.len() != 2 {
155 return Err(BackendError::InvalidArgument(
156 "padding must have 2 elements [ph, pw]".into(),
157 ));
158 }
159
160 Err(BackendError::Unsupported(
161 "WebGPU conv2d not yet wired".into(),
162 ))
163 }
164
165 fn attention(
166 &self,
167 _q_ptr: u64,
168 _k_ptr: u64,
169 _v_ptr: u64,
170 _o_ptr: u64,
171 _batch: usize,
172 _heads: usize,
173 seq_q: usize,
174 seq_kv: usize,
175 head_dim: usize,
176 scale: f64,
177 _causal: bool,
178 ) -> BackendResult<()> {
179 self.check_init()?;
180
181 if seq_q == 0 || seq_kv == 0 || head_dim == 0 {
182 return Err(BackendError::InvalidArgument(
183 "seq_q, seq_kv, and head_dim must all be > 0".into(),
184 ));
185 }
186 if scale <= 0.0 || !scale.is_finite() {
187 return Err(BackendError::InvalidArgument(format!(
188 "scale must be a positive finite number, got {scale}"
189 )));
190 }
191
192 Err(BackendError::Unsupported(
193 "WebGPU attention not yet wired".into(),
194 ))
195 }
196
197 fn reduce(
198 &self,
199 _op: ReduceOp,
200 _input_ptr: u64,
201 _output_ptr: u64,
202 shape: &[usize],
203 axis: usize,
204 ) -> BackendResult<()> {
205 self.check_init()?;
206
207 if shape.is_empty() {
208 return Err(BackendError::InvalidArgument(
209 "shape must not be empty".into(),
210 ));
211 }
212 if axis >= shape.len() {
213 return Err(BackendError::InvalidArgument(format!(
214 "axis {axis} is out of bounds for shape of length {}",
215 shape.len()
216 )));
217 }
218
219 Err(BackendError::Unsupported(
220 "WebGPU reduce not yet wired".into(),
221 ))
222 }
223
224 fn unary(
225 &self,
226 _op: UnaryOp,
227 _input_ptr: u64,
228 _output_ptr: u64,
229 n: usize,
230 ) -> BackendResult<()> {
231 self.check_init()?;
232 if n == 0 {
233 return Ok(());
234 }
235 Err(BackendError::Unsupported(
236 "WebGPU unary not yet wired".into(),
237 ))
238 }
239
240 fn binary(
241 &self,
242 _op: BinaryOp,
243 _a_ptr: u64,
244 _b_ptr: u64,
245 _output_ptr: u64,
246 n: usize,
247 ) -> BackendResult<()> {
248 self.check_init()?;
249 if n == 0 {
250 return Ok(());
251 }
252 Err(BackendError::Unsupported(
253 "WebGPU binary not yet wired".into(),
254 ))
255 }
256
257 fn synchronize(&self) -> BackendResult<()> {
260 self.check_init()?;
261 if let Some(dev) = &self.device {
262 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
263 }
264 Ok(())
265 }
266
267 fn alloc(&self, bytes: usize) -> BackendResult<u64> {
270 self.check_init()?;
271 if bytes == 0 {
272 return Err(BackendError::InvalidArgument(
273 "cannot allocate 0 bytes".into(),
274 ));
275 }
276 self.memory()?.alloc(bytes).map_err(BackendError::from)
277 }
278
279 fn free(&self, ptr: u64) -> BackendResult<()> {
280 self.check_init()?;
281 self.memory()?.free(ptr).map_err(BackendError::from)
282 }
283
284 fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()> {
285 self.check_init()?;
286 if src.is_empty() {
287 return Ok(());
288 }
289 self.memory()?
290 .copy_to_device(dst, src)
291 .map_err(BackendError::from)
292 }
293
294 fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()> {
295 self.check_init()?;
296 if dst.is_empty() {
297 return Ok(());
298 }
299 self.memory()?
300 .copy_from_device(dst, src)
301 .map_err(BackendError::from)
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use oxicuda_backend::{BackendTranspose, BinaryOp, ReduceOp, UnaryOp};

    // ---- Construction and metadata: never touch the GPU ----

    #[test]
    fn webgpu_backend_new_uninitialized() {
        let b = WebGpuBackend::new();
        assert!(!b.is_initialized());
    }

    #[test]
    fn webgpu_backend_name() {
        let b = WebGpuBackend::new();
        assert_eq!(b.name(), "webgpu");
    }

    #[test]
    fn webgpu_backend_default() {
        // Default must be indistinguishable from `new()`.
        let b = WebGpuBackend::default();
        assert!(!b.is_initialized());
        assert_eq!(b.name(), "webgpu");
    }

    #[test]
    fn backend_debug_impl() {
        let b = WebGpuBackend::new();
        let s = format!("{b:?}");
        assert!(s.contains("WebGpuBackend"));
    }

    #[test]
    fn backend_object_safe() {
        // The trait must stay usable behind `dyn` for backend registries.
        let b: Box<dyn ComputeBackend> = Box::new(WebGpuBackend::new());
        assert_eq!(b.name(), "webgpu");
    }

    // ---- Every method must fail with NotInitialized before init() ----

    #[test]
    fn backend_not_initialized_gemm() {
        let b = WebGpuBackend::new();
        let result = b.gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            4,
            4,
            4,
            1.0,
            0,
            4,
            0,
            4,
            0.0,
            0,
            4,
        );
        assert_eq!(result, Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_alloc() {
        let b = WebGpuBackend::new();
        let result = b.alloc(1024);
        assert_eq!(result, Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_synchronize() {
        let b = WebGpuBackend::new();
        assert_eq!(b.synchronize(), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_free() {
        let b = WebGpuBackend::new();
        assert_eq!(b.free(1), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_copy_htod() {
        let b = WebGpuBackend::new();
        assert_eq!(b.copy_htod(1, b"hello"), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_copy_dtoh() {
        let b = WebGpuBackend::new();
        let mut buf = [0u8; 4];
        assert_eq!(b.copy_dtoh(&mut buf, 1), Err(BackendError::NotInitialized));
    }

    // Helper for the GPU-dependent tests below: returns an initialised
    // backend, or None when no WebGPU adapter is available (e.g. headless
    // CI), in which case those tests silently skip.
    fn try_init() -> Option<WebGpuBackend> {
        let mut b = WebGpuBackend::new();
        match b.init() {
            Ok(()) => Some(b),
            Err(_) => None,
        }
    }

    // ---- Post-init argument validation and zero-size fast paths ----

    #[test]
    fn gemm_zero_size_after_init() {
        let Some(b) = try_init() else {
            return;
        };
        // m = n = k = 0 is a degenerate problem: succeed without dispatching.
        let result = b.gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            0,
            0,
            0,
            1.0,
            0,
            1,
            0,
            1,
            0.0,
            0,
            1,
        );
        assert_eq!(result, Ok(()));
    }

    #[test]
    fn unary_zero_elements_after_init() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.unary(UnaryOp::Relu, 0, 0, 0), Ok(()));
    }

    #[test]
    fn binary_zero_elements_after_init() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.binary(BinaryOp::Add, 0, 0, 0, 0), Ok(()));
    }

    #[test]
    fn copy_htod_empty_noop() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.copy_htod(0, &[]), Ok(()));
    }

    #[test]
    fn copy_dtoh_empty_noop() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.copy_dtoh(&mut [], 0), Ok(()));
    }

    #[test]
    fn alloc_zero_bytes_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.alloc(0),
            Err(BackendError::InvalidArgument(
                "cannot allocate 0 bytes".into()
            ))
        );
    }

    #[test]
    fn synchronize_after_init() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.synchronize(), Ok(()));
    }

    #[test]
    fn reduce_empty_shape_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.reduce(ReduceOp::Sum, 0, 0, &[], 0),
            Err(BackendError::InvalidArgument(
                "shape must not be empty".into()
            ))
        );
    }

    #[test]
    fn reduce_axis_out_of_bounds_error() {
        let Some(b) = try_init() else {
            return;
        };
        // NOTE: asserts the exact message text — keep in sync with reduce().
        assert_eq!(
            b.reduce(ReduceOp::Sum, 0, 0, &[4, 4], 5),
            Err(BackendError::InvalidArgument(
                "axis 5 is out of bounds for shape of length 2".into()
            ))
        );
    }

    #[test]
    fn attention_zero_seq_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.attention(0, 0, 0, 0, 1, 1, 0, 8, 64, 0.125, false),
            Err(BackendError::InvalidArgument(
                "seq_q, seq_kv, and head_dim must all be > 0".into()
            ))
        );
    }

    #[test]
    fn attention_nonpositive_scale_error() {
        let Some(b) = try_init() else {
            return;
        };
        // f64 Display renders 0.0 as "0" and -1.0 as "-1", hence the
        // expected message text below.
        assert_eq!(
            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, 0.0, false),
            Err(BackendError::InvalidArgument(
                "scale must be a positive finite number, got 0".into()
            ))
        );
        assert_eq!(
            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, -1.0, false),
            Err(BackendError::InvalidArgument(
                "scale must be a positive finite number, got -1".into()
            ))
        );
        assert!(
            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, f64::INFINITY, false)
                .is_err()
        );
    }

    #[test]
    fn conv2d_wrong_input_shape_error() {
        let Some(b) = try_init() else {
            return;
        };
        // input_shape has 3 elements instead of the required 4 (NCHW).
        assert_eq!(
            b.conv2d_forward(
                0,
                &[1, 3, 32],
                0,
                &[16, 3, 3, 3],
                0,
                &[1, 16, 30, 30],
                &[1, 1],
                &[0, 0]
            ),
            Err(BackendError::InvalidArgument(
                "input_shape must have 4 elements (NCHW)".into()
            ))
        );
    }

    #[test]
    fn conv2d_wrong_filter_shape_error() {
        let Some(b) = try_init() else {
            return;
        };
        // filter_shape has 3 elements instead of the required 4.
        assert_eq!(
            b.conv2d_forward(
                0,
                &[1, 3, 32, 32],
                0,
                &[16, 3, 3],
                0,
                &[1, 16, 30, 30],
                &[1, 1],
                &[0, 0]
            ),
            Err(BackendError::InvalidArgument(
                "filter_shape must have 4 elements (KCFHFW)".into()
            ))
        );
    }

    #[test]
    fn conv2d_wrong_stride_shape_error() {
        let Some(b) = try_init() else {
            return;
        };
        // stride has 1 element instead of the required [sh, sw] pair.
        assert_eq!(
            b.conv2d_forward(
                0,
                &[1, 3, 32, 32],
                0,
                &[16, 3, 3, 3],
                0,
                &[1, 16, 30, 30],
                &[1], &[0, 0],
            ),
            Err(BackendError::InvalidArgument(
                "stride must have 2 elements [sh, sw]".into()
            ))
        );
    }

    #[test]
    fn init_idempotent() {
        let Some(mut b) = try_init() else {
            return;
        };
        // A second init() on an initialised backend must be an Ok no-op.
        assert_eq!(b.init(), Ok(()));
        assert!(b.is_initialized());
    }

    #[test]
    fn webgpu_init_graceful_failure() {
        let mut b = WebGpuBackend::new();
        // On hosts without a WebGPU adapter init() may fail; either outcome
        // is acceptable — the call must simply return rather than panic.
        let _result = b.init(); }
}