cubecl_wgpu/compute/
storage.rs1use cubecl_core::server::IoError;
2use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
3use hashbrown::HashMap;
4use std::num::NonZeroU64;
5use wgpu::BufferUsages;
6
7pub struct WgpuStorage {
9 memory: HashMap<StorageId, wgpu::Buffer>,
10 device: wgpu::Device,
11 buffer_usages: BufferUsages,
12 mem_alignment: usize,
13}
14
15impl core::fmt::Debug for WgpuStorage {
16 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17 f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str())
18 }
19}
20
21#[derive(new, Debug)]
23pub struct WgpuResource {
24 pub buffer: wgpu::Buffer,
26 pub offset: u64,
28 pub size: u64,
34}
35
36impl WgpuResource {
37 pub fn as_wgpu_bind_resource(&self) -> wgpu::BindingResource<'_> {
39 let binding = wgpu::BufferBinding {
40 buffer: &self.buffer,
41 offset: self.offset,
42 size: Some(
43 NonZeroU64::new(self.size).expect("0 size resources are not yet supported."),
44 ),
45 };
46 wgpu::BindingResource::Buffer(binding)
47 }
48}
49
50impl WgpuStorage {
52 pub fn new(mem_alignment: usize, device: wgpu::Device, usages: BufferUsages) -> Self {
54 Self {
55 memory: HashMap::new(),
56 device,
57 buffer_usages: usages,
58 mem_alignment,
59 }
60 }
61}
62
63impl ComputeStorage for WgpuStorage {
64 type Resource = WgpuResource;
65
66 fn alignment(&self) -> usize {
67 self.mem_alignment
68 }
69
70 fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
71 let buffer = self.memory.get(&handle.id).unwrap();
72 WgpuResource::new(buffer.clone(), handle.offset(), handle.size())
73 }
74
75 fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
76 let id = StorageId::new();
77
78 let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
79 label: None,
80 size,
81 usage: self.buffer_usages,
82 mapped_at_creation: false,
83 });
84
85 self.memory.insert(id, buffer);
86 Ok(StorageHandle::new(
87 id,
88 StorageUtilization { offset: 0, size },
89 ))
90 }
91
92 fn dealloc(&mut self, id: StorageId) {
93 self.memory.remove(&id);
94 }
95}