use std::sync::Arc;

use oxicuda_backend::{
    BackendError, BackendResult, BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp,
};
use wgpu;

use crate::{device::WebGpuDevice, memory::WebGpuMemoryManager, shader};

fn map_unary_op(op: UnaryOp) -> &'static str {
    match op {
        UnaryOp::Relu => "relu",
        UnaryOp::Sigmoid => "sigmoid",
        UnaryOp::Tanh => "tanh",
        UnaryOp::Exp => "exp",
        UnaryOp::Log => "log",
        UnaryOp::Sqrt => "sqrt",
        UnaryOp::Abs => "abs",
        UnaryOp::Neg => "neg",
    }
}

fn map_binary_op(op: BinaryOp) -> &'static str {
    match op {
        BinaryOp::Add => "add",
        BinaryOp::Sub => "sub",
        BinaryOp::Mul => "mul",
        BinaryOp::Div => "div",
        BinaryOp::Max => "max",
        BinaryOp::Min => "min",
    }
}

fn map_reduce_op(op: ReduceOp) -> &'static str {
    match op {
        ReduceOp::Sum => "sum",
        ReduceOp::Max => "max",
        ReduceOp::Min => "min",
        ReduceOp::Mean => "mean",
    }
}

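/// Compute backend that lowers `ComputeBackend` operations onto WebGPU via
/// `wgpu`. The backend owns an optional device/queue pair plus a memory
/// manager that maps opaque `u64` handles to `wgpu::Buffer`s.
///
/// A minimal usage sketch. This is hedged: the crate path `oxicuda_webgpu`
/// is a hypothetical placeholder for illustration, not the confirmed public
/// path.
///
/// ```ignore
/// use oxicuda_backend::ComputeBackend;
/// use oxicuda_webgpu::WebGpuBackend; // hypothetical crate path
///
/// let mut backend = WebGpuBackend::new();
/// backend.init()?; // requests an adapter/device; fails without a GPU
/// assert!(backend.is_initialized());
/// assert_eq!(backend.name(), "webgpu");
/// ```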
#[derive(Debug)]
pub struct WebGpuBackend {
    device: Option<Arc<WebGpuDevice>>,
    memory: Option<Arc<WebGpuMemoryManager>>,
    initialized: bool,
}

impl WebGpuBackend {
    pub fn new() -> Self {
        Self {
            device: None,
            memory: None,
            initialized: false,
        }
    }

    fn check_init(&self) -> BackendResult<()> {
        if self.initialized {
            Ok(())
        } else {
            Err(BackendError::NotInitialized)
        }
    }

    fn memory(&self) -> BackendResult<&Arc<WebGpuMemoryManager>> {
        self.memory.as_ref().ok_or(BackendError::NotInitialized)
    }

    fn device(&self) -> BackendResult<&Arc<WebGpuDevice>> {
        self.device.as_ref().ok_or(BackendError::NotInitialized)
    }

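    /// Reduces `shape` along `axis` by decomposing the tensor into
    /// `outer x dk x inner`, where `dk = shape[axis]`, `outer` is the product
    /// of the dimensions before the axis, and `inner` the product of those
    /// after it. Each of the `outer * inner` output elements reduces over
    /// `dk` strided inputs on the GPU.
    ///
    /// A worked example of the index math (not executed on the GPU):
    ///
    /// ```
    /// let shape = [2usize, 3, 4];
    /// let axis = 1;
    /// let outer: usize = shape[..axis].iter().product();     // 2
    /// let dk = shape[axis];                                   // 3
    /// let inner: usize = shape[axis + 1..].iter().product();  // 4
    /// // Row-major strides used by the kernel:
    /// let (outer_stride, dk_stride, inner_stride) = (dk * inner, inner, 1);
    /// assert_eq!((outer, dk, inner), (2, 3, 4));
    /// assert_eq!((outer_stride, dk_stride, inner_stride), (12, 4, 1));
    /// ```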
    fn reduce_nd(
        &self,
        op: ReduceOp,
        input_ptr: u64,
        output_ptr: u64,
        shape: &[usize],
        axis: usize,
    ) -> BackendResult<()> {
        debug_assert!(!shape.is_empty());
        debug_assert!(axis < shape.len());

        let outer: usize = shape[..axis].iter().product();
        let dk: usize = shape[axis];
        let inner: usize = shape[axis + 1..].iter().product();

        if outer == 0 || dk == 0 || inner == 0 {
            return Ok(());
        }

        let total = outer.checked_mul(inner).ok_or_else(|| {
            BackendError::InvalidArgument("reduce: outer * inner overflows usize".into())
        })?;

        // Row-major strides for the outer/dk/inner decomposition.
        let inner_stride: usize = 1;
        let dk_stride: usize = inner;
        let outer_stride: usize = dk
            .checked_mul(inner)
            .ok_or_else(|| BackendError::InvalidArgument("reduce: dk * inner overflows".into()))?;

        // Spread the output elements over a 2-D dispatch so neither grid
        // dimension exceeds the conservative per-dimension limit.
        const MAX_GRID_DIM: u32 = 32_768;
        let total_u32: u32 = total.try_into().map_err(|_| {
            BackendError::InvalidArgument(format!(
                "reduce: output element count {total} exceeds u32 range"
            ))
        })?;
        let grid_x: u32 = total_u32.clamp(1, MAX_GRID_DIM);
        let grid_y: u32 = total_u32.div_ceil(grid_x);

        let dev = self.device()?;
        let mem = self.memory()?;
        let op_str = map_reduce_op(op);

        let wgsl = shader::reduction_nd_wgsl(op_str);
        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-reduce-nd"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });
        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-reduce-nd"),
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        // 28 bytes of reduction geometry (including grid_x for the 2-D
        // dispatch), zero-padded to 32 bytes.
        let mut params_bytes = [0u8; 32];
        let outer_u32: u32 = outer
            .try_into()
            .map_err(|_| BackendError::InvalidArgument("reduce: outer exceeds u32 range".into()))?;
        let dk_u32: u32 = dk
            .try_into()
            .map_err(|_| BackendError::InvalidArgument("reduce: dk exceeds u32 range".into()))?;
        let inner_u32: u32 = inner
            .try_into()
            .map_err(|_| BackendError::InvalidArgument("reduce: inner exceeds u32 range".into()))?;
        let outer_stride_u32: u32 = outer_stride.try_into().map_err(|_| {
            BackendError::InvalidArgument("reduce: outer_stride exceeds u32 range".into())
        })?;
        let dk_stride_u32: u32 = dk_stride.try_into().map_err(|_| {
            BackendError::InvalidArgument("reduce: dk_stride exceeds u32 range".into())
        })?;
        let inner_stride_u32: u32 = inner_stride.try_into().map_err(|_| {
            BackendError::InvalidArgument("reduce: inner_stride exceeds u32 range".into())
        })?;
        params_bytes[0..4].copy_from_slice(&outer_u32.to_le_bytes());
        params_bytes[4..8].copy_from_slice(&dk_u32.to_le_bytes());
        params_bytes[8..12].copy_from_slice(&inner_u32.to_le_bytes());
        params_bytes[12..16].copy_from_slice(&outer_stride_u32.to_le_bytes());
        params_bytes[16..20].copy_from_slice(&dk_stride_u32.to_le_bytes());
        params_bytes[20..24].copy_from_slice(&inner_stride_u32.to_le_bytes());
        params_bytes[24..28].copy_from_slice(&grid_x.to_le_bytes());
        let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-reduce-nd-params"),
            size: 32,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&uniform_buf, 0, &params_bytes);

        let bgl = pipeline.get_bind_group_layout(0);
        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let in_info = buffers.get(&input_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {input_ptr}"))
            })?;
            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
            })?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-reduce-nd"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: in_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: out_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: uniform_buf.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-reduce-nd"),
            });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-reduce-nd"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            pass.dispatch_workgroups(grid_x, grid_y, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }
}

impl WebGpuBackend {
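    /// GEMM over half-precision storage: `C = alpha * A * B + beta * C`,
    /// with `A` of shape `m x k`, `B` of shape `k x n`, and `C` of shape
    /// `m x n`, all row-major and untransposed.
    ///
    /// A hedged call sketch; the handles are assumed to come from
    /// `ComputeBackend::alloc` and to hold packed f16 data:
    ///
    /// ```ignore
    /// // C (2x2) = 1.0 * A (2x3) * B (3x2) + 0.0 * C
    /// backend.gemm_f16(2, 2, 3, 1.0, a_handle, b_handle, 0.0, c_handle)?;
    /// ```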
    #[allow(clippy::too_many_arguments)]
    pub fn gemm_f16(
        &self,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        b_ptr: u64,
        beta: f64,
        c_ptr: u64,
    ) -> BackendResult<()> {
        self.check_init()?;
        if m == 0 || n == 0 || k == 0 {
            return Ok(());
        }

        let dev = self.device()?;
        let mem = self.memory()?;

        let tile_size: u32 = 8;
        let wgsl = shader::gemm_wgsl_f16(tile_size);

        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-gemm-f16"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });

        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-gemm-f16"),
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        let bgl = pipeline.get_bind_group_layout(0);

        // Uniform layout: m, n, k as u32 followed by alpha and beta as f32.
        let mut params_bytes = [0u8; 20];
        params_bytes[0..4].copy_from_slice(&(m as u32).to_le_bytes());
        params_bytes[4..8].copy_from_slice(&(n as u32).to_le_bytes());
        params_bytes[8..12].copy_from_slice(&(k as u32).to_le_bytes());
        params_bytes[12..16].copy_from_slice(&(alpha as f32).to_le_bytes());
        params_bytes[16..20].copy_from_slice(&(beta as f32).to_le_bytes());

        let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-gemm-f16-params"),
            size: 20,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&uniform_buf, 0, &params_bytes);

        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let a_info = buffers
                .get(&a_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
            let b_info = buffers
                .get(&b_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
            let c_info = buffers
                .get(&c_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {c_ptr}")))?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-gemm-f16"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: a_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: b_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: c_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 3,
                        resource: uniform_buf.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-gemm-f16"),
            });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-gemm-f16"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            let wg_x = (n as u32).div_ceil(tile_size);
            let wg_y = (m as u32).div_ceil(tile_size);
            pass.dispatch_workgroups(wg_x, wg_y, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }
}

impl Default for WebGpuBackend {
    fn default() -> Self {
        Self::new()
    }
}

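/// `ComputeBackend` entry points. A hedged end-to-end sketch of the intended
/// flow, built only from calls defined in this file (`oxicuda_webgpu` is a
/// hypothetical crate path):
///
/// ```ignore
/// let mut backend = oxicuda_webgpu::WebGpuBackend::new();
/// backend.init()?;
///
/// // Upload two f32 vectors, add them on the GPU, and read the result back.
/// let n = 4usize;
/// let a = backend.alloc(n * 4)?;
/// let b = backend.alloc(n * 4)?;
/// let out = backend.alloc(n * 4)?;
/// let ones = vec![1.0f32; n];
/// let bytes: Vec<u8> = ones.iter().flat_map(|v| v.to_le_bytes()).collect();
/// backend.copy_htod(a, &bytes)?;
/// backend.copy_htod(b, &bytes)?;
/// backend.binary(oxicuda_backend::BinaryOp::Add, a, b, out, n)?;
/// let mut result = vec![0u8; n * 4];
/// backend.copy_dtoh(&mut result, out)?;
/// ```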
impl ComputeBackend for WebGpuBackend {
    fn name(&self) -> &str {
        "webgpu"
    }

    fn init(&mut self) -> BackendResult<()> {
        if self.initialized {
            return Ok(());
        }

        match WebGpuDevice::new() {
            Ok(dev) => {
                let dev = Arc::new(dev);
                tracing::info!("WebGPU backend initialised on: {}", dev.adapter_name);
                let memory = WebGpuMemoryManager::new(Arc::clone(&dev));
                self.device = Some(dev);
                self.memory = Some(Arc::new(memory));
                self.initialized = true;
                Ok(())
            }
            Err(e) => Err(BackendError::from(e)),
        }
    }

    fn is_initialized(&self) -> bool {
        self.initialized
    }

    fn gemm(
        &self,
        trans_a: BackendTranspose,
        trans_b: BackendTranspose,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        _lda: usize,
        b_ptr: u64,
        _ldb: usize,
        beta: f64,
        c_ptr: u64,
        _ldc: usize,
    ) -> BackendResult<()> {
        self.check_init()?;
        if m == 0 || n == 0 || k == 0 {
            return Ok(());
        }

        if trans_a != BackendTranspose::NoTrans || trans_b != BackendTranspose::NoTrans {
            return Err(BackendError::Unsupported(
                "WebGPU GEMM does not yet support transposed inputs".into(),
            ));
        }

        let dev = self.device()?;
        let mem = self.memory()?;

        let tile_size: u32 = 8;
        let wgsl = shader::gemm_wgsl(tile_size);

        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-gemm"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });

        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-gemm"),
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        let bgl = pipeline.get_bind_group_layout(0);

        // Uniform layout: m, n, k as u32 followed by alpha and beta as f32.
        let mut params_bytes = [0u8; 20];
        params_bytes[0..4].copy_from_slice(&(m as u32).to_le_bytes());
        params_bytes[4..8].copy_from_slice(&(n as u32).to_le_bytes());
        params_bytes[8..12].copy_from_slice(&(k as u32).to_le_bytes());
        params_bytes[12..16].copy_from_slice(&(alpha as f32).to_le_bytes());
        params_bytes[16..20].copy_from_slice(&(beta as f32).to_le_bytes());

        let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-gemm-params"),
            size: 20,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&uniform_buf, 0, &params_bytes);

        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let a_info = buffers
                .get(&a_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
            let b_info = buffers
                .get(&b_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
            let c_info = buffers
                .get(&c_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {c_ptr}")))?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-gemm"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: a_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: b_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: c_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 3,
                        resource: uniform_buf.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-gemm"),
            });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-gemm"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            let wg_x = (n as u32).div_ceil(tile_size);
            let wg_y = (m as u32).div_ceil(tile_size);
            pass.dispatch_workgroups(wg_x, wg_y, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }

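    /// Strided batched GEMM: for each `batch` in `0..batch_count`, computes
    /// `C[batch] = alpha * A[batch] * B[batch] + beta * C[batch]`, where the
    /// per-batch matrices start at offsets `batch * stride_a`,
    /// `batch * stride_b`, and `batch * stride_c` inside their buffers (the
    /// strides are forwarded to the shader as given). Batches map onto the
    /// `z` dimension of the dispatch grid.
    ///
    /// A hedged call sketch:
    ///
    /// ```ignore
    /// // 8 independent 4x4 multiplies packed contiguously (16 elements apart):
    /// backend.batched_gemm(
    ///     BackendTranspose::NoTrans, BackendTranspose::NoTrans,
    ///     4, 4, 4, 1.0,
    ///     a, 4, 16,
    ///     b, 4, 16,
    ///     0.0,
    ///     c, 4, 16,
    ///     8,
    /// )?;
    /// ```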
    #[allow(clippy::too_many_arguments)]
    fn batched_gemm(
        &self,
        trans_a: BackendTranspose,
        trans_b: BackendTranspose,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        _lda: usize,
        stride_a: usize,
        b_ptr: u64,
        _ldb: usize,
        stride_b: usize,
        beta: f64,
        c_ptr: u64,
        _ldc: usize,
        stride_c: usize,
        batch_count: usize,
    ) -> BackendResult<()> {
        self.check_init()?;

        if batch_count == 0 || m == 0 || n == 0 || k == 0 {
            return Ok(());
        }

        if trans_a != BackendTranspose::NoTrans || trans_b != BackendTranspose::NoTrans {
            return Err(BackendError::Unsupported(
                "WebGPU batched GEMM does not yet support transposed inputs".into(),
            ));
        }

        let dev = self.device()?;
        let mem = self.memory()?;

        let tile_size: u32 = 8;
        let wgsl = shader::batched_gemm_wgsl(tile_size);

        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-batched-gemm"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });

        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-batched-gemm"),
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        let bgl = pipeline.get_bind_group_layout(0);

        // 36 bytes of parameters, zero-padded to 48 (a multiple of 16) to
        // match the uniform block's layout.
        let mut params_bytes = [0u8; 48];
        params_bytes[0..4].copy_from_slice(&(m as u32).to_le_bytes());
        params_bytes[4..8].copy_from_slice(&(n as u32).to_le_bytes());
        params_bytes[8..12].copy_from_slice(&(k as u32).to_le_bytes());
        params_bytes[12..16].copy_from_slice(&(alpha as f32).to_le_bytes());
        params_bytes[16..20].copy_from_slice(&(beta as f32).to_le_bytes());
        params_bytes[20..24].copy_from_slice(&(batch_count as u32).to_le_bytes());
        params_bytes[24..28].copy_from_slice(&(stride_a as u32).to_le_bytes());
        params_bytes[28..32].copy_from_slice(&(stride_b as u32).to_le_bytes());
        params_bytes[32..36].copy_from_slice(&(stride_c as u32).to_le_bytes());
        let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-batched-gemm-params"),
            size: 48,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&uniform_buf, 0, &params_bytes);

        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let a_info = buffers
                .get(&a_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
            let b_info = buffers
                .get(&b_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
            let c_info = buffers
                .get(&c_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {c_ptr}")))?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-batched-gemm"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: a_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: b_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: c_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 3,
                        resource: uniform_buf.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-batched-gemm"),
            });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-batched-gemm"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            let wg_x = (n as u32).div_ceil(tile_size);
            let wg_y = (m as u32).div_ceil(tile_size);
            pass.dispatch_workgroups(wg_x, wg_y, batch_count as u32);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }

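    /// Naive NCHW convolution. Output spatial dims follow the usual
    /// `Oh = (H + 2*ph - Fh) / sh + 1` relation (and likewise for width);
    /// the caller supplies `output_shape`, which this routine trusts.
    ///
    /// A small worked example of the shape arithmetic:
    ///
    /// ```
    /// let (h, fh, sh, ph) = (5usize, 3usize, 1usize, 1usize);
    /// let oh = (h + 2 * ph - fh) / sh + 1;
    /// assert_eq!(oh, 5); // "same" padding at stride 1
    /// ```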
    fn conv2d_forward(
        &self,
        input_ptr: u64,
        input_shape: &[usize],
        filter_ptr: u64,
        filter_shape: &[usize],
        output_ptr: u64,
        output_shape: &[usize],
        stride: &[usize],
        padding: &[usize],
    ) -> BackendResult<()> {
        self.check_init()?;

        if input_shape.len() != 4 {
            return Err(BackendError::InvalidArgument(
                "input_shape must have 4 elements (NCHW)".into(),
            ));
        }
        if filter_shape.len() != 4 {
            return Err(BackendError::InvalidArgument(
                "filter_shape must have 4 elements (K, C, Fh, Fw)".into(),
            ));
        }
        if output_shape.len() != 4 {
            return Err(BackendError::InvalidArgument(
                "output_shape must have 4 elements (N, K, Oh, Ow)".into(),
            ));
        }
        if stride.len() != 2 {
            return Err(BackendError::InvalidArgument(
                "stride must have 2 elements [sh, sw]".into(),
            ));
        }
        if padding.len() != 2 {
            return Err(BackendError::InvalidArgument(
                "padding must have 2 elements [ph, pw]".into(),
            ));
        }

        let mem = self.memory()?;

        let batch = input_shape[0];
        let c_in = input_shape[1];
        let h_in = input_shape[2];
        let w_in = input_shape[3];
        let k_out = filter_shape[0];
        let fh = filter_shape[2];
        let fw = filter_shape[3];
        let oh = output_shape[2];
        let ow = output_shape[3];
        let sh = stride[0];
        let sw = stride[1];
        let ph = padding[0];
        let pw = padding[1];

        let in_elems: usize = input_shape.iter().product();
        let f_elems: usize = filter_shape.iter().product();
        let o_elems: usize = output_shape.iter().product();

        // Host-side reference path: read both tensors back, convolve on the
        // CPU with naive loops, then write the result to the output buffer.
        let mut in_bytes = vec![0u8; in_elems * 4];
        let mut f_bytes = vec![0u8; f_elems * 4];
        mem.copy_from_device(&mut in_bytes, input_ptr)
            .map_err(BackendError::from)?;
        mem.copy_from_device(&mut f_bytes, filter_ptr)
            .map_err(BackendError::from)?;

        let in_f32 = bytes_to_f32_vec(&in_bytes);
        let f_f32 = bytes_to_f32_vec(&f_bytes);
        let mut out_f32 = vec![0.0f32; o_elems];

        for b in 0..batch {
            for kf in 0..k_out {
                for oy in 0..oh {
                    for ox in 0..ow {
                        let mut acc = 0.0f32;
                        for ci in 0..c_in {
                            for fy in 0..fh {
                                for fx in 0..fw {
                                    // Padding can push a tap off the input; skip those.
                                    let iy = (oy * sh + fy) as isize - ph as isize;
                                    let ix = (ox * sw + fx) as isize - pw as isize;
                                    if iy >= 0
                                        && (iy as usize) < h_in
                                        && ix >= 0
                                        && (ix as usize) < w_in
                                    {
                                        let in_idx = ((b * c_in + ci) * h_in + iy as usize) * w_in
                                            + ix as usize;
                                        let f_idx = ((kf * c_in + ci) * fh + fy) * fw + fx;
                                        acc += in_f32[in_idx] * f_f32[f_idx];
                                    }
                                }
                            }
                        }
                        out_f32[((b * k_out + kf) * oh + oy) * ow + ox] = acc;
                    }
                }
            }
        }

        let out_bytes = f32_slice_to_bytes(&out_f32);
        mem.copy_to_device(output_ptr, &out_bytes)
            .map_err(BackendError::from)?;

        Ok(())
    }

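    /// Scaled dot-product attention computed on the host as a reference path:
    /// `O = softmax(scale * Q K^T) V` per `(batch, head)` slice, using the
    /// numerically stable softmax `exp(s - max(s)) / sum(exp(s - max(s)))`
    /// computed in two passes. With `causal` set, query `sq` only attends to
    /// keys `0..=sq`, mirrored by this check (extracted from the body below):
    ///
    /// ```
    /// let (sq, seq_kv) = (2usize, 5usize);
    /// let kv_limit = (sq + 1).min(seq_kv);
    /// assert_eq!(kv_limit, 3); // keys 0, 1, 2 are visible
    /// ```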
    fn attention(
        &self,
        q_ptr: u64,
        k_ptr: u64,
        v_ptr: u64,
        o_ptr: u64,
        batch: usize,
        heads: usize,
        seq_q: usize,
        seq_kv: usize,
        head_dim: usize,
        scale: f64,
        causal: bool,
    ) -> BackendResult<()> {
        self.check_init()?;

        if seq_q == 0 || seq_kv == 0 || head_dim == 0 {
            return Err(BackendError::InvalidArgument(
                "seq_q, seq_kv, and head_dim must all be > 0".into(),
            ));
        }
        if scale <= 0.0 || !scale.is_finite() {
            return Err(BackendError::InvalidArgument(format!(
                "scale must be a positive finite number, got {scale}"
            )));
        }

        let mem = self.memory()?;

        let batch_heads = batch * heads;
        let q_elems = batch_heads * seq_q * head_dim;
        let kv_elems = batch_heads * seq_kv * head_dim;
        let o_elems = q_elems;

        // Host-side reference path: read Q, K, V back and compute on the CPU.
        let mut q_bytes = vec![0u8; q_elems * 4];
        let mut k_bytes = vec![0u8; kv_elems * 4];
        let mut v_bytes = vec![0u8; kv_elems * 4];

        mem.copy_from_device(&mut q_bytes, q_ptr)
            .map_err(BackendError::from)?;
        mem.copy_from_device(&mut k_bytes, k_ptr)
            .map_err(BackendError::from)?;
        mem.copy_from_device(&mut v_bytes, v_ptr)
            .map_err(BackendError::from)?;

        let q_f32 = bytes_to_f32_vec(&q_bytes);
        let k_f32 = bytes_to_f32_vec(&k_bytes);
        let v_f32 = bytes_to_f32_vec(&v_bytes);
        let mut o_f32 = vec![0.0f32; o_elems];

        let scale_f32 = scale as f32;

        for bh in 0..batch_heads {
            let q_off = bh * seq_q * head_dim;
            let k_off = bh * seq_kv * head_dim;
            let v_off = k_off;

            for sq in 0..seq_q {
                let kv_limit = if causal { (sq + 1).min(seq_kv) } else { seq_kv };

                // Pass 1: running maximum of the scaled scores, for a
                // numerically stable softmax.
                let mut max_score = f32::NEG_INFINITY;
                for sk in 0..kv_limit {
                    let mut dot = 0.0f32;
                    for dd in 0..head_dim {
                        dot +=
                            q_f32[q_off + sq * head_dim + dd] * k_f32[k_off + sk * head_dim + dd];
                    }
                    let s = dot * scale_f32;
                    if s > max_score {
                        max_score = s;
                    }
                }

                // Pass 2: accumulate exp-weights and the weighted value sum.
                let mut sum_exp = 0.0f32;
                let mut acc = vec![0.0f32; head_dim];
                for sk in 0..kv_limit {
                    let mut dot = 0.0f32;
                    for dd in 0..head_dim {
                        dot +=
                            q_f32[q_off + sq * head_dim + dd] * k_f32[k_off + sk * head_dim + dd];
                    }
                    let w = (dot * scale_f32 - max_score).exp();
                    sum_exp += w;
                    for dd in 0..head_dim {
                        acc[dd] += w * v_f32[v_off + sk * head_dim + dd];
                    }
                }

                // Normalise; a fully underflowed row leaves its output zeroed.
                let o_base = q_off + sq * head_dim;
                if sum_exp > 0.0 {
                    for dd in 0..head_dim {
                        o_f32[o_base + dd] = acc[dd] / sum_exp;
                    }
                }
            }
        }

        let o_bytes = f32_slice_to_bytes(&o_f32);
        mem.copy_to_device(o_ptr, &o_bytes)
            .map_err(BackendError::from)?;

        Ok(())
    }

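    /// Full reduction for 1-D shapes, dispatched as two GPU passes: pass 1
    /// reduces 256-element workgroups into per-workgroup partials, and pass 2
    /// folds those partials into a single value. For `Mean`, the reduced
    /// value is divided by the element count on the host after the GPU
    /// passes finish. Shapes with more than one dimension delegate to
    /// `reduce_nd`.
    ///
    /// The workgroup math, mirroring the dispatch below:
    ///
    /// ```
    /// let n_elements = 1000usize;
    /// let wg_count = (n_elements as u32).div_ceil(256);
    /// assert_eq!(wg_count, 4); // 4 partials feed the final pass
    /// ```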
    fn reduce(
        &self,
        op: ReduceOp,
        input_ptr: u64,
        output_ptr: u64,
        shape: &[usize],
        axis: usize,
    ) -> BackendResult<()> {
        self.check_init()?;

        if shape.is_empty() {
            return Err(BackendError::InvalidArgument(
                "shape must not be empty".into(),
            ));
        }
        if axis >= shape.len() {
            return Err(BackendError::InvalidArgument(format!(
                "axis {axis} is out of bounds for shape of length {}",
                shape.len()
            )));
        }

        // Multi-dimensional shapes take the strided N-D kernel path.
        if shape.len() != 1 {
            return self.reduce_nd(op, input_ptr, output_ptr, shape, axis);
        }

        let n_elements = shape[0];
        if n_elements == 0 {
            return Ok(());
        }

        let dev = self.device()?;
        let mem = self.memory()?;
        let op_str = map_reduce_op(op);

        // One 256-thread workgroup per 256 input elements.
        let wg_count = (n_elements as u32).div_ceil(256);

        let pass1_wgsl = shader::reduction_wgsl(op_str);
        let pass1_shader = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-reduce-pass1"),
                source: wgpu::ShaderSource::Wgsl(pass1_wgsl.into()),
            });
        let pass1_pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-reduce-pass1"),
                layout: None,
                module: &pass1_shader,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        // Pass 1 writes one partial result per workgroup into this buffer.
        let partial_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-reduce-partial"),
            size: (wg_count as u64) * 4,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let mut p1_params = [0u8; 4];
        p1_params[0..4].copy_from_slice(&(n_elements as u32).to_le_bytes());
        let p1_uniform = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-reduce-p1-params"),
            size: 4,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&p1_uniform, 0, &p1_params);

        let bgl1 = pass1_pipeline.get_bind_group_layout(0);

        let bg1 = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let in_info = buffers.get(&input_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {input_ptr}"))
            })?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-reduce-pass1"),
                layout: &bgl1,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: in_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: partial_buf.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: p1_uniform.as_entire_binding(),
                    },
                ],
            })
        };

        // Pass 2 folds the per-workgroup partials into the final value.
        let pass2_wgsl = shader::reduction_final_wgsl(op_str);
        let pass2_shader = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-reduce-pass2"),
                source: wgpu::ShaderSource::Wgsl(pass2_wgsl.into()),
            });
        let pass2_pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-reduce-pass2"),
                layout: None,
                module: &pass2_shader,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        let mut p2_params = [0u8; 4];
        p2_params[0..4].copy_from_slice(&wg_count.to_le_bytes());
        let p2_uniform = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-reduce-p2-params"),
            size: 4,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&p2_uniform, 0, &p2_params);

        let bgl2 = pass2_pipeline.get_bind_group_layout(0);

        let bg2 = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
            })?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-reduce-pass2"),
                layout: &bgl2,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: partial_buf.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: out_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: p2_uniform.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-reduce"),
            });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-reduce-pass1"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pass1_pipeline);
            pass.set_bind_group(0, &bg1, &[]);
            pass.dispatch_workgroups(wg_count, 1, 1);
        }
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-reduce-pass2"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pass2_pipeline);
            pass.set_bind_group(0, &bg2, &[]);
            pass.dispatch_workgroups(1, 1, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        // Mean: divide the reduced value by the element count on the host.
        if op == ReduceOp::Mean && n_elements > 1 {
            let mut buf = [0u8; 4];
            mem.copy_from_device(&mut buf, output_ptr)
                .map_err(BackendError::from)?;
            let val = f32::from_le_bytes(buf);
            let mean = val / (n_elements as f32);
            mem.copy_to_device(output_ptr, &mean.to_le_bytes())
                .map_err(BackendError::from)?;
        }

        Ok(())
    }

    fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()> {
        self.check_init()?;
        if n == 0 {
            return Ok(());
        }

        let dev = self.device()?;
        let mem = self.memory()?;

        let op_str = map_unary_op(op);
        let wgsl = shader::elementwise_wgsl(op_str);

        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-unary"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });

        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-unary"),
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        let bgl = pipeline.get_bind_group_layout(0);

        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let in_info = buffers.get(&input_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {input_ptr}"))
            })?;
            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
            })?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-unary"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: in_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: out_info.buffer.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-unary"),
            });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-unary"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            let workgroups = (n as u32).div_ceil(256);
            pass.dispatch_workgroups(workgroups, 1, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }

    fn binary(
        &self,
        op: BinaryOp,
        a_ptr: u64,
        b_ptr: u64,
        output_ptr: u64,
        n: usize,
    ) -> BackendResult<()> {
        self.check_init()?;
        if n == 0 {
            return Ok(());
        }

        let dev = self.device()?;
        let mem = self.memory()?;

        let op_str = map_binary_op(op);
        let wgsl = shader::binary_wgsl(op_str);

        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-binary"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });

        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-binary"),
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        let bgl = pipeline.get_bind_group_layout(0);

        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let a_info = buffers
                .get(&a_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
            let b_info = buffers
                .get(&b_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
            })?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-binary"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: a_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: b_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: out_info.buffer.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-binary"),
            });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-binary"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            let workgroups = (n as u32).div_ceil(256);
            pass.dispatch_workgroups(workgroups, 1, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }

    fn synchronize(&self) -> BackendResult<()> {
        self.check_init()?;
        if let Some(dev) = &self.device {
            let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
        }
        Ok(())
    }

    fn alloc(&self, bytes: usize) -> BackendResult<u64> {
        self.check_init()?;
        if bytes == 0 {
            return Err(BackendError::InvalidArgument(
                "cannot allocate 0 bytes".into(),
            ));
        }
        self.memory()?.alloc(bytes).map_err(BackendError::from)
    }

    fn free(&self, ptr: u64) -> BackendResult<()> {
        self.check_init()?;
        self.memory()?.free(ptr).map_err(BackendError::from)
    }

    fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()> {
        self.check_init()?;
        if src.is_empty() {
            return Ok(());
        }
        self.memory()?
            .copy_to_device(dst, src)
            .map_err(BackendError::from)
    }

    fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()> {
        self.check_init()?;
        if dst.is_empty() {
            return Ok(());
        }
        self.memory()?
            .copy_from_device(dst, src)
            .map_err(BackendError::from)
    }
}

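/// Reinterprets a little-endian byte buffer as `f32` values. Any trailing
/// bytes that do not form a complete 4-byte chunk are silently dropped by
/// `chunks_exact`.
///
/// ```
/// # fn bytes_to_f32_vec(bytes: &[u8]) -> Vec<f32> {
/// #     bytes
/// #         .chunks_exact(4)
/// #         .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
/// #         .collect()
/// # }
/// assert_eq!(bytes_to_f32_vec(&1.5f32.to_le_bytes()), vec![1.5]);
/// ```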
fn bytes_to_f32_vec(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}

/// Serialises `f32` values to little-endian bytes; the inverse of
/// `bytes_to_f32_vec` for well-formed input.
fn f32_slice_to_bytes(data: &[f32]) -> Vec<u8> {
    data.iter().flat_map(|v| v.to_le_bytes()).collect()
}

#[cfg(test)]
#[path = "backend_tests.rs"]
mod tests;