numrs/backend/webgpu/
mod.rs

1//! WebGPU backend
2//!
3//! Cross-platform GPU acceleration using wgpu.
4//! Works on both native (Vulkan/Metal/DX12) and WASM (WebGPU/WebGL).
5
6pub mod batchnorm;
7pub mod codegen;
8pub mod conv;
9pub mod dropout;
10
11use crate::array::Array;
12use anyhow::{anyhow, Result};
13use std::borrow::Cow;
14
// Native: use OnceCell + Mutex for a thread-safe cache
16#[cfg(not(target_arch = "wasm32"))]
17use once_cell::sync::OnceCell;
18#[cfg(not(target_arch = "wasm32"))]
19use std::sync::Mutex;
20
// WASM: use thread_local!, which does not require Sync
22#[cfg(target_arch = "wasm32")]
23use std::cell::RefCell;
24
25// Global flag for WASM WebGPU availability (set from JS binding)
26#[cfg(target_arch = "wasm32")]
27use std::sync::atomic::{AtomicBool, Ordering};
28
29#[cfg(target_arch = "wasm32")]
30static WEBGPU_AVAILABLE_FROM_JS: AtomicBool = AtomicBool::new(false);
31
/// Record whether the host page reports a usable WebGPU implementation.
///
/// Called from the JS binding layer; the flag is later read back through
/// `get_webgpu_available_wasm` / `is_available_cached`.
#[cfg(target_arch = "wasm32")]
pub fn set_webgpu_available_wasm(available: bool) {
    // Log first, then publish the flag — the two statements are independent.
    eprintln!("[numrs-webgpu] WebGPU available flag set to: {}", available);
    WEBGPU_AVAILABLE_FROM_JS.store(available, Ordering::SeqCst);
}
37
/// Read back the WebGPU-availability flag previously set from JavaScript
/// via [`set_webgpu_available_wasm`].
#[cfg(target_arch = "wasm32")]
pub fn get_webgpu_available_wasm() -> bool {
    WEBGPU_AVAILABLE_FROM_JS.load(Ordering::SeqCst)
}
42
43#[derive(Debug, Clone)]
44pub struct WebGpuBackend {}
45
46impl WebGpuBackend {
47    pub fn new() -> Self {
48        Self {}
49    }
50
51    /// Quick probe helper: is there a usable WebGPU adapter available?
52    pub fn is_available() -> bool {
53        #[cfg(target_arch = "wasm32")]
54        {
55            get_webgpu_available_wasm()
56        }
57        #[cfg(not(target_arch = "wasm32"))]
58        {
59            let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
60                backends: wgpu::Backends::all(),
61                ..Default::default()
62            });
63            let adapter =
64                pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
65                    power_preference: wgpu::PowerPreference::HighPerformance,
66                    compatible_surface: None,
67                    force_fallback_adapter: false,
68                }));
69
70            adapter.is_some()
71        }
72    }
73}
74
// Cached GPU context to avoid re-creating adapter/device/pipelines every call.
// Built once by `get_gpu_context` and kept for the lifetime of the process.
struct GpuContext {
    device: wgpu::Device,                   // owning handle to the logical GPU device
    queue: wgpu::Queue,                     // command-submission queue for `device`
    matmul_pipeline: wgpu::ComputePipeline, // compiled matmul compute pipeline
    matmul_bgl: wgpu::BindGroupLayout,      // layout used to build matmul bind groups
}
82
83use std::sync::Arc;
84
// Shared handles to the wgpu device/queue pair. `Arc` lets the WASM path hand
// out cheap owned clones (see the WASM `get_gpu_device`), while native keeps a
// single cached instance behind a `'static` OnceCell.
struct DeviceQueue {
    device: Arc<wgpu::Device>, // logical GPU device
    queue: Arc<wgpu::Queue>,   // its submission queue
}
89
/// WASM: asynchronously request a WebGPU adapter/device and cache it in the
/// thread-local `GPU_DEVICE` slot.
///
/// Must be awaited (from the browser main thread) before any GPU op runs; on
/// success it also flips the JS-visible availability flag.
///
/// # Errors
/// Returns an error when no adapter is available or device creation fails.
/// Failures are NOT cached, so the call may be retried.
#[cfg(target_arch = "wasm32")]
pub async fn init_webgpu_wasm() -> Result<()> {
    // 1. Bail out early if a previous call already populated the cache.
    //    This function only ever stores `Ok`, so `is_some()` is sufficient.
    let already_init = GPU_DEVICE.with(|cell| cell.borrow().is_some());
    if already_init {
        return Ok(());
    }

    // 2. Request adapter and device asynchronously (safe in browser main thread).
    let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
        backends: wgpu::Backends::all(),
        ..Default::default()
    });

    let adapter = instance
        .request_adapter(&wgpu::RequestAdapterOptions {
            // No power preference: let the implementation choose.
            power_preference: wgpu::PowerPreference::None,
            compatible_surface: None,
            force_fallback_adapter: false,
        })
        .await
        .ok_or_else(|| anyhow!("no WebGPU adapter available"))?;

    let (device, queue) = adapter
        .request_device(
            &wgpu::DeviceDescriptor {
                label: None,
                required_features: wgpu::Features::empty(),
                required_limits: wgpu::Limits::default(),
                memory_hints: Default::default(),
            },
            None,
        )
        .await
        .map_err(|e| anyhow!("request device failed: {:?}", e))?;

    // 3. Store in thread-local cache (Arc so later lookups can clone cheaply).
    GPU_DEVICE.with(|cell| {
        *cell.borrow_mut() = Some(Ok(DeviceQueue {
            device: Arc::new(device),
            queue: Arc::new(queue),
        }));
    });

    set_webgpu_available_wasm(true);
    Ok(())
}
139
/// WASM: fetch a cheap owned clone of the cached device/queue pair.
///
/// Errors if `init_webgpu_wasm` has not completed (or a failure was cached).
#[cfg(target_arch = "wasm32")]
fn get_gpu_device() -> Result<DeviceQueue> {
    GPU_DEVICE.with(|cell| match cell.borrow().as_ref() {
        // Arc clones are just reference-count bumps.
        Some(Ok(dq)) => Ok(DeviceQueue {
            device: Arc::clone(&dq.device),
            queue: Arc::clone(&dq.queue),
        }),
        Some(Err(e)) => Err(anyhow!("WebGPU init failed previously: {:?}", e)),
        None => Err(anyhow!(
            "WebGPU/WebGL not initialized. Ensure init_webgpu() is called."
        )),
    })
}
159
160// ============================================================================
161// GPU Device Cache - Native vs WASM
162// ============================================================================
163
164#[cfg(not(target_arch = "wasm32"))]
165static GPU_DEVICE: OnceCell<Result<DeviceQueue, anyhow::Error>> = OnceCell::new();
166
167#[cfg(target_arch = "wasm32")]
168thread_local! {
169    static GPU_DEVICE: RefCell<Option<Result<DeviceQueue, anyhow::Error>>> = RefCell::new(None);
170}
171
172#[cfg(not(target_arch = "wasm32"))]
173fn get_gpu_device() -> Result<&'static DeviceQueue> {
174    GPU_DEVICE.get_or_init(|| {
175        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
176            backends: wgpu::Backends::all(),
177            ..Default::default()
178        });
179        let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
180            power_preference: wgpu::PowerPreference::HighPerformance,
181            compatible_surface: None,
182            force_fallback_adapter: false,
183        }))
184        .ok_or_else(|| anyhow!("no WebGPU adapter available"))?;
185
186        let (device, queue) = pollster::block_on(adapter.request_device(
187            &wgpu::DeviceDescriptor {
188                label: None,
189                required_features: wgpu::Features::empty(), // Minimal features
190                required_limits: wgpu::Limits::default(),   // Default limits
191                memory_hints: Default::default(),
192            },
193            None,
194        ))
195        .map_err(|e| anyhow!("request device failed: {:?}", e))?;
196
197        Ok(DeviceQueue {
198            device: Arc::new(device),
199            queue: Arc::new(queue),
200        })
201    });
202
203    let init_ref = GPU_DEVICE.get().expect("gpu device was just initialized");
204    match init_ref {
205        Ok(dq) => Ok(dq),
206        Err(e) => Err(anyhow!("gpu init failed: {:?}", e)),
207    }
208}
209
// Helper macro: run `$code` with `$dq` bound to the cached device/queue,
// regardless of target. On native `get_gpu_device()` yields a
// `&'static DeviceQueue`; on WASM it yields an owned `DeviceQueue` — either
// way `$dq` is usable by reference inside `$code`.
//
// The previous cfg-gated native/WASM definitions were byte-identical, so a
// single un-gated definition serves both targets.
macro_rules! with_gpu_device {
    ($dq:ident, $code:expr) => {{
        let $dq = get_gpu_device()?;
        $code
    }};
}
227
// Cached reduction pipeline (pipeline + bind group layout) - native only.
// NOTE(review): not referenced in this module chunk; presumably consumed by
// reduction kernels elsewhere — confirm before removing.
#[cfg(not(target_arch = "wasm32"))]
static REDUCTION_PIPELINE: OnceCell<
    Result<(wgpu::ComputePipeline, wgpu::BindGroupLayout), anyhow::Error>,
> = OnceCell::new();
233
234// get_gpu_context solo se usa en fast path (deshabilitado en WASM)
235#[cfg(not(target_arch = "wasm32"))]
236fn get_gpu_context(shader_src: &str) -> Result<&'static GpuContext> {
237    static CTX: OnceCell<Result<GpuContext, anyhow::Error>> = OnceCell::new();
238
239    CTX.get_or_init(|| -> Result<GpuContext, anyhow::Error> {
240        // Create instance, adapter, device and queue
241        eprintln!("NumRs-Core: Starting WebGPU Init...");
242
243        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
244            backends: wgpu::Backends::all(),
245            ..Default::default()
246        });
247
248        eprintln!("NumRs-Core: Instance created. Requesting adapter...");
249
250        let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
251            power_preference: wgpu::PowerPreference::None,
252            compatible_surface: None,
253            force_fallback_adapter: false,
254        }));
255
256        if adapter.is_none() {
257            eprintln!("NumRs-Core: Adapter request returned None!");
258            return Err(anyhow!("no WebGPU adapter available"));
259        }
260        let adapter = adapter.unwrap();
261
262        eprintln!("NumRs-Core: Adapter found. Requesting device...");
263
264        let (device, queue) = pollster::block_on(adapter.request_device(
265            &wgpu::DeviceDescriptor {
266                label: None,
267                required_features: wgpu::Features::empty(),
268                required_limits: wgpu::Limits::default(),
269                memory_hints: Default::default(),
270            },
271            None,
272        ))
273        .map_err(|e| anyhow!("request device failed: {:?}", e))?;
274
275        // create shader module and pipeline for matmul
276        let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
277            label: Some("matmul_shader"),
278            source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Owned(shader_src.to_string())),
279        });
280
281        let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
282            label: Some("bgl_matmul"),
283            entries: &[
284                wgpu::BindGroupLayoutEntry {
285                    binding: 0,
286                    visibility: wgpu::ShaderStages::COMPUTE,
287                    ty: wgpu::BindingType::Buffer {
288                        ty: wgpu::BufferBindingType::Storage { read_only: true },
289                        has_dynamic_offset: false,
290                        min_binding_size: None,
291                    },
292                    count: None,
293                },
294                wgpu::BindGroupLayoutEntry {
295                    binding: 1,
296                    visibility: wgpu::ShaderStages::COMPUTE,
297                    ty: wgpu::BindingType::Buffer {
298                        ty: wgpu::BufferBindingType::Storage { read_only: true },
299                        has_dynamic_offset: false,
300                        min_binding_size: None,
301                    },
302                    count: None,
303                },
304                wgpu::BindGroupLayoutEntry {
305                    binding: 2,
306                    visibility: wgpu::ShaderStages::COMPUTE,
307                    ty: wgpu::BindingType::Buffer {
308                        ty: wgpu::BufferBindingType::Storage { read_only: false },
309                        has_dynamic_offset: false,
310                        min_binding_size: None,
311                    },
312                    count: None,
313                },
314                wgpu::BindGroupLayoutEntry {
315                    binding: 3,
316                    visibility: wgpu::ShaderStages::COMPUTE,
317                    ty: wgpu::BindingType::Buffer {
318                        ty: wgpu::BufferBindingType::Uniform,
319                        has_dynamic_offset: false,
320                        min_binding_size: None,
321                    },
322                    count: None,
323                },
324            ],
325        });
326
327        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
328            label: Some("pl_matmul"),
329            bind_group_layouts: &[&bgl],
330            push_constant_ranges: &[],
331        });
332
333        let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
334            label: Some("pipeline_matmul"),
335            layout: Some(&pipeline_layout),
336            module: &shader_module,
337            entry_point: Some("main"),
338            cache: None,
339            compilation_options: Default::default(),
340        });
341
342        Ok(GpuContext {
343            device,
344            queue,
345            matmul_pipeline: compute_pipeline,
346            matmul_bgl: bgl,
347        })
348    });
349
350    let init_ref = CTX.get().expect("OnceCell was just initialized");
351    match init_ref {
352        Ok(ctx) => Ok(ctx),
353        Err(e) => Err(anyhow!("gpu init failed: {:?}", e)),
354    }
355}
356
// Cached buffers to avoid re-allocating buffers across repeated matmul calls.
// Buffers are keyed by the (m, n, k) problem shape; a shape change triggers a
// full re-allocation (see run_matmul_gpu_fast).
#[cfg(not(target_arch = "wasm32"))]
struct CachedBuffers {
    m: u32,                   // rows of A / output
    n: u32,                   // cols of B / output
    k: u32,                   // shared inner dimension
    _len: usize,              // output element count (m * n); retained but unread
    a_buf: wgpu::Buffer,      // input A (storage, refilled via queue.write_buffer)
    b_buf: wgpu::Buffer,      // input B (storage, refilled via queue.write_buffer)
    out_buf: wgpu::Buffer,    // output (storage, copied to `staging` for readback)
    params_buf: wgpu::Buffer, // uniform holding (m, n, k)
    staging: wgpu::Buffer,    // MAP_READ buffer for CPU readback
}

// Single-slot cache guarded by a Mutex; `None` until the first matmul.
#[cfg(not(target_arch = "wasm32"))]
static BUFFERS: OnceCell<Mutex<Option<CachedBuffers>>> = OnceCell::new();
373
374/// Cached probe helper that performs a single adapter check and caches the result.
375#[cfg(not(target_arch = "wasm32"))]
376pub fn is_available_cached() -> bool {
377    static PROBE: OnceCell<bool> = OnceCell::new();
378    *PROBE.get_or_init(|| WebGpuBackend::is_available())
379}
380
// WASM: no probing or caching here — availability is driven entirely by the
// flag that the JavaScript binding sets.
#[cfg(target_arch = "wasm32")]
pub fn is_available_cached() -> bool {
    get_webgpu_available_wasm()
}
387
388fn run_elementwise_gpu(a: &Array, b: &Array, kind: crate::llo::ElementwiseKind) -> Result<Array> {
389    use wgpu::util::DeviceExt;
390
391    let len = a.len();
392
393    // Generate WGSL depending on the kind
394    let op = match kind {
395        crate::llo::ElementwiseKind::Add => "a[idx] + b[idx]",
396        crate::llo::ElementwiseKind::Mul => "a[idx] * b[idx]",
397        crate::llo::ElementwiseKind::Sub => "a[idx] - b[idx]",
398        crate::llo::ElementwiseKind::Div => "a[idx] / b[idx]",
399        crate::llo::ElementwiseKind::Sqrt => "sqrt(a[idx])",
400        crate::llo::ElementwiseKind::Sin => "sin(a[idx])",
401        crate::llo::ElementwiseKind::Cos => "cos(a[idx])",
402        crate::llo::ElementwiseKind::Pow => "pow(a[idx], b[idx])",
403        crate::llo::ElementwiseKind::Abs => "abs(a[idx])",
404        crate::llo::ElementwiseKind::Neg => "-a[idx]",
405        crate::llo::ElementwiseKind::Exp => "exp(a[idx])",
406        crate::llo::ElementwiseKind::Log => "log(a[idx])",
407        crate::llo::ElementwiseKind::Tan => "tan(a[idx])",
408        crate::llo::ElementwiseKind::Asin => "asin(a[idx])",
409        crate::llo::ElementwiseKind::Acos => "acos(a[idx])",
410        crate::llo::ElementwiseKind::Atan => "atan(a[idx])",
411        crate::llo::ElementwiseKind::Relu => "max(a[idx], 0.0)",
412        crate::llo::ElementwiseKind::LeakyRelu => "select(0.01 * a[idx], a[idx], a[idx] > 0.0)",
413        crate::llo::ElementwiseKind::Sigmoid => "1.0 / (1.0 + exp(-a[idx]))",
414        crate::llo::ElementwiseKind::Tanh => "tanh(a[idx])",
415        crate::llo::ElementwiseKind::Softplus => "log(1.0 + exp(a[idx]))",
416    };
417
418    let shader = format!(
419        r#"
420    struct Params {{ size: u32, }};
421@group(0) @binding(0) var<storage, read> a: array<f32>;
422@group(0) @binding(1) var<storage, read> b: array<f32>;
423@group(0) @binding(2) var<storage, read_write> out: array<f32>;
424@group(0) @binding(3) var<uniform> params: Params;
425
426@compute @workgroup_size(64)
427fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
428    let idx: u32 = gid.x;
429    if (idx >= params.size) {{ return; }}
430    out[idx] = {op};
431}}
432"#
433    );
434
435    // Create instance, adapter, device and queue
436    let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
437        backends: wgpu::Backends::all(),
438        ..Default::default()
439    });
440    let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
441        power_preference: wgpu::PowerPreference::HighPerformance,
442        compatible_surface: None,
443        force_fallback_adapter: false,
444    }))
445    .ok_or_else(|| anyhow!("no WebGPU adapter available"))?;
446
447    let (device, queue) = pollster::block_on(adapter.request_device(
448        &wgpu::DeviceDescriptor {
449            label: None,
450            required_features: wgpu::Features::empty(),
451            required_limits: wgpu::Limits::default(),
452            memory_hints: Default::default(),
453        },
454        None,
455    ))
456    .map_err(|e| anyhow!("request device failed: {:?}", e))?;
457
458    // Create buffers
459    let a_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
460        label: Some("a_buf"),
461        contents: bytemuck::cast_slice(&a.data),
462        usage: wgpu::BufferUsages::STORAGE,
463    });
464
465    let b_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
466        label: Some("b_buf"),
467        contents: bytemuck::cast_slice(&b.data),
468        usage: wgpu::BufferUsages::STORAGE,
469    });
470
471    let out_buf = device.create_buffer(&wgpu::BufferDescriptor {
472        label: Some("out_buf"),
473        size: (len * std::mem::size_of::<f32>()) as u64,
474        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
475        mapped_at_creation: false,
476    });
477
478    let params = [len as u32];
479    let params_bytes = bytemuck::cast_slice(&params);
480    let params_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
481        label: Some("params"),
482        contents: params_bytes,
483        usage: wgpu::BufferUsages::UNIFORM,
484    });
485
486    // staging buffer for readback
487    let staging = device.create_buffer(&wgpu::BufferDescriptor {
488        label: Some("staging"),
489        size: (len * std::mem::size_of::<f32>()) as u64,
490        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
491        mapped_at_creation: false,
492    });
493
494    let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
495        label: Some("elementwise_shader"),
496        source: wgpu::ShaderSource::Wgsl(Cow::Owned(shader)),
497    });
498
499    // Bind group layout and pipeline
500    let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
501        label: Some("bgl"),
502        entries: &[
503            wgpu::BindGroupLayoutEntry {
504                binding: 0,
505                visibility: wgpu::ShaderStages::COMPUTE,
506                ty: wgpu::BindingType::Buffer {
507                    ty: wgpu::BufferBindingType::Storage { read_only: true },
508                    has_dynamic_offset: false,
509                    min_binding_size: None,
510                },
511                count: None,
512            },
513            wgpu::BindGroupLayoutEntry {
514                binding: 1,
515                visibility: wgpu::ShaderStages::COMPUTE,
516                ty: wgpu::BindingType::Buffer {
517                    ty: wgpu::BufferBindingType::Storage { read_only: true },
518                    has_dynamic_offset: false,
519                    min_binding_size: None,
520                },
521                count: None,
522            },
523            wgpu::BindGroupLayoutEntry {
524                binding: 2,
525                visibility: wgpu::ShaderStages::COMPUTE,
526                ty: wgpu::BindingType::Buffer {
527                    ty: wgpu::BufferBindingType::Storage { read_only: false },
528                    has_dynamic_offset: false,
529                    min_binding_size: None,
530                },
531                count: None,
532            },
533            wgpu::BindGroupLayoutEntry {
534                binding: 3,
535                visibility: wgpu::ShaderStages::COMPUTE,
536                ty: wgpu::BindingType::Buffer {
537                    ty: wgpu::BufferBindingType::Uniform,
538                    has_dynamic_offset: false,
539                    min_binding_size: None,
540                },
541                count: None,
542            },
543        ],
544    });
545
546    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
547        label: Some("pl"),
548        bind_group_layouts: &[&bgl],
549        push_constant_ranges: &[],
550    });
551
552    let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
553        label: Some("pipeline"),
554        layout: Some(&pipeline_layout),
555        module: &shader_module,
556        entry_point: Some("main"),
557        cache: None,
558        compilation_options: Default::default(),
559    });
560
561    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
562        label: Some("bg"),
563        layout: &bgl,
564        entries: &[
565            wgpu::BindGroupEntry {
566                binding: 0,
567                resource: a_buf.as_entire_binding(),
568            },
569            wgpu::BindGroupEntry {
570                binding: 1,
571                resource: b_buf.as_entire_binding(),
572            },
573            wgpu::BindGroupEntry {
574                binding: 2,
575                resource: out_buf.as_entire_binding(),
576            },
577            wgpu::BindGroupEntry {
578                binding: 3,
579                resource: params_buf.as_entire_binding(),
580            },
581        ],
582    });
583
584    let mut encoder =
585        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("ce") });
586
587    {
588        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
589            label: Some("cp"),
590            timestamp_writes: None,
591        });
592        cpass.set_pipeline(&compute_pipeline);
593        cpass.set_bind_group(0, &bind_group, &[]);
594
595        let workgroups = ((len as u32) + 63) / 64;
596        cpass.dispatch_workgroups(workgroups, 1, 1);
597    }
598
599    // copy to staging
600    encoder.copy_buffer_to_buffer(
601        &out_buf,
602        0,
603        &staging,
604        0,
605        (len * std::mem::size_of::<f32>()) as u64,
606    );
607
608    queue.submit(Some(encoder.finish()));
609
610    // map staging and read back
611    let buffer_slice = staging.slice(..);
612    // map_async requires a callback for non-async contexts — we'll use a
613    // synchronous channel to wait for completion.
614    use std::sync::mpsc::channel;
615    let (tx, rx) = channel();
616    buffer_slice.map_async(wgpu::MapMode::Read, move |r| {
617        let _ = tx.send(r);
618    });
619    device.poll(wgpu::Maintain::Wait);
620    let ok = rx
621        .recv()
622        .map_err(|_| anyhow!("map callback channel error"))?;
623    ok.map_err(|e| anyhow!("map async failed: {:?}", e))?;
624
625    let data = buffer_slice.get_mapped_range();
626    let mut out_vec = Vec::with_capacity(len);
627    // convert bytes -> f32
628    for chunk in data.chunks_exact(4) {
629        let b = [chunk[0], chunk[1], chunk[2], chunk[3]];
630        out_vec.push(f32::from_bits(u32::from_le_bytes(b)));
631    }
632
633    drop(data);
634    staging.unmap();
635
636    Ok(Array::new(a.shape.clone(), out_vec))
637}
638
639// Fast path: original optimized tiled matmul that uses monolithic buffers and a
640// WGSL kernel with workgroup/shared memory. This is high-performance but requires
641// that the device supports buffers large enough for the full A/B/output tiles.
642#[cfg(not(target_arch = "wasm32"))]
643fn run_matmul_gpu_fast(a: &Array, b: &Array) -> Result<Array> {
644    use wgpu::util::DeviceExt;
645    let m = a.shape[0] as u32;
646    let k = a.shape[1] as u32;
647    let n = b.shape[1] as u32;
648    let len = (m as usize) * (n as usize);
649
650    // Optimized tiled WGSL matmul using workgroup memory.
651    // TILE = 32x32 (increased from 16x16 for better occupancy)
652    // Each thread computes a 4x4 sub-block using vec4 operations for maximum throughput
653    // Total: 8x8 threads = 64 threads per workgroup processing 32x32 outputs
654    let shader = format!(
655        r#"
656struct Params {{ m: u32, n: u32, k: u32, }};
657@group(0) @binding(0) var<storage, read> a: array<f32>;
658@group(0) @binding(1) var<storage, read> b: array<f32>;
659@group(0) @binding(2) var<storage, read_write> out: array<f32>;
660@group(0) @binding(3) var<uniform> params: Params;
661
662const TILE: u32 = 32u;
663var<workgroup> tileA: array<f32, 1024>;  // 32x32 tile
664var<workgroup> tileB: array<f32, 1024>;  // 32x32 tile
665
666@compute @workgroup_size(8, 8)
667fn main(@builtin(global_invocation_id) _gid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>, @builtin(workgroup_id) wid: vec3<u32>) {{
668    let row_base: u32 = wid.y * TILE;
669    let col_base: u32 = wid.x * TILE;
670
671    // Each thread computes a 4x4 block starting at these local indices
672    let local_r0: u32 = lid.y * 4u;
673    let local_c0: u32 = lid.x * 4u;
674
675    // Accumulator registers for 4x4 output block (16 values)
676    var sum00: vec4<f32> = vec4<f32>(0.0, 0.0, 0.0, 0.0);
677    var sum10: vec4<f32> = vec4<f32>(0.0, 0.0, 0.0, 0.0);
678    var sum20: vec4<f32> = vec4<f32>(0.0, 0.0, 0.0, 0.0);
679    var sum30: vec4<f32> = vec4<f32>(0.0, 0.0, 0.0, 0.0);
680
681    var k0: u32 = 0u;
682    loop {{
683        if (k0 >= params.k) {{ break; }}
684
685        // Cooperative loading: each thread loads 4x4 elements into shared memory
686        // This ensures coalesced memory access and full utilization
687        for (var dy: u32 = 0u; dy < 4u; dy = dy + 1u) {{
688            for (var dx: u32 = 0u; dx < 4u; dx = dx + 1u) {{
689                let lr = local_r0 + dy;
690                let lc = local_c0 + dx;
691                
692                // Load A tile
693                let a_row = row_base + lr;
694                let a_col = k0 + lc;
695                if (a_row < params.m && a_col < params.k) {{
696                    tileA[lr * TILE + lc] = a[a_row * params.k + a_col];
697                }} else {{
698                    tileA[lr * TILE + lc] = 0.0;
699                }}
700
701                // Load B tile
702                let b_row = k0 + lr;
703                let b_col = col_base + lc;
704                if (b_row < params.k && b_col < params.n) {{
705                    tileB[lr * TILE + lc] = b[b_row * params.n + b_col];
706                }} else {{
707                    tileB[lr * TILE + lc] = 0.0;
708                }}
709            }}
710        }}
711
712        workgroupBarrier();
713
714        // Inner loop: matrix multiply within tile using vec4 operations
715        // Process 4 k-elements at a time for better arithmetic throughput
716        var t: u32 = 0u;
717        loop {{
718            if (t + 4u > TILE) {{ break; }}
719
720            // Load A vectors for 4 output rows
721            let a0 = vec4<f32>(
722                tileA[local_r0 * TILE + t],
723                tileA[local_r0 * TILE + t + 1u],
724                tileA[local_r0 * TILE + t + 2u],
725                tileA[local_r0 * TILE + t + 3u]
726            );
727            let a1 = vec4<f32>(
728                tileA[(local_r0 + 1u) * TILE + t],
729                tileA[(local_r0 + 1u) * TILE + t + 1u],
730                tileA[(local_r0 + 1u) * TILE + t + 2u],
731                tileA[(local_r0 + 1u) * TILE + t + 3u]
732            );
733            let a2 = vec4<f32>(
734                tileA[(local_r0 + 2u) * TILE + t],
735                tileA[(local_r0 + 2u) * TILE + t + 1u],
736                tileA[(local_r0 + 2u) * TILE + t + 2u],
737                tileA[(local_r0 + 2u) * TILE + t + 3u]
738            );
739            let a3 = vec4<f32>(
740                tileA[(local_r0 + 3u) * TILE + t],
741                tileA[(local_r0 + 3u) * TILE + t + 1u],
742                tileA[(local_r0 + 3u) * TILE + t + 2u],
743                tileA[(local_r0 + 3u) * TILE + t + 3u]
744            );
745
746            // Load B vectors for 4 output columns (transposed access pattern)
747            let b0 = vec4<f32>(
748                tileB[t * TILE + local_c0],
749                tileB[(t + 1u) * TILE + local_c0],
750                tileB[(t + 2u) * TILE + local_c0],
751                tileB[(t + 3u) * TILE + local_c0]
752            );
753            let b1 = vec4<f32>(
754                tileB[t * TILE + local_c0 + 1u],
755                tileB[(t + 1u) * TILE + local_c0 + 1u],
756                tileB[(t + 2u) * TILE + local_c0 + 1u],
757                tileB[(t + 3u) * TILE + local_c0 + 1u]
758            );
759            let b2 = vec4<f32>(
760                tileB[t * TILE + local_c0 + 2u],
761                tileB[(t + 1u) * TILE + local_c0 + 2u],
762                tileB[(t + 2u) * TILE + local_c0 + 2u],
763                tileB[(t + 3u) * TILE + local_c0 + 2u]
764            );
765            let b3 = vec4<f32>(
766                tileB[t * TILE + local_c0 + 3u],
767                tileB[(t + 1u) * TILE + local_c0 + 3u],
768                tileB[(t + 2u) * TILE + local_c0 + 3u],
769                tileB[(t + 3u) * TILE + local_c0 + 3u]
770            );
771
772            // Compute 4x4 block using dot products (16 FMA operations)
773            sum00 = sum00 + vec4<f32>(dot(a0, b0), dot(a0, b1), dot(a0, b2), dot(a0, b3));
774            sum10 = sum10 + vec4<f32>(dot(a1, b0), dot(a1, b1), dot(a1, b2), dot(a1, b3));
775            sum20 = sum20 + vec4<f32>(dot(a2, b0), dot(a2, b1), dot(a2, b2), dot(a2, b3));
776            sum30 = sum30 + vec4<f32>(dot(a3, b0), dot(a3, b1), dot(a3, b2), dot(a3, b3));
777
778            t = t + 4u;
779        }}
780
781        workgroupBarrier();
782        k0 = k0 + TILE;
783    }}
784
785    // Write 4x4 block results back to output (with bounds checking)
786    for (var row_off: u32 = 0u; row_off < 4u; row_off = row_off + 1u) {{
787        let global_row = row_base + local_r0 + row_off;
788        if (global_row >= params.m) {{ continue; }}
789        
790        let result_vec = select(
791            sum00,
792            select(sum10, select(sum20, sum30, row_off == 3u), row_off == 2u),
793            row_off == 1u
794        );
795        
796        for (var col_off: u32 = 0u; col_off < 4u; col_off = col_off + 1u) {{
797            let global_col = col_base + local_c0 + col_off;
798            if (global_col < params.n) {{
799                out[global_row * params.n + global_col] = result_vec[col_off];
800            }}
801        }}
802    }}
803}}
804"#
805    );
806
807    // Use cached GPU context (device/queue/pipeline/layout) to avoid setup overhead
808    let ctx = get_gpu_context(&shader)?;
809
810    let device = &ctx.device;
811    let queue = &ctx.queue;
812
813    // Prepare or reuse GPU buffers for this shape
814    let buf_mutex = BUFFERS.get_or_init(|| Mutex::new(None));
815    let mut guard = buf_mutex.lock().unwrap();
816
817    if guard
818        .as_ref()
819        .map(|c| c.m != m || c.n != n || c.k != k)
820        .unwrap_or(true)
821    {
822        // allocate fresh buffers sized for this matmul
823        let a_buf = device.create_buffer(&wgpu::BufferDescriptor {
824            label: Some("a_buf"),
825            size: ((m as usize * k as usize) * std::mem::size_of::<f32>()) as u64,
826            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
827            mapped_at_creation: false,
828        });
829
830        let b_buf = device.create_buffer(&wgpu::BufferDescriptor {
831            label: Some("b_buf"),
832            size: ((k as usize * n as usize) * std::mem::size_of::<f32>()) as u64,
833            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
834            mapped_at_creation: false,
835        });
836
837        let out_buf = device.create_buffer(&wgpu::BufferDescriptor {
838            label: Some("out_buf"),
839            size: (len * std::mem::size_of::<f32>()) as u64,
840            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
841            mapped_at_creation: false,
842        });
843
844        let params = [m, n, k];
845        let params_bytes = bytemuck::cast_slice(&params);
846        let params_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
847            label: Some("params"),
848            contents: params_bytes,
849            usage: wgpu::BufferUsages::UNIFORM,
850        });
851
852        let staging = device.create_buffer(&wgpu::BufferDescriptor {
853            label: Some("staging"),
854            size: (len * std::mem::size_of::<f32>()) as u64,
855            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
856            mapped_at_creation: false,
857        });
858
859        *guard = Some(CachedBuffers {
860            m,
861            n,
862            k,
863            _len: len,
864            a_buf,
865            b_buf,
866            out_buf,
867            params_buf,
868            staging,
869        });
870    }
871
872    let cached = guard.as_ref().unwrap();
873
874    // upload input matrices using queue.write_buffer (avoids buffer re-allocation)
875    // a.data and b.data are f32 slices
876    let a_bytes = bytemuck::cast_slice(&a.data);
877    let b_bytes = bytemuck::cast_slice(&b.data);
878    queue.write_buffer(&cached.a_buf, 0, a_bytes);
879    queue.write_buffer(&cached.b_buf, 0, b_bytes);
880
881    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
882        label: Some("bg_matmul"),
883        layout: &ctx.matmul_bgl,
884        entries: &[
885            wgpu::BindGroupEntry {
886                binding: 0,
887                resource: cached.a_buf.as_entire_binding(),
888            },
889            wgpu::BindGroupEntry {
890                binding: 1,
891                resource: cached.b_buf.as_entire_binding(),
892            },
893            wgpu::BindGroupEntry {
894                binding: 2,
895                resource: cached.out_buf.as_entire_binding(),
896            },
897            wgpu::BindGroupEntry {
898                binding: 3,
899                resource: cached.params_buf.as_entire_binding(),
900            },
901        ],
902    });
903
904    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
905        label: Some("ce_matmul"),
906    });
907
908    {
909        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
910            label: Some("cp_matmul"),
911            timestamp_writes: None,
912        });
913        cpass.set_pipeline(&ctx.matmul_pipeline);
914        cpass.set_bind_group(0, &bind_group, &[]);
915
916        // Updated for 32x32 tiles (increased from 16x16)
917        let wg_x = ((n + 31) / 32) as u32;
918        let wg_y = ((m + 31) / 32) as u32;
919        cpass.dispatch_workgroups(wg_x, wg_y, 1);
920    }
921
922    encoder.copy_buffer_to_buffer(
923        &cached.out_buf,
924        0,
925        &cached.staging,
926        0,
927        (len * std::mem::size_of::<f32>()) as u64,
928    );
929
930    queue.submit(Some(encoder.finish()));
931
932    // read back
933    let buffer_slice = cached.staging.slice(..);
934    use std::sync::mpsc::channel;
935    let (tx, rx) = channel();
936    buffer_slice.map_async(wgpu::MapMode::Read, move |r| {
937        let _ = tx.send(r);
938    });
939    device.poll(wgpu::Maintain::Wait);
940    let ok = rx
941        .recv()
942        .map_err(|_| anyhow!("map callback channel error"))?;
943    ok.map_err(|e| anyhow!("map async failed: {:?}", e))?;
944
945    let data = buffer_slice.get_mapped_range();
946    let mut out_vec = Vec::with_capacity(len);
947    for chunk in data.chunks_exact(4) {
948        let b = [chunk[0], chunk[1], chunk[2], chunk[3]];
949        out_vec.push(f32::from_bits(u32::from_le_bytes(b)));
950    }
951    drop(data);
952    cached.staging.unmap();
953
954    Ok(Array::new(vec![m as usize, n as usize], out_vec))
955}
956
957// ============================================================================
958// Public API for Dispatch System
959// ============================================================================
960
/// Elementwise operations on WebGPU (public API for dispatch).
///
/// Thin forwarding wrapper: all work is done by the internal
/// `run_elementwise_gpu`. `kind` selects which elementwise operation to
/// run, as declared by `crate::llo::ElementwiseKind`.
///
/// # Errors
/// Propagates any error from the underlying GPU execution (device setup,
/// shader/pipeline creation, or buffer readback).
pub fn elementwise_webgpu(
    a: &Array,
    b: &Array,
    kind: crate::llo::ElementwiseKind,
) -> Result<Array> {
    run_elementwise_gpu(a, b, kind)
}
969
/// Matrix multiplication on WebGPU (public API for dispatch).
///
/// Thin forwarding wrapper around the internal `run_matmul_gpu`.
///
/// # Panics
/// Panics if the GPU matmul fails (the internal `Result` is unwrapped with
/// `expect`). Callers that need recoverable errors have no fallible variant
/// of this entry point. NOTE(review): consider exposing a `Result`-returning
/// sibling rather than panicking inside the dispatch API.
pub fn matmul_webgpu(a: &Array, b: &Array) -> Array {
    run_matmul_gpu(a, b).expect("WebGPU matmul failed")
}
974
/// Reduction operations on WebGPU (public API for dispatch).
///
/// Thin forwarding wrapper around the internal `run_reduction_gpu`, which
/// currently implements only the full (all-elements) sum reduction.
///
/// # Errors
/// Returns an error when `axis` is `Some(_)` (axis-based reduction is not
/// implemented in the GPU prototype), or when GPU execution fails.
pub fn reduction_webgpu(a: &Array, axis: Option<usize>) -> Result<Array> {
    run_reduction_gpu(a, axis)
}
979
/// Broadcast operation on WebGPU (public API for dispatch).
///
/// Thin forwarding wrapper: expands `a` to `target_shape` via the internal
/// `run_broadcast_gpu` compute-shader implementation.
///
/// # Errors
/// Propagates any error from the underlying GPU execution.
pub fn broadcast_to_webgpu(a: &Array, target_shape: &[usize]) -> Result<Array> {
    run_broadcast_gpu(a, target_shape)
}
984
985// ============================================================================
986// Internal Implementation
987// ============================================================================
988
// Streaming path: for very large matrices that would require buffers exceeding the
// device limits, process tiles on the host and stream smaller A/B tiles to GPU.
//
// Assumes row-major 2-D inputs with a.shape = [m, k] and b.shape = [k, n]
// (the host-side tile extraction below indexes `a.data` / `b.data` that way).
// Each output tile is accumulated on the GPU across k-slices — the shader
// does `out[idx] = out[idx] + sum` — so the tile's result buffer must start
// zeroed (see the `zeros` upload below).
fn run_matmul_gpu_streaming(a: &Array, b: &Array) -> Result<Array> {
    use std::cmp::min;
    use wgpu::util::DeviceExt;

    let m = a.shape[0] as usize;
    let k = a.shape[1] as usize;
    let n = b.shape[1] as usize;

    // simple per-tile kernel: computes a tm x tn partial for given tk and
    // accumulates it into the read_write `out` buffer, so repeated dispatches
    // over successive k-slices sum into the same output tile
    let tile_shader = r#"
struct Params { tm: u32, tn: u32, tk: u32 };
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(16,16)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let row = gid.y;
    let col = gid.x;
    if (row >= params.tm || col >= params.tn) { return; }
    var sum: f32 = 0.0;
    var kk: u32 = 0u;
    loop {
        if (kk >= params.tk) { break; }
        sum = sum + a[row * params.tk + kk] * b[kk * params.tn + col];
        kk = kk + 1u;
    }
    let idx = row * params.tn + col;
    out[idx] = out[idx] + sum;
}
"#;

    with_gpu_device!(dq, {
        let device = &dq.device;
        let queue = &dq.queue;

        // Convert the device's storage-buffer binding limit into a budget of
        // f32 elements; tile sizes below are chosen against this budget.
        let max_buf_bytes = device.limits().max_storage_buffer_binding_size as usize;
        let max_elems = max_buf_bytes / std::mem::size_of::<f32>();

        // Tile-size heuristic: prefer 1024 per dimension. The k-tile is
        // shrunk so a max(m, n) x tile_k slab fits the element budget; the
        // m-tile is floored at 64 (with the same budget check against k).
        // NOTE(review): tile_m * tile_k is not itself checked against
        // max_elems, so pathological shapes could still exceed the limit —
        // verify against the device limits in practice.
        let prefer_tile = 1024usize;
        let tile_k = std::cmp::min(
            k,
            std::cmp::min(
                prefer_tile,
                std::cmp::max(1, max_elems / std::cmp::max(1, std::cmp::max(m, n))),
            ),
        );
        let tile_m = std::cmp::min(
            m,
            std::cmp::max(
                64,
                std::cmp::min(
                    prefer_tile,
                    std::cmp::max(1, max_elems / std::cmp::max(1, k)),
                ),
            ),
        );
        // Square-ish output tiles: the column tile reuses the row tile size.
        let tile_n = tile_m;

        // create module/pipeline once (reused for every tile dispatch below)
        let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("tile_matmul"),
            source: wgpu::ShaderSource::Wgsl(Cow::Owned(tile_shader.to_string())),
        });
        // Layout matches the shader bindings: a (read), b (read),
        // out (read_write accumulator), params (uniform).
        let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("bgl_tile_matmul"),
            entries: &[
                wgpu::BindGroupLayoutEntry {
                    binding: 0,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 1,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 2,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: false },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 3,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
            ],
        });
        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("pl_tile_matmul"),
            bind_group_layouts: &[&bgl],
            push_constant_ranges: &[],
        });
        let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("pipeline_tile_matmul"),
            layout: Some(&pipeline_layout),
            module: &shader_module,
            entry_point: Some("main"),
            cache: None,
            compilation_options: Default::default(),
        });

        // Host-side result matrix, filled tile by tile as results come back.
        let mut result = vec![0.0f32; m * n];

        // Iterate over output tiles (rows i..i+tm, cols j..j+tn); for each,
        // accumulate partial products over k-slices, then read the tile back.
        for i in (0..m).step_by(tile_m) {
            let tm = min(tile_m, m - i);
            for j in (0..n).step_by(tile_n) {
                let tn = min(tile_n, n - j);

                let out_size = (tm * tn) * std::mem::size_of::<f32>();
                let out_buf = device.create_buffer(&wgpu::BufferDescriptor {
                    label: Some("out_tile"),
                    size: out_size as u64,
                    usage: wgpu::BufferUsages::STORAGE
                        | wgpu::BufferUsages::COPY_SRC
                        | wgpu::BufferUsages::COPY_DST,
                    mapped_at_creation: false,
                });
                // Zero the accumulator tile — the shader adds into it.
                let zeros = vec![0u8; out_size];
                queue.write_buffer(&out_buf, 0, &zeros);

                // March along the shared k dimension in slices of tile_k.
                let mut p = 0usize;
                while p < k {
                    let tk = min(tile_k, k - p);

                    // Gather the tm x tk slice of A (rows i.., cols p..)
                    // into a contiguous host buffer for upload.
                    let mut a_tile = Vec::with_capacity(tm * tk);
                    for ii in 0..tm {
                        let row = i + ii;
                        let src_off = row * k + p;
                        a_tile.extend_from_slice(&a.data[src_off..src_off + tk]);
                    }

                    // Gather the tk x tn slice of B (rows p.., cols j..).
                    let mut b_tile = Vec::with_capacity(tk * tn);
                    for kk in 0..tk {
                        let src_off = (p + kk) * n + j;
                        b_tile.extend_from_slice(&b.data[src_off..src_off + tn]);
                    }

                    let a_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                        label: Some("a_tile"),
                        contents: bytemuck::cast_slice(&a_tile),
                        usage: wgpu::BufferUsages::STORAGE,
                    });
                    let b_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                        label: Some("b_tile"),
                        contents: bytemuck::cast_slice(&b_tile),
                        usage: wgpu::BufferUsages::STORAGE,
                    });

                    // Per-dispatch tile dimensions for the shader's Params.
                    let params = [tm as u32, tn as u32, tk as u32];
                    let params_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                        label: Some("params_tile"),
                        contents: bytemuck::cast_slice(&params),
                        usage: wgpu::BufferUsages::UNIFORM,
                    });

                    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
                        label: Some("bg_tile"),
                        layout: &bgl,
                        entries: &[
                            wgpu::BindGroupEntry {
                                binding: 0,
                                resource: a_buf.as_entire_binding(),
                            },
                            wgpu::BindGroupEntry {
                                binding: 1,
                                resource: b_buf.as_entire_binding(),
                            },
                            wgpu::BindGroupEntry {
                                binding: 2,
                                resource: out_buf.as_entire_binding(),
                            },
                            wgpu::BindGroupEntry {
                                binding: 3,
                                resource: params_buf.as_entire_binding(),
                            },
                        ],
                    });

                    let mut encoder =
                        device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
                            label: Some("ce_tile"),
                        });
                    {
                        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                            label: Some("cp_tile"),
                            timestamp_writes: None,
                        });
                        cpass.set_pipeline(&compute_pipeline);
                        cpass.set_bind_group(0, &bind_group, &[]);
                        // Ceil-divide by the shader's 16x16 workgroup size.
                        let wg_x = ((tn as u32) + 15) / 16;
                        let wg_y = ((tm as u32) + 15) / 16;
                        cpass.dispatch_workgroups(wg_x, wg_y, 1);
                    }

                    // Wait for this slice before uploading the next one, so
                    // successive dispatches accumulate in order.
                    // NOTE(review): on wasm32, `Maintain::Wait` does not
                    // actually block — verify readback ordering there.
                    queue.submit(Some(encoder.finish()));
                    device.poll(wgpu::Maintain::Wait);

                    p += tk;
                }

                // Copy the finished tile into a mappable staging buffer and
                // read it back to the host.
                let staging = device.create_buffer(&wgpu::BufferDescriptor {
                    label: Some("staging_out_tile"),
                    size: out_size as u64,
                    usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
                    mapped_at_creation: false,
                });
                let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
                    label: Some("ce_copy_out"),
                });
                encoder.copy_buffer_to_buffer(&out_buf, 0, &staging, 0, out_size as u64);
                queue.submit(Some(encoder.finish()));

                // Synchronous map: the channel carries the map_async result
                // back once device.poll drives the callback.
                let buffer_slice = staging.slice(..);
                use std::sync::mpsc::channel;
                let (tx, rx) = channel();
                buffer_slice.map_async(wgpu::MapMode::Read, move |r| {
                    let _ = tx.send(r);
                });
                device.poll(wgpu::Maintain::Wait);
                let ok = rx
                    .recv()
                    .map_err(|_| anyhow!("map callback channel error"))?;
                ok.map_err(|e| anyhow!("map async failed: {:?}", e))?;
                let data = buffer_slice.get_mapped_range();

                // Scatter the contiguous tile rows back into the full
                // row-major result matrix (little-endian f32 decode).
                let mut idx = 0usize;
                for rr in 0..tm {
                    let dest_off = (i + rr) * n + j;
                    let row_bytes = &data[idx * 4..(idx + tn) * 4];
                    for cc in 0..tn {
                        let b0 = row_bytes[cc * 4..cc * 4 + 4].try_into().unwrap();
                        result[dest_off + cc] = f32::from_bits(u32::from_le_bytes(b0));
                    }
                    idx += tn;
                }
                // Unmap before the staging buffer is dropped/reused.
                drop(data);
                staging.unmap();
            }
        }

        Ok(Array::new(vec![m, n], result))
    })
}
1257
1258// Dispatcher: choose fast path when buffers fit, otherwise fall back to streaming.
1259fn run_matmul_gpu(a: &Array, b: &Array) -> Result<Array> {
1260    // En WASM: siempre usar streaming (evita complicaciones con static caches)
1261    #[cfg(target_arch = "wasm32")]
1262    {
1263        return run_matmul_gpu_streaming(a, b);
1264    }
1265
1266    // En native: usar fast path si los buffers caben
1267    #[cfg(not(target_arch = "wasm32"))]
1268    {
1269        with_gpu_device!(dq, {
1270            let device = &dq.device;
1271
1272            let m = a.shape[0] as usize;
1273            let k = a.shape[1] as usize;
1274            let n = b.shape[1] as usize;
1275
1276            let bytes_a = m * k * std::mem::size_of::<f32>();
1277            let bytes_b = k * n * std::mem::size_of::<f32>();
1278            let bytes_out = m * n * std::mem::size_of::<f32>();
1279
1280            let max = device.limits().max_storage_buffer_binding_size as usize;
1281
1282            if bytes_a <= max && bytes_b <= max && bytes_out <= max {
1283                run_matmul_gpu_fast(a, b)
1284            } else {
1285                run_matmul_gpu_streaming(a, b)
1286            }
1287        })
1288    }
1289}
1290
// Full sum reduction on the GPU: repeatedly reduces blocks of WG_SIZE
// elements into per-workgroup partials until a single scalar remains.
// Only `axis == None` (reduce-all) is supported; axis reductions error out.
fn run_reduction_gpu(a: &Array, axis: Option<usize>) -> Result<Array> {
    use wgpu::util::DeviceExt;

    if axis.is_some() {
        return Err(anyhow!(
            "axis-based reduction not implemented in GPU prototype"
        ));
    }

    let size = a.len() as u32;
    // Empty input: the sum of no elements is 0.0 (returned as a 1-element array).
    if size == 0 {
        return Ok(Array::new(vec![1], vec![0.0]));
    }

    // WGSL: each workgroup reduces up to WORKGROUP_SIZE elements into one partial
    const WG_SIZE: u32 = 256u32;

    // simple WGSL kernel: each invocation loads one element (if in range), does local reduction into workgroup memory, thread 0 writes partial
    // (classic tree reduction in shared memory: halving stride with barriers)
    let shader = format!(
        r#"
struct Params {{ size: u32, }};
@group(0) @binding(0) var<storage, read> data: array<f32>;
@group(0) @binding(1) var<storage, read_write> partials: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;

var<workgroup> sdata: array<f32, {wg}>;

@compute @workgroup_size({wg})
fn main(@builtin(global_invocation_id) gid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>, @builtin(workgroup_id) wid: vec3<u32>) {{
    let local = lid.x;
    let group = wid.x;
    let idx = group * {wg}u + local;
    var v: f32 = 0.0;
    if (idx < params.size) {{ v = data[idx]; }}
    sdata[local] = v;
    workgroupBarrier();

    var stride: u32 = {wg}u / 2u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (local < stride) {{
            sdata[local] = sdata[local] + sdata[local + stride];
        }}
        workgroupBarrier();
        stride = stride / 2u;
    }}

    if (local == 0u) {{
        partials[group] = sdata[0];
    }}
}}
"#,
        wg = WG_SIZE
    );

    // get cached device/queue
    with_gpu_device!(dq, {
        let device = &dq.device;
        let queue = &dq.queue;

        // Native: cache the compiled pipeline in a static so repeated
        // reductions skip shader compilation. This is sound because the WGSL
        // above only depends on the compile-time WG_SIZE constant.
        // NOTE(review): get_or_init also caches an Err result permanently —
        // a transient init failure would poison all later calls; confirm
        // that's acceptable.
        #[cfg(not(target_arch = "wasm32"))]
        let pipe_res = REDUCTION_PIPELINE.get_or_init(
            || -> Result<(wgpu::ComputePipeline, wgpu::BindGroupLayout), anyhow::Error> {
                let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
                    label: Some("reduction_shader"),
                    source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Owned(shader.clone())),
                });

                // Bindings mirror the shader: data (read), partials
                // (read_write), params (uniform).
                let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                    label: Some("bgl_reduction"),
                    entries: &[
                        wgpu::BindGroupLayoutEntry {
                            binding: 0,
                            visibility: wgpu::ShaderStages::COMPUTE,
                            ty: wgpu::BindingType::Buffer {
                                ty: wgpu::BufferBindingType::Storage { read_only: true },
                                has_dynamic_offset: false,
                                min_binding_size: None,
                            },
                            count: None,
                        },
                        wgpu::BindGroupLayoutEntry {
                            binding: 1,
                            visibility: wgpu::ShaderStages::COMPUTE,
                            ty: wgpu::BindingType::Buffer {
                                ty: wgpu::BufferBindingType::Storage { read_only: false },
                                has_dynamic_offset: false,
                                min_binding_size: None,
                            },
                            count: None,
                        },
                        wgpu::BindGroupLayoutEntry {
                            binding: 2,
                            visibility: wgpu::ShaderStages::COMPUTE,
                            ty: wgpu::BindingType::Buffer {
                                ty: wgpu::BufferBindingType::Uniform,
                                has_dynamic_offset: false,
                                min_binding_size: None,
                            },
                            count: None,
                        },
                    ],
                });

                let pipeline_layout =
                    device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some("pl_reduction"),
                        bind_group_layouts: &[&bgl],
                        push_constant_ranges: &[],
                    });

                let compute_pipeline =
                    device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                        label: Some("pipeline_reduction"),
                        layout: Some(&pipeline_layout),
                        module: &shader_module,
                        entry_point: Some("main"),
                        cache: None,
                        compilation_options: Default::default(),
                    });

                Ok((compute_pipeline, bgl))
            },
        );

        #[cfg(not(target_arch = "wasm32"))]
        let (compute_pipeline, bgl) = pipe_res
            .as_ref()
            .map_err(|e| anyhow!("reduction pipeline init failed: {:?}", e))?;

        // WASM: build module/pipeline fresh on every call (no static cache;
        // see the top-of-file note about wasm32 lacking Sync-friendly statics).
        // NOTE(review): this branch duplicates the native construction above —
        // consider extracting a shared helper.
        #[cfg(target_arch = "wasm32")]
        let (compute_pipeline, bgl) = {
            let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("reduction_shader"),
                source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Owned(shader.clone())),
            });

            let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("bgl_reduction"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 2,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });

            let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: Some("pl_reduction"),
                bind_group_layouts: &[&bgl],
                push_constant_ranges: &[],
            });

            let compute_pipeline =
                device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                    label: Some("pipeline_reduction"),
                    layout: Some(&pipeline_layout),
                    module: &shader_module,
                    entry_point: Some("main"),
                    cache: None,
                    compilation_options: Default::default(),
                });

            (compute_pipeline, bgl)
        };

        // iterative GPU reduction: feed input buffer, run reduction producing partials, then repeat on partials until single value
        let mut current_size = size as u32;
        // Upload the source data once; later iterations reuse the previous
        // round's partials buffer as input.
        let mut in_buf = {
            let data_bytes = bytemuck::cast_slice(&a.data);
            let buf = device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("reduce_data"),
                size: data_bytes.len() as u64,
                usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            });
            queue.write_buffer(&buf, 0, data_bytes);
            buf
        };

        // temporary staging used only for final scalar readback
        let final_value: f32;

        loop {
            // Number of workgroups = ceil(current_size / WG_SIZE); each one
            // emits exactly one partial sum.
            let groups = ((current_size + WG_SIZE - 1) / WG_SIZE) as u32;

            let out_buf = device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("partials"),
                size: (groups as usize * std::mem::size_of::<f32>()) as u64,
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_SRC
                    | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            });

            // Per-iteration element count for the shader's bounds check.
            let params = [current_size];
            let params_bytes = bytemuck::cast_slice(&params);
            let params_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("params_reduction"),
                contents: params_bytes,
                usage: wgpu::BufferUsages::UNIFORM,
            });

            let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("bg_reduction"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: in_buf.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: out_buf.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: params_buf.as_entire_binding(),
                    },
                ],
            });

            // dispatch
            let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("ce_reduction_iter"),
            });
            {
                let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                    label: Some("cp_reduction_iter"),
                    timestamp_writes: None,
                });
                cpass.set_pipeline(&compute_pipeline);
                cpass.set_bind_group(0, &bind_group, &[]);
                cpass.dispatch_workgroups(groups, 1, 1);
            }

            // if this is final (groups == 1) copy to staging and read back
            if groups == 1 {
                let staging = device.create_buffer(&wgpu::BufferDescriptor {
                    label: Some("staging_final"),
                    size: std::mem::size_of::<f32>() as u64,
                    usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
                    mapped_at_creation: false,
                });
                encoder.copy_buffer_to_buffer(
                    &out_buf,
                    0,
                    &staging,
                    0,
                    std::mem::size_of::<f32>() as u64,
                );
                queue.submit(Some(encoder.finish()));

                // read back the single remaining f32 (synchronous map via
                // channel + poll). NOTE(review): on wasm32, `Maintain::Wait`
                // does not block — confirm the callback has fired there.
                let buffer_slice = staging.slice(..);
                use std::sync::mpsc::channel;
                let (tx, rx) = channel();
                buffer_slice.map_async(wgpu::MapMode::Read, move |r| {
                    let _ = tx.send(r);
                });
                device.poll(wgpu::Maintain::Wait);
                let ok = rx
                    .recv()
                    .map_err(|_| anyhow!("map callback channel error"))?;
                ok.map_err(|e| anyhow!("map async failed: {:?}", e))?;

                // Decode little-endian f32 from the 4 mapped bytes.
                let data = buffer_slice.get_mapped_range();
                let b = [data[0], data[1], data[2], data[3]];
                final_value = f32::from_bits(u32::from_le_bytes(b));
                drop(data);
                staging.unmap();
                break;
            } else {
                // prepare for next iteration: set in_buf = out_buf and continue
                queue.submit(Some(encoder.finish()));
                in_buf = out_buf;
                current_size = groups;
                // loop
            }
        }

        let total = final_value;
        Ok(Array::new(vec![1], vec![total]))
    })
}
1604
1605/// GPU implementation of broadcast_to using WebGPU compute shader
1606fn run_broadcast_gpu(a: &Array, target_shape: &[usize]) -> Result<Array> {
1607    use wgpu::util::DeviceExt;
1608
1609    let src_ndim = a.shape.len() as u32;
1610    let target_ndim = target_shape.len() as u32;
1611    let target_size: usize = target_shape.iter().product();
1612
1613    if target_size == 0 {
1614        return Ok(Array::new(target_shape.to_vec(), vec![]));
1615    }
1616
1617    // Preparar shapes y strides para el shader
1618    let mut src_shape_padded = vec![1u32; 4];
1619    let mut target_shape_padded = vec![1u32; 4];
1620    let mut src_strides = vec![0u32; 4];
1621
1622    // Copiar shapes (alineados a la derecha)
1623    for i in 0..src_ndim.min(4) as usize {
1624        src_shape_padded[4 - src_ndim as usize + i] = a.shape[i] as u32;
1625    }
1626    for i in 0..target_ndim.min(4) as usize {
1627        target_shape_padded[4 - target_ndim as usize + i] = target_shape[i] as u32;
1628    }
1629
1630    // Calcular strides para source
1631    let mut stride = 1u32;
1632    for i in (0..src_ndim as usize).rev() {
1633        let idx = 4 - src_ndim as usize + i;
1634        src_strides[idx] = stride;
1635        stride *= a.shape[i] as u32;
1636    }
1637
1638    // Shader de broadcast
1639    let shader_code = format!(
1640        r#"
1641struct Params {{
1642    src_shape: vec4<u32>,
1643    target_shape: vec4<u32>,
1644    src_strides: vec4<u32>,
1645    src_ndim: u32,
1646    target_ndim: u32,
1647}};
1648
1649@group(0) @binding(0) var<storage, read> src: array<f32>;
1650@group(0) @binding(1) var<storage, read_write> dst: array<f32>;
1651@group(0) @binding(2) var<uniform> params: Params;
1652
1653@compute @workgroup_size(256)
1654fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
1655    let idx = gid.x;
1656    if (idx >= {target_size}u) {{ return; }}
1657    
1658    // Convertir índice plano a multi-índice en target
1659    var target_idx: vec4<u32> = vec4<u32>(0u);
1660    var remaining = idx;
1661    
1662    for (var i = 0u; i < 4u; i++) {{
1663        let dim_idx = 3u - i;
1664        if (dim_idx >= (4u - params.target_ndim)) {{
1665            target_idx[dim_idx] = remaining % params.target_shape[dim_idx];
1666            remaining = remaining / params.target_shape[dim_idx];
1667        }}
1668    }}
1669    
1670    // Mapear a índice en source (con broadcasting)
1671    var src_flat_idx = 0u;
1672    let src_start_dim = 4u - params.src_ndim;
1673    let target_start_dim = 4u - params.target_ndim;
1674    
1675    for (var i = 0u; i < params.src_ndim; i++) {{
1676        let src_dim_idx = src_start_dim + i;
1677        let target_dim_idx = target_start_dim + i;
1678        let src_dim = params.src_shape[src_dim_idx];
1679        
1680        // Si la dimensión es 1, usar índice 0 (broadcasting)
1681        var idx_val: u32;
1682        if (src_dim == 1u) {{
1683            idx_val = 0u;
1684        }} else {{
1685            idx_val = target_idx[target_dim_idx];
1686        }}
1687        
1688        src_flat_idx += idx_val * params.src_strides[src_dim_idx];
1689    }}
1690    
1691    dst[idx] = src[src_flat_idx];
1692}}
1693"#,
1694        target_size = target_size
1695    );
1696
1697    // Crear instancia, adapter y device
1698    let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
1699        backends: wgpu::Backends::all(),
1700        ..Default::default()
1701    });
1702
1703    let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
1704        power_preference: wgpu::PowerPreference::HighPerformance,
1705        compatible_surface: None,
1706        force_fallback_adapter: false,
1707    }))
1708    .ok_or_else(|| anyhow!("no WebGPU adapter available"))?;
1709
1710    let (device, queue) = pollster::block_on(adapter.request_device(
1711        &wgpu::DeviceDescriptor {
1712            label: Some("broadcast device"),
1713            required_features: wgpu::Features::empty(),
1714            required_limits: wgpu::Limits::default(),
1715            memory_hints: Default::default(),
1716        },
1717        None,
1718    ))
1719    .map_err(|e| anyhow!("request device failed: {:?}", e))?;
1720
1721    // Crear buffers
1722    let src_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1723        label: Some("broadcast src"),
1724        contents: bytemuck::cast_slice(&a.data),
1725        usage: wgpu::BufferUsages::STORAGE,
1726    });
1727
1728    let dst_buffer = device.create_buffer(&wgpu::BufferDescriptor {
1729        label: Some("broadcast dst"),
1730        size: (target_size * 4) as u64,
1731        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
1732        mapped_at_creation: false,
1733    });
1734
1735    // Crear uniform buffer con parámetros
1736    let params_data = [
1737        src_shape_padded[0],
1738        src_shape_padded[1],
1739        src_shape_padded[2],
1740        src_shape_padded[3],
1741        target_shape_padded[0],
1742        target_shape_padded[1],
1743        target_shape_padded[2],
1744        target_shape_padded[3],
1745        src_strides[0],
1746        src_strides[1],
1747        src_strides[2],
1748        src_strides[3],
1749        src_ndim,
1750        target_ndim,
1751        0,
1752        0, // padding
1753    ];
1754
1755    let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1756        label: Some("broadcast params"),
1757        contents: bytemuck::cast_slice(&params_data),
1758        usage: wgpu::BufferUsages::UNIFORM,
1759    });
1760
1761    // Compilar shader
1762    let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1763        label: Some("broadcast shader"),
1764        source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(&shader_code)),
1765    });
1766
1767    // Crear pipeline
1768    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1769        label: Some("broadcast bgl"),
1770        entries: &[
1771            wgpu::BindGroupLayoutEntry {
1772                binding: 0,
1773                visibility: wgpu::ShaderStages::COMPUTE,
1774                ty: wgpu::BindingType::Buffer {
1775                    ty: wgpu::BufferBindingType::Storage { read_only: true },
1776                    has_dynamic_offset: false,
1777                    min_binding_size: None,
1778                },
1779                count: None,
1780            },
1781            wgpu::BindGroupLayoutEntry {
1782                binding: 1,
1783                visibility: wgpu::ShaderStages::COMPUTE,
1784                ty: wgpu::BindingType::Buffer {
1785                    ty: wgpu::BufferBindingType::Storage { read_only: false },
1786                    has_dynamic_offset: false,
1787                    min_binding_size: None,
1788                },
1789                count: None,
1790            },
1791            wgpu::BindGroupLayoutEntry {
1792                binding: 2,
1793                visibility: wgpu::ShaderStages::COMPUTE,
1794                ty: wgpu::BindingType::Buffer {
1795                    ty: wgpu::BufferBindingType::Uniform,
1796                    has_dynamic_offset: false,
1797                    min_binding_size: None,
1798                },
1799                count: None,
1800            },
1801        ],
1802    });
1803
1804    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
1805        label: Some("broadcast pipeline layout"),
1806        bind_group_layouts: &[&bind_group_layout],
1807        push_constant_ranges: &[],
1808    });
1809
1810    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1811        label: Some("broadcast pipeline"),
1812        layout: Some(&pipeline_layout),
1813        module: &shader_module,
1814        entry_point: Some("main"),
1815        cache: None,
1816        compilation_options: Default::default(),
1817    });
1818
1819    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1820        label: Some("broadcast bind group"),
1821        layout: &bind_group_layout,
1822        entries: &[
1823            wgpu::BindGroupEntry {
1824                binding: 0,
1825                resource: src_buffer.as_entire_binding(),
1826            },
1827            wgpu::BindGroupEntry {
1828                binding: 1,
1829                resource: dst_buffer.as_entire_binding(),
1830            },
1831            wgpu::BindGroupEntry {
1832                binding: 2,
1833                resource: params_buffer.as_entire_binding(),
1834            },
1835        ],
1836    });
1837
1838    // Ejecutar compute pass
1839    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
1840        label: Some("broadcast encoder"),
1841    });
1842
1843    {
1844        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1845            label: Some("broadcast pass"),
1846            timestamp_writes: None,
1847        });
1848        cpass.set_pipeline(&pipeline);
1849        cpass.set_bind_group(0, &bind_group, &[]);
1850        let workgroups = (target_size + 255) / 256;
1851        cpass.dispatch_workgroups(workgroups as u32, 1, 1);
1852    }
1853
1854    // Copiar resultado a staging buffer
1855    let staging = device.create_buffer(&wgpu::BufferDescriptor {
1856        label: Some("broadcast staging"),
1857        size: (target_size * 4) as u64,
1858        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
1859        mapped_at_creation: false,
1860    });
1861
1862    encoder.copy_buffer_to_buffer(&dst_buffer, 0, &staging, 0, (target_size * 4) as u64);
1863    queue.submit(Some(encoder.finish()));
1864
1865    // Leer resultado
1866    let buffer_slice = staging.slice(..);
1867    use std::sync::mpsc::channel;
1868    let (tx, rx) = channel();
1869    buffer_slice.map_async(wgpu::MapMode::Read, move |r| {
1870        let _ = tx.send(r);
1871    });
1872    device.poll(wgpu::Maintain::Wait);
1873    let ok = rx.recv().map_err(|_| anyhow!("map callback error"))?;
1874    ok.map_err(|e| anyhow!("map async failed: {:?}", e))?;
1875
1876    let data = buffer_slice.get_mapped_range();
1877    let result: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
1878    drop(data);
1879    staging.unmap();
1880
1881    Ok(Array::new(target_shape.to_vec(), result))
1882}