Skip to main content

rlmesh_spaces/tensor/
storage.rs

1use std::sync::Arc;
2
/// Byte alignment of the payload start for every buffer that [`Storage`]
/// allocates itself (adopted buffers are exempt).
const ALIGN: usize = 64;
5
6/// Reference-counted, immutable byte storage shared by one or more
7/// [`Tensor`](super::Tensor) views.
8///
9/// Buffers allocated by `Storage` ([`from_slice`](Storage::from_slice),
10/// [`zeroed`](Storage::zeroed)) are 64-byte aligned. Buffers adopted from a
11/// caller ([`from_vec`](Storage::from_vec)) keep their original allocation and
12/// carry no alignment guarantee.
13#[derive(Debug, Clone)]
14pub struct Storage(Inner);
15
16#[derive(Debug, Clone)]
17enum Inner {
18    /// Storage-allocated buffer, 64-byte aligned.
19    Aligned(Arc<AlignedBytes>),
20    /// Caller-owned buffer adopted without copying.
21    Adopted(Arc<Vec<u8>>),
22}
23
24impl Storage {
25    /// Adopt an existing buffer without copying. No alignment guarantee.
26    pub fn from_vec(data: Vec<u8>) -> Self {
27        Storage(Inner::Adopted(Arc::new(data)))
28    }
29
30    /// Copy `data` into a fresh 64-byte-aligned buffer.
31    pub fn from_slice(data: &[u8]) -> Self {
32        Self::aligned_with(data.len(), |buf| buf.extend_from_slice(data))
33    }
34
35    /// Allocate a zero-filled 64-byte-aligned buffer of `len` bytes.
36    pub fn zeroed(len: usize) -> Self {
37        Self::aligned_with(len, |buf| buf.resize(buf.len() + len, 0))
38    }
39
40    /// Build a 64-byte-aligned buffer by letting `fill` append exactly `len`
41    /// payload bytes. `fill` must only append; growing past the reserved
42    /// capacity would reallocate and lose the alignment.
43    pub(crate) fn aligned_with(len: usize, fill: impl FnOnce(&mut Vec<u8>)) -> Self {
44        let mut bytes = AlignedBytes::with_aligned_offset(len);
45        fill(&mut bytes.buf);
46        debug_assert_eq!(bytes.buf.len(), bytes.offset + bytes.len);
47        Storage(Inner::Aligned(Arc::new(bytes)))
48    }
49
50    /// The full backing buffer.
51    pub fn as_slice(&self) -> &[u8] {
52        match &self.0 {
53            Inner::Aligned(bytes) => bytes.as_slice(),
54            Inner::Adopted(bytes) => bytes.as_slice(),
55        }
56    }
57
58    /// Length of the backing buffer in bytes.
59    pub fn len(&self) -> usize {
60        self.as_slice().len()
61    }
62
63    /// Whether the backing buffer is empty.
64    pub fn is_empty(&self) -> bool {
65        self.as_slice().is_empty()
66    }
67
68    /// Whether two storages share the same underlying allocation.
69    pub fn ptr_eq(&self, other: &Self) -> bool {
70        match (&self.0, &other.0) {
71            (Inner::Aligned(a), Inner::Aligned(b)) => Arc::ptr_eq(a, b),
72            (Inner::Adopted(a), Inner::Adopted(b)) => Arc::ptr_eq(a, b),
73            _ => false,
74        }
75    }
76}
77
/// Owned payload bytes placed at a 64-byte-aligned offset inside an
/// over-allocated `Vec`.
#[derive(Debug)]
struct AlignedBytes {
    /// Backing vector: alignment padding followed by the payload.
    buf: Vec<u8>,
    /// Number of padding bytes at the front of `buf`.
    offset: usize,
    /// Number of payload bytes following the padding.
    len: usize,
}
86
87impl AlignedBytes {
88    /// Reserve capacity for `len` payload bytes plus alignment padding and
89    /// fill the padding. The reserved capacity guarantees later appends of up
90    /// to `len` bytes never reallocate, so the offset computed from the heap
91    /// pointer stays aligned.
92    fn with_aligned_offset(len: usize) -> Self {
93        let mut buf = Vec::with_capacity(len + ALIGN - 1);
94        let addr = buf.as_ptr() as usize;
95        let offset = (ALIGN - addr % ALIGN) % ALIGN;
96        buf.resize(offset, 0);
97        Self { buf, offset, len }
98    }
99
100    fn as_slice(&self) -> &[u8] {
101        &self.buf[self.offset..self.offset + self.len]
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Whether the first byte exposed by `storage` sits on an `ALIGN` boundary.
    fn starts_aligned(storage: &Storage) -> bool {
        storage.as_slice().as_ptr() as usize % ALIGN == 0
    }

    #[test]
    fn test_aligned_storage_is_64_byte_aligned() {
        for len in [1usize, 7, 64, 100, 4096] {
            let zeros = Storage::zeroed(len);
            assert!(starts_aligned(&zeros));
            assert_eq!(zeros.len(), len);

            let sevens = vec![7u8; len];
            let copied = Storage::from_slice(&sevens);
            assert!(starts_aligned(&copied));
            assert_eq!(copied.as_slice(), sevens.as_slice());
        }
    }

    #[test]
    fn test_from_vec_adopts_without_copying() {
        let payload = vec![1u8, 2, 3, 4];
        // Remember where the vector lives before ownership moves into Storage.
        let original_ptr = payload.as_ptr();
        let adopted = Storage::from_vec(payload);

        let view = adopted.as_slice();
        assert_eq!(view.as_ptr(), original_ptr);
        assert_eq!(view, &[1, 2, 3, 4]);
    }

    #[test]
    fn test_ptr_eq_tracks_shared_allocations() {
        let aligned = Storage::from_slice(&[1, 2, 3]);
        let adopted = Storage::from_vec(vec![1, 2, 3]);

        // Clones share the allocation of their source.
        let aligned_clone = aligned.clone();
        let adopted_clone = adopted.clone();
        assert!(aligned.ptr_eq(&aligned_clone));
        assert!(adopted.ptr_eq(&adopted_clone));

        // Different kinds, or equal contents in separate allocations, do not.
        assert!(!aligned.ptr_eq(&adopted));
        let other_aligned = Storage::from_slice(&[1, 2, 3]);
        assert!(!aligned.ptr_eq(&other_aligned));
        let other_adopted = Storage::from_vec(vec![1, 2, 3]);
        assert!(!adopted.ptr_eq(&other_adopted));
    }

    #[test]
    fn test_zero_length_storage() {
        let empties = [
            Storage::zeroed(0),
            Storage::from_slice(&[]),
            Storage::from_vec(Vec::new()),
        ];
        for storage in &empties {
            assert!(storage.is_empty());
        }
    }
}