Skip to main content

oxicuda_webgpu/
backend.rs

1//! [`WebGpuBackend`] — the main entry point for the oxicuda-webgpu crate.
2//!
3//! Implements the [`ComputeBackend`] trait from `oxicuda-backend` using
4//! `wgpu` for cross-platform GPU compute (Vulkan, Metal, DX12, WebGPU).
5
6use std::sync::Arc;
7
8use oxicuda_backend::{
9    BackendError, BackendResult, BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp,
10};
11use wgpu;
12
13use crate::{device::WebGpuDevice, memory::WebGpuMemoryManager, shader};
14
15// ─── Op-mapping helpers ──────────────────────────────────────────────────────
16
17fn map_unary_op(op: UnaryOp) -> &'static str {
18    match op {
19        UnaryOp::Relu => "relu",
20        UnaryOp::Sigmoid => "sigmoid",
21        UnaryOp::Tanh => "tanh",
22        UnaryOp::Exp => "exp",
23        UnaryOp::Log => "log",
24        UnaryOp::Sqrt => "sqrt",
25        UnaryOp::Abs => "abs",
26        UnaryOp::Neg => "neg",
27    }
28}
29
30fn map_binary_op(op: BinaryOp) -> &'static str {
31    match op {
32        BinaryOp::Add => "add",
33        BinaryOp::Sub => "sub",
34        BinaryOp::Mul => "mul",
35        BinaryOp::Div => "div",
36        BinaryOp::Max => "max",
37        BinaryOp::Min => "min",
38    }
39}
40
41fn map_reduce_op(op: ReduceOp) -> &'static str {
42    match op {
43        ReduceOp::Sum => "sum",
44        ReduceOp::Max => "max",
45        ReduceOp::Min => "min",
46        ReduceOp::Mean => "mean",
47    }
48}
49
50// ─── Backend struct ──────────────────────────────────────────────────────────
51
/// Cross-platform GPU compute backend backed by `wgpu`.
///
/// # Lifecycle
///
/// 1. `WebGpuBackend::new()` — create an uninitialised backend.
/// 2. `init()` — select the best available adapter and create the device.
/// 3. Use `alloc`, `copy_htod`, compute ops, `copy_dtoh`, `free`.
/// 4. `synchronize()` — wait for all pending GPU work to finish.
///
/// `init()` fills `device` and `memory` and sets `initialized` in one step,
/// so either all three are populated or none are.
#[derive(Debug)]
pub struct WebGpuBackend {
    /// Shared device/queue wrapper; `None` until `init()` succeeds.
    device: Option<Arc<WebGpuDevice>>,
    /// Buffer manager built on a clone of `device`; `None` until `init()`
    /// succeeds. Handles (`u64`) passed to compute ops are looked up here.
    memory: Option<Arc<WebGpuMemoryManager>>,
    /// `true` once `init()` has completed successfully.
    initialized: bool,
}
66
67impl WebGpuBackend {
68    /// Create a new, uninitialised WebGPU backend.
69    pub fn new() -> Self {
70        Self {
71            device: None,
72            memory: None,
73            initialized: false,
74        }
75    }
76
77    /// Return an error if the backend is not yet initialised.
78    fn check_init(&self) -> BackendResult<()> {
79        if self.initialized {
80            Ok(())
81        } else {
82            Err(BackendError::NotInitialized)
83        }
84    }
85
86    /// Convenience accessor: get the memory manager or return `NotInitialized`.
87    fn memory(&self) -> BackendResult<&Arc<WebGpuMemoryManager>> {
88        self.memory.as_ref().ok_or(BackendError::NotInitialized)
89    }
90
91    /// Convenience accessor: get the device or return `NotInitialized`.
92    fn device(&self) -> BackendResult<&Arc<WebGpuDevice>> {
93        self.device.as_ref().ok_or(BackendError::NotInitialized)
94    }
95
96    /// Multi-dimensional reduce along a single axis.
97    ///
98    /// The tensor is logically reshaped to `[outer, dk, inner]`:
99    /// * `outer` = product of dimensions before the reduce axis,
100    /// * `dk`    = the reduce axis length,
101    /// * `inner` = product of dimensions after the reduce axis.
102    ///
103    /// One workgroup of 256 threads is dispatched per `(o, j)` output slot.
104    /// To stay within WebGPU's 65 535-per-axis dispatch limit a 2-D grid is
105    /// used and the workgroup decodes its linear slot internally.
106    ///
107    /// `Mean` is handled inside the shader (divide by `dk`); the host does
108    /// not need a post-pass.
109    fn reduce_nd(
110        &self,
111        op: ReduceOp,
112        input_ptr: u64,
113        output_ptr: u64,
114        shape: &[usize],
115        axis: usize,
116    ) -> BackendResult<()> {
117        // Caller (`reduce`) already validated `shape.is_empty()` and
118        // `axis < shape.len()`; assert in debug to catch regressions but
119        // recompute defensively in release as well.
120        debug_assert!(!shape.is_empty());
121        debug_assert!(axis < shape.len());
122
123        // Output shape = shape with `axis` removed; length = outer * inner.
124        let outer: usize = shape[..axis].iter().product();
125        let dk: usize = shape[axis];
126        let inner: usize = shape[axis + 1..].iter().product();
127
128        // Empty tensor — nothing to do.
129        if outer == 0 || dk == 0 || inner == 0 {
130            return Ok(());
131        }
132
133        let total = outer.checked_mul(inner).ok_or_else(|| {
134            BackendError::InvalidArgument("reduce: outer * inner overflows usize".into())
135        })?;
136
137        // Strides in elements: row-major (C order) layout.
138        let inner_stride: usize = 1;
139        let dk_stride: usize = inner;
140        let outer_stride: usize = dk
141            .checked_mul(inner)
142            .ok_or_else(|| BackendError::InvalidArgument("reduce: dk * inner overflows".into()))?;
143
144        // Cap each dispatch dimension below the WebGPU 65 535 limit.  We pick
145        // grid_x = min(total, 32 768) so grid_y stays modest for huge tensors.
146        const MAX_GRID_DIM: u32 = 32_768;
147        let total_u32: u32 = total.try_into().map_err(|_| {
148            BackendError::InvalidArgument(format!(
149                "reduce: output element count {total} exceeds u32 range"
150            ))
151        })?;
152        let grid_x: u32 = total_u32.clamp(1, MAX_GRID_DIM);
153        let grid_y: u32 = total_u32.div_ceil(grid_x);
154
155        let dev = self.device()?;
156        let mem = self.memory()?;
157        let op_str = map_reduce_op(op);
158
159        let wgsl = shader::reduction_nd_wgsl(op_str);
160        let shader_mod = dev
161            .device
162            .create_shader_module(wgpu::ShaderModuleDescriptor {
163                label: Some("oxicuda-reduce-nd"),
164                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
165            });
166        let pipeline = dev
167            .device
168            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
169                label: Some("oxicuda-reduce-nd"),
170                layout: None,
171                module: &shader_mod,
172                entry_point: Some("main"),
173                compilation_options: Default::default(),
174                cache: None,
175            });
176
177        // Build the uniform buffer: 8 × u32 = 32 bytes (16-byte aligned).
178        let mut params_bytes = [0u8; 32];
179        let outer_u32: u32 = outer
180            .try_into()
181            .map_err(|_| BackendError::InvalidArgument("reduce: outer exceeds u32 range".into()))?;
182        let dk_u32: u32 = dk
183            .try_into()
184            .map_err(|_| BackendError::InvalidArgument("reduce: dk exceeds u32 range".into()))?;
185        let inner_u32: u32 = inner
186            .try_into()
187            .map_err(|_| BackendError::InvalidArgument("reduce: inner exceeds u32 range".into()))?;
188        let outer_stride_u32: u32 = outer_stride.try_into().map_err(|_| {
189            BackendError::InvalidArgument("reduce: outer_stride exceeds u32 range".into())
190        })?;
191        let dk_stride_u32: u32 = dk_stride.try_into().map_err(|_| {
192            BackendError::InvalidArgument("reduce: dk_stride exceeds u32 range".into())
193        })?;
194        let inner_stride_u32: u32 = inner_stride.try_into().map_err(|_| {
195            BackendError::InvalidArgument("reduce: inner_stride exceeds u32 range".into())
196        })?;
197        params_bytes[0..4].copy_from_slice(&outer_u32.to_le_bytes());
198        params_bytes[4..8].copy_from_slice(&dk_u32.to_le_bytes());
199        params_bytes[8..12].copy_from_slice(&inner_u32.to_le_bytes());
200        params_bytes[12..16].copy_from_slice(&outer_stride_u32.to_le_bytes());
201        params_bytes[16..20].copy_from_slice(&dk_stride_u32.to_le_bytes());
202        params_bytes[20..24].copy_from_slice(&inner_stride_u32.to_le_bytes());
203        params_bytes[24..28].copy_from_slice(&grid_x.to_le_bytes());
204        // bytes 28..32 are zero padding.
205
206        let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
207            label: Some("oxicuda-reduce-nd-params"),
208            size: 32,
209            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
210            mapped_at_creation: false,
211        });
212        dev.queue.write_buffer(&uniform_buf, 0, &params_bytes);
213
214        let bgl = pipeline.get_bind_group_layout(0);
215        let bind_group = {
216            let buffers = mem
217                .lock_buffers()
218                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
219            let in_info = buffers.get(&input_ptr).ok_or_else(|| {
220                BackendError::InvalidArgument(format!("unknown handle {input_ptr}"))
221            })?;
222            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
223                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
224            })?;
225
226            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
227                label: Some("oxicuda-reduce-nd"),
228                layout: &bgl,
229                entries: &[
230                    wgpu::BindGroupEntry {
231                        binding: 0,
232                        resource: in_info.buffer.as_entire_binding(),
233                    },
234                    wgpu::BindGroupEntry {
235                        binding: 1,
236                        resource: out_info.buffer.as_entire_binding(),
237                    },
238                    wgpu::BindGroupEntry {
239                        binding: 2,
240                        resource: uniform_buf.as_entire_binding(),
241                    },
242                ],
243            })
244        };
245
246        let mut encoder = dev
247            .device
248            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
249                label: Some("oxicuda-reduce-nd"),
250            });
251        {
252            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
253                label: Some("oxicuda-reduce-nd"),
254                timestamp_writes: None,
255            });
256            pass.set_pipeline(&pipeline);
257            pass.set_bind_group(0, &bind_group, &[]);
258            pass.dispatch_workgroups(grid_x, grid_y, 1);
259        }
260
261        dev.queue.submit(std::iter::once(encoder.finish()));
262        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
263
264        Ok(())
265    }
266}
267
268impl WebGpuBackend {
269    /// FP16 GEMM: `C = alpha * A * B + beta * C` with half-precision storage.
270    ///
271    /// This is an inherent method (not on `ComputeBackend`) because FP16
272    /// support is WebGPU-specific and requires the `f16` WGSL extension.
273    ///
274    /// Buffers pointed to by `a_ptr`, `b_ptr`, `c_ptr` must contain `f16`
275    /// elements (2 bytes each).
276    #[allow(clippy::too_many_arguments)]
277    pub fn gemm_f16(
278        &self,
279        m: usize,
280        n: usize,
281        k: usize,
282        alpha: f64,
283        a_ptr: u64,
284        b_ptr: u64,
285        beta: f64,
286        c_ptr: u64,
287    ) -> BackendResult<()> {
288        self.check_init()?;
289        if m == 0 || n == 0 || k == 0 {
290            return Ok(());
291        }
292
293        let dev = self.device()?;
294        let mem = self.memory()?;
295
296        let tile_size: u32 = 8;
297        let wgsl = shader::gemm_wgsl_f16(tile_size);
298
299        let shader_mod = dev
300            .device
301            .create_shader_module(wgpu::ShaderModuleDescriptor {
302                label: Some("oxicuda-gemm-f16"),
303                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
304            });
305
306        let pipeline = dev
307            .device
308            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
309                label: Some("oxicuda-gemm-f16"),
310                layout: None,
311                module: &shader_mod,
312                entry_point: Some("main"),
313                compilation_options: Default::default(),
314                cache: None,
315            });
316
317        let bgl = pipeline.get_bind_group_layout(0);
318
319        // Build uniform buffer for GemmParams { m, n, k, alpha, beta }.
320        let mut params_bytes = [0u8; 20];
321        params_bytes[0..4].copy_from_slice(&(m as u32).to_le_bytes());
322        params_bytes[4..8].copy_from_slice(&(n as u32).to_le_bytes());
323        params_bytes[8..12].copy_from_slice(&(k as u32).to_le_bytes());
324        params_bytes[12..16].copy_from_slice(&(alpha as f32).to_le_bytes());
325        params_bytes[16..20].copy_from_slice(&(beta as f32).to_le_bytes());
326
327        let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
328            label: Some("oxicuda-gemm-f16-params"),
329            size: 20,
330            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
331            mapped_at_creation: false,
332        });
333        dev.queue.write_buffer(&uniform_buf, 0, &params_bytes);
334
335        let bind_group = {
336            let buffers = mem
337                .lock_buffers()
338                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
339            let a_info = buffers
340                .get(&a_ptr)
341                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
342            let b_info = buffers
343                .get(&b_ptr)
344                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
345            let c_info = buffers
346                .get(&c_ptr)
347                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {c_ptr}")))?;
348
349            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
350                label: Some("oxicuda-gemm-f16"),
351                layout: &bgl,
352                entries: &[
353                    wgpu::BindGroupEntry {
354                        binding: 0,
355                        resource: a_info.buffer.as_entire_binding(),
356                    },
357                    wgpu::BindGroupEntry {
358                        binding: 1,
359                        resource: b_info.buffer.as_entire_binding(),
360                    },
361                    wgpu::BindGroupEntry {
362                        binding: 2,
363                        resource: c_info.buffer.as_entire_binding(),
364                    },
365                    wgpu::BindGroupEntry {
366                        binding: 3,
367                        resource: uniform_buf.as_entire_binding(),
368                    },
369                ],
370            })
371        };
372
373        let mut encoder = dev
374            .device
375            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
376                label: Some("oxicuda-gemm-f16"),
377            });
378
379        {
380            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
381                label: Some("oxicuda-gemm-f16"),
382                timestamp_writes: None,
383            });
384            pass.set_pipeline(&pipeline);
385            pass.set_bind_group(0, &bind_group, &[]);
386            let wg_x = (n as u32).div_ceil(tile_size);
387            let wg_y = (m as u32).div_ceil(tile_size);
388            pass.dispatch_workgroups(wg_x, wg_y, 1);
389        }
390
391        dev.queue.submit(std::iter::once(encoder.finish()));
392        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
393
394        Ok(())
395    }
396}
397
398impl Default for WebGpuBackend {
399    fn default() -> Self {
400        Self::new()
401    }
402}
403
404// ─── ComputeBackend impl ─────────────────────────────────────────────────────
405
406impl ComputeBackend for WebGpuBackend {
407    fn name(&self) -> &str {
408        "webgpu"
409    }
410
411    fn init(&mut self) -> BackendResult<()> {
412        if self.initialized {
413            return Ok(());
414        }
415
416        match WebGpuDevice::new() {
417            Ok(dev) => {
418                let dev = Arc::new(dev);
419                tracing::info!("WebGPU backend initialised on: {}", dev.adapter_name);
420                let memory = WebGpuMemoryManager::new(Arc::clone(&dev));
421                self.device = Some(dev);
422                self.memory = Some(Arc::new(memory));
423                self.initialized = true;
424                Ok(())
425            }
426            Err(e) => Err(BackendError::from(e)),
427        }
428    }
429
430    fn is_initialized(&self) -> bool {
431        self.initialized
432    }
433
434    // ── Compute operations ────────────────────────────────────────────────────
435
436    fn gemm(
437        &self,
438        trans_a: BackendTranspose,
439        trans_b: BackendTranspose,
440        m: usize,
441        n: usize,
442        k: usize,
443        alpha: f64,
444        a_ptr: u64,
445        _lda: usize,
446        b_ptr: u64,
447        _ldb: usize,
448        beta: f64,
449        c_ptr: u64,
450        _ldc: usize,
451    ) -> BackendResult<()> {
452        self.check_init()?;
453        // Zero-dimension matrices are trivially done.
454        if m == 0 || n == 0 || k == 0 {
455            return Ok(());
456        }
457
458        // The WGSL tiled GEMM kernel handles every NN / NT / TN / TT
459        // combination at runtime via the `trans_a` / `trans_b` uniforms.
460        // `ConjTrans` collapses to `Trans` because the f32 buffers are real.
461        let trans_a_flag: u32 = u32::from(trans_a != BackendTranspose::NoTrans);
462        let trans_b_flag: u32 = u32::from(trans_b != BackendTranspose::NoTrans);
463
464        let dev = self.device()?;
465        let mem = self.memory()?;
466
467        let tile_size: u32 = 8;
468        let wgsl = shader::gemm_wgsl(tile_size);
469
470        let shader_mod = dev
471            .device
472            .create_shader_module(wgpu::ShaderModuleDescriptor {
473                label: Some("oxicuda-gemm"),
474                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
475            });
476
477        let pipeline = dev
478            .device
479            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
480                label: Some("oxicuda-gemm"),
481                layout: None,
482                module: &shader_mod,
483                entry_point: Some("main"),
484                compilation_options: Default::default(),
485                cache: None,
486            });
487
488        let bgl = pipeline.get_bind_group_layout(0);
489
490        // Build uniform buffer for GemmParams { m, n, k, alpha, beta,
491        // trans_a, trans_b, _pad } — 8 × 4 = 32 bytes (16-byte aligned).
492        let mut params_bytes = [0u8; 32];
493        params_bytes[0..4].copy_from_slice(&(m as u32).to_le_bytes());
494        params_bytes[4..8].copy_from_slice(&(n as u32).to_le_bytes());
495        params_bytes[8..12].copy_from_slice(&(k as u32).to_le_bytes());
496        params_bytes[12..16].copy_from_slice(&(alpha as f32).to_le_bytes());
497        params_bytes[16..20].copy_from_slice(&(beta as f32).to_le_bytes());
498        params_bytes[20..24].copy_from_slice(&trans_a_flag.to_le_bytes());
499        params_bytes[24..28].copy_from_slice(&trans_b_flag.to_le_bytes());
500        // bytes 28..32 are zero padding.
501
502        let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
503            label: Some("oxicuda-gemm-params"),
504            size: 32,
505            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
506            mapped_at_creation: false,
507        });
508        dev.queue.write_buffer(&uniform_buf, 0, &params_bytes);
509
510        // Create bind group while holding the buffer lock.
511        let bind_group = {
512            let buffers = mem
513                .lock_buffers()
514                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
515            let a_info = buffers
516                .get(&a_ptr)
517                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
518            let b_info = buffers
519                .get(&b_ptr)
520                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
521            let c_info = buffers
522                .get(&c_ptr)
523                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {c_ptr}")))?;
524
525            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
526                label: Some("oxicuda-gemm"),
527                layout: &bgl,
528                entries: &[
529                    wgpu::BindGroupEntry {
530                        binding: 0,
531                        resource: a_info.buffer.as_entire_binding(),
532                    },
533                    wgpu::BindGroupEntry {
534                        binding: 1,
535                        resource: b_info.buffer.as_entire_binding(),
536                    },
537                    wgpu::BindGroupEntry {
538                        binding: 2,
539                        resource: c_info.buffer.as_entire_binding(),
540                    },
541                    wgpu::BindGroupEntry {
542                        binding: 3,
543                        resource: uniform_buf.as_entire_binding(),
544                    },
545                ],
546            })
547        };
548
549        let mut encoder = dev
550            .device
551            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
552                label: Some("oxicuda-gemm"),
553            });
554
555        {
556            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
557                label: Some("oxicuda-gemm"),
558                timestamp_writes: None,
559            });
560            pass.set_pipeline(&pipeline);
561            pass.set_bind_group(0, &bind_group, &[]);
562            let wg_x = (n as u32).div_ceil(tile_size);
563            let wg_y = (m as u32).div_ceil(tile_size);
564            pass.dispatch_workgroups(wg_x, wg_y, 1);
565        }
566
567        dev.queue.submit(std::iter::once(encoder.finish()));
568        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
569
570        Ok(())
571    }
572
573    #[allow(clippy::too_many_arguments)]
574    fn batched_gemm(
575        &self,
576        trans_a: BackendTranspose,
577        trans_b: BackendTranspose,
578        m: usize,
579        n: usize,
580        k: usize,
581        alpha: f64,
582        a_ptr: u64,
583        _lda: usize,
584        stride_a: usize,
585        b_ptr: u64,
586        _ldb: usize,
587        stride_b: usize,
588        beta: f64,
589        c_ptr: u64,
590        _ldc: usize,
591        stride_c: usize,
592        batch_count: usize,
593    ) -> BackendResult<()> {
594        self.check_init()?;
595
596        if batch_count == 0 || m == 0 || n == 0 || k == 0 {
597            return Ok(());
598        }
599
600        // The WGSL batched tiled GEMM kernel handles every NN / NT / TN / TT
601        // combination at runtime via the `trans_a` / `trans_b` uniforms.
602        // `ConjTrans` collapses to `Trans` because the f32 buffers are real.
603        let trans_a_flag: u32 = u32::from(trans_a != BackendTranspose::NoTrans);
604        let trans_b_flag: u32 = u32::from(trans_b != BackendTranspose::NoTrans);
605
606        let dev = self.device()?;
607        let mem = self.memory()?;
608
609        let tile_size: u32 = 8;
610        let wgsl = shader::batched_gemm_wgsl(tile_size);
611
612        let shader_mod = dev
613            .device
614            .create_shader_module(wgpu::ShaderModuleDescriptor {
615                label: Some("oxicuda-batched-gemm"),
616                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
617            });
618
619        let pipeline = dev
620            .device
621            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
622                label: Some("oxicuda-batched-gemm"),
623                layout: None,
624                module: &shader_mod,
625                entry_point: Some("main"),
626                compilation_options: Default::default(),
627                cache: None,
628            });
629
630        let bgl = pipeline.get_bind_group_layout(0);
631
632        // BatchedGemmParams: m, n, k, alpha, beta, batch_count, stride_a,
633        // stride_b, stride_c, trans_a, trans_b — 11 × 4 = 44 bytes.
634        // Uniform buffers need 16-byte alignment, so 44 rounds up to 48.
635        let mut params_bytes = [0u8; 48];
636        params_bytes[0..4].copy_from_slice(&(m as u32).to_le_bytes());
637        params_bytes[4..8].copy_from_slice(&(n as u32).to_le_bytes());
638        params_bytes[8..12].copy_from_slice(&(k as u32).to_le_bytes());
639        params_bytes[12..16].copy_from_slice(&(alpha as f32).to_le_bytes());
640        params_bytes[16..20].copy_from_slice(&(beta as f32).to_le_bytes());
641        params_bytes[20..24].copy_from_slice(&(batch_count as u32).to_le_bytes());
642        params_bytes[24..28].copy_from_slice(&(stride_a as u32).to_le_bytes());
643        params_bytes[28..32].copy_from_slice(&(stride_b as u32).to_le_bytes());
644        params_bytes[32..36].copy_from_slice(&(stride_c as u32).to_le_bytes());
645        params_bytes[36..40].copy_from_slice(&trans_a_flag.to_le_bytes());
646        params_bytes[40..44].copy_from_slice(&trans_b_flag.to_le_bytes());
647        // bytes 44..48 are padding zeros
648
649        let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
650            label: Some("oxicuda-batched-gemm-params"),
651            size: 48,
652            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
653            mapped_at_creation: false,
654        });
655        dev.queue.write_buffer(&uniform_buf, 0, &params_bytes);
656
657        let bind_group = {
658            let buffers = mem
659                .lock_buffers()
660                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
661            let a_info = buffers
662                .get(&a_ptr)
663                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
664            let b_info = buffers
665                .get(&b_ptr)
666                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
667            let c_info = buffers
668                .get(&c_ptr)
669                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {c_ptr}")))?;
670
671            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
672                label: Some("oxicuda-batched-gemm"),
673                layout: &bgl,
674                entries: &[
675                    wgpu::BindGroupEntry {
676                        binding: 0,
677                        resource: a_info.buffer.as_entire_binding(),
678                    },
679                    wgpu::BindGroupEntry {
680                        binding: 1,
681                        resource: b_info.buffer.as_entire_binding(),
682                    },
683                    wgpu::BindGroupEntry {
684                        binding: 2,
685                        resource: c_info.buffer.as_entire_binding(),
686                    },
687                    wgpu::BindGroupEntry {
688                        binding: 3,
689                        resource: uniform_buf.as_entire_binding(),
690                    },
691                ],
692            })
693        };
694
695        let mut encoder = dev
696            .device
697            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
698                label: Some("oxicuda-batched-gemm"),
699            });
700
701        {
702            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
703                label: Some("oxicuda-batched-gemm"),
704                timestamp_writes: None,
705            });
706            pass.set_pipeline(&pipeline);
707            pass.set_bind_group(0, &bind_group, &[]);
708            let wg_x = (n as u32).div_ceil(tile_size);
709            let wg_y = (m as u32).div_ceil(tile_size);
710            pass.dispatch_workgroups(wg_x, wg_y, batch_count as u32);
711        }
712
713        dev.queue.submit(std::iter::once(encoder.finish()));
714        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
715
716        Ok(())
717    }
718
719    fn conv2d_forward(
720        &self,
721        input_ptr: u64,
722        input_shape: &[usize],
723        filter_ptr: u64,
724        filter_shape: &[usize],
725        output_ptr: u64,
726        output_shape: &[usize],
727        stride: &[usize],
728        padding: &[usize],
729    ) -> BackendResult<()> {
730        self.check_init()?;
731
732        if input_shape.len() != 4 {
733            return Err(BackendError::InvalidArgument(
734                "input_shape must have 4 elements (NCHW)".into(),
735            ));
736        }
737        if filter_shape.len() != 4 {
738            return Err(BackendError::InvalidArgument(
739                "filter_shape must have 4 elements (KCFHFW)".into(),
740            ));
741        }
742        if output_shape.len() != 4 {
743            return Err(BackendError::InvalidArgument(
744                "output_shape must have 4 elements (NKOhOw)".into(),
745            ));
746        }
747        if stride.len() != 2 {
748            return Err(BackendError::InvalidArgument(
749                "stride must have 2 elements [sh, sw]".into(),
750            ));
751        }
752        if padding.len() != 2 {
753            return Err(BackendError::InvalidArgument(
754                "padding must have 2 elements [ph, pw]".into(),
755            ));
756        }
757
758        let mem = self.memory()?;
759
760        let batch = input_shape[0];
761        let c_in = input_shape[1];
762        let h_in = input_shape[2];
763        let w_in = input_shape[3];
764        let k_out = filter_shape[0];
765        let fh = filter_shape[2];
766        let fw = filter_shape[3];
767        let oh = output_shape[2];
768        let ow = output_shape[3];
769        let sh = stride[0];
770        let sw = stride[1];
771        let ph = padding[0];
772        let pw = padding[1];
773
774        let in_elems: usize = input_shape.iter().product();
775        let f_elems: usize = filter_shape.iter().product();
776        let o_elems: usize = output_shape.iter().product();
777
778        // CPU fallback: download input + filter, compute, upload output.
779        let mut in_bytes = vec![0u8; in_elems * 4];
780        let mut f_bytes = vec![0u8; f_elems * 4];
781        mem.copy_from_device(&mut in_bytes, input_ptr)
782            .map_err(BackendError::from)?;
783        mem.copy_from_device(&mut f_bytes, filter_ptr)
784            .map_err(BackendError::from)?;
785
786        let in_f32 = bytes_to_f32_vec(&in_bytes);
787        let f_f32 = bytes_to_f32_vec(&f_bytes);
788        let mut out_f32 = vec![0.0f32; o_elems];
789
790        for b in 0..batch {
791            for kf in 0..k_out {
792                for oy in 0..oh {
793                    for ox in 0..ow {
794                        let mut acc = 0.0f32;
795                        for ci in 0..c_in {
796                            for fy in 0..fh {
797                                for fx in 0..fw {
798                                    let iy = (oy * sh + fy) as isize - ph as isize;
799                                    let ix = (ox * sw + fx) as isize - pw as isize;
800                                    if iy >= 0
801                                        && (iy as usize) < h_in
802                                        && ix >= 0
803                                        && (ix as usize) < w_in
804                                    {
805                                        let in_idx = ((b * c_in + ci) * h_in + iy as usize) * w_in
806                                            + ix as usize;
807                                        let f_idx = ((kf * c_in + ci) * fh + fy) * fw + fx;
808                                        acc += in_f32[in_idx] * f_f32[f_idx];
809                                    }
810                                }
811                            }
812                        }
813                        out_f32[((b * k_out + kf) * oh + oy) * ow + ox] = acc;
814                    }
815                }
816            }
817        }
818
819        let out_bytes = f32_slice_to_bytes(&out_f32);
820        mem.copy_to_device(output_ptr, &out_bytes)
821            .map_err(BackendError::from)?;
822
823        Ok(())
824    }
825
826    fn attention(
827        &self,
828        q_ptr: u64,
829        k_ptr: u64,
830        v_ptr: u64,
831        o_ptr: u64,
832        batch: usize,
833        heads: usize,
834        seq_q: usize,
835        seq_kv: usize,
836        head_dim: usize,
837        scale: f64,
838        causal: bool,
839    ) -> BackendResult<()> {
840        self.check_init()?;
841
842        if seq_q == 0 || seq_kv == 0 || head_dim == 0 {
843            return Err(BackendError::InvalidArgument(
844                "seq_q, seq_kv, and head_dim must all be > 0".into(),
845            ));
846        }
847        if scale <= 0.0 || !scale.is_finite() {
848            return Err(BackendError::InvalidArgument(format!(
849                "scale must be a positive finite number, got {scale}"
850            )));
851        }
852
853        let mem = self.memory()?;
854
855        let batch_heads = batch * heads;
856        let q_elems = batch_heads * seq_q * head_dim;
857        let kv_elems = batch_heads * seq_kv * head_dim;
858        let o_elems = q_elems;
859
860        // CPU fallback: download Q, K, V, compute attention, upload O.
861        let mut q_bytes = vec![0u8; q_elems * 4];
862        let mut k_bytes = vec![0u8; kv_elems * 4];
863        let mut v_bytes = vec![0u8; kv_elems * 4];
864
865        mem.copy_from_device(&mut q_bytes, q_ptr)
866            .map_err(BackendError::from)?;
867        mem.copy_from_device(&mut k_bytes, k_ptr)
868            .map_err(BackendError::from)?;
869        mem.copy_from_device(&mut v_bytes, v_ptr)
870            .map_err(BackendError::from)?;
871
872        let q_f32 = bytes_to_f32_vec(&q_bytes);
873        let k_f32 = bytes_to_f32_vec(&k_bytes);
874        let v_f32 = bytes_to_f32_vec(&v_bytes);
875        let mut o_f32 = vec![0.0f32; o_elems];
876
877        let scale_f32 = scale as f32;
878
879        for bh in 0..batch_heads {
880            let q_off = bh * seq_q * head_dim;
881            let k_off = bh * seq_kv * head_dim;
882            let v_off = k_off;
883
884            for sq in 0..seq_q {
885                let kv_limit = if causal { (sq + 1).min(seq_kv) } else { seq_kv };
886
887                // Pass 1: find max score for numerical stability
888                let mut max_score = f32::NEG_INFINITY;
889                for sk in 0..kv_limit {
890                    let mut dot = 0.0f32;
891                    for dd in 0..head_dim {
892                        dot +=
893                            q_f32[q_off + sq * head_dim + dd] * k_f32[k_off + sk * head_dim + dd];
894                    }
895                    let s = dot * scale_f32;
896                    if s > max_score {
897                        max_score = s;
898                    }
899                }
900
901                // Pass 2: exp(score - max), accumulate weighted V
902                let mut sum_exp = 0.0f32;
903                let mut acc = vec![0.0f32; head_dim];
904                for sk in 0..kv_limit {
905                    let mut dot = 0.0f32;
906                    for dd in 0..head_dim {
907                        dot +=
908                            q_f32[q_off + sq * head_dim + dd] * k_f32[k_off + sk * head_dim + dd];
909                    }
910                    let w = (dot * scale_f32 - max_score).exp();
911                    sum_exp += w;
912                    for dd in 0..head_dim {
913                        acc[dd] += w * v_f32[v_off + sk * head_dim + dd];
914                    }
915                }
916
917                // Normalise
918                let o_base = q_off + sq * head_dim;
919                if sum_exp > 0.0 {
920                    for dd in 0..head_dim {
921                        o_f32[o_base + dd] = acc[dd] / sum_exp;
922                    }
923                }
924            }
925        }
926
927        let o_bytes = f32_slice_to_bytes(&o_f32);
928        mem.copy_to_device(o_ptr, &o_bytes)
929            .map_err(BackendError::from)?;
930
931        Ok(())
932    }
933
934    fn reduce(
935        &self,
936        op: ReduceOp,
937        input_ptr: u64,
938        output_ptr: u64,
939        shape: &[usize],
940        axis: usize,
941    ) -> BackendResult<()> {
942        self.check_init()?;
943
944        if shape.is_empty() {
945            return Err(BackendError::InvalidArgument(
946                "shape must not be empty".into(),
947            ));
948        }
949        if axis >= shape.len() {
950            return Err(BackendError::InvalidArgument(format!(
951                "axis {axis} is out of bounds for shape of length {}",
952                shape.len()
953            )));
954        }
955
956        // 1-D shapes (or any shape that reduces to a single scalar) take the
957        // optimised two-pass scalar path.  Higher-rank shapes go through the
958        // batched N-D shader below.
959        if shape.len() != 1 {
960            return self.reduce_nd(op, input_ptr, output_ptr, shape, axis);
961        }
962
963        let n_elements = shape[0];
964        if n_elements == 0 {
965            return Ok(());
966        }
967
968        let dev = self.device()?;
969        let mem = self.memory()?;
970        let op_str = map_reduce_op(op);
971
972        // ── Pass 1: per-workgroup reduction ─────────────────────────────────
973        let wg_count = (n_elements as u32).div_ceil(256);
974
975        let pass1_wgsl = shader::reduction_wgsl(op_str);
976        let pass1_shader = dev
977            .device
978            .create_shader_module(wgpu::ShaderModuleDescriptor {
979                label: Some("oxicuda-reduce-pass1"),
980                source: wgpu::ShaderSource::Wgsl(pass1_wgsl.into()),
981            });
982        let pass1_pipeline = dev
983            .device
984            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
985                label: Some("oxicuda-reduce-pass1"),
986                layout: None,
987                module: &pass1_shader,
988                entry_point: Some("main"),
989                compilation_options: Default::default(),
990                cache: None,
991            });
992
993        // Partial-sums buffer (temporary).
994        let partial_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
995            label: Some("oxicuda-reduce-partial"),
996            size: (wg_count as u64) * 4, // f32 per workgroup
997            usage: wgpu::BufferUsages::STORAGE
998                | wgpu::BufferUsages::COPY_SRC
999                | wgpu::BufferUsages::COPY_DST,
1000            mapped_at_creation: false,
1001        });
1002
1003        // Uniform for ReduceParams { n: u32 }.
1004        let mut p1_params = [0u8; 4];
1005        p1_params[0..4].copy_from_slice(&(n_elements as u32).to_le_bytes());
1006        let p1_uniform = dev.device.create_buffer(&wgpu::BufferDescriptor {
1007            label: Some("oxicuda-reduce-p1-params"),
1008            size: 4,
1009            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
1010            mapped_at_creation: false,
1011        });
1012        dev.queue.write_buffer(&p1_uniform, 0, &p1_params);
1013
1014        let bgl1 = pass1_pipeline.get_bind_group_layout(0);
1015
1016        let bg1 = {
1017            let buffers = mem
1018                .lock_buffers()
1019                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
1020            let in_info = buffers.get(&input_ptr).ok_or_else(|| {
1021                BackendError::InvalidArgument(format!("unknown handle {input_ptr}"))
1022            })?;
1023
1024            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
1025                label: Some("oxicuda-reduce-pass1"),
1026                layout: &bgl1,
1027                entries: &[
1028                    wgpu::BindGroupEntry {
1029                        binding: 0,
1030                        resource: in_info.buffer.as_entire_binding(),
1031                    },
1032                    wgpu::BindGroupEntry {
1033                        binding: 1,
1034                        resource: partial_buf.as_entire_binding(),
1035                    },
1036                    wgpu::BindGroupEntry {
1037                        binding: 2,
1038                        resource: p1_uniform.as_entire_binding(),
1039                    },
1040                ],
1041            })
1042        };
1043
1044        // ── Pass 2: final reduction of partial sums ─────────────────────────
1045        let pass2_wgsl = shader::reduction_final_wgsl(op_str);
1046        let pass2_shader = dev
1047            .device
1048            .create_shader_module(wgpu::ShaderModuleDescriptor {
1049                label: Some("oxicuda-reduce-pass2"),
1050                source: wgpu::ShaderSource::Wgsl(pass2_wgsl.into()),
1051            });
1052        let pass2_pipeline = dev
1053            .device
1054            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1055                label: Some("oxicuda-reduce-pass2"),
1056                layout: None,
1057                module: &pass2_shader,
1058                entry_point: Some("main"),
1059                compilation_options: Default::default(),
1060                cache: None,
1061            });
1062
1063        // FinalReduceParams { num_groups: u32 }.
1064        let mut p2_params = [0u8; 4];
1065        p2_params[0..4].copy_from_slice(&wg_count.to_le_bytes());
1066        let p2_uniform = dev.device.create_buffer(&wgpu::BufferDescriptor {
1067            label: Some("oxicuda-reduce-p2-params"),
1068            size: 4,
1069            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
1070            mapped_at_creation: false,
1071        });
1072        dev.queue.write_buffer(&p2_uniform, 0, &p2_params);
1073
1074        let bgl2 = pass2_pipeline.get_bind_group_layout(0);
1075
1076        let bg2 = {
1077            let buffers = mem
1078                .lock_buffers()
1079                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
1080            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
1081                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
1082            })?;
1083
1084            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
1085                label: Some("oxicuda-reduce-pass2"),
1086                layout: &bgl2,
1087                entries: &[
1088                    wgpu::BindGroupEntry {
1089                        binding: 0,
1090                        resource: partial_buf.as_entire_binding(),
1091                    },
1092                    wgpu::BindGroupEntry {
1093                        binding: 1,
1094                        resource: out_info.buffer.as_entire_binding(),
1095                    },
1096                    wgpu::BindGroupEntry {
1097                        binding: 2,
1098                        resource: p2_uniform.as_entire_binding(),
1099                    },
1100                ],
1101            })
1102        };
1103
1104        // ── Encode both passes into one command buffer ──────────────────────
1105        let mut encoder = dev
1106            .device
1107            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1108                label: Some("oxicuda-reduce"),
1109            });
1110
1111        {
1112            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1113                label: Some("oxicuda-reduce-pass1"),
1114                timestamp_writes: None,
1115            });
1116            pass.set_pipeline(&pass1_pipeline);
1117            pass.set_bind_group(0, &bg1, &[]);
1118            pass.dispatch_workgroups(wg_count, 1, 1);
1119        }
1120        {
1121            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1122                label: Some("oxicuda-reduce-pass2"),
1123                timestamp_writes: None,
1124            });
1125            pass.set_pipeline(&pass2_pipeline);
1126            pass.set_bind_group(0, &bg2, &[]);
1127            pass.dispatch_workgroups(1, 1, 1);
1128        }
1129
1130        dev.queue.submit(std::iter::once(encoder.finish()));
1131        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
1132
1133        // For "mean", divide the result by N on the host side.
1134        if op == ReduceOp::Mean && n_elements > 1 {
1135            let mut buf = [0u8; 4];
1136            mem.copy_from_device(&mut buf, output_ptr)
1137                .map_err(BackendError::from)?;
1138            let val = f32::from_le_bytes(buf);
1139            let mean = val / (n_elements as f32);
1140            mem.copy_to_device(output_ptr, &mean.to_le_bytes())
1141                .map_err(BackendError::from)?;
1142        }
1143
1144        Ok(())
1145    }
1146
1147    fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()> {
1148        self.check_init()?;
1149        if n == 0 {
1150            return Ok(());
1151        }
1152
1153        let dev = self.device()?;
1154        let mem = self.memory()?;
1155
1156        let op_str = map_unary_op(op);
1157        let wgsl = shader::elementwise_wgsl(op_str);
1158
1159        let shader_mod = dev
1160            .device
1161            .create_shader_module(wgpu::ShaderModuleDescriptor {
1162                label: Some("oxicuda-unary"),
1163                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1164            });
1165
1166        let pipeline = dev
1167            .device
1168            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1169                label: Some("oxicuda-unary"),
1170                layout: None,
1171                module: &shader_mod,
1172                entry_point: Some("main"),
1173                compilation_options: Default::default(),
1174                cache: None,
1175            });
1176
1177        let bgl = pipeline.get_bind_group_layout(0);
1178
1179        let bind_group = {
1180            let buffers = mem
1181                .lock_buffers()
1182                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
1183            let in_info = buffers.get(&input_ptr).ok_or_else(|| {
1184                BackendError::InvalidArgument(format!("unknown handle {input_ptr}"))
1185            })?;
1186            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
1187                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
1188            })?;
1189
1190            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
1191                label: Some("oxicuda-unary"),
1192                layout: &bgl,
1193                entries: &[
1194                    wgpu::BindGroupEntry {
1195                        binding: 0,
1196                        resource: in_info.buffer.as_entire_binding(),
1197                    },
1198                    wgpu::BindGroupEntry {
1199                        binding: 1,
1200                        resource: out_info.buffer.as_entire_binding(),
1201                    },
1202                ],
1203            })
1204        };
1205
1206        let mut encoder = dev
1207            .device
1208            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1209                label: Some("oxicuda-unary"),
1210            });
1211
1212        {
1213            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1214                label: Some("oxicuda-unary"),
1215                timestamp_writes: None,
1216            });
1217            pass.set_pipeline(&pipeline);
1218            pass.set_bind_group(0, &bind_group, &[]);
1219            let workgroups = (n as u32).div_ceil(256);
1220            pass.dispatch_workgroups(workgroups, 1, 1);
1221        }
1222
1223        dev.queue.submit(std::iter::once(encoder.finish()));
1224        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
1225
1226        Ok(())
1227    }
1228
1229    fn binary(
1230        &self,
1231        op: BinaryOp,
1232        a_ptr: u64,
1233        b_ptr: u64,
1234        output_ptr: u64,
1235        n: usize,
1236    ) -> BackendResult<()> {
1237        self.check_init()?;
1238        if n == 0 {
1239            return Ok(());
1240        }
1241
1242        let dev = self.device()?;
1243        let mem = self.memory()?;
1244
1245        let op_str = map_binary_op(op);
1246        let wgsl = shader::binary_wgsl(op_str);
1247
1248        let shader_mod = dev
1249            .device
1250            .create_shader_module(wgpu::ShaderModuleDescriptor {
1251                label: Some("oxicuda-binary"),
1252                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
1253            });
1254
1255        let pipeline = dev
1256            .device
1257            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1258                label: Some("oxicuda-binary"),
1259                layout: None,
1260                module: &shader_mod,
1261                entry_point: Some("main"),
1262                compilation_options: Default::default(),
1263                cache: None,
1264            });
1265
1266        let bgl = pipeline.get_bind_group_layout(0);
1267
1268        let bind_group = {
1269            let buffers = mem
1270                .lock_buffers()
1271                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
1272            let a_info = buffers
1273                .get(&a_ptr)
1274                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
1275            let b_info = buffers
1276                .get(&b_ptr)
1277                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
1278            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
1279                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
1280            })?;
1281
1282            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
1283                label: Some("oxicuda-binary"),
1284                layout: &bgl,
1285                entries: &[
1286                    wgpu::BindGroupEntry {
1287                        binding: 0,
1288                        resource: a_info.buffer.as_entire_binding(),
1289                    },
1290                    wgpu::BindGroupEntry {
1291                        binding: 1,
1292                        resource: b_info.buffer.as_entire_binding(),
1293                    },
1294                    wgpu::BindGroupEntry {
1295                        binding: 2,
1296                        resource: out_info.buffer.as_entire_binding(),
1297                    },
1298                ],
1299            })
1300        };
1301
1302        let mut encoder = dev
1303            .device
1304            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1305                label: Some("oxicuda-binary"),
1306            });
1307
1308        {
1309            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1310                label: Some("oxicuda-binary"),
1311                timestamp_writes: None,
1312            });
1313            pass.set_pipeline(&pipeline);
1314            pass.set_bind_group(0, &bind_group, &[]);
1315            let workgroups = (n as u32).div_ceil(256);
1316            pass.dispatch_workgroups(workgroups, 1, 1);
1317        }
1318
1319        dev.queue.submit(std::iter::once(encoder.finish()));
1320        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
1321
1322        Ok(())
1323    }
1324
1325    // ── Synchronisation ───────────────────────────────────────────────────────
1326
1327    fn synchronize(&self) -> BackendResult<()> {
1328        self.check_init()?;
1329        if let Some(dev) = &self.device {
1330            let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
1331        }
1332        Ok(())
1333    }
1334
1335    // ── Memory management ─────────────────────────────────────────────────────
1336
1337    fn alloc(&self, bytes: usize) -> BackendResult<u64> {
1338        self.check_init()?;
1339        if bytes == 0 {
1340            return Err(BackendError::InvalidArgument(
1341                "cannot allocate 0 bytes".into(),
1342            ));
1343        }
1344        self.memory()?.alloc(bytes).map_err(BackendError::from)
1345    }
1346
1347    fn free(&self, ptr: u64) -> BackendResult<()> {
1348        self.check_init()?;
1349        self.memory()?.free(ptr).map_err(BackendError::from)
1350    }
1351
1352    fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()> {
1353        self.check_init()?;
1354        if src.is_empty() {
1355            return Ok(());
1356        }
1357        self.memory()?
1358            .copy_to_device(dst, src)
1359            .map_err(BackendError::from)
1360    }
1361
1362    fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()> {
1363        self.check_init()?;
1364        if dst.is_empty() {
1365            return Ok(());
1366        }
1367        self.memory()?
1368            .copy_from_device(dst, src)
1369            .map_err(BackendError::from)
1370    }
1371}
1372
1373// ─── Byte ↔ f32 helpers ──────────────────────────────────────────────────────
1374
/// Decodes little-endian `f32` values from a byte slice.
///
/// The input is expected to be a whole number of 4-byte words. Any trailing
/// bytes that do not form a complete word are ignored.
fn bytes_to_f32_vec(bytes: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(bytes.len() / 4);
    for word in bytes.chunks_exact(4) {
        values.push(f32::from_le_bytes([word[0], word[1], word[2], word[3]]));
    }
    values
}
1382
/// Encodes a slice of `f32` values as consecutive little-endian byte words.
fn f32_slice_to_bytes(data: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(data.len() * 4);
    for value in data {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}
1387
1388// ─── Tests ───────────────────────────────────────────────────────────────────
1389//
1390// The test module lives in a sibling file (`backend_tests.rs`) so the
1391// production code in this file stays under the 2 000-line refactoring policy.
1392#[cfg(test)]
1393#[path = "backend_tests.rs"]
1394mod tests;