// cubecl_wgpu/compute/storage.rs
use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
use hashbrown::HashMap;
use std::{num::NonZeroU64, sync::Arc};
4
5pub struct WgpuStorage {
7 memory: HashMap<StorageId, Arc<wgpu::Buffer>>,
8 deallocations: Vec<StorageId>,
9 device: Arc<wgpu::Device>,
10}
11
12impl core::fmt::Debug for WgpuStorage {
13 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14 f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str())
15 }
16}
17
18#[derive(new)]
20pub struct WgpuResource {
21 pub buffer: Arc<wgpu::Buffer>,
23
24 offset: u64,
25 size: u64,
26}
27
28impl WgpuResource {
29 pub fn as_wgpu_bind_resource(&self) -> wgpu::BindingResource {
31 let binding = wgpu::BufferBinding {
32 buffer: &self.buffer,
33 offset: self.offset,
34 size: Some(
35 NonZeroU64::new(self.size).expect("0 size resources are not yet supported."),
36 ),
37 };
38 wgpu::BindingResource::Buffer(binding)
39 }
40
41 pub fn size(&self) -> u64 {
43 self.size
44 }
45
46 pub fn offset(&self) -> u64 {
48 self.offset
49 }
50}
51
52impl WgpuStorage {
54 pub fn new(device: Arc<wgpu::Device>) -> Self {
56 Self {
57 memory: HashMap::new(),
58 deallocations: Vec::new(),
59 device,
60 }
61 }
62
63 pub fn perform_deallocations(&mut self) {
65 for id in self.deallocations.drain(..) {
66 if let Some(buffer) = self.memory.remove(&id) {
67 buffer.destroy()
68 }
69 }
70 }
71}
72
73impl ComputeStorage for WgpuStorage {
74 type Resource = WgpuResource;
75
76 const ALIGNMENT: u64 = 32;
81
82 fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
83 let buffer = self.memory.get(&handle.id).unwrap();
84 WgpuResource::new(buffer.clone(), handle.offset(), handle.size())
85 }
86
87 fn alloc(&mut self, size: u64) -> StorageHandle {
88 let id = StorageId::new();
89 let buffer = Arc::new(self.device.create_buffer(&wgpu::BufferDescriptor {
90 label: None,
91 size,
92 usage: wgpu::BufferUsages::COPY_DST
93 | wgpu::BufferUsages::STORAGE
94 | wgpu::BufferUsages::COPY_SRC
95 | wgpu::BufferUsages::INDIRECT,
96 mapped_at_creation: false,
97 }));
98
99 self.memory.insert(id, buffer);
100 StorageHandle::new(id, StorageUtilization { offset: 0, size })
101 }
102
103 fn dealloc(&mut self, id: StorageId) {
104 self.deallocations.push(id);
105 }
106}