rten_tensor/
contiguous.rs1use std::ops::Deref;
2
3use crate::storage::{CowData, ViewData};
4use crate::{AsView, Layout, Storage, TensorBase};
5
/// Newtype wrapper marking the wrapped value as having a contiguous layout.
///
/// The inner field is private. For tensors, [`Contiguous::new`] only wraps
/// the value when its `is_contiguous()` check passes, so holding a
/// `Contiguous<TensorBase<..>>` built that way implies the check succeeded
/// at construction time.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Contiguous<T>(T);
12
13impl<T> Deref for Contiguous<T> {
14 type Target = T;
15
16 fn deref(&self) -> &T {
17 &self.0
18 }
19}
20
21impl<T> Contiguous<T> {
22 pub fn into_inner(self) -> T {
24 self.0
25 }
26}
27
28impl<S: Storage, L: Layout> Contiguous<TensorBase<S, L>> {
29 pub fn new(inner: TensorBase<S, L>) -> Option<Self> {
32 if inner.is_contiguous() {
33 Some(Self(inner))
34 } else {
35 None
36 }
37 }
38
39 pub fn data(&self) -> &[S::Elem] {
44 let len = self.0.len();
45 let ptr = self.0.data_ptr();
46
47 unsafe { std::slice::from_raw_parts(ptr, len) }
49 }
50
51 pub fn view(&self) -> Contiguous<TensorBase<ViewData<'_, S::Elem>, L>>
53 where
54 TensorBase<S, L>: AsView<Elem = S::Elem, Layout = L>,
55 {
56 Contiguous(self.0.view())
57 }
58}
59
60impl<T, L: Clone + Layout> Contiguous<TensorBase<Vec<T>, L>> {
61 pub fn into_data(self) -> Vec<T> {
63 self.0.into_non_contiguous_data()
64 }
65}
66
67impl<'a, T, L: Clone + Layout> Contiguous<TensorBase<CowData<'a, T>, L>> {
68 pub fn into_data(self) -> Option<Vec<T>> {
70 self.0.into_non_contiguous_data()
71 }
72}
73
74impl<S: Storage, L: Layout> From<Contiguous<TensorBase<S, L>>> for TensorBase<S, L> {
75 fn from(val: Contiguous<TensorBase<S, L>>) -> Self {
76 val.0
77 }
78}
79
#[cfg(test)]
mod tests {
    use crate::{AsView, Contiguous, Layout, NdTensor};

    /// `Contiguous::new` accepts a freshly created tensor and rejects the
    /// same tensor after it has been transposed.
    #[test]
    fn test_contiguous() {
        let fresh = NdTensor::<f32, 2>::zeros([3, 3]);
        let maybe_wrapped = Contiguous::new(fresh);
        assert!(maybe_wrapped.is_some());

        // Unwrap back into a plain tensor via the `From` impl, then
        // transpose it before trying to wrap it again.
        let mut transposed: NdTensor<f32, 2> = maybe_wrapped.unwrap().into();
        transposed.transpose();
        let maybe_wrapped = Contiguous::new(transposed);
        assert!(maybe_wrapped.is_none());
    }

    /// A view taken through the wrapper reports the same shape as the
    /// tensor it was created from.
    #[test]
    fn test_contiguous_view() {
        let wrapped = Contiguous::new(NdTensor::<f32, 2>::zeros([3, 4])).unwrap();
        let view = wrapped.view();
        assert_eq!(view.shape(), [3, 4]);
    }
}