cuda-bindings 0.1.0

NVIDIA CUDA bindings.
// SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

//! CUDA toolkit bindings with dynamic library loading.
//!
//! Types, enums, and constants are generated at build time from CUDA headers
//! via `bindgen`. The CUDA driver (`libcuda`) and cuRAND (`libcurand`)
//! libraries are loaded **dynamically at runtime** via `libloading`, so they
//! do not need to be present at build time.

#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]

// Types, enums, and constants generated by bindgen (no function declarations).
include!(concat!(env!("OUT_DIR"), "/types.rs"));

/// Stores `id` into the `id` field of a [`CUmemLocation`].
///
/// The field is reached through one of two paths, selected at compile time by
/// the custom `cu_mem_location_anon_union` cfg flag:
///
/// * flag set: `loc.__bindgen_anon_1.id`
/// * flag not set: `loc.id`
///
/// Callers can use this helper instead of touching the field directly so that
/// their code compiles against either generated layout.
///
/// NOTE(review): the flag is presumably emitted by the build script after
/// inspecting the generated `types.rs` (newer CUDA headers appear to wrap `id`
/// in an anonymous union) — confirm against `build.rs`.
#[inline]
pub fn set_mem_location_id(loc: &mut CUmemLocation, id: ::core::ffi::c_int) {
    // Layout where `id` lives inside the bindgen-generated anonymous member.
    #[cfg(cu_mem_location_anon_union)]
    {
        loc.__bindgen_anon_1.id = id;
    }
    // Layout where `id` is a plain field of `CUmemLocation`.
    #[cfg(not(cu_mem_location_anon_union))]
    {
        loc.id = id;
    }
}

// Build-time generated CUDA driver API code, textually included from
// `$OUT_DIR/cuda_driver_api.rs`. The module is private and nothing in this
// file re-exports it, hence the `dead_code` allowance for items that end up
// unused.
// NOTE(review): presumably consumed by `dyn_load` — confirm.
#[allow(dead_code)]
mod generated_cuda {
    // Bring the bindgen-generated types from the crate root into scope for
    // the included code.
    use super::*;
    include!(concat!(env!("OUT_DIR"), "/cuda_driver_api.rs"));
}

// Build-time generated cuRAND API code, textually included from
// `$OUT_DIR/curand_api.rs`. The module is private and nothing in this file
// re-exports it, hence the `dead_code` allowance for items that end up unused.
// NOTE(review): presumably consumed by `dyn_load` — confirm.
#[allow(dead_code)]
mod generated_curand {
    // Bring the bindgen-generated types from the crate root into scope for
    // the included code.
    use super::*;
    include!(concat!(env!("OUT_DIR"), "/curand_api.rs"));
}

// Dynamic loading of CUDA driver and cuRAND function symbols.
//
// Everything public in `dyn_load` is glob re-exported at the crate root, so
// callers write `cuda_bindings::<name>` rather than going through the
// submodule. The tests below reach the driver/cuRAND entry points this way
// (e.g. `crate::cuInit`).
mod dyn_load;
pub use dyn_load::*;

use std::env;

/// Returns the CUDA toolkit location taken from the `CUDA_TOOLKIT_PATH`
/// environment variable.
///
/// The value is returned verbatim; no check is made that it names an existing
/// directory.
///
/// # Panics
///
/// Panics with the message `CUDA_TOOLKIT_PATH is required but not set` when
/// the variable cannot be read, i.e. whenever [`std::env::var`] returns an
/// error.
pub fn cuda_toolkit_dir() -> String {
    const TOOLKIT_ENV_VAR: &str = "CUDA_TOOLKIT_PATH";

    let lookup: Result<String, env::VarError> = env::var(TOOLKIT_ENV_VAR);
    lookup.expect("CUDA_TOOLKIT_PATH is required but not set")
}

/// Smoke tests that drive the dynamically loaded CUDA driver and cuRAND entry
/// points end to end.
///
/// NOTE(review): these call the real libraries, so they presumably need a
/// CUDA-capable device plus loadable `libcuda`/`libcurand` at test time —
/// confirm how CI is expected to run them.
#[cfg(test)]
mod cuda_tests {
    use super::*;
    use std::ffi::{c_int, c_ulonglong};
    use std::mem::MaybeUninit;

    /// Brings the driver up for a test and returns `(device, context)`.
    ///
    /// Steps, each asserted to return `0`:
    /// 1. `cuInit(0)`
    /// 2. `cuDeviceGet` for ordinal `0`
    /// 3. `cuDevicePrimaryCtxRetain` on that device
    /// 4. `cuCtxSetCurrent` with the retained context
    ///
    /// NOTE(review): nothing in this module releases the retained primary
    /// context again — confirm whether that is intentional for tests.
    fn init() -> (CUdevice, CUcontext) {
        // SAFETY (assumed): each out-parameter is only `assume_init`-ed after
        // the corresponding call was asserted to return 0; this relies on the
        // driver writing the output on success — TODO confirm.
        unsafe {
            let init_res = crate::cuInit(0);
            assert_eq!(init_res, 0, "init failed");

            // Device handle for ordinal 0.
            let mut dev: MaybeUninit<crate::CUdevice> = MaybeUninit::uninit();
            let dev_result = crate::cuDeviceGet(dev.as_mut_ptr(), 0 as c_int);
            assert_eq!(dev_result, 0, "get device failed");
            let dev = dev.assume_init();

            // Retain the device's primary context and make it current.
            let mut ctx = MaybeUninit::uninit();
            let ctx_res = crate::cuDevicePrimaryCtxRetain(ctx.as_mut_ptr(), dev);
            assert_eq!(ctx_res, 0, "retain context failed");
            let ctx = ctx.assume_init();
            assert_eq!(
                crate::cuCtxSetCurrent(ctx),
                0,
                "failed to set current context"
            );
            (dev, ctx)
        }
    }

    /// Allocates `bytesize` bytes with `cuMemAlloc_v2` and returns the
    /// resulting device pointer. Panics if the call does not return `0`.
    ///
    /// The allocation is not freed here; `test_curand` frees it with
    /// `cuMemFree_v2`.
    ///
    /// # Safety
    ///
    /// Calls directly into the driver. NOTE(review): presumably requires a
    /// current context (as set up by `init`) — confirm.
    unsafe fn get_dptr(bytesize: usize) -> CUdeviceptr {
        let mut dptr: MaybeUninit<CUdeviceptr> = MaybeUninit::uninit();
        assert!(cuMemAlloc_v2(dptr.as_mut_ptr(), bytesize) == 0);
        // Only reached when the allocation call returned 0.
        dptr.assume_init()
    }

    /// Creates a cuRAND generator of type
    /// `curandRngType_CURAND_RNG_PSEUDO_DEFAULT` and returns its handle.
    /// Panics if `curandCreateGenerator` does not return `0`.
    ///
    /// # Safety
    ///
    /// Calls directly into cuRAND; the returned handle is not destroyed here
    /// (`test_curand` calls `curandDestroyGenerator` on it).
    unsafe fn get_rng() -> curandGenerator_t {
        let mut curand_gen_uninited: MaybeUninit<curandGenerator_t> = MaybeUninit::uninit();
        let curand_rng_type = curandRngType_CURAND_RNG_PSEUDO_DEFAULT;
        assert!(curandCreateGenerator(curand_gen_uninited.as_mut_ptr(), curand_rng_type) == 0);
        // Only reached when generator creation returned 0.
        curand_gen_uninited.assume_init()
    }

    /// Seeds `gen` via `curandSetPseudoRandomGeneratorSeed`, panicking if the
    /// call does not return `0`.
    ///
    /// # Safety
    ///
    /// `gen` is passed straight to cuRAND; presumably it must be a live
    /// generator handle such as one returned by `get_rng` — confirm.
    unsafe fn set_seed(gen: curandGenerator_t, seed: u64) {
        assert!(curandSetPseudoRandomGeneratorSeed(gen, c_ulonglong::from(seed)) == 0);
    }

    /// End-to-end check: initialise the driver, create and seed a generator,
    /// allocate a device buffer for 32 `f32` values, fill it with
    /// `curandGenerateNormal`, then destroy the generator and free the buffer.
    ///
    /// Only return codes are asserted; the generated values are never copied
    /// back or inspected.
    #[test]
    fn test_curand() {
        unsafe {
            let (_dev, _ctx) = init();
            let curand_gen = get_rng();
            set_seed(curand_gen, 123);
            let num_elements = 32;
            // Buffer size in bytes for `num_elements` f32 values.
            let bytesize = num_elements * size_of::<f32>();
            let dptr = get_dptr(bytesize);
            // The device pointer value is cast to `*mut f32` for the cuRAND
            // call. The trailing `0.0, 1.0` are presumably mean and standard
            // deviation — see the cuRAND host API docs.
            assert!(
                curandGenerateNormal(curand_gen, dptr as *mut f32, num_elements, 0.0, 1.0) == 0
            );
            // Cleanup; skipped if any assertion above panics.
            assert!(curandDestroyGenerator(curand_gen) == 0);
            assert!(cuMemFree_v2(dptr) == 0);
        }
    }
}