use std::ffi::c_void;
use oxicuda_driver::loader::try_driver;
use crate::error::{CudaRtError, CudaRtResult};
use crate::stream::CudaStream;
/// A device-memory address held as a raw `u64`.
///
/// Values are produced by the allocation functions in this module (e.g. the driver writes the
/// address into a `u64` in `malloc`). `0` is used as the null address.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DevicePtr(pub u64);

impl DevicePtr {
    /// The null device address (`0`), returned for zero-sized allocations.
    pub const NULL: Self = Self(0);

    /// Returns `true` when this is the null address.
    #[must_use]
    pub fn is_null(self) -> bool {
        self.0 == 0
    }

    /// Returns this address displaced by `offset` bytes (negative values move backwards).
    ///
    /// Arithmetic wraps modulo 2^64. This is the same result the previous `i64`-based
    /// computation produced in release builds, but it no longer panics in debug builds when
    /// the address is above `i64::MAX` or the sum crosses that boundary.
    #[must_use]
    pub fn offset(self, offset: isize) -> Self {
        // `isize -> i64` is lossless on all supported targets.
        Self(self.0.wrapping_add_signed(offset as i64))
    }
}
/// Direction of a copy performed by [`memcpy`] or [`memcpy_async`].
///
/// The explicit discriminants 0..=4 follow the numbering of the CUDA runtime's
/// `cudaMemcpyKind`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemcpyKind {
    /// Host memory to host memory.
    HostToHost = 0,
    /// Host memory to device memory.
    HostToDevice = 1,
    /// Device memory to host memory.
    DeviceToHost = 2,
    /// Device memory to device memory.
    DeviceToDevice = 3,
    /// NOTE(review): the copy functions in this module route this exactly like
    /// `HostToDevice`; no pointer-based direction inference is performed.
    Default = 4,
}
/// Attachment mode for managed memory, forwarded to the driver as a `u32` by
/// [`malloc_managed`].
///
/// Each variant is a single bit, matching the CUDA driver's `CU_MEM_ATTACH_GLOBAL`,
/// `CU_MEM_ATTACH_HOST` and `CU_MEM_ATTACH_SINGLE` values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemAttachFlags {
    /// Bit 0 (`0x1`).
    Global = 1 << 0,
    /// Bit 1 (`0x2`).
    Host = 1 << 1,
    /// Bit 2 (`0x4`).
    Single = 1 << 2,
}
/// Allocates `size` bytes of device memory via the driver's `cu_mem_alloc_v2` entry point.
///
/// A zero-byte request returns [`DevicePtr::NULL`] without loading the driver.
///
/// # Errors
/// - `DriverNotAvailable` when the driver cannot be loaded.
/// - The error mapped from the driver's status code, or `MemoryAllocation` when that code has
///   no mapping.
pub fn malloc(size: usize) -> CudaRtResult<DevicePtr> {
    if size == 0 {
        return Ok(DevicePtr::NULL);
    }
    let driver = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
    let mut address: u64 = 0;
    // SAFETY: `address` is a live local, so the out-pointer stays valid for the whole call.
    let status = unsafe { (driver.cu_mem_alloc_v2)(&raw mut address, size) };
    if status == 0 {
        Ok(DevicePtr(address))
    } else {
        Err(CudaRtError::from_code(status).unwrap_or(CudaRtError::MemoryAllocation))
    }
}
/// Releases a device allocation through the driver's `cu_mem_free_v2` entry point.
///
/// Passing [`DevicePtr::NULL`] is a no-op that succeeds without loading the driver.
///
/// # Errors
/// - `DriverNotAvailable` when the driver cannot be loaded.
/// - The error mapped from the driver's status code, or `InvalidDevicePointer` when that code
///   has no mapping.
pub fn free(ptr: DevicePtr) -> CudaRtResult<()> {
    if !ptr.is_null() {
        let driver = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
        // SAFETY: only the integer device address crosses the FFI boundary; no host memory
        // is read or written by this call.
        let status = unsafe { (driver.cu_mem_free_v2)(ptr.0) };
        if status != 0 {
            return Err(
                CudaRtError::from_code(status).unwrap_or(CudaRtError::InvalidDevicePointer),
            );
        }
    }
    Ok(())
}
/// Allocates `size` bytes of host memory via the driver's `cu_mem_alloc_host_v2` entry point.
///
/// A zero-byte request returns a null pointer without loading the driver. Release the result
/// with [`free_host`].
///
/// # Errors
/// - `DriverNotAvailable` when the driver cannot be loaded.
/// - The error mapped from the driver's status code, or `MemoryAllocation` when that code has
///   no mapping.
pub fn malloc_host(size: usize) -> CudaRtResult<*mut c_void> {
    if size == 0 {
        return Ok(std::ptr::null_mut());
    }
    let driver = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
    let mut host_buf: *mut c_void = std::ptr::null_mut();
    // SAFETY: `host_buf` is a live local, so the out-pointer stays valid for the whole call.
    match unsafe { (driver.cu_mem_alloc_host_v2)(&raw mut host_buf, size) } {
        0 => Ok(host_buf),
        code => Err(CudaRtError::from_code(code).unwrap_or(CudaRtError::MemoryAllocation)),
    }
}
/// Releases host memory through the driver's `cu_mem_free_host` entry point.
///
/// A null `ptr` is a no-op that succeeds without loading the driver.
///
/// # Safety
/// `ptr` must be null or a pointer obtained from [`malloc_host`] that has not already been
/// released. It must not be used after this call returns successfully.
///
/// # Errors
/// - `DriverNotAvailable` when the driver cannot be loaded.
/// - The error mapped from the driver's status code, or `InvalidHostPointer` when that code
///   has no mapping.
pub unsafe fn free_host(ptr: *mut c_void) -> CudaRtResult<()> {
    if ptr.is_null() {
        return Ok(());
    }
    let driver = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
    // SAFETY: the caller upholds the `# Safety` contract for `ptr`.
    let status = unsafe { (driver.cu_mem_free_host)(ptr) };
    if status == 0 {
        Ok(())
    } else {
        Err(CudaRtError::from_code(status).unwrap_or(CudaRtError::InvalidHostPointer))
    }
}
/// Allocates `size` bytes of managed memory via the driver's `cu_mem_alloc_managed` entry
/// point. `flags` is passed to the driver as its `u32` discriminant.
///
/// A zero-byte request returns [`DevicePtr::NULL`] without loading the driver.
///
/// # Errors
/// - `DriverNotAvailable` when the driver cannot be loaded.
/// - The error mapped from the driver's status code, or `MemoryAllocation` when that code has
///   no mapping.
pub fn malloc_managed(size: usize, flags: MemAttachFlags) -> CudaRtResult<DevicePtr> {
    if size == 0 {
        return Ok(DevicePtr::NULL);
    }
    let driver = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
    let attach_mode = flags as u32;
    let mut address = 0u64;
    // SAFETY: `address` is a live local, so the out-pointer stays valid for the whole call.
    let status = unsafe { (driver.cu_mem_alloc_managed)(&raw mut address, size, attach_mode) };
    match status {
        0 => Ok(DevicePtr(address)),
        code => Err(CudaRtError::from_code(code).unwrap_or(CudaRtError::MemoryAllocation)),
    }
}
pub fn malloc_pitch(width_bytes: usize, height: usize) -> CudaRtResult<(DevicePtr, usize)> {
if width_bytes == 0 || height == 0 {
return Ok((DevicePtr::NULL, 0));
}
let align: usize = 512;
let pitch = width_bytes.div_ceil(align) * align;
let size = pitch * height;
let ptr = malloc(size)?;
Ok((ptr, pitch))
}
/// Copies `count` bytes from `src` to `dst` in the direction given by `kind`.
///
/// Device addresses are carried in the raw-pointer parameters and reinterpreted as `u64`
/// before being handed to the driver.
/// - `HostToHost` is done on the CPU with `ptr::copy_nonoverlapping`.
/// - The other kinds call the driver's `cu_memcpy_{htod,dtoh,dtod}_v2` entry points.
/// - A zero `count` returns `Ok(())` immediately.
///
/// NOTE(review): `Default` is routed exactly like `HostToDevice`; the real direction is not
/// inferred from the pointers, so a device-to-host or device-to-device `Default` copy takes
/// the wrong driver path.
///
/// # Safety
/// - Each host pointer must be valid for `count` bytes (reads from `src`, writes to `dst`).
/// - Each device address must refer to at least `count` bytes of device memory.
/// - For `HostToHost` the two ranges must not overlap.
///
/// # Errors
/// - `DriverNotAvailable` when the driver cannot be loaded. The driver is loaded for every
///   non-empty copy, including `HostToHost`.
/// - Otherwise, the error mapped from the driver's status code, or `InvalidMemcpyDirection`
///   when that code has no mapping.
pub unsafe fn memcpy(
    dst: *mut c_void,
    src: *const c_void,
    count: usize,
    kind: MemcpyKind,
) -> CudaRtResult<()> {
    if count == 0 {
        return Ok(());
    }
    let driver = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
    // Device-side views of the two pointers (only used by the driver-backed arms).
    let (dst_dev, src_dev) = (dst as u64, src as u64);
    let status = match kind {
        MemcpyKind::HostToHost => {
            // SAFETY: caller guarantees both host ranges are valid for `count` bytes and disjoint.
            unsafe { std::ptr::copy_nonoverlapping(src.cast::<u8>(), dst.cast::<u8>(), count) };
            0u32
        }
        // SAFETY (all driver arms): caller guarantees validity of both ranges per `kind`.
        MemcpyKind::HostToDevice | MemcpyKind::Default => unsafe {
            (driver.cu_memcpy_htod_v2)(dst_dev, src, count)
        },
        MemcpyKind::DeviceToHost => unsafe { (driver.cu_memcpy_dtoh_v2)(dst, src_dev, count) },
        MemcpyKind::DeviceToDevice => unsafe {
            (driver.cu_memcpy_dtod_v2)(dst_dev, src_dev, count)
        },
    };
    if status == 0 {
        Ok(())
    } else {
        Err(CudaRtError::from_code(status).unwrap_or(CudaRtError::InvalidMemcpyDirection))
    }
}
/// Copies `count` bytes from `src` to `dst`, enqueuing device transfers on `stream`.
///
/// Device addresses are carried in the raw-pointer parameters and reinterpreted as `u64`
/// for the driver. A zero `count` returns `Ok(())` immediately.
///
/// - `HostToDevice` / `Default`: `cu_memcpy_htod_async_v2` on `stream`.
/// - `DeviceToHost`: `cu_memcpy_dtoh_async_v2` on `stream`.
/// - `DeviceToDevice`: `cu_memcpy_dtod_async_v2` on `stream`. The synchronous
///   `cu_memcpy_dtod_v2` would run on the legacy default stream and be unordered with
///   respect to work on a non-blocking `stream`.
/// - `HostToHost`: copied immediately on the calling thread with `ptr::copy_nonoverlapping`.
///   It is NOT ordered with respect to work already queued on `stream`.
///
/// NOTE(review): `Default` is treated as host-to-device; no pointer-based direction
/// inference is performed.
///
/// # Safety
/// - All pointers must be valid for `count` bytes in the memory space implied by `kind`.
/// - They must remain valid until the enqueued copy has completed on `stream`.
/// - For `HostToHost` the two ranges must not overlap.
///
/// # Errors
/// - `DriverNotAvailable` when the driver cannot be loaded. This is checked for every
///   non-empty copy, including `HostToHost`.
/// - Otherwise, the error mapped from the driver's status code, or `InvalidMemcpyDirection`
///   when that code has no mapping.
pub unsafe fn memcpy_async(
    dst: *mut c_void,
    src: *const c_void,
    count: usize,
    kind: MemcpyKind,
    stream: &CudaStream,
) -> CudaRtResult<()> {
    if count == 0 {
        return Ok(());
    }
    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
    let rc = match kind {
        MemcpyKind::HostToHost => {
            // SAFETY: caller guarantees both host ranges are valid for `count` bytes and disjoint.
            unsafe { std::ptr::copy_nonoverlapping(src as *const u8, dst as *mut u8, count) };
            0u32
        }
        MemcpyKind::HostToDevice | MemcpyKind::Default => {
            let dst_ptr = dst as u64;
            // SAFETY: caller guarantees validity of both ranges until the copy completes.
            unsafe { (api.cu_memcpy_htod_async_v2)(dst_ptr, src, count, stream.raw()) }
        }
        MemcpyKind::DeviceToHost => {
            let src_ptr = src as u64;
            // SAFETY: caller guarantees validity of both ranges until the copy completes.
            unsafe { (api.cu_memcpy_dtoh_async_v2)(dst, src_ptr, count, stream.raw()) }
        }
        MemcpyKind::DeviceToDevice => {
            let dst_ptr = dst as u64;
            let src_ptr = src as u64;
            // Must be enqueued on `stream` so it is ordered with the other work on it.
            // SAFETY: caller guarantees validity of both device ranges until the copy completes.
            unsafe { (api.cu_memcpy_dtod_async_v2)(dst_ptr, src_ptr, count, stream.raw()) }
        }
    };
    if rc != 0 {
        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidMemcpyDirection));
    }
    Ok(())
}
/// Copies the whole host slice `src` to device memory starting at `dst`.
///
/// The byte count is `size_of_val(src)`. An empty slice is a no-op, because [`memcpy`]
/// returns early for zero bytes.
///
/// NOTE(review): `dst` is not validated here; it is passed through to [`memcpy`] as-is.
///
/// # Errors
/// Propagates any error from [`memcpy`].
pub fn memcpy_h2d<T: Copy>(dst: DevicePtr, src: &[T]) -> CudaRtResult<()> {
    let byte_len = std::mem::size_of_val(src);
    let host_src = src.as_ptr().cast::<c_void>();
    let device_dst = dst.0 as *mut c_void;
    // SAFETY: `host_src` comes from a live slice and is readable for `byte_len` bytes.
    unsafe { memcpy(device_dst, host_src, byte_len, MemcpyKind::HostToDevice) }
}
/// Fills the whole host slice `dst` with bytes read from device memory starting at `src`.
///
/// The byte count is `size_of_val(dst)`. An empty slice is a no-op, because [`memcpy`]
/// returns early for zero bytes.
///
/// NOTE(review): `src` is not validated here; it is passed through to [`memcpy`] as-is.
///
/// # Errors
/// Propagates any error from [`memcpy`].
pub fn memcpy_d2h<T: Copy>(dst: &mut [T], src: DevicePtr) -> CudaRtResult<()> {
    let byte_len = std::mem::size_of_val(dst);
    let device_src = src.0 as *const c_void;
    let host_dst = dst.as_mut_ptr().cast::<c_void>();
    // SAFETY: `host_dst` comes from a live, exclusively borrowed slice of `byte_len` bytes.
    unsafe { memcpy(host_dst, device_src, byte_len, MemcpyKind::DeviceToHost) }
}
/// Copies `bytes` bytes from device address `src` to device address `dst`.
///
/// Zero `bytes` is a no-op, because [`memcpy`] returns early.
///
/// NOTE(review): neither address is validated here; both are passed through to [`memcpy`].
///
/// # Errors
/// Propagates any error from [`memcpy`].
pub fn memcpy_d2d(dst: DevicePtr, src: DevicePtr, bytes: usize) -> CudaRtResult<()> {
    let (to, from) = (dst.0 as *mut c_void, src.0 as *const c_void);
    // SAFETY: both pointers are device addresses; no host memory is dereferenced for a
    // DeviceToDevice copy.
    unsafe { memcpy(to, from, bytes, MemcpyKind::DeviceToDevice) }
}
/// Sets `count` bytes of device memory at `ptr` to `value` via the driver's
/// `cu_memset_d8_v2` entry point.
///
/// A zero `count` or a null `ptr` returns `Ok(())` without loading the driver.
///
/// # Errors
/// - `DriverNotAvailable` when the driver cannot be loaded.
/// - The error mapped from the driver's status code, or `InvalidDevicePointer` when that code
///   has no mapping.
pub fn memset(ptr: DevicePtr, value: u8, count: usize) -> CudaRtResult<()> {
    if ptr.is_null() || count == 0 {
        return Ok(());
    }
    let driver = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
    // SAFETY: only an integer device address and plain values cross the FFI boundary.
    match unsafe { (driver.cu_memset_d8_v2)(ptr.0, value, count) } {
        0 => Ok(()),
        code => Err(CudaRtError::from_code(code).unwrap_or(CudaRtError::InvalidDevicePointer)),
    }
}
/// Fills device memory at `ptr` with the 32-bit `value` via the driver's `cu_memset_d32_v2`
/// entry point.
///
/// `count` is passed unchanged as the driver's element count: it is the number of 32-bit
/// words, not bytes. A zero `count` or a null `ptr` returns `Ok(())` without loading the
/// driver.
///
/// # Errors
/// - `DriverNotAvailable` when the driver cannot be loaded.
/// - The error mapped from the driver's status code, or `InvalidDevicePointer` when that code
///   has no mapping.
pub fn memset32(ptr: DevicePtr, value: u32, count: usize) -> CudaRtResult<()> {
    if ptr.is_null() || count == 0 {
        return Ok(());
    }
    let driver = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
    // SAFETY: only an integer device address and plain values cross the FFI boundary.
    let status = unsafe { (driver.cu_memset_d32_v2)(ptr.0, value, count) };
    if status == 0 {
        Ok(())
    } else {
        Err(CudaRtError::from_code(status).unwrap_or(CudaRtError::InvalidDevicePointer))
    }
}
/// Returns `(free_bytes, total_bytes)` as reported by the driver's `cu_mem_get_info_v2`
/// entry point.
///
/// # Errors
/// - `DriverNotAvailable` when the driver cannot be loaded.
/// - The error mapped from the driver's status code, or `Unknown` when that code has no
///   mapping.
pub fn mem_get_info() -> CudaRtResult<(usize, usize)> {
    let driver = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
    let (mut free_bytes, mut total_bytes) = (0usize, 0usize);
    // SAFETY: both out-pointers reference live, writable locals for the duration of the call.
    let status =
        unsafe { (driver.cu_mem_get_info_v2)(&raw mut free_bytes, &raw mut total_bytes) };
    match status {
        0 => Ok((free_bytes, total_bytes)),
        code => Err(CudaRtError::from_code(code).unwrap_or(CudaRtError::Unknown)),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn malloc_zero_returns_null() {
        assert!(matches!(malloc(0), Ok(DevicePtr(0))));
    }

    #[test]
    fn free_null_is_noop() {
        // `free` returns before touching the driver for NULL, so this must succeed even on
        // machines without CUDA.
        assert!(free(DevicePtr::NULL).is_ok());
    }

    #[test]
    fn malloc_host_zero_returns_null() {
        assert!(matches!(malloc_host(0), Ok(p) if p.is_null()));
    }

    #[test]
    fn free_host_null_is_noop() {
        // SAFETY: a null pointer is explicitly accepted by `free_host`.
        assert!(unsafe { free_host(std::ptr::null_mut()) }.is_ok());
    }

    #[test]
    fn malloc_managed_zero_returns_null() {
        assert!(matches!(
            malloc_managed(0, MemAttachFlags::Global),
            Ok(DevicePtr(0))
        ));
    }

    #[test]
    fn memset_null_or_empty_is_noop() {
        assert!(memset(DevicePtr::NULL, 0xAB, 16).is_ok());
        assert!(memset(DevicePtr(0x1000), 0xAB, 0).is_ok());
        assert!(memset32(DevicePtr::NULL, 0xDEAD_BEEF, 16).is_ok());
        assert!(memset32(DevicePtr(0x1000), 0xDEAD_BEEF, 0).is_ok());
    }

    #[test]
    fn zero_length_memcpy_is_noop() {
        // SAFETY: a zero-byte copy returns before any pointer is used.
        let r = unsafe {
            memcpy(
                std::ptr::null_mut(),
                std::ptr::null(),
                0,
                MemcpyKind::DeviceToHost,
            )
        };
        assert!(r.is_ok());
        let empty: [u32; 0] = [];
        assert!(memcpy_h2d(DevicePtr::NULL, &empty).is_ok());
    }

    #[test]
    fn device_ptr_offset() {
        let p = DevicePtr(1000);
        assert_eq!(p.offset(8), DevicePtr(1008));
        assert_eq!(p.offset(-8), DevicePtr(992));
    }

    #[test]
    fn device_ptr_is_null() {
        assert!(DevicePtr::NULL.is_null());
        assert!(!DevicePtr(1).is_null());
    }

    #[test]
    fn malloc_pitch_zero_extent_is_null() {
        assert!(matches!(malloc_pitch(0, 32), Ok((DevicePtr(0), 0))));
        assert!(matches!(malloc_pitch(100, 0), Ok((DevicePtr(0), 0))));
    }

    #[test]
    fn malloc_pitch_returns_aligned_pitch() {
        // Without a driver/device the allocation fails; only inspect the pitch on success,
        // and release the buffer so the test does not leak device memory.
        if let Ok((ptr, pitch)) = malloc_pitch(100, 32) {
            assert_eq!(pitch % 512, 0);
            assert!(pitch >= 100);
            assert!(free(ptr).is_ok());
        }
    }

    #[test]
    fn memcpy_kind_values() {
        assert_eq!(MemcpyKind::HostToHost as u32, 0);
        assert_eq!(MemcpyKind::HostToDevice as u32, 1);
        assert_eq!(MemcpyKind::DeviceToHost as u32, 2);
        assert_eq!(MemcpyKind::DeviceToDevice as u32, 3);
        assert_eq!(MemcpyKind::Default as u32, 4);
    }
}