Skip to main content

trueno/matrix/ops/
storage.rs

//! Matrix storage and construction operations
//!
//! This module provides storage-related operations for `Matrix<f32>`:
//! - Constructors: `new()`, `from_vec()`, `from_vec_with_backend()`, `from_slice()`,
//!   `zeros()`, `identity()`
//! - Accessors: `rows()`, `cols()`, `shape()`, `get()`, `get_mut()`, `as_slice()`,
//!   `backend()`
//! - Indexing: `m[(row, col)]` via `std::ops::Index`
//!
//! ## Domain Separation (PMAT-018)
//!
//! Storage is kept separate from algebra (arithmetic operations): a matrix's
//! memory layout is independent of the mathematical operations performed on it.
11
12use crate::{Backend, TruenoError};
13
14use super::super::Matrix;
15
16impl std::ops::Index<(usize, usize)> for Matrix<f32> {
17    type Output = f32;
18
19    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
20        &self.data[row * self.cols + col]
21    }
22}
23
24impl Matrix<f32> {
25    // =========================================================================
26    // Constructors
27    // =========================================================================
28
29    /// Creates a new matrix with uninitialized values
30    ///
31    /// # Arguments
32    ///
33    /// * `rows` - Number of rows
34    /// * `cols` - Number of columns
35    ///
36    /// # Returns
37    ///
38    /// A new matrix with dimensions `rows x cols` containing uninitialized values
39    ///
40    /// # Example
41    ///
42    /// ```
43    /// use trueno::Matrix;
44    ///
45    /// let m = Matrix::new(3, 4);
46    /// assert_eq!(m.rows(), 3);
47    /// assert_eq!(m.cols(), 4);
48    /// ```
49    pub fn new(rows: usize, cols: usize) -> Self {
50        let backend = Backend::select_best();
51        Matrix { rows, cols, data: vec![0.0; rows * cols], backend }
52    }
53
54    /// Creates a matrix from a vector of data
55    ///
56    /// # Arguments
57    ///
58    /// * `rows` - Number of rows
59    /// * `cols` - Number of columns
60    /// * `data` - Vector containing matrix elements in row-major order
61    ///
62    /// # Errors
63    ///
64    /// Returns `InvalidInput` if `data.len() != rows * cols`
65    ///
66    /// # Example
67    ///
68    /// ```
69    /// use trueno::Matrix;
70    ///
71    /// let m = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
72    /// assert_eq!(m.rows(), 2);
73    /// assert_eq!(m.cols(), 2);
74    /// ```
75    pub fn from_vec(rows: usize, cols: usize, data: Vec<f32>) -> Result<Self, TruenoError> {
76        if data.len() != rows * cols {
77            return Err(TruenoError::InvalidInput(format!(
78                "Data length {} does not match matrix dimensions {}x{} (expected {})",
79                data.len(),
80                rows,
81                cols,
82                rows * cols
83            )));
84        }
85
86        let backend = Backend::select_best();
87        Ok(Matrix { rows, cols, data, backend })
88    }
89
90    /// Creates a matrix from a vector with a specific backend
91    ///
92    /// This is useful for testing specific SIMD code paths.
93    pub fn from_vec_with_backend(
94        rows: usize,
95        cols: usize,
96        data: Vec<f32>,
97        backend: Backend,
98    ) -> Self {
99        assert_eq!(
100            data.len(),
101            rows * cols,
102            "Data length {} does not match matrix dimensions {}x{}",
103            data.len(),
104            rows,
105            cols
106        );
107        Matrix { rows, cols, data, backend }
108    }
109
110    /// Creates a matrix from a slice by copying the data
111    ///
112    /// This is a convenience method that copies the slice into an owned vector.
113    /// For zero-copy scenarios, consider using the data directly with `from_vec`
114    /// if you already have an owned `Vec`.
115    ///
116    /// # Arguments
117    ///
118    /// * `rows` - Number of rows
119    /// * `cols` - Number of columns
120    /// * `data` - Slice containing matrix elements in row-major order
121    ///
122    /// # Errors
123    ///
124    /// Returns `InvalidInput` if `data.len() != rows * cols`
125    ///
126    /// # Example
127    ///
128    /// ```
129    /// use trueno::Matrix;
130    ///
131    /// let data = [1.0, 2.0, 3.0, 4.0];
132    /// let m = Matrix::from_slice(2, 2, &data).unwrap();
133    /// assert_eq!(m.get(0, 0), Some(&1.0));
134    /// ```
135    pub fn from_slice(rows: usize, cols: usize, data: &[f32]) -> Result<Self, TruenoError> {
136        Self::from_vec(rows, cols, data.to_vec())
137    }
138
139    /// Creates a matrix filled with zeros
140    ///
141    /// # Example
142    ///
143    /// ```
144    /// use trueno::Matrix;
145    ///
146    /// let m = Matrix::zeros(3, 3);
147    /// assert_eq!(m.get(1, 1), Some(&0.0));
148    /// ```
149    pub fn zeros(rows: usize, cols: usize) -> Self {
150        Matrix::new(rows, cols)
151    }
152
153    /// Creates a matrix filled with zeros using a specific backend
154    /// (Internal use only - reuses backend from parent matrix)
155    pub(crate) fn zeros_with_backend(rows: usize, cols: usize, backend: Backend) -> Self {
156        Matrix { rows, cols, data: vec![0.0; rows * cols], backend }
157    }
158
159    /// Creates an identity matrix (square matrix with 1s on diagonal)
160    ///
161    /// # Example
162    ///
163    /// ```
164    /// use trueno::Matrix;
165    ///
166    /// let m = Matrix::identity(3);
167    /// assert_eq!(m.get(0, 0), Some(&1.0));
168    /// assert_eq!(m.get(0, 1), Some(&0.0));
169    /// assert_eq!(m.get(1, 1), Some(&1.0));
170    /// ```
171    pub fn identity(n: usize) -> Self {
172        let mut data = vec![0.0; n * n];
173        for i in 0..n {
174            data[i * n + i] = 1.0;
175        }
176        let backend = Backend::select_best();
177        Matrix { rows: n, cols: n, data, backend }
178    }
179
180    // =========================================================================
181    // Accessors
182    // =========================================================================
183
184    /// Returns the number of rows
185    pub fn rows(&self) -> usize {
186        self.rows
187    }
188
189    /// Returns the number of columns
190    pub fn cols(&self) -> usize {
191        self.cols
192    }
193
194    /// Returns the shape as (rows, cols)
195    pub fn shape(&self) -> (usize, usize) {
196        (self.rows, self.cols)
197    }
198
199    /// Gets a reference to an element at (row, col)
200    ///
201    /// Returns `None` if indices are out of bounds
202    pub fn get(&self, row: usize, col: usize) -> Option<&f32> {
203        if row >= self.rows || col >= self.cols {
204            None
205        } else {
206            self.data.get(row * self.cols + col)
207        }
208    }
209
210    /// Gets a mutable reference to an element at (row, col)
211    ///
212    /// Returns `None` if indices are out of bounds
213    pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut f32> {
214        if row >= self.rows || col >= self.cols {
215            None
216        } else {
217            let idx = row * self.cols + col;
218            self.data.get_mut(idx)
219        }
220    }
221
222    /// Returns a reference to the underlying data
223    pub fn as_slice(&self) -> &[f32] {
224        &self.data
225    }
226
227    /// Returns the backend used by this matrix
228    pub fn backend(&self) -> Backend {
229        self.backend
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Row-major 2x2 fixture `[[1, 2], [3, 4]]` shared by several tests.
    fn sample_2x2() -> Matrix<f32> {
        Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap()
    }

    #[test]
    fn test_new_creates_zero_matrix() {
        let zeroed = Matrix::new(3, 4);
        assert_eq!((zeroed.rows(), zeroed.cols()), (3, 4));
        assert!(!zeroed.as_slice().iter().any(|&v| v != 0.0));
    }

    #[test]
    fn test_from_vec_success() {
        let sample = sample_2x2();
        assert_eq!(sample.get(0, 0), Some(&1.0));
        assert_eq!(sample.get(1, 1), Some(&4.0));
    }

    #[test]
    fn test_from_vec_dimension_mismatch() {
        let too_short = vec![1.0, 2.0, 3.0];
        assert!(Matrix::from_vec(2, 2, too_short).is_err());
    }

    #[test]
    fn test_identity() {
        let eye = Matrix::identity(3);
        for d in 0..3 {
            assert_eq!(eye.get(d, d), Some(&1.0));
        }
        assert_eq!(eye.get(0, 1), Some(&0.0));
    }

    #[test]
    fn test_index_operator() {
        let sample = sample_2x2();
        assert_eq!(sample[(0, 0)], 1.0);
        assert_eq!(sample[(1, 1)], 4.0);
    }

    #[test]
    fn test_get_out_of_bounds() {
        let small = Matrix::new(2, 2);
        for &(r, c) in &[(2, 0), (0, 2)] {
            assert_eq!(small.get(r, c), None);
        }
    }

    #[test]
    fn test_get_mut() {
        let mut target = Matrix::new(2, 2);
        if let Some(slot) = target.get_mut(1, 1) {
            *slot = 42.0;
        }
        assert_eq!(target.get(1, 1), Some(&42.0));
    }
}