// cubecl_runtime/storage/bytes_cpu.rs
use super::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
use alloc::alloc::{Layout, alloc, dealloc, handle_alloc_error};
use hashbrown::HashMap;

5#[derive(Default)]
7pub struct BytesStorage {
8 memory: HashMap<StorageId, AllocatedBytes>,
9}
10
11impl core::fmt::Debug for BytesStorage {
12 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
13 f.write_str("BytesStorage")
14 }
15}
16
17unsafe impl Send for BytesStorage {}
19unsafe impl Send for BytesResource {}
20
21pub struct BytesResource {
23 ptr: *mut u8,
24 utilization: StorageUtilization,
25}
26
/// A live allocation: its base pointer together with the exact [`Layout`] it was allocated
/// with, which must be handed back unchanged when the block is freed.
struct AllocatedBytes {
    ptr: *mut u8,
    layout: Layout,
}

33impl BytesResource {
34 fn get_exact_location_and_length(&self) -> (*mut u8, usize) {
35 unsafe {
36 (
37 self.ptr.add(self.utilization.offset as usize),
38 self.utilization.size as usize,
39 )
40 }
41 }
42
43 pub fn write<'a>(&self) -> &'a mut [u8] {
45 let (ptr, len) = self.get_exact_location_and_length();
46
47 unsafe { core::slice::from_raw_parts_mut(ptr, len) }
48 }
49
50 pub fn read<'a>(&self) -> &'a [u8] {
52 let (ptr, len) = self.get_exact_location_and_length();
53
54 unsafe { core::slice::from_raw_parts(ptr, len) }
55 }
56}
57
58impl ComputeStorage for BytesStorage {
59 type Resource = BytesResource;
60
61 fn alignment(&self) -> usize {
62 4
63 }
64
65 fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
66 let allocated_bytes = self.memory.get(&handle.id).unwrap();
67
68 BytesResource {
69 ptr: allocated_bytes.ptr,
70 utilization: handle.utilization.clone(),
71 }
72 }
73
74 fn alloc(&mut self, size: u64) -> StorageHandle {
75 let id = StorageId::new();
76 let handle = StorageHandle {
77 id,
78 utilization: StorageUtilization { offset: 0, size },
79 };
80
81 unsafe {
82 let layout = Layout::array::<u8>(size as usize).unwrap();
83 let ptr = alloc(layout);
84 let memory = AllocatedBytes { ptr, layout };
85
86 self.memory.insert(id, memory);
87 }
88
89 handle
90 }
91
92 fn dealloc(&mut self, id: StorageId) {
93 if let Some(memory) = self.memory.remove(&id) {
94 unsafe {
95 dealloc(memory.ptr, memory.layout);
96 }
97 }
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_can_alloc_and_dealloc() {
        let mut storage = BytesStorage::default();

        let handle = storage.alloc(64);
        assert_eq!(handle.size(), 64);

        storage.dealloc(handle.id);
    }

    #[test]
    fn test_slices() {
        let mut storage = BytesStorage::default();
        let whole = storage.alloc(64);
        // An 8-byte window starting 24 bytes into the same allocation.
        let window = StorageHandle::new(
            whole.id,
            StorageUtilization {
                offset: 24,
                size: 8,
            },
        );

        // Fill the whole allocation with each byte's own index: 0, 1, 2, ..., 63.
        for (index, byte) in storage.get(&whole).write().iter_mut().enumerate() {
            *byte = index as u8;
        }

        let bytes = storage.get(&window).read().to_vec();
        storage.dealloc(whole.id);

        assert_eq!(bytes, &[24, 25, 26, 27, 28, 29, 30, 31]);
    }
}