use std::sync::Arc;
use cudarc::driver::sys as driver_sys;
use cudarc::driver::CudaStream;
use crate::error::GpuError;
use crate::sys::cuda_driver;
use super::managed::PrefetchTarget;
/// Enqueues an asynchronous prefetch of `bytes` bytes starting at `dev_ptr`
/// toward `target`, ordered on `stream`.
///
/// The destination is described to the driver as a `CUmemLocation`:
/// * `PrefetchTarget::Device(ordinal)` becomes a `DEVICE` location whose `id`
///   is that ordinal, so the prefetch goes to the requested device rather
///   than always to device 0.
/// * `PrefetchTarget::Cpu` becomes a `HOST` location; `id` is left at zero.
///
/// The call is forwarded to `cuda_driver::mem_prefetch_async_v2` with a
/// literal `0` in the fourth position (presumably the driver's `flags`
/// argument — confirm against the wrapper's signature).
///
/// # Errors
///
/// Returns whatever `GpuError` `cuda_driver::mem_prefetch_async_v2` reports.
pub fn prefetch_async(
    dev_ptr: driver_sys::CUdeviceptr,
    bytes: usize,
    target: PrefetchTarget,
    stream: &Arc<CudaStream>,
) -> Result<(), GpuError> {
    // SAFETY: `CUmemLocation` is a plain C struct from the driver bindings
    // (a location-type enum plus an integer id). All-zero bytes are assumed to
    // be a valid value for it (zero is the driver's `INVALID` location type),
    // and `type_` is overwritten below before the value is used.
    let mut location: driver_sys::CUmemLocation = unsafe { std::mem::zeroed() };

    match target {
        PrefetchTarget::Device(ordinal) => {
            location.type_ = driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE;
            // NOTE(review): `PrefetchTarget` is defined outside this file, so the
            // payload's integer width is not visible here; `try_from` accepts any
            // primitive integer. If the payload is already `i32`, simplify this to
            // a direct assignment. An ordinal that does not fit in `i32` is mapped
            // to `i32::MAX`, which cannot name a real device, so the driver is
            // expected to reject it instead of the value silently wrapping onto a
            // valid device.
            location.id = i32::try_from(ordinal).unwrap_or(i32::MAX);
        }
        PrefetchTarget::Cpu => {
            // `id` stays at zero for a host location.
            location.type_ = driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_HOST;
        }
    }

    cuda_driver::mem_prefetch_async_v2(dev_ptr, bytes, location, 0, stream.cu_stream())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Calls the raw `mem_prefetch_async_v2` wrapper with a zero address, zero
    /// length, a host location and a null stream handle, and checks that the
    /// result is one of the outcomes this crate expects: success, or a
    /// `GpuError::Unrecoverable` / `GpuError::LibraryError`.
    ///
    /// Going by the test name, the point is presumably that a machine without
    /// a usable CUDA driver gets a typed error rather than an unexpected
    /// variant.
    #[test]
    fn prefetch_async_returns_typed_error_on_no_driver() {
        // SAFETY: `CUmemLocation` is a plain C struct from the driver bindings;
        // all-zero bytes are assumed to be a valid value for it, and `type_`
        // is set on the next line.
        let mut host_location: driver_sys::CUmemLocation = unsafe { std::mem::zeroed() };
        host_location.type_ = driver_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_HOST;

        let outcome =
            cuda_driver::mem_prefetch_async_v2(0, 0, host_location, 0, std::ptr::null_mut());

        // None of these patterns bind anything, so `outcome` is still available
        // for the failure message.
        let acceptable = matches!(
            outcome,
            Ok(()) | Err(GpuError::Unrecoverable(_)) | Err(GpuError::LibraryError { .. })
        );
        assert!(acceptable, "unexpected: {outcome:?}");
    }
}