unmtx_gpu/
cuda.rs

//
// Copyright (c) 2025 Ɓukasz Szpakowski
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
//! A module that contains a CUDA backend.
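//!
//! A minimal usage sketch (assumptions: this module is exposed as
//! `unmtx_gpu::cuda`, and `Backend` and `Result` are re-exported from the crate
//! root):
//!
//! ```no_run
//! use unmtx_gpu::Backend;
//! use unmtx_gpu::cuda::CudaBackend;
//!
//! fn example() -> unmtx_gpu::Result<()>
//! {
//!     let backend = CudaBackend::new()?;
//!     // A 2x3 matrix, a 3x2 matrix, and a 2x2 matrix for the product.
//!     let a = backend.alloc_and_store(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])?;
//!     let b = backend.alloc_and_store(&[7.0, 8.0, 9.0, 10.0, 11.0, 12.0])?;
//!     let c = backend.alloc_and_store_zeros(2 * 2)?;
//!     backend.mul_a_b(&a, &b, &c, 2, 2, 3)?;
//!     let mut elems = [0.0f32; 4];
//!     backend.load(&c, &mut elems)?;
//!     Ok(())
//! }
//! ```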
use std::default::Default;
use std::ffi::c_int;
use std::ffi::c_void;
use std::sync::Arc;
use std::sync::Mutex;
use crate::Backend;
use crate::BackendArray;
use crate::Error;
use crate::Result;
use crate::mutex_lock;

pub use cudarc::cublas::result::CublasError;
pub use cudarc::driver::DriverError;

use cudarc::cublas::result::sgemm;
use cudarc::cublas::sys::cublasOperation_t;
use cudarc::cublas::CudaBlas;
use cudarc::driver::sys::CUdeviceptr;
use cudarc::driver::CudaDevice;
use cudarc::driver::CudaFunction;
use cudarc::driver::CudaSlice;
use cudarc::driver::DeviceRepr;
use cudarc::driver::DevicePtr;
use cudarc::driver::DevicePtrMut;
use cudarc::driver::LaunchAsync;
use cudarc::driver::LaunchConfig;
use cudarc::nvrtc::CompileError;
use cudarc::nvrtc::CompileOptions;
use cudarc::nvrtc::compile_ptx_with_opts;

const SOURCE: &'static str = include_str!("cuda.cu");

const KERNELS: &'static [&'static str] = &[
    "transpose_a",
    "add_a_b",
    "add_at_b",
    "add_a_bt",
    "add_at_bt",
    "sub_a_b",
    "sub_at_b",
    "sub_a_bt",
    "sub_at_bt",
    "mul_a_b",
    "mul_at_b",
    "mul_a_bt",
    "mul_at_bt",
    "mul_a_b_for_elems",
    "mul_at_b_for_elems",
    "mul_a_bt_for_elems",
    "mul_at_bt_for_elems",
    "div_a_b_for_elems",
    "div_at_b_for_elems",
    "div_a_bt_for_elems",
    "div_at_bt_for_elems",
    "add_a_b_for_scalar",
    "add_at_b_for_scalar",
    "sub_a_b_for_scalar",
    "sub_at_b_for_scalar",
    "rsub_a_b_for_scalar",
    "rsub_at_b_for_scalar",
    "mul_a_b_for_scalar",
    "mul_at_b_for_scalar",
    "div_a_b_for_scalar",
    "div_at_b_for_scalar",
    "rdiv_a_b_for_scalar",
    "rdiv_at_b_for_scalar",
    "sigmoid_a",
    "sigmoid_at",
    "tanh_a",
    "tanh_at",
    "softmax_a",
    "softmax_at",
    "repeat_col_a",
    "repeat_row_a"
];

/// A structure of CUDA backend array.
///
/// This structure contains the reference to the device memory.
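///
/// A CUDA backend array is usually obtained from the methods of the [`Backend`]
/// trait, wrapped in a [`BackendArray::Cuda`] variant. A sketch (assuming the
/// crate root re-exports `Backend` and `BackendArray` and exposes this module as
/// `unmtx_gpu::cuda`):
///
/// ```no_run
/// use unmtx_gpu::{Backend, BackendArray};
/// use unmtx_gpu::cuda::CudaBackend;
///
/// let backend = CudaBackend::new().unwrap();
/// let array = backend.alloc_and_store(&[1.0, 2.0, 3.0, 4.0]).unwrap();
/// assert!(matches!(array, BackendArray::Cuda(_)));
/// ```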
#[derive(Debug)]
pub struct CudaBackendArray
{
    slice: Arc<Mutex<CudaSlice<f32>>>,
    len: usize,
}

struct CudaInnerBackend
{
    device: Arc<CudaDevice>,
    cublas: Option<CudaBlas>,
}

/// A structure of CUDA backend.
pub struct CudaBackend
{
    inner: Mutex<CudaInnerBackend>,
    has_cublas: bool,
    has_mma: bool,
}

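// A launch configuration preferred for an n x m result: one-dimensional blocks
// for vectors, 16x16 thread blocks for multiplication without mma, 64x64 element
// tiles with 1024-thread blocks for multiplication with mma, and 32x32 thread
// blocks otherwise. For example, an element-wise kernel for a 1000x1000 matrix
// gets the grid dimensions (32, 32, 1) and the block dimensions (32, 32, 1),
// since (1000 + 31) / 32 == 32.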
fn preferred_launch_config(n: usize, m: usize, is_mul: bool, is_mma: bool) -> LaunchConfig
{
    if m == 1 && !is_mul {
        let n2 = ((n + 1023) / 1024) as u32;
        LaunchConfig {
            grid_dim: (n2, 1, 1),
            block_dim: (1024, 1, 1),
            shared_mem_bytes: 0,
        }
    } else if n == 1 && !is_mul {
        let m2 = ((m + 1023) / 1024) as u32;
        LaunchConfig {
            grid_dim: (1, m2, 1),
            block_dim: (1, 1024, 1),
            shared_mem_bytes: 0,
        }
    } else if is_mul {
        if is_mma {
            let n2 = ((n + 63) / 64) as u32;
            let m2 = ((m + 63) / 64) as u32;
            LaunchConfig {
                grid_dim: (n2, m2, 1),
                block_dim: (1024, 1, 1),
                shared_mem_bytes: 0,
            }
        } else {
            let n2 = (((n + 3) / 4 + 15) / 16) as u32;
            let m2 = (((m + 3) / 4 + 15) / 16) as u32;
            LaunchConfig {
                grid_dim: (n2, m2, 1),
                block_dim: (16, 16, 1),
                shared_mem_bytes: 0,
            }
        }
    } else {
        let n2 = ((n + 31) / 32) as u32;
        let m2 = ((m + 31) / 32) as u32;
        LaunchConfig {
            grid_dim: (n2, m2, 1),
            block_dim: (32, 32, 1),
            shared_mem_bytes: 0,
        }
    }
}

impl CudaBackend
{
    /// Creates a CUDA backend for the first device.
    pub fn new() -> Result<CudaBackend>
    {
        if cfg!(feature = "default_cublas") {
            Self::new_with_ordinal_and_flags(0, true, false)
        } else if cfg!(feature = "default_mma") {
            Self::new_with_ordinal_and_flags(0, false, true)
        } else {
            Self::new_with_ordinal_and_flags(0, false, false)
        }
    }
    
    /// Creates a CUDA backend for the device with the given ordinal number and
    /// the given flags.
    ///
    /// This method takes the following flags:
    ///
    /// - `is_cublas` - use the cuBLAS library for matrix multiplication
    /// - `is_mma` - use the mma instruction for matrix multiplication
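    ///
    /// A sketch of creating a backend that uses cuBLAS for the device with the
    /// ordinal number zero (the `unmtx_gpu::cuda` module path is an assumption):
    ///
    /// ```no_run
    /// use unmtx_gpu::cuda::CudaBackend;
    ///
    /// let backend = CudaBackend::new_with_ordinal_and_flags(0, true, false).unwrap();
    /// assert!(backend.has_cublas());
    /// ```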
    pub fn new_with_ordinal_and_flags(ordinal: usize, is_cublas: bool, is_mma: bool) -> Result<CudaBackend>
    {
        let device = match CudaDevice::new(ordinal) {
            Ok(tmp_device) => tmp_device,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let mut options: CompileOptions = Default::default();
        if is_mma {
            options.options = vec![String::from("-DUNMTX_GPU_MMA=1")];
            options.arch = Some("sm_80");
        }
        let ptx = match compile_ptx_with_opts(SOURCE, options) {
            Ok(tmp_ptx) => tmp_ptx,
            Err(CompileError::CompileError { log, .. }) => return Err(Error::Compilation(log.as_c_str().to_string_lossy().into_owned())),
            Err(err) => return Err(Error::Compilation(format!("{}", err))),
        };
        match device.load_ptx(ptx, "unmtx_gpu", KERNELS) {
            Ok(()) => (),
            Err(err) => return Err(Error::Cuda(err)),
        }
        let cublas = if is_cublas {
            match CudaBlas::new(device.clone()) {
                Ok(tmp_cublas) => Some(tmp_cublas),
                Err(err) => return Err(Error::Cublas(err)),
            }
        } else {
            None
        };
        Ok(CudaBackend { inner: Mutex::new(CudaInnerBackend { device, cublas, }), has_cublas: is_cublas, has_mma: is_mma, })
    }
    
    /// Returns `true` if the backend uses the cuBLAS library for matrix
    /// multiplication, otherwise `false`.
    pub fn has_cublas(&self) -> bool
    { self.has_cublas }
    
    fn check_and_launch2<F, G>(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray) -> Result<()>,
            G: FnOnce(&CudaInnerBackend, CudaFunction, *mut c_void, *mut c_void) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2)) => {
                f(a2, b2)?;
                let inner_g = mutex_lock(&self.inner)?;
                let kernel = match inner_g.device.get_func("unmtx_gpu", kernel_name) {
                    Some(tmp_kernel) => tmp_kernel,
                    None => return Err(Error::NoKernel(String::from(kernel_name))),
                };
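                // The two arrays may share one slice; in that case the mutex is
                // locked only once, so that a second lock cannot deadlock.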
                if !Arc::ptr_eq(&a2.slice, &b2.slice) {
                    let a_slice_g = mutex_lock(&a2.slice)?;
                    let mut b_slice_g = mutex_lock(&b2.slice)?;
                    g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param())?;
                } else {
                    let mut a_slice_g = mutex_lock(&a2.slice)?;
                    g(&*inner_g, kernel, (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?;
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }

    fn check_and_launch3<F, G>(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray, &CudaBackendArray) -> Result<()>,
            G: FnOnce(&CudaInnerBackend, CudaFunction, *mut c_void, *mut c_void, *mut c_void) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b, c) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2), BackendArray::Cuda(c2)) => {
                f(a2, b2, c2)?;
                let inner_g = mutex_lock(&self.inner)?;
                let kernel = match inner_g.device.get_func("unmtx_gpu", kernel_name) {
                    Some(tmp_kernel) => tmp_kernel,
                    None => return Err(Error::NoKernel(String::from(kernel_name))),
                };
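                // The three arrays may alias one another in any combination; each
                // distinct slice is locked exactly once.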
                match (Arc::ptr_eq(&a2.slice, &b2.slice), Arc::ptr_eq(&a2.slice, &c2.slice), Arc::ptr_eq(&b2.slice, &c2.slice)) {
                    (false, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*b_slice_g)).as_kernel_param(), (&mut (*c_slice_g)).as_kernel_param())?
                    },
                    (true, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*a_slice_g)).as_kernel_param(), (&mut (*c_slice_g)).as_kernel_param())?
                    },
                    (false, true, false) => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*b_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?
                    },
                    (false, false, true) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut b_slice_g = mutex_lock(&b2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param())?
                    },
                    _ => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        g(&*inner_g, kernel, (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?
                    },
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }

    fn check_and_launch_cublas3<F, G>(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray, &CudaBackendArray) -> Result<()>,
            G: FnOnce(&CudaInnerBackend, CUdeviceptr, CUdeviceptr, CUdeviceptr) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b, c) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2), BackendArray::Cuda(c2)) => {
                f(a2, b2, c2)?;
                let inner_g = mutex_lock(&self.inner)?;
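                // Unlike the kernel launches, cuBLAS takes raw device pointers, so
                // they are extracted from the locked slices; as above, each distinct
                // slice is locked only once.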
                match (Arc::ptr_eq(&a2.slice, &b2.slice), Arc::ptr_eq(&a2.slice, &c2.slice), Arc::ptr_eq(&b2.slice, &c2.slice)) {
                    (false, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let b_device_ptr = *(&(*b_slice_g)).device_ptr();
                        let c_device_ptr = *(&mut (*c_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, b_device_ptr, c_device_ptr)?
                    },
                    (true, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let c_device_ptr = *(&mut (*c_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, a_device_ptr, c_device_ptr)?
                    },
                    (false, true, false) => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let a_device_ptr = *(&mut (*a_slice_g)).device_ptr_mut();
                        let b_device_ptr = *(&(*b_slice_g)).device_ptr();
                        g(&*inner_g, a_device_ptr, b_device_ptr, a_device_ptr)?
                    },
                    (false, false, true) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut b_slice_g = mutex_lock(&b2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let b_device_ptr = *(&mut (*b_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, b_device_ptr, b_device_ptr)?
                    },
                    _ => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let a_device_ptr = *(&mut (*a_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, a_device_ptr, a_device_ptr)?
                    },
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }
    
    fn check_and_launch_for_fun(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
                if a2.len != n * m {
                    return Err(Error::BackendArrayElemCount(a2.len, n * m));
                }
                if b2.len != n * m {
                    return Err(Error::BackendArrayElemCount(b2.len, n * m));
                }
                Ok(())
        }, |_, kernel, a_param, b_param| {
                let config = preferred_launch_config(n, m, false, is_mma);
                let mut params = vec![
                    a_param,
                    b_param,
                    n.as_kernel_param(),
                    m.as_kernel_param()
                ];
                unsafe {
                    match kernel.launch(config, &mut params) {
                        Ok(()) => Ok(()),
                        Err(err) => Err(Error::Cuda(err)),
                    }
                }
        })
    }

    fn check_and_launch_for_op(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch3(kernel_name, a, b, c, |a2, b2, c2| {
                if a2.len != n * m {
                    return Err(Error::BackendArrayElemCount(a2.len, n * m));
                }
                if b2.len != n * m {
                    return Err(Error::BackendArrayElemCount(b2.len, n * m));
                }
                if c2.len != n * m {
                    return Err(Error::BackendArrayElemCount(c2.len, n * m));
                }
                Ok(())
        }, |_, kernel, a_param, b_param, c_param| {
                let config = preferred_launch_config(n, m, false, is_mma);
                let mut params = vec![
                    a_param,
                    b_param,
                    c_param,
                    n.as_kernel_param(),
                    m.as_kernel_param()
                ];
                unsafe {
                    match kernel.launch(config, &mut params) {
                        Ok(()) => Ok(()),
                        Err(err) => Err(Error::Cuda(err)),
                    }
                }
        })
    }

    fn check_and_launch_for_mul(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch3(kernel_name, a, b, c, |a2, b2, c2| {
                if a2.len != n * l {
                    return Err(Error::BackendArrayElemCount(a2.len, n * l));
                }
                if b2.len != l * m {
                    return Err(Error::BackendArrayElemCount(b2.len, l * m));
                }
                if c2.len != n * m {
                    return Err(Error::BackendArrayElemCount(c2.len, n * m));
                }
                Ok(())
        }, |_, kernel, a_param, b_param, c_param| {
                let config = preferred_launch_config(n, m, true, is_mma);
                let mut params = vec![
                    a_param,
                    b_param,
                    c_param,
                    n.as_kernel_param(),
                    m.as_kernel_param(),
                    l.as_kernel_param()
                ];
                unsafe {
                    match kernel.launch(config, &mut params) {
                        Ok(()) => Ok(()),
                        Err(err) => Err(Error::Cuda(err)),
                    }
                }
        })
    }

    fn check_and_launch_for_scalar(&self, kernel_name: &str, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, c, |a2, c2| {
                if a2.len != n * m {
                    return Err(Error::BackendArrayElemCount(a2.len, n * m));
                }
                if c2.len != n * m {
                    return Err(Error::BackendArrayElemCount(c2.len, n * m));
                }
                Ok(())
        }, |_, kernel, a_param, c_param| {
                let config = preferred_launch_config(n, m, false, is_mma);
                let mut params = vec![
                    a_param,
                    b.as_kernel_param(),
                    c_param,
                    n.as_kernel_param(),
                    m.as_kernel_param()
                ];
                unsafe {
                    match kernel.launch(config, &mut params) {
                        Ok(()) => Ok(()),
                        Err(err) => Err(Error::Cuda(err)),
                    }
                }
        })
    }

    fn check_and_launch_for_fun_and_tiles(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
                if a2.len != n * m {
                    return Err(Error::BackendArrayElemCount(a2.len, n * m));
                }
                if b2.len != n * m {
                    return Err(Error::BackendArrayElemCount(b2.len, n * m));
                }
                Ok(())
        }, |_, kernel, a_param, b_param| {
                let config = preferred_launch_config(n, m, false, is_mma);
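                // The softmax kernels take two extra parameters (presumably the
                // tile sizes used by the kernel); the block dimensions are passed
                // for them here.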
                let mut params = vec![
                    a_param,
                    b_param,
                    n.as_kernel_param(),
                    m.as_kernel_param(),
                    ((config.block_dim.1) as usize).as_kernel_param(),
                    ((config.block_dim.0) as usize).as_kernel_param()
                ];
                unsafe {
                    match kernel.launch(config, &mut params) {
                        Ok(()) => Ok(()),
                        Err(err) => Err(Error::Cuda(err)),
                    }
                }
        })
    }

    fn check_and_launch_for_repeat_col(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
                if a2.len != n {
                    return Err(Error::BackendArrayElemCount(a2.len, n));
                }
                if b2.len != n * m {
                    return Err(Error::BackendArrayElemCount(b2.len, n * m));
                }
                Ok(())
        }, |_, kernel, a_param, b_param| {
                let config = preferred_launch_config(n, m, false, is_mma);
                let mut params = vec![
                    a_param,
                    b_param,
                    n.as_kernel_param(),
                    m.as_kernel_param()
                ];
                unsafe {
                    match kernel.launch(config, &mut params) {
                        Ok(()) => Ok(()),
                        Err(err) => Err(Error::Cuda(err)),
                    }
                }
        })
    }

    fn check_and_launch_for_repeat_row(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
                if a2.len != m {
                    return Err(Error::BackendArrayElemCount(a2.len, m));
                }
                if b2.len != n * m {
                    return Err(Error::BackendArrayElemCount(b2.len, n * m));
                }
                Ok(())
        }, |_, kernel, a_param, b_param| {
                let config = preferred_launch_config(n, m, false, is_mma);
                let mut params = vec![
                    a_param,
                    b_param,
                    n.as_kernel_param(),
                    m.as_kernel_param()
                ];
                unsafe {
                    match kernel.launch(config, &mut params) {
                        Ok(()) => Ok(()),
                        Err(err) => Err(Error::Cuda(err)),
                    }
                }
        })
    }
    
    fn check_and_launch_cublas_for_mul(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize, is_trans_a: bool, is_trans_b: bool) -> Result<()>
    {
        self.check_and_launch_cublas3(a, b, c, |a2, b2, c2| {
                if a2.len != n * l {
                    return Err(Error::BackendArrayElemCount(a2.len, n * l));
                }
                if b2.len != l * m {
                    return Err(Error::BackendArrayElemCount(b2.len, l * m));
                }
                if c2.len != n * m {
                    return Err(Error::BackendArrayElemCount(c2.len, n * m));
                }
                Ok(())
        }, |inner, a_device_ptr, b_device_ptr, c_device_ptr| {
                unsafe {
                    match &inner.cublas {
                        Some(cublas) => {
                            let (transa, lda) = if is_trans_a {
                                (cublasOperation_t::CUBLAS_OP_T, n as c_int)
                            } else {
                                (cublasOperation_t::CUBLAS_OP_N, l as c_int)
                            };
                            let (transb, ldb) = if is_trans_b {
                                (cublasOperation_t::CUBLAS_OP_T, l as c_int)
                            } else {
                                (cublasOperation_t::CUBLAS_OP_N, m as c_int)
                            };
                            let alpha = 1.0f32;
                            let beta = 0.0f32;
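                            // cuBLAS works on column-major matrices, while the
                            // matrices here are treated as row-major, so the call
                            // below computes c^T = op(b)^T * op(a)^T: b is passed
                            // before a and the n and m dimensions are swapped.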
                            let res = sgemm(*cublas.handle(),
                                transb, transa,
                                m as c_int, n as c_int, l as c_int,
                                (&alpha) as *const _,
                                b_device_ptr as *const _, ldb,
                                a_device_ptr as *const _, lda,
                                (&beta) as *const _,
                                c_device_ptr as *mut _, m as c_int);
                            match res {
                                Ok(()) => Ok(()),
                                Err(err) => Err(Error::Cublas(err)),
                            }
                        },
                        None => Err(Error::NoCublas),
                    }
                }
        })
    }
}

impl Backend for CudaBackend
{
    fn name(&self) -> &'static str
    {
        if self.has_cublas {
            "CUDA(cuBLAS)"
        } else if self.has_mma {
            "CUDA(mma)"
        } else {
            "CUDA"
        }
    }
    
    fn has_cublas(&self) -> bool
    { self.has_cublas }

    unsafe fn alloc(&self, n: usize) -> Result<BackendArray>
    {
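        // The device memory is allocated without being initialized, which is
        // presumably why this method is unsafe.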
        let inner_g = mutex_lock(&self.inner)?;
        let slice: CudaSlice<f32> = match inner_g.device.alloc(n) {
            Ok(tmp_slice) => tmp_slice,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: n, };
        Ok(BackendArray::Cuda(cuda_array))
    }

    fn alloc_and_store_zeros(&self, n: usize) -> Result<BackendArray>
    {
        let inner_g = mutex_lock(&self.inner)?;
        let slice: CudaSlice<f32> = match inner_g.device.alloc_zeros(n) {
            Ok(tmp_slice) => tmp_slice,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: n, };
        Ok(BackendArray::Cuda(cuda_array))
    }
    
    fn alloc_and_store(&self, elems: &[f32]) -> Result<BackendArray>
    {
        let inner_g = mutex_lock(&self.inner)?;
        let slice: CudaSlice<f32> = match inner_g.device.htod_sync_copy(elems) {
            Ok(tmp_slice) => tmp_slice,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: elems.len(), };
        Ok(BackendArray::Cuda(cuda_array))
    }
    
    fn load(&self, a: &BackendArray, elems: &mut [f32]) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match a {
            BackendArray::Cuda(a2) => {
                if a2.len != elems.len() {
                    return Err(Error::BackendArrayElemCount(a2.len, elems.len()));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let a_slice_g = mutex_lock(&a2.slice)?;
                match inner_g.device.dtoh_sync_copy_into(&(*a_slice_g), elems) {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }

    fn store(&self, a: &BackendArray, elems: &[f32]) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match a {
            BackendArray::Cuda(a2) => {
                if a2.len != elems.len() {
                    return Err(Error::BackendArrayElemCount(a2.len, elems.len()));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let mut a_slice_g = mutex_lock(&a2.slice)?;
                match inner_g.device.htod_sync_copy_into(elems, &mut (*a_slice_g)) {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }
    
    fn copy(&self, a: &BackendArray, b: &BackendArray) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2)) => {
                if Arc::ptr_eq(&a2.slice, &b2.slice) {
                    return Ok(());
                }
                if a2.len != b2.len {
                    return Err(Error::TwoBackendArrayElemCounts(a2.len, b2.len));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let a_slice_g = mutex_lock(&a2.slice)?;
                let mut b_slice_g = mutex_lock(&b2.slice)?;
                match inner_g.device.dtod_copy(&(*a_slice_g), &mut (*b_slice_g)) {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }

    fn transpose_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("transpose_a", a, b, n, m) }

    fn add_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_a_b", a, b, c, n, m) }

    fn add_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_at_b", a, b, c, n, m) }
    
    fn add_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_a_bt", a, b, c, n, m) }

    fn add_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_at_bt", a, b, c, n, m) }

    fn sub_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_a_b", a, b, c, n, m) }

    fn sub_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_at_b", a, b, c, n, m) }
    
    fn sub_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_a_bt", a, b, c, n, m) }

    fn sub_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_at_bt", a, b, c, n, m) }
    
    fn mul_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, false, false)
        } else {
            self.check_and_launch_for_mul("mul_a_b", a, b, c, n, m, l)
        }
    }

    fn mul_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, true, false)
        } else {
            self.check_and_launch_for_mul("mul_at_b", a, b, c, n, m, l)
        }
    }

    fn mul_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, false, true)
        } else {
            self.check_and_launch_for_mul("mul_a_bt", a, b, c, n, m, l)
        }
    }

    fn mul_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, true, true)
        } else {
            self.check_and_launch_for_mul("mul_at_bt", a, b, c, n, m, l)
        }
    }

    fn mul_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_a_b_for_elems", a, b, c, n, m) }

    fn mul_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_at_b_for_elems", a, b, c, n, m) }
    
    fn mul_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_a_bt_for_elems", a, b, c, n, m) }
    
    fn mul_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_at_bt_for_elems", a, b, c, n, m) }

    fn div_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_a_b_for_elems", a, b, c, n, m) }

    fn div_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_at_b_for_elems", a, b, c, n, m) }
    
    fn div_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_a_bt_for_elems", a, b, c, n, m) }
    
    fn div_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_at_bt_for_elems", a, b, c, n, m) }

    fn add_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("add_a_b_for_scalar", a, b, c, n, m) }

    fn add_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("add_at_b_for_scalar", a, b, c, n, m) }

    fn sub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("sub_a_b_for_scalar", a, b, c, n, m) }

    fn sub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("sub_at_b_for_scalar", a, b, c, n, m) }

    fn rsub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rsub_a_b_for_scalar", a, b, c, n, m) }

    fn rsub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rsub_at_b_for_scalar", a, b, c, n, m) }
    
    fn mul_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("mul_a_b_for_scalar", a, b, c, n, m) }

    fn mul_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("mul_at_b_for_scalar", a, b, c, n, m) }

    fn div_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("div_a_b_for_scalar", a, b, c, n, m) }

    fn div_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("div_at_b_for_scalar", a, b, c, n, m) }

    fn rdiv_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rdiv_a_b_for_scalar", a, b, c, n, m) }

    fn rdiv_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rdiv_at_b_for_scalar", a, b, c, n, m) }

    fn sigmoid_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sigmoid_a", a, b, n, m) }

    fn sigmoid_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sigmoid_at", a, b, n, m) }

    fn tanh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("tanh_a", a, b, n, m) }

    fn tanh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("tanh_at", a, b, n, m) }

    fn softmax_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun_and_tiles("softmax_a", a, b, n, m) }

    fn softmax_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun_and_tiles("softmax_at", a, b, n, m) }

    fn repeat_col_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_repeat_col("repeat_col_a", a, b, n, m) }

    fn repeat_row_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_repeat_row("repeat_row_a", a, b, n, m) }
}

#[cfg(test)]
mod tests;