cubecl_runtime/storage/
bytes_cpu.rs1use crate::server::IoError;
2
3use super::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
4use alloc::alloc::{Layout, alloc, dealloc};
5use hashbrown::HashMap;
6
7#[derive(Default)]
9pub struct BytesStorage {
10 memory: HashMap<StorageId, AllocatedBytes>,
11}
12
13impl core::fmt::Debug for BytesStorage {
14 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
15 f.write_str("BytesStorage")
16 }
17}
18
19unsafe impl Send for BytesStorage {}
21unsafe impl Send for BytesResource {}
22
23#[derive(Debug, Clone)]
25pub struct BytesResource {
26 ptr: *mut u8,
27 utilization: StorageUtilization,
28}
29
/// One heap allocation owned by [`BytesStorage`], kept together with the layout it was
/// allocated with so the exact same layout can be handed back to `dealloc`.
struct AllocatedBytes {
    /// Start of the allocation, as returned by `alloc`.
    ptr: *mut u8,
    /// Layout the allocation was made with; must be reused verbatim when deallocating.
    layout: Layout,
}
35
36impl BytesResource {
37 fn get_exact_location_and_length(&self) -> (*mut u8, usize) {
38 unsafe {
39 (
40 self.ptr.add(self.utilization.offset as usize),
41 self.utilization.size as usize,
42 )
43 }
44 }
45
46 pub fn mut_ptr(&mut self) -> *mut u8 {
48 let (ptr, _) = self.get_exact_location_and_length();
49 ptr
50 }
51
52 pub fn write<'a>(&mut self) -> &'a mut [u8] {
54 let (ptr, len) = self.get_exact_location_and_length();
55
56 unsafe { core::slice::from_raw_parts_mut(ptr, len) }
62 }
63
64 pub fn read<'a>(&self) -> &'a [u8] {
66 let (ptr, len) = self.get_exact_location_and_length();
67
68 unsafe { core::slice::from_raw_parts(ptr, len) }
75 }
76}
77
78impl ComputeStorage for BytesStorage {
79 type Resource = BytesResource;
80
81 fn alignment(&self) -> usize {
82 4
83 }
84
85 fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
86 let allocated_bytes = self.memory.get(&handle.id).unwrap();
87
88 BytesResource {
89 ptr: allocated_bytes.ptr,
90 utilization: handle.utilization.clone(),
91 }
92 }
93
94 fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
95 let id = StorageId::new();
96 let handle = StorageHandle {
97 id,
98 utilization: StorageUtilization { offset: 0, size },
99 };
100
101 unsafe {
102 let layout = Layout::array::<u8>(size as usize).unwrap();
103 let ptr = alloc(layout);
104 if ptr.is_null() {
105 return Err(IoError::BufferTooBig(size as usize));
107 }
108 let memory = AllocatedBytes { ptr, layout };
109
110 self.memory.insert(id, memory);
111 }
112
113 Ok(handle)
114 }
115
116 fn dealloc(&mut self, id: StorageId) {
117 if let Some(memory) = self.memory.remove(&id) {
118 unsafe {
119 dealloc(memory.ptr, memory.layout);
120 }
121 }
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh allocation reports the requested size and can be released again.
    #[test]
    fn test_can_alloc_and_dealloc() {
        let mut storage = BytesStorage::default();
        let handle = storage.alloc(64).unwrap();

        assert_eq!(handle.size(), 64);
        storage.dealloc(handle.id);
    }

    /// A handle with a non-zero offset sees exactly its window of the shared allocation.
    #[test]
    fn test_slices() {
        let mut storage = BytesStorage::default();
        let whole = storage.alloc(64).unwrap();
        // Bytes 24..32 of the same allocation.
        let window = StorageHandle::new(
            whole.id,
            StorageUtilization {
                offset: 24,
                size: 8,
            },
        );

        let mut resource = storage.get(&whole);
        for (index, byte) in resource.write().iter_mut().enumerate() {
            *byte = index as u8;
        }

        let bytes = storage.get(&window).read().to_vec();
        storage.dealloc(whole.id);
        let expected: Vec<u8> = (24..32).collect();
        assert_eq!(bytes, expected);
    }
}