cubecl_wgpu/compute/
storage.rs

1use cubecl_core::server::IoError;
2use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
3use hashbrown::HashMap;
4use std::num::NonZeroU64;
5use wgpu::BufferUsages;
6
7/// Buffer storage for wgpu.
8pub struct WgpuStorage {
9    memory: HashMap<StorageId, wgpu::Buffer>,
10    device: wgpu::Device,
11    buffer_usages: BufferUsages,
12    mem_alignment: usize,
13}
14
15impl core::fmt::Debug for WgpuStorage {
16    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17        f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str())
18    }
19}
20
21/// The memory resource that can be allocated for wgpu.
22#[derive(new, Debug)]
23pub struct WgpuResource {
24    /// The wgpu buffer.
25    pub buffer: wgpu::Buffer,
26    /// The buffer offset.
27    pub offset: u64,
28    /// The size of the resource.
29    ///
30    /// # Notes
31    ///
32    /// The result considers the offset.
33    pub size: u64,
34}
35
36impl WgpuResource {
37    /// Return the binding view of the buffer.
38    pub fn as_wgpu_bind_resource(&self) -> wgpu::BindingResource<'_> {
39        let binding = wgpu::BufferBinding {
40            buffer: &self.buffer,
41            offset: self.offset,
42            size: Some(
43                NonZeroU64::new(self.size).expect("0 size resources are not yet supported."),
44            ),
45        };
46        wgpu::BindingResource::Buffer(binding)
47    }
48}
49
50/// Keeps actual wgpu buffer references in a hashmap with ids as key.
51impl WgpuStorage {
52    /// Create a new storage on the given [device](wgpu::Device).
53    pub fn new(mem_alignment: usize, device: wgpu::Device, usages: BufferUsages) -> Self {
54        Self {
55            memory: HashMap::new(),
56            device,
57            buffer_usages: usages,
58            mem_alignment,
59        }
60    }
61}
62
63impl ComputeStorage for WgpuStorage {
64    type Resource = WgpuResource;
65
66    fn alignment(&self) -> usize {
67        self.mem_alignment
68    }
69
70    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
71        let buffer = self.memory.get(&handle.id).unwrap();
72        WgpuResource::new(buffer.clone(), handle.offset(), handle.size())
73    }
74
75    fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
76        let id = StorageId::new();
77
78        let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
79            label: None,
80            size,
81            usage: self.buffer_usages,
82            mapped_at_creation: false,
83        });
84
85        self.memory.insert(id, buffer);
86        Ok(StorageHandle::new(
87            id,
88            StorageUtilization { offset: 0, size },
89        ))
90    }
91
92    fn dealloc(&mut self, id: StorageId) {
93        self.memory.remove(&id);
94    }
95}