af_cuda_interop/
lib.rs

//! The af-cuda-interop crate is meant to be used only when an application needs
//! to mix ArrayFire code with raw CUDA code.
3
4use arrayfire::{handle_error_general, AfError};
5use cuda_runtime_sys::cudaStream_t;
6use libc::c_int;
7
8extern "C" {
9    fn afcu_get_native_id(native_id: *mut c_int, id: c_int) -> c_int;
10    fn afcu_set_native_id(native_id: c_int) -> c_int;
11    fn afcu_get_stream(out: *mut cudaStream_t, id: c_int) -> c_int;
12}
13
14/// Get active device's id in CUDA context
15///
16/// # Parameters
17///
18/// - `id` is the integer identifier of concerned CUDA device as per ArrayFire context
19///
20/// # Return Values
21///
22/// Integer identifier of device in CUDA context
23pub fn get_device_native_id(id: i32) -> i32 {
24    unsafe {
25        let mut temp: i32 = 0;
26        let err_val = afcu_get_native_id(&mut temp as *mut c_int, id);
27        handle_error_general(AfError::from(err_val));
28        temp
29    }
30}
31
32/// Set active device using CUDA context's id
33///
34/// # Parameters
35///
36/// - `id` is the identifier of GPU in CUDA context
37pub fn set_device_native_id(native_id: i32) {
38    unsafe {
39        let err_val = afcu_set_native_id(native_id);
40        handle_error_general(AfError::from(err_val));
41    }
42}
43
44/// Get CUDA stream of active CUDA device
45///
46/// # Parameters
47///
48/// - `id` is the identifier of device in ArrayFire context
49///
50/// # Return Values
51///
52/// [cudaStream_t](https://docs.rs/cuda-runtime-sys/0.3.0-alpha.1/cuda_runtime_sys/type.cudaStream_t.html) handle.
53pub fn get_stream(native_id: i32) -> cudaStream_t {
54    unsafe {
55        let mut ret_val: cudaStream_t = std::ptr::null_mut();
56        let err_val = afcu_get_stream(&mut ret_val as *mut cudaStream_t, native_id);
57        handle_error_general(AfError::from(err_val));
58        ret_val
59    }
60}