trueno/matrix/ops/storage.rs
1//! Matrix storage and construction operations
2//!
3//! This module provides storage-related operations for Matrix:
4//! - Constructors: `new()`, `from_vec()`, `from_slice()`, `zeros()`, `identity()`
5//! - Accessors: `rows()`, `cols()`, `shape()`, `get()`, `get_mut()`, `as_slice()`
6//!
7//! ## Domain Separation (PMAT-018)
8//!
9//! Storage is separate from Algebra (arithmetic operations).
10//! A matrix's memory layout is independent of its mathematical operations.
11
12use crate::{Backend, TruenoError};
13
14use super::super::Matrix;
15
16impl std::ops::Index<(usize, usize)> for Matrix<f32> {
17 type Output = f32;
18
19 fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
20 &self.data[row * self.cols + col]
21 }
22}
23
24impl Matrix<f32> {
25 // =========================================================================
26 // Constructors
27 // =========================================================================
28
29 /// Creates a new matrix with uninitialized values
30 ///
31 /// # Arguments
32 ///
33 /// * `rows` - Number of rows
34 /// * `cols` - Number of columns
35 ///
36 /// # Returns
37 ///
38 /// A new matrix with dimensions `rows x cols` containing uninitialized values
39 ///
40 /// # Example
41 ///
42 /// ```
43 /// use trueno::Matrix;
44 ///
45 /// let m = Matrix::new(3, 4);
46 /// assert_eq!(m.rows(), 3);
47 /// assert_eq!(m.cols(), 4);
48 /// ```
49 pub fn new(rows: usize, cols: usize) -> Self {
50 let backend = Backend::select_best();
51 Matrix { rows, cols, data: vec![0.0; rows * cols], backend }
52 }
53
54 /// Creates a matrix from a vector of data
55 ///
56 /// # Arguments
57 ///
58 /// * `rows` - Number of rows
59 /// * `cols` - Number of columns
60 /// * `data` - Vector containing matrix elements in row-major order
61 ///
62 /// # Errors
63 ///
64 /// Returns `InvalidInput` if `data.len() != rows * cols`
65 ///
66 /// # Example
67 ///
68 /// ```
69 /// use trueno::Matrix;
70 ///
71 /// let m = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
72 /// assert_eq!(m.rows(), 2);
73 /// assert_eq!(m.cols(), 2);
74 /// ```
75 pub fn from_vec(rows: usize, cols: usize, data: Vec<f32>) -> Result<Self, TruenoError> {
76 if data.len() != rows * cols {
77 return Err(TruenoError::InvalidInput(format!(
78 "Data length {} does not match matrix dimensions {}x{} (expected {})",
79 data.len(),
80 rows,
81 cols,
82 rows * cols
83 )));
84 }
85
86 let backend = Backend::select_best();
87 Ok(Matrix { rows, cols, data, backend })
88 }
89
90 /// Creates a matrix from a vector with a specific backend
91 ///
92 /// This is useful for testing specific SIMD code paths.
93 pub fn from_vec_with_backend(
94 rows: usize,
95 cols: usize,
96 data: Vec<f32>,
97 backend: Backend,
98 ) -> Self {
99 assert_eq!(
100 data.len(),
101 rows * cols,
102 "Data length {} does not match matrix dimensions {}x{}",
103 data.len(),
104 rows,
105 cols
106 );
107 Matrix { rows, cols, data, backend }
108 }
109
110 /// Creates a matrix from a slice by copying the data
111 ///
112 /// This is a convenience method that copies the slice into an owned vector.
113 /// For zero-copy scenarios, consider using the data directly with `from_vec`
114 /// if you already have an owned `Vec`.
115 ///
116 /// # Arguments
117 ///
118 /// * `rows` - Number of rows
119 /// * `cols` - Number of columns
120 /// * `data` - Slice containing matrix elements in row-major order
121 ///
122 /// # Errors
123 ///
124 /// Returns `InvalidInput` if `data.len() != rows * cols`
125 ///
126 /// # Example
127 ///
128 /// ```
129 /// use trueno::Matrix;
130 ///
131 /// let data = [1.0, 2.0, 3.0, 4.0];
132 /// let m = Matrix::from_slice(2, 2, &data).unwrap();
133 /// assert_eq!(m.get(0, 0), Some(&1.0));
134 /// ```
135 pub fn from_slice(rows: usize, cols: usize, data: &[f32]) -> Result<Self, TruenoError> {
136 Self::from_vec(rows, cols, data.to_vec())
137 }
138
139 /// Creates a matrix filled with zeros
140 ///
141 /// # Example
142 ///
143 /// ```
144 /// use trueno::Matrix;
145 ///
146 /// let m = Matrix::zeros(3, 3);
147 /// assert_eq!(m.get(1, 1), Some(&0.0));
148 /// ```
149 pub fn zeros(rows: usize, cols: usize) -> Self {
150 Matrix::new(rows, cols)
151 }
152
153 /// Creates a matrix filled with zeros using a specific backend
154 /// (Internal use only - reuses backend from parent matrix)
155 pub(crate) fn zeros_with_backend(rows: usize, cols: usize, backend: Backend) -> Self {
156 Matrix { rows, cols, data: vec![0.0; rows * cols], backend }
157 }
158
159 /// Creates an identity matrix (square matrix with 1s on diagonal)
160 ///
161 /// # Example
162 ///
163 /// ```
164 /// use trueno::Matrix;
165 ///
166 /// let m = Matrix::identity(3);
167 /// assert_eq!(m.get(0, 0), Some(&1.0));
168 /// assert_eq!(m.get(0, 1), Some(&0.0));
169 /// assert_eq!(m.get(1, 1), Some(&1.0));
170 /// ```
171 pub fn identity(n: usize) -> Self {
172 let mut data = vec![0.0; n * n];
173 for i in 0..n {
174 data[i * n + i] = 1.0;
175 }
176 let backend = Backend::select_best();
177 Matrix { rows: n, cols: n, data, backend }
178 }
179
180 // =========================================================================
181 // Accessors
182 // =========================================================================
183
184 /// Returns the number of rows
185 pub fn rows(&self) -> usize {
186 self.rows
187 }
188
189 /// Returns the number of columns
190 pub fn cols(&self) -> usize {
191 self.cols
192 }
193
194 /// Returns the shape as (rows, cols)
195 pub fn shape(&self) -> (usize, usize) {
196 (self.rows, self.cols)
197 }
198
199 /// Gets a reference to an element at (row, col)
200 ///
201 /// Returns `None` if indices are out of bounds
202 pub fn get(&self, row: usize, col: usize) -> Option<&f32> {
203 if row >= self.rows || col >= self.cols {
204 None
205 } else {
206 self.data.get(row * self.cols + col)
207 }
208 }
209
210 /// Gets a mutable reference to an element at (row, col)
211 ///
212 /// Returns `None` if indices are out of bounds
213 pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut f32> {
214 if row >= self.rows || col >= self.cols {
215 None
216 } else {
217 let idx = row * self.cols + col;
218 self.data.get_mut(idx)
219 }
220 }
221
222 /// Returns a reference to the underlying data
223 pub fn as_slice(&self) -> &[f32] {
224 &self.data
225 }
226
227 /// Returns the backend used by this matrix
228 pub fn backend(&self) -> Backend {
229 self.backend
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` reports the requested shape and every stored element is zero.
    #[test]
    fn test_new_creates_zero_matrix() {
        let zeroed = Matrix::new(3, 4);
        assert_eq!(zeroed.rows(), 3);
        assert_eq!(zeroed.cols(), 4);
        assert!(!zeroed.as_slice().iter().any(|&value| value != 0.0));
    }

    /// A buffer of exactly `rows * cols` elements is accepted and laid out
    /// in row-major order.
    #[test]
    fn test_from_vec_success() {
        let elements = vec![1.0, 2.0, 3.0, 4.0];
        let built = Matrix::from_vec(2, 2, elements).unwrap();
        assert_eq!(built.get(0, 0), Some(&1.0));
        assert_eq!(built.get(1, 1), Some(&4.0));
    }

    /// Three elements cannot fill a 2x2 matrix, so construction must fail.
    #[test]
    fn test_from_vec_dimension_mismatch() {
        assert!(Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0]).is_err());
    }

    /// The identity matrix has ones on the diagonal and zeros elsewhere.
    #[test]
    fn test_identity() {
        let eye = Matrix::identity(3);
        for i in 0..3 {
            assert_eq!(eye.get(i, i), Some(&1.0));
        }
        assert_eq!(eye.get(0, 1), Some(&0.0));
    }

    /// Tuple indexing reads the same elements as `get`.
    #[test]
    fn test_index_operator() {
        let built = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert_eq!(built[(0, 0)], 1.0);
        assert_eq!(built[(1, 1)], 4.0);
    }

    /// `get` yields `None` when either the row or the column is out of range.
    #[test]
    fn test_get_out_of_bounds() {
        let small = Matrix::new(2, 2);
        assert_eq!(small.get(2, 0), None);
        assert_eq!(small.get(0, 2), None);
    }

    /// A write through `get_mut` is visible through a subsequent `get`.
    #[test]
    fn test_get_mut() {
        let mut target = Matrix::new(2, 2);
        if let Some(slot) = target.get_mut(1, 1) {
            *slot = 42.0;
        }
        assert_eq!(target.get(1, 1), Some(&42.0));
    }
}