1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
use core::mem::MaybeUninit;
use cubecl_common::bytes::{AllocationController, AllocationProperty};
use cubecl_runtime::memory_management::ManagedMemoryBinding;
use wgpu::BufferView;
/// Controller for managing wgpu staging buffers managed by a memory pool.
///
/// Bundles a mapped view of a staging buffer with the buffer itself and the
/// memory binding it was created with, and exposes the mapped bytes through
/// the `AllocationController` implementation.
pub struct WgpuAllocController {
/// View over the mapped range of `buffer`. Set to `Some` on construction and
/// cleared in `Drop` before the buffer is unmapped.
view: Option<BufferView>,
/// The wgpu buffer the view was obtained from; unmapped when the controller drops.
buffer: wgpu::Buffer,
/// Binding stored alongside the buffer. It is never read in this file; it is
/// only held so that it is dropped together with the controller.
_binding: ManagedMemoryBinding,
}
impl Drop for WgpuAllocController {
    /// Releases the mapped view, then unmaps the staging buffer.
    fn drop(&mut self) {
        // The view must be gone before `unmap` is called. The implicit field
        // drops only run after this body, so clear the view explicitly here.
        self.view = None;

        // Unmap the buffer; the binding is released right after, when the
        // fields are dropped, so that the same buffer can be used again.
        self.buffer.unmap();
    }
}
impl AllocationController for WgpuAllocController {
    /// Alignment of the allocation: wgpu's `COPY_BUFFER_ALIGNMENT`, as `usize`.
    fn alloc_align(&self) -> usize {
        let align = wgpu::COPY_BUFFER_ALIGNMENT;
        align as usize
    }

    /// Reports this allocation as pinned memory.
    fn property(&self) -> AllocationProperty {
        AllocationProperty::Pinned
    }

    /// Returns the mapped bytes as a mutable slice of `MaybeUninit<u8>`.
    ///
    /// # Panics
    ///
    /// Panics if the view is absent (it is only cleared while dropping).
    unsafe fn memory_mut(&mut self) -> &mut [MaybeUninit<u8>] {
        let view = self.view.as_ref().unwrap();
        let mapped: &[u8] = view;
        let len = mapped.len();
        // NOTE(review): the mutable pointer is derived from a shared `&[u8]`
        // obtained through an immutable view; confirm this matches what wgpu
        // permits for the mapped range.
        let start = mapped.as_ptr().cast_mut().cast::<MaybeUninit<u8>>();
        // SAFETY:
        // - `MaybeUninit<u8>` has the same layout as `u8`, so pointer and
        //   length describe the same region as `mapped`.
        // - The caller promises not to write uninitialized values.
        unsafe { std::slice::from_raw_parts_mut(start, len) }
    }

    /// Returns the mapped bytes as a shared slice of `MaybeUninit<u8>`.
    ///
    /// # Panics
    ///
    /// Panics if the view is absent (it is only cleared while dropping).
    fn memory(&self) -> &[std::mem::MaybeUninit<u8>] {
        let view = self.view.as_ref().unwrap();
        let mapped: &[u8] = view;
        let start = mapped.as_ptr().cast::<MaybeUninit<u8>>();
        // SAFETY:
        // - `MaybeUninit<u8>` has the same layout as `u8`, so pointer and
        //   length describe the same region as `mapped`.
        unsafe { std::slice::from_raw_parts(start, mapped.len()) }
    }
}
impl WgpuAllocController {
    /// Builds an allocation controller around a managed wgpu staging buffer.
    ///
    /// The whole buffer is viewed via `get_mapped_range(..)` and the resulting
    /// view is stored next to the buffer and its binding.
    ///
    /// # Arguments
    ///
    /// * `binding` - The memory binding for the managed buffer.
    /// * `buffer` - The wgpu buffer.
    ///
    /// # Returns
    ///
    /// The controller.
    pub fn init(binding: ManagedMemoryBinding, buffer: wgpu::Buffer) -> Self {
        // An immutable view is required as of wgpu v29: the mutable one cannot
        // be dereferenced as a slice. The difference only matters for wgpu's
        // internal overlap checks, which should be fine because the entire
        // buffer is mapped anyway.
        let view = Some(buffer.get_mapped_range(..));

        Self {
            view,
            buffer,
            _binding: binding,
        }
    }
}