#[cfg(feature = "cuda")]
use cudarc::driver::CudaSlice;
/// Optional hook that a [`CudaBuffer`] calls when it is dropped.
///
/// The arguments are, in order, the buffer's device ordinal, its `len`, and
/// the owned `CudaSlice<T>`. `None` means no hook is installed.
#[cfg(feature = "cuda")]
type PoolReturnFn<T> = Option<fn(usize, usize, CudaSlice<T>)>;
/// Drop hook for `f32` buffers: forwards the slice to `crate::pool::pool_return`.
///
/// `device` and `len` are passed through unchanged together with the
/// constant `4`.
#[cfg(feature = "cuda")]
fn return_f32(device: usize, len: usize, slice: CudaSlice<f32>) {
    // NOTE(review): presumably the size of one `f32` in bytes — confirm
    // against the signature of `pool_return`.
    let elem_bytes = 4;
    crate::pool::pool_return::<CudaSlice<f32>>(device, len, elem_bytes, slice);
}
/// Drop hook for `f64` buffers: forwards the slice to `crate::pool::pool_return`.
///
/// `device` and `len` are passed through unchanged together with the
/// constant `8`.
#[cfg(feature = "cuda")]
fn return_f64(device: usize, len: usize, slice: CudaSlice<f64>) {
    // NOTE(review): presumably the size of one `f64` in bytes — confirm
    // against the signature of `pool_return`.
    let elem_bytes = 8;
    crate::pool::pool_return::<CudaSlice<f64>>(device, len, elem_bytes, slice);
}
/// Owning wrapper around a `CudaSlice<T>`, compiled only when the `cuda`
/// feature is enabled.
///
/// On drop, a slice that is still present is passed to `pool_fn` when one is
/// set; otherwise the slice is dropped normally.
#[cfg(feature = "cuda")]
pub struct CudaBuffer<T> {
/// The wrapped slice. `None` once it has been taken out, as `Drop` does.
pub(crate) data: Option<CudaSlice<T>>,
/// Length reported by `len()` and forwarded to `pool_fn` on drop.
pub(crate) len: usize,
/// Device ordinal reported by `device_ordinal()` and forwarded to `pool_fn` on drop.
pub(crate) device_ordinal: usize,
/// Optional hook invoked on drop with `(device_ordinal, len, slice)`.
pub(crate) pool_fn: PoolReturnFn<T>,
}
#[cfg(feature = "cuda")]
impl CudaBuffer<f32> {
    /// Wraps `slice` in a buffer whose drop hook is `return_f32`.
    ///
    /// `len` and `device` are stored as given and are later passed to the
    /// hook alongside the slice.
    pub(crate) fn new_pooled(slice: CudaSlice<f32>, len: usize, device: usize) -> Self {
        // The annotation coerces the function item to the fn-pointer type
        // stored in the struct.
        let pool_fn: PoolReturnFn<f32> = Some(return_f32);
        let data = Some(slice);
        let device_ordinal = device;
        Self {
            data,
            len,
            device_ordinal,
            pool_fn,
        }
    }
}
#[cfg(feature = "cuda")]
impl CudaBuffer<f64> {
    /// Wraps `slice` in a buffer whose drop hook is `return_f64`.
    ///
    /// `len` and `device` are stored as given and are later passed to the
    /// hook alongside the slice.
    pub(crate) fn new_pooled(slice: CudaSlice<f64>, len: usize, device: usize) -> Self {
        // The annotation coerces the function item to the fn-pointer type
        // stored in the struct.
        let pool_fn: PoolReturnFn<f64> = Some(return_f64);
        let data = Some(slice);
        let device_ordinal = device;
        Self {
            data,
            len,
            device_ordinal,
            pool_fn,
        }
    }
}
#[cfg(feature = "cuda")]
impl<T> Drop for CudaBuffer<T> {
    /// Takes the slice out of the buffer and, when a hook is installed,
    /// hands it to that hook together with the device ordinal and length.
    fn drop(&mut self) {
        let taken = self.data.take();
        // If either side is `None` the pattern does not match and any taken
        // slice is simply dropped here without calling a hook.
        if let (Some(slice), Some(hand_back)) = (taken, self.pool_fn) {
            hand_back(self.device_ordinal, self.len, slice);
        }
    }
}
#[cfg(feature = "cuda")]
impl<T> CudaBuffer<T> {
    /// Panic message shared by both slice accessors.
    const SLICE_TAKEN: &'static str = "CudaBuffer: inner slice already taken";

    /// Returns the stored `len` value.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the stored `len` is zero.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the stored device ordinal.
    #[inline]
    pub fn device_ordinal(&self) -> usize {
        self.device_ordinal
    }

    /// Borrows the wrapped slice.
    ///
    /// # Panics
    ///
    /// Panics if the slice has already been taken out of the buffer.
    #[inline]
    pub fn inner(&self) -> &CudaSlice<T> {
        let slot = self.data.as_ref();
        slot.expect(Self::SLICE_TAKEN)
    }

    /// Mutably borrows the wrapped slice.
    ///
    /// # Panics
    ///
    /// Panics if the slice has already been taken out of the buffer.
    #[inline]
    pub fn inner_mut(&mut self) -> &mut CudaSlice<T> {
        let slot = self.data.as_mut();
        slot.expect(Self::SLICE_TAKEN)
    }
}
#[cfg(feature = "cuda")]
impl<T> std::fmt::Debug for CudaBuffer<T> {
    /// Formats the buffer as `CudaBuffer { len, device_ordinal, pooled, .. }`.
    ///
    /// The slice itself is not printed; `pooled` reports whether a drop hook
    /// is installed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let pooled = self.pool_fn.is_some();
        let mut out = f.debug_struct("CudaBuffer");
        out.field("len", &self.len);
        out.field("device_ordinal", &self.device_ordinal);
        out.field("pooled", &pooled);
        out.finish_non_exhaustive()
    }
}
/// Stand-in for the GPU buffer type, compiled when the `cuda` feature is
/// disabled.
///
/// It holds no slice: only a marker for the element type `T` plus the two
/// values exposed by `len()` and `device_ordinal()`.
#[cfg(not(feature = "cuda"))]
pub struct CudaBuffer<T> {
    /// Zero-sized marker tying the buffer to its element type.
    pub(crate) _phantom: std::marker::PhantomData<T>,
    /// Value reported by `len()`.
    pub(crate) len: usize,
    /// Value reported by `device_ordinal()`.
    pub(crate) device_ordinal: usize,
}

// Implemented by hand rather than derived: `#[derive(Debug)]` would add an
// implicit `T: Debug` bound, whereas the `cuda` build implements `Debug` for
// every `T`. The output is the same as the derive would produce.
#[cfg(not(feature = "cuda"))]
impl<T> std::fmt::Debug for CudaBuffer<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CudaBuffer")
            .field("_phantom", &self._phantom)
            .field("len", &self.len)
            .field("device_ordinal", &self.device_ordinal)
            .finish()
    }
}

#[cfg(not(feature = "cuda"))]
impl<T> CudaBuffer<T> {
    /// Returns the stored `len` value.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the stored `len` is zero.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns the stored device ordinal.
    #[inline]
    pub fn device_ordinal(&self) -> usize {
        self.device_ordinal
    }
}