use std::ops::{AddAssign, Index, IndexMut};
use crate::custom_error;
// Error type for fallible `NDArray` operations; `NDArray::new` returns it when the
// requested shape does not match the supplied data. The type itself is generated
// by the crate-level `custom_error!` macro (its definition is not visible here).
custom_error!(pub NDArrayError);
/// A dense `N`-dimensional array backed by one contiguous `Vec`.
///
/// Elements are laid out in row-major (C) order as computed by [`NDArray::new`]:
/// the last axis is contiguous and `data.len()` equals the product of `shape`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct NDArray<T, const N: usize> {
    /// Extent of each axis.
    shape: [usize; N],
    /// Distance in `data` between consecutive coordinates along each axis.
    strides: [usize; N],
    /// Flat element storage.
    data: Vec<T>,
}
impl<T, const N: usize> NDArray<T, N> {
    /// Builds an array of the given `shape` over `data`, laid out in row-major
    /// (C) order: the last axis is contiguous.
    ///
    /// # Errors
    ///
    /// Returns an [`NDArrayError`] if the product of `shape` does not equal
    /// `data.len()`, or if that product (or any stride) would overflow `usize`.
    pub fn new(shape: [usize; N], data: Vec<T>) -> Result<Self, NDArrayError> {
        // Walk the axes from last to first; each stride is the product of all
        // later dimensions. Checked arithmetic keeps a huge shape from wrapping
        // around in release builds and passing the length check by accident.
        let mut strides = [0; N];
        let mut stride: usize = 1;
        for (axis, &dim) in shape.iter().enumerate().rev() {
            strides[axis] = stride;
            stride = stride.checked_mul(dim).ok_or_else(|| {
                NDArrayError::new(format!("Shape {:?} is too large to address", shape))
            })?;
        }
        // After the loop `stride` is the total number of elements in `shape`.
        if stride != data.len() {
            return Err(NDArrayError::new(format!(
                "Incompatible shapes: {:?} and {:?}",
                shape,
                data.len()
            )));
        }
        Ok(Self {
            shape,
            strides,
            data,
        })
    }

    /// Returns the extent of each axis.
    pub fn shape(&self) -> [usize; N] {
        self.shape
    }

    /// Converts a multi-index into the flat offset of that element in the
    /// underlying storage.
    ///
    /// Note: method-call syntax `arr.index(i)` resolves to this inherent method
    /// (which returns the offset), not to `Index::index`.
    ///
    /// # Panics
    ///
    /// Panics if any coordinate is not smaller than the extent of its axis.
    /// Without this check an out-of-range coordinate would silently alias a
    /// different element (e.g. `[0, 3]` in a `[2, 3]` array maps to `[1, 0]`).
    pub fn index(&self, indices: [usize; N]) -> usize {
        let mut offset = 0;
        for (axis, &idx) in indices.iter().enumerate() {
            let dim = self.shape[axis];
            assert!(
                idx < dim,
                "Index {} out of bounds for axis {} with size {}",
                idx,
                axis,
                dim
            );
            offset += idx * self.strides[axis];
        }
        offset
    }

    /// Converts a flat storage offset back into a multi-index; for in-range
    /// offsets this is the inverse of the inherent `index` method.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is not smaller than the number of stored elements. This
    /// also covers arrays with a zero-length axis, whose zero strides would
    /// otherwise cause a division by zero.
    pub fn inverted_index(&self, mut idx: usize) -> [usize; N] {
        assert!(
            idx < self.data.len(),
            "Flat index {} out of bounds for NDArray with {} elements",
            idx,
            self.data.len()
        );
        let mut indices = [0; N];
        // Peel off coordinates from the outermost axis inward: the quotient by
        // each stride is that axis' coordinate, the remainder is what is left
        // for the inner axes.
        for (coord, &stride) in indices.iter_mut().zip(self.strides.iter()) {
            *coord = idx / stride;
            idx %= stride;
        }
        indices
    }
}
impl<T: Default + Copy + AddAssign, const N: usize> NDArray<T, N> {
    /// Projects the array onto a single axis by summing over all other axes.
    ///
    /// The result has one entry per position along `axis`; entry `k` accumulates
    /// (onto `T::default()`) every element whose coordinate along `axis` is `k`,
    /// visiting elements in flat storage order.
    ///
    /// # Panics
    ///
    /// Panics if `axis >= N`.
    pub fn project_axis(&self, axis: usize) -> Vec<T> {
        assert!(axis < N, "Axis out of bounds");
        let stride = self.strides[axis];
        let extent = self.shape[axis];
        let mut totals = vec![T::default(); extent];
        for (flat, &value) in self.data.iter().enumerate() {
            // In row-major layout the coordinate along `axis` is the quotient by
            // that axis' stride, reduced modulo the axis extent.
            totals[(flat / stride) % extent] += value;
        }
        totals
    }
}
// `Copy` is not required: elements of `other` are moved out, not copied, so any
// `T: AddAssign` (including non-`Copy` numeric types) is supported.
impl<T: AddAssign, const N: usize> AddAssign for NDArray<T, N> {
    /// Adds `other` to `self` element-wise, consuming `other`.
    ///
    /// # Panics
    ///
    /// Panics if the two arrays do not have the same shape.
    fn add_assign(&mut self, other: Self) {
        assert_eq!(self.shape, other.shape);
        // `new` guarantees `data.len()` equals the product of `shape`, so equal
        // shapes mean the zip pairs every element exactly once.
        for (lhs, rhs) in self.data.iter_mut().zip(other.data) {
            *lhs += rhs;
        }
    }
}
impl<T: Default, const N: usize> NDArray<T, N> {
    /// Creates an array of the given `shape` with every element set to
    /// `T::default()` (the storage is fully initialized, not uninitialized).
    ///
    /// # Panics
    ///
    /// Panics if the total number of elements does not fit in `usize`, or if
    /// `NDArray::new` rejects the shape.
    pub fn empty(shape: [usize; N]) -> Self {
        // Checked product: a plain `product()` silently wraps in release builds.
        let size = shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .expect("NDArray shape has too many elements");
        let data = std::iter::repeat_with(T::default).take(size).collect();
        Self::new(shape, data).expect("Failed to create empty NDArray")
    }
}
impl<T, const N: usize> Index<[usize; N]> for NDArray<T, N> {
    type Output = T;

    /// Borrows the element stored at the multi-index `indices`.
    ///
    /// The flat offset comes from the inherent `index` method, which takes
    /// precedence over this trait method in method-call syntax.
    ///
    /// # Panics
    ///
    /// Panics if that offset lies outside the storage, or if the inherent
    /// `index` method itself panics.
    fn index(&self, indices: [usize; N]) -> &T {
        &self.data[self.index(indices)]
    }
}
impl<T, const N: usize> IndexMut<[usize; N]> for NDArray<T, N> {
    /// Mutably borrows the element stored at the multi-index `indices`.
    ///
    /// # Panics
    ///
    /// Panics if the offset from the inherent `index` method lies outside the
    /// storage, or if that method itself panics.
    fn index_mut(&mut self, indices: [usize; N]) -> &mut T {
        // Compute the offset before taking the mutable borrow of the storage.
        let offset = self.index(indices);
        let storage = &mut self.data;
        &mut storage[offset]
    }
}