cubecl_runtime/storage/bytes_cpu.rs

1use crate::server::IoError;
2
3use super::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
4use alloc::alloc::{Layout, alloc, dealloc};
5use cubecl_common::backtrace::BackTrace;
6use hashbrown::HashMap;
7
8/// The bytes storage maps ids to pointers of bytes in a contiguous layout.
9#[derive(Default)]
10pub struct BytesStorage {
11    memory: HashMap<StorageId, AllocatedBytes>,
12}
13
14impl core::fmt::Debug for BytesStorage {
15    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
16        f.write_str("BytesStorage")
17    }
18}
19
20/// Can send to other threads.
21unsafe impl Send for BytesStorage {}
22unsafe impl Send for BytesResource {}
23
24/// This struct is a pointer to a memory chunk or slice.
25#[derive(Debug, Clone)]
26pub struct BytesResource {
27    ptr: *mut u8,
28    utilization: StorageUtilization,
29}
30
/// One contiguous heap allocation owned by [`BytesStorage`].
///
/// The layout is kept next to the pointer because `dealloc` must be called with exactly the
/// layout that was used to allocate.
struct AllocatedBytes {
    /// Start of the allocation, as returned by the allocator.
    ptr: *mut u8,
    /// Layout the allocation was made with.
    layout: Layout,
}
36
37impl BytesResource {
38    fn get_exact_location_and_length(&self) -> (*mut u8, usize) {
39        unsafe {
40            (
41                self.ptr.add(self.utilization.offset as usize),
42                self.utilization.size as usize,
43            )
44        }
45    }
46
47    /// Get the ptr this resource points to.
48    pub fn mut_ptr(&mut self) -> *mut u8 {
49        let (ptr, _) = self.get_exact_location_and_length();
50        ptr
51    }
52
53    /// Returns the resource as a mutable slice of bytes.
54    pub fn write<'a>(&mut self) -> &'a mut [u8] {
55        let (ptr, len) = self.get_exact_location_and_length();
56
57        // TODO: This is not safe if there are multiple resources which have a pointer.
58        // SAFETY:
59        // - ptr is constructed to not be null and aligned.
60        // - Total size of the allocation is at least `len`.
61        // - The total len is <= isize::MAX.
62        unsafe { core::slice::from_raw_parts_mut(ptr, len) }
63    }
64
65    /// Returns the resource as an immutable slice of bytes.
66    pub fn read<'a>(&self) -> &'a [u8] {
67        let (ptr, len) = self.get_exact_location_and_length();
68
69        // TODO: This is not safe if there are multiple resources which have a pointer.
70        //
71        // SAFETY:
72        // - ptr is constructed to not be null and aligned.
73        // - Total size of the allocation is at least `len`.
74        // - The total len is <= isize::MAX.
75        unsafe { core::slice::from_raw_parts(ptr, len) }
76    }
77}
78
79impl ComputeStorage for BytesStorage {
80    type Resource = BytesResource;
81
82    fn alignment(&self) -> usize {
83        4
84    }
85
86    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
87        let allocated_bytes = self.memory.get(&handle.id).unwrap();
88
89        BytesResource {
90            ptr: allocated_bytes.ptr,
91            utilization: handle.utilization.clone(),
92        }
93    }
94
95    fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
96        let id = StorageId::new();
97        let handle = StorageHandle {
98            id,
99            utilization: StorageUtilization { offset: 0, size },
100        };
101
102        unsafe {
103            let layout = Layout::array::<u8>(size as usize).unwrap();
104            let ptr = alloc(layout);
105            if ptr.is_null() {
106                // Assume allocation failure is OOM, we can't see the actual error on stable
107                return Err(IoError::BufferTooBig {
108                    size,
109                    backtrace: BackTrace::capture(),
110                });
111            }
112            let memory = AllocatedBytes { ptr, layout };
113
114            self.memory.insert(id, memory);
115        }
116
117        Ok(handle)
118    }
119
120    fn dealloc(&mut self, id: StorageId) {
121        if let Some(memory) = self.memory.remove(&id) {
122            unsafe {
123                dealloc(memory.ptr, memory.layout);
124            }
125        }
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_can_alloc_and_dealloc() {
        let mut storage = BytesStorage::default();
        let handle = storage.alloc(64).expect("allocating 64 bytes should succeed");

        assert_eq!(handle.size(), 64);
        storage.dealloc(handle.id);
    }

    #[test]
    fn test_slices() {
        let mut storage = BytesStorage::default();
        let whole = storage.alloc(64).unwrap();
        let window = StorageHandle::new(
            whole.id,
            StorageUtilization {
                offset: 24,
                size: 8,
            },
        );

        // Fill the whole buffer with its own byte indices: 0, 1, 2, ..., 63.
        for (index, byte) in storage.get(&whole).write().iter_mut().enumerate() {
            *byte = index as u8;
        }

        // The 8-byte window starting at offset 24 must see exactly those indices.
        let observed = storage.get(&window).read().to_vec();
        storage.dealloc(whole.id);
        assert_eq!(observed, &[24, 25, 26, 27, 28, 29, 30, 31]);
    }
}