use crate::autograd::GradFn;
use crate::buffer_pool::{pool_acquire, pool_release};
use crate::cuda_sys::{self, cudaMemcpyKind, CUDA_SUCCESS};
use crate::device::{get_device, CudaDevice};
use crate::{shape_to_bytes, DType};
use std::cell::UnsafeCell;
use std::ffi::c_void;
use std::sync::Arc;
use tl_backend::BackendResult;
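/// Shared, reference-counted handle to a `CudaTensor` with interior mutability,
/// used by the autograd machinery to keep graph inputs alive.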
pub type TensorRef = Arc<UnsafeCell<CudaTensor>>;
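/// Rebuilds a `TensorRef` from a raw pointer without consuming the caller's reference:
/// the reconstructed `Arc` is cloned and the temporary is forgotten so the strong count
/// is only incremented.
///
/// # Safety
/// `ptr` must point into an allocation managed by an `Arc<UnsafeCell<CudaTensor>>`
/// (e.g. obtained via `Arc::into_raw`) that is still alive.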
pub unsafe fn tensor_ref_from_ptr(ptr: *mut CudaTensor) -> TensorRef {
let arc: TensorRef = Arc::from_raw(ptr as *const UnsafeCell<CudaTensor>);
    let cloned = arc.clone();
    std::mem::forget(arc);
    cloned
}
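/// Immutably borrows the tensor behind the cell.
///
/// # Safety
/// No mutable reference to the same tensor may be live while the returned reference is in use.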
#[inline]
pub unsafe fn tensor_ref_get(r: &TensorRef) -> &CudaTensor {
&*r.get()
}
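/// Mutably borrows the tensor behind the cell.
///
/// # Safety
/// The caller must have exclusive access: no other reference to the same tensor may be
/// live while the returned mutable reference is in use.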
#[inline]
pub unsafe fn tensor_ref_get_mut(r: &TensorRef) -> &mut CudaTensor {
&mut *r.get()
}
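/// Owning handle to a raw CUDA device allocation: a device pointer plus its size in bytes.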
pub struct CudaBuffer {
ptr: *mut c_void,
size: usize,
}
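// SAFETY: the wrapped device pointer is never dereferenced on the host; it is only
// handed to CUDA runtime calls (cudaMemcpy/cudaMemset/cudaFree).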
unsafe impl Send for CudaBuffer {}
unsafe impl Sync for CudaBuffer {}
impl CudaBuffer {
pub fn new(size: usize) -> Result<Self, String> {
let device = get_device();
let ptr = device.allocate_buffer(size)?;
Ok(CudaBuffer { ptr, size })
}
pub fn ptr(&self) -> *mut c_void {
self.ptr
}
pub fn size(&self) -> usize {
self.size
}
}
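/// Frees the device allocation. The `cudaFree` return code is ignored because `Drop`
/// cannot propagate errors.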
impl Drop for CudaBuffer {
fn drop(&mut self) {
if !self.ptr.is_null() {
unsafe {
cuda_sys::cudaFree(self.ptr);
}
}
}
}
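/// Per-tensor autograd state: the accumulated gradient, the backward function that
/// produced this tensor, and the requires-grad flag.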
pub struct AutogradMeta {
pub grad: Option<CudaTensor>,
pub grad_fn: Option<Box<dyn GradFn>>,
pub requires_grad: bool,
}
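/// Dense tensor stored in CUDA device memory. The buffer is reference-counted so that
/// views and shallow clones can share storage without copying.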
pub struct CudaTensor {
pub(crate) buffer: Arc<CudaBuffer>,
pub shape: Vec<usize>,
pub dtype: DType,
#[allow(dead_code)]
device: Arc<CudaDevice>,
pub autograd: Option<Box<AutogradMeta>>,
}
impl std::fmt::Debug for CudaTensor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CudaTensor")
.field("shape", &self.shape)
.field("dtype", &self.dtype)
.field("buffer_ptr", &self.buffer.ptr())
.finish()
}
}
impl Clone for CudaTensor {
fn clone(&self) -> Self {
self.clone_data().expect("Clone failed")
}
}
unsafe impl Send for CudaTensor {}
unsafe impl Sync for CudaTensor {}
impl CudaTensor {
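    /// Allocates a tensor without initializing its contents, reusing a pooled buffer of
    /// the right size when one is available.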
pub fn uninit(shape: &[usize], dtype: DType) -> Self {
let device = get_device();
let size = shape_to_bytes(shape, dtype);
let buffer = pool_acquire(size).unwrap_or_else(|| {
Arc::new(CudaBuffer::new(size).expect("CUDA buffer allocation failed"))
});
CudaTensor {
buffer,
shape: shape.to_vec(),
dtype,
device,
autograd: None,
}
}
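    /// Returns a view over the same buffer with a different shape; no data is copied and
    /// no autograd state is carried over.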
pub(crate) fn view_with_shape(&self, new_shape: &[usize]) -> Self {
CudaTensor {
            buffer: self.buffer.clone(),
            shape: new_shape.to_vec(),
dtype: self.dtype,
device: self.device.clone(),
autograd: None,
}
}
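    /// Allocates a tensor and zero-fills it with `cudaMemset`.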
pub fn zeros(shape: &[usize], dtype: DType) -> Self {
let tensor = Self::uninit(shape, dtype);
let size = shape_to_bytes(shape, dtype);
if size > 0 {
let err = unsafe { cuda_sys::cudaMemset(tensor.buffer.ptr(), 0, size) };
if err != CUDA_SUCCESS {
eprintln!("cudaMemset failed in zeros(): {}", err);
}
}
tensor
}
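    /// Uploads host `data` into a freshly allocated device tensor; panics if the
    /// host-to-device copy fails.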
pub fn from_slice<T: Copy>(data: &[T], shape: &[usize], dtype: DType) -> Self {
let tensor = Self::uninit(shape, dtype);
        let byte_size = data.len() * std::mem::size_of::<T>();
        // Guard against overrunning the device allocation if `data` is larger than `shape` implies.
        debug_assert!(
            byte_size <= tensor.buffer.size(),
            "from_slice: host data larger than device buffer"
        );
if byte_size > 0 {
let err = unsafe {
cuda_sys::cudaMemcpy(
tensor.buffer.ptr(),
data.as_ptr() as *const c_void,
byte_size,
cudaMemcpyKind::cudaMemcpyHostToDevice,
)
};
if err != CUDA_SUCCESS {
panic!("cudaMemcpy HostToDevice failed in from_slice(): {}", err);
}
}
tensor
}
pub fn shape(&self) -> &[usize] {
&self.shape
}
pub fn dtype(&self) -> DType {
self.dtype
}
pub fn elem_count(&self) -> usize {
self.shape.iter().product()
}
pub fn buffer_ptr(&self) -> *mut c_void {
self.buffer.ptr()
}
pub fn buffer_arc(&self) -> &Arc<CudaBuffer> {
&self.buffer
}
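    /// Fills a tensor with ones by building the values on the host and uploading them;
    /// only `F32`, `I64`, and `I32` are supported.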
pub fn ones(shape: &[usize], dtype: DType) -> Self {
let count = shape.iter().product::<usize>();
match dtype {
DType::F32 => {
let data = vec![1.0f32; count];
Self::from_slice(&data, shape, dtype)
}
DType::I64 => {
let data = vec![1i64; count];
Self::from_slice(&data, shape, dtype)
}
DType::I32 => {
let data = vec![1i32; count];
Self::from_slice(&data, shape, dtype)
}
_ => unimplemented!("ones for {:?}", dtype),
}
}
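    /// Draws standard-normal samples on the host with the Box-Muller transform and
    /// uploads them; only `F32` is supported.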
pub fn randn(shape: &[usize], dtype: DType) -> Self {
use rand::Rng;
let count = shape.iter().product::<usize>();
match dtype {
DType::F32 => {
let mut rng = rand::thread_rng();
let data: Vec<f32> = (0..count)
.map(|_| {
                    // Map into (0, 1] so ln(u1) stays finite; rng.gen() can return exactly 0.0.
                    let u1: f32 = 1.0 - rng.gen::<f32>();
let u2: f32 = rng.gen();
(-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos()
})
.collect();
Self::from_slice(&data, shape, dtype)
}
_ => unimplemented!("randn for {:?}", dtype),
}
}
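    /// Synchronizes the stream and copies the tensor back to host memory. If the
    /// requested element width differs from the stored dtype width, values are converted
    /// numerically on the host; on copy failure a zero-filled vector is returned.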
pub fn to_vec<T: Copy + Default>(&self) -> Vec<T> {
crate::stream::sync_stream();
let count = self.elem_count();
if count == 0 {
return Vec::new();
}
let t_size = std::mem::size_of::<T>();
let dtype_size = match self.dtype {
DType::F32 => 4,
DType::I64 => 8,
DType::I32 => 4,
_ => t_size,
};
if t_size == dtype_size {
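            // Element widths match: one direct device-to-host copy into the result buffer.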
let byte_size = count * t_size;
let mut result = vec![T::default(); count];
let err = unsafe {
cuda_sys::cudaMemcpy(
result.as_mut_ptr() as *mut c_void,
self.buffer.ptr() as *const c_void,
byte_size,
cudaMemcpyKind::cudaMemcpyDeviceToHost,
)
};
if err != CUDA_SUCCESS {
eprintln!("cudaMemcpy DeviceToHost failed in to_vec(): {} (byte_size={}, buffer_size={}, shape={:?}, dtype={:?})",
err, byte_size, self.buffer.size(), self.shape, self.dtype);
return vec![T::default(); count];
}
result
} else if dtype_size == 4 && t_size == 8 {
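            // Stored elements are 4 bytes (read as f32) but the caller asked for 8-byte
            // elements: convert numerically on the host, assuming the target type is i64.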
let byte_size = count * 4;
let mut f32_buf = vec![0.0f32; count];
let err = unsafe {
cuda_sys::cudaMemcpy(
f32_buf.as_mut_ptr() as *mut c_void,
self.buffer.ptr() as *const c_void,
byte_size,
cudaMemcpyKind::cudaMemcpyDeviceToHost,
)
};
if err != CUDA_SUCCESS {
eprintln!("cudaMemcpy DeviceToHost failed in to_vec() (f32→i64): {} (byte_size={}, buffer_size={}, shape={:?})",
err, byte_size, self.buffer.size(), self.shape);
return vec![T::default(); count];
}
let i64_buf: Vec<i64> = f32_buf.iter().map(|&x| x as i64).collect();
let mut result = vec![T::default(); count];
unsafe {
std::ptr::copy_nonoverlapping(
i64_buf.as_ptr() as *const T,
result.as_mut_ptr(),
count,
);
}
result
} else if dtype_size == 8 && t_size == 4 {
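            // Stored elements are 8 bytes (read as i64) but the caller asked for 4-byte
            // elements: convert numerically on the host, assuming the target type is f32.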
let byte_size = count * 8;
let mut i64_buf = vec![0i64; count];
let err = unsafe {
cuda_sys::cudaMemcpy(
i64_buf.as_mut_ptr() as *mut c_void,
self.buffer.ptr() as *const c_void,
byte_size,
cudaMemcpyKind::cudaMemcpyDeviceToHost,
)
};
if err != CUDA_SUCCESS {
eprintln!("cudaMemcpy DeviceToHost failed in to_vec() (i64→f32): {} (byte_size={}, buffer_size={}, shape={:?})",
err, byte_size, self.buffer.size(), self.shape);
return vec![T::default(); count];
}
let f32_buf: Vec<f32> = i64_buf.iter().map(|&x| x as f32).collect();
let mut result = vec![T::default(); count];
unsafe {
std::ptr::copy_nonoverlapping(
f32_buf.as_ptr() as *const T,
result.as_mut_ptr(),
count,
);
}
result
} else {
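            // Unrecognized width combination: fall back to a raw byte copy clamped to the
            // buffer size.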
let copy_size = self.buffer.size().min(count * t_size);
let mut result = vec![T::default(); count];
let err = unsafe {
cuda_sys::cudaMemcpy(
result.as_mut_ptr() as *mut c_void,
self.buffer.ptr() as *const c_void,
copy_size,
cudaMemcpyKind::cudaMemcpyDeviceToHost,
)
};
if err != CUDA_SUCCESS {
eprintln!(
"cudaMemcpy DeviceToHost failed in to_vec() (fallback): {}",
err
);
}
result
}
}
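    /// Deep copy: allocates a new buffer and performs a device-to-device `cudaMemcpy`.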
pub fn clone_data(&self) -> BackendResult<CudaTensor> {
crate::stream::sync_stream();
let result = CudaTensor::uninit(self.shape(), self.dtype());
let size = shape_to_bytes(self.shape(), self.dtype());
if size > 0 {
let err = unsafe {
cuda_sys::cudaMemcpy(
result.buffer.ptr(),
self.buffer.ptr() as *const c_void,
size,
cudaMemcpyKind::cudaMemcpyDeviceToDevice,
)
};
if err != CUDA_SUCCESS {
return Err(tl_backend::BackendError::InternalError(format!(
"cudaMemcpy DeviceToDevice failed: {}",
err
)));
}
}
Ok(result)
}
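    /// Clones the metadata while sharing the underlying buffer. The gradient and
    /// `grad_fn` are not copied; only the requires-grad flag is preserved.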
pub fn shallow_clone(&self) -> Self {
CudaTensor {
buffer: Arc::clone(&self.buffer),
shape: self.shape.clone(),
dtype: self.dtype,
device: Arc::clone(&self.device),
autograd: if self.requires_grad() {
Some(Box::new(AutogradMeta {
grad: None,
grad_fn: None,
requires_grad: true,
}))
} else {
None
},
}
}
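    /// Allocates a fresh, uninitialized buffer directly from the device (bypassing the
    /// pool) for the given shape and dtype.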
pub fn from_buffer_shared(shape: Vec<usize>, dtype: DType) -> Self {
let device = get_device();
let size = shape_to_bytes(&shape, dtype);
let buffer = Arc::new(CudaBuffer::new(size).expect("CUDA buffer allocation failed"));
CudaTensor {
buffer,
shape,
dtype,
device,
autograd: None,
}
}
pub fn requires_grad(&self) -> bool {
self.autograd.as_ref().map_or(false, |a| a.requires_grad)
}
pub fn enable_grad(&mut self) {
if self.autograd.is_none() {
self.autograd = Some(Box::new(AutogradMeta {
grad: None,
grad_fn: None,
requires_grad: true,
}));
} else {
self.autograd.as_mut().unwrap().requires_grad = true;
}
}
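    /// Attaches a backward function, replacing any existing autograd state (including an
    /// already accumulated gradient).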
pub fn set_grad_fn(&mut self, grad_fn: Box<dyn GradFn>) {
self.autograd = Some(Box::new(AutogradMeta {
grad: None,
grad_fn: Some(grad_fn),
requires_grad: true,
}));
}
pub fn get_grad(&self) -> Option<CudaTensor> {
self.autograd
.as_ref()
.and_then(|a| a.grad.as_ref().map(|g| g.shallow_clone()))
}
pub fn zero_grad(&mut self) {
if let Some(ref mut a) = self.autograd {
a.grad = None;
}
}
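    /// Adds `grad` (detached from the graph) into the stored gradient, creating it on
    /// first use; does nothing when autograd is disabled for this tensor.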
pub fn accumulate_grad(&mut self, grad: CudaTensor) -> BackendResult<()> {
if let Some(ref mut meta) = self.autograd {
let detached_grad = grad.detach();
if let Some(ref mut g) = meta.grad {
*g = g.add_impl(&detached_grad)?;
} else {
meta.grad = Some(detached_grad);
}
}
Ok(())
}
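    /// Runs the reverse-mode pass from this tensor, seeded with a tensor of ones. The
    /// graph is traversed with an explicit worklist over raw tensor pointers; the
    /// `TensorRef`s collected in `visited` keep upstream tensors alive until their
    /// `grad_fn` links are cleared at the end.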
pub fn backward(&mut self) -> BackendResult<()> {
if !self.requires_grad() {
return Ok(());
}
let ones = CudaTensor::ones(self.shape(), self.dtype());
let self_ptr = self as *mut CudaTensor;
let mut worklist: Vec<(*mut CudaTensor, CudaTensor, Option<TensorRef>)> =
vec![(self_ptr, ones, None)];
let mut visited: Vec<(*mut CudaTensor, Option<TensorRef>)> = Vec::new();
while let Some((tensor_ptr, grad_output, arc_ref)) = worklist.pop() {
let tensor = unsafe { &mut *tensor_ptr };
visited.push((tensor_ptr, arc_ref));
let propagation = if let Some(meta) = tensor.autograd.as_ref() {
if let Some(gf) = meta.grad_fn.as_ref() {
let grads = gf.backward(&grad_output)?;
let inputs = gf.inputs();
Some((grads, inputs))
} else {
None
}
} else {
None
};
tensor.accumulate_grad(grad_output)?;
if let Some((grads, inputs)) = propagation {
for (input_ref, grad) in inputs.into_iter().zip(grads.into_iter()) {
let input_ptr = input_ref.get() as *mut CudaTensor;
let input = unsafe { &*input_ptr };
if input.requires_grad() {
worklist.push((input_ptr, grad, Some(input_ref)));
}
}
}
}
for entry in visited.iter_mut() {
let tensor = unsafe { &mut *entry.0 };
if let Some(ref mut meta) = tensor.autograd {
meta.grad_fn = None;
}
}
Ok(())
}
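    /// Returns a shallow clone that shares storage but carries no backward function.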
pub fn detach(&self) -> CudaTensor {
self.shallow_clone()
}
}
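/// Returns the buffer to the pool for reuse when this tensor holds the last strong
/// reference to it; otherwise the shared buffer stays alive in its other owners.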
impl Drop for CudaTensor {
fn drop(&mut self) {
if Arc::strong_count(&self.buffer) == 1 {
pool_release(self.buffer.clone());
}
}
}