cubecl_wgpu/compute/
storage.rs1use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
2use hashbrown::HashMap;
3use std::num::NonZeroU64;
4use wgpu::BufferUsages;
5
6pub struct WgpuStorage {
8 memory: HashMap<StorageId, wgpu::Buffer>,
9 device: wgpu::Device,
10 buffer_usages: BufferUsages,
11 mem_alignment: usize,
12}
13
14impl core::fmt::Debug for WgpuStorage {
15 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16 f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str())
17 }
18}
19
20#[derive(new, Debug)]
22pub struct WgpuResource {
23 buffer: wgpu::Buffer,
25 offset: u64,
26 size: u64,
27}
28
29impl WgpuResource {
30 pub fn as_wgpu_bind_resource(&self) -> wgpu::BindingResource {
32 let binding = wgpu::BufferBinding {
33 buffer: &self.buffer,
34 offset: self.offset,
35 size: Some(
36 NonZeroU64::new(self.size).expect("0 size resources are not yet supported."),
37 ),
38 };
39 wgpu::BindingResource::Buffer(binding)
40 }
41
42 pub fn buffer(&self) -> &wgpu::Buffer {
43 &self.buffer
44 }
45
46 pub fn size(&self) -> u64 {
48 self.size
49 }
50
51 pub fn offset(&self) -> u64 {
53 self.offset
54 }
55}
56
57impl WgpuStorage {
59 pub fn new(mem_alignment: usize, device: wgpu::Device, usages: BufferUsages) -> Self {
61 Self {
62 memory: HashMap::new(),
63 device,
64 buffer_usages: usages,
65 mem_alignment,
66 }
67 }
68}
69
70impl ComputeStorage for WgpuStorage {
71 type Resource = WgpuResource;
72
73 fn alignment(&self) -> usize {
74 self.mem_alignment
75 }
76
77 fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
78 let buffer = self.memory.get(&handle.id).unwrap();
79 WgpuResource::new(buffer.clone(), handle.offset(), handle.size())
80 }
81
82 fn alloc(&mut self, size: u64) -> StorageHandle {
83 let id = StorageId::new();
84
85 let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
86 label: None,
87 size,
88 usage: self.buffer_usages,
89 mapped_at_creation: false,
90 });
91
92 self.memory.insert(id, buffer);
93 StorageHandle::new(id, StorageUtilization { offset: 0, size })
94 }
95
96 fn dealloc(&mut self, id: StorageId) {
97 self.memory.remove(&id);
98 }
99}