cumath 0.2.7

CUDA-based matrix/vector computations

#![allow(dead_code)]

use super::cuda_ffi::*;


/// Opaque handle type for a cuRAND generator.
///
/// Declared as an enum with no variants, so it is uninhabited and can never be
/// constructed or inspected from Rust. In this module it only appears behind raw
/// pointers (`*mut StructCurandGenerator`) handed to and returned from the cuRAND
/// FFI functions. NOTE(review): presumably mirrors the C `struct curandGenerator_st`
/// behind `curandGenerator_t` — confirm against the cuRAND headers.
pub enum StructCurandGenerator {}


/// Status code returned by every cuRAND entry point (`curandStatus_t`).
///
/// The discriminants mirror the `CURAND_STATUS_*` constants, so the value returned
/// across the FFI boundary is read directly as this enum. Every status code the
/// library can return must therefore have a variant here: receiving a value with
/// no matching discriminant would be undefined behaviour.
#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
#[repr(u32)]
pub enum CurandStatus {
    /// `CURAND_STATUS_SUCCESS`
    Success = 0,
    /// `CURAND_STATUS_VERSION_MISMATCH`
    VersionMismatch = 100,
    /// `CURAND_STATUS_NOT_INITIALIZED`
    NotInitialized = 101,
    /// `CURAND_STATUS_ALLOCATION_FAILED`
    AllocationFailed = 102,
    /// `CURAND_STATUS_TYPE_ERROR`
    TypeError = 103,
    /// `CURAND_STATUS_OUT_OF_RANGE` (the variant name is kept as-is for backward
    /// compatibility; its error string is the correctly spelled "OutOfRange").
    OutOfRand = 104,
    /// `CURAND_STATUS_LENGTH_NOT_MULTIPLE`
    LengthNotMultiple = 105,
    /// `CURAND_STATUS_DOUBLE_PRECISION_REQUIRED`
    DoublePrecisionRequired = 106,
    /// `CURAND_STATUS_LAUNCH_FAILURE`
    LaunchFailure = 201,
    /// `CURAND_STATUS_PREEXISTING_FAILURE`
    PreexistingFailure = 202,
    /// `CURAND_STATUS_INITIALIZATION_FAILED`
    InitializationFailed = 203,
    /// `CURAND_STATUS_ARCH_MISMATCH`
    ArchMismatch = 204,
    /// `CURAND_STATUS_INTERNAL_ERROR`
    InternalError = 999,
}
impl CurandStatus {
    /// Panics unless `self` is [`CurandStatus::Success`].
    ///
    /// # Panics
    /// For any non-success status. The message names the status and its numeric
    /// code, so a failure log says which cuRAND error occurred.
    fn assert_success(&self) {
        if let Some(error) = self.get_error_str() {
            panic!("cuRAND call failed: {} (status {})", error, *self as u32);
        }
    }

    /// Returns a short name for the status, or `None` for `Success`.
    fn get_error_str(&self) -> Option<&'static str> {
        match *self {
            CurandStatus::Success => None,
            CurandStatus::VersionMismatch => Some("VersionMismatch"),
            CurandStatus::NotInitialized => Some("NotInitialized"),
            CurandStatus::AllocationFailed => Some("AllocationFailed"),
            CurandStatus::TypeError => Some("TypeError"),
            // Reported under the cuRAND constant's real meaning (out of range),
            // not the misspelled variant name.
            CurandStatus::OutOfRand => Some("OutOfRange"),
            CurandStatus::LengthNotMultiple => Some("LengthNotMultiple"),
            CurandStatus::DoublePrecisionRequired => Some("DoublePrecisionRequired"),
            CurandStatus::LaunchFailure => Some("LaunchFailure"),
            CurandStatus::PreexistingFailure => Some("PreexistingFailure"),
            CurandStatus::InitializationFailed => Some("InitializationFailed"),
            CurandStatus::ArchMismatch => Some("ArchMismatch"),
            CurandStatus::InternalError => Some("InternalError"),
        }
    }
}

/// Generator algorithm selector passed to `curandCreateGenerator` (`curandRngType_t`).
///
/// Discriminants mirror the `CURAND_RNG_*` constants of the cuRAND C API and are
/// passed across the FFI boundary unchanged.
#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
#[repr(u32)]
// Tolerates variant names that follow the C constant spelling rather than strict CamelCase.
#[allow(non_camel_case_types)]
pub enum CurandRngType {
    /// `CURAND_RNG_TEST`
    Test = 0,
    /// `CURAND_RNG_PSEUDO_DEFAULT`
    PseudoDefault = 100,
    /// `CURAND_RNG_PSEUDO_XORWOW`
    PseudoXORWOW = 101,
    /// `CURAND_RNG_PSEUDO_MRG32K3A`
    PseudoMRG32K3A = 121,
    /// `CURAND_RNG_PSEUDO_MTGP32`
    PseudoMTGP32 = 141,
    /// `CURAND_RNG_PSEUDO_MT19937`
    PseudoMT19937 = 142,
    /// `CURAND_RNG_PSEUDO_PHILOX4_32_10`
    PseudoPHILOX4_32 = 161,
    /// `CURAND_RNG_QUASI_DEFAULT`
    QuasiDefault = 200,
    /// `CURAND_RNG_QUASI_SOBOL32`
    QuasiSOBOL32 = 201,
    /// `CURAND_RNG_QUASI_SCRAMBLED_SOBOL32`
    QuasiScrambledSOBOL32 = 202,
    /// `CURAND_RNG_QUASI_SOBOL64`
    QuasiSOBOL64 = 203,
    /// `CURAND_RNG_QUASI_SCRAMBLED_SOBOL64`
    QuasiScrambledSOBOL64 = 204,
}


extern {
    fn curandCreateGenerator(generator: *mut*mut StructCurandGenerator, rng_type: CurandRngType) -> CurandStatus;
    fn curandDestroyGenerator(generator: *mut StructCurandGenerator) -> CurandStatus;

    fn curandGenerateUniform(generator: *mut StructCurandGenerator, outputPtr: *mut f32, num: usize) -> CurandStatus;
    fn curandGenerateNormal(generator: *mut StructCurandGenerator, outputPtr: *mut f32, num: usize, mean: f32, stddev: f32) -> CurandStatus;
    fn curandGenerateLogNormal(generator: *mut StructCurandGenerator, outputPtr: *mut f32, num: usize, mean: f32, stddev: f32) -> CurandStatus;
    fn curandGeneratePoisson(generator: *mut StructCurandGenerator, outputPtr: *mut f32, num: usize, lambda: f32) -> CurandStatus;

    fn curandSetStream(generator: *mut StructCurandGenerator, stream: cudaStream_t) -> CurandStatus;
}


#[inline]
pub fn curand_create_generator(generator: *mut*mut StructCurandGenerator, rng_type: CurandRngType) -> Option<&'static str> {
    #[cfg(not(feature = "disable_checks"))] {
        unsafe { curandCreateGenerator(generator, rng_type) }.get_error_str()
    }
    #[cfg(feature = "disable_checks")] {
        unsafe { curandCreateGenerator(generator, rng_type) };
        None
    }
}

#[inline]
pub fn curand_destroy_generator(generator: *mut StructCurandGenerator) {
    #[cfg(not(feature = "disable_checks"))] {
        unsafe { curandDestroyGenerator(generator) }.assert_success()
    }
    #[cfg(feature = "disable_checks")] {
        unsafe { curandDestroyGenerator(generator) };
    }
}

#[inline]
pub fn curand_generate_uniform(generator: *mut StructCurandGenerator, output_ptr: *mut f32, num: usize) {
    #[cfg(not(feature = "disable_checks"))] {
        unsafe { curandGenerateUniform(generator, output_ptr, num) }.assert_success()
    }
    #[cfg(feature = "disable_checks")] {
        unsafe { curandGenerateUniform(generator, output_ptr, num) };
    }
}

#[inline]
pub fn curand_generate_normal(generator: *mut StructCurandGenerator, output_ptr: *mut f32, num: usize, mean: f32, stddev: f32) {
    #[cfg(not(feature = "disable_checks"))] {
        unsafe { curandGenerateNormal(generator, output_ptr, num, mean, stddev) }.assert_success()
    }
    #[cfg(feature = "disable_checks")] {
        unsafe { curandGenerateNormal(generator, output_ptr, num, mean, stddev) };
    }
}

#[inline]
pub fn curand_generate_lognormal(generator: *mut StructCurandGenerator, output_ptr: *mut f32, num: usize, mean: f32, stddev: f32) {
    #[cfg(not(feature = "disable_checks"))] {
        unsafe { curandGenerateLogNormal(generator, output_ptr, num, mean, stddev) }.assert_success()
    }
    #[cfg(feature = "disable_checks")] {
        unsafe { curandGenerateLogNormal(generator, output_ptr, num, mean, stddev) };
    }
}

#[inline]
pub fn curand_generate_poisson(generator: *mut StructCurandGenerator, output_ptr: *mut f32, num: usize, lambda: f32) {
    #[cfg(not(feature = "disable_checks"))] {
        unsafe { curandGeneratePoisson(generator, output_ptr, num, lambda) }.assert_success()
    }
    #[cfg(feature = "disable_checks")] {
        unsafe { curandGeneratePoisson(generator, output_ptr, num, lambda) };
    }
}

/// Associates `generator` with the CUDA stream `stream` for its subsequent work.
///
/// # Panics
/// If cuRAND reports an error, unless the `disable_checks` feature is enabled, in
/// which case the returned status is ignored.
#[inline]
pub fn curand_set_stream(generator: *mut StructCurandGenerator, stream: cudaStream_t) {
    let outcome = unsafe { curandSetStream(generator, stream) };
    if !cfg!(feature = "disable_checks") {
        outcome.assert_success();
    }
}