cubecl_runtime/storage/bytes_cpu.rs

use crate::server::IoError;

use super::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
use alloc::alloc::{Layout, alloc, dealloc};
use cubecl_common::backtrace::BackTrace;
use hashbrown::HashMap;

#[derive(Default)]
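/// CPU storage that maps each [StorageId] to a heap allocation of raw bytes.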
pub struct BytesStorage {
    memory: HashMap<StorageId, AllocatedBytes>,
}

impl core::fmt::Debug for BytesStorage {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("BytesStorage")
    }
}

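// Raw pointers are not `Send` by default; these impls assert that the storage
// and its resources, which point into allocations owned by the storage, may be
// moved across threads.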
unsafe impl Send for BytesStorage {}
unsafe impl Send for BytesResource {}

#[derive(Debug, Clone)]
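/// A view into a [BytesStorage] allocation: a base pointer plus the offset and
/// size described by its [StorageUtilization].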
pub struct BytesResource {
    ptr: *mut u8,
    utilization: StorageUtilization,
}

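/// A raw heap allocation stored together with the [Layout] it was allocated
/// with, so it can be deallocated correctly.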
struct AllocatedBytes {
    ptr: *mut u8,
    layout: Layout,
}

impl BytesResource {
    fn get_exact_location_and_length(&self) -> (*mut u8, usize) {
        unsafe {
            (
                self.ptr.add(self.utilization.offset as usize),
                self.utilization.size as usize,
            )
        }
    }

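    /// Returns a mutable pointer to the start of this resource's byte range.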
    pub fn mut_ptr(&mut self) -> *mut u8 {
        let (ptr, _) = self.get_exact_location_and_length();
        ptr
    }

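    /// Returns the resource as a mutable slice of bytes.
    ///
    /// The returned lifetime is unconstrained, so the caller must ensure the
    /// slice does not outlive the underlying allocation.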
    pub fn write<'a>(&mut self) -> &'a mut [u8] {
        let (ptr, len) = self.get_exact_location_and_length();

        unsafe { core::slice::from_raw_parts_mut(ptr, len) }
    }

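    /// Returns the resource as an immutable slice of bytes.
    ///
    /// As with [Self::write], the returned lifetime is unconstrained and the
    /// slice must not outlive the underlying allocation.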
    pub fn read<'a>(&self) -> &'a [u8] {
        let (ptr, len) = self.get_exact_location_and_length();

        unsafe { core::slice::from_raw_parts(ptr, len) }
    }
}

impl ComputeStorage for BytesStorage {
    type Resource = BytesResource;

    fn alignment(&self) -> usize {
        4
    }

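    /// Builds a [BytesResource] view for the given handle; panics if the
    /// handle's id is not present in this storage.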
    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
        let allocated_bytes = self.memory.get(&handle.id).unwrap();

        BytesResource {
            ptr: allocated_bytes.ptr,
            utilization: handle.utilization.clone(),
        }
    }

    fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
        let id = StorageId::new();
        let handle = StorageHandle {
            id,
            utilization: StorageUtilization { offset: 0, size },
        };

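        // Allocate the raw buffer; a null pointer means the allocation failed,
        // which is reported as `BufferTooBig` with a captured backtrace.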
        unsafe {
            let layout = Layout::array::<u8>(size as usize).unwrap();
            let ptr = alloc(layout);
            if ptr.is_null() {
                return Err(IoError::BufferTooBig {
                    size,
                    backtrace: BackTrace::capture(),
                });
            }
            let memory = AllocatedBytes { ptr, layout };

            self.memory.insert(id, memory);
        }

        Ok(handle)
    }

    fn dealloc(&mut self, id: StorageId) {
        if let Some(memory) = self.memory.remove(&id) {
            unsafe {
                dealloc(memory.ptr, memory.layout);
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_can_alloc_and_dealloc() {
        let mut storage = BytesStorage::default();
        let handle_1 = storage.alloc(64).unwrap();

        assert_eq!(handle_1.size(), 64);
        storage.dealloc(handle_1.id);
    }

    #[test]
    fn test_slices() {
        let mut storage = BytesStorage::default();
        let handle_1 = storage.alloc(64).unwrap();
        let handle_2 = StorageHandle::new(
            handle_1.id,
            StorageUtilization {
                offset: 24,
                size: 8,
            },
        );

        storage
            .get(&handle_1)
            .write()
            .iter_mut()
            .enumerate()
            .for_each(|(i, b)| {
                *b = i as u8;
            });

        let bytes = storage.get(&handle_2).read().to_vec();
        storage.dealloc(handle_1.id);
        assert_eq!(bytes, &[24, 25, 26, 27, 28, 29, 30, 31]);
    }
}