use std::ffi::c_void;
use std::marker::PhantomData;
use std::sync::Arc;
use gpufft_cuda_sys as sys;
use super::device::CudaContext;
use super::error::{CudaError, check_cuda};
use crate::backend::BufferOps;
use crate::scalar::Scalar;
/// Handle to a CUDA device allocation holding `len` elements of `T`.
///
/// The pointer is obtained from `cudaMalloc` in `new` and passed to `cudaFree`
/// when the value is dropped.
pub struct CudaBuffer<T: Scalar> {
/// Context that is made current before each CUDA call issued by this buffer.
pub(crate) ctx: Arc<CudaContext>,
/// Raw device pointer written by `cudaMalloc`; not a host-dereferenceable
/// address as far as this file is concerned.
pub(crate) d_ptr: *mut c_void,
/// Number of `T` elements requested at construction.
pub(crate) len: usize,
/// Allocation size in bytes, computed in `new` as `len * T::BYTES`.
pub(crate) size_bytes: u64,
/// Ties the buffer to its element type without storing a `T` on the host.
_marker: PhantomData<T>,
}
// SAFETY (review note): the raw `d_ptr` field suppresses the automatic
// `Send`/`Sync` impls, so they are asserted manually here. This presumably
// relies on the device pointer being usable from any host thread once the
// owning context has been made current, and on `CudaContext` itself being
// safe to share across threads — TODO confirm against `CudaContext`.
// Note that `read` takes `&self`, so `Sync` permits concurrent reads from
// several threads.
unsafe impl<T: Scalar> Send for CudaBuffer<T> {}
unsafe impl<T: Scalar> Sync for CudaBuffer<T> {}
impl<T: Scalar> CudaBuffer<T> {
    /// Allocates a device buffer with room for `len` elements of `T`.
    ///
    /// `ctx` is made current on the calling thread before the allocation and
    /// is then stored in the returned buffer. The contents of the allocation
    /// are not initialised by this function.
    ///
    /// # Errors
    ///
    /// Returns the error from `make_current`, or the `cudaMalloc` failure as
    /// reported by `check_cuda`. A `len` whose byte size does not fit in
    /// `usize` is surfaced as an allocation failure instead of wrapping
    /// around to a smaller size.
    pub(crate) fn new(ctx: Arc<CudaContext>, len: usize) -> Result<Self, CudaError> {
        ctx.make_current()?;

        // A plain `len * T::BYTES` wraps silently in release builds, which
        // would yield an allocation smaller than `len` elements while `len`
        // still advertises the requested count. Saturating instead forwards
        // an impossible request to `cudaMalloc`, which rejects it through
        // the normal error path.
        let alloc_bytes = len.saturating_mul(T::BYTES);

        let mut d_ptr: *mut c_void = std::ptr::null_mut();
        // SAFETY: `d_ptr` is a live local, so the out-pointer handed to
        // `cudaMalloc` is valid for a single pointer-sized write.
        let status = unsafe { sys::cudaMalloc(&mut d_ptr, alloc_bytes) };
        check_cuda("cudaMalloc", status)?;

        Ok(Self {
            ctx,
            d_ptr,
            len,
            // Lossless: `usize` is never wider than 64 bits on supported targets.
            size_bytes: alloc_bytes as u64,
            _marker: PhantomData,
        })
    }

    /// Returns the raw device pointer backing this buffer.
    ///
    /// The buffer keeps ownership; the pointer is passed to `cudaFree` when
    /// the buffer is dropped.
    pub fn device_ptr(&self) -> *mut c_void {
        self.d_ptr
    }

    /// Returns the size of the device allocation in bytes.
    pub fn size_bytes(&self) -> u64 {
        self.size_bytes
    }
}
impl<T: Scalar> BufferOps<super::CudaBackend, T> for CudaBuffer<T> {
    /// Number of `T` elements the buffer was created with.
    fn len(&self) -> usize {
        self.len
    }

    /// Uploads `src` from host memory into the device allocation.
    ///
    /// # Errors
    ///
    /// Returns `CudaError::LengthMismatch` when `src` does not contain
    /// exactly `len()` elements; otherwise propagates any failure from
    /// `make_current` or from `cudaMemcpy` as reported by `check_cuda`.
    fn write(&mut self, src: &[T]) -> Result<(), CudaError> {
        let (expected, got) = (self.len, src.len());
        if got != expected {
            return Err(CudaError::LengthMismatch { expected, got });
        }
        self.ctx.make_current()?;

        let host_src = src.as_ptr().cast::<c_void>();
        let byte_count = self.size_bytes as usize;
        // SAFETY: `src` holds exactly `self.len` elements (checked above) and
        // `d_ptr`/`size_bytes` are the values recorded by `new`. This assumes
        // `T::BYTES` equals the in-memory size of `T`, so that `byte_count`
        // bytes are readable from `src` — TODO confirm against `Scalar`.
        let status = unsafe {
            sys::cudaMemcpy(
                self.d_ptr,
                host_src,
                byte_count,
                sys::cudaMemcpyKind_cudaMemcpyHostToDevice,
            )
        };
        check_cuda("cudaMemcpy(host-to-device)", status)?;
        Ok(())
    }

    /// Downloads the device allocation into `dst` in host memory.
    ///
    /// # Errors
    ///
    /// Returns `CudaError::LengthMismatch` when `dst` does not contain
    /// exactly `len()` elements; otherwise propagates any failure from
    /// `make_current` or from `cudaMemcpy` as reported by `check_cuda`.
    fn read(&self, dst: &mut [T]) -> Result<(), CudaError> {
        let (expected, got) = (self.len, dst.len());
        if got != expected {
            return Err(CudaError::LengthMismatch { expected, got });
        }
        self.ctx.make_current()?;

        let host_dst = dst.as_mut_ptr().cast::<c_void>();
        let byte_count = self.size_bytes as usize;
        // SAFETY: `dst` holds exactly `self.len` elements (checked above) and
        // `d_ptr`/`size_bytes` are the values recorded by `new`. This assumes
        // `T::BYTES` equals the in-memory size of `T`, so that `byte_count`
        // bytes are writable in `dst` — TODO confirm against `Scalar`.
        let status = unsafe {
            sys::cudaMemcpy(
                host_dst,
                self.d_ptr,
                byte_count,
                sys::cudaMemcpyKind_cudaMemcpyDeviceToHost,
            )
        };
        check_cuda("cudaMemcpy(device-to-host)", status)?;
        Ok(())
    }
}
impl<T: Scalar> Drop for CudaBuffer<T> {
    /// Hands the device pointer back to `cudaFree`.
    ///
    /// Both the context switch and the free are best effort: `drop` has no
    /// way to report failure, so their results are discarded.
    fn drop(&mut self) {
        let _ = self.ctx.make_current();
        // SAFETY: `d_ptr` is the pointer stored by `new` after `cudaMalloc`;
        // nothing in this file frees it anywhere else.
        let _ = unsafe { sys::cudaFree(self.d_ptr) };
    }
}