1use crate::prelude_dev::*;
2
3pub fn into_transpose_f<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> Result<TensorBase<S, D>>
6where
7 D: DimAPI,
8 I: TryInto<AxesIndex<isize>, Error = Error>,
9{
10 let axes = axes.try_into()?;
11 if axes.as_ref().is_empty() {
12 return Ok(into_reverse_axes(tensor));
13 }
14 let (storage, layout) = tensor.into_raw_parts();
15 let layout = layout.transpose(axes.as_ref())?;
16 unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
17}
18
19pub fn transpose<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> TensorView<'_, T, B, D>
25where
26 D: DimAPI,
27 I: TryInto<AxesIndex<isize>, Error = Error>,
28 R: DataAPI<Data = B::Raw>,
29 B: DeviceAPI<T>,
30{
31 into_transpose_f(tensor.view(), axes).rstsr_unwrap()
32}
33
34pub fn transpose_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> Result<TensorView<'_, T, B, D>>
35where
36 D: DimAPI,
37 I: TryInto<AxesIndex<isize>, Error = Error>,
38 R: DataAPI<Data = B::Raw>,
39 B: DeviceAPI<T>,
40{
41 into_transpose_f(tensor.view(), axes)
42}
43
44pub fn into_transpose<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> TensorBase<S, D>
45where
46 D: DimAPI,
47 I: TryInto<AxesIndex<isize>, Error = Error>,
48{
49 into_transpose_f(tensor, axes).rstsr_unwrap()
50}
51
52pub use into_transpose as into_permute_dims;
53pub use into_transpose_f as into_permute_dims_f;
54pub use transpose as permute_dims;
55pub use transpose_f as permute_dims_f;
56
57impl<R, T, B, D> TensorAny<R, T, B, D>
58where
59 R: DataAPI<Data = B::Raw>,
60 B: DeviceAPI<T>,
61 D: DimAPI,
62{
63 pub fn transpose<I>(&self, axes: I) -> TensorView<'_, T, B, D>
69 where
70 I: TryInto<AxesIndex<isize>, Error = Error>,
71 {
72 transpose(self, axes)
73 }
74
75 pub fn transpose_f<I>(&self, axes: I) -> Result<TensorView<'_, T, B, D>>
76 where
77 I: TryInto<AxesIndex<isize>, Error = Error>,
78 {
79 transpose_f(self, axes)
80 }
81
82 pub fn into_transpose<I>(self, axes: I) -> TensorAny<R, T, B, D>
88 where
89 I: TryInto<AxesIndex<isize>, Error = Error>,
90 {
91 into_transpose(self, axes)
92 }
93
94 pub fn into_transpose_f<I>(self, axes: I) -> Result<TensorAny<R, T, B, D>>
95 where
96 I: TryInto<AxesIndex<isize>, Error = Error>,
97 {
98 into_transpose_f(self, axes)
99 }
100
101 pub fn permute_dims<I>(&self, axes: I) -> TensorView<'_, T, B, D>
107 where
108 I: TryInto<AxesIndex<isize>, Error = Error>,
109 {
110 transpose(self, axes)
111 }
112
113 pub fn permute_dims_f<I>(&self, axes: I) -> Result<TensorView<'_, T, B, D>>
114 where
115 I: TryInto<AxesIndex<isize>, Error = Error>,
116 {
117 transpose_f(self, axes)
118 }
119
120 pub fn into_permute_dims<I>(self, axes: I) -> TensorAny<R, T, B, D>
126 where
127 I: TryInto<AxesIndex<isize>, Error = Error>,
128 {
129 into_transpose(self, axes)
130 }
131
132 pub fn into_permute_dims_f<I>(self, axes: I) -> Result<TensorAny<R, T, B, D>>
133 where
134 I: TryInto<AxesIndex<isize>, Error = Error>,
135 {
136 into_transpose_f(self, axes)
137 }
138}
139
140pub fn into_reverse_axes<S, D>(tensor: TensorBase<S, D>) -> TensorBase<S, D>
145where
146 D: DimAPI,
147{
148 let (storage, layout) = tensor.into_raw_parts();
149 let layout = layout.reverse_axes();
150 unsafe { TensorBase::new_unchecked(storage, layout) }
151}
152
153pub fn reverse_axes<R, T, B, D>(tensor: &TensorAny<R, T, B, D>) -> TensorView<'_, T, B, D>
155where
156 D: DimAPI,
157 R: DataAPI<Data = B::Raw>,
158 B: DeviceAPI<T>,
159{
160 into_reverse_axes(tensor.view())
161}
162
163impl<R, T, B, D> TensorAny<R, T, B, D>
164where
165 R: DataAPI<Data = B::Raw>,
166 B: DeviceAPI<T>,
167 D: DimAPI,
168{
169 pub fn reverse_axes(&self) -> TensorView<'_, T, B, D> {
175 into_reverse_axes(self.view())
176 }
177
178 pub fn into_reverse_axes(self) -> TensorAny<R, T, B, D> {
184 into_reverse_axes(self)
185 }
186
187 pub fn t(&self) -> TensorView<'_, T, B, D> {
193 into_reverse_axes(self.view())
194 }
195}
196
197pub fn into_swapaxes_f<I, S, D>(tensor: TensorBase<S, D>, axis1: I, axis2: I) -> Result<TensorBase<S, D>>
202where
203 D: DimAPI,
204 I: TryInto<isize>,
205{
206 let axis1 = axis1.try_into().map_err(|_| rstsr_error!(TryFromIntError))?;
207 let axis2 = axis2.try_into().map_err(|_| rstsr_error!(TryFromIntError))?;
208 let (storage, layout) = tensor.into_raw_parts();
209 let layout = layout.swapaxes(axis1, axis2)?;
210 unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
211}
212
213pub fn swapaxes<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axis1: I, axis2: I) -> TensorView<'_, T, B, D>
219where
220 D: DimAPI,
221 I: TryInto<isize>,
222 R: DataAPI<Data = B::Raw>,
223 B: DeviceAPI<T>,
224{
225 into_swapaxes_f(tensor.view(), axis1, axis2).rstsr_unwrap()
226}
227
228pub fn swapaxes_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axis1: I, axis2: I) -> Result<TensorView<'_, T, B, D>>
229where
230 D: DimAPI,
231 I: TryInto<isize>,
232 R: DataAPI<Data = B::Raw>,
233 B: DeviceAPI<T>,
234{
235 into_swapaxes_f(tensor.view(), axis1, axis2)
236}
237
238pub fn into_swapaxes<I, S, D>(tensor: TensorBase<S, D>, axis1: I, axis2: I) -> TensorBase<S, D>
239where
240 D: DimAPI,
241 I: TryInto<isize>,
242{
243 into_swapaxes_f(tensor, axis1, axis2).rstsr_unwrap()
244}
245
246impl<R, T, B, D> TensorAny<R, T, B, D>
247where
248 R: DataAPI<Data = B::Raw>,
249 B: DeviceAPI<T>,
250 D: DimAPI,
251{
252 pub fn swapaxes<I>(&self, axis1: I, axis2: I) -> TensorView<'_, T, B, D>
258 where
259 I: TryInto<isize>,
260 {
261 swapaxes(self, axis1, axis2)
262 }
263
264 pub fn swapaxes_f<I>(&self, axis1: I, axis2: I) -> Result<TensorView<'_, T, B, D>>
265 where
266 I: TryInto<isize>,
267 {
268 swapaxes_f(self, axis1, axis2)
269 }
270
271 pub fn into_swapaxes<I>(self, axis1: I, axis2: I) -> TensorAny<R, T, B, D>
277 where
278 I: TryInto<isize>,
279 {
280 into_swapaxes(self, axis1, axis2)
281 }
282
283 pub fn into_swapaxes_f<I>(self, axis1: I, axis2: I) -> Result<TensorAny<R, T, B, D>>
284 where
285 I: TryInto<isize>,
286 {
287 into_swapaxes_f(self, axis1, axis2)
288 }
289}
290
291