use std::fs::File;
use std::mem::ManuallyDrop;
use std::ops::Range;
use std::sync::Arc;
use super::{CudaContext, CudaEvent, CudaStream, DevicePtr, DeviceSlice, SyncOnDrop};
use crate::driver::{result, sys, DriverError};
/// Handle to external memory that has been imported into a CUDA context.
///
/// Values are created by [`CudaContext::import_external_memory`]. Dropping one
/// passes the raw handle to `destroy_external_memory`.
#[derive(Debug)]
pub struct ExternalMemory {
/// Raw driver handle returned by the import call.
external_memory: sys::CUexternalMemory,
/// Size supplied at import time; upper bound for ranges accepted by `map_range`.
/// Presumably in bytes, since mappings are exposed as `u8` slices.
size: u64,
/// Context that performed the import; it is re-bound to the thread before teardown.
ctx: Arc<CudaContext>,
/// The imported file. Wrapped in `ManuallyDrop` so it is never closed
/// implicitly; the `Drop` impl closes it explicitly on Windows only.
_file: ManuallyDrop<File>,
}
impl Drop for ExternalMemory {
    /// Destroys the imported external memory and, on Windows, closes the file.
    ///
    /// Every driver result is handed to the context's `record_err` instead of
    /// being propagated, because `drop` cannot return an error.
    fn drop(&mut self) {
        let bind_result = self.ctx.bind_to_thread();
        self.ctx.record_err(bind_result);

        // SAFETY: `external_memory` is the handle stored at import time and this
        // is the only place in this file that destroys it.
        let destroy_result =
            unsafe { result::external_memory::destroy_external_memory(self.external_memory) };
        self.ctx.record_err(destroy_result);

        // The file is only closed here on Windows; on Unix it is intentionally
        // left alone. NOTE(review): this matches the CUDA documentation, which
        // says the driver takes ownership of an imported file descriptor but not
        // of an imported Win32 handle -- confirm against `cuImportExternalMemory`.
        //
        // SAFETY: `_file` is not touched again after this point, so the inner
        // `File` is dropped exactly once.
        #[cfg(windows)]
        unsafe {
            ManuallyDrop::drop(&mut self._file);
        }
    }
}
impl CudaContext {
    /// Imports the memory object referred to by `file` as CUDA external memory.
    ///
    /// The context is bound to the calling thread, then the file's raw file
    /// descriptor (Unix) or raw handle (Windows) is passed to the driver
    /// together with `size`. On success the file is stored inside the returned
    /// [`ExternalMemory`] wrapped in `ManuallyDrop`, so it is not closed
    /// automatically when that value goes away.
    ///
    /// # Errors
    ///
    /// Returns the [`DriverError`] from binding the context or from the import
    /// call. In that case `file` is dropped normally.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is not visible in this file. Presumably
    /// `file` must refer to a memory object the driver can import and `size`
    /// must be the true size of that object -- confirm against the CUDA
    /// `cuImportExternalMemory` documentation.
    #[cfg(any(unix, windows))]
    pub unsafe fn import_external_memory(
        self: &Arc<Self>,
        file: File,
        size: u64,
    ) -> Result<ExternalMemory, DriverError> {
        #[cfg(unix)]
        use std::os::fd::AsRawFd;
        #[cfg(windows)]
        use std::os::windows::io::AsRawHandle;

        self.bind_to_thread()?;

        // Exactly one of the two bindings below exists for a given target.
        #[cfg(unix)]
        let imported = unsafe {
            result::external_memory::import_external_memory_opaque_fd(file.as_raw_fd(), size)
        };
        #[cfg(windows)]
        let imported = unsafe {
            result::external_memory::import_external_memory_opaque_win32(
                file.as_raw_handle(),
                size,
            )
        };
        let handle = imported?;

        Ok(ExternalMemory {
            external_memory: handle,
            size,
            ctx: Arc::clone(self),
            // Keep the file alive without letting it close implicitly.
            _file: ManuallyDrop::new(file),
        })
    }
}
impl ExternalMemory {
    /// Maps the whole imported allocation, i.e. the range `0..size`.
    ///
    /// Consumes `self`; the returned [`MappedBuffer`] owns the external memory.
    ///
    /// # Errors
    ///
    /// See [`ExternalMemory::map_range`].
    pub fn map_all(self) -> Result<MappedBuffer, DriverError> {
        // NOTE: `as` would truncate on a target where `usize` is narrower than `u64`.
        let size = self.size as usize;
        self.map_range(0..size)
    }

    /// Maps `range` of the imported allocation as a device buffer.
    ///
    /// Consumes `self`; on success the returned [`MappedBuffer`] owns the
    /// external memory. On failure `self` is dropped, which destroys the
    /// imported external memory.
    ///
    /// # Panics
    ///
    /// Panics if `range.start` or `range.end` is greater than the size given
    /// at import time.
    ///
    /// # Errors
    ///
    /// Returns the [`DriverError`] from binding the context, creating the
    /// event, or mapping the buffer.
    pub fn map_range(self, range: Range<usize>) -> Result<MappedBuffer, DriverError> {
        assert!(
            range.start as u64 <= self.size,
            "range start {} exceeds external memory size {}",
            range.start,
            self.size
        );
        assert!(
            range.end as u64 <= self.size,
            "range end {} exceeds external memory size {}",
            range.end,
            self.size
        );

        // Bind the owning context before talking to the driver, as the other
        // driver-calling paths in this file do.
        self.ctx.bind_to_thread()?;

        // Acquire everything that can fail *before* mapping. If the event were
        // created after the mapping, an error here would drop the only copy of
        // the mapped device pointer without ever freeing it.
        let event = self.ctx.new_event(None)?;
        let stream = self.ctx.default_stream();

        let len = range.len();
        // SAFETY: `self` still owns `external_memory`, so the handle has not
        // been destroyed, and both range bounds were checked against `size`.
        let device_ptr = unsafe {
            result::external_memory::get_mapped_buffer(
                self.external_memory,
                range.start as u64,
                len as u64,
            )
        }?;

        Ok(MappedBuffer {
            device_ptr,
            len,
            external_memory: self,
            event,
            stream,
        })
    }
}
/// A device buffer mapped from a range of an [`ExternalMemory`].
///
/// Created by `ExternalMemory::map_all` / `ExternalMemory::map_range`. The
/// `Drop` impl frees `device_ptr`; the owned `external_memory` field is only
/// dropped afterwards, because fields are dropped after `Drop::drop` returns.
#[derive(Debug)]
pub struct MappedBuffer {
/// Device pointer returned by `get_mapped_buffer`.
device_ptr: sys::CUdeviceptr,
/// Length of the mapped range (`range.len()` at mapping time).
len: usize,
/// The external memory this buffer was mapped from; owned so it outlives the mapping.
external_memory: ExternalMemory,
/// Event handed out through `DevicePtr::device_ptr` and waited on before the buffer is freed.
event: CudaEvent,
/// The context's default stream, captured at mapping time.
stream: Arc<CudaStream>,
}
impl Drop for MappedBuffer {
    /// Frees the mapped device pointer.
    ///
    /// The owning context is bound to the thread, the buffer's stream is made
    /// to wait on the buffer's event, and then the pointer is freed. Each
    /// result is handed to the context's `record_err` rather than propagated.
    fn drop(&mut self) {
        let ctx = &self.external_memory.ctx;

        let bind_result = ctx.bind_to_thread();
        ctx.record_err(bind_result);

        // Order matters: wait on the event before releasing the memory.
        let wait_result = self.stream.wait(&self.event);
        ctx.record_err(wait_result);

        // SAFETY: `device_ptr` came from `get_mapped_buffer` and this is the
        // only place in this file that frees it.
        let free_result = unsafe { result::memory_free(self.device_ptr) };
        ctx.record_err(free_result)
    }
}
impl DeviceSlice<u8> for MappedBuffer {
    /// Returns the number of `u8` elements in the mapped range.
    fn len(&self) -> usize {
        let Self { len, .. } = self;
        *len
    }

    /// Returns the stream stored in this buffer at mapping time.
    fn stream(&self) -> &Arc<CudaStream> {
        let Self { stream, .. } = self;
        stream
    }
}
impl DevicePtr<u8> for MappedBuffer {
    /// Returns the raw device pointer together with a `SyncOnDrop::Record`
    /// guard built from this buffer's event and the given `stream`.
    ///
    /// NOTE(review): the guard presumably records the event on `stream` when it
    /// is dropped -- confirm against the `SyncOnDrop` definition.
    fn device_ptr<'a>(&'a self, stream: &'a CudaStream) -> (sys::CUdeviceptr, SyncOnDrop<'a>) {
        let guard = SyncOnDrop::Record(Some((&self.event, stream)));
        (self.device_ptr, guard)
    }
}