1use std::ops::{Index, IndexMut};
2
/// A dense two-dimensional array whose elements are stored contiguously in row-major order.
///
/// The constructors in this file keep `data.len() == nrows * ncols`, and element
/// `[row, col]` lives at `data[row * ncols + col]`.
///
/// NOTE(review): with the `serde` feature enabled, the derived `Deserialize` fills these fields
/// directly and does not re-check the length invariant that `from_shape_vec` enforces —
/// consider routing deserialization through a validating `#[serde(try_from = ...)]`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct Array2<T> {
    /// Number of rows.
    nrows: usize,
    /// Number of columns; also the distance in `data` between the starts of consecutive rows.
    ncols: usize,
    /// Elements in row-major order.
    data: Vec<T>,
}
11
12impl<T> Array2<T> {
13 pub fn from_shape_vec(shape: (usize, usize), data: Vec<T>) -> Result<Self, String> {
14 let (nrows, ncols) = shape;
15 if data.len() != nrows * ncols {
16 return Err(format!(
17 "Array2 data length mismatch: got {}, expected {} for shape ({nrows}, {ncols})",
18 data.len(),
19 nrows * ncols
20 ));
21 }
22 Ok(Self { nrows, ncols, data })
23 }
24
25 #[inline]
26 pub fn dim(&self) -> (usize, usize) {
27 (self.nrows, self.ncols)
28 }
29
30 #[inline]
31 pub fn nrows(&self) -> usize {
32 self.nrows
33 }
34
35 #[inline]
36 pub fn ncols(&self) -> usize {
37 self.ncols
38 }
39
40 #[inline]
41 pub fn shape(&self) -> [usize; 2] {
42 [self.nrows, self.ncols]
43 }
44
45 #[inline]
46 pub fn as_slice(&self) -> &[T] {
47 &self.data
48 }
49
50 #[inline]
51 pub fn iter(&self) -> std::slice::Iter<'_, T> {
52 self.data.iter()
53 }
54
55 #[inline]
56 pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, T> {
57 self.data.iter_mut()
58 }
59
60 #[inline]
61 fn offset(&self, row: usize, col: usize) -> usize {
62 row * self.ncols + col
63 }
64}
65
66impl<T: Clone> Array2<T> {
67 pub fn from_elem(shape: (usize, usize), value: T) -> Self {
68 let (nrows, ncols) = shape;
69 Self {
70 nrows,
71 ncols,
72 data: vec![value; nrows * ncols],
73 }
74 }
75
76 pub fn fill(&mut self, value: T) {
77 self.data.fill(value);
78 }
79}
80
81impl<T: Clone + Default> Array2<T> {
82 pub fn zeros(shape: (usize, usize)) -> Self {
83 Self::from_elem(shape, T::default())
84 }
85}
86
87impl<T> Index<[usize; 2]> for Array2<T> {
88 type Output = T;
89
90 fn index(&self, index: [usize; 2]) -> &Self::Output {
91 &self.data[self.offset(index[0], index[1])]
92 }
93}
94
95impl<T> IndexMut<[usize; 2]> for Array2<T> {
96 fn index_mut(&mut self, index: [usize; 2]) -> &mut Self::Output {
97 let offset = self.offset(index[0], index[1]);
98 &mut self.data[offset]
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::Array2;

    #[test]
    fn shape_vec_roundtrip_preserves_row_major_order() {
        let grid = Array2::from_shape_vec((2, 3), vec![1, 2, 3, 4, 5, 6])
            .expect("test shape is valid");
        assert_eq!(grid.shape(), [2, 3]);
        // Each corner of the 2x3 grid paired with the value row-major storage puts there.
        let corners: [([usize; 2], i32); 4] = [([0, 0], 1), ([0, 2], 3), ([1, 0], 4), ([1, 2], 6)];
        for (position, expected) in corners {
            assert_eq!(grid[position], expected);
        }
        assert_eq!(grid.as_slice(), &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn from_elem_and_fill_cover_whole_buffer() {
        let mut flags = Array2::from_elem((2, 2), false);
        // Set one element first so `fill` has to overwrite a mixed buffer.
        flags[[1, 1]] = true;
        flags.fill(true);
        assert_eq!(flags.as_slice(), &[true; 4]);
    }

    #[test]
    fn zeros_uses_default_values() {
        let zeroed: Array2<f64> = Array2::zeros((2, 2));
        assert_eq!(zeroed.as_slice(), &[0.0; 4]);
    }

    #[test]
    fn shape_mismatch_is_rejected() {
        let message = Array2::from_shape_vec((2, 2), vec![1, 2, 3]).expect_err("shape mismatch");
        assert!(message.contains("mismatch"), "{message}");
    }
}