rten_tensor/
storage.rs

1//! Data storage types and traits.
2
3use std::borrow::Cow;
4use std::marker::PhantomData;
5use std::mem::MaybeUninit;
6use std::ops::Range;
7use std::sync::Arc;
8
9use crate::assume_init::AssumeInit;
10
11/// Trait for backing storage used by tensors and views.
12///
13/// Mutable tensors have storage which also implement [`StorageMut`].
14///
15/// This specifies a contiguous array of elements in memory, as a pointer and a
16/// length. The storage may be owned or borrowed. For borrowed storage, there
17/// may be other storage whose ranges overlap. This is necessary to support
18/// mutable views of non-contiguous tensors (eg. independent columns of a
19/// matrix, whose data is stored in row-major order).
20///
21/// # Safety
22///
23/// Since different storage objects can have memory ranges that overlap, it is
24/// up to the caller to ensure that mutable tensors cannot logically overlap any
25/// other tensors. In other words, whenever a mutable tensor is split or sliced
26/// or iterated, it should not be possible to get duplicate mutable references
27/// to the same elements from those views.
28///
29/// Implementations of this trait must ensure that the
30/// [`as_ptr`](Storage::as_ptr) and [`len`](Storage::len) methods define a valid
31/// range of memory within the same allocated object, which is correctly aligned
32/// for the `Elem` type. For the case where the storage is contiguous, these
33/// requirements are the same as
34/// [`slice::from_raw_parts`](std::slice::from_raw_parts).
35///
36/// The [`MUTABLE`](Storage::MUTABLE) associated const must be true if the
37/// storage is mutable.
38pub unsafe trait Storage {
39    /// The element type.
40    type Elem;
41
42    /// True if this storage allows mutable access, either directly or by
43    /// creating a mutable view with dynamically-checked borrowing (think
44    /// [`Arc::get_mut`]).
45    ///
46    /// This used to determine if a layout can be safely used with a storage.
47    /// Mutable storage can only be used with layouts where every index maps
48    /// to a unique element (ie. non-broadcasting layouts), in order to comply
49    /// with Rust rules for mutable references. Conversely non-mutable storage
50    /// _can_ be used with broadcasting layouts.
51    const MUTABLE: bool;
52
53    /// Return the number of elements in the storage.
54    fn len(&self) -> usize;
55
56    /// Return true if the storage contains no elements.
57    fn is_empty(&self) -> bool {
58        self.len() == 0
59    }
60
61    /// Return a pointer to the first element in the storage.
62    fn as_ptr(&self) -> *const Self::Elem;
63
64    /// Return the element at a given offset, or None if `offset >= self.len()`.
65    ///
66    /// # Safety
67    ///
68    /// - The caller must ensure that no mutable references to the same element
69    ///   can be created.
70    unsafe fn get(&self, offset: usize) -> Option<&Self::Elem> {
71        if offset < self.len() {
72            Some(unsafe { &*self.as_ptr().add(offset) })
73        } else {
74            None
75        }
76    }
77
78    /// Return a reference to the element at `offset`.
79    ///
80    /// # Safety
81    ///
82    /// This has the same safety requirements as [`get`](Storage::get) plus
83    /// the caller must ensure that `offset < len`.
84    unsafe fn get_unchecked(&self, offset: usize) -> &Self::Elem {
85        debug_assert!(offset < self.len());
86        unsafe { &*self.as_ptr().add(offset) }
87    }
88
89    /// Return a view of a sub-region of the storage.
90    ///
91    /// Panics if the range is out of bounds.
92    fn slice(&self, range: Range<usize>) -> ViewData<'_, Self::Elem> {
93        assert_storage_range_valid(self, range.clone());
94        ViewData {
95            // Safety: We verified that `range` is in bounds.
96            ptr: unsafe { self.as_ptr().add(range.start) },
97            len: range.len(),
98            _marker: PhantomData,
99        }
100    }
101
102    /// Return an immutable view of this storage.
103    fn view(&self) -> ViewData<'_, Self::Elem> {
104        ViewData {
105            ptr: self.as_ptr(),
106            len: self.len(),
107            _marker: PhantomData,
108        }
109    }
110
111    /// Return the contents of the storage as a slice.
112    ///
113    /// # Safety
114    ///
115    /// The caller must ensure that no mutable references exist to any element
116    /// in the storage.
117    unsafe fn as_slice(&self) -> &[Self::Elem] {
118        let (ptr, len) = (self.as_ptr(), self.len());
119        unsafe { std::slice::from_raw_parts(ptr, len) }
120    }
121}
122
123/// Trait for converting owned and borrowed element containers (`Vec<T>`, slices)
124/// into their corresponding `Storage` type.
125///
126/// This is used by [`Tensor::from_data`](crate::TensorBase::from_data).
127pub trait IntoStorage {
128    type Output: Storage;
129
130    fn into_storage(self) -> Self::Output;
131}
132
133impl<T: Storage> IntoStorage for T {
134    type Output = Self;
135
136    fn into_storage(self) -> Self {
137        self
138    }
139}
140
141impl<'a, T> IntoStorage for &'a [T] {
142    type Output = ViewData<'a, T>;
143
144    fn into_storage(self) -> ViewData<'a, T> {
145        ViewData {
146            ptr: self.as_ptr(),
147            len: self.len(),
148            _marker: PhantomData,
149        }
150    }
151}
152
153impl<'a, T, const N: usize> IntoStorage for &'a [T; N] {
154    type Output = ViewData<'a, T>;
155
156    fn into_storage(self) -> ViewData<'a, T> {
157        self.as_slice().into_storage()
158    }
159}
160
161impl<'a, T> IntoStorage for &'a mut [T] {
162    type Output = ViewMutData<'a, T>;
163
164    fn into_storage(self) -> ViewMutData<'a, T> {
165        ViewMutData {
166            ptr: self.as_mut_ptr(),
167            len: self.len(),
168            _marker: PhantomData,
169        }
170    }
171}
172
173/// Panic if an offset range is out of bounds for a given storage.
174fn assert_storage_range_valid<S: Storage + ?Sized>(storage: &S, range: Range<usize>) {
175    assert!(
176        range.start <= storage.len() && range.end <= storage.len(),
177        "invalid slice range {:?} for storage length {}",
178        range,
179        storage.len()
180    );
181}
182
183/// Trait for backing storage used by mutable tensors and views.
184///
185/// This extends [`Storage`] with methods to get mutable pointers and references
186/// to elements in the storage.
187///
188/// # Safety
189///
190/// The [`as_mut_ptr`](StorageMut::as_mut_ptr) method has the same safety
191/// requirements as [`Storage::as_ptr`]. The result of `as_mut_ptr` must also
192/// be equal to `as_ptr`.
193pub unsafe trait StorageMut: Storage {
194    /// Return a mutable pointer to the first element in storage.
195    fn as_mut_ptr(&mut self) -> *mut Self::Elem;
196
197    /// Mutable version of [`Storage::get`].
198    ///
199    /// # Safety
200    ///
201    /// This has the same safety requirements as [`get`](Storage::get).
202    unsafe fn get_mut(&mut self, offset: usize) -> Option<&mut Self::Elem> {
203        if offset < self.len() {
204            Some(unsafe { &mut *self.as_mut_ptr().add(offset) })
205        } else {
206            None
207        }
208    }
209
210    /// Mutable version of [`Storage::get_unchecked`].
211    ///
212    /// # Safety
213    ///
214    /// This has the same requirement as [`get_mut`](StorageMut::get_mut) plus
215    /// the caller must ensure that `offset < self.len()`.
216    unsafe fn get_unchecked_mut(&mut self, offset: usize) -> &mut Self::Elem {
217        debug_assert!(offset < self.len());
218        unsafe { &mut *self.as_mut_ptr().add(offset) }
219    }
220
221    /// Return a slice of this storage.
222    fn slice_mut(&mut self, range: Range<usize>) -> ViewMutData<'_, Self::Elem> {
223        assert_storage_range_valid(self, range.clone());
224        ViewMutData {
225            // Safety: We verified that `range` is in bounds.
226            ptr: unsafe { self.as_mut_ptr().add(range.start) },
227            len: range.len(),
228            _marker: PhantomData,
229        }
230    }
231
232    /// Return a mutable view of this storage.
233    fn view_mut(&mut self) -> ViewMutData<'_, Self::Elem> {
234        ViewMutData {
235            ptr: self.as_mut_ptr(),
236            len: self.len(),
237            _marker: PhantomData,
238        }
239    }
240
241    /// Return the stored elements as a mutable slice.
242    ///
243    /// # Safety
244    ///
245    /// The caller must ensure that the storage is contiguous (ie. no unused
246    /// elements) and that there are no other references to any elements in the
247    /// storage.
248    unsafe fn as_slice_mut(&mut self) -> &mut [Self::Elem] {
249        let (ptr, len) = (self.as_mut_ptr(), self.len());
250        unsafe { std::slice::from_raw_parts_mut(ptr, len) }
251    }
252}
253
254unsafe impl<T> Storage for Vec<T> {
255    type Elem = T;
256
257    const MUTABLE: bool = true;
258
259    fn len(&self) -> usize {
260        self.len()
261    }
262
263    fn as_ptr(&self) -> *const T {
264        self.as_ptr()
265    }
266}
267
268unsafe impl<T> StorageMut for Vec<T> {
269    fn as_mut_ptr(&mut self) -> *mut T {
270        self.as_mut_ptr()
271    }
272}
273
274unsafe impl<T> Storage for Arc<Vec<T>> {
275    type Elem = T;
276
277    // This storage as marked as mutable to allow for adding methods to
278    // `ArcTensor` in future which are analagous to `Arc::{get_mut, make_mut}`
279    // (ie. they would return a mutable view or cloned tensor after a dynamic
280    // check of the reference count).
281    const MUTABLE: bool = true;
282
283    fn len(&self) -> usize {
284        self.as_ref().len()
285    }
286
287    fn as_ptr(&self) -> *const T {
288        self.as_ref().as_ptr()
289    }
290}
291
/// Storage for an immutable tensor view.
///
/// This has the same representation in memory as a slice: a pointer and a
/// length. Unlike a slice it allows for other mutable storage to reference
/// memory ranges that overlap with this one. It is up to APIs built on top of
/// this to ensure uniqueness of mutable element references.
#[derive(Debug)]
pub struct ViewData<'a, T> {
    /// Pointer to the first element of the view.
    ptr: *const T,
    /// Number of elements in the view.
    len: usize,
    /// Ties the view to the lifetime of the borrowed elements.
    _marker: PhantomData<&'a T>,
}

// SAFETY: `ViewData<'a, T>` is a shared borrow of `T`s (see the
// `PhantomData<&'a T>` marker) and only ever hands out `&T`. It therefore
// follows the same rules as `&'a [T]`: sending or sharing the view lets
// several threads observe the same `T` through shared references, which is
// only sound when `T: Sync`. Without the bound a view of eg. `Cell<_>` or
// `Rc<_>` elements could cross threads.
unsafe impl<T: Sync> Send for ViewData<'_, T> {}
unsafe impl<T: Sync> Sync for ViewData<'_, T> {}

// Implemented by hand rather than derived so that `T: Clone` / `T: Copy` is
// not required; copying the view never copies elements.
impl<T> Clone for ViewData<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}
impl<T> Copy for ViewData<'_, T> {}
316
317impl<'a, T> ViewData<'a, T> {
318    /// Variant of [`Storage::get`] which preserves lifetimes.
319    ///
320    /// # Safety
321    ///
322    /// See [`Storage::get`].
323    pub unsafe fn get(&self, offset: usize) -> Option<&'a T> {
324        if offset < self.len {
325            Some(unsafe { &*self.ptr.add(offset) })
326        } else {
327            None
328        }
329    }
330
331    /// Variant of [`Storage::get_unchecked`] which preserves lifetimes.
332    ///
333    /// # Safety
334    ///
335    /// See [`Storage::get_unchecked`].
336    pub unsafe fn get_unchecked(&self, offset: usize) -> &'a T {
337        debug_assert!(offset < self.len);
338        unsafe { &*self.ptr.add(offset) }
339    }
340
341    /// Variant of [`Storage::slice`] which preserves lifetimes.
342    pub fn slice(&self, range: Range<usize>) -> ViewData<'a, T> {
343        assert_storage_range_valid(self, range.clone());
344        ViewData {
345            // Safety: `range.start < range.end` and `range.end <= self.len())`,
346            // so this is in-bounds.
347            ptr: unsafe { self.as_ptr().add(range.start) },
348            len: range.len(),
349            _marker: PhantomData,
350        }
351    }
352
353    /// Variant of [`Storage::view`] which preserves lifetimes.
354    pub fn view(&self) -> ViewData<'a, T> {
355        *self
356    }
357
358    /// Return the contents of the storage as a slice.
359    ///
360    /// # Safety
361    ///
362    /// The caller must ensure that no mutable references exist to any element
363    /// in the storage.
364    pub unsafe fn as_slice(&self) -> &'a [T] {
365        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
366    }
367}
368
369unsafe impl<T> Storage for ViewData<'_, T> {
370    type Elem = T;
371
372    const MUTABLE: bool = false;
373
374    fn len(&self) -> usize {
375        self.len
376    }
377
378    fn as_ptr(&self) -> *const T {
379        self.ptr
380    }
381}
382
383impl<'a, T> AssumeInit for ViewData<'a, MaybeUninit<T>> {
384    type Output = ViewData<'a, T>;
385
386    unsafe fn assume_init(self) -> Self::Output {
387        unsafe { std::mem::transmute(self) }
388    }
389}
390
/// Storage for a mutable tensor view.
///
/// This has the same representation in memory as a mutable slice: a pointer
/// and a length. Unlike a slice it allows for other storage objects to
/// reference memory ranges that overlap with this one. It is up to
/// APIs built on top of this to ensure uniqueness of mutable references.
#[derive(Debug)]
pub struct ViewMutData<'a, T> {
    /// Pointer to the first element of the view.
    ptr: *mut T,
    /// Number of elements in the view.
    len: usize,
    /// Ties the view to the lifetime of the mutably borrowed elements.
    _marker: PhantomData<&'a mut T>,
}

// SAFETY: `ViewMutData<'a, T>` is a mutable borrow of `T`s (see the
// `PhantomData<&'a mut T>` marker), so like `&'a mut [T]` it may only move to
// another thread when the elements themselves may, ie. when `T: Send`.
// Without the bound a view of eg. `Rc<_>` elements could cross threads.
unsafe impl<T: Send> Send for ViewMutData<'_, T> {}
404
405impl<'a, T> ViewMutData<'a, T> {
406    /// Variant of [`StorageMut::as_slice_mut`] which preserves the underlying
407    /// lifetime in the result.
408    ///
409    /// # Safety
410    ///
411    /// See [`StorageMut::as_slice_mut`].
412    pub unsafe fn to_slice_mut(mut self) -> &'a mut [T] {
413        let (ptr, len) = (self.as_mut_ptr(), self.len());
414        unsafe { std::slice::from_raw_parts_mut(ptr, len) }
415    }
416
417    /// Split the storage into two sub-views.
418    ///
419    /// Unlike splitting a slice, this does *not* ensure that the two halves
420    /// do not overlap, only that the "left" and "right" ranges are valid.
421    pub fn split_mut(
422        self,
423        left: Range<usize>,
424        right: Range<usize>,
425    ) -> (ViewMutData<'a, T>, ViewMutData<'a, T>) {
426        assert_storage_range_valid(&self, left.clone());
427        assert_storage_range_valid(&self, right.clone());
428
429        let left = ViewMutData {
430            ptr: unsafe { self.ptr.add(left.start) },
431            len: left.len(),
432            _marker: PhantomData,
433        };
434        let right = ViewMutData {
435            ptr: unsafe { self.ptr.add(right.start) },
436            len: right.len(),
437            _marker: PhantomData,
438        };
439        (left, right)
440    }
441
442    /// A variant of [`StorageMut::slice_mut`] which preserves the lifetime of
443    /// the slice.
444    ///
445    /// # Safety
446    ///
447    /// This is unsafe since this function cannot ensure that multiple references
448    /// to the same element are not created (by using `get_mut` on `self` and
449    /// the slice). It is up to the caller to prevent this.
450    pub unsafe fn to_view_slice_mut(&mut self, range: Range<usize>) -> ViewMutData<'a, T> {
451        assert_storage_range_valid(self, range.clone());
452        ViewMutData {
453            // Safety: We verified that `range` is in bounds.
454            ptr: unsafe { self.as_mut_ptr().add(range.start) },
455            len: range.len(),
456            _marker: PhantomData,
457        }
458    }
459}
460
461unsafe impl<T> Storage for ViewMutData<'_, T> {
462    type Elem = T;
463
464    const MUTABLE: bool = true;
465
466    fn len(&self) -> usize {
467        self.len
468    }
469
470    fn as_ptr(&self) -> *const T {
471        self.ptr
472    }
473}
474
475unsafe impl<T> StorageMut for ViewMutData<'_, T> {
476    fn as_mut_ptr(&mut self) -> *mut T {
477        self.ptr
478    }
479}
480
481impl<'a, T> AssumeInit for ViewMutData<'a, MaybeUninit<T>> {
482    type Output = ViewMutData<'a, T>;
483
484    unsafe fn assume_init(self) -> Self::Output {
485        unsafe { std::mem::transmute(self) }
486    }
487}
488
489/// Tensor storage which may be either owned or borrowed.
490///
491/// The name is taken from [`std::borrow::Cow`] in the standard library,
492/// which is conceptually similar.
493pub enum CowData<'a, T> {
494    /// A [`CowData`] that owns its data.
495    Owned(Vec<T>),
496    /// A [`CowData`] that borrows data.
497    Borrowed(ViewData<'a, T>),
498}
499
500unsafe impl<T> Storage for CowData<'_, T> {
501    type Elem = T;
502
503    const MUTABLE: bool = false;
504
505    fn len(&self) -> usize {
506        match self {
507            CowData::Owned(vec) => vec.len(),
508            CowData::Borrowed(view) => view.len(),
509        }
510    }
511
512    fn as_ptr(&self) -> *const T {
513        match self {
514            CowData::Owned(vec) => vec.as_ptr(),
515            CowData::Borrowed(view) => view.as_ptr(),
516        }
517    }
518}
519
520impl<'a, T> IntoStorage for Cow<'a, [T]>
521where
522    [T]: ToOwned<Owned = Vec<T>>,
523{
524    type Output = CowData<'a, T>;
525
526    fn into_storage(self) -> Self::Output {
527        match self {
528            Cow::Owned(vec) => CowData::Owned(vec),
529            Cow::Borrowed(slice) => CowData::Borrowed(slice.into_storage()),
530        }
531    }
532}
533
/// Trait for allocating tensor storage.
///
/// Various methods on [`TensorBase`](crate::TensorBase) with an `_in` suffix
/// take an implementation of this trait, letting the caller control how the
/// data buffer of the returned owned tensor is allocated.
pub trait Alloc {
    /// Allocate storage for an owned tensor.
    ///
    /// The returned `Vec` should be empty but have the given capacity.
    fn alloc<T>(&self, capacity: usize) -> Vec<T>;
}

// A reference to an allocator is itself an allocator, forwarding to the
// referenced value.
impl<A: Alloc> Alloc for &A {
    fn alloc<T>(&self, capacity: usize) -> Vec<T> {
        <A as Alloc>::alloc::<T>(*self, capacity)
    }
}
551
/// [`Alloc`] implementation which wraps the global allocator.
pub struct GlobalAlloc {}

impl GlobalAlloc {
    /// Create a new `GlobalAlloc`. Usable in `const` contexts.
    pub const fn new() -> Self {
        Self {}
    }
}

impl Default for GlobalAlloc {
    /// Equivalent to [`GlobalAlloc::new`].
    fn default() -> GlobalAlloc {
        GlobalAlloc::new()
    }
}
566
567impl Alloc for GlobalAlloc {
568    fn alloc<T>(&self, capacity: usize) -> Vec<T> {
569        Vec::with_capacity(capacity)
570    }
571}
572
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use super::{IntoStorage, Storage, StorageMut, ViewData, ViewMutData};

    /// Check the `Storage` API of `s` against the elements it is expected to
    /// hold.
    fn test_storage_impl<S: Storage<Elem = i32>>(s: S, expected: &[i32]) {
        // Length and element lookup, including one-past-the-end.
        let len = s.len();
        assert_eq!(len, expected.len());
        for (offset, elem) in expected.iter().enumerate() {
            assert_eq!(unsafe { s.get(offset) }, Some(elem));
        }
        assert_eq!(unsafe { s.get(len) }, None);

        // Slicing: drop the first and last elements.
        let range = 1..len - 1;
        let slice = s.slice(range.clone());
        assert_eq!(slice.len(), range.len());
        for (offset, elem) in expected[range].iter().enumerate() {
            assert_eq!(unsafe { slice.get(offset) }, Some(elem));
        }

        // Round trip back to a slice.
        assert_eq!(unsafe { s.as_slice() }, expected);
    }

    #[test]
    fn test_storage() {
        let data = &mut [1, 2, 3, 4];

        // Owned storage.
        test_storage_impl(data.to_vec(), data);

        // Immutable view.
        let view: ViewData<i32> = data.as_slice().into_storage();
        test_storage_impl(view, data);

        // Borrowed copy-on-write storage.
        let cow_view = Cow::Borrowed(data.as_slice()).into_storage();
        test_storage_impl(cow_view, data);

        // Mutable view.
        let mut_view: ViewMutData<i32> = data.as_mut_slice().into_storage();
        test_storage_impl(mut_view, &[1, 2, 3, 4]);
    }

    #[test]
    #[should_panic(expected = "invalid slice range 5..2 for storage length 4")]
    fn test_storage_slice_invalid_start() {
        let data: Vec<i32> = vec![1, 2, 3, 4];
        <Vec<i32> as Storage>::slice(&data, 5..2);
    }

    #[test]
    #[should_panic(expected = "invalid slice range 2..5 for storage length 4")]
    fn test_storage_slice_invalid_end() {
        let data: Vec<i32> = vec![1, 2, 3, 4];
        <Vec<i32> as Storage>::slice(&data, 2..5);
    }

    #[test]
    #[should_panic(expected = "invalid slice range 5..2 for storage length 4")]
    fn test_storage_slice_mut_invalid_start() {
        let mut data: Vec<i32> = vec![1, 2, 3, 4];
        <Vec<i32> as StorageMut>::slice_mut(&mut data, 5..2);
    }

    #[test]
    #[should_panic(expected = "invalid slice range 2..5 for storage length 4")]
    fn test_storage_slice_mut_invalid_end() {
        let mut data: Vec<i32> = vec![1, 2, 3, 4];
        <Vec<i32> as StorageMut>::slice_mut(&mut data, 2..5);
    }
}
643}