// oxicuda_webgpu/backend.rs

1//! [`WebGpuBackend`] — the main entry point for the oxicuda-webgpu crate.
2//!
3//! Implements the [`ComputeBackend`] trait from `oxicuda-backend` using
4//! `wgpu` for cross-platform GPU compute (Vulkan, Metal, DX12, WebGPU).
5
6use std::sync::Arc;
7
8use oxicuda_backend::{
9    BackendError, BackendResult, BackendTranspose, BinaryOp, ComputeBackend, ReduceOp, UnaryOp,
10};
11use wgpu;
12
13use crate::{device::WebGpuDevice, memory::WebGpuMemoryManager, shader};
14
15// ─── Op-mapping helpers ──────────────────────────────────────────────────────
16
17fn map_unary_op(op: UnaryOp) -> &'static str {
18    match op {
19        UnaryOp::Relu => "relu",
20        UnaryOp::Sigmoid => "sigmoid",
21        UnaryOp::Tanh => "tanh",
22        UnaryOp::Exp => "exp",
23        UnaryOp::Log => "log",
24        UnaryOp::Sqrt => "sqrt",
25        UnaryOp::Abs => "abs",
26        UnaryOp::Neg => "neg",
27    }
28}
29
30fn map_binary_op(op: BinaryOp) -> &'static str {
31    match op {
32        BinaryOp::Add => "add",
33        BinaryOp::Sub => "sub",
34        BinaryOp::Mul => "mul",
35        BinaryOp::Div => "div",
36        BinaryOp::Max => "max",
37        BinaryOp::Min => "min",
38    }
39}
40
41fn map_reduce_op(op: ReduceOp) -> &'static str {
42    match op {
43        ReduceOp::Sum => "sum",
44        ReduceOp::Max => "max",
45        ReduceOp::Min => "min",
46        ReduceOp::Mean => "mean",
47    }
48}
49
50// ─── Backend struct ──────────────────────────────────────────────────────────
51
/// Cross-platform GPU compute backend backed by `wgpu`.
///
/// # Lifecycle
///
/// 1. `WebGpuBackend::new()` — create an uninitialised backend.
/// 2. `init()` — select the best available adapter and create the device.
/// 3. Use `alloc`, `copy_htod`, compute ops, `copy_dtoh`, `free`.
/// 4. `synchronize()` — wait for all pending GPU work to finish.
#[derive(Debug)]
pub struct WebGpuBackend {
    // Created by `init()`; `None` until then. Shared (`Arc`) with the
    // memory manager.
    device: Option<Arc<WebGpuDevice>>,
    // Buffer registry keyed by opaque u64 handles; created by `init()`.
    memory: Option<Arc<WebGpuMemoryManager>>,
    // Set once `init()` succeeds; every op gates on it via `check_init`.
    initialized: bool,
}
66
67impl WebGpuBackend {
68    /// Create a new, uninitialised WebGPU backend.
69    pub fn new() -> Self {
70        Self {
71            device: None,
72            memory: None,
73            initialized: false,
74        }
75    }
76
77    /// Return an error if the backend is not yet initialised.
78    fn check_init(&self) -> BackendResult<()> {
79        if self.initialized {
80            Ok(())
81        } else {
82            Err(BackendError::NotInitialized)
83        }
84    }
85
86    /// Convenience accessor: get the memory manager or return `NotInitialized`.
87    fn memory(&self) -> BackendResult<&Arc<WebGpuMemoryManager>> {
88        self.memory.as_ref().ok_or(BackendError::NotInitialized)
89    }
90
91    /// Convenience accessor: get the device or return `NotInitialized`.
92    fn device(&self) -> BackendResult<&Arc<WebGpuDevice>> {
93        self.device.as_ref().ok_or(BackendError::NotInitialized)
94    }
95}
96
impl WebGpuBackend {
    /// FP16 GEMM: `C = alpha * A * B + beta * C` with half-precision storage.
    ///
    /// This is an inherent method (not on `ComputeBackend`) because FP16
    /// support is WebGPU-specific and requires the `f16` WGSL extension.
    ///
    /// Buffers pointed to by `a_ptr`, `b_ptr`, `c_ptr` must contain `f16`
    /// elements (2 bytes each). A is m×k, B is k×n, C is m×n. `alpha` and
    /// `beta` are narrowed to `f32` before upload. The call blocks until the
    /// dispatch completes (see the final `poll`).
    ///
    /// # Errors
    /// `NotInitialized` before `init()`; `InvalidArgument` for unknown buffer
    /// handles; `DeviceError` if the buffer registry lock is poisoned.
    #[allow(clippy::too_many_arguments)]
    pub fn gemm_f16(
        &self,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        b_ptr: u64,
        beta: f64,
        c_ptr: u64,
    ) -> BackendResult<()> {
        self.check_init()?;
        // Early-out on empty dimensions. NOTE(review): for k == 0 a BLAS GEMM
        // would still scale C by beta; here C is left untouched — confirm
        // callers do not rely on the beta-scaling behaviour.
        if m == 0 || n == 0 || k == 0 {
            return Ok(());
        }

        let dev = self.device()?;
        let mem = self.memory()?;

        // NOTE(review): the shader module and pipeline are rebuilt on every
        // call; a cache keyed by tile size would avoid recompilation.
        let tile_size: u32 = 8;
        let wgsl = shader::gemm_wgsl_f16(tile_size);

        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-gemm-f16"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });

        // `layout: None` lets wgpu infer the bind-group layout from the
        // shader, which `get_bind_group_layout(0)` then retrieves.
        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-gemm-f16"),
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        let bgl = pipeline.get_bind_group_layout(0);

        // Build uniform buffer for GemmParams { m, n, k, alpha, beta }:
        // three u32 dims followed by two f32 scalars, little-endian.
        let mut params_bytes = [0u8; 20];
        params_bytes[0..4].copy_from_slice(&(m as u32).to_le_bytes());
        params_bytes[4..8].copy_from_slice(&(n as u32).to_le_bytes());
        params_bytes[8..12].copy_from_slice(&(k as u32).to_le_bytes());
        params_bytes[12..16].copy_from_slice(&(alpha as f32).to_le_bytes());
        params_bytes[16..20].copy_from_slice(&(beta as f32).to_le_bytes());

        let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-gemm-f16-params"),
            size: 20,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&uniform_buf, 0, &params_bytes);

        // Resolve handles to buffers and build the bind group while holding
        // the registry lock, so the buffers cannot be freed underneath us.
        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let a_info = buffers
                .get(&a_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
            let b_info = buffers
                .get(&b_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
            let c_info = buffers
                .get(&c_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {c_ptr}")))?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-gemm-f16"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: a_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: b_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: c_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 3,
                        resource: uniform_buf.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-gemm-f16"),
            });

        // Scoped so the pass is ended (dropped) before `encoder.finish()`.
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-gemm-f16"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            // One workgroup per tile_size x tile_size tile of C.
            let wg_x = (n as u32).div_ceil(tile_size);
            let wg_y = (m as u32).div_ceil(tile_size);
            pass.dispatch_workgroups(wg_x, wg_y, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        // Block until the GPU finishes; the poll result is intentionally
        // discarded (errors surface via the device's error handling).
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }
}
226
227impl Default for WebGpuBackend {
228    fn default() -> Self {
229        Self::new()
230    }
231}
232
233// ─── ComputeBackend impl ─────────────────────────────────────────────────────
234
235impl ComputeBackend for WebGpuBackend {
    /// Stable identifier for this backend, used for selection/logging.
    fn name(&self) -> &str {
        "webgpu"
    }
239
240    fn init(&mut self) -> BackendResult<()> {
241        if self.initialized {
242            return Ok(());
243        }
244
245        match WebGpuDevice::new() {
246            Ok(dev) => {
247                let dev = Arc::new(dev);
248                tracing::info!("WebGPU backend initialised on: {}", dev.adapter_name);
249                let memory = WebGpuMemoryManager::new(Arc::clone(&dev));
250                self.device = Some(dev);
251                self.memory = Some(Arc::new(memory));
252                self.initialized = true;
253                Ok(())
254            }
255            Err(e) => Err(BackendError::from(e)),
256        }
257    }
258
    /// True once `init()` has completed successfully.
    fn is_initialized(&self) -> bool {
        self.initialized
    }
262
263    // ── Compute operations ────────────────────────────────────────────────────
264
    /// Single-precision GEMM: `C = alpha * A * B + beta * C`, dispatched as
    /// an `tile_size`×`tile_size`-tiled WGSL compute shader and executed
    /// synchronously (blocks until the GPU finishes).
    ///
    /// Limitations visible here:
    /// * transposed inputs are rejected with `Unsupported`;
    /// * `_lda`/`_ldb`/`_ldc` are ignored — NOTE(review): this assumes
    ///   densely packed row-major buffers; confirm callers never pass
    ///   padded leading dimensions.
    ///
    /// # Errors
    /// `NotInitialized` before `init()`; `Unsupported` for transposes;
    /// `InvalidArgument` for unknown buffer handles; `DeviceError` if the
    /// buffer registry lock is poisoned.
    fn gemm(
        &self,
        trans_a: BackendTranspose,
        trans_b: BackendTranspose,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        _lda: usize,
        b_ptr: u64,
        _ldb: usize,
        beta: f64,
        c_ptr: u64,
        _ldc: usize,
    ) -> BackendResult<()> {
        self.check_init()?;
        // Zero-dimension matrices are trivially done.
        // NOTE(review): for k == 0 a BLAS GEMM would still scale C by beta;
        // here C is left untouched.
        if m == 0 || n == 0 || k == 0 {
            return Ok(());
        }

        // Transpose not yet supported in the WGSL shader.
        if trans_a != BackendTranspose::NoTrans || trans_b != BackendTranspose::NoTrans {
            return Err(BackendError::Unsupported(
                "WebGPU GEMM does not yet support transposed inputs".into(),
            ));
        }

        let dev = self.device()?;
        let mem = self.memory()?;

        // NOTE(review): shader + pipeline are rebuilt per call; caching
        // would avoid recompilation overhead.
        let tile_size: u32 = 8;
        let wgsl = shader::gemm_wgsl(tile_size);

        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-gemm"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });

        // `layout: None` lets wgpu infer the bind-group layout from the
        // shader source.
        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-gemm"),
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        let bgl = pipeline.get_bind_group_layout(0);

        // Build uniform buffer for GemmParams { m, n, k, alpha, beta }:
        // three u32 dims followed by two f32 scalars (alpha/beta narrowed
        // from f64), little-endian.
        let mut params_bytes = [0u8; 20];
        params_bytes[0..4].copy_from_slice(&(m as u32).to_le_bytes());
        params_bytes[4..8].copy_from_slice(&(n as u32).to_le_bytes());
        params_bytes[8..12].copy_from_slice(&(k as u32).to_le_bytes());
        params_bytes[12..16].copy_from_slice(&(alpha as f32).to_le_bytes());
        params_bytes[16..20].copy_from_slice(&(beta as f32).to_le_bytes());

        let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-gemm-params"),
            size: 20,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&uniform_buf, 0, &params_bytes);

        // Create bind group while holding the buffer lock.
        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let a_info = buffers
                .get(&a_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
            let b_info = buffers
                .get(&b_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
            let c_info = buffers
                .get(&c_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {c_ptr}")))?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-gemm"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: a_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: b_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: c_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 3,
                        resource: uniform_buf.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-gemm"),
            });

        // Scoped so the compute pass ends before `encoder.finish()`.
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-gemm"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            // One workgroup per tile of C: x covers columns, y covers rows.
            let wg_x = (n as u32).div_ceil(tile_size);
            let wg_y = (m as u32).div_ceil(tile_size);
            pass.dispatch_workgroups(wg_x, wg_y, 1);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        // Block until completion; poll result intentionally discarded.
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }
398
    /// Batched single-precision GEMM: for each batch `b`,
    /// `C_b = alpha * A_b * B_b + beta * C_b`, where the per-batch matrices
    /// are located at element strides `stride_a`/`stride_b`/`stride_c` from
    /// the base buffers. One compute dispatch covers all batches (batch index
    /// on the z axis). Blocks until the GPU finishes.
    ///
    /// Same limitations as `gemm`: transposes are rejected and the leading
    /// dimensions are ignored. NOTE(review): strides and dims are narrowed
    /// `usize -> u32` for the uniform buffer; very large tensors would
    /// silently truncate — confirm upstream bounds.
    #[allow(clippy::too_many_arguments)]
    fn batched_gemm(
        &self,
        trans_a: BackendTranspose,
        trans_b: BackendTranspose,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        _lda: usize,
        stride_a: usize,
        b_ptr: u64,
        _ldb: usize,
        stride_b: usize,
        beta: f64,
        c_ptr: u64,
        _ldc: usize,
        stride_c: usize,
        batch_count: usize,
    ) -> BackendResult<()> {
        self.check_init()?;

        // Nothing to do for empty batches or zero-sized matrices.
        if batch_count == 0 || m == 0 || n == 0 || k == 0 {
            return Ok(());
        }

        if trans_a != BackendTranspose::NoTrans || trans_b != BackendTranspose::NoTrans {
            return Err(BackendError::Unsupported(
                "WebGPU batched GEMM does not yet support transposed inputs".into(),
            ));
        }

        let dev = self.device()?;
        let mem = self.memory()?;

        // NOTE(review): shader + pipeline rebuilt per call (no cache).
        let tile_size: u32 = 8;
        let wgsl = shader::batched_gemm_wgsl(tile_size);

        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-batched-gemm"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });

        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-batched-gemm"),
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        let bgl = pipeline.get_bind_group_layout(0);

        // BatchedGemmParams layout: { m, n, k: u32, alpha, beta: f32,
        // batch_count, stride_a, stride_b, stride_c: u32 } — nine 4-byte
        // fields = 36 bytes, zero-padded to 48 so the uniform buffer size
        // is a 16-byte multiple.
        let mut params_bytes = [0u8; 48];
        params_bytes[0..4].copy_from_slice(&(m as u32).to_le_bytes());
        params_bytes[4..8].copy_from_slice(&(n as u32).to_le_bytes());
        params_bytes[8..12].copy_from_slice(&(k as u32).to_le_bytes());
        params_bytes[12..16].copy_from_slice(&(alpha as f32).to_le_bytes());
        params_bytes[16..20].copy_from_slice(&(beta as f32).to_le_bytes());
        params_bytes[20..24].copy_from_slice(&(batch_count as u32).to_le_bytes());
        params_bytes[24..28].copy_from_slice(&(stride_a as u32).to_le_bytes());
        params_bytes[28..32].copy_from_slice(&(stride_b as u32).to_le_bytes());
        params_bytes[32..36].copy_from_slice(&(stride_c as u32).to_le_bytes());
        // bytes 36..48 are padding zeros

        let uniform_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("oxicuda-batched-gemm-params"),
            size: 48,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        dev.queue.write_buffer(&uniform_buf, 0, &params_bytes);

        // Resolve handles and build the bind group under the registry lock.
        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let a_info = buffers
                .get(&a_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
            let b_info = buffers
                .get(&b_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
            let c_info = buffers
                .get(&c_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {c_ptr}")))?;

            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-batched-gemm"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: a_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: b_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: c_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 3,
                        resource: uniform_buf.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-batched-gemm"),
            });

        // Scoped so the compute pass ends before `encoder.finish()`.
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-batched-gemm"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            // x/y cover tiles of each C_b; z indexes the batch.
            let wg_x = (n as u32).div_ceil(tile_size);
            let wg_y = (m as u32).div_ceil(tile_size);
            pass.dispatch_workgroups(wg_x, wg_y, batch_count as u32);
        }

        dev.queue.submit(std::iter::once(encoder.finish()));
        // Block until completion; poll result intentionally discarded.
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }
544
545    fn conv2d_forward(
546        &self,
547        input_ptr: u64,
548        input_shape: &[usize],
549        filter_ptr: u64,
550        filter_shape: &[usize],
551        output_ptr: u64,
552        output_shape: &[usize],
553        stride: &[usize],
554        padding: &[usize],
555    ) -> BackendResult<()> {
556        self.check_init()?;
557
558        if input_shape.len() != 4 {
559            return Err(BackendError::InvalidArgument(
560                "input_shape must have 4 elements (NCHW)".into(),
561            ));
562        }
563        if filter_shape.len() != 4 {
564            return Err(BackendError::InvalidArgument(
565                "filter_shape must have 4 elements (KCFHFW)".into(),
566            ));
567        }
568        if output_shape.len() != 4 {
569            return Err(BackendError::InvalidArgument(
570                "output_shape must have 4 elements (NKOhOw)".into(),
571            ));
572        }
573        if stride.len() != 2 {
574            return Err(BackendError::InvalidArgument(
575                "stride must have 2 elements [sh, sw]".into(),
576            ));
577        }
578        if padding.len() != 2 {
579            return Err(BackendError::InvalidArgument(
580                "padding must have 2 elements [ph, pw]".into(),
581            ));
582        }
583
584        let mem = self.memory()?;
585
586        let batch = input_shape[0];
587        let c_in = input_shape[1];
588        let h_in = input_shape[2];
589        let w_in = input_shape[3];
590        let k_out = filter_shape[0];
591        let fh = filter_shape[2];
592        let fw = filter_shape[3];
593        let oh = output_shape[2];
594        let ow = output_shape[3];
595        let sh = stride[0];
596        let sw = stride[1];
597        let ph = padding[0];
598        let pw = padding[1];
599
600        let in_elems: usize = input_shape.iter().product();
601        let f_elems: usize = filter_shape.iter().product();
602        let o_elems: usize = output_shape.iter().product();
603
604        // CPU fallback: download input + filter, compute, upload output.
605        let mut in_bytes = vec![0u8; in_elems * 4];
606        let mut f_bytes = vec![0u8; f_elems * 4];
607        mem.copy_from_device(&mut in_bytes, input_ptr)
608            .map_err(BackendError::from)?;
609        mem.copy_from_device(&mut f_bytes, filter_ptr)
610            .map_err(BackendError::from)?;
611
612        let in_f32 = bytes_to_f32_vec(&in_bytes);
613        let f_f32 = bytes_to_f32_vec(&f_bytes);
614        let mut out_f32 = vec![0.0f32; o_elems];
615
616        for b in 0..batch {
617            for kf in 0..k_out {
618                for oy in 0..oh {
619                    for ox in 0..ow {
620                        let mut acc = 0.0f32;
621                        for ci in 0..c_in {
622                            for fy in 0..fh {
623                                for fx in 0..fw {
624                                    let iy = (oy * sh + fy) as isize - ph as isize;
625                                    let ix = (ox * sw + fx) as isize - pw as isize;
626                                    if iy >= 0
627                                        && (iy as usize) < h_in
628                                        && ix >= 0
629                                        && (ix as usize) < w_in
630                                    {
631                                        let in_idx = ((b * c_in + ci) * h_in + iy as usize) * w_in
632                                            + ix as usize;
633                                        let f_idx = ((kf * c_in + ci) * fh + fy) * fw + fx;
634                                        acc += in_f32[in_idx] * f_f32[f_idx];
635                                    }
636                                }
637                            }
638                        }
639                        out_f32[((b * k_out + kf) * oh + oy) * ow + ox] = acc;
640                    }
641                }
642            }
643        }
644
645        let out_bytes = f32_slice_to_bytes(&out_f32);
646        mem.copy_to_device(output_ptr, &out_bytes)
647            .map_err(BackendError::from)?;
648
649        Ok(())
650    }
651
652    fn attention(
653        &self,
654        q_ptr: u64,
655        k_ptr: u64,
656        v_ptr: u64,
657        o_ptr: u64,
658        batch: usize,
659        heads: usize,
660        seq_q: usize,
661        seq_kv: usize,
662        head_dim: usize,
663        scale: f64,
664        causal: bool,
665    ) -> BackendResult<()> {
666        self.check_init()?;
667
668        if seq_q == 0 || seq_kv == 0 || head_dim == 0 {
669            return Err(BackendError::InvalidArgument(
670                "seq_q, seq_kv, and head_dim must all be > 0".into(),
671            ));
672        }
673        if scale <= 0.0 || !scale.is_finite() {
674            return Err(BackendError::InvalidArgument(format!(
675                "scale must be a positive finite number, got {scale}"
676            )));
677        }
678
679        let mem = self.memory()?;
680
681        let batch_heads = batch * heads;
682        let q_elems = batch_heads * seq_q * head_dim;
683        let kv_elems = batch_heads * seq_kv * head_dim;
684        let o_elems = q_elems;
685
686        // CPU fallback: download Q, K, V, compute attention, upload O.
687        let mut q_bytes = vec![0u8; q_elems * 4];
688        let mut k_bytes = vec![0u8; kv_elems * 4];
689        let mut v_bytes = vec![0u8; kv_elems * 4];
690
691        mem.copy_from_device(&mut q_bytes, q_ptr)
692            .map_err(BackendError::from)?;
693        mem.copy_from_device(&mut k_bytes, k_ptr)
694            .map_err(BackendError::from)?;
695        mem.copy_from_device(&mut v_bytes, v_ptr)
696            .map_err(BackendError::from)?;
697
698        let q_f32 = bytes_to_f32_vec(&q_bytes);
699        let k_f32 = bytes_to_f32_vec(&k_bytes);
700        let v_f32 = bytes_to_f32_vec(&v_bytes);
701        let mut o_f32 = vec![0.0f32; o_elems];
702
703        let scale_f32 = scale as f32;
704
705        for bh in 0..batch_heads {
706            let q_off = bh * seq_q * head_dim;
707            let k_off = bh * seq_kv * head_dim;
708            let v_off = k_off;
709
710            for sq in 0..seq_q {
711                let kv_limit = if causal { (sq + 1).min(seq_kv) } else { seq_kv };
712
713                // Pass 1: find max score for numerical stability
714                let mut max_score = f32::NEG_INFINITY;
715                for sk in 0..kv_limit {
716                    let mut dot = 0.0f32;
717                    for dd in 0..head_dim {
718                        dot +=
719                            q_f32[q_off + sq * head_dim + dd] * k_f32[k_off + sk * head_dim + dd];
720                    }
721                    let s = dot * scale_f32;
722                    if s > max_score {
723                        max_score = s;
724                    }
725                }
726
727                // Pass 2: exp(score - max), accumulate weighted V
728                let mut sum_exp = 0.0f32;
729                let mut acc = vec![0.0f32; head_dim];
730                for sk in 0..kv_limit {
731                    let mut dot = 0.0f32;
732                    for dd in 0..head_dim {
733                        dot +=
734                            q_f32[q_off + sq * head_dim + dd] * k_f32[k_off + sk * head_dim + dd];
735                    }
736                    let w = (dot * scale_f32 - max_score).exp();
737                    sum_exp += w;
738                    for dd in 0..head_dim {
739                        acc[dd] += w * v_f32[v_off + sk * head_dim + dd];
740                    }
741                }
742
743                // Normalise
744                let o_base = q_off + sq * head_dim;
745                if sum_exp > 0.0 {
746                    for dd in 0..head_dim {
747                        o_f32[o_base + dd] = acc[dd] / sum_exp;
748                    }
749                }
750            }
751        }
752
753        let o_bytes = f32_slice_to_bytes(&o_f32);
754        mem.copy_to_device(o_ptr, &o_bytes)
755            .map_err(BackendError::from)?;
756
757        Ok(())
758    }
759
760    fn reduce(
761        &self,
762        op: ReduceOp,
763        input_ptr: u64,
764        output_ptr: u64,
765        shape: &[usize],
766        axis: usize,
767    ) -> BackendResult<()> {
768        self.check_init()?;
769
770        if shape.is_empty() {
771            return Err(BackendError::InvalidArgument(
772                "shape must not be empty".into(),
773            ));
774        }
775        if axis >= shape.len() {
776            return Err(BackendError::InvalidArgument(format!(
777                "axis {axis} is out of bounds for shape of length {}",
778                shape.len()
779            )));
780        }
781
782        // Only flat 1-D reduction (shape.len() == 1, axis == 0) is currently
783        // supported on the GPU.  Multi-dimensional reductions require batched
784        // shaders that are not yet implemented.
785        if shape.len() != 1 {
786            return Err(BackendError::Unsupported(
787                "WebGPU reduce currently supports only 1-D shapes".into(),
788            ));
789        }
790
791        let n_elements = shape[0];
792        if n_elements == 0 {
793            return Ok(());
794        }
795
796        let dev = self.device()?;
797        let mem = self.memory()?;
798        let op_str = map_reduce_op(op);
799
800        // ── Pass 1: per-workgroup reduction ─────────────────────────────────
801        let wg_count = (n_elements as u32).div_ceil(256);
802
803        let pass1_wgsl = shader::reduction_wgsl(op_str);
804        let pass1_shader = dev
805            .device
806            .create_shader_module(wgpu::ShaderModuleDescriptor {
807                label: Some("oxicuda-reduce-pass1"),
808                source: wgpu::ShaderSource::Wgsl(pass1_wgsl.into()),
809            });
810        let pass1_pipeline = dev
811            .device
812            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
813                label: Some("oxicuda-reduce-pass1"),
814                layout: None,
815                module: &pass1_shader,
816                entry_point: Some("main"),
817                compilation_options: Default::default(),
818                cache: None,
819            });
820
821        // Partial-sums buffer (temporary).
822        let partial_buf = dev.device.create_buffer(&wgpu::BufferDescriptor {
823            label: Some("oxicuda-reduce-partial"),
824            size: (wg_count as u64) * 4, // f32 per workgroup
825            usage: wgpu::BufferUsages::STORAGE
826                | wgpu::BufferUsages::COPY_SRC
827                | wgpu::BufferUsages::COPY_DST,
828            mapped_at_creation: false,
829        });
830
831        // Uniform for ReduceParams { n: u32 }.
832        let mut p1_params = [0u8; 4];
833        p1_params[0..4].copy_from_slice(&(n_elements as u32).to_le_bytes());
834        let p1_uniform = dev.device.create_buffer(&wgpu::BufferDescriptor {
835            label: Some("oxicuda-reduce-p1-params"),
836            size: 4,
837            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
838            mapped_at_creation: false,
839        });
840        dev.queue.write_buffer(&p1_uniform, 0, &p1_params);
841
842        let bgl1 = pass1_pipeline.get_bind_group_layout(0);
843
844        let bg1 = {
845            let buffers = mem
846                .lock_buffers()
847                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
848            let in_info = buffers.get(&input_ptr).ok_or_else(|| {
849                BackendError::InvalidArgument(format!("unknown handle {input_ptr}"))
850            })?;
851
852            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
853                label: Some("oxicuda-reduce-pass1"),
854                layout: &bgl1,
855                entries: &[
856                    wgpu::BindGroupEntry {
857                        binding: 0,
858                        resource: in_info.buffer.as_entire_binding(),
859                    },
860                    wgpu::BindGroupEntry {
861                        binding: 1,
862                        resource: partial_buf.as_entire_binding(),
863                    },
864                    wgpu::BindGroupEntry {
865                        binding: 2,
866                        resource: p1_uniform.as_entire_binding(),
867                    },
868                ],
869            })
870        };
871
872        // ── Pass 2: final reduction of partial sums ─────────────────────────
873        let pass2_wgsl = shader::reduction_final_wgsl(op_str);
874        let pass2_shader = dev
875            .device
876            .create_shader_module(wgpu::ShaderModuleDescriptor {
877                label: Some("oxicuda-reduce-pass2"),
878                source: wgpu::ShaderSource::Wgsl(pass2_wgsl.into()),
879            });
880        let pass2_pipeline = dev
881            .device
882            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
883                label: Some("oxicuda-reduce-pass2"),
884                layout: None,
885                module: &pass2_shader,
886                entry_point: Some("main"),
887                compilation_options: Default::default(),
888                cache: None,
889            });
890
891        // FinalReduceParams { num_groups: u32 }.
892        let mut p2_params = [0u8; 4];
893        p2_params[0..4].copy_from_slice(&wg_count.to_le_bytes());
894        let p2_uniform = dev.device.create_buffer(&wgpu::BufferDescriptor {
895            label: Some("oxicuda-reduce-p2-params"),
896            size: 4,
897            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
898            mapped_at_creation: false,
899        });
900        dev.queue.write_buffer(&p2_uniform, 0, &p2_params);
901
902        let bgl2 = pass2_pipeline.get_bind_group_layout(0);
903
904        let bg2 = {
905            let buffers = mem
906                .lock_buffers()
907                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
908            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
909                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
910            })?;
911
912            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
913                label: Some("oxicuda-reduce-pass2"),
914                layout: &bgl2,
915                entries: &[
916                    wgpu::BindGroupEntry {
917                        binding: 0,
918                        resource: partial_buf.as_entire_binding(),
919                    },
920                    wgpu::BindGroupEntry {
921                        binding: 1,
922                        resource: out_info.buffer.as_entire_binding(),
923                    },
924                    wgpu::BindGroupEntry {
925                        binding: 2,
926                        resource: p2_uniform.as_entire_binding(),
927                    },
928                ],
929            })
930        };
931
932        // ── Encode both passes into one command buffer ──────────────────────
933        let mut encoder = dev
934            .device
935            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
936                label: Some("oxicuda-reduce"),
937            });
938
939        {
940            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
941                label: Some("oxicuda-reduce-pass1"),
942                timestamp_writes: None,
943            });
944            pass.set_pipeline(&pass1_pipeline);
945            pass.set_bind_group(0, &bg1, &[]);
946            pass.dispatch_workgroups(wg_count, 1, 1);
947        }
948        {
949            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
950                label: Some("oxicuda-reduce-pass2"),
951                timestamp_writes: None,
952            });
953            pass.set_pipeline(&pass2_pipeline);
954            pass.set_bind_group(0, &bg2, &[]);
955            pass.dispatch_workgroups(1, 1, 1);
956        }
957
958        dev.queue.submit(std::iter::once(encoder.finish()));
959        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
960
961        // For "mean", divide the result by N on the host side.
962        if op == ReduceOp::Mean && n_elements > 1 {
963            let mut buf = [0u8; 4];
964            mem.copy_from_device(&mut buf, output_ptr)
965                .map_err(BackendError::from)?;
966            let val = f32::from_le_bytes(buf);
967            let mean = val / (n_elements as f32);
968            mem.copy_to_device(output_ptr, &mean.to_le_bytes())
969                .map_err(BackendError::from)?;
970        }
971
972        Ok(())
973    }
974
    /// Element-wise unary op: writes `op(input[i])` into the output buffer for
    /// `n` f32 elements.
    ///
    /// A fresh WGSL shader module and compute pipeline are created on every
    /// call (no pipeline caching), the input/output buffers are bound, the
    /// dispatch covers `ceil(n / 256)` workgroups (256 presumably matches the
    /// shader's `@workgroup_size` — generated in `shader::elementwise_wgsl`;
    /// confirm there), and the call blocks until the GPU finishes.
    ///
    /// Errors with `InvalidArgument` for unknown buffer handles and
    /// `DeviceError` if the buffer table lock is poisoned. `n == 0` is a no-op.
    fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()> {
        self.check_init()?;
        if n == 0 {
            return Ok(());
        }

        let dev = self.device()?;
        let mem = self.memory()?;

        // Map the enum to the op name the shader generator understands.
        let op_str = map_unary_op(op);
        let wgsl = shader::elementwise_wgsl(op_str);

        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-unary"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });

        // `layout: None` asks wgpu to infer the bind-group layout from the
        // shader itself (retrieved below via `get_bind_group_layout`).
        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-unary"),
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        let bgl = pipeline.get_bind_group_layout(0);

        // Resolve both handles to real buffers while holding the buffer-table
        // lock; the scope block drops the lock before encoding starts.
        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let in_info = buffers.get(&input_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {input_ptr}"))
            })?;
            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
            })?;

            // Binding layout: 0 = input, 1 = output.
            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-unary"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: in_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: out_info.buffer.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-unary"),
            });

        // The inner scope ends the compute pass (Drop) before `finish()`.
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-unary"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            let workgroups = (n as u32).div_ceil(256);
            pass.dispatch_workgroups(workgroups, 1, 1);
        }

        // Submit and block until the GPU is idle; the poll result is
        // intentionally ignored (same best-effort wait as `synchronize`).
        dev.queue.submit(std::iter::once(encoder.finish()));
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }
1056
    /// Element-wise binary op: writes `op(a[i], b[i])` into the output buffer
    /// for `n` f32 elements.
    ///
    /// Mirrors [`unary`]'s structure: shader + pipeline built per call,
    /// bind-group layout inferred from the shader (`layout: None`), dispatch
    /// of `ceil(n / 256)` workgroups, then a blocking wait. `n == 0` is a
    /// no-op; unknown handles yield `InvalidArgument`.
    fn binary(
        &self,
        op: BinaryOp,
        a_ptr: u64,
        b_ptr: u64,
        output_ptr: u64,
        n: usize,
    ) -> BackendResult<()> {
        self.check_init()?;
        if n == 0 {
            return Ok(());
        }

        let dev = self.device()?;
        let mem = self.memory()?;

        // Map the enum to the op name the shader generator understands.
        let op_str = map_binary_op(op);
        let wgsl = shader::binary_wgsl(op_str);

        let shader_mod = dev
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("oxicuda-binary"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });

        let pipeline = dev
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("oxicuda-binary"),
                layout: None,
                module: &shader_mod,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            });

        let bgl = pipeline.get_bind_group_layout(0);

        // Resolve all three handles under the buffer-table lock; the scope
        // block releases the lock before command encoding.
        let bind_group = {
            let buffers = mem
                .lock_buffers()
                .map_err(|e| BackendError::DeviceError(e.to_string()))?;
            let a_info = buffers
                .get(&a_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {a_ptr}")))?;
            let b_info = buffers
                .get(&b_ptr)
                .ok_or_else(|| BackendError::InvalidArgument(format!("unknown handle {b_ptr}")))?;
            let out_info = buffers.get(&output_ptr).ok_or_else(|| {
                BackendError::InvalidArgument(format!("unknown handle {output_ptr}"))
            })?;

            // Binding layout: 0 = lhs, 1 = rhs, 2 = output.
            dev.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("oxicuda-binary"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: a_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: b_info.buffer.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: out_info.buffer.as_entire_binding(),
                    },
                ],
            })
        };

        let mut encoder = dev
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("oxicuda-binary"),
            });

        // The inner scope ends the compute pass (Drop) before `finish()`.
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("oxicuda-binary"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            let workgroups = (n as u32).div_ceil(256);
            pass.dispatch_workgroups(workgroups, 1, 1);
        }

        // Submit and block; poll errors are deliberately ignored.
        dev.queue.submit(std::iter::once(encoder.finish()));
        let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());

        Ok(())
    }
1152
1153    // ── Synchronisation ───────────────────────────────────────────────────────
1154
1155    fn synchronize(&self) -> BackendResult<()> {
1156        self.check_init()?;
1157        if let Some(dev) = &self.device {
1158            let _ = dev.device.poll(wgpu::PollType::wait_indefinitely());
1159        }
1160        Ok(())
1161    }
1162
1163    // ── Memory management ─────────────────────────────────────────────────────
1164
1165    fn alloc(&self, bytes: usize) -> BackendResult<u64> {
1166        self.check_init()?;
1167        if bytes == 0 {
1168            return Err(BackendError::InvalidArgument(
1169                "cannot allocate 0 bytes".into(),
1170            ));
1171        }
1172        self.memory()?.alloc(bytes).map_err(BackendError::from)
1173    }
1174
1175    fn free(&self, ptr: u64) -> BackendResult<()> {
1176        self.check_init()?;
1177        self.memory()?.free(ptr).map_err(BackendError::from)
1178    }
1179
1180    fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()> {
1181        self.check_init()?;
1182        if src.is_empty() {
1183            return Ok(());
1184        }
1185        self.memory()?
1186            .copy_to_device(dst, src)
1187            .map_err(BackendError::from)
1188    }
1189
1190    fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()> {
1191        self.check_init()?;
1192        if dst.is_empty() {
1193            return Ok(());
1194        }
1195        self.memory()?
1196            .copy_from_device(dst, src)
1197            .map_err(BackendError::from)
1198    }
1199}
1200
1201// ─── Byte ↔ f32 helpers ──────────────────────────────────────────────────────
1202
/// Convert a `&[u8]` (length must be a multiple of 4) to a `Vec<f32>`.
///
/// Bytes are interpreted as little-endian IEEE-754 `f32` values.
/// `chunks_exact(4)` would silently drop any trailing bytes, so the
/// documented length precondition is now enforced in debug builds instead of
/// being violated silently.
fn bytes_to_f32_vec(bytes: &[u8]) -> Vec<f32> {
    debug_assert!(
        bytes.len() % 4 == 0,
        "byte length {} is not a multiple of 4",
        bytes.len()
    );
    bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}
1210
/// Convert a `&[f32]` slice to its little-endian byte representation.
///
/// Each value contributes exactly four bytes, so the result length is
/// `4 * data.len()`.
fn f32_slice_to_bytes(data: &[f32]) -> Vec<u8> {
    let mut out = Vec::with_capacity(data.len() * 4);
    for value in data {
        out.extend_from_slice(&value.to_le_bytes());
    }
    out
}
1215
1216// ─── Tests ───────────────────────────────────────────────────────────────────
1217
1218#[cfg(test)]
1219mod tests {
1220    use super::*;
1221    use oxicuda_backend::{BackendTranspose, BinaryOp, ReduceOp, UnaryOp};
1222
1223    // ── Construction ──────────────────────────────────────────────────────────
1224
1225    #[test]
1226    fn webgpu_backend_new_uninitialized() {
1227        let b = WebGpuBackend::new();
1228        assert!(!b.is_initialized());
1229    }
1230
1231    #[test]
1232    fn webgpu_backend_name() {
1233        let b = WebGpuBackend::new();
1234        assert_eq!(b.name(), "webgpu");
1235    }
1236
1237    #[test]
1238    fn webgpu_backend_default() {
1239        let b = WebGpuBackend::default();
1240        assert!(!b.is_initialized());
1241        assert_eq!(b.name(), "webgpu");
1242    }
1243
1244    #[test]
1245    fn backend_debug_impl() {
1246        let b = WebGpuBackend::new();
1247        let s = format!("{b:?}");
1248        assert!(s.contains("WebGpuBackend"));
1249    }
1250
1251    // ── Object-safety smoke test ──────────────────────────────────────────────
1252
1253    #[test]
1254    fn backend_object_safe() {
1255        let b: Box<dyn ComputeBackend> = Box::new(WebGpuBackend::new());
1256        assert_eq!(b.name(), "webgpu");
1257    }
1258
1259    // ── Not-initialized guards ────────────────────────────────────────────────
1260
1261    #[test]
1262    fn backend_not_initialized_gemm() {
1263        let b = WebGpuBackend::new();
1264        let result = b.gemm(
1265            BackendTranspose::NoTrans,
1266            BackendTranspose::NoTrans,
1267            4,
1268            4,
1269            4,
1270            1.0,
1271            0,
1272            4,
1273            0,
1274            4,
1275            0.0,
1276            0,
1277            4,
1278        );
1279        assert_eq!(result, Err(BackendError::NotInitialized));
1280    }
1281
1282    #[test]
1283    fn backend_not_initialized_alloc() {
1284        let b = WebGpuBackend::new();
1285        let result = b.alloc(1024);
1286        assert_eq!(result, Err(BackendError::NotInitialized));
1287    }
1288
1289    #[test]
1290    fn backend_not_initialized_synchronize() {
1291        let b = WebGpuBackend::new();
1292        assert_eq!(b.synchronize(), Err(BackendError::NotInitialized));
1293    }
1294
1295    #[test]
1296    fn backend_not_initialized_free() {
1297        let b = WebGpuBackend::new();
1298        assert_eq!(b.free(1), Err(BackendError::NotInitialized));
1299    }
1300
1301    #[test]
1302    fn backend_not_initialized_copy_htod() {
1303        let b = WebGpuBackend::new();
1304        assert_eq!(b.copy_htod(1, b"hello"), Err(BackendError::NotInitialized));
1305    }
1306
1307    #[test]
1308    fn backend_not_initialized_copy_dtoh() {
1309        let b = WebGpuBackend::new();
1310        let mut buf = [0u8; 4];
1311        assert_eq!(b.copy_dtoh(&mut buf, 1), Err(BackendError::NotInitialized));
1312    }
1313
1314    // ── Zero-size / trivial-OK paths (no GPU needed) ─────────────────────────
1315
1316    /// These tests exercise the "no-op for zero size" branches.  We need the
1317    /// backend to be initialised, but if no GPU is available we skip.
1318    fn try_init() -> Option<WebGpuBackend> {
1319        let mut b = WebGpuBackend::new();
1320        match b.init() {
1321            Ok(()) => Some(b),
1322            Err(_) => None,
1323        }
1324    }
1325
1326    #[test]
1327    fn gemm_zero_size_after_init() {
1328        let Some(b) = try_init() else {
1329            return;
1330        };
1331        let result = b.gemm(
1332            BackendTranspose::NoTrans,
1333            BackendTranspose::NoTrans,
1334            0,
1335            0,
1336            0,
1337            1.0,
1338            0,
1339            1,
1340            0,
1341            1,
1342            0.0,
1343            0,
1344            1,
1345        );
1346        assert_eq!(result, Ok(()));
1347    }
1348
1349    #[test]
1350    fn unary_zero_elements_after_init() {
1351        let Some(b) = try_init() else {
1352            return;
1353        };
1354        assert_eq!(b.unary(UnaryOp::Relu, 0, 0, 0), Ok(()));
1355    }
1356
1357    #[test]
1358    fn binary_zero_elements_after_init() {
1359        let Some(b) = try_init() else {
1360            return;
1361        };
1362        assert_eq!(b.binary(BinaryOp::Add, 0, 0, 0, 0), Ok(()));
1363    }
1364
1365    #[test]
1366    fn copy_htod_empty_noop() {
1367        let Some(b) = try_init() else {
1368            return;
1369        };
1370        assert_eq!(b.copy_htod(0, &[]), Ok(()));
1371    }
1372
1373    #[test]
1374    fn copy_dtoh_empty_noop() {
1375        let Some(b) = try_init() else {
1376            return;
1377        };
1378        assert_eq!(b.copy_dtoh(&mut [], 0), Ok(()));
1379    }
1380
1381    #[test]
1382    fn alloc_zero_bytes_error() {
1383        let Some(b) = try_init() else {
1384            return;
1385        };
1386        assert_eq!(
1387            b.alloc(0),
1388            Err(BackendError::InvalidArgument(
1389                "cannot allocate 0 bytes".into()
1390            ))
1391        );
1392    }
1393
1394    #[test]
1395    fn synchronize_after_init() {
1396        let Some(b) = try_init() else {
1397            return;
1398        };
1399        assert_eq!(b.synchronize(), Ok(()));
1400    }
1401
1402    // ── Argument validation (post-init) ───────────────────────────────────────
1403
1404    #[test]
1405    fn reduce_empty_shape_error() {
1406        let Some(b) = try_init() else {
1407            return;
1408        };
1409        assert_eq!(
1410            b.reduce(ReduceOp::Sum, 0, 0, &[], 0),
1411            Err(BackendError::InvalidArgument(
1412                "shape must not be empty".into()
1413            ))
1414        );
1415    }
1416
1417    #[test]
1418    fn reduce_axis_out_of_bounds_error() {
1419        let Some(b) = try_init() else {
1420            return;
1421        };
1422        assert_eq!(
1423            b.reduce(ReduceOp::Sum, 0, 0, &[4, 4], 5),
1424            Err(BackendError::InvalidArgument(
1425                "axis 5 is out of bounds for shape of length 2".into()
1426            ))
1427        );
1428    }
1429
1430    #[test]
1431    fn attention_zero_seq_error() {
1432        let Some(b) = try_init() else {
1433            return;
1434        };
1435        assert_eq!(
1436            b.attention(0, 0, 0, 0, 1, 1, 0, 8, 64, 0.125, false),
1437            Err(BackendError::InvalidArgument(
1438                "seq_q, seq_kv, and head_dim must all be > 0".into()
1439            ))
1440        );
1441    }
1442
1443    #[test]
1444    fn attention_nonpositive_scale_error() {
1445        let Some(b) = try_init() else {
1446            return;
1447        };
1448        assert_eq!(
1449            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, 0.0, false),
1450            Err(BackendError::InvalidArgument(
1451                "scale must be a positive finite number, got 0".into()
1452            ))
1453        );
1454        assert_eq!(
1455            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, -1.0, false),
1456            Err(BackendError::InvalidArgument(
1457                "scale must be a positive finite number, got -1".into()
1458            ))
1459        );
1460        assert!(
1461            b.attention(0, 0, 0, 0, 1, 1, 8, 8, 64, f64::INFINITY, false)
1462                .is_err()
1463        );
1464    }
1465
1466    #[test]
1467    fn conv2d_wrong_input_shape_error() {
1468        let Some(b) = try_init() else {
1469            return;
1470        };
1471        // 3-element input_shape — should fail.
1472        assert_eq!(
1473            b.conv2d_forward(
1474                0,
1475                &[1, 3, 32],
1476                0,
1477                &[16, 3, 3, 3],
1478                0,
1479                &[1, 16, 30, 30],
1480                &[1, 1],
1481                &[0, 0]
1482            ),
1483            Err(BackendError::InvalidArgument(
1484                "input_shape must have 4 elements (NCHW)".into()
1485            ))
1486        );
1487    }
1488
1489    #[test]
1490    fn conv2d_wrong_filter_shape_error() {
1491        let Some(b) = try_init() else {
1492            return;
1493        };
1494        assert_eq!(
1495            b.conv2d_forward(
1496                0,
1497                &[1, 3, 32, 32],
1498                0,
1499                &[16, 3, 3],
1500                0,
1501                &[1, 16, 30, 30],
1502                &[1, 1],
1503                &[0, 0]
1504            ),
1505            Err(BackendError::InvalidArgument(
1506                "filter_shape must have 4 elements (KCFHFW)".into()
1507            ))
1508        );
1509    }
1510
1511    #[test]
1512    fn conv2d_wrong_stride_shape_error() {
1513        let Some(b) = try_init() else {
1514            return;
1515        };
1516        assert_eq!(
1517            b.conv2d_forward(
1518                0,
1519                &[1, 3, 32, 32],
1520                0,
1521                &[16, 3, 3, 3],
1522                0,
1523                &[1, 16, 30, 30],
1524                &[1], // <-- wrong
1525                &[0, 0],
1526            ),
1527            Err(BackendError::InvalidArgument(
1528                "stride must have 2 elements [sh, sw]".into()
1529            ))
1530        );
1531    }
1532
1533    // ── Init is idempotent ────────────────────────────────────────────────────
1534
1535    #[test]
1536    fn init_idempotent() {
1537        let Some(mut b) = try_init() else {
1538            return;
1539        };
1540        // Second call must succeed without error.
1541        assert_eq!(b.init(), Ok(()));
1542        assert!(b.is_initialized());
1543    }
1544
1545    // ── Graceful failure ──────────────────────────────────────────────────────
1546
1547    #[test]
1548    fn webgpu_init_graceful_failure() {
1549        // We cannot force a failure, but we can at least verify that init()
1550        // returns a Result and never panics.
1551        let mut b = WebGpuBackend::new();
1552        let _result = b.init(); // Ok or Err — both are acceptable.
1553        // No panic => test passes.
1554    }
1555
1556    // ── GPU compute tests ─────────────────────────────────────────────────────
1557    //
1558    // These helpers upload f32 slices and read back results, exercising the
1559    // full shader → pipeline → dispatch path.
1560
1561    /// Helper: upload `data` (f32 slice) to a new GPU buffer, return its handle.
1562    fn upload_f32(b: &WebGpuBackend, data: &[f32]) -> u64 {
1563        let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
1564        let h = b.alloc(bytes.len()).expect("alloc");
1565        b.copy_htod(h, &bytes).expect("copy_htod");
1566        h
1567    }
1568
1569    /// Helper: download `n` f32 values from a GPU buffer handle.
1570    fn download_f32(b: &WebGpuBackend, h: u64, n: usize) -> Vec<f32> {
1571        let mut bytes = vec![0u8; n * 4];
1572        b.copy_dtoh(&mut bytes, h).expect("copy_dtoh");
1573        bytes
1574            .chunks_exact(4)
1575            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
1576            .collect()
1577    }
1578
1579    #[test]
1580    fn unary_neg_small() {
1581        let Some(b) = try_init() else { return };
1582        let input = [1.0f32, -2.0, 3.0, 0.0];
1583        let in_h = upload_f32(&b, &input);
1584        let out_h = b.alloc(input.len() * 4).expect("alloc output");
1585
1586        b.unary(UnaryOp::Neg, in_h, out_h, input.len())
1587            .expect("unary neg");
1588
1589        let result = download_f32(&b, out_h, input.len());
1590        let expected = [-1.0f32, 2.0, -3.0, 0.0];
1591        for (r, e) in result.iter().zip(expected.iter()) {
1592            assert!((r - e).abs() < 1e-6, "got {r}, expected {e}");
1593        }
1594
1595        b.free(in_h).expect("free");
1596        b.free(out_h).expect("free");
1597    }
1598
1599    #[test]
1600    fn unary_abs_small() {
1601        let Some(b) = try_init() else { return };
1602        let input = [-3.0f32, 4.0, -5.0, 0.0];
1603        let in_h = upload_f32(&b, &input);
1604        let out_h = b.alloc(input.len() * 4).expect("alloc output");
1605
1606        b.unary(UnaryOp::Abs, in_h, out_h, input.len())
1607            .expect("unary abs");
1608
1609        let result = download_f32(&b, out_h, input.len());
1610        let expected = [3.0f32, 4.0, 5.0, 0.0];
1611        for (r, e) in result.iter().zip(expected.iter()) {
1612            assert!((r - e).abs() < 1e-6, "got {r}, expected {e}");
1613        }
1614
1615        b.free(in_h).expect("free");
1616        b.free(out_h).expect("free");
1617    }
1618
1619    #[test]
1620    fn binary_add_small() {
1621        let Some(b) = try_init() else { return };
1622        let a = [1.0f32, 2.0, 3.0, 4.0];
1623        let bv = [10.0f32, 20.0, 30.0, 40.0];
1624        let a_h = upload_f32(&b, &a);
1625        let b_h = upload_f32(&b, &bv);
1626        let out_h = b.alloc(a.len() * 4).expect("alloc output");
1627
1628        b.binary(BinaryOp::Add, a_h, b_h, out_h, a.len())
1629            .expect("binary add");
1630
1631        let result = download_f32(&b, out_h, a.len());
1632        let expected = [11.0f32, 22.0, 33.0, 44.0];
1633        for (r, e) in result.iter().zip(expected.iter()) {
1634            assert!((r - e).abs() < 1e-6, "got {r}, expected {e}");
1635        }
1636
1637        b.free(a_h).expect("free");
1638        b.free(b_h).expect("free");
1639        b.free(out_h).expect("free");
1640    }
1641
1642    #[test]
1643    fn binary_mul_small() {
1644        let Some(b) = try_init() else { return };
1645        let a = [2.0f32, 3.0, 4.0, 5.0];
1646        let bv = [10.0f32, 10.0, 10.0, 10.0];
1647        let a_h = upload_f32(&b, &a);
1648        let b_h = upload_f32(&b, &bv);
1649        let out_h = b.alloc(a.len() * 4).expect("alloc output");
1650
1651        b.binary(BinaryOp::Mul, a_h, b_h, out_h, a.len())
1652            .expect("binary mul");
1653
1654        let result = download_f32(&b, out_h, a.len());
1655        let expected = [20.0f32, 30.0, 40.0, 50.0];
1656        for (r, e) in result.iter().zip(expected.iter()) {
1657            assert!((r - e).abs() < 1e-6, "got {r}, expected {e}");
1658        }
1659
1660        b.free(a_h).expect("free");
1661        b.free(b_h).expect("free");
1662        b.free(out_h).expect("free");
1663    }
1664
1665    #[test]
1666    fn reduce_sum_small() {
1667        let Some(b) = try_init() else { return };
1668        let input = [1.0f32, 2.0, 3.0, 4.0];
1669        let in_h = upload_f32(&b, &input);
1670        let out_h = b.alloc(4).expect("alloc output"); // single f32
1671
1672        b.reduce(ReduceOp::Sum, in_h, out_h, &[4], 0)
1673            .expect("reduce sum");
1674
1675        let result = download_f32(&b, out_h, 1);
1676        assert!(
1677            (result[0] - 10.0).abs() < 1e-5,
1678            "expected 10.0, got {}",
1679            result[0]
1680        );
1681
1682        b.free(in_h).expect("free");
1683        b.free(out_h).expect("free");
1684    }
1685
1686    #[test]
1687    fn reduce_max_small() {
1688        let Some(b) = try_init() else { return };
1689        let input = [1.0f32, 5.0, 3.0, 2.0];
1690        let in_h = upload_f32(&b, &input);
1691        let out_h = b.alloc(4).expect("alloc output");
1692
1693        b.reduce(ReduceOp::Max, in_h, out_h, &[4], 0)
1694            .expect("reduce max");
1695
1696        let result = download_f32(&b, out_h, 1);
1697        assert!(
1698            (result[0] - 5.0).abs() < 1e-5,
1699            "expected 5.0, got {}",
1700            result[0]
1701        );
1702
1703        b.free(in_h).expect("free");
1704        b.free(out_h).expect("free");
1705    }
1706
1707    #[test]
1708    fn reduce_mean_small() {
1709        let Some(b) = try_init() else { return };
1710        let input = [2.0f32, 4.0, 6.0, 8.0];
1711        let in_h = upload_f32(&b, &input);
1712        let out_h = b.alloc(4).expect("alloc output");
1713
1714        b.reduce(ReduceOp::Mean, in_h, out_h, &[4], 0)
1715            .expect("reduce mean");
1716
1717        let result = download_f32(&b, out_h, 1);
1718        assert!(
1719            (result[0] - 5.0).abs() < 1e-5,
1720            "expected 5.0, got {}",
1721            result[0]
1722        );
1723
1724        b.free(in_h).expect("free");
1725        b.free(out_h).expect("free");
1726    }
1727
1728    #[test]
1729    fn gemm_identity_2x2() {
1730        let Some(b) = try_init() else { return };
1731        // A = [[1,2],[3,4]], B = [[1,0],[0,1]] (identity), C = zeros
1732        // C = 1.0 * A * I + 0.0 * C = A
1733        let a = [1.0f32, 2.0, 3.0, 4.0];
1734        let eye = [1.0f32, 0.0, 0.0, 1.0];
1735        let c_init = [0.0f32; 4];
1736
1737        let a_h = upload_f32(&b, &a);
1738        let b_h = upload_f32(&b, &eye);
1739        let c_h = upload_f32(&b, &c_init);
1740
1741        b.gemm(
1742            BackendTranspose::NoTrans,
1743            BackendTranspose::NoTrans,
1744            2,
1745            2,
1746            2,
1747            1.0,
1748            a_h,
1749            2,
1750            b_h,
1751            2,
1752            0.0,
1753            c_h,
1754            2,
1755        )
1756        .expect("gemm");
1757
1758        let result = download_f32(&b, c_h, 4);
1759        for (r, e) in result.iter().zip(a.iter()) {
1760            assert!((r - e).abs() < 1e-5, "got {r}, expected {e}");
1761        }
1762
1763        b.free(a_h).expect("free");
1764        b.free(b_h).expect("free");
1765        b.free(c_h).expect("free");
1766    }
1767
1768    #[test]
1769    fn gemm_2x3_times_3x2() {
1770        let Some(b) = try_init() else { return };
1771        // A 2x3, B 3x2 → C 2x2
1772        let a = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
1773        let bm = [7.0f32, 8.0, 9.0, 10.0, 11.0, 12.0];
1774        let c_init = [0.0f32; 4];
1775
1776        let a_h = upload_f32(&b, &a);
1777        let b_h = upload_f32(&b, &bm);
1778        let c_h = upload_f32(&b, &c_init);
1779
1780        b.gemm(
1781            BackendTranspose::NoTrans,
1782            BackendTranspose::NoTrans,
1783            2,
1784            2,
1785            3,
1786            1.0,
1787            a_h,
1788            3,
1789            b_h,
1790            2,
1791            0.0,
1792            c_h,
1793            2,
1794        )
1795        .expect("gemm");
1796
1797        // Expected: [[58, 64], [139, 154]]
1798        let result = download_f32(&b, c_h, 4);
1799        let expected = [58.0f32, 64.0, 139.0, 154.0];
1800        for (r, e) in result.iter().zip(expected.iter()) {
1801            assert!((r - e).abs() < 1e-4, "got {r}, expected {e}");
1802        }
1803
1804        b.free(a_h).expect("free");
1805        b.free(b_h).expect("free");
1806        b.free(c_h).expect("free");
1807    }
1808
1809    #[test]
1810    fn gemm_alpha_beta() {
1811        let Some(b) = try_init() else { return };
1812        // C = 2.0 * A * B + 3.0 * C
1813        // A = [[1,0],[0,1]], B = [[1,0],[0,1]], C = [[1,1],[1,1]]
1814        // C = 2*I + 3*ones = [[5,3],[3,5]]
1815        let a = [1.0f32, 0.0, 0.0, 1.0];
1816        let bm = [1.0f32, 0.0, 0.0, 1.0];
1817        let c_init = [1.0f32, 1.0, 1.0, 1.0];
1818
1819        let a_h = upload_f32(&b, &a);
1820        let b_h = upload_f32(&b, &bm);
1821        let c_h = upload_f32(&b, &c_init);
1822
1823        b.gemm(
1824            BackendTranspose::NoTrans,
1825            BackendTranspose::NoTrans,
1826            2,
1827            2,
1828            2,
1829            2.0,
1830            a_h,
1831            2,
1832            b_h,
1833            2,
1834            3.0,
1835            c_h,
1836            2,
1837        )
1838        .expect("gemm alpha+beta");
1839
1840        let result = download_f32(&b, c_h, 4);
1841        let expected = [5.0f32, 3.0, 3.0, 5.0];
1842        for (r, e) in result.iter().zip(expected.iter()) {
1843            assert!((r - e).abs() < 1e-4, "got {r}, expected {e}");
1844        }
1845
1846        b.free(a_h).expect("free");
1847        b.free(b_h).expect("free");
1848        b.free(c_h).expect("free");
1849    }
1850
1851    // ── Conv2D tests ──────────────────────────────────────────────────────
1852
1853    #[test]
1854    fn conv2d_identity_1x1() {
1855        // 1×1 convolution with single channel, no padding, stride=1
1856        // input: 1×1×3×3, filter: 1×1×1×1 (weight=2.0), output: 1×1×3×3
1857        let Some(b) = try_init() else { return };
1858        let input: Vec<f32> = (1..=9).map(|x| x as f32).collect();
1859        let filter = [2.0f32];
1860        let expected: Vec<f32> = input.iter().map(|x| x * 2.0).collect();
1861
1862        let in_h = upload_f32(&b, &input);
1863        let f_h = upload_f32(&b, &filter);
1864        let out_h = b.alloc(9 * 4).expect("alloc output");
1865
1866        b.conv2d_forward(
1867            in_h,
1868            &[1, 1, 3, 3],
1869            f_h,
1870            &[1, 1, 1, 1],
1871            out_h,
1872            &[1, 1, 3, 3],
1873            &[1, 1],
1874            &[0, 0],
1875        )
1876        .expect("conv2d");
1877
1878        let result = download_f32(&b, out_h, 9);
1879        for (r, e) in result.iter().zip(expected.iter()) {
1880            assert!((r - e).abs() < 1e-5, "got {r}, expected {e}");
1881        }
1882
1883        b.free(in_h).expect("free");
1884        b.free(f_h).expect("free");
1885        b.free(out_h).expect("free");
1886    }
1887
1888    #[test]
1889    fn conv2d_3x3_no_padding() {
1890        // input: 1×1×4×4, filter: 1×1×3×3 (all ones), stride=1, pad=0
1891        // output: 1×1×2×2
1892        let Some(b) = try_init() else { return };
1893        let input: Vec<f32> = (0..16).map(|x| x as f32).collect();
1894        let filter = [1.0f32; 9];
1895
1896        let in_h = upload_f32(&b, &input);
1897        let f_h = upload_f32(&b, &filter);
1898        let out_h = b.alloc(4 * 4).expect("alloc output");
1899
1900        b.conv2d_forward(
1901            in_h,
1902            &[1, 1, 4, 4],
1903            f_h,
1904            &[1, 1, 3, 3],
1905            out_h,
1906            &[1, 1, 2, 2],
1907            &[1, 1],
1908            &[0, 0],
1909        )
1910        .expect("conv2d");
1911
1912        let result = download_f32(&b, out_h, 4);
1913        // top-left 3×3 sum: 0+1+2+4+5+6+8+9+10 = 45
1914        assert!((result[0] - 45.0).abs() < 1e-4, "got {}", result[0]);
1915        // top-right 3×3 sum: 1+2+3+5+6+7+9+10+11 = 54
1916        assert!((result[1] - 54.0).abs() < 1e-4, "got {}", result[1]);
1917
1918        b.free(in_h).expect("free");
1919        b.free(f_h).expect("free");
1920        b.free(out_h).expect("free");
1921    }
1922
1923    #[test]
1924    fn conv2d_with_padding() {
1925        // input: 1×1×2×2, filter: 1×1×3×3 (all ones), stride=1, pad=1
1926        // output: 1×1×2×2
1927        // With padding=1 around a 2×2 input, the output is also 2×2.
1928        let Some(b) = try_init() else { return };
1929        let input = [1.0f32, 2.0, 3.0, 4.0];
1930        let filter = [1.0f32; 9];
1931
1932        let in_h = upload_f32(&b, &input);
1933        let f_h = upload_f32(&b, &filter);
1934        let out_h = b.alloc(4 * 4).expect("alloc output");
1935
1936        b.conv2d_forward(
1937            in_h,
1938            &[1, 1, 2, 2],
1939            f_h,
1940            &[1, 1, 3, 3],
1941            out_h,
1942            &[1, 1, 2, 2],
1943            &[1, 1],
1944            &[1, 1],
1945        )
1946        .expect("conv2d");
1947
1948        let result = download_f32(&b, out_h, 4);
1949        // Top-left output: only 4 of 9 filter taps hit valid input
1950        // input[0,0]=1, input[0,1]=2, input[1,0]=3, input[1,1]=4 => sum=10
1951        assert!((result[0] - 10.0).abs() < 1e-4, "got {}", result[0]);
1952
1953        b.free(in_h).expect("free");
1954        b.free(f_h).expect("free");
1955        b.free(out_h).expect("free");
1956    }
1957
1958    // ── Attention tests ───────────────────────────────────────────────────
1959
1960    #[test]
1961    fn attention_uniform_weights() {
1962        // 1 head, seq_q=1, seq_kv=2, head_dim=2, no causal
1963        // Q = [1, 0], K = [[1, 0], [1, 0]], V = [[1, 2], [3, 4]]
1964        // scores = [1*scale, 1*scale] => equal weights => O = mean(V) = [2, 3]
1965        let Some(b) = try_init() else { return };
1966
1967        let q = [1.0f32, 0.0];
1968        let k = [1.0f32, 0.0, 1.0, 0.0];
1969        let v = [1.0f32, 2.0, 3.0, 4.0];
1970
1971        let q_h = upload_f32(&b, &q);
1972        let k_h = upload_f32(&b, &k);
1973        let v_h = upload_f32(&b, &v);
1974        let o_h = b.alloc(2 * 4).expect("alloc output");
1975
1976        b.attention(q_h, k_h, v_h, o_h, 1, 1, 1, 2, 2, 1.0, false)
1977            .expect("attention");
1978
1979        let result = download_f32(&b, o_h, 2);
1980        // Equal scores → equal softmax weights → average of V rows
1981        assert!(
1982            (result[0] - 2.0).abs() < 1e-4,
1983            "got {}, expected 2.0",
1984            result[0]
1985        );
1986        assert!(
1987            (result[1] - 3.0).abs() < 1e-4,
1988            "got {}, expected 3.0",
1989            result[1]
1990        );
1991
1992        b.free(q_h).expect("free");
1993        b.free(k_h).expect("free");
1994        b.free(v_h).expect("free");
1995        b.free(o_h).expect("free");
1996    }
1997
1998    #[test]
1999    fn attention_causal_single_token() {
2000        // 1 head, seq_q=2, seq_kv=2, head_dim=1, causal
2001        // Q = [1, 1], K = [1, 1], V = [10, 20]
2002        // sq=0: only sees sk=0 → O[0] = V[0] = 10
2003        // sq=1: sees sk=0,1 with equal scores → O[1] = (10+20)/2 = 15
2004        let Some(b) = try_init() else { return };
2005
2006        let q = [1.0f32, 1.0];
2007        let k = [1.0f32, 1.0];
2008        let v = [10.0f32, 20.0];
2009
2010        let q_h = upload_f32(&b, &q);
2011        let k_h = upload_f32(&b, &k);
2012        let v_h = upload_f32(&b, &v);
2013        let o_h = b.alloc(2 * 4).expect("alloc output");
2014
2015        b.attention(q_h, k_h, v_h, o_h, 1, 1, 2, 2, 1, 1.0, true)
2016            .expect("attention causal");
2017
2018        let result = download_f32(&b, o_h, 2);
2019        assert!(
2020            (result[0] - 10.0).abs() < 1e-4,
2021            "got {}, expected 10.0",
2022            result[0]
2023        );
2024        assert!(
2025            (result[1] - 15.0).abs() < 1e-4,
2026            "got {}, expected 15.0",
2027            result[1]
2028        );
2029
2030        b.free(q_h).expect("free");
2031        b.free(k_h).expect("free");
2032        b.free(v_h).expect("free");
2033        b.free(o_h).expect("free");
2034    }
2035
2036    // ── Batched GEMM tests ─────────────────────────────────────────────
2037
2038    #[test]
2039    fn batched_gemm_not_initialized() {
2040        let b = WebGpuBackend::new();
2041        let result = b.batched_gemm(
2042            BackendTranspose::NoTrans,
2043            BackendTranspose::NoTrans,
2044            4,
2045            4,
2046            4,
2047            1.0,
2048            0,
2049            4,
2050            16,
2051            0,
2052            4,
2053            16,
2054            0.0,
2055            0,
2056            4,
2057            16,
2058            2,
2059        );
2060        assert_eq!(result, Err(BackendError::NotInitialized));
2061    }
2062
2063    #[test]
2064    fn batched_gemm_zero_batch_noop() {
2065        let Some(b) = try_init() else { return };
2066        let result = b.batched_gemm(
2067            BackendTranspose::NoTrans,
2068            BackendTranspose::NoTrans,
2069            4,
2070            4,
2071            4,
2072            1.0,
2073            0,
2074            4,
2075            16,
2076            0,
2077            4,
2078            16,
2079            0.0,
2080            0,
2081            4,
2082            16,
2083            0, // batch_count = 0
2084        );
2085        assert_eq!(result, Ok(()));
2086    }
2087
2088    #[test]
2089    fn batched_gemm_zero_dims_noop() {
2090        let Some(b) = try_init() else { return };
2091        // m = 0
2092        let result = b.batched_gemm(
2093            BackendTranspose::NoTrans,
2094            BackendTranspose::NoTrans,
2095            0,
2096            4,
2097            4,
2098            1.0,
2099            0,
2100            4,
2101            16,
2102            0,
2103            4,
2104            16,
2105            0.0,
2106            0,
2107            4,
2108            16,
2109            2,
2110        );
2111        assert_eq!(result, Ok(()));
2112        // n = 0
2113        let result = b.batched_gemm(
2114            BackendTranspose::NoTrans,
2115            BackendTranspose::NoTrans,
2116            4,
2117            0,
2118            4,
2119            1.0,
2120            0,
2121            4,
2122            16,
2123            0,
2124            4,
2125            16,
2126            0.0,
2127            0,
2128            4,
2129            16,
2130            2,
2131        );
2132        assert_eq!(result, Ok(()));
2133        // k = 0
2134        let result = b.batched_gemm(
2135            BackendTranspose::NoTrans,
2136            BackendTranspose::NoTrans,
2137            4,
2138            4,
2139            0,
2140            1.0,
2141            0,
2142            4,
2143            16,
2144            0,
2145            4,
2146            16,
2147            0.0,
2148            0,
2149            4,
2150            16,
2151            2,
2152        );
2153        assert_eq!(result, Ok(()));
2154    }
2155
2156    #[test]
2157    fn batched_gemm_identity_2x2() {
2158        let Some(b) = try_init() else { return };
2159        // 2 batches of 2x2 identity multiply
2160        // batch 0: A0=[[1,2],[3,4]] * I = [[1,2],[3,4]]
2161        // batch 1: A1=[[5,6],[7,8]] * I = [[5,6],[7,8]]
2162        let a = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
2163        let eye = [1.0f32, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
2164        let c_init = [0.0f32; 8];
2165
2166        let a_h = upload_f32(&b, &a);
2167        let b_h = upload_f32(&b, &eye);
2168        let c_h = upload_f32(&b, &c_init);
2169
2170        b.batched_gemm(
2171            BackendTranspose::NoTrans,
2172            BackendTranspose::NoTrans,
2173            2,
2174            2,
2175            2,
2176            1.0,
2177            a_h,
2178            2,
2179            4, // stride_a = 2*2 = 4
2180            b_h,
2181            2,
2182            4, // stride_b = 4
2183            0.0,
2184            c_h,
2185            2,
2186            4, // stride_c = 4
2187            2, // batch_count
2188        )
2189        .expect("batched_gemm");
2190
2191        let result = download_f32(&b, c_h, 8);
2192        for (r, e) in result.iter().zip(a.iter()) {
2193            assert!((r - e).abs() < 1e-5, "got {r}, expected {e}");
2194        }
2195
2196        b.free(a_h).expect("free");
2197        b.free(b_h).expect("free");
2198        b.free(c_h).expect("free");
2199    }
2200
2201    // ── FP16 GEMM tests ─────────────────────────────────────────────────
2202
2203    #[test]
2204    fn gemm_f16_not_initialized() {
2205        let b = WebGpuBackend::new();
2206        let result = b.gemm_f16(4, 4, 4, 1.0, 0, 0, 0.0, 0);
2207        assert_eq!(result, Err(BackendError::NotInitialized));
2208    }
2209
2210    #[test]
2211    fn gemm_f16_zero_dims_noop() {
2212        let Some(b) = try_init() else { return };
2213        assert_eq!(b.gemm_f16(0, 4, 4, 1.0, 0, 0, 0.0, 0), Ok(()));
2214        assert_eq!(b.gemm_f16(4, 0, 4, 1.0, 0, 0, 0.0, 0), Ok(()));
2215        assert_eq!(b.gemm_f16(4, 4, 0, 1.0, 0, 0, 0.0, 0), Ok(()));
2216    }
2217
2218    #[test]
2219    fn attention_dominant_key() {
2220        // 1 head, seq_q=1, seq_kv=2, head_dim=2, no causal
2221        // Q = [1, 0], K = [[10, 0], [0, 0]], V = [[100, 200], [0, 0]]
2222        // score[0] = 10*scale, score[1] = 0*scale
2223        // With large enough difference, softmax saturates → O ≈ V[0]
2224        let Some(b) = try_init() else { return };
2225
2226        let q = [1.0f32, 0.0];
2227        let k = [10.0f32, 0.0, 0.0, 0.0];
2228        let v = [100.0f32, 200.0, 0.0, 0.0];
2229
2230        let q_h = upload_f32(&b, &q);
2231        let k_h = upload_f32(&b, &k);
2232        let v_h = upload_f32(&b, &v);
2233        let o_h = b.alloc(2 * 4).expect("alloc output");
2234
2235        // scale=1.0 gives scores 10 vs 0 → softmax ≈ [1, 0]
2236        b.attention(q_h, k_h, v_h, o_h, 1, 1, 1, 2, 2, 1.0, false)
2237            .expect("attention dominant");
2238
2239        let result = download_f32(&b, o_h, 2);
2240        assert!(
2241            (result[0] - 100.0).abs() < 0.1,
2242            "got {}, expected ~100",
2243            result[0]
2244        );
2245        assert!(
2246            (result[1] - 200.0).abs() < 0.1,
2247            "got {}, expected ~200",
2248            result[1]
2249        );
2250
2251        b.free(q_h).expect("free");
2252        b.free(k_h).expect("free");
2253        b.free(v_h).expect("free");
2254        b.free(o_h).expect("free");
2255    }
2256}