use crate::cuda::cuda_sys_compat as cuda_sys;
use crate::cuda::device::CudaDevice;
use crate::cuda::error::{CudaError, CudaResult};
use crate::cuda::memory::{CudaAllocation, SendSyncPtr};
use crate::cuda::stream::CudaStream;
#[allow(unused_imports)]
use crate::{Buffer, BufferError};
use std::ffi::c_void;
use std::sync::Arc;
use torsh_core::DType;
/// Typed device pointer handed out by `CudaBuffer::device_ptr`; an alias for [`SendSyncPtr`].
pub type BufferDevicePtr<T> = SendSyncPtr<T>;
/// A device-memory buffer holding `length` elements of `T` on one CUDA device.
///
/// The allocation is obtained from (and returned to) the device's memory
/// manager; see the `Drop` impl for `CudaBuffer`.
// NOTE(review): `Clone` is derived while `Drop` deallocates `self.allocation`
// for every instance. If `CudaAllocation::clone` is a shallow handle copy, two
// clones will both deallocate the same device memory (double free). Confirm
// `CudaAllocation`'s clone semantics or replace the derive with a deep copy.
#[derive(Debug, Clone)]
pub struct CudaBuffer<T> {
    /// Backing device allocation; copies are sized from `length`, not from this.
    allocation: CudaAllocation,
    /// Number of `T` elements the buffer represents.
    length: usize,
    /// Element dtype tag; stored as metadata and not checked against `T`.
    dtype: DType,
    /// Device that owns `allocation`; its memory manager frees it on drop.
    device: Arc<CudaDevice>,
    /// Ties the buffer to element type `T` without storing a `T`.
    _phantom: std::marker::PhantomData<T>,
}
impl<T: Clone + Send + Sync + 'static> CudaBuffer<T> {
    /// Allocates a device buffer with room for `length` elements of `T`.
    ///
    /// The requested byte size is `length * size_of::<T>()`. `dtype` is stored
    /// as metadata only and is not checked against `T`.
    ///
    /// # Errors
    ///
    /// Returns [`CudaError::Memory`] if the byte size overflows `usize`, and
    /// propagates any error returned by the device memory manager.
    pub fn new(device: Arc<CudaDevice>, length: usize, dtype: DType) -> CudaResult<Self> {
        // Checked: an unchecked multiply wraps in release builds and would
        // under-allocate, letting later full-length copies run past the end.
        let size = Self::byte_size(length)?;
        let allocation = device.memory_manager().allocate(size)?;
        Ok(Self::from_allocation(device, allocation, length, dtype))
    }

    /// Wraps an existing allocation as a buffer of `length` elements.
    ///
    /// The buffer takes ownership: when dropped it hands `allocation` back to
    /// `device`'s memory manager.
    ///
    /// The caller must ensure `allocation` holds at least
    /// `length * size_of::<T>()` bytes. This is not checked here, and the copy
    /// methods size their transfers from `length` alone, so an undersized
    /// allocation leads to out-of-bounds device accesses.
    pub fn from_allocation(
        device: Arc<CudaDevice>,
        allocation: CudaAllocation,
        length: usize,
        dtype: DType,
    ) -> Self {
        Self {
            allocation,
            length,
            dtype,
            device,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Returns the start of the device allocation as a typed pointer.
    pub fn device_ptr(&self) -> BufferDevicePtr<T> {
        SendSyncPtr::new(self.allocation.as_ptr() as *mut T)
    }

    /// Returns the start of the device allocation as an untyped byte pointer.
    pub fn raw_ptr(&self) -> *mut u8 {
        self.allocation.as_ptr()
    }

    /// Synchronously copies `data` from host memory into this buffer.
    ///
    /// # Errors
    ///
    /// Returns [`CudaError::Memory`] if `data.len() != self.len()` or if
    /// `cudaMemcpy` reports a failure.
    pub fn copy_from_host(&mut self, data: &[T]) -> CudaResult<()> {
        self.check_len(data.len(), "Data")?;
        // SAFETY: `data` is a live slice of exactly `self.length` elements, and
        // the destination allocation holds `self.length` elements of `T` (sized
        // by `new`, or guaranteed by the `from_allocation` caller).
        let result = unsafe {
            cuda_sys::cudaMemcpy(
                self.allocation.as_ptr() as *mut c_void,
                data.as_ptr() as *const c_void,
                std::mem::size_of_val(data),
                cuda_sys::cudaMemcpyKind_cudaMemcpyHostToDevice,
            )
        };
        if result != crate::cuda::cudaSuccess {
            return Err(CudaError::Memory {
                message: format!("Host-to-device copy failed: {:?}", result),
            });
        }
        Ok(())
    }

    /// Enqueues a host-to-device copy of `data` into this buffer on `stream`.
    ///
    /// Per the CUDA runtime documentation, `cudaMemcpyAsync` may return before
    /// the transfer has finished reading `data`, notably when `data` is in
    /// page-locked memory. The Rust borrow of `data` ends when this function
    /// returns, so the caller must keep `data` alive and unmodified until
    /// `stream` has been synchronized.
    ///
    /// # Errors
    ///
    /// Returns [`CudaError::Memory`] if `data.len() != self.len()` or if the
    /// copy could not be enqueued.
    pub fn copy_from_host_async(&mut self, data: &[T], stream: &CudaStream) -> CudaResult<()> {
        self.check_len(data.len(), "Data")?;
        // SAFETY: pointer/size validity as in `copy_from_host`; the lifetime of
        // `data` beyond this call is the caller's responsibility (see docs).
        let result = unsafe {
            cuda_sys::cudaMemcpyAsync(
                self.allocation.as_ptr() as *mut c_void,
                data.as_ptr() as *const c_void,
                std::mem::size_of_val(data),
                cuda_sys::cudaMemcpyKind_cudaMemcpyHostToDevice,
                stream.stream(),
            )
        };
        if result != crate::cuda::cudaSuccess {
            return Err(CudaError::Memory {
                message: format!("Async host-to-device copy failed: {:?}", result),
            });
        }
        Ok(())
    }

    /// Synchronously copies this buffer's contents into the host slice `data`.
    ///
    /// # Errors
    ///
    /// Returns [`CudaError::Memory`] if `data.len() != self.len()` or if
    /// `cudaMemcpy` reports a failure.
    pub fn copy_to_host(&self, data: &mut [T]) -> CudaResult<()> {
        self.check_len(data.len(), "Data")?;
        // SAFETY: `data` is a live, exclusively borrowed slice of exactly
        // `self.length` elements; the source allocation holds as many.
        let result = unsafe {
            cuda_sys::cudaMemcpy(
                data.as_mut_ptr() as *mut c_void,
                self.allocation.as_ptr() as *const c_void,
                std::mem::size_of_val(data),
                cuda_sys::cudaMemcpyKind_cudaMemcpyDeviceToHost,
            )
        };
        if result != crate::cuda::cudaSuccess {
            return Err(CudaError::Memory {
                message: format!("Device-to-host copy failed: {:?}", result),
            });
        }
        Ok(())
    }

    /// Enqueues a device-to-host copy of this buffer into `data` on `stream`.
    ///
    /// Per the CUDA runtime documentation, `cudaMemcpyAsync` may return before
    /// the transfer has finished writing `data`, notably when `data` is in
    /// page-locked memory. The Rust borrow of `data` ends when this function
    /// returns, so the caller must keep `data` alive and must neither read nor
    /// write it until `stream` has been synchronized.
    ///
    /// # Errors
    ///
    /// Returns [`CudaError::Memory`] if `data.len() != self.len()` or if the
    /// copy could not be enqueued.
    pub fn copy_to_host_async(&self, data: &mut [T], stream: &CudaStream) -> CudaResult<()> {
        self.check_len(data.len(), "Data")?;
        // SAFETY: pointer/size validity as in `copy_to_host`; the lifetime of
        // `data` beyond this call is the caller's responsibility (see docs).
        let result = unsafe {
            cuda_sys::cudaMemcpyAsync(
                data.as_mut_ptr() as *mut c_void,
                self.allocation.as_ptr() as *const c_void,
                std::mem::size_of_val(data),
                cuda_sys::cudaMemcpyKind_cudaMemcpyDeviceToHost,
                stream.stream(),
            )
        };
        if result != crate::cuda::cudaSuccess {
            return Err(CudaError::Memory {
                message: format!("Async device-to-host copy failed: {:?}", result),
            });
        }
        Ok(())
    }

    /// Synchronously copies every element of `src` into this buffer.
    ///
    /// # Errors
    ///
    /// Returns [`CudaError::Memory`] if the two buffers differ in length, the
    /// byte size overflows, or `cudaMemcpy` reports a failure.
    pub fn copy_from_buffer(&mut self, src: &CudaBuffer<T>) -> CudaResult<()> {
        self.check_len(src.length, "Buffer")?;
        let bytes = Self::byte_size(self.length)?;
        // SAFETY: both allocations hold `self.length` elements of `T`, and the
        // `&mut self` / `&src` borrows guarantee they are distinct buffers.
        let result = unsafe {
            cuda_sys::cudaMemcpy(
                self.allocation.as_ptr() as *mut c_void,
                src.allocation.as_ptr() as *const c_void,
                bytes,
                cuda_sys::cudaMemcpyKind_cudaMemcpyDeviceToDevice,
            )
        };
        if result != crate::cuda::cudaSuccess {
            return Err(CudaError::Memory {
                message: format!("Device-to-device copy failed: {:?}", result),
            });
        }
        Ok(())
    }

    /// Alias for [`Self::copy_from_buffer`].
    pub fn copy_from(&mut self, src: &CudaBuffer<T>) -> CudaResult<()> {
        self.copy_from_buffer(src)
    }

    /// Enqueues a device-to-device copy of `src` into this buffer on `stream`.
    ///
    /// # Errors
    ///
    /// Returns [`CudaError::Memory`] if the two buffers differ in length, the
    /// byte size overflows, or the copy could not be enqueued.
    pub fn copy_from_buffer_async(
        &mut self,
        src: &CudaBuffer<T>,
        stream: &CudaStream,
    ) -> CudaResult<()> {
        self.check_len(src.length, "Buffer")?;
        let bytes = Self::byte_size(self.length)?;
        // SAFETY: as in `copy_from_buffer`; completion is ordered by `stream`.
        let result = unsafe {
            cuda_sys::cudaMemcpyAsync(
                self.allocation.as_ptr() as *mut c_void,
                src.allocation.as_ptr() as *const c_void,
                bytes,
                cuda_sys::cudaMemcpyKind_cudaMemcpyDeviceToDevice,
                stream.stream(),
            )
        };
        if result != crate::cuda::cudaSuccess {
            return Err(CudaError::Memory {
                message: format!("Async device-to-device copy failed: {:?}", result),
            });
        }
        Ok(())
    }

    /// Sets every element to `value`.
    ///
    /// Builds a temporary host vector of `len()` copies of `value` and uploads
    /// it with [`Self::copy_from_host`].
    pub fn fill(&mut self, value: T) -> CudaResult<()>
    where
        T: Copy,
    {
        let data = vec![value; self.length];
        self.copy_from_host(&data)
    }

    /// Size in bytes as reported by the underlying allocation.
    ///
    /// This may not equal `len() * size_of::<T>()` if the memory manager
    /// reports a different size (e.g. after rounding); not verified here.
    pub fn size_bytes(&self) -> usize {
        self.allocation.size()
    }

    /// Number of `T` elements in the buffer.
    pub fn len(&self) -> usize {
        self.length
    }

    /// Returns `true` when the buffer holds zero elements.
    pub fn is_empty(&self) -> bool {
        self.length == 0
    }

    /// Element dtype tag this buffer was created with.
    pub fn dtype(&self) -> DType {
        self.dtype
    }

    /// Device that owns this buffer's allocation.
    pub fn device(&self) -> &Arc<CudaDevice> {
        &self.device
    }

    /// Byte size of `length` elements of `T`, or an error if it overflows `usize`.
    fn byte_size(length: usize) -> CudaResult<usize> {
        let elem = std::mem::size_of::<T>();
        length.checked_mul(elem).ok_or_else(|| CudaError::Memory {
            message: format!(
                "Buffer size overflow: {} elements of {} bytes",
                length, elem
            ),
        })
    }

    /// Errors unless `actual` equals this buffer's element count.
    ///
    /// `what` prefixes the message ("Data" for host slices, "Buffer" for
    /// device buffers) so the text matches the per-method messages.
    fn check_len(&self, actual: usize, what: &str) -> CudaResult<()> {
        if actual == self.length {
            return Ok(());
        }
        Err(CudaError::Memory {
            message: format!(
                "{} length mismatch: expected {}, got {}",
                what, self.length, actual
            ),
        })
    }
}
impl<T> Drop for CudaBuffer<T> {
    /// Returns this buffer's allocation to the owning device's memory manager.
    ///
    /// `drop` cannot propagate errors, so a failed deallocation is only
    /// reported as a `tracing` warning.
    fn drop(&mut self) {
        let manager = self.device.memory_manager();
        let outcome = manager.deallocate(self.allocation.clone());
        match outcome {
            Ok(_) => {}
            Err(err) => tracing::warn!("Failed to deallocate CUDA buffer: {}", err),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::DType;

    /// Opens CUDA device 0, or returns `None` so a test can skip cleanly when
    /// `crate::is_available()` reports no CUDA support.
    fn device_or_skip() -> Option<Arc<CudaDevice>> {
        if !crate::is_available() {
            return None;
        }
        let device =
            CudaDevice::new(0).expect("CUDA device 0 should open when CUDA is available");
        Some(Arc::new(device))
    }

    #[test]
    fn test_cuda_buffer_creation() {
        let device = match device_or_skip() {
            Some(device) => device,
            None => return,
        };
        let buffer = CudaBuffer::<f32>::new(device, 1024, DType::F32);
        assert!(buffer.is_ok());
        let buffer = buffer.expect("allocating a 1024-element buffer should succeed");
        assert_eq!(buffer.len(), 1024);
        assert_eq!(buffer.dtype(), DType::F32);
    }

    #[test]
    fn test_host_device_copy() {
        let device = match device_or_skip() {
            Some(device) => device,
            None => return,
        };
        let mut buffer = CudaBuffer::<f32>::new(device, 4, DType::F32)
            .expect("construction with valid parameters should succeed");
        let host_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        buffer
            .copy_from_host(&host_data)
            .expect("copy from host memory should succeed");
        let mut result = vec![0.0f32; 4];
        buffer
            .copy_to_host(&mut result)
            .expect("copy to host memory should succeed");
        assert_eq!(host_data, result);
    }

    #[test]
    fn test_buffer_copy() {
        let device = match device_or_skip() {
            Some(device) => device,
            None => return,
        };
        let mut src = CudaBuffer::<f32>::new(Arc::clone(&device), 4, DType::F32)
            .expect("allocating the source buffer should succeed");
        let mut dst = CudaBuffer::<f32>::new(Arc::clone(&device), 4, DType::F32)
            .expect("allocating the destination buffer should succeed");
        let host_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        src.copy_from_host(&host_data)
            .expect("copy from host memory should succeed");
        dst.copy_from_buffer(&src)
            .expect("buffer copy should succeed");
        let mut result = vec![0.0f32; 4];
        dst.copy_to_host(&mut result)
            .expect("copy to host memory should succeed");
        assert_eq!(host_data, result);
    }

    #[test]
    fn test_host_length_mismatch_is_rejected() {
        let device = match device_or_skip() {
            Some(device) => device,
            None => return,
        };
        let mut buffer = CudaBuffer::<f32>::new(device, 4, DType::F32)
            .expect("construction with valid parameters should succeed");
        assert!(buffer.copy_from_host(&[1.0f32; 3]).is_err());
        let mut too_long = [0.0f32; 5];
        assert!(buffer.copy_to_host(&mut too_long).is_err());
    }

    #[test]
    fn test_buffer_length_mismatch_is_rejected() {
        let device = match device_or_skip() {
            Some(device) => device,
            None => return,
        };
        let src = CudaBuffer::<f32>::new(Arc::clone(&device), 4, DType::F32)
            .expect("allocating the source buffer should succeed");
        let mut dst = CudaBuffer::<f32>::new(device, 3, DType::F32)
            .expect("allocating the destination buffer should succeed");
        assert!(dst.copy_from_buffer(&src).is_err());
    }
}