//! CPU byte-buffer storage backend (`cubecl_runtime/storage/bytes_cpu.rs`).

1use crate::server::IoError;
2
3use super::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
4use alloc::alloc::{Layout, alloc_zeroed, dealloc};
5use cubecl_common::backtrace::BackTrace;
6use hashbrown::HashMap;
7
/// The bytes storage maps ids to pointers of bytes in a contiguous layout.
#[derive(Default)]
pub struct BytesStorage {
    // Owned allocations, keyed by the id handed out by `alloc`.
    memory: HashMap<StorageId, AllocatedBytes>,
}
13
14impl core::fmt::Debug for BytesStorage {
15    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
16        f.write_str("BytesStorage")
17    }
18}
19
/// Can send to other threads.
// SAFETY: the storage exclusively owns its allocations; the raw pointers it
// holds are never shared except through the resources it hands out.
unsafe impl Send for BytesStorage {}
// NOTE(review): soundness presumably relies on callers never accessing the same
// resource's region concurrently from multiple threads — confirm the
// memory-management layer upholds this contract.
unsafe impl Send for BytesResource {}
23
/// This struct is a pointer to a memory chunk or slice.
#[derive(Debug)]
pub struct BytesResource {
    // Base pointer of the underlying allocation (not of this region itself).
    ptr: *mut u8,
    // Offset and size selecting this resource's region within the allocation.
    utilization: StorageUtilization,
}
30
/// This struct refers to a specific (contiguous) layout of bytes.
struct AllocatedBytes {
    // Pointer returned by `alloc_zeroed` (dangling for zero-size allocations).
    ptr: *mut u8,
    // Layout used at allocation time; required again to call `dealloc`.
    layout: Layout,
}
36
37impl BytesResource {
38    /// Returns a mutable pointer to the start of the resource and its length.
39    pub fn get_write_ptr_and_length(&self) -> (*mut u8, usize) {
40        (
41            // SAFETY:
42            // - The offset is created to be within the bounds of the allocation.
43            unsafe { self.ptr.add(self.utilization.offset as usize) },
44            self.utilization.size as usize,
45        )
46    }
47
48    /// Returns the resource as a mutable slice of bytes.
49    ///
50    /// The lifetime `'a` is the lifetime of the underlying `BytesStorage` allocation,
51    /// not of `self`. The `&mut self` ensures only one mutable slice is created per
52    /// resource. Multiple resources may point to non-overlapping regions of the same
53    /// allocation (like `split_at_mut`); each resource owns its region exclusively.
54    pub fn write<'a>(&mut self) -> &'a mut [u8] {
55        let (ptr, len) = self.get_write_ptr_and_length();
56
57        // SAFETY:
58        // - ptr is non-null and aligned (from BytesStorage::alloc).
59        // - The region [ptr..ptr+len) is within a single allocation.
60        // - Memory is initialized (BytesStorage uses alloc_zeroed).
61        // - `&mut self` ensures exclusive access to this resource's region.
62        // - `StorageHandle` assigns non-overlapping regions per resource.
63        // - Systems must make sure this is the only `BytesResource` with an outstanding mutable borrow.
64        unsafe { core::slice::from_raw_parts_mut(ptr, len) }
65    }
66
67    /// Returns the resource as an immutable slice of bytes.
68    ///
69    /// See [`write`](Self::write) for lifetime and safety notes.
70    pub fn read<'a>(&self) -> &'a [u8] {
71        let (ptr, len) = self.get_write_ptr_and_length();
72
73        // SAFETY:
74        // - ptr is non-null and aligned (from BytesStorage::alloc).
75        // - The region [ptr..ptr+len) is within a single allocation.
76        // - Memory is initialized (BytesStorage uses alloc_zeroed).
77        unsafe { core::slice::from_raw_parts(ptr, len) }
78    }
79}
80
81impl ComputeStorage for BytesStorage {
82    type Resource = BytesResource;
83
84    fn alignment(&self) -> usize {
85        4
86    }
87
88    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
89        let allocated_bytes = self.memory.get(&handle.id).unwrap();
90
91        BytesResource {
92            ptr: allocated_bytes.ptr,
93            utilization: handle.utilization.clone(),
94        }
95    }
96
97    #[cfg_attr(
98        feature = "tracing",
99        tracing::instrument(level = "trace", skip(self, size))
100    )]
101    fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
102        let id = StorageId::new();
103        let handle = StorageHandle {
104            id,
105            utilization: StorageUtilization { offset: 0, size },
106        };
107
108        if size == 0 {
109            // Zero-size allocations are valid handles but don't need real memory.
110            let memory = AllocatedBytes {
111                ptr: core::ptr::NonNull::dangling().as_ptr(),
112                layout: Layout::new::<()>(),
113            };
114            self.memory.insert(id, memory);
115        } else {
116            unsafe {
117                let layout = Layout::array::<u8>(size as usize).unwrap();
118
119                // We allocate zeroed memory since we expose it as &[u8] / &mut [u8]
120                // which requires initialization.
121                let ptr = alloc_zeroed(layout);
122                if ptr.is_null() {
123                    return Err(IoError::BufferTooBig {
124                        size,
125                        backtrace: BackTrace::capture(),
126                    });
127                }
128                let memory = AllocatedBytes { ptr, layout };
129                self.memory.insert(id, memory);
130            }
131        }
132
133        Ok(handle)
134    }
135
136    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
137    fn dealloc(&mut self, id: StorageId) {
138        if let Some(memory) = self.memory.remove(&id)
139            && memory.layout.size() > 0
140        {
141            unsafe {
142                dealloc(memory.ptr, memory.layout);
143            }
144        }
145    }
146
147    fn flush(&mut self) {
148        // We don't wait for dealloc.
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    #[test_log::test]
    fn test_can_alloc_and_dealloc() {
        let mut storage = BytesStorage::default();
        let handle = storage.alloc(64).unwrap();

        assert_eq!(handle.size(), 64);
        storage.dealloc(handle.id);
    }

    #[test_log::test]
    fn test_slices() {
        let mut storage = BytesStorage::default();
        let full = storage.alloc(64).unwrap();
        let slice = StorageHandle::new(
            full.id,
            StorageUtilization {
                offset: 24,
                size: 8,
            },
        );

        // Fill the whole allocation with its own byte index.
        for (i, byte) in storage.get(&full).write().iter_mut().enumerate() {
            *byte = i as u8;
        }

        let observed = storage.get(&slice).read().to_vec();

        storage.dealloc(full.id);
        assert_eq!(observed, &[24, 25, 26, 27, 28, 29, 30, 31]);
    }

    /// Miri catches: "reading memory, but memory is uninitialized"
    #[test_log::test]
    fn test_read_after_alloc_without_write() {
        let mut storage = BytesStorage::default();
        let handle = storage.alloc(16).unwrap();
        let resource = storage.get(&handle);
        // Fresh allocations must be zeroed (alloc uses `alloc_zeroed`).
        assert!(resource.read().iter().all(|&byte| byte == 0));
        storage.dealloc(handle.id);
    }

    /// Miri catches: "creating allocation with size 0"
    #[test_log::test]
    fn test_zero_size_alloc_and_dealloc() {
        let mut storage = BytesStorage::default();
        let handle = storage.alloc(0).unwrap();
        assert_eq!(handle.size(), 0);
        storage.dealloc(handle.id);
    }

    #[test_log::test]
    fn test_alloc_dealloc_realloc() {
        let mut storage = BytesStorage::default();
        let first = storage.alloc(32).unwrap();
        storage.get(&first).write()[0] = 0xAA;
        storage.dealloc(first.id);
        let second = storage.alloc(32).unwrap();
        storage.dealloc(second.id);
    }

    #[test_log::test]
    fn test_multiple_non_overlapping_regions() {
        let mut storage = BytesStorage::default();
        let base = storage.alloc(64).unwrap();

        // Four disjoint 16-byte regions carved out of the same allocation.
        let mut regions = alloc::vec::Vec::new();
        for i in 0..4u64 {
            regions.push(StorageHandle::new(
                base.id,
                StorageUtilization {
                    offset: i * 16,
                    size: 16,
                },
            ));
        }

        for (i, region) in regions.iter().enumerate() {
            storage.get(region).write().fill(i as u8);
        }
        for (i, region) in regions.iter().enumerate() {
            assert!(storage.get(region).read().iter().all(|&b| b == i as u8));
        }
        storage.dealloc(base.id);
    }
}