cubecl_runtime/storage/
bytes_cpu.rs

use super::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
use alloc::alloc::{Layout, alloc, dealloc, handle_alloc_error};
use hashbrown::HashMap;
4
5/// The bytes storage maps ids to pointers of bytes in a contiguous layout.
6#[derive(Default)]
7pub struct BytesStorage {
8    memory: HashMap<StorageId, AllocatedBytes>,
9}
10
11impl core::fmt::Debug for BytesStorage {
12    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
13        f.write_str("BytesStorage")
14    }
15}
16
17/// Can send to other threads.
18unsafe impl Send for BytesStorage {}
19unsafe impl Send for BytesResource {}
20
21/// This struct is a pointer to a memory chunk or slice.
22pub struct BytesResource {
23    ptr: *mut u8,
24    utilization: StorageUtilization,
25}
26
/// Bookkeeping for one contiguous buffer owned by the storage.
struct AllocatedBytes {
    /// Start of the buffer, as returned by the allocator.
    ptr: *mut u8,
    /// Layout the buffer was allocated with; the same value is passed back on release.
    layout: Layout,
}
32
33impl BytesResource {
34    fn get_exact_location_and_length(&self) -> (*mut u8, usize) {
35        unsafe {
36            (
37                self.ptr.add(self.utilization.offset as usize),
38                self.utilization.size as usize,
39            )
40        }
41    }
42
43    /// Returns the resource as a mutable slice of bytes.
44    pub fn write<'a>(&self) -> &'a mut [u8] {
45        let (ptr, len) = self.get_exact_location_and_length();
46
47        unsafe { core::slice::from_raw_parts_mut(ptr, len) }
48    }
49
50    /// Returns the resource as an immutable slice of bytes.
51    pub fn read<'a>(&self) -> &'a [u8] {
52        let (ptr, len) = self.get_exact_location_and_length();
53
54        unsafe { core::slice::from_raw_parts(ptr, len) }
55    }
56}
57
58impl ComputeStorage for BytesStorage {
59    type Resource = BytesResource;
60
61    fn alignment(&self) -> usize {
62        4
63    }
64
65    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
66        let allocated_bytes = self.memory.get(&handle.id).unwrap();
67
68        BytesResource {
69            ptr: allocated_bytes.ptr,
70            utilization: handle.utilization.clone(),
71        }
72    }
73
74    fn alloc(&mut self, size: u64) -> StorageHandle {
75        let id = StorageId::new();
76        let handle = StorageHandle {
77            id,
78            utilization: StorageUtilization { offset: 0, size },
79        };
80
81        unsafe {
82            let layout = Layout::array::<u8>(size as usize).unwrap();
83            let ptr = alloc(layout);
84            let memory = AllocatedBytes { ptr, layout };
85
86            self.memory.insert(id, memory);
87        }
88
89        handle
90    }
91
92    fn dealloc(&mut self, id: StorageId) {
93        if let Some(memory) = self.memory.remove(&id) {
94            unsafe {
95                dealloc(memory.ptr, memory.layout);
96            }
97        }
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh allocation reports the requested size and can be released again.
    #[test]
    fn test_can_alloc_and_dealloc() {
        let mut storage = BytesStorage::default();

        let handle = storage.alloc(64);
        assert_eq!(handle.size(), 64);

        storage.dealloc(handle.id);
    }

    /// A handle with an offset and size only exposes that window of the buffer.
    #[test]
    fn test_slices() {
        let mut storage = BytesStorage::default();
        let whole = storage.alloc(64);
        let window = StorageHandle::new(
            whole.id,
            StorageUtilization {
                offset: 24,
                size: 8,
            },
        );

        // Fill the whole buffer so that every byte holds its own index.
        let buffer = storage.get(&whole).write();
        for (index, byte) in buffer.iter_mut().enumerate() {
            *byte = index as u8;
        }

        // Copy the window out before the buffer is released.
        let observed = storage.get(&window).read().to_vec();
        storage.dealloc(whole.id);

        assert_eq!(observed, &[24, 25, 26, 27, 28, 29, 30, 31]);
    }
}