// oxicuda_webgpu/backend.rs
//! [`WebGpuBackend`] — the main entry point for the oxicuda-webgpu crate.
//!
//! Implements the [`ComputeBackend`] trait from `oxicuda-backend` using
//! `wgpu` for cross-platform GPU compute (Vulkan, Metal, DX12, WebGPU).

use std::sync::Arc;

use oxicuda_backend::{
    BackendError, BackendResult, BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp,
};
use wgpu;

use crate::{device::WebGpuDevice, memory::WebGpuMemoryManager, shader};

// ─── Op-mapping helpers ──────────────────────────────────────────────────────

17fn map_unary_op(op: UnaryOp) -> &'static str {
18    match op {
19        UnaryOp::Relu => "relu",
20        UnaryOp::Sigmoid => "sigmoid",
21        UnaryOp::Tanh => "tanh",
22        UnaryOp::Exp => "exp",
23        UnaryOp::Log => "log",
24        UnaryOp::Sqrt => "sqrt",
25        UnaryOp::Abs => "abs",
26        UnaryOp::Neg => "neg",
27    }
28}
29
30fn map_binary_op(op: BinaryOp) -> &'static str {
31    match op {
32        BinaryOp::Add => "add",
33        BinaryOp::Sub => "sub",
34        BinaryOp::Mul => "mul",
35        BinaryOp::Div => "div",
36        BinaryOp::Max => "max",
37        BinaryOp::Min => "min",
38    }
39}
40
41fn map_reduce_op(op: ReduceOp) -> &'static str {
42    match op {
43        ReduceOp::Sum => "sum",
44        ReduceOp::Max => "max",
45        ReduceOp::Min => "min",
46        ReduceOp::Mean => "mean",
47    }
48}
49
// ─── Backend struct ──────────────────────────────────────────────────────────

/// Cross-platform GPU compute backend backed by `wgpu`.
///
/// # Lifecycle
///
/// 1. `WebGpuBackend::new()` — create an uninitialised backend.
/// 2. `init()` — select the best available adapter and create the device.
/// 3. Use `alloc`, `copy_htod`, compute ops, `copy_dtoh`, `free`.
/// 4. `synchronize()` — wait for all pending GPU work to finish.
#[derive(Debug)]
pub struct WebGpuBackend {
    // Adapter + device/queue pair; `None` until `init()` succeeds.
    device: Option<Arc<WebGpuDevice>>,
    // Registry mapping opaque u64 handles to wgpu buffers; `None` until
    // `init()` succeeds.
    memory: Option<Arc<WebGpuMemoryManager>>,
    // Set to `true` only after both `device` and `memory` are populated,
    // so `initialized == true` implies both Options are `Some`.
    initialized: bool,
}

67impl WebGpuBackend {
68    /// Create a new, uninitialised WebGPU backend.
69    pub fn new() -> Self {
70        Self {
71            device: None,
72            memory: None,
73            initialized: false,
74        }
75    }
76
77    /// Return an error if the backend is not yet initialised.
78    fn check_init(&self) -> BackendResult<()> {
79        if self.initialized {
80            Ok(())
81        } else {
82            Err(BackendError::NotInitialized)
83        }
84    }
85
86    /// Convenience accessor: get the memory manager or return `NotInitialized`.
87    fn memory(&self) -> BackendResult<&Arc<WebGpuMemoryManager>> {
88        self.memory.as_ref().ok_or(BackendError::NotInitialized)
89    }
90
91    /// Convenience accessor: get the device or return `NotInitialized`.
92    fn device(&self) -> BackendResult<&Arc<WebGpuDevice>> {
93        self.device.as_ref().ok_or(BackendError::NotInitialized)
94    }
95}
96
97impl Default for WebGpuBackend {
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
// ─── ComputeBackend impl ─────────────────────────────────────────────────────

impl ComputeBackend for WebGpuBackend {
    /// Backend identifier reported to `oxicuda-backend` consumers.
    fn name(&self) -> &str {
        "webgpu"
    }

110    fn init(&mut self) -> BackendResult<()> {
111        if self.initialized {
112            return Ok(());
113        }
114
115        match WebGpuDevice::new() {
116            Ok(dev) => {
117                let dev = Arc::new(dev);
118                tracing::info!("WebGPU backend initialised on: {}", dev.adapter_name);
119                let memory = WebGpuMemoryManager::new(Arc::clone(&dev));
120                self.device = Some(dev);
121                self.memory = Some(Arc::new(memory));
122                self.initialized = true;
123                Ok(())
124            }
125            Err(e) => Err(BackendError::from(e)),
126        }
127    }
128
    /// `true` once `init()` has completed successfully.
    fn is_initialized(&self) -> bool {
        self.initialized
    }

    // ── Compute operations ────────────────────────────────────────────────────

    /// Single-precision GEMM: `C = alpha · A·B + beta · C` with A `m×k` and
    /// B `k×n`, executed as a tiled WGSL kernel from [`shader::gemm_wgsl`].
    ///
    /// NOTE(review): `_lda`/`_ldb`/`_ldc` are ignored — this path appears to
    /// assume tightly-packed matrices; confirm against the generated shader
    /// before passing strided sub-matrix views.  `alpha`/`beta` are narrowed
    /// from f64 to f32 for the shader.
    ///
    /// # Errors
    /// `NotInitialized` before `init()`; `Unsupported` for transposed inputs;
    /// `InvalidArgument` for unknown buffer handles.
    fn gemm(
        &self,
        trans_a: BackendTranspose,
        trans_b: BackendTranspose,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        _lda: usize,
        b_ptr: u64,
        _ldb: usize,
        beta: f64,
        c_ptr: u64,
        _ldc: usize,
    ) -> BackendResult<()> {
        self.check_init()?;
        // Zero-dimension matrices are trivially done.
        if m == 0 || n == 0 || k == 0 {
            return Ok(());
        }

        // Transpose not yet supported in the WGSL shader.
        if trans_a != BackendTranspose::NoTrans || trans_b != BackendTranspose::NoTrans {
            return Err(BackendError::Unsupported(
                "WebGPU GEMM does not yet support transposed inputs".into(),
            ));
        }

        let dev = self.device()?;
        let mem = self.memory()?;

        // Each workgroup computes a tile_size × tile_size tile of C.
        let tile_size: u32 = 8;
        let wgsl = shader::gemm_wgsl(tile_size);

        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-gemm"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });

        // `layout: None` lets wgpu derive the bind-group layout from the
        // shader itself.
        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-gemm"),
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        let bgl = pipeline.get_bind_group_layout(0);

        // Build uniform buffer for GemmParams { m, n, k, alpha, beta },
        // packed little-endian as 3×u32 + 2×f32 (20 bytes).
        let mut params_bytes = [0u8; 20];
        params_bytes[0..4].copy_from_slice(&(m as u32).to_le_bytes());
        params_bytes[4..8].copy_from_slice(&(n as u32).to_le_bytes());
        params_bytes[8..12].copy_from_slice(&(k as u32).to_le_bytes());
        params_bytes[12..16].copy_from_slice(&(alpha as f32).to_le_bytes());
        params_bytes[16..20].copy_from_slice(&(beta as f32).to_le_bytes());

        let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-gemm-params"),
            size: 20,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&uniform_buf, 0, &params_bytes);

        // Create bind group while holding the buffer lock.  The lock scope is
        // deliberately tight: it is released before encoding/submission.
        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let a_info = buffers
                .get(&a_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
            let b_info = buffers
                .get(&b_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
            let c_info = buffers
                .get(&c_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {c_ptr}")))?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-gemm"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: a_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: b_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: c_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 3,
                        resource: uniform_buf.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-gemm"),
            });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-gemm"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            // One workgroup per output tile: x spans columns (n), y rows (m).
            let wg_x = (n as u32).div_ceil(tile_size);
            let wg_y = (m as u32).div_ceil(tile_size);
            pass.dispatch_workgroups(wg_x, wg_y, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        // Block until the dispatch completes so callers may read C right away.
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }

269    fn conv2d_forward(
270        &self,
271        input_ptr: u64,
272        input_shape: &[usize],
273        filter_ptr: u64,
274        filter_shape: &[usize],
275        output_ptr: u64,
276        output_shape: &[usize],
277        stride: &[usize],
278        padding: &[usize],
279    ) -> BackendResult<()> {
280        self.check_init()?;
281
282        if input_shape.len() != 4 {
283            return Err(BackendError::InvalidArgument(
284                "input_shape must have 4 elements (NCHW)".into(),
285            ));
286        }
287        if filter_shape.len() != 4 {
288            return Err(BackendError::InvalidArgument(
289                "filter_shape must have 4 elements (KCFHFW)".into(),
290            ));
291        }
292        if output_shape.len() != 4 {
293            return Err(BackendError::InvalidArgument(
294                "output_shape must have 4 elements (NKOhOw)".into(),
295            ));
296        }
297        if stride.len() != 2 {
298            return Err(BackendError::InvalidArgument(
299                "stride must have 2 elements [sh, sw]".into(),
300            ));
301        }
302        if padding.len() != 2 {
303            return Err(BackendError::InvalidArgument(
304                "padding must have 2 elements [ph, pw]".into(),
305            ));
306        }
307
308        let mem = self.memory()?;
309
310        let batch = input_shape[0];
311        let c_in = input_shape[1];
312        let h_in = input_shape[2];
313        let w_in = input_shape[3];
314        let k_out = filter_shape[0];
315        let fh = filter_shape[2];
316        let fw = filter_shape[3];
317        let oh = output_shape[2];
318        let ow = output_shape[3];
319        let sh = stride[0];
320        let sw = stride[1];
321        let ph = padding[0];
322        let pw = padding[1];
323
324        let in_elems: usize = input_shape.iter().product();
325        let f_elems: usize = filter_shape.iter().product();
326        let o_elems: usize = output_shape.iter().product();
327
328        // CPU fallback: download input + filter, compute, upload output.
329        let mut in_bytes = vec![0u8; in_elems * 4];
330        let mut f_bytes = vec![0u8; f_elems * 4];
331        mem.copy_from_device(&mut in_bytes, input_ptr)
332            .map_err(BackendError::from)?;
333        mem.copy_from_device(&mut f_bytes, filter_ptr)
334            .map_err(BackendError::from)?;
335
336        let in_f32 = bytes_to_f32_vec(&in_bytes);
337        let f_f32 = bytes_to_f32_vec(&f_bytes);
338        let mut out_f32 = vec![0.0f32; o_elems];
339
340        for b in 0..batch {
341            for kf in 0..k_out {
342                for oy in 0..oh {
343                    for ox in 0..ow {
344                        let mut acc = 0.0f32;
345                        for ci in 0..c_in {
346                            for fy in 0..fh {
347                                for fx in 0..fw {
348                                    let iy = (oy * sh + fy) as isize - ph as isize;
349                                    let ix = (ox * sw + fx) as isize - pw as isize;
350                                    if iy >= 0
351                                        && (iy as usize) < h_in
352                                        && ix >= 0
353                                        && (ix as usize) < w_in
354                                    {
355                                        let in_idx = ((b * c_in + ci) * h_in + iy as usize) * w_in
356                                            + ix as usize;
357                                        let f_idx = ((kf * c_in + ci) * fh + fy) * fw + fx;
358                                        acc += in_f32[in_idx] * f_f32[f_idx];
359                                    }
360                                }
361                            }
362                        }
363                        out_f32[((b * k_out + kf) * oh + oy) * ow + ox] = acc;
364                    }
365                }
366            }
367        }
368
369        let out_bytes = f32_slice_to_bytes(&out_f32);
370        mem.copy_to_device(output_ptr, &out_bytes)
371            .map_err(BackendError::from)?;
372
373        Ok(())
374    }
375
376    fn attention(
377        &self,
378        q_ptr: u64,
379        k_ptr: u64,
380        v_ptr: u64,
381        o_ptr: u64,
382        batch: usize,
383        heads: usize,
384        seq_q: usize,
385        seq_kv: usize,
386        head_dim: usize,
387        scale: f64,
388        causal: bool,
389    ) -> BackendResult<()> {
390        self.check_init()?;
391
392        if seq_q == 0 || seq_kv == 0 || head_dim == 0 {
393            return Err(BackendError::InvalidArgument(
394                "seq_q, seq_kv, and head_dim must all be > 0".into(),
395            ));
396        }
397        if scale <= 0.0 || !scale.is_finite() {
398            return Err(BackendError::InvalidArgument(format!(
399                "scale must be a positive finite number, got {scale}"
400            )));
401        }
402
403        let mem = self.memory()?;
404
405        let batch_heads = batch * heads;
406        let q_elems = batch_heads * seq_q * head_dim;
407        let kv_elems = batch_heads * seq_kv * head_dim;
408        let o_elems = q_elems;
409
410        // CPU fallback: download Q, K, V, compute attention, upload O.
411        let mut q_bytes = vec![0u8; q_elems * 4];
412        let mut k_bytes = vec![0u8; kv_elems * 4];
413        let mut v_bytes = vec![0u8; kv_elems * 4];
414
415        mem.copy_from_device(&mut q_bytes, q_ptr)
416            .map_err(BackendError::from)?;
417        mem.copy_from_device(&mut k_bytes, k_ptr)
418            .map_err(BackendError::from)?;
419        mem.copy_from_device(&mut v_bytes, v_ptr)
420            .map_err(BackendError::from)?;
421
422        let q_f32 = bytes_to_f32_vec(&q_bytes);
423        let k_f32 = bytes_to_f32_vec(&k_bytes);
424        let v_f32 = bytes_to_f32_vec(&v_bytes);
425        let mut o_f32 = vec![0.0f32; o_elems];
426
427        let scale_f32 = scale as f32;
428
429        for bh in 0..batch_heads {
430            let q_off = bh * seq_q * head_dim;
431            let k_off = bh * seq_kv * head_dim;
432            let v_off = k_off;
433
434            for sq in 0..seq_q {
435                let kv_limit = if causal { (sq + 1).min(seq_kv) } else { seq_kv };
436
437                // Pass 1: find max score for numerical stability
438                let mut max_score = f32::NEG_INFINITY;
439                for sk in 0..kv_limit {
440                    let mut dot = 0.0f32;
441                    for dd in 0..head_dim {
442                        dot +=
443                            q_f32[q_off + sq * head_dim + dd] * k_f32[k_off + sk * head_dim + dd];
444                    }
445                    let s = dot * scale_f32;
446                    if s > max_score {
447                        max_score = s;
448                    }
449                }
450
451                // Pass 2: exp(score - max), accumulate weighted V
452                let mut sum_exp = 0.0f32;
453                let mut acc = vec![0.0f32; head_dim];
454                for sk in 0..kv_limit {
455                    let mut dot = 0.0f32;
456                    for dd in 0..head_dim {
457                        dot +=
458                            q_f32[q_off + sq * head_dim + dd] * k_f32[k_off + sk * head_dim + dd];
459                    }
460                    let w = (dot * scale_f32 - max_score).exp();
461                    sum_exp += w;
462                    for dd in 0..head_dim {
463                        acc[dd] += w * v_f32[v_off + sk * head_dim + dd];
464                    }
465                }
466
467                // Normalise
468                let o_base = q_off + sq * head_dim;
469                if sum_exp > 0.0 {
470                    for dd in 0..head_dim {
471                        o_f32[o_base + dd] = acc[dd] / sum_exp;
472                    }
473                }
474            }
475        }
476
477        let o_bytes = f32_slice_to_bytes(&o_f32);
478        mem.copy_to_device(o_ptr, &o_bytes)
479            .map_err(BackendError::from)?;
480
481        Ok(())
482    }
483
    /// Reduce a flat f32 array to a single value with a two-pass GPU kernel.
    ///
    /// Pass 1 reduces 256-element chunks into one partial result per
    /// workgroup; pass 2 folds the partials into a single value at
    /// `output_ptr`.
    ///
    /// NOTE(review): for [`ReduceOp::Mean`] the host divides the final value
    /// by N after the GPU passes — this assumes the generated "mean" shaders
    /// accumulate sum-style; confirm against `shader::reduction_wgsl` /
    /// `shader::reduction_final_wgsl` that the divide is not applied twice.
    ///
    /// # Errors
    /// `NotInitialized` before `init()`; `InvalidArgument` for an empty
    /// shape, out-of-range axis, or unknown handles; `Unsupported` for
    /// shapes with more than one dimension.
    fn reduce(
        &self,
        op: ReduceOp,
        input_ptr: u64,
        output_ptr: u64,
        shape: &[usize],
        axis: usize,
    ) -> BackendResult<()> {
        self.check_init()?;

        if shape.is_empty() {
            return Err(BackendError::InvalidArgument(
                "shape must not be empty".into(),
            ));
        }
        if axis >= shape.len() {
            return Err(BackendError::InvalidArgument(format!(
                "axis {axis} is out of bounds for shape of length {}",
                shape.len()
            )));
        }

        // Only flat 1-D reduction (shape.len() == 1, axis == 0) is currently
        // supported on the GPU.  Multi-dimensional reductions require batched
        // shaders that are not yet implemented.
        if shape.len() != 1 {
            return Err(BackendError::Unsupported(
                "WebGPU reduce currently supports only 1-D shapes".into(),
            ));
        }

        let n_elements = shape[0];
        // N == 0: nothing to reduce; note the output buffer is left untouched.
        if n_elements == 0 {
            return Ok(());
        }

        let dev = self.device()?;
        let mem = self.memory()?;
        let op_str = map_reduce_op(op);

        // ── Pass 1: per-workgroup reduction ─────────────────────────────────
        // 256 threads per workgroup, so one partial result per 256 elements.
        let wg_count = (n_elements as u32).div_ceil(256);

        let pass1_wgsl = shader::reduction_wgsl(op_str);
        let pass1_shader = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-reduce-pass1"),
                source: wgpu::ShaderSource::Wgsl(pass1_wgsl.into()),
            });
        let pass1_pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-reduce-pass1"),
                layout: None,
                module: &pass1_shader,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        // Partial-sums buffer (temporary): written by pass 1, read by pass 2.
        let partial_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-reduce-partial"),
            size: (wg_count as u64) * 4, // f32 per workgroup
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        // Uniform for ReduceParams { n: u32 }.
        let mut p1_params = [0u8; 4];
        p1_params[0..4].copy_from_slice(&(n_elements as u32).to_le_bytes());
        let p1_uniform = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-reduce-p1-params"),
            size: 4,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&p1_uniform, 0, &p1_params);

        let bgl1 = pass1_pipeline.get_bind_group_layout(0);

        // Resolve the input handle and build the bind group under the buffer
        // lock; the lock is released before encoding.
        let bg1 = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let in_info = buffers.get(&input_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {input_ptr}"))
            })?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-reduce-pass1"),
                layout: &bgl1,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: in_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: partial_buf.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: p1_uniform.as_entire_binding(),
                    },
                ],
            })
        };

        // ── Pass 2: final reduction of partial sums ─────────────────────────
        let pass2_wgsl = shader::reduction_final_wgsl(op_str);
        let pass2_shader = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-reduce-pass2"),
                source: wgpu::ShaderSource::Wgsl(pass2_wgsl.into()),
            });
        let pass2_pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-reduce-pass2"),
                layout: None,
                module: &pass2_shader,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        // FinalReduceParams { num_groups: u32 }.
        let mut p2_params = [0u8; 4];
        p2_params[0..4].copy_from_slice(&wg_count.to_le_bytes());
        let p2_uniform = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-reduce-p2-params"),
            size: 4,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&p2_uniform, 0, &p2_params);

        let bgl2 = pass2_pipeline.get_bind_group_layout(0);

        let bg2 = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
            })?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-reduce-pass2"),
                layout: &bgl2,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: partial_buf.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: out_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: p2_uniform.as_entire_binding(),
                    },
                ],
            })
        };

        // ── Encode both passes into one command buffer ──────────────────────
        // Submitting once lets the queue order pass 2 after pass 1.
        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-reduce"),
            });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-reduce-pass1"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pass1_pipeline);
            pass.set_bind_group(0, &bg1, &[]);
            pass.dispatch_workgroups(wg_count, 1, 1);
        }
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-reduce-pass2"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pass2_pipeline);
            pass.set_bind_group(0, &bg2, &[]);
            // A single workgroup folds all wg_count partials.
            pass.dispatch_workgroups(1, 1, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        // For "mean", divide the result by N on the host side.  N == 1 is
        // skipped because mean == sum in that case.
        if op == ReduceOp::Mean && n_elements > 1 {
            let mut buf = [0u8; 4];
            mem.copy_from_device(&mut buf, output_ptr)
                .map_err(BackendError::from)?;
            let val = f32::from_le_bytes(buf);
            let mean = val / (n_elements as f32);
            mem.copy_to_device(output_ptr, &mean.to_le_bytes())
                .map_err(BackendError::from)?;
        }

        Ok(())
    }

699    fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()> {
700        self.check_init()?;
701        if n == 0 {
702            return Ok(());
703        }
704
705        let dev = self.device()?;
706        let mem = self.memory()?;
707
708        let op_str = map_unary_op(op);
709        let wgsl = shader::elementwise_wgsl(op_str);
710
711        let shader_mod = dev
712            .device
713            .create_shader_module(wgpu::ShaderModuleDescriptor {
714                label: Some("oxicuda-unary"),
715                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
716            });
717
718        let pipeline = dev
719            .device
720            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
721                label: Some("oxicuda-unary"),
722                layout: None,
723                module: &shader_mod,
724                entry_point: Some("main"),
725                compilation_options: Default::default(),
726                cache: None,
727            });
728
729        let bgl = pipeline.get_bind_group_layout(0);
730
731        let bind_group = {
732            let buffers = mem
733                .lock_buffers()
734                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
735            let in_info = buffers.get(&input_ptr).ok_or_else(|| {
736                BackendError::InvalidArgument(format!("unknown handle {input_ptr}"))
737            })?;
738            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
739                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
740            })?;
741
742            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
743                label: Some("oxicuda-unary"),
744                layout: &bgl,
745                entries: &[
746                    wgpu::BindGroupEntry {
747                        binding: 0,
748                        resource: in_info.buffer.as_entire_binding(),
749                    },
750                    wgpu::BindGroupEntry {
751                        binding: 1,
752                        resource: out_info.buffer.as_entire_binding(),
753                    },
754                ],
755            })
756        };
757
758        let mut encoder = dev
759            .device
760            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
761                label: Some("oxicuda-unary"),
762            });
763
764        {
765            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
766                label: Some("oxicuda-unary"),
767                timestamp_writes: None,
768            });
769            pass.set_pipeline(&pipeline);
770            pass.set_bind_group(0, &bind_group, &[]);
771            let workgroups = (n as u32).div_ceil(256);
772            pass.dispatch_workgroups(workgroups, 1, 1);
773        }
774
775        dev.queue.submit(std::iter::once(encoder.finish()));
776        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
777
778        Ok(())
779    }
780
781    fn binary(
782        &self,
783        op: BinaryOp,
784        a_ptr: u64,
785        b_ptr: u64,
786        output_ptr: u64,
787        n: usize,
788    ) -> BackendResult<()> {
789        self.check_init()?;
790        if n == 0 {
791            return Ok(());
792        }
793
794        let dev = self.device()?;
795        let mem = self.memory()?;
796
797        let op_str = map_binary_op(op);
798        let wgsl = shader::binary_wgsl(op_str);
799
800        let shader_mod = dev
801            .device
802            .create_shader_module(wgpu::ShaderModuleDescriptor {
803                label: Some("oxicuda-binary"),
804                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
805            });
806
807        let pipeline = dev
808            .device
809            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
810                label: Some("oxicuda-binary"),
811                layout: None,
812                module: &shader_mod,
813                entry_point: Some("main"),
814                compilation_options: Default::default(),
815                cache: None,
816            });
817
818        let bgl = pipeline.get_bind_group_layout(0);
819
820        let bind_group = {
821            let buffers = mem
822                .lock_buffers()
823                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
824            let a_info = buffers
825                .get(&a_ptr)
826                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
827            let b_info = buffers
828                .get(&b_ptr)
829                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
830            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
831                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
832            })?;
833
834            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
835                label: Some("oxicuda-binary"),
836                layout: &bgl,
837                entries: &[
838                    wgpu::BindGroupEntry {
839                        binding: 0,
840                        resource: a_info.buffer.as_entire_binding(),
841                    },
842                    wgpu::BindGroupEntry {
843                        binding: 1,
844                        resource: b_info.buffer.as_entire_binding(),
845                    },
846                    wgpu::BindGroupEntry {
847                        binding: 2,
848                        resource: out_info.buffer.as_entire_binding(),
849                    },
850                ],
851            })
852        };
853
854        let mut encoder = dev
855            .device
856            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
857                label: Some("oxicuda-binary"),
858            });
859
860        {
861            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
862                label: Some("oxicuda-binary"),
863                timestamp_writes: None,
864            });
865            pass.set_pipeline(&pipeline);
866            pass.set_bind_group(0, &bind_group, &[]);
867            let workgroups = (n as u32).div_ceil(256);
868            pass.dispatch_workgroups(workgroups, 1, 1);
869        }
870
871        dev.queue.submit(std::iter::once(encoder.finish()));
872        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
873
874        Ok(())
875    }
876
877    // ── Synchronisation ───────────────────────────────────────────────────────
878
879    fn synchronize(&self) -> BackendResult<()> {
880        self.check_init()?;
881        if let Some(dev) = &self.device {
882            let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
883        }
884        Ok(())
885    }
886
887    // ── Memory management ─────────────────────────────────────────────────────
888
889    fn alloc(&self, bytes: usize) -> BackendResult<u64> {
890        self.check_init()?;
891        if bytes == 0 {
892            return Err(BackendError::InvalidArgument(
893                "cannot allocate 0 bytes".into(),
894            ));
895        }
896        self.memory()?.alloc(bytes).map_err(BackendError::from)
897    }
898
899    fn free(&self, ptr: u64) -> BackendResult<()> {
900        self.check_init()?;
901        self.memory()?.free(ptr).map_err(BackendError::from)
902    }
903
904    fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()> {
905        self.check_init()?;
906        if src.is_empty() {
907            return Ok(());
908        }
909        self.memory()?
910            .copy_to_device(dst, src)
911            .map_err(BackendError::from)
912    }
913
914    fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()> {
915        self.check_init()?;
916        if dst.is_empty() {
917            return Ok(());
918        }
919        self.memory()?
920            .copy_from_device(dst, src)
921            .map_err(BackendError::from)
922    }
923}
924
925// ─── Byte ↔ f32 helpers ──────────────────────────────────────────────────────
926
/// Convert a `&[u8]` buffer to a `Vec<f32>`, decoding little-endian.
///
/// The length of `bytes` must be a multiple of 4. `chunks_exact` would
/// silently drop a trailing partial chunk, so the documented precondition
/// is enforced in debug builds.
fn bytes_to_f32_vec(bytes: &[u8]) -> Vec<f32> {
    debug_assert_eq!(
        bytes.len() % 4,
        0,
        "byte length must be a multiple of 4"
    );
    bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}
934
/// Serialise an `&[f32]` slice into its little-endian byte representation.
fn f32_slice_to_bytes(data: &[f32]) -> Vec<u8> {
    // Pre-size the output: exactly 4 bytes per f32.
    let mut out = Vec::with_capacity(data.len() * 4);
    for value in data {
        out.extend_from_slice(&value.to_le_bytes());
    }
    out
}
939
940// ─── Tests ───────────────────────────────────────────────────────────────────
941
#[cfg(test)]
mod tests {
    // `use super::*` re-imports everything visible in the parent module,
    // including its `use oxicuda_backend::{...}` items, so no explicit
    // re-import of `BackendTranspose`/`BinaryOp`/`ReduceOp`/`UnaryOp` is
    // needed here.
    use super::*;

    // ── Construction ──────────────────────────────────────────────────────────

    #[test]
    fn webgpu_backend_new_uninitialized() {
        let b = WebGpuBackend::new();
        assert!(!b.is_initialized());
    }

    #[test]
    fn webgpu_backend_name() {
        let b = WebGpuBackend::new();
        assert_eq!(b.name(), "webgpu");
    }

    #[test]
    fn webgpu_backend_default() {
        let b = WebGpuBackend::default();
        assert!(!b.is_initialized());
        assert_eq!(b.name(), "webgpu");
    }

    #[test]
    fn backend_debug_impl() {
        let b = WebGpuBackend::new();
        let s = format!("{b:?}");
        assert!(s.contains("WebGpuBackend"));
    }

    // ── Object-safety smoke test ──────────────────────────────────────────────

    #[test]
    fn backend_object_safe() {
        let b: Box<dyn ComputeBackend> = Box::new(WebGpuBackend::new());
        assert_eq!(b.name(), "webgpu");
    }

    // ── Not-initialized guards ────────────────────────────────────────────────

    #[test]
    fn backend_not_initialized_gemm() {
        let b = WebGpuBackend::new();
        let result = b.gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            4,
            4,
            4,
            1.0,
            0,
            4,
            0,
            4,
            0.0,
            0,
            4,
        );
        assert_eq!(result, Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_alloc() {
        let b = WebGpuBackend::new();
        let result = b.alloc(1024);
        assert_eq!(result, Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_synchronize() {
        let b = WebGpuBackend::new();
        assert_eq!(b.synchronize(), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_free() {
        let b = WebGpuBackend::new();
        assert_eq!(b.free(1), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_copy_htod() {
        let b = WebGpuBackend::new();
        assert_eq!(b.copy_htod(1, b"hello"), Err(BackendError::NotInitialized));
    }

    #[test]
    fn backend_not_initialized_copy_dtoh() {
        let b = WebGpuBackend::new();
        let mut buf = [0u8; 4];
        assert_eq!(b.copy_dtoh(&mut buf, 1), Err(BackendError::NotInitialized));
    }

    // ── Zero-size / trivial-OK paths (no GPU needed) ─────────────────────────

    /// These tests exercise the "no-op for zero size" branches.  We need the
    /// backend to be initialised, but if no GPU is available we skip.
    fn try_init() -> Option<WebGpuBackend> {
        let mut b = WebGpuBackend::new();
        match b.init() {
            Ok(()) => Some(b),
            Err(_) => None,
        }
    }

    #[test]
    fn gemm_zero_size_after_init() {
        let Some(b) = try_init() else {
            return;
        };
        let result = b.gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            0,
            0,
            0,
            1.0,
            0,
            1,
            0,
            1,
            0.0,
            0,
            1,
        );
        assert_eq!(result, Ok(()));
    }

    #[test]
    fn unary_zero_elements_after_init() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.unary(UnaryOp::Relu, 0, 0, 0), Ok(()));
    }

    #[test]
    fn binary_zero_elements_after_init() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.binary(BinaryOp::Add, 0, 0, 0, 0), Ok(()));
    }

    #[test]
    fn copy_htod_empty_noop() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.copy_htod(0, &[]), Ok(()));
    }

    #[test]
    fn copy_dtoh_empty_noop() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.copy_dtoh(&mut [], 0), Ok(()));
    }

    #[test]
    fn alloc_zero_bytes_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.alloc(0),
            Err(BackendError::InvalidArgument(
                "cannot allocate 0 bytes".into()
            ))
        );
    }

    #[test]
    fn synchronize_after_init() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(b.synchronize(), Ok(()));
    }

    // ── Argument validation (post-init) ───────────────────────────────────────

    #[test]
    fn reduce_empty_shape_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.reduce(ReduceOp::Sum, 0, 0, &[], 0),
            Err(BackendError::InvalidArgument(
                "shape must not be empty".into()
            ))
        );
    }

    #[test]
    fn reduce_axis_out_of_bounds_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.reduce(ReduceOp::Sum, 0, 0, &[4, 4], 5),
            Err(BackendError::InvalidArgument(
                "axis 5 is out of bounds for shape of length 2".into()
            ))
        );
    }

    #[test]
    fn attention_zero_seq_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.attention(0, 0, 0, 0, 1, 1, 0, 8, 64, 0.125, false),
            Err(BackendError::InvalidArgument(
                "seq_q, seq_kv, and head_dim must all be > 0".into()
            ))
        );
    }

    #[test]
    fn attention_nonpositive_scale_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, 0.0, false),
            Err(BackendError::InvalidArgument(
                "scale must be a positive finite number, got 0".into()
            ))
        );
        assert_eq!(
            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, -1.0, false),
            Err(BackendError::InvalidArgument(
                "scale must be a positive finite number, got -1".into()
            ))
        );
        assert!(
            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, f64::INFINITY, false)
                .is_err()
        );
    }

    #[test]
    fn conv2d_wrong_input_shape_error() {
        let Some(b) = try_init() else {
            return;
        };
        // 3-element input_shape — should fail.
        assert_eq!(
            b.conv2d_forward(
                0,
                &[1, 3, 32],
                0,
                &[16, 3, 3, 3],
                0,
                &[1, 16, 30, 30],
                &[1, 1],
                &[0, 0]
            ),
            Err(BackendError::InvalidArgument(
                "input_shape must have 4 elements (NCHW)".into()
            ))
        );
    }

    #[test]
    fn conv2d_wrong_filter_shape_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.conv2d_forward(
                0,
                &[1, 3, 32, 32],
                0,
                &[16, 3, 3],
                0,
                &[1, 16, 30, 30],
                &[1, 1],
                &[0, 0]
            ),
            Err(BackendError::InvalidArgument(
                "filter_shape must have 4 elements (KCFHFW)".into()
            ))
        );
    }

    #[test]
    fn conv2d_wrong_stride_shape_error() {
        let Some(b) = try_init() else {
            return;
        };
        assert_eq!(
            b.conv2d_forward(
                0,
                &[1, 3, 32, 32],
                0,
                &[16, 3, 3, 3],
                0,
                &[1, 16, 30, 30],
                &[1], // <-- wrong
                &[0, 0],
            ),
            Err(BackendError::InvalidArgument(
                "stride must have 2 elements [sh, sw]".into()
            ))
        );
    }

    // ── Init is idempotent ────────────────────────────────────────────────────

    #[test]
    fn init_idempotent() {
        let Some(mut b) = try_init() else {
            return;
        };
        // Second call must succeed without error.
        assert_eq!(b.init(), Ok(()));
        assert!(b.is_initialized());
    }

    // ── Graceful failure ──────────────────────────────────────────────────────

    #[test]
    fn webgpu_init_graceful_failure() {
        // We cannot force a failure, but we can at least verify that init()
        // returns a Result and never panics.
        let mut b = WebGpuBackend::new();
        let _result = b.init(); // Ok or Err — both are acceptable.
        // No panic => test passes.
    }

    // ── GPU compute tests ─────────────────────────────────────────────────────
    //
    // These helpers upload f32 slices and read back results, exercising the
    // full shader → pipeline → dispatch path.  Byte ↔ f32 conversion is
    // delegated to the module-level helpers so the encoding logic lives in
    // exactly one place.

    /// Helper: upload `data` (f32 slice) to a new GPU buffer, return its handle.
    fn upload_f32(b: &WebGpuBackend, data: &[f32]) -> u64 {
        let bytes = f32_slice_to_bytes(data);
        let h = b.alloc(bytes.len()).expect("alloc");
        b.copy_htod(h, &bytes).expect("copy_htod");
        h
    }

    /// Helper: download `n` f32 values from a GPU buffer handle.
    fn download_f32(b: &WebGpuBackend, h: u64, n: usize) -> Vec<f32> {
        let mut bytes = vec![0u8; n * 4];
        b.copy_dtoh(&mut bytes, h).expect("copy_dtoh");
        bytes_to_f32_vec(&bytes)
    }

    #[test]
    fn unary_neg_small() {
        let Some(b) = try_init() else { return };
        let input = [1.0f32, -2.0, 3.0, 0.0];
        let in_h = upload_f32(&b, &input);
        let out_h = b.alloc(input.len() * 4).expect("alloc output");

        b.unary(UnaryOp::Neg, in_h, out_h, input.len())
            .expect("unary neg");

        let result = download_f32(&b, out_h, input.len());
        let expected = [-1.0f32, 2.0, -3.0, 0.0];
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6, "got {r}, expected {e}");
        }

        b.free(in_h).expect("free");
        b.free(out_h).expect("free");
    }

    #[test]
    fn unary_abs_small() {
        let Some(b) = try_init() else { return };
        let input = [-3.0f32, 4.0, -5.0, 0.0];
        let in_h = upload_f32(&b, &input);
        let out_h = b.alloc(input.len() * 4).expect("alloc output");

        b.unary(UnaryOp::Abs, in_h, out_h, input.len())
            .expect("unary abs");

        let result = download_f32(&b, out_h, input.len());
        let expected = [3.0f32, 4.0, 5.0, 0.0];
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6, "got {r}, expected {e}");
        }

        b.free(in_h).expect("free");
        b.free(out_h).expect("free");
    }

    #[test]
    fn binary_add_small() {
        let Some(b) = try_init() else { return };
        let a = [1.0f32, 2.0, 3.0, 4.0];
        let bv = [10.0f32, 20.0, 30.0, 40.0];
        let a_h = upload_f32(&b, &a);
        let b_h = upload_f32(&b, &bv);
        let out_h = b.alloc(a.len() * 4).expect("alloc output");

        b.binary(BinaryOp::Add, a_h, b_h, out_h, a.len())
            .expect("binary add");

        let result = download_f32(&b, out_h, a.len());
        let expected = [11.0f32, 22.0, 33.0, 44.0];
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6, "got {r}, expected {e}");
        }

        b.free(a_h).expect("free");
        b.free(b_h).expect("free");
        b.free(out_h).expect("free");
    }

    #[test]
    fn binary_mul_small() {
        let Some(b) = try_init() else { return };
        let a = [2.0f32, 3.0, 4.0, 5.0];
        let bv = [10.0f32, 10.0, 10.0, 10.0];
        let a_h = upload_f32(&b, &a);
        let b_h = upload_f32(&b, &bv);
        let out_h = b.alloc(a.len() * 4).expect("alloc output");

        b.binary(BinaryOp::Mul, a_h, b_h, out_h, a.len())
            .expect("binary mul");

        let result = download_f32(&b, out_h, a.len());
        let expected = [20.0f32, 30.0, 40.0, 50.0];
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6, "got {r}, expected {e}");
        }

        b.free(a_h).expect("free");
        b.free(b_h).expect("free");
        b.free(out_h).expect("free");
    }

    #[test]
    fn reduce_sum_small() {
        let Some(b) = try_init() else { return };
        let input = [1.0f32, 2.0, 3.0, 4.0];
        let in_h = upload_f32(&b, &input);
        let out_h = b.alloc(4).expect("alloc output"); // single f32

        b.reduce(ReduceOp::Sum, in_h, out_h, &[4], 0)
            .expect("reduce sum");

        let result = download_f32(&b, out_h, 1);
        assert!(
            (result[0] - 10.0).abs() < 1e-5,
            "expected 10.0, got {}",
            result[0]
        );

        b.free(in_h).expect("free");
        b.free(out_h).expect("free");
    }

    #[test]
    fn reduce_max_small() {
        let Some(b) = try_init() else { return };
        let input = [1.0f32, 5.0, 3.0, 2.0];
        let in_h = upload_f32(&b, &input);
        let out_h = b.alloc(4).expect("alloc output");

        b.reduce(ReduceOp::Max, in_h, out_h, &[4], 0)
            .expect("reduce max");

        let result = download_f32(&b, out_h, 1);
        assert!(
            (result[0] - 5.0).abs() < 1e-5,
            "expected 5.0, got {}",
            result[0]
        );

        b.free(in_h).expect("free");
        b.free(out_h).expect("free");
    }

    #[test]
    fn reduce_mean_small() {
        let Some(b) = try_init() else { return };
        let input = [2.0f32, 4.0, 6.0, 8.0];
        let in_h = upload_f32(&b, &input);
        let out_h = b.alloc(4).expect("alloc output");

        b.reduce(ReduceOp::Mean, in_h, out_h, &[4], 0)
            .expect("reduce mean");

        let result = download_f32(&b, out_h, 1);
        assert!(
            (result[0] - 5.0).abs() < 1e-5,
            "expected 5.0, got {}",
            result[0]
        );

        b.free(in_h).expect("free");
        b.free(out_h).expect("free");
    }

    #[test]
    fn gemm_identity_2x2() {
        let Some(b) = try_init() else { return };
        // A = [[1,2],[3,4]], B = [[1,0],[0,1]] (identity), C = zeros
        // C = 1.0 * A * I + 0.0 * C = A
        let a = [1.0f32, 2.0, 3.0, 4.0];
        let eye = [1.0f32, 0.0, 0.0, 1.0];
        let c_init = [0.0f32; 4];

        let a_h = upload_f32(&b, &a);
        let b_h = upload_f32(&b, &eye);
        let c_h = upload_f32(&b, &c_init);

        b.gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            2,
            2,
            2,
            1.0,
            a_h,
            2,
            b_h,
            2,
            0.0,
            c_h,
            2,
        )
        .expect("gemm");

        let result = download_f32(&b, c_h, 4);
        for (r, e) in result.iter().zip(a.iter()) {
            assert!((r - e).abs() < 1e-5, "got {r}, expected {e}");
        }

        b.free(a_h).expect("free");
        b.free(b_h).expect("free");
        b.free(c_h).expect("free");
    }

    #[test]
    fn gemm_2x3_times_3x2() {
        let Some(b) = try_init() else { return };
        // A 2x3, B 3x2 → C 2x2
        let a = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let bm = [7.0f32, 8.0, 9.0, 10.0, 11.0, 12.0];
        let c_init = [0.0f32; 4];

        let a_h = upload_f32(&b, &a);
        let b_h = upload_f32(&b, &bm);
        let c_h = upload_f32(&b, &c_init);

        b.gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            2,
            2,
            3,
            1.0,
            a_h,
            3,
            b_h,
            2,
            0.0,
            c_h,
            2,
        )
        .expect("gemm");

        // Expected: [[58, 64], [139, 154]]
        let result = download_f32(&b, c_h, 4);
        let expected = [58.0f32, 64.0, 139.0, 154.0];
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-4, "got {r}, expected {e}");
        }

        b.free(a_h).expect("free");
        b.free(b_h).expect("free");
        b.free(c_h).expect("free");
    }

    #[test]
    fn gemm_alpha_beta() {
        let Some(b) = try_init() else { return };
        // C = 2.0 * A * B + 3.0 * C
        // A = [[1,0],[0,1]], B = [[1,0],[0,1]], C = [[1,1],[1,1]]
        // C = 2*I + 3*ones = [[5,3],[3,5]]
        let a = [1.0f32, 0.0, 0.0, 1.0];
        let bm = [1.0f32, 0.0, 0.0, 1.0];
        let c_init = [1.0f32, 1.0, 1.0, 1.0];

        let a_h = upload_f32(&b, &a);
        let b_h = upload_f32(&b, &bm);
        let c_h = upload_f32(&b, &c_init);

        b.gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            2,
            2,
            2,
            2.0,
            a_h,
            2,
            b_h,
            2,
            3.0,
            c_h,
            2,
        )
        .expect("gemm alpha+beta");

        let result = download_f32(&b, c_h, 4);
        let expected = [5.0f32, 3.0, 3.0, 5.0];
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-4, "got {r}, expected {e}");
        }

        b.free(a_h).expect("free");
        b.free(b_h).expect("free");
        b.free(c_h).expect("free");
    }

    // ── Conv2D tests ──────────────────────────────────────────────────────

    #[test]
    fn conv2d_identity_1x1() {
        // 1×1 convolution with single channel, no padding, stride=1
        // input: 1×1×3×3, filter: 1×1×1×1 (weight=2.0), output: 1×1×3×3
        let Some(b) = try_init() else { return };
        let input: Vec<f32> = (1..=9).map(|x| x as f32).collect();
        let filter = [2.0f32];
        let expected: Vec<f32> = input.iter().map(|x| x * 2.0).collect();

        let in_h = upload_f32(&b, &input);
        let f_h = upload_f32(&b, &filter);
        let out_h = b.alloc(9 * 4).expect("alloc output");

        b.conv2d_forward(
            in_h,
            &[1, 1, 3, 3],
            f_h,
            &[1, 1, 1, 1],
            out_h,
            &[1, 1, 3, 3],
            &[1, 1],
            &[0, 0],
        )
        .expect("conv2d");

        let result = download_f32(&b, out_h, 9);
        for (r, e) in result.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-5, "got {r}, expected {e}");
        }

        b.free(in_h).expect("free");
        b.free(f_h).expect("free");
        b.free(out_h).expect("free");
    }

    #[test]
    fn conv2d_3x3_no_padding() {
        // input: 1×1×4×4, filter: 1×1×3×3 (all ones), stride=1, pad=0
        // output: 1×1×2×2
        let Some(b) = try_init() else { return };
        let input: Vec<f32> = (0..16).map(|x| x as f32).collect();
        let filter = [1.0f32; 9];

        let in_h = upload_f32(&b, &input);
        let f_h = upload_f32(&b, &filter);
        let out_h = b.alloc(4 * 4).expect("alloc output");

        b.conv2d_forward(
            in_h,
            &[1, 1, 4, 4],
            f_h,
            &[1, 1, 3, 3],
            out_h,
            &[1, 1, 2, 2],
            &[1, 1],
            &[0, 0],
        )
        .expect("conv2d");

        let result = download_f32(&b, out_h, 4);
        // top-left 3×3 sum: 0+1+2+4+5+6+8+9+10 = 45
        assert!((result[0] - 45.0).abs() < 1e-4, "got {}", result[0]);
        // top-right 3×3 sum: 1+2+3+5+6+7+9+10+11 = 54
        assert!((result[1] - 54.0).abs() < 1e-4, "got {}", result[1]);

        b.free(in_h).expect("free");
        b.free(f_h).expect("free");
        b.free(out_h).expect("free");
    }

    #[test]
    fn conv2d_with_padding() {
        // input: 1×1×2×2, filter: 1×1×3×3 (all ones), stride=1, pad=1
        // output: 1×1×2×2
        // With padding=1 around a 2×2 input, the output is also 2×2.
        let Some(b) = try_init() else { return };
        let input = [1.0f32, 2.0, 3.0, 4.0];
        let filter = [1.0f32; 9];

        let in_h = upload_f32(&b, &input);
        let f_h = upload_f32(&b, &filter);
        let out_h = b.alloc(4 * 4).expect("alloc output");

        b.conv2d_forward(
            in_h,
            &[1, 1, 2, 2],
            f_h,
            &[1, 1, 3, 3],
            out_h,
            &[1, 1, 2, 2],
            &[1, 1],
            &[1, 1],
        )
        .expect("conv2d");

        let result = download_f32(&b, out_h, 4);
        // Top-left output: only 4 of 9 filter taps hit valid input
        // input[0,0]=1, input[0,1]=2, input[1,0]=3, input[1,1]=4 => sum=10
        assert!((result[0] - 10.0).abs() < 1e-4, "got {}", result[0]);

        b.free(in_h).expect("free");
        b.free(f_h).expect("free");
        b.free(out_h).expect("free");
    }

    // ── Attention tests ───────────────────────────────────────────────────

    #[test]
    fn attention_uniform_weights() {
        // 1 head, seq_q=1, seq_kv=2, head_dim=2, no causal
        // Q = [1, 0], K = [[1, 0], [1, 0]], V = [[1, 2], [3, 4]]
        // scores = [1*scale, 1*scale] => equal weights => O = mean(V) = [2, 3]
        let Some(b) = try_init() else { return };

        let q = [1.0f32, 0.0];
        let k = [1.0f32, 0.0, 1.0, 0.0];
        let v = [1.0f32, 2.0, 3.0, 4.0];

        let q_h = upload_f32(&b, &q);
        let k_h = upload_f32(&b, &k);
        let v_h = upload_f32(&b, &v);
        let o_h = b.alloc(2 * 4).expect("alloc output");

        b.attention(q_h, k_h, v_h, o_h, 1, 1, 1, 2, 2, 1.0, false)
            .expect("attention");

        let result = download_f32(&b, o_h, 2);
        // Equal scores → equal softmax weights → average of V rows
        assert!(
            (result[0] - 2.0).abs() < 1e-4,
            "got {}, expected 2.0",
            result[0]
        );
        assert!(
            (result[1] - 3.0).abs() < 1e-4,
            "got {}, expected 3.0",
            result[1]
        );

        b.free(q_h).expect("free");
        b.free(k_h).expect("free");
        b.free(v_h).expect("free");
        b.free(o_h).expect("free");
    }

    #[test]
    fn attention_causal_single_token() {
        // 1 head, seq_q=2, seq_kv=2, head_dim=1, causal
        // Q = [1, 1], K = [1, 1], V = [10, 20]
        // sq=0: only sees sk=0 → O[0] = V[0] = 10
        // sq=1: sees sk=0,1 with equal scores → O[1] = (10+20)/2 = 15
        let Some(b) = try_init() else { return };

        let q = [1.0f32, 1.0];
        let k = [1.0f32, 1.0];
        let v = [10.0f32, 20.0];

        let q_h = upload_f32(&b, &q);
        let k_h = upload_f32(&b, &k);
        let v_h = upload_f32(&b, &v);
        let o_h = b.alloc(2 * 4).expect("alloc output");

        b.attention(q_h, k_h, v_h, o_h, 1, 1, 2, 2, 1, 1.0, true)
            .expect("attention causal");

        let result = download_f32(&b, o_h, 2);
        assert!(
            (result[0] - 10.0).abs() < 1e-4,
            "got {}, expected 10.0",
            result[0]
        );
        assert!(
            (result[1] - 15.0).abs() < 1e-4,
            "got {}, expected 15.0",
            result[1]
        );

        b.free(q_h).expect("free");
        b.free(k_h).expect("free");
        b.free(v_h).expect("free");
        b.free(o_h).expect("free");
    }

    #[test]
    fn attention_dominant_key() {
        // 1 head, seq_q=1, seq_kv=2, head_dim=2, no causal
        // Q = [1, 0], K = [[10, 0], [0, 0]], V = [[100, 200], [0, 0]]
        // score[0] = 10*scale, score[1] = 0*scale
        // With large enough difference, softmax saturates → O ≈ V[0]
        let Some(b) = try_init() else { return };

        let q = [1.0f32, 0.0];
        let k = [10.0f32, 0.0, 0.0, 0.0];
        let v = [100.0f32, 200.0, 0.0, 0.0];

        let q_h = upload_f32(&b, &q);
        let k_h = upload_f32(&b, &k);
        let v_h = upload_f32(&b, &v);
        let o_h = b.alloc(2 * 4).expect("alloc output");

        // scale=1.0 gives scores 10 vs 0 → softmax ≈ [1, 0]
        b.attention(q_h, k_h, v_h, o_h, 1, 1, 1, 2, 2, 1.0, false)
            .expect("attention dominant");

        let result = download_f32(&b, o_h, 2);
        assert!(
            (result[0] - 100.0).abs() < 0.1,
            "got {}, expected ~100",
            result[0]
        );
        assert!(
            (result[1] - 200.0).abs() < 0.1,
            "got {}, expected ~200",
            result[1]
        );

        b.free(q_h).expect("free");
        b.free(k_h).expect("free");
        b.free(v_h).expect("free");
        b.free(o_h).expect("free");
    }
}