// custos_math/ops/transpose.rs

1#[cfg(feature = "cuda")]
2use std::ptr::null_mut;
3
4use crate::Matrix;
5use custos::{CDatatype, Device, MainMemory, Shape};
6
7#[cfg(feature = "cpu")]
8use custos::{Cache, CPU};
9
10#[cfg(feature = "cuda")]
11use custos::{
12    cuda::api::cublas::{cublasDgeam, cublasOperation_t, cublasSgeam, CublasHandle},
13    CUdeviceptr,
14};
15
16#[cfg(feature = "opencl")]
17use crate::cl_transpose;
18
/// Transposes the row-major `rows x cols` matrix stored in `a` into `b`.
///
/// The element at `a[r * cols + c]` is written to `b[c * rows + r]`, so `b` ends up holding
/// the `cols x rows` transpose in row-major order. Only the first `rows * cols` elements of
/// `a` are read.
///
/// # Panics
/// Panics if `a` has fewer than `rows * cols` elements, or if `b` has fewer than
/// `rows * cols` elements while both dimensions are non-zero.
pub fn slice_transpose<T: Copy>(rows: usize, cols: usize, a: &[T], b: &mut [T]) {
    for r in 0..rows {
        let start = r * cols;
        let src_row = &a[start..start + cols];

        // Column `c` of source row `r` becomes row `c`, column `r` of the output.
        for (c, &value) in src_row.iter().enumerate() {
            b[c * rows + r] = value;
        }
    }
}
30
31impl<'a, T, IS: Shape, D: Device> Matrix<'a, T, D, IS> {
32    #[allow(non_snake_case)]
33    pub fn T<OS: Shape>(&self) -> Matrix<'a, T, D, OS>
34    where
35        D: TransposeOp<T, IS, OS>,
36    {
37        self.device().transpose(self)
38    }
39}
40
41pub trait TransposeOp<T, IS: Shape = (), OS: Shape = (), D: Device = Self>: Device {
42    fn transpose(&self, x: &Matrix<T, D, IS>) -> Matrix<T, Self, OS>;
43}
44
#[cfg(feature = "cpu")]
impl<T, D, IS, OS> TransposeOp<T, IS, OS, D> for CPU
where
    T: Default + Copy,
    D: MainMemory,
    IS: Shape,
    OS: Shape,
{
    /// Transposes `x` on the host.
    ///
    /// Obtains an output buffer of `x.len()` elements via `Cache::get` (passing `x.node.idx`,
    /// presumably for cache bookkeeping — verify), fills it with [`slice_transpose`], and
    /// wraps it as a `cols x rows` matrix.
    fn transpose(&self, x: &Matrix<T, D, IS>) -> Matrix<T, Self, OS> {
        let mut transposed = Cache::get(self, x.len(), x.node.idx);
        let (rows, cols) = (x.rows(), x.cols());
        slice_transpose(rows, cols, x.as_slice(), transposed.as_mut_slice());
        // Output dimensions are swapped relative to the input.
        (transposed, cols, rows).into()
    }
}
53
#[cfg(feature = "opencl")]
impl<T: CDatatype> TransposeOp<T> for custos::OpenCL {
    /// Transposes `x` on the OpenCL device using `cl_transpose`.
    ///
    /// # Panics
    /// Panics if `cl_transpose` returns an error; `TransposeOp::transpose` has no error
    /// channel to propagate it through.
    fn transpose(&self, x: &Matrix<T, custos::OpenCL>) -> Matrix<T, custos::OpenCL> {
        let (rows, cols) = (x.rows(), x.cols());
        let data = cl_transpose(self, x, rows, cols).expect("OpenCL transpose (cl_transpose) failed");
        Matrix {
            data,
            // Output dimensions are swapped relative to the input.
            dims: (cols, rows),
        }
    }
}
63
#[cfg(feature = "cuda")]
impl<T: CudaTranspose> TransposeOp<T> for custos::CUDA {
    /// Transposes `x` on the GPU through the element type's [`CudaTranspose`] impl
    /// (cuBLAS `geam`).
    ///
    /// # Panics
    /// Panics if the cuBLAS call reports an error; `TransposeOp::transpose` has no error
    /// channel to propagate it through.
    fn transpose(&self, x: &Matrix<T, custos::CUDA>) -> Matrix<T, custos::CUDA> {
        // Fully qualified on purpose: the file-level `use custos::Cache` is gated on the
        // `cpu` feature, so a CUDA-only build could not resolve a bare `Cache`.
        let out = custos::Cache::get(self, x.len(), x.node.idx);
        T::transpose(&self.handle(), x.rows(), x.cols(), x.ptr.ptr, out.ptr.ptr)
            .expect("cuBLAS transpose (geam) failed");
        // Output dimensions are swapped relative to the input.
        (out, x.cols(), x.rows()).into()
    }
}
72
/// Element types that can be transposed on the GPU with cuBLAS.
///
/// The method is only compiled with the `cuda` feature; the trait itself always exists
/// (presumably so it can still be named as a bound in non-CUDA builds).
pub trait CudaTranspose {
    /// Writes the transpose of the `rows x cols` matrix at device address `src` into the
    /// device buffer at `dst`, using the cuBLAS `handle`.
    ///
    /// Implementations report failures through `custos::Result`.
    #[cfg(feature = "cuda")]
    fn transpose(
        handle: &CublasHandle,
        rows: usize,
        cols: usize,
        src: CUdeviceptr,
        dst: CUdeviceptr,
    ) -> custos::Result<()>;
}
83
84impl CudaTranspose for f32 {
85    #[cfg(feature = "cuda")]
86    fn transpose(
87        handle: &CublasHandle,
88        m: usize,
89        n: usize,
90        a: CUdeviceptr,
91        c: CUdeviceptr,
92    ) -> custos::Result<()> {
93        unsafe {
94            // TODO: better casting than: usize as i32
95            cublasSgeam(
96                handle.0,
97                cublasOperation_t::CUBLAS_OP_T,
98                cublasOperation_t::CUBLAS_OP_N,
99                m as i32,
100                n as i32,
101                &1f32 as *const f32,
102                a as *const CUdeviceptr as *const f32,
103                n as i32,
104                &0f32 as *const f32,
105                null_mut(),
106                m as i32,
107                c as *mut CUdeviceptr as *mut f32,
108                m as i32,
109            )
110            .to_result()?;
111        }
112        Ok(())
113    }
114}
115
116impl CudaTranspose for f64 {
117    #[cfg(feature = "cuda")]
118    fn transpose(
119        handle: &CublasHandle,
120        m: usize,
121        n: usize,
122        a: CUdeviceptr,
123        c: CUdeviceptr,
124    ) -> custos::Result<()> {
125        unsafe {
126            // TODO: better casting than: usize as i32
127            cublasDgeam(
128                handle.0,
129                cublasOperation_t::CUBLAS_OP_T,
130                cublasOperation_t::CUBLAS_OP_N,
131                m as i32,
132                n as i32,
133                &1f64 as *const f64,
134                a as *const CUdeviceptr as *const f64,
135                n as i32,
136                &0f64 as *const f64,
137                null_mut(),
138                m as i32,
139                c as *mut CUdeviceptr as *mut f64,
140                m as i32,
141            )
142            .to_result()?;
143        }
144        Ok(())
145    }
146}