1#[cfg(feature = "cuda")]
2use std::ptr::null_mut;
3
4use crate::Matrix;
5use custos::{CDatatype, Device, MainMemory, Shape};
6
7#[cfg(feature = "cpu")]
8use custos::{Cache, CPU};
9
10#[cfg(feature = "cuda")]
11use custos::{
12 cuda::api::cublas::{cublasDgeam, cublasOperation_t, cublasSgeam, CublasHandle},
13 CUdeviceptr,
14};
15
16#[cfg(feature = "opencl")]
17use crate::cl_transpose;
18
/// Transposes a row-major `rows` x `cols` matrix stored in `a` into `b`.
///
/// Input element `(r, c)`, located at `a[r * cols + c]`, is written to
/// `b[c * rows + r]`. After the call, `b` therefore holds the `cols` x `rows`
/// transpose in row-major order. Only the first `rows * cols` elements of
/// `a` are read and only the first `rows * cols` elements of `b` are written.
///
/// # Panics
///
/// Panics if `a` holds fewer than `rows * cols` elements or if `b` is too short
/// for any of the written indices.
pub fn slice_transpose<T: Copy>(rows: usize, cols: usize, a: &[T], b: &mut [T]) {
    for r in 0..rows {
        let start = r * cols;
        // Slice the whole source row first, so an undersized `a` panics before
        // any element of this row is written into `b`.
        let src_row = &a[start..start + cols];
        src_row
            .iter()
            .enumerate()
            .for_each(|(c, &value)| b[c * rows + r] = value);
    }
}
30
31impl<'a, T, IS: Shape, D: Device> Matrix<'a, T, D, IS> {
32 #[allow(non_snake_case)]
33 pub fn T<OS: Shape>(&self) -> Matrix<'a, T, D, OS>
34 where
35 D: TransposeOp<T, IS, OS>,
36 {
37 self.device().transpose(self)
38 }
39}
40
/// Device-level transpose of a [`Matrix`].
///
/// `Self` is the device that performs the transpose; the result is a matrix on
/// `Self`. `D` is the device of the input matrix (defaults to `Self`). `IS` and
/// `OS` are the input and output shape types (both default to `()`).
pub trait TransposeOp<T, IS: Shape = (), OS: Shape = (), D: Device = Self>: Device {
    /// Returns the transpose of `x` as a matrix on this device.
    ///
    /// The implementations in this module turn a `rows x cols` input into a
    /// `cols x rows` output.
    fn transpose(&self, x: &Matrix<T, D, IS>) -> Matrix<T, Self, OS>;
}
44
/// Host-side transpose: reads `x` as a host slice (`D: MainMemory`) and fills an
/// output buffer obtained from `Cache::get`.
#[cfg(feature = "cpu")]
impl<T, D, IS, OS> TransposeOp<T, IS, OS, D> for CPU
where
    T: Default + Copy,
    D: MainMemory,
    IS: Shape,
    OS: Shape,
{
    fn transpose(&self, x: &Matrix<T, D, IS>) -> Matrix<T, Self, OS> {
        let (rows, cols) = (x.rows(), x.cols());
        // The output buffer is requested with the input's length and node index.
        let mut transposed = Cache::get(self, x.len(), x.node.idx);
        slice_transpose(rows, cols, x.as_slice(), transposed.as_mut_slice());
        // Swap the dimensions: a `rows x cols` input becomes `cols x rows`.
        (transposed, cols, rows).into()
    }
}
53
/// OpenCL transpose that delegates the device work to `cl_transpose`.
///
/// Panics if `cl_transpose` returns an error, because [`TransposeOp::transpose`]
/// has no way to report one.
#[cfg(feature = "opencl")]
impl<T> TransposeOp<T> for custos::OpenCL
where
    T: CDatatype,
{
    fn transpose(&self, x: &Matrix<T, custos::OpenCL>) -> Matrix<T, custos::OpenCL> {
        let (rows, cols) = (x.rows(), x.cols());
        let data = cl_transpose(self, x, rows, cols).unwrap();
        // A `rows x cols` input becomes a `cols x rows` output.
        Matrix {
            data,
            dims: (cols, rows),
        }
    }
}
63
/// cuBLAS-backed transpose for CUDA matrices, delegating to [`CudaTranspose`].
///
/// # Panics
///
/// Panics if the cuBLAS call fails, because [`TransposeOp::transpose`] cannot
/// return an error.
#[cfg(feature = "cuda")]
impl<T: CudaTranspose> TransposeOp<T> for custos::CUDA {
    fn transpose(&self, x: &Matrix<T, custos::CUDA>) -> Matrix<T, custos::CUDA> {
        // `Cache` is only imported when the `cpu` feature is enabled. Use the full
        // path so that a `cuda`-only build still resolves it.
        let out = custos::Cache::get(self, x.len(), x.node.idx);
        T::transpose(&self.handle(), x.rows(), x.cols(), x.ptr.ptr, out.ptr.ptr)
            .expect("cuBLAS geam transpose failed");
        // A `rows x cols` input becomes a `cols x rows` output.
        (out, x.cols(), x.rows()).into()
    }
}
72
/// Element types that can be transposed on a CUDA device through cuBLAS.
///
/// Without the `cuda` feature the trait has no methods.
pub trait CudaTranspose {
    /// Transposes the `m x n` matrix at device pointer `a` into device pointer `c`.
    ///
    /// The `f32`/`f64` implementations in this module call cuBLAS `geam` with
    /// `lda = n` and `ldc = m`. This means `a` is read as `m x n` row-major and
    /// `c` receives the `n x m` row-major transpose.
    ///
    /// # Errors
    ///
    /// In those implementations, a non-success cuBLAS status is converted into
    /// an error via `to_result`.
    #[cfg(feature = "cuda")]
    fn transpose(
        handle: &CublasHandle,
        m: usize,
        n: usize,
        a: CUdeviceptr,
        c: CUdeviceptr,
    ) -> custos::Result<()>;
}
83
84impl CudaTranspose for f32 {
85 #[cfg(feature = "cuda")]
86 fn transpose(
87 handle: &CublasHandle,
88 m: usize,
89 n: usize,
90 a: CUdeviceptr,
91 c: CUdeviceptr,
92 ) -> custos::Result<()> {
93 unsafe {
94 cublasSgeam(
96 handle.0,
97 cublasOperation_t::CUBLAS_OP_T,
98 cublasOperation_t::CUBLAS_OP_N,
99 m as i32,
100 n as i32,
101 &1f32 as *const f32,
102 a as *const CUdeviceptr as *const f32,
103 n as i32,
104 &0f32 as *const f32,
105 null_mut(),
106 m as i32,
107 c as *mut CUdeviceptr as *mut f32,
108 m as i32,
109 )
110 .to_result()?;
111 }
112 Ok(())
113 }
114}
115
116impl CudaTranspose for f64 {
117 #[cfg(feature = "cuda")]
118 fn transpose(
119 handle: &CublasHandle,
120 m: usize,
121 n: usize,
122 a: CUdeviceptr,
123 c: CUdeviceptr,
124 ) -> custos::Result<()> {
125 unsafe {
126 cublasDgeam(
128 handle.0,
129 cublasOperation_t::CUBLAS_OP_T,
130 cublasOperation_t::CUBLAS_OP_N,
131 m as i32,
132 n as i32,
133 &1f64 as *const f64,
134 a as *const CUdeviceptr as *const f64,
135 n as i32,
136 &0f64 as *const f64,
137 null_mut(),
138 m as i32,
139 c as *mut CUdeviceptr as *mut f64,
140 m as i32,
141 )
142 .to_result()?;
143 }
144 Ok(())
145 }
146}