use std::fmt;
use std::sync::{Arc, Mutex};
use crate::device::Device;
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::shape::{c_contiguous_strides, channels_last_3d_strides, channels_last_strides};
use crate::storage::TensorStorage;
/// Physical layout that a tensor's strides may follow.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryFormat {
/// Row-major layout; strides are those produced by `c_contiguous_strides`.
Contiguous,
/// Layout described by `channels_last_strides`; only accepted for 4-D tensors.
ChannelsLast,
/// Layout described by `channels_last_3d_strides`; only accepted for 5-D tensors.
ChannelsLast3d,
}
/// Process-wide counter from which every new tensor id is drawn.
static NEXT_TENSOR_ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);

/// Opaque identifier handed to a tensor when it is constructed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TensorId(u64);

impl TensorId {
    /// Reserves a fresh id by post-incrementing the global counter.
    fn next() -> Self {
        use std::sync::atomic::Ordering::Relaxed;
        let raw = NEXT_TENSOR_ID.fetch_add(1, Relaxed);
        TensorId(raw)
    }
}
/// Backward-pass node attached to tensors produced by differentiable operations.
pub trait GradFn<T: Float>: Send + Sync + fmt::Debug {
    /// Computes gradients from `grad_output`.
    ///
    /// Each entry of the returned vector is an optional gradient
    /// (presumably matched positionally to [`GradFn::inputs`] — confirm with the autograd engine).
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>>;

    /// Tensors this node holds as its inputs.
    fn inputs(&self) -> Vec<&Tensor<T>>;

    /// Static name of the node kind; shown in `Tensor`'s `Debug` output.
    fn name(&self) -> &'static str;

    /// Scalar parameters of the operation; empty unless an implementor overrides it.
    fn scalar_args(&self) -> Vec<f64> {
        Vec::new()
    }
}
/// Shared state behind every `Tensor` handle.
struct TensorInner<T: Float> {
/// Identity assigned at construction (`requires_grad_` carries it over).
id: TensorId,
/// Backing buffer; reference-counted so several tensors can share it.
storage: Arc<TensorStorage<T>>,
/// Extent of each dimension; an empty vector denotes a scalar.
shape: Vec<usize>,
/// Per-dimension step, in elements, between consecutive indices.
strides: Vec<isize>,
/// Element index in `storage` at which this tensor's data begins.
offset: usize,
/// Accumulated gradient, if one has been stored.
grad: Mutex<Option<Box<Tensor<T>>>>,
/// Backward node of the operation that produced this tensor; `None` for leaves.
grad_fn: Option<Arc<dyn GradFn<T>>>,
/// Whether this tensor is flagged as requiring gradients.
requires_grad: bool,
/// `true` when the tensor was not produced by a recorded operation.
is_leaf: bool,
/// Hooks registered on this tensor.
hooks: Mutex<crate::autograd::hooks::HookStorage<T>>,
}
/// Handle to a tensor; the element type defaults to `f32`.
///
/// The handle is a reference-counted pointer to the shared inner state, so
/// cloning it yields another handle to the same tensor rather than a copy.
pub struct Tensor<T: Float = f32> {
/// Shared metadata, storage reference, gradient slot and hooks.
inner: Arc<TensorInner<T>>,
}
impl<T: Float> Tensor<T> {
/// Builds a leaf tensor over `storage`, laid out row-major for `shape`.
///
/// The tensor starts at storage offset 0, has no gradient and no `grad_fn`.
///
/// # Errors
/// `ShapeMismatch` when `shape` needs more elements than `storage` holds.
pub fn from_storage(
    storage: TensorStorage<T>,
    shape: Vec<usize>,
    requires_grad: bool,
) -> FerrotorchResult<Self> {
    let required = shape.iter().product::<usize>();
    if required > storage.len() {
        let message = format!(
            "shape {:?} requires {} elements but storage has {}",
            shape,
            required,
            storage.len()
        );
        return Err(FerrotorchError::ShapeMismatch { message });
    }
    let strides = c_contiguous_strides(&shape);
    let inner = TensorInner {
        id: TensorId::next(),
        storage: Arc::new(storage),
        shape,
        strides,
        offset: 0,
        grad: Mutex::new(None),
        grad_fn: None,
        requires_grad,
        is_leaf: true,
        hooks: Mutex::new(crate::autograd::hooks::HookStorage::new()),
    };
    Ok(Self {
        inner: Arc::new(inner),
    })
}
/// Returns a detached leaf that reinterprets this tensor's storage with `new_shape`.
///
/// A non-contiguous tensor is first passed through `contiguous()`; otherwise the
/// result shares this tensor's storage and offset.
///
/// # Errors
/// `ShapeMismatch` when `new_shape` does not have the same element count.
pub fn view_reshape(&self, new_shape: Vec<usize>) -> FerrotorchResult<Self> {
    if !self.is_contiguous() {
        let compact = self.contiguous()?;
        return compact.view_reshape(new_shape);
    }
    let wanted = new_shape.iter().product::<usize>();
    if wanted != self.numel() {
        let message = format!(
            "view_reshape: new shape {:?} ({} elements) vs old {:?} ({} elements)",
            new_shape,
            wanted,
            self.shape(),
            self.numel()
        );
        return Err(FerrotorchError::ShapeMismatch { message });
    }
    let strides = c_contiguous_strides(&new_shape);
    let inner = TensorInner {
        id: TensorId::next(),
        storage: Arc::clone(&self.inner.storage),
        shape: new_shape,
        strides,
        offset: self.inner.offset,
        grad: Mutex::new(None),
        grad_fn: None,
        requires_grad: false,
        is_leaf: true,
        hooks: Mutex::new(crate::autograd::hooks::HookStorage::new()),
    };
    Ok(Self {
        inner: Arc::new(inner),
    })
}
/// Creates a graph-attached view of this tensor with `new_shape`.
///
/// The result is a non-leaf that requires grad and carries `grad_fn`. For a
/// contiguous tensor it shares this tensor's storage and offset. A
/// non-contiguous tensor is first gathered into a fresh contiguous buffer
/// **on the tensor's own device**, so the view never silently changes device.
///
/// # Errors
/// - `ShapeMismatch` when `new_shape` has a different element count.
/// - Any error raised while gathering the data or rebuilding the storage.
pub fn view_operation(
    &self,
    new_shape: Vec<usize>,
    grad_fn: Arc<dyn GradFn<T>>,
) -> FerrotorchResult<Self> {
    if !self.is_contiguous() {
        let data = self.data_vec()?;
        // `data_vec` yields host data; rebuild on the source device (as
        // `into_storage_and_shape` does) instead of pinning the copy to the CPU.
        let storage = TensorStorage::on_device(data, self.device())?;
        let contiguous =
            Tensor::from_storage(storage, self.shape().to_vec(), self.requires_grad())?;
        return contiguous.view_operation(new_shape, grad_fn);
    }
    let new_numel: usize = new_shape.iter().product();
    if new_numel != self.numel() {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "view_operation: new shape {:?} ({} elements) vs {:?} ({} elements)",
                new_shape,
                new_numel,
                self.shape(),
                self.numel()
            ),
        });
    }
    let strides = c_contiguous_strides(&new_shape);
    Ok(Self {
        inner: Arc::new(TensorInner {
            id: TensorId::next(),
            storage: Arc::clone(&self.inner.storage),
            shape: new_shape,
            strides,
            offset: self.inner.offset,
            grad: Mutex::new(None),
            grad_fn: Some(grad_fn),
            requires_grad: true,
            is_leaf: false,
            hooks: Mutex::new(crate::autograd::hooks::HookStorage::new()),
        }),
    })
}
/// Builds a detached leaf over the same storage with caller-supplied geometry.
///
/// `new_shape`, `new_strides` and `new_offset` are stored as given; nothing is
/// validated against the storage length here.
pub fn stride_view(
    &self,
    new_shape: Vec<usize>,
    new_strides: Vec<isize>,
    new_offset: usize,
) -> Self {
    let inner = TensorInner {
        id: TensorId::next(),
        storage: Arc::clone(&self.inner.storage),
        shape: new_shape,
        strides: new_strides,
        offset: new_offset,
        grad: Mutex::new(None),
        grad_fn: None,
        requires_grad: false,
        is_leaf: true,
        hooks: Mutex::new(crate::autograd::hooks::HookStorage::new()),
    };
    Self {
        inner: Arc::new(inner),
    }
}
/// Graph-attached counterpart of `stride_view`.
///
/// The result shares this tensor's storage, uses the supplied geometry as-is,
/// requires grad, is not a leaf and carries `grad_fn`.
pub fn stride_view_operation(
    &self,
    new_shape: Vec<usize>,
    new_strides: Vec<isize>,
    new_offset: usize,
    grad_fn: Arc<dyn GradFn<T>>,
) -> Self {
    let inner = TensorInner {
        id: TensorId::next(),
        storage: Arc::clone(&self.inner.storage),
        shape: new_shape,
        strides: new_strides,
        offset: new_offset,
        grad: Mutex::new(None),
        grad_fn: Some(grad_fn),
        requires_grad: true,
        is_leaf: false,
        hooks: Mutex::new(crate::autograd::hooks::HookStorage::new()),
    };
    Self {
        inner: Arc::new(inner),
    }
}
/// Builds the result tensor of a differentiable operation.
///
/// In inference mode the `grad_fn` is dropped and a plain leaf with
/// `requires_grad = false` is returned. Otherwise the result is a non-leaf
/// that requires grad and is laid out row-major over `shape`.
///
/// # Errors
/// `ShapeMismatch` when `shape` needs more elements than `storage` holds.
pub fn from_operation(
    storage: TensorStorage<T>,
    shape: Vec<usize>,
    grad_fn: Arc<dyn GradFn<T>>,
) -> FerrotorchResult<Self> {
    if crate::autograd::no_grad::is_inference_mode() {
        return Self::from_storage(storage, shape, false);
    }
    let required = shape.iter().product::<usize>();
    if required > storage.len() {
        let message = format!(
            "shape {:?} requires {} elements but storage has {}",
            shape,
            required,
            storage.len()
        );
        return Err(FerrotorchError::ShapeMismatch { message });
    }
    let strides = c_contiguous_strides(&shape);
    let inner = TensorInner {
        id: TensorId::next(),
        storage: Arc::new(storage),
        shape,
        strides,
        offset: 0,
        grad: Mutex::new(None),
        grad_fn: Some(grad_fn),
        requires_grad: true,
        is_leaf: false,
        hooks: Mutex::new(crate::autograd::hooks::HookStorage::new()),
    };
    Ok(Self {
        inner: Arc::new(inner),
    })
}
}
/// Backward node for device transfers: routes the gradient back to the
/// device of the tensor the transfer started from.
#[derive(Debug)]
struct ToDeviceBackward<T: Float> {
    /// The tensor that was transferred.
    source: Tensor<T>,
}

impl<T: Float> GradFn<T> for ToDeviceBackward<T> {
    /// Returns `grad_output` unchanged if it already lives on the source's
    /// device, otherwise a copy moved there.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let home = self.source.device();
        let routed = if grad_output.device() != home {
            grad_output.to(home)?
        } else {
            grad_output.clone()
        };
        Ok(vec![Some(routed)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        let source = &self.source;
        vec![source]
    }

    fn name(&self) -> &'static str {
        "ToDeviceBackward"
    }
}
impl<T: Float> Tensor<T> {
/// Returns the identifier stored in this tensor's shared inner state.
#[inline]
pub fn id(&self) -> TensorId {
self.inner.id
}
/// Returns the extent of each dimension; empty for a scalar.
#[inline]
pub fn shape(&self) -> &[usize] {
&self.inner.shape
}
/// Number of dimensions (0 for a scalar).
#[inline]
pub fn ndim(&self) -> usize {
    let dims: &[usize] = &self.inner.shape;
    dims.len()
}
/// Total element count: the product of all dimension extents (1 for a scalar).
#[inline]
pub fn numel(&self) -> usize {
    let dims: &[usize] = &self.inner.shape;
    dims.iter().copied().product()
}
/// Returns the per-dimension strides, expressed in elements.
#[inline]
pub fn strides(&self) -> &[isize] {
&self.inner.strides
}
/// Returns the element offset into the storage at which this tensor starts.
#[inline]
pub fn storage_offset(&self) -> usize {
self.inner.offset
}
/// Returns the length reported by the backing storage (not this tensor's `numel`).
#[inline]
pub fn storage_len(&self) -> usize {
self.inner.storage.len()
}
/// Borrows the backing storage.
#[inline]
pub fn storage(&self) -> &TensorStorage<T> {
&self.inner.storage
}
/// Returns the device reported by the backing storage.
#[inline]
pub fn device(&self) -> Device {
self.inner.storage.device()
}
/// Returns whether this tensor is flagged as requiring gradients.
#[inline]
pub fn requires_grad(&self) -> bool {
self.inner.requires_grad
}
/// Returns the leaf flag recorded when the tensor was constructed.
#[inline]
pub fn is_leaf(&self) -> bool {
self.inner.is_leaf
}
/// Returns the backward node attached to this tensor, or `None` if it has none.
#[inline]
pub fn grad_fn(&self) -> Option<&Arc<dyn GradFn<T>>> {
self.inner.grad_fn.as_ref()
}
/// Crate-internal access to the mutex guarding this tensor's hook storage.
pub(crate) fn hooks(&self) -> &Mutex<crate::autograd::hooks::HookStorage<T>> {
&self.inner.hooks
}
/// Registers a gradient hook on this tensor and returns a handle for removing it.
///
/// # Errors
/// `LockPoisoned` if the hook-storage mutex is poisoned.
pub fn register_hook<F>(&self, func: F) -> FerrotorchResult<crate::autograd::hooks::HookHandle>
where
    F: Fn(&Tensor<T>) -> Option<Tensor<T>> + Send + Sync + 'static,
{
    match self.inner.hooks.lock() {
        Ok(mut registry) => Ok(registry.add_grad_hook(func)),
        Err(poison) => Err(FerrotorchError::LockPoisoned {
            message: format!("hook storage mutex: {poison}"),
        }),
    }
}
/// Registers a post-accumulate hook on this tensor and returns its handle.
///
/// # Errors
/// `LockPoisoned` if the hook-storage mutex is poisoned.
pub fn register_post_accumulate_grad_hook<F>(
    &self,
    func: F,
) -> FerrotorchResult<crate::autograd::hooks::HookHandle>
where
    F: Fn(&Tensor<T>) + Send + Sync + 'static,
{
    match self.inner.hooks.lock() {
        Ok(mut registry) => Ok(registry.add_post_accumulate_hook(func)),
        Err(poison) => Err(FerrotorchError::LockPoisoned {
            message: format!("hook storage mutex: {poison}"),
        }),
    }
}
/// Removes the hook identified by `handle`, forwarding the storage's `remove` result.
///
/// # Errors
/// `LockPoisoned` if the hook-storage mutex is poisoned.
pub fn remove_hook(
    &self,
    handle: crate::autograd::hooks::HookHandle,
) -> FerrotorchResult<bool> {
    match self.inner.hooks.lock() {
        Ok(mut registry) => Ok(registry.remove(handle)),
        Err(poison) => Err(FerrotorchError::LockPoisoned {
            message: format!("hook storage mutex: {poison}"),
        }),
    }
}
/// Returns a handle to the stored gradient, or `None` when nothing is stored.
///
/// # Errors
/// `LockPoisoned` if the gradient mutex is poisoned.
pub fn grad(&self) -> FerrotorchResult<Option<Tensor<T>>> {
    match self.inner.grad.lock() {
        Ok(slot) => Ok(slot.as_deref().cloned()),
        Err(poison) => Err(FerrotorchError::LockPoisoned {
            message: format!("grad mutex: {poison}"),
        }),
    }
}
/// Replaces the stored gradient with `grad` (`None` clears it).
///
/// # Errors
/// `LockPoisoned` if the gradient mutex is poisoned.
pub fn set_grad(&self, grad: Option<Tensor<T>>) -> FerrotorchResult<()> {
    match self.inner.grad.lock() {
        Ok(mut slot) => {
            *slot = grad.map(Box::new);
            Ok(())
        }
        Err(poison) => Err(FerrotorchError::LockPoisoned {
            message: format!("grad mutex: {poison}"),
        }),
    }
}
/// Clears the stored gradient by setting it to `None`.
///
/// # Errors
/// Propagates the error returned by `set_grad`.
pub fn zero_grad(&self) -> FerrotorchResult<()> {
self.set_grad(None)
}
/// Adds `incoming` into this tensor's stored gradient.
///
/// - With no gradient stored yet, a copy of `incoming` becomes the gradient.
/// - When both the stored gradient and `incoming` are CUDA tensors, they are
///   summed through the GPU backend (`add_f32` for 4-byte elements, `add_f64` otherwise).
/// - In every other case both sides are read to host vectors, summed element by
///   element, and the result is placed on the stored gradient's device.
///
/// # Errors
/// `LockPoisoned`, `DeviceUnavailable`, `ShapeMismatch` on differing element
/// counts, or any error propagated from the storage / backend calls.
pub(crate) fn accumulate_grad(&self, incoming: &Tensor<T>) -> FerrotorchResult<()> {
    let mut slot = match self.inner.grad.lock() {
        Ok(slot) => slot,
        Err(poison) => {
            return Err(FerrotorchError::LockPoisoned {
                message: format!("grad mutex: {poison}"),
            });
        }
    };
    // Compute the replacement first; the slot is only written on success.
    let updated = match slot.as_deref() {
        None => {
            let (storage, shape) = incoming.clone().into_storage_and_shape()?;
            Tensor::from_storage(storage, shape, false)?
        }
        Some(current) if current.is_cuda() && incoming.is_cuda() => {
            let backend =
                crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            if current.numel() != incoming.numel() {
                return Err(FerrotorchError::ShapeMismatch {
                    message: format!(
                        "gradient accumulation shape mismatch: {:?} vs {:?}",
                        current.shape(),
                        incoming.shape()
                    ),
                });
            }
            let lhs = current.gpu_handle()?;
            let rhs = incoming.gpu_handle()?;
            let summed = if std::mem::size_of::<T>() == 4 {
                backend.add_f32(lhs, rhs)?
            } else {
                backend.add_f64(lhs, rhs)?
            };
            let storage = TensorStorage::gpu(summed);
            Tensor::from_storage(storage, current.shape().to_vec(), false)?
        }
        Some(current) => {
            let addend = incoming.data_vec()?;
            let mut total = current.data_vec()?;
            if total.len() != addend.len() {
                return Err(FerrotorchError::ShapeMismatch {
                    message: format!(
                        "gradient accumulation shape mismatch: {:?} vs {:?}",
                        current.shape(),
                        incoming.shape()
                    ),
                });
            }
            for (acc, &delta) in total.iter_mut().zip(addend.iter()) {
                *acc += delta;
            }
            let device = current.device();
            Tensor::from_storage(
                TensorStorage::on_device(total, device)?,
                current.shape().to_vec(),
                false,
            )?
        }
    };
    *slot = Some(Box::new(updated));
    Ok(())
}
/// Borrows this tensor's elements as a contiguous slice of host memory.
///
/// # Errors
/// - `GpuTensorNotAccessible` for GPU or CubeCL storage.
/// - `InvalidArgument` for meta tensors, non-contiguous tensors, or a view
///   that extends beyond the storage.
/// - Any error from `try_as_slice`.
pub fn data(&self) -> FerrotorchResult<&[T]> {
    let storage = &self.inner.storage;
    if storage.is_gpu() || storage.is_cubecl() {
        return Err(FerrotorchError::GpuTensorNotAccessible);
    }
    if storage.is_meta() {
        return Err(FerrotorchError::InvalidArgument {
            message: "cannot read data from a meta tensor; meta tensors carry shape only. \
                Call .to(Device::Cpu) to materialize, or use .shape() / .numel() / .device() \
                for metadata access."
                .into(),
        });
    }
    if !self.is_contiguous() {
        return Err(FerrotorchError::InvalidArgument {
            message: "tensor is not contiguous; call .contiguous() or use .data_vec()".into(),
        });
    }
    let elements = storage.try_as_slice()?;
    let start = self.inner.offset;
    let stop = start + self.numel();
    if stop > elements.len() {
        return Err(FerrotorchError::InvalidArgument {
            message: "tensor view extends beyond storage".into(),
        });
    }
    Ok(&elements[start..stop])
}
/// Alias for [`Tensor::data`]; forwards to it unchanged.
#[inline]
pub fn data_ref(&self) -> FerrotorchResult<&[T]> {
self.data()
}
/// Copies this tensor's elements into a `Vec` in logical (row-major) order.
///
/// - CUDA / CubeCL tensors are first moved to the CPU via `cpu()`.
/// - Contiguous host tensors are copied straight from `data()`.
/// - Non-contiguous host tensors are gathered element by element using the
///   tensor's strides and offset.
///
/// # Errors
/// - `InvalidArgument` for meta tensors.
/// - `InvalidArgument` when the strides do not match the shape's rank or when
///   a strided element falls outside the storage (instead of panicking).
/// - Any error propagated from the transfer or storage access.
pub fn data_vec(&self) -> FerrotorchResult<Vec<T>> {
    if self.inner.storage.is_meta() {
        return Err(FerrotorchError::InvalidArgument {
            message: "cannot read data from a meta tensor; meta tensors carry shape only. \
                Call .to(Device::Cpu) to materialize, or use .shape() / .numel() / .device() \
                for metadata access."
                .into(),
        });
    }
    if self.is_cuda() || self.inner.storage.is_cubecl() {
        let cpu_tensor = self.cpu()?;
        return Ok(cpu_tensor.data()?.to_vec());
    }
    if self.is_contiguous() {
        return Ok(self.data()?.to_vec());
    }

    let slice = self.inner.storage.try_as_slice()?;
    let shape = &self.inner.shape;
    let strides = &self.inner.strides;
    let offset = self.inner.offset;
    let numel = self.numel();
    let ndim = shape.len();
    // Views built through `stride_view` are not validated, so guard the rank here.
    if strides.len() != ndim {
        return Err(FerrotorchError::InvalidArgument {
            message: "tensor strides rank does not match shape rank".into(),
        });
    }

    let mut result = Vec::with_capacity(numel);
    let mut indices = vec![0usize; ndim];
    for _ in 0..numel {
        let flat = indices
            .iter()
            .zip(strides.iter())
            .fold(offset as isize, |acc, (&i, &s)| acc + i as isize * s);
        // A negative or too-large position means the view leaves the storage.
        let value = usize::try_from(flat)
            .ok()
            .and_then(|pos| slice.get(pos))
            .ok_or_else(|| FerrotorchError::InvalidArgument {
                message: "tensor view extends beyond storage".into(),
            })?;
        result.push(*value);
        // Advance the multi-index like an odometer, last dimension fastest.
        for d in (0..ndim).rev() {
            indices[d] += 1;
            if indices[d] < shape[d] {
                break;
            }
            indices[d] = 0;
        }
    }
    Ok(result)
}
pub fn into_storage_and_shape(self) -> FerrotorchResult<(TensorStorage<T>, Vec<usize>)> {
if !self.is_contiguous() {
let data = self.data_vec()?;
let shape = self.shape().to_vec();
let device = self.device();
return Ok((TensorStorage::on_device(data, device)?, shape));
}
let shape = self.inner.shape.clone();
let offset = self.inner.offset;
let numel: usize = shape.iter().product();
match Arc::try_unwrap(self.inner) {
Ok(inner) => {
match Arc::try_unwrap(inner.storage) {
Ok(storage) if offset == 0 && storage.len() == numel => {
Ok((storage, shape))
}
Ok(storage) => {
let sub = storage.try_clone_subregion(offset, numel)?;
Ok((sub, shape))
}
Err(arc_storage) => {
let sub = arc_storage.try_clone_subregion(offset, numel)?;
Ok((sub, shape))
}
}
}
Err(arc_inner) => {
let sub = arc_inner
.storage
.try_clone_subregion(arc_inner.offset, numel)?;
Ok((sub, shape))
}
}
}
pub fn to(&self, device: Device) -> FerrotorchResult<Tensor<T>> {
if self.device() == device {
return Ok(self.clone());
}
let needs_grad_fn =
self.requires_grad() && !self.is_leaf() && crate::autograd::no_grad::is_grad_enabled();
match (self.device(), device) {
(Device::Cpu, Device::Cuda(ordinal)) => {
let contiguous_self = if self.is_contiguous() {
self.clone()
} else {
crate::methods::contiguous_t(self)?
};
let backend =
crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
let cpu_data = contiguous_self.data()?;
let bytes = unsafe {
std::slice::from_raw_parts(
cpu_data.as_ptr().cast::<u8>(),
std::mem::size_of_val(cpu_data),
)
};
let handle = backend.cpu_to_gpu(bytes, T::dtype(), ordinal)?;
let storage = TensorStorage::gpu(handle);
if needs_grad_fn {
let grad_fn = Arc::new(ToDeviceBackward {
source: self.clone(),
});
Tensor::from_operation(storage, self.shape().to_vec(), grad_fn)
} else {
Tensor::from_storage(storage, self.shape().to_vec(), self.requires_grad())
}
}
(Device::Cuda(_), Device::Cpu) => {
use std::any::TypeId;
let backend =
crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
let handle = self.gpu_handle()?;
let numel = self.numel();
let needs_materialize =
!self.is_contiguous() || self.storage_offset() != 0 || handle.len() != numel;
let bytes = if needs_materialize
&& self.shape().len() <= 8
&& (TypeId::of::<T>() == TypeId::of::<f32>()
|| TypeId::of::<T>() == TypeId::of::<f64>())
{
let view_shape = self.shape().to_vec();
let src_strides = self.strides().to_vec();
let src_offset = self.storage_offset();
let materialized = if TypeId::of::<T>() == TypeId::of::<f32>() {
backend.strided_copy_f32(handle, &view_shape, &src_strides, src_offset)?
} else {
backend.strided_copy_f64(handle, &view_shape, &src_strides, src_offset)?
};
backend.gpu_to_cpu(&materialized)?
} else {
backend.gpu_to_cpu(handle)?
};
let data: Vec<T> = unsafe {
let mut bytes = std::mem::ManuallyDrop::new(bytes);
let len = bytes.len() / std::mem::size_of::<T>();
let cap = bytes.capacity() / std::mem::size_of::<T>();
Vec::from_raw_parts(bytes.as_mut_ptr().cast::<T>(), len, cap)
};
let storage = TensorStorage::cpu(data);
if needs_grad_fn {
let grad_fn = Arc::new(ToDeviceBackward {
source: self.clone(),
});
Tensor::from_operation(storage, self.shape().to_vec(), grad_fn)
} else {
Tensor::from_storage(storage, self.shape().to_vec(), self.requires_grad())
}
}
(Device::Cuda(a), Device::Cuda(b)) if a != b => {
let cpu = self.to(Device::Cpu)?;
cpu.to(Device::Cuda(b))
}
(Device::Cpu, Device::Xpu(_)) => Err(FerrotorchError::InvalidArgument {
message: "CPU→XPU transfer requires a CubeRuntime. \
Use ferrotorch_xpu::make_xpu_tensor or \
ferrotorch_xpu::XpuDevice::upload instead. Issue #673."
.into(),
}),
(Device::Xpu(_), Device::Cpu) => {
let handle = self.inner.storage.cubecl_handle().ok_or_else(|| {
FerrotorchError::InvalidArgument {
message: "XPU→CPU transfer: storage does not contain a CubeCL handle. \
This tensor may have been created before issue #673 was applied."
.into(),
}
})?;
let host_f32 = handle.read_to_host()?;
let data: Vec<T> = {
if std::mem::size_of::<T>() != std::mem::size_of::<f32>() {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"XPU→CPU: expected f32 storage (size 4), got size {}; \
only f32 XPU tensors are supported. Issue #673.",
std::mem::size_of::<T>()
),
});
}
unsafe {
let mut md = std::mem::ManuallyDrop::new(host_f32);
Vec::from_raw_parts(md.as_mut_ptr().cast::<T>(), md.len(), md.capacity())
}
};
let storage = TensorStorage::cpu(data);
if needs_grad_fn {
let grad_fn = Arc::new(ToDeviceBackward {
source: self.clone(),
});
Tensor::from_operation(storage, self.shape().to_vec(), grad_fn)
} else {
Tensor::from_storage(storage, self.shape().to_vec(), self.requires_grad())
}
}
(Device::Xpu(a), Device::Xpu(b)) if a != b => {
let cpu = self.to(Device::Cpu)?;
cpu.to(Device::Xpu(b))
}
(Device::Cuda(_), Device::Xpu(_)) | (Device::Xpu(_), Device::Cuda(_)) => {
let cpu = self.to(Device::Cpu)?;
cpu.to(device)
}
(_, Device::Meta) => {
let storage = TensorStorage::meta(self.numel());
Tensor::from_storage(storage, self.shape().to_vec(), self.requires_grad())
}
(Device::Meta, _) => Err(FerrotorchError::InvalidArgument {
message: format!(
"cannot move a meta tensor to {device} -- meta tensors carry no data. \
Construct a real tensor on {device} via creation::zeros/randn/etc."
),
}),
_ => Ok(self.clone()),
}
}
/// Like [`Tensor::to`], but CPU → CUDA transfers go through pinned storage.
///
/// Every other source/target combination is delegated to `to`. Returns a
/// cheap clone when the tensor already lives on `device`.
///
/// # Errors
/// Propagates errors from making the data contiguous, from
/// `TensorStorage::on_device_pinned`, and from tensor construction.
pub fn to_pinned(&self, device: Device) -> FerrotorchResult<Tensor<T>> {
    let origin = self.device();
    if origin == device {
        return Ok(self.clone());
    }
    if !matches!((origin, device), (Device::Cpu, Device::Cuda(_))) {
        return self.to(device);
    }

    let needs_grad_fn =
        self.requires_grad() && !self.is_leaf() && crate::autograd::no_grad::is_grad_enabled();
    let compact = if self.is_contiguous() {
        self.clone()
    } else {
        crate::methods::contiguous_t(self)?
    };
    let owned: Vec<T> = compact.data()?.to_vec();
    let storage = TensorStorage::on_device_pinned(owned, device)?;
    let dims = self.shape().to_vec();
    if !needs_grad_fn {
        return Tensor::from_storage(storage, dims, self.requires_grad());
    }
    let grad_fn = Arc::new(ToDeviceBackward {
        source: self.clone(),
    });
    Tensor::from_operation(storage, dims, grad_fn)
}
/// Shorthand for `self.to(Device::Cuda(0))`.
pub fn cuda(&self) -> FerrotorchResult<Tensor<T>> {
self.to(Device::Cuda(0))
}
/// Shorthand for `self.to(Device::Cpu)`.
pub fn cpu(&self) -> FerrotorchResult<Tensor<T>> {
self.to(Device::Cpu)
}
pub fn to_dtype<U: Float>(&self) -> FerrotorchResult<Tensor<U>> {
use std::any::TypeId;
if TypeId::of::<T>() == TypeId::of::<U>() {
let cloned = self.clone();
return Ok(unsafe {
let md = std::mem::ManuallyDrop::new(cloned);
std::mem::transmute_copy::<Tensor<T>, Tensor<U>>(&md)
});
}
match self.device() {
Device::Cpu => {
let materialised = if self.is_contiguous() {
self.clone()
} else {
crate::methods::contiguous_t(self)?
};
let src = materialised.data()?;
let mut out: Vec<U> = Vec::with_capacity(src.len());
for (i, &v) in src.iter().enumerate() {
out.push(crate::numeric_cast::cast::<T, U>(v).map_err(|_| {
FerrotorchError::InvalidArgument {
message: format!(
"Tensor::to_dtype: element {i} = {v:?} not representable in {}",
U::dtype()
),
}
})?);
}
let storage = TensorStorage::cpu(out);
Tensor::<U>::from_storage(storage, self.shape().to_vec(), false)
}
Device::Cuda(_) => {
let materialised = if self.is_contiguous() {
self.clone()
} else {
crate::methods::contiguous_t(self)?
};
let backend =
crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
let src_handle = materialised.gpu_handle()?;
let new_handle = backend.cast_f_to_f(src_handle, U::dtype())?;
let storage = TensorStorage::gpu(new_handle);
Tensor::<U>::from_storage(storage, self.shape().to_vec(), false)
}
_ => Err(FerrotorchError::InvalidArgument {
message: format!(
"Tensor::to_dtype: unsupported source device {:?}",
self.device()
),
}),
}
}
/// Returns `true` when the tensor's device reports itself as CPU.
#[inline]
pub fn is_cpu(&self) -> bool {
self.device().is_cpu()
}
/// Returns `true` when the tensor's device reports itself as meta.
#[inline]
pub fn is_meta(&self) -> bool {
self.device().is_meta()
}
/// Forwards to the storage's `meta_fill_value`; `None` when the storage has none.
#[inline]
pub fn meta_fill_value(&self) -> Option<&T> {
self.inner.storage.meta_fill_value()
}
/// Returns `true` when the tensor's device reports itself as CUDA.
#[inline]
pub fn is_cuda(&self) -> bool {
self.device().is_cuda()
}
#[inline]
pub fn is_xpu(&self) -> bool {
matches!(self.device(), crate::device::Device::Xpu(_))
}
/// Borrows the GPU buffer handle backing this tensor.
///
/// # Errors
/// `InvalidArgument` when the storage exposes no GPU handle.
pub fn gpu_handle(&self) -> FerrotorchResult<&crate::gpu_dispatch::GpuBufferHandle> {
    // Build the error lazily: `ok_or` would allocate the message string on
    // every call, including the successful ones.
    self.inner
        .storage
        .gpu_handle()
        .ok_or_else(|| FerrotorchError::InvalidArgument {
            message: "tensor is on CPU, not GPU".into(),
        })
}
/// Method form of `crate::grad_fns::indexing::masked_fill_bt`; forwards
/// `self`, `mask` and `value` unchanged.
#[inline]
pub fn masked_fill(
&self,
mask: &crate::bool_tensor::BoolTensor,
value: T,
) -> FerrotorchResult<Tensor<T>> {
crate::grad_fns::indexing::masked_fill_bt(self, mask, value)
}
/// Method form of `crate::ops::indexing::masked_select`; forwards `self`
/// and `mask` unchanged.
#[inline]
pub fn masked_select(
&self,
mask: &crate::bool_tensor::BoolTensor,
) -> FerrotorchResult<Tensor<T>> {
crate::ops::indexing::masked_select(self, mask)
}
/// Returns a mutable slice over this tensor's elements inside the shared storage.
///
/// # Safety
/// The storage sits behind an `Arc` and is reached through a pointer cast
/// from a shared reference. The caller must guarantee that nothing else reads
/// or writes the same storage while the returned slice is alive.
///
/// # Errors
/// `InvalidArgument` if the tensor is not contiguous or the view extends
/// beyond the storage; errors from `try_as_mut_slice` are propagated.
#[allow(clippy::mut_from_ref)]
pub unsafe fn data_mut(&self) -> FerrotorchResult<&mut [T]> {
if !self.is_contiguous() {
return Err(FerrotorchError::InvalidArgument {
message: "data_mut requires a contiguous tensor".into(),
});
}
// Obtain a mutable pointer to the storage behind the `Arc`; soundness
// rests entirely on the caller's exclusivity guarantee.
let storage_ptr = Arc::as_ptr(&self.inner.storage).cast_mut();
let storage = unsafe { &mut *storage_ptr };
let slice = storage.try_as_mut_slice()?;
let end = self.inner.offset + self.numel();
if end > slice.len() {
return Err(FerrotorchError::InvalidArgument {
message: "tensor view extends beyond storage".into(),
});
}
Ok(&mut slice[self.inner.offset..end])
}
/// Overwrites this tensor's elements with `new_data`, in place.
///
/// For GPU storage the data is uploaded and the storage's buffer is replaced
/// by the new handle. For host storage the elements are copied into the
/// window `offset..offset + numel` of the existing buffer.
///
/// # Safety
/// The storage sits behind an `Arc` and is mutated through a pointer cast
/// from a shared reference. The caller must guarantee that nothing else reads
/// or writes the same storage for the duration of the call.
///
/// # Errors
/// - `ShapeMismatch` when `new_data.len()` differs from `numel()`.
/// - `DeviceUnavailable` when GPU storage is used without a registered backend.
/// - `InvalidArgument` when the tensor's window extends beyond the host
///   storage (reported as an error rather than a slice-index panic, matching
///   `data_mut`).
pub unsafe fn update_data(&self, new_data: &[T]) -> FerrotorchResult<()> {
    let numel = self.numel();
    if new_data.len() != numel {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "update_data: new data has {} elements but tensor has {}",
                new_data.len(),
                numel,
            ),
        });
    }
    let storage_ptr = Arc::as_ptr(&self.inner.storage).cast_mut();
    // SAFETY: exclusivity of access is the caller's obligation (see `# Safety`).
    let storage = unsafe { &mut *storage_ptr };
    if storage.is_gpu() {
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let ordinal = match storage.device() {
            Device::Cuda(o) => o,
            _ => unreachable!(),
        };
        // SAFETY: `new_data` is a live, initialized slice viewed as exactly
        // `size_of_val(new_data)` bytes for the duration of the upload.
        let bytes: &[u8] = unsafe {
            std::slice::from_raw_parts(
                new_data.as_ptr().cast::<u8>(),
                std::mem::size_of_val(new_data),
            )
        };
        let new_handle = backend.cpu_to_gpu(bytes, T::dtype(), ordinal)?;
        storage.data = crate::storage::StorageBuffer::Gpu(new_handle);
    } else {
        let slice = storage.try_as_mut_slice()?;
        let offset = self.inner.offset;
        let end = offset
            .checked_add(numel)
            .filter(|&end| end <= slice.len())
            .ok_or_else(|| FerrotorchError::InvalidArgument {
                message: "tensor view extends beyond storage".into(),
            })?;
        slice[offset..end].copy_from_slice(new_data);
    }
    Ok(())
}
/// Replaces this tensor's storage and shape in place.
///
/// After the call the tensor has `new_shape`, row-major strides for that
/// shape, and offset 0. The previous storage value is dropped.
///
/// # Safety
/// Both the storage and the inner state sit behind `Arc`s and are overwritten
/// through pointer casts from shared references. The caller must guarantee
/// that no other access to this tensor's inner state or to the shared storage
/// happens concurrently, and that no outstanding borrows of the old shape,
/// strides or storage exist.
///
/// # Errors
/// `ShapeMismatch` when `new_storage.len()` differs from the element count of `new_shape`.
pub unsafe fn update_storage_and_shape(
&self,
new_storage: TensorStorage<T>,
new_shape: Vec<usize>,
) -> FerrotorchResult<()> {
let new_numel: usize = new_shape.iter().product();
if new_storage.len() != new_numel {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"update_storage_and_shape: new storage has {} elements but \
new shape {:?} requires {}",
new_storage.len(),
new_shape,
new_numel,
),
});
}
let new_strides = c_contiguous_strides(&new_shape);
// Swap the storage value inside the `Arc` allocation; every tensor sharing
// this `Arc` observes the new storage.
let storage_ptr = Arc::as_ptr(&self.inner.storage).cast_mut();
let old_storage = unsafe { std::ptr::replace(storage_ptr, new_storage) };
drop(old_storage);
// Rewrite the geometry fields of the shared inner state.
let inner_ptr = Arc::as_ptr(&self.inner).cast_mut();
unsafe {
(*inner_ptr).shape = new_shape;
(*inner_ptr).strides = new_strides;
(*inner_ptr).offset = 0;
}
Ok(())
}
/// Replaces this tensor's storage in place, keeping shape, strides and offset.
///
/// The previous storage value is dropped.
///
/// # Safety
/// The storage sits behind an `Arc` and is overwritten through a pointer cast
/// from a shared reference. The caller must guarantee that nothing else
/// accesses the same storage concurrently and that no borrows of the old
/// storage are outstanding.
///
/// # Errors
/// `ShapeMismatch` when `new_storage.len()` differs from `numel()`.
pub unsafe fn update_storage(&self, new_storage: TensorStorage<T>) -> FerrotorchResult<()> {
let numel = self.numel();
if new_storage.len() != numel {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"update_storage: new storage has {} elements but tensor has {}",
new_storage.len(),
numel,
),
});
}
// Swap the storage value inside the `Arc` allocation.
let storage_ptr = Arc::as_ptr(&self.inner.storage).cast_mut();
let old = unsafe { std::ptr::replace(storage_ptr, new_storage) };
drop(old);
Ok(())
}
/// Runs `f` with mutable access to this tensor's GPU buffer handle and
/// returns whatever `f` returns.
///
/// # Errors
/// `DeviceUnavailable` when the storage exposes no mutable GPU handle;
/// otherwise the result of `f`.
///
/// NOTE(review): this is a safe function that creates a `&mut` to storage
/// shared through an `Arc`. It looks like callers are expected to ensure
/// nothing else touches the storage during the call — confirm and consider
/// documenting or enforcing that contract.
pub fn with_gpu_handle_mut<R>(
&self,
f: impl FnOnce(&mut crate::gpu_dispatch::GpuBufferHandle) -> FerrotorchResult<R>,
) -> FerrotorchResult<R> {
// Obtain a mutable pointer to the storage behind the `Arc`.
let storage_ptr = Arc::as_ptr(&self.inner.storage).cast_mut();
let storage: &mut TensorStorage<T> = unsafe { &mut *storage_ptr };
let handle = storage
.gpu_handle_mut()
.ok_or(FerrotorchError::DeviceUnavailable)?;
f(handle)
}
/// Returns a new leaf that shares this tensor's storage and geometry but is
/// cut off from the graph: fresh id, no gradient, no `grad_fn`, no hooks,
/// `requires_grad = false`.
pub fn detach(&self) -> Self {
    let src = &*self.inner;
    let inner = TensorInner {
        id: TensorId::next(),
        storage: Arc::clone(&src.storage),
        shape: src.shape.clone(),
        strides: src.strides.clone(),
        offset: src.offset,
        grad: Mutex::new(None),
        grad_fn: None,
        requires_grad: false,
        is_leaf: true,
        hooks: Mutex::new(crate::autograd::hooks::HookStorage::new()),
    };
    Self {
        inner: Arc::new(inner),
    }
}
/// Consumes the tensor and returns one with the `requires_grad` flag set as given.
///
/// The result keeps the same id, storage, geometry, `grad_fn` and leaf flag,
/// but starts with an empty gradient slot and empty hook storage.
pub fn requires_grad_(self, requires_grad: bool) -> Self {
    let src = &*self.inner;
    let inner = TensorInner {
        id: src.id,
        storage: Arc::clone(&src.storage),
        shape: src.shape.clone(),
        strides: src.strides.clone(),
        offset: src.offset,
        grad: Mutex::new(None),
        grad_fn: src.grad_fn.clone(),
        requires_grad,
        is_leaf: src.is_leaf,
        hooks: Mutex::new(crate::autograd::hooks::HookStorage::new()),
    };
    Self {
        inner: Arc::new(inner),
    }
}
/// Returns `true` when the tensor's strides describe a row-major layout.
///
/// Rules applied:
/// - A scalar is contiguous.
/// - A tensor with any zero-sized dimension holds no elements and is always
///   contiguous, regardless of which dimension is empty or what the strides are.
/// - Strides of size-1 dimensions are ignored.
/// - Otherwise, scanning from the last dimension, each stride must equal the
///   product of the extents of the dimensions after it.
///
/// A tensor whose stride count does not match its rank is reported as
/// non-contiguous instead of indexing out of bounds.
pub fn is_contiguous(&self) -> bool {
    let shape = &self.inner.shape;
    let strides = &self.inner.strides;
    // Decide emptiness up front so the answer does not depend on whether an
    // inner dimension's stride happens to be checked before the empty one.
    if shape.contains(&0) {
        return true;
    }
    if strides.len() != shape.len() {
        return false;
    }
    let mut expected_stride: isize = 1;
    for (&extent, &stride) in shape.iter().zip(strides.iter()).rev() {
        if extent == 1 {
            continue;
        }
        if stride != expected_stride {
            return false;
        }
        expected_stride *= extent as isize;
    }
    true
}
/// Returns `true` when the tensor's strides already match `format`.
///
/// `ChannelsLast` only ever matches 4-D tensors and `ChannelsLast3d` only
/// 5-D tensors; strides of size-1 dimensions are not compared.
pub fn is_contiguous_for(&self, format: MemoryFormat) -> bool {
    let shape = &self.inner.shape;
    let strides = &self.inner.strides;
    match format {
        MemoryFormat::Contiguous => self.is_contiguous(),
        MemoryFormat::ChannelsLast => {
            self.ndim() == 4
                && strides_match_with_size1(shape, strides, &channels_last_strides(shape))
        }
        MemoryFormat::ChannelsLast3d => {
            self.ndim() == 5
                && strides_match_with_size1(shape, strides, &channels_last_3d_strides(shape))
        }
    }
}
/// Returns this tensor laid out in `format`.
///
/// If the strides already match, a cheap clone of the handle is returned;
/// otherwise the data is rearranged into new storage.
pub fn to_memory_format(&self, format: MemoryFormat) -> FerrotorchResult<Self> {
    if !self.is_contiguous_for(format) {
        return self.materialize_format(format);
    }
    Ok(self.clone())
}
/// Alias for [`Tensor::to_memory_format`]; forwards `format` unchanged.
pub fn contiguous_in(&self, format: MemoryFormat) -> FerrotorchResult<Self> {
self.to_memory_format(format)
}
/// Copies this tensor into new storage whose strides follow `format`.
///
/// The result is a leaf with a fresh id, offset 0, no `grad_fn`, and the
/// same `requires_grad` flag as `self`.
///
/// # Errors
/// `InvalidArgument` when `ChannelsLast` is requested for a non-4-D tensor or
/// `ChannelsLast3d` for a non-5-D tensor; otherwise errors from the CPU path.
fn materialize_format(&self, format: MemoryFormat) -> FerrotorchResult<Self> {
let shape = &self.inner.shape;
let ndim = shape.len();
// Reject rank/format combinations before computing any strides.
match format {
MemoryFormat::ChannelsLast if ndim != 4 => {
return Err(FerrotorchError::InvalidArgument {
message: format!("ChannelsLast requires a 4D tensor, got {ndim}D"),
});
}
MemoryFormat::ChannelsLast3d if ndim != 5 => {
return Err(FerrotorchError::InvalidArgument {
message: format!("ChannelsLast3d requires a 5D tensor, got {ndim}D"),
});
}
_ => {}
}
let target_strides = match format {
MemoryFormat::Contiguous => c_contiguous_strides(shape),
MemoryFormat::ChannelsLast => channels_last_strides(shape),
MemoryFormat::ChannelsLast3d => channels_last_3d_strides(shape),
};
// GPU path: taken only for CUDA tensors with at most 8 dimensions when a
// backend is registered.
if self.is_cuda()
&& ndim <= 8
&& let Some(backend) = crate::gpu_dispatch::gpu_backend()
{
use std::any::TypeId;
// Reorder shape and source strides into the target format's dimension
// order before handing them to the backend's strided copy.
let perm = format_permutation(format, ndim);
let permuted_shape: Vec<usize> = perm.iter().map(|&d| shape[d]).collect();
let permuted_src_strides: Vec<isize> =
perm.iter().map(|&d| self.inner.strides[d]).collect();
let in_handle = self.gpu_handle()?;
let src_offset = self.inner.offset;
let out_handle = if TypeId::of::<T>() == TypeId::of::<f32>() {
backend.strided_copy_f32(
in_handle,
&permuted_shape,
&permuted_src_strides,
src_offset,
)
} else if TypeId::of::<T>() == TypeId::of::<f64>() {
backend.strided_copy_f64(
in_handle,
&permuted_shape,
&permuted_src_strides,
src_offset,
)
} else {
// Element types other than f32/f64 go straight to the CPU path.
return self.materialize_format_cpu(format, target_strides);
};
// A failed GPU copy is not reported: its error is discarded and the CPU
// path below is used instead.
if let Ok(handle) = out_handle {
let storage = TensorStorage::gpu(handle);
return Ok(Self {
inner: Arc::new(TensorInner {
id: TensorId::next(),
storage: Arc::new(storage),
shape: shape.clone(),
strides: target_strides,
offset: 0,
grad: Mutex::new(None),
grad_fn: None,
requires_grad: self.inner.requires_grad,
is_leaf: true,
hooks: Mutex::new(crate::autograd::hooks::HookStorage::new()),
}),
});
}
}
self.materialize_format_cpu(format, target_strides)
}
/// Host-side rearrangement of this tensor's elements into `target_strides`.
///
/// Every logical index is read through the tensor's own strides and offset
/// and written at the position given by `target_strides`. The new buffer is
/// then placed on the tensor's device. The result is a leaf with a fresh id,
/// offset 0, no `grad_fn`, and the same `requires_grad` flag as `self`.
///
/// `_format` is unused; the layout is fully described by `target_strides`.
///
/// # Errors
/// Propagates errors from reading the source data and from `on_device`.
fn materialize_format_cpu(
    &self,
    _format: MemoryFormat,
    target_strides: Vec<isize>,
) -> FerrotorchResult<Self> {
    let dims = &self.inner.shape;
    let rank = dims.len();
    let total = self.numel();
    let from_strides = &self.inner.strides;
    let base = self.inner.offset;
    let device = self.device();

    // CUDA data is fetched into an owned host vector; host data is borrowed.
    let staged: Vec<T>;
    let source: &[T] = if self.is_cuda() {
        staged = self.data_vec()?;
        &staged
    } else {
        self.inner.storage.try_as_slice()?
    };

    let mut packed = vec![<T as num_traits::Zero>::zero(); total];
    let mut cursor = vec![0usize; rank];
    for _ in 0..total {
        let mut read_at = base as isize;
        let mut write_at: isize = 0;
        for (axis, &coord) in cursor.iter().enumerate() {
            read_at += coord as isize * from_strides[axis];
            write_at += coord as isize * target_strides[axis];
        }
        packed[write_at as usize] = source[read_at as usize];
        // Step the multi-index, last dimension fastest, carrying on overflow.
        for axis in (0..rank).rev() {
            cursor[axis] += 1;
            if cursor[axis] < dims[axis] {
                break;
            }
            cursor[axis] = 0;
        }
    }

    let storage = TensorStorage::on_device(packed, device)?;
    let inner = TensorInner {
        id: TensorId::next(),
        storage: Arc::new(storage),
        shape: dims.clone(),
        strides: target_strides,
        offset: 0,
        grad: Mutex::new(None),
        grad_fn: None,
        requires_grad: self.inner.requires_grad,
        is_leaf: true,
        hooks: Mutex::new(crate::autograd::hooks::HookStorage::new()),
    };
    Ok(Self {
        inner: Arc::new(inner),
    })
}
/// Returns `true` when the tensor has no dimensions (empty shape).
#[inline]
pub fn is_scalar(&self) -> bool {
self.inner.shape.is_empty()
}
/// Extracts the single value of a scalar or one-element tensor.
///
/// # Errors
/// `InvalidArgument` when the tensor holds more (or fewer) than one element;
/// errors from `data()` are propagated.
pub fn item(&self) -> FerrotorchResult<T> {
    if self.is_scalar() || self.numel() == 1 {
        let values = self.data()?;
        return Ok(values[0]);
    }
    Err(FerrotorchError::InvalidArgument {
        message: format!(
            "item() requires a scalar or single-element tensor, got shape {:?}",
            self.shape()
        ),
    })
}
/// Returns `true` when both handles carry the same tensor id.
pub fn is_same(&self, other: &Self) -> bool {
self.inner.id == other.inner.id
}
/// Number of strong references currently held to this tensor's inner state.
#[inline]
pub(crate) fn inner_refcount(&self) -> usize {
Arc::strong_count(&self.inner)
}
/// Number of strong references currently held to this tensor's storage.
#[inline]
pub(crate) fn storage_refcount(&self) -> usize {
Arc::strong_count(&self.inner.storage)
}
/// Borrows the `Arc` that owns this tensor's storage.
#[inline]
pub fn inner_storage_arc(&self) -> &Arc<TensorStorage<T>> {
&self.inner.storage
}
/// Test-only helper: `true` when both tensors point at the same storage allocation.
#[cfg(test)]
pub(crate) fn shares_storage(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.inner.storage, &other.inner.storage)
}
}
/// Dimension order used when copying into `format`.
///
/// Returns `[0, 2, 3, 1]` for 4-D `ChannelsLast`, `[0, 2, 3, 4, 1]` for 5-D
/// `ChannelsLast3d`, and the identity order `0..ndim` in every other case.
fn format_permutation(format: MemoryFormat, ndim: usize) -> Vec<usize> {
    match (format, ndim) {
        (MemoryFormat::ChannelsLast, 4) => vec![0, 2, 3, 1],
        (MemoryFormat::ChannelsLast3d, 5) => vec![0, 2, 3, 4, 1],
        _ => {
            let identity: Vec<usize> = (0..ndim).collect();
            identity
        }
    }
}
/// Compares `actual` strides against `expected`, ignoring size-1 dimensions.
///
/// A dimension of extent 1 can carry any stride without changing the memory
/// layout, so only dimensions with `shape[i] != 1` are compared.
///
/// Returns `false` when the three slices do not all have the same length
/// (rather than indexing out of bounds when `shape` is the longest).
fn strides_match_with_size1(shape: &[usize], actual: &[isize], expected: &[isize]) -> bool {
    if actual.len() != expected.len() || shape.len() != actual.len() {
        return false;
    }
    shape
        .iter()
        .zip(actual.iter().zip(expected.iter()))
        .all(|(&extent, (have, want))| extent == 1 || have == want)
}
impl<T: Float> Clone for Tensor<T> {
    /// Produces another handle to the same tensor; only the reference count
    /// of the shared inner state is incremented.
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}
impl<T: Float> fmt::Debug for Tensor<T> {
    /// Prints metadata only (id, shape, device, grad flags and the name of
    /// the `grad_fn`, if any); element data is not included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner = &*self.inner;
        let grad_fn_name = inner.grad_fn.as_ref().map(|node| node.name());
        let mut out = f.debug_struct("Tensor");
        out.field("id", &inner.id);
        out.field("shape", &inner.shape);
        out.field("device", &self.device());
        out.field("requires_grad", &inner.requires_grad);
        out.field("is_leaf", &inner.is_leaf);
        out.field("grad_fn", &grad_fn_name);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::TensorStorage;
/// A freshly built 2x3 tensor reports the expected metadata.
#[test]
fn test_tensor_from_storage() {
    let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let tensor = Tensor::from_storage(TensorStorage::cpu(values), vec![2, 3], false).unwrap();

    // Geometry.
    assert_eq!(tensor.shape(), &[2, 3]);
    assert_eq!(tensor.strides(), &[3, 1]);
    assert_eq!(tensor.ndim(), 2);
    assert_eq!(tensor.numel(), 6);
    assert!(tensor.is_contiguous());

    // Autograd flags and placement.
    assert!(tensor.is_leaf());
    assert!(!tensor.requires_grad());
    assert_eq!(tensor.device(), Device::Cpu);
}
/// A shape needing more elements than the storage holds is rejected.
#[test]
fn test_tensor_shape_mismatch() {
    let too_small = TensorStorage::cpu(vec![1.0f32, 2.0, 3.0]);
    let outcome = Tensor::from_storage(too_small, vec![2, 3], false);
    assert!(outcome.is_err());
}
/// `data()` exposes the stored elements of a contiguous CPU tensor.
#[test]
fn test_tensor_data_access() {
    let tensor =
        Tensor::from_storage(TensorStorage::cpu(vec![1.0f64, 2.0, 3.0]), vec![3], false).unwrap();
    let seen = tensor.data().unwrap();
    assert_eq!(seen, &[1.0, 2.0, 3.0]);
}
/// An empty shape yields a scalar whose value `item()` returns.
#[test]
#[allow(clippy::float_cmp)]
fn test_tensor_scalar() {
    let scalar = Tensor::from_storage(TensorStorage::cpu(vec![42.0f32]), vec![], false).unwrap();
    assert!(scalar.is_scalar());
    let value = scalar.item().unwrap();
    assert_eq!(value, 42.0);
}
#[test]
fn test_tensor_detach() {
let storage = TensorStorage::cpu(vec![1.0f32, 2.0]);
let t = Tensor::from_storage(storage, vec![2], true).unwrap();
assert!(t.requires_grad());
let d = t.detach();
assert!(!d.requires_grad());
assert!(d.is_leaf());
assert!(d.grad_fn().is_none());
}
#[test]
fn test_tensor_is_send_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<Tensor<f32>>();
assert_send_sync::<Tensor<f64>>();
}
#[test]
fn test_clone_shares_identity() {
let storage = TensorStorage::cpu(vec![1.0f32, 2.0]);
let t = Tensor::from_storage(storage, vec![2], true).unwrap();
let t2 = t.clone();
assert!(t.is_same(&t2));
assert_eq!(t.id(), t2.id());
}
#[test]
fn test_view_operation_shares_storage() {
use crate::grad_fns::shape::FlattenBackward;
let storage = TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
let t = Tensor::from_storage(storage, vec![2, 3], true).unwrap();
let grad_fn = Arc::new(FlattenBackward::new(t.clone(), t.shape().to_vec()));
let view = t.view_operation(vec![6], grad_fn).unwrap();
assert!(t.shares_storage(&view), "view_operation must share storage");
assert!(
!t.is_same(&view),
"view_operation creates new tensor identity"
);
}
#[test]
fn test_clone_shares_grad() {
let storage = TensorStorage::cpu(vec![1.0f32, 2.0, 3.0]);
let t = Tensor::from_storage(storage, vec![3], true).unwrap();
let t2 = t.clone();
let g =
Tensor::from_storage(TensorStorage::cpu(vec![0.1, 0.2, 0.3]), vec![3], false).unwrap();
t.accumulate_grad(&g).unwrap();
let grad = t2.grad().unwrap().unwrap();
let data = grad.data().unwrap();
assert!((data[0] - 0.1).abs() < 1e-7);
}
#[test]
fn test_tensor_grad_accumulation() {
let storage = TensorStorage::cpu(vec![1.0f32, 2.0, 3.0]);
let t = Tensor::from_storage(storage, vec![3], true).unwrap();
assert!(t.grad().unwrap().is_none());
let g1 =
Tensor::from_storage(TensorStorage::cpu(vec![0.1, 0.2, 0.3]), vec![3], false).unwrap();
t.accumulate_grad(&g1).unwrap();
let grad = t.grad().unwrap().unwrap();
let data = grad.data().unwrap();
assert!((data[0] - 0.1).abs() < 1e-7);
let g2 =
Tensor::from_storage(TensorStorage::cpu(vec![1.0, 1.0, 1.0]), vec![3], false).unwrap();
t.accumulate_grad(&g2).unwrap();
let grad = t.grad().unwrap().unwrap();
let data = grad.data().unwrap();
assert!((data[0] - 1.1).abs() < 1e-6);
assert!((data[1] - 1.2).abs() < 1e-6);
assert!((data[2] - 1.3).abs() < 1e-6);
}
#[test]
fn to_dtype_same_dtype_is_zero_copy_clone() {
let storage = TensorStorage::cpu(vec![1.0f32, 2.0, 3.0]);
let t = Tensor::from_storage(storage, vec![3], false).unwrap();
let same: Tensor<f32> = t.to_dtype::<f32>().unwrap();
assert_eq!(same.shape(), &[3usize]);
assert_eq!(same.data().unwrap(), &[1.0_f32, 2.0, 3.0]);
assert_eq!(same.id(), t.id());
}
#[test]
fn to_dtype_cpu_f32_to_bf16_round_trips_bf16_representable_values() {
let storage = TensorStorage::cpu(vec![1.0f32, -2.0, 0.5, 100.0]);
let t = Tensor::from_storage(storage, vec![4], false).unwrap();
let bf16 = t.to_dtype::<half::bf16>().unwrap();
assert_eq!(bf16.shape(), &[4usize]);
let bits: Vec<u16> = bf16.data().unwrap().iter().map(|b| b.to_bits()).collect();
let expect: Vec<u16> = [1.0f32, -2.0, 0.5, 100.0]
.iter()
.map(|&v| half::bf16::from_f32(v).to_bits())
.collect();
assert_eq!(bits, expect);
}
#[test]
fn to_dtype_cpu_bf16_to_f32_widens_exactly() {
let bf16_data: Vec<half::bf16> = [1.0f32, 1.5, -2.25, 100.0]
.iter()
.map(|&v| half::bf16::from_f32(v))
.collect();
let storage = TensorStorage::cpu(bf16_data.clone());
let t = Tensor::from_storage(storage, vec![4], false).unwrap();
let f32_t = t.to_dtype::<f32>().unwrap();
let got = f32_t.data().unwrap();
let want: Vec<f32> = bf16_data.iter().map(|b| b.to_f32()).collect();
for (g, w) in got.iter().zip(want.iter()) {
assert_eq!(g.to_bits(), w.to_bits());
}
}
#[test]
fn to_dtype_cpu_saturating_cast_errors() {
let big = TensorStorage::cpu(vec![1e300_f64, -1e300_f64]);
let t = Tensor::from_storage(big, vec![2], false).unwrap();
let result = t.to_dtype::<f32>();
assert!(
result.is_err(),
"expected saturation error casting 1e300_f64 to f32, got Ok"
);
}
#[test]
fn to_dtype_cpu_preserves_shape() {
let storage = TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
let t = Tensor::from_storage(storage, vec![2, 3], false).unwrap();
let bf16 = t.to_dtype::<half::bf16>().unwrap();
assert_eq!(bf16.shape(), &[2usize, 3]);
assert_eq!(bf16.numel(), 6);
}
}