use std::sync::Arc;
use parking_lot::RwLock;
use rand_distr::{Distribution, Normal, Uniform};
use crate::dtype::{DType, TensorElement};
use crate::shape::{Shape, Strides};
use crate::storage::Storage;
use crate::error::{GhostError, Result};
/// A dense n-dimensional array backed by shared storage.
///
/// A `Tensor` is a *view*: `shape`, `strides`, and `offset` describe how
/// the flat `storage` buffer is interpreted, so several tensors may share
/// one buffer (e.g. transposes, and reshapes of contiguous tensors).
#[derive(Debug)]
pub struct Tensor {
    // Flat element buffer; cloning it shares the underlying data.
    storage: Storage,
    // Logical dimension sizes of this view.
    shape: Shape,
    // Per-dimension element steps into `storage`.
    strides: Strides,
    // Starting position of this view within `storage`.
    offset: usize,
    // When true, gradient-tracking machinery should follow this tensor.
    requires_grad: bool,
    // Accumulated gradient; shared and interior-mutable across clones.
    grad: Option<Arc<RwLock<Tensor>>>,
}
impl Tensor {
/// Builds a contiguous tensor from a flat buffer and a target shape.
///
/// Returns `GhostError::InvalidShape` when the buffer length does not
/// equal the element count implied by `shape`.
pub fn from_slice<T: TensorElement>(data: &[T], shape: &[usize]) -> Result<Self> {
    let shape = Shape::new(shape);
    let expected = shape.numel();
    if data.len() != expected {
        return Err(GhostError::InvalidShape(format!(
            "Data length {} doesn't match shape {:?} (numel={})",
            data.len(),
            shape.dims(),
            expected
        )));
    }
    Ok(Tensor {
        storage: Storage::from_slice(data),
        strides: shape.default_strides(),
        shape,
        offset: 0,
        requires_grad: false,
        grad: None,
    })
}
pub fn zeros(shape: &[usize]) -> Self {
Self::full(shape, 0.0f32)
}
pub fn ones(shape: &[usize]) -> Self {
Self::full(shape, 1.0f32)
}
pub fn full<T: TensorElement>(shape: &[usize], value: T) -> Self {
let shape = Shape::new(shape);
let numel = shape.numel();
let data: Vec<T> = vec![value; numel];
let strides = shape.default_strides();
let storage = Storage::from_slice(&data);
Tensor {
storage,
shape,
strides,
offset: 0,
requires_grad: false,
grad: None,
}
}
/// Returns a tensor of the given shape filled with samples drawn
/// uniformly from `[0, 1)`.
pub fn rand(shape: &[usize]) -> Self {
    let count = Shape::new(shape).numel();
    let mut rng = rand::thread_rng();
    let uniform = Uniform::new(0.0f32, 1.0);
    let samples: Vec<f32> = std::iter::repeat_with(|| uniform.sample(&mut rng))
        .take(count)
        .collect();
    Tensor::from_slice(&samples, shape).unwrap()
}
/// Returns a tensor of the given shape filled with samples from the
/// standard normal distribution N(0, 1).
pub fn randn(shape: &[usize]) -> Self {
    let count = Shape::new(shape).numel();
    let mut rng = rand::thread_rng();
    // Parameters (0, 1) are always valid for Normal, so this cannot fail.
    let normal = Normal::new(0.0f32, 1.0).unwrap();
    let samples: Vec<f32> = std::iter::repeat_with(|| normal.sample(&mut rng))
        .take(count)
        .collect();
    Tensor::from_slice(&samples, shape).unwrap()
}
/// Returns the n×n identity matrix as an f32 tensor.
pub fn eye(n: usize) -> Self {
    // In row-major layout the diagonal entries of an n×n matrix sit at
    // flat indices i*(n+1), i.e. exactly the multiples of n+1 below n*n.
    let data: Vec<f32> = (0..n * n)
        .map(|idx| if idx % (n + 1) == 0 { 1.0 } else { 0.0 })
        .collect();
    Tensor::from_slice(&data, &[n, n]).unwrap()
}
/// Creates a 1-D tensor of values in `[start, end)` spaced by `step`,
/// mirroring NumPy's `arange`: a positive `step` counts up, a negative
/// `step` counts down, and a range that is empty for the given direction
/// yields a zero-length tensor.
///
/// Fixes two defects in the previous implementation:
/// - `step == 0` (or a negative step with `start < end`) made the old
///   `while val < end { val += step }` loop spin forever; such inputs now
///   yield an empty tensor.
/// - accumulating `val += step` drifts with floating-point rounding
///   error; the element count is now computed up front and each value is
///   derived as `start + i * step`.
pub fn arange(start: f32, end: f32, step: f32) -> Self {
    let count = if step == 0.0 || !step.is_finite() {
        // A zero / non-finite step can never reach `end`.
        0
    } else {
        // Number of steps of size `step` that fit strictly before `end`;
        // negative for a direction mismatch, clamped to empty.
        ((end - start) / step).ceil().max(0.0) as usize
    };
    let data: Vec<f32> = (0..count).map(|i| start + i as f32 * step).collect();
    Tensor::from_slice(&data, &[count]).unwrap()
}
/// Returns `n` evenly spaced values from `start` to `end` inclusive.
/// `n == 0` yields an empty tensor and `n == 1` yields just `start`.
pub fn linspace(start: f32, end: f32, n: usize) -> Self {
    match n {
        0 => Tensor::from_slice::<f32>(&[], &[0]).unwrap(),
        1 => Tensor::from_slice(&[start], &[1]).unwrap(),
        _ => {
            // n - 1 intervals between n inclusive endpoints.
            let step = (end - start) / (n - 1) as f32;
            let data: Vec<f32> = (0..n).map(|i| start + i as f32 * step).collect();
            Tensor::from_slice(&data, &[n]).unwrap()
        }
    }
}
/// Returns the tensor's shape descriptor.
pub fn shape(&self) -> &Shape {
    &self.shape
}
/// Returns the dimension sizes as a slice.
pub fn dims(&self) -> &[usize] {
    self.shape.dims()
}
/// Number of dimensions (rank).
pub fn ndim(&self) -> usize {
    self.shape.ndim()
}
/// Total number of elements.
pub fn numel(&self) -> usize {
    self.shape.numel()
}
/// Element dtype, as recorded by the underlying storage.
pub fn dtype(&self) -> DType {
    self.storage.dtype()
}
/// Returns the per-dimension strides of this view.
pub fn strides(&self) -> &Strides {
    &self.strides
}
/// True when this view's strides match the default row-major layout.
pub fn is_contiguous(&self) -> bool {
    self.strides.is_contiguous(&self.shape)
}
/// Whether gradient tracking is enabled for this tensor.
pub fn requires_grad(&self) -> bool {
    self.requires_grad
}
/// Enables or disables gradient tracking.
pub fn set_requires_grad(&mut self, requires_grad: bool) {
    self.requires_grad = requires_grad;
}
/// Returns a snapshot (clone) of the stored gradient, if any.
pub fn grad(&self) -> Option<Tensor> {
    self.grad.as_ref().map(|g| g.read().clone())
}
/// Returns the underlying storage buffer.
pub fn storage(&self) -> &Storage {
    &self.storage
}
/// Replaces the stored gradient with `grad`, wrapping it in a fresh
/// shared cell.
pub fn set_grad(&mut self, grad: Tensor) {
    self.grad = Some(Arc::new(RwLock::new(grad)));
}
/// Resets an existing gradient to zeros in place; no-op when no
/// gradient has been set.
pub fn zero_grad(&mut self) {
    if let Some(ref grad) = self.grad {
        let mut g = grad.write();
        let zeros = Tensor::zeros(g.dims());
        *g = zeros;
    }
}
/// Copies the tensor's elements out into a contiguous `Vec<f32>`,
/// following strides/offset when the view is not a plain contiguous one.
///
/// NOTE(review): reads storage as f32; presumably callers guarantee the
/// dtype is f32 — confirm against `Storage::as_slice`.
pub fn data_f32(&self) -> Vec<f32> {
    if self.is_contiguous() && self.offset == 0 {
        // Fast path: the whole buffer already is the logical data.
        self.storage.as_slice::<f32>().to_vec()
    } else {
        self.to_contiguous_data::<f32>()
    }
}
/// Gathers this view's elements into a fresh contiguous buffer by
/// visiting every logical index and reading through strides + offset.
fn to_contiguous_data<T: TensorElement>(&self) -> Vec<T> {
    let numel = self.numel();
    let mut result = Vec::with_capacity(numel);
    let guard = self.storage.as_slice::<T>();
    self.for_each_index(|indices| {
        let offset = self.compute_offset(indices);
        result.push(guard[offset]);
    });
    result
}
/// Maps a logical multi-dimensional index to a flat storage position
/// (view offset plus the stride-weighted index).
fn compute_offset(&self, indices: &[usize]) -> usize {
    self.offset + self.strides.offset(indices)
}
/// Invokes `f` once per logical index, iterating in row-major order
/// (last dimension varies fastest). A 0-d tensor yields the single
/// empty index `[]`; a tensor with any zero-sized dimension yields
/// nothing.
fn for_each_index<F: FnMut(&[usize])>(&self, mut f: F) {
    let dims = self.dims();
    if dims.is_empty() {
        f(&[]);
        return;
    }
    // Bug fix: with a zero-sized dimension the odometer below would
    // still emit the all-zeros index once, which made callers such as
    // `to_contiguous_data` index past the end of an empty buffer. An
    // empty tensor has no valid indices, so emit none.
    if dims.iter().any(|&d| d == 0) {
        return;
    }
    // Odometer-style counter: bump the last index and carry leftward;
    // a carry out of dimension 0 means the walk is complete.
    let mut indices = vec![0usize; dims.len()];
    loop {
        f(&indices);
        let mut i = dims.len() - 1;
        loop {
            indices[i] += 1;
            if indices[i] < dims[i] {
                break;
            }
            indices[i] = 0;
            if i == 0 {
                return;
            }
            i -= 1;
        }
    }
}
/// Returns a tensor holding the same elements viewed under `new_shape`.
///
/// Contiguous tensors are reshaped as zero-copy views sharing the same
/// storage; non-contiguous tensors are materialized into a fresh
/// contiguous buffer first. Errors when the element counts differ.
///
/// NOTE(review): the copy path reads elements as `f32`; non-f32 dtypes
/// presumably need a dtype dispatch here — confirm against `Storage`.
pub fn reshape(&self, new_shape: &[usize]) -> Result<Tensor> {
    let new_shape = Shape::new(new_shape);
    if new_shape.numel() != self.numel() {
        return Err(GhostError::InvalidShape(format!(
            "Cannot reshape tensor of {} elements to shape {:?}",
            self.numel(),
            new_shape.dims()
        )));
    }
    if self.is_contiguous() {
        // Zero-copy: same storage and offset, fresh contiguous strides.
        let new_strides = new_shape.default_strides();
        return Ok(Tensor {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
            requires_grad: self.requires_grad,
            grad: None,
        });
    }
    // Non-contiguous: copy into a contiguous buffer, then carry over the
    // grad flag so this path agrees with the view path above (it used to
    // silently reset `requires_grad` to false).
    let data = self.to_contiguous_data::<f32>();
    let mut out = Tensor::from_slice(&data, new_shape.dims())?;
    out.requires_grad = self.requires_grad;
    Ok(out)
}
/// Collapses the tensor into a single 1-D dimension holding all elements.
pub fn flatten(&self) -> Result<Tensor> {
    let total = self.numel();
    self.reshape(&[total])
}
/// Returns a zero-copy view with dimensions `dim0` and `dim1` swapped:
/// the storage is shared and only the shape/stride entries change.
/// Errors when either dimension index is out of range.
pub fn transpose(&self, dim0: usize, dim1: usize) -> Result<Tensor> {
    let ndim = self.ndim();
    if dim0 >= ndim || dim1 >= ndim {
        return Err(GhostError::DimOutOfBounds {
            dim: dim0.max(dim1),
            ndim,
        });
    }
    let mut dims = self.shape.dims().to_vec();
    dims.swap(dim0, dim1);
    let mut strides = self.strides.as_slice().to_vec();
    strides.swap(dim0, dim1);
    Ok(Tensor {
        storage: self.storage.clone(),
        shape: Shape::from(dims),
        strides: Strides::from(strides.as_slice()),
        offset: self.offset,
        requires_grad: self.requires_grad,
        grad: None,
    })
}
/// Matrix transpose: swaps the two dimensions of a 2-D tensor.
/// Errors on any other rank.
pub fn t(&self) -> Result<Tensor> {
    match self.ndim() {
        2 => self.transpose(0, 1),
        _ => Err(GhostError::InvalidOperation(
            "t() only works on 2D tensors".to_string()
        )),
    }
}
/// Removes every dimension of size 1; an all-ones shape collapses to a
/// 0-d (scalar) tensor holding the single element.
///
/// The result now keeps `requires_grad` on every path — the old scalar
/// branch rebuilt the tensor via `from_slice`, silently dropping the
/// flag.
pub fn squeeze(&self) -> Tensor {
    let kept: Vec<usize> = self.dims().iter()
        .copied()
        .filter(|&d| d != 1)
        .collect();
    let mut out = if kept.is_empty() {
        // All dims were 1: exactly one element, reshaped to rank 0.
        let data = self.data_f32();
        Tensor::from_slice(&data, &[]).unwrap()
    } else {
        // Element count is unchanged, so reshape cannot fail.
        self.reshape(&kept).unwrap()
    };
    out.requires_grad = self.requires_grad;
    out
}
/// Inserts a new size-1 dimension at position `dim` (0 through `ndim`
/// inclusive). Errors when `dim` exceeds the resulting rank.
pub fn unsqueeze(&self, dim: usize) -> Result<Tensor> {
    let ndim = self.ndim();
    if dim > ndim {
        return Err(GhostError::DimOutOfBounds { dim, ndim: ndim + 1 });
    }
    let mut dims = self.dims().to_vec();
    dims.insert(dim, 1);
    self.reshape(&dims)
}
/// Returns an independent copy backed by freshly allocated storage.
///
/// NOTE(review): the copy does not carry `requires_grad` or any stored
/// gradient — presumably intentional (a detached copy); confirm.
pub fn deep_clone(&self) -> Self {
    Tensor::from_slice(&self.data_f32(), self.dims()).unwrap()
}
}
impl Clone for Tensor {
    /// Shallow clone: shares the same storage buffer and the same grad
    /// cell (both are cheap reference-counted clones), copying only the
    /// view metadata.
    fn clone(&self) -> Self {
        Tensor {
            grad: self.grad.clone(),
            storage: self.storage.clone(),
            offset: self.offset,
            requires_grad: self.requires_grad,
            shape: self.shape.clone(),
            strides: self.strides.clone(),
        }
    }
}
impl std::fmt::Display for Tensor {
    /// Renders a short summary of the form `Tensor(shape=…, dtype=…)`
    /// without printing any element data.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let dtype = self.dtype();
        write!(f, "Tensor(shape={}, dtype={})", self.shape, dtype)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // from_slice records both the dimension sizes and the element count.
    #[test]
    fn test_tensor_creation() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert_eq!(t.dims(), &[2, 2]);
        assert_eq!(t.numel(), 4);
    }
    // zeros()/ones() fill every element with the expected constant.
    #[test]
    fn test_zeros_ones() {
        let zeros = Tensor::zeros(&[3, 3]);
        let ones = Tensor::ones(&[3, 3]);
        assert!(zeros.data_f32().iter().all(|&x| x == 0.0));
        assert!(ones.data_f32().iter().all(|&x| x == 1.0));
    }
    // reshape of a contiguous 1-D tensor reports the new dims.
    #[test]
    fn test_reshape() {
        let t = Tensor::arange(0.0, 12.0, 1.0);
        let reshaped = t.reshape(&[3, 4]).unwrap();
        assert_eq!(reshaped.dims(), &[3, 4]);
    }
    // t() swaps the two dims of a 2-D tensor.
    #[test]
    fn test_transpose() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let transposed = t.t().unwrap();
        assert_eq!(transposed.dims(), &[3, 2]);
    }
}