use std::ops::{Add, Index};
use lut::Lut;
mod lut;
/// A dense n-dimensional array stored as a single flat element buffer.
///
/// Expected invariant: `data.len()` equals the product of the axis lengths in
/// `shape`. Constructors establish this; `lut` translates a multi-dimensional
/// index into an offset into `data`.
#[derive(Clone)]
pub struct Array<T> {
    /// Length of each axis.
    shape: Vec<usize>,
    /// Lookup table mapping an index tuple to a flat offset into `data`.
    lut: Lut,
    /// Flat element buffer addressed through `lut`.
    data: Vec<T>,
}
impl<T> Array<T> {
    /// Creates an array of the given `shape` with every element set to `val`.
    ///
    /// An empty `shape` yields a single-element array; any zero-length axis
    /// yields an empty array.
    ///
    /// # Panics
    ///
    /// Panics if the total number of elements overflows `usize`.
    pub fn constants(val: T, shape: &[usize]) -> Array<T>
    where
        T: Clone,
    {
        // Checked product: an unchecked multiply would wrap in release builds and
        // produce a buffer whose length disagrees with `shape`.
        let size = Self::element_count(shape).expect("array size overflows usize");
        Array {
            shape: shape.to_vec(),
            lut: Lut::new(shape),
            data: vec![val; size],
        }
    }

    /// Returns the total number of elements stored in the array.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` if the array holds no elements.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Changes the array's shape in place; the element buffer is left untouched.
    ///
    /// # Panics
    ///
    /// Panics if `shape` does not describe exactly `self.len()` elements, since
    /// the lookup table would otherwise address positions outside the buffer.
    pub fn reshape(&mut self, shape: &[usize]) {
        assert!(
            Self::element_count(shape) == Some(self.data.len()),
            "cannot reshape array of {} elements into shape {:?}",
            self.data.len(),
            shape
        );
        self.shape = shape.to_vec();
        self.lut.update(shape);
    }

    /// Product of the axis lengths in `shape`, or `None` if it overflows `usize`.
    fn element_count(shape: &[usize]) -> Option<usize> {
        shape.iter().try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
    }
}
impl<T: Add<Output = T>> Add for Array<T> {
    type Output = Self;

    /// Element-wise sum of two arrays with identical shapes.
    ///
    /// Both operands are consumed, so the result takes ownership of `self`'s
    /// shape and lookup table instead of cloning them.
    ///
    /// # Panics
    ///
    /// Panics with "invalid number of dimensions" if the ranks differ, with
    /// "dimension mismatch" if any axis length differs, and if the element
    /// buffers differ in length despite equal shapes.
    fn add(self, other: Self) -> Self {
        if self.shape.len() != other.shape.len() {
            panic!("invalid number of dimensions");
        }
        if self.shape != other.shape {
            panic!("dimension mismatch");
        }
        // `zip` would silently truncate on a length mismatch; fail loudly instead.
        assert_eq!(
            self.data.len(),
            other.data.len(),
            "element buffers differ in length"
        );
        let data = self
            .data
            .into_iter()
            .zip(other.data)
            .map(|(a, b)| a + b)
            .collect();
        Self {
            shape: self.shape,
            lut: self.lut,
            data,
        }
    }
}
impl<T> Index<&[usize]> for Array<T> {
    type Output = T;

    /// Returns a reference to the element at the multi-dimensional index `coords`.
    ///
    /// The lookup table converts `coords` into a flat offset into the buffer;
    /// an offset past the end of the buffer panics via slice indexing.
    fn index(&self, coords: &[usize]) -> &T {
        let offset = self.lut.at(coords);
        &self.data[offset]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A constant-filled 10x10x10 array holds one element per cell.
    #[test]
    fn test_constants() {
        let cube = Array::constants(0, &[10, 10, 10]);
        let expected_cells = 10 * 10 * 10;
        assert_eq!(cube.len(), expected_cells);
    }

    /// Adding two constant 10x10 arrays sums element-wise and keeps the size.
    #[test]
    fn test_addition() {
        let lhs = Array::constants(2, &[10, 10]);
        let rhs = Array::constants(2, &[10, 10]);
        let sum = lhs + rhs;
        assert_eq!(sum.len(), 100);
        let coord: &[usize] = &[1, 0];
        assert_eq!(sum[coord], 4);
    }
}