use std::ops::{Index, IndexMut};
/// A dense N-dimensional array stored in a single flat `Vec`.
///
/// Every constructor on this type lays the elements out in row-major order:
/// the last axis has stride 1 and each earlier axis has a stride equal to the
/// product of the extents of all later axes.
#[derive(Debug, Clone, PartialEq)]
pub struct Array<T> {
    /// Flat element storage.
    pub data: Vec<T>,
    /// Extent of each axis.
    pub shape: Vec<usize>,
    /// Number of elements to skip in `data` to move one step along each axis.
    pub strides: Vec<usize>,
}

impl<T: Clone + Default> Array<T> {
    /// Wraps `data` in an array of the given `shape`.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` differs from the product of `shape`.
    pub fn from_vec(data: Vec<T>, shape: Vec<usize>) -> Self {
        let total_size: usize = shape.iter().product();
        assert_eq!(data.len(), total_size, "Data length must match shape");
        let strides = Self::compute_strides(&shape);
        Array {
            data,
            shape,
            strides,
        }
    }

    /// Creates an array of the given `shape` filled with `T::default()`.
    pub fn zeros(shape: Vec<usize>) -> Self {
        let total_size: usize = shape.iter().product();
        let strides = Self::compute_strides(&shape);
        Array {
            data: vec![T::default(); total_size],
            shape,
            strides,
        }
    }

    /// Creates an `Array<f64>` of the given `shape` filled with `1.0`.
    ///
    /// The result is always `Array<f64>` regardless of `T`; the
    /// `T: Into<f64>` bound must hold for the call to compile but no value of
    /// type `T` is created or converted.
    pub fn ones(shape: Vec<usize>) -> Array<f64>
    where
        T: Into<f64>,
    {
        let total_size: usize = shape.iter().product();
        let strides = Self::compute_strides(&shape);
        Array {
            data: vec![1.0; total_size],
            shape,
            strides,
        }
    }

    /// Creates an array of the given `shape` with every element a clone of
    /// `fill_value`.
    pub fn full(shape: Vec<usize>, fill_value: T) -> Self {
        let total_size: usize = shape.iter().product();
        let strides = Self::compute_strides(&shape);
        Array {
            data: vec![fill_value; total_size],
            shape,
            strides,
        }
    }

    /// Creates a 1-D array holding `start, start + step, ...` for as long as
    /// the value stays strictly below `stop`.
    ///
    /// Returns an empty array when `start >= stop`.
    ///
    /// Only addition and ordering are needed, so the bound is
    /// `Add<Output = T>`; every numeric type that previously qualified still
    /// does.
    ///
    /// # Panics
    ///
    /// Panics if adding `step` does not strictly increase the running value
    /// (zero or negative step, NaN, or a float step too small to change the
    /// value). Without this check the loop would never terminate.
    pub fn arange(start: T, stop: T, step: T) -> Array<T>
    where
        T: std::ops::Add<Output = T> + PartialOrd + Copy,
    {
        let mut data = Vec::new();
        let mut current = start;
        while current < stop {
            data.push(current);
            let next = current + step;
            assert!(
                next > current,
                "arange: step must strictly increase the running value"
            );
            current = next;
        }
        let len = data.len();
        Array::from_vec(data, vec![len])
    }

    /// Computes row-major strides for `shape`.
    fn compute_strides(shape: &[usize]) -> Vec<usize> {
        let mut strides = vec![1; shape.len()];
        // Walk from the second-to-last axis backwards; the last axis keeps
        // stride 1. `saturating_sub` keeps the range empty for a 0-D shape.
        for i in (0..shape.len().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * shape[i + 1];
        }
        strides
    }

    /// Converts a multi-dimensional index into an offset into `data`.
    ///
    /// Individual indices are not checked against `shape`; only the number of
    /// indices is validated.
    ///
    /// # Panics
    ///
    /// Panics if `indices.len()` differs from the number of axes.
    pub fn ravel_index(&self, indices: &[usize]) -> usize {
        assert_eq!(indices.len(), self.shape.len(), "Index dimension mismatch");
        indices
            .iter()
            .zip(&self.strides)
            .map(|(&idx, &stride)| idx * stride)
            .sum()
    }

    /// Returns the flat offset for `indices`, or `None` when the number of
    /// indices is wrong or any index is outside its axis.
    fn checked_flat_index(&self, indices: &[usize]) -> Option<usize> {
        // `zip` would silently truncate on a rank mismatch, so check it first.
        if indices.len() != self.shape.len() {
            return None;
        }
        let in_bounds = indices
            .iter()
            .zip(&self.shape)
            .all(|(&idx, &dim_size)| idx < dim_size);
        if in_bounds {
            Some(self.ravel_index(indices))
        } else {
            None
        }
    }

    /// Returns a reference to the element at `indices`, or `None` if the
    /// number of indices does not match the number of axes or any index is
    /// out of bounds.
    pub fn get(&self, indices: &[usize]) -> Option<&T> {
        let flat_index = self.checked_flat_index(indices)?;
        self.data.get(flat_index)
    }

    /// Returns a mutable reference to the element at `indices`, or `None` if
    /// the number of indices does not match the number of axes or any index
    /// is out of bounds.
    pub fn get_mut(&mut self, indices: &[usize]) -> Option<&mut T> {
        let flat_index = self.checked_flat_index(indices)?;
        self.data.get_mut(flat_index)
    }

    /// Returns a copy of this array's data arranged under `new_shape`.
    ///
    /// # Panics
    ///
    /// Panics if the product of `new_shape` differs from the product of the
    /// current shape.
    pub fn reshape(&self, new_shape: Vec<usize>) -> Array<T> {
        let old_size: usize = self.shape.iter().product();
        let new_size: usize = new_shape.iter().product();
        assert_eq!(old_size, new_size, "Total size must remain the same");
        Array::from_vec(self.data.clone(), new_shape)
    }

    /// Number of axes.
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Number of stored elements.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Whether the array stores no elements.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Product of the axis extents.
    pub fn size(&self) -> usize {
        self.shape.iter().product()
    }

    /// Returns a new 2-D array with rows and columns swapped.
    ///
    /// # Panics
    ///
    /// Panics if the array is not 2-dimensional.
    pub fn transpose(&self) -> Array<T> {
        if self.ndim() != 2 {
            panic!("Transpose currently only supports 2D arrays");
        }
        let (rows, cols) = (self.shape[0], self.shape[1]);
        let (row_stride, col_stride) = (self.strides[0], self.strides[1]);
        let mut new_data = Vec::with_capacity(self.data.len());
        // Emit the source column by column so the result is row-major for the
        // swapped shape.
        for j in 0..cols {
            for i in 0..rows {
                let flat_idx = i * row_stride + j * col_stride;
                new_data.push(self.data[flat_idx].clone());
            }
        }
        Array::from_vec(new_data, vec![cols, rows])
    }
}
/// Read access to a 2-D array with `array[(row, col)]`.
impl<T: Clone + Default> Index<(usize, usize)> for Array<T> {
    type Output = T;

    /// Returns the element at `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics if the array is not 2-dimensional or if either index is outside
    /// its axis.
    fn index(&self, index: (usize, usize)) -> &Self::Output {
        if self.ndim() != 2 {
            panic!("2D indexing only works for 2D arrays");
        }
        let (i, j) = index;
        // Check each axis separately: an oversized column index could still
        // land inside `data` and silently alias an element of another row.
        assert!(
            i < self.shape[0] && j < self.shape[1],
            "index ({}, {}) out of bounds for shape ({}, {})",
            i,
            j,
            self.shape[0],
            self.shape[1]
        );
        &self.data[i * self.strides[0] + j * self.strides[1]]
    }
}
/// Write access to a 2-D array with `array[(row, col)] = value`.
impl<T: Clone + Default> IndexMut<(usize, usize)> for Array<T> {
    /// Returns a mutable reference to the element at `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics if the array is not 2-dimensional or if either index is outside
    /// its axis.
    fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
        if self.ndim() != 2 {
            panic!("2D indexing only works for 2D arrays");
        }
        let (i, j) = index;
        // Check each axis separately: an oversized column index could still
        // land inside `data` and silently overwrite an element of another row.
        assert!(
            i < self.shape[0] && j < self.shape[1],
            "index ({}, {}) out of bounds for shape ({}, {})",
            i,
            j,
            self.shape[0],
            self.shape[1]
        );
        &mut self.data[i * self.strides[0] + j * self.strides[1]]
    }
}
/// Read access with one index per axis, e.g. `array[&[i, j, k][..]]`.
impl<T: Clone + Default> Index<&[usize]> for Array<T> {
    type Output = T;

    /// Returns the element at `indices`.
    ///
    /// # Panics
    ///
    /// Panics if the number of indices differs from the number of axes or if
    /// any index is outside its axis.
    fn index(&self, indices: &[usize]) -> &Self::Output {
        assert_eq!(indices.len(), self.shape.len(), "Index dimension mismatch");
        // Validate every axis: the flat offset of an out-of-range index can
        // still fall inside `data` and alias an unrelated element.
        for (axis, (&idx, &dim_size)) in indices.iter().zip(&self.shape).enumerate() {
            assert!(
                idx < dim_size,
                "index {} out of bounds for axis {} with size {}",
                idx,
                axis,
                dim_size
            );
        }
        let flat_index = self.ravel_index(indices);
        &self.data[flat_index]
    }
}
/// Write access with one index per axis.
impl<T: Clone + Default> IndexMut<&[usize]> for Array<T> {
    /// Returns a mutable reference to the element at `indices`.
    ///
    /// # Panics
    ///
    /// Panics if the number of indices differs from the number of axes or if
    /// any index is outside its axis.
    fn index_mut(&mut self, indices: &[usize]) -> &mut Self::Output {
        assert_eq!(indices.len(), self.shape.len(), "Index dimension mismatch");
        // Validate every axis: the flat offset of an out-of-range index can
        // still fall inside `data` and overwrite an unrelated element.
        for (axis, (&idx, &dim_size)) in indices.iter().zip(&self.shape).enumerate() {
            assert!(
                idx < dim_size,
                "index {} out of bounds for axis {} with size {}",
                idx,
                axis,
                dim_size
            );
        }
        let flat_index = self.ravel_index(indices);
        &mut self.data[flat_index]
    }
}