1use std::sync::Arc;
7
8use oxicuda_backend::{
9 BackendError, BackendResult, BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp,
10};
11use wgpu;
12
13use crate::{device::WebGpuDevice, memory::WebGpuMemoryManager, shader};
14
15fn map_unary_op(op: UnaryOp) -> &'static str {
18 match op {
19 UnaryOp::Relu => "relu",
20 UnaryOp::Sigmoid => "sigmoid",
21 UnaryOp::Tanh => "tanh",
22 UnaryOp::Exp => "exp",
23 UnaryOp::Log => "log",
24 UnaryOp::Sqrt => "sqrt",
25 UnaryOp::Abs => "abs",
26 UnaryOp::Neg => "neg",
27 }
28}
29
30fn map_binary_op(op: BinaryOp) -> &'static str {
31 match op {
32 BinaryOp::Add => "add",
33 BinaryOp::Sub => "sub",
34 BinaryOp::Mul => "mul",
35 BinaryOp::Div => "div",
36 BinaryOp::Max => "max",
37 BinaryOp::Min => "min",
38 }
39}
40
41fn map_reduce_op(op: ReduceOp) -> &'static str {
42 match op {
43 ReduceOp::Sum => "sum",
44 ReduceOp::Max => "max",
45 ReduceOp::Min => "min",
46 ReduceOp::Mean => "mean",
47 }
48}
49
/// Compute backend that executes operations through `wgpu`.
///
/// A freshly constructed value holds no device. `init` fills in `device` and
/// `memory` and sets `initialized`; until then every operation reports
/// `BackendError::NotInitialized`.
#[derive(Debug)]
pub struct WebGpuBackend {
    /// Shared device wrapper; `None` until `init` succeeds.
    device: Option<Arc<WebGpuDevice>>,
    /// Buffer manager that resolves `u64` handles to GPU buffers; `None`
    /// until `init` succeeds.
    memory: Option<Arc<WebGpuMemoryManager>>,
    /// `true` once `init` has completed successfully.
    initialized: bool,
}
66
67impl WebGpuBackend {
68 pub fn new() -> Self {
70 Self {
71 device: None,
72 memory: None,
73 initialized: false,
74 }
75 }
76
77 fn check_init(&self) -> BackendResult<()> {
79 if self.initialized {
80 Ok(())
81 } else {
82 Err(BackendError::NotInitialized)
83 }
84 }
85
86 fn memory(&self) -> BackendResult<&Arc<WebGpuMemoryManager>> {
88 self.memory.as_ref().ok_or(BackendError::NotInitialized)
89 }
90
91 fn device(&self) -> BackendResult<&Arc<WebGpuDevice>> {
93 self.device.as_ref().ok_or(BackendError::NotInitialized)
94 }
95
96 fn reduce_nd(
110 &self,
111 op: ReduceOp,
112 input_ptr: u64,
113 output_ptr: u64,
114 shape: &[usize],
115 axis: usize,
116 ) -> BackendResult<()> {
117 debug_assert!(!shape.is_empty());
121 debug_assert!(axis < shape.len());
122
123 let outer: usize = shape[..axis].iter().product();
125 let dk: usize = shape[axis];
126 let inner: usize = shape[axis + 1..].iter().product();
127
128 if outer == 0 || dk == 0 || inner == 0 {
130 return Ok(());
131 }
132
133 let total = outer.checked_mul(inner).ok_or_else(|| {
134 BackendError::InvalidArgument("reduce: outer * inner overflows usize".into())
135 })?;
136
137 let inner_stride: usize = 1;
139 let dk_stride: usize = inner;
140 let outer_stride: usize = dk
141 .checked_mul(inner)
142 .ok_or_else(|| BackendError::InvalidArgument("reduce: dk * inner overflows".into()))?;
143
144 const MAX_GRID_DIM: u32 = 32_768;
147 let total_u32: u32 = total.try_into().map_err(|_| {
148 BackendError::InvalidArgument(format!(
149 "reduce: output element count {total} exceeds u32 range"
150 ))
151 })?;
152 let grid_x: u32 = total_u32.clamp(1, MAX_GRID_DIM);
153 let grid_y: u32 = total_u32.div_ceil(grid_x);
154
155 let dev = self.device()?;
156 let mem = self.memory()?;
157 let op_str = map_reduce_op(op);
158
159 let wgsl = shader::reduction_nd_wgsl(op_str);
160 let shader_mod = dev
161 .device
162 .create_shader_module(wgpu::ShaderModuleDescriptor {
163 label: Some("oxicuda-reduce-nd"),
164 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
165 });
166 let pipeline = dev
167 .device
168 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
169 label: Some("oxicuda-reduce-nd"),
170 layout: None,
171 module: &shader_mod,
172 entry_point: Some("main"),
173 compilation_options: Default::default(),
174 cache: None,
175 });
176
177 let mut params_bytes = [0u8; 32];
179 let outer_u32: u32 = outer
180 .try_into()
181 .map_err(|_| BackendError::InvalidArgument("reduce: outer exceeds u32 range".into()))?;
182 let dk_u32: u32 = dk
183 .try_into()
184 .map_err(|_| BackendError::InvalidArgument("reduce: dk exceeds u32 range".into()))?;
185 let inner_u32: u32 = inner
186 .try_into()
187 .map_err(|_| BackendError::InvalidArgument("reduce: inner exceeds u32 range".into()))?;
188 let outer_stride_u32: u32 = outer_stride.try_into().map_err(|_| {
189 BackendError::InvalidArgument("reduce: outer_stride exceeds u32 range".into())
190 })?;
191 let dk_stride_u32: u32 = dk_stride.try_into().map_err(|_| {
192 BackendError::InvalidArgument("reduce: dk_stride exceeds u32 range".into())
193 })?;
194 let inner_stride_u32: u32 = inner_stride.try_into().map_err(|_| {
195 BackendError::InvalidArgument("reduce: inner_stride exceeds u32 range".into())
196 })?;
197 params_bytes[0..4].copy_from_slice(&outer_u32.to_le_bytes());
198 params_bytes[4..8].copy_from_slice(&dk_u32.to_le_bytes());
199 params_bytes[8..12].copy_from_slice(&inner_u32.to_le_bytes());
200 params_bytes[12..16].copy_from_slice(&outer_stride_u32.to_le_bytes());
201 params_bytes[16..20].copy_from_slice(&dk_stride_u32.to_le_bytes());
202 params_bytes[20..24].copy_from_slice(&inner_stride_u32.to_le_bytes());
203 params_bytes[24..28].copy_from_slice(&grid_x.to_le_bytes());
204 let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
207 label: Some("oxicuda-reduce-nd-params"),
208 size: 32,
209 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
210 mapped_at_creation: false,
211 });
212 dev.queue.write_buffer(&uniform_buf, 0, ¶ms_bytes);
213
214 let bgl = pipeline.get_bind_group_layout(0);
215 let bind_group = {
216 let buffers = mem
217 .lock_buffers()
218 .map_err(|e| BackendError::DeviceError(e.to_string()))?;
219 let in_info = buffers.get(&input_ptr).ok_or_else(|| {
220 BackendError::InvalidArgument(format!("unknown handle {input_ptr}"))
221 })?;
222 let out_info = buffers.get(&output_ptr).ok_or_else(|| {
223 BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
224 })?;
225
226 dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
227 label: Some("oxicuda-reduce-nd"),
228 layout: &bgl,
229 entries: &[
230 wgpu::BindGroupEntry {
231 binding: 0,
232 resource: in_info.buffer.as_entire_binding(),
233 },
234 wgpu::BindGroupEntry {
235 binding: 1,
236 resource: out_info.buffer.as_entire_binding(),
237 },
238 wgpu::BindGroupEntry {
239 binding: 2,
240 resource: uniform_buf.as_entire_binding(),
241 },
242 ],
243 })
244 };
245
246 let mut encoder = dev
247 .device
248 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
249 label: Some("oxicuda-reduce-nd"),
250 });
251 {
252 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
253 label: Some("oxicuda-reduce-nd"),
254 timestamp_writes: None,
255 });
256 pass.set_pipeline(&pipeline);
257 pass.set_bind_group(0, &bind_group, &[]);
258 pass.dispatch_workgroups(grid_x, grid_y, 1);
259 }
260
261 dev.queue.submit(std::iter::once(encoder.finish()));
262 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
263
264 Ok(())
265 }
266}
267
268impl WebGpuBackend {
269 #[allow(clippy::too_many_arguments)]
277 pub fn gemm_f16(
278 &self,
279 m: usize,
280 n: usize,
281 k: usize,
282 alpha: f64,
283 a_ptr: u64,
284 b_ptr: u64,
285 beta: f64,
286 c_ptr: u64,
287 ) -> BackendResult<()> {
288 self.check_init()?;
289 if m == 0 || n == 0 || k == 0 {
290 return Ok(());
291 }
292
293 let dev = self.device()?;
294 let mem = self.memory()?;
295
296 let tile_size: u32 = 8;
297 let wgsl = shader::gemm_wgsl_f16(tile_size);
298
299 let shader_mod = dev
300 .device
301 .create_shader_module(wgpu::ShaderModuleDescriptor {
302 label: Some("oxicuda-gemm-f16"),
303 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
304 });
305
306 let pipeline = dev
307 .device
308 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
309 label: Some("oxicuda-gemm-f16"),
310 layout: None,
311 module: &shader_mod,
312 entry_point: Some("main"),
313 compilation_options: Default::default(),
314 cache: None,
315 });
316
317 let bgl = pipeline.get_bind_group_layout(0);
318
319 let mut params_bytes = [0u8; 20];
321 params_bytes[0..4].copy_from_slice(&(m as u32).to_le_bytes());
322 params_bytes[4..8].copy_from_slice(&(n as u32).to_le_bytes());
323 params_bytes[8..12].copy_from_slice(&(k as u32).to_le_bytes());
324 params_bytes[12..16].copy_from_slice(&(alpha as f32).to_le_bytes());
325 params_bytes[16..20].copy_from_slice(&(beta as f32).to_le_bytes());
326
327 let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
328 label: Some("oxicuda-gemm-f16-params"),
329 size: 20,
330 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
331 mapped_at_creation: false,
332 });
333 dev.queue.write_buffer(&uniform_buf, 0, ¶ms_bytes);
334
335 let bind_group = {
336 let buffers = mem
337 .lock_buffers()
338 .map_err(|e| BackendError::DeviceError(e.to_string()))?;
339 let a_info = buffers
340 .get(&a_ptr)
341 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
342 let b_info = buffers
343 .get(&b_ptr)
344 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
345 let c_info = buffers
346 .get(&c_ptr)
347 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {c_ptr}")))?;
348
349 dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
350 label: Some("oxicuda-gemm-f16"),
351 layout: &bgl,
352 entries: &[
353 wgpu::BindGroupEntry {
354 binding: 0,
355 resource: a_info.buffer.as_entire_binding(),
356 },
357 wgpu::BindGroupEntry {
358 binding: 1,
359 resource: b_info.buffer.as_entire_binding(),
360 },
361 wgpu::BindGroupEntry {
362 binding: 2,
363 resource: c_info.buffer.as_entire_binding(),
364 },
365 wgpu::BindGroupEntry {
366 binding: 3,
367 resource: uniform_buf.as_entire_binding(),
368 },
369 ],
370 })
371 };
372
373 let mut encoder = dev
374 .device
375 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
376 label: Some("oxicuda-gemm-f16"),
377 });
378
379 {
380 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
381 label: Some("oxicuda-gemm-f16"),
382 timestamp_writes: None,
383 });
384 pass.set_pipeline(&pipeline);
385 pass.set_bind_group(0, &bind_group, &[]);
386 let wg_x = (n as u32).div_ceil(tile_size);
387 let wg_y = (m as u32).div_ceil(tile_size);
388 pass.dispatch_workgroups(wg_x, wg_y, 1);
389 }
390
391 dev.queue.submit(std::iter::once(encoder.finish()));
392 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
393
394 Ok(())
395 }
396}
397
398impl Default for WebGpuBackend {
399 fn default() -> Self {
400 Self::new()
401 }
402}
403
404impl ComputeBackend for WebGpuBackend {
407 fn name(&self) -> &str {
408 "webgpu"
409 }
410
411 fn init(&mut self) -> BackendResult<()> {
412 if self.initialized {
413 return Ok(());
414 }
415
416 match WebGpuDevice::new() {
417 Ok(dev) => {
418 let dev = Arc::new(dev);
419 tracing::info!("WebGPU backend initialised on: {}", dev.adapter_name);
420 let memory = WebGpuMemoryManager::new(Arc::clone(&dev));
421 self.device = Some(dev);
422 self.memory = Some(Arc::new(memory));
423 self.initialized = true;
424 Ok(())
425 }
426 Err(e) => Err(BackendError::from(e)),
427 }
428 }
429
430 fn is_initialized(&self) -> bool {
431 self.initialized
432 }
433
434 fn gemm(
437 &self,
438 trans_a: BackendTranspose,
439 trans_b: BackendTranspose,
440 m: usize,
441 n: usize,
442 k: usize,
443 alpha: f64,
444 a_ptr: u64,
445 _lda: usize,
446 b_ptr: u64,
447 _ldb: usize,
448 beta: f64,
449 c_ptr: u64,
450 _ldc: usize,
451 ) -> BackendResult<()> {
452 self.check_init()?;
453 if m == 0 || n == 0 || k == 0 {
455 return Ok(());
456 }
457
458 let trans_a_flag: u32 = u32::from(trans_a != BackendTranspose::NoTrans);
462 let trans_b_flag: u32 = u32::from(trans_b != BackendTranspose::NoTrans);
463
464 let dev = self.device()?;
465 let mem = self.memory()?;
466
467 let tile_size: u32 = 8;
468 let wgsl = shader::gemm_wgsl(tile_size);
469
470 let shader_mod = dev
471 .device
472 .create_shader_module(wgpu::ShaderModuleDescriptor {
473 label: Some("oxicuda-gemm"),
474 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
475 });
476
477 let pipeline = dev
478 .device
479 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
480 label: Some("oxicuda-gemm"),
481 layout: None,
482 module: &shader_mod,
483 entry_point: Some("main"),
484 compilation_options: Default::default(),
485 cache: None,
486 });
487
488 let bgl = pipeline.get_bind_group_layout(0);
489
490 let mut params_bytes = [0u8; 32];
493 params_bytes[0..4].copy_from_slice(&(m as u32).to_le_bytes());
494 params_bytes[4..8].copy_from_slice(&(n as u32).to_le_bytes());
495 params_bytes[8..12].copy_from_slice(&(k as u32).to_le_bytes());
496 params_bytes[12..16].copy_from_slice(&(alpha as f32).to_le_bytes());
497 params_bytes[16..20].copy_from_slice(&(beta as f32).to_le_bytes());
498 params_bytes[20..24].copy_from_slice(&trans_a_flag.to_le_bytes());
499 params_bytes[24..28].copy_from_slice(&trans_b_flag.to_le_bytes());
500 let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
503 label: Some("oxicuda-gemm-params"),
504 size: 32,
505 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
506 mapped_at_creation: false,
507 });
508 dev.queue.write_buffer(&uniform_buf, 0, ¶ms_bytes);
509
510 let bind_group = {
512 let buffers = mem
513 .lock_buffers()
514 .map_err(|e| BackendError::DeviceError(e.to_string()))?;
515 let a_info = buffers
516 .get(&a_ptr)
517 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
518 let b_info = buffers
519 .get(&b_ptr)
520 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
521 let c_info = buffers
522 .get(&c_ptr)
523 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {c_ptr}")))?;
524
525 dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
526 label: Some("oxicuda-gemm"),
527 layout: &bgl,
528 entries: &[
529 wgpu::BindGroupEntry {
530 binding: 0,
531 resource: a_info.buffer.as_entire_binding(),
532 },
533 wgpu::BindGroupEntry {
534 binding: 1,
535 resource: b_info.buffer.as_entire_binding(),
536 },
537 wgpu::BindGroupEntry {
538 binding: 2,
539 resource: c_info.buffer.as_entire_binding(),
540 },
541 wgpu::BindGroupEntry {
542 binding: 3,
543 resource: uniform_buf.as_entire_binding(),
544 },
545 ],
546 })
547 };
548
549 let mut encoder = dev
550 .device
551 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
552 label: Some("oxicuda-gemm"),
553 });
554
555 {
556 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
557 label: Some("oxicuda-gemm"),
558 timestamp_writes: None,
559 });
560 pass.set_pipeline(&pipeline);
561 pass.set_bind_group(0, &bind_group, &[]);
562 let wg_x = (n as u32).div_ceil(tile_size);
563 let wg_y = (m as u32).div_ceil(tile_size);
564 pass.dispatch_workgroups(wg_x, wg_y, 1);
565 }
566
567 dev.queue.submit(std::iter::once(encoder.finish()));
568 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
569
570 Ok(())
571 }
572
573 #[allow(clippy::too_many_arguments)]
574 fn batched_gemm(
575 &self,
576 trans_a: BackendTranspose,
577 trans_b: BackendTranspose,
578 m: usize,
579 n: usize,
580 k: usize,
581 alpha: f64,
582 a_ptr: u64,
583 _lda: usize,
584 stride_a: usize,
585 b_ptr: u64,
586 _ldb: usize,
587 stride_b: usize,
588 beta: f64,
589 c_ptr: u64,
590 _ldc: usize,
591 stride_c: usize,
592 batch_count: usize,
593 ) -> BackendResult<()> {
594 self.check_init()?;
595
596 if batch_count == 0 || m == 0 || n == 0 || k == 0 {
597 return Ok(());
598 }
599
600 let trans_a_flag: u32 = u32::from(trans_a != BackendTranspose::NoTrans);
604 let trans_b_flag: u32 = u32::from(trans_b != BackendTranspose::NoTrans);
605
606 let dev = self.device()?;
607 let mem = self.memory()?;
608
609 let tile_size: u32 = 8;
610 let wgsl = shader::batched_gemm_wgsl(tile_size);
611
612 let shader_mod = dev
613 .device
614 .create_shader_module(wgpu::ShaderModuleDescriptor {
615 label: Some("oxicuda-batched-gemm"),
616 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
617 });
618
619 let pipeline = dev
620 .device
621 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
622 label: Some("oxicuda-batched-gemm"),
623 layout: None,
624 module: &shader_mod,
625 entry_point: Some("main"),
626 compilation_options: Default::default(),
627 cache: None,
628 });
629
630 let bgl = pipeline.get_bind_group_layout(0);
631
632 let mut params_bytes = [0u8; 48];
636 params_bytes[0..4].copy_from_slice(&(m as u32).to_le_bytes());
637 params_bytes[4..8].copy_from_slice(&(n as u32).to_le_bytes());
638 params_bytes[8..12].copy_from_slice(&(k as u32).to_le_bytes());
639 params_bytes[12..16].copy_from_slice(&(alpha as f32).to_le_bytes());
640 params_bytes[16..20].copy_from_slice(&(beta as f32).to_le_bytes());
641 params_bytes[20..24].copy_from_slice(&(batch_count as u32).to_le_bytes());
642 params_bytes[24..28].copy_from_slice(&(stride_a as u32).to_le_bytes());
643 params_bytes[28..32].copy_from_slice(&(stride_b as u32).to_le_bytes());
644 params_bytes[32..36].copy_from_slice(&(stride_c as u32).to_le_bytes());
645 params_bytes[36..40].copy_from_slice(&trans_a_flag.to_le_bytes());
646 params_bytes[40..44].copy_from_slice(&trans_b_flag.to_le_bytes());
647 let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
650 label: Some("oxicuda-batched-gemm-params"),
651 size: 48,
652 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
653 mapped_at_creation: false,
654 });
655 dev.queue.write_buffer(&uniform_buf, 0, ¶ms_bytes);
656
657 let bind_group = {
658 let buffers = mem
659 .lock_buffers()
660 .map_err(|e| BackendError::DeviceError(e.to_string()))?;
661 let a_info = buffers
662 .get(&a_ptr)
663 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
664 let b_info = buffers
665 .get(&b_ptr)
666 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
667 let c_info = buffers
668 .get(&c_ptr)
669 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {c_ptr}")))?;
670
671 dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
672 label: Some("oxicuda-batched-gemm"),
673 layout: &bgl,
674 entries: &[
675 wgpu::BindGroupEntry {
676 binding: 0,
677 resource: a_info.buffer.as_entire_binding(),
678 },
679 wgpu::BindGroupEntry {
680 binding: 1,
681 resource: b_info.buffer.as_entire_binding(),
682 },
683 wgpu::BindGroupEntry {
684 binding: 2,
685 resource: c_info.buffer.as_entire_binding(),
686 },
687 wgpu::BindGroupEntry {
688 binding: 3,
689 resource: uniform_buf.as_entire_binding(),
690 },
691 ],
692 })
693 };
694
695 let mut encoder = dev
696 .device
697 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
698 label: Some("oxicuda-batched-gemm"),
699 });
700
701 {
702 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
703 label: Some("oxicuda-batched-gemm"),
704 timestamp_writes: None,
705 });
706 pass.set_pipeline(&pipeline);
707 pass.set_bind_group(0, &bind_group, &[]);
708 let wg_x = (n as u32).div_ceil(tile_size);
709 let wg_y = (m as u32).div_ceil(tile_size);
710 pass.dispatch_workgroups(wg_x, wg_y, batch_count as u32);
711 }
712
713 dev.queue.submit(std::iter::once(encoder.finish()));
714 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
715
716 Ok(())
717 }
718
719 fn conv2d_forward(
720 &self,
721 input_ptr: u64,
722 input_shape: &[usize],
723 filter_ptr: u64,
724 filter_shape: &[usize],
725 output_ptr: u64,
726 output_shape: &[usize],
727 stride: &[usize],
728 padding: &[usize],
729 ) -> BackendResult<()> {
730 self.check_init()?;
731
732 if input_shape.len() != 4 {
733 return Err(BackendError::InvalidArgument(
734 "input_shape must have 4 elements (NCHW)".into(),
735 ));
736 }
737 if filter_shape.len() != 4 {
738 return Err(BackendError::InvalidArgument(
739 "filter_shape must have 4 elements (KCFHFW)".into(),
740 ));
741 }
742 if output_shape.len() != 4 {
743 return Err(BackendError::InvalidArgument(
744 "output_shape must have 4 elements (NKOhOw)".into(),
745 ));
746 }
747 if stride.len() != 2 {
748 return Err(BackendError::InvalidArgument(
749 "stride must have 2 elements [sh, sw]".into(),
750 ));
751 }
752 if padding.len() != 2 {
753 return Err(BackendError::InvalidArgument(
754 "padding must have 2 elements [ph, pw]".into(),
755 ));
756 }
757
758 let mem = self.memory()?;
759
760 let batch = input_shape[0];
761 let c_in = input_shape[1];
762 let h_in = input_shape[2];
763 let w_in = input_shape[3];
764 let k_out = filter_shape[0];
765 let fh = filter_shape[2];
766 let fw = filter_shape[3];
767 let oh = output_shape[2];
768 let ow = output_shape[3];
769 let sh = stride[0];
770 let sw = stride[1];
771 let ph = padding[0];
772 let pw = padding[1];
773
774 let in_elems: usize = input_shape.iter().product();
775 let f_elems: usize = filter_shape.iter().product();
776 let o_elems: usize = output_shape.iter().product();
777
778 let mut in_bytes = vec![0u8; in_elems * 4];
780 let mut f_bytes = vec![0u8; f_elems * 4];
781 mem.copy_from_device(&mut in_bytes, input_ptr)
782 .map_err(BackendError::from)?;
783 mem.copy_from_device(&mut f_bytes, filter_ptr)
784 .map_err(BackendError::from)?;
785
786 let in_f32 = bytes_to_f32_vec(&in_bytes);
787 let f_f32 = bytes_to_f32_vec(&f_bytes);
788 let mut out_f32 = vec![0.0f32; o_elems];
789
790 for b in 0..batch {
791 for kf in 0..k_out {
792 for oy in 0..oh {
793 for ox in 0..ow {
794 let mut acc = 0.0f32;
795 for ci in 0..c_in {
796 for fy in 0..fh {
797 for fx in 0..fw {
798 let iy = (oy * sh + fy) as isize - ph as isize;
799 let ix = (ox * sw + fx) as isize - pw as isize;
800 if iy >= 0
801 && (iy as usize) < h_in
802 && ix >= 0
803 && (ix as usize) < w_in
804 {
805 let in_idx = ((b * c_in + ci) * h_in + iy as usize) * w_in
806 + ix as usize;
807 let f_idx = ((kf * c_in + ci) * fh + fy) * fw + fx;
808 acc += in_f32[in_idx] * f_f32[f_idx];
809 }
810 }
811 }
812 }
813 out_f32[((b * k_out + kf) * oh + oy) * ow + ox] = acc;
814 }
815 }
816 }
817 }
818
819 let out_bytes = f32_slice_to_bytes(&out_f32);
820 mem.copy_to_device(output_ptr, &out_bytes)
821 .map_err(BackendError::from)?;
822
823 Ok(())
824 }
825
826 fn attention(
827 &self,
828 q_ptr: u64,
829 k_ptr: u64,
830 v_ptr: u64,
831 o_ptr: u64,
832 batch: usize,
833 heads: usize,
834 seq_q: usize,
835 seq_kv: usize,
836 head_dim: usize,
837 scale: f64,
838 causal: bool,
839 ) -> BackendResult<()> {
840 self.check_init()?;
841
842 if seq_q == 0 || seq_kv == 0 || head_dim == 0 {
843 return Err(BackendError::InvalidArgument(
844 "seq_q, seq_kv, and head_dim must all be > 0".into(),
845 ));
846 }
847 if scale <= 0.0 || !scale.is_finite() {
848 return Err(BackendError::InvalidArgument(format!(
849 "scale must be a positive finite number, got {scale}"
850 )));
851 }
852
853 let mem = self.memory()?;
854
855 let batch_heads = batch * heads;
856 let q_elems = batch_heads * seq_q * head_dim;
857 let kv_elems = batch_heads * seq_kv * head_dim;
858 let o_elems = q_elems;
859
860 let mut q_bytes = vec![0u8; q_elems * 4];
862 let mut k_bytes = vec![0u8; kv_elems * 4];
863 let mut v_bytes = vec![0u8; kv_elems * 4];
864
865 mem.copy_from_device(&mut q_bytes, q_ptr)
866 .map_err(BackendError::from)?;
867 mem.copy_from_device(&mut k_bytes, k_ptr)
868 .map_err(BackendError::from)?;
869 mem.copy_from_device(&mut v_bytes, v_ptr)
870 .map_err(BackendError::from)?;
871
872 let q_f32 = bytes_to_f32_vec(&q_bytes);
873 let k_f32 = bytes_to_f32_vec(&k_bytes);
874 let v_f32 = bytes_to_f32_vec(&v_bytes);
875 let mut o_f32 = vec![0.0f32; o_elems];
876
877 let scale_f32 = scale as f32;
878
879 for bh in 0..batch_heads {
880 let q_off = bh * seq_q * head_dim;
881 let k_off = bh * seq_kv * head_dim;
882 let v_off = k_off;
883
884 for sq in 0..seq_q {
885 let kv_limit = if causal { (sq + 1).min(seq_kv) } else { seq_kv };
886
887 let mut max_score = f32::NEG_INFINITY;
889 for sk in 0..kv_limit {
890 let mut dot = 0.0f32;
891 for dd in 0..head_dim {
892 dot +=
893 q_f32[q_off + sq * head_dim + dd] * k_f32[k_off + sk * head_dim + dd];
894 }
895 let s = dot * scale_f32;
896 if s > max_score {
897 max_score = s;
898 }
899 }
900
901 let mut sum_exp = 0.0f32;
903 let mut acc = vec![0.0f32; head_dim];
904 for sk in 0..kv_limit {
905 let mut dot = 0.0f32;
906 for dd in 0..head_dim {
907 dot +=
908 q_f32[q_off + sq * head_dim + dd] * k_f32[k_off + sk * head_dim + dd];
909 }
910 let w = (dot * scale_f32 - max_score).exp();
911 sum_exp += w;
912 for dd in 0..head_dim {
913 acc[dd] += w * v_f32[v_off + sk * head_dim + dd];
914 }
915 }
916
917 let o_base = q_off + sq * head_dim;
919 if sum_exp > 0.0 {
920 for dd in 0..head_dim {
921 o_f32[o_base + dd] = acc[dd] / sum_exp;
922 }
923 }
924 }
925 }
926
927 let o_bytes = f32_slice_to_bytes(&o_f32);
928 mem.copy_to_device(o_ptr, &o_bytes)
929 .map_err(BackendError::from)?;
930
931 Ok(())
932 }
933
934 fn reduce(
935 &self,
936 op: ReduceOp,
937 input_ptr: u64,
938 output_ptr: u64,
939 shape: &[usize],
940 axis: usize,
941 ) -> BackendResult<()> {
942 self.check_init()?;
943
944 if shape.is_empty() {
945 return Err(BackendError::InvalidArgument(
946 "shape must not be empty".into(),
947 ));
948 }
949 if axis >= shape.len() {
950 return Err(BackendError::InvalidArgument(format!(
951 "axis {axis} is out of bounds for shape of length {}",
952 shape.len()
953 )));
954 }
955
956 if shape.len() != 1 {
960 return self.reduce_nd(op, input_ptr, output_ptr, shape, axis);
961 }
962
963 let n_elements = shape[0];
964 if n_elements == 0 {
965 return Ok(());
966 }
967
968 let dev = self.device()?;
969 let mem = self.memory()?;
970 let op_str = map_reduce_op(op);
971
972 let wg_count = (n_elements as u32).div_ceil(256);
974
975 let pass1_wgsl = shader::reduction_wgsl(op_str);
976 let pass1_shader = dev
977 .device
978 .create_shader_module(wgpu::ShaderModuleDescriptor {
979 label: Some("oxicuda-reduce-pass1"),
980 source: wgpu::ShaderSource::Wgsl(pass1_wgsl.into()),
981 });
982 let pass1_pipeline = dev
983 .device
984 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
985 label: Some("oxicuda-reduce-pass1"),
986 layout: None,
987 module: &pass1_shader,
988 entry_point: Some("main"),
989 compilation_options: Default::default(),
990 cache: None,
991 });
992
993 let partial_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
995 label: Some("oxicuda-reduce-partial"),
996 size: (wg_count as u64) * 4, usage: wgpu::BufferUsages::STORAGE
998 | wgpu::BufferUsages::COPY_SRC
999 | wgpu::BufferUsages::COPY_DST,
1000 mapped_at_creation: false,
1001 });
1002
1003 let mut p1_params = [0u8; 4];
1005 p1_params[0..4].copy_from_slice(&(n_elements as u32).to_le_bytes());
1006 let p1_uniform = dev.device.create_buffer(&wgpu::BufferDescriptor {
1007 label: Some("oxicuda-reduce-p1-params"),
1008 size: 4,
1009 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
1010 mapped_at_creation: false,
1011 });
1012 dev.queue.write_buffer(&p1_uniform, 0, &p1_params);
1013
1014 let bgl1 = pass1_pipeline.get_bind_group_layout(0);
1015
1016 let bg1 = {
1017 let buffers = mem
1018 .lock_buffers()
1019 .map_err(|e| BackendError::DeviceError(e.to_string()))?;
1020 let in_info = buffers.get(&input_ptr).ok_or_else(|| {
1021 BackendError::InvalidArgument(format!("unknown handle {input_ptr}"))
1022 })?;
1023
1024 dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
1025 label: Some("oxicuda-reduce-pass1"),
1026 layout: &bgl1,
1027 entries: &[
1028 wgpu::BindGroupEntry {
1029 binding: 0,
1030 resource: in_info.buffer.as_entire_binding(),
1031 },
1032 wgpu::BindGroupEntry {
1033 binding: 1,
1034 resource: partial_buf.as_entire_binding(),
1035 },
1036 wgpu::BindGroupEntry {
1037 binding: 2,
1038 resource: p1_uniform.as_entire_binding(),
1039 },
1040 ],
1041 })
1042 };
1043
1044 let pass2_wgsl = shader::reduction_final_wgsl(op_str);
1046 let pass2_shader = dev
1047 .device
1048 .create_shader_module(wgpu::ShaderModuleDescriptor {
1049 label: Some("oxicuda-reduce-pass2"),
1050 source: wgpu::ShaderSource::Wgsl(pass2_wgsl.into()),
1051 });
1052 let pass2_pipeline = dev
1053 .device
1054 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1055 label: Some("oxicuda-reduce-pass2"),
1056 layout: None,
1057 module: &pass2_shader,
1058 entry_point: Some("main"),
1059 compilation_options: Default::default(),
1060 cache: None,
1061 });
1062
1063 let mut p2_params = [0u8; 4];
1065 p2_params[0..4].copy_from_slice(&wg_count.to_le_bytes());
1066 let p2_uniform = dev.device.create_buffer(&wgpu::BufferDescriptor {
1067 label: Some("oxicuda-reduce-p2-params"),
1068 size: 4,
1069 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
1070 mapped_at_creation: false,
1071 });
1072 dev.queue.write_buffer(&p2_uniform, 0, &p2_params);
1073
1074 let bgl2 = pass2_pipeline.get_bind_group_layout(0);
1075
1076 let bg2 = {
1077 let buffers = mem
1078 .lock_buffers()
1079 .map_err(|e| BackendError::DeviceError(e.to_string()))?;
1080 let out_info = buffers.get(&output_ptr).ok_or_else(|| {
1081 BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
1082 })?;
1083
1084 dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
1085 label: Some("oxicuda-reduce-pass2"),
1086 layout: &bgl2,
1087 entries: &[
1088 wgpu::BindGroupEntry {
1089 binding: 0,
1090 resource: partial_buf.as_entire_binding(),
1091 },
1092 wgpu::BindGroupEntry {
1093 binding: 1,
1094 resource: out_info.buffer.as_entire_binding(),
1095 },
1096 wgpu::BindGroupEntry {
1097 binding: 2,
1098 resource: p2_uniform.as_entire_binding(),
1099 },
1100 ],
1101 })
1102 };
1103
1104 let mut encoder = dev
1106 .device
1107 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1108 label: Some("oxicuda-reduce"),
1109 });
1110
1111 {
1112 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1113 label: Some("oxicuda-reduce-pass1"),
1114 timestamp_writes: None,
1115 });
1116 pass.set_pipeline(&pass1_pipeline);
1117 pass.set_bind_group(0, &bg1, &[]);
1118 pass.dispatch_workgroups(wg_count, 1, 1);
1119 }
1120 {
1121 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1122 label: Some("oxicuda-reduce-pass2"),
1123 timestamp_writes: None,
1124 });
1125 pass.set_pipeline(&pass2_pipeline);
1126 pass.set_bind_group(0, &bg2, &[]);
1127 pass.dispatch_workgroups(1, 1, 1);
1128 }
1129
1130 dev.queue.submit(std::iter::once(encoder.finish()));
1131 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
1132
1133 if op == ReduceOp::Mean && n_elements > 1 {
1135 let mut buf = [0u8; 4];
1136 mem.copy_from_device(&mut buf, output_ptr)
1137 .map_err(BackendError::from)?;
1138 let val = f32::from_le_bytes(buf);
1139 let mean = val / (n_elements as f32);
1140 mem.copy_to_device(output_ptr, &mean.to_le_bytes())
1141 .map_err(BackendError::from)?;
1142 }
1143
1144 Ok(())
1145 }
1146
1147 fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()> {
1148 self.check_init()?;
1149 if n == 0 {
1150 return Ok(());
1151 }
1152
1153 let dev = self.device()?;
1154 let mem = self.memory()?;
1155
1156 let op_str = map_unary_op(op);
1157 let wgsl = shader::elementwise_wgsl(op_str);
1158
1159 let shader_mod = dev
1160 .device
1161 .create_shader_module(wgpu::ShaderModuleDescriptor {
1162 label: Some("oxicuda-unary"),
1163 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1164 });
1165
1166 let pipeline = dev
1167 .device
1168 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1169 label: Some("oxicuda-unary"),
1170 layout: None,
1171 module: &shader_mod,
1172 entry_point: Some("main"),
1173 compilation_options: Default::default(),
1174 cache: None,
1175 });
1176
1177 let bgl = pipeline.get_bind_group_layout(0);
1178
1179 let bind_group = {
1180 let buffers = mem
1181 .lock_buffers()
1182 .map_err(|e| BackendError::DeviceError(e.to_string()))?;
1183 let in_info = buffers.get(&input_ptr).ok_or_else(|| {
1184 BackendError::InvalidArgument(format!("unknown handle {input_ptr}"))
1185 })?;
1186 let out_info = buffers.get(&output_ptr).ok_or_else(|| {
1187 BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
1188 })?;
1189
1190 dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
1191 label: Some("oxicuda-unary"),
1192 layout: &bgl,
1193 entries: &[
1194 wgpu::BindGroupEntry {
1195 binding: 0,
1196 resource: in_info.buffer.as_entire_binding(),
1197 },
1198 wgpu::BindGroupEntry {
1199 binding: 1,
1200 resource: out_info.buffer.as_entire_binding(),
1201 },
1202 ],
1203 })
1204 };
1205
1206 let mut encoder = dev
1207 .device
1208 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1209 label: Some("oxicuda-unary"),
1210 });
1211
1212 {
1213 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1214 label: Some("oxicuda-unary"),
1215 timestamp_writes: None,
1216 });
1217 pass.set_pipeline(&pipeline);
1218 pass.set_bind_group(0, &bind_group, &[]);
1219 let workgroups = (n as u32).div_ceil(256);
1220 pass.dispatch_workgroups(workgroups, 1, 1);
1221 }
1222
1223 dev.queue.submit(std::iter::once(encoder.finish()));
1224 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
1225
1226 Ok(())
1227 }
1228
1229 fn binary(
1230 &self,
1231 op: BinaryOp,
1232 a_ptr: u64,
1233 b_ptr: u64,
1234 output_ptr: u64,
1235 n: usize,
1236 ) -> BackendResult<()> {
1237 self.check_init()?;
1238 if n == 0 {
1239 return Ok(());
1240 }
1241
1242 let dev = self.device()?;
1243 let mem = self.memory()?;
1244
1245 let op_str = map_binary_op(op);
1246 let wgsl = shader::binary_wgsl(op_str);
1247
1248 let shader_mod = dev
1249 .device
1250 .create_shader_module(wgpu::ShaderModuleDescriptor {
1251 label: Some("oxicuda-binary"),
1252 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1253 });
1254
1255 let pipeline = dev
1256 .device
1257 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1258 label: Some("oxicuda-binary"),
1259 layout: None,
1260 module: &shader_mod,
1261 entry_point: Some("main"),
1262 compilation_options: Default::default(),
1263 cache: None,
1264 });
1265
1266 let bgl = pipeline.get_bind_group_layout(0);
1267
1268 let bind_group = {
1269 let buffers = mem
1270 .lock_buffers()
1271 .map_err(|e| BackendError::DeviceError(e.to_string()))?;
1272 let a_info = buffers
1273 .get(&a_ptr)
1274 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
1275 let b_info = buffers
1276 .get(&b_ptr)
1277 .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
1278 let out_info = buffers.get(&output_ptr).ok_or_else(|| {
1279 BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
1280 })?;
1281
1282 dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
1283 label: Some("oxicuda-binary"),
1284 layout: &bgl,
1285 entries: &[
1286 wgpu::BindGroupEntry {
1287 binding: 0,
1288 resource: a_info.buffer.as_entire_binding(),
1289 },
1290 wgpu::BindGroupEntry {
1291 binding: 1,
1292 resource: b_info.buffer.as_entire_binding(),
1293 },
1294 wgpu::BindGroupEntry {
1295 binding: 2,
1296 resource: out_info.buffer.as_entire_binding(),
1297 },
1298 ],
1299 })
1300 };
1301
1302 let mut encoder = dev
1303 .device
1304 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1305 label: Some("oxicuda-binary"),
1306 });
1307
1308 {
1309 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1310 label: Some("oxicuda-binary"),
1311 timestamp_writes: None,
1312 });
1313 pass.set_pipeline(&pipeline);
1314 pass.set_bind_group(0, &bind_group, &[]);
1315 let workgroups = (n as u32).div_ceil(256);
1316 pass.dispatch_workgroups(workgroups, 1, 1);
1317 }
1318
1319 dev.queue.submit(std::iter::once(encoder.finish()));
1320 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
1321
1322 Ok(())
1323 }
1324
1325 fn synchronize(&self) -> BackendResult<()> {
1328 self.check_init()?;
1329 if let Some(dev) = &self.device {
1330 let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
1331 }
1332 Ok(())
1333 }
1334
1335 fn alloc(&self, bytes: usize) -> BackendResult<u64> {
1338 self.check_init()?;
1339 if bytes == 0 {
1340 return Err(BackendError::InvalidArgument(
1341 "cannot allocate 0 bytes".into(),
1342 ));
1343 }
1344 self.memory()?.alloc(bytes).map_err(BackendError::from)
1345 }
1346
1347 fn free(&self, ptr: u64) -> BackendResult<()> {
1348 self.check_init()?;
1349 self.memory()?.free(ptr).map_err(BackendError::from)
1350 }
1351
1352 fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()> {
1353 self.check_init()?;
1354 if src.is_empty() {
1355 return Ok(());
1356 }
1357 self.memory()?
1358 .copy_to_device(dst, src)
1359 .map_err(BackendError::from)
1360 }
1361
1362 fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()> {
1363 self.check_init()?;
1364 if dst.is_empty() {
1365 return Ok(());
1366 }
1367 self.memory()?
1368 .copy_from_device(dst, src)
1369 .map_err(BackendError::from)
1370 }
1371}
1372
/// Decodes a byte slice into `f32` values, reading each consecutive group of
/// four bytes as one little-endian float.
///
/// Trailing bytes that do not fill a complete four-byte group are ignored.
fn bytes_to_f32_vec(bytes: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(bytes.len() / 4);
    for chunk in bytes.chunks_exact(4) {
        let mut word = [0u8; 4];
        word.copy_from_slice(chunk);
        values.push(f32::from_le_bytes(word));
    }
    values
}
1382
/// Encodes a slice of `f32` values as bytes, emitting each value as four
/// little-endian bytes in input order.
fn f32_slice_to_bytes(data: &[f32]) -> Vec<u8> {
    let mut encoded = Vec::with_capacity(data.len() * 4);
    for value in data {
        encoded.extend_from_slice(&value.to_le_bytes());
    }
    encoded
}
1387
1388#[cfg(test)]
1393#[path = "backend_tests.rs"]
1394mod tests;