1use std::ops::{Index, IndexMut};
2
/// A dense n-dimensional array backed by a single flat `Vec`.
///
/// Arrays built by this type's constructors use row-major (C-order) strides:
/// the last axis has stride 1 and each earlier axis steps over the full extent
/// of the axes after it.
///
/// NOTE(review): all fields are public, so callers can break the invariants
/// described below (e.g. by editing `shape` without updating `strides`).
#[derive(Debug, Clone, PartialEq)]
pub struct Array<T> {
    /// The elements in flat storage order; a multi-index maps to a position
    /// here via the dot product with `strides`.
    pub data: Vec<T>,
    /// Length of each axis; the constructors require its product to equal
    /// `data.len()`.
    pub shape: Vec<usize>,
    /// Number of flat positions to advance in `data` per unit step along each
    /// axis; same length as `shape`.
    pub strides: Vec<usize>,
}
9
10impl<T: Clone + Default> Array<T> {
11 pub fn from_vec(data: Vec<T>, shape: Vec<usize>) -> Self {
13 let total_size: usize = shape.iter().product();
14 assert_eq!(data.len(), total_size, "Data length must match shape");
15
16 let strides = Self::compute_strides(&shape);
17 Array {
18 data,
19 shape,
20 strides,
21 }
22 }
23
24 pub fn zeros(shape: Vec<usize>) -> Self {
26 let total_size: usize = shape.iter().product();
27 let strides = Self::compute_strides(&shape);
28 Array {
29 data: vec![T::default(); total_size],
30 shape,
31 strides,
32 }
33 }
34
35 pub fn ones(shape: Vec<usize>) -> Array<f64>
37 where
38 T: Into<f64>,
39 {
40 let total_size: usize = shape.iter().product();
41 let strides = Self::compute_strides(&shape);
42 Array {
43 data: vec![1.0; total_size],
44 shape,
45 strides,
46 }
47 }
48
49 pub fn full(shape: Vec<usize>, fill_value: T) -> Self {
51 let total_size: usize = shape.iter().product();
52 let strides = Self::compute_strides(&shape);
53 Array {
54 data: vec![fill_value; total_size],
55 shape,
56 strides,
57 }
58 }
59
60 pub fn arange(start: T, stop: T, step: T) -> Array<T>
62 where
63 T: num_traits::Num + PartialOrd + Copy,
64 {
65 let mut data = Vec::new();
66 let mut current = start;
67 while current < stop {
68 data.push(current);
69 current = current + step;
70 }
71 let len = data.len();
72 Array::from_vec(data, vec![len])
73 }
74
75 fn compute_strides(shape: &[usize]) -> Vec<usize> {
77 let mut strides = vec![1; shape.len()];
78 for i in (0..shape.len().saturating_sub(1)).rev() {
79 strides[i] = strides[i + 1] * shape[i + 1];
80 }
81 strides
82 }
83
84 pub fn ravel_index(&self, indices: &[usize]) -> usize {
86 assert_eq!(indices.len(), self.shape.len(), "Index dimension mismatch");
87 indices
88 .iter()
89 .zip(&self.strides)
90 .map(|(&idx, &stride)| idx * stride)
91 .sum()
92 }
93
94 pub fn get(&self, indices: &[usize]) -> Option<&T> {
96 for (idx, dim_size) in indices.iter().zip(&self.shape) {
97 if *idx >= *dim_size {
98 return None;
99 }
100 }
101 let flat_index = self.ravel_index(indices);
102 self.data.get(flat_index)
103 }
104
105 pub fn get_mut(&mut self, indices: &[usize]) -> Option<&mut T> {
107 for (idx, dim_size) in indices.iter().zip(&self.shape) {
108 if *idx >= *dim_size {
109 return None;
110 }
111 }
112 let flat_index = self.ravel_index(indices);
113 self.data.get_mut(flat_index)
114 }
115
116 pub fn reshape(&self, new_shape: Vec<usize>) -> Array<T> {
118 let old_size: usize = self.shape.iter().product();
119 let new_size: usize = new_shape.iter().product();
120 assert_eq!(old_size, new_size, "Total size must remain the same");
121
122 Array::from_vec(self.data.clone(), new_shape)
123 }
124
125 pub fn ndim(&self) -> usize {
127 self.shape.len()
128 }
129
130 pub fn len(&self) -> usize {
132 self.data.len()
133 }
134
135 pub fn is_empty(&self) -> bool {
137 self.data.is_empty()
138 }
139
140 pub fn size(&self) -> usize {
142 self.shape.iter().product()
143 }
144
145 pub fn transpose(&self) -> Array<T> {
147 if self.ndim() != 2 {
148 panic!("Transpose currently only supports 2D arrays");
149 }
150
151 let (rows, cols) = (self.shape[0], self.shape[1]);
152 let mut new_data = Vec::with_capacity(self.data.len());
153
154 for j in 0..cols {
155 for i in 0..rows {
156 let flat_idx = i * self.strides[0] + j * self.strides[i];
157 new_data.push(self.data[flat_idx].clone());
158 }
159 }
160
161 Array::from_vec(new_data, vec![cols, rows])
162 }
163}
164
165impl<T: Clone + Default> Index<(usize, usize)> for Array<T> {
167 type Output = T;
168 fn index(&self, index: (usize, usize)) -> &Self::Output {
169 if self.ndim() != 2 {
170 panic!("2D indexing only works for 2D arrays");
171 }
172 let (i, j) = index;
173 &self.data[i * self.strides[0] + j * self.strides[1]]
174 }
175}
176
177impl<T: Clone + Default> IndexMut<(usize, usize)> for Array<T> {
178 fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
179 if self.ndim() != 2 {
180 panic!("2D indexing only works for 2D arrays");
181 }
182 let (i, j) = index;
183 &mut self.data[i * self.strides[0] + j * self.strides[1]]
184 }
185}
186
187impl<T: Clone + Default> Index<&[usize]> for Array<T> {
189 type Output = T;
190 fn index(&self, indices: &[usize]) -> &Self::Output {
191 let flat_index = self.ravel_index(indices);
192 &self.data[flat_index]
193 }
194}
195
196impl<T: Clone + Default> IndexMut<&[usize]> for Array<T> {
197 fn index_mut(&mut self, indices: &[usize]) -> &mut Self::Output {
198 let flat_index = self.ravel_index(indices);
199 &mut self.data[flat_index]
200 }
201}