use crate::{Device, Result, Shape};
use scirs2_core::ndarray::ArrayD;
use std::sync::Arc;
/// A multi-dimensional array generic over element type `T`.
///
/// Holds the element data (CPU or, behind the `gpu` feature, GPU), the
/// logical shape, the owning device, and optional autograd state.
#[derive(Debug, Clone)]
pub struct Tensor<T> {
    /// Backing element storage; see [`TensorStorage`].
    pub storage: TensorStorage<T>,
    // Logical dimensions of the tensor.
    pub(in crate::tensor) shape: Shape,
    // Device the storage lives on (kept in sync with `storage` by construction
    // elsewhere in crate::tensor — not enforced here).
    pub(in crate::tensor) device: Device,
    // Whether autograd should track operations on this tensor.
    pub(in crate::tensor) requires_grad: bool,
    // Accumulated gradient, if any. Arc-wrapped so `Tensor: Clone` stays cheap
    // for the gradient part (clones share the same gradient tensor).
    pub(in crate::tensor) grad: Option<Arc<Tensor<T>>>,
}
/// Where a tensor's elements physically live.
#[derive(Debug, Clone)]
pub enum TensorStorage<T> {
    /// Host-memory storage backed by a dynamically-dimensioned ndarray.
    Cpu(ArrayD<T>),
    /// Device-memory storage; only compiled in with the `gpu` feature.
    #[cfg(feature = "gpu")]
    Gpu(crate::gpu::buffer::GpuBuffer<T>),
}
impl<T> Tensor<T> {
    /// Returns the tensor's logical shape.
    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    /// Returns the device this tensor's storage lives on.
    pub fn device(&self) -> &Device {
        &self.device
    }

    /// Returns the runtime dtype tag corresponding to the element type `T`.
    pub fn dtype(&self) -> crate::DType
    where
        T: 'static,
    {
        crate::dtype_from_type::<T>()
    }

    /// Whether autograd tracks operations on this tensor.
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    /// Enables or disables gradient tracking.
    pub fn set_requires_grad(&mut self, requires_grad: bool) {
        self.requires_grad = requires_grad;
    }

    /// Returns the stored gradient, if one has been set.
    pub fn grad(&self) -> Option<&Tensor<T>> {
        self.grad.as_ref().map(|g| g.as_ref())
    }

    /// Replaces the stored gradient; `None` clears it.
    pub fn set_grad(&mut self, grad: Option<Tensor<T>>) {
        self.grad = grad.map(Arc::new);
    }

    /// Borrows the element data as a contiguous slice.
    ///
    /// # Panics
    ///
    /// Panics if the CPU array is not contiguous (standard layout) or if the
    /// tensor lives on the GPU. Prefer [`Self::as_slice`] for a fallible,
    /// non-panicking variant.
    pub fn data(&self) -> &[T] {
        match &self.storage {
            TensorStorage::Cpu(arr) => arr.as_slice().unwrap_or_else(|| {
                panic!("Tensor data is not contiguous. Use to_owned() or iter() for non-contiguous access.")
            }),
            #[cfg(feature = "gpu")]
            TensorStorage::Gpu(_) => {
                panic!("Cannot access GPU tensor data directly. Use to_cpu() first.")
            }
        }
    }

    /// Returns a clone of the element at `index`.
    ///
    /// Returns `None` when the index rank does not match the tensor rank,
    /// when any coordinate is out of bounds, or when the tensor is on the GPU.
    pub fn get(&self, index: &[usize]) -> Option<T>
    where
        T: Clone,
    {
        match &self.storage {
            TensorStorage::Cpu(arr) => {
                if index.len() != arr.ndim() {
                    return None;
                }
                arr.get(index).cloned()
            }
            #[cfg(feature = "gpu")]
            _ => None,
        }
    }

    /// Borrows the element data as a contiguous slice, or `None` if the CPU
    /// array is non-contiguous or the tensor lives on the GPU.
    pub fn as_slice(&self) -> Option<&[T]> {
        match &self.storage {
            TensorStorage::Cpu(array) => array.as_slice(),
            #[cfg(feature = "gpu")]
            TensorStorage::Gpu(_) => None,
        }
    }

    /// `true` if the tensor contains no elements (some dimension is zero).
    pub fn is_empty(&self) -> bool {
        self.shape.elements() == 0
    }

    /// Approximate memory footprint of the element data in bytes
    /// (element count × `size_of::<T>()`; excludes metadata).
    pub fn memory_usage(&self) -> usize {
        let element_size = std::mem::size_of::<T>();
        self.shape.elements() * element_size
    }

    /// `true` if `self` and `other` have identical shapes.
    pub fn same_shape(&self, other: &Self) -> bool {
        self.shape == other.shape
    }

    /// `true` if the two shapes are compatible under NumPy-style broadcasting:
    /// aligned from the trailing dimension, each pair of dimensions must be
    /// equal or one of them must be 1; missing leading dimensions act as 1.
    pub fn is_broadcastable_with(&self, other: &Self) -> bool {
        let dims1 = self.shape.dims();
        let dims2 = other.shape.dims();
        let max_dims = dims1.len().max(dims2.len());
        for i in 0..max_dims {
            // Walk from the trailing dimension; a rank that has run out
            // contributes an implicit 1. (A saturating index here would
            // wrongly re-read dims[0] once `i` exceeds the shorter rank.)
            let dim1 = if i < dims1.len() { dims1[dims1.len() - 1 - i] } else { 1 };
            let dim2 = if i < dims2.len() { dims2[dims2.len() - 1 - i] } else { 1 };
            if dim1 != dim2 && dim1 != 1 && dim2 != 1 {
                return false;
            }
        }
        true
    }

    /// One-line human-readable description of the tensor's metadata.
    ///
    /// Only metadata is formatted, so no bounds on `T` are required
    /// (the previous `Display + Clone` bounds were unused).
    pub fn summary(&self) -> String {
        format!(
            "Tensor<{}>: shape={:?}, device={:?}, numel={}, memory={}B, requires_grad={}",
            std::any::type_name::<T>(),
            self.shape.dims(),
            self.device,
            self.shape.elements(),
            self.memory_usage(),
            self.requires_grad
        )
    }

    /// Total number of elements.
    pub fn size(&self) -> usize {
        self.shape.size()
    }

    /// Total number of elements (alias of [`Self::size`], PyTorch-style name).
    pub fn numel(&self) -> usize {
        self.shape.size()
    }

    /// Number of dimensions.
    pub fn rank(&self) -> usize {
        self.shape.rank()
    }

    /// Number of dimensions (alias of [`Self::rank`], NumPy-style name).
    pub fn ndim(&self) -> usize {
        self.shape.rank()
    }

    /// `true` for a rank-0 tensor.
    pub fn is_scalar(&self) -> bool {
        self.shape.rank() == 0
    }

    /// `true` for a rank-1 tensor.
    pub fn is_vector(&self) -> bool {
        self.shape.rank() == 1
    }

    /// `true` for a rank-2 tensor.
    pub fn is_matrix(&self) -> bool {
        self.shape.rank() == 2
    }

    /// `true` if the element data is laid out contiguously in standard
    /// (row-major) order. GPU buffers are treated as always contiguous.
    pub fn is_contiguous(&self) -> bool {
        match &self.storage {
            TensorStorage::Cpu(arr) => arr.is_standard_layout(),
            #[cfg(feature = "gpu")]
            TensorStorage::Gpu(_) => true,
        }
    }
}
/// In-place element transforms. Bounded to POD element types because the GPU
/// path moves raw bytes between host and device (presumably via `bytemuck`
/// casting inside `GpuBuffer` — not visible here).
impl<T> Tensor<T>
where
    T: Clone + bytemuck::Pod + bytemuck::Zeroable + Send + Sync + 'static,
{
    /// Applies `f` to every element of the tensor in place.
    ///
    /// CPU tensors are mutated directly. GPU tensors are round-tripped through
    /// host memory: download, apply `f` on the CPU, upload a fresh buffer back
    /// to the same device. The GPU path is only supported for `f32`/`f64`.
    ///
    /// # Errors
    ///
    /// On the GPU path: any error from the download/upload, a device error if
    /// the tensor's device is not `Device::Gpu`, or an unsupported-operation
    /// error for element types other than `f32`/`f64`.
    pub fn map_inplace<F>(&mut self, f: F) -> Result<()>
    where
        F: Fn(&T) -> T,
    {
        match &mut self.storage {
            TensorStorage::Cpu(arr) => {
                // `mapv_inplace` passes elements by value; adapt to `f`'s
                // by-reference signature.
                arr.mapv_inplace(|x| f(&x));
                Ok(())
            }
            #[cfg(feature = "gpu")]
            TensorStorage::Gpu(buffer) => {
                // Runtime type gate: only f32/f64 take the host round-trip.
                if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
                    || std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>()
                {
                    // Download to host, transform, then upload a new buffer.
                    let mut cpu_array = buffer.to_cpu_array()?;
                    cpu_array.mapv_inplace(|x| f(&x));
                    // Re-upload to the same device id the tensor claims to be on.
                    // NOTE(review): storage is Gpu but device could theoretically
                    // disagree; that inconsistency is reported as a device error.
                    let device_id = match self.device {
                        crate::Device::Gpu(id) => id,
                        _ => {
                            return Err(crate::TensorError::device_error_simple(
                                "Expected GPU device".to_string(),
                            ))
                        }
                    };
                    let new_gpu_buffer =
                        crate::gpu::buffer::GpuBuffer::from_cpu_array(&cpu_array, device_id)?;
                    *buffer = new_gpu_buffer;
                    Ok(())
                } else {
                    Err(crate::TensorError::unsupported_operation_simple(format!(
                        "GPU map_inplace not supported for type {}",
                        std::any::type_name::<T>()
                    )))
                }
            }
        }
    }
}