use crate::cuda::device::CudaDevice;
use crate::cuda::error::CudaResult;
use crate::cuda::memory::{MemoryAdvice, UnifiedAllocation};
use std::sync::Arc;
use torsh_core::DType;
/// Common read/write interface for typed device buffers.
///
/// `T` is the element type. Implementations expose host transfer
/// primitives plus `Any`-based upcasts so callers can downcast to the
/// concrete buffer type (see `BufferOpsTrait::copy_from_buffer`).
pub trait BufferTrait<T> {
    /// Number of `T` elements the buffer holds.
    fn len(&self) -> usize;
    /// `true` when the buffer holds zero elements.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Runtime element data-type tag associated with this buffer.
    fn dtype(&self) -> DType;
    /// Device that owns the underlying storage.
    fn device(&self) -> &dyn DeviceTrait;
    /// Copies `data` from host memory into the buffer.
    /// Implementations reject length mismatches with an error.
    fn copy_from_host(&mut self, data: &[T]) -> Result<(), crate::BackendError>;
    /// Copies the buffer contents into the host slice `data`.
    /// Implementations reject length mismatches with an error.
    fn copy_to_host(&self, data: &mut [T]) -> Result<(), crate::BackendError>;
    /// Upcast to `Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn std::any::Any;
    /// Mutable upcast to `Any`.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
}
/// Minimal device-identity trait, letting buffer APIs expose their owning
/// device without depending on a concrete device type.
pub trait DeviceTrait: std::fmt::Debug + Send + Sync {
    /// Numeric device ordinal (e.g. the CUDA device index).
    fn id(&self) -> usize;
}
impl DeviceTrait for CudaDevice {
    // Delegates to `CudaDevice`'s inherent `id` method. Rust method
    // resolution prefers inherent methods over trait methods, so the
    // `self.id()` call below is not recursive — this assumes `CudaDevice`
    // defines an inherent `id()` (TODO confirm; without one this call
    // would recurse infinitely).
    fn id(&self) -> usize {
        self.id()
    }
}
/// Bulk mutation operations on typed buffers.
pub trait BufferOpsTrait<T> {
    /// Sets every element of the buffer to `value`.
    fn fill(&mut self, value: T) -> Result<(), crate::BackendError>;
    /// Copies the contents of `src` into this buffer; lengths must match.
    fn copy_from_buffer(&mut self, src: &dyn BufferTrait<T>) -> Result<(), crate::BackendError>;
    /// Zeroes every byte of the buffer.
    fn set_zero(&mut self) -> Result<(), crate::BackendError>;
}
/// Snapshot of a `UnifiedBuffer`'s internals for diagnostics/logging.
///
/// Contains a raw pointer, which makes this type automatically
/// `!Send`/`!Sync`; it is meant for inspection only, not dereferencing.
#[derive(Debug, Clone)]
pub struct UnifiedBufferDebugInfo {
    /// Base address of the unified allocation.
    pub ptr: *mut u8,
    /// Element count of the buffer.
    pub length: usize,
    /// Total size in bytes (`length * size_of::<T>()`).
    pub size_bytes: usize,
    /// Runtime element data-type tag.
    pub dtype: DType,
    /// Ordinal of the owning CUDA device.
    pub device_id: usize,
}
/// A CUDA unified (managed) memory buffer holding `length` elements of `T`.
///
/// The backing `UnifiedAllocation` is accessible from both host and device;
/// the prefetch/advice methods tune page migration. The `dtype` field is a
/// runtime tag carried alongside the static element type `T` — it is not
/// validated against `T` at construction.
#[derive(Debug)]
pub struct UnifiedBuffer<T> {
    // Managed-memory allocation backing the buffer.
    allocation: UnifiedAllocation,
    // Number of `T` elements (not bytes).
    length: usize,
    // Runtime dtype tag reported through `BufferTrait::dtype`.
    dtype: DType,
    // Owning device; also supplies the memory manager used for
    // allocation, prefetch, and advice calls.
    device: Arc<CudaDevice>,
    // Marks logical ownership of `T` values without storing any inline.
    _phantom: std::marker::PhantomData<T>,
}
impl<T: Clone + Send + Sync + 'static> UnifiedBuffer<T> {
    /// Allocates a unified (managed) buffer of `length` elements on `device`.
    ///
    /// `dtype` is stored as-is; it is not checked against `size_of::<T>()`.
    ///
    /// NOTE(review): `length * size_of::<T>()` can overflow `usize` for very
    /// large lengths (panics in debug builds, wraps in release, which would
    /// under-allocate). Consider `checked_mul` if lengths are untrusted.
    pub fn new(device: Arc<CudaDevice>, length: usize, dtype: DType) -> CudaResult<Self> {
        let byte_size = length * std::mem::size_of::<T>();
        let allocation = device.memory_manager().allocate_unified(byte_size)?;
        Ok(Self {
            allocation,
            length,
            dtype,
            device,
            _phantom: std::marker::PhantomData,
        })
    }
    /// Shared access to the backing allocation.
    pub fn allocation(&self) -> &UnifiedAllocation {
        &self.allocation
    }
    /// Mutable access to the backing allocation.
    pub fn allocation_mut(&mut self) -> &mut UnifiedAllocation {
        &mut self.allocation
    }
    /// Raw const pointer to the first element.
    pub fn as_ptr(&self) -> *const T {
        self.allocation.as_ptr() as *const T
    }
    /// Raw mutable pointer to the first element.
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.allocation.as_ptr() as *mut T
    }
    /// Views the buffer as a host-side slice.
    ///
    /// # Safety
    /// The caller must ensure the memory currently holds `length` valid `T`
    /// values, that no device work is concurrently writing the buffer, and
    /// that no mutable access aliases the returned slice for its lifetime.
    pub unsafe fn as_slice(&self) -> &[T] {
        std::slice::from_raw_parts(self.as_ptr(), self.length)
    }
    /// Views the buffer as a mutable host-side slice.
    ///
    /// # Safety
    /// Same conditions as [`Self::as_slice`], plus the returned slice must be
    /// the only access (host or device) to the buffer for its lifetime.
    pub unsafe fn as_mut_slice(&mut self) -> &mut [T] {
        std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.length)
    }
    /// Prefetches the whole buffer toward device memory.
    ///
    /// NOTE(review): the `_device_id` argument is currently ignored — the
    /// underlying `prefetch_to_device` call takes only a pointer and size,
    /// so the target device is chosen by the memory manager. Confirm whether
    /// per-device targeting should be plumbed through.
    pub fn prefetch_to_device(&self, _device_id: Option<usize>) -> CudaResult<()> {
        let byte_size = self.length * std::mem::size_of::<T>();
        self.device
            .memory_manager()
            .prefetch_to_device(self.allocation.ptr(), byte_size)
    }
    /// Prefetches the whole buffer back to host memory.
    pub fn prefetch_to_host(&self) -> CudaResult<()> {
        let byte_size = self.length * std::mem::size_of::<T>();
        self.device
            .memory_manager()
            .prefetch_to_host(self.allocation.ptr(), byte_size)
    }
    /// Applies a unified-memory advice hint to the whole buffer.
    ///
    /// When `device_id` is `None`, device 0 is used as the advice target.
    pub fn set_memory_advice(
        &self,
        advice: MemoryAdvice,
        device_id: Option<usize>,
    ) -> CudaResult<()> {
        let byte_size = self.length * std::mem::size_of::<T>();
        // Default advice target is device 0 when none is given.
        let device = device_id.unwrap_or(0) as i32;
        self.device.memory_manager().set_memory_advice(
            self.allocation.ptr(),
            byte_size,
            advice,
            device,
        )
    }
    /// Convenience: marks the buffer as mostly read (duplicate-on-read hint).
    pub fn set_read_mostly(&self) -> CudaResult<()> {
        self.set_memory_advice(MemoryAdvice::SetReadMostly, None)
    }
    /// Convenience: sets the preferred physical location to `device_id`.
    pub fn set_preferred_location(&self, device_id: usize) -> CudaResult<()> {
        self.set_memory_advice(MemoryAdvice::SetPreferredLocation, Some(device_id))
    }
    /// Convenience: hints that `device_id` will access this buffer.
    pub fn set_accessed_by(&self, device_id: usize) -> CudaResult<()> {
        self.set_memory_advice(MemoryAdvice::SetAccessedBy, Some(device_id))
    }
    /// Collects a diagnostic snapshot of the buffer's internals.
    pub fn debug_info(&self) -> UnifiedBufferDebugInfo {
        UnifiedBufferDebugInfo {
            ptr: self.allocation.ptr(),
            length: self.length,
            size_bytes: self.length * std::mem::size_of::<T>(),
            dtype: self.dtype,
            device_id: self.device.id(),
        }
    }
}
impl<T: Clone + Send + Sync + 'static> BufferTrait<T> for UnifiedBuffer<T> {
    /// Element count of the buffer.
    fn len(&self) -> usize {
        self.length
    }

    /// Runtime dtype tag stored at construction.
    fn dtype(&self) -> DType {
        self.dtype
    }

    /// Owning device as a trait object.
    fn device(&self) -> &dyn DeviceTrait {
        &*self.device
    }

    /// Uploads `data` into the buffer after validating its length.
    fn copy_from_host(&mut self, data: &[T]) -> Result<(), crate::BackendError> {
        let expected = self.length;
        if data.len() != expected {
            return Err(crate::BackendError::InvalidBuffer {
                message: format!(
                    "Data length {} does not match buffer length {}",
                    data.len(),
                    expected
                ),
            });
        }
        match self.allocation.copy_from_host(data) {
            Ok(done) => Ok(done),
            Err(e) => Err(crate::BackendError::Runtime {
                message: format!("Failed to copy from host: {}", e),
            }),
        }
    }

    /// Downloads the buffer contents into `data` after validating its length.
    fn copy_to_host(&self, data: &mut [T]) -> Result<(), crate::BackendError> {
        let expected = self.length;
        if data.len() != expected {
            return Err(crate::BackendError::InvalidBuffer {
                message: format!(
                    "Data length {} does not match buffer length {}",
                    data.len(),
                    expected
                ),
            });
        }
        match self.allocation.copy_to_host(data) {
            Ok(done) => Ok(done),
            Err(e) => Err(crate::BackendError::Runtime {
                message: format!("Failed to copy to host: {}", e),
            }),
        }
    }

    /// Upcast for downcasting back to `UnifiedBuffer<T>`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Mutable upcast for downcasting back to `UnifiedBuffer<T>`.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}
impl<T: Clone + Send + Sync + 'static> BufferOpsTrait<T> for UnifiedBuffer<T> {
    /// Sets every element of the buffer to a clone of `value`.
    ///
    /// Uses `ptr::write` rather than slice assignment so the previous
    /// (possibly uninitialized or device-written) bytes are never dropped as
    /// `T` values — the old `as_mut_slice().fill(value)` path drop-assigned
    /// over arbitrary bytes, which is undefined behavior for `T: Drop`.
    fn fill(&mut self, value: T) -> Result<(), crate::BackendError> {
        let ptr = self.as_mut_ptr();
        for i in 0..self.length {
            // SAFETY: `ptr` points to an allocation of `self.length` `T`-sized
            // slots in unified (host-accessible) memory, so `ptr.add(i)` is in
            // bounds for every `i < self.length`; `write` overwrites without
            // reading or dropping the previous contents.
            unsafe {
                std::ptr::write(ptr.add(i), value.clone());
            }
        }
        Ok(())
    }

    /// Copies the contents of `src` into this buffer; lengths must match.
    fn copy_from_buffer(&mut self, src: &dyn BufferTrait<T>) -> Result<(), crate::BackendError> {
        if src.len() != self.length {
            return Err(crate::BackendError::InvalidBuffer {
                message: format!(
                    "Source buffer length {} does not match target length {}",
                    src.len(),
                    self.length
                ),
            });
        }
        // Fast path: same concrete buffer type — copy between the unified
        // allocations directly, without a host round-trip.
        if let Some(src_unified) = src.as_any().downcast_ref::<UnifiedBuffer<T>>() {
            // SAFETY: both buffers hold exactly `self.length` elements
            // (checked above), and the borrow rules at the call site keep
            // `src` and `self` distinct, so the ranges cannot overlap.
            unsafe {
                std::ptr::copy_nonoverlapping(src_unified.as_ptr(), self.as_mut_ptr(), self.length);
            }
            return Ok(());
        }
        // Slow path: stage through host memory. Use `MaybeUninit` so an early
        // error return never drops uninitialized `T` values — the previous
        // `Vec::<T>::with_capacity` + `set_len` approach was undefined
        // behavior on the `copy_to_host` error path for any `T` with a
        // destructor.
        let mut staging: Vec<std::mem::MaybeUninit<T>> = Vec::with_capacity(self.length);
        // SAFETY: `MaybeUninit<T>` requires no initialization, and dropping
        // the Vec never runs `T`'s destructor, so extending the length to
        // capacity is sound even before the elements are written.
        unsafe {
            staging.set_len(self.length);
        }
        // SAFETY: `MaybeUninit<T>` has the same size and alignment as `T`;
        // `copy_to_host` only writes into this slice, and on success every
        // element has been initialized.
        let host_view: &mut [T] =
            unsafe { std::slice::from_raw_parts_mut(staging.as_mut_ptr() as *mut T, self.length) };
        src.copy_to_host(host_view)?;
        // All elements are initialized now; upload them into this buffer.
        self.copy_from_host(host_view)
    }

    /// Zeroes every byte of the buffer.
    ///
    /// NOTE(review): an all-zero bit pattern is only a valid `T` for
    /// plain-old-data element types — confirm callers never use this buffer
    /// with pointer-carrying `T` (e.g. `String`).
    fn set_zero(&mut self) -> Result<(), crate::BackendError> {
        // SAFETY: the byte range `length * size_of::<T>()` lies entirely
        // within the unified allocation created for exactly that size.
        unsafe {
            std::ptr::write_bytes(
                self.as_mut_ptr() as *mut u8,
                0,
                self.length * std::mem::size_of::<T>(),
            );
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cuda::device::CudaDevice;
    use torsh_core::DType;

    // All tests are skipped gracefully when no CUDA runtime is available or
    // the device lacks managed-memory support, so they are safe on CI hosts
    // without GPUs.
    //
    // Fix: the `expect` messages previously read "Arc should succeed", which
    // was misleading — `Arc::new` is infallible; the fallible call is
    // `CudaDevice::new(0)`.

    #[test]
    fn test_unified_buffer_creation() {
        if crate::is_available() {
            let device = Arc::new(CudaDevice::new(0).expect("CUDA device 0 should be available"));
            if device
                .supports_feature(crate::cuda::device::CudaFeature::ManagedMemory)
                .unwrap_or(false)
            {
                let buffer = UnifiedBuffer::<f32>::new(device, 1024, DType::F32);
                assert!(buffer.is_ok());
                let buffer = buffer.expect("operation should succeed");
                assert_eq!(buffer.len(), 1024);
                assert_eq!(buffer.dtype(), DType::F32);
            }
        }
    }

    #[test]
    fn test_unified_buffer_operations() {
        if crate::is_available() {
            let device = Arc::new(CudaDevice::new(0).expect("CUDA device 0 should be available"));
            if device
                .supports_feature(crate::cuda::device::CudaFeature::ManagedMemory)
                .unwrap_or(false)
            {
                let mut buffer = UnifiedBuffer::<f32>::new(device, 4, DType::F32)
                    .expect("construction with valid parameters should succeed");
                // Round-trip host -> unified -> host must preserve the data.
                let test_data = vec![1.0, 2.0, 3.0, 4.0];
                buffer
                    .copy_from_host(&test_data)
                    .expect("copy from host memory should succeed");
                let mut result_data = vec![0.0; 4];
                buffer
                    .copy_to_host(&mut result_data)
                    .expect("copy to host memory should succeed");
                assert_eq!(result_data, test_data);
                // Prefetch and advice calls should all be accepted.
                buffer
                    .prefetch_to_device(None)
                    .expect("prefetch to device should succeed");
                buffer
                    .prefetch_to_host()
                    .expect("prefetch to host should succeed");
                buffer
                    .set_read_mostly()
                    .expect("read-mostly hint should be applied successfully");
                buffer
                    .set_preferred_location(0)
                    .expect("preferred location should be set successfully");
                buffer
                    .set_accessed_by(0)
                    .expect("accessed-by hint should be set successfully");
            }
        }
    }

    #[test]
    fn test_unified_buffer_slice_access() {
        if crate::is_available() {
            let device = Arc::new(CudaDevice::new(0).expect("CUDA device 0 should be available"));
            if device
                .supports_feature(crate::cuda::device::CudaFeature::ManagedMemory)
                .unwrap_or(false)
            {
                let mut buffer = UnifiedBuffer::<i32>::new(device, 8, DType::I32)
                    .expect("construction with valid parameters should succeed");
                let test_data = vec![1, 2, 3, 4, 5, 6, 7, 8];
                buffer
                    .copy_from_host(&test_data)
                    .expect("copy from host memory should succeed");
                // Bring pages to the host before reading through a slice.
                buffer
                    .prefetch_to_host()
                    .expect("prefetch to host should succeed");
                // SAFETY: the buffer was fully initialized by copy_from_host
                // and no device work is touching it concurrently.
                unsafe {
                    let slice = buffer.as_slice();
                    assert_eq!(slice.len(), 8);
                    assert_eq!(slice[0], 1);
                    assert_eq!(slice[7], 8);
                }
            }
        }
    }

    #[test]
    fn test_unified_buffer_fill_and_zero() {
        if crate::is_available() {
            let device = Arc::new(CudaDevice::new(0).expect("CUDA device 0 should be available"));
            if device
                .supports_feature(crate::cuda::device::CudaFeature::ManagedMemory)
                .unwrap_or(false)
            {
                let mut buffer = UnifiedBuffer::<f32>::new(device, 10, DType::F32)
                    .expect("construction with valid parameters should succeed");
                buffer.fill(3.14).expect("fill operation should succeed");
                let mut result = vec![0.0; 10];
                buffer
                    .copy_to_host(&mut result)
                    .expect("copy to host memory should succeed");
                // Exact equality is fine: the value round-trips bitwise
                // through memcpy, with no arithmetic involved.
                for &val in &result {
                    assert_eq!(val, 3.14);
                }
                buffer.set_zero().expect("zero-fill should succeed");
                buffer
                    .copy_to_host(&mut result)
                    .expect("copy to host memory should succeed");
                for &val in &result {
                    assert_eq!(val, 0.0);
                }
            }
        }
    }
}