1use crate::prelude_dev::*;
4use core::mem::ManuallyDrop;
5
6impl<T, const N: usize> PackableArrayAPI<T, N> for DataOwned<Vec<T>> {
11 type Array = [T; N];
12 type ArrayVec = DataOwned<Vec<[T; N]>>;
13}
14
15impl<T> PackArrayAPI<T> for DataOwned<Vec<T>> {
16 fn pack_array_f<const N: usize>(self) -> Result<<Self as PackableArrayAPI<T, N>>::ArrayVec> {
17 let raw = self.into_raw();
18 let raw = raw.pack_array_f::<N>()?;
19 Ok(DataOwned::from(raw))
20 }
21}
22
23impl<T, const N: usize> UnpackArrayAPI for DataOwned<Vec<[T; N]>> {
24 type Output = DataOwned<Vec<T>>;
25
26 fn unpack_array(self) -> Self::Output {
27 let raw = self.into_raw();
28 let raw = raw.unpack_array();
29 DataOwned::from(raw)
30 }
31}
32
33impl<'l, T, const N: usize> PackableArrayAPI<T, N> for DataRef<'l, Vec<T>> {
36 type Array = [T; N];
37 type ArrayVec = DataRef<'l, Vec<[T; N]>>;
38}
39
40impl<'l, T> PackArrayAPI<T> for DataRef<'l, Vec<T>> {
41 fn pack_array_f<const N: usize>(self) -> Result<<Self as PackableArrayAPI<T, N>>::ArrayVec> {
42 let raw = self.raw().as_slice().pack_array_f::<N>()?;
43 let vec = unsafe { Vec::from_raw_parts(raw.as_ptr() as *mut [T; N], raw.len(), raw.len()) };
44 Ok(DataRef::from_manually_drop(ManuallyDrop::new(vec)))
45 }
46}
47
48impl<'l, T, const N: usize> UnpackArrayAPI for DataRef<'l, Vec<[T; N]>> {
49 type Output = DataRef<'l, Vec<T>>;
50
51 fn unpack_array(self) -> Self::Output {
52 let raw = self.raw().as_slice().unpack_array();
53 let vec = unsafe { Vec::from_raw_parts(raw.as_ptr() as *mut T, raw.len(), raw.len()) };
54 DataRef::from_manually_drop(ManuallyDrop::new(vec))
55 }
56}
57
58impl<'l, T, const N: usize> PackableArrayAPI<T, N> for DataMut<'l, Vec<T>> {
61 type Array = [T; N];
62 type ArrayVec = DataMut<'l, Vec<[T; N]>>;
63}
64
65impl<'l, T> PackArrayAPI<T> for DataMut<'l, Vec<T>> {
66 fn pack_array_f<const N: usize>(self) -> Result<<Self as PackableArrayAPI<T, N>>::ArrayVec> {
67 let raw = self.raw().as_slice().pack_array_f::<N>()?;
68 let vec = unsafe { Vec::from_raw_parts(raw.as_ptr() as *mut [T; N], raw.len(), raw.len()) };
69 Ok(DataMut::from_manually_drop(ManuallyDrop::new(vec)))
70 }
71}
72
73impl<'l, T, const N: usize> UnpackArrayAPI for DataMut<'l, Vec<[T; N]>> {
74 type Output = DataMut<'l, Vec<T>>;
75
76 fn unpack_array(self) -> Self::Output {
77 let raw = self.raw().as_slice().unpack_array();
78 let vec = unsafe { Vec::from_raw_parts(raw.as_ptr() as *mut T, raw.len(), raw.len()) };
79 DataMut::from_manually_drop(ManuallyDrop::new(vec))
80 }
81}
82
83impl<'l, T, const N: usize> PackableArrayAPI<T, N> for DataCow<'l, Vec<T>> {
86 type Array = [T; N];
87 type ArrayVec = DataCow<'l, Vec<[T; N]>>;
88}
89
90impl<'l, T> PackArrayAPI<T> for DataCow<'l, Vec<T>> {
91 fn pack_array_f<const N: usize>(self) -> Result<<Self as PackableArrayAPI<T, N>>::ArrayVec> {
92 match self {
93 DataCow::Owned(data) => Ok(DataCow::Owned(data.pack_array_f::<N>()?)),
94 DataCow::Ref(data) => Ok(DataCow::Ref(data.pack_array_f::<N>()?)),
95 }
96 }
97}
98
99impl<'l, T, const N: usize> UnpackArrayAPI for DataCow<'l, Vec<[T; N]>> {
100 type Output = DataCow<'l, Vec<T>>;
101
102 fn unpack_array(self) -> Self::Output {
103 match self {
104 DataCow::Owned(data) => DataCow::Owned(data.unpack_array()),
105 DataCow::Ref(data) => DataCow::Ref(data.unpack_array()),
106 }
107 }
108}
109
110impl<R, T, B, D> TensorAny<R, T, B, D>
115where
116 R: DataAPI<Data = B::Raw>,
117 B: DeviceAPI<T>,
118 D: DimAPI + DimSmallerOneAPI,
119 D::SmallerOne: DimAPI,
120{
121 #[substitute_item(
122 ArrayData [<R as PackableArrayAPI<T, N>>::ArrayVec];
123 ArrayType [<R as PackableArrayAPI<T, N>>::Array];
124 )]
125 #[allow(clippy::type_complexity)]
126 pub fn into_pack_array_f<const N: usize>(
127 self,
128 axis: isize,
129 ) -> Result<TensorAny<ArrayData, ArrayType, B, D::SmallerOne>>
130 where
131 B: DeviceAPI<ArrayType>,
132 R: PackableArrayAPI<T, N> + PackArrayAPI<T>,
133 ArrayData: DataAPI<Data = <B as DeviceRawAPI<ArrayType>>::Raw>,
134 {
135 let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
138 rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
139 let axis = axis as usize;
140 rstsr_assert_eq!(self.layout().stride()[axis], 1, InvalidLayout, "The axis must be contiguous")?;
141 rstsr_assert_eq!(self.layout().shape()[axis], N, InvalidLayout, "The axis length must be a exactly {N}")?;
142 rstsr_assert!(self.layout().offset() % N == 0, InvalidLayout, "The offset must be a multiple of {N}")?;
143
144 let (storage, layout) = self.into_raw_parts();
145 let (data, device) = storage.into_raw_parts();
146 let data = data.pack_array_f::<N>()?;
147 let storage = Storage::new(data, device);
148 let layout = layout.dim_select(axis as isize, 0)?;
149 let stride = layout
150 .stride()
151 .as_ref()
152 .iter()
153 .map(|&s| s / N as isize)
154 .collect_vec()
155 .try_into()
156 .unwrap_or_else(|_| panic!("stride conversion failed"));
157 let new_offset = layout.offset() / N;
158 let new_layout = unsafe { Layout::new_unchecked(layout.shape().clone(), stride, new_offset) };
159 let tensor = unsafe { TensorAny::new_unchecked(storage, new_layout) };
160 Ok(tensor)
161 }
162
163 #[substitute_item(
164 ArrayData [<R as PackableArrayAPI<T, N>>::ArrayVec];
165 ArrayType [<R as PackableArrayAPI<T, N>>::Array];
166 )]
167 #[allow(clippy::type_complexity)]
168 pub fn into_pack_array<const N: usize>(self, axis: isize) -> TensorAny<ArrayData, ArrayType, B, D::SmallerOne>
169 where
170 B: DeviceAPI<ArrayType>,
171 R: PackableArrayAPI<T, N> + PackArrayAPI<T>,
172 ArrayData: DataAPI<Data = <B as DeviceRawAPI<ArrayType>>::Raw>,
173 {
174 self.into_pack_array_f::<N>(axis).unwrap()
175 }
176}
177
178impl<R, T, B, D, const N: usize> TensorAny<R, [T; N], B, D>
183where
184 R: DataAPI<Data = <B as DeviceRawAPI<[T; N]>>::Raw>,
185 B: DeviceAPI<T> + DeviceAPI<[T; N]>,
186 D: DimAPI + DimLargerOneAPI,
187 D::LargerOne: DimAPI,
188{
189 #[substitute_item(ROut [<R as UnpackArrayAPI>::Output])]
190 pub fn into_unpack_array_f(self, axis: isize) -> Result<TensorAny<ROut, T, B, D::LargerOne>>
191 where
192 R: UnpackArrayAPI,
193 ROut: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
194 B: DeviceAPI<T>,
195 {
196 let axis = if axis < 0 { self.ndim() as isize + axis + 1 } else { axis };
198 rstsr_pattern!(axis, 0..(self.ndim() + 1) as isize, ValueOutOfRange)?;
199 let axis = axis as usize;
200
201 let (storage, layout) = self.into_raw_parts();
202 let (data, device) = storage.into_raw_parts();
203 let data = data.unpack_array();
204 let storage = Storage::new(data, device);
205
206 let mut shape = layout.shape().as_ref().to_vec();
207 let mut stride = layout.stride().as_ref().to_vec();
208 let mut offset = layout.offset();
209
210 shape.insert(axis, N);
211 stride.iter_mut().map(|s| *s *= N as isize).count();
212 stride.insert(axis, 1);
213 offset *= N;
214 let layout = unsafe { Layout::new_unchecked(shape, stride, offset) };
215 let layout = layout.into_dim().unwrap();
216 let tensor = unsafe { TensorAny::new_unchecked(storage, layout) };
217 Ok(tensor)
218 }
219
220 #[substitute_item(ROut [<R as UnpackArrayAPI>::Output])]
221 pub fn into_unpack_array(self, axis: isize) -> TensorAny<ROut, T, B, D::LargerOne>
222 where
223 R: UnpackArrayAPI,
224 ROut: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
225 B: DeviceAPI<T>,
226 {
227 self.into_unpack_array_f(axis).unwrap()
228 }
229}
230
#[cfg(test)]
mod test {
    use super::*;

    // Round-trip an owned 3x2 tensor through `[_; 2]` packing along the last axis.
    #[test]
    fn test_pack_array_owned() {
        let device = DeviceCpuSerial::default();
        let flat = asarray((vec![1, 2, 3, 4, 5, 6], [3, 2].c(), &device));

        let packed = flat.into_pack_array_f::<2>(-1).unwrap();
        println!("{packed:?}");
        let expected_packed = vec![[1, 2], [3, 4], [5, 6]];
        assert_eq!(packed.raw(), &expected_packed);

        let unpacked = packed.into_unpack_array(-1);
        println!("{unpacked:?}");
        let expected_flat = vec![1, 2, 3, 4, 5, 6];
        assert_eq!(unpacked.raw(), &expected_flat);
    }

    // Same round trip, but starting from a tensor that borrows its data.
    #[test]
    fn test_pack_array_ref() {
        let device = DeviceCpuSerial::default();
        let source = vec![1, 2, 3, 4, 5, 6];
        let flat = asarray((&source, [3, 2].c(), &device));

        let packed = flat.into_pack_array_f::<2>(-1).unwrap();
        println!("{packed:?}");
        let expected_packed = vec![[1, 2], [3, 4], [5, 6]];
        assert_eq!(packed.raw(), &expected_packed);

        let unpacked = packed.into_unpack_array(-1);
        println!("{unpacked:?}");
        let expected_flat = vec![1, 2, 3, 4, 5, 6];
        assert_eq!(unpacked.raw(), &expected_flat);
    }
}