Skip to main content

cuda_rust_wasm/backend/
webgpu.rs

//! WebGPU backend implementation using wgpu.
//!
//! Runs compute kernels on a real GPU through wgpu, managing the device, queue, and
//! compiled pipelines. Buffer handles returned by `allocate_memory` are synthetic
//! pointers that map to `wgpu::Buffer` objects stored internally, bridging the
//! `BackendTrait` raw-pointer API with wgpu's owned buffer model.
7
8use super::backend_trait::{BackendCapabilities, BackendTrait, MemcpyKind};
9use async_trait::async_trait;
10use crate::{runtime_error, Result};
11use std::collections::HashMap;
12use std::sync::atomic::{AtomicUsize, Ordering};
13use std::sync::Mutex;
14
/// Smallest synthetic buffer handle ever issued. Starting well above zero guarantees a
/// handle is never null and never collides with small sentinel addresses.
const HANDLE_BASE: usize = 65_536;

/// Byte alignment wgpu requires for buffer copy sizes and offsets
/// (e.g. `copy_buffer_to_buffer`).
const COPY_ALIGN: u64 = 4;

/// Rounds `size` up to the nearest multiple of [`COPY_ALIGN`], returned as a `u64`
/// ready for use as a wgpu buffer size or copy length.
fn aligned(size: usize) -> u64 {
    let bytes = size as u64;
    match bytes % COPY_ALIGN {
        0 => bytes,
        remainder => bytes + (COPY_ALIGN - remainder),
    }
}
26
/// Compute backend that executes WGSL kernels through wgpu.
///
/// `device` and `queue` are populated by `initialize`. Buffers and pipelines are kept
/// in internal registries so that the raw-pointer / byte-handle API of `BackendTrait`
/// can refer to owned wgpu objects.
pub struct WebGPUBackend {
    /// Capability description returned by `capabilities()`.
    capabilities: BackendCapabilities,
    /// Logical GPU device; `None` until `initialize` succeeds.
    device: Option<wgpu::Device>,
    /// Command queue paired with `device`; `None` until `initialize` succeeds.
    queue: Option<wgpu::Queue>,
    /// Compiled compute pipelines keyed by pipeline ID (the ID is what kernel handles encode).
    pipelines: Mutex<HashMap<u64, wgpu::ComputePipeline>>,
    /// Live GPU allocations: synthetic handle address -> (buffer, requested byte size).
    /// The wgpu buffer itself may be padded past the requested size for copy alignment.
    buffers: Mutex<HashMap<usize, (wgpu::Buffer, usize)>>,
    /// ID to assign to the next compiled pipeline.
    next_pipeline_id: Mutex<u64>,
    /// Monotonic counter producing unique synthetic buffer handles.
    next_handle: AtomicUsize,
}
41
42impl Default for WebGPUBackend {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl WebGPUBackend {
49    /// Create a new WebGPU backend. Call [`initialize`] before any GPU operations.
50    pub fn new() -> Self {
51        Self {
52            capabilities: BackendCapabilities {
53                name: "WebGPU (wgpu)".to_string(),
54                supports_cuda: false,
55                supports_opencl: false,
56                supports_vulkan: false,
57                supports_webgpu: true,
58                max_threads: 65535 * 256,
59                max_threads_per_block: 256,
60                max_blocks_per_grid: 65535,
61                max_shared_memory: 16384,
62                supports_dynamic_parallelism: false,
63                supports_unified_memory: false,
64                max_grid_dim: [65535, 65535, 65535],
65                max_block_dim: [256, 256, 64],
66                warp_size: 32,
67            },
68            device: None,
69            queue: None,
70            pipelines: Mutex::new(HashMap::new()),
71            buffers: Mutex::new(HashMap::new()),
72            next_pipeline_id: Mutex::new(1),
73            next_handle: AtomicUsize::new(HANDLE_BASE),
74        }
75    }
76
77    /// Check if WebGPU is conceptually available on this platform.
78    /// Actual adapter availability is verified in [`initialize`].
79    pub fn is_available() -> bool {
80        true
81    }
82
83    /// Encode a pipeline ID as kernel bytes (8 bytes, little-endian).
84    fn pipeline_id_to_bytes(id: u64) -> Vec<u8> {
85        id.to_le_bytes().to_vec()
86    }
87
88    /// Decode kernel bytes back to a pipeline ID.
89    fn bytes_to_pipeline_id(bytes: &[u8]) -> Result<u64> {
90        if bytes.len() < 8 {
91            return Err(runtime_error!("Invalid kernel handle: too short"));
92        }
93        let mut arr = [0u8; 8];
94        arr.copy_from_slice(&bytes[..8]);
95        Ok(u64::from_le_bytes(arr))
96    }
97
98    fn device(&self) -> Result<&wgpu::Device> {
99        self.device
100            .as_ref()
101            .ok_or_else(|| runtime_error!("Backend not initialized: call initialize() first"))
102    }
103
104    fn queue(&self) -> Result<&wgpu::Queue> {
105        self.queue
106            .as_ref()
107            .ok_or_else(|| runtime_error!("Backend not initialized: call initialize() first"))
108    }
109}
110
// SAFETY: On native targets wgpu handles, `Mutex`, and `AtomicUsize` are already
// `Send + Sync`, so the compiler derives these impls and keeps verifying them.
// (Assumes `BackendCapabilities` is plain data — confirm if it gains new fields.)
// On wasm32 the WebGPU handles wrap JavaScript objects and are not `Send`/`Sync`.
// Asserting thread-safety there relies on the module running single-threaded.
// NOTE(review): this is unsound if the wasm build enables shared-memory threads (`atomics`).
#[cfg(target_arch = "wasm32")]
unsafe impl Send for WebGPUBackend {}
#[cfg(target_arch = "wasm32")]
unsafe impl Sync for WebGPUBackend {}
113
114#[async_trait(?Send)]
115impl BackendTrait for WebGPUBackend {
116    fn name(&self) -> &str {
117        "WebGPU (wgpu)"
118    }
119
120    fn capabilities(&self) -> &BackendCapabilities {
121        &self.capabilities
122    }
123
124    async fn initialize(&mut self) -> Result<()> {
125        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
126            backends: wgpu::Backends::all(),
127            ..Default::default()
128        });
129
130        let adapter = instance
131            .request_adapter(&wgpu::RequestAdapterOptions {
132                power_preference: wgpu::PowerPreference::HighPerformance,
133                compatible_surface: None,
134                force_fallback_adapter: false,
135            })
136            .await
137            .ok_or_else(|| runtime_error!("No WebGPU adapter found"))?;
138
139        let (device, queue) = adapter
140            .request_device(
141                &wgpu::DeviceDescriptor {
142                    label: Some("cuda-wasm"),
143                    required_features: wgpu::Features::empty(),
144                    required_limits: wgpu::Limits::downlevel_defaults(),
145                },
146                None,
147            )
148            .await
149            .map_err(|e| runtime_error!("Failed to create wgpu device: {}", e))?;
150
151        self.device = Some(device);
152        self.queue = Some(queue);
153        Ok(())
154    }
155
156    async fn compile_kernel(&self, source: &str) -> Result<Vec<u8>> {
157        let device = self.device()?;
158
159        // Use error scopes to capture shader validation failures.
160        device.push_error_scope(wgpu::ErrorFilter::Validation);
161        let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
162            label: Some("kernel"),
163            source: wgpu::ShaderSource::Wgsl(source.into()),
164        });
165        device.poll(wgpu::Maintain::Wait);
166        if let Some(e) = pollster::block_on(device.pop_error_scope()) {
167            return Err(runtime_error!("Shader compilation failed: {}", e));
168        }
169
170        // Create compute pipeline with auto bind-group layout.
171        device.push_error_scope(wgpu::ErrorFilter::Validation);
172        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
173            label: Some("compute_pipeline"),
174            layout: None,
175            module: &module,
176            entry_point: "main",
177        });
178        device.poll(wgpu::Maintain::Wait);
179        if let Some(e) = pollster::block_on(device.pop_error_scope()) {
180            return Err(runtime_error!("Pipeline creation failed: {}", e));
181        }
182
183        let mut id_guard = self
184            .next_pipeline_id
185            .lock()
186            .map_err(|e| runtime_error!("Pipeline ID lock poisoned: {}", e))?;
187        let id = *id_guard;
188        *id_guard += 1;
189
190        self.pipelines
191            .lock()
192            .map_err(|e| runtime_error!("Pipeline lock poisoned: {}", e))?
193            .insert(id, pipeline);
194
195        Ok(Self::pipeline_id_to_bytes(id))
196    }
197
198    async fn launch_kernel(
199        &self,
200        kernel: &[u8],
201        grid: (u32, u32, u32),
202        _block: (u32, u32, u32),
203        args: &[*const u8],
204    ) -> Result<()> {
205        // Snapshot arg pointers as usize immediately so the future is Send
206        // (raw pointers are !Sync, making &[*const u8] !Send).
207        let arg_handles: Vec<usize> = args.iter().map(|p| *p as usize).collect();
208
209        let device = self.device()?;
210        let queue = self.queue()?;
211        let pipeline_id = Self::bytes_to_pipeline_id(kernel)?;
212
213        if grid.0 == 0 || grid.1 == 0 || grid.2 == 0 {
214            return Err(runtime_error!("Grid dimensions must be non-zero"));
215        }
216        if grid.0 > 65535 || grid.1 > 65535 || grid.2 > 65535 {
217            return Err(runtime_error!("Grid dimension exceeds maximum (65535)"));
218        }
219
220        let pipelines = self
221            .pipelines
222            .lock()
223            .map_err(|e| runtime_error!("Pipeline lock poisoned: {}", e))?;
224        let pipeline = pipelines
225            .get(&pipeline_id)
226            .ok_or_else(|| runtime_error!("Kernel not found: pipeline ID {}", pipeline_id))?;
227
228        let buffers_guard = self
229            .buffers
230            .lock()
231            .map_err(|e| runtime_error!("Buffer lock poisoned: {}", e))?;
232
233        // Build bind group entries from arg handles.
234        let mut entries = Vec::with_capacity(arg_handles.len());
235        for (i, &handle) in arg_handles.iter().enumerate() {
236            let (buf, _) = buffers_guard
237                .get(&handle)
238                .ok_or_else(|| runtime_error!("Arg {} buffer handle {:#x} not found", i, handle))?;
239            entries.push(wgpu::BindGroupEntry {
240                binding: i as u32,
241                resource: buf.as_entire_binding(),
242            });
243        }
244
245        let bind_group = if !entries.is_empty() {
246            let layout = pipeline.get_bind_group_layout(0);
247            Some(device.create_bind_group(&wgpu::BindGroupDescriptor {
248                label: None,
249                layout: &layout,
250                entries: &entries,
251            }))
252        } else {
253            None
254        };
255
256        let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
257            label: Some("compute_encoder"),
258        });
259        {
260            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
261                label: Some("compute_pass"),
262                timestamp_writes: None,
263            });
264            pass.set_pipeline(pipeline);
265            if let Some(bg) = &bind_group {
266                pass.set_bind_group(0, bg, &[]);
267            }
268            pass.dispatch_workgroups(grid.0, grid.1, grid.2);
269        }
270        queue.submit(std::iter::once(encoder.finish()));
271        device.poll(wgpu::Maintain::Wait);
272
273        Ok(())
274    }
275
276    fn allocate_memory(&self, size: usize) -> Result<*mut u8> {
277        if size == 0 {
278            return Err(runtime_error!("Cannot allocate zero bytes"));
279        }
280        let device = self.device()?;
281
282        let buffer = device.create_buffer(&wgpu::BufferDescriptor {
283            label: None,
284            size: aligned(size),
285            usage: wgpu::BufferUsages::STORAGE
286                | wgpu::BufferUsages::COPY_SRC
287                | wgpu::BufferUsages::COPY_DST,
288            mapped_at_creation: false,
289        });
290
291        let handle = self.next_handle.fetch_add(1, Ordering::SeqCst);
292        self.buffers
293            .lock()
294            .map_err(|e| runtime_error!("Buffer lock poisoned: {}", e))?
295            .insert(handle, (buffer, size));
296
297        Ok(handle as *mut u8)
298    }
299
300    fn free_memory(&self, ptr: *mut u8) -> Result<()> {
301        let handle = ptr as usize;
302        let (buffer, _) = self
303            .buffers
304            .lock()
305            .map_err(|e| runtime_error!("Buffer lock poisoned: {}", e))?
306            .remove(&handle)
307            .ok_or_else(|| runtime_error!("Attempted to free untracked handle {:#x}", handle))?;
308        drop(buffer);
309        Ok(())
310    }
311
312    fn copy_memory(
313        &self,
314        dst: *mut u8,
315        src: *const u8,
316        size: usize,
317        kind: MemcpyKind,
318    ) -> Result<()> {
319        if size == 0 {
320            return Ok(());
321        }
322        match kind {
323            MemcpyKind::HostToDevice => {
324                let queue = self.queue()?;
325                let device = self.device()?;
326                let dst_handle = dst as usize;
327                let buffers = self
328                    .buffers
329                    .lock()
330                    .map_err(|e| runtime_error!("Buffer lock poisoned: {}", e))?;
331                let (gpu_buf, buf_size) = buffers
332                    .get(&dst_handle)
333                    .ok_or_else(|| runtime_error!("Dst buffer handle not found"))?;
334                if size > *buf_size {
335                    return Err(runtime_error!(
336                        "Copy size {} exceeds buffer size {}",
337                        size,
338                        buf_size
339                    ));
340                }
341                let data = unsafe { std::slice::from_raw_parts(src, size) };
342                queue.write_buffer(gpu_buf, 0, data);
343                queue.submit(std::iter::empty());
344                device.poll(wgpu::Maintain::Wait);
345                Ok(())
346            }
347            MemcpyKind::DeviceToHost => {
348                let device = self.device()?;
349                let queue = self.queue()?;
350                let src_handle = src as usize;
351                let copy_size = aligned(size);
352                let buffers = self
353                    .buffers
354                    .lock()
355                    .map_err(|e| runtime_error!("Buffer lock poisoned: {}", e))?;
356                let (gpu_buf, buf_size) = buffers
357                    .get(&src_handle)
358                    .ok_or_else(|| runtime_error!("Src buffer handle not found"))?;
359                if size > *buf_size {
360                    return Err(runtime_error!(
361                        "Copy size {} exceeds buffer size {}",
362                        size,
363                        buf_size
364                    ));
365                }
366                let staging = device.create_buffer(&wgpu::BufferDescriptor {
367                    label: Some("staging_read"),
368                    size: copy_size,
369                    usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
370                    mapped_at_creation: false,
371                });
372                let mut encoder =
373                    device.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
374                encoder.copy_buffer_to_buffer(gpu_buf, 0, &staging, 0, copy_size);
375                queue.submit(std::iter::once(encoder.finish()));
376
377                let slice = staging.slice(..);
378                let (tx, rx) = std::sync::mpsc::channel();
379                slice.map_async(wgpu::MapMode::Read, move |result| {
380                    tx.send(result).ok();
381                });
382                device.poll(wgpu::Maintain::Wait);
383                rx.recv()
384                    .map_err(|_| runtime_error!("Buffer map channel closed"))?
385                    .map_err(|e| runtime_error!("Buffer map failed: {:?}", e))?;
386
387                let mapped = slice.get_mapped_range();
388                unsafe {
389                    std::ptr::copy_nonoverlapping(mapped.as_ptr(), dst, size);
390                }
391                drop(mapped);
392                staging.unmap();
393                Ok(())
394            }
395            MemcpyKind::DeviceToDevice => {
396                let device = self.device()?;
397                let queue = self.queue()?;
398                let src_handle = src as usize;
399                let dst_handle = dst as usize;
400                let copy_size = aligned(size);
401                let buffers = self
402                    .buffers
403                    .lock()
404                    .map_err(|e| runtime_error!("Buffer lock poisoned: {}", e))?;
405                let (src_buf, _) = buffers
406                    .get(&src_handle)
407                    .ok_or_else(|| runtime_error!("Src buffer handle not found"))?;
408                let (dst_buf, _) = buffers
409                    .get(&dst_handle)
410                    .ok_or_else(|| runtime_error!("Dst buffer handle not found"))?;
411                let mut encoder =
412                    device.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
413                encoder.copy_buffer_to_buffer(src_buf, 0, dst_buf, 0, copy_size);
414                queue.submit(std::iter::once(encoder.finish()));
415                device.poll(wgpu::Maintain::Wait);
416                Ok(())
417            }
418            MemcpyKind::HostToHost => {
419                if dst.is_null() || src.is_null() {
420                    return Err(runtime_error!("Null pointer in host memory copy"));
421                }
422                unsafe { std::ptr::copy_nonoverlapping(src, dst, size) };
423                Ok(())
424            }
425        }
426    }
427
428    fn synchronize(&self) -> Result<()> {
429        if let Some(device) = &self.device {
430            device.poll(wgpu::Maintain::Wait);
431        }
432        Ok(())
433    }
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    /// Try to create and initialize a backend. Returns `None` if no GPU adapter is found.
    fn try_init_backend() -> Option<WebGPUBackend> {
        let mut backend = WebGPUBackend::new();
        pollster::block_on(backend.initialize()).ok()?;
        Some(backend)
    }

    /// Like [`try_init_backend`], but logs that `test` is skipped when no adapter exists.
    fn gpu_backend_or_skip(test: &str) -> Option<WebGPUBackend> {
        let backend = try_init_backend();
        if backend.is_none() {
            eprintln!("Skipping {}: no GPU adapter", test);
        }
        backend
    }

    // ---- Tests that do NOT require a GPU ----

    #[test]
    fn test_backend_creation() {
        let backend = WebGPUBackend::new();
        assert_eq!(backend.name(), "WebGPU (wgpu)");
        assert!(backend.capabilities().supports_webgpu);
    }

    #[test]
    fn test_is_available() {
        assert!(WebGPUBackend::is_available());
    }

    #[test]
    fn test_capabilities() {
        let backend = WebGPUBackend::new();
        let caps = backend.capabilities();
        assert_eq!(caps.warp_size, 32);
        assert!(caps.max_shared_memory > 0);
    }

    #[test]
    fn test_aligned_rounds_up_to_copy_alignment() {
        assert_eq!(aligned(1), 4);
        assert_eq!(aligned(4), 4);
        assert_eq!(aligned(5), 8);
        assert_eq!(aligned(256), 256);
    }

    #[test]
    fn test_pipeline_id_roundtrip() {
        let id = 12345u64;
        let bytes = WebGPUBackend::pipeline_id_to_bytes(id);
        assert_eq!(bytes.len(), 8);
        assert_eq!(WebGPUBackend::bytes_to_pipeline_id(&bytes).unwrap(), id);
    }

    #[test]
    fn test_pipeline_id_short_fails() {
        assert!(WebGPUBackend::bytes_to_pipeline_id(&[1, 2]).is_err());
    }

    #[test]
    fn test_allocate_zero_fails() {
        let backend = WebGPUBackend::new();
        assert!(backend.allocate_memory(0).is_err());
    }

    #[test]
    fn test_uninitialized_allocate_fails() {
        let backend = WebGPUBackend::new();
        assert!(backend.allocate_memory(1024).is_err());
    }

    #[test]
    fn test_free_untracked_fails() {
        let backend = WebGPUBackend::new();
        let fake = 0xDEAD as *mut u8;
        assert!(backend.free_memory(fake).is_err());
    }

    #[test]
    fn test_copy_zero_noop() {
        let backend = WebGPUBackend::new();
        let a = 1 as *mut u8;
        backend
            .copy_memory(a, a, 0, MemcpyKind::DeviceToDevice)
            .unwrap();
    }

    #[test]
    fn test_host_to_host_copy() {
        let backend = WebGPUBackend::new();
        let src = vec![1u8, 2, 3, 4];
        let mut dst = vec![0u8; 4];
        backend
            .copy_memory(dst.as_mut_ptr(), src.as_ptr(), 4, MemcpyKind::HostToHost)
            .unwrap();
        assert_eq!(dst, vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_host_to_host_null_fails() {
        let backend = WebGPUBackend::new();
        let ptr = vec![0u8; 64];
        assert!(backend
            .copy_memory(std::ptr::null_mut(), ptr.as_ptr(), 64, MemcpyKind::HostToHost)
            .is_err());
    }

    #[test]
    fn test_synchronize_uninitialized() {
        let backend = WebGPUBackend::new();
        backend.synchronize().unwrap();
    }

    // ---- Tests that REQUIRE a GPU adapter (skipped when none is found) ----
    // Async backend calls are driven with `pollster`, the same executor used above.

    #[test]
    fn test_gpu_allocate_and_free() {
        let Some(backend) = gpu_backend_or_skip("test_gpu_allocate_and_free") else {
            return;
        };
        let handle = backend.allocate_memory(1024).unwrap();
        assert!(!handle.is_null());
        assert!(handle as usize >= HANDLE_BASE);
        backend.free_memory(handle).unwrap();
    }

    #[test]
    fn test_gpu_double_free_fails() {
        let Some(backend) = gpu_backend_or_skip("test_gpu_double_free_fails") else {
            return;
        };
        let handle = backend.allocate_memory(64).unwrap();
        backend.free_memory(handle).unwrap();
        assert!(backend.free_memory(handle).is_err());
    }

    #[test]
    fn test_gpu_data_roundtrip() {
        let Some(backend) = gpu_backend_or_skip("test_gpu_data_roundtrip") else {
            return;
        };
        let data: Vec<u8> = (0..256).map(|i| i as u8).collect();
        let gpu_buf = backend.allocate_memory(256).unwrap();

        backend
            .copy_memory(gpu_buf, data.as_ptr(), 256, MemcpyKind::HostToDevice)
            .unwrap();

        let mut readback = vec![0u8; 256];
        backend
            .copy_memory(
                readback.as_mut_ptr(),
                gpu_buf as *const u8,
                256,
                MemcpyKind::DeviceToHost,
            )
            .unwrap();

        assert_eq!(readback, data);
        backend.free_memory(gpu_buf).unwrap();
    }

    #[test]
    fn test_gpu_device_to_device_copy() {
        let Some(backend) = gpu_backend_or_skip("test_gpu_device_to_device_copy") else {
            return;
        };
        let data: Vec<u8> = (0..128).map(|i| (i * 2) as u8).collect();
        let buf_a = backend.allocate_memory(128).unwrap();
        let buf_b = backend.allocate_memory(128).unwrap();

        backend
            .copy_memory(buf_a, data.as_ptr(), 128, MemcpyKind::HostToDevice)
            .unwrap();
        backend
            .copy_memory(buf_b, buf_a as *const u8, 128, MemcpyKind::DeviceToDevice)
            .unwrap();

        let mut readback = vec![0u8; 128];
        backend
            .copy_memory(
                readback.as_mut_ptr(),
                buf_b as *const u8,
                128,
                MemcpyKind::DeviceToHost,
            )
            .unwrap();

        assert_eq!(readback, data);
        backend.free_memory(buf_a).unwrap();
        backend.free_memory(buf_b).unwrap();
    }

    #[test]
    fn test_gpu_synchronize() {
        let Some(backend) = gpu_backend_or_skip("test_gpu_synchronize") else {
            return;
        };
        backend.synchronize().unwrap();
    }

    #[test]
    fn test_gpu_compile_valid_wgsl() {
        let Some(backend) = gpu_backend_or_skip("test_gpu_compile_valid_wgsl") else {
            return;
        };
        let kernel =
            pollster::block_on(backend.compile_kernel("@compute @workgroup_size(64) fn main() {}"))
                .unwrap();
        assert_eq!(kernel.len(), 8);
    }

    #[test]
    fn test_gpu_compile_invalid_wgsl() {
        let Some(backend) = gpu_backend_or_skip("test_gpu_compile_invalid_wgsl") else {
            return;
        };
        assert!(pollster::block_on(backend.compile_kernel("not valid wgsl")).is_err());
    }

    #[test]
    fn test_gpu_launch_missing_kernel() {
        let Some(backend) = gpu_backend_or_skip("test_gpu_launch_missing_kernel") else {
            return;
        };
        let fake = WebGPUBackend::pipeline_id_to_bytes(999);
        assert!(
            pollster::block_on(backend.launch_kernel(&fake, (1, 1, 1), (64, 1, 1), &[])).is_err()
        );
    }

    #[test]
    fn test_gpu_launch_zero_grid_fails() {
        let Some(backend) = gpu_backend_or_skip("test_gpu_launch_zero_grid_fails") else {
            return;
        };
        let kernel =
            pollster::block_on(backend.compile_kernel("@compute @workgroup_size(64) fn main() {}"))
                .unwrap();
        assert!(
            pollster::block_on(backend.launch_kernel(&kernel, (0, 1, 1), (64, 1, 1), &[])).is_err()
        );
    }

    #[test]
    fn test_gpu_compile_and_launch() {
        let Some(backend) = gpu_backend_or_skip("test_gpu_compile_and_launch") else {
            return;
        };
        let kernel =
            pollster::block_on(backend.compile_kernel("@compute @workgroup_size(64) fn main() {}"))
                .unwrap();
        pollster::block_on(backend.launch_kernel(&kernel, (1, 1, 1), (64, 1, 1), &[])).unwrap();
    }
}