ndarray_layout/transform/transpose.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
use crate::ArrayLayout;
use std::{collections::BTreeSet, iter::zip};
impl<const N: usize> ArrayLayout<N> {
/// 转置变换允许调换张量的维度顺序,但不改变元素的存储顺序。
///
/// ```rust
/// # use ndarray_layout::ArrayLayout;
/// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).transpose(&[1, 0]);
/// assert_eq!(layout.shape(), &[3, 2, 4]);
/// assert_eq!(layout.strides(), &[4, 12, 1]);
/// assert_eq!(layout.offset(), 0);
/// ```
pub fn transpose(&self, perm: &[usize]) -> Self {
let perm_ = perm.iter().collect::<BTreeSet<_>>();
assert_eq!(perm_.len(), perm.len());
let content = self.content();
let shape = content.shape();
let strides = content.strides();
let mut ans = Self::with_ndim(self.ndim);
let mut content = ans.content_mut();
content.set_offset(self.offset());
let mut set = |i, j| {
content.set_shape(i, shape[j]);
content.set_stride(i, strides[j]);
};
let mut last = 0;
for (&i, &j) in zip(perm_, perm) {
for i in last..i {
set(i, i);
}
set(i, j);
last = i + 1;
}
for i in last..shape.len() {
set(i, i);
}
ans
}
}