metal_matrix/matrix.rs
1/*!
2 * # Matrix
3 *
4 * This module provides the core `Matrix` data structure for linear algebra operations.
5 *
6 * The `Matrix` struct represents a 2D matrix with floating-point elements.
7 * It supports both regular matrices and vectors (as 1D matrices).
8 */
9
10use anyhow::Result;
11
/// A dense, two-dimensional matrix of `f32` values.
///
/// This is the core data structure for all linear algebra operations in the
/// library. Elements are stored contiguously in row-major order, so the element
/// at `(row, col)` lives at offset `row * cols + col` inside `data`.
/// Vectors are represented as matrices with exactly one row or one column.
///
/// Two matrices compare equal when their dimensions match and every pair of
/// elements compares equal under `f32` equality (so a matrix containing `NaN`
/// is never equal to anything, including itself).
///
/// The fields are public; code that mutates them directly is responsible for
/// keeping `data.len() == rows * cols`.
///
/// # Examples
///
/// Creating a new matrix:
/// ```
/// use metal_matrix::Matrix;
///
/// // Create a 3x3 zero matrix
/// let mut matrix = Matrix::new(3, 3);
///
/// // Set some values
/// matrix.set(0, 0, 1.0);
/// matrix.set(1, 1, 2.0);
/// matrix.set(2, 2, 3.0);
/// ```
///
/// Creating a vector:
/// ```
/// use metal_matrix::Matrix;
///
/// // Create a column vector
/// let vector = Matrix::vector(vec![1.0, 2.0, 3.0]);
/// assert_eq!(vector.rows, 3);
/// assert_eq!(vector.cols, 1);
/// ```
#[derive(Clone, Debug, PartialEq)]
pub struct Matrix {
    /// Number of rows in the matrix
    pub rows: usize,

    /// Number of columns in the matrix
    pub cols: usize,

    /// Matrix elements in row-major order (`rows * cols` entries)
    pub data: Vec<f32>,
}
53
54impl Matrix {
55 /// Create a new matrix with given dimensions, initialized with zeros.
56 ///
57 /// # Arguments
58 ///
59 /// * `rows` - Number of rows
60 /// * `cols` - Number of columns
61 ///
62 /// # Returns
63 ///
64 /// A new matrix of the specified dimensions, filled with zeros.
65 pub fn new(rows: usize, cols: usize) -> Self {
66 Self {
67 rows,
68 cols,
69 data: vec![0.0; rows * cols],
70 }
71 }
72
73 /// Create a new matrix with given dimensions and data.
74 ///
75 /// # Arguments
76 ///
77 /// * `rows` - Number of rows
78 /// * `cols` - Number of columns
79 /// * `data` - Vector of data in row-major order
80 ///
81 /// # Returns
82 ///
83 /// A `Result` containing the new matrix or an error if the data length doesn't match dimensions.
84 ///
85 /// # Errors
86 ///
87 /// Returns an error if `data.len() != rows * cols`.
88 pub fn with_data(rows: usize, cols: usize, data: Vec<f32>) -> Result<Self> {
89 if data.len() != rows * cols {
90 anyhow::bail!("Data length does not match matrix dimensions");
91 }
92
93 Ok(Self { rows, cols, data })
94 }
95
96 /// Create a column vector (1D matrix) with given data.
97 ///
98 /// # Arguments
99 ///
100 /// * `data` - Vector of data
101 ///
102 /// # Returns
103 ///
104 /// A new matrix with dimensions `(data.len(), 1)`.
105 pub fn vector(data: Vec<f32>) -> Self {
106 Self {
107 rows: data.len(),
108 cols: 1,
109 data,
110 }
111 }
112
113 /// Create an identity matrix of size n×n.
114 ///
115 /// # Arguments
116 ///
117 /// * `n` - Size of the square matrix
118 ///
119 /// # Returns
120 ///
121 /// A new n×n identity matrix (ones on the diagonal, zeros elsewhere).
122 pub fn identity(n: usize) -> Self {
123 let mut matrix = Self::new(n, n);
124 for i in 0..n {
125 matrix.set(i, i, 1.0);
126 }
127 matrix
128 }
129
130 /// Get element at position (row, col).
131 ///
132 /// # Arguments
133 ///
134 /// * `row` - Row index (0-based)
135 /// * `col` - Column index (0-based)
136 ///
137 /// # Returns
138 ///
139 /// The value at the specified position.
140 ///
141 /// # Panics
142 ///
143 /// Panics if the indices are out of bounds.
144 pub fn get(&self, row: usize, col: usize) -> f32 {
145 self.data[row * self.cols + col]
146 }
147
148 /// Set element at position (row, col).
149 ///
150 /// # Arguments
151 ///
152 /// * `row` - Row index (0-based)
153 /// * `col` - Column index (0-based)
154 /// * `value` - Value to set
155 ///
156 /// # Panics
157 ///
158 /// Panics if the indices are out of bounds.
159 pub fn set(&mut self, row: usize, col: usize, value: f32) {
160 self.data[row * self.cols + col] = value;
161 }
162
163 /// Check if this matrix is a vector (1D matrix).
164 ///
165 /// # Returns
166 ///
167 /// `true` if the matrix has either one row or one column, `false` otherwise.
168 pub fn is_vector(&self) -> bool {
169 self.cols == 1 || self.rows == 1
170 }
171
172 /// Get the size of the matrix if it's a vector.
173 ///
174 /// # Returns
175 ///
176 /// The number of elements if the matrix is a vector, or 0 if it's not a vector.
177 pub fn vector_size(&self) -> usize {
178 if self.cols == 1 {
179 self.rows
180 } else if self.rows == 1 {
181 self.cols
182 } else {
183 0 // Not a vector
184 }
185 }
186
187 /// Get vector element at index (for 1D matrices).
188 ///
189 /// # Arguments
190 ///
191 /// * `index` - Element index (0-based)
192 ///
193 /// # Returns
194 ///
195 /// A `Result` containing the value at the specified index or an error if the matrix is not a vector.
196 ///
197 /// # Errors
198 ///
199 /// Returns an error if the matrix is not a vector.
200 pub fn vector_get(&self, index: usize) -> Result<f32> {
201 if !self.is_vector() {
202 anyhow::bail!("Not a vector");
203 }
204
205 if self.cols == 1 {
206 Ok(self.get(index, 0))
207 } else {
208 Ok(self.get(0, index))
209 }
210 }
211
212 /// Extract a row as a new Matrix.
213 ///
214 /// # Arguments
215 ///
216 /// * `row` - Row index (0-based)
217 ///
218 /// # Returns
219 ///
220 /// A new 1×n matrix containing the specified row.
221 ///
222 /// # Panics
223 ///
224 /// Panics if the row index is out of bounds.
225 pub fn row(&self, row: usize) -> Self {
226 let mut data = Vec::with_capacity(self.cols);
227 for col in 0..self.cols {
228 data.push(self.get(row, col));
229 }
230 Self {
231 rows: 1,
232 cols: self.cols,
233 data,
234 }
235 }
236
237 /// Extract a column as a new Matrix.
238 ///
239 /// # Arguments
240 ///
241 /// * `col` - Column index (0-based)
242 ///
243 /// # Returns
244 ///
245 /// A new m×1 matrix containing the specified column.
246 ///
247 /// # Panics
248 ///
249 /// Panics if the column index is out of bounds.
250 pub fn column(&self, col: usize) -> Self {
251 let mut data = Vec::with_capacity(self.rows);
252 for row in 0..self.rows {
253 data.push(self.get(row, col));
254 }
255 Self {
256 rows: self.rows,
257 cols: 1,
258 data,
259 }
260 }
261}