1use std::sync::Arc;
7
8use oxicuda_backend::{
9 BackendError, BackendResult, BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp,
10};
11use wgpu;
12
13use crate::{device::WebGpuDevice, memory::WebGpuMemoryManager, shader};
14
15fn map_unary_op(op: UnaryOp) -> &'static str {
18 match op {
19 UnaryOp::Relu => "relu",
20 UnaryOp::Sigmoid => "sigmoid",
21 UnaryOp::Tanh => "tanh",
22 UnaryOp::Exp => "exp",
23 UnaryOp::Log => "log",
24 UnaryOp::Sqrt => "sqrt",
25 UnaryOp::Abs => "abs",
26 UnaryOp::Neg => "neg",
27 }
28}
29
30fn map_binary_op(op: BinaryOp) -> &'static str {
31 match op {
32 BinaryOp::Add => "add",
33 BinaryOp::Sub => "sub",
34 BinaryOp::Mul => "mul",
35 BinaryOp::Div => "div",
36 BinaryOp::Max => "max",
37 BinaryOp::Min => "min",
38 }
39}
40
41fn map_reduce_op(op: ReduceOp) -> &'static str {
42 match op {
43 ReduceOp::Sum => "sum",
44 ReduceOp::Max => "max",
45 ReduceOp::Min => "min",
46 ReduceOp::Mean => "mean",
47 }
48}
49
/// WebGPU implementation of the compute backend.
///
/// Starts uninitialised; `init()` creates the device and memory manager and
/// flips `initialized`, after which the other entry points become usable.
#[derive(Debug)]
pub struct WebGpuBackend {
    // Lazily created by `init()`; `None` until then.
    device: Option<Arc<WebGpuDevice>>,
    // Tracks buffer handles; shares the device via `Arc`.
    memory: Option<Arc<WebGpuMemoryManager>>,
    // Set once `init()` succeeds; gates every other method via `check_init`.
    initialized: bool,
}
66
67impl WebGpuBackend {
68 pub fn new() -> Self {
70 Self {
71 device: None,
72 memory: None,
73 initialized: false,
74 }
75 }
76
77 fn check_init(&self) -> BackendResult<()> {
79 if self.initialized {
80 Ok(())
81 } else {
82 Err(BackendError::NotInitialized)
83 }
84 }
85
86 fn memory(&self) -> BackendResult<&Arc<WebGpuMemoryManager>> {
88 self.memory.as_ref().ok_or(BackendError::NotInitialized)
89 }
90
91 fn device(&self) -> BackendResult<&Arc<WebGpuDevice>> {
93 self.device.as_ref().ok_or(BackendError::NotInitialized)
94 }
95}
96
97impl WebGpuBackend {
98 #[allow(clippy::too_many_arguments)]
106 pub fn gemm_f16(
107 &self,
108 m: usize,
109 n: usize,
110 k: usize,
111 alpha: f64,
112 a_ptr: u64,
113 b_ptr: u64,
114 beta: f64,
115 c_ptr: u64,
116 ) -> BackendResult<()> {
117 self.check_init()?;
118 if m == 0 || n == 0 || k == 0 {
119 return Ok(());
120 }
121
122 let dev = self.device()?;
123 let mem = self.memory()?;
124
125 let tile_size: u32 = 8;
126 let wgsl = shader::gemm_wgsl_f16(tile_size);
127
128 let shader_mod = dev
129 .device
130 .create_shader_module(wgpu::ShaderModuleDescriptor {
131 label: Some("oxicuda-gemm-f16"),
132 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
133 });
134
135 let pipeline = dev
136 .device
137 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
138 label: Some("oxicuda-gemm-f16"),
139 layout: None,
140 module: &shader_mod,
141 entry_point: Some("main"),
142 compilation_options: Default::default(),
143 cache: None,
144 });
145
146 let bgl = pipeline.get_bind_group_layout(0);
147
148 let mut params_bytes = [0u8; 20];
150 params_bytes[0..4].copy_from_slice(&(m as u32).to_le_bytes());
151 params_bytes[4..8].copy_from_slice(&(n as u32).to_le_bytes());
152 params_bytes[8..12].copy_from_slice(&(k as u32).to_le_bytes());
153 params_bytes[12..16].copy_from_slice(&(alpha as f32).to_le_bytes());
154 params_bytes[16..20].copy_from_slice(&(beta as f32).to_le_bytes());
155
156 let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
157 label: Some("oxicuda-gemm-f16-params"),
158 size: 20,
159 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
160 mapped_at_creation: false,
161 });
162 dev.queue.write_buffer(&uniform_buf, 0, ¶ms_bytes);
163
164 let bind_group = {
165 let buffers = mem
166 .lock_buffers()
167 .map_err(|e| BackendError::DeviceError(e.to_string()))?;
168 let a_info = buffers
169 .get(&a_ptr)
170 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
171 let b_info = buffers
172 .get(&b_ptr)
173 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
174 let c_info = buffers
175 .get(&c_ptr)
176 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {c_ptr}")))?;
177
178 dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
179 label: Some("oxicuda-gemm-f16"),
180 layout: &bgl,
181 entries: &[
182 wgpu::BindGroupEntry {
183 binding: 0,
184 resource: a_info.buffer.as_entire_binding(),
185 },
186 wgpu::BindGroupEntry {
187 binding: 1,
188 resource: b_info.buffer.as_entire_binding(),
189 },
190 wgpu::BindGroupEntry {
191 binding: 2,
192 resource: c_info.buffer.as_entire_binding(),
193 },
194 wgpu::BindGroupEntry {
195 binding: 3,
196 resource: uniform_buf.as_entire_binding(),
197 },
198 ],
199 })
200 };
201
202 let mut encoder = dev
203 .device
204 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
205 label: Some("oxicuda-gemm-f16"),
206 });
207
208 {
209 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
210 label: Some("oxicuda-gemm-f16"),
211 timestamp_writes: None,
212 });
213 pass.set_pipeline(&pipeline);
214 pass.set_bind_group(0, &bind_group, &[]);
215 let wg_x = (n as u32).div_ceil(tile_size);
216 let wg_y = (m as u32).div_ceil(tile_size);
217 pass.dispatch_workgroups(wg_x, wg_y, 1);
218 }
219
220 dev.queue.submit(std::iter::once(encoder.finish()));
221 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
222
223 Ok(())
224 }
225}
226
227impl Default for WebGpuBackend {
228 fn default() -> Self {
229 Self::new()
230 }
231}
232
233impl ComputeBackend for WebGpuBackend {
    /// Stable identifier used to select this backend.
    fn name(&self) -> &str {
        "webgpu"
    }
239
240 fn init(&mut self) -> BackendResult<()> {
241 if self.initialized {
242 return Ok(());
243 }
244
245 match WebGpuDevice::new() {
246 Ok(dev) => {
247 let dev = Arc::new(dev);
248 tracing::info!("WebGPU backend initialised on: {}", dev.adapter_name);
249 let memory = WebGpuMemoryManager::new(Arc::clone(&dev));
250 self.device = Some(dev);
251 self.memory = Some(Arc::new(memory));
252 self.initialized = true;
253 Ok(())
254 }
255 Err(e) => Err(BackendError::from(e)),
256 }
257 }
258
    /// Reports whether `init()` has completed successfully.
    fn is_initialized(&self) -> bool {
        self.initialized
    }
262
263 fn gemm(
266 &self,
267 trans_a: BackendTranspose,
268 trans_b: BackendTranspose,
269 m: usize,
270 n: usize,
271 k: usize,
272 alpha: f64,
273 a_ptr: u64,
274 _lda: usize,
275 b_ptr: u64,
276 _ldb: usize,
277 beta: f64,
278 c_ptr: u64,
279 _ldc: usize,
280 ) -> BackendResult<()> {
281 self.check_init()?;
282 if m == 0 || n == 0 || k == 0 {
284 return Ok(());
285 }
286
287 if trans_a != BackendTranspose::NoTrans || trans_b != BackendTranspose::NoTrans {
289 return Err(BackendError::Unsupported(
290 "WebGPU GEMM does not yet support transposed inputs".into(),
291 ));
292 }
293
294 let dev = self.device()?;
295 let mem = self.memory()?;
296
297 let tile_size: u32 = 8;
298 let wgsl = shader::gemm_wgsl(tile_size);
299
300 let shader_mod = dev
301 .device
302 .create_shader_module(wgpu::ShaderModuleDescriptor {
303 label: Some("oxicuda-gemm"),
304 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
305 });
306
307 let pipeline = dev
308 .device
309 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
310 label: Some("oxicuda-gemm"),
311 layout: None,
312 module: &shader_mod,
313 entry_point: Some("main"),
314 compilation_options: Default::default(),
315 cache: None,
316 });
317
318 let bgl = pipeline.get_bind_group_layout(0);
319
320 let mut params_bytes = [0u8; 20];
322 params_bytes[0..4].copy_from_slice(&(m as u32).to_le_bytes());
323 params_bytes[4..8].copy_from_slice(&(n as u32).to_le_bytes());
324 params_bytes[8..12].copy_from_slice(&(k as u32).to_le_bytes());
325 params_bytes[12..16].copy_from_slice(&(alpha as f32).to_le_bytes());
326 params_bytes[16..20].copy_from_slice(&(beta as f32).to_le_bytes());
327
328 let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
329 label: Some("oxicuda-gemm-params"),
330 size: 20,
331 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
332 mapped_at_creation: false,
333 });
334 dev.queue.write_buffer(&uniform_buf, 0, ¶ms_bytes);
335
336 let bind_group = {
338 let buffers = mem
339 .lock_buffers()
340 .map_err(|e| BackendError::DeviceError(e.to_string()))?;
341 let a_info = buffers
342 .get(&a_ptr)
343 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
344 let b_info = buffers
345 .get(&b_ptr)
346 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
347 let c_info = buffers
348 .get(&c_ptr)
349 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {c_ptr}")))?;
350
351 dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
352 label: Some("oxicuda-gemm"),
353 layout: &bgl,
354 entries: &[
355 wgpu::BindGroupEntry {
356 binding: 0,
357 resource: a_info.buffer.as_entire_binding(),
358 },
359 wgpu::BindGroupEntry {
360 binding: 1,
361 resource: b_info.buffer.as_entire_binding(),
362 },
363 wgpu::BindGroupEntry {
364 binding: 2,
365 resource: c_info.buffer.as_entire_binding(),
366 },
367 wgpu::BindGroupEntry {
368 binding: 3,
369 resource: uniform_buf.as_entire_binding(),
370 },
371 ],
372 })
373 };
374
375 let mut encoder = dev
376 .device
377 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
378 label: Some("oxicuda-gemm"),
379 });
380
381 {
382 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
383 label: Some("oxicuda-gemm"),
384 timestamp_writes: None,
385 });
386 pass.set_pipeline(&pipeline);
387 pass.set_bind_group(0, &bind_group, &[]);
388 let wg_x = (n as u32).div_ceil(tile_size);
389 let wg_y = (m as u32).div_ceil(tile_size);
390 pass.dispatch_workgroups(wg_x, wg_y, 1);
391 }
392
393 dev.queue.submit(std::iter::once(encoder.finish()));
394 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
395
396 Ok(())
397 }
398
399 #[allow(clippy::too_many_arguments)]
400 fn batched_gemm(
401 &self,
402 trans_a: BackendTranspose,
403 trans_b: BackendTranspose,
404 m: usize,
405 n: usize,
406 k: usize,
407 alpha: f64,
408 a_ptr: u64,
409 _lda: usize,
410 stride_a: usize,
411 b_ptr: u64,
412 _ldb: usize,
413 stride_b: usize,
414 beta: f64,
415 c_ptr: u64,
416 _ldc: usize,
417 stride_c: usize,
418 batch_count: usize,
419 ) -> BackendResult<()> {
420 self.check_init()?;
421
422 if batch_count == 0 || m == 0 || n == 0 || k == 0 {
423 return Ok(());
424 }
425
426 if trans_a != BackendTranspose::NoTrans || trans_b != BackendTranspose::NoTrans {
427 return Err(BackendError::Unsupported(
428 "WebGPU batched GEMM does not yet support transposed inputs".into(),
429 ));
430 }
431
432 let dev = self.device()?;
433 let mem = self.memory()?;
434
435 let tile_size: u32 = 8;
436 let wgsl = shader::batched_gemm_wgsl(tile_size);
437
438 let shader_mod = dev
439 .device
440 .create_shader_module(wgpu::ShaderModuleDescriptor {
441 label: Some("oxicuda-batched-gemm"),
442 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
443 });
444
445 let pipeline = dev
446 .device
447 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
448 label: Some("oxicuda-batched-gemm"),
449 layout: None,
450 module: &shader_mod,
451 entry_point: Some("main"),
452 compilation_options: Default::default(),
453 cache: None,
454 });
455
456 let bgl = pipeline.get_bind_group_layout(0);
457
458 let mut params_bytes = [0u8; 48];
464 params_bytes[0..4].copy_from_slice(&(m as u32).to_le_bytes());
465 params_bytes[4..8].copy_from_slice(&(n as u32).to_le_bytes());
466 params_bytes[8..12].copy_from_slice(&(k as u32).to_le_bytes());
467 params_bytes[12..16].copy_from_slice(&(alpha as f32).to_le_bytes());
468 params_bytes[16..20].copy_from_slice(&(beta as f32).to_le_bytes());
469 params_bytes[20..24].copy_from_slice(&(batch_count as u32).to_le_bytes());
470 params_bytes[24..28].copy_from_slice(&(stride_a as u32).to_le_bytes());
471 params_bytes[28..32].copy_from_slice(&(stride_b as u32).to_le_bytes());
472 params_bytes[32..36].copy_from_slice(&(stride_c as u32).to_le_bytes());
473 let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
476 label: Some("oxicuda-batched-gemm-params"),
477 size: 48,
478 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
479 mapped_at_creation: false,
480 });
481 dev.queue.write_buffer(&uniform_buf, 0, ¶ms_bytes);
482
483 let bind_group = {
484 let buffers = mem
485 .lock_buffers()
486 .map_err(|e| BackendError::DeviceError(e.to_string()))?;
487 let a_info = buffers
488 .get(&a_ptr)
489 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
490 let b_info = buffers
491 .get(&b_ptr)
492 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
493 let c_info = buffers
494 .get(&c_ptr)
495 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {c_ptr}")))?;
496
497 dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
498 label: Some("oxicuda-batched-gemm"),
499 layout: &bgl,
500 entries: &[
501 wgpu::BindGroupEntry {
502 binding: 0,
503 resource: a_info.buffer.as_entire_binding(),
504 },
505 wgpu::BindGroupEntry {
506 binding: 1,
507 resource: b_info.buffer.as_entire_binding(),
508 },
509 wgpu::BindGroupEntry {
510 binding: 2,
511 resource: c_info.buffer.as_entire_binding(),
512 },
513 wgpu::BindGroupEntry {
514 binding: 3,
515 resource: uniform_buf.as_entire_binding(),
516 },
517 ],
518 })
519 };
520
521 let mut encoder = dev
522 .device
523 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
524 label: Some("oxicuda-batched-gemm"),
525 });
526
527 {
528 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
529 label: Some("oxicuda-batched-gemm"),
530 timestamp_writes: None,
531 });
532 pass.set_pipeline(&pipeline);
533 pass.set_bind_group(0, &bind_group, &[]);
534 let wg_x = (n as u32).div_ceil(tile_size);
535 let wg_y = (m as u32).div_ceil(tile_size);
536 pass.dispatch_workgroups(wg_x, wg_y, batch_count as u32);
537 }
538
539 dev.queue.submit(std::iter::once(encoder.finish()));
540 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
541
542 Ok(())
543 }
544
    /// Naive CPU reference implementation of a 2-D convolution (NCHW layout).
    ///
    /// Data is copied device -> host, convolved in f32 on the CPU, and the
    /// result copied back — a correctness fallback, not a GPU kernel. Stride
    /// and padding are honoured; dilation and grouping are not supported.
    ///
    /// NOTE(review): `filter_shape[1]` (the filter's input-channel count) is
    /// never validated against `input_shape[1]` — confirm callers guarantee
    /// they match, otherwise the f_idx computation reads out of step.
    ///
    /// # Errors
    /// - `NotInitialized` if `init()` has not run.
    /// - `InvalidArgument` for mis-sized shape/stride/padding slices.
    fn conv2d_forward(
        &self,
        input_ptr: u64,
        input_shape: &[usize],
        filter_ptr: u64,
        filter_shape: &[usize],
        output_ptr: u64,
        output_shape: &[usize],
        stride: &[usize],
        padding: &[usize],
    ) -> BackendResult<()> {
        self.check_init()?;

        // Shape validation: all tensors are rank-4, stride/padding are pairs.
        if input_shape.len() != 4 {
            return Err(BackendError::InvalidArgument(
                "input_shape must have 4 elements (NCHW)".into(),
            ));
        }
        if filter_shape.len() != 4 {
            return Err(BackendError::InvalidArgument(
                "filter_shape must have 4 elements (KCFHFW)".into(),
            ));
        }
        if output_shape.len() != 4 {
            return Err(BackendError::InvalidArgument(
                "output_shape must have 4 elements (NKOhOw)".into(),
            ));
        }
        if stride.len() != 2 {
            return Err(BackendError::InvalidArgument(
                "stride must have 2 elements [sh, sw]".into(),
            ));
        }
        if padding.len() != 2 {
            return Err(BackendError::InvalidArgument(
                "padding must have 2 elements [ph, pw]".into(),
            ));
        }

        let mem = self.memory()?;

        // Unpack dimensions; output batch/channel counts are taken from the
        // input and filter shapes rather than re-read from output_shape.
        let batch = input_shape[0];
        let c_in = input_shape[1];
        let h_in = input_shape[2];
        let w_in = input_shape[3];
        let k_out = filter_shape[0];
        let fh = filter_shape[2];
        let fw = filter_shape[3];
        let oh = output_shape[2];
        let ow = output_shape[3];
        let sh = stride[0];
        let sw = stride[1];
        let ph = padding[0];
        let pw = padding[1];

        let in_elems: usize = input_shape.iter().product();
        let f_elems: usize = filter_shape.iter().product();
        let o_elems: usize = output_shape.iter().product();

        // Pull input and filter back to the host as raw f32 bytes.
        let mut in_bytes = vec![0u8; in_elems * 4];
        let mut f_bytes = vec![0u8; f_elems * 4];
        mem.copy_from_device(&mut in_bytes, input_ptr)
            .map_err(BackendError::from)?;
        mem.copy_from_device(&mut f_bytes, filter_ptr)
            .map_err(BackendError::from)?;

        let in_f32 = bytes_to_f32_vec(&in_bytes);
        let f_f32 = bytes_to_f32_vec(&f_bytes);
        let mut out_f32 = vec![0.0f32; o_elems];

        // Direct convolution: for each output element, accumulate the
        // filter window over all input channels; out-of-bounds taps
        // (from padding) contribute zero by being skipped.
        for b in 0..batch {
            for kf in 0..k_out {
                for oy in 0..oh {
                    for ox in 0..ow {
                        let mut acc = 0.0f32;
                        for ci in 0..c_in {
                            for fy in 0..fh {
                                for fx in 0..fw {
                                    // Map the output coordinate + filter tap
                                    // back to an input coordinate (may be
                                    // negative in the padded border).
                                    let iy = (oy * sh + fy) as isize - ph as isize;
                                    let ix = (ox * sw + fx) as isize - pw as isize;
                                    if iy >= 0
                                        && (iy as usize) < h_in
                                        && ix >= 0
                                        && (ix as usize) < w_in
                                    {
                                        let in_idx = ((b * c_in + ci) * h_in + iy as usize) * w_in
                                            + ix as usize;
                                        let f_idx = ((kf * c_in + ci) * fh + fy) * fw + fx;
                                        acc += in_f32[in_idx] * f_f32[f_idx];
                                    }
                                }
                            }
                        }
                        out_f32[((b * k_out + kf) * oh + oy) * ow + ox] = acc;
                    }
                }
            }
        }

        // Push the finished result back to the device buffer.
        let out_bytes = f32_slice_to_bytes(&out_f32);
        mem.copy_to_device(output_ptr, &out_bytes)
            .map_err(BackendError::from)?;

        Ok(())
    }
651
    /// Scaled dot-product attention computed on the CPU as a fallback.
    ///
    /// Q is laid out as [batch, heads, seq_q, head_dim] and K/V as
    /// [batch, heads, seq_kv, head_dim], contiguous f32 (layout inferred
    /// from the indexing below). Buffers are copied device -> host, attention
    /// is evaluated per (batch, head) slice, and the result is written back
    /// to `o_ptr`.
    ///
    /// Softmax uses the numerically stable two-pass form: first the row
    /// maximum is found, then exp/sum/weighted-V accumulation. The Q.K dot
    /// products are recomputed in the second pass instead of being stored,
    /// trading FLOPs for O(head_dim) scratch memory per query row.
    ///
    /// # Errors
    /// - `NotInitialized` if `init()` has not run.
    /// - `InvalidArgument` for zero-sized dims or a non-positive /
    ///   non-finite `scale`.
    fn attention(
        &self,
        q_ptr: u64,
        k_ptr: u64,
        v_ptr: u64,
        o_ptr: u64,
        batch: usize,
        heads: usize,
        seq_q: usize,
        seq_kv: usize,
        head_dim: usize,
        scale: f64,
        causal: bool,
    ) -> BackendResult<()> {
        self.check_init()?;

        if seq_q == 0 || seq_kv == 0 || head_dim == 0 {
            return Err(BackendError::InvalidArgument(
                "seq_q, seq_kv, and head_dim must all be > 0".into(),
            ));
        }
        if scale <= 0.0 || !scale.is_finite() {
            return Err(BackendError::InvalidArgument(format!(
                "scale must be a positive finite number, got {scale}"
            )));
        }

        let mem = self.memory()?;

        // Element counts for each buffer (4 bytes per f32 element).
        let batch_heads = batch * heads;
        let q_elems = batch_heads * seq_q * head_dim;
        let kv_elems = batch_heads * seq_kv * head_dim;
        let o_elems = q_elems;

        // Pull Q, K, V back to the host.
        let mut q_bytes = vec![0u8; q_elems * 4];
        let mut k_bytes = vec![0u8; kv_elems * 4];
        let mut v_bytes = vec![0u8; kv_elems * 4];

        mem.copy_from_device(&mut q_bytes, q_ptr)
            .map_err(BackendError::from)?;
        mem.copy_from_device(&mut k_bytes, k_ptr)
            .map_err(BackendError::from)?;
        mem.copy_from_device(&mut v_bytes, v_ptr)
            .map_err(BackendError::from)?;

        let q_f32 = bytes_to_f32_vec(&q_bytes);
        let k_f32 = bytes_to_f32_vec(&k_bytes);
        let v_f32 = bytes_to_f32_vec(&v_bytes);
        let mut o_f32 = vec![0.0f32; o_elems];

        let scale_f32 = scale as f32;

        // Each (batch, head) pair is an independent attention problem.
        for bh in 0..batch_heads {
            let q_off = bh * seq_q * head_dim;
            let k_off = bh * seq_kv * head_dim;
            let v_off = k_off;

            for sq in 0..seq_q {
                // Causal mask: query sq may only attend to keys 0..=sq.
                let kv_limit = if causal { (sq + 1).min(seq_kv) } else { seq_kv };

                // Pass 1: row maximum of the scaled scores, for a
                // numerically stable softmax.
                let mut max_score = f32::NEG_INFINITY;
                for sk in 0..kv_limit {
                    let mut dot = 0.0f32;
                    for dd in 0..head_dim {
                        dot +=
                            q_f32[q_off + sq * head_dim + dd] * k_f32[k_off + sk * head_dim + dd];
                    }
                    let s = dot * scale_f32;
                    if s > max_score {
                        max_score = s;
                    }
                }

                // Pass 2: recompute each score, exponentiate relative to the
                // max, and accumulate the softmax denominator plus the
                // weighted sum of V rows.
                let mut sum_exp = 0.0f32;
                let mut acc = vec![0.0f32; head_dim];
                for sk in 0..kv_limit {
                    let mut dot = 0.0f32;
                    for dd in 0..head_dim {
                        dot +=
                            q_f32[q_off + sq * head_dim + dd] * k_f32[k_off + sk * head_dim + dd];
                    }
                    let w = (dot * scale_f32 - max_score).exp();
                    sum_exp += w;
                    for dd in 0..head_dim {
                        acc[dd] += w * v_f32[v_off + sk * head_dim + dd];
                    }
                }

                // Normalise; if every weight underflowed to zero the output
                // row is left as zeros rather than dividing by zero.
                let o_base = q_off + sq * head_dim;
                if sum_exp > 0.0 {
                    for dd in 0..head_dim {
                        o_f32[o_base + dd] = acc[dd] / sum_exp;
                    }
                }
            }
        }

        let o_bytes = f32_slice_to_bytes(&o_f32);
        mem.copy_to_device(o_ptr, &o_bytes)
            .map_err(BackendError::from)?;

        Ok(())
    }
759
    /// Reduces a 1-D f32 buffer to a single scalar on the GPU.
    ///
    /// Two-pass scheme: pass 1 reduces each 256-element workgroup into one
    /// partial; pass 2 folds the partials into a single value written to
    /// `output_ptr`. Only 1-D shapes are supported, so `axis` (already
    /// range-checked) is effectively always 0. Blocks until the GPU work
    /// completes.
    ///
    /// NOTE(review): the Mean case divides by `n_elements` on the host after
    /// the GPU passes — this assumes the "mean" shader variant accumulates
    /// like "sum"; confirm against `shader::reduction_wgsl`.
    ///
    /// # Errors
    /// - `NotInitialized` if `init()` has not run.
    /// - `InvalidArgument` for an empty shape, out-of-bounds axis, or an
    ///   unknown buffer handle.
    /// - `Unsupported` for shapes with more than one dimension.
    fn reduce(
        &self,
        op: ReduceOp,
        input_ptr: u64,
        output_ptr: u64,
        shape: &[usize],
        axis: usize,
    ) -> BackendResult<()> {
        self.check_init()?;

        if shape.is_empty() {
            return Err(BackendError::InvalidArgument(
                "shape must not be empty".into(),
            ));
        }
        if axis >= shape.len() {
            return Err(BackendError::InvalidArgument(format!(
                "axis {axis} is out of bounds for shape of length {}",
                shape.len()
            )));
        }

        if shape.len() != 1 {
            return Err(BackendError::Unsupported(
                "WebGPU reduce currently supports only 1-D shapes".into(),
            ));
        }

        let n_elements = shape[0];
        // Zero-length input: leave the output buffer untouched.
        if n_elements == 0 {
            return Ok(());
        }

        let dev = self.device()?;
        let mem = self.memory()?;
        let op_str = map_reduce_op(op);

        // One pass-1 workgroup (256 lanes) per 256 input elements.
        let wg_count = (n_elements as u32).div_ceil(256);

        // --- Pass 1: per-workgroup partial reduction ---
        let pass1_wgsl = shader::reduction_wgsl(op_str);
        let pass1_shader = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-reduce-pass1"),
                source: wgpu::ShaderSource::Wgsl(pass1_wgsl.into()),
            });
        let pass1_pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-reduce-pass1"),
                layout: None,
                module: &pass1_shader,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        // Scratch buffer holding one f32 partial per pass-1 workgroup.
        let partial_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-reduce-partial"),
            size: (wg_count as u64) * 4,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        // Pass-1 uniform: the total element count.
        let mut p1_params = [0u8; 4];
        p1_params[0..4].copy_from_slice(&(n_elements as u32).to_le_bytes());
        let p1_uniform = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-reduce-p1-params"),
            size: 4,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&p1_uniform, 0, &p1_params);

        let bgl1 = pass1_pipeline.get_bind_group_layout(0);

        let bg1 = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let in_info = buffers.get(&input_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {input_ptr}"))
            })?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-reduce-pass1"),
                layout: &bgl1,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: in_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: partial_buf.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: p1_uniform.as_entire_binding(),
                    },
                ],
            })
        };

        // --- Pass 2: fold the partials into the final scalar ---
        let pass2_wgsl = shader::reduction_final_wgsl(op_str);
        let pass2_shader = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-reduce-pass2"),
                source: wgpu::ShaderSource::Wgsl(pass2_wgsl.into()),
            });
        let pass2_pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-reduce-pass2"),
                layout: None,
                module: &pass2_shader,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        // Pass-2 uniform: the number of partials produced by pass 1.
        let mut p2_params = [0u8; 4];
        p2_params[0..4].copy_from_slice(&wg_count.to_le_bytes());
        let p2_uniform = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-reduce-p2-params"),
            size: 4,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&p2_uniform, 0, &p2_params);

        let bgl2 = pass2_pipeline.get_bind_group_layout(0);

        let bg2 = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
            })?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-reduce-pass2"),
                layout: &bgl2,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: partial_buf.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: out_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: p2_uniform.as_entire_binding(),
                    },
                ],
            })
        };

        // Both passes are recorded into one command buffer; wgpu inserts the
        // barrier between the passes.
        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-reduce"),
            });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-reduce-pass1"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pass1_pipeline);
            pass.set_bind_group(0, &bg1, &[]);
            pass.dispatch_workgroups(wg_count, 1, 1);
        }
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-reduce-pass2"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pass2_pipeline);
            pass.set_bind_group(0, &bg2, &[]);
            // A single workgroup folds all partials.
            pass.dispatch_workgroups(1, 1, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        // Mean post-processing on the host: read the reduced value back and
        // divide by the element count (n_elements == 1 is skipped since the
        // division would be a no-op).
        if op == ReduceOp::Mean && n_elements > 1 {
            let mut buf = [0u8; 4];
            mem.copy_from_device(&mut buf, output_ptr)
                .map_err(BackendError::from)?;
            let val = f32::from_le_bytes(buf);
            let mean = val / (n_elements as f32);
            mem.copy_to_device(output_ptr, &mean.to_le_bytes())
                .map_err(BackendError::from)?;
        }

        Ok(())
    }
974
    /// Applies an elementwise unary op to `n` f32 elements on the GPU,
    /// reading from `input_ptr` and writing to `output_ptr`.
    ///
    /// Dispatches one thread per element in workgroups of 256 and blocks
    /// until completion.
    ///
    /// # Errors
    /// - `NotInitialized` if `init()` has not run.
    /// - `InvalidArgument` for unknown buffer handles.
    /// - `DeviceError` if the buffer table lock is poisoned.
    fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()> {
        self.check_init()?;
        // Nothing to do for an empty tensor.
        if n == 0 {
            return Ok(());
        }

        let dev = self.device()?;
        let mem = self.memory()?;

        // Generate and compile the op-specific WGSL.
        let op_str = map_unary_op(op);
        let wgsl = shader::elementwise_wgsl(op_str);

        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-unary"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });

        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-unary"),
                // `layout: None` derives the bind group layout from the shader.
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        let bgl = pipeline.get_bind_group_layout(0);

        // Resolve handles and build the bind group under the buffer lock.
        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let in_info = buffers.get(&input_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {input_ptr}"))
            })?;
            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
            })?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-unary"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: in_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: out_info.buffer.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-unary"),
            });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-unary"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            // 256 threads per workgroup, one element per thread.
            let workgroups = (n as u32).div_ceil(256);
            pass.dispatch_workgroups(workgroups, 1, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        // Best-effort blocking wait; poll errors are intentionally ignored.
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }
1056
    /// Applies an elementwise binary op over `n` f32 elements on the GPU:
    /// `out[i] = op(a[i], b[i])`.
    ///
    /// Dispatches one thread per element in workgroups of 256 and blocks
    /// until completion.
    ///
    /// # Errors
    /// - `NotInitialized` if `init()` has not run.
    /// - `InvalidArgument` for unknown buffer handles.
    /// - `DeviceError` if the buffer table lock is poisoned.
    fn binary(
        &self,
        op: BinaryOp,
        a_ptr: u64,
        b_ptr: u64,
        output_ptr: u64,
        n: usize,
    ) -> BackendResult<()> {
        self.check_init()?;
        // Nothing to do for an empty tensor.
        if n == 0 {
            return Ok(());
        }

        let dev = self.device()?;
        let mem = self.memory()?;

        // Generate and compile the op-specific WGSL.
        let op_str = map_binary_op(op);
        let wgsl = shader::binary_wgsl(op_str);

        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-binary"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });

        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-binary"),
                // `layout: None` derives the bind group layout from the shader.
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        let bgl = pipeline.get_bind_group_layout(0);

        // Resolve handles and build the bind group under the buffer lock.
        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let a_info = buffers
                .get(&a_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
            let b_info = buffers
                .get(&b_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
            })?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-binary"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: a_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: b_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: out_info.buffer.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-binary"),
            });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-binary"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            // 256 threads per workgroup, one element per thread.
            let workgroups = (n as u32).div_ceil(256);
            pass.dispatch_workgroups(workgroups, 1, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        // Best-effort blocking wait; poll errors are intentionally ignored.
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }
1152
1153 fn synchronize(&self) -> BackendResult<()> {
1156 self.check_init()?;
1157 if let Some(dev) = &self.device {
1158 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
1159 }
1160 Ok(())
1161 }
1162
1163 fn alloc(&self, bytes: usize) -> BackendResult<u64> {
1166 self.check_init()?;
1167 if bytes == 0 {
1168 return Err(BackendError::InvalidArgument(
1169 "cannot allocate 0 bytes".into(),
1170 ));
1171 }
1172 self.memory()?.alloc(bytes).map_err(BackendError::from)
1173 }
1174
1175 fn free(&self, ptr: u64) -> BackendResult<()> {
1176 self.check_init()?;
1177 self.memory()?.free(ptr).map_err(BackendError::from)
1178 }
1179
1180 fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()> {
1181 self.check_init()?;
1182 if src.is_empty() {
1183 return Ok(());
1184 }
1185 self.memory()?
1186 .copy_to_device(dst, src)
1187 .map_err(BackendError::from)
1188 }
1189
1190 fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()> {
1191 self.check_init()?;
1192 if dst.is_empty() {
1193 return Ok(());
1194 }
1195 self.memory()?
1196 .copy_from_device(dst, src)
1197 .map_err(BackendError::from)
1198 }
1199}
1200
/// Reinterprets a little-endian byte slice as a vector of f32 values.
/// Trailing bytes that do not form a complete 4-byte group are ignored.
fn bytes_to_f32_vec(bytes: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(bytes.len() / 4);
    for chunk in bytes.chunks_exact(4) {
        let raw: [u8; 4] = chunk.try_into().expect("chunks_exact yields 4-byte chunks");
        values.push(f32::from_le_bytes(raw));
    }
    values
}
1210
/// Serialises a slice of f32 values into contiguous little-endian bytes
/// (4 bytes per value, in order).
fn f32_slice_to_bytes(data: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(data.len() * 4);
    for value in data {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}
1215
// Unit tests. Tests that need a live GPU go through `try_init` and silently
// skip when no WebGPU adapter is available, so the suite still passes on
// headless CI machines.
#[cfg(test)]
mod tests {
    use super::*;
    use oxicuda_backend::{BackendTranspose, BinaryOp, ReduceOp, UnaryOp};

    // --- Construction / trait-surface tests (no GPU required) ---

    #[test]
    fn webgpu_backend_new_uninitialized() {
        let b = WebGpuBackend::new();
        assert!(!b.is_initialized());
    }

    #[test]
    fn webgpu_backend_name() {
        let b = WebGpuBackend::new();
        assert_eq!(b.name(), "webgpu");
    }

    #[test]
    fn webgpu_backend_default() {
        let b = WebGpuBackend::default();
        assert!(!b.is_initialized());
        assert_eq!(b.name(), "webgpu");
    }

    #[test]
    fn backend_debug_impl() {
        let b = WebGpuBackend::new();
        let s = format!("{b:?}");
        assert!(s.contains("WebGpuBackend"));
    }

    // The backend must remain usable behind a `dyn ComputeBackend` object.
    #[test]
    fn backend_object_safe() {
        let b: Box<dyn ComputeBackend> = Box::new(WebGpuBackend::new());
        assert_eq!(b.name(), "webgpu");
    }

    // --- Every operation on an uninitialized backend must return
    // NotInitialized rather than panic. ---

    #[test]
    fn backend_not_initialized_gemm() {
        let b = WebGpuBackend::new();
        let result = b.gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            4,
            4,
            4,
            1.0,
            0,
            4,
            0,
            4,
            0.0,
            0,
            4,
        );
        assert_eq!(result, Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_alloc() {
        let b = WebGpuBackend::new();
        let result = b.alloc(1024);
        assert_eq!(result, Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_synchronize() {
        let b = WebGpuBackend::new();
        assert_eq!(b.synchronize(), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_free() {
        let b = WebGpuBackend::new();
        assert_eq!(b.free(1), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_copy_htod() {
        let b = WebGpuBackend::new();
        assert_eq!(b.copy_htod(1, b"hello"), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_copy_dtoh() {
        let b = WebGpuBackend::new();
        let mut buf = [0u8; 4];
        assert_eq!(b.copy_dtoh(&mut buf, 1), Err(BackendError::NotInitialized));
    }

    /// Attempts to initialize a backend, returning `None` when `init` fails
    /// (e.g. no WebGPU adapter present) so GPU-dependent tests can skip
    /// instead of failing.
    fn try_init() -> Option<WebGpuBackend> {
        let mut b = WebGpuBackend::new();
        match b.init() {
            Ok(()) => Some(b),
            Err(_) => None,
        }
    }

    // --- Degenerate-size no-op and argument-validation tests ---

    #[test]
    fn gemm_zero_size_after_init() {
        let Some(b) = try_init() else {
            return;
        };
        // m = n = k = 0 must be accepted as a no-op.
        let result = b.gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            0,
            0,
            0,
            1.0,
            0,
            1,
            0,
            1,
            0.0,
            0,
            1,
        );
        assert_eq!(result, Ok(()));
    }

    #[test]
    fn unary_zero_elements_after_init() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.unary(UnaryOp::Relu, 0, 0, 0), Ok(()));
    }

    #[test]
    fn binary_zero_elements_after_init() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.binary(BinaryOp::Add, 0, 0, 0, 0), Ok(()));
    }

    #[test]
    fn copy_htod_empty_noop() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.copy_htod(0, &[]), Ok(()));
    }

    #[test]
    fn copy_dtoh_empty_noop() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.copy_dtoh(&mut [], 0), Ok(()));
    }

    #[test]
    fn alloc_zero_bytes_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.alloc(0),
            Err(BackendError::InvalidArgument(
                "cannot allocate 0 bytes".into()
            ))
        );
    }

    #[test]
    fn synchronize_after_init() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.synchronize(), Ok(()));
    }

    #[test]
    fn reduce_empty_shape_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.reduce(ReduceOp::Sum, 0, 0, &[], 0),
            Err(BackendError::InvalidArgument(
                "shape must not be empty".into()
            ))
        );
    }

    #[test]
    fn reduce_axis_out_of_bounds_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.reduce(ReduceOp::Sum, 0, 0, &[4, 4], 5),
            Err(BackendError::InvalidArgument(
                "axis 5 is out of bounds for shape of length 2".into()
            ))
        );
    }

    #[test]
    fn attention_zero_seq_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.attention(0, 0, 0, 0, 1, 1, 0, 8, 64, 0.125, false),
            Err(BackendError::InvalidArgument(
                "seq_q, seq_kv, and head_dim must all be > 0".into()
            ))
        );
    }

    #[test]
    fn attention_nonpositive_scale_error() {
        let Some(b) = try_init() else {
            return;
        };
        // Zero, negative, and non-finite scales must all be rejected.
        assert_eq!(
            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, 0.0, false),
            Err(BackendError::InvalidArgument(
                "scale must be a positive finite number, got 0".into()
            ))
        );
        assert_eq!(
            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, -1.0, false),
            Err(BackendError::InvalidArgument(
                "scale must be a positive finite number, got -1".into()
            ))
        );
        assert!(
            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, f64::INFINITY, false)
                .is_err()
        );
    }

    #[test]
    fn conv2d_wrong_input_shape_error() {
        let Some(b) = try_init() else {
            return;
        };
        // Input shape has only 3 dims; NCHW requires 4.
        assert_eq!(
            b.conv2d_forward(
                0,
                &[1, 3, 32],
                0,
                &[16, 3, 3, 3],
                0,
                &[1, 16, 30, 30],
                &[1, 1],
                &[0, 0]
            ),
            Err(BackendError::InvalidArgument(
                "input_shape must have 4 elements (NCHW)".into()
            ))
        );
    }

    #[test]
    fn conv2d_wrong_filter_shape_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.conv2d_forward(
                0,
                &[1, 3, 32, 32],
                0,
                &[16, 3, 3],
                0,
                &[1, 16, 30, 30],
                &[1, 1],
                &[0, 0]
            ),
            Err(BackendError::InvalidArgument(
                "filter_shape must have 4 elements (KCFHFW)".into()
            ))
        );
    }

    #[test]
    fn conv2d_wrong_stride_shape_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.conv2d_forward(
                0,
                &[1, 3, 32, 32],
                0,
                &[16, 3, 3, 3],
                0,
                &[1, 16, 30, 30],
                &[1],
                &[0, 0],
            ),
            Err(BackendError::InvalidArgument(
                "stride must have 2 elements [sh, sw]".into()
            ))
        );
    }

    // A second `init` on an already-initialized backend must succeed.
    #[test]
    fn init_idempotent() {
        let Some(mut b) = try_init() else {
            return;
        };
        assert_eq!(b.init(), Ok(()));
        assert!(b.is_initialized());
    }

    // Init may fail on machines without a GPU; it must not panic either way.
    #[test]
    fn webgpu_init_graceful_failure() {
        let mut b = WebGpuBackend::new();
        let _result = b.init();
    }

    /// Uploads an f32 slice to a freshly allocated device buffer and returns
    /// the buffer handle. Panics on allocation/copy failure (test helper).
    fn upload_f32(b: &WebGpuBackend, data: &[f32]) -> u64 {
        let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
        let h = b.alloc(bytes.len()).expect("alloc");
        b.copy_htod(h, &bytes).expect("copy_htod");
        h
    }

    /// Downloads `n` f32 values from the device buffer at handle `h`.
    fn download_f32(b: &WebGpuBackend, h: u64, n: usize) -> Vec<f32> {
        let mut bytes = vec![0u8; n * 4];
        b.copy_dtoh(&mut bytes, h).expect("copy_dtoh");
        bytes
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect()
    }

    // --- Small numeric correctness tests (skipped without a GPU) ---

    #[test]
    fn unary_neg_small() {
        let Some(b) = try_init() else { return };
        let input = [1.0f32, -2.0, 3.0, 0.0];
        let in_h = upload_f32(&b, &input);
        let out_h = b.alloc(input.len() * 4).expect("alloc output");

        b.unary(UnaryOp::Neg, in_h, out_h, input.len())
            .expect("unary neg");

        let result = download_f32(&b, out_h, input.len());
        let expected = [-1.0f32, 2.0, -3.0, 0.0];
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6, "got {r}, expected {e}");
        }

        b.free(in_h).expect("free");
        b.free(out_h).expect("free");
    }

    #[test]
    fn unary_abs_small() {
        let Some(b) = try_init() else { return };
        let input = [-3.0f32, 4.0, -5.0, 0.0];
        let in_h = upload_f32(&b, &input);
        let out_h = b.alloc(input.len() * 4).expect("alloc output");

        b.unary(UnaryOp::Abs, in_h, out_h, input.len())
            .expect("unary abs");

        let result = download_f32(&b, out_h, input.len());
        let expected = [3.0f32, 4.0, 5.0, 0.0];
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6, "got {r}, expected {e}");
        }

        b.free(in_h).expect("free");
        b.free(out_h).expect("free");
    }

    #[test]
    fn binary_add_small() {
        let Some(b) = try_init() else { return };
        let a = [1.0f32, 2.0, 3.0, 4.0];
        let bv = [10.0f32, 20.0, 30.0, 40.0];
        let a_h = upload_f32(&b, &a);
        let b_h = upload_f32(&b, &bv);
        let out_h = b.alloc(a.len() * 4).expect("alloc output");

        b.binary(BinaryOp::Add, a_h, b_h, out_h, a.len())
            .expect("binary add");

        let result = download_f32(&b, out_h, a.len());
        let expected = [11.0f32, 22.0, 33.0, 44.0];
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6, "got {r}, expected {e}");
        }

        b.free(a_h).expect("free");
        b.free(b_h).expect("free");
        b.free(out_h).expect("free");
    }

    #[test]
    fn binary_mul_small() {
        let Some(b) = try_init() else { return };
        let a = [2.0f32, 3.0, 4.0, 5.0];
        let bv = [10.0f32, 10.0, 10.0, 10.0];
        let a_h = upload_f32(&b, &a);
        let b_h = upload_f32(&b, &bv);
        let out_h = b.alloc(a.len() * 4).expect("alloc output");

        b.binary(BinaryOp::Mul, a_h, b_h, out_h, a.len())
            .expect("binary mul");

        let result = download_f32(&b, out_h, a.len());
        let expected = [20.0f32, 30.0, 40.0, 50.0];
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6, "got {r}, expected {e}");
        }

        b.free(a_h).expect("free");
        b.free(b_h).expect("free");
        b.free(out_h).expect("free");
    }

    #[test]
    fn reduce_sum_small() {
        let Some(b) = try_init() else { return };
        let input = [1.0f32, 2.0, 3.0, 4.0];
        let in_h = upload_f32(&b, &input);
        // 4 bytes = one f32 scalar result.
        let out_h = b.alloc(4).expect("alloc output");
        b.reduce(ReduceOp::Sum, in_h, out_h, &[4], 0)
            .expect("reduce sum");

        let result = download_f32(&b, out_h, 1);
        assert!(
            (result[0] - 10.0).abs() < 1e-5,
            "expected 10.0, got {}",
            result[0]
        );

        b.free(in_h).expect("free");
        b.free(out_h).expect("free");
    }

    #[test]
    fn reduce_max_small() {
        let Some(b) = try_init() else { return };
        let input = [1.0f32, 5.0, 3.0, 2.0];
        let in_h = upload_f32(&b, &input);
        let out_h = b.alloc(4).expect("alloc output");

        b.reduce(ReduceOp::Max, in_h, out_h, &[4], 0)
            .expect("reduce max");

        let result = download_f32(&b, out_h, 1);
        assert!(
            (result[0] - 5.0).abs() < 1e-5,
            "expected 5.0, got {}",
            result[0]
        );

        b.free(in_h).expect("free");
        b.free(out_h).expect("free");
    }

    #[test]
    fn reduce_mean_small() {
        let Some(b) = try_init() else { return };
        let input = [2.0f32, 4.0, 6.0, 8.0];
        let in_h = upload_f32(&b, &input);
        let out_h = b.alloc(4).expect("alloc output");

        b.reduce(ReduceOp::Mean, in_h, out_h, &[4], 0)
            .expect("reduce mean");

        let result = download_f32(&b, out_h, 1);
        assert!(
            (result[0] - 5.0).abs() < 1e-5,
            "expected 5.0, got {}",
            result[0]
        );

        b.free(in_h).expect("free");
        b.free(out_h).expect("free");
    }

    // A * I = A for a 2x2 matrix.
    #[test]
    fn gemm_identity_2x2() {
        let Some(b) = try_init() else { return };
        let a = [1.0f32, 2.0, 3.0, 4.0];
        let eye = [1.0f32, 0.0, 0.0, 1.0];
        let c_init = [0.0f32; 4];

        let a_h = upload_f32(&b, &a);
        let b_h = upload_f32(&b, &eye);
        let c_h = upload_f32(&b, &c_init);

        b.gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            2,
            2,
            2,
            1.0,
            a_h,
            2,
            b_h,
            2,
            0.0,
            c_h,
            2,
        )
        .expect("gemm");

        let result = download_f32(&b, c_h, 4);
        for (r, e) in result.iter().zip(a.iter()) {
            assert!((r - e).abs() < 1e-5, "got {r}, expected {e}");
        }

        b.free(a_h).expect("free");
        b.free(b_h).expect("free");
        b.free(c_h).expect("free");
    }

    // [[1,2,3],[4,5,6]] x [[7,8],[9,10],[11,12]] = [[58,64],[139,154]].
    #[test]
    fn gemm_2x3_times_3x2() {
        let Some(b) = try_init() else { return };
        let a = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let bm = [7.0f32, 8.0, 9.0, 10.0, 11.0, 12.0];
        let c_init = [0.0f32; 4];

        let a_h = upload_f32(&b, &a);
        let b_h = upload_f32(&b, &bm);
        let c_h = upload_f32(&b, &c_init);

        b.gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            2,
            2,
            3,
            1.0,
            a_h,
            3,
            b_h,
            2,
            0.0,
            c_h,
            2,
        )
        .expect("gemm");

        let result = download_f32(&b, c_h, 4);
        let expected = [58.0f32, 64.0, 139.0, 154.0];
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-4, "got {r}, expected {e}");
        }

        b.free(a_h).expect("free");
        b.free(b_h).expect("free");
        b.free(c_h).expect("free");
    }

    // C = 2*(I*I) + 3*ones => diagonal 2+3=5, off-diagonal 0+3=3.
    #[test]
    fn gemm_alpha_beta() {
        let Some(b) = try_init() else { return };
        let a = [1.0f32, 0.0, 0.0, 1.0];
        let bm = [1.0f32, 0.0, 0.0, 1.0];
        let c_init = [1.0f32, 1.0, 1.0, 1.0];

        let a_h = upload_f32(&b, &a);
        let b_h = upload_f32(&b, &bm);
        let c_h = upload_f32(&b, &c_init);

        b.gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            2,
            2,
            2,
            2.0,
            a_h,
            2,
            b_h,
            2,
            3.0,
            c_h,
            2,
        )
        .expect("gemm alpha+beta");

        let result = download_f32(&b, c_h, 4);
        let expected = [5.0f32, 3.0, 3.0, 5.0];
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-4, "got {r}, expected {e}");
        }

        b.free(a_h).expect("free");
        b.free(b_h).expect("free");
        b.free(c_h).expect("free");
    }

    // A 1x1 filter with weight 2 simply doubles every input element.
    #[test]
    fn conv2d_identity_1x1() {
        let Some(b) = try_init() else { return };
        let input: Vec<f32> = (1..=9).map(|x| x as f32).collect();
        let filter = [2.0f32];
        let expected: Vec<f32> = input.iter().map(|x| x * 2.0).collect();

        let in_h = upload_f32(&b, &input);
        let f_h = upload_f32(&b, &filter);
        let out_h = b.alloc(9 * 4).expect("alloc output");

        b.conv2d_forward(
            in_h,
            &[1, 1, 3, 3],
            f_h,
            &[1, 1, 1, 1],
            out_h,
            &[1, 1, 3, 3],
            &[1, 1],
            &[0, 0],
        )
        .expect("conv2d");

        let result = download_f32(&b, out_h, 9);
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-5, "got {r}, expected {e}");
        }

        b.free(in_h).expect("free");
        b.free(f_h).expect("free");
        b.free(out_h).expect("free");
    }

    // 4x4 ramp input, 3x3 all-ones filter, no padding: out[0][0] is the sum
    // of the top-left 3x3 window (45), out[0][1] the next window (54).
    #[test]
    fn conv2d_3x3_no_padding() {
        let Some(b) = try_init() else { return };
        let input: Vec<f32> = (0..16).map(|x| x as f32).collect();
        let filter = [1.0f32; 9];

        let in_h = upload_f32(&b, &input);
        let f_h = upload_f32(&b, &filter);
        let out_h = b.alloc(4 * 4).expect("alloc output");

        b.conv2d_forward(
            in_h,
            &[1, 1, 4, 4],
            f_h,
            &[1, 1, 3, 3],
            out_h,
            &[1, 1, 2, 2],
            &[1, 1],
            &[0, 0],
        )
        .expect("conv2d");

        let result = download_f32(&b, out_h, 4);
        assert!((result[0] - 45.0).abs() < 1e-4, "got {}", result[0]);
        assert!((result[1] - 54.0).abs() < 1e-4, "got {}", result[1]);

        b.free(in_h).expect("free");
        b.free(f_h).expect("free");
        b.free(out_h).expect("free");
    }

    // 2x2 input with 3x3 ones filter and padding 1: every window covers the
    // whole input, so each output element is 1+2+3+4 = 10.
    #[test]
    fn conv2d_with_padding() {
        let Some(b) = try_init() else { return };
        let input = [1.0f32, 2.0, 3.0, 4.0];
        let filter = [1.0f32; 9];

        let in_h = upload_f32(&b, &input);
        let f_h = upload_f32(&b, &filter);
        let out_h = b.alloc(4 * 4).expect("alloc output");

        b.conv2d_forward(
            in_h,
            &[1, 1, 2, 2],
            f_h,
            &[1, 1, 3, 3],
            out_h,
            &[1, 1, 2, 2],
            &[1, 1],
            &[1, 1],
        )
        .expect("conv2d");

        let result = download_f32(&b, out_h, 4);
        assert!((result[0] - 10.0).abs() < 1e-4, "got {}", result[0]);

        b.free(in_h).expect("free");
        b.free(f_h).expect("free");
        b.free(out_h).expect("free");
    }

    // Both keys score identically against the query, so softmax weights are
    // uniform and the output is the average of the value rows: [2, 3].
    #[test]
    fn attention_uniform_weights() {
        let Some(b) = try_init() else { return };

        let q = [1.0f32, 0.0];
        let k = [1.0f32, 0.0, 1.0, 0.0];
        let v = [1.0f32, 2.0, 3.0, 4.0];

        let q_h = upload_f32(&b, &q);
        let k_h = upload_f32(&b, &k);
        let v_h = upload_f32(&b, &v);
        let o_h = b.alloc(2 * 4).expect("alloc output");

        b.attention(q_h, k_h, v_h, o_h, 1, 1, 1, 2, 2, 1.0, false)
            .expect("attention");

        let result = download_f32(&b, o_h, 2);
        assert!(
            (result[0] - 2.0).abs() < 1e-4,
            "got {}, expected 2.0",
            result[0]
        );
        assert!(
            (result[1] - 3.0).abs() < 1e-4,
            "got {}, expected 3.0",
            result[1]
        );

        b.free(q_h).expect("free");
        b.free(k_h).expect("free");
        b.free(v_h).expect("free");
        b.free(o_h).expect("free");
    }

    // With a causal mask, token 0 attends only to itself (output 10), while
    // token 1 attends to both tokens with equal scores (output (10+20)/2).
    #[test]
    fn attention_causal_single_token() {
        let Some(b) = try_init() else { return };

        let q = [1.0f32, 1.0];
        let k = [1.0f32, 1.0];
        let v = [10.0f32, 20.0];

        let q_h = upload_f32(&b, &q);
        let k_h = upload_f32(&b, &k);
        let v_h = upload_f32(&b, &v);
        let o_h = b.alloc(2 * 4).expect("alloc output");

        b.attention(q_h, k_h, v_h, o_h, 1, 1, 2, 2, 1, 1.0, true)
            .expect("attention causal");

        let result = download_f32(&b, o_h, 2);
        assert!(
            (result[0] - 10.0).abs() < 1e-4,
            "got {}, expected 10.0",
            result[0]
        );
        assert!(
            (result[1] - 15.0).abs() < 1e-4,
            "got {}, expected 15.0",
            result[1]
        );

        b.free(q_h).expect("free");
        b.free(k_h).expect("free");
        b.free(v_h).expect("free");
        b.free(o_h).expect("free");
    }

    // --- batched_gemm / gemm_f16 ---

    #[test]
    fn batched_gemm_not_initialized() {
        let b = WebGpuBackend::new();
        let result = b.batched_gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            4,
            4,
            4,
            1.0,
            0,
            4,
            16,
            0,
            4,
            16,
            0.0,
            0,
            4,
            16,
            2,
        );
        assert_eq!(result, Err(BackendError::NotInitialized));
    }

    #[test]
    fn batched_gemm_zero_batch_noop() {
        let Some(b) = try_init() else { return };
        let result = b.batched_gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            4,
            4,
            4,
            1.0,
            0,
            4,
            16,
            0,
            4,
            16,
            0.0,
            0,
            4,
            16,
            0, // zero batches: must be a no-op
        );
        assert_eq!(result, Ok(()));
    }

    // A zero in any of m, n, k must be a no-op.
    #[test]
    fn batched_gemm_zero_dims_noop() {
        let Some(b) = try_init() else { return };
        let result = b.batched_gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            0,
            4,
            4,
            1.0,
            0,
            4,
            16,
            0,
            4,
            16,
            0.0,
            0,
            4,
            16,
            2,
        );
        assert_eq!(result, Ok(()));
        let result = b.batched_gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            4,
            0,
            4,
            1.0,
            0,
            4,
            16,
            0,
            4,
            16,
            0.0,
            0,
            4,
            16,
            2,
        );
        assert_eq!(result, Ok(()));
        let result = b.batched_gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            4,
            4,
            0,
            1.0,
            0,
            4,
            16,
            0,
            4,
            16,
            0.0,
            0,
            4,
            16,
            2,
        );
        assert_eq!(result, Ok(()));
    }

    // Two batches of A * I = A; the per-batch stride of 4 spans one 2x2
    // matrix (NOTE(review): stride appears to be in elements — confirm).
    #[test]
    fn batched_gemm_identity_2x2() {
        let Some(b) = try_init() else { return };
        let a = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let eye = [1.0f32, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
        let c_init = [0.0f32; 8];

        let a_h = upload_f32(&b, &a);
        let b_h = upload_f32(&b, &eye);
        let c_h = upload_f32(&b, &c_init);

        b.batched_gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            2,
            2,
            2,
            1.0,
            a_h,
            2,
            4,
            b_h,
            2,
            4,
            0.0,
            c_h,
            2,
            4,
            2,
        )
        .expect("batched_gemm");

        let result = download_f32(&b, c_h, 8);
        for (r, e) in result.iter().zip(a.iter()) {
            assert!((r - e).abs() < 1e-5, "got {r}, expected {e}");
        }

        b.free(a_h).expect("free");
        b.free(b_h).expect("free");
        b.free(c_h).expect("free");
    }

    #[test]
    fn gemm_f16_not_initialized() {
        let b = WebGpuBackend::new();
        let result = b.gemm_f16(4, 4, 4, 1.0, 0, 0, 0.0, 0);
        assert_eq!(result, Err(BackendError::NotInitialized));
    }

    #[test]
    fn gemm_f16_zero_dims_noop() {
        let Some(b) = try_init() else { return };
        assert_eq!(b.gemm_f16(0, 4, 4, 1.0, 0, 0, 0.0, 0), Ok(()));
        assert_eq!(b.gemm_f16(4, 0, 4, 1.0, 0, 0, 0.0, 0), Ok(()));
        assert_eq!(b.gemm_f16(4, 4, 0, 1.0, 0, 0, 0.0, 0), Ok(()));
    }

    // Key 0 scores 10 vs 0 for key 1, so softmax puts almost all weight on
    // value row 0 and the output is approximately [100, 200].
    #[test]
    fn attention_dominant_key() {
        let Some(b) = try_init() else { return };

        let q = [1.0f32, 0.0];
        let k = [10.0f32, 0.0, 0.0, 0.0];
        let v = [100.0f32, 200.0, 0.0, 0.0];

        let q_h = upload_f32(&b, &q);
        let k_h = upload_f32(&b, &k);
        let v_h = upload_f32(&b, &v);
        let o_h = b.alloc(2 * 4).expect("alloc output");

        b.attention(q_h, k_h, v_h, o_h, 1, 1, 1, 2, 2, 1.0, false)
            .expect("attention dominant");

        let result = download_f32(&b, o_h, 2);
        assert!(
            (result[0] - 100.0).abs() < 0.1,
            "got {}, expected ~100",
            result[0]
        );
        assert!(
            (result[1] - 200.0).abs() < 0.1,
            "got {}, expected ~200",
            result[1]
        );

        b.free(q_h).expect("free");
        b.free(k_h).expect("free");
        b.free(v_h).expect("free");
        b.free(o_h).expect("free");
    }
}