use super::DeviceCopy;
use crate::error::*;
use crate::memory::DevicePointer;
use crate::memory::UnifiedPointer;
use crate::prelude::Stream;
use crate::sys as cuda;
use std::mem;
use std::os::raw::c_void;
use std::ptr;
/// Allocates `count` elements of type `T` in device memory.
///
/// Returns `InvalidMemoryAllocation` when the byte size is zero — i.e. for a
/// zero `count`, a zero-sized `T`, or a `count * size_of::<T>()` overflow
/// (overflow is collapsed to 0 by `unwrap_or` and caught by the same guard).
pub unsafe fn cuda_malloc<T: DeviceCopy>(count: usize) -> CudaResult<DevicePointer<T>> {
    let bytes = mem::size_of::<T>().checked_mul(count).unwrap_or(0);
    if bytes == 0 {
        return Err(CudaError::InvalidMemoryAllocation);
    }
    let mut raw: cuda::CUdeviceptr = 0;
    cuda::cuMemAlloc_v2(&mut raw, bytes).to_result()?;
    Ok(DevicePointer::from_raw(raw))
}
/// Asynchronously allocates `count` elements of type `T` in device memory,
/// ordered on `stream`.
///
/// Returns `InvalidMemoryAllocation` when the byte size is zero — i.e. for a
/// zero `count`, a zero-sized `T`, or a size overflow (collapsed to 0 by
/// `unwrap_or` and caught by the same guard).
pub unsafe fn cuda_malloc_async<T: DeviceCopy>(
    stream: &Stream,
    count: usize,
) -> CudaResult<DevicePointer<T>> {
    let size = count.checked_mul(mem::size_of::<T>()).unwrap_or(0);
    if size == 0 {
        return Err(CudaError::InvalidMemoryAllocation);
    }
    // Receive the device address directly as a CUdeviceptr, matching
    // `cuda_malloc`. The previous `&mut (*mut c_void) as *mut u64` cast had
    // the driver write 8 bytes through a pointer-sized slot, which is
    // unsound on 32-bit targets, and then needlessly round-tripped the
    // address through `*mut T`.
    let mut ptr: cuda::CUdeviceptr = 0;
    cuda::cuMemAllocAsync(&mut ptr, size, stream.as_inner()).to_result()?;
    Ok(DevicePointer::from_raw(ptr))
}
/// Asynchronously frees device memory previously allocated with
/// [`cuda_malloc_async`], ordered on `stream`.
///
/// Zero-sized types are rejected with `InvalidMemoryAllocation`, mirroring
/// the allocation-side guard: no real allocation ever backs a ZST.
pub unsafe fn cuda_free_async<T: DeviceCopy>(
    stream: &Stream,
    p: DevicePointer<T>,
) -> CudaResult<()> {
    match mem::size_of::<T>() {
        0 => Err(CudaError::InvalidMemoryAllocation),
        _ => cuda::cuMemFreeAsync(p.as_raw(), stream.as_inner()).to_result(),
    }
}
/// Allocates `count` elements of type `T` in unified (managed) memory,
/// attached globally so it is accessible from both host and device.
///
/// Returns `InvalidMemoryAllocation` when the byte size is zero — i.e. for a
/// zero `count`, a zero-sized `T`, or a size overflow (collapsed to 0 by
/// `unwrap_or` and caught by the same guard).
pub unsafe fn cuda_malloc_unified<T: DeviceCopy>(count: usize) -> CudaResult<UnifiedPointer<T>> {
    let size = count.checked_mul(mem::size_of::<T>()).unwrap_or(0);
    if size == 0 {
        return Err(CudaError::InvalidMemoryAllocation);
    }
    // Receive the address directly as a CUdeviceptr, matching `cuda_malloc`.
    // The previous `&mut (*mut c_void) as *mut u64` cast had the driver write
    // 8 bytes through a pointer-sized slot, which is unsound on 32-bit
    // targets; it also double-cast the result (`as *mut T` twice).
    let mut ptr: cuda::CUdeviceptr = 0;
    cuda::cuMemAllocManaged(
        &mut ptr,
        size,
        cuda::CUmemAttach_flags_enum::CU_MEM_ATTACH_GLOBAL as u32,
    )
    .to_result()?;
    Ok(UnifiedPointer::wrap(ptr as *mut T))
}
/// Frees device memory previously allocated with [`cuda_malloc`].
///
/// A null pointer is rejected with `InvalidMemoryAllocation` rather than
/// being passed to the driver.
pub unsafe fn cuda_free<T: DeviceCopy>(ptr: DevicePointer<T>) -> CudaResult<()> {
    if ptr.is_null() {
        return Err(CudaError::InvalidMemoryAllocation);
    }
    cuda::cuMemFree_v2(ptr.as_raw()).to_result()
}
/// Frees unified (managed) memory previously allocated with
/// [`cuda_malloc_unified`].
///
/// A null pointer is rejected with `InvalidMemoryAllocation` rather than
/// being passed to the driver.
pub unsafe fn cuda_free_unified<T: DeviceCopy>(mut p: UnifiedPointer<T>) -> CudaResult<()> {
    let raw = p.as_raw_mut();
    if raw.is_null() {
        return Err(CudaError::InvalidMemoryAllocation);
    }
    cuda::cuMemFree_v2(raw as u64).to_result()
}
/// Allocates `count` elements of type `T` in page-locked (pinned) host
/// memory, which the device can access directly via DMA.
///
/// Returns `InvalidMemoryAllocation` when the byte size is zero — i.e. for a
/// zero `count`, a zero-sized `T`, or a size overflow (collapsed to 0 by
/// `unwrap_or` and caught by the same guard).
pub unsafe fn cuda_malloc_locked<T>(count: usize) -> CudaResult<*mut T> {
    let bytes = mem::size_of::<T>().checked_mul(count).unwrap_or(0);
    if bytes == 0 {
        return Err(CudaError::InvalidMemoryAllocation);
    }
    let mut host_ptr: *mut c_void = ptr::null_mut();
    cuda::cuMemAllocHost_v2(&mut host_ptr, bytes).to_result()?;
    Ok(host_ptr as *mut T)
}
/// Frees page-locked host memory previously allocated with
/// [`cuda_malloc_locked`].
///
/// A null pointer is rejected with `InvalidMemoryAllocation` rather than
/// being passed to the driver.
pub unsafe fn cuda_free_locked<T>(ptr: *mut T) -> CudaResult<()> {
    if ptr.is_null() {
        return Err(CudaError::InvalidMemoryAllocation);
    }
    cuda::cuMemFreeHost(ptr as *mut c_void).to_result()
}
#[cfg(test)]
mod test {
    use super::*;

    /// Zero-sized marker type used to verify that ZST allocations are rejected.
    #[derive(Clone, Copy, Debug)]
    struct ZeroSizedType;
    unsafe impl DeviceCopy for ZeroSizedType {}

    #[test]
    fn test_cuda_malloc() {
        let _context = crate::quick_init().unwrap();
        unsafe {
            // Round-trip a single-element device allocation.
            let device_mem = cuda_malloc::<u64>(1).unwrap();
            assert!(!device_mem.is_null());
            cuda_free(device_mem).unwrap();
        }
    }

    #[test]
    fn test_cuda_malloc_zero_bytes() {
        let _context = crate::quick_init().unwrap();
        // A zero-element request must be refused.
        let err = unsafe { cuda_malloc::<u64>(0) }.unwrap_err();
        assert_eq!(CudaError::InvalidMemoryAllocation, err);
    }

    #[test]
    fn test_cuda_malloc_zero_sized() {
        let _context = crate::quick_init().unwrap();
        // Any count of a zero-sized type must be refused.
        let err = unsafe { cuda_malloc::<ZeroSizedType>(10) }.unwrap_err();
        assert_eq!(CudaError::InvalidMemoryAllocation, err);
    }

    #[test]
    fn test_cuda_alloc_overflow() {
        let _context = crate::quick_init().unwrap();
        // count * size_of::<u64>() overflows usize and must be refused.
        let err = unsafe { cuda_malloc::<u64>(::std::usize::MAX - 1) }.unwrap_err();
        assert_eq!(CudaError::InvalidMemoryAllocation, err);
    }

    #[test]
    fn test_cuda_malloc_unified() {
        let _context = crate::quick_init().unwrap();
        unsafe {
            // Unified memory is writable from the host, so exercise a store.
            let mut unified = cuda_malloc_unified::<u64>(1).unwrap();
            assert!(!unified.is_null());
            *unified.as_raw_mut() = 64;
            cuda_free_unified(unified).unwrap();
        }
    }

    #[test]
    fn test_cuda_malloc_unified_zero_bytes() {
        let _context = crate::quick_init().unwrap();
        let err = unsafe { cuda_malloc_unified::<u64>(0) }.unwrap_err();
        assert_eq!(CudaError::InvalidMemoryAllocation, err);
    }

    #[test]
    fn test_cuda_malloc_unified_zero_sized() {
        let _context = crate::quick_init().unwrap();
        let err = unsafe { cuda_malloc_unified::<ZeroSizedType>(10) }.unwrap_err();
        assert_eq!(CudaError::InvalidMemoryAllocation, err);
    }

    #[test]
    fn test_cuda_malloc_unified_overflow() {
        let _context = crate::quick_init().unwrap();
        let err = unsafe { cuda_malloc_unified::<u64>(::std::usize::MAX - 1) }.unwrap_err();
        assert_eq!(CudaError::InvalidMemoryAllocation, err);
    }

    #[test]
    fn test_cuda_free_null() {
        let _context = crate::quick_init().unwrap();
        // Freeing a null device pointer must be refused, not sent to the driver.
        let err = unsafe { cuda_free(DevicePointer::<u64>::null()) }.unwrap_err();
        assert_eq!(CudaError::InvalidMemoryAllocation, err);
    }

    #[test]
    fn test_cuda_malloc_locked() {
        let _context = crate::quick_init().unwrap();
        unsafe {
            // Pinned host memory is directly writable; exercise a store.
            let locked = cuda_malloc_locked::<u64>(1).unwrap();
            assert!(!locked.is_null());
            *locked = 64;
            cuda_free_locked(locked).unwrap();
        }
    }

    #[test]
    fn test_cuda_malloc_locked_zero_bytes() {
        let _context = crate::quick_init().unwrap();
        let err = unsafe { cuda_malloc_locked::<u64>(0) }.unwrap_err();
        assert_eq!(CudaError::InvalidMemoryAllocation, err);
    }

    #[test]
    fn test_cuda_malloc_locked_zero_sized() {
        let _context = crate::quick_init().unwrap();
        let err = unsafe { cuda_malloc_locked::<ZeroSizedType>(10) }.unwrap_err();
        assert_eq!(CudaError::InvalidMemoryAllocation, err);
    }

    #[test]
    fn test_cuda_malloc_locked_overflow() {
        let _context = crate::quick_init().unwrap();
        let err = unsafe { cuda_malloc_locked::<u64>(::std::usize::MAX - 1) }.unwrap_err();
        assert_eq!(CudaError::InvalidMemoryAllocation, err);
    }

    #[test]
    fn test_cuda_free_locked_null() {
        let _context = crate::quick_init().unwrap();
        // Freeing a null host pointer must be refused, not sent to the driver.
        let err = unsafe { cuda_free_locked(::std::ptr::null_mut::<u64>()) }.unwrap_err();
        assert_eq!(CudaError::InvalidMemoryAllocation, err);
    }
}