// cubecl_wgpu/compute/storage.rs
use cubecl_core::server::IoError;
use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
use hashbrown::HashMap;
use std::num::NonZeroU64;
use wgpu::BufferUsages;

/// Lower bound, in bytes, on the size of every buffer created by `WgpuStorage::alloc`;
/// smaller requests are clamped up to this value.
const MIN_BUFFER_SIZE: u64 = 32;

12pub struct WgpuStorage {
14 memory: HashMap<StorageId, wgpu::Buffer>,
15 device: wgpu::Device,
16 buffer_usages: BufferUsages,
17 mem_alignment: usize,
18}
19
20impl core::fmt::Debug for WgpuStorage {
21 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22 f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str())
23 }
24}
25
26#[derive(new, Debug)]
28pub struct WgpuResource {
29 pub buffer: wgpu::Buffer,
31 pub offset: u64,
33 pub size: u64,
39}
40
41impl WgpuResource {
42 pub fn as_wgpu_bind_resource(&self) -> wgpu::BindingResource<'_> {
44 let size = NonZeroU64::new(self.size.next_multiple_of(4));
54
55 let binding = wgpu::BufferBinding {
56 buffer: &self.buffer,
57 offset: self.offset,
58 size,
59 };
60 wgpu::BindingResource::Buffer(binding)
61 }
62}
63
64impl WgpuStorage {
66 pub fn new(mem_alignment: usize, device: wgpu::Device, usages: BufferUsages) -> Self {
68 Self {
69 memory: HashMap::new(),
70 device,
71 buffer_usages: usages,
72 mem_alignment,
73 }
74 }
75}
76
77impl ComputeStorage for WgpuStorage {
78 type Resource = WgpuResource;
79
80 fn alignment(&self) -> usize {
81 self.mem_alignment
82 }
83
84 fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
85 let buffer = self.memory.get(&handle.id).unwrap();
86 WgpuResource::new(buffer.clone(), handle.offset(), handle.size())
87 }
88
89 #[cfg_attr(
90 feature = "tracing",
91 tracing::instrument(level = "trace", skip(self, size))
92 )]
93 fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
94 let id = StorageId::new();
95
96 let alloc_size = size.max(MIN_BUFFER_SIZE);
97
98 let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
99 label: None,
100 size: alloc_size,
101 usage: self.buffer_usages,
102 mapped_at_creation: false,
103 });
104
105 self.memory.insert(id, buffer);
106 Ok(StorageHandle::new(
107 id,
108 StorageUtilization { offset: 0, size },
109 ))
110 }
111
112 #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
113 fn dealloc(&mut self, id: StorageId) {
114 self.memory.remove(&id);
115 }
116
117 fn flush(&mut self) {
118 }
120}