Skip to main content

neco_array2/
lib.rs

1use std::ops::{Index, IndexMut};
2
/// Lightweight row-major 2D array used by grid-oriented helper crates.
///
/// Elements live contiguously in `data`; element `[row, col]` is stored at flat
/// offset `row * ncols + col`. Every constructor establishes the invariant
/// `data.len() == nrows * ncols`.
///
/// NOTE(review): with the `serde` feature, the derived `Deserialize` fills the
/// fields directly and does not re-check the length invariant above; consider a
/// `#[serde(try_from = "...")]` representation routed through
/// [`Array2::from_shape_vec`].
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct Array2<T> {
    nrows: usize,
    ncols: usize,
    data: Vec<T>,
}

impl<T> Array2<T> {
    /// Builds an array of the given `(nrows, ncols)` shape from row-major `data`.
    ///
    /// # Errors
    ///
    /// Returns a descriptive message if `nrows * ncols` overflows `usize`, or if
    /// `data.len()` does not equal `nrows * ncols`.
    pub fn from_shape_vec(shape: (usize, usize), data: Vec<T>) -> Result<Self, String> {
        let (nrows, ncols) = shape;
        // An unchecked product panics in debug builds and wraps in release builds,
        // where a wrapped value could match a short `data`; reject it explicitly.
        let expected = nrows
            .checked_mul(ncols)
            .ok_or_else(|| format!("Array2 shape ({nrows}, {ncols}) overflows usize"))?;
        if data.len() != expected {
            return Err(format!(
                "Array2 data length mismatch: got {}, expected {} for shape ({nrows}, {ncols})",
                data.len(),
                expected
            ));
        }
        Ok(Self { nrows, ncols, data })
    }

    /// Returns the shape as an `(nrows, ncols)` tuple.
    #[inline]
    pub fn dim(&self) -> (usize, usize) {
        (self.nrows, self.ncols)
    }

    /// Returns the number of rows.
    #[inline]
    pub fn nrows(&self) -> usize {
        self.nrows
    }

    /// Returns the number of columns.
    #[inline]
    pub fn ncols(&self) -> usize {
        self.ncols
    }

    /// Returns the shape as a `[nrows, ncols]` array.
    #[inline]
    pub fn shape(&self) -> [usize; 2] {
        [self.nrows, self.ncols]
    }

    /// Returns all elements as a flat row-major slice.
    #[inline]
    pub fn as_slice(&self) -> &[T] {
        &self.data
    }

    /// Iterates over all elements in row-major order.
    #[inline]
    pub fn iter(&self) -> std::slice::Iter<'_, T> {
        self.data.iter()
    }

    /// Mutably iterates over all elements in row-major order.
    #[inline]
    pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, T> {
        self.data.iter_mut()
    }

    /// Returns a reference to the element at `[row, col]`, or `None` if either
    /// index is out of range.
    #[inline]
    pub fn get(&self, index: [usize; 2]) -> Option<&T> {
        let [row, col] = index;
        let offset = self.checked_offset(row, col)?;
        self.data.get(offset)
    }

    /// Returns a mutable reference to the element at `[row, col]`, or `None` if
    /// either index is out of range.
    #[inline]
    pub fn get_mut(&mut self, index: [usize; 2]) -> Option<&mut T> {
        let [row, col] = index;
        let offset = self.checked_offset(row, col)?;
        self.data.get_mut(offset)
    }

    /// Flat offset of `[row, col]`, or `None` when `row >= nrows` or `col >= ncols`.
    ///
    /// Checking `col` separately matters: without it, `[r, ncols + k]` would
    /// silently alias an element of a later row instead of failing.
    #[inline]
    fn checked_offset(&self, row: usize, col: usize) -> Option<usize> {
        if row < self.nrows && col < self.ncols {
            // Given the length invariant, this is < nrows * ncols == data.len(),
            // so it cannot overflow.
            Some(row * self.ncols + col)
        } else {
            None
        }
    }

    /// Flat offset of `[row, col]` for the `Index`/`IndexMut` impls.
    ///
    /// # Panics
    ///
    /// Panics if `row >= nrows` or `col >= ncols`.
    #[inline]
    fn offset(&self, row: usize, col: usize) -> usize {
        match self.checked_offset(row, col) {
            Some(offset) => offset,
            None => panic!(
                "Array2 index [{row}, {col}] out of bounds for shape ({}, {})",
                self.nrows, self.ncols
            ),
        }
    }
}

impl<T: Clone> Array2<T> {
    /// Creates an array of the given shape with every element a clone of `value`.
    ///
    /// # Panics
    ///
    /// Panics if `nrows * ncols` overflows `usize`.
    pub fn from_elem(shape: (usize, usize), value: T) -> Self {
        let (nrows, ncols) = shape;
        // Without the check, a wrapped product in release builds would allocate a
        // buffer whose length disagrees with the recorded shape.
        let len = nrows
            .checked_mul(ncols)
            .unwrap_or_else(|| panic!("Array2 shape ({nrows}, {ncols}) overflows usize"));
        Self {
            nrows,
            ncols,
            data: vec![value; len],
        }
    }

    /// Overwrites every element with a clone of `value`; the shape is unchanged.
    pub fn fill(&mut self, value: T) {
        self.data.fill(value);
    }
}

impl<T: Clone + Default> Array2<T> {
    /// Creates an array of the given shape filled with `T::default()`.
    ///
    /// # Panics
    ///
    /// Panics if `nrows * ncols` overflows `usize`.
    pub fn zeros(shape: (usize, usize)) -> Self {
        Self::from_elem(shape, T::default())
    }
}

impl<T> Index<[usize; 2]> for Array2<T> {
    type Output = T;

    /// Returns the element at `[row, col]`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= nrows()` or `col >= ncols()`.
    fn index(&self, index: [usize; 2]) -> &Self::Output {
        &self.data[self.offset(index[0], index[1])]
    }
}

impl<T> IndexMut<[usize; 2]> for Array2<T> {
    /// Returns a mutable reference to the element at `[row, col]`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= nrows()` or `col >= ncols()`.
    fn index_mut(&mut self, index: [usize; 2]) -> &mut Self::Output {
        // Computed before the mutable borrow of `data` to keep the borrows disjoint.
        let offset = self.offset(index[0], index[1]);
        &mut self.data[offset]
    }
}
101
#[cfg(test)]
mod tests {
    use super::Array2;

    #[test]
    fn shape_vec_roundtrip_preserves_row_major_order() {
        let array =
            Array2::from_shape_vec((2, 3), vec![1, 2, 3, 4, 5, 6]).expect("test shape is valid");
        assert_eq!(array.shape(), [2, 3]);
        assert_eq!(array[[0, 0]], 1);
        assert_eq!(array[[0, 2]], 3);
        assert_eq!(array[[1, 0]], 4);
        assert_eq!(array[[1, 2]], 6);
        assert_eq!(array.as_slice(), &[1, 2, 3, 4, 5, 6]);
    }

    /// All shape accessors must report the same `(nrows, ncols)`.
    #[test]
    fn shape_accessors_agree() {
        let array = Array2::from_shape_vec((2, 3), vec![0u8; 6]).expect("test shape is valid");
        assert_eq!(array.dim(), (2, 3));
        assert_eq!(array.nrows(), 2);
        assert_eq!(array.ncols(), 3);
        assert_eq!(array.shape(), [2, 3]);
    }

    /// Writes through `iter_mut` and `IndexMut` must land in the backing buffer.
    #[test]
    fn index_mut_and_iter_mut_write_through() {
        let mut array =
            Array2::from_shape_vec((2, 2), vec![1, 2, 3, 4]).expect("test shape is valid");
        array.iter_mut().for_each(|x| *x *= 10);
        array[[1, 0]] = 7;
        assert_eq!(array.as_slice(), &[10, 20, 7, 40]);
        assert_eq!(array.iter().sum::<i32>(), 77);
    }

    /// Shapes with a zero dimension are valid and hold no elements.
    #[test]
    fn empty_shapes_are_valid() {
        let no_rows = Array2::<i32>::from_shape_vec((0, 4), Vec::new()).expect("0x4 is valid");
        assert_eq!(no_rows.shape(), [0, 4]);
        assert!(no_rows.as_slice().is_empty());

        let no_cols = Array2::<f64>::zeros((3, 0));
        assert_eq!(no_cols.dim(), (3, 0));
        assert_eq!(no_cols.iter().count(), 0);
    }

    #[test]
    fn from_elem_and_fill_cover_whole_buffer() {
        let mut array = Array2::from_elem((2, 2), false);
        array[[1, 1]] = true;
        array.fill(true);
        assert_eq!(array.as_slice(), &[true, true, true, true]);
    }

    #[test]
    fn zeros_uses_default_values() {
        let array = Array2::<f64>::zeros((2, 2));
        assert_eq!(array.as_slice(), &[0.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn shape_mismatch_is_rejected() {
        let err = Array2::from_shape_vec((2, 2), vec![1, 2, 3]).expect_err("shape mismatch");
        assert!(err.contains("mismatch"), "{err}");
    }
}