use crate::server::IoError;

use super::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
use alloc::alloc::{Layout, alloc, dealloc};
use cubecl_common::backtrace::BackTrace;
use hashbrown::HashMap;

/// CPU storage backend that keeps compute buffers as plain heap allocations
/// obtained from the global allocator.
#[derive(Default)]
pub struct BytesStorage {
    /// Maps each storage id to the heap allocation that backs it.
    memory: HashMap<StorageId, AllocatedBytes>,
}
13
14impl core::fmt::Debug for BytesStorage {
15 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
16 f.write_str("BytesStorage")
17 }
18}
19
// SAFETY: the raw pointers behind these types originate from the global
// allocator (not thread-local state), so moving ownership of a handle to
// another thread is sound. NOTE(review): `BytesResource` is `Clone`, so two
// copies on different threads can alias the same buffer — confirm callers
// uphold an exclusive-writer discipline before relying on this.
unsafe impl Send for BytesStorage {}
unsafe impl Send for BytesResource {}
23
/// A view into one allocation owned by [`BytesStorage`]: a base pointer plus
/// the offset/size window currently in use.
#[derive(Debug, Clone)]
pub struct BytesResource {
    // Base pointer of the full allocation (NOT offset-adjusted).
    ptr: *mut u8,
    // Offset and size selecting the sub-range of the allocation to expose.
    utilization: StorageUtilization,
}
30
/// Owns one heap allocation. The `Layout` is retained so the exact same
/// layout can be passed back to `dealloc`, as the allocator API requires.
struct AllocatedBytes {
    ptr: *mut u8,
    layout: Layout,
}
36
37impl BytesResource {
38 fn get_exact_location_and_length(&self) -> (*mut u8, usize) {
39 unsafe {
40 (
41 self.ptr.add(self.utilization.offset as usize),
42 self.utilization.size as usize,
43 )
44 }
45 }
46
47 pub fn mut_ptr(&mut self) -> *mut u8 {
49 let (ptr, _) = self.get_exact_location_and_length();
50 ptr
51 }
52
53 pub fn write<'a>(&mut self) -> &'a mut [u8] {
55 let (ptr, len) = self.get_exact_location_and_length();
56
57 unsafe { core::slice::from_raw_parts_mut(ptr, len) }
63 }
64
65 pub fn read<'a>(&self) -> &'a [u8] {
67 let (ptr, len) = self.get_exact_location_and_length();
68
69 unsafe { core::slice::from_raw_parts(ptr, len) }
76 }
77}
78
79impl ComputeStorage for BytesStorage {
80 type Resource = BytesResource;
81
82 fn alignment(&self) -> usize {
83 4
84 }
85
86 fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
87 let allocated_bytes = self.memory.get(&handle.id).unwrap();
88
89 BytesResource {
90 ptr: allocated_bytes.ptr,
91 utilization: handle.utilization.clone(),
92 }
93 }
94
95 #[cfg_attr(
96 feature = "tracing",
97 tracing::instrument(level = "trace", skip(self, size))
98 )]
99 fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
100 let id = StorageId::new();
101 let handle = StorageHandle {
102 id,
103 utilization: StorageUtilization { offset: 0, size },
104 };
105
106 unsafe {
107 let layout = Layout::array::<u8>(size as usize).unwrap();
108 let ptr = alloc(layout);
109 if ptr.is_null() {
110 return Err(IoError::BufferTooBig {
112 size,
113 backtrace: BackTrace::capture(),
114 });
115 }
116 let memory = AllocatedBytes { ptr, layout };
117
118 self.memory.insert(id, memory);
119 }
120
121 Ok(handle)
122 }
123
124 #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
125 fn dealloc(&mut self, id: StorageId) {
126 if let Some(memory) = self.memory.remove(&id) {
127 unsafe {
128 dealloc(memory.ptr, memory.layout);
129 }
130 }
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    #[test_log::test]
    fn test_can_alloc_and_dealloc() {
        // A fresh storage hands out a handle covering the full request.
        let mut storage = BytesStorage::default();
        let handle = storage.alloc(64).unwrap();

        assert_eq!(handle.size(), 64);
        storage.dealloc(handle.id);
    }

    #[test_log::test]
    fn test_slices() {
        let mut storage = BytesStorage::default();
        let full = storage.alloc(64).unwrap();

        // An 8-byte window starting at offset 24 of the same allocation.
        let window = StorageHandle::new(
            full.id,
            StorageUtilization {
                offset: 24,
                size: 8,
            },
        );

        // Fill the whole buffer with its own byte indices.
        for (i, byte) in storage.get(&full).write().iter_mut().enumerate() {
            *byte = i as u8;
        }

        let bytes = storage.get(&window).read().to_vec();
        storage.dealloc(full.id);
        assert_eq!(bytes, &[24, 25, 26, 27, 28, 29, 30, 31]);
    }
}