1use std::default::Default;
10use std::ffi::c_int;
11use std::ffi::c_void;
12use std::sync::Arc;
13use std::sync::Mutex;
14use crate::Backend;
15use crate::BackendArray;
16use crate::Error;
17use crate::Result;
18use crate::mutex_lock;
19
20pub use cudarc::cublas::result::CublasError;
21pub use cudarc::driver::DriverError;
22
23use cudarc::cublas::result::sgemm;
24use cudarc::cublas::sys::cublasOperation_t;
25use cudarc::cublas::CudaBlas;
26use cudarc::driver::sys::CUdeviceptr;
27use cudarc::driver::CudaDevice;
28use cudarc::driver::CudaFunction;
29use cudarc::driver::CudaSlice;
30use cudarc::driver::DeviceRepr;
31use cudarc::driver::DevicePtr;
32use cudarc::driver::DevicePtrMut;
33use cudarc::driver::LaunchAsync;
34use cudarc::driver::LaunchConfig;
35use cudarc::nvrtc::CompileError;
36use cudarc::nvrtc::CompileOptions;
37use cudarc::nvrtc::compile_ptx_with_opts;
38
39const SOURCE: &'static str = include_str!("cuda.cu");
40
// Names of all kernels defined in cuda.cu. Every name must be present in the
// compiled PTX module, since they are all registered at once via
// `CudaDevice::load_ptx` in `new_with_ordinal_and_flags`.
// (`'static` removed — it is implied on const references.)
const KERNELS: &[&str] = &[
    "transpose_a",
    "add_a_b",
    "add_at_b",
    "add_a_bt",
    "add_at_bt",
    "sub_a_b",
    "sub_at_b",
    "sub_a_bt",
    "sub_at_bt",
    "mul_a_b",
    "mul_at_b",
    "mul_a_bt",
    "mul_at_bt",
    "mul_a_b_for_elems",
    "mul_at_b_for_elems",
    "mul_a_bt_for_elems",
    "mul_at_bt_for_elems",
    "div_a_b_for_elems",
    "div_at_b_for_elems",
    "div_a_bt_for_elems",
    "div_at_bt_for_elems",
    "add_a_b_for_scalar",
    "add_at_b_for_scalar",
    "sub_a_b_for_scalar",
    "sub_at_b_for_scalar",
    "rsub_a_b_for_scalar",
    "rsub_at_b_for_scalar",
    "mul_a_b_for_scalar",
    "mul_at_b_for_scalar",
    "div_a_b_for_scalar",
    "div_at_b_for_scalar",
    "rdiv_a_b_for_scalar",
    "rdiv_at_b_for_scalar",
    "sigmoid_a",
    "sigmoid_at",
    "tanh_a",
    "tanh_at",
    "swish_a",
    "swish_at",
    "softmax_a",
    "softmax_at",
    "sqrt_a",
    "sqrt_at",
    "repeat_col_a",
    "repeat_row_a",
    "abs_a",
    "abs_at",
    "pow_a_b",
    "pow_at_b",
    "pow_a_bt",
    "pow_at_bt",
    "pow_a_b_for_scalar",
    "pow_at_b_for_scalar",
    "rpow_a_b_for_scalar",
    "rpow_at_b_for_scalar",
    "exp_a",
    "exp_at",
    "ln_a",
    "ln_at",
    "log2_a",
    "log2_at",
    "log10_a",
    "log10_at",
    "sin_a",
    "sin_at",
    "cos_a",
    "cos_at",
    "tan_a",
    "tan_at",
    "asin_a",
    "asin_at",
    "acos_a",
    "acos_at",
    "atan_a",
    "atan_at",
    "atan2_a_b",
    "atan2_at_b",
    "atan2_a_bt",
    "atan2_at_bt",
    "atan2_a_b_for_scalar",
    "atan2_at_b_for_scalar",
    "ratan2_a_b_for_scalar",
    "ratan2_at_b_for_scalar",
    "sinh_a",
    "sinh_at",
    "cosh_a",
    "cosh_at",
    "asinh_a",
    "asinh_at",
    "acosh_a",
    "acosh_at",
    "atanh_a",
    "atanh_at",
    "signum_a",
    "signum_at",
    "ceil_a",
    "ceil_at",
    "floor_a",
    "floor_at",
    "round_a",
    "round_at",
    "trunc_a",
    "trunc_at",
    "max_a_b",
    "max_at_b",
    "max_a_bt",
    "max_at_bt",
    "max_a_b_for_scalar",
    "max_at_b_for_scalar",
    "min_a_b",
    "min_at_b",
    "min_a_bt",
    "min_at_bt",
    "min_a_b_for_scalar",
    "min_at_b_for_scalar"
];
158
/// A device-side array of `f32` elements managed by [`CudaBackend`].
#[derive(Debug)]
pub struct CudaBackendArray
{
    // Shared device buffer. Kept behind Arc so that aliasing arguments
    // (the same array passed as input and output) can be detected with
    // `Arc::ptr_eq`, and behind Mutex so operations lock it per-launch.
    slice: Arc<Mutex<CudaSlice<f32>>>,
    // Number of f32 elements held by `slice`.
    len: usize,
}
168
// Mutable state of the backend: the device handle plus the optional cuBLAS
// context, guarded together by a single mutex inside `CudaBackend`.
struct CudaInnerBackend
{
    device: Arc<CudaDevice>,
    // Present only when the backend was created with `is_cublas == true`;
    // used by the sgemm path for matrix multiplication.
    cublas: Option<CudaBlas>,
}
174
/// A matrix backend that executes operations with CUDA kernels compiled from
/// `cuda.cu`, optionally routing matrix products through cuBLAS.
pub struct CudaBackend
{
    inner: Mutex<CudaInnerBackend>,
    // When true, mul_* operations use cuBLAS sgemm instead of the custom kernels.
    has_cublas: bool,
    // When true, kernels were compiled with -DUNMTX_GPU_MMA=1 for sm_80 and
    // the multiplication launch configuration uses the MMA tiling.
    has_mma: bool,
}
182
183fn preferred_launch_config(n: usize, m: usize, is_mul: bool, is_mma: bool) -> LaunchConfig
184{
185 if m == 1 && !is_mul {
186 let n2 = ((n + 1023) / 1024) as u32;
187 LaunchConfig {
188 grid_dim: (n2, 1, 1),
189 block_dim: (1024, 1, 1),
190 shared_mem_bytes: 0,
191 }
192 } else if n == 1 && !is_mul {
193 let m2 = ((m + 1023) / 1024) as u32;
194 LaunchConfig {
195 grid_dim: (1, m2, 1),
196 block_dim: (1, 1024, 1),
197 shared_mem_bytes: 0,
198 }
199 } else if is_mul {
200 if is_mma {
201 let n2 = ((n + 63) / 64) as u32;
202 let m2 = ((m + 63) / 64) as u32;
203 LaunchConfig {
204 grid_dim: (n2, m2, 1),
205 block_dim: (1024, 1, 1),
206 shared_mem_bytes: 0,
207 }
208 } else {
209 let n2 = (((n + 3) / 4 + 15) / 16) as u32;
210 let m2 = (((m + 3) / 4 + 15) / 16) as u32;
211 LaunchConfig {
212 grid_dim: (n2, m2, 1),
213 block_dim: (16, 16, 1),
214 shared_mem_bytes: 0,
215 }
216 }
217 } else {
218 let n2 = ((n + 31) / 32) as u32;
219 let m2 = ((m + 31) / 32) as u32;
220 LaunchConfig {
221 grid_dim: (n2, m2, 1),
222 block_dim: (32, 32, 1),
223 shared_mem_bytes: 0,
224 }
225 }
226}
227
impl CudaBackend
{
    /// Creates a backend on device 0, selecting the cuBLAS or MMA variant
    /// from the `default_cublas` / `default_mma` cargo features
    /// (cuBLAS takes priority when both are enabled).
    pub fn new() -> Result<CudaBackend>
    {
        let (is_cublas, is_mma) = if cfg!(feature = "default_cublas") {
            (true, false)
        } else if cfg!(feature = "default_mma") {
            (false, true)
        } else {
            (false, false)
        };
        Self::new_with_ordinal_and_flags(0, is_cublas, is_mma)
    }
241
242 pub fn new_with_ordinal_and_flags(ordinal: usize, is_cublas: bool, is_mma: bool) -> Result<CudaBackend>
249 {
250 let device = match CudaDevice::new(ordinal) {
251 Ok(tmp_device) => tmp_device,
252 Err(err) => return Err(Error::Cuda(err)),
253 };
254 let mut options: CompileOptions = Default::default();
255 if is_mma {
256 options.options = vec![String::from("-DUNMTX_GPU_MMA=1")];
257 options.arch = Some("sm_80");
258 }
259 let ptx = match compile_ptx_with_opts(SOURCE, options) {
260 Ok(tmp_ptx) => tmp_ptx,
261 Err(CompileError::CompileError { log, .. }) => return Err(Error::Compilation(log.as_c_str().to_string_lossy().into_owned())),
262 Err(err) => return Err(Error::Compilation(format!("{}", err))),
263 };
264 match device.load_ptx(ptx, "unmtx_gpu", KERNELS) {
265 Ok(()) => (),
266 Err(err) => return Err(Error::Cuda(err)),
267 }
268 let cublas = if is_cublas {
269 match CudaBlas::new(device.clone()) {
270 Ok(tmp_cublas) => Some(tmp_cublas),
271 Err(err) => return Err(Error::Cublas(err)),
272 }
273 } else {
274 None
275 };
276 Ok(CudaBackend { inner: Mutex::new(CudaInnerBackend { device, cublas, }), has_cublas: is_cublas, has_mma: is_mma, })
277 }
278
    /// Returns `true` when matrix multiplication goes through cuBLAS.
    pub fn has_cublas(&self) -> bool
    { self.has_cublas }
281
    /// Validates two backend arrays with `f`, then resolves `kernel_name`
    /// and invokes `g` with raw kernel parameters for `a` and `b`, finally
    /// synchronizing the device.
    ///
    /// The raw `*mut c_void` parameters passed to `g` are only valid while
    /// the slice mutex guards are held inside this call.
    fn check_and_launch2<F, G>(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray) -> Result<()>,
              G: FnOnce(&CudaInnerBackend, CudaFunction, *mut c_void, *mut c_void) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2)) => {
                f(a2, b2)?;
                let inner_g = mutex_lock(&self.inner)?;
                let kernel = match inner_g.device.get_func("unmtx_gpu", kernel_name) {
                    Some(tmp_kernel) => tmp_kernel,
                    None => return Err(Error::NoKernel(String::from(kernel_name))),
                };
                // Lock each distinct slice exactly once: when a and b alias
                // the same slice, locking it twice would self-deadlock.
                if !Arc::ptr_eq(&a2.slice, &b2.slice) {
                    let a_slice_g = mutex_lock(&a2.slice)?;
                    let mut b_slice_g = mutex_lock(&b2.slice)?;
                    g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param())?;
                } else {
                    let mut a_slice_g = mutex_lock(&a2.slice)?;
                    g(&*inner_g, kernel, (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?;
                }
                // Wait for the device to finish before returning results.
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }
312
    /// Three-array variant of [`Self::check_and_launch2`]: validates with
    /// `f`, then invokes `g` with raw kernel parameters for `a`, `b` and `c`
    /// and synchronizes the device.
    ///
    /// The match over `Arc::ptr_eq` results ensures each distinct slice is
    /// locked exactly once, covering every aliasing combination of the
    /// three arguments.
    fn check_and_launch3<F, G>(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray, &CudaBackendArray) -> Result<()>,
              G: FnOnce(&CudaInnerBackend, CudaFunction, *mut c_void, *mut c_void, *mut c_void) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b, c) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2), BackendArray::Cuda(c2)) => {
                f(a2, b2, c2)?;
                let inner_g = mutex_lock(&self.inner)?;
                let kernel = match inner_g.device.get_func("unmtx_gpu", kernel_name) {
                    Some(tmp_kernel) => tmp_kernel,
                    None => return Err(Error::NoKernel(String::from(kernel_name))),
                };
                match (Arc::ptr_eq(&a2.slice, &b2.slice), Arc::ptr_eq(&a2.slice, &c2.slice), Arc::ptr_eq(&b2.slice, &c2.slice)) {
                    // All three slices distinct.
                    (false, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*b_slice_g)).as_kernel_param(), (&mut (*c_slice_g)).as_kernel_param())?
                    },
                    // a aliases b.
                    (true, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*a_slice_g)).as_kernel_param(), (&mut (*c_slice_g)).as_kernel_param())?
                    },
                    // a aliases c (output overwrites the first input).
                    (false, true, false) => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*b_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?
                    },
                    // b aliases c (output overwrites the second input).
                    (false, false, true) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut b_slice_g = mutex_lock(&b2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param())?
                    },
                    // All three alias the same slice.
                    _ => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        g(&*inner_g, kernel, (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?
                    },
                }
                // Wait for the device to finish before returning results.
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }
362
    /// cuBLAS counterpart of [`Self::check_and_launch3`]: validates with
    /// `f`, then invokes `g` with the raw `CUdeviceptr`s of the three
    /// arrays and synchronizes the device.
    ///
    /// Aliasing combinations are resolved the same way as in
    /// `check_and_launch3` so each distinct slice mutex is locked once.
    fn check_and_launch_cublas3<F, G>(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray, &CudaBackendArray) -> Result<()>,
              G: FnOnce(&CudaInnerBackend, CUdeviceptr, CUdeviceptr, CUdeviceptr) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b, c) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2), BackendArray::Cuda(c2)) => {
                f(a2, b2, c2)?;
                let inner_g = mutex_lock(&self.inner)?;
                match (Arc::ptr_eq(&a2.slice, &b2.slice), Arc::ptr_eq(&a2.slice, &c2.slice), Arc::ptr_eq(&b2.slice, &c2.slice)) {
                    // All three slices distinct.
                    (false, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let b_device_ptr = *(&(*b_slice_g)).device_ptr();
                        let c_device_ptr = *(&mut (*c_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, b_device_ptr, c_device_ptr)?
                    },
                    // a aliases b.
                    (true, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let c_device_ptr = *(&mut (*c_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, a_device_ptr, c_device_ptr)?
                    },
                    // a aliases c (output overwrites the first input).
                    (false, true, false) => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let a_device_ptr = *(&mut (*a_slice_g)).device_ptr_mut();
                        let b_device_ptr = *(&(*b_slice_g)).device_ptr();
                        g(&*inner_g, a_device_ptr, b_device_ptr, a_device_ptr)?
                    },
                    // b aliases c (output overwrites the second input).
                    (false, false, true) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut b_slice_g = mutex_lock(&b2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let b_device_ptr = *(&mut (*b_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, b_device_ptr, b_device_ptr)?
                    },
                    // All three alias the same slice.
                    _ => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let a_device_ptr = *(&mut (*a_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, a_device_ptr, a_device_ptr)?
                    },
                }
                // Wait for the device to finish before returning results.
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }
418
419 fn check_and_launch_for_fun(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
420 {
421 let is_mma = self.has_mma;
422 self.check_and_launch2(kernel_name, a, b, |a2, b2| {
423 if a2.len != n * m {
424 return Err(Error::BackendArrayElemCount(a2.len, n * m));
425 }
426 if b2.len != n * m {
427 return Err(Error::BackendArrayElemCount(b2.len, n * m));
428 }
429 Ok(())
430 }, |_, kernel, a_param, b_param| {
431 let config = preferred_launch_config(n, m, false, is_mma);
432 let mut params = vec![
433 a_param,
434 b_param,
435 n.as_kernel_param(),
436 m.as_kernel_param()
437 ];
438 unsafe {
439 match kernel.launch(config, &mut params) {
440 Ok(()) => Ok(()),
441 Err(err) => Err(Error::Cuda(err)),
442 }
443 }
444 })
445 }
446
447 fn check_and_launch_for_op(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
448 {
449 let is_mma = self.has_mma;
450 self.check_and_launch3(kernel_name, a, b, c, |a2, b2, c2| {
451 if a2.len != n * m {
452 return Err(Error::BackendArrayElemCount(a2.len, n * m));
453 }
454 if b2.len != n * m {
455 return Err(Error::BackendArrayElemCount(b2.len, n * m));
456 }
457 if c2.len != n * m {
458 return Err(Error::BackendArrayElemCount(c2.len, n * m));
459 }
460 Ok(())
461 }, |_, kernel, a_param, b_param, c_param| {
462 let config = preferred_launch_config(n, m, false, is_mma);
463 let mut params = vec![
464 a_param,
465 b_param,
466 c_param,
467 n.as_kernel_param(),
468 m.as_kernel_param()
469 ];
470 unsafe {
471 match kernel.launch(config, &mut params) {
472 Ok(()) => Ok(()),
473 Err(err) => Err(Error::Cuda(err)),
474 }
475 }
476 })
477 }
478
479 fn check_and_launch_for_mul(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
480 {
481 let is_mma = self.has_mma;
482 self.check_and_launch3(kernel_name, a, b, c, |a2, b2, c2| {
483 if a2.len != n * l {
484 return Err(Error::BackendArrayElemCount(a2.len, n * l));
485 }
486 if b2.len != l * m {
487 return Err(Error::BackendArrayElemCount(b2.len, l * m));
488 }
489 if c2.len != n * m {
490 return Err(Error::BackendArrayElemCount(c2.len, n * m));
491 }
492 Ok(())
493 }, |_, kernel, a_param, b_param, c_param| {
494 let config = preferred_launch_config(n, m, true, is_mma);
495 let mut params = vec![
496 a_param,
497 b_param,
498 c_param,
499 n.as_kernel_param(),
500 m.as_kernel_param(),
501 l.as_kernel_param()
502 ];
503 unsafe {
504 match kernel.launch(config, &mut params) {
505 Ok(()) => Ok(()),
506 Err(err) => Err(Error::Cuda(err)),
507 }
508 }
509 })
510 }
511
512 fn check_and_launch_for_scalar(&self, kernel_name: &str, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
513 {
514 let is_mma = self.has_mma;
515 self.check_and_launch2(kernel_name, a, c, |a2, c2| {
516 if a2.len != n * m {
517 return Err(Error::BackendArrayElemCount(a2.len, n * m));
518 }
519 if c2.len != n * m {
520 return Err(Error::BackendArrayElemCount(c2.len, n * m));
521 }
522 Ok(())
523 }, |_, kernel, a_param, c_param| {
524 let config = preferred_launch_config(n, m, false, is_mma);
525 let mut params = vec![
526 a_param,
527 b.as_kernel_param(),
528 c_param,
529 n.as_kernel_param(),
530 m.as_kernel_param()
531 ];
532 unsafe {
533 match kernel.launch(config, &mut params) {
534 Ok(()) => Ok(()),
535 Err(err) => Err(Error::Cuda(err)),
536 }
537 }
538 })
539 }
540
541 fn check_and_launch_for_fun_and_tiles(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
542 {
543 let is_mma = self.has_mma;
544 self.check_and_launch2(kernel_name, a, b, |a2, b2| {
545 if a2.len != n * m {
546 return Err(Error::BackendArrayElemCount(a2.len, n * m));
547 }
548 if b2.len != n * m {
549 return Err(Error::BackendArrayElemCount(b2.len, n * m));
550 }
551 Ok(())
552 }, |_, kernel, a_param, b_param| {
553 let config = preferred_launch_config(n, m, false, is_mma);
554 let mut params = vec![
555 a_param,
556 b_param,
557 n.as_kernel_param(),
558 m.as_kernel_param(),
559 ((config.block_dim.1) as usize).as_kernel_param(),
560 ((config.block_dim.0) as usize).as_kernel_param()
561 ];
562 unsafe {
563 match kernel.launch(config, &mut params) {
564 Ok(()) => Ok(()),
565 Err(err) => Err(Error::Cuda(err)),
566 }
567 }
568 })
569 }
570
571 fn check_and_launch_for_repeat_col(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
572 {
573 let is_mma = self.has_mma;
574 self.check_and_launch2(kernel_name, a, b, |a2, b2| {
575 if a2.len != n {
576 return Err(Error::BackendArrayElemCount(a2.len, n));
577 }
578 if b2.len != n * m {
579 return Err(Error::BackendArrayElemCount(b2.len, n * m));
580 }
581 Ok(())
582 }, |_, kernel, a_param, b_param| {
583 let config = preferred_launch_config(n, m, false, is_mma);
584 let mut params = vec![
585 a_param,
586 b_param,
587 n.as_kernel_param(),
588 m.as_kernel_param()
589 ];
590 unsafe {
591 match kernel.launch(config, &mut params) {
592 Ok(()) => Ok(()),
593 Err(err) => Err(Error::Cuda(err)),
594 }
595 }
596 })
597 }
598
599 fn check_and_launch_for_repeat_row(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
600 {
601 let is_mma = self.has_mma;
602 self.check_and_launch2(kernel_name, a, b, |a2, b2| {
603 if a2.len != m {
604 return Err(Error::BackendArrayElemCount(a2.len, m));
605 }
606 if b2.len != n * m {
607 return Err(Error::BackendArrayElemCount(b2.len, n * m));
608 }
609 Ok(())
610 }, |_, kernel, a_param, b_param| {
611 let config = preferred_launch_config(n, m, false, is_mma);
612 let mut params = vec![
613 a_param,
614 b_param,
615 n.as_kernel_param(),
616 m.as_kernel_param()
617 ];
618 unsafe {
619 match kernel.launch(config, &mut params) {
620 Ok(()) => Ok(()),
621 Err(err) => Err(Error::Cuda(err)),
622 }
623 }
624 })
625 }
626
    /// Computes `c = op(a) * op(b)` with cuBLAS sgemm, where `a` is n x l,
    /// `b` is l x m and `c` is n x m in row-major layout.
    ///
    /// cuBLAS works in column-major order, so the operands are passed in
    /// swapped order with swapped transpose flags: the call computes
    /// C^T = op(B)^T * op(A)^T, which leaves the row-major product in `c`
    /// without any extra transposition.
    fn check_and_launch_cublas_for_mul(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize, is_trans_a: bool, is_trans_b: bool) -> Result<()>
    {
        self.check_and_launch_cublas3(a, b, c, |a2, b2, c2| {
            if a2.len != n * l {
                return Err(Error::BackendArrayElemCount(a2.len, n * l));
            }
            if b2.len != l * m {
                return Err(Error::BackendArrayElemCount(b2.len, l * m));
            }
            if c2.len != n * m {
                return Err(Error::BackendArrayElemCount(c2.len, n * m));
            }
            Ok(())
        }, |inner, a_device_ptr, b_device_ptr, c_device_ptr| {
            unsafe {
                match &inner.cublas {
                    Some(cublas) => {
                        // Leading dimensions follow the row-major storage:
                        // row length of the (possibly transposed) operand.
                        let (transa, lda) = if is_trans_a {
                            (cublasOperation_t::CUBLAS_OP_T, n as c_int)
                        } else {
                            (cublasOperation_t::CUBLAS_OP_N, l as c_int)
                        };
                        let (transb, ldb) = if is_trans_b {
                            (cublasOperation_t::CUBLAS_OP_T, l as c_int)
                        } else {
                            (cublasOperation_t::CUBLAS_OP_N, m as c_int)
                        };
                        // c = 1.0 * op(a) * op(b) + 0.0 * c
                        let alpha = 1.0f32;
                        let beta = 0.0f32;
                        // b before a (and transb before transa): see the
                        // column-major note in the doc comment above.
                        let res = sgemm(*cublas.handle(),
                            transb, transa,
                            m as c_int, n as c_int, l as c_int,
                            (&alpha) as *const _,
                            b_device_ptr as *const _, ldb,
                            a_device_ptr as *const _, lda,
                            (&beta) as *const _,
                            c_device_ptr as *mut _, m as c_int);
                        match res {
                            Ok(()) => Ok(()),
                            Err(err) => Err(Error::Cublas(err)),
                        }
                    },
                    None => Err(Error::NoCublas),
                }
            }
        })
    }
}
675
impl Backend for CudaBackend
{
    /// Human-readable backend name; cuBLAS takes priority over MMA in the
    /// label when both flags are set.
    fn name(&self) -> &'static str
    {
        match (self.has_cublas, self.has_mma) {
            (true, _) => "CUDA(cuBLAS)",
            (false, true) => "CUDA(mma)",
            (false, false) => "CUDA",
        }
    }
688
    /// Trait accessor mirroring [`CudaBackend::has_cublas`].
    fn has_cublas(&self) -> bool
    { self.has_cublas }
691
    /// Allocates a device buffer for `n` `f32` elements without initializing it.
    ///
    /// # Safety
    ///
    /// The returned array's contents are whatever was previously in device
    /// memory; callers must store data before loading it back.
    unsafe fn alloc(&self, n: usize) -> Result<BackendArray>
    {
        let inner_g = mutex_lock(&self.inner)?;
        let slice: CudaSlice<f32> = match inner_g.device.alloc(n) {
            Ok(tmp_slice) => tmp_slice,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: n, };
        Ok(BackendArray::Cuda(cuda_array))
    }
702
703 fn alloc_and_store_zeros(&self, n: usize) -> Result<BackendArray>
704 {
705 let inner_g = mutex_lock(&self.inner)?;
706 let slice: CudaSlice<f32> = match inner_g.device.alloc_zeros(n) {
707 Ok(tmp_slice) => tmp_slice,
708 Err(err) => return Err(Error::Cuda(err)),
709 };
710 let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: n, };
711 Ok(BackendArray::Cuda(cuda_array))
712 }
713
714 fn alloc_and_store(&self, elems: &[f32]) -> Result<BackendArray>
715 {
716 let inner_g = mutex_lock(&self.inner)?;
717 let slice: CudaSlice<f32> = match inner_g.device.htod_sync_copy(elems) {
718 Ok(tmp_slice) => tmp_slice,
719 Err(err) => return Err(Error::Cuda(err)),
720 };
721 let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: elems.len(), };
722 Ok(BackendArray::Cuda(cuda_array))
723 }
724
    /// Copies the device array `a` into the host slice `elems`
    /// (device-to-host, synchronous). Lengths must match exactly.
    fn load(&self, a: &BackendArray, elems: &mut [f32]) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match a {
            BackendArray::Cuda(a2) => {
                if a2.len != elems.len() {
                    return Err(Error::BackendArrayElemCount(a2.len, elems.len()));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let a_slice_g = mutex_lock(&a2.slice)?;
                match inner_g.device.dtoh_sync_copy_into(&(*a_slice_g), elems) {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }
744
    /// Copies the host slice `elems` into the device array `a`
    /// (host-to-device, synchronous). Lengths must match exactly.
    fn store(&self, a: &BackendArray, elems: &[f32]) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match a {
            BackendArray::Cuda(a2) => {
                if a2.len != elems.len() {
                    return Err(Error::BackendArrayElemCount(a2.len, elems.len()));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let mut a_slice_g = mutex_lock(&a2.slice)?;
                match inner_g.device.htod_sync_copy_into(elems, &mut (*a_slice_g)) {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }
764
    /// Copies device array `a` into device array `b` (device-to-device).
    ///
    /// Copying an array onto itself is a no-op; otherwise the lengths must
    /// match. The explicit `synchronize` makes the copy observable before
    /// this method returns.
    fn copy(&self, a: &BackendArray, b: &BackendArray) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2)) => {
                // Same underlying slice: nothing to do (and locking both
                // would deadlock on the shared mutex).
                if Arc::ptr_eq(&a2.slice, &b2.slice) {
                    return Ok(());
                }
                if a2.len != b2.len {
                    return Err(Error::TwoBackendArrayElemCounts(a2.len, b2.len));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let a_slice_g = mutex_lock(&a2.slice)?;
                let mut b_slice_g = mutex_lock(&b2.slice)?;
                match inner_g.device.dtod_copy(&(*a_slice_g), &mut (*b_slice_g)) {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }
792
    // --- Thin delegations to the generic launchers. ---
    // The `_at`/`_bt` name suffixes select kernel variants that appear to
    // read the corresponding operand transposed (see cuda.cu for the
    // authoritative definitions).

    fn transpose_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("transpose_a", a, b, n, m) }

    fn add_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_a_b", a, b, c, n, m) }

    fn add_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_at_b", a, b, c, n, m) }

    fn add_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_a_bt", a, b, c, n, m) }

    fn add_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_at_bt", a, b, c, n, m) }

    fn sub_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_a_b", a, b, c, n, m) }

    fn sub_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_at_b", a, b, c, n, m) }

    fn sub_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_a_bt", a, b, c, n, m) }

    fn sub_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_at_bt", a, b, c, n, m) }

    // Matrix products: routed through cuBLAS sgemm when the backend was
    // created with is_cublas == true, otherwise through the custom kernels.

    fn mul_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, false, false)
        } else {
            self.check_and_launch_for_mul("mul_a_b", a, b, c, n, m, l)
        }
    }

    fn mul_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, true, false)
        } else {
            self.check_and_launch_for_mul("mul_at_b", a, b, c, n, m, l)
        }
    }

    fn mul_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, false, true)
        } else {
            self.check_and_launch_for_mul("mul_a_bt", a, b, c, n, m, l)
        }
    }

    fn mul_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, true, true)
        } else {
            self.check_and_launch_for_mul("mul_at_bt", a, b, c, n, m, l)
        }
    }
855
    // Element-wise multiplication/division (`_for_elems`) delegations.

    fn mul_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_a_b_for_elems", a, b, c, n, m) }

    fn mul_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_at_b_for_elems", a, b, c, n, m) }

    fn mul_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_a_bt_for_elems", a, b, c, n, m) }

    fn mul_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_at_bt_for_elems", a, b, c, n, m) }

    fn div_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_a_b_for_elems", a, b, c, n, m) }

    fn div_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_at_b_for_elems", a, b, c, n, m) }

    fn div_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_a_bt_for_elems", a, b, c, n, m) }

    fn div_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_at_bt_for_elems", a, b, c, n, m) }

    // Array-scalar delegations; the `r`-prefixed variants presumably apply
    // the operation with swapped operands (scalar op array) — see cuda.cu.

    fn add_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("add_a_b_for_scalar", a, b, c, n, m) }

    fn add_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("add_at_b_for_scalar", a, b, c, n, m) }

    fn sub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("sub_a_b_for_scalar", a, b, c, n, m) }

    fn sub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("sub_at_b_for_scalar", a, b, c, n, m) }

    fn rsub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rsub_a_b_for_scalar", a, b, c, n, m) }

    fn rsub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rsub_at_b_for_scalar", a, b, c, n, m) }

    fn mul_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("mul_a_b_for_scalar", a, b, c, n, m) }

    fn mul_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("mul_at_b_for_scalar", a, b, c, n, m) }

    fn div_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("div_a_b_for_scalar", a, b, c, n, m) }

    fn div_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("div_at_b_for_scalar", a, b, c, n, m) }

    fn rdiv_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rdiv_a_b_for_scalar", a, b, c, n, m) }

    fn rdiv_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rdiv_at_b_for_scalar", a, b, c, n, m) }
915
916 fn sigmoid_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
917 { self.check_and_launch_for_fun("sigmoid_a", a, b, n, m) }
918
919 fn sigmoid_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
920 { self.check_and_launch_for_fun("sigmoid_at", a, b, n, m) }
921
922 fn tanh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
923 { self.check_and_launch_for_fun("tanh_a", a, b, n, m) }
924
925 fn tanh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
926 { self.check_and_launch_for_fun("tanh_at", a, b, n, m) }
927
928 fn swish_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
929 { self.check_and_launch_for_fun("swish_a", a, b, n, m) }
930
931 fn swish_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
932 { self.check_and_launch_for_fun("swish_at", a, b, n, m) }
933
934 fn softmax_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
935 { self.check_and_launch_for_fun_and_tiles("softmax_a", a, b, n, m) }
936
937 fn softmax_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
938 { self.check_and_launch_for_fun_and_tiles("softmax_at", a, b, n, m) }
939
940 fn sqrt_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
941 { self.check_and_launch_for_fun("sqrt_a", a, b, n, m) }
942
943 fn sqrt_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
944 { self.check_and_launch_for_fun("sqrt_at", a, b, n, m) }
945
946 fn repeat_col_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
947 { self.check_and_launch_for_repeat_col("repeat_col_a", a, b, n, m) }
948
949 fn repeat_row_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
950 { self.check_and_launch_for_repeat_row("repeat_row_a", a, b, n, m) }
951
952 fn abs_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
953 { self.check_and_launch_for_fun("abs_a", a, b, n, m) }
954
955 fn abs_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
956 { self.check_and_launch_for_fun("abs_at", a, b, n, m) }
957
// NOTE(review): in this family `at`/`bt` suffixes appear to mean "operand read
// transposed" and the `r*` prefix appears to swap the operand order for the
// scalar variants (scalar becomes the base) — confirm against cuda.cu.

/// Element-wise power `c = a ^ b` (`pow_a_b` kernel).
fn pow_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_op("pow_a_b", a, b, c, n, m) }

/// As [`Self::pow_a_b`], with `a` read transposed (`pow_at_b` kernel).
fn pow_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_op("pow_at_b", a, b, c, n, m) }

/// As [`Self::pow_a_b`], with `b` read transposed (`pow_a_bt` kernel).
fn pow_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_op("pow_a_bt", a, b, c, n, m) }

/// As [`Self::pow_a_b`], with both operands read transposed (`pow_at_bt`).
fn pow_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_op("pow_at_bt", a, b, c, n, m) }

/// Element-wise power with a scalar exponent `b` (`pow_a_b_for_scalar`).
fn pow_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_scalar("pow_a_b_for_scalar", a, b, c, n, m) }

/// As [`Self::pow_a_b_for_scalar`], with `a` read transposed.
fn pow_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_scalar("pow_at_b_for_scalar", a, b, c, n, m) }

/// Reversed scalar power: the scalar `b` is presumably the base
/// (`rpow_a_b_for_scalar` kernel).
fn rpow_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_scalar("rpow_a_b_for_scalar", a, b, c, n, m) }

/// As [`Self::rpow_a_b_for_scalar`], with `a` read transposed.
fn rpow_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_scalar("rpow_at_b_for_scalar", a, b, c, n, m) }
981
// NOTE(review): summaries inferred from kernel names (kernels in cuda.cu);
// the `*_at` suffix appears to mean "a read transposed" — confirm there.

/// Element-wise natural exponential of `a` into `b` (`exp_a` kernel).
fn exp_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("exp_a", a, b, n, m) }

/// As [`Self::exp_a`], with `a` read transposed (`exp_at` kernel).
fn exp_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("exp_at", a, b, n, m) }

/// Element-wise natural logarithm of `a` into `b` (`ln_a` kernel).
fn ln_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("ln_a", a, b, n, m) }

/// As [`Self::ln_a`], with `a` read transposed (`ln_at` kernel).
fn ln_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("ln_at", a, b, n, m) }

/// Element-wise base-2 logarithm of `a` into `b` (`log2_a` kernel).
fn log2_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("log2_a", a, b, n, m) }

/// As [`Self::log2_a`], with `a` read transposed (`log2_at` kernel).
fn log2_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("log2_at", a, b, n, m) }

/// Element-wise base-10 logarithm of `a` into `b` (`log10_a` kernel).
fn log10_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("log10_a", a, b, n, m) }

/// As [`Self::log10_a`], with `a` read transposed (`log10_at` kernel).
fn log10_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("log10_at", a, b, n, m) }

/// Element-wise sine of `a` into `b` (`sin_a` kernel).
fn sin_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("sin_a", a, b, n, m) }

/// As [`Self::sin_a`], with `a` read transposed (`sin_at` kernel).
fn sin_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("sin_at", a, b, n, m) }

/// Element-wise cosine of `a` into `b` (`cos_a` kernel).
fn cos_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("cos_a", a, b, n, m) }

/// As [`Self::cos_a`], with `a` read transposed (`cos_at` kernel).
fn cos_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("cos_at", a, b, n, m) }

/// Element-wise tangent of `a` into `b` (`tan_a` kernel).
fn tan_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("tan_a", a, b, n, m) }

/// As [`Self::tan_a`], with `a` read transposed (`tan_at` kernel).
fn tan_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("tan_at", a, b, n, m) }

/// Element-wise arcsine of `a` into `b` (`asin_a` kernel).
fn asin_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("asin_a", a, b, n, m) }

/// As [`Self::asin_a`], with `a` read transposed (`asin_at` kernel).
fn asin_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("asin_at", a, b, n, m) }

/// Element-wise arccosine of `a` into `b` (`acos_a` kernel).
fn acos_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("acos_a", a, b, n, m) }

/// As [`Self::acos_a`], with `a` read transposed (`acos_at` kernel).
fn acos_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("acos_at", a, b, n, m) }

/// Element-wise arctangent of `a` into `b` (`atan_a` kernel).
fn atan_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("atan_a", a, b, n, m) }

/// As [`Self::atan_a`], with `a` read transposed (`atan_at` kernel).
fn atan_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("atan_at", a, b, n, m) }
1041
// NOTE(review): `at`/`bt` suffixes appear to mean "operand read transposed"
// and the `r*` prefix appears to swap operand order for the scalar variants —
// confirm against cuda.cu.

/// Element-wise two-argument arctangent `c = atan2(a, b)` (`atan2_a_b`).
fn atan2_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_op("atan2_a_b", a, b, c, n, m) }

/// As [`Self::atan2_a_b`], with `a` read transposed (`atan2_at_b` kernel).
fn atan2_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_op("atan2_at_b", a, b, c, n, m) }

/// As [`Self::atan2_a_b`], with `b` read transposed (`atan2_a_bt` kernel).
fn atan2_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_op("atan2_a_bt", a, b, c, n, m) }

/// As [`Self::atan2_a_b`], with both operands read transposed (`atan2_at_bt`).
fn atan2_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_op("atan2_at_bt", a, b, c, n, m) }

/// Element-wise atan2 with a scalar second argument (`atan2_a_b_for_scalar`).
fn atan2_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_scalar("atan2_a_b_for_scalar", a, b, c, n, m) }

/// As [`Self::atan2_a_b_for_scalar`], with `a` read transposed.
fn atan2_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_scalar("atan2_at_b_for_scalar", a, b, c, n, m) }

/// Reversed scalar atan2: the scalar `b` is presumably the first argument
/// (`ratan2_a_b_for_scalar` kernel).
fn ratan2_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_scalar("ratan2_a_b_for_scalar", a, b, c, n, m) }

/// As [`Self::ratan2_a_b_for_scalar`], with `a` read transposed.
fn ratan2_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_scalar("ratan2_at_b_for_scalar", a, b, c, n, m) }
1065
// NOTE(review): summaries inferred from kernel names (kernels in cuda.cu);
// the `*_at` suffix appears to mean "a read transposed" — confirm there.

/// Element-wise hyperbolic sine of `a` into `b` (`sinh_a` kernel).
fn sinh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("sinh_a", a, b, n, m) }

/// As [`Self::sinh_a`], with `a` read transposed (`sinh_at` kernel).
fn sinh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("sinh_at", a, b, n, m) }

/// Element-wise hyperbolic cosine of `a` into `b` (`cosh_a` kernel).
fn cosh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("cosh_a", a, b, n, m) }

/// As [`Self::cosh_a`], with `a` read transposed (`cosh_at` kernel).
fn cosh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("cosh_at", a, b, n, m) }

/// Element-wise inverse hyperbolic sine of `a` into `b` (`asinh_a` kernel).
fn asinh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("asinh_a", a, b, n, m) }

/// As [`Self::asinh_a`], with `a` read transposed (`asinh_at` kernel).
fn asinh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("asinh_at", a, b, n, m) }

/// Element-wise inverse hyperbolic cosine of `a` into `b` (`acosh_a` kernel).
fn acosh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("acosh_a", a, b, n, m) }

/// As [`Self::acosh_a`], with `a` read transposed (`acosh_at` kernel).
fn acosh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("acosh_at", a, b, n, m) }

/// Element-wise inverse hyperbolic tangent of `a` into `b` (`atanh_a` kernel).
fn atanh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("atanh_a", a, b, n, m) }

/// As [`Self::atanh_a`], with `a` read transposed (`atanh_at` kernel).
fn atanh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("atanh_at", a, b, n, m) }
1095
// NOTE(review): summaries inferred from kernel names (kernels in cuda.cu);
// the `*_at` suffix appears to mean "a read transposed" — confirm there.

/// Element-wise sign of `a` into `b` (`signum_a` kernel).
fn signum_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("signum_a", a, b, n, m) }

/// As [`Self::signum_a`], with `a` read transposed (`signum_at` kernel).
fn signum_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("signum_at", a, b, n, m) }

/// Element-wise ceiling of `a` into `b` (`ceil_a` kernel).
fn ceil_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("ceil_a", a, b, n, m) }

/// As [`Self::ceil_a`], with `a` read transposed (`ceil_at` kernel).
fn ceil_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("ceil_at", a, b, n, m) }

/// Element-wise floor of `a` into `b` (`floor_a` kernel).
fn floor_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("floor_a", a, b, n, m) }

/// As [`Self::floor_a`], with `a` read transposed (`floor_at` kernel).
fn floor_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("floor_at", a, b, n, m) }

/// Element-wise rounding of `a` into `b` (`round_a` kernel).
fn round_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("round_a", a, b, n, m) }

/// As [`Self::round_a`], with `a` read transposed (`round_at` kernel).
fn round_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("round_at", a, b, n, m) }

/// Element-wise truncation of `a` into `b` (`trunc_a` kernel).
fn trunc_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("trunc_a", a, b, n, m) }

/// As [`Self::trunc_a`], with `a` read transposed (`trunc_at` kernel).
fn trunc_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_fun("trunc_at", a, b, n, m) }
1125
// NOTE(review): `at`/`bt` suffixes appear to mean "operand read transposed" —
// confirm against cuda.cu. Unlike pow/atan2, there are no reversed (`r*`)
// scalar variants here; max/min are commutative, so none are needed.

/// Element-wise maximum `c = max(a, b)` (`max_a_b` kernel).
fn max_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_op("max_a_b", a, b, c, n, m) }

/// As [`Self::max_a_b`], with `a` read transposed (`max_at_b` kernel).
fn max_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_op("max_at_b", a, b, c, n, m) }

/// As [`Self::max_a_b`], with `b` read transposed (`max_a_bt` kernel).
fn max_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_op("max_a_bt", a, b, c, n, m) }

/// As [`Self::max_a_b`], with both operands read transposed (`max_at_bt`).
fn max_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_op("max_at_bt", a, b, c, n, m) }

/// Element-wise maximum against a scalar `b` (`max_a_b_for_scalar` kernel).
fn max_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_scalar("max_a_b_for_scalar", a, b, c, n, m) }

/// As [`Self::max_a_b_for_scalar`], with `a` read transposed.
fn max_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_scalar("max_at_b_for_scalar", a, b, c, n, m) }

/// Element-wise minimum `c = min(a, b)` (`min_a_b` kernel).
fn min_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_op("min_a_b", a, b, c, n, m) }

/// As [`Self::min_a_b`], with `a` read transposed (`min_at_b` kernel).
fn min_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_op("min_at_b", a, b, c, n, m) }

/// As [`Self::min_a_b`], with `b` read transposed (`min_a_bt` kernel).
fn min_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_op("min_a_bt", a, b, c, n, m) }

/// As [`Self::min_a_b`], with both operands read transposed (`min_at_bt`).
fn min_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_op("min_at_bt", a, b, c, n, m) }

/// Element-wise minimum against a scalar `b` (`min_a_b_for_scalar` kernel).
fn min_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_scalar("min_a_b_for_scalar", a, b, c, n, m) }

/// As [`Self::min_a_b_for_scalar`], with `a` read transposed.
fn min_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
{ self.check_and_launch_for_scalar("min_at_b_for_scalar", a, b, c, n, m) }
1161}
1162
// Unit tests for this backend live in the sibling `tests` module (tests.rs or
// tests/mod.rs); compiled only under `cargo test`.
#[cfg(test)]
mod tests;