use slop_alloc::{Backend, CpuBackend};
use crate::{Tensor, TensorViewMut};
/// A backend that knows how to write the transpose of a tensor into a
/// destination view living on the same backend.
///
/// `Tensor::transpose` dispatches to this trait, so every backend that wants
/// to support that method must provide an implementation.
pub trait TransposeBackend<T>: Backend {
/// Writes the transpose of `src` into `dst`.
///
/// The caller in this file (`Tensor::transpose`) always passes a
/// two-dimensional `src` and a `dst` whose two sizes are those of `src`
/// swapped. NOTE(review): whether implementations must accept anything
/// else is not specified here — confirm before relying on other shapes.
fn transpose_tensor_into(src: &Tensor<T, Self>, dst: TensorViewMut<T, Self>);
}
impl<T, A: TransposeBackend<T>> Tensor<T, A> {
    /// Returns a new tensor holding the transpose of this two-dimensional tensor.
    ///
    /// The result has this tensor's two sizes swapped and is built with a clone
    /// of this tensor's backend; the element movement itself is delegated to
    /// [`TransposeBackend::transpose_tensor_into`].
    ///
    /// # Panics
    ///
    /// Panics if the tensor does not have exactly two dimensions.
    pub fn transpose(&self) -> Tensor<T, A> {
        let mut transposed_sizes = self.sizes().to_vec();
        let rank = transposed_sizes.len();
        assert_eq!(rank, 2, "Transpose is only supported for 2D tensors");

        // Exchange the two trailing dimensions: [rows, cols] -> [cols, rows].
        transposed_sizes.swap(rank - 1, rank - 2);

        let backend = self.backend().clone();
        let mut output = Tensor::with_sizes_in(transposed_sizes, backend);

        // SAFETY (review): presumably this marks the freshly created storage as
        // initialized so the backend can write through a mutable view; the
        // backend call below is expected to overwrite every element. Confirm
        // against the contract documented on `assume_init`.
        unsafe {
            output.assume_init();
        }

        let output_view = output.as_view_mut();
        A::transpose_tensor_into(self, output_view);
        output
    }
}
impl<T: Copy> TransposeBackend<T> for CpuBackend {
    /// Transposes `src` into `dst` on the CPU using the `transpose` crate.
    ///
    /// Debug builds verify that both tensors are two-dimensional and that the
    /// sizes of `dst` are those of `src` swapped; these checks are compiled out
    /// when debug assertions are disabled.
    fn transpose_tensor_into(src: &Tensor<T, Self>, dst: TensorViewMut<T, Self>) {
        let src_sizes = src.sizes();
        let dst_sizes = dst.sizes();
        let src_rank = src_sizes.len();
        let dst_rank = dst_sizes.len();

        debug_assert_eq!(src_rank, 2);
        debug_assert_eq!(dst_rank, 2);
        debug_assert_eq!(src_sizes[src_rank - 1], dst_sizes[dst_rank - 2]);
        debug_assert_eq!(src_sizes[src_rank - 2], dst_sizes[dst_rank - 1]);

        // The last dimension is treated as the row length (width) and the one
        // before it as the number of rows (height).
        // NOTE(review): this assumes row-major element storage — confirm.
        let input_width = src_sizes[src_rank - 1];
        let input_height = src_sizes[src_rank - 2];

        transpose::transpose(src.as_buffer(), dst.as_mut_slice(), input_width, input_height);
    }
}
#[cfg(test)]
mod tests {
    use rand::Rng;

    use super::*;

    /// Transposes randomly filled tensors of several shapes and, for each one,
    /// checks the resulting sizes and spot-checks a single randomly chosen
    /// element against its mirrored position.
    #[test]
    fn test_transpose() {
        let mut rng = rand::thread_rng();
        // Each entry is (width, height); the tensor is reshaped to [height, width].
        let shapes = [(2, 3), (5, 10), (100, 500), (1000, 1 << 16)];

        for (width, height) in shapes {
            let element_count = width * height;
            let values = (0..element_count).map(|_| rng.gen()).collect::<Vec<_>>();
            let tensor = Tensor::<u32>::from(values).reshape([height, width]);

            let transposed = tensor.transpose();
            assert_eq!(transposed.sizes(), &[width, height]);

            // Element (row, col) of the input must appear at (col, row) in the output.
            let row = rng.gen_range(0..height);
            let col = rng.gen_range(0..width);
            assert_eq!(tensor[[row, col]], transposed[[col, row]]);
        }
    }
}