1pub mod batchnorm;
7pub mod codegen;
8pub mod conv;
9pub mod dropout;
10
11use crate::array::Array;
12use anyhow::{anyhow, Result};
13use std::borrow::Cow;
14
15#[cfg(not(target_arch = "wasm32"))]
17use once_cell::sync::OnceCell;
18#[cfg(not(target_arch = "wasm32"))]
19use std::sync::Mutex;
20
21#[cfg(target_arch = "wasm32")]
23use std::cell::RefCell;
24
25#[cfg(target_arch = "wasm32")]
27use std::sync::atomic::{AtomicBool, Ordering};
28
// Flag pushed in from the JS host (via `set_webgpu_available_wasm`) telling
// the wasm build whether the browser exposes a usable WebGPU implementation.
#[cfg(target_arch = "wasm32")]
static WEBGPU_AVAILABLE_FROM_JS: AtomicBool = AtomicBool::new(false);

/// Records whether WebGPU is available, as determined by the JS host.
/// Also logs the new flag value for debugging.
#[cfg(target_arch = "wasm32")]
pub fn set_webgpu_available_wasm(available: bool) {
    WEBGPU_AVAILABLE_FROM_JS.store(available, Ordering::SeqCst);
    eprintln!("[numrs-webgpu] WebGPU available flag set to: {}", available);
}

/// Reads the WebGPU-availability flag previously set by the JS host.
#[cfg(target_arch = "wasm32")]
pub fn get_webgpu_available_wasm() -> bool {
    WEBGPU_AVAILABLE_FROM_JS.load(Ordering::SeqCst)
}
42
/// Marker type for the WebGPU compute backend. The struct carries no state;
/// device/queue/pipeline caches live in module-level statics below.
#[derive(Debug, Clone)]
pub struct WebGpuBackend {}
45
46impl WebGpuBackend {
47 pub fn new() -> Self {
48 Self {}
49 }
50
51 pub fn is_available() -> bool {
53 #[cfg(target_arch = "wasm32")]
54 {
55 get_webgpu_available_wasm()
56 }
57 #[cfg(not(target_arch = "wasm32"))]
58 {
59 let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
60 backends: wgpu::Backends::all(),
61 ..Default::default()
62 });
63 let adapter =
64 pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
65 power_preference: wgpu::PowerPreference::HighPerformance,
66 compatible_surface: None,
67 force_fallback_adapter: false,
68 }));
69
70 adapter.is_some()
71 }
72 }
73}
74
// Cached device/queue together with the compiled matmul compute pipeline and
// its bind-group layout; built once by `get_gpu_context` on native targets.
struct GpuContext {
    device: wgpu::Device,
    queue: wgpu::Queue,
    matmul_pipeline: wgpu::ComputePipeline,
    matmul_bgl: wgpu::BindGroupLayout,
}
82
83use std::sync::Arc;
84
// Shareable device/queue pair. The `Arc`s let the wasm path hand out cheap
// clones from its thread-local cache (see the wasm `get_gpu_device`).
struct DeviceQueue {
    device: Arc<wgpu::Device>,
    queue: Arc<wgpu::Queue>,
}
89
/// Asynchronously initializes the WebGPU device/queue for wasm targets.
///
/// Idempotent: if a device is already cached in the thread-local
/// `GPU_DEVICE`, returns `Ok(())` immediately. On success the device/queue
/// pair is cached and the JS-visible availability flag is set to `true`.
///
/// # Errors
/// Fails when no adapter is available or device creation fails.
///
/// NOTE(review): a failed init is returned but NOT stored in `GPU_DEVICE`,
/// so `get_gpu_device` will report "not initialized" rather than "failed
/// previously" — confirm this retry-on-next-call behavior is intended.
#[cfg(target_arch = "wasm32")]
pub async fn init_webgpu_wasm() -> Result<()> {
    let already_init = GPU_DEVICE.with(|cell| cell.borrow().is_some());
    if already_init {
        return Ok(());
    }

    let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
        backends: wgpu::Backends::all(),
        ..Default::default()
    });

    // `PowerPreference::None` lets the browser pick whichever adapter it
    // considers appropriate.
    let adapter = instance
        .request_adapter(&wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::None,
            compatible_surface: None,
            force_fallback_adapter: false,
        })
        .await
        .ok_or_else(|| anyhow!("no WebGPU adapter available"))?;

    let (device, queue) = adapter
        .request_device(
            &wgpu::DeviceDescriptor {
                label: None,
                required_features: wgpu::Features::empty(),
                required_limits: wgpu::Limits::default(),
                memory_hints: Default::default(),
            },
            None,
        )
        .await
        .map_err(|e| anyhow!("request device failed: {:?}", e))?;

    // Cache the pair for later synchronous retrieval by `get_gpu_device`.
    GPU_DEVICE.with(|cell| {
        *cell.borrow_mut() = Some(Ok(DeviceQueue {
            device: Arc::new(device),
            queue: Arc::new(queue),
        }));
    });

    set_webgpu_available_wasm(true);
    Ok(())
}
139
/// Returns a clone of the cached wasm device/queue pair.
///
/// # Errors
/// Fails when a previous initialization failed, or when `init_webgpu_wasm`
/// has not been run yet for this thread.
#[cfg(target_arch = "wasm32")]
fn get_gpu_device() -> Result<DeviceQueue> {
    GPU_DEVICE.with(|cell| match &*cell.borrow() {
        // Hand out Arc clones so the borrow of the cell ends here.
        Some(Ok(dq)) => Ok(DeviceQueue {
            device: Arc::clone(&dq.device),
            queue: Arc::clone(&dq.queue),
        }),
        Some(Err(e)) => Err(anyhow!("WebGPU init failed previously: {:?}", e)),
        None => Err(anyhow!(
            "WebGPU/WebGL not initialized. Ensure init_webgpu() is called."
        )),
    })
}
159
// Native: process-wide lazily-initialized device/queue; filled (success or
// failure) by the native `get_gpu_device` below and never retried.
#[cfg(not(target_arch = "wasm32"))]
static GPU_DEVICE: OnceCell<Result<DeviceQueue, anyhow::Error>> = OnceCell::new();

// Wasm: thread-local cache populated by `init_webgpu_wasm`; `None` means
// initialization has not been attempted on this thread.
#[cfg(target_arch = "wasm32")]
thread_local! {
    static GPU_DEVICE: RefCell<Option<Result<DeviceQueue, anyhow::Error>>> = RefCell::new(None);
}
171
172#[cfg(not(target_arch = "wasm32"))]
173fn get_gpu_device() -> Result<&'static DeviceQueue> {
174 GPU_DEVICE.get_or_init(|| {
175 let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
176 backends: wgpu::Backends::all(),
177 ..Default::default()
178 });
179 let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
180 power_preference: wgpu::PowerPreference::HighPerformance,
181 compatible_surface: None,
182 force_fallback_adapter: false,
183 }))
184 .ok_or_else(|| anyhow!("no WebGPU adapter available"))?;
185
186 let (device, queue) = pollster::block_on(adapter.request_device(
187 &wgpu::DeviceDescriptor {
188 label: None,
189 required_features: wgpu::Features::empty(), required_limits: wgpu::Limits::default(), memory_hints: Default::default(),
192 },
193 None,
194 ))
195 .map_err(|e| anyhow!("request device failed: {:?}", e))?;
196
197 Ok(DeviceQueue {
198 device: Arc::new(device),
199 queue: Arc::new(queue),
200 })
201 });
202
203 let init_ref = GPU_DEVICE.get().expect("gpu device was just initialized");
204 match init_ref {
205 Ok(dq) => Ok(dq),
206 Err(e) => Err(anyhow!("gpu init failed: {:?}", e)),
207 }
208}
209
// Binds the platform's GPU device/queue as `$dq` and evaluates `$code`.
// The native and wasm expansions are textually identical, but they resolve
// to differently-typed bindings: `&'static DeviceQueue` on native targets
// vs an owned `DeviceQueue` (Arc clones) on wasm.
#[cfg(not(target_arch = "wasm32"))]
macro_rules! with_gpu_device {
    ($dq:ident, $code:expr) => {{
        let $dq = get_gpu_device()?;
        $code
    }};
}

#[cfg(target_arch = "wasm32")]
macro_rules! with_gpu_device {
    ($dq:ident, $code:expr) => {{
        let $dq = get_gpu_device()?;
        $code
    }};
}
227
// Cached reduction pipeline + bind-group layout (native targets).
// NOTE(review): not referenced in this part of the file — presumably used by
// `run_reduction_gpu` further down; verify before removing.
#[cfg(not(target_arch = "wasm32"))]
static REDUCTION_PIPELINE: OnceCell<
    Result<(wgpu::ComputePipeline, wgpu::BindGroupLayout), anyhow::Error>,
> = OnceCell::new();
233
234#[cfg(not(target_arch = "wasm32"))]
236fn get_gpu_context(shader_src: &str) -> Result<&'static GpuContext> {
237 static CTX: OnceCell<Result<GpuContext, anyhow::Error>> = OnceCell::new();
238
239 CTX.get_or_init(|| -> Result<GpuContext, anyhow::Error> {
240 eprintln!("NumRs-Core: Starting WebGPU Init...");
242
243 let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
244 backends: wgpu::Backends::all(),
245 ..Default::default()
246 });
247
248 eprintln!("NumRs-Core: Instance created. Requesting adapter...");
249
250 let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
251 power_preference: wgpu::PowerPreference::None,
252 compatible_surface: None,
253 force_fallback_adapter: false,
254 }));
255
256 if adapter.is_none() {
257 eprintln!("NumRs-Core: Adapter request returned None!");
258 return Err(anyhow!("no WebGPU adapter available"));
259 }
260 let adapter = adapter.unwrap();
261
262 eprintln!("NumRs-Core: Adapter found. Requesting device...");
263
264 let (device, queue) = pollster::block_on(adapter.request_device(
265 &wgpu::DeviceDescriptor {
266 label: None,
267 required_features: wgpu::Features::empty(),
268 required_limits: wgpu::Limits::default(),
269 memory_hints: Default::default(),
270 },
271 None,
272 ))
273 .map_err(|e| anyhow!("request device failed: {:?}", e))?;
274
275 let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
277 label: Some("matmul_shader"),
278 source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Owned(shader_src.to_string())),
279 });
280
281 let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
282 label: Some("bgl_matmul"),
283 entries: &[
284 wgpu::BindGroupLayoutEntry {
285 binding: 0,
286 visibility: wgpu::ShaderStages::COMPUTE,
287 ty: wgpu::BindingType::Buffer {
288 ty: wgpu::BufferBindingType::Storage { read_only: true },
289 has_dynamic_offset: false,
290 min_binding_size: None,
291 },
292 count: None,
293 },
294 wgpu::BindGroupLayoutEntry {
295 binding: 1,
296 visibility: wgpu::ShaderStages::COMPUTE,
297 ty: wgpu::BindingType::Buffer {
298 ty: wgpu::BufferBindingType::Storage { read_only: true },
299 has_dynamic_offset: false,
300 min_binding_size: None,
301 },
302 count: None,
303 },
304 wgpu::BindGroupLayoutEntry {
305 binding: 2,
306 visibility: wgpu::ShaderStages::COMPUTE,
307 ty: wgpu::BindingType::Buffer {
308 ty: wgpu::BufferBindingType::Storage { read_only: false },
309 has_dynamic_offset: false,
310 min_binding_size: None,
311 },
312 count: None,
313 },
314 wgpu::BindGroupLayoutEntry {
315 binding: 3,
316 visibility: wgpu::ShaderStages::COMPUTE,
317 ty: wgpu::BindingType::Buffer {
318 ty: wgpu::BufferBindingType::Uniform,
319 has_dynamic_offset: false,
320 min_binding_size: None,
321 },
322 count: None,
323 },
324 ],
325 });
326
327 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
328 label: Some("pl_matmul"),
329 bind_group_layouts: &[&bgl],
330 push_constant_ranges: &[],
331 });
332
333 let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
334 label: Some("pipeline_matmul"),
335 layout: Some(&pipeline_layout),
336 module: &shader_module,
337 entry_point: Some("main"),
338 cache: None,
339 compilation_options: Default::default(),
340 });
341
342 Ok(GpuContext {
343 device,
344 queue,
345 matmul_pipeline: compute_pipeline,
346 matmul_bgl: bgl,
347 })
348 });
349
350 let init_ref = CTX.get().expect("OnceCell was just initialized");
351 match init_ref {
352 Ok(ctx) => Ok(ctx),
353 Err(e) => Err(anyhow!("gpu init failed: {:?}", e)),
354 }
355}
356
// Per-shape cached GPU buffers for the fast matmul path; rebuilt whenever
// the (m, n, k) problem size changes (see `run_matmul_gpu_fast`).
#[cfg(not(target_arch = "wasm32"))]
struct CachedBuffers {
    m: u32,
    n: u32,
    k: u32,
    _len: usize,
    a_buf: wgpu::Buffer,
    b_buf: wgpu::Buffer,
    out_buf: wgpu::Buffer,
    params_buf: wgpu::Buffer,
    staging: wgpu::Buffer,
}

// Process-wide buffer cache; the Mutex serializes matmul calls that share it.
#[cfg(not(target_arch = "wasm32"))]
static BUFFERS: OnceCell<Mutex<Option<CachedBuffers>>> = OnceCell::new();
373
374#[cfg(not(target_arch = "wasm32"))]
376pub fn is_available_cached() -> bool {
377 static PROBE: OnceCell<bool> = OnceCell::new();
378 *PROBE.get_or_init(|| WebGpuBackend::is_available())
379}
380
/// Wasm availability check: reads the flag the JS host pushed in via
/// `set_webgpu_available_wasm` (same load `get_webgpu_available_wasm` does).
#[cfg(target_arch = "wasm32")]
pub fn is_available_cached() -> bool {
    WEBGPU_AVAILABLE_FROM_JS.load(Ordering::SeqCst)
}
387
/// Runs one elementwise operation (`kind`) over `a` — and `b` for binary
/// kinds — on the GPU, returning a new `Array` with `a`'s shape.
///
/// For unary kinds (`Sqrt`, `Neg`, `Exp`, ...) `b` is still uploaded and
/// bound at @binding(1) but never read by the generated shader.
///
/// NOTE(review): unlike the matmul path, this creates a fresh wgpu
/// instance/adapter/device on every call (expensive) and blocks via
/// `pollster::block_on` with no `cfg(not(wasm32))` gate — confirm the wasm
/// call path never reaches this function.
fn run_elementwise_gpu(a: &Array, b: &Array, kind: crate::llo::ElementwiseKind) -> Result<Array> {
    use wgpu::util::DeviceExt;

    let len = a.len();

    // WGSL expression that produces out[idx] for each op kind; it is
    // spliced into the shader template below.
    let op = match kind {
        crate::llo::ElementwiseKind::Add => "a[idx] + b[idx]",
        crate::llo::ElementwiseKind::Mul => "a[idx] * b[idx]",
        crate::llo::ElementwiseKind::Sub => "a[idx] - b[idx]",
        crate::llo::ElementwiseKind::Div => "a[idx] / b[idx]",
        crate::llo::ElementwiseKind::Sqrt => "sqrt(a[idx])",
        crate::llo::ElementwiseKind::Sin => "sin(a[idx])",
        crate::llo::ElementwiseKind::Cos => "cos(a[idx])",
        crate::llo::ElementwiseKind::Pow => "pow(a[idx], b[idx])",
        crate::llo::ElementwiseKind::Abs => "abs(a[idx])",
        crate::llo::ElementwiseKind::Neg => "-a[idx]",
        crate::llo::ElementwiseKind::Exp => "exp(a[idx])",
        crate::llo::ElementwiseKind::Log => "log(a[idx])",
        crate::llo::ElementwiseKind::Tan => "tan(a[idx])",
        crate::llo::ElementwiseKind::Asin => "asin(a[idx])",
        crate::llo::ElementwiseKind::Acos => "acos(a[idx])",
        crate::llo::ElementwiseKind::Atan => "atan(a[idx])",
        crate::llo::ElementwiseKind::Relu => "max(a[idx], 0.0)",
        crate::llo::ElementwiseKind::LeakyRelu => "select(0.01 * a[idx], a[idx], a[idx] > 0.0)",
        crate::llo::ElementwiseKind::Sigmoid => "1.0 / (1.0 + exp(-a[idx]))",
        crate::llo::ElementwiseKind::Tanh => "tanh(a[idx])",
        crate::llo::ElementwiseKind::Softplus => "log(1.0 + exp(a[idx]))",
    };

    // One thread per element; threads past `params.size` return early.
    let shader = format!(
        r#"
    struct Params {{ size: u32, }};
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let idx: u32 = gid.x;
    if (idx >= params.size) {{ return; }}
    out[idx] = {op};
}}
"#
    );

    let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
        backends: wgpu::Backends::all(),
        ..Default::default()
    });
    let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
        power_preference: wgpu::PowerPreference::HighPerformance,
        compatible_surface: None,
        force_fallback_adapter: false,
    }))
    .ok_or_else(|| anyhow!("no WebGPU adapter available"))?;

    let (device, queue) = pollster::block_on(adapter.request_device(
        &wgpu::DeviceDescriptor {
            label: None,
            required_features: wgpu::Features::empty(),
            required_limits: wgpu::Limits::default(),
            memory_hints: Default::default(),
        },
        None,
    ))
    .map_err(|e| anyhow!("request device failed: {:?}", e))?;

    // Read-only input buffers.
    let a_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("a_buf"),
        contents: bytemuck::cast_slice(&a.data),
        usage: wgpu::BufferUsages::STORAGE,
    });

    let b_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("b_buf"),
        contents: bytemuck::cast_slice(&b.data),
        usage: wgpu::BufferUsages::STORAGE,
    });

    // Output buffer; COPY_SRC so results can be copied into `staging`.
    let out_buf = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("out_buf"),
        size: (len * std::mem::size_of::<f32>()) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });

    let params = [len as u32];
    let params_bytes = bytemuck::cast_slice(&params);
    let params_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("params"),
        contents: params_bytes,
        usage: wgpu::BufferUsages::UNIFORM,
    });

    // CPU-mappable staging buffer for reading the result back.
    let staging = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("staging"),
        size: (len * std::mem::size_of::<f32>()) as u64,
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some("elementwise_shader"),
        source: wgpu::ShaderSource::Wgsl(Cow::Owned(shader)),
    });

    // Layout mirrors the shader's @binding declarations: 0/1 inputs,
    // 2 output, 3 uniform params.
    let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: Some("bgl"),
        entries: &[
            wgpu::BindGroupLayoutEntry {
                binding: 0,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            wgpu::BindGroupLayoutEntry {
                binding: 1,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            wgpu::BindGroupLayoutEntry {
                binding: 2,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: false },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            wgpu::BindGroupLayoutEntry {
                binding: 3,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Uniform,
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
        ],
    });

    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some("pl"),
        bind_group_layouts: &[&bgl],
        push_constant_ranges: &[],
    });

    let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some("pipeline"),
        layout: Some(&pipeline_layout),
        module: &shader_module,
        entry_point: Some("main"),
        cache: None,
        compilation_options: Default::default(),
    });

    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("bg"),
        layout: &bgl,
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: a_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: b_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: out_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: params_buf.as_entire_binding(),
            },
        ],
    });

    let mut encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("ce") });

    {
        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("cp"),
            timestamp_writes: None,
        });
        cpass.set_pipeline(&compute_pipeline);
        cpass.set_bind_group(0, &bind_group, &[]);

        // One 64-thread workgroup per 64 elements, rounded up.
        let workgroups = ((len as u32) + 63) / 64;
        cpass.dispatch_workgroups(workgroups, 1, 1);
    }

    encoder.copy_buffer_to_buffer(
        &out_buf,
        0,
        &staging,
        0,
        (len * std::mem::size_of::<f32>()) as u64,
    );

    queue.submit(Some(encoder.finish()));

    // Map the staging buffer and block until the GPU finishes; the channel
    // carries the map result out of the callback.
    let buffer_slice = staging.slice(..);
    use std::sync::mpsc::channel;
    let (tx, rx) = channel();
    buffer_slice.map_async(wgpu::MapMode::Read, move |r| {
        let _ = tx.send(r);
    });
    device.poll(wgpu::Maintain::Wait);
    let ok = rx
        .recv()
        .map_err(|_| anyhow!("map callback channel error"))?;
    ok.map_err(|e| anyhow!("map async failed: {:?}", e))?;

    // Decode little-endian f32s from the mapped bytes.
    let data = buffer_slice.get_mapped_range();
    let mut out_vec = Vec::with_capacity(len);
    for chunk in data.chunks_exact(4) {
        let b = [chunk[0], chunk[1], chunk[2], chunk[3]];
        out_vec.push(f32::from_bits(u32::from_le_bytes(b)));
    }

    drop(data);
    staging.unmap();

    Ok(Array::new(a.shape.clone(), out_vec))
}
638
/// Tiled GPU matmul for 2-D arrays: `a` is (m, k), `b` is (k, n); returns an
/// (m, n) `Array`.
///
/// Uses the shared `GpuContext` pipeline and the per-shape buffer cache
/// (`BUFFERS`), so repeated matmuls of the same shape only pay for data
/// upload, dispatch, and readback. The WGSL kernel computes one 32x32 output
/// tile per 8x8 workgroup, each thread producing a 4x4 block of results
/// through shared-memory tiles.
///
/// # Errors
/// Fails if GPU context initialization or the result readback fails.
#[cfg(not(target_arch = "wasm32"))]
fn run_matmul_gpu_fast(a: &Array, b: &Array) -> Result<Array> {
    use wgpu::util::DeviceExt;
    let m = a.shape[0] as u32;
    let k = a.shape[1] as u32;
    let n = b.shape[1] as u32;
    let len = (m as usize) * (n as usize);

    let shader = format!(
        r#"
struct Params {{ m: u32, n: u32, k: u32, }};
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;

const TILE: u32 = 32u;
var<workgroup> tileA: array<f32, 1024>; // 32x32 tile
var<workgroup> tileB: array<f32, 1024>; // 32x32 tile

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) _gid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>, @builtin(workgroup_id) wid: vec3<u32>) {{
    let row_base: u32 = wid.y * TILE;
    let col_base: u32 = wid.x * TILE;

    // Each thread computes a 4x4 block starting at these local indices
    let local_r0: u32 = lid.y * 4u;
    let local_c0: u32 = lid.x * 4u;

    // Accumulator registers for 4x4 output block (16 values)
    var sum00: vec4<f32> = vec4<f32>(0.0, 0.0, 0.0, 0.0);
    var sum10: vec4<f32> = vec4<f32>(0.0, 0.0, 0.0, 0.0);
    var sum20: vec4<f32> = vec4<f32>(0.0, 0.0, 0.0, 0.0);
    var sum30: vec4<f32> = vec4<f32>(0.0, 0.0, 0.0, 0.0);

    var k0: u32 = 0u;
    loop {{
        if (k0 >= params.k) {{ break; }}

        // Cooperative loading: each thread loads 4x4 elements into shared memory
        // This ensures coalesced memory access and full utilization
        for (var dy: u32 = 0u; dy < 4u; dy = dy + 1u) {{
            for (var dx: u32 = 0u; dx < 4u; dx = dx + 1u) {{
                let lr = local_r0 + dy;
                let lc = local_c0 + dx;

                // Load A tile
                let a_row = row_base + lr;
                let a_col = k0 + lc;
                if (a_row < params.m && a_col < params.k) {{
                    tileA[lr * TILE + lc] = a[a_row * params.k + a_col];
                }} else {{
                    tileA[lr * TILE + lc] = 0.0;
                }}

                // Load B tile
                let b_row = k0 + lr;
                let b_col = col_base + lc;
                if (b_row < params.k && b_col < params.n) {{
                    tileB[lr * TILE + lc] = b[b_row * params.n + b_col];
                }} else {{
                    tileB[lr * TILE + lc] = 0.0;
                }}
            }}
        }}

        workgroupBarrier();

        // Inner loop: matrix multiply within tile using vec4 operations
        // Process 4 k-elements at a time for better arithmetic throughput
        var t: u32 = 0u;
        loop {{
            if (t + 4u > TILE) {{ break; }}

            // Load A vectors for 4 output rows
            let a0 = vec4<f32>(
                tileA[local_r0 * TILE + t],
                tileA[local_r0 * TILE + t + 1u],
                tileA[local_r0 * TILE + t + 2u],
                tileA[local_r0 * TILE + t + 3u]
            );
            let a1 = vec4<f32>(
                tileA[(local_r0 + 1u) * TILE + t],
                tileA[(local_r0 + 1u) * TILE + t + 1u],
                tileA[(local_r0 + 1u) * TILE + t + 2u],
                tileA[(local_r0 + 1u) * TILE + t + 3u]
            );
            let a2 = vec4<f32>(
                tileA[(local_r0 + 2u) * TILE + t],
                tileA[(local_r0 + 2u) * TILE + t + 1u],
                tileA[(local_r0 + 2u) * TILE + t + 2u],
                tileA[(local_r0 + 2u) * TILE + t + 3u]
            );
            let a3 = vec4<f32>(
                tileA[(local_r0 + 3u) * TILE + t],
                tileA[(local_r0 + 3u) * TILE + t + 1u],
                tileA[(local_r0 + 3u) * TILE + t + 2u],
                tileA[(local_r0 + 3u) * TILE + t + 3u]
            );

            // Load B vectors for 4 output columns (transposed access pattern)
            let b0 = vec4<f32>(
                tileB[t * TILE + local_c0],
                tileB[(t + 1u) * TILE + local_c0],
                tileB[(t + 2u) * TILE + local_c0],
                tileB[(t + 3u) * TILE + local_c0]
            );
            let b1 = vec4<f32>(
                tileB[t * TILE + local_c0 + 1u],
                tileB[(t + 1u) * TILE + local_c0 + 1u],
                tileB[(t + 2u) * TILE + local_c0 + 1u],
                tileB[(t + 3u) * TILE + local_c0 + 1u]
            );
            let b2 = vec4<f32>(
                tileB[t * TILE + local_c0 + 2u],
                tileB[(t + 1u) * TILE + local_c0 + 2u],
                tileB[(t + 2u) * TILE + local_c0 + 2u],
                tileB[(t + 3u) * TILE + local_c0 + 2u]
            );
            let b3 = vec4<f32>(
                tileB[t * TILE + local_c0 + 3u],
                tileB[(t + 1u) * TILE + local_c0 + 3u],
                tileB[(t + 2u) * TILE + local_c0 + 3u],
                tileB[(t + 3u) * TILE + local_c0 + 3u]
            );

            // Compute 4x4 block using dot products (16 FMA operations)
            sum00 = sum00 + vec4<f32>(dot(a0, b0), dot(a0, b1), dot(a0, b2), dot(a0, b3));
            sum10 = sum10 + vec4<f32>(dot(a1, b0), dot(a1, b1), dot(a1, b2), dot(a1, b3));
            sum20 = sum20 + vec4<f32>(dot(a2, b0), dot(a2, b1), dot(a2, b2), dot(a2, b3));
            sum30 = sum30 + vec4<f32>(dot(a3, b0), dot(a3, b1), dot(a3, b2), dot(a3, b3));

            t = t + 4u;
        }}

        workgroupBarrier();
        k0 = k0 + TILE;
    }}

    // Write 4x4 block results back to output (with bounds checking)
    for (var row_off: u32 = 0u; row_off < 4u; row_off = row_off + 1u) {{
        let global_row = row_base + local_r0 + row_off;
        if (global_row >= params.m) {{ continue; }}

        let result_vec = select(
            sum00,
            select(sum10, select(sum20, sum30, row_off == 3u), row_off == 2u),
            row_off == 1u
        );

        for (var col_off: u32 = 0u; col_off < 4u; col_off = col_off + 1u) {{
            let global_col = col_base + local_c0 + col_off;
            if (global_col < params.n) {{
                out[global_row * params.n + global_col] = result_vec[col_off];
            }}
        }}
    }}
}}
"#
    );

    // NOTE(review): the shader text is constant, so the first-call-only
    // caching inside `get_gpu_context` is harmless here.
    let ctx = get_gpu_context(&shader)?;

    let device = &ctx.device;
    let queue = &ctx.queue;

    // Hold the buffer-cache lock for the whole call; this also serializes
    // concurrent matmuls over the shared buffers.
    let buf_mutex = BUFFERS.get_or_init(|| Mutex::new(None));
    let mut guard = buf_mutex.lock().unwrap();

    // (Re)allocate buffers when the problem shape changed (or on first use).
    if guard
        .as_ref()
        .map(|c| c.m != m || c.n != n || c.k != k)
        .unwrap_or(true)
    {
        let a_buf = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("a_buf"),
            size: ((m as usize * k as usize) * std::mem::size_of::<f32>()) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let b_buf = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("b_buf"),
            size: ((k as usize * n as usize) * std::mem::size_of::<f32>()) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let out_buf = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("out_buf"),
            size: (len * std::mem::size_of::<f32>()) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        let params = [m, n, k];
        let params_bytes = bytemuck::cast_slice(&params);
        let params_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("params"),
            contents: params_bytes,
            usage: wgpu::BufferUsages::UNIFORM,
        });

        let staging = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("staging"),
            size: (len * std::mem::size_of::<f32>()) as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        *guard = Some(CachedBuffers {
            m,
            n,
            k,
            _len: len,
            a_buf,
            b_buf,
            out_buf,
            params_buf,
            staging,
        });
    }

    let cached = guard.as_ref().unwrap();

    // Upload this call's operands into the cached buffers.
    let a_bytes = bytemuck::cast_slice(&a.data);
    let b_bytes = bytemuck::cast_slice(&b.data);
    queue.write_buffer(&cached.a_buf, 0, a_bytes);
    queue.write_buffer(&cached.b_buf, 0, b_bytes);

    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("bg_matmul"),
        layout: &ctx.matmul_bgl,
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: cached.a_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: cached.b_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: cached.out_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: cached.params_buf.as_entire_binding(),
            },
        ],
    });

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("ce_matmul"),
    });

    {
        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("cp_matmul"),
            timestamp_writes: None,
        });
        cpass.set_pipeline(&ctx.matmul_pipeline);
        cpass.set_bind_group(0, &bind_group, &[]);

        // One workgroup per 32x32 output tile, rounded up in each dimension.
        let wg_x = ((n + 31) / 32) as u32;
        let wg_y = ((m + 31) / 32) as u32;
        cpass.dispatch_workgroups(wg_x, wg_y, 1);
    }

    encoder.copy_buffer_to_buffer(
        &cached.out_buf,
        0,
        &cached.staging,
        0,
        (len * std::mem::size_of::<f32>()) as u64,
    );

    queue.submit(Some(encoder.finish()));

    // Map the staging buffer and block until the GPU finishes.
    let buffer_slice = cached.staging.slice(..);
    use std::sync::mpsc::channel;
    let (tx, rx) = channel();
    buffer_slice.map_async(wgpu::MapMode::Read, move |r| {
        let _ = tx.send(r);
    });
    device.poll(wgpu::Maintain::Wait);
    let ok = rx
        .recv()
        .map_err(|_| anyhow!("map callback channel error"))?;
    ok.map_err(|e| anyhow!("map async failed: {:?}", e))?;

    // Decode little-endian f32s from the mapped bytes.
    let data = buffer_slice.get_mapped_range();
    let mut out_vec = Vec::with_capacity(len);
    for chunk in data.chunks_exact(4) {
        let b = [chunk[0], chunk[1], chunk[2], chunk[3]];
        out_vec.push(f32::from_bits(u32::from_le_bytes(b)));
    }
    drop(data);
    cached.staging.unmap();

    Ok(Array::new(vec![m as usize, n as usize], out_vec))
}
956
957pub fn elementwise_webgpu(
963 a: &Array,
964 b: &Array,
965 kind: crate::llo::ElementwiseKind,
966) -> Result<Array> {
967 run_elementwise_gpu(a, b, kind)
968}
969
/// Public matmul entry point.
///
/// # Panics
/// Panics (via `expect`) if the GPU matmul fails — unlike the sibling
/// wrappers here, which return `Result`.
///
/// NOTE(review): `run_matmul_gpu` is not defined in this part of the file;
/// presumably it dispatches to `run_matmul_gpu_fast`/`_streaming` — confirm.
pub fn matmul_webgpu(a: &Array, b: &Array) -> Array {
    run_matmul_gpu(a, b).expect("WebGPU matmul failed")
}
974
/// Public reduction entry point; delegates to `run_reduction_gpu`
/// (defined later in this file).
pub fn reduction_webgpu(a: &Array, axis: Option<usize>) -> Result<Array> {
    run_reduction_gpu(a, axis)
}
979
/// Public broadcast entry point; delegates to `run_broadcast_gpu`
/// (defined later in this file).
pub fn broadcast_to_webgpu(a: &Array, target_shape: &[usize]) -> Result<Array> {
    run_broadcast_gpu(a, target_shape)
}
984
985fn run_matmul_gpu_streaming(a: &Array, b: &Array) -> Result<Array> {
992 use std::cmp::min;
993 use wgpu::util::DeviceExt;
994
995 let m = a.shape[0] as usize;
996 let k = a.shape[1] as usize;
997 let n = b.shape[1] as usize;
998
999 let tile_shader = r#"
1001struct Params { tm: u32, tn: u32, tk: u32 };
1002@group(0) @binding(0) var<storage, read> a: array<f32>;
1003@group(0) @binding(1) var<storage, read> b: array<f32>;
1004@group(0) @binding(2) var<storage, read_write> out: array<f32>;
1005@group(0) @binding(3) var<uniform> params: Params;
1006
1007@compute @workgroup_size(16,16)
1008fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
1009 let row = gid.y;
1010 let col = gid.x;
1011 if (row >= params.tm || col >= params.tn) { return; }
1012 var sum: f32 = 0.0;
1013 var kk: u32 = 0u;
1014 loop {
1015 if (kk >= params.tk) { break; }
1016 sum = sum + a[row * params.tk + kk] * b[kk * params.tn + col];
1017 kk = kk + 1u;
1018 }
1019 let idx = row * params.tn + col;
1020 out[idx] = out[idx] + sum;
1021}
1022"#;
1023
1024 with_gpu_device!(dq, {
1025 let device = &dq.device;
1026 let queue = &dq.queue;
1027
1028 let max_buf_bytes = device.limits().max_storage_buffer_binding_size as usize;
1029 let max_elems = max_buf_bytes / std::mem::size_of::<f32>();
1030
1031 let prefer_tile = 1024usize;
1032 let tile_k = std::cmp::min(
1033 k,
1034 std::cmp::min(
1035 prefer_tile,
1036 std::cmp::max(1, max_elems / std::cmp::max(1, std::cmp::max(m, n))),
1037 ),
1038 );
1039 let tile_m = std::cmp::min(
1040 m,
1041 std::cmp::max(
1042 64,
1043 std::cmp::min(
1044 prefer_tile,
1045 std::cmp::max(1, max_elems / std::cmp::max(1, k)),
1046 ),
1047 ),
1048 );
1049 let tile_n = tile_m;
1050
1051 let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1053 label: Some("tile_matmul"),
1054 source: wgpu::ShaderSource::Wgsl(Cow::Owned(tile_shader.to_string())),
1055 });
1056 let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1057 label: Some("bgl_tile_matmul"),
1058 entries: &[
1059 wgpu::BindGroupLayoutEntry {
1060 binding: 0,
1061 visibility: wgpu::ShaderStages::COMPUTE,
1062 ty: wgpu::BindingType::Buffer {
1063 ty: wgpu::BufferBindingType::Storage { read_only: true },
1064 has_dynamic_offset: false,
1065 min_binding_size: None,
1066 },
1067 count: None,
1068 },
1069 wgpu::BindGroupLayoutEntry {
1070 binding: 1,
1071 visibility: wgpu::ShaderStages::COMPUTE,
1072 ty: wgpu::BindingType::Buffer {
1073 ty: wgpu::BufferBindingType::Storage { read_only: true },
1074 has_dynamic_offset: false,
1075 min_binding_size: None,
1076 },
1077 count: None,
1078 },
1079 wgpu::BindGroupLayoutEntry {
1080 binding: 2,
1081 visibility: wgpu::ShaderStages::COMPUTE,
1082 ty: wgpu::BindingType::Buffer {
1083 ty: wgpu::BufferBindingType::Storage { read_only: false },
1084 has_dynamic_offset: false,
1085 min_binding_size: None,
1086 },
1087 count: None,
1088 },
1089 wgpu::BindGroupLayoutEntry {
1090 binding: 3,
1091 visibility: wgpu::ShaderStages::COMPUTE,
1092 ty: wgpu::BindingType::Buffer {
1093 ty: wgpu::BufferBindingType::Uniform,
1094 has_dynamic_offset: false,
1095 min_binding_size: None,
1096 },
1097 count: None,
1098 },
1099 ],
1100 });
1101 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
1102 label: Some("pl_tile_matmul"),
1103 bind_group_layouts: &[&bgl],
1104 push_constant_ranges: &[],
1105 });
1106 let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1107 label: Some("pipeline_tile_matmul"),
1108 layout: Some(&pipeline_layout),
1109 module: &shader_module,
1110 entry_point: Some("main"),
1111 cache: None,
1112 compilation_options: Default::default(),
1113 });
1114
1115 let mut result = vec![0.0f32; m * n];
1116
1117 for i in (0..m).step_by(tile_m) {
1118 let tm = min(tile_m, m - i);
1119 for j in (0..n).step_by(tile_n) {
1120 let tn = min(tile_n, n - j);
1121
1122 let out_size = (tm * tn) * std::mem::size_of::<f32>();
1123 let out_buf = device.create_buffer(&wgpu::BufferDescriptor {
1124 label: Some("out_tile"),
1125 size: out_size as u64,
1126 usage: wgpu::BufferUsages::STORAGE
1127 | wgpu::BufferUsages::COPY_SRC
1128 | wgpu::BufferUsages::COPY_DST,
1129 mapped_at_creation: false,
1130 });
1131 let zeros = vec![0u8; out_size];
1132 queue.write_buffer(&out_buf, 0, &zeros);
1133
1134 let mut p = 0usize;
1135 while p < k {
1136 let tk = min(tile_k, k - p);
1137
1138 let mut a_tile = Vec::with_capacity(tm * tk);
1139 for ii in 0..tm {
1140 let row = i + ii;
1141 let src_off = row * k + p;
1142 a_tile.extend_from_slice(&a.data[src_off..src_off + tk]);
1143 }
1144
1145 let mut b_tile = Vec::with_capacity(tk * tn);
1146 for kk in 0..tk {
1147 let src_off = (p + kk) * n + j;
1148 b_tile.extend_from_slice(&b.data[src_off..src_off + tn]);
1149 }
1150
1151 let a_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1152 label: Some("a_tile"),
1153 contents: bytemuck::cast_slice(&a_tile),
1154 usage: wgpu::BufferUsages::STORAGE,
1155 });
1156 let b_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1157 label: Some("b_tile"),
1158 contents: bytemuck::cast_slice(&b_tile),
1159 usage: wgpu::BufferUsages::STORAGE,
1160 });
1161
1162 let params = [tm as u32, tn as u32, tk as u32];
1163 let params_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1164 label: Some("params_tile"),
1165 contents: bytemuck::cast_slice(¶ms),
1166 usage: wgpu::BufferUsages::UNIFORM,
1167 });
1168
1169 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1170 label: Some("bg_tile"),
1171 layout: &bgl,
1172 entries: &[
1173 wgpu::BindGroupEntry {
1174 binding: 0,
1175 resource: a_buf.as_entire_binding(),
1176 },
1177 wgpu::BindGroupEntry {
1178 binding: 1,
1179 resource: b_buf.as_entire_binding(),
1180 },
1181 wgpu::BindGroupEntry {
1182 binding: 2,
1183 resource: out_buf.as_entire_binding(),
1184 },
1185 wgpu::BindGroupEntry {
1186 binding: 3,
1187 resource: params_buf.as_entire_binding(),
1188 },
1189 ],
1190 });
1191
1192 let mut encoder =
1193 device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
1194 label: Some("ce_tile"),
1195 });
1196 {
1197 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1198 label: Some("cp_tile"),
1199 timestamp_writes: None,
1200 });
1201 cpass.set_pipeline(&compute_pipeline);
1202 cpass.set_bind_group(0, &bind_group, &[]);
1203 let wg_x = ((tn as u32) + 15) / 16;
1204 let wg_y = ((tm as u32) + 15) / 16;
1205 cpass.dispatch_workgroups(wg_x, wg_y, 1);
1206 }
1207
1208 queue.submit(Some(encoder.finish()));
1209 device.poll(wgpu::Maintain::Wait);
1210
1211 p += tk;
1212 }
1213
1214 let staging = device.create_buffer(&wgpu::BufferDescriptor {
1215 label: Some("staging_out_tile"),
1216 size: out_size as u64,
1217 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
1218 mapped_at_creation: false,
1219 });
1220 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
1221 label: Some("ce_copy_out"),
1222 });
1223 encoder.copy_buffer_to_buffer(&out_buf, 0, &staging, 0, out_size as u64);
1224 queue.submit(Some(encoder.finish()));
1225
1226 let buffer_slice = staging.slice(..);
1227 use std::sync::mpsc::channel;
1228 let (tx, rx) = channel();
1229 buffer_slice.map_async(wgpu::MapMode::Read, move |r| {
1230 let _ = tx.send(r);
1231 });
1232 device.poll(wgpu::Maintain::Wait);
1233 let ok = rx
1234 .recv()
1235 .map_err(|_| anyhow!("map callback channel error"))?;
1236 ok.map_err(|e| anyhow!("map async failed: {:?}", e))?;
1237 let data = buffer_slice.get_mapped_range();
1238
1239 let mut idx = 0usize;
1240 for rr in 0..tm {
1241 let dest_off = (i + rr) * n + j;
1242 let row_bytes = &data[idx * 4..(idx + tn) * 4];
1243 for cc in 0..tn {
1244 let b0 = row_bytes[cc * 4..cc * 4 + 4].try_into().unwrap();
1245 result[dest_off + cc] = f32::from_bits(u32::from_le_bytes(b0));
1246 }
1247 idx += tn;
1248 }
1249 drop(data);
1250 staging.unmap();
1251 }
1252 }
1253
1254 Ok(Array::new(vec![m, n], result))
1255 })
1256}
1257
1258fn run_matmul_gpu(a: &Array, b: &Array) -> Result<Array> {
1260 #[cfg(target_arch = "wasm32")]
1262 {
1263 return run_matmul_gpu_streaming(a, b);
1264 }
1265
1266 #[cfg(not(target_arch = "wasm32"))]
1268 {
1269 with_gpu_device!(dq, {
1270 let device = &dq.device;
1271
1272 let m = a.shape[0] as usize;
1273 let k = a.shape[1] as usize;
1274 let n = b.shape[1] as usize;
1275
1276 let bytes_a = m * k * std::mem::size_of::<f32>();
1277 let bytes_b = k * n * std::mem::size_of::<f32>();
1278 let bytes_out = m * n * std::mem::size_of::<f32>();
1279
1280 let max = device.limits().max_storage_buffer_binding_size as usize;
1281
1282 if bytes_a <= max && bytes_b <= max && bytes_out <= max {
1283 run_matmul_gpu_fast(a, b)
1284 } else {
1285 run_matmul_gpu_streaming(a, b)
1286 }
1287 })
1288 }
1289}
1290
/// Sums all elements of `a` on the GPU using an iterative tree reduction.
///
/// Only a full (axis-less) sum is supported; `Some(axis)` returns an error.
/// Each compute pass reduces up to `WG_SIZE` elements per workgroup into a
/// single partial sum via workgroup shared memory; passes repeat (feeding the
/// partials buffer back in as input) until one value remains, which is copied
/// to a staging buffer, mapped, and returned as a 1-element `Array`.
fn run_reduction_gpu(a: &Array, axis: Option<usize>) -> Result<Array> {
    use wgpu::util::DeviceExt;

    // Per-axis reduction is not implemented; only the full sum is supported.
    if axis.is_some() {
        return Err(anyhow!(
            "axis-based reduction not implemented in GPU prototype"
        ));
    }

    let size = a.len() as u32;
    // Empty input: the sum of no elements is 0.0.
    if size == 0 {
        return Ok(Array::new(vec![1], vec![0.0]));
    }

    // Workgroup width; also the shared-memory array length in the shader.
    // Must be a power of two for the halving-stride loop below to be correct.
    const WG_SIZE: u32 = 256u32;

    // WGSL shader: classic shared-memory tree reduction. Each thread loads
    // one element (0.0 when out of range), then pairs are summed with a
    // halving stride until lane 0 holds the workgroup's partial sum.
    let shader = format!(
        r#"
struct Params {{ size: u32, }};
@group(0) @binding(0) var<storage, read> data: array<f32>;
@group(0) @binding(1) var<storage, read_write> partials: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;

var<workgroup> sdata: array<f32, {wg}>;

@compute @workgroup_size({wg})
fn main(@builtin(global_invocation_id) gid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>, @builtin(workgroup_id) wid: vec3<u32>) {{
    let local = lid.x;
    let group = wid.x;
    let idx = group * {wg}u + local;
    var v: f32 = 0.0;
    if (idx < params.size) {{ v = data[idx]; }}
    sdata[local] = v;
    workgroupBarrier();

    var stride: u32 = {wg}u / 2u;
    loop {{
        if (stride == 0u) {{ break; }}
        if (local < stride) {{
            sdata[local] = sdata[local] + sdata[local + stride];
        }}
        workgroupBarrier();
        stride = stride / 2u;
    }}

    if (local == 0u) {{
        partials[group] = sdata[0];
    }}
}}
"#,
        wg = WG_SIZE
    );

    with_gpu_device!(dq, {
        let device = &dq.device;
        let queue = &dq.queue;

        // Native: compile the pipeline once and cache it in the process-wide
        // REDUCTION_PIPELINE cell (declared elsewhere in this file).
        // NOTE(review): the cache assumes a single long-lived device; confirm
        // the cached pipeline is never used with a different device.
        #[cfg(not(target_arch = "wasm32"))]
        let pipe_res = REDUCTION_PIPELINE.get_or_init(
            || -> Result<(wgpu::ComputePipeline, wgpu::BindGroupLayout), anyhow::Error> {
                let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
                    label: Some("reduction_shader"),
                    source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Owned(shader.clone())),
                });

                // Bindings: 0 = input data (read-only), 1 = partial sums
                // (read-write), 2 = uniform with the current element count.
                let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                    label: Some("bgl_reduction"),
                    entries: &[
                        wgpu::BindGroupLayoutEntry {
                            binding: 0,
                            visibility: wgpu::ShaderStages::COMPUTE,
                            ty: wgpu::BindingType::Buffer {
                                ty: wgpu::BufferBindingType::Storage { read_only: true },
                                has_dynamic_offset: false,
                                min_binding_size: None,
                            },
                            count: None,
                        },
                        wgpu::BindGroupLayoutEntry {
                            binding: 1,
                            visibility: wgpu::ShaderStages::COMPUTE,
                            ty: wgpu::BindingType::Buffer {
                                ty: wgpu::BufferBindingType::Storage { read_only: false },
                                has_dynamic_offset: false,
                                min_binding_size: None,
                            },
                            count: None,
                        },
                        wgpu::BindGroupLayoutEntry {
                            binding: 2,
                            visibility: wgpu::ShaderStages::COMPUTE,
                            ty: wgpu::BindingType::Buffer {
                                ty: wgpu::BufferBindingType::Uniform,
                                has_dynamic_offset: false,
                                min_binding_size: None,
                            },
                            count: None,
                        },
                    ],
                });

                let pipeline_layout =
                    device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some("pl_reduction"),
                        bind_group_layouts: &[&bgl],
                        push_constant_ranges: &[],
                    });

                let compute_pipeline =
                    device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                        label: Some("pipeline_reduction"),
                        layout: Some(&pipeline_layout),
                        module: &shader_module,
                        entry_point: Some("main"),
                        cache: None,
                        compilation_options: Default::default(),
                    });

                Ok((compute_pipeline, bgl))
            },
        );

        #[cfg(not(target_arch = "wasm32"))]
        let (compute_pipeline, bgl) = pipe_res
            .as_ref()
            .map_err(|e| anyhow!("reduction pipeline init failed: {:?}", e))?;

        // wasm32: no OnceCell cache here, so the pipeline is rebuilt on every
        // call (same layout and shader as the native branch above).
        #[cfg(target_arch = "wasm32")]
        let (compute_pipeline, bgl) = {
            let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("reduction_shader"),
                source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Owned(shader.clone())),
            });

            let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("bgl_reduction"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 2,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });

            let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: Some("pl_reduction"),
                bind_group_layouts: &[&bgl],
                push_constant_ranges: &[],
            });

            let compute_pipeline =
                device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                    label: Some("pipeline_reduction"),
                    layout: Some(&pipeline_layout),
                    module: &shader_module,
                    entry_point: Some("main"),
                    cache: None,
                    compilation_options: Default::default(),
                });

            (compute_pipeline, bgl)
        };

        // Number of elements still to be reduced in the current pass.
        let mut current_size = size as u32;
        // First-pass input: the array data uploaded to a storage buffer.
        let mut in_buf = {
            let data_bytes = bytemuck::cast_slice(&a.data);
            let buf = device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("reduce_data"),
                size: data_bytes.len() as u64,
                usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            });
            queue.write_buffer(&buf, 0, data_bytes);
            buf
        };

        let final_value: f32;

        // Repeatedly reduce: each pass shrinks the element count by a factor
        // of WG_SIZE; the partials buffer of one pass becomes the input of
        // the next, until a single workgroup produces the final sum.
        loop {
            let groups = ((current_size + WG_SIZE - 1) / WG_SIZE) as u32;

            let out_buf = device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("partials"),
                size: (groups as usize * std::mem::size_of::<f32>()) as u64,
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_SRC
                    | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            });

            // Uniform carries the valid element count so out-of-range threads
            // contribute 0.0.
            let params = [current_size];
            let params_bytes = bytemuck::cast_slice(&params);
            let params_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("params_reduction"),
                contents: params_bytes,
                usage: wgpu::BufferUsages::UNIFORM,
            });

            let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("bg_reduction"),
                layout: &bgl,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: in_buf.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: out_buf.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: params_buf.as_entire_binding(),
                    },
                ],
            });

            let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("ce_reduction_iter"),
            });
            {
                let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                    label: Some("cp_reduction_iter"),
                    timestamp_writes: None,
                });
                cpass.set_pipeline(&compute_pipeline);
                cpass.set_bind_group(0, &bind_group, &[]);
                cpass.dispatch_workgroups(groups, 1, 1);
            }

            if groups == 1 {
                // Last pass: copy the single result to a mappable staging
                // buffer and read it back.
                let staging = device.create_buffer(&wgpu::BufferDescriptor {
                    label: Some("staging_final"),
                    size: std::mem::size_of::<f32>() as u64,
                    usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
                    mapped_at_creation: false,
                });
                encoder.copy_buffer_to_buffer(
                    &out_buf,
                    0,
                    &staging,
                    0,
                    std::mem::size_of::<f32>() as u64,
                );
                queue.submit(Some(encoder.finish()));

                // Blocking map: poll until the map callback fires.
                // NOTE(review): the mpsc + blocking recv pattern assumes
                // Maintain::Wait drives the callback synchronously; confirm
                // this holds on the wasm32 path as well.
                let buffer_slice = staging.slice(..);
                use std::sync::mpsc::channel;
                let (tx, rx) = channel();
                buffer_slice.map_async(wgpu::MapMode::Read, move |r| {
                    let _ = tx.send(r);
                });
                device.poll(wgpu::Maintain::Wait);
                let ok = rx
                    .recv()
                    .map_err(|_| anyhow!("map callback channel error"))?;
                ok.map_err(|e| anyhow!("map async failed: {:?}", e))?;

                // Reassemble the f32 from its four little-endian bytes.
                let data = buffer_slice.get_mapped_range();
                let b = [data[0], data[1], data[2], data[3]];
                final_value = f32::from_bits(u32::from_le_bytes(b));
                drop(data);
                staging.unmap();
                break;
            } else {
                // More than one partial remains: this pass's output becomes
                // the next pass's input.
                queue.submit(Some(encoder.finish()));
                in_buf = out_buf;
                current_size = groups;
            }
        }

        let total = final_value;
        Ok(Array::new(vec![1], vec![total]))
    })
}
1604
1605fn run_broadcast_gpu(a: &Array, target_shape: &[usize]) -> Result<Array> {
1607 use wgpu::util::DeviceExt;
1608
1609 let src_ndim = a.shape.len() as u32;
1610 let target_ndim = target_shape.len() as u32;
1611 let target_size: usize = target_shape.iter().product();
1612
1613 if target_size == 0 {
1614 return Ok(Array::new(target_shape.to_vec(), vec![]));
1615 }
1616
1617 let mut src_shape_padded = vec![1u32; 4];
1619 let mut target_shape_padded = vec![1u32; 4];
1620 let mut src_strides = vec![0u32; 4];
1621
1622 for i in 0..src_ndim.min(4) as usize {
1624 src_shape_padded[4 - src_ndim as usize + i] = a.shape[i] as u32;
1625 }
1626 for i in 0..target_ndim.min(4) as usize {
1627 target_shape_padded[4 - target_ndim as usize + i] = target_shape[i] as u32;
1628 }
1629
1630 let mut stride = 1u32;
1632 for i in (0..src_ndim as usize).rev() {
1633 let idx = 4 - src_ndim as usize + i;
1634 src_strides[idx] = stride;
1635 stride *= a.shape[i] as u32;
1636 }
1637
1638 let shader_code = format!(
1640 r#"
1641struct Params {{
1642 src_shape: vec4<u32>,
1643 target_shape: vec4<u32>,
1644 src_strides: vec4<u32>,
1645 src_ndim: u32,
1646 target_ndim: u32,
1647}};
1648
1649@group(0) @binding(0) var<storage, read> src: array<f32>;
1650@group(0) @binding(1) var<storage, read_write> dst: array<f32>;
1651@group(0) @binding(2) var<uniform> params: Params;
1652
1653@compute @workgroup_size(256)
1654fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{
1655 let idx = gid.x;
1656 if (idx >= {target_size}u) {{ return; }}
1657
1658 // Convertir índice plano a multi-índice en target
1659 var target_idx: vec4<u32> = vec4<u32>(0u);
1660 var remaining = idx;
1661
1662 for (var i = 0u; i < 4u; i++) {{
1663 let dim_idx = 3u - i;
1664 if (dim_idx >= (4u - params.target_ndim)) {{
1665 target_idx[dim_idx] = remaining % params.target_shape[dim_idx];
1666 remaining = remaining / params.target_shape[dim_idx];
1667 }}
1668 }}
1669
1670 // Mapear a índice en source (con broadcasting)
1671 var src_flat_idx = 0u;
1672 let src_start_dim = 4u - params.src_ndim;
1673 let target_start_dim = 4u - params.target_ndim;
1674
1675 for (var i = 0u; i < params.src_ndim; i++) {{
1676 let src_dim_idx = src_start_dim + i;
1677 let target_dim_idx = target_start_dim + i;
1678 let src_dim = params.src_shape[src_dim_idx];
1679
1680 // Si la dimensión es 1, usar índice 0 (broadcasting)
1681 var idx_val: u32;
1682 if (src_dim == 1u) {{
1683 idx_val = 0u;
1684 }} else {{
1685 idx_val = target_idx[target_dim_idx];
1686 }}
1687
1688 src_flat_idx += idx_val * params.src_strides[src_dim_idx];
1689 }}
1690
1691 dst[idx] = src[src_flat_idx];
1692}}
1693"#,
1694 target_size = target_size
1695 );
1696
1697 let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
1699 backends: wgpu::Backends::all(),
1700 ..Default::default()
1701 });
1702
1703 let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
1704 power_preference: wgpu::PowerPreference::HighPerformance,
1705 compatible_surface: None,
1706 force_fallback_adapter: false,
1707 }))
1708 .ok_or_else(|| anyhow!("no WebGPU adapter available"))?;
1709
1710 let (device, queue) = pollster::block_on(adapter.request_device(
1711 &wgpu::DeviceDescriptor {
1712 label: Some("broadcast device"),
1713 required_features: wgpu::Features::empty(),
1714 required_limits: wgpu::Limits::default(),
1715 memory_hints: Default::default(),
1716 },
1717 None,
1718 ))
1719 .map_err(|e| anyhow!("request device failed: {:?}", e))?;
1720
1721 let src_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1723 label: Some("broadcast src"),
1724 contents: bytemuck::cast_slice(&a.data),
1725 usage: wgpu::BufferUsages::STORAGE,
1726 });
1727
1728 let dst_buffer = device.create_buffer(&wgpu::BufferDescriptor {
1729 label: Some("broadcast dst"),
1730 size: (target_size * 4) as u64,
1731 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
1732 mapped_at_creation: false,
1733 });
1734
1735 let params_data = [
1737 src_shape_padded[0],
1738 src_shape_padded[1],
1739 src_shape_padded[2],
1740 src_shape_padded[3],
1741 target_shape_padded[0],
1742 target_shape_padded[1],
1743 target_shape_padded[2],
1744 target_shape_padded[3],
1745 src_strides[0],
1746 src_strides[1],
1747 src_strides[2],
1748 src_strides[3],
1749 src_ndim,
1750 target_ndim,
1751 0,
1752 0, ];
1754
1755 let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1756 label: Some("broadcast params"),
1757 contents: bytemuck::cast_slice(¶ms_data),
1758 usage: wgpu::BufferUsages::UNIFORM,
1759 });
1760
1761 let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
1763 label: Some("broadcast shader"),
1764 source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(&shader_code)),
1765 });
1766
1767 let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
1769 label: Some("broadcast bgl"),
1770 entries: &[
1771 wgpu::BindGroupLayoutEntry {
1772 binding: 0,
1773 visibility: wgpu::ShaderStages::COMPUTE,
1774 ty: wgpu::BindingType::Buffer {
1775 ty: wgpu::BufferBindingType::Storage { read_only: true },
1776 has_dynamic_offset: false,
1777 min_binding_size: None,
1778 },
1779 count: None,
1780 },
1781 wgpu::BindGroupLayoutEntry {
1782 binding: 1,
1783 visibility: wgpu::ShaderStages::COMPUTE,
1784 ty: wgpu::BindingType::Buffer {
1785 ty: wgpu::BufferBindingType::Storage { read_only: false },
1786 has_dynamic_offset: false,
1787 min_binding_size: None,
1788 },
1789 count: None,
1790 },
1791 wgpu::BindGroupLayoutEntry {
1792 binding: 2,
1793 visibility: wgpu::ShaderStages::COMPUTE,
1794 ty: wgpu::BindingType::Buffer {
1795 ty: wgpu::BufferBindingType::Uniform,
1796 has_dynamic_offset: false,
1797 min_binding_size: None,
1798 },
1799 count: None,
1800 },
1801 ],
1802 });
1803
1804 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
1805 label: Some("broadcast pipeline layout"),
1806 bind_group_layouts: &[&bind_group_layout],
1807 push_constant_ranges: &[],
1808 });
1809
1810 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
1811 label: Some("broadcast pipeline"),
1812 layout: Some(&pipeline_layout),
1813 module: &shader_module,
1814 entry_point: Some("main"),
1815 cache: None,
1816 compilation_options: Default::default(),
1817 });
1818
1819 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1820 label: Some("broadcast bind group"),
1821 layout: &bind_group_layout,
1822 entries: &[
1823 wgpu::BindGroupEntry {
1824 binding: 0,
1825 resource: src_buffer.as_entire_binding(),
1826 },
1827 wgpu::BindGroupEntry {
1828 binding: 1,
1829 resource: dst_buffer.as_entire_binding(),
1830 },
1831 wgpu::BindGroupEntry {
1832 binding: 2,
1833 resource: params_buffer.as_entire_binding(),
1834 },
1835 ],
1836 });
1837
1838 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
1840 label: Some("broadcast encoder"),
1841 });
1842
1843 {
1844 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1845 label: Some("broadcast pass"),
1846 timestamp_writes: None,
1847 });
1848 cpass.set_pipeline(&pipeline);
1849 cpass.set_bind_group(0, &bind_group, &[]);
1850 let workgroups = (target_size + 255) / 256;
1851 cpass.dispatch_workgroups(workgroups as u32, 1, 1);
1852 }
1853
1854 let staging = device.create_buffer(&wgpu::BufferDescriptor {
1856 label: Some("broadcast staging"),
1857 size: (target_size * 4) as u64,
1858 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
1859 mapped_at_creation: false,
1860 });
1861
1862 encoder.copy_buffer_to_buffer(&dst_buffer, 0, &staging, 0, (target_size * 4) as u64);
1863 queue.submit(Some(encoder.finish()));
1864
1865 let buffer_slice = staging.slice(..);
1867 use std::sync::mpsc::channel;
1868 let (tx, rx) = channel();
1869 buffer_slice.map_async(wgpu::MapMode::Read, move |r| {
1870 let _ = tx.send(r);
1871 });
1872 device.poll(wgpu::Maintain::Wait);
1873 let ok = rx.recv().map_err(|_| anyhow!("map callback error"))?;
1874 ok.map_err(|e| anyhow!("map async failed: {:?}", e))?;
1875
1876 let data = buffer_slice.get_mapped_range();
1877 let result: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
1878 drop(data);
1879 staging.unmap();
1880
1881 Ok(Array::new(target_shape.to_vec(), result))
1882}