cubecl_wgpu/compute/
storage.rs1use cubecl_core::server::IoError;
2use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
3use hashbrown::HashMap;
4use std::num::NonZeroU64;
5use wgpu::BufferUsages;
6
7pub struct WgpuStorage {
9 memory: HashMap<StorageId, wgpu::Buffer>,
10 device: wgpu::Device,
11 buffer_usages: BufferUsages,
12 mem_alignment: usize,
13}
14
15impl core::fmt::Debug for WgpuStorage {
16 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17 f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str())
18 }
19}
20
21#[derive(new, Debug)]
23pub struct WgpuResource {
24 pub buffer: wgpu::Buffer,
26 pub offset: u64,
28 pub size: u64,
34}
35
36impl WgpuResource {
37 pub fn as_wgpu_bind_resource(&self) -> wgpu::BindingResource<'_> {
39 let size = self.size.next_multiple_of(4);
46
47 let binding = wgpu::BufferBinding {
48 buffer: &self.buffer,
49 offset: self.offset,
50 size: Some(NonZeroU64::new(size).expect("0 size resources are not yet supported.")),
51 };
52 wgpu::BindingResource::Buffer(binding)
53 }
54}
55
56impl WgpuStorage {
58 pub fn new(mem_alignment: usize, device: wgpu::Device, usages: BufferUsages) -> Self {
60 Self {
61 memory: HashMap::new(),
62 device,
63 buffer_usages: usages,
64 mem_alignment,
65 }
66 }
67}
68
69impl ComputeStorage for WgpuStorage {
70 type Resource = WgpuResource;
71
72 fn alignment(&self) -> usize {
73 self.mem_alignment
74 }
75
76 fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
77 let buffer = self.memory.get(&handle.id).unwrap();
78 WgpuResource::new(buffer.clone(), handle.offset(), handle.size())
79 }
80
81 #[cfg_attr(
82 feature = "tracing",
83 tracing::instrument(level = "trace", skip(self, size))
84 )]
85 fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
86 let id = StorageId::new();
87
88 let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
89 label: None,
90 size,
91 usage: self.buffer_usages,
92 mapped_at_creation: false,
93 });
94
95 self.memory.insert(id, buffer);
96 Ok(StorageHandle::new(
97 id,
98 StorageUtilization { offset: 0, size },
99 ))
100 }
101
102 #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
103 fn dealloc(&mut self, id: StorageId) {
104 self.memory.remove(&id);
105 }
106}