oxicuda_webgpu/backend.rs

//! [`WebGpuBackend`] — the main entry point for the oxicuda-webgpu crate.
//!
//! Implements the [`ComputeBackend`] trait from `oxicuda-backend` using
//! `wgpu` for cross-platform GPU compute (Vulkan, Metal, DX12, WebGPU).

use std::sync::Arc;

use oxicuda_backend::{
    BackendError, BackendResult, BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp,
};
use wgpu;

use crate::{device::WebGpuDevice, memory::WebGpuMemoryManager, shader};

// ─── Op-mapping helpers ──────────────────────────────────────────────────────

fn map_unary_op(op: UnaryOp) -> &'static str {
    match op {
        UnaryOp::Relu => "relu",
        UnaryOp::Sigmoid => "sigmoid",
        UnaryOp::Tanh => "tanh",
        UnaryOp::Exp => "exp",
        UnaryOp::Log => "log",
        UnaryOp::Sqrt => "sqrt",
        UnaryOp::Abs => "abs",
        UnaryOp::Neg => "neg",
    }
}

fn map_binary_op(op: BinaryOp) -> &'static str {
    match op {
        BinaryOp::Add => "add",
        BinaryOp::Sub => "sub",
        BinaryOp::Mul => "mul",
        BinaryOp::Div => "div",
        BinaryOp::Max => "max",
        BinaryOp::Min => "min",
    }
}

fn map_reduce_op(op: ReduceOp) -> &'static str {
    match op {
        ReduceOp::Sum => "sum",
        ReduceOp::Max => "max",
        ReduceOp::Min => "min",
        ReduceOp::Mean => "mean",
    }
}

// ─── Backend struct ──────────────────────────────────────────────────────────

/// Cross-platform GPU compute backend backed by `wgpu`.
///
/// # Lifecycle
///
/// 1. `WebGpuBackend::new()` — create an uninitialised backend.
/// 2. `init()` — select the best available adapter and create the device.
/// 3. Use `alloc`, `copy_htod`, compute ops, `copy_dtoh`, `free`.
/// 4. `synchronize()` — wait for all pending GPU work to finish.
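///
/// # Example
///
/// A minimal lifecycle sketch. Marked `ignore` because the root re-export
/// path and the byte counts are illustrative assumptions, not part of this
/// file:
///
/// ```ignore
/// use oxicuda_backend::ComputeBackend;
///
/// let mut backend = oxicuda_webgpu::WebGpuBackend::new();
/// backend.init()?;                           // pick an adapter, create device
/// let ptr = backend.alloc(4 * 1024)?;        // room for 1024 f32 values
/// backend.copy_htod(ptr, &[0u8; 4 * 1024])?;
/// backend.synchronize()?;                    // wait for pending GPU work
/// backend.free(ptr)?;
/// ```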
#[derive(Debug)]
pub struct WebGpuBackend {
    device: Option<Arc<WebGpuDevice>>,
    memory: Option<Arc<WebGpuMemoryManager>>,
    initialized: bool,
}

impl WebGpuBackend {
    /// Create a new, uninitialised WebGPU backend.
    pub fn new() -> Self {
        Self {
            device: None,
            memory: None,
            initialized: false,
        }
    }

    /// Return an error if the backend is not yet initialised.
    fn check_init(&self) -> BackendResult<()> {
        if self.initialized {
            Ok(())
        } else {
            Err(BackendError::NotInitialized)
        }
    }

    /// Convenience accessor: get the memory manager or return `NotInitialized`.
    fn memory(&self) -> BackendResult<&Arc<WebGpuMemoryManager>> {
        self.memory.as_ref().ok_or(BackendError::NotInitialized)
    }

    /// Convenience accessor: get the device or return `NotInitialized`.
    fn device(&self) -> BackendResult<&Arc<WebGpuDevice>> {
        self.device.as_ref().ok_or(BackendError::NotInitialized)
    }

    /// Multi-dimensional reduce along a single axis.
    ///
    /// The tensor is logically reshaped to `[outer, dk, inner]`:
    /// * `outer` = product of dimensions before the reduce axis,
    /// * `dk`    = the reduce axis length,
    /// * `inner` = product of dimensions after the reduce axis.
    ///
    /// One workgroup of 256 threads is dispatched per `(o, j)` output slot.
    /// To stay within WebGPU's 65 535-per-axis dispatch limit a 2-D grid is
    /// used and the workgroup decodes its linear slot internally.
    ///
    /// `Mean` is handled inside the shader (divide by `dk`); the host does
    /// not need a post-pass.
    fn reduce_nd(
        &self,
        op: ReduceOp,
        input_ptr: u64,
        output_ptr: u64,
        shape: &[usize],
        axis: usize,
    ) -> BackendResult<()> {
        // Caller (`reduce`) has already validated that `shape` is non-empty
        // and that `axis < shape.len()`; assert in debug builds to catch
        // regressions.
        debug_assert!(!shape.is_empty());
        debug_assert!(axis < shape.len());

        // Output shape = shape with `axis` removed; length = outer * inner.
        let outer: usize = shape[..axis].iter().product();
        let dk: usize = shape[axis];
        let inner: usize = shape[axis + 1..].iter().product();

        // Empty tensor — nothing to do.
        if outer == 0 || dk == 0 || inner == 0 {
            return Ok(());
        }

        let total = outer.checked_mul(inner).ok_or_else(|| {
            BackendError::InvalidArgument("reduce: outer * inner overflows usize".into())
        })?;

        // Strides in elements: row-major (C order) layout.
        let inner_stride: usize = 1;
        let dk_stride: usize = inner;
        let outer_stride: usize = dk
            .checked_mul(inner)
            .ok_or_else(|| BackendError::InvalidArgument("reduce: dk * inner overflows".into()))?;
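        // Worked example: shape = [2, 3, 4] with axis = 1 gives outer = 2,
        // dk = 3, inner = 4, so element (o, k, j) lives at linear index
        // o * 12 + k * 4 + j (strides 12, 4, 1).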

        // Cap each dispatch dimension below the WebGPU 65 535 limit.  We pick
        // grid_x = min(total, 32 768) so grid_y stays modest for huge tensors.
        const MAX_GRID_DIM: u32 = 32_768;
        let total_u32: u32 = total.try_into().map_err(|_| {
            BackendError::InvalidArgument(format!(
                "reduce: output element count {total} exceeds u32 range"
            ))
        })?;
        let grid_x: u32 = total_u32.clamp(1, MAX_GRID_DIM);
        let grid_y: u32 = total_u32.div_ceil(grid_x);
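        // Worked example: total = 100_000 → grid_x = 32_768,
        // grid_y = ceil(100_000 / 32_768) = 4.  Note that grid_y itself is
        // not capped, so outputs larger than 32 768 × 65 535 ≈ 2.1e9 slots
        // would still exceed the per-axis dispatch limit.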

        let dev = self.device()?;
        let mem = self.memory()?;
        let op_str = map_reduce_op(op);

        let wgsl = shader::reduction_nd_wgsl(op_str);
        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-reduce-nd"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });
        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-reduce-nd"),
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        // Build the uniform buffer: 8 × u32 = 32 bytes (16-byte aligned).
        let mut params_bytes = [0u8; 32];
        let outer_u32: u32 = outer
            .try_into()
            .map_err(|_| BackendError::InvalidArgument("reduce: outer exceeds u32 range".into()))?;
        let dk_u32: u32 = dk
            .try_into()
            .map_err(|_| BackendError::InvalidArgument("reduce: dk exceeds u32 range".into()))?;
        let inner_u32: u32 = inner
            .try_into()
            .map_err(|_| BackendError::InvalidArgument("reduce: inner exceeds u32 range".into()))?;
        let outer_stride_u32: u32 = outer_stride.try_into().map_err(|_| {
            BackendError::InvalidArgument("reduce: outer_stride exceeds u32 range".into())
        })?;
        let dk_stride_u32: u32 = dk_stride.try_into().map_err(|_| {
            BackendError::InvalidArgument("reduce: dk_stride exceeds u32 range".into())
        })?;
        let inner_stride_u32: u32 = inner_stride.try_into().map_err(|_| {
            BackendError::InvalidArgument("reduce: inner_stride exceeds u32 range".into())
        })?;
        params_bytes[0..4].copy_from_slice(&outer_u32.to_le_bytes());
        params_bytes[4..8].copy_from_slice(&dk_u32.to_le_bytes());
        params_bytes[8..12].copy_from_slice(&inner_u32.to_le_bytes());
        params_bytes[12..16].copy_from_slice(&outer_stride_u32.to_le_bytes());
        params_bytes[16..20].copy_from_slice(&dk_stride_u32.to_le_bytes());
        params_bytes[20..24].copy_from_slice(&inner_stride_u32.to_le_bytes());
        params_bytes[24..28].copy_from_slice(&grid_x.to_le_bytes());
        // bytes 28..32 are zero padding.

        let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-reduce-nd-params"),
            size: 32,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&uniform_buf, 0, &params_bytes);

        let bgl = pipeline.get_bind_group_layout(0);
        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let in_info = buffers.get(&input_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {input_ptr}"))
            })?;
            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
            })?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-reduce-nd"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: in_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: out_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: uniform_buf.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-reduce-nd"),
            });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-reduce-nd"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            pass.dispatch_workgroups(grid_x, grid_y, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }
}

impl WebGpuBackend {
    /// FP16 GEMM: `C = alpha * A * B + beta * C` with half-precision storage.
    ///
    /// This is an inherent method (not on `ComputeBackend`) because FP16
    /// support is WebGPU-specific and requires the `f16` WGSL extension.
    ///
    /// Buffers pointed to by `a_ptr`, `b_ptr`, `c_ptr` must contain `f16`
    /// elements (2 bytes each).
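    ///
    /// A call sketch (hypothetical handles; the buffers must already hold
    /// `f16` data):
    ///
    /// ```ignore
    /// // C[m×n] = 1.0 * A[m×k] * B[k×n] + 0.0 * C
    /// backend.gemm_f16(m, n, k, 1.0, a_ptr, b_ptr, 0.0, c_ptr)?;
    /// ```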
    #[allow(clippy::too_many_arguments)]
    pub fn gemm_f16(
        &self,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        b_ptr: u64,
        beta: f64,
        c_ptr: u64,
    ) -> BackendResult<()> {
        self.check_init()?;
        if m == 0 || n == 0 || k == 0 {
            return Ok(());
        }

        let dev = self.device()?;
        let mem = self.memory()?;

        let tile_size: u32 = 8;
        let wgsl = shader::gemm_wgsl_f16(tile_size);

        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-gemm-f16"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });

        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-gemm-f16"),
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        let bgl = pipeline.get_bind_group_layout(0);

        // Build uniform buffer for GemmParams { m, n, k, alpha, beta }.
        let mut params_bytes = [0u8; 20];
        params_bytes[0..4].copy_from_slice(&(m as u32).to_le_bytes());
        params_bytes[4..8].copy_from_slice(&(n as u32).to_le_bytes());
        params_bytes[8..12].copy_from_slice(&(k as u32).to_le_bytes());
        params_bytes[12..16].copy_from_slice(&(alpha as f32).to_le_bytes());
        params_bytes[16..20].copy_from_slice(&(beta as f32).to_le_bytes());
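        // The WGSL side is assumed to declare a matching five-field layout,
        // e.g. `struct GemmParams { m: u32, n: u32, k: u32, alpha: f32,
        // beta: f32 }` — 20 bytes, written little-endian field by field.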

        let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-gemm-f16-params"),
            size: 20,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&uniform_buf, 0, &params_bytes);

        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let a_info = buffers
                .get(&a_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
            let b_info = buffers
                .get(&b_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
            let c_info = buffers
                .get(&c_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {c_ptr}")))?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-gemm-f16"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: a_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: b_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: c_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 3,
                        resource: uniform_buf.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-gemm-f16"),
            });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-gemm-f16"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            let wg_x = (n as u32).div_ceil(tile_size);
            let wg_y = (m as u32).div_ceil(tile_size);
            pass.dispatch_workgroups(wg_x, wg_y, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }
}

impl Default for WebGpuBackend {
    fn default() -> Self {
        Self::new()
    }
}

// ─── ComputeBackend impl ─────────────────────────────────────────────────────

impl ComputeBackend for WebGpuBackend {
    fn name(&self) -> &str {
        "webgpu"
    }

    fn init(&mut self) -> BackendResult<()> {
        if self.initialized {
            return Ok(());
        }

        match WebGpuDevice::new() {
            Ok(dev) => {
                let dev = Arc::new(dev);
                tracing::info!("WebGPU backend initialised on: {}", dev.adapter_name);
                let memory = WebGpuMemoryManager::new(Arc::clone(&dev));
                self.device = Some(dev);
                self.memory = Some(Arc::new(memory));
                self.initialized = true;
                Ok(())
            }
            Err(e) => Err(BackendError::from(e)),
        }
    }

    fn is_initialized(&self) -> bool {
        self.initialized
    }

    // ── Compute operations ────────────────────────────────────────────────────

    fn gemm(
        &self,
        trans_a: BackendTranspose,
        trans_b: BackendTranspose,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        _lda: usize,
        b_ptr: u64,
        _ldb: usize,
        beta: f64,
        c_ptr: u64,
        _ldc: usize,
    ) -> BackendResult<()> {
        self.check_init()?;
        // Zero-dimension matrices are trivially done.
        if m == 0 || n == 0 || k == 0 {
            return Ok(());
        }

        // Transpose not yet supported in the WGSL shader.
        if trans_a != BackendTranspose::NoTrans || trans_b != BackendTranspose::NoTrans {
            return Err(BackendError::Unsupported(
                "WebGPU GEMM does not yet support transposed inputs".into(),
            ));
        }

        let dev = self.device()?;
        let mem = self.memory()?;

        let tile_size: u32 = 8;
        let wgsl = shader::gemm_wgsl(tile_size);

        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-gemm"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });

        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-gemm"),
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        let bgl = pipeline.get_bind_group_layout(0);

        // Build uniform buffer for GemmParams { m, n, k, alpha, beta }.
        let mut params_bytes = [0u8; 20];
        params_bytes[0..4].copy_from_slice(&(m as u32).to_le_bytes());
        params_bytes[4..8].copy_from_slice(&(n as u32).to_le_bytes());
        params_bytes[8..12].copy_from_slice(&(k as u32).to_le_bytes());
        params_bytes[12..16].copy_from_slice(&(alpha as f32).to_le_bytes());
        params_bytes[16..20].copy_from_slice(&(beta as f32).to_le_bytes());

        let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-gemm-params"),
            size: 20,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&uniform_buf, 0, &params_bytes);

        // Create bind group while holding the buffer lock.
        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let a_info = buffers
                .get(&a_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
            let b_info = buffers
                .get(&b_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
            let c_info = buffers
                .get(&c_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {c_ptr}")))?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-gemm"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: a_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: b_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: c_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 3,
                        resource: uniform_buf.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-gemm"),
            });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-gemm"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            let wg_x = (n as u32).div_ceil(tile_size);
            let wg_y = (m as u32).div_ceil(tile_size);
            pass.dispatch_workgroups(wg_x, wg_y, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    fn batched_gemm(
        &self,
        trans_a: BackendTranspose,
        trans_b: BackendTranspose,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        _lda: usize,
        stride_a: usize,
        b_ptr: u64,
        _ldb: usize,
        stride_b: usize,
        beta: f64,
        c_ptr: u64,
        _ldc: usize,
        stride_c: usize,
        batch_count: usize,
    ) -> BackendResult<()> {
        self.check_init()?;

        if batch_count == 0 || m == 0 || n == 0 || k == 0 {
            return Ok(());
        }

        if trans_a != BackendTranspose::NoTrans || trans_b != BackendTranspose::NoTrans {
            return Err(BackendError::Unsupported(
                "WebGPU batched GEMM does not yet support transposed inputs".into(),
            ));
        }

        let dev = self.device()?;
        let mem = self.memory()?;

        let tile_size: u32 = 8;
        let wgsl = shader::batched_gemm_wgsl(tile_size);

        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-batched-gemm"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });

        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-batched-gemm"),
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        let bgl = pipeline.get_bind_group_layout(0);

        // BatchedGemmParams { m, n, k, alpha, beta, batch_count, stride_a,
        // stride_b, stride_c }: nine 4-byte scalars = 36 bytes, padded to
        // 48 bytes so the uniform buffer size stays 16-byte aligned.
        let mut params_bytes = [0u8; 48];
        params_bytes[0..4].copy_from_slice(&(m as u32).to_le_bytes());
        params_bytes[4..8].copy_from_slice(&(n as u32).to_le_bytes());
        params_bytes[8..12].copy_from_slice(&(k as u32).to_le_bytes());
        params_bytes[12..16].copy_from_slice(&(alpha as f32).to_le_bytes());
        params_bytes[16..20].copy_from_slice(&(beta as f32).to_le_bytes());
        params_bytes[20..24].copy_from_slice(&(batch_count as u32).to_le_bytes());
        params_bytes[24..28].copy_from_slice(&(stride_a as u32).to_le_bytes());
        params_bytes[28..32].copy_from_slice(&(stride_b as u32).to_le_bytes());
        params_bytes[32..36].copy_from_slice(&(stride_c as u32).to_le_bytes());
        // bytes 36..48 are padding zeros
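        // NOTE: the `as u32` casts above silently truncate values that do
        // not fit in 32 bits; unlike `reduce_nd`, no checked conversion is
        // performed here.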

        let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-batched-gemm-params"),
            size: 48,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&uniform_buf, 0, &params_bytes);

        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let a_info = buffers
                .get(&a_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
            let b_info = buffers
                .get(&b_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
            let c_info = buffers
                .get(&c_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {c_ptr}")))?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-batched-gemm"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: a_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: b_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: c_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 3,
                        resource: uniform_buf.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-batched-gemm"),
            });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-batched-gemm"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            let wg_x = (n as u32).div_ceil(tile_size);
            let wg_y = (m as u32).div_ceil(tile_size);
            pass.dispatch_workgroups(wg_x, wg_y, batch_count as u32);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }

    fn conv2d_forward(
        &self,
        input_ptr: u64,
        input_shape: &[usize],
        filter_ptr: u64,
        filter_shape: &[usize],
        output_ptr: u64,
        output_shape: &[usize],
        stride: &[usize],
        padding: &[usize],
    ) -> BackendResult<()> {
        self.check_init()?;

        if input_shape.len() != 4 {
            return Err(BackendError::InvalidArgument(
                "input_shape must have 4 elements [N, C, H, W]".into(),
            ));
        }
        if filter_shape.len() != 4 {
            return Err(BackendError::InvalidArgument(
                "filter_shape must have 4 elements [K, C, Fh, Fw]".into(),
            ));
        }
        if output_shape.len() != 4 {
            return Err(BackendError::InvalidArgument(
                "output_shape must have 4 elements [N, K, Oh, Ow]".into(),
            ));
        }
        if stride.len() != 2 {
            return Err(BackendError::InvalidArgument(
                "stride must have 2 elements [sh, sw]".into(),
            ));
        }
        if padding.len() != 2 {
            return Err(BackendError::InvalidArgument(
                "padding must have 2 elements [ph, pw]".into(),
            ));
        }

        let mem = self.memory()?;

        let batch = input_shape[0];
        let c_in = input_shape[1];
        let h_in = input_shape[2];
        let w_in = input_shape[3];
        let k_out = filter_shape[0];
        let fh = filter_shape[2];
        let fw = filter_shape[3];
        let oh = output_shape[2];
        let ow = output_shape[3];
        let sh = stride[0];
        let sw = stride[1];
        let ph = padding[0];
        let pw = padding[1];
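        // The output spatial size is expected to satisfy
        // oh = (h_in + 2 * ph - fh) / sh + 1 (integer division; analogously
        // for ow); the shapes are taken as given rather than re-derived here.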

        let in_elems: usize = input_shape.iter().product();
        let f_elems: usize = filter_shape.iter().product();
        let o_elems: usize = output_shape.iter().product();

        // CPU fallback: download input + filter, compute, upload output.
        let mut in_bytes = vec![0u8; in_elems * 4];
        let mut f_bytes = vec![0u8; f_elems * 4];
        mem.copy_from_device(&mut in_bytes, input_ptr)
            .map_err(BackendError::from)?;
        mem.copy_from_device(&mut f_bytes, filter_ptr)
            .map_err(BackendError::from)?;

        let in_f32 = bytes_to_f32_vec(&in_bytes);
        let f_f32 = bytes_to_f32_vec(&f_bytes);
        let mut out_f32 = vec![0.0f32; o_elems];

        for b in 0..batch {
            for kf in 0..k_out {
                for oy in 0..oh {
                    for ox in 0..ow {
                        let mut acc = 0.0f32;
                        for ci in 0..c_in {
                            for fy in 0..fh {
                                for fx in 0..fw {
                                    let iy = (oy * sh + fy) as isize - ph as isize;
                                    let ix = (ox * sw + fx) as isize - pw as isize;
                                    if iy >= 0
                                        && (iy as usize) < h_in
                                        && ix >= 0
                                        && (ix as usize) < w_in
                                    {
                                        let in_idx = ((b * c_in + ci) * h_in + iy as usize) * w_in
                                            + ix as usize;
                                        let f_idx = ((kf * c_in + ci) * fh + fy) * fw + fx;
                                        acc += in_f32[in_idx] * f_f32[f_idx];
                                    }
                                }
                            }
                        }
                        out_f32[((b * k_out + kf) * oh + oy) * ow + ox] = acc;
                    }
                }
            }
        }

        let out_bytes = f32_slice_to_bytes(&out_f32);
        mem.copy_to_device(output_ptr, &out_bytes)
            .map_err(BackendError::from)?;

        Ok(())
    }

    fn attention(
        &self,
        q_ptr: u64,
        k_ptr: u64,
        v_ptr: u64,
        o_ptr: u64,
        batch: usize,
        heads: usize,
        seq_q: usize,
        seq_kv: usize,
        head_dim: usize,
        scale: f64,
        causal: bool,
    ) -> BackendResult<()> {
        self.check_init()?;

        if seq_q == 0 || seq_kv == 0 || head_dim == 0 {
            return Err(BackendError::InvalidArgument(
                "seq_q, seq_kv, and head_dim must all be > 0".into(),
            ));
        }
        if scale <= 0.0 || !scale.is_finite() {
            return Err(BackendError::InvalidArgument(format!(
                "scale must be a positive finite number, got {scale}"
            )));
        }

        let mem = self.memory()?;

        let batch_heads = batch * heads;
        let q_elems = batch_heads * seq_q * head_dim;
        let kv_elems = batch_heads * seq_kv * head_dim;
        let o_elems = q_elems;

        // CPU fallback: download Q, K, V, compute attention, upload O.
        let mut q_bytes = vec![0u8; q_elems * 4];
        let mut k_bytes = vec![0u8; kv_elems * 4];
        let mut v_bytes = vec![0u8; kv_elems * 4];

        mem.copy_from_device(&mut q_bytes, q_ptr)
            .map_err(BackendError::from)?;
        mem.copy_from_device(&mut k_bytes, k_ptr)
            .map_err(BackendError::from)?;
        mem.copy_from_device(&mut v_bytes, v_ptr)
            .map_err(BackendError::from)?;

        let q_f32 = bytes_to_f32_vec(&q_bytes);
        let k_f32 = bytes_to_f32_vec(&k_bytes);
        let v_f32 = bytes_to_f32_vec(&v_bytes);
        let mut o_f32 = vec![0.0f32; o_elems];

        let scale_f32 = scale as f32;

        for bh in 0..batch_heads {
            let q_off = bh * seq_q * head_dim;
            let k_off = bh * seq_kv * head_dim;
            let v_off = k_off;

            for sq in 0..seq_q {
                let kv_limit = if causal { (sq + 1).min(seq_kv) } else { seq_kv };

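                // Standard max-subtraction softmax: softmax(s)_i =
                // exp(s_i - max_j s_j) / Σ_k exp(s_k - max_j s_j), which is
                // algebraically identical to the naive form but cannot
                // overflow in the exponential.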
                // Pass 1: find max score for numerical stability
                let mut max_score = f32::NEG_INFINITY;
                for sk in 0..kv_limit {
                    let mut dot = 0.0f32;
                    for dd in 0..head_dim {
                        dot +=
                            q_f32[q_off + sq * head_dim + dd] * k_f32[k_off + sk * head_dim + dd];
                    }
                    let s = dot * scale_f32;
                    if s > max_score {
                        max_score = s;
                    }
                }

                // Pass 2: exp(score - max), accumulate weighted V
                let mut sum_exp = 0.0f32;
                let mut acc = vec![0.0f32; head_dim];
                for sk in 0..kv_limit {
                    let mut dot = 0.0f32;
                    for dd in 0..head_dim {
                        dot +=
                            q_f32[q_off + sq * head_dim + dd] * k_f32[k_off + sk * head_dim + dd];
                    }
                    let w = (dot * scale_f32 - max_score).exp();
                    sum_exp += w;
                    for dd in 0..head_dim {
                        acc[dd] += w * v_f32[v_off + sk * head_dim + dd];
                    }
                }

                // Normalise
                let o_base = q_off + sq * head_dim;
                if sum_exp > 0.0 {
                    for dd in 0..head_dim {
                        o_f32[o_base + dd] = acc[dd] / sum_exp;
                    }
                }
            }
        }

        let o_bytes = f32_slice_to_bytes(&o_f32);
        mem.copy_to_device(o_ptr, &o_bytes)
            .map_err(BackendError::from)?;

        Ok(())
    }

    fn reduce(
        &self,
        op: ReduceOp,
        input_ptr: u64,
        output_ptr: u64,
        shape: &[usize],
        axis: usize,
    ) -> BackendResult<()> {
        self.check_init()?;

        if shape.is_empty() {
            return Err(BackendError::InvalidArgument(
                "shape must not be empty".into(),
            ));
        }
        if axis >= shape.len() {
            return Err(BackendError::InvalidArgument(format!(
                "axis {axis} is out of bounds for shape of length {}",
                shape.len()
            )));
        }

        // 1-D shapes take the optimised two-pass scalar path below.
        // Higher-rank shapes go through the N-D shader in `reduce_nd`.
        if shape.len() != 1 {
            return self.reduce_nd(op, input_ptr, output_ptr, shape, axis);
        }

        let n_elements = shape[0];
        if n_elements == 0 {
            return Ok(());
        }

        let dev = self.device()?;
        let mem = self.memory()?;
        let op_str = map_reduce_op(op);

        // ── Pass 1: per-workgroup reduction ─────────────────────────────────
        let wg_count = (n_elements as u32).div_ceil(256);
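        // e.g. n_elements = 1_000 → wg_count = ceil(1000 / 256) = 4 partial
        // results, which pass 2 then folds into a single scalar.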

        let pass1_wgsl = shader::reduction_wgsl(op_str);
        let pass1_shader = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-reduce-pass1"),
                source: wgpu::ShaderSource::Wgsl(pass1_wgsl.into()),
            });
        let pass1_pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-reduce-pass1"),
                layout: None,
                module: &pass1_shader,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        // Partial-sums buffer (temporary).
        let partial_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-reduce-partial"),
            size: (wg_count as u64) * 4, // f32 per workgroup
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        // Uniform for ReduceParams { n: u32 }.
        let mut p1_params = [0u8; 4];
        p1_params[0..4].copy_from_slice(&(n_elements as u32).to_le_bytes());
        let p1_uniform = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-reduce-p1-params"),
            size: 4,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&p1_uniform, 0, &p1_params);

        let bgl1 = pass1_pipeline.get_bind_group_layout(0);

        let bg1 = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let in_info = buffers.get(&input_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {input_ptr}"))
            })?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-reduce-pass1"),
                layout: &bgl1,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: in_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: partial_buf.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: p1_uniform.as_entire_binding(),
                    },
                ],
            })
        };

        // ── Pass 2: final reduction of partial sums ─────────────────────────
        let pass2_wgsl = shader::reduction_final_wgsl(op_str);
        let pass2_shader = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-reduce-pass2"),
                source: wgpu::ShaderSource::Wgsl(pass2_wgsl.into()),
            });
        let pass2_pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-reduce-pass2"),
                layout: None,
                module: &pass2_shader,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        // FinalReduceParams { num_groups: u32 }.
        let mut p2_params = [0u8; 4];
        p2_params[0..4].copy_from_slice(&wg_count.to_le_bytes());
        let p2_uniform = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-reduce-p2-params"),
            size: 4,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&p2_uniform, 0, &p2_params);

        let bgl2 = pass2_pipeline.get_bind_group_layout(0);

        let bg2 = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
            })?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-reduce-pass2"),
                layout: &bgl2,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: partial_buf.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: out_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: p2_uniform.as_entire_binding(),
                    },
                ],
            })
        };

        // ── Encode both passes into one command buffer ──────────────────────
        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-reduce"),
            });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-reduce-pass1"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pass1_pipeline);
            pass.set_bind_group(0, &bg1, &[]);
            pass.dispatch_workgroups(wg_count, 1, 1);
        }
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-reduce-pass2"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pass2_pipeline);
            pass.set_bind_group(0, &bg2, &[]);
            pass.dispatch_workgroups(1, 1, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        // For "mean", divide the result by N on the host side.
        if op == ReduceOp::Mean && n_elements > 1 {
            let mut buf = [0u8; 4];
            mem.copy_from_device(&mut buf, output_ptr)
                .map_err(BackendError::from)?;
            let val = f32::from_le_bytes(buf);
            let mean = val / (n_elements as f32);
            mem.copy_to_device(output_ptr, &mean.to_le_bytes())
                .map_err(BackendError::from)?;
        }

        Ok(())
    }

    fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()> {
        self.check_init()?;
        if n == 0 {
            return Ok(());
        }

        let dev = self.device()?;
        let mem = self.memory()?;

        let op_str = map_unary_op(op);
        let wgsl = shader::elementwise_wgsl(op_str);

        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-unary"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });

        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-unary"),
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        let bgl = pipeline.get_bind_group_layout(0);

        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let in_info = buffers.get(&input_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {input_ptr}"))
            })?;
            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
            })?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-unary"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: in_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: out_info.buffer.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-unary"),
            });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-unary"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            let workgroups = (n as u32).div_ceil(256);
            pass.dispatch_workgroups(workgroups, 1, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }

    fn binary(
        &self,
        op: BinaryOp,
        a_ptr: u64,
        b_ptr: u64,
        output_ptr: u64,
        n: usize,
    ) -> BackendResult<()> {
        self.check_init()?;
        if n == 0 {
            return Ok(());
        }

        let dev = self.device()?;
        let mem = self.memory()?;

        let op_str = map_binary_op(op);
        let wgsl = shader::binary_wgsl(op_str);

        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-binary"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });

        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-binary"),
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        let bgl = pipeline.get_bind_group_layout(0);

        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let a_info = buffers
                .get(&a_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
            let b_info = buffers
                .get(&b_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
            })?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-binary"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: a_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: b_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: out_info.buffer.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-binary"),
            });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-binary"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            let workgroups = (n as u32).div_ceil(256);
            pass.dispatch_workgroups(workgroups, 1, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }

    // ── Synchronisation ───────────────────────────────────────────────────────

    fn synchronize(&self) -> BackendResult<()> {
        self.check_init()?;
        if let Some(dev) = &self.device {
            let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
        }
        Ok(())
    }

    // ── Memory management ─────────────────────────────────────────────────────

    fn alloc(&self, bytes: usize) -> BackendResult<u64> {
        self.check_init()?;
        if bytes == 0 {
            return Err(BackendError::InvalidArgument(
                "cannot allocate 0 bytes".into(),
            ));
        }
        self.memory()?.alloc(bytes).map_err(BackendError::from)
    }

    fn free(&self, ptr: u64) -> BackendResult<()> {
        self.check_init()?;
        self.memory()?.free(ptr).map_err(BackendError::from)
    }

    fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()> {
        self.check_init()?;
        if src.is_empty() {
            return Ok(());
        }
        self.memory()?
            .copy_to_device(dst, src)
            .map_err(BackendError::from)
    }

    fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()> {
        self.check_init()?;
        if dst.is_empty() {
            return Ok(());
        }
        self.memory()?
            .copy_from_device(dst, src)
            .map_err(BackendError::from)
    }
}

// ─── Byte ↔ f32 helpers ──────────────────────────────────────────────────────

/// Convert a `&[u8]` (length must be a multiple of 4) to a `Vec<f32>`.
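/// Any trailing bytes beyond the last complete 4-byte chunk are silently
/// dropped by `chunks_exact`.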
fn bytes_to_f32_vec(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}

/// Convert a `&[f32]` slice to its little-endian byte representation.
fn f32_slice_to_bytes(data: &[f32]) -> Vec<u8> {
    data.iter().flat_map(|v| v.to_le_bytes()).collect()
}

// ─── Tests ───────────────────────────────────────────────────────────────────
//
// The test module lives in a sibling file (`backend_tests.rs`) so the
// production code in this file stays under the 2 000-line refactoring policy.
#[cfg(test)]
#[path = "backend_tests.rs"]
mod tests;