// cubecl_wgpu/compute/storage.rs

1use cubecl_core::server::IoError;
2use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
3use hashbrown::HashMap;
4use std::num::NonZeroU64;
5use wgpu::BufferUsages;
6
7/// Buffer storage for wgpu.
8pub struct WgpuStorage {
9    memory: HashMap<StorageId, wgpu::Buffer>,
10    device: wgpu::Device,
11    buffer_usages: BufferUsages,
12    mem_alignment: usize,
13}
14
15impl core::fmt::Debug for WgpuStorage {
16    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17        f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str())
18    }
19}
20
21/// The memory resource that can be allocated for wgpu.
22#[derive(new, Debug)]
23pub struct WgpuResource {
24    /// The wgpu buffer.
25    pub buffer: wgpu::Buffer,
26    /// The buffer offset.
27    pub offset: u64,
28    /// The size of the resource.
29    ///
30    /// # Notes
31    ///
32    /// The result considers the offset.
33    pub size: u64,
34}
35
36impl WgpuResource {
37    /// Return the binding view of the buffer.
38    pub fn as_wgpu_bind_resource(&self) -> wgpu::BindingResource<'_> {
39        // wgpu enforces 4-byte alignment for buffer binding sizes per the WebGPU spec.
40        // - https://github.com/gfx-rs/wgpu/pull/8041
41        //
42        // This padding is safe because:
43        // 1. In checked mode, bounds checks prevent reading beyond the logical size.
44        // 2. In unchecked mode, OOB access is already undefined behavior.
45        let size = self.size.next_multiple_of(4);
46
47        let binding = wgpu::BufferBinding {
48            buffer: &self.buffer,
49            offset: self.offset,
50            size: Some(NonZeroU64::new(size).expect("0 size resources are not yet supported.")),
51        };
52        wgpu::BindingResource::Buffer(binding)
53    }
54}
55
56/// Keeps actual wgpu buffer references in a hashmap with ids as key.
57impl WgpuStorage {
58    /// Create a new storage on the given [device](wgpu::Device).
59    pub fn new(mem_alignment: usize, device: wgpu::Device, usages: BufferUsages) -> Self {
60        Self {
61            memory: HashMap::new(),
62            device,
63            buffer_usages: usages,
64            mem_alignment,
65        }
66    }
67}
68
69impl ComputeStorage for WgpuStorage {
70    type Resource = WgpuResource;
71
72    fn alignment(&self) -> usize {
73        self.mem_alignment
74    }
75
76    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
77        let buffer = self.memory.get(&handle.id).unwrap();
78        WgpuResource::new(buffer.clone(), handle.offset(), handle.size())
79    }
80
81    #[cfg_attr(
82        feature = "tracing",
83        tracing::instrument(level = "trace", skip(self, size))
84    )]
85    fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
86        let id = StorageId::new();
87
88        let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
89            label: None,
90            size,
91            usage: self.buffer_usages,
92            mapped_at_creation: false,
93        });
94
95        self.memory.insert(id, buffer);
96        Ok(StorageHandle::new(
97            id,
98            StorageUtilization { offset: 0, size },
99        ))
100    }
101
102    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
103    fn dealloc(&mut self, id: StorageId) {
104        self.memory.remove(&id);
105    }
106}