metal_matrix/
matrix.rs

/*!
 * # Matrix
 *
 * This module provides the core `Matrix` data structure used by the library's
 * linear algebra operations.
 *
 * A `Matrix` is a 2D grid of `f32` elements stored in row-major order. Vectors
 * are also represented by `Matrix`, as matrices with a single row or a single column.
 */
9
10use anyhow::Result;
11
/// A two-dimensional grid of `f32` values stored in row-major order.
///
/// `Matrix` is the central type that the library's linear algebra operations
/// work on. There is no separate vector type: a column vector is a matrix with
/// exactly one column, and a row vector is a matrix with exactly one row.
///
/// # Examples
///
/// Building a 3×3 diagonal matrix from zeros:
/// ```
/// use metal_matrix::Matrix;
///
/// let mut matrix = Matrix::new(3, 3);
/// matrix.set(0, 0, 1.0);
/// matrix.set(1, 1, 2.0);
/// matrix.set(2, 2, 3.0);
/// assert_eq!(matrix.get(1, 1), 2.0);
/// ```
///
/// Building a column vector:
/// ```
/// use metal_matrix::Matrix;
///
/// let vector = Matrix::vector(vec![1.0, 2.0, 3.0]);
/// assert_eq!(vector.rows, 3);
/// assert_eq!(vector.cols, 1);
/// ```
#[derive(Clone, Debug)]
pub struct Matrix {
    /// Row count (height) of the matrix.
    pub rows: usize,

    /// Column count (width) of the matrix.
    pub cols: usize,

    /// Element storage laid out row after row: the element at `(r, c)`
    /// lives at index `r * cols + c`.
    pub data: Vec<f32>,
}
53
54impl Matrix {
55    /// Create a new matrix with given dimensions, initialized with zeros.
56    ///
57    /// # Arguments
58    ///
59    /// * `rows` - Number of rows
60    /// * `cols` - Number of columns
61    ///
62    /// # Returns
63    ///
64    /// A new matrix of the specified dimensions, filled with zeros.
65    pub fn new(rows: usize, cols: usize) -> Self {
66        Self {
67            rows,
68            cols,
69            data: vec![0.0; rows * cols],
70        }
71    }
72
73    /// Create a new matrix with given dimensions and data.
74    ///
75    /// # Arguments
76    ///
77    /// * `rows` - Number of rows
78    /// * `cols` - Number of columns
79    /// * `data` - Vector of data in row-major order
80    ///
81    /// # Returns
82    ///
83    /// A `Result` containing the new matrix or an error if the data length doesn't match dimensions.
84    ///
85    /// # Errors
86    ///
87    /// Returns an error if `data.len() != rows * cols`.
88    pub fn with_data(rows: usize, cols: usize, data: Vec<f32>) -> Result<Self> {
89        if data.len() != rows * cols {
90            anyhow::bail!("Data length does not match matrix dimensions");
91        }
92
93        Ok(Self { rows, cols, data })
94    }
95
96    /// Create a column vector (1D matrix) with given data.
97    ///
98    /// # Arguments
99    ///
100    /// * `data` - Vector of data
101    ///
102    /// # Returns
103    ///
104    /// A new matrix with dimensions `(data.len(), 1)`.
105    pub fn vector(data: Vec<f32>) -> Self {
106        Self {
107            rows: data.len(),
108            cols: 1,
109            data,
110        }
111    }
112
113    /// Create an identity matrix of size n×n.
114    ///
115    /// # Arguments
116    ///
117    /// * `n` - Size of the square matrix
118    ///
119    /// # Returns
120    ///
121    /// A new n×n identity matrix (ones on the diagonal, zeros elsewhere).
122    pub fn identity(n: usize) -> Self {
123        let mut matrix = Self::new(n, n);
124        for i in 0..n {
125            matrix.set(i, i, 1.0);
126        }
127        matrix
128    }
129
130    /// Get element at position (row, col).
131    ///
132    /// # Arguments
133    ///
134    /// * `row` - Row index (0-based)
135    /// * `col` - Column index (0-based)
136    ///
137    /// # Returns
138    ///
139    /// The value at the specified position.
140    ///
141    /// # Panics
142    ///
143    /// Panics if the indices are out of bounds.
144    pub fn get(&self, row: usize, col: usize) -> f32 {
145        self.data[row * self.cols + col]
146    }
147
148    /// Set element at position (row, col).
149    ///
150    /// # Arguments
151    ///
152    /// * `row` - Row index (0-based)
153    /// * `col` - Column index (0-based)
154    /// * `value` - Value to set
155    ///
156    /// # Panics
157    ///
158    /// Panics if the indices are out of bounds.
159    pub fn set(&mut self, row: usize, col: usize, value: f32) {
160        self.data[row * self.cols + col] = value;
161    }
162
163    /// Check if this matrix is a vector (1D matrix).
164    ///
165    /// # Returns
166    ///
167    /// `true` if the matrix has either one row or one column, `false` otherwise.
168    pub fn is_vector(&self) -> bool {
169        self.cols == 1 || self.rows == 1
170    }
171
172    /// Get the size of the matrix if it's a vector.
173    ///
174    /// # Returns
175    ///
176    /// The number of elements if the matrix is a vector, or 0 if it's not a vector.
177    pub fn vector_size(&self) -> usize {
178        if self.cols == 1 {
179            self.rows
180        } else if self.rows == 1 {
181            self.cols
182        } else {
183            0 // Not a vector
184        }
185    }
186
187    /// Get vector element at index (for 1D matrices).
188    ///
189    /// # Arguments
190    ///
191    /// * `index` - Element index (0-based)
192    ///
193    /// # Returns
194    ///
195    /// A `Result` containing the value at the specified index or an error if the matrix is not a vector.
196    ///
197    /// # Errors
198    ///
199    /// Returns an error if the matrix is not a vector.
200    pub fn vector_get(&self, index: usize) -> Result<f32> {
201        if !self.is_vector() {
202            anyhow::bail!("Not a vector");
203        }
204
205        if self.cols == 1 {
206            Ok(self.get(index, 0))
207        } else {
208            Ok(self.get(0, index))
209        }
210    }
211
212    /// Extract a row as a new Matrix.
213    ///
214    /// # Arguments
215    ///
216    /// * `row` - Row index (0-based)
217    ///
218    /// # Returns
219    ///
220    /// A new 1×n matrix containing the specified row.
221    ///
222    /// # Panics
223    ///
224    /// Panics if the row index is out of bounds.
225    pub fn row(&self, row: usize) -> Self {
226        let mut data = Vec::with_capacity(self.cols);
227        for col in 0..self.cols {
228            data.push(self.get(row, col));
229        }
230        Self {
231            rows: 1,
232            cols: self.cols,
233            data,
234        }
235    }
236
237    /// Extract a column as a new Matrix.
238    ///
239    /// # Arguments
240    ///
241    /// * `col` - Column index (0-based)
242    ///
243    /// # Returns
244    ///
245    /// A new m×1 matrix containing the specified column.
246    ///
247    /// # Panics
248    ///
249    /// Panics if the column index is out of bounds.
250    pub fn column(&self, col: usize) -> Self {
251        let mut data = Vec::with_capacity(self.rows);
252        for row in 0..self.rows {
253            data.push(self.get(row, col));
254        }
255        Self {
256            rows: self.rows,
257            cols: 1,
258            data,
259        }
260    }
261}