use alloc::vec::Vec;
use core::ops::{Add, Mul, Sub};
/// An n-dimensional tensor stored as a flat element buffer plus a shape.
///
/// The constructors below ([`Tensor::new`], [`Tensor::fill`], [`Tensor::reshape`])
/// guarantee that `data.len()` equals the product of `shape`. Both fields are
/// public, so code that edits them directly is responsible for keeping that
/// invariant.
#[derive(Debug, Clone)]
pub struct Tensor<T> {
    /// Flat element storage.
    pub data: Vec<T>,
    /// Extent of each dimension. An empty shape has an element count of 1
    /// (a single-element tensor).
    pub shape: Vec<usize>,
}

// Methods that never duplicate elements, so they place no bounds on `T`.
impl<T> Tensor<T> {
    /// Returns the number of elements described by `shape`.
    ///
    /// Any zero extent yields 0. Otherwise the product uses checked
    /// multiplication, so an enormous shape panics instead of silently wrapping
    /// to a small, wrong element count in release builds.
    fn checked_element_count(shape: &[usize]) -> usize {
        if shape.contains(&0) {
            return 0;
        }
        shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .unwrap_or_else(|| panic!("Shape {:?} has more elements than fit in usize", shape))
    }

    /// Wraps an existing flat buffer with the given shape.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` differs from the element count of `shape`, or if
    /// that element count overflows `usize`.
    pub fn new(data: Vec<T>, shape: Vec<usize>) -> Self {
        let expected_size = Self::checked_element_count(&shape);
        assert_eq!(
            data.len(),
            expected_size,
            "Data length {} doesn't match shape {:?} (expected {})",
            data.len(),
            shape,
            expected_size
        );
        Tensor { data, shape }
    }

    /// Number of dimensions (the length of `shape`).
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Total number of stored elements (the length of `data`).
    pub fn size(&self) -> usize {
        self.data.len()
    }

    /// Returns the element at flat index `idx`, or `None` if it is out of bounds.
    pub fn get(&self, idx: usize) -> Option<&T> {
        self.data.get(idx)
    }

    /// Returns a mutable reference to the element at flat index `idx`, or
    /// `None` if it is out of bounds.
    pub fn get_mut(&mut self, idx: usize) -> Option<&mut T> {
        self.data.get_mut(idx)
    }

    /// Returns `true` when the tensor has exactly two dimensions.
    pub fn is_matrix(&self) -> bool {
        self.ndim() == 2
    }

    /// Returns `(shape[0], shape[1])` for a two-dimensional tensor.
    ///
    /// # Panics
    ///
    /// Panics if the tensor is not two-dimensional.
    pub fn matrix_dims(&self) -> (usize, usize) {
        assert!(self.is_matrix(), "Tensor is not a matrix");
        (self.shape[0], self.shape[1])
    }
}

// Methods that must copy elements and therefore require `T: Clone`.
impl<T: Clone> Tensor<T> {
    /// Creates a tensor of the given shape with every element a clone of `value`.
    ///
    /// # Panics
    ///
    /// Panics if the element count of `shape` overflows `usize`.
    pub fn fill(value: T, shape: Vec<usize>) -> Self {
        let size = Self::checked_element_count(&shape);
        Tensor {
            data: vec![value; size],
            shape,
        }
    }

    /// Returns a new tensor holding a copy of this tensor's data under `new_shape`.
    ///
    /// `self` is left unchanged; the flat element order is preserved as-is.
    ///
    /// # Panics
    ///
    /// Panics if the element count of `new_shape` differs from `self.size()`,
    /// or if that element count overflows `usize`.
    pub fn reshape(&self, new_shape: Vec<usize>) -> Self {
        let new_size = Self::checked_element_count(&new_shape);
        assert_eq!(
            self.size(),
            new_size,
            "Cannot reshape tensor of size {} to shape {:?} (size {})",
            self.size(),
            new_shape,
            new_size
        );
        Tensor {
            data: self.data.clone(),
            shape: new_shape,
        }
    }
}
/// Combines two equally-shaped tensors element by element with `op`.
///
/// Elements are paired by flat position. If the `data` lengths differ despite
/// equal shapes (possible only when the public fields were edited directly),
/// the result is truncated to the shorter buffer, as `Iterator::zip` does.
///
/// # Panics
///
/// Panics if `lhs.shape != rhs.shape`.
fn elementwise<T>(lhs: Tensor<T>, rhs: Tensor<T>, mut op: impl FnMut(T, T) -> T) -> Tensor<T> {
    assert_eq!(
        lhs.shape, rhs.shape,
        "Shape mismatch: {:?} vs {:?}",
        lhs.shape, rhs.shape
    );
    let data: Vec<T> = lhs
        .data
        .into_iter()
        .zip(rhs.data)
        .map(|(a, b)| op(a, b))
        .collect();
    Tensor {
        data,
        shape: lhs.shape,
    }
}

/// Element-wise addition of two tensors with identical shapes.
///
/// # Panics
///
/// Panics if the shapes differ.
impl<T> Add for Tensor<T>
where
    T: Add<Output = T>,
{
    type Output = Tensor<T>;

    fn add(self, rhs: Self) -> Self::Output {
        elementwise(self, rhs, |a, b| a + b)
    }
}

/// Element-wise subtraction (`self - rhs`) of two tensors with identical shapes.
///
/// # Panics
///
/// Panics if the shapes differ.
impl<T> Sub for Tensor<T>
where
    T: Sub<Output = T>,
{
    type Output = Tensor<T>;

    fn sub(self, rhs: Self) -> Self::Output {
        elementwise(self, rhs, |a, b| a - b)
    }
}

/// Element-wise (Hadamard) product of two tensors with identical shapes.
///
/// This is not matrix multiplication.
///
/// # Panics
///
/// Panics if the shapes differ.
impl<T> Mul for Tensor<T>
where
    T: Mul<Output = T>,
{
    type Output = Tensor<T>;

    fn mul(self, rhs: Self) -> Self::Output {
        elementwise(self, rhs, |a, b| a * b)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ScalarF4E4;

    #[test]
    fn test_tensor_creation() {
        let data = vec![
            ScalarF4E4::from(1.0),
            ScalarF4E4::from(2.0),
            ScalarF4E4::from(3.0),
            ScalarF4E4::from(4.0),
        ];
        let tensor = Tensor::new(data, vec![2, 2]);
        assert_eq!(tensor.shape, vec![2, 2]);
        assert_eq!(tensor.size(), 4);
        assert!(tensor.is_matrix());
    }

    #[test]
    fn test_element_wise_add() {
        let a = Tensor::new(vec![ScalarF4E4::from(1.0), ScalarF4E4::from(2.0)], vec![2]);
        let b = Tensor::new(vec![ScalarF4E4::from(3.0), ScalarF4E4::from(4.0)], vec![2]);
        let c = a + b;
        assert_eq!(c.data[0].to_f64(), 4.0);
        assert_eq!(c.data[1].to_f64(), 6.0);
    }

    // Integer element types keep these checks independent of ScalarF4E4 rounding.
    #[test]
    fn test_element_wise_sub_and_mul() {
        let a = Tensor::new(vec![5, 7], vec![2]);
        let b = Tensor::new(vec![2, 3], vec![2]);
        assert_eq!((a.clone() - b.clone()).data, vec![3, 4]);
        let product = a * b;
        assert_eq!(product.data, vec![10, 21]);
        assert_eq!(product.shape, vec![2]);
    }

    #[test]
    fn test_reshape_preserves_data() {
        let t = Tensor::new(vec![1, 2, 3, 4, 5, 6], vec![2, 3]);
        let r = t.reshape(vec![3, 2]);
        assert_eq!(r.shape, vec![3, 2]);
        assert_eq!(r.data, t.data);
        assert_eq!(r.matrix_dims(), (3, 2));
        assert_eq!(t.shape, vec![2, 3]);
    }

    #[test]
    fn test_fill() {
        let t = Tensor::fill(9, vec![2, 2]);
        assert_eq!(t.data, vec![9, 9, 9, 9]);
        assert_eq!(t.matrix_dims(), (2, 2));
    }

    #[test]
    #[should_panic(expected = "doesn't match shape")]
    fn test_new_rejects_wrong_length() {
        let _ = Tensor::new(vec![1, 2, 3], vec![2, 2]);
    }

    #[test]
    #[should_panic(expected = "Cannot reshape")]
    fn test_reshape_rejects_size_change() {
        let t = Tensor::new(vec![1, 2, 3, 4], vec![2, 2]);
        let _ = t.reshape(vec![3]);
    }

    #[test]
    #[should_panic(expected = "Shape mismatch")]
    fn test_add_rejects_shape_mismatch() {
        let a = Tensor::new(vec![1, 2], vec![2]);
        let b = Tensor::new(vec![1, 2], vec![1, 2]);
        let _ = a + b;
    }

    #[test]
    #[should_panic(expected = "Tensor is not a matrix")]
    fn test_matrix_dims_requires_2d() {
        let t = Tensor::new(vec![1, 2, 3], vec![3]);
        let _ = t.matrix_dims();
    }
}