1use std::sync::Arc;
7
8use oxicuda_backend::{
9 BackendError, BackendResult, BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp,
10};
11use wgpu;
12
13use crate::{device::WebGpuDevice, memory::WebGpuMemoryManager, shader};
14
15fn map_unary_op(op: UnaryOp) -> &'static str {
18 match op {
19 UnaryOp::Relu => "relu",
20 UnaryOp::Sigmoid => "sigmoid",
21 UnaryOp::Tanh => "tanh",
22 UnaryOp::Exp => "exp",
23 UnaryOp::Log => "log",
24 UnaryOp::Sqrt => "sqrt",
25 UnaryOp::Abs => "abs",
26 UnaryOp::Neg => "neg",
27 }
28}
29
30fn map_binary_op(op: BinaryOp) -> &'static str {
31 match op {
32 BinaryOp::Add => "add",
33 BinaryOp::Sub => "sub",
34 BinaryOp::Mul => "mul",
35 BinaryOp::Div => "div",
36 BinaryOp::Max => "max",
37 BinaryOp::Min => "min",
38 }
39}
40
41fn map_reduce_op(op: ReduceOp) -> &'static str {
42 match op {
43 ReduceOp::Sum => "sum",
44 ReduceOp::Max => "max",
45 ReduceOp::Min => "min",
46 ReduceOp::Mean => "mean",
47 }
48}
49
/// [`ComputeBackend`] implementation backed by WebGPU (`wgpu`).
///
/// GEMM and element-wise/reduction ops run as WGSL compute shaders;
/// convolution and attention are currently computed on the CPU after copying
/// data back from the device.
#[derive(Debug)]
pub struct WebGpuBackend {
    // GPU device/queue wrapper; `None` until `init` succeeds.
    device: Option<Arc<WebGpuDevice>>,
    // Buffer registry/allocator; `None` until `init` succeeds.
    memory: Option<Arc<WebGpuMemoryManager>>,
    // Gate flag checked by every compute/memory entry point.
    initialized: bool,
}
66
67impl WebGpuBackend {
68 pub fn new() -> Self {
70 Self {
71 device: None,
72 memory: None,
73 initialized: false,
74 }
75 }
76
77 fn check_init(&self) -> BackendResult<()> {
79 if self.initialized {
80 Ok(())
81 } else {
82 Err(BackendError::NotInitialized)
83 }
84 }
85
86 fn memory(&self) -> BackendResult<&Arc<WebGpuMemoryManager>> {
88 self.memory.as_ref().ok_or(BackendError::NotInitialized)
89 }
90
91 fn device(&self) -> BackendResult<&Arc<WebGpuDevice>> {
93 self.device.as_ref().ok_or(BackendError::NotInitialized)
94 }
95}
96
impl Default for WebGpuBackend {
    /// Equivalent to [`WebGpuBackend::new`]: an uninitialised backend.
    fn default() -> Self {
        Self::new()
    }
}
102
103impl ComputeBackend for WebGpuBackend {
    /// Stable identifier for this backend.
    fn name(&self) -> &str {
        "webgpu"
    }
109
110 fn init(&mut self) -> BackendResult<()> {
111 if self.initialized {
112 return Ok(());
113 }
114
115 match WebGpuDevice::new() {
116 Ok(dev) => {
117 let dev = Arc::new(dev);
118 tracing::info!("WebGPU backend initialised on: {}", dev.adapter_name);
119 let memory = WebGpuMemoryManager::new(Arc::clone(&dev));
120 self.device = Some(dev);
121 self.memory = Some(Arc::new(memory));
122 self.initialized = true;
123 Ok(())
124 }
125 Err(e) => Err(BackendError::from(e)),
126 }
127 }
128
    /// True once `init` has succeeded.
    fn is_initialized(&self) -> bool {
        self.initialized
    }
132
133 fn gemm(
136 &self,
137 trans_a: BackendTranspose,
138 trans_b: BackendTranspose,
139 m: usize,
140 n: usize,
141 k: usize,
142 alpha: f64,
143 a_ptr: u64,
144 _lda: usize,
145 b_ptr: u64,
146 _ldb: usize,
147 beta: f64,
148 c_ptr: u64,
149 _ldc: usize,
150 ) -> BackendResult<()> {
151 self.check_init()?;
152 if m == 0 || n == 0 || k == 0 {
154 return Ok(());
155 }
156
157 if trans_a != BackendTranspose::NoTrans || trans_b != BackendTranspose::NoTrans {
159 return Err(BackendError::Unsupported(
160 "WebGPU GEMM does not yet support transposed inputs".into(),
161 ));
162 }
163
164 let dev = self.device()?;
165 let mem = self.memory()?;
166
167 let tile_size: u32 = 8;
168 let wgsl = shader::gemm_wgsl(tile_size);
169
170 let shader_mod = dev
171 .device
172 .create_shader_module(wgpu::ShaderModuleDescriptor {
173 label: Some("oxicuda-gemm"),
174 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
175 });
176
177 let pipeline = dev
178 .device
179 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
180 label: Some("oxicuda-gemm"),
181 layout: None,
182 module: &shader_mod,
183 entry_point: Some("main"),
184 compilation_options: Default::default(),
185 cache: None,
186 });
187
188 let bgl = pipeline.get_bind_group_layout(0);
189
190 let mut params_bytes = [0u8; 20];
192 params_bytes[0..4].copy_from_slice(&(m as u32).to_le_bytes());
193 params_bytes[4..8].copy_from_slice(&(n as u32).to_le_bytes());
194 params_bytes[8..12].copy_from_slice(&(k as u32).to_le_bytes());
195 params_bytes[12..16].copy_from_slice(&(alpha as f32).to_le_bytes());
196 params_bytes[16..20].copy_from_slice(&(beta as f32).to_le_bytes());
197
198 let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
199 label: Some("oxicuda-gemm-params"),
200 size: 20,
201 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
202 mapped_at_creation: false,
203 });
204 dev.queue.write_buffer(&uniform_buf, 0, ¶ms_bytes);
205
206 let bind_group = {
208 let buffers = mem
209 .lock_buffers()
210 .map_err(|e| BackendError::DeviceError(e.to_string()))?;
211 let a_info = buffers
212 .get(&a_ptr)
213 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
214 let b_info = buffers
215 .get(&b_ptr)
216 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
217 let c_info = buffers
218 .get(&c_ptr)
219 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {c_ptr}")))?;
220
221 dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
222 label: Some("oxicuda-gemm"),
223 layout: &bgl,
224 entries: &[
225 wgpu::BindGroupEntry {
226 binding: 0,
227 resource: a_info.buffer.as_entire_binding(),
228 },
229 wgpu::BindGroupEntry {
230 binding: 1,
231 resource: b_info.buffer.as_entire_binding(),
232 },
233 wgpu::BindGroupEntry {
234 binding: 2,
235 resource: c_info.buffer.as_entire_binding(),
236 },
237 wgpu::BindGroupEntry {
238 binding: 3,
239 resource: uniform_buf.as_entire_binding(),
240 },
241 ],
242 })
243 };
244
245 let mut encoder = dev
246 .device
247 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
248 label: Some("oxicuda-gemm"),
249 });
250
251 {
252 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
253 label: Some("oxicuda-gemm"),
254 timestamp_writes: None,
255 });
256 pass.set_pipeline(&pipeline);
257 pass.set_bind_group(0, &bind_group, &[]);
258 let wg_x = (n as u32).div_ceil(tile_size);
259 let wg_y = (m as u32).div_ceil(tile_size);
260 pass.dispatch_workgroups(wg_x, wg_y, 1);
261 }
262
263 dev.queue.submit(std::iter::once(encoder.finish()));
264 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
265
266 Ok(())
267 }
268
269 fn conv2d_forward(
270 &self,
271 input_ptr: u64,
272 input_shape: &[usize],
273 filter_ptr: u64,
274 filter_shape: &[usize],
275 output_ptr: u64,
276 output_shape: &[usize],
277 stride: &[usize],
278 padding: &[usize],
279 ) -> BackendResult<()> {
280 self.check_init()?;
281
282 if input_shape.len() != 4 {
283 return Err(BackendError::InvalidArgument(
284 "input_shape must have 4 elements (NCHW)".into(),
285 ));
286 }
287 if filter_shape.len() != 4 {
288 return Err(BackendError::InvalidArgument(
289 "filter_shape must have 4 elements (KCFHFW)".into(),
290 ));
291 }
292 if output_shape.len() != 4 {
293 return Err(BackendError::InvalidArgument(
294 "output_shape must have 4 elements (NKOhOw)".into(),
295 ));
296 }
297 if stride.len() != 2 {
298 return Err(BackendError::InvalidArgument(
299 "stride must have 2 elements [sh, sw]".into(),
300 ));
301 }
302 if padding.len() != 2 {
303 return Err(BackendError::InvalidArgument(
304 "padding must have 2 elements [ph, pw]".into(),
305 ));
306 }
307
308 let mem = self.memory()?;
309
310 let batch = input_shape[0];
311 let c_in = input_shape[1];
312 let h_in = input_shape[2];
313 let w_in = input_shape[3];
314 let k_out = filter_shape[0];
315 let fh = filter_shape[2];
316 let fw = filter_shape[3];
317 let oh = output_shape[2];
318 let ow = output_shape[3];
319 let sh = stride[0];
320 let sw = stride[1];
321 let ph = padding[0];
322 let pw = padding[1];
323
324 let in_elems: usize = input_shape.iter().product();
325 let f_elems: usize = filter_shape.iter().product();
326 let o_elems: usize = output_shape.iter().product();
327
328 let mut in_bytes = vec![0u8; in_elems * 4];
330 let mut f_bytes = vec![0u8; f_elems * 4];
331 mem.copy_from_device(&mut in_bytes, input_ptr)
332 .map_err(BackendError::from)?;
333 mem.copy_from_device(&mut f_bytes, filter_ptr)
334 .map_err(BackendError::from)?;
335
336 let in_f32 = bytes_to_f32_vec(&in_bytes);
337 let f_f32 = bytes_to_f32_vec(&f_bytes);
338 let mut out_f32 = vec![0.0f32; o_elems];
339
340 for b in 0..batch {
341 for kf in 0..k_out {
342 for oy in 0..oh {
343 for ox in 0..ow {
344 let mut acc = 0.0f32;
345 for ci in 0..c_in {
346 for fy in 0..fh {
347 for fx in 0..fw {
348 let iy = (oy * sh + fy) as isize - ph as isize;
349 let ix = (ox * sw + fx) as isize - pw as isize;
350 if iy >= 0
351 && (iy as usize) < h_in
352 && ix >= 0
353 && (ix as usize) < w_in
354 {
355 let in_idx = ((b * c_in + ci) * h_in + iy as usize) * w_in
356 + ix as usize;
357 let f_idx = ((kf * c_in + ci) * fh + fy) * fw + fx;
358 acc += in_f32[in_idx] * f_f32[f_idx];
359 }
360 }
361 }
362 }
363 out_f32[((b * k_out + kf) * oh + oy) * ow + ox] = acc;
364 }
365 }
366 }
367 }
368
369 let out_bytes = f32_slice_to_bytes(&out_f32);
370 mem.copy_to_device(output_ptr, &out_bytes)
371 .map_err(BackendError::from)?;
372
373 Ok(())
374 }
375
    /// Scaled dot-product attention, computed on the CPU in `f32`.
    ///
    /// Q is laid out `[batch, heads, seq_q, head_dim]` and K/V
    /// `[batch, heads, seq_kv, head_dim]`, densely packed; the output matches
    /// Q. `scale` multiplies each Q·K dot product; `causal` masks key
    /// positions after the query index.
    fn attention(
        &self,
        q_ptr: u64,
        k_ptr: u64,
        v_ptr: u64,
        o_ptr: u64,
        batch: usize,
        heads: usize,
        seq_q: usize,
        seq_kv: usize,
        head_dim: usize,
        scale: f64,
        causal: bool,
    ) -> BackendResult<()> {
        self.check_init()?;

        if seq_q == 0 || seq_kv == 0 || head_dim == 0 {
            return Err(BackendError::InvalidArgument(
                "seq_q, seq_kv, and head_dim must all be > 0".into(),
            ));
        }
        if scale <= 0.0 || !scale.is_finite() {
            return Err(BackendError::InvalidArgument(format!(
                "scale must be a positive finite number, got {scale}"
            )));
        }

        let mem = self.memory()?;

        let batch_heads = batch * heads;
        let q_elems = batch_heads * seq_q * head_dim;
        let kv_elems = batch_heads * seq_kv * head_dim;
        let o_elems = q_elems;

        // Stage all inputs on the host (4 bytes per f32 element).
        let mut q_bytes = vec![0u8; q_elems * 4];
        let mut k_bytes = vec![0u8; kv_elems * 4];
        let mut v_bytes = vec![0u8; kv_elems * 4];

        mem.copy_from_device(&mut q_bytes, q_ptr)
            .map_err(BackendError::from)?;
        mem.copy_from_device(&mut k_bytes, k_ptr)
            .map_err(BackendError::from)?;
        mem.copy_from_device(&mut v_bytes, v_ptr)
            .map_err(BackendError::from)?;

        let q_f32 = bytes_to_f32_vec(&q_bytes);
        let k_f32 = bytes_to_f32_vec(&k_bytes);
        let v_f32 = bytes_to_f32_vec(&v_bytes);
        let mut o_f32 = vec![0.0f32; o_elems];

        let scale_f32 = scale as f32;

        // One (batch, head) pair at a time; offsets index into the flat arrays.
        for bh in 0..batch_heads {
            let q_off = bh * seq_q * head_dim;
            let k_off = bh * seq_kv * head_dim;
            let v_off = k_off;

            for sq in 0..seq_q {
                // Causal masking attends to keys 0..=sq only.
                // NOTE(review): this assumes query and key positions are
                // aligned (seq_q == seq_kv); confirm intended semantics for
                // cross-attention with differing lengths.
                let kv_limit = if causal { (sq + 1).min(seq_kv) } else { seq_kv };

                // Pass 1: row maximum of the scaled scores, for a numerically
                // stable softmax.
                let mut max_score = f32::NEG_INFINITY;
                for sk in 0..kv_limit {
                    let mut dot = 0.0f32;
                    for dd in 0..head_dim {
                        dot +=
                            q_f32[q_off + sq * head_dim + dd] * k_f32[k_off + sk * head_dim + dd];
                    }
                    let s = dot * scale_f32;
                    if s > max_score {
                        max_score = s;
                    }
                }

                // Pass 2: recompute each score (trades FLOPs for not storing
                // the score row), accumulate exp-weighted V and the softmax
                // normaliser.
                let mut sum_exp = 0.0f32;
                let mut acc = vec![0.0f32; head_dim];
                for sk in 0..kv_limit {
                    let mut dot = 0.0f32;
                    for dd in 0..head_dim {
                        dot +=
                            q_f32[q_off + sq * head_dim + dd] * k_f32[k_off + sk * head_dim + dd];
                    }
                    let w = (dot * scale_f32 - max_score).exp();
                    sum_exp += w;
                    for dd in 0..head_dim {
                        acc[dd] += w * v_f32[v_off + sk * head_dim + dd];
                    }
                }

                // Normalise; if every weight underflowed to zero the output
                // row stays zero instead of dividing by zero.
                let o_base = q_off + sq * head_dim;
                if sum_exp > 0.0 {
                    for dd in 0..head_dim {
                        o_f32[o_base + dd] = acc[dd] / sum_exp;
                    }
                }
            }
        }

        let o_bytes = f32_slice_to_bytes(&o_f32);
        mem.copy_to_device(o_ptr, &o_bytes)
            .map_err(BackendError::from)?;

        Ok(())
    }
483
    /// Reduces a 1-D array on the GPU in two shader passes: pass 1 writes one
    /// partial result per 256-element workgroup, pass 2 combines the partials
    /// into a single value at `output_ptr`.
    ///
    /// Only 1-D shapes are supported today; `axis` must be in bounds. A zero
    /// element count is a no-op that leaves the output untouched.
    fn reduce(
        &self,
        op: ReduceOp,
        input_ptr: u64,
        output_ptr: u64,
        shape: &[usize],
        axis: usize,
    ) -> BackendResult<()> {
        self.check_init()?;

        if shape.is_empty() {
            return Err(BackendError::InvalidArgument(
                "shape must not be empty".into(),
            ));
        }
        if axis >= shape.len() {
            return Err(BackendError::InvalidArgument(format!(
                "axis {axis} is out of bounds for shape of length {}",
                shape.len()
            )));
        }

        if shape.len() != 1 {
            return Err(BackendError::Unsupported(
                "WebGPU reduce currently supports only 1-D shapes".into(),
            ));
        }

        let n_elements = shape[0];
        if n_elements == 0 {
            return Ok(());
        }

        let dev = self.device()?;
        let mem = self.memory()?;
        let op_str = map_reduce_op(op);

        // One 256-thread workgroup per 256 input elements.
        let wg_count = (n_elements as u32).div_ceil(256);

        // --- Pass 1: per-workgroup partial reduction -----------------------
        let pass1_wgsl = shader::reduction_wgsl(op_str);
        let pass1_shader = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-reduce-pass1"),
                source: wgpu::ShaderSource::Wgsl(pass1_wgsl.into()),
            });
        let pass1_pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-reduce-pass1"),
                layout: None,
                module: &pass1_shader,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        // Scratch buffer holding one f32 partial per workgroup.
        let partial_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-reduce-partial"),
            size: (wg_count as u64) * 4,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        // Pass-1 uniform: total input element count (little-endian u32).
        let mut p1_params = [0u8; 4];
        p1_params[0..4].copy_from_slice(&(n_elements as u32).to_le_bytes());
        let p1_uniform = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-reduce-p1-params"),
            size: 4,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&p1_uniform, 0, &p1_params);

        let bgl1 = pass1_pipeline.get_bind_group_layout(0);

        // Resolve the input handle while holding the registry lock.
        let bg1 = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let in_info = buffers.get(&input_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {input_ptr}"))
            })?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-reduce-pass1"),
                layout: &bgl1,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: in_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: partial_buf.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: p1_uniform.as_entire_binding(),
                    },
                ],
            })
        };

        // --- Pass 2: combine the partials into the final scalar ------------
        let pass2_wgsl = shader::reduction_final_wgsl(op_str);
        let pass2_shader = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-reduce-pass2"),
                source: wgpu::ShaderSource::Wgsl(pass2_wgsl.into()),
            });
        let pass2_pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-reduce-pass2"),
                layout: None,
                module: &pass2_shader,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        // Pass-2 uniform: number of partials produced by pass 1.
        let mut p2_params = [0u8; 4];
        p2_params[0..4].copy_from_slice(&wg_count.to_le_bytes());
        let p2_uniform = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-reduce-p2-params"),
            size: 4,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&p2_uniform, 0, &p2_params);

        let bgl2 = pass2_pipeline.get_bind_group_layout(0);

        let bg2 = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
            })?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-reduce-pass2"),
                layout: &bgl2,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: partial_buf.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: out_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: p2_uniform.as_entire_binding(),
                    },
                ],
            })
        };

        // Both passes are recorded into one submission; the pass-2 dispatch is
        // a single workgroup reading all partials.
        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-reduce"),
            });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-reduce-pass1"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pass1_pipeline);
            pass.set_bind_group(0, &bg1, &[]);
            pass.dispatch_workgroups(wg_count, 1, 1);
        }
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-reduce-pass2"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pass2_pipeline);
            pass.set_bind_group(0, &bg2, &[]);
            pass.dispatch_workgroups(1, 1, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        // Host-side fix-up for Mean: read the reduced value back and divide by
        // the element count. The `n_elements > 1` guard makes n == 1 a no-op
        // (mean of one element is that element).
        // NOTE(review): this presumes the "mean" kernels accumulate a sum and
        // do not already divide — confirm against shader.rs, otherwise this
        // would divide twice.
        if op == ReduceOp::Mean && n_elements > 1 {
            let mut buf = [0u8; 4];
            mem.copy_from_device(&mut buf, output_ptr)
                .map_err(BackendError::from)?;
            let val = f32::from_le_bytes(buf);
            let mean = val / (n_elements as f32);
            mem.copy_to_device(output_ptr, &mean.to_le_bytes())
                .map_err(BackendError::from)?;
        }

        Ok(())
    }
698
    /// Applies an element-wise unary op over `n` elements via the generated
    /// WGSL shader. `n == 0` is a no-op. Blocks until the GPU has finished.
    fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()> {
        self.check_init()?;
        if n == 0 {
            return Ok(());
        }

        let dev = self.device()?;
        let mem = self.memory()?;

        let op_str = map_unary_op(op);
        let wgsl = shader::elementwise_wgsl(op_str);

        // Pipeline is rebuilt per call; `layout: None` derives the bind group
        // layout from the shader.
        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-unary"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });

        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-unary"),
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        let bgl = pipeline.get_bind_group_layout(0);

        // Resolve the opaque handles while holding the registry lock.
        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let in_info = buffers.get(&input_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {input_ptr}"))
            })?;
            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
            })?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-unary"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: in_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: out_info.buffer.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-unary"),
            });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-unary"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            // One thread per element at a workgroup size of 256 (must match
            // the generated shader).
            let workgroups = (n as u32).div_ceil(256);
            pass.dispatch_workgroups(workgroups, 1, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        // Synchronous backend contract: wait for GPU completion.
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }
780
    /// Applies an element-wise binary op over `n` elements of `a` and `b` via
    /// the generated WGSL shader. `n == 0` is a no-op. Blocks until the GPU
    /// has finished.
    fn binary(
        &self,
        op: BinaryOp,
        a_ptr: u64,
        b_ptr: u64,
        output_ptr: u64,
        n: usize,
    ) -> BackendResult<()> {
        self.check_init()?;
        if n == 0 {
            return Ok(());
        }

        let dev = self.device()?;
        let mem = self.memory()?;

        let op_str = map_binary_op(op);
        let wgsl = shader::binary_wgsl(op_str);

        // Pipeline is rebuilt per call; `layout: None` derives the bind group
        // layout from the shader.
        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-binary"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });

        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-binary"),
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        let bgl = pipeline.get_bind_group_layout(0);

        // Resolve the opaque handles while holding the registry lock.
        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let a_info = buffers
                .get(&a_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
            let b_info = buffers
                .get(&b_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
            })?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-binary"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: a_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: b_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: out_info.buffer.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-binary"),
            });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-binary"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            // One thread per element at a workgroup size of 256 (must match
            // the generated shader).
            let workgroups = (n as u32).div_ceil(256);
            pass.dispatch_workgroups(workgroups, 1, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        // Synchronous backend contract: wait for GPU completion.
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }
876
877 fn synchronize(&self) -> BackendResult<()> {
880 self.check_init()?;
881 if let Some(dev) = &self.device {
882 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
883 }
884 Ok(())
885 }
886
887 fn alloc(&self, bytes: usize) -> BackendResult<u64> {
890 self.check_init()?;
891 if bytes == 0 {
892 return Err(BackendError::InvalidArgument(
893 "cannot allocate 0 bytes".into(),
894 ));
895 }
896 self.memory()?.alloc(bytes).map_err(BackendError::from)
897 }
898
    /// Releases the device buffer behind `ptr`; unknown handles surface as an
    /// error from the memory manager.
    fn free(&self, ptr: u64) -> BackendResult<()> {
        self.check_init()?;
        self.memory()?.free(ptr).map_err(BackendError::from)
    }
903
904 fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()> {
905 self.check_init()?;
906 if src.is_empty() {
907 return Ok(());
908 }
909 self.memory()?
910 .copy_to_device(dst, src)
911 .map_err(BackendError::from)
912 }
913
914 fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()> {
915 self.check_init()?;
916 if dst.is_empty() {
917 return Ok(());
918 }
919 self.memory()?
920 .copy_from_device(dst, src)
921 .map_err(BackendError::from)
922 }
923}
924
/// Reinterprets a little-endian byte buffer as `f32` values.
///
/// Trailing bytes that do not form a complete 4-byte word are ignored.
fn bytes_to_f32_vec(bytes: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(bytes.len() / 4);
    for word in bytes.chunks_exact(4) {
        let arr: [u8; 4] = word.try_into().expect("chunks_exact yields 4 bytes");
        values.push(f32::from_le_bytes(arr));
    }
    values
}
934
/// Serialises `f32` values into a little-endian byte vector (4 bytes each).
fn f32_slice_to_bytes(data: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(data.len() * 4);
    for value in data {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}
939
940#[cfg(test)]
943mod tests {
944 use super::*;
945 use oxicuda_backend::{BackendTranspose, BinaryOp, ReduceOp, UnaryOp};
946
    // --- Construction and pre-init error behaviour ------------------------

    // A freshly constructed backend reports uninitialised.
    #[test]
    fn webgpu_backend_new_uninitialized() {
        let b = WebGpuBackend::new();
        assert!(!b.is_initialized());
    }

    #[test]
    fn webgpu_backend_name() {
        let b = WebGpuBackend::new();
        assert_eq!(b.name(), "webgpu");
    }

    // `Default` must behave exactly like `new()`.
    #[test]
    fn webgpu_backend_default() {
        let b = WebGpuBackend::default();
        assert!(!b.is_initialized());
        assert_eq!(b.name(), "webgpu");
    }

    #[test]
    fn backend_debug_impl() {
        let b = WebGpuBackend::new();
        let s = format!("{b:?}");
        assert!(s.contains("WebGpuBackend"));
    }

    // The trait must stay object-safe (consumers hold `Box<dyn ComputeBackend>`).
    #[test]
    fn backend_object_safe() {
        let b: Box<dyn ComputeBackend> = Box::new(WebGpuBackend::new());
        assert_eq!(b.name(), "webgpu");
    }

    // Every entry point must fail with `NotInitialized` before `init`.
    #[test]
    fn backend_not_initialized_gemm() {
        let b = WebGpuBackend::new();
        let result = b.gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            4,
            4,
            4,
            1.0,
            0,
            4,
            0,
            4,
            0.0,
            0,
            4,
        );
        assert_eq!(result, Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_alloc() {
        let b = WebGpuBackend::new();
        let result = b.alloc(1024);
        assert_eq!(result, Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_synchronize() {
        let b = WebGpuBackend::new();
        assert_eq!(b.synchronize(), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_free() {
        let b = WebGpuBackend::new();
        assert_eq!(b.free(1), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_copy_htod() {
        let b = WebGpuBackend::new();
        assert_eq!(b.copy_htod(1, b"hello"), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_copy_dtoh() {
        let b = WebGpuBackend::new();
        let mut buf = [0u8; 4];
        assert_eq!(b.copy_dtoh(&mut buf, 1), Err(BackendError::NotInitialized));
    }
1037
    // Initialises a backend, or returns `None` when no usable GPU adapter is
    // available so GPU-dependent tests can skip gracefully.
    fn try_init() -> Option<WebGpuBackend> {
        let mut b = WebGpuBackend::new();
        match b.init() {
            Ok(()) => Some(b),
            Err(_) => None,
        }
    }

    // --- Documented no-ops on an initialised backend ----------------------

    #[test]
    fn gemm_zero_size_after_init() {
        let Some(b) = try_init() else {
            return;
        };
        let result = b.gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            0,
            0,
            0,
            1.0,
            0,
            1,
            0,
            1,
            0.0,
            0,
            1,
        );
        assert_eq!(result, Ok(()));
    }

    #[test]
    fn unary_zero_elements_after_init() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.unary(UnaryOp::Relu, 0, 0, 0), Ok(()));
    }

    #[test]
    fn binary_zero_elements_after_init() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.binary(BinaryOp::Add, 0, 0, 0, 0), Ok(()));
    }

    #[test]
    fn copy_htod_empty_noop() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.copy_htod(0, &[]), Ok(()));
    }

    #[test]
    fn copy_dtoh_empty_noop() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.copy_dtoh(&mut [], 0), Ok(()));
    }

    // Zero-byte allocations are rejected rather than silently succeeding.
    #[test]
    fn alloc_zero_bytes_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.alloc(0),
            Err(BackendError::InvalidArgument(
                "cannot allocate 0 bytes".into()
            ))
        );
    }

    #[test]
    fn synchronize_after_init() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.synchronize(), Ok(()));
    }
1125
    // --- Argument-validation errors (need an initialised backend) ---------

    #[test]
    fn reduce_empty_shape_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.reduce(ReduceOp::Sum, 0, 0, &[], 0),
            Err(BackendError::InvalidArgument(
                "shape must not be empty".into()
            ))
        );
    }

    #[test]
    fn reduce_axis_out_of_bounds_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.reduce(ReduceOp::Sum, 0, 0, &[4, 4], 5),
            Err(BackendError::InvalidArgument(
                "axis 5 is out of bounds for shape of length 2".into()
            ))
        );
    }

    #[test]
    fn attention_zero_seq_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.attention(0, 0, 0, 0, 1, 1, 0, 8, 64, 0.125, false),
            Err(BackendError::InvalidArgument(
                "seq_q, seq_kv, and head_dim must all be > 0".into()
            ))
        );
    }

    // Zero, negative and non-finite scales must all be rejected.
    #[test]
    fn attention_nonpositive_scale_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, 0.0, false),
            Err(BackendError::InvalidArgument(
                "scale must be a positive finite number, got 0".into()
            ))
        );
        assert_eq!(
            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, -1.0, false),
            Err(BackendError::InvalidArgument(
                "scale must be a positive finite number, got -1".into()
            ))
        );
        assert!(
            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, f64::INFINITY, false)
                .is_err()
        );
    }

    #[test]
    fn conv2d_wrong_input_shape_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.conv2d_forward(
                0,
                &[1, 3, 32],
                0,
                &[16, 3, 3, 3],
                0,
                &[1, 16, 30, 30],
                &[1, 1],
                &[0, 0]
            ),
            Err(BackendError::InvalidArgument(
                "input_shape must have 4 elements (NCHW)".into()
            ))
        );
    }

    #[test]
    fn conv2d_wrong_filter_shape_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.conv2d_forward(
                0,
                &[1, 3, 32, 32],
                0,
                &[16, 3, 3],
                0,
                &[1, 16, 30, 30],
                &[1, 1],
                &[0, 0]
            ),
            Err(BackendError::InvalidArgument(
                "filter_shape must have 4 elements (KCFHFW)".into()
            ))
        );
    }

    #[test]
    fn conv2d_wrong_stride_shape_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.conv2d_forward(
                0,
                &[1, 3, 32, 32],
                0,
                &[16, 3, 3, 3],
                0,
                &[1, 16, 30, 30],
                &[1],
                &[0, 0],
            ),
            Err(BackendError::InvalidArgument(
                "stride must have 2 elements [sh, sw]".into()
            ))
        );
    }

    // Calling `init` twice must be a cheap no-op.
    #[test]
    fn init_idempotent() {
        let Some(mut b) = try_init() else {
            return;
        };
        assert_eq!(b.init(), Ok(()));
        assert!(b.is_initialized());
    }

    // `init` must not panic on machines without a GPU adapter; either
    // outcome (Ok or Err) is acceptable here.
    #[test]
    fn webgpu_init_graceful_failure() {
        let mut b = WebGpuBackend::new();
        let _result = b.init();
    }
1279
    // Allocates a device buffer, uploads `data` as little-endian f32 bytes,
    // and returns the handle. Panics on allocation/copy failure.
    fn upload_f32(b: &WebGpuBackend, data: &[f32]) -> u64 {
        let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
        let h = b.alloc(bytes.len()).expect("alloc");
        b.copy_htod(h, &bytes).expect("copy_htod");
        h
    }

    // Reads `n` little-endian f32 values back from the device buffer `h`.
    fn download_f32(b: &WebGpuBackend, h: u64, n: usize) -> Vec<f32> {
        let mut bytes = vec![0u8; n * 4];
        b.copy_dtoh(&mut bytes, h).expect("copy_dtoh");
        bytes
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect()
    }
1302
    // --- Numeric round-trips through the element-wise shaders -------------

    #[test]
    fn unary_neg_small() {
        let Some(b) = try_init() else { return };
        let input = [1.0f32, -2.0, 3.0, 0.0];
        let in_h = upload_f32(&b, &input);
        let out_h = b.alloc(input.len() * 4).expect("alloc output");

        b.unary(UnaryOp::Neg, in_h, out_h, input.len())
            .expect("unary neg");

        let result = download_f32(&b, out_h, input.len());
        let expected = [-1.0f32, 2.0, -3.0, 0.0];
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6, "got {r}, expected {e}");
        }

        b.free(in_h).expect("free");
        b.free(out_h).expect("free");
    }

    #[test]
    fn unary_abs_small() {
        let Some(b) = try_init() else { return };
        let input = [-3.0f32, 4.0, -5.0, 0.0];
        let in_h = upload_f32(&b, &input);
        let out_h = b.alloc(input.len() * 4).expect("alloc output");

        b.unary(UnaryOp::Abs, in_h, out_h, input.len())
            .expect("unary abs");

        let result = download_f32(&b, out_h, input.len());
        let expected = [3.0f32, 4.0, 5.0, 0.0];
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6, "got {r}, expected {e}");
        }

        b.free(in_h).expect("free");
        b.free(out_h).expect("free");
    }

    #[test]
    fn binary_add_small() {
        let Some(b) = try_init() else { return };
        let a = [1.0f32, 2.0, 3.0, 4.0];
        let bv = [10.0f32, 20.0, 30.0, 40.0];
        let a_h = upload_f32(&b, &a);
        let b_h = upload_f32(&b, &bv);
        let out_h = b.alloc(a.len() * 4).expect("alloc output");

        b.binary(BinaryOp::Add, a_h, b_h, out_h, a.len())
            .expect("binary add");

        let result = download_f32(&b, out_h, a.len());
        let expected = [11.0f32, 22.0, 33.0, 44.0];
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6, "got {r}, expected {e}");
        }

        b.free(a_h).expect("free");
        b.free(b_h).expect("free");
        b.free(out_h).expect("free");
    }

    #[test]
    fn binary_mul_small() {
        let Some(b) = try_init() else { return };
        let a = [2.0f32, 3.0, 4.0, 5.0];
        let bv = [10.0f32, 10.0, 10.0, 10.0];
        let a_h = upload_f32(&b, &a);
        let b_h = upload_f32(&b, &bv);
        let out_h = b.alloc(a.len() * 4).expect("alloc output");

        b.binary(BinaryOp::Mul, a_h, b_h, out_h, a.len())
            .expect("binary mul");

        let result = download_f32(&b, out_h, a.len());
        let expected = [20.0f32, 30.0, 40.0, 50.0];
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6, "got {r}, expected {e}");
        }

        b.free(a_h).expect("free");
        b.free(b_h).expect("free");
        b.free(out_h).expect("free");
    }
1388
1389 #[test]
1390 fn reduce_sum_small() {
1391 let Some(b) = try_init() else { return };
1392 let input = [1.0f32, 2.0, 3.0, 4.0];
1393 let in_h = upload_f32(&b, &input);
1394 let out_h = b.alloc(4).expect("alloc output"); b.reduce(ReduceOp::Sum, in_h, out_h, &[4], 0)
1397 .expect("reduce sum");
1398
1399 let result = download_f32(&b, out_h, 1);
1400 assert!(
1401 (result[0] - 10.0).abs() < 1e-5,
1402 "expected 10.0, got {}",
1403 result[0]
1404 );
1405
1406 b.free(in_h).expect("free");
1407 b.free(out_h).expect("free");
1408 }
1409
1410 #[test]
1411 fn reduce_max_small() {
1412 let Some(b) = try_init() else { return };
1413 let input = [1.0f32, 5.0, 3.0, 2.0];
1414 let in_h = upload_f32(&b, &input);
1415 let out_h = b.alloc(4).expect("alloc output");
1416
1417 b.reduce(ReduceOp::Max, in_h, out_h, &[4], 0)
1418 .expect("reduce max");
1419
1420 let result = download_f32(&b, out_h, 1);
1421 assert!(
1422 (result[0] - 5.0).abs() < 1e-5,
1423 "expected 5.0, got {}",
1424 result[0]
1425 );
1426
1427 b.free(in_h).expect("free");
1428 b.free(out_h).expect("free");
1429 }
1430
1431 #[test]
1432 fn reduce_mean_small() {
1433 let Some(b) = try_init() else { return };
1434 let input = [2.0f32, 4.0, 6.0, 8.0];
1435 let in_h = upload_f32(&b, &input);
1436 let out_h = b.alloc(4).expect("alloc output");
1437
1438 b.reduce(ReduceOp::Mean, in_h, out_h, &[4], 0)
1439 .expect("reduce mean");
1440
1441 let result = download_f32(&b, out_h, 1);
1442 assert!(
1443 (result[0] - 5.0).abs() < 1e-5,
1444 "expected 5.0, got {}",
1445 result[0]
1446 );
1447
1448 b.free(in_h).expect("free");
1449 b.free(out_h).expect("free");
1450 }
1451
1452 #[test]
1453 fn gemm_identity_2x2() {
1454 let Some(b) = try_init() else { return };
1455 let a = [1.0f32, 2.0, 3.0, 4.0];
1458 let eye = [1.0f32, 0.0, 0.0, 1.0];
1459 let c_init = [0.0f32; 4];
1460
1461 let a_h = upload_f32(&b, &a);
1462 let b_h = upload_f32(&b, &eye);
1463 let c_h = upload_f32(&b, &c_init);
1464
1465 b.gemm(
1466 BackendTranspose::NoTrans,
1467 BackendTranspose::NoTrans,
1468 2,
1469 2,
1470 2,
1471 1.0,
1472 a_h,
1473 2,
1474 b_h,
1475 2,
1476 0.0,
1477 c_h,
1478 2,
1479 )
1480 .expect("gemm");
1481
1482 let result = download_f32(&b, c_h, 4);
1483 for (r, e) in result.iter().zip(a.iter()) {
1484 assert!((r - e).abs() < 1e-5, "got {r}, expected {e}");
1485 }
1486
1487 b.free(a_h).expect("free");
1488 b.free(b_h).expect("free");
1489 b.free(c_h).expect("free");
1490 }
1491
1492 #[test]
1493 fn gemm_2x3_times_3x2() {
1494 let Some(b) = try_init() else { return };
1495 let a = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
1497 let bm = [7.0f32, 8.0, 9.0, 10.0, 11.0, 12.0];
1498 let c_init = [0.0f32; 4];
1499
1500 let a_h = upload_f32(&b, &a);
1501 let b_h = upload_f32(&b, &bm);
1502 let c_h = upload_f32(&b, &c_init);
1503
1504 b.gemm(
1505 BackendTranspose::NoTrans,
1506 BackendTranspose::NoTrans,
1507 2,
1508 2,
1509 3,
1510 1.0,
1511 a_h,
1512 3,
1513 b_h,
1514 2,
1515 0.0,
1516 c_h,
1517 2,
1518 )
1519 .expect("gemm");
1520
1521 let result = download_f32(&b, c_h, 4);
1523 let expected = [58.0f32, 64.0, 139.0, 154.0];
1524 for (r, e) in result.iter().zip(expected.iter()) {
1525 assert!((r - e).abs() < 1e-4, "got {r}, expected {e}");
1526 }
1527
1528 b.free(a_h).expect("free");
1529 b.free(b_h).expect("free");
1530 b.free(c_h).expect("free");
1531 }
1532
1533 #[test]
1534 fn gemm_alpha_beta() {
1535 let Some(b) = try_init() else { return };
1536 let a = [1.0f32, 0.0, 0.0, 1.0];
1540 let bm = [1.0f32, 0.0, 0.0, 1.0];
1541 let c_init = [1.0f32, 1.0, 1.0, 1.0];
1542
1543 let a_h = upload_f32(&b, &a);
1544 let b_h = upload_f32(&b, &bm);
1545 let c_h = upload_f32(&b, &c_init);
1546
1547 b.gemm(
1548 BackendTranspose::NoTrans,
1549 BackendTranspose::NoTrans,
1550 2,
1551 2,
1552 2,
1553 2.0,
1554 a_h,
1555 2,
1556 b_h,
1557 2,
1558 3.0,
1559 c_h,
1560 2,
1561 )
1562 .expect("gemm alpha+beta");
1563
1564 let result = download_f32(&b, c_h, 4);
1565 let expected = [5.0f32, 3.0, 3.0, 5.0];
1566 for (r, e) in result.iter().zip(expected.iter()) {
1567 assert!((r - e).abs() < 1e-4, "got {r}, expected {e}");
1568 }
1569
1570 b.free(a_h).expect("free");
1571 b.free(b_h).expect("free");
1572 b.free(c_h).expect("free");
1573 }
1574
1575 #[test]
1578 fn conv2d_identity_1x1() {
1579 let Some(b) = try_init() else { return };
1582 let input: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1583 let filter = [2.0f32];
1584 let expected: Vec<f32> = input.iter().map(|x| x * 2.0).collect();
1585
1586 let in_h = upload_f32(&b, &input);
1587 let f_h = upload_f32(&b, &filter);
1588 let out_h = b.alloc(9 * 4).expect("alloc output");
1589
1590 b.conv2d_forward(
1591 in_h,
1592 &[1, 1, 3, 3],
1593 f_h,
1594 &[1, 1, 1, 1],
1595 out_h,
1596 &[1, 1, 3, 3],
1597 &[1, 1],
1598 &[0, 0],
1599 )
1600 .expect("conv2d");
1601
1602 let result = download_f32(&b, out_h, 9);
1603 for (r, e) in result.iter().zip(expected.iter()) {
1604 assert!((r - e).abs() < 1e-5, "got {r}, expected {e}");
1605 }
1606
1607 b.free(in_h).expect("free");
1608 b.free(f_h).expect("free");
1609 b.free(out_h).expect("free");
1610 }
1611
1612 #[test]
1613 fn conv2d_3x3_no_padding() {
1614 let Some(b) = try_init() else { return };
1617 let input: Vec<f32> = (0..16).map(|x| x as f32).collect();
1618 let filter = [1.0f32; 9];
1619
1620 let in_h = upload_f32(&b, &input);
1621 let f_h = upload_f32(&b, &filter);
1622 let out_h = b.alloc(4 * 4).expect("alloc output");
1623
1624 b.conv2d_forward(
1625 in_h,
1626 &[1, 1, 4, 4],
1627 f_h,
1628 &[1, 1, 3, 3],
1629 out_h,
1630 &[1, 1, 2, 2],
1631 &[1, 1],
1632 &[0, 0],
1633 )
1634 .expect("conv2d");
1635
1636 let result = download_f32(&b, out_h, 4);
1637 assert!((result[0] - 45.0).abs() < 1e-4, "got {}", result[0]);
1639 assert!((result[1] - 54.0).abs() < 1e-4, "got {}", result[1]);
1641
1642 b.free(in_h).expect("free");
1643 b.free(f_h).expect("free");
1644 b.free(out_h).expect("free");
1645 }
1646
1647 #[test]
1648 fn conv2d_with_padding() {
1649 let Some(b) = try_init() else { return };
1653 let input = [1.0f32, 2.0, 3.0, 4.0];
1654 let filter = [1.0f32; 9];
1655
1656 let in_h = upload_f32(&b, &input);
1657 let f_h = upload_f32(&b, &filter);
1658 let out_h = b.alloc(4 * 4).expect("alloc output");
1659
1660 b.conv2d_forward(
1661 in_h,
1662 &[1, 1, 2, 2],
1663 f_h,
1664 &[1, 1, 3, 3],
1665 out_h,
1666 &[1, 1, 2, 2],
1667 &[1, 1],
1668 &[1, 1],
1669 )
1670 .expect("conv2d");
1671
1672 let result = download_f32(&b, out_h, 4);
1673 assert!((result[0] - 10.0).abs() < 1e-4, "got {}", result[0]);
1676
1677 b.free(in_h).expect("free");
1678 b.free(f_h).expect("free");
1679 b.free(out_h).expect("free");
1680 }
1681
1682 #[test]
1685 fn attention_uniform_weights() {
1686 let Some(b) = try_init() else { return };
1690
1691 let q = [1.0f32, 0.0];
1692 let k = [1.0f32, 0.0, 1.0, 0.0];
1693 let v = [1.0f32, 2.0, 3.0, 4.0];
1694
1695 let q_h = upload_f32(&b, &q);
1696 let k_h = upload_f32(&b, &k);
1697 let v_h = upload_f32(&b, &v);
1698 let o_h = b.alloc(2 * 4).expect("alloc output");
1699
1700 b.attention(q_h, k_h, v_h, o_h, 1, 1, 1, 2, 2, 1.0, false)
1701 .expect("attention");
1702
1703 let result = download_f32(&b, o_h, 2);
1704 assert!(
1706 (result[0] - 2.0).abs() < 1e-4,
1707 "got {}, expected 2.0",
1708 result[0]
1709 );
1710 assert!(
1711 (result[1] - 3.0).abs() < 1e-4,
1712 "got {}, expected 3.0",
1713 result[1]
1714 );
1715
1716 b.free(q_h).expect("free");
1717 b.free(k_h).expect("free");
1718 b.free(v_h).expect("free");
1719 b.free(o_h).expect("free");
1720 }
1721
1722 #[test]
1723 fn attention_causal_single_token() {
1724 let Some(b) = try_init() else { return };
1729
1730 let q = [1.0f32, 1.0];
1731 let k = [1.0f32, 1.0];
1732 let v = [10.0f32, 20.0];
1733
1734 let q_h = upload_f32(&b, &q);
1735 let k_h = upload_f32(&b, &k);
1736 let v_h = upload_f32(&b, &v);
1737 let o_h = b.alloc(2 * 4).expect("alloc output");
1738
1739 b.attention(q_h, k_h, v_h, o_h, 1, 1, 2, 2, 1, 1.0, true)
1740 .expect("attention causal");
1741
1742 let result = download_f32(&b, o_h, 2);
1743 assert!(
1744 (result[0] - 10.0).abs() < 1e-4,
1745 "got {}, expected 10.0",
1746 result[0]
1747 );
1748 assert!(
1749 (result[1] - 15.0).abs() < 1e-4,
1750 "got {}, expected 15.0",
1751 result[1]
1752 );
1753
1754 b.free(q_h).expect("free");
1755 b.free(k_h).expect("free");
1756 b.free(v_h).expect("free");
1757 b.free(o_h).expect("free");
1758 }
1759
1760 #[test]
1761 fn attention_dominant_key() {
1762 let Some(b) = try_init() else { return };
1767
1768 let q = [1.0f32, 0.0];
1769 let k = [10.0f32, 0.0, 0.0, 0.0];
1770 let v = [100.0f32, 200.0, 0.0, 0.0];
1771
1772 let q_h = upload_f32(&b, &q);
1773 let k_h = upload_f32(&b, &k);
1774 let v_h = upload_f32(&b, &v);
1775 let o_h = b.alloc(2 * 4).expect("alloc output");
1776
1777 b.attention(q_h, k_h, v_h, o_h, 1, 1, 1, 2, 2, 1.0, false)
1779 .expect("attention dominant");
1780
1781 let result = download_f32(&b, o_h, 2);
1782 assert!(
1783 (result[0] - 100.0).abs() < 0.1,
1784 "got {}, expected ~100",
1785 result[0]
1786 );
1787 assert!(
1788 (result[1] - 200.0).abs() < 0.1,
1789 "got {}, expected ~200",
1790 result[1]
1791 );
1792
1793 b.free(q_h).expect("free");
1794 b.free(k_h).expect("free");
1795 b.free(v_h).expect("free");
1796 b.free(o_h).expect("free");
1797 }
1798}