// cubecl_runtime/storage/bytes_cpu.rs
use crate::server::IoError;
2
3use super::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
4use alloc::alloc::{Layout, alloc_zeroed, dealloc};
5use cubecl_common::backtrace::BackTrace;
6use hashbrown::HashMap;
7
/// CPU storage backend that keeps every allocation as a raw heap byte
/// buffer, indexed by its [`StorageId`].
#[derive(Default)]
pub struct BytesStorage {
    // Map from storage id to the owned allocation (base pointer + layout).
    memory: HashMap<StorageId, AllocatedBytes>,
}
13
14impl core::fmt::Debug for BytesStorage {
15 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
16 f.write_str("BytesStorage")
17 }
18}
19
// SAFETY: `BytesStorage` exclusively owns the allocations behind its raw
// pointers (they are created in `alloc` and only freed in `dealloc`), so
// moving the storage to another thread is sound. NOTE(review): `Sync` is
// deliberately not implemented — confirm callers never share `&BytesStorage`
// across threads.
unsafe impl Send for BytesStorage {}
// SAFETY: `BytesResource` is a pointer/size view into an allocation owned by
// a `BytesStorage`; sending it to another thread does not itself create
// aliasing. Presumably the memory-management layer guarantees non-overlapping
// use of `read`/`write` views — TODO confirm.
unsafe impl Send for BytesResource {}
23
/// Resource handed out by [`BytesStorage`]: the base pointer of an
/// allocation plus the offset/size window ("utilization") this resource
/// exposes within it.
#[derive(Debug)]
pub struct BytesResource {
    // Base pointer of the whole allocation (offset not yet applied).
    ptr: *mut u8,
    // Byte offset and length of the region this resource covers.
    utilization: StorageUtilization,
}
30
/// Owned heap allocation tracked by [`BytesStorage`]: the pointer returned
/// by the allocator together with the layout needed to deallocate it.
struct AllocatedBytes {
    ptr: *mut u8,
    // Layout used at allocation time; a zero-size layout marks the
    // zero-byte (dangling-pointer) case that must never be deallocated.
    layout: Layout,
}
36
37impl BytesResource {
38 pub fn get_write_ptr_and_length(&self) -> (*mut u8, usize) {
40 (
41 unsafe { self.ptr.add(self.utilization.offset as usize) },
44 self.utilization.size as usize,
45 )
46 }
47
48 pub fn write<'a>(&mut self) -> &'a mut [u8] {
55 let (ptr, len) = self.get_write_ptr_and_length();
56
57 unsafe { core::slice::from_raw_parts_mut(ptr, len) }
65 }
66
67 pub fn read<'a>(&self) -> &'a [u8] {
71 let (ptr, len) = self.get_write_ptr_and_length();
72
73 unsafe { core::slice::from_raw_parts(ptr, len) }
78 }
79}
80
81impl ComputeStorage for BytesStorage {
82 type Resource = BytesResource;
83
84 fn alignment(&self) -> usize {
85 4
86 }
87
88 fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
89 let allocated_bytes = self.memory.get(&handle.id).unwrap();
90
91 BytesResource {
92 ptr: allocated_bytes.ptr,
93 utilization: handle.utilization.clone(),
94 }
95 }
96
97 #[cfg_attr(
98 feature = "tracing",
99 tracing::instrument(level = "trace", skip(self, size))
100 )]
101 fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
102 let id = StorageId::new();
103 let handle = StorageHandle {
104 id,
105 utilization: StorageUtilization { offset: 0, size },
106 };
107
108 if size == 0 {
109 let memory = AllocatedBytes {
111 ptr: core::ptr::NonNull::dangling().as_ptr(),
112 layout: Layout::new::<()>(),
113 };
114 self.memory.insert(id, memory);
115 } else {
116 unsafe {
117 let layout = Layout::array::<u8>(size as usize).unwrap();
118
119 let ptr = alloc_zeroed(layout);
122 if ptr.is_null() {
123 return Err(IoError::BufferTooBig {
124 size,
125 backtrace: BackTrace::capture(),
126 });
127 }
128 let memory = AllocatedBytes { ptr, layout };
129 self.memory.insert(id, memory);
130 }
131 }
132
133 Ok(handle)
134 }
135
136 #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
137 fn dealloc(&mut self, id: StorageId) {
138 if let Some(memory) = self.memory.remove(&id)
139 && memory.layout.size() > 0
140 {
141 unsafe {
142 dealloc(memory.ptr, memory.layout);
143 }
144 }
145 }
146
147 fn flush(&mut self) {
148 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    // Basic lifecycle: a 64-byte allocation reports its size and can be freed.
    #[test_log::test]
    fn test_can_alloc_and_dealloc() {
        let mut storage = BytesStorage::default();
        let handle_1 = storage.alloc(64).unwrap();

        assert_eq!(handle_1.size(), 64);
        storage.dealloc(handle_1.id);
    }

    // A second handle into the same allocation (offset 24, size 8) must see
    // exactly the bytes written through the full-buffer handle.
    #[test_log::test]
    fn test_slices() {
        let mut storage = BytesStorage::default();
        let handle_1 = storage.alloc(64).unwrap();
        let handle_2 = StorageHandle::new(
            handle_1.id,
            StorageUtilization {
                offset: 24,
                size: 8,
            },
        );

        // Fill the whole buffer with its own byte index: buf[i] = i.
        storage
            .get(&handle_1)
            .write()
            .iter_mut()
            .enumerate()
            .for_each(|(i, b)| {
                *b = i as u8;
            });

        // Copy out before dealloc so the read does not outlive the buffer.
        let bytes = storage.get(&handle_2).read().to_vec();

        storage.dealloc(handle_1.id);
        // The window at offset 24 must contain indices 24..32.
        assert_eq!(bytes, &[24, 25, 26, 27, 28, 29, 30, 31]);
    }

    // Fresh allocations are zero-initialized (alloc uses `alloc_zeroed`).
    #[test_log::test]
    fn test_read_after_alloc_without_write() {
        let mut storage = BytesStorage::default();
        let handle = storage.alloc(16).unwrap();
        let resource = storage.get(&handle);
        assert!(resource.read().iter().all(|&b| b == 0));
        storage.dealloc(handle.id);
    }

    // The zero-size path (dangling pointer, no real allocation) must not
    // panic on alloc or dealloc.
    #[test_log::test]
    fn test_zero_size_alloc_and_dealloc() {
        let mut storage = BytesStorage::default();
        let handle = storage.alloc(0).unwrap();
        assert_eq!(handle.size(), 0);
        storage.dealloc(handle.id);
    }

    // Allocating again after a dealloc must yield a usable, independent
    // buffer (no double-free / use-after-free in the map bookkeeping).
    #[test_log::test]
    fn test_alloc_dealloc_realloc() {
        let mut storage = BytesStorage::default();
        let h1 = storage.alloc(32).unwrap();
        storage.get(&h1).write()[0] = 0xAA;
        storage.dealloc(h1.id);
        let h2 = storage.alloc(32).unwrap();
        storage.dealloc(h2.id);
    }

    // Four adjacent 16-byte windows into one 64-byte allocation: writes to
    // each window must not bleed into the others.
    #[test_log::test]
    fn test_multiple_non_overlapping_regions() {
        let mut storage = BytesStorage::default();
        let base = storage.alloc(64).unwrap();

        let regions: alloc::vec::Vec<_> = (0..4)
            .map(|i| {
                StorageHandle::new(
                    base.id,
                    StorageUtilization {
                        offset: i * 16,
                        size: 16,
                    },
                )
            })
            .collect();

        // Tag each region with its index, then verify every byte.
        for (i, region) in regions.iter().enumerate() {
            storage.get(region).write().fill(i as u8);
        }
        for (i, region) in regions.iter().enumerate() {
            assert!(storage.get(region).read().iter().all(|&b| b == i as u8));
        }
        storage.dealloc(base.id);
    }
}
246}