use std::ffi::c_void;
use oxicuda_driver::error::{CudaError, CudaResult};
use oxicuda_driver::loader::try_driver;
use oxicuda_driver::stream::Stream;
use crate::device_buffer::DeviceBuffer;
use crate::host_buffer::PinnedBuffer;
/// Copies the contents of a host slice into a device buffer through the
/// driver's non-stream `cu_memcpy_htod_v2` entry point.
///
/// `src` and `dst` must contain the same number of elements; the number of
/// bytes transferred is `dst.byte_size()`.
///
/// # Errors
///
/// * [`CudaError::InvalidValue`] if `src.len() != dst.len()`.
/// * Any error returned by [`try_driver`] when the driver table is unavailable.
/// * Any error `oxicuda_driver::check` derives from the driver's status code.
pub fn copy_htod<T: Copy>(dst: &mut DeviceBuffer<T>, src: &[T]) -> CudaResult<()> {
    if dst.len() != src.len() {
        return Err(CudaError::InvalidValue);
    }

    let num_bytes = dst.byte_size();
    let driver = try_driver()?;
    let host_ptr: *const c_void = src.as_ptr().cast();

    // SAFETY: `host_ptr` comes from a live slice that has as many elements as
    // `dst` (checked above), so it should be readable for `num_bytes` bytes,
    // assuming `byte_size()` is the element count times `size_of::<T>()`.
    let status = unsafe { (driver.cu_memcpy_htod_v2)(dst.as_device_ptr(), host_ptr, num_bytes) };
    oxicuda_driver::check(status)
}
/// Copies the contents of a device buffer into a host slice through the
/// driver's non-stream `cu_memcpy_dtoh_v2` entry point.
///
/// `dst` and `src` must contain the same number of elements; the number of
/// bytes transferred is `src.byte_size()`.
///
/// # Errors
///
/// * [`CudaError::InvalidValue`] if `dst.len() != src.len()`.
/// * Any error returned by [`try_driver`] when the driver table is unavailable.
/// * Any error `oxicuda_driver::check` derives from the driver's status code.
pub fn copy_dtoh<T: Copy>(dst: &mut [T], src: &DeviceBuffer<T>) -> CudaResult<()> {
    if src.len() != dst.len() {
        return Err(CudaError::InvalidValue);
    }

    let num_bytes = src.byte_size();
    let driver = try_driver()?;
    let host_ptr: *mut c_void = dst.as_mut_ptr().cast();

    // SAFETY: `host_ptr` comes from an exclusively borrowed slice with as many
    // elements as `src` (checked above), so it should be writable for
    // `num_bytes` bytes, assuming `byte_size()` is the element count times
    // `size_of::<T>()`.
    let status = unsafe { (driver.cu_memcpy_dtoh_v2)(host_ptr, src.as_device_ptr(), num_bytes) };
    oxicuda_driver::check(status)
}
/// Copies one device buffer into another through the driver's non-stream
/// `cu_memcpy_dtod_v2` entry point.
///
/// Both buffers must contain the same number of elements; the number of bytes
/// transferred is `src.byte_size()`.
///
/// # Errors
///
/// * [`CudaError::InvalidValue`] if `dst.len() != src.len()`.
/// * Any error returned by [`try_driver`] when the driver table is unavailable.
/// * Any error `oxicuda_driver::check` derives from the driver's status code.
pub fn copy_dtod<T: Copy>(dst: &mut DeviceBuffer<T>, src: &DeviceBuffer<T>) -> CudaResult<()> {
    if src.len() != dst.len() {
        return Err(CudaError::InvalidValue);
    }

    let num_bytes = src.byte_size();
    let driver = try_driver()?;

    // SAFETY: both device pointers are taken from live buffers whose element
    // counts were just verified to match.
    oxicuda_driver::check(unsafe {
        (driver.cu_memcpy_dtod_v2)(dst.as_device_ptr(), src.as_device_ptr(), num_bytes)
    })
}
/// Enqueues a host-to-device copy from an ordinary host slice on `stream`
/// via the driver's `cu_memcpy_htod_async_v2` entry point.
///
/// `src` and `dst` must contain the same number of elements; the number of
/// bytes transferred is `dst.byte_size()`.
///
/// NOTE(review): the borrow of `src` ends when this function returns, but a
/// stream-ordered copy may still be reading from it afterwards. Callers should
/// presumably keep `src` alive and unmodified until `stream` has been
/// synchronized — confirm against the CUDA driver documentation.
///
/// # Errors
///
/// * [`CudaError::InvalidValue`] if `src.len() != dst.len()`.
/// * Any error returned by [`try_driver`] when the driver table is unavailable.
/// * Any error `oxicuda_driver::check` derives from the driver's status code.
pub fn copy_htod_async_raw<T: Copy>(
    dst: &mut DeviceBuffer<T>,
    src: &[T],
    stream: &Stream,
) -> CudaResult<()> {
    if dst.len() != src.len() {
        return Err(CudaError::InvalidValue);
    }

    let num_bytes = dst.byte_size();
    let driver = try_driver()?;
    let host_ptr: *const c_void = src.as_ptr().cast();

    // SAFETY: `host_ptr` is derived from a live slice with as many elements as
    // `dst` (checked above). Keeping that memory valid until the stream has
    // finished with it is left to the caller.
    let status = unsafe {
        (driver.cu_memcpy_htod_async_v2)(dst.as_device_ptr(), host_ptr, num_bytes, stream.raw())
    };
    oxicuda_driver::check(status)
}
/// Enqueues a device-to-host copy into an ordinary host slice on `stream`
/// via the driver's `cu_memcpy_dtoh_async_v2` entry point.
///
/// `dst` and `src` must contain the same number of elements; the number of
/// bytes transferred is `src.byte_size()`.
///
/// NOTE(review): the exclusive borrow of `dst` ends when this function
/// returns, but a stream-ordered copy may still be writing to it afterwards.
/// Callers should presumably not read or drop `dst` until `stream` has been
/// synchronized — confirm against the CUDA driver documentation.
///
/// # Errors
///
/// * [`CudaError::InvalidValue`] if `dst.len() != src.len()`.
/// * Any error returned by [`try_driver`] when the driver table is unavailable.
/// * Any error `oxicuda_driver::check` derives from the driver's status code.
pub fn copy_dtoh_async_raw<T: Copy>(
    dst: &mut [T],
    src: &DeviceBuffer<T>,
    stream: &Stream,
) -> CudaResult<()> {
    if src.len() != dst.len() {
        return Err(CudaError::InvalidValue);
    }

    let num_bytes = src.byte_size();
    let driver = try_driver()?;
    let host_ptr: *mut c_void = dst.as_mut_ptr().cast();

    // SAFETY: `host_ptr` is derived from an exclusively borrowed slice with as
    // many elements as `src` (checked above). Keeping that memory valid until
    // the stream has finished with it is left to the caller.
    let status = unsafe {
        (driver.cu_memcpy_dtoh_async_v2)(host_ptr, src.as_device_ptr(), num_bytes, stream.raw())
    };
    oxicuda_driver::check(status)
}
pub fn copy_dtod_async<T: Copy>(
dst: &mut DeviceBuffer<T>,
src: &DeviceBuffer<T>,
stream: &Stream,
) -> CudaResult<()> {
if dst.len() != src.len() {
return Err(CudaError::InvalidValue);
}
let _ = stream;
copy_dtod(dst, src)
}
/// Enqueues a host-to-device copy from a [`PinnedBuffer`] on `stream` via the
/// driver's `cu_memcpy_htod_async_v2` entry point.
///
/// `src` and `dst` must contain the same number of elements; the number of
/// bytes transferred is `dst.byte_size()`.
///
/// NOTE(review): the copy is only submitted here; callers should presumably
/// keep both buffers alive until `stream` has been synchronized — confirm.
///
/// # Errors
///
/// * [`CudaError::InvalidValue`] if `src.len() != dst.len()`.
/// * Any error returned by [`try_driver`] when the driver table is unavailable.
/// * Any error `oxicuda_driver::check` derives from the driver's status code.
pub fn copy_htod_async<T: Copy>(
    dst: &mut DeviceBuffer<T>,
    src: &PinnedBuffer<T>,
    stream: &Stream,
) -> CudaResult<()> {
    if dst.len() != src.len() {
        return Err(CudaError::InvalidValue);
    }

    let num_bytes = dst.byte_size();
    let driver = try_driver()?;

    // SAFETY: the host pointer comes from a live pinned buffer with as many
    // elements as `dst` (checked above); the stream handle belongs to a
    // stream the caller holds a reference to.
    oxicuda_driver::check(unsafe {
        (driver.cu_memcpy_htod_async_v2)(
            dst.as_device_ptr(),
            src.as_ptr().cast::<c_void>(),
            num_bytes,
            stream.raw(),
        )
    })
}
/// Enqueues a device-to-host copy into a [`PinnedBuffer`] on `stream` via the
/// driver's `cu_memcpy_dtoh_async_v2` entry point.
///
/// `dst` and `src` must contain the same number of elements; the number of
/// bytes transferred is `src.byte_size()`.
///
/// NOTE(review): the copy is only submitted here; callers should presumably
/// synchronize `stream` before reading `dst` — confirm.
///
/// # Errors
///
/// * [`CudaError::InvalidValue`] if `dst.len() != src.len()`.
/// * Any error returned by [`try_driver`] when the driver table is unavailable.
/// * Any error `oxicuda_driver::check` derives from the driver's status code.
pub fn copy_dtoh_async<T: Copy>(
    dst: &mut PinnedBuffer<T>,
    src: &DeviceBuffer<T>,
    stream: &Stream,
) -> CudaResult<()> {
    if src.len() != dst.len() {
        return Err(CudaError::InvalidValue);
    }

    let num_bytes = src.byte_size();
    let driver = try_driver()?;

    // SAFETY: the host pointer comes from an exclusively borrowed pinned
    // buffer with as many elements as `src` (checked above); the stream handle
    // belongs to a stream the caller holds a reference to.
    oxicuda_driver::check(unsafe {
        (driver.cu_memcpy_dtoh_async_v2)(
            dst.as_mut_ptr().cast::<c_void>(),
            src.as_device_ptr(),
            num_bytes,
            stream.raw(),
        )
    })
}
// Compile-time signature checks: each test coerces a public copy function to
// an explicit `fn` pointer type, so an accidental change to a parameter or
// return type fails the build. No driver call is made.
#[cfg(test)]
mod tests {
    /// The blocking host<->device copies keep their slice/buffer signatures.
    #[test]
    fn copy_htod_signature_compiles() {
        let _f: fn(&mut super::DeviceBuffer<f32>, &[f32]) -> super::CudaResult<()> =
            super::copy_htod;
        let _f2: fn(&mut [f32], &super::DeviceBuffer<f32>) -> super::CudaResult<()> =
            super::copy_dtoh;
    }

    /// The blocking device->device copy keeps its signature.
    #[test]
    fn copy_dtod_signature_compiles() {
        let _f: fn(
            &mut super::DeviceBuffer<f32>,
            &super::DeviceBuffer<f32>,
        ) -> super::CudaResult<()> = super::copy_dtod;
    }

    /// The slice-based async host->device copy keeps its signature.
    #[test]
    fn async_raw_htod_signature_compiles() {
        let _f: fn(
            &mut super::DeviceBuffer<f32>,
            &[f32],
            &oxicuda_driver::stream::Stream,
        ) -> super::CudaResult<()> = super::copy_htod_async_raw;
    }

    /// The slice-based async device->host copy keeps its signature.
    #[test]
    fn async_raw_dtoh_signature_compiles() {
        let _f: fn(
            &mut [f32],
            &super::DeviceBuffer<f32>,
            &oxicuda_driver::stream::Stream,
        ) -> super::CudaResult<()> = super::copy_dtoh_async_raw;
    }

    /// The async device->device copy keeps its signature.
    #[test]
    fn async_dtod_signature_compiles() {
        let _f: fn(
            &mut super::DeviceBuffer<f32>,
            &super::DeviceBuffer<f32>,
            &oxicuda_driver::stream::Stream,
        ) -> super::CudaResult<()> = super::copy_dtod_async;
    }

    /// The pinned-buffer async host->device copy keeps its signature.
    #[test]
    fn async_pinned_htod_signature_compiles() {
        let _f: fn(
            &mut super::DeviceBuffer<f32>,
            &super::PinnedBuffer<f32>,
            &oxicuda_driver::stream::Stream,
        ) -> super::CudaResult<()> = super::copy_htod_async;
    }

    /// The pinned-buffer async device->host copy keeps its signature.
    #[test]
    fn async_pinned_dtoh_signature_compiles() {
        let _f: fn(
            &mut super::PinnedBuffer<f32>,
            &super::DeviceBuffer<f32>,
            &oxicuda_driver::stream::Stream,
        ) -> super::CudaResult<()> = super::copy_dtoh_async;
    }
}