cust 0.3.2

High-level bindings to the CUDA Driver API
Documentation
//! Functions and types for working with CUDA modules.

use crate::error::{CudaResult, DropResult, ToResult};
use crate::function::Function;
use crate::memory::{CopyDestination, DeviceCopy, DevicePointer};
use crate::sys as cuda;
use std::ffi::{c_void, CStr, CString};
use std::fmt;
use std::marker::PhantomData;
use std::mem;
use std::os::raw::c_uint;
use std::path::Path;
use std::ptr;

/// Handle to a compiled CUDA module that has been loaded into a context.
///
/// The underlying driver handle is unloaded when the `Module` is dropped; call
/// `Module::drop` explicitly to observe any error reported while unloading.
#[derive(Debug)]
pub struct Module {
    /// Raw driver handle. Null before a load has filled it in and after the module has been
    /// unloaded.
    inner: cuda::CUmodule,
}

// SAFETY: NOTE(review): these impls assume the CUDA driver allows a `CUmodule` handle to be
// used from any host thread (with an appropriate context current); confirm against the
// driver documentation.
unsafe impl Send for Module {}
unsafe impl Sync for Module {}

/// The optimization level used when JIT compiling a PTX module.
///
/// If no level is passed, the driver itself defaults to `O4` (the most optimized level). This
/// type does not implement `Default`; the default described here is the driver's.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptLevel {
    /// Level 0 (least optimized).
    O0 = 0,
    /// Level 1.
    O1 = 1,
    /// Level 2.
    O2 = 2,
    /// Level 3.
    O3 = 3,
    /// Level 4 (most optimized; the driver's default).
    O4 = 4,
}

/// The possible target architectures when JIT compiling a PTX module.
///
/// Each discriminant is the compute capability multiplied by ten, e.g. `Compute86` is
/// compute capability 8.6 and has the value `86`. The discriminant is passed to the driver
/// as the value of the `CU_JIT_TARGET` option.
#[non_exhaustive]
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JitTarget {
    Compute20 = 20,
    Compute21 = 21,
    Compute30 = 30,
    Compute32 = 32,
    Compute35 = 35,
    Compute37 = 37,
    Compute50 = 50,
    Compute52 = 52,
    Compute53 = 53,
    Compute60 = 60,
    Compute61 = 61,
    Compute62 = 62,
    Compute70 = 70,
    Compute72 = 72,
    Compute75 = 75,
    Compute80 = 80,
    Compute86 = 86,
}

/// How the driver should proceed when a loaded module's data does not contain an exact match
/// for the target architecture.
///
/// The discriminant is passed to the driver as the value of the `CU_JIT_FALLBACK_STRATEGY`
/// option.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JitFallback {
    /// Prefer to compile PTX, if present, when an exact binary match is not found.
    PreferPtx = 0,
    /// Prefer to fall back to a compatible binary code match if an exact match is not found.
    /// This means the driver may pick binary code for `7.0` if your device is `7.2`, for example.
    PreferCompatibleBinary = 1,
}

/// An option controlling how the driver loads (and, for PTX, JIT-compiles) a module.
///
/// A slice of these is accepted by loaders such as `Module::from_ptx` and is converted to
/// the driver's raw representation by `ModuleJitOption::into_raw`.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ModuleJitOption {
    /// Upper bound on the number of registers any compiled PTX may use.
    MaxRegisters(u32),
    /// Optimization level used by the JIT compiler.
    OptLevel(OptLevel),
    /// Derive the PTX target from the architecture of the current context.
    ///
    /// Cannot be combined with [`ModuleJitOption::Target`].
    DetermineTargetFromContext,
    /// Explicit target architecture for the JIT compiler.
    ///
    /// Cannot be combined with [`ModuleJitOption::DetermineTargetFromContext`].
    Target(JitTarget),
    /// What to do when the module's data has no exact match for the target architecture.
    Fallback(JitFallback),
    /// Whether to generate debug info in the compiled binary.
    ///
    /// Note: the variant name is misspelled; it is kept as-is for backward compatibility.
    GenenerateDebugInfo(bool),
    /// Whether to generate line info in the compiled binary.
    GenerateLineInfo(bool),
}

impl ModuleJitOption {
    pub fn into_raw(opts: &[Self]) -> (Vec<cuda::CUjit_option>, Vec<*mut c_void>) {
        // And here we stumble across one of the most horrific things i have ever seen in my entire
        // journey of working with many parts of CUDA. As a background, CUDA usually wants an array
        // of pointers to values when it takes void**, after all, this is what is expected by anyone.
        // However, there is a SINGLE exception in the entire driver API, and that is cuModuleLoadDataEx,
        // it actually wants you to pass values by value instead of by ref if they fit into pointer length.
        // Therefore something like MaxRegisters should be passed as `u32 as usize as *mut c_void`.
        // This is completely undocumented. I initially brought this up to an nvidia developer,
        // who eventually was able to figure out this issue, currently it appears to be labeled "not a bug",
        // however this will likely be changed in the future, or at least get documented better. (hopefully)
        let mut raw_opts = Vec::with_capacity(opts.len());
        let mut raw_vals = Vec::with_capacity(opts.len());

        for opt in opts {
            match opt {
                Self::MaxRegisters(regs) => {
                    raw_opts.push(cuda::CUjit_option::CU_JIT_MAX_REGISTERS);
                    raw_vals.push(*regs as usize as *mut c_void);
                }
                Self::OptLevel(level) => {
                    raw_opts.push(cuda::CUjit_option::CU_JIT_OPTIMIZATION_LEVEL);
                    raw_vals.push(*level as usize as *mut c_void);
                }
                Self::DetermineTargetFromContext => {
                    raw_opts.push(cuda::CUjit_option::CU_JIT_TARGET_FROM_CUCONTEXT);
                }
                Self::Target(target) => {
                    raw_opts.push(cuda::CUjit_option::CU_JIT_TARGET);
                    raw_vals.push(*target as usize as *mut c_void);
                }
                Self::Fallback(fallback) => {
                    raw_opts.push(cuda::CUjit_option::CU_JIT_FALLBACK_STRATEGY);
                    raw_vals.push(*fallback as usize as *mut c_void);
                }
                Self::GenenerateDebugInfo(gen) => {
                    raw_opts.push(cuda::CUjit_option::CU_JIT_GENERATE_DEBUG_INFO);
                    raw_vals.push(*gen as usize as *mut c_void);
                }
                Self::GenerateLineInfo(gen) => {
                    raw_opts.push(cuda::CUjit_option::CU_JIT_GENERATE_LINE_INFO);
                    raw_vals.push(*gen as usize as *mut c_void)
                }
            }
        }
        (raw_opts, raw_vals)
    }
}

/// Returns the raw OS bytes of `path`, without a trailing nul.
///
/// On Unix, paths are arbitrary byte strings, so this conversion is lossless.
#[cfg(unix)]
fn path_to_bytes<P: AsRef<Path>>(path: P) -> Vec<u8> {
    use std::os::unix::ffi::OsStrExt;

    let os_str = path.as_ref().as_os_str();
    Vec::from(os_str.as_bytes())
}

/// Returns the bytes of `path` encoded as UTF-8, without a trailing nul.
///
/// Non-Unix platforms do not expose paths as raw bytes, so any part of the path that is not
/// valid Unicode is replaced with U+FFFD (see `Path::to_string_lossy`).
#[cfg(not(unix))]
fn path_to_bytes<P: AsRef<Path>>(path: P) -> Vec<u8> {
    // `into_owned` hands back the already-allocated buffer when the lossy conversion had to
    // allocate, instead of copying it a second time as `to_string` would.
    path.as_ref().to_string_lossy().into_owned().into_bytes()
}

impl Module {
    /// Load a module from the given path into the current context.
    ///
    /// The given path should be either a cubin file, a ptx file, or a fatbin file such as
    /// those produced by `nvcc`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `cuModuleLoad` if the driver fails to load the file.
    ///
    /// # Example
    ///
    /// ```
    /// # use cust::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// use cust::module::Module;
    ///
    /// let module = Module::from_file("./resources/add.ptx")?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn from_file<P: AsRef<Path>>(path: P) -> CudaResult<Module> {
        // SAFETY: `bytes` is nul-terminated (or contains a nul) and outlives the call, and
        // `module.inner` is a valid location for the driver to write the new handle into.
        unsafe {
            let mut bytes = path_to_bytes(path);
            // `cuModuleLoad` takes a C string. Append a terminator unless the bytes already
            // contain a nul, in which case the driver only sees the path up to that nul.
            if !bytes.contains(&0) {
                bytes.push(0);
            }
            let mut module = Module {
                inner: ptr::null_mut(),
            };
            cuda::cuModuleLoad(
                &mut module.inner as *mut cuda::CUmodule,
                bytes.as_ptr() as *const _,
            )
            .to_result()?;
            Ok(module)
        }
    }

    /// Creates a new module by loading a fatbin (fat binary) file.
    ///
    /// Fatbinary files are files that contain multiple ptx or cubin files. The driver will choose already-built
    /// cubin if it is present, and otherwise JIT compile any PTX in the file to cubin.
    ///
    /// # Example
    ///
    /// ```
    /// # use cust::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// use cust::module::Module;
    /// let fatbin_bytes = std::fs::read("./resources/add.fatbin")?;
    /// // will return InvalidSource if the fatbin does not contain any compatible code, meaning, either
    /// // cubin compiled for the same device architecture OR PTX that can be JITted into valid code.
    /// let module = Module::from_fatbin(&fatbin_bytes, &[])?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn from_fatbin<T: AsRef<[u8]>>(
        bytes: T,
        options: &[ModuleJitOption],
    ) -> CudaResult<Module> {
        // fatbins can be loaded just like cubins, we just use different methods so it's explicit.
        // please don't use from_cubin for fatbins, that is pure chaos and ferris will come to your house
        Self::from_cubin(bytes, options)
    }

    /// Creates a new module by loading a cubin (CUDA Binary) file.
    ///
    /// Cubins are architecture/compute-capability specific files generated as the final step of the CUDA compilation
    /// process. They cannot be interchanged across compute capabilities unlike PTX (to some degree). You can create one
    /// using the PTX compiler APIs, the cust [`Linker`](crate::link::Linker), or nvcc (`nvcc a.ptx --cubin -arch=sm_XX`).
    ///
    /// # Example
    ///
    /// ```
    /// # use cust::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// use cust::module::Module;
    /// let cubin_bytes = std::fs::read("./resources/add.cubin")?;
    /// // will return InvalidSource if the cubin arch doesn't match the context's device arch!
    /// let module = Module::from_cubin(&cubin_bytes, &[])?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn from_cubin<T: AsRef<[u8]>>(bytes: T, options: &[ModuleJitOption]) -> CudaResult<Module> {
        // it is very unclear whether cuda wants or doesn't want a null terminator. The method works
        // whether you have one or not. So for safety we just add one. In theory you can figure out the
        // length of an ELF image without a null terminator. But the docs are confusing, so we add one just
        // to be sure.
        let mut bytes = bytes.as_ref().to_vec();
        bytes.push(0);
        // SAFETY: `bytes` holds the complete image and lives until after `load_module` returns.
        unsafe { Self::load_module(bytes.as_ptr() as *const c_void, options) }
    }

    /// Loads a module image through `cuModuleLoadDataEx`, passing the given JIT options.
    ///
    /// # Safety
    ///
    /// `image` must point to module data (a cubin/fatbin image, or nul-terminated PTX) that
    /// remains valid for the whole duration of this call.
    unsafe fn load_module(image: *const c_void, options: &[ModuleJitOption]) -> CudaResult<Module> {
        let mut module = Module {
            inner: ptr::null_mut(),
        };
        // `into_raw` returns two arrays of equal length; the driver reads `options.len()`
        // entries from each.
        let (mut options, mut option_values) = ModuleJitOption::into_raw(options);
        cuda::cuModuleLoadDataEx(
            &mut module.inner as *mut cuda::CUmodule,
            image,
            options.len() as c_uint,
            options.as_mut_ptr(),
            option_values.as_mut_ptr(),
        )
        .to_result()?;
        Ok(module)
    }

    /// Creates a new module from a [`CStr`] pointing to PTX code.
    ///
    /// The driver will JIT the PTX into arch-specific cubin or pick already-cached cubin if available.
    pub fn from_ptx_cstr(cstr: &CStr, options: &[ModuleJitOption]) -> CudaResult<Module> {
        // SAFETY: `cstr` is a valid nul-terminated string that outlives the call.
        unsafe { Self::load_module(cstr.as_ptr() as *const c_void, options) }
    }

    /// Creates a new module from a PTX string, allocating an intermediate buffer for the [`CString`].
    ///
    /// The driver will JIT the PTX into arch-specific cubin or pick already-cached cubin if available.
    ///
    /// # Panics
    ///
    /// Panics if `string` contains a nul.
    ///
    /// # Example
    ///
    /// ```
    /// # use cust::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// use cust::module::Module;
    /// let ptx = std::fs::read_to_string("./resources/add.ptx")?;
    /// let module = Module::from_ptx(&ptx, &[])?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn from_ptx<T: AsRef<str>>(string: T, options: &[ModuleJitOption]) -> CudaResult<Module> {
        let cstr = CString::new(string.as_ref())
            .expect("string given to Module::from_ptx contained nul bytes");
        Self::from_ptx_cstr(cstr.as_c_str(), options)
    }

    /// Load a module from a normal (rust) string, implicitly making it into
    /// a cstring.
    ///
    /// # Panics
    ///
    /// Panics if `string` contains a nul.
    #[deprecated(
        since = "0.3.0",
        note = "from_str was too generic of a name, use from_ptx instead, passing an empty slice of options (usually)"
    )]
    #[allow(clippy::should_implement_trait)]
    pub fn from_str<T: AsRef<str>>(string: T) -> CudaResult<Module> {
        let cstr = CString::new(string.as_ref())
            .expect("string given to Module::from_str contained nul bytes");
        #[allow(deprecated)]
        Self::load_from_string(cstr.as_c_str())
    }

    /// Load a module from a CStr.
    ///
    /// This is useful in combination with `include_str!`, to include the device code into the
    /// compiled executable.
    ///
    /// The given CStr must contain the bytes of a cubin file, a ptx file or a fatbin file such as
    /// those produced by `nvcc`.
    ///
    /// # Example
    ///
    /// ```
    /// # use cust::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// use cust::module::Module;
    /// use std::ffi::CString;
    ///
    /// let image = CString::new(include_str!("../resources/add.ptx"))?;
    /// let module = Module::load_from_string(&image)?;
    /// # Ok(())
    /// # }
    /// ```
    #[deprecated(
        since = "0.3.0",
        note = "load_from_string was an inconsistent name with inconsistent params, use \
                from_ptx/from_ptx_cstr, passing an empty slice of options (usually)"
    )]
    pub fn load_from_string(image: &CStr) -> CudaResult<Module> {
        // SAFETY: `image` is a valid nul-terminated string that outlives the call.
        unsafe {
            let mut module = Module {
                inner: ptr::null_mut(),
            };
            cuda::cuModuleLoadData(
                &mut module.inner as *mut cuda::CUmodule,
                image.as_ptr() as *const c_void,
            )
            .to_result()?;
            Ok(module)
        }
    }

    /// Get a reference to a global symbol, which can then be copied to/from.
    ///
    /// # Panics
    ///
    /// This function panics if the size of the symbol is not the same as `mem::size_of::<T>()`.
    ///
    /// # Examples
    ///
    /// ```
    /// # use cust::*;
    /// # use cust::memory::CopyDestination;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// use cust::module::Module;
    /// use std::ffi::CString;
    ///
    /// let module = Module::from_ptx(include_str!("../resources/add.ptx"), &[])?;
    /// let name = CString::new("my_constant")?;
    /// let symbol = module.get_global::<u32>(&name)?;
    /// let mut host_const = 0;
    /// symbol.copy_to(&mut host_const)?;
    /// assert_eq!(314, host_const);
    /// # Ok(())
    /// # }
    /// ```
    pub fn get_global<'a, T: DeviceCopy>(&'a self, name: &CStr) -> CudaResult<Symbol<'a, T>> {
        unsafe {
            let mut ptr: DevicePointer<T> = DevicePointer::null();
            let mut size: usize = 0;

            cuda::cuModuleGetGlobal_v2(
                &mut ptr as *mut DevicePointer<T> as *mut cuda::CUdeviceptr,
                &mut size as *mut usize,
                self.inner,
                name.as_ptr(),
            )
            .to_result()?;
            // `Symbol`'s copy methods transfer exactly `size_of::<T>()` bytes, so the symbol
            // must have exactly that size.
            assert_eq!(size, mem::size_of::<T>());
            Ok(Symbol {
                ptr,
                module: PhantomData,
            })
        }
    }

    /// Get a reference to a kernel function which can then be launched.
    ///
    /// # Panics
    ///
    /// Panics if `name` contains a nul byte.
    ///
    /// # Examples
    ///
    /// ```
    /// # use cust::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// use cust::module::Module;
    ///
    /// let module = Module::from_ptx(include_str!("../resources/add.ptx"), &[])?;
    /// let function = module.get_function("sum")?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn get_function<T: AsRef<str>>(&'_ self, name: T) -> CudaResult<Function<'_>> {
        unsafe {
            let name = name.as_ref();
            let cstr = CString::new(name).expect("Argument to get_function had a nul");
            let mut func: cuda::CUfunction = ptr::null_mut();

            cuda::cuModuleGetFunction(
                &mut func as *mut cuda::CUfunction,
                self.inner,
                cstr.as_ptr(),
            )
            .to_result()?;
            Ok(Function::new(func, self))
        }
    }

    /// Destroy a `Module`, returning any error reported by the driver.
    ///
    /// Destroying a module can return errors from previous asynchronous work. This function
    /// destroys the given module and returns the error and the un-destroyed module on failure.
    ///
    /// # Example
    ///
    /// ```
    /// # use cust::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// use cust::module::Module;
    ///
    /// let module = Module::from_ptx(include_str!("../resources/add.ptx"), &[])?;
    /// match Module::drop(module) {
    ///     Ok(()) => println!("Successfully destroyed"),
    ///     Err((e, module)) => {
    ///         println!("Failed to destroy module: {:?}", e);
    ///         // Do something with module
    ///     },
    /// }
    /// # Ok(())
    /// # }
    /// ```
    pub fn drop(mut module: Module) -> DropResult<Module> {
        if module.inner.is_null() {
            return Ok(());
        }

        unsafe {
            // Null out the handle first so the `Drop` impl never unloads it a second time.
            let inner = mem::replace(&mut module.inner, ptr::null_mut());
            match cuda::cuModuleUnload(inner).to_result() {
                Ok(()) => {
                    mem::forget(module);
                    Ok(())
                }
                // Hand the still-loaded handle back to the caller in a fresh `Module`.
                Err(e) => Err((e, Module { inner })),
            }
        }
    }
}
impl Drop for Module {
    /// Unloads the module unless it has already been unloaded (null handle).
    ///
    /// `Drop` cannot report failure, so any error returned by `cuModuleUnload` is discarded.
    /// Call `Module::drop` explicitly to observe unload errors.
    fn drop(&mut self) {
        if self.inner.is_null() {
            return;
        }
        // Reset the stored handle before unloading so it can never be unloaded twice.
        let module = mem::replace(&mut self.inner, ptr::null_mut());
        // SAFETY: `module` is the non-null handle owned by this `Module`, and it is passed to
        // `cuModuleUnload` exactly once because `self.inner` has just been reset to null.
        unsafe {
            // The result is intentionally ignored: there is no way to return it from `Drop`,
            // and panicking during drop could abort the process.
            let _ = cuda::cuModuleUnload(module);
        }
    }
}

/// Typed handle to a global symbol defined within a CUDA module.
///
/// Obtained from `Module::get_global`. The `'a` lifetime borrows the originating `Module`,
/// so a `Symbol` cannot outlive the module it points into.
#[derive(Debug)]
pub struct Symbol<'a, T: DeviceCopy> {
    /// Device address of the symbol.
    ptr: DevicePointer<T>,
    /// Zero-sized borrow of the owning module.
    module: PhantomData<&'a Module>,
}

// Opt `Symbol` into the crate's sealed-trait marker.
impl<T: DeviceCopy> crate::private::Sealed for Symbol<'_, T> {}
impl<'a, T: DeviceCopy> fmt::Pointer for Symbol<'a, T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Pointer::fmt(&self.ptr, f)
    }
}
impl<'a, T: DeviceCopy> CopyDestination<T> for Symbol<'a, T> {
    /// Copies `val` from host memory into the device symbol.
    ///
    /// A zero-sized `T` performs no driver call and always succeeds.
    fn copy_from(&mut self, val: &T) -> CudaResult<()> {
        let size = mem::size_of::<T>();
        if size != 0 {
            // SAFETY: `val` refers to `size` readable bytes. In this module a `Symbol` is only
            // created by `Module::get_global`, which asserts the device symbol is exactly
            // `size` bytes.
            unsafe {
                cuda::cuMemcpyHtoD_v2(self.ptr.as_raw(), val as *const T as *const c_void, size)
                    .to_result()?
            }
        }
        Ok(())
    }

    /// Copies the device symbol's current value into `val` in host memory.
    ///
    /// A zero-sized `T` performs no driver call and always succeeds.
    fn copy_to(&self, val: &mut T) -> CudaResult<()> {
        let size = mem::size_of::<T>();
        if size != 0 {
            // SAFETY: `val` is a unique reference to `size` writable bytes, and the pointer is
            // derived from it as `*mut` because the driver writes through it. The device symbol
            // is `size` bytes (checked in `Module::get_global`).
            unsafe {
                cuda::cuMemcpyDtoH_v2(val as *mut T as *mut c_void, self.ptr.as_raw(), size)
                    .to_result()?
            }
        }
        Ok(())
    }
}