cubecl_wgpu/compute/
storage.rs

1use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
2use hashbrown::HashMap;
3use std::num::NonZeroU64;
4use wgpu::BufferUsages;
5
/// Buffer storage for wgpu.
///
/// Keeps every allocated [`wgpu::Buffer`] in a map keyed by its [`StorageId`].
pub struct WgpuStorage {
    /// Allocated buffers, keyed by the id handed out when they were allocated.
    memory: HashMap<StorageId, wgpu::Buffer>,
    /// Device used to create new buffers.
    device: wgpu::Device,
    /// Usage flags applied to every buffer created by this storage.
    buffer_usages: BufferUsages,
    /// Alignment reported through `ComputeStorage::alignment`.
    mem_alignment: usize,
}
13
14impl core::fmt::Debug for WgpuStorage {
15    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16        f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str())
17    }
18}
19
/// The memory resource that can be allocated for wgpu.
///
/// Describes a region of a [`wgpu::Buffer`]: `size` bytes starting at `offset`.
// NOTE: field order defines the argument order of the derived `new` constructor
// (`WgpuResource::new(buffer, offset, size)`), so do not reorder the fields.
#[derive(new, Debug)]
pub struct WgpuResource {
    /// The wgpu buffer.
    buffer: wgpu::Buffer,
    /// Offset of the region within `buffer`, in bytes.
    offset: u64,
    /// Size of the region, in bytes.
    size: u64,
}
28
29impl WgpuResource {
30    /// Return the binding view of the buffer.
31    pub fn as_wgpu_bind_resource(&self) -> wgpu::BindingResource {
32        let binding = wgpu::BufferBinding {
33            buffer: &self.buffer,
34            offset: self.offset,
35            size: Some(
36                NonZeroU64::new(self.size).expect("0 size resources are not yet supported."),
37            ),
38        };
39        wgpu::BindingResource::Buffer(binding)
40    }
41
42    pub fn buffer(&self) -> &wgpu::Buffer {
43        &self.buffer
44    }
45
46    /// Return the buffer size.
47    pub fn size(&self) -> u64 {
48        self.size
49    }
50
51    /// Return the buffer offset.
52    pub fn offset(&self) -> u64 {
53        self.offset
54    }
55}
56
57/// Keeps actual wgpu buffer references in a hashmap with ids as key.
58impl WgpuStorage {
59    /// Create a new storage on the given [device](wgpu::Device).
60    pub fn new(mem_alignment: usize, device: wgpu::Device, usages: BufferUsages) -> Self {
61        Self {
62            memory: HashMap::new(),
63            device,
64            buffer_usages: usages,
65            mem_alignment,
66        }
67    }
68}
69
70impl ComputeStorage for WgpuStorage {
71    type Resource = WgpuResource;
72
73    fn alignment(&self) -> usize {
74        self.mem_alignment
75    }
76
77    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
78        let buffer = self.memory.get(&handle.id).unwrap();
79        WgpuResource::new(buffer.clone(), handle.offset(), handle.size())
80    }
81
82    fn alloc(&mut self, size: u64) -> StorageHandle {
83        let id = StorageId::new();
84
85        let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
86            label: None,
87            size,
88            usage: self.buffer_usages,
89            mapped_at_creation: false,
90        });
91
92        self.memory.insert(id, buffer);
93        StorageHandle::new(id, StorageUtilization { offset: 0, size })
94    }
95
96    fn dealloc(&mut self, id: StorageId) {
97        self.memory.remove(&id);
98    }
99}