rten_tensor/
contiguous.rs

1use std::ops::Deref;
2
3use crate::storage::{CowData, ViewData};
4use crate::{AsView, Layout, Storage, TensorBase};
5
/// Wrapper around a tensor that is known to have a contiguous layout.
///
/// "Contiguous" means that the elements are stored in memory in the same
/// order as their logical row-major order, with no gaps between them.
///
/// The wrapped value is a private field, so a `Contiguous` can only be
/// constructed inside this module.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Contiguous<T>(T);

/// Gives read-only access to the wrapped value, so its `&self` methods can be
/// called directly on the wrapper.
impl<T> Deref for Contiguous<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}

impl<T> Contiguous<T> {
    /// Consume the wrapper and hand back the value it holds.
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}
27
28impl<S: Storage, L: Layout> Contiguous<TensorBase<S, L>> {
29    /// Wrap a tensor if it is contiguous, or return `None` if the tensor has
30    /// a non-contiguous layout.
31    pub fn new(inner: TensorBase<S, L>) -> Option<Self> {
32        if inner.is_contiguous() {
33            Some(Self(inner))
34        } else {
35            None
36        }
37    }
38
39    /// Return the tensor's underlying data as a slice.
40    ///
41    /// Unlike [`TensorBase::data`] this returns a slice instead of an option
42    /// because the tensor is known to be contiguous.
43    pub fn data(&self) -> &[S::Elem] {
44        let len = self.0.len();
45        let ptr = self.0.data_ptr();
46
47        // Safety: Constructor verified that tensor is contiguous.
48        unsafe { std::slice::from_raw_parts(ptr, len) }
49    }
50
51    /// Return a contiguous view of this tensor.
52    pub fn view(&self) -> Contiguous<TensorBase<ViewData<'_, S::Elem>, L>>
53    where
54        TensorBase<S, L>: AsView<Elem = S::Elem, Layout = L>,
55    {
56        Contiguous(self.0.view())
57    }
58}
59
60impl<T, L: Clone + Layout> Contiguous<TensorBase<Vec<T>, L>> {
61    /// Extract the owned, contiguous data from this tensor.
62    pub fn into_data(self) -> Vec<T> {
63        self.0.into_non_contiguous_data()
64    }
65}
66
67impl<'a, T, L: Clone + Layout> Contiguous<TensorBase<CowData<'a, T>, L>> {
68    /// Extract the owned data from this tensor, if the data is owned.
69    pub fn into_data(self) -> Option<Vec<T>> {
70        self.0.into_non_contiguous_data()
71    }
72}
73
74impl<S: Storage, L: Layout> From<Contiguous<TensorBase<S, L>>> for TensorBase<S, L> {
75    fn from(val: Contiguous<TensorBase<S, L>>) -> Self {
76        val.0
77    }
78}
79
#[cfg(test)]
mod tests {
    use crate::{AsView, Contiguous, Layout, NdTensor};

    #[test]
    fn test_contiguous() {
        // A freshly allocated tensor must be accepted by the wrapper.
        let wrapped = Contiguous::new(NdTensor::<f32, 2>::zeros([3, 3]));
        assert!(wrapped.is_some());

        // After an in-place transpose the layout is no longer contiguous,
        // so wrapping must be rejected.
        let mut transposed = NdTensor::<f32, 2>::from(wrapped.unwrap());
        transposed.transpose();
        assert!(Contiguous::new(transposed).is_none());
    }

    #[test]
    fn test_contiguous_view() {
        let owned = Contiguous::new(NdTensor::<f32, 2>::zeros([3, 4])).unwrap();
        let view = owned.view();
        assert_eq!(view.shape(), [3, 4]);
    }
}