custos/devices/
generic_blas.rs

1/// Provides generic access to f32 and f64 BLAS functions
2
3#[cfg(feature = "blas")]
4#[cfg(feature = "cpu")]
5use super::cpu::{
6    api::{cblas_dgemm, cblas_sgemm},
7    Order, Transpose,
8};
9
10#[cfg(feature = "cuda")]
11use super::cuda::api::{
12    cublas::{cublasDgemm_v2, cublasOperation_t, cublasSgemm_v2, CublasHandle},
13    CUdeviceptr,
14};
15
/// Provides generic access to f32 and f64 BLAS functions
///
/// The CPU entry points are only compiled when both the `blas` and `cpu` features are enabled,
/// the cuBLAS entry point only with the `cuda` feature.
pub trait GenericBlas
where
    Self: Sized,
{
    /// Performs an f32 or f64 matrix multiplication
    ///
    /// This is the low-level entry point: `order`, the transpose flags, the dimensions
    /// `m`, `n`, `k` and the leading dimensions `lda`, `ldb`, `ldc` are handed to the
    /// implementation as given. The result is written into `c`.
    ///
    /// NOTE(review): nothing at this level relates the slice lengths to the dimensions;
    /// callers are responsible for passing consistent values. Prefer [`GenericBlas::gemm`],
    /// [`GenericBlas::gemmT`] or [`GenericBlas::Tgemm`], which check the slice lengths.
    #[cfg(feature = "blas")]
    #[cfg(feature = "cpu")]
    #[allow(clippy::too_many_arguments)]
    fn blas_gemm(
        order: Order,
        trans_a: Transpose,
        trans_b: Transpose,
        m: usize,
        n: usize,
        k: usize,
        a: &[Self],
        lda: usize,
        b: &[Self],
        ldb: usize,
        c: &mut [Self],
        ldc: usize,
    );

    /// A shortened wrapper around [`GenericBlas::blas_gemm`] with the correct parameters for a matrix multiplication
    ///
    /// Multiplies the row-major `m x k` matrix `a` with the row-major `k x n` matrix `b`
    /// and stores the row-major `m x n` result in `c` (`lda = k`, `ldb = n`, `ldc = n`).
    ///
    /// # Panics
    /// Panics if `a`, `b` or `c` holds fewer than `m * k`, `k * n` or `m * n` elements
    /// respectively, or if one of these products overflows `usize`.
    #[cfg(feature = "blas")]
    #[cfg(feature = "cpu")]
    #[inline]
    fn gemm(m: usize, n: usize, k: usize, a: &[Self], b: &[Self], c: &mut [Self]) {
        // The raw BLAS call trusts the dimensions, so validate them against the slices first.
        check_gemm_operand_lens(m, n, k, a.len(), b.len(), c.len());
        Self::blas_gemm(
            Order::RowMajor,
            Transpose::NoTrans,
            Transpose::NoTrans,
            m,
            n,
            k,
            a,
            k,
            b,
            n,
            c,
            n,
        )
    }

    /// A shortened wrapper around [`GenericBlas::blas_gemm`] with the correct parameters for a matrix multiplication
    /// It transposes the rhs (b) matrix.
    ///
    /// `a` is a row-major `m x k` matrix and `b` is stored as a row-major `n x k` matrix
    /// (`ldb = k`) that is transposed for the product; `c` receives the row-major
    /// `m x n` result.
    ///
    /// # Panics
    /// Panics if `a`, `b` or `c` holds fewer than `m * k`, `n * k` or `m * n` elements
    /// respectively, or if one of these products overflows `usize`.
    #[cfg(feature = "blas")]
    #[cfg(feature = "cpu")]
    #[inline]
    #[allow(non_snake_case)]
    fn gemmT(m: usize, n: usize, k: usize, a: &[Self], b: &[Self], c: &mut [Self]) {
        // Transposition does not change the element count of `b` (n * k == k * n).
        check_gemm_operand_lens(m, n, k, a.len(), b.len(), c.len());
        Self::blas_gemm(
            Order::RowMajor,
            Transpose::NoTrans,
            Transpose::Trans,
            m,
            n,
            k,
            a,
            k,
            b,
            k,
            c,
            n,
        )
    }

    /// A shortened wrapper around [`GenericBlas::blas_gemm`] with the correct parameters for a matrix multiplication
    /// It transposes the lhs (a) matrix.
    ///
    /// `a` is stored as a row-major `k x m` matrix (`lda = m`) that is transposed for the
    /// product, `b` is a row-major `k x n` matrix and `c` receives the row-major
    /// `m x n` result.
    ///
    /// # Panics
    /// Panics if `a`, `b` or `c` holds fewer than `k * m`, `k * n` or `m * n` elements
    /// respectively, or if one of these products overflows `usize`.
    #[cfg(feature = "blas")]
    #[cfg(feature = "cpu")]
    #[inline]
    #[allow(non_snake_case)]
    fn Tgemm(m: usize, n: usize, k: usize, a: &[Self], b: &[Self], c: &mut [Self]) {
        // Transposition does not change the element count of `a` (k * m == m * k).
        check_gemm_operand_lens(m, n, k, a.len(), b.len(), c.len());
        Self::blas_gemm(
            Order::RowMajor,
            Transpose::Trans,
            Transpose::NoTrans,
            m,
            n,
            k,
            a,
            m,
            b,
            n,
            c,
            n,
        )
    }

    /// Access to cublas matrix multiplication
    ///
    /// `a`, `b` and `c` are raw device pointers; `m`, `n` and `k` are the matrix dimensions.
    /// Returns an error if the cuBLAS call reports a failure status.
    #[cfg(feature = "cuda")]
    fn cugemm(
        handle: &CublasHandle,
        m: usize,
        n: usize,
        k: usize,
        a: CUdeviceptr,
        b: CUdeviceptr,
        c: CUdeviceptr,
    ) -> crate::Result<()>;
}

/// Verifies that the operand slices of a tightly packed `m x k` by `k x n` product are
/// large enough, so the BLAS routine cannot read or write past the end of a slice.
///
/// # Panics
/// Panics if a slice is too short or if a required element count overflows `usize`.
#[cfg(feature = "blas")]
#[cfg(feature = "cpu")]
#[inline]
fn check_gemm_operand_lens(
    m: usize,
    n: usize,
    k: usize,
    a_len: usize,
    b_len: usize,
    c_len: usize,
) {
    // An overflowing product can never be satisfied by a real slice, so treat it as a failure.
    let fits = |rows: usize, cols: usize, len: usize| {
        rows.checked_mul(cols)
            .map_or(false, |required| len >= required)
    };

    assert!(
        fits(m, k, a_len),
        "gemm: lhs slice has {} elements, but m = {} and k = {} require m * k",
        a_len,
        m,
        k
    );
    assert!(
        fits(k, n, b_len),
        "gemm: rhs slice has {} elements, but k = {} and n = {} require k * n",
        b_len,
        k,
        n
    );
    assert!(
        fits(m, n, c_len),
        "gemm: output slice has {} elements, but m = {} and n = {} require m * n",
        c_len,
        m,
        n
    );
}
119
120impl GenericBlas for f32 {
121    #[cfg(feature = "blas")]
122    #[cfg(feature = "cpu")]
123    #[inline]
124    fn blas_gemm(
125        order: Order,
126        trans_a: Transpose,
127        trans_b: Transpose,
128        m: usize,
129        n: usize,
130        k: usize,
131        a: &[Self],
132        lda: usize,
133        b: &[Self],
134        ldb: usize,
135        c: &mut [Self],
136        ldc: usize,
137    ) {
138        unsafe {
139            cblas_sgemm(
140                order,
141                trans_a,
142                trans_b,
143                m,
144                n,
145                k,
146                1.0,
147                a.as_ptr(),
148                lda,
149                b.as_ptr(),
150                ldb,
151                0.0,
152                c.as_mut_ptr(),
153                ldc,
154            )
155        };
156    }
157    #[cfg(feature = "cuda")]
158    #[inline]
159    fn cugemm(
160        handle: &CublasHandle,
161        m: usize,
162        n: usize,
163        k: usize,
164        a: CUdeviceptr,
165        b: CUdeviceptr,
166        c: CUdeviceptr,
167    ) -> crate::Result<()> {
168        unsafe {
169            cublasSgemm_v2(
170                handle.0,
171                cublasOperation_t::CUBLAS_OP_N,
172                cublasOperation_t::CUBLAS_OP_N,
173                n as i32,
174                m as i32,
175                k as i32,
176                &1f32 as *const f32,
177                b as *const u64 as *const f32,
178                n as i32,
179                a as *const u64 as *const f32,
180                k as i32,
181                &0f32 as *const f32,
182                c as *mut u64 as *mut f32,
183                n as i32,
184            )
185        }
186        .to_result()?;
187        Ok(())
188    }
189}
190
191impl GenericBlas for f64 {
192    #[cfg(feature = "blas")]
193    #[cfg(feature = "cpu")]
194    #[inline]
195    fn blas_gemm(
196        order: Order,
197        trans_a: Transpose,
198        trans_b: Transpose,
199        m: usize,
200        n: usize,
201        k: usize,
202        a: &[Self],
203        lda: usize,
204        b: &[Self],
205        ldb: usize,
206        c: &mut [Self],
207        ldc: usize,
208    ) {
209        unsafe {
210            cblas_dgemm(
211                order,
212                trans_a,
213                trans_b,
214                m,
215                n,
216                k,
217                1.0,
218                a.as_ptr(),
219                lda,
220                b.as_ptr(),
221                ldb,
222                0.0,
223                c.as_mut_ptr(),
224                ldc,
225            )
226        };
227    }
228    #[cfg(feature = "cuda")]
229    #[inline]
230    fn cugemm(
231        handle: &CublasHandle,
232        m: usize,
233        n: usize,
234        k: usize,
235        a: CUdeviceptr,
236        b: CUdeviceptr,
237        c: CUdeviceptr,
238    ) -> crate::Result<()> {
239        unsafe {
240            cublasDgemm_v2(
241                handle.0,
242                cublasOperation_t::CUBLAS_OP_N,
243                cublasOperation_t::CUBLAS_OP_N,
244                n as i32,
245                m as i32,
246                k as i32,
247                &1f64 as *const f64,
248                b as *const u64 as *const f64,
249                n as i32,
250                a as *const u64 as *const f64,
251                k as i32,
252                &0f64 as *const f64,
253                c as *mut u64 as *mut f64,
254                n as i32,
255            )
256        }
257        .to_result()?;
258        Ok(())
259    }
260}