use std::fmt;
use std::sync::{Arc, Mutex};
use crate::device::Device;
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::shape::c_contiguous_strides;
use crate::storage::TensorStorage;
/// Monotonically increasing source of process-unique tensor identifiers.
static NEXT_TENSOR_ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
/// Opaque identifier for a tensor's *identity* (its autograd node), as
/// distinct from its storage: clones share an id, views get a fresh one.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TensorId(u64);
impl TensorId {
/// Allocates the next unused id. `Relaxed` ordering suffices because only
/// uniqueness is needed, not ordering with any other memory operation.
fn next() -> Self {
Self(NEXT_TENSOR_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed))
}
}
/// Node in the reverse-mode autograd graph: computes the gradients of an
/// operation's inputs from the gradient of its output.
pub trait GradFn<T: Float>: Send + Sync + fmt::Debug {
/// Returns one optional gradient per input (`None` for inputs that need
/// no gradient), in the same order as `inputs()`.
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>>;
/// The input tensors captured when the forward operation ran.
fn inputs(&self) -> Vec<&Tensor<T>>;
/// Short operation name, surfaced in `Tensor`'s `Debug` output.
fn name(&self) -> &'static str;
}
/// Shared state behind a `Tensor` handle. `Tensor::clone` bumps the `Arc`
/// refcount on this struct, so clones share identity, storage, and gradient.
struct TensorInner<T: Float> {
// Identity of this autograd node (preserved across `Tensor::clone`).
id: TensorId,
// Underlying buffer; may be shared by several views of different shapes.
storage: Arc<TensorStorage<T>>,
// Dimension sizes; empty means a 0-dim (scalar) tensor.
shape: Vec<usize>,
// Per-dimension strides in elements (signed, matching `c_contiguous_strides`).
strides: Vec<isize>,
// Element offset of this view into `storage`.
offset: usize,
// Accumulated gradient, lazily created by `accumulate_grad`.
grad: Mutex<Option<Box<Tensor<T>>>>,
// Backward node that produced this tensor; `None` for leaves.
grad_fn: Option<Arc<dyn GradFn<T>>>,
// Whether autograd should track operations on this tensor.
requires_grad: bool,
// True for user-created tensors (not outputs of tracked operations).
is_leaf: bool,
}
/// Cheap-to-clone handle to a (possibly shared) n-dimensional array with
/// optional autograd tracking. Defaults to `f32` elements.
pub struct Tensor<T: Float = f32> {
inner: Arc<TensorInner<T>>,
}
impl<T: Float> Tensor<T> {
/// Builds a leaf tensor viewing `storage` with C-contiguous `shape`.
///
/// Storage longer than the shape requires is allowed; the tensor views the
/// first `numel` elements.
///
/// # Errors
/// `ShapeMismatch` if the shape's element count overflows `usize` or
/// exceeds the storage length.
pub fn from_storage(
    storage: TensorStorage<T>,
    shape: Vec<usize>,
    requires_grad: bool,
) -> FerrotorchResult<Self> {
    // Checked product: a plain `iter().product()` wraps silently in release
    // builds, which could let an oversized shape slip past the bounds check.
    let numel = shape
        .iter()
        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
        .ok_or_else(|| FerrotorchError::ShapeMismatch {
            message: format!("shape {:?} element count overflows usize", shape),
        })?;
    if numel > storage.len() {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "shape {:?} requires {} elements but storage has {}",
                shape,
                numel,
                storage.len()
            ),
        });
    }
    let strides = c_contiguous_strides(&shape);
    Ok(Self {
        inner: Arc::new(TensorInner {
            id: TensorId::next(),
            storage: Arc::new(storage),
            shape,
            strides,
            offset: 0,
            grad: Mutex::new(None),
            grad_fn: None,
            requires_grad,
            is_leaf: true,
        }),
    })
}
/// Returns a detached reshape of this tensor sharing the same storage.
///
/// The result is a leaf with `requires_grad == false`; use
/// `view_operation` for a gradient-tracked view.
///
/// # Errors
/// - `ShapeMismatch` if `new_shape` has a different element count.
/// - `InvalidArgument` if this tensor is not contiguous: the view reuses
///   the storage with fresh C-contiguous strides, which would silently
///   reorder the elements of a non-contiguous source.
pub fn view_reshape(&self, new_shape: Vec<usize>) -> FerrotorchResult<Self> {
    let new_numel: usize = new_shape.iter().product();
    if new_numel != self.numel() {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "view_reshape: new shape {:?} ({} elements) vs old {:?} ({} elements)",
                new_shape, new_numel, self.shape(), self.numel()
            ),
        });
    }
    // Reinterpreting the same buffer with C-contiguous strides is only
    // valid when the source layout is itself C-contiguous.
    if !self.is_contiguous() {
        return Err(FerrotorchError::InvalidArgument {
            message: "view_reshape requires a contiguous tensor".into(),
        });
    }
    let strides = c_contiguous_strides(&new_shape);
    Ok(Self {
        inner: Arc::new(TensorInner {
            id: TensorId::next(),
            storage: Arc::clone(&self.inner.storage),
            shape: new_shape,
            strides,
            offset: self.inner.offset,
            grad: Mutex::new(None),
            grad_fn: None,
            requires_grad: false,
            is_leaf: true,
        }),
    })
}
/// Creates a gradient-tracked view with a new shape over the same storage,
/// recording `grad_fn` as the backward node (result is a non-leaf with
/// `requires_grad == true`).
///
/// # Errors
/// - `ShapeMismatch` if `new_shape` has a different element count.
/// - `InvalidArgument` if this tensor is not contiguous: the view reuses
///   the storage with fresh C-contiguous strides, which would silently
///   reorder the elements of a non-contiguous source.
pub fn view_operation(
    &self,
    new_shape: Vec<usize>,
    grad_fn: Arc<dyn GradFn<T>>,
) -> FerrotorchResult<Self> {
    let new_numel: usize = new_shape.iter().product();
    if new_numel != self.numel() {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "view_operation: new shape {:?} ({} elements) vs {:?} ({} elements)",
                new_shape, new_numel, self.shape(), self.numel()
            ),
        });
    }
    // Same contiguity requirement as `view_reshape`: contiguous strides
    // over a non-contiguous source would scramble element order.
    if !self.is_contiguous() {
        return Err(FerrotorchError::InvalidArgument {
            message: "view_operation requires a contiguous tensor".into(),
        });
    }
    let strides = c_contiguous_strides(&new_shape);
    Ok(Self {
        inner: Arc::new(TensorInner {
            id: TensorId::next(),
            storage: Arc::clone(&self.inner.storage),
            shape: new_shape,
            strides,
            offset: self.inner.offset,
            grad: Mutex::new(None),
            grad_fn: Some(grad_fn),
            requires_grad: true,
            is_leaf: false,
        }),
    })
}
/// Builds a non-leaf tensor from freshly computed `storage`, recording
/// `grad_fn` as its backward node (`requires_grad == true`).
///
/// # Errors
/// `ShapeMismatch` if the shape's element count overflows `usize` or
/// exceeds the storage length.
pub fn from_operation(
    storage: TensorStorage<T>,
    shape: Vec<usize>,
    grad_fn: Arc<dyn GradFn<T>>,
) -> FerrotorchResult<Self> {
    // Checked product, consistent with `from_storage`: `iter().product()`
    // wraps silently in release builds.
    let numel = shape
        .iter()
        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
        .ok_or_else(|| FerrotorchError::ShapeMismatch {
            message: format!("shape {:?} element count overflows usize", shape),
        })?;
    if numel > storage.len() {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "shape {:?} requires {} elements but storage has {}",
                shape,
                numel,
                storage.len()
            ),
        });
    }
    let strides = c_contiguous_strides(&shape);
    Ok(Self {
        inner: Arc::new(TensorInner {
            id: TensorId::next(),
            storage: Arc::new(storage),
            shape,
            strides,
            offset: 0,
            grad: Mutex::new(None),
            grad_fn: Some(grad_fn),
            requires_grad: true,
            is_leaf: false,
        }),
    })
}
}
impl<T: Float> Tensor<T> {
/// Stable identity of this tensor node (shared by clones, not by views).
#[inline]
pub fn id(&self) -> TensorId {
self.inner.id
}
/// Dimension sizes; empty slice for a 0-dim (scalar) tensor.
#[inline]
pub fn shape(&self) -> &[usize] {
&self.inner.shape
}
/// Number of dimensions (0 for a scalar).
#[inline]
pub fn ndim(&self) -> usize {
self.inner.shape.len()
}
/// Total element count; the empty-shape product is 1, so scalars hold one element.
#[inline]
pub fn numel(&self) -> usize {
self.inner.shape.iter().product()
}
/// Per-dimension strides in elements.
#[inline]
pub fn strides(&self) -> &[isize] {
&self.inner.strides
}
/// Device of the backing storage.
#[inline]
pub fn device(&self) -> Device {
self.inner.storage.device()
}
/// Whether autograd tracks operations on this tensor.
#[inline]
pub fn requires_grad(&self) -> bool {
self.inner.requires_grad
}
/// True for user-created tensors (not outputs of tracked operations).
#[inline]
pub fn is_leaf(&self) -> bool {
self.inner.is_leaf
}
/// The backward node that produced this tensor, if any.
#[inline]
pub fn grad_fn(&self) -> Option<&Arc<dyn GradFn<T>>> {
self.inner.grad_fn.as_ref()
}
/// Returns a clone of the currently accumulated gradient, or `None` if no
/// gradient has been recorded yet.
///
/// # Errors
/// `LockPoisoned` if the grad mutex was poisoned by a panicking thread.
pub fn grad(&self) -> FerrotorchResult<Option<Tensor<T>>> {
    let guard = match self.inner.grad.lock() {
        Ok(g) => g,
        Err(e) => {
            return Err(FerrotorchError::LockPoisoned {
                message: format!("grad mutex: {e}"),
            })
        }
    };
    Ok(guard.as_ref().map(|boxed| boxed.as_ref().clone()))
}
/// Replaces the stored gradient with `grad` (pass `None` to clear it).
///
/// # Errors
/// `LockPoisoned` if the grad mutex was poisoned.
pub fn set_grad(&self, grad: Option<Tensor<T>>) -> FerrotorchResult<()> {
    let mut slot = self
        .inner
        .grad
        .lock()
        .map_err(|e| FerrotorchError::LockPoisoned {
            message: format!("grad mutex: {e}"),
        })?;
    *slot = grad.map(Box::new);
    Ok(())
}
/// Clears any accumulated gradient; shorthand for `set_grad(None)`.
pub fn zero_grad(&self) -> FerrotorchResult<()> {
    self.set_grad(None)
}
/// Adds `incoming` into this tensor's accumulated gradient, creating it on
/// first use. The stored gradient is always materialized on the CPU.
///
/// # Errors
/// - `LockPoisoned` if the grad mutex is poisoned.
/// - `ShapeMismatch` if the element counts differ.
pub(crate) fn accumulate_grad(&self, incoming: &Tensor<T>) -> FerrotorchResult<()> {
    // Bring the incoming gradient to the CPU *before* taking the lock, so a
    // potentially slow device transfer never runs while holding the mutex.
    // (This conversion was previously duplicated inside both match arms.)
    let cpu_incoming = if incoming.is_cuda() {
        incoming.cpu()?
    } else {
        incoming.clone()
    };
    let mut guard = self
        .inner
        .grad
        .lock()
        .map_err(|e| FerrotorchError::LockPoisoned {
            message: format!("grad mutex: {e}"),
        })?;
    match guard.as_mut() {
        None => {
            // First gradient: store a fresh CPU copy.
            let data = cpu_incoming.data()?.to_vec();
            let storage = TensorStorage::cpu(data);
            let tensor = Tensor::from_storage(storage, incoming.shape().to_vec(), false)?;
            *guard = Some(Box::new(tensor));
        }
        Some(existing) => {
            let incoming_data = cpu_incoming.data()?;
            // NOTE(review): `data_mut` mutates through a shared Arc. The
            // grad mutex serializes accumulators, but `grad()` hands out
            // clones that alias this same storage — confirm no caller reads
            // a grad clone concurrently with accumulation.
            let existing_data = unsafe { existing.data_mut()? };
            if existing_data.len() != incoming_data.len() {
                return Err(FerrotorchError::ShapeMismatch {
                    message: format!(
                        "gradient accumulation shape mismatch: {:?} vs {:?}",
                        existing.shape(),
                        incoming.shape()
                    ),
                });
            }
            for (e, &n) in existing_data.iter_mut().zip(incoming_data.iter()) {
                *e = *e + n;
            }
        }
    }
    Ok(())
}
/// Borrows this tensor's elements as a slice of the CPU storage,
/// starting at the view's offset.
///
/// # Errors
/// - `GpuTensorNotAccessible` if the storage lives on the GPU.
/// - `InvalidArgument` if the view's range exceeds the storage length.
pub fn data(&self) -> FerrotorchResult<&[T]> {
    if self.inner.storage.is_gpu() {
        return Err(FerrotorchError::GpuTensorNotAccessible);
    }
    let start = self.inner.offset;
    let end = start + self.numel();
    // `get` performs the bounds check and indexing in one step.
    self.inner
        .storage
        .as_slice()
        .get(start..end)
        .ok_or_else(|| FerrotorchError::InvalidArgument {
            message: "tensor view extends beyond storage".into(),
        })
}
/// Copies this tensor's elements into an owned `Vec`, transparently
/// staging CUDA tensors through host memory first.
pub fn data_vec(&self) -> FerrotorchResult<Vec<T>> {
    let host = if self.is_cuda() { self.cpu()? } else { self.clone() };
    Ok(host.data()?.to_vec())
}
/// Consumes the tensor, returning its storage and shape.
///
/// Fast path: when this handle is the sole owner of both the inner state
/// and the storage, the buffer is moved out without copying. If the storage
/// is shared it is cloned; if the inner state is shared the viewed element
/// range is copied into fresh CPU storage.
/// NOTE(review): the shared-inner fallback calls `storage.as_slice()`
/// directly without a `is_gpu` check (compare `data()`), so a shared GPU
/// tensor reaching that path looks unsupported — confirm `as_slice`
/// behavior on GPU storage.
/// NOTE(review): the fallback also assumes a contiguous view starting at
/// `offset`; strides are ignored — TODO confirm callers only pass
/// contiguous tensors.
pub fn into_storage_and_shape(self) -> FerrotorchResult<(TensorStorage<T>, Vec<usize>)> {
let shape = self.inner.shape.clone();
match Arc::try_unwrap(self.inner) {
Ok(inner) => {
match Arc::try_unwrap(inner.storage) {
Ok(storage) => Ok((storage, shape)),
Err(arc_storage) => {
// Storage still shared by another view: clone the buffer.
Ok(((*arc_storage).clone(), shape))
}
}
}
Err(arc_inner) => {
// Inner still shared (tensor was cloned): copy out the viewed range.
let data = arc_inner.storage.as_slice();
let end = arc_inner.offset + shape.iter().product::<usize>();
Ok((TensorStorage::cpu(data[arc_inner.offset..end].to_vec()), shape))
}
}
}
/// Moves this tensor to `device`, copying data across devices when needed.
/// Returns a cheap clone (shared storage) when already on the target.
///
/// # Errors
/// `DeviceUnavailable` if no GPU backend is registered; propagates
/// transfer errors from the backend.
pub fn to(&self, device: Device) -> FerrotorchResult<Tensor<T>> {
    if self.device() == device {
        return Ok(self.clone());
    }
    match (self.device(), device) {
        (Device::Cpu, Device::Cuda(ordinal)) => {
            let backend = crate::gpu_dispatch::gpu_backend()
                .ok_or(FerrotorchError::DeviceUnavailable)?;
            let cpu_data = self.data()?;
            // SAFETY: viewing an initialized `[T]` as raw bytes is always
            // valid — `u8` has alignment 1 and the byte length is exact.
            let bytes = unsafe {
                std::slice::from_raw_parts(
                    cpu_data.as_ptr() as *const u8,
                    cpu_data.len() * std::mem::size_of::<T>(),
                )
            };
            let handle = backend.cpu_to_gpu(bytes, std::mem::size_of::<T>(), ordinal)?;
            let storage = TensorStorage::gpu(handle);
            Tensor::from_storage(storage, self.shape().to_vec(), self.requires_grad())
        }
        (Device::Cuda(_), Device::Cpu) => {
            let backend = crate::gpu_dispatch::gpu_backend()
                .ok_or(FerrotorchError::DeviceUnavailable)?;
            let handle = self.gpu_handle()?;
            let bytes = backend.gpu_to_cpu(handle)?;
            // Copy into a freshly allocated Vec<T> instead of rebuilding the
            // Vec<u8> in place with `Vec::from_raw_parts`: reusing the byte
            // vector's allocation is undefined behavior when
            // `align_of::<T>() > 1` (the buffer was allocated with u8
            // alignment) and whenever its capacity in bytes is not an exact
            // multiple of `size_of::<T>()`.
            let elem = std::mem::size_of::<T>();
            let len = bytes.len() / elem;
            let mut data: Vec<T> = Vec::with_capacity(len);
            // SAFETY: `data` has capacity for `len` elements; we fill exactly
            // `len * elem` initialized bytes from `bytes` into its properly
            // aligned buffer, then mark those `len` elements initialized.
            // The byte-wise copy imposes no alignment requirement on the
            // source buffer.
            unsafe {
                std::ptr::copy_nonoverlapping(
                    bytes.as_ptr(),
                    data.as_mut_ptr() as *mut u8,
                    len * elem,
                );
                data.set_len(len);
            }
            Tensor::from_storage(TensorStorage::cpu(data), self.shape().to_vec(), self.requires_grad())
        }
        (Device::Cuda(a), Device::Cuda(b)) if a != b => {
            // No device-to-device copy path yet: bounce through host memory.
            let cpu = self.to(Device::Cpu)?;
            cpu.to(Device::Cuda(b))
        }
        _ => Ok(self.clone()),
    }
}
/// Shorthand for `to(Device::Cuda(0))`.
pub fn cuda(&self) -> FerrotorchResult<Tensor<T>> {
self.to(Device::Cuda(0))
}
/// Shorthand for `to(Device::Cpu)`.
pub fn cpu(&self) -> FerrotorchResult<Tensor<T>> {
self.to(Device::Cpu)
}
/// True when the backing storage lives in host memory.
#[inline]
pub fn is_cpu(&self) -> bool {
self.device().is_cpu()
}
/// True when the backing storage lives on a CUDA device.
#[inline]
pub fn is_cuda(&self) -> bool {
self.device().is_cuda()
}
/// Returns the GPU buffer handle backing this tensor.
///
/// # Errors
/// `InvalidArgument` if the storage lives on the CPU.
pub fn gpu_handle(&self) -> FerrotorchResult<&crate::gpu_dispatch::GpuBufferHandle> {
    // `ok_or_else` so the error (and its String allocation via `.into()`)
    // is only constructed on the failure path, not on every call.
    self.inner
        .storage
        .gpu_handle()
        .ok_or_else(|| FerrotorchError::InvalidArgument {
            message: "tensor is on CPU, not GPU".into(),
        })
}
/// Mutably borrows this tensor's elements.
///
/// # Safety
/// Obtains `&mut` access by casting away the shared
/// `Arc<TensorStorage<T>>` (`Arc::as_ptr` to `*mut` below). The caller must
/// guarantee exclusive access: no other thread or live borrow may read or
/// write the same storage while the returned slice exists, otherwise this
/// is a data race / aliasing violation (undefined behavior). Note that
/// `Tensor::clone`, `detach`, and the view constructors all share this
/// storage.
/// NOTE(review): also presumes `as_mut_slice` is valid for this storage's
/// device (looks CPU-only, unlike `data()` which checks `is_gpu`) —
/// confirm callers never pass GPU tensors.
pub unsafe fn data_mut(&self) -> FerrotorchResult<&mut [T]> {
let storage_ptr = Arc::as_ptr(&self.inner.storage) as *mut TensorStorage<T>;
let storage = unsafe { &mut *storage_ptr };
let slice = storage.as_mut_slice();
let end = self.inner.offset + self.numel();
if end > slice.len() {
return Err(FerrotorchError::InvalidArgument {
message: "tensor view extends beyond storage".into(),
});
}
Ok(&mut slice[self.inner.offset..end])
}
/// Returns a new leaf tensor that shares this tensor's storage but is cut
/// out of the autograd graph: fresh id, no `grad_fn`, no gradient, and
/// `requires_grad == false`.
pub fn detach(&self) -> Self {
    let detached = TensorInner {
        id: TensorId::next(),
        storage: Arc::clone(&self.inner.storage),
        shape: self.inner.shape.clone(),
        strides: self.inner.strides.clone(),
        offset: self.inner.offset,
        grad: Mutex::new(None),
        grad_fn: None,
        requires_grad: false,
        is_leaf: true,
    };
    Self {
        inner: Arc::new(detached),
    }
}
/// Consumes this handle and returns a copy with the `requires_grad` flag
/// replaced, preserving the original `TensorId` (so `is_same` still
/// matches), storage, strides, offset, `grad_fn`, and leaf status.
/// NOTE(review): the accumulated gradient is *not* carried over — `grad`
/// is reset to `None` here. Confirm that dropping an already-accumulated
/// gradient is intended.
pub fn requires_grad_(self, requires_grad: bool) -> Self {
Self {
inner: Arc::new(TensorInner {
id: self.inner.id,
storage: Arc::clone(&self.inner.storage),
shape: self.inner.shape.clone(),
strides: self.inner.strides.clone(),
offset: self.inner.offset,
grad: Mutex::new(None),
grad_fn: self.inner.grad_fn.clone(),
requires_grad,
is_leaf: self.inner.is_leaf,
}),
}
}
/// Reports whether this tensor's layout is C-contiguous.
///
/// Scalars (empty shape) and any shape containing a zero-sized dimension
/// count as contiguous; strides are checked from the innermost dimension
/// outwards.
pub fn is_contiguous(&self) -> bool {
    let mut expected: isize = 1;
    let dims = self.inner.shape.iter().zip(self.inner.strides.iter());
    for (&dim, &stride) in dims.rev() {
        if dim == 0 {
            return true;
        }
        if stride != expected {
            return false;
        }
        expected *= dim as isize;
    }
    true
}
/// True for 0-dimensional tensors (empty shape). A scalar still holds one
/// element, since `numel()` is the product of an empty shape (= 1).
#[inline]
pub fn is_scalar(&self) -> bool {
self.inner.shape.is_empty()
}
/// Extracts the single value of a scalar or one-element tensor, copying
/// from the GPU first when necessary.
///
/// # Errors
/// `InvalidArgument` if the tensor holds more than one element; propagates
/// device-transfer errors for CUDA tensors.
pub fn item(&self) -> FerrotorchResult<T> {
    if !self.is_scalar() && self.numel() != 1 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "item() requires a scalar or single-element tensor, got shape {:?}",
                self.shape()
            ),
        });
    }
    // `data_vec` stages CUDA storage through host memory, so item() now
    // works on GPU tensors too (the previous direct `data()` call errored
    // with GpuTensorNotAccessible on GPU).
    let data = self.data_vec()?;
    Ok(data[0])
}
/// True when both handles refer to the same tensor node (same `TensorId`).
/// Clones compare equal; views, detached copies, and fresh tensors do not.
pub fn is_same(&self, other: &Self) -> bool {
self.inner.id == other.inner.id
}
/// Test helper: true when both tensors share the same underlying buffer.
#[cfg(test)]
pub(crate) fn shares_storage(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.inner.storage, &other.inner.storage)
}
}
/// Cloning is cheap: it bumps the refcount on the shared inner state, so a
/// clone shares id, storage, gradient, and autograd metadata with the
/// original. Use `detach` for an independent autograd node.
impl<T: Float> Clone for Tensor<T> {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
impl<T: Float> fmt::Debug for Tensor<T> {
    /// Summarizes identity and autograd metadata; element data is omitted
    /// (it may live on the GPU).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner = &self.inner;
        let grad_fn_name = inner.grad_fn.as_ref().map(|gf| gf.name());
        f.debug_struct("Tensor")
            .field("id", &inner.id)
            .field("shape", &inner.shape)
            .field("device", &self.device())
            .field("requires_grad", &inner.requires_grad)
            .field("is_leaf", &inner.is_leaf)
            .field("grad_fn", &grad_fn_name)
            .finish()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::TensorStorage;
// Construction from storage sets shape, C-contiguous strides, and leaf
// metadata, and leaves the tensor on the CPU.
#[test]
fn test_tensor_from_storage() {
let storage = TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
let t = Tensor::from_storage(storage, vec![2, 3], false).unwrap();
assert_eq!(t.shape(), &[2, 3]);
assert_eq!(t.strides(), &[3, 1]);
assert_eq!(t.ndim(), 2);
assert_eq!(t.numel(), 6);
assert!(t.is_contiguous());
assert!(t.is_leaf());
assert!(!t.requires_grad());
assert_eq!(t.device(), Device::Cpu);
}
// A shape that needs more elements than the storage holds is rejected.
#[test]
fn test_tensor_shape_mismatch() {
let storage = TensorStorage::cpu(vec![1.0f32, 2.0, 3.0]);
let result = Tensor::from_storage(storage, vec![2, 3], false);
assert!(result.is_err());
}
// CPU tensors expose their elements through data().
#[test]
fn test_tensor_data_access() {
let storage = TensorStorage::cpu(vec![1.0f64, 2.0, 3.0]);
let t = Tensor::from_storage(storage, vec![3], false).unwrap();
assert_eq!(t.data().unwrap(), &[1.0, 2.0, 3.0]);
}
// An empty shape makes a 0-dim scalar whose value item() returns.
#[test]
fn test_tensor_scalar() {
let storage = TensorStorage::cpu(vec![42.0f32]);
let t = Tensor::from_storage(storage, vec![], false).unwrap();
assert!(t.is_scalar());
assert_eq!(t.item().unwrap(), 42.0);
}
// detach() yields a grad-free leaf with no grad_fn.
#[test]
fn test_tensor_detach() {
let storage = TensorStorage::cpu(vec![1.0f32, 2.0]);
let t = Tensor::from_storage(storage, vec![2], true).unwrap();
assert!(t.requires_grad());
let d = t.detach();
assert!(!d.requires_grad());
assert!(d.is_leaf());
assert!(d.grad_fn().is_none());
}
// Compile-time check that Tensor is Send + Sync for cross-thread use.
#[test]
fn test_tensor_is_send_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<Tensor<f32>>();
assert_send_sync::<Tensor<f64>>();
}
// Clones share the same TensorId (same autograd node identity).
#[test]
fn test_clone_shares_identity() {
let storage = TensorStorage::cpu(vec![1.0f32, 2.0]);
let t = Tensor::from_storage(storage, vec![2], true).unwrap();
let t2 = t.clone();
assert!(t.is_same(&t2));
assert_eq!(t.id(), t2.id());
}
// view_operation shares the underlying buffer but mints a new identity.
#[test]
fn test_view_operation_shares_storage() {
use crate::grad_fns::shape::FlattenBackward;
let storage = TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
let t = Tensor::from_storage(storage, vec![2, 3], true).unwrap();
let grad_fn = Arc::new(FlattenBackward::new(t.clone(), t.shape().to_vec()));
let view = t.view_operation(vec![6], grad_fn).unwrap();
assert!(t.shares_storage(&view), "view_operation must share storage");
assert!(!t.is_same(&view), "view_operation creates new tensor identity");
}
// Because clones share inner state, a gradient accumulated through one
// handle is visible through the other.
#[test]
fn test_clone_shares_grad() {
let storage = TensorStorage::cpu(vec![1.0f32, 2.0, 3.0]);
let t = Tensor::from_storage(storage, vec![3], true).unwrap();
let t2 = t.clone();
let g = Tensor::from_storage(TensorStorage::cpu(vec![0.1, 0.2, 0.3]), vec![3], false)
.unwrap();
t.accumulate_grad(&g).unwrap();
let grad = t2.grad().unwrap().unwrap();
let data = grad.data().unwrap();
assert!((data[0] - 0.1).abs() < 1e-7);
}
// Gradients add element-wise across successive accumulate_grad calls.
#[test]
fn test_tensor_grad_accumulation() {
let storage = TensorStorage::cpu(vec![1.0f32, 2.0, 3.0]);
let t = Tensor::from_storage(storage, vec![3], true).unwrap();
assert!(t.grad().unwrap().is_none());
let g1 = Tensor::from_storage(TensorStorage::cpu(vec![0.1, 0.2, 0.3]), vec![3], false)
.unwrap();
t.accumulate_grad(&g1).unwrap();
let grad = t.grad().unwrap().unwrap();
let data = grad.data().unwrap();
assert!((data[0] - 0.1).abs() < 1e-7);
let g2 = Tensor::from_storage(TensorStorage::cpu(vec![1.0, 1.0, 1.0]), vec![3], false)
.unwrap();
t.accumulate_grad(&g2).unwrap();
let grad = t.grad().unwrap().unwrap();
let data = grad.data().unwrap();
assert!((data[0] - 1.1).abs() < 1e-6);
assert!((data[1] - 1.2).abs() < 1e-6);
assert!((data[2] - 1.3).abs() < 1e-6);
}
}