use crate::{check_cufile_error, sys, CuFileResult};
use std::fs::File;
use std::os::raw::{c_int, c_void};
use std::os::unix::io::AsRawFd;
use std::ptr;
/// A file that has been registered with the cuFile driver.
///
/// Owns both the cuFile handle and the `File` it was registered from. Dropping this value
/// deregisters the handle and then closes the file.
pub struct CuFileHandle {
    /// Handle produced by `cuFileHandleRegister`.
    handle: sys::CUfileHandle_t,
    /// Keeps the registered descriptor open for as long as `handle` may be used.
    _file: File,
}
impl CuFileHandle {
    /// Registers `file` with the cuFile driver and takes ownership of it.
    ///
    /// The descriptor is handed to `cuFileHandleRegister` as an opaque fd. The returned
    /// handle keeps `file` alive, so the descriptor stays open until the handle is dropped.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `check_cufile_error` when `cuFileHandleRegister` fails.
    /// In that case `file` is dropped (and therefore closed).
    pub fn register(file: File) -> CuFileResult<Self> {
        let fd = file.as_raw_fd();
        let mut handle = ptr::null_mut();
        let mut descr = sys::CUfileDescr_t {
            type_: sys::CUfileFileHandleType::CU_FILE_HANDLE_TYPE_OPAQUE_FD,
            handle: sys::CUfileDescrHandle { fd: fd as c_int },
            fs_ops: ptr::null_mut(),
        };
        // SAFETY: `handle` and `descr` are live locals borrowed mutably only for this call,
        // and the fd stored in `descr` belongs to `file`, which is still owned here.
        unsafe {
            check_cufile_error(sys::cuFileHandleRegister(&mut handle, &mut descr))?;
        }
        Ok(CuFileHandle {
            handle,
            _file: file,
        })
    }

    /// Returns the underlying cuFile handle for other wrappers in this crate.
    ///
    /// The handle is deregistered when `self` is dropped, so it must not be used after that.
    pub(crate) fn raw_handle(&self) -> sys::CUfileHandle_t {
        self.handle
    }

    /// Reads from the file at `file_offset` into `buffer`, starting `dest_offset` bytes into it.
    ///
    /// Exactly `buffer.len() - dest_offset` bytes are requested, so the transfer can never
    /// leave `buffer`.
    ///
    /// Returns the byte count that `cuFileRead` reports.
    ///
    /// # Errors
    ///
    /// Propagates the mapped error from [`Self::read_raw`].
    ///
    /// # Panics
    ///
    /// Panics if `dest_offset` is negative or greater than `buffer.len()`, or if `cuFileRead`
    /// reports a failure that cannot be mapped to an error (see [`Self::read_raw`]).
    pub fn read(
        &self,
        buffer: &mut [u8],
        file_offset: i64,
        dest_offset: i64,
    ) -> CuFileResult<usize> {
        let size = Self::len_after_offset(buffer.len(), dest_offset);
        // SAFETY: cuFile writes at most `size` bytes starting at `buffer + dest_offset`, and
        // `dest_offset + size == buffer.len()`, so every write lands inside `buffer`, which is
        // exclusively borrowed for the whole call.
        let count = unsafe {
            self.read_raw(
                buffer.as_mut_ptr().cast::<c_void>(),
                size,
                file_offset,
                dest_offset,
            )?
        };
        Ok(usize::try_from(count).expect("read_raw only returns non-negative byte counts"))
    }

    /// Reads `size` bytes from the file at `file_offset` into `dest_base + dest_offset`.
    ///
    /// This is a thin wrapper over `cuFileRead`. It returns the non-negative byte count that
    /// the call reports.
    ///
    /// # Errors
    ///
    /// A return of `-1` is mapped from `errno`. Any other negative return is treated as a
    /// negated `CUfileOpError` code. Both are converted with `check_cufile_error`.
    ///
    /// # Panics
    ///
    /// Panics if a negative return cannot be mapped to an error (for example `-1` with
    /// `errno == 0`), instead of reporting a negative byte count as success.
    ///
    /// # Safety
    ///
    /// `dest_base + dest_offset` must be valid for writes of `size` bytes for the duration of
    /// the call, in memory that `cuFileRead` accepts as a destination.
    /// NOTE(review): confirm device/host/buffer-registration requirements against the cuFile
    /// API reference.
    pub unsafe fn read_raw(
        &self,
        dest_base: *mut c_void,
        size: usize,
        file_offset: i64,
        dest_offset: i64,
    ) -> CuFileResult<isize> {
        // SAFETY: the caller upholds this function's `# Safety` contract for the buffer, and
        // `self.handle` stays registered for as long as `self` is alive.
        let ret =
            unsafe { sys::cuFileRead(self.handle, dest_base, size, file_offset, dest_offset) };
        Self::io_return_to_result(ret)
    }

    /// Writes `buffer[src_offset..]` to the file starting at `file_offset`.
    ///
    /// Exactly `buffer.len() - src_offset` bytes are submitted, so the transfer can never
    /// leave `buffer`.
    ///
    /// Returns the byte count that `cuFileWrite` reports.
    ///
    /// # Errors
    ///
    /// Propagates the mapped error from [`Self::write_raw`].
    ///
    /// # Panics
    ///
    /// Panics if `src_offset` is negative or greater than `buffer.len()`, or if `cuFileWrite`
    /// reports a failure that cannot be mapped to an error (see [`Self::write_raw`]).
    pub fn write(&self, buffer: &[u8], file_offset: i64, src_offset: i64) -> CuFileResult<usize> {
        let size = Self::len_after_offset(buffer.len(), src_offset);
        // SAFETY: cuFile reads at most `size` bytes starting at `buffer + src_offset`, and
        // `src_offset + size == buffer.len()`, so every read stays inside `buffer`, which is
        // borrowed for the whole call.
        let count = unsafe {
            self.write_raw(
                buffer.as_ptr().cast::<c_void>(),
                size,
                file_offset,
                src_offset,
            )?
        };
        Ok(usize::try_from(count).expect("write_raw only returns non-negative byte counts"))
    }

    /// Writes `size` bytes from `src_base + src_offset` to the file at `file_offset`.
    ///
    /// This is a thin wrapper over `cuFileWrite`. It returns the non-negative byte count that
    /// the call reports.
    ///
    /// # Errors
    ///
    /// A return of `-1` is mapped from `errno`. Any other negative return is treated as a
    /// negated `CUfileOpError` code. Both are converted with `check_cufile_error`.
    ///
    /// # Panics
    ///
    /// Panics if a negative return cannot be mapped to an error (for example `-1` with
    /// `errno == 0`), instead of reporting a negative byte count as success.
    ///
    /// # Safety
    ///
    /// `src_base + src_offset` must be valid for reads of `size` bytes for the duration of the
    /// call, in memory that `cuFileWrite` accepts as a source.
    /// NOTE(review): confirm device/host/buffer-registration requirements against the cuFile
    /// API reference.
    pub unsafe fn write_raw(
        &self,
        src_base: *const c_void,
        size: usize,
        file_offset: i64,
        src_offset: i64,
    ) -> CuFileResult<isize> {
        // SAFETY: the caller upholds this function's `# Safety` contract for the buffer, and
        // `self.handle` stays registered for as long as `self` is alive.
        let ret =
            unsafe { sys::cuFileWrite(self.handle, src_base, size, file_offset, src_offset) };
        Self::io_return_to_result(ret)
    }

    /// Converts the `ssize_t` returned by `cuFileRead`/`cuFileWrite` into a `CuFileResult`.
    ///
    /// Must be called directly after the FFI call so that `errno` still holds the value that
    /// call left behind.
    ///
    /// # Panics
    ///
    /// Panics if a negative return cannot be turned into an error.
    fn io_return_to_result(ret: isize) -> CuFileResult<isize> {
        if ret >= 0 {
            return Ok(ret);
        }
        let code: c_int = if ret == -1 {
            // `-1`: the failure reason is reported through `errno`.
            std::io::Error::last_os_error().raw_os_error().unwrap_or(0)
        } else {
            // Any other negative value is a negated `CUfileOpError` code.
            c_int::try_from(ret.unsigned_abs())
                .unwrap_or_else(|_| panic!("Unexpected CuFile error code: {}", ret))
        };
        check_cufile_error(code)?;
        // `check_cufile_error` accepted the code (e.g. `errno` was 0). A negative byte count
        // must never be reported as success, so treat this as an invariant violation.
        panic!("Unexpected CuFile error code: {}", ret);
    }

    /// Returns how many bytes of a `len`-byte buffer remain after skipping `offset` bytes.
    ///
    /// # Panics
    ///
    /// Panics if `offset` is negative or greater than `len`. Handing such an offset to cuFile
    /// would make it access memory outside the caller's slice.
    fn len_after_offset(len: usize, offset: i64) -> usize {
        usize::try_from(offset)
            .ok()
            .and_then(|start| len.checked_sub(start))
            .unwrap_or_else(|| {
                panic!(
                    "buffer offset {} is out of bounds for a buffer of length {}",
                    offset, len
                )
            })
    }
}
impl Drop for CuFileHandle {
    /// Deregisters the cuFile handle.
    ///
    /// `drop` runs before the struct's fields are dropped. The handle is therefore
    /// deregistered while the owned `File` (and its descriptor) is still open.
    fn drop(&mut self) {
        let registered = self.handle;
        // SAFETY: `registered` was produced by a successful `cuFileHandleRegister` in
        // `register`, and this is the only place that deregisters it.
        let _ = unsafe { sys::cuFileHandleDeregister(registered) };
    }
}
// SAFETY: `CuFileHandle` holds only an owned `File` and the cuFile handle. The handle is a raw
// pointer, which is why `Send`/`Sync` are not derived automatically.
// - `Send` assumes the handle is not tied to the thread that registered it.
// - `Sync` lets `&CuFileHandle` be shared, so `read`/`write` may run concurrently on the same
//   handle from several threads.
// NOTE(review): both rely on the cuFile library's thread-safety guarantees for
// `cuFileRead`/`cuFileWrite`/`cuFileHandleDeregister` — confirm against the cuFile API docs.
unsafe impl Send for CuFileHandle {}
unsafe impl Sync for CuFileHandle {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CuFileError;
    use std::fs::OpenOptions;
    use tempfile::tempdir;

    /// A freshly created read/write regular file can be registered.
    #[test]
    fn test_handle_creation() {
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("test_file.dat");
        let file = OpenOptions::new()
            .create(true)
            .read(true)
            .write(true)
            .open(&file_path)
            .unwrap();
        CuFileHandle::register(file).unwrap();
    }

    /// Registering the same file descriptor twice must fail with `HandleAlreadyRegistered`.
    #[test]
    fn test_handle_already_registered() {
        use std::mem;
        use std::os::unix::io::{FromRawFd, IntoRawFd};

        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("test_file.dat");
        let file = OpenOptions::new()
            .create(true)
            .read(true)
            .write(true)
            .open(&file_path)
            .unwrap();
        let raw_fd = file.into_raw_fd();
        // Two `File`s deliberately own the same descriptor so both registrations see the same
        // fd. Only one of them may close it:
        // - `file2` is consumed (and closed) by the second `register` call.
        // - `hdl1`, which owns `file1`, is leaked with `mem::forget` below so the fd is not
        //   closed a second time. Its registration is intentionally leaked as well.
        let file1 = unsafe { File::from_raw_fd(raw_fd) };
        let file2 = unsafe { File::from_raw_fd(raw_fd) };
        let hdl1 = CuFileHandle::register(file1).unwrap();
        match CuFileHandle::register(file2) {
            Ok(_handle) => {
                panic!("Handle created successfully even though it should have failed");
            }
            Err(e) => {
                assert_eq!(
                    e,
                    CuFileError::HandleAlreadyRegistered,
                    "Handle creation failed (expected): {:?}",
                    e
                );
            }
        }
        mem::forget(hdl1);
    }

    /// Registering the process's stdin is expected to be rejected with `InvalidFile`.
    #[test]
    fn test_handle_invalid_file_type() {
        // `/proc/self/fd/0` is this process's stdin, which is usually a terminal or pipe rather
        // than a regular file. The path is absolute, so no temporary directory is involved.
        let file = OpenOptions::new()
            .create(true)
            .read(true)
            .write(true)
            .open("/proc/self/fd/0")
            .unwrap();
        match CuFileHandle::register(file) {
            Ok(_handle) => {
                panic!("Handle created successfully even though it should have failed");
            }
            Err(e) => {
                assert_eq!(
                    e,
                    CuFileError::InvalidFile,
                    "Handle creation failed (expected): {:?}",
                    e
                );
            }
        }
    }

    /// Reading through a handle registered from a write-only file must fail with
    /// `InvalidFileOpenFlag`.
    #[test]
    fn test_handle_read_permissions_error() {
        let temp_dir = tempdir().unwrap();
        let file_path = temp_dir.path().join("test_file.dat");
        let file = OpenOptions::new()
            .create(true)
            .read(false)
            .write(true)
            .open(&file_path)
            .unwrap();
        let handle = CuFileHandle::register(file).unwrap();
        let mut buffer = [0u8; 10];
        let ret = handle.read(&mut buffer, 0, 0);
        assert_eq!(ret, Err(CuFileError::InvalidFileOpenFlag));
    }
}