use crate::{check_cufile_error, sys, CuFileResult};
use std::marker::PhantomData;
use std::os::raw::c_void;
/// Named values for the `flags` argument of [`RegisteredBuffer::register`].
pub struct BufferFlags;

impl BufferFlags {
    /// No registration flags; this is what `register_bytes` and `register_typed` pass.
    pub const NONE: i32 = 0x0;
}
/// A memory region registered with cuFile through `cuFileBufRegister`.
///
/// Dropping the value passes the same pointer to `cuFileBufDeregister`. The type parameter `T`
/// only types the pointer handed out by `as_ptr`; no `T` values are stored in this struct.
#[derive(Debug)]
pub struct RegisteredBuffer<T> {
    /// Start of the registered region, exactly as passed to `cuFileBufRegister`.
    ptr: *mut c_void,
    /// Length of the registered region in bytes.
    size: usize,
    /// Ties the buffer to its element type `T` without storing a `T`.
    _phantom: PhantomData<T>,
}
impl<T> RegisteredBuffer<T> {
    /// Registers the `size`-byte region starting at `dev_ptr` with cuFile.
    ///
    /// Calls `cuFileBufRegister(dev_ptr, size, flags)` and passes `flags` through unchanged.
    /// Dropping the returned value deregisters the same pointer.
    ///
    /// # Safety
    ///
    /// `dev_ptr` must be a pointer that `cuFileBufRegister` accepts for a region of at least
    /// `size` bytes. This is presumably CUDA device memory; confirm against the cuFile
    /// documentation. The region must stay allocated until the returned buffer is dropped,
    /// because dropping it hands the same pointer to `cuFileBufDeregister`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `check_cufile_error` when registration reports failure.
    pub unsafe fn register(dev_ptr: *mut T, size: usize, flags: i32) -> CuFileResult<Self> {
        let ptr = dev_ptr as *mut c_void;
        check_cufile_error(sys::cuFileBufRegister(ptr, size, flags))?;
        Ok(RegisteredBuffer {
            ptr,
            size,
            _phantom: PhantomData,
        })
    }

    /// Returns the registered pointer, typed as `*mut T`.
    pub fn as_ptr(&self) -> *mut T {
        self.ptr as *mut T
    }

    /// Returns the registered pointer exactly as it was passed to cuFile.
    pub fn as_void_ptr(&self) -> *mut c_void {
        self.ptr
    }

    /// Size of the registered region in bytes.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Number of whole `T` elements that fit in the registered region.
    ///
    /// Any trailing bytes smaller than one `T` are not counted. Returns `0` for zero-sized `T`
    /// instead of dividing by zero.
    pub fn len(&self) -> usize {
        // `checked_div` yields `None` when `size_of::<T>() == 0`, which used to panic here.
        self.size
            .checked_div(std::mem::size_of::<T>())
            .unwrap_or(0)
    }

    /// Returns `true` when the region holds no whole `T` element, i.e. `len() == 0`.
    ///
    /// Defined through `len` so the two methods can never disagree. A region smaller than one
    /// `T` is therefore empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
impl<T> Drop for RegisteredBuffer<T> {
    /// Deregisters the region from cuFile.
    ///
    /// The status returned by `cuFileBufDeregister` is discarded, because `drop` has no way
    /// to report an error to the caller.
    fn drop(&mut self) {
        let registered = self.ptr;
        // SAFETY: `registered` is the pointer stored when this value was built. This assumes
        // it was built through `register`, so it is the pointer given to `cuFileBufRegister`.
        let _ = unsafe { sys::cuFileBufDeregister(registered) };
    }
}
// SAFETY: every `&self` method only reads `ptr` or `size` and returns copies of them. There is
// no interior mutability and no host-side dereference, so sharing `&RegisteredBuffer` across
// threads cannot race.
unsafe impl<T> Sync for RegisteredBuffer<T> {}

// SAFETY: the struct holds only an address and a byte count and never dereferences the address,
// so moving it between threads moves no `T`. NOTE(review): this assumes cuFile allows
// `cuFileBufDeregister` (run by `Drop`) on a different thread than the registering one; confirm.
unsafe impl<T> Send for RegisteredBuffer<T> {}
impl RegisteredBuffer<u8> {
pub unsafe fn register_bytes(dev_ptr: *mut u8, size: usize) -> CuFileResult<Self> {
Self::register(dev_ptr, size, BufferFlags::NONE)
}
}
impl<T> RegisteredBuffer<T>
where
    T: Copy,
{
    /// Registers `count` contiguous `T` values starting at `dev_ptr`, using `BufferFlags::NONE`.
    ///
    /// The byte size passed on to `register` is `count * size_of::<T>()`.
    ///
    /// # Safety
    ///
    /// Same contract as `RegisteredBuffer::register`: `dev_ptr` must be acceptable to
    /// `cuFileBufRegister` for at least `count` values of `T` and remain allocated until the
    /// returned buffer is dropped.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `register`.
    ///
    /// # Panics
    ///
    /// Panics if `count * size_of::<T>()` overflows `usize`. No real allocation can be that
    /// large, so this signals a caller bug rather than a recoverable condition.
    pub unsafe fn register_typed(dev_ptr: *mut T, count: usize) -> CuFileResult<Self> {
        // Checked so that release builds cannot silently wrap to a smaller byte count and
        // register the wrong region.
        let size = count
            .checked_mul(std::mem::size_of::<T>())
            .expect("register_typed: count * size_of::<T>() overflows usize");
        Self::register(dev_ptr, size, BufferFlags::NONE)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Element counts of different widths that should each describe the same 1 KiB region.
    #[test]
    fn test_buffer_size_calculations() {
        const REGION_BYTES: usize = 1024;
        let cases = [
            (256, std::mem::size_of::<u32>()),
            (128, std::mem::size_of::<f64>()),
        ];
        for &(count, elem_size) in &cases {
            assert_eq!(count * elem_size, REGION_BYTES);
        }
    }

    /// `BufferFlags::NONE` must stay the zero (no-flags) value.
    #[test]
    fn test_buffer_flags() {
        let none = BufferFlags::NONE;
        assert_eq!(none, 0);
    }
}