rstsr_core/tensor/
pack_array.rs

//! Pack the most contiguous dimension into fixed-size arrays `[T; N]`, and unpack it back.
2
3use crate::prelude_dev::*;
4use core::mem::ManuallyDrop;
5
6/* #region impl directly to PackableArrayAPI */
7
8// DataOwned
9
// Owned data: packing `Vec<T>` yields an owned `Vec<[T; N]>`.
impl<T, const N: usize> PackableArrayAPI<T, N> for DataOwned<Vec<T>> {
    type Array = [T; N];
    type ArrayVec = DataOwned<Vec<[T; N]>>;
}
14
15impl<T> PackArrayAPI<T> for DataOwned<Vec<T>> {
16    fn pack_array_f<const N: usize>(self) -> Result<<Self as PackableArrayAPI<T, N>>::ArrayVec> {
17        let raw = self.into_raw();
18        let raw = raw.pack_array_f::<N>()?;
19        Ok(DataOwned::from(raw))
20    }
21}
22
23impl<T, const N: usize> UnpackArrayAPI for DataOwned<Vec<[T; N]>> {
24    type Output = DataOwned<Vec<T>>;
25
26    fn unpack_array(self) -> Self::Output {
27        let raw = self.into_raw();
28        let raw = raw.unpack_array();
29        DataOwned::from(raw)
30    }
31}
32
33// DataRef
34
// Borrowed data: the packed view keeps the same lifetime `'l` as the borrow.
impl<'l, T, const N: usize> PackableArrayAPI<T, N> for DataRef<'l, Vec<T>> {
    type Array = [T; N];
    type ArrayVec = DataRef<'l, Vec<[T; N]>>;
}
39
impl<'l, T> PackArrayAPI<T> for DataRef<'l, Vec<T>> {
    /// Pack a borrowed slice into a borrowed view of `[T; N]` arrays.
    fn pack_array_f<const N: usize>(self) -> Result<<Self as PackableArrayAPI<T, N>>::ArrayVec> {
        // Reinterpret the borrowed `&[T]` as `&[[T; N]]` (slice-level impl not
        // visible here; presumably it validates the length is a multiple of N).
        let raw = self.raw().as_slice().pack_array_f::<N>()?;
        // SAFETY: the Vec below aliases memory owned elsewhere. It is wrapped
        // in ManuallyDrop so it is never dropped (no double free), and the
        // returned DataRef carries the original borrow's lifetime 'l.
        // NOTE(review): `as_ptr() as *mut` casts away constness; relies on the
        // ManuallyDrop wrapper never being mutated through — confirm.
        let vec = unsafe { Vec::from_raw_parts(raw.as_ptr() as *mut [T; N], raw.len(), raw.len()) };
        Ok(DataRef::from_manually_drop(ManuallyDrop::new(vec)))
    }
}
47
impl<'l, T, const N: usize> UnpackArrayAPI for DataRef<'l, Vec<[T; N]>> {
    type Output = DataRef<'l, Vec<T>>;

    /// Flatten a borrowed `Vec<[T; N]>` view into a borrowed `Vec<T>` view.
    fn unpack_array(self) -> Self::Output {
        // Reinterpret `&[[T; N]]` as `&[T]` (slice-level impl not visible here).
        let raw = self.raw().as_slice().unpack_array();
        // SAFETY: the Vec aliases memory owned elsewhere; ManuallyDrop ensures
        // it is never dropped, and the result keeps the borrow's lifetime 'l.
        let vec = unsafe { Vec::from_raw_parts(raw.as_ptr() as *mut T, raw.len(), raw.len()) };
        DataRef::from_manually_drop(ManuallyDrop::new(vec))
    }
}
57
58// DataMut
59
// Mutably borrowed data: the packed view keeps the same lifetime `'l`.
impl<'l, T, const N: usize> PackableArrayAPI<T, N> for DataMut<'l, Vec<T>> {
    type Array = [T; N];
    type ArrayVec = DataMut<'l, Vec<[T; N]>>;
}
64
impl<'l, T> PackArrayAPI<T> for DataMut<'l, Vec<T>> {
    /// Pack a mutably borrowed slice into a mutable view of `[T; N]` arrays.
    fn pack_array_f<const N: usize>(self) -> Result<<Self as PackableArrayAPI<T, N>>::ArrayVec> {
        // Reinterpret the borrowed `&[T]` as `&[[T; N]]` (slice-level impl not
        // visible here; presumably it validates the length is a multiple of N).
        let raw = self.raw().as_slice().pack_array_f::<N>()?;
        // SAFETY: the Vec aliases memory owned elsewhere; ManuallyDrop ensures
        // it is never dropped, and the DataMut keeps the borrow's lifetime 'l.
        // NOTE(review): the mutable view is derived from `raw()`/`as_ptr()`
        // (shared access) — relies on DataMut holding exclusive access; confirm.
        let vec = unsafe { Vec::from_raw_parts(raw.as_ptr() as *mut [T; N], raw.len(), raw.len()) };
        Ok(DataMut::from_manually_drop(ManuallyDrop::new(vec)))
    }
}
72
impl<'l, T, const N: usize> UnpackArrayAPI for DataMut<'l, Vec<[T; N]>> {
    type Output = DataMut<'l, Vec<T>>;

    /// Flatten a mutably borrowed `Vec<[T; N]>` view into a `Vec<T>` view.
    fn unpack_array(self) -> Self::Output {
        // Reinterpret `&[[T; N]]` as `&[T]` (slice-level impl not visible here).
        let raw = self.raw().as_slice().unpack_array();
        // SAFETY: the Vec aliases memory owned elsewhere; ManuallyDrop ensures
        // it is never dropped, and the DataMut keeps the borrow's lifetime 'l.
        let vec = unsafe { Vec::from_raw_parts(raw.as_ptr() as *mut T, raw.len(), raw.len()) };
        DataMut::from_manually_drop(ManuallyDrop::new(vec))
    }
}
82
83// DataCow
84
// Copy-on-write data: packing preserves the owned/borrowed variant.
impl<'l, T, const N: usize> PackableArrayAPI<T, N> for DataCow<'l, Vec<T>> {
    type Array = [T; N];
    type ArrayVec = DataCow<'l, Vec<[T; N]>>;
}
89
90impl<'l, T> PackArrayAPI<T> for DataCow<'l, Vec<T>> {
91    fn pack_array_f<const N: usize>(self) -> Result<<Self as PackableArrayAPI<T, N>>::ArrayVec> {
92        match self {
93            DataCow::Owned(data) => Ok(DataCow::Owned(data.pack_array_f::<N>()?)),
94            DataCow::Ref(data) => Ok(DataCow::Ref(data.pack_array_f::<N>()?)),
95        }
96    }
97}
98
99impl<'l, T, const N: usize> UnpackArrayAPI for DataCow<'l, Vec<[T; N]>> {
100    type Output = DataCow<'l, Vec<T>>;
101
102    fn unpack_array(self) -> Self::Output {
103        match self {
104            DataCow::Owned(data) => DataCow::Owned(data.unpack_array()),
105            DataCow::Ref(data) => DataCow::Ref(data.unpack_array()),
106        }
107    }
108}
109
110/* #endregion */
111
112/* #region into_pack_array */
113
114impl<R, T, B, D> TensorAny<R, T, B, D>
115where
116    R: DataAPI<Data = B::Raw>,
117    B: DeviceAPI<T>,
118    D: DimAPI + DimSmallerOneAPI,
119    D::SmallerOne: DimAPI,
120{
121    #[substitute_item(
122        ArrayData [<R as PackableArrayAPI<T, N>>::ArrayVec];
123        ArrayType [<R as PackableArrayAPI<T, N>>::Array];
124    )]
125    #[allow(clippy::type_complexity)]
126    pub fn into_pack_array_f<const N: usize>(
127        self,
128        axis: isize,
129    ) -> Result<TensorAny<ArrayData, ArrayType, B, D::SmallerOne>>
130    where
131        B: DeviceAPI<ArrayType>,
132        R: PackableArrayAPI<T, N> + PackArrayAPI<T>,
133        ArrayData: DataAPI<Data = <B as DeviceRawAPI<ArrayType>>::Raw>,
134    {
135        // check if the axis is valid
136        // dimension check
137        let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
138        rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
139        let axis = axis as usize;
140        rstsr_assert_eq!(self.layout().stride()[axis], 1, InvalidLayout, "The axis must be contiguous")?;
141        rstsr_assert_eq!(self.layout().shape()[axis], N, InvalidLayout, "The axis length must be a exactly {N}")?;
142        rstsr_assert!(self.layout().offset() % N == 0, InvalidLayout, "The offset must be a multiple of {N}")?;
143
144        let (storage, layout) = self.into_raw_parts();
145        let (data, device) = storage.into_raw_parts();
146        let data = data.pack_array_f::<N>()?;
147        let storage = Storage::new(data, device);
148        let layout = layout.dim_select(axis as isize, 0)?;
149        let stride = layout
150            .stride()
151            .as_ref()
152            .iter()
153            .map(|&s| s / N as isize)
154            .collect_vec()
155            .try_into()
156            .unwrap_or_else(|_| panic!("stride conversion failed"));
157        let new_offset = layout.offset() / N;
158        let new_layout = unsafe { Layout::new_unchecked(layout.shape().clone(), stride, new_offset) };
159        let tensor = unsafe { TensorAny::new_unchecked(storage, new_layout) };
160        Ok(tensor)
161    }
162
163    #[substitute_item(
164        ArrayData [<R as PackableArrayAPI<T, N>>::ArrayVec];
165        ArrayType [<R as PackableArrayAPI<T, N>>::Array];
166    )]
167    #[allow(clippy::type_complexity)]
168    pub fn into_pack_array<const N: usize>(self, axis: isize) -> TensorAny<ArrayData, ArrayType, B, D::SmallerOne>
169    where
170        B: DeviceAPI<ArrayType>,
171        R: PackableArrayAPI<T, N> + PackArrayAPI<T>,
172        ArrayData: DataAPI<Data = <B as DeviceRawAPI<ArrayType>>::Raw>,
173    {
174        self.into_pack_array_f::<N>(axis).unwrap()
175    }
176}
177
178/* #endregion */
179
180/* #region into_unpack_array */
181
182impl<R, T, B, D, const N: usize> TensorAny<R, [T; N], B, D>
183where
184    R: DataAPI<Data = <B as DeviceRawAPI<[T; N]>>::Raw>,
185    B: DeviceAPI<T> + DeviceAPI<[T; N]>,
186    D: DimAPI + DimLargerOneAPI,
187    D::LargerOne: DimAPI,
188{
189    #[substitute_item(ROut [<R as UnpackArrayAPI>::Output])]
190    pub fn into_unpack_array_f(self, axis: isize) -> Result<TensorAny<ROut, T, B, D::LargerOne>>
191    where
192        R: UnpackArrayAPI,
193        ROut: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
194        B: DeviceAPI<T>,
195    {
196        // dimension check
197        let axis = if axis < 0 { self.ndim() as isize + axis + 1 } else { axis };
198        rstsr_pattern!(axis, 0..(self.ndim() + 1) as isize, ValueOutOfRange)?;
199        let axis = axis as usize;
200
201        let (storage, layout) = self.into_raw_parts();
202        let (data, device) = storage.into_raw_parts();
203        let data = data.unpack_array();
204        let storage = Storage::new(data, device);
205
206        let mut shape = layout.shape().as_ref().to_vec();
207        let mut stride = layout.stride().as_ref().to_vec();
208        let mut offset = layout.offset();
209
210        shape.insert(axis, N);
211        stride.iter_mut().map(|s| *s *= N as isize).count();
212        stride.insert(axis, 1);
213        offset *= N;
214        let layout = unsafe { Layout::new_unchecked(shape, stride, offset) };
215        let layout = layout.into_dim().unwrap();
216        let tensor = unsafe { TensorAny::new_unchecked(storage, layout) };
217        Ok(tensor)
218    }
219
220    #[substitute_item(ROut [<R as UnpackArrayAPI>::Output])]
221    pub fn into_unpack_array(self, axis: isize) -> TensorAny<ROut, T, B, D::LargerOne>
222    where
223        R: UnpackArrayAPI,
224        ROut: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
225        B: DeviceAPI<T>,
226    {
227        self.into_unpack_array_f(axis).unwrap()
228    }
229}
230
231/* #endregion */
232
#[cfg(test)]
mod test {
    use super::*;

    // Round-trip pack/unpack on an owned 3x2 tensor along the last axis.
    #[test]
    fn test_pack_array_owned() {
        let device = DeviceCpuSerial::default();
        let tensor = asarray((vec![1, 2, 3, 4, 5, 6], [3, 2].c(), &device));

        let packed = tensor.into_pack_array_f::<2>(-1).unwrap();
        println!("{packed:?}");
        assert_eq!(packed.raw(), &vec![[1, 2], [3, 4], [5, 6]]);

        let unpacked = packed.into_unpack_array(-1);
        println!("{unpacked:?}");
        assert_eq!(unpacked.raw(), &vec![1, 2, 3, 4, 5, 6]);
    }

    // Same round-trip, but through a borrowed (DataRef) tensor.
    #[test]
    fn test_pack_array_ref() {
        let device = DeviceCpuSerial::default();
        let data = vec![1, 2, 3, 4, 5, 6];
        let tensor = asarray((&data, [3, 2].c(), &device));

        let packed = tensor.into_pack_array_f::<2>(-1).unwrap();
        println!("{packed:?}");
        assert_eq!(packed.raw(), &vec![[1, 2], [3, 4], [5, 6]]);

        let unpacked = packed.into_unpack_array(-1);
        println!("{unpacked:?}");
        assert_eq!(unpacked.raw(), &vec![1, 2, 3, 4, 5, 6]);
    }
}