1use super::backend_trait::{BackendCapabilities, BackendTrait, MemcpyKind};
9use async_trait::async_trait;
10use crate::{runtime_error, Result};
11use std::collections::HashMap;
12use std::sync::atomic::{AtomicUsize, Ordering};
13use std::sync::Mutex;
14
/// First value handed out as a device-memory handle by `allocate_memory`.
///
/// Handles are a counter starting here (0x1_0000), so a valid handle is never null
/// and never collides with small integer sentinels.
const HANDLE_BASE: usize = 1 << 16;
17
/// Byte granularity that GPU buffer sizes and buffer copy lengths are rounded up to.
const COPY_ALIGN: u64 = 4;

/// Rounds `size` up to the next multiple of [`COPY_ALIGN`], returned as a `u64`
/// (the integer type wgpu uses for buffer sizes and offsets).
fn aligned(size: usize) -> u64 {
    let bytes = size as u64;
    let remainder = bytes % COPY_ALIGN;
    if remainder == 0 {
        bytes
    } else {
        bytes + (COPY_ALIGN - remainder)
    }
}
26
/// Compute backend built on wgpu.
///
/// Device memory is exposed to callers as opaque integer handles cast to pointers;
/// they are keys into `buffers` and must never be dereferenced on the host.
pub struct WebGPUBackend {
    /// Static limits reported through `capabilities()`, filled in by `new()`.
    capabilities: BackendCapabilities,
    /// Set by `initialize()`; `None` until then.
    device: Option<wgpu::Device>,
    /// Submission queue paired with `device`; `None` until `initialize()`.
    queue: Option<wgpu::Queue>,
    /// Compiled compute pipelines, keyed by the ID encoded in kernel handles.
    pipelines: Mutex<HashMap<u64, wgpu::ComputePipeline>>,
    /// Live allocations: handle -> (GPU buffer, size in bytes requested by the caller).
    /// The GPU buffer itself is created with `aligned(size)` bytes, so it may be larger.
    buffers: Mutex<HashMap<usize, (wgpu::Buffer, usize)>>,
    /// Next pipeline ID to hand out (initialized to 1 by `new()`).
    next_pipeline_id: Mutex<u64>,
    /// Next buffer handle to hand out (initialized to `HANDLE_BASE` by `new()`).
    next_handle: AtomicUsize,
}
41
42impl Default for WebGPUBackend {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl WebGPUBackend {
49 pub fn new() -> Self {
51 Self {
52 capabilities: BackendCapabilities {
53 name: "WebGPU (wgpu)".to_string(),
54 supports_cuda: false,
55 supports_opencl: false,
56 supports_vulkan: false,
57 supports_webgpu: true,
58 max_threads: 65535 * 256,
59 max_threads_per_block: 256,
60 max_blocks_per_grid: 65535,
61 max_shared_memory: 16384,
62 supports_dynamic_parallelism: false,
63 supports_unified_memory: false,
64 max_grid_dim: [65535, 65535, 65535],
65 max_block_dim: [256, 256, 64],
66 warp_size: 32,
67 },
68 device: None,
69 queue: None,
70 pipelines: Mutex::new(HashMap::new()),
71 buffers: Mutex::new(HashMap::new()),
72 next_pipeline_id: Mutex::new(1),
73 next_handle: AtomicUsize::new(HANDLE_BASE),
74 }
75 }
76
77 pub fn is_available() -> bool {
80 true
81 }
82
83 fn pipeline_id_to_bytes(id: u64) -> Vec<u8> {
85 id.to_le_bytes().to_vec()
86 }
87
88 fn bytes_to_pipeline_id(bytes: &[u8]) -> Result<u64> {
90 if bytes.len() < 8 {
91 return Err(runtime_error!("Invalid kernel handle: too short"));
92 }
93 let mut arr = [0u8; 8];
94 arr.copy_from_slice(&bytes[..8]);
95 Ok(u64::from_le_bytes(arr))
96 }
97
98 fn device(&self) -> Result<&wgpu::Device> {
99 self.device
100 .as_ref()
101 .ok_or_else(|| runtime_error!("Backend not initialized: call initialize() first"))
102 }
103
104 fn queue(&self) -> Result<&wgpu::Queue> {
105 self.queue
106 .as_ref()
107 .ok_or_else(|| runtime_error!("Backend not initialized: call initialize() first"))
108 }
109}
110
// SAFETY: NOTE(review): these impls assert that `WebGPUBackend` may be moved to and
// shared between threads. The mutable bookkeeping (`pipelines`, `buffers`,
// `next_pipeline_id`, `next_handle`) sits behind `Mutex`/atomics. Whether the wgpu
// handles themselves (`Device`, `Queue`, `Buffer`, `ComputePipeline`) are thread-safe
// depends on the target wgpu is built for. Presumably these impls exist for targets
// where wgpu does not mark them Send/Sync (e.g. wasm, which is single-threaded).
// Confirm this before relying on genuine cross-thread use.
unsafe impl Send for WebGPUBackend {}
unsafe impl Sync for WebGPUBackend {}
113
114#[async_trait(?Send)]
115impl BackendTrait for WebGPUBackend {
116 fn name(&self) -> &str {
117 "WebGPU (wgpu)"
118 }
119
120 fn capabilities(&self) -> &BackendCapabilities {
121 &self.capabilities
122 }
123
124 async fn initialize(&mut self) -> Result<()> {
125 let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
126 backends: wgpu::Backends::all(),
127 ..Default::default()
128 });
129
130 let adapter = instance
131 .request_adapter(&wgpu::RequestAdapterOptions {
132 power_preference: wgpu::PowerPreference::HighPerformance,
133 compatible_surface: None,
134 force_fallback_adapter: false,
135 })
136 .await
137 .ok_or_else(|| runtime_error!("No WebGPU adapter found"))?;
138
139 let (device, queue) = adapter
140 .request_device(
141 &wgpu::DeviceDescriptor {
142 label: Some("cuda-wasm"),
143 required_features: wgpu::Features::empty(),
144 required_limits: wgpu::Limits::downlevel_defaults(),
145 },
146 None,
147 )
148 .await
149 .map_err(|e| runtime_error!("Failed to create wgpu device: {}", e))?;
150
151 self.device = Some(device);
152 self.queue = Some(queue);
153 Ok(())
154 }
155
156 async fn compile_kernel(&self, source: &str) -> Result<Vec<u8>> {
157 let device = self.device()?;
158
159 device.push_error_scope(wgpu::ErrorFilter::Validation);
161 let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
162 label: Some("kernel"),
163 source: wgpu::ShaderSource::Wgsl(source.into()),
164 });
165 device.poll(wgpu::Maintain::Wait);
166 if let Some(e) = pollster::block_on(device.pop_error_scope()) {
167 return Err(runtime_error!("Shader compilation failed: {}", e));
168 }
169
170 device.push_error_scope(wgpu::ErrorFilter::Validation);
172 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
173 label: Some("compute_pipeline"),
174 layout: None,
175 module: &module,
176 entry_point: "main",
177 });
178 device.poll(wgpu::Maintain::Wait);
179 if let Some(e) = pollster::block_on(device.pop_error_scope()) {
180 return Err(runtime_error!("Pipeline creation failed: {}", e));
181 }
182
183 let mut id_guard = self
184 .next_pipeline_id
185 .lock()
186 .map_err(|e| runtime_error!("Pipeline ID lock poisoned: {}", e))?;
187 let id = *id_guard;
188 *id_guard += 1;
189
190 self.pipelines
191 .lock()
192 .map_err(|e| runtime_error!("Pipeline lock poisoned: {}", e))?
193 .insert(id, pipeline);
194
195 Ok(Self::pipeline_id_to_bytes(id))
196 }
197
198 async fn launch_kernel(
199 &self,
200 kernel: &[u8],
201 grid: (u32, u32, u32),
202 _block: (u32, u32, u32),
203 args: &[*const u8],
204 ) -> Result<()> {
205 let arg_handles: Vec<usize> = args.iter().map(|p| *p as usize).collect();
208
209 let device = self.device()?;
210 let queue = self.queue()?;
211 let pipeline_id = Self::bytes_to_pipeline_id(kernel)?;
212
213 if grid.0 == 0 || grid.1 == 0 || grid.2 == 0 {
214 return Err(runtime_error!("Grid dimensions must be non-zero"));
215 }
216 if grid.0 > 65535 || grid.1 > 65535 || grid.2 > 65535 {
217 return Err(runtime_error!("Grid dimension exceeds maximum (65535)"));
218 }
219
220 let pipelines = self
221 .pipelines
222 .lock()
223 .map_err(|e| runtime_error!("Pipeline lock poisoned: {}", e))?;
224 let pipeline = pipelines
225 .get(&pipeline_id)
226 .ok_or_else(|| runtime_error!("Kernel not found: pipeline ID {}", pipeline_id))?;
227
228 let buffers_guard = self
229 .buffers
230 .lock()
231 .map_err(|e| runtime_error!("Buffer lock poisoned: {}", e))?;
232
233 let mut entries = Vec::with_capacity(arg_handles.len());
235 for (i, &handle) in arg_handles.iter().enumerate() {
236 let (buf, _) = buffers_guard
237 .get(&handle)
238 .ok_or_else(|| runtime_error!("Arg {} buffer handle {:#x} not found", i, handle))?;
239 entries.push(wgpu::BindGroupEntry {
240 binding: i as u32,
241 resource: buf.as_entire_binding(),
242 });
243 }
244
245 let bind_group = if !entries.is_empty() {
246 let layout = pipeline.get_bind_group_layout(0);
247 Some(device.create_bind_group(&wgpu::BindGroupDescriptor {
248 label: None,
249 layout: &layout,
250 entries: &entries,
251 }))
252 } else {
253 None
254 };
255
256 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
257 label: Some("compute_encoder"),
258 });
259 {
260 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
261 label: Some("compute_pass"),
262 timestamp_writes: None,
263 });
264 pass.set_pipeline(pipeline);
265 if let Some(bg) = &bind_group {
266 pass.set_bind_group(0, bg, &[]);
267 }
268 pass.dispatch_workgroups(grid.0, grid.1, grid.2);
269 }
270 queue.submit(std::iter::once(encoder.finish()));
271 device.poll(wgpu::Maintain::Wait);
272
273 Ok(())
274 }
275
276 fn allocate_memory(&self, size: usize) -> Result<*mut u8> {
277 if size == 0 {
278 return Err(runtime_error!("Cannot allocate zero bytes"));
279 }
280 let device = self.device()?;
281
282 let buffer = device.create_buffer(&wgpu::BufferDescriptor {
283 label: None,
284 size: aligned(size),
285 usage: wgpu::BufferUsages::STORAGE
286 | wgpu::BufferUsages::COPY_SRC
287 | wgpu::BufferUsages::COPY_DST,
288 mapped_at_creation: false,
289 });
290
291 let handle = self.next_handle.fetch_add(1, Ordering::SeqCst);
292 self.buffers
293 .lock()
294 .map_err(|e| runtime_error!("Buffer lock poisoned: {}", e))?
295 .insert(handle, (buffer, size));
296
297 Ok(handle as *mut u8)
298 }
299
300 fn free_memory(&self, ptr: *mut u8) -> Result<()> {
301 let handle = ptr as usize;
302 let (buffer, _) = self
303 .buffers
304 .lock()
305 .map_err(|e| runtime_error!("Buffer lock poisoned: {}", e))?
306 .remove(&handle)
307 .ok_or_else(|| runtime_error!("Attempted to free untracked handle {:#x}", handle))?;
308 drop(buffer);
309 Ok(())
310 }
311
312 fn copy_memory(
313 &self,
314 dst: *mut u8,
315 src: *const u8,
316 size: usize,
317 kind: MemcpyKind,
318 ) -> Result<()> {
319 if size == 0 {
320 return Ok(());
321 }
322 match kind {
323 MemcpyKind::HostToDevice => {
324 let queue = self.queue()?;
325 let device = self.device()?;
326 let dst_handle = dst as usize;
327 let buffers = self
328 .buffers
329 .lock()
330 .map_err(|e| runtime_error!("Buffer lock poisoned: {}", e))?;
331 let (gpu_buf, buf_size) = buffers
332 .get(&dst_handle)
333 .ok_or_else(|| runtime_error!("Dst buffer handle not found"))?;
334 if size > *buf_size {
335 return Err(runtime_error!(
336 "Copy size {} exceeds buffer size {}",
337 size,
338 buf_size
339 ));
340 }
341 let data = unsafe { std::slice::from_raw_parts(src, size) };
342 queue.write_buffer(gpu_buf, 0, data);
343 queue.submit(std::iter::empty());
344 device.poll(wgpu::Maintain::Wait);
345 Ok(())
346 }
347 MemcpyKind::DeviceToHost => {
348 let device = self.device()?;
349 let queue = self.queue()?;
350 let src_handle = src as usize;
351 let copy_size = aligned(size);
352 let buffers = self
353 .buffers
354 .lock()
355 .map_err(|e| runtime_error!("Buffer lock poisoned: {}", e))?;
356 let (gpu_buf, buf_size) = buffers
357 .get(&src_handle)
358 .ok_or_else(|| runtime_error!("Src buffer handle not found"))?;
359 if size > *buf_size {
360 return Err(runtime_error!(
361 "Copy size {} exceeds buffer size {}",
362 size,
363 buf_size
364 ));
365 }
366 let staging = device.create_buffer(&wgpu::BufferDescriptor {
367 label: Some("staging_read"),
368 size: copy_size,
369 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
370 mapped_at_creation: false,
371 });
372 let mut encoder =
373 device.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
374 encoder.copy_buffer_to_buffer(gpu_buf, 0, &staging, 0, copy_size);
375 queue.submit(std::iter::once(encoder.finish()));
376
377 let slice = staging.slice(..);
378 let (tx, rx) = std::sync::mpsc::channel();
379 slice.map_async(wgpu::MapMode::Read, move |result| {
380 tx.send(result).ok();
381 });
382 device.poll(wgpu::Maintain::Wait);
383 rx.recv()
384 .map_err(|_| runtime_error!("Buffer map channel closed"))?
385 .map_err(|e| runtime_error!("Buffer map failed: {:?}", e))?;
386
387 let mapped = slice.get_mapped_range();
388 unsafe {
389 std::ptr::copy_nonoverlapping(mapped.as_ptr(), dst, size);
390 }
391 drop(mapped);
392 staging.unmap();
393 Ok(())
394 }
395 MemcpyKind::DeviceToDevice => {
396 let device = self.device()?;
397 let queue = self.queue()?;
398 let src_handle = src as usize;
399 let dst_handle = dst as usize;
400 let copy_size = aligned(size);
401 let buffers = self
402 .buffers
403 .lock()
404 .map_err(|e| runtime_error!("Buffer lock poisoned: {}", e))?;
405 let (src_buf, _) = buffers
406 .get(&src_handle)
407 .ok_or_else(|| runtime_error!("Src buffer handle not found"))?;
408 let (dst_buf, _) = buffers
409 .get(&dst_handle)
410 .ok_or_else(|| runtime_error!("Dst buffer handle not found"))?;
411 let mut encoder =
412 device.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
413 encoder.copy_buffer_to_buffer(src_buf, 0, dst_buf, 0, copy_size);
414 queue.submit(std::iter::once(encoder.finish()));
415 device.poll(wgpu::Maintain::Wait);
416 Ok(())
417 }
418 MemcpyKind::HostToHost => {
419 if dst.is_null() || src.is_null() {
420 return Err(runtime_error!("Null pointer in host memory copy"));
421 }
422 unsafe { std::ptr::copy_nonoverlapping(src, dst, size) };
423 Ok(())
424 }
425 }
426 }
427
428 fn synchronize(&self) -> Result<()> {
429 if let Some(device) = &self.device {
430 device.poll(wgpu::Maintain::Wait);
431 }
432 Ok(())
433 }
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    /// Trivial WGSL compute kernel with no bindings.
    const EMPTY_KERNEL: &str = "@compute @workgroup_size(64) fn main() {}";

    /// Creates and initializes a backend, yielding `None` if initialization fails
    /// (for example when no adapter is available).
    fn try_init_backend() -> Option<WebGPUBackend> {
        let mut backend = WebGPUBackend::new();
        match pollster::block_on(backend.initialize()) {
            Ok(()) => Some(backend),
            Err(_) => None,
        }
    }

    /// Initializes a backend for the GPU test named `test`. When that is impossible,
    /// prints a skip notice to stderr and returns `None`.
    fn backend_or_skip(test: &str) -> Option<WebGPUBackend> {
        let backend = try_init_backend();
        if backend.is_none() {
            eprintln!("Skipping {}: no GPU adapter", test);
        }
        backend
    }

    #[test]
    fn test_backend_creation() {
        let backend = WebGPUBackend::new();
        assert!(backend.capabilities().supports_webgpu);
        assert_eq!(backend.name(), "WebGPU (wgpu)");
    }

    #[test]
    fn test_is_available() {
        assert!(WebGPUBackend::is_available());
    }

    #[test]
    fn test_capabilities() {
        let backend = WebGPUBackend::new();
        let caps = backend.capabilities();
        assert!(caps.max_shared_memory > 0);
        assert_eq!(caps.warp_size, 32);
    }

    #[test]
    fn test_pipeline_id_roundtrip() {
        let encoded = WebGPUBackend::pipeline_id_to_bytes(12345);
        assert_eq!(encoded.len(), 8);
        let decoded = WebGPUBackend::bytes_to_pipeline_id(&encoded).unwrap();
        assert_eq!(decoded, 12345);
    }

    #[test]
    fn test_pipeline_id_short_fails() {
        let too_short = [1u8, 2];
        assert!(WebGPUBackend::bytes_to_pipeline_id(&too_short).is_err());
    }

    #[test]
    fn test_allocate_zero_fails() {
        assert!(WebGPUBackend::new().allocate_memory(0).is_err());
    }

    #[test]
    fn test_uninitialized_allocate_fails() {
        assert!(WebGPUBackend::new().allocate_memory(1024).is_err());
    }

    #[test]
    fn test_free_untracked_fails() {
        let untracked = 0xDEAD as *mut u8;
        assert!(WebGPUBackend::new().free_memory(untracked).is_err());
    }

    #[test]
    fn test_copy_zero_noop() {
        let handle = 1 as *mut u8;
        let result = WebGPUBackend::new().copy_memory(handle, handle, 0, MemcpyKind::DeviceToDevice);
        result.unwrap();
    }

    #[test]
    fn test_host_to_host_copy() {
        let backend = WebGPUBackend::new();
        let src = [1u8, 2, 3, 4];
        let mut dst = [0u8; 4];
        backend
            .copy_memory(dst.as_mut_ptr(), src.as_ptr(), src.len(), MemcpyKind::HostToHost)
            .unwrap();
        assert_eq!(dst, src);
    }

    #[test]
    fn test_host_to_host_null_fails() {
        let backend = WebGPUBackend::new();
        let src = [0u8; 64];
        let result = backend.copy_memory(
            std::ptr::null_mut(),
            src.as_ptr(),
            src.len(),
            MemcpyKind::HostToHost,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_synchronize_uninitialized() {
        WebGPUBackend::new().synchronize().unwrap();
    }

    #[test]
    fn test_gpu_allocate_and_free() {
        let Some(backend) = backend_or_skip("test_gpu_allocate_and_free") else {
            return;
        };
        let handle = backend.allocate_memory(1024).unwrap();
        assert!(!handle.is_null());
        assert!(handle as usize >= HANDLE_BASE);
        backend.free_memory(handle).unwrap();
    }

    #[test]
    fn test_gpu_data_roundtrip() {
        let Some(backend) = backend_or_skip("test_gpu_data_roundtrip") else {
            return;
        };
        let data: Vec<u8> = (0..=u8::MAX).collect();
        let gpu_buf = backend.allocate_memory(data.len()).unwrap();

        backend
            .copy_memory(gpu_buf, data.as_ptr(), data.len(), MemcpyKind::HostToDevice)
            .unwrap();

        let mut readback = vec![0u8; data.len()];
        backend
            .copy_memory(readback.as_mut_ptr(), gpu_buf, readback.len(), MemcpyKind::DeviceToHost)
            .unwrap();

        assert_eq!(readback, data);
        backend.free_memory(gpu_buf).unwrap();
    }

    #[test]
    fn test_gpu_device_to_device_copy() {
        let Some(backend) = backend_or_skip("test_gpu_device_to_device_copy") else {
            return;
        };
        let data: Vec<u8> = (0u8..128).map(|i| i * 2).collect();
        let buf_a = backend.allocate_memory(data.len()).unwrap();
        let buf_b = backend.allocate_memory(data.len()).unwrap();

        backend
            .copy_memory(buf_a, data.as_ptr(), data.len(), MemcpyKind::HostToDevice)
            .unwrap();
        backend
            .copy_memory(buf_b, buf_a, data.len(), MemcpyKind::DeviceToDevice)
            .unwrap();

        let mut readback = vec![0u8; data.len()];
        backend
            .copy_memory(readback.as_mut_ptr(), buf_b, readback.len(), MemcpyKind::DeviceToHost)
            .unwrap();

        assert_eq!(readback, data);
        for handle in [buf_a, buf_b] {
            backend.free_memory(handle).unwrap();
        }
    }

    #[test]
    fn test_gpu_synchronize() {
        let Some(backend) = backend_or_skip("test_gpu_synchronize") else {
            return;
        };
        backend.synchronize().unwrap();
    }

    #[tokio::test]
    async fn test_gpu_compile_valid_wgsl() {
        let Some(backend) = backend_or_skip("test_gpu_compile_valid_wgsl") else {
            return;
        };
        let kernel = backend.compile_kernel(EMPTY_KERNEL).await.unwrap();
        assert_eq!(kernel.len(), 8);
    }

    #[tokio::test]
    async fn test_gpu_compile_invalid_wgsl() {
        let Some(backend) = backend_or_skip("test_gpu_compile_invalid_wgsl") else {
            return;
        };
        let outcome = backend.compile_kernel("not valid wgsl").await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_gpu_launch_missing_kernel() {
        let Some(backend) = backend_or_skip("test_gpu_launch_missing_kernel") else {
            return;
        };
        let unknown = WebGPUBackend::pipeline_id_to_bytes(999);
        let outcome = backend.launch_kernel(&unknown, (1, 1, 1), (64, 1, 1), &[]).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_gpu_compile_and_launch() {
        let Some(backend) = backend_or_skip("test_gpu_compile_and_launch") else {
            return;
        };
        let kernel = backend.compile_kernel(EMPTY_KERNEL).await.unwrap();
        backend
            .launch_kernel(&kernel, (1, 1, 1), (64, 1, 1), &[])
            .await
            .unwrap();
    }
}