// burn_tensor/tensor/bytes.rs

//! A version of [`bytemuck::BoxBytes`] that is cloneable and allows trailing uninitialized elements.
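//!
//! # Example (illustrative)
//!
//! A minimal sketch of a round-trip; the `burn_tensor::Bytes` import path is an
//! assumption of this example, not something this module defines:
//!
//! ```
//! use burn_tensor::Bytes; // assumed re-export path
//!
//! let bytes = Bytes::from_elems(vec![1u32, 2, 3]);
//! assert_eq!(bytes.len(), 3 * core::mem::size_of::<u32>());
//! let elems = bytes.try_into_vec::<u32>().expect("same element type, so the layout matches");
//! assert_eq!(elems, vec![1, 2, 3]);
//! ```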

use alloc::alloc::{Layout, LayoutError};
use core::mem::MaybeUninit;
use core::ops::{Deref, DerefMut};
use core::ptr::NonNull;

use alloc::vec::Vec;

/// Internally used to avoid accidentally leaking an allocation or using the wrong layout.
struct Allocation {
    /// SAFETY:
    ///  - If `layout.size() > 0`, `ptr` points to a valid allocation from the global allocator
    ///    of the specified layout. The first `len` bytes are initialized.
    ///  - If `layout.size() == 0`, `ptr` is aligned to `layout.align()` and `len` is 0.
    ///    `ptr` is further suitable to be used as the argument for `Vec::from_raw_parts`; see [buffer alloc]
    ///    for more details.
    ptr: NonNull<u8>,
    layout: Layout,
}

/// A sort of `Box<[u8]>` that remembers the original alignment and can contain trailing uninitialized bytes.
pub struct Bytes {
    alloc: Allocation,
    // SAFETY: The first `len` bytes of the allocation are initialized
    len: usize,
}

/// The maximum supported alignment. This limit exists so that the alignment does not have to be stored
/// when serializing. Instead, the bytes are always over-aligned to `MAX_ALIGN` when deserializing.
const MAX_ALIGN: usize = core::mem::align_of::<u128>();

fn debug_from_fn<F: Fn(&mut core::fmt::Formatter<'_>) -> core::fmt::Result>(
    f: F,
) -> impl core::fmt::Debug {
    // See also: std::fmt::from_fn
    struct FromFn<F>(F);
    impl<F> core::fmt::Debug for FromFn<F>
    where
        F: Fn(&mut core::fmt::Formatter<'_>) -> core::fmt::Result,
    {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            (self.0)(f)
        }
    }
    FromFn(f)
}

impl core::fmt::Debug for Bytes {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let data = &**self;
        let fmt_data = move |f: &mut core::fmt::Formatter<'_>| {
            if data.len() > 3 {
                // There is a nightly API `debug_more_non_exhaustive` which has `finish_non_exhaustive`
                f.debug_list().entries(&data[0..3]).entry(&"...").finish()
            } else {
                f.debug_list().entries(data).finish()
            }
        };
        f.debug_struct("Bytes")
            .field("data", &debug_from_fn(fmt_data))
            .field("len", &self.len)
            .finish()
    }
}

impl serde::Serialize for Bytes {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serde_bytes::serialize(self.deref(), serializer)
    }
}

impl<'de> serde::Deserialize<'de> for Bytes {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[cold]
        fn too_large<E: serde::de::Error>(len: usize, align: usize) -> E {
            // max_length = largest multiple of align that is <= isize::MAX
            // align is a power of 2, hence a multiple has the lower bits unset. Mask them off to find the largest multiple
            let max_length = (isize::MAX as usize) & !(align - 1);
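            // For example, with align == 16 on a 64-bit target: isize::MAX as usize is
            // 0x7fff_ffff_ffff_ffff, and !(16 - 1) clears the low 4 bits, giving
            // max_length == 0x7fff_ffff_ffff_fff0, the largest multiple of 16 that is <= isize::MAX.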
            E::custom(core::format_args!(
                "length too large: {len}. Expected at most {max_length} bytes"
            ))
        }

        // TODO: we can possibly avoid one copy here by deserializing into an existing, correctly aligned, slice of bytes.
        // We might not be able to predict the length of the data, hence it's far more convenient to let `Vec` handle the growth and re-allocations.
        // Further, on many systems the allocator naturally aligns data to some reasonably large alignment, in which case no further
        // copy is necessary.
        let data: Vec<u8> = serde_bytes::deserialize(deserializer)?;
        // When deserializing, we over-align the data. This saves us from having to encode the alignment (which is platform-dependent in any case).
        // If we had more context information here, we could enforce some (smaller) alignment per data type. But this information is only available
        // in `TensorData`. Moreover, it depends on the `Deserializer` whether the datatype or the data comes first.
        let align = MAX_ALIGN;
        let mut bytes = Self::from_elems(data);
        bytes
            .try_enforce_runtime_align(align)
            .map_err(|_| too_large(bytes.len(), align))?;
        Ok(bytes)
    }
}

impl Clone for Bytes {
    fn clone(&self) -> Self {
        // unwrap here: the layout is valid as it has the alignment & size of self
        Self::try_from_data(self.align(), self.deref()).unwrap()
    }
}

impl PartialEq for Bytes {
    fn eq(&self, other: &Self) -> bool {
        self.deref() == other.deref()
    }
}

impl Eq for Bytes {}

impl Allocation {
    // Wrap the allocation of a vector without copying
    fn from_vec<E: Copy>(vec: Vec<E>) -> Self {
        let mut elems = core::mem::ManuallyDrop::new(vec);
        // Set the length to 0, then all data is in the "spare capacity".
        // SAFETY: Data is Copy, so in particular does not need to be dropped. In any case, try not to panic until
        //  we have taken ownership of the data!
        unsafe { elems.set_len(0) };
        let data = elems.spare_capacity_mut();
        // We now have one contiguous slice of data to pass to Layout::for_value.
        let layout = Layout::for_value(data);
        // SAFETY: data is the allocation of a vec, hence can not be null. We use unchecked to avoid a panic-path.
        let ptr = unsafe { NonNull::new_unchecked(elems.as_mut_ptr().cast()) };
        Self { ptr, layout }
    }
    // Create a new allocation with the specified layout
    fn new(layout: Layout) -> Self {
        let ptr = buffer_alloc(layout);
        Self { ptr, layout }
    }
    // Reallocate to fit at least the size and align of min_layout
    fn grow(&mut self, min_layout: Layout) {
        (self.layout, self.ptr) = buffer_grow(self.layout, self.ptr, min_layout);
    }
    // Returns a mutable view of the memory of the whole allocation
    fn memory_mut(&mut self) -> &mut [MaybeUninit<u8>] {
        // SAFETY: See type invariants
        unsafe { core::slice::from_raw_parts_mut(self.ptr.as_ptr().cast(), self.layout.size()) }
    }
    // Return a pointer to the underlying allocation. This pointer is valid for reads and writes until the allocation is dropped or reallocated.
    fn as_mut_ptr(&self) -> *mut u8 {
        self.ptr.as_ptr()
    }
    // Try to convert the allocation to a Vec. The Vec has a length of 0 when returned, but correct capacity and pointer!
    fn try_into_vec<E>(self) -> Result<Vec<E>, Self> {
        let byte_capacity = self.layout.size();
        let Some(capacity) = byte_capacity.checked_div(size_of::<E>()) else {
            return Err(self);
        };
        if capacity * size_of::<E>() != byte_capacity {
            return Err(self);
        }
        if self.layout.align() != align_of::<E>() {
            return Err(self);
        }
        // Okay, let's commit
        let ptr = self.ptr.as_ptr().cast();
        core::mem::forget(self);
        // SAFETY:
        // - ptr was allocated by the global allocator as per type-invariant
        // - `E` has the same alignment as indicated by the stored layout.
        // - capacity * size_of::<E> == layout.size()
        // - 0 <= capacity
        // - no bytes are claimed to be initialized
        // - the layout represents a valid allocation, hence has allocation size less than isize::MAX
        Ok(unsafe { Vec::from_raw_parts(ptr, 0, capacity) })
    }
}

impl Drop for Allocation {
    fn drop(&mut self) {
        buffer_dealloc(self.layout, self.ptr);
    }
}

// Allocate a pointer that can be passed to Vec::from_raw_parts
fn buffer_alloc(layout: Layout) -> NonNull<u8> {
    // [buffer alloc]: The current docs of Vec::from_raw_parts(ptr, ...) say:
    //   > ptr must have been allocated using the global allocator
    // Yet, an empty Vec is guaranteed to not allocate (it is even illegal to allocate with a zero-sized layout).
    // Hence, we slightly re-interpret the above to only needing to hold if `capacity > 0`. Still, the pointer
    // must be non-zero. So in case we need a pointer for an empty vec, use a correctly aligned, dangling one.
    if layout.size() == 0 {
        // we would use NonNull::dangling() but we don't have a concrete type for the requested alignment
        let ptr = core::ptr::null_mut::<u8>().wrapping_add(layout.align());
        // SAFETY: layout.align() is never 0
        unsafe { NonNull::new_unchecked(ptr) }
    } else {
        // SAFETY: layout has non-zero size.
        let ptr = unsafe { alloc::alloc::alloc(layout) };
        NonNull::new(ptr).unwrap_or_else(|| alloc::alloc::handle_alloc_error(layout))
    }
}

fn expect_dangling(align: usize, buffer: NonNull<u8>) {
    debug_assert!(
        buffer.as_ptr().wrapping_sub(align).is_null(),
        "expected a dangling pointer for size 0"
    );
}

#[cold]
fn alloc_overflow() -> ! {
    panic!("Overflow, too many elements")
}

// Grow the buffer while keeping alignment
fn buffer_grow(
    old_layout: Layout,
    buffer: NonNull<u8>,
    min_layout: Layout,
) -> (Layout, NonNull<u8>) {
    let new_align = min_layout.align().max(old_layout.align()); // Don't let data become less aligned
    let new_size = min_layout.size().next_multiple_of(new_align);
    if new_size > isize::MAX as usize {
        alloc_overflow();
    }

    assert!(new_size > old_layout.size(), "size must actually grow");
    if old_layout.size() == 0 {
        expect_dangling(old_layout.align(), buffer);
        let new_layout = Layout::from_size_align(new_size, new_align).unwrap();
        let buffer = buffer_alloc(new_layout);
        return (new_layout, buffer);
    }
    let realloc = || {
        let new_layout = Layout::from_size_align(new_size, old_layout.align()).unwrap();
        // SAFETY:
        // - buffer comes from a Vec or from [`buffer_alloc`/`buffer_grow`].
        // - old_layout is the same as with which the pointer was allocated
        // - new_size is not 0, since it is larger than old_layout.size() which is non-zero
        // - size constitutes a valid layout
        let ptr = unsafe { alloc::alloc::realloc(buffer.as_ptr(), old_layout, new_layout.size()) };
        (new_layout, ptr)
    };
    if new_align == old_layout.align() {
        // happy path. We can just realloc.
        let (new_layout, ptr) = realloc();
        let buffer = NonNull::new(ptr);
        let buffer = buffer.unwrap_or_else(|| alloc::alloc::handle_alloc_error(new_layout));
        return (new_layout, buffer);
    }
    // [buffer grow]: alloc::realloc can *not* change the alignment of the allocation's layout.
    // The unstable Allocator::{grow,shrink} API changes this, but might take a while to make it
    // into alloc::GlobalAlloc.
    //
    // As such, we can not request a specific alignment. But most allocators will give us the required
    // alignment "for free". Hence, we speculatively avoid a mem-copy by using realloc.
    //
    // If requesting an alignment change for an existing allocation becomes available in the future, this can be removed.
    #[cfg(target_has_atomic = "8")]
    mod alignment_assumption {
        use core::sync::atomic::{AtomicBool, Ordering};
        static SPECULATE: AtomicBool = AtomicBool::new(true);
        pub fn speculate() -> bool {
            // We load and store with relaxed order, since worst case this leads to a few more memcopies
            SPECULATE.load(Ordering::Relaxed)
        }
        pub fn report_violation() {
            SPECULATE.store(false, Ordering::Relaxed)
        }
    }
    #[cfg(not(target_has_atomic = "8"))]
    mod alignment_assumption {
        // On these platforms we don't speculate, and take the performance hit
        pub fn speculate() -> bool {
            false
        }
        pub fn report_violation() {}
    }
    // reminder: old_layout.align() < new_align
    let mut old_buffer = buffer;
    let mut old_layout = old_layout;
    if alignment_assumption::speculate() {
        let (realloc_layout, ptr) = realloc();
        if let Some(buffer) = NonNull::new(ptr) {
            if buffer.align_offset(new_align) == 0 {
                return (realloc_layout, buffer);
            }
            // Speculating hasn't succeeded, but access now has to go through the reallocated buffer
            alignment_assumption::report_violation();
            old_buffer = buffer;
            old_layout = realloc_layout;
        } else {
            // If realloc fails, the later alloc will likely fail too, but don't report this yet
        }
    }
    // realloc but change alignment. This requires a mem copy as pointed out above
    let new_layout = Layout::from_size_align(new_size, new_align).unwrap();
    let new_buffer = buffer_alloc(new_layout);
    // SAFETY: two different memory allocations, and the old buffer's size is smaller than new_size
    unsafe {
        core::ptr::copy_nonoverlapping(old_buffer.as_ptr(), new_buffer.as_ptr(), old_layout.size());
    }
    buffer_dealloc(old_layout, old_buffer);
    (new_layout, new_buffer)
}

// Deallocate a buffer of a Vec
fn buffer_dealloc(layout: Layout, buffer: NonNull<u8>) {
    if layout.size() != 0 {
        // SAFETY: buffer comes from a Vec or from [`buffer_alloc`/`buffer_grow`].
        // The layout is the same as per type-invariants
        unsafe {
            alloc::alloc::dealloc(buffer.as_ptr(), layout);
        }
    } else {
        // An empty Vec does not allocate, hence nothing to dealloc
        expect_dangling(layout.align(), buffer);
    }
}

impl Bytes {
    /// Copy an existing slice of data into Bytes that are aligned to `align`
    fn try_from_data(align: usize, data: &[u8]) -> Result<Self, LayoutError> {
        let len = data.len();
        let layout = Layout::from_size_align(len, align)?;
        let alloc = Allocation::new(layout);
        unsafe {
            // SAFETY:
            // - data and alloc are distinct allocations of `len` bytes
            core::ptr::copy_nonoverlapping::<u8>(data.as_ptr(), alloc.as_mut_ptr(), len);
        };
        Ok(Self { alloc, len })
    }

    /// Ensure the contained buffer is aligned to `align` by possibly moving it to a new buffer.
    fn try_enforce_runtime_align(&mut self, align: usize) -> Result<(), LayoutError> {
        if self.as_mut_ptr().align_offset(align) == 0 {
            // data is already aligned correctly
            return Ok(());
        }
        *self = Self::try_from_data(align, self)?;
        Ok(())
    }

    /// Create a sequence of [Bytes] from the memory representation of an unknown type of elements.
    /// Prefer this over [Self::from_elems] when the datatype is not statically known, i.e. it has been erased at runtime.
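    ///
    /// # Example (illustrative)
    ///
    /// A minimal sketch; the `burn_tensor` import path is an assumption of this example:
    ///
    /// ```
    /// use burn_tensor::Bytes; // assumed re-export path
    ///
    /// // Four bytes of type-erased data, e.g. one little-endian `u32`.
    /// let bytes = Bytes::from_bytes_vec(vec![0x78, 0x56, 0x34, 0x12]);
    /// assert_eq!(bytes.len(), 4);
    /// ```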
    pub fn from_bytes_vec(bytes: Vec<u8>) -> Self {
        let mut bytes = Self::from_elems(bytes);
        // TODO: this method could be datatype aware and enforce a less strict alignment.
        // On most platforms, this alignment check is fulfilled either way though, so
        // the benefits of potentially saving a memcopy are negligible.
        bytes.try_enforce_runtime_align(MAX_ALIGN).unwrap();
        bytes
    }

    /// Erase the element type of a vector by converting it into a sequence of [Bytes].
    ///
    /// If the element type is not statically known at runtime, prefer [Self::from_bytes_vec].
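    ///
    /// # Example (illustrative)
    ///
    /// A minimal sketch; the `burn_tensor` import path is an assumption of this example:
    ///
    /// ```
    /// use burn_tensor::Bytes; // assumed re-export path
    ///
    /// let bytes = Bytes::from_elems(vec![1.0f32, 2.0]);
    /// assert_eq!(bytes.len(), 2 * core::mem::size_of::<f32>());
    /// ```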
    pub fn from_elems<E>(elems: Vec<E>) -> Self
    where
        // NoUninit implies Copy
        E: bytemuck::NoUninit + Send + Sync,
    {
        let _: () = const {
            assert!(
                core::mem::align_of::<E>() <= MAX_ALIGN,
                "element type not supported due to too large alignment"
            );
        };
        // Note: going through a Box as in Vec::into_boxed_slice would re-allocate on excess capacity. Avoid that.
        let byte_len = elems.len() * core::mem::size_of::<E>();
        let alloc = Allocation::from_vec(elems);
        Self {
            alloc,
            len: byte_len,
        }
    }

    fn reserve(&mut self, additional: usize, align: usize) {
        let needs_to_grow = additional > self.capacity().wrapping_sub(self.len());
        if !needs_to_grow {
            return;
        }
        let Some(required_cap) = self.len().checked_add(additional) else {
            alloc_overflow()
        };
        // guarantee exponential growth for amortization
        let new_cap = required_cap.max(self.capacity() * 2);
        let new_cap = new_cap.max(align); // Small allocations would be pointless
        let Ok(new_layout) = Layout::from_size_align(new_cap, align) else {
            alloc_overflow()
        };
        self.alloc.grow(new_layout);
    }

    /// Extend the byte buffer from a slice of bytes.
    ///
    /// This is used internally to preserve the alignment of the memory layout when elements of a
    /// matching type are extended. Prefer [`Self::extend_from_byte_slice`] otherwise.
    pub(crate) fn extend_from_byte_slice_aligned(&mut self, bytes: &[u8], align: usize) {
        let additional = bytes.len();
        self.reserve(additional, align);
        let len = self.len();
        let new_cap = len.wrapping_add(additional); // Cannot overflow, as we've just successfully reserved sufficient space for it
        let uninit_spare = &mut self.alloc.memory_mut()[len..new_cap];
        // SAFETY: reinterpreting the slice as a MaybeUninit<u8>.
        // See also #![feature(maybe_uninit_write_slice)], which would replace this with safe code
        uninit_spare.copy_from_slice(unsafe {
            core::slice::from_raw_parts(bytes.as_ptr().cast(), additional)
        });
        self.len = new_cap;
    }

    /// Extend the byte buffer from a slice of bytes.
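    ///
    /// # Example (illustrative)
    ///
    /// A minimal sketch; the `burn_tensor` import path is an assumption of this example:
    ///
    /// ```
    /// use burn_tensor::Bytes; // assumed re-export path
    ///
    /// let mut bytes = Bytes::from_elems(vec![1u8, 2]);
    /// bytes.extend_from_byte_slice(&[3, 4]);
    /// assert_eq!(bytes[..], [1, 2, 3, 4][..]);
    /// ```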
    pub fn extend_from_byte_slice(&mut self, bytes: &[u8]) {
        self.extend_from_byte_slice_aligned(bytes, MAX_ALIGN)
    }

    /// Get the total capacity, in bytes, of the wrapped allocation.
    pub fn capacity(&self) -> usize {
        self.alloc.layout.size()
    }

    /// Get the alignment of the wrapped allocation.
    pub(crate) fn align(&self) -> usize {
        self.alloc.layout.align()
    }

    /// Convert the bytes back into a vector. This requires that the type has the same alignment as the element
    /// type this [Bytes] was initialized with.
    /// This only returns `Ok(_)` if the conversion can be done without a memcopy.
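    ///
    /// # Example (illustrative)
    ///
    /// A minimal sketch; the `burn_tensor` import path is an assumption of this example:
    ///
    /// ```
    /// use burn_tensor::Bytes; // assumed re-export path
    ///
    /// let bytes = Bytes::from_elems(vec![1u32, 2]);
    /// // Same element type, hence the same alignment: the conversion succeeds without a copy.
    /// let elems = bytes.try_into_vec::<u32>().expect("layout is compatible");
    /// assert_eq!(elems, vec![1, 2]);
    /// ```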
    pub fn try_into_vec<E: bytemuck::CheckedBitPattern + bytemuck::NoUninit>(
        mut self,
    ) -> Result<Vec<E>, Self> {
        // See if the length is compatible
        let Ok(data) = bytemuck::checked::try_cast_slice_mut::<_, E>(&mut self) else {
            return Err(self);
        };
        let length = data.len();
        // If so, try to convert the allocation to a vec
        let mut vec = match self.alloc.try_into_vec::<E>() {
            Ok(vec) => vec,
            Err(alloc) => {
                self.alloc = alloc;
                return Err(self);
            }
        };
        // SAFETY: We computed this length from the bytemuck-ed slice into this allocation
        unsafe {
            vec.set_len(length);
        };
        Ok(vec)
    }
}

impl Deref for Bytes {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        // SAFETY: see type invariants
        unsafe { core::slice::from_raw_parts(self.alloc.as_mut_ptr(), self.len) }
    }
}

impl DerefMut for Bytes {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: see type invariants
        unsafe { core::slice::from_raw_parts_mut(self.alloc.as_mut_ptr(), self.len) }
    }
}

// SAFETY: Bytes behaves like a Box<[u8]> and can contain only elements that are themselves Send
unsafe impl Send for Bytes {}
// SAFETY: Bytes behaves like a Box<[u8]> and can contain only elements that are themselves Sync
unsafe impl Sync for Bytes {}

#[cfg(test)]
mod tests {
    use super::Bytes;
    use alloc::{vec, vec::Vec};

    const _CONST_ASSERTS: fn() = || {
        fn test_send<T: Send>() {}
        fn test_sync<T: Sync>() {}
        test_send::<Bytes>();
        test_sync::<Bytes>();
    };

    fn test_serialization_roundtrip(bytes: &Bytes) {
        let config = bincode::config::standard();
        let serialized =
            bincode::serde::encode_to_vec(bytes, config).expect("serialization to succeed");
        let (roundtripped, _) = bincode::serde::decode_from_slice(&serialized, config)
            .expect("deserialization to succeed");
        assert_eq!(
            bytes, &roundtripped,
            "roundtripping through serialization didn't lead to equal Bytes"
        );
    }

    #[test]
    fn test_serialization() {
        test_serialization_roundtrip(&Bytes::from_elems::<i32>(vec![]));
        test_serialization_roundtrip(&Bytes::from_elems(vec![0xdead, 0xbeaf]));
    }

    #[test]
    fn test_into_vec() {
        // We test an edge case here, where the capacity (but not actual size) makes it impossible to convert to a vec
        let mut bytes = Vec::with_capacity(6);
        let actual_cap = bytes.capacity();
        bytes.extend_from_slice(&[0, 1, 2, 3]);
        let mut bytes = Bytes::from_elems::<u8>(bytes);

        bytes = bytes
            .try_into_vec::<[u8; 0]>()
            .expect_err("Conversion should not succeed for a zero-sized type");
        if actual_cap % 4 != 0 {
            // We most likely get actual_cap == 6, but we can't force Vec to guarantee that. Code coverage should complain if the test misses this case.
            bytes = bytes.try_into_vec::<[u8; 4]>().err().unwrap_or_else(|| {
                panic!("Conversion should not succeed due to capacity {actual_cap} not fitting a whole number of elements");
            });
        }
        bytes = bytes
            .try_into_vec::<u16>()
            .expect_err("Conversion should not succeed due to mismatched alignment");
        bytes = bytes.try_into_vec::<[u8; 3]>().expect_err(
            "Conversion should not succeed due to size not fitting a whole number of elements",
        );
        let bytes = bytes.try_into_vec::<[u8; 2]>().expect("Conversion should succeed for bit-convertible types of equal alignment and compatible size");
        assert_eq!(bytes, &[[0, 1], [2, 3]]);
    }

    #[test]
    fn test_grow() {
        let mut bytes = Bytes::from_elems::<u8>(vec![]);
        bytes.extend_from_byte_slice(&[0, 1, 2, 3]);
        assert_eq!(bytes[..], [0, 1, 2, 3][..]);

        let mut bytes = Bytes::from_elems(vec![42u8; 4]);
        bytes.extend_from_byte_slice(&[0, 1, 2, 3]);
        assert_eq!(bytes[..], [42, 42, 42, 42, 0, 1, 2, 3][..]);
    }

    #[test]
    fn test_large_elems() {
        let mut bytes = Bytes::from_elems(vec![42u128]);
        const TEST_BYTES: [u8; 16] = [
            0x12, 0x90, 0x78, 0x56, 0x34, 0x12, 0x90, 0x78, 0x56, 0x34, 0x12, 0x90, 0x78, 0x56,
            0x34, 0x12,
        ];
        bytes.extend_from_byte_slice(&TEST_BYTES);
        let vec = bytes.try_into_vec::<u128>().unwrap();
        assert_eq!(vec, [42u128, u128::from_ne_bytes(TEST_BYTES)]);
    }
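
    // A minimal additional check (a sketch): `Clone` goes through `try_from_data`,
    // so contents should survive and the clone should be an independent allocation.
    #[test]
    fn test_clone_is_independent() {
        let mut bytes = Bytes::from_elems(vec![1u32, 2, 3]);
        let cloned = bytes.clone();
        assert_eq!(bytes, cloned);
        // Mutating the original must not affect the clone.
        bytes.extend_from_byte_slice(&[0xff; 4]);
        assert_ne!(bytes, cloned);
    }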
}