cubecl_wgpu/compute/
storage.rs

1use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
2use hashbrown::HashMap;
3use std::{num::NonZeroU64, sync::Arc};
4
5/// Buffer storage for wgpu.
6pub struct WgpuStorage {
7    memory: HashMap<StorageId, Arc<wgpu::Buffer>>,
8    deallocations: Vec<StorageId>,
9    device: Arc<wgpu::Device>,
10}
11
12impl core::fmt::Debug for WgpuStorage {
13    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14        f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str())
15    }
16}
17
18/// The memory resource that can be allocated for wgpu.
19#[derive(new)]
20pub struct WgpuResource {
21    /// The wgpu buffer.
22    pub buffer: Arc<wgpu::Buffer>,
23
24    offset: u64,
25    size: u64,
26}
27
28impl WgpuResource {
29    /// Return the binding view of the buffer.
30    pub fn as_wgpu_bind_resource(&self) -> wgpu::BindingResource {
31        let binding = wgpu::BufferBinding {
32            buffer: &self.buffer,
33            offset: self.offset,
34            size: Some(
35                NonZeroU64::new(self.size).expect("0 size resources are not yet supported."),
36            ),
37        };
38        wgpu::BindingResource::Buffer(binding)
39    }
40
41    /// Return the buffer size.
42    pub fn size(&self) -> u64 {
43        self.size
44    }
45
46    /// Return the buffer offset.
47    pub fn offset(&self) -> u64 {
48        self.offset
49    }
50}
51
52/// Keeps actual wgpu buffer references in a hashmap with ids as key.
53impl WgpuStorage {
54    /// Create a new storage on the given [device](wgpu::Device).
55    pub fn new(device: Arc<wgpu::Device>) -> Self {
56        Self {
57            memory: HashMap::new(),
58            deallocations: Vec::new(),
59            device,
60        }
61    }
62
63    /// Actually deallocates buffers tagged to be deallocated.
64    pub fn perform_deallocations(&mut self) {
65        for id in self.deallocations.drain(..) {
66            if let Some(buffer) = self.memory.remove(&id) {
67                buffer.destroy()
68            }
69        }
70    }
71}
72
73impl ComputeStorage for WgpuStorage {
74    type Resource = WgpuResource;
75
76    // 32 bytes is enough to handle a double4 worth of alignment.
77    // See: https://github.com/gfx-rs/wgpu/issues/3508
78    // NB: cudamalloc and co. actually align to _256_ bytes. Worth
79    // trying this in the future to see if it reduces memory coalescing.
80    const ALIGNMENT: u64 = 32;
81
82    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
83        let buffer = self.memory.get(&handle.id).unwrap();
84        WgpuResource::new(buffer.clone(), handle.offset(), handle.size())
85    }
86
87    fn alloc(&mut self, size: u64) -> StorageHandle {
88        let id = StorageId::new();
89        let buffer = Arc::new(self.device.create_buffer(&wgpu::BufferDescriptor {
90            label: None,
91            size,
92            usage: wgpu::BufferUsages::COPY_DST
93                | wgpu::BufferUsages::STORAGE
94                | wgpu::BufferUsages::COPY_SRC
95                | wgpu::BufferUsages::INDIRECT,
96            mapped_at_creation: false,
97        }));
98
99        self.memory.insert(id, buffer);
100        StorageHandle::new(id, StorageUtilization { offset: 0, size })
101    }
102
103    fn dealloc(&mut self, id: StorageId) {
104        self.deallocations.push(id);
105    }
106}