Skip to main content

cubecl_wgpu/compute/
storage.rs

1use cubecl_core::server::IoError;
2use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
3use hashbrown::HashMap;
4use std::num::NonZeroU64;
5use wgpu::BufferUsages;
6
/// Smallest size, in bytes, of any buffer this storage creates.
///
/// The WebGPU spec requires buffer sizes > 0, and shaders declare typed arrays (e.g.
/// `array<vec4<f32>>`) that impose a minimum binding size. Sizing this after `vec4<f64>`
/// (four `f64` lanes, 32 bytes) covers the largest possible binding type.
const MIN_BUFFER_SIZE: u64 = std::mem::size_of::<[f64; 4]>() as u64;
11
12/// Buffer storage for wgpu.
13pub struct WgpuStorage {
14    memory: HashMap<StorageId, wgpu::Buffer>,
15    device: wgpu::Device,
16    buffer_usages: BufferUsages,
17    mem_alignment: usize,
18}
19
20impl core::fmt::Debug for WgpuStorage {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str())
23    }
24}
25
26/// The memory resource that can be allocated for wgpu.
27#[derive(new, Debug)]
28pub struct WgpuResource {
29    /// The wgpu buffer.
30    pub buffer: wgpu::Buffer,
31    /// The buffer offset.
32    pub offset: u64,
33    /// The size of the resource.
34    ///
35    /// # Notes
36    ///
37    /// The result considers the offset.
38    pub size: u64,
39}
40
41impl WgpuResource {
42    /// Return the binding view of the buffer.
43    pub fn as_wgpu_bind_resource(&self) -> wgpu::BindingResource<'_> {
44        // wgpu enforces 4-byte alignment for buffer binding sizes per the WebGPU spec.
45        // - https://github.com/gfx-rs/wgpu/pull/8041
46        //
47        // This padding is safe because:
48        // 1. In checked mode, bounds checks prevent reading beyond the logical size.
49        // 2. In unchecked mode, OOB access is already undefined behavior.
50        //
51        // For zero-sized resources, pass None (use rest of buffer from offset).
52        // The allocator guarantees the buffer is at least MIN_BUFFER_SIZE bytes.
53        let size = NonZeroU64::new(self.size.next_multiple_of(4));
54
55        let binding = wgpu::BufferBinding {
56            buffer: &self.buffer,
57            offset: self.offset,
58            size,
59        };
60        wgpu::BindingResource::Buffer(binding)
61    }
62}
63
64/// Keeps actual wgpu buffer references in a hashmap with ids as key.
65impl WgpuStorage {
66    /// Create a new storage on the given [device](wgpu::Device).
67    pub fn new(mem_alignment: usize, device: wgpu::Device, usages: BufferUsages) -> Self {
68        Self {
69            memory: HashMap::new(),
70            device,
71            buffer_usages: usages,
72            mem_alignment,
73        }
74    }
75}
76
77impl ComputeStorage for WgpuStorage {
78    type Resource = WgpuResource;
79
80    fn alignment(&self) -> usize {
81        self.mem_alignment
82    }
83
84    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
85        let buffer = self.memory.get(&handle.id).unwrap();
86        WgpuResource::new(buffer.clone(), handle.offset(), handle.size())
87    }
88
89    #[cfg_attr(
90        feature = "tracing",
91        tracing::instrument(level = "trace", skip(self, size))
92    )]
93    fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
94        let id = StorageId::new();
95
96        let alloc_size = size.max(MIN_BUFFER_SIZE);
97
98        let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
99            label: None,
100            size: alloc_size,
101            usage: self.buffer_usages,
102            mapped_at_creation: false,
103        });
104
105        self.memory.insert(id, buffer);
106        Ok(StorageHandle::new(
107            id,
108            StorageUtilization { offset: 0, size },
109        ))
110    }
111
112    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
113    fn dealloc(&mut self, id: StorageId) {
114        self.memory.remove(&id);
115    }
116
117    fn flush(&mut self) {
118        // We don't wait for dealloc
119    }
120}