// rustframes/array/core.rs

use std::ops::{Add, Div, Index, IndexMut, Mul};

/// A dense N-dimensional array stored as one flat `Vec` in row-major (C-style) order.
///
/// Every constructor in this impl keeps `data.len()` equal to the product of `shape` and
/// sets `strides` to the row-major strides of `shape`. The fields are public, so code that
/// modifies them directly is responsible for preserving those invariants.
#[derive(Debug, Clone, PartialEq)]
pub struct Array<T> {
    /// Elements in row-major order.
    pub data: Vec<T>,
    /// Length of each axis.
    pub shape: Vec<usize>,
    /// Flat-buffer offset corresponding to one step along each axis.
    pub strides: Vec<usize>,
}

impl<T: Clone + Default> Array<T> {
    /// Creates an array from a flat row-major vector and a shape.
    ///
    /// # Panics
    /// Panics if `data.len()` differs from the product of `shape`.
    pub fn from_vec(data: Vec<T>, shape: Vec<usize>) -> Self {
        let total_size: usize = shape.iter().product();
        assert_eq!(data.len(), total_size, "Data length must match shape");

        let strides = Self::compute_strides(&shape);
        Array {
            data,
            shape,
            strides,
        }
    }

    /// Creates an array of the given shape filled with `T::default()`
    /// (zero for the primitive numeric types).
    pub fn zeros(shape: Vec<usize>) -> Self {
        let total_size: usize = shape.iter().product();
        let strides = Self::compute_strides(&shape);
        Array {
            data: vec![T::default(); total_size],
            shape,
            strides,
        }
    }

    /// Creates an `Array<f64>` of the given shape filled with `1.0`.
    ///
    /// The type parameter `T` only selects this impl (it must convert into `f64`); the
    /// result is always `Array<f64>`, e.g. `Array::<f64>::ones(vec![2, 3])`.
    pub fn ones(shape: Vec<usize>) -> Array<f64>
    where
        T: Into<f64>,
    {
        let total_size: usize = shape.iter().product();
        let strides = Self::compute_strides(&shape);
        Array {
            data: vec![1.0; total_size],
            shape,
            strides,
        }
    }

    /// Creates an array of the given shape with every element a clone of `fill_value`.
    pub fn full(shape: Vec<usize>, fill_value: T) -> Self {
        let total_size: usize = shape.iter().product();
        let strides = Self::compute_strides(&shape);
        Array {
            data: vec![fill_value; total_size],
            shape,
            strides,
        }
    }

    /// Creates a 1-D array of the values `start, start + step, start + 2*step, ...`
    /// that lie strictly before `stop` (a half-open range, like NumPy's `arange`).
    ///
    /// A negative `step` counts down: values are produced while they are greater than
    /// `stop`. Each value is computed as `start + k * step` instead of by repeated
    /// addition, so floating-point rounding does not accumulate (for example
    /// `arange(0.0, 1.0, 0.1)` has 10 elements, not 11).
    ///
    /// `T::default()` is used as zero, which holds for all primitive numeric types.
    ///
    /// # Panics
    /// Panics if `step` is zero (or not comparable with zero, e.g. NaN), because the
    /// sequence would never reach `stop`. For integer `T`, overflow while computing
    /// `start + k * step` panics in debug builds.
    pub fn arange(start: T, stop: T, step: T) -> Array<T>
    where
        T: Add<Output = T> + Mul<Output = T> + Div<Output = T> + PartialOrd + Copy,
    {
        let zero = T::default();
        let ascending = step > zero;
        assert!(ascending || step < zero, "arange step must be non-zero");

        // `step / step` is exactly one for any finite non-zero step, which lets `k` count
        // 0, 1, 2, ... in `T` without a usize -> T conversion.
        let one = step / step;
        let mut k = zero;
        let mut data = Vec::new();
        loop {
            let value = start + k * step;
            let before_stop = if ascending { value < stop } else { value > stop };
            if !before_stop {
                break;
            }
            data.push(value);
            k = k + one;
        }
        let len = data.len();
        Array::from_vec(data, vec![len])
    }

    /// Computes row-major (C-style) strides: the last axis has stride 1 and each
    /// earlier axis strides over the product of all later axis lengths.
    fn compute_strides(shape: &[usize]) -> Vec<usize> {
        let mut strides = vec![1; shape.len()];
        for i in (0..shape.len().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * shape[i + 1];
        }
        strides
    }

    /// Converts a multi-dimensional index to an offset into `data`.
    ///
    /// Per-axis bounds are NOT checked: for shape `[2, 3]`, `[0, 3]` maps to the same
    /// offset as `[1, 0]`. Use [`Array::get`] for a checked lookup.
    ///
    /// # Panics
    /// Panics if `indices.len()` differs from the number of dimensions.
    pub fn ravel_index(&self, indices: &[usize]) -> usize {
        assert_eq!(indices.len(), self.shape.len(), "Index dimension mismatch");
        indices
            .iter()
            .zip(&self.strides)
            .map(|(&idx, &stride)| idx * stride)
            .sum()
    }

    /// Returns `true` if `indices` has one entry per axis and every entry is smaller
    /// than that axis' length.
    fn in_bounds(&self, indices: &[usize]) -> bool {
        indices.len() == self.shape.len()
            && indices
                .iter()
                .zip(&self.shape)
                .all(|(&idx, &dim)| idx < dim)
    }

    /// Returns the element at a multi-dimensional index, or `None` if the index has the
    /// wrong number of dimensions or is out of bounds on any axis.
    pub fn get(&self, indices: &[usize]) -> Option<&T> {
        if !self.in_bounds(indices) {
            return None;
        }
        self.data.get(self.ravel_index(indices))
    }

    /// Returns a mutable reference to the element at a multi-dimensional index, or
    /// `None` if the index has the wrong number of dimensions or is out of bounds.
    pub fn get_mut(&mut self, indices: &[usize]) -> Option<&mut T> {
        if !self.in_bounds(indices) {
            return None;
        }
        let flat_index = self.ravel_index(indices);
        self.data.get_mut(flat_index)
    }

    /// Returns a copy of this array with a new shape; the row-major element order is kept.
    ///
    /// # Panics
    /// Panics if the new shape has a different total number of elements.
    pub fn reshape(&self, new_shape: Vec<usize>) -> Array<T> {
        let old_size: usize = self.shape.iter().product();
        let new_size: usize = new_shape.iter().product();
        assert_eq!(old_size, new_size, "Total size must remain the same");

        Array::from_vec(self.data.clone(), new_shape)
    }

    /// Returns the number of dimensions (axes).
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Returns the number of elements stored in `data`.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` if `data` holds no elements.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Returns the product of `shape` (equal to [`Array::len`] while the invariants hold).
    pub fn size(&self) -> usize {
        self.shape.iter().product()
    }

    /// Returns the transpose of a 2-D array as a new `(cols, rows)` array.
    ///
    /// # Panics
    /// Panics if the array is not 2-D.
    pub fn transpose(&self) -> Array<T> {
        if self.ndim() != 2 {
            panic!("Transpose currently only supports 2D arrays");
        }

        let (rows, cols) = (self.shape[0], self.shape[1]);
        let mut new_data = Vec::with_capacity(self.data.len());

        // Walk the source column by column; that order is exactly the row-major layout
        // of the transposed (cols, rows) result.
        for j in 0..cols {
            for i in 0..rows {
                let flat_idx = i * self.strides[0] + j * self.strides[1];
                new_data.push(self.data[flat_idx].clone());
            }
        }

        Array::from_vec(new_data, vec![cols, rows])
    }
}

// Implement `(row, col)` indexing for backwards compatibility with 2D arrays
166impl<T: Clone + Default> Index<(usize, usize)> for Array<T> {
167    type Output = T;
168    fn index(&self, index: (usize, usize)) -> &Self::Output {
169        if self.ndim() != 2 {
170            panic!("2D indexing only works for 2D arrays");
171        }
172        let (i, j) = index;
173        &self.data[i * self.strides[0] + j * self.strides[1]]
174    }
175}
176
177impl<T: Clone + Default> IndexMut<(usize, usize)> for Array<T> {
178    fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
179        if self.ndim() != 2 {
180            panic!("2D indexing only works for 2D arrays");
181        }
182        let (i, j) = index;
183        &mut self.data[i * self.strides[0] + j * self.strides[1]]
184    }
185}
186
// Implement `&[usize]` indexing for N-dimensional arrays
188impl<T: Clone + Default> Index<&[usize]> for Array<T> {
189    type Output = T;
190    fn index(&self, indices: &[usize]) -> &Self::Output {
191        let flat_index = self.ravel_index(indices);
192        &self.data[flat_index]
193    }
194}
195
196impl<T: Clone + Default> IndexMut<&[usize]> for Array<T> {
197    fn index_mut(&mut self, indices: &[usize]) -> &mut Self::Output {
198        let flat_index = self.ravel_index(indices);
199        &mut self.data[flat_index]
200    }
201}