cubecl_runtime/storage/
bytes_cpu.rs

1use crate::server::IoError;
2
3use super::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
4use alloc::alloc::{Layout, alloc, dealloc};
5use hashbrown::HashMap;
6
7/// The bytes storage maps ids to pointers of bytes in a contiguous layout.
8#[derive(Default)]
9pub struct BytesStorage {
10    memory: HashMap<StorageId, AllocatedBytes>,
11}
12
13impl core::fmt::Debug for BytesStorage {
14    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
15        f.write_str("BytesStorage")
16    }
17}
18
19/// Can send to other threads.
20unsafe impl Send for BytesStorage {}
21unsafe impl Send for BytesResource {}
22
23/// This struct is a pointer to a memory chunk or slice.
24#[derive(Debug, Clone)]
25pub struct BytesResource {
26    ptr: *mut u8,
27    utilization: StorageUtilization,
28}
29
/// One contiguous heap allocation owned by [`BytesStorage`].
///
/// The pointer must be released with exactly the `layout` stored next to it.
struct AllocatedBytes {
    /// Layout passed to the allocator when `ptr` was obtained.
    layout: Layout,
    /// Start of the allocation returned by the allocator.
    ptr: *mut u8,
}
35
36impl BytesResource {
37    fn get_exact_location_and_length(&self) -> (*mut u8, usize) {
38        unsafe {
39            (
40                self.ptr.add(self.utilization.offset as usize),
41                self.utilization.size as usize,
42            )
43        }
44    }
45
46    /// Get the ptr this resource points to.
47    pub fn mut_ptr(&mut self) -> *mut u8 {
48        let (ptr, _) = self.get_exact_location_and_length();
49        ptr
50    }
51
52    /// Returns the resource as a mutable slice of bytes.
53    pub fn write<'a>(&mut self) -> &'a mut [u8] {
54        let (ptr, len) = self.get_exact_location_and_length();
55
56        // TODO: This is not safe if there are multiple resources which have a pointer.
57        // SAFETY:
58        // - ptr is constructed to not be null and aligned.
59        // - Total size of the allocation is at least `len`.
60        // - The total len is <= isize::MAX.
61        unsafe { core::slice::from_raw_parts_mut(ptr, len) }
62    }
63
64    /// Returns the resource as an immutable slice of bytes.
65    pub fn read<'a>(&self) -> &'a [u8] {
66        let (ptr, len) = self.get_exact_location_and_length();
67
68        // TODO: This is not safe if there are multiple resources which have a pointer.
69        //
70        // SAFETY:
71        // - ptr is constructed to not be null and aligned.
72        // - Total size of the allocation is at least `len`.
73        // - The total len is <= isize::MAX.
74        unsafe { core::slice::from_raw_parts(ptr, len) }
75    }
76}
77
78impl ComputeStorage for BytesStorage {
79    type Resource = BytesResource;
80
81    fn alignment(&self) -> usize {
82        4
83    }
84
85    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
86        let allocated_bytes = self.memory.get(&handle.id).unwrap();
87
88        BytesResource {
89            ptr: allocated_bytes.ptr,
90            utilization: handle.utilization.clone(),
91        }
92    }
93
94    fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
95        let id = StorageId::new();
96        let handle = StorageHandle {
97            id,
98            utilization: StorageUtilization { offset: 0, size },
99        };
100
101        unsafe {
102            let layout = Layout::array::<u8>(size as usize).unwrap();
103            let ptr = alloc(layout);
104            if ptr.is_null() {
105                // Assume allocation failure is OOM, we can't see the actual error on stable
106                return Err(IoError::BufferTooBig(size as usize));
107            }
108            let memory = AllocatedBytes { ptr, layout };
109
110            self.memory.insert(id, memory);
111        }
112
113        Ok(handle)
114    }
115
116    fn dealloc(&mut self, id: StorageId) {
117        if let Some(memory) = self.memory.remove(&id) {
118            unsafe {
119                dealloc(memory.ptr, memory.layout);
120            }
121        }
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_can_alloc_and_dealloc() {
        let mut storage = BytesStorage::default();
        let handle = storage.alloc(64).unwrap();

        assert_eq!(handle.size(), 64);
        storage.dealloc(handle.id);
    }

    #[test]
    fn test_slices() {
        let mut storage = BytesStorage::default();
        let whole = storage.alloc(64).unwrap();
        let window = StorageHandle::new(
            whole.id,
            StorageUtilization {
                offset: 24,
                size: 8,
            },
        );

        // Fill the whole allocation with its own byte indices.
        let mut resource = storage.get(&whole);
        for (index, byte) in resource.write().iter_mut().enumerate() {
            *byte = index as u8;
        }

        // The window must see exactly bytes 24..32.
        let bytes = storage.get(&window).read().to_vec();
        storage.dealloc(whole.id);
        let expected: Vec<u8> = (24u8..32).collect();
        assert_eq!(bytes, expected);
    }
}
163}