burn_compute/storage/bytes_cpu.rs

1use super::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
2use alloc::alloc::{alloc, dealloc, Layout};
3use hashbrown::HashMap;
4
5/// The bytes storage maps ids to pointers of bytes in a contiguous layout.
6#[derive(Default)]
7pub struct BytesStorage {
8    memory: HashMap<StorageId, AllocatedBytes>,
9}
10
11impl core::fmt::Debug for BytesStorage {
12    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
13        f.write_str("BytesStorage")
14    }
15}
16
/// Can be sent to other threads.
// SAFETY: the raw pointers inside `BytesStorage` refer to allocations the storage created in
// `alloc` and releases via `dealloc`; moving the storage to another thread moves that ownership
// along with it.
unsafe impl Send for BytesStorage {}
// SAFETY / NOTE(review): `BytesResource` is only a raw pointer into memory owned by a
// `BytesStorage`. Sending it is sound only if callers never use it after its allocation has been
// deallocated and never access aliasing resources concurrently; nothing in this file enforces
// either condition.
unsafe impl Send for BytesResource {}
20
21/// This struct is a pointer to a memory chunk or slice.
22pub struct BytesResource {
23    ptr: *mut u8,
24    utilization: StorageUtilization,
25}
26
/// One contiguous block obtained from the global allocator, stored together with the exact
/// [`Layout`] it was requested with, because `dealloc` needs that same layout to free it.
struct AllocatedBytes {
    /// Layout passed to the allocator; must be passed back unchanged when freeing.
    layout: Layout,
    /// First byte of the block.
    ptr: *mut u8,
}
32
33impl BytesResource {
34    fn get_exact_location_and_length(&self) -> (*mut u8, usize) {
35        match self.utilization {
36            StorageUtilization::Full(len) => (self.ptr, len),
37            StorageUtilization::Slice(location, len) => unsafe { (self.ptr.add(location), len) },
38        }
39    }
40
41    /// Returns the resource as a mutable slice of bytes.
42    pub fn write<'a>(&self) -> &'a mut [u8] {
43        let (ptr, len) = self.get_exact_location_and_length();
44
45        unsafe { core::slice::from_raw_parts_mut(ptr, len) }
46    }
47
48    /// Returns the resource as an immutable slice of bytes.
49    pub fn read<'a>(&self) -> &'a [u8] {
50        let (ptr, len) = self.get_exact_location_and_length();
51
52        unsafe { core::slice::from_raw_parts(ptr, len) }
53    }
54}
55
56impl ComputeStorage for BytesStorage {
57    type Resource = BytesResource;
58
59    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
60        let allocated_bytes = self.memory.get_mut(&handle.id).unwrap();
61
62        BytesResource {
63            ptr: allocated_bytes.ptr,
64            utilization: handle.utilization.clone(),
65        }
66    }
67
68    fn alloc(&mut self, size: usize) -> StorageHandle {
69        let id = StorageId::new();
70        let handle = StorageHandle {
71            id: id.clone(),
72            utilization: StorageUtilization::Full(size),
73        };
74
75        unsafe {
76            let layout = Layout::array::<u8>(size).unwrap();
77            let ptr = alloc(layout);
78            let memory = AllocatedBytes { ptr, layout };
79
80            self.memory.insert(id, memory);
81        }
82
83        handle
84    }
85
86    fn dealloc(&mut self, id: StorageId) {
87        if let Some(memory) = self.memory.remove(&id) {
88            unsafe {
89                dealloc(memory.ptr, memory.layout);
90            }
91        }
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh allocation reports the requested size and can be released again.
    #[test]
    fn test_can_alloc_and_dealloc() {
        let mut storage = BytesStorage::default();

        let handle = storage.alloc(64);
        assert_eq!(handle.size(), 64);

        storage.dealloc(handle.id);
    }

    /// A `Slice` handle reads back exactly the window it describes within the full allocation.
    #[test]
    fn test_slices() {
        let mut storage = BytesStorage::default();
        let full = storage.alloc(64);
        let window = StorageHandle::new(full.id.clone(), StorageUtilization::Slice(24, 8));

        // Fill every byte with its own offset: 0, 1, ..., 63.
        let written = storage.get(&full).write();
        for (offset, byte) in written.iter_mut().enumerate() {
            *byte = offset as u8;
        }

        let observed = storage.get(&window).read().to_vec();
        storage.dealloc(full.id);

        assert_eq!(observed, &[24, 25, 26, 27, 28, 29, 30, 31]);
    }
}