use std::fmt::{Display};
use std::ops::{Add, Sub, Mul, Div};
/// A dense tensor of rank `RANK` whose elements live in a flat buffer.
///
/// `shape[i]` is the extent of axis `i`. Constructors that go through
/// [`Tensor::new`] guarantee `data.len() == shape.iter().product()`. Note that
/// `data` is a public field, so outside code can break that invariant, and
/// [`Tensor::null`] / [`Tensor::empty`] deliberately do not uphold it.
///
/// Trait bounds live on the `impl` blocks that need them rather than on the
/// type itself.
#[derive(Clone, Debug, PartialEq)]
pub struct Tensor<T, const RANK: usize> {
    /// Elements in flat storage order.
    pub data: Vec<T>,
    /// Extent of each axis.
    shape: [usize; RANK],
}

impl<T: Copy, const RANK: usize> Tensor<T, RANK> {
    /// Creates a tensor from a flat `data` buffer and a `shape`.
    ///
    /// # Panics
    /// Panics if `data.len()` differs from the product of `shape`.
    pub fn new(data: Vec<T>, shape: [usize; RANK]) -> Self {
        assert_eq!(
            data.len(),
            shape.iter().product::<usize>(),
            "Data length must match shape product"
        );
        Tensor { data, shape }
    }

    /// Creates a rank-1 tensor holding a copy of `data`.
    ///
    /// Accepts any slice. A `&Vec<T>` still works through deref coercion.
    /// The elements are copied exactly once.
    pub fn new_vec(data: &[T]) -> Tensor<T, 1> {
        Tensor::new(data.to_vec(), [data.len()])
    }

    /// Returns a rank-0 tensor with no elements.
    ///
    /// The product of an empty shape is 1, so this value does *not* satisfy
    /// the `data.len() == shape product` invariant that [`Tensor::new`]
    /// enforces.
    pub fn null() -> Tensor<T, 0> {
        Tensor {
            data: vec![],
            shape: [],
        }
    }

    /// Returns a tensor with the given `shape` but an empty data buffer.
    ///
    /// Unless `shape` contains a zero extent, the result violates the
    /// `data.len() == shape product` invariant (for example, `reshape` on it
    /// will panic). Presumably callers fill `data` themselves before use.
    /// TODO: confirm the intended use.
    pub fn empty(shape: &[usize; RANK]) -> Tensor<T, RANK> {
        Tensor {
            data: vec![],
            shape: *shape,
        }
    }

    /// Borrows the flat element buffer.
    pub fn as_vec(&self) -> &Vec<T> {
        &self.data
    }

    /// Borrows the per-axis extents.
    pub fn shape(&self) -> &[usize; RANK] {
        &self.shape
    }

    /// Number of elements currently stored in `data`.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` when `data` holds no elements.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Consumes the tensor and reinterprets its data with `new_shape`.
    ///
    /// The rank may change. The data buffer is moved, not copied.
    ///
    /// # Panics
    /// Panics if the product of `new_shape` differs from the data length.
    pub fn reshape<const NEW_RANK: usize>(self, new_shape: [usize; NEW_RANK]) -> Tensor<T, NEW_RANK> {
        assert_eq!(
            self.data.len(),
            new_shape.iter().product::<usize>(),
            "Data length must match new shape product"
        );
        Tensor {
            data: self.data,
            shape: new_shape,
        }
    }

    /// Number of elements currently stored in `data` (same as [`Tensor::len`]).
    pub fn size(&self) -> usize {
        self.data.len()
    }

    /// Returns an owned copy of the flat element buffer.
    pub fn to_vec(&self) -> Vec<T> {
        self.data.clone()
    }
}
impl<T: Copy, const RANK: usize> Tensor<T, RANK> {
pub fn filled(value: T, shape: [usize; RANK]) -> Self {
Tensor::new(vec![value; shape.iter().product()], shape)
}
}
impl<T: Add<Output = T> + Copy, const RANK: usize> Add for Tensor<T, RANK> {
    type Output = Self;

    /// Element-wise sum of two tensors of identical shape.
    ///
    /// # Panics
    /// Panics if the two shapes differ.
    fn add(self, rhs: Self) -> Self::Output {
        assert_eq!(self.shape, rhs.shape, "Shapes must match for addition");
        let shape = self.shape;
        let sums: Vec<T> = self
            .data
            .into_iter()
            .zip(rhs.data)
            .map(|(lhs, rhs_elem)| lhs + rhs_elem)
            .collect();
        Tensor::new(sums, shape)
    }
}
impl<T: Sub<Output = T> + Copy, const RANK: usize> Sub for Tensor<T, RANK> {
    type Output = Self;

    /// Element-wise difference `self - rhs` of two tensors of identical shape.
    ///
    /// # Panics
    /// Panics if the two shapes differ.
    fn sub(self, rhs: Self) -> Self::Output {
        assert_eq!(self.shape, rhs.shape, "Shapes must match for subtraction");
        let mut differences = Vec::with_capacity(self.data.len());
        for (&minuend, &subtrahend) in self.data.iter().zip(&rhs.data) {
            differences.push(minuend - subtrahend);
        }
        Tensor::new(differences, self.shape)
    }
}
impl<T: Mul<Output = T> + Copy, const RANK: usize> Mul for Tensor<T, RANK> {
type Output = Self;
fn mul(self, rhs: Self) -> Self::Output {
assert_eq!(self.shape, rhs.shape, "Shapes must match for multiplication");
let data = self.data.iter().zip(rhs.data.iter()).map(|(&a, &b)| a * b).collect();
Tensor::new(data, *self.shape())
}
}
impl<T: Div<Output = T> + Copy, const RANK: usize> Div for Tensor<T, RANK> {
type Output = Self;
fn div(self, rhs: Self) -> Self::Output {
assert_eq!(self.shape, rhs.shape, "Shapes must match for division");
let data = self.data.iter().zip(rhs.data.iter()).map(|(&a, &b)| a / b).collect();
Tensor::new(data, *self.shape())
}
}
/// Inner product between `Self` and `Rhs`.
pub trait DotProduct<Rhs = Self> {
    /// Result type of the inner product.
    type Output;
    /// Computes the inner product of `self` and `rhs`.
    fn dot(&self, rhs: &Rhs) -> Self::Output;
}

impl<T: Copy + Mul<Output = T> + Add<Output = T> + Default> DotProduct for Tensor<T, 1> {
    type Output = T;

    /// Sum of pairwise products of two rank-1 tensors.
    ///
    /// Accumulation starts from `T::default()`, so it assumes the default is
    /// the additive identity (true for the built-in numeric types).
    ///
    /// # Panics
    /// Panics if the two shapes differ.
    fn dot(&self, rhs: &Tensor<T, 1>) -> Self::Output {
        assert_eq!(self.shape, rhs.shape, "Shapes must match for dot product");
        let mut total = T::default();
        for (&x, &y) in self.data.iter().zip(&rhs.data) {
            total = total + x * y;
        }
        total
    }
}
/// Error returned when two tensor shapes cannot be broadcast together.
///
/// This happens when, on some axis, the extents differ and neither is 1.
/// Implements `Display` and `std::error::Error`, so it can be propagated with
/// `?` into `Box<dyn std::error::Error>` and similar error types.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TensorBroadcastError;

impl std::fmt::Display for TensorBroadcastError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("tensor shapes cannot be broadcast together")
    }
}

impl std::error::Error for TensorBroadcastError {}
/// Broadcasting operations between tensors of the same rank.
pub trait Broadcast<T: Copy, const RANK: usize> {
    /// Computes the shape that `self` and `other` combine to, or
    /// `TensorBroadcastError` if the shapes are incompatible.
    fn get_broadcast_shape(&self, other: &Tensor<T, RANK>) -> Result<[usize; RANK], TensorBroadcastError>;

    /// Returns a copy of `self` resized to the shape computed by
    /// `get_broadcast_shape`, propagating that method's error.
    fn broadcast(&self, other: &Tensor<T, RANK>) -> Result<Tensor<T, RANK>, TensorBroadcastError>;

    /// Resizes `self` to `new_shape` in place. Implementations may panic when
    /// the target shape is unsupported.
    fn broadcast_inplace(&mut self, new_shape: [usize; RANK]);

    /// Returns a tensor holding a copy of `self`'s data labelled with `shape`.
    fn broadcast_shape(&self, shape: [usize; RANK]) -> Tensor<T, RANK>;
}
impl<T: Copy, const RANK: usize> Broadcast<T, RANK> for Tensor<T, RANK> {
    /// Returns a copy of `self` broadcast to the common shape of `self` and
    /// `other`.
    ///
    /// # Errors
    /// Returns `TensorBroadcastError` if the shapes are incompatible on any axis.
    ///
    /// # Panics
    /// Panics if `self.data.len()` differs from the product of `self.shape`
    /// (the check is done by `Tensor::new`).
    fn broadcast(&self, other: &Tensor<T, RANK>) -> Result<Tensor<T, RANK>, TensorBroadcastError> {
        // Check compatibility first so an error is returned before copying anything.
        let target = self.get_broadcast_shape(other)?;
        let mut tensor = Tensor::new(self.data.clone(), self.shape);
        tensor.broadcast_inplace(target);
        Ok(tensor)
    }

    /// Resizes `self` in place to `new_shape`.
    ///
    /// Rules, checked in this order:
    /// * If `new_shape` has as many elements as `self.data`, only the shape is
    ///   replaced and the data is left untouched.
    /// * If `self.data` holds exactly one element, that element is repeated to
    ///   fill `new_shape`.
    /// * Otherwise every axis must either already equal its target extent or
    ///   have extent 1, and `self.data` must match `self.shape`. Axes of
    ///   extent 1 are then repeated along the target extent. The data is read
    ///   as row-major (last axis varies fastest).
    ///
    /// # Panics
    /// Panics if none of the rules applies.
    fn broadcast_inplace(&mut self, new_shape: [usize; RANK]) {
        let new_size = new_shape.iter().product::<usize>();
        if new_size != self.data.len() {
            if self.data.len() == 1 {
                let value = self.data[0];
                self.data = vec![value; new_size];
            } else if self.data.len() == self.shape.iter().product::<usize>()
                && self
                    .shape
                    .iter()
                    .zip(&new_shape)
                    .all(|(&old, &new)| old == new || old == 1)
            {
                // Row-major source strides. Axes of extent 1 get stride 0 so the
                // same source element is reused all along that axis.
                let mut strides = [0usize; RANK];
                let mut step = 1;
                for axis in (0..RANK).rev() {
                    strides[axis] = if self.shape[axis] == 1 { 0 } else { step };
                    step *= self.shape[axis];
                }
                // When new_size > 0, every target extent is non-zero, so the
                // modulo below cannot divide by zero.
                let mut expanded = Vec::with_capacity(new_size);
                for flat in 0..new_size {
                    let mut rest = flat;
                    let mut source = 0;
                    for axis in (0..RANK).rev() {
                        source += (rest % new_shape[axis]) * strides[axis];
                        rest /= new_shape[axis];
                    }
                    expanded.push(self.data[source]);
                }
                self.data = expanded;
            } else {
                panic!("Cannot resize tensor: new shape product must match data length or be a broadcastable shape");
            }
        }
        self.shape = new_shape;
    }

    /// Returns a tensor holding a copy of `self`'s data labelled with `shape`.
    ///
    /// Despite the name, this does not expand any data.
    ///
    /// # Panics
    /// Panics (via `Tensor::new`) unless the product of `shape` equals the
    /// data length.
    fn broadcast_shape(&self, shape: [usize; RANK]) -> Tensor<T, RANK> {
        Tensor::new(self.data.clone(), shape)
    }

    /// Computes the axis-wise broadcast of `self.shape` and `other.shape`.
    ///
    /// On each axis the extents must be equal or one of them must be 1.
    ///
    /// # Errors
    /// Returns `TensorBroadcastError` on the first axis where neither holds.
    fn get_broadcast_shape(&self, other: &Tensor<T, RANK>) -> Result<[usize; RANK], TensorBroadcastError> {
        let mut result_shape = [0; RANK];
        for i in 0..RANK {
            if self.shape[i] == other.shape[i] {
                result_shape[i] = self.shape[i];
            } else if self.shape[i] == 1 {
                result_shape[i] = other.shape[i];
            } else if other.shape[i] == 1 {
                result_shape[i] = self.shape[i];
            } else {
                return Err(TensorBroadcastError);
            }
        }
        Ok(result_shape)
    }
}
impl<T: Copy, const RANK: usize> Display for Tensor<T, RANK> {
    /// Writes only the `Debug` form of the tensor's shape (e.g. `[2, 3]`).
    /// Element values are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let dims = &self.shape;
        write!(f, "{:?}", dims)
    }
}