use std::ops::{Add, Mul};
/// An n-dimensional tensor: a flat buffer of `f32` elements paired with a shape.
///
/// Invariant: `data.len()` equals the product of the dimensions in `shape`.
/// An empty shape therefore describes a single-element (rank-0) tensor.
#[derive(Debug, Clone, PartialEq)]
struct Tensor {
    /// Element storage, flattened into one vector.
    data: Vec<f32>,
    /// Length of each dimension; its length is the tensor's rank.
    shape: Vec<usize>,
}

impl Tensor {
    /// Creates a tensor from existing element data and a shape.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` is not equal to the product of `shape`,
    /// since every other method relies on that invariant.
    fn new(data: Vec<f32>, shape: Vec<usize>) -> Tensor {
        let expected: usize = shape.iter().product();
        if data.len() != expected {
            panic!(
                "Cannot create tensor with {} elements for shape {:?} (expected {})",
                data.len(),
                shape,
                expected
            );
        }
        Tensor { data, shape }
    }

    /// Creates a zero-filled tensor with the given shape.
    fn from_shape(shape: Vec<usize>) -> Tensor {
        let size = shape.iter().product();
        Tensor {
            data: vec![0.0; size],
            shape,
        }
    }

    /// Replaces the shape in place; the element data is left untouched.
    ///
    /// # Panics
    ///
    /// Panics if `new_shape` describes a different number of elements.
    fn reshape(&mut self, new_shape: Vec<usize>) {
        let size = self.size();
        let new_size: usize = new_shape.iter().product();
        if size != new_size {
            panic!("Cannot reshape tensor of size {} to size {}", size, new_size);
        }
        self.shape = new_shape;
    }

    /// Total number of elements stored.
    fn size(&self) -> usize {
        self.data.len()
    }

    /// The dimensions of the tensor.
    fn shape(&self) -> &Vec<usize> {
        &self.shape
    }

    /// Number of dimensions (length of the shape).
    fn rank(&self) -> usize {
        self.shape.len()
    }
}
impl Add for Tensor {
type Output = Tensor;
fn add(self, other: Tensor) -> Tensor {
let size = self.size();
if size != other.size() {
panic!("Cannot add tensors of different sizes ({} and {})", size, other.size());
}
let mut data = Vec::with_capacity(size);
for i in 0..size {
data.push(self.data[i] + other.data[i]);
}
Tensor {
data,
shape: self.shape.clone(),
}
}
}
impl Mul<Tensor> for Tensor {
type Output = Tensor;
fn mul(self, other: Tensor) -> Tensor {
let size = self.size();
if size != other.size() {
panic!("Cannot multiply tensors of different sizes ({} and {})", size, other.size());
}
let mut data = Vec::with_capacity(size);
for i in 0..size {
data.push(self.data[i] * other.data[i]);
}
Tensor {
data,
shape: self.shape.clone(),
}
}
}
#[cfg(test)]
mod tests {
    use super::Tensor;

    /// The pair of 2x2 operands shared by the arithmetic tests.
    fn operands() -> (Tensor, Tensor) {
        let lhs = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let rhs = Tensor::new(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]);
        (lhs, rhs)
    }

    /// A 2x3 zero tensor keeps six elements and rank two across a reshape to 3x2.
    #[test]
    fn test_tensor() {
        let expect = |tensor: &Tensor, dims: &Vec<usize>| {
            assert_eq!(tensor.size(), 6);
            assert_eq!(tensor.rank(), 2);
            assert_eq!(tensor.shape(), dims);
        };

        let mut tensor = Tensor::from_shape(vec![2, 3]);
        expect(&tensor, &vec![2, 3]);

        tensor.reshape(vec![3, 2]);
        expect(&tensor, &vec![3, 2]);
    }

    /// `+` sums matching elements and keeps the operand shape.
    #[test]
    fn test_add() {
        let (lhs, rhs) = operands();
        let sum = lhs + rhs;
        assert_eq!(sum.data, vec![6.0, 8.0, 10.0, 12.0]);
        assert_eq!(sum.shape, vec![2, 2]);
    }

    /// `*` multiplies matching elements and keeps the operand shape.
    #[test]
    fn test_mul() {
        let (lhs, rhs) = operands();
        let product = lhs * rhs;
        assert_eq!(product.data, vec![5.0, 12.0, 21.0, 32.0]);
        assert_eq!(product.shape, vec![2, 2]);
    }
}