// mathhook_core/matrices/unified/construction.rs

//! Matrix construction and property methods
3use crate::core::matrix::NumericMatrix;
4use crate::core::Expression;
5use crate::matrices::types::*;
6use crate::matrices::unified::Matrix;
8impl Matrix {
9    /// Get matrix dimensions efficiently
10    ///
11    /// This method provides O(1) dimension lookup for all matrix types.
12    #[inline]
13    pub fn dimensions(&self) -> (usize, usize) {
14        match self {
15            Matrix::Dense(data) => {
16                let rows = data.rows.len();
17                let cols = data.rows.first().map(|row| row.len()).unwrap_or(0);
18                (rows, cols)
19            }
20            Matrix::Identity(data) => (data.size, data.size),
21            Matrix::Zero(data) => (data.rows, data.cols),
22            Matrix::Diagonal(data) => {
23                let size = data.diagonal_elements.len();
24                (size, size)
25            }
26            Matrix::Scalar(data) => (data.size, data.size),
27            Matrix::UpperTriangular(data) => (data.size, data.size),
28            Matrix::LowerTriangular(data) => (data.size, data.size),
29            Matrix::Symmetric(data) => (data.size, data.size),
30            Matrix::Permutation(data) => {
31                let size = data.permutation.len();
32                (size, size)
33            }
34        }
35    }
36
37    /// Get element at position (i, j) efficiently
38    ///
39    /// This method provides optimized element access for each matrix type.
40    #[inline]
41    pub fn get_element(&self, i: usize, j: usize) -> Expression {
42        match self {
43            Matrix::Dense(data) => data
44                .rows
45                .get(i)
46                .and_then(|row| row.get(j))
47                .cloned()
48                .unwrap_or_else(|| Expression::integer(0)),
49
50            Matrix::Identity(data) => {
51                if i < data.size && j < data.size && i == j {
52                    Expression::integer(1)
53                } else {
54                    Expression::integer(0)
55                }
56            }
57
58            Matrix::Zero(_) => Expression::integer(0),
59
60            Matrix::Diagonal(data) => {
61                if i == j && i < data.diagonal_elements.len() {
62                    data.diagonal_elements[i].clone()
63                } else {
64                    Expression::integer(0)
65                }
66            }
67
68            Matrix::Scalar(data) => {
69                if i < data.size && j < data.size && i == j {
70                    data.scalar_value.clone()
71                } else {
72                    Expression::integer(0)
73                }
74            }
75
76            Matrix::UpperTriangular(data) => {
77                if i <= j && i < data.size && j < data.size {
78                    data.get_element(i, j)
79                        .cloned()
80                        .unwrap_or_else(|| Expression::integer(0))
81                } else {
82                    Expression::integer(0)
83                }
84            }
85
86            Matrix::LowerTriangular(data) => {
87                if i >= j && i < data.size && j < data.size {
88                    data.get_element(i, j)
89                        .cloned()
90                        .unwrap_or_else(|| Expression::integer(0))
91                } else {
92                    Expression::integer(0)
93                }
94            }
95
96            Matrix::Symmetric(data) => {
97                if i < data.size && j < data.size {
98                    data.get_element(i, j)
99                        .cloned()
100                        .unwrap_or_else(|| Expression::integer(0))
101                } else {
102                    Expression::integer(0)
103                }
104            }
105
106            Matrix::Permutation(data) => Expression::integer(data.get_element(i, j)),
107        }
108    }
109
110    /// Try to convert this matrix to a NumericMatrix for fast numeric operations.
111    ///
112    /// Returns Some(NumericMatrix) if all elements can be converted to f64,
113    /// None otherwise (e.g., if matrix contains symbolic expressions).
114    pub fn as_numeric(&self) -> Option<NumericMatrix> {
115        NumericMatrix::try_from_matrix(self)
116    }
117
118    /// Check if this is a square matrix
119    #[inline]
120    pub fn is_square(&self) -> bool {
121        let (rows, cols) = self.dimensions();
122        rows == cols
123    }
124
125    /// Check if this is a zero matrix
126    #[inline]
127    pub fn is_zero(&self) -> bool {
128        matches!(self, Matrix::Zero(_))
129    }
130
131    /// Check if this is an identity matrix
132    #[inline]
133    pub fn is_identity(&self) -> bool {
134        match self {
135            Matrix::Identity(_) => true,
136            Matrix::Scalar(data) => data.scalar_value == Expression::integer(1),
137            _ => false,
138        }
139    }
140
141    /// Check if this is a diagonal matrix
142    #[inline]
143    pub fn is_diagonal(&self) -> bool {
144        matches!(
145            self,
146            Matrix::Identity(_) | Matrix::Zero(_) | Matrix::Diagonal(_) | Matrix::Scalar(_)
147        )
148    }
149
150    /// Check if this is symmetric
151    #[inline]
152    pub fn is_symmetric(&self) -> bool {
153        matches!(
154            self,
155            Matrix::Identity(_)
156                | Matrix::Zero(_)
157                | Matrix::Diagonal(_)
158                | Matrix::Scalar(_)
159                | Matrix::Symmetric(_)
160        )
161    }
162
163    /// Convert to the most efficient representation
164    ///
165    /// This method analyzes the matrix and converts it to the most
166    /// memory-efficient representation possible.
167    pub fn optimize(self) -> Matrix {
168        match self {
169            Matrix::Dense(data) => {
170                let (rows, cols) = (
171                    data.rows.len(),
172                    data.rows.first().map(|r| r.len()).unwrap_or(0),
173                );
174
175                if data
176                    .rows
177                    .iter()
178                    .all(|row| row.iter().all(|elem| elem.is_zero_fast()))
179                {
180                    return Matrix::Zero(ZeroMatrixData { rows, cols });
181                }
182
183                if rows == cols
184                    && data.rows.iter().enumerate().all(|(i, row)| {
185                        row.iter().enumerate().all(|(j, elem)| {
186                            if i == j {
187                                elem == &Expression::integer(1)
188                            } else {
189                                elem.is_zero_fast()
190                            }
191                        })
192                    })
193                {
194                    return Matrix::Identity(IdentityMatrixData { size: rows });
195                }
196
197                if rows == cols
198                    && data.rows.iter().enumerate().all(|(i, row)| {
199                        row.iter()
200                            .enumerate()
201                            .all(|(j, elem)| i == j || elem.is_zero_fast())
202                    })
203                {
204                    let diagonal_elements: Vec<Expression> =
205                        (0..rows).map(|i| data.rows[i][i].clone()).collect();
206
207                    if diagonal_elements
208                        .iter()
209                        .all(|elem| elem == &Expression::integer(1))
210                    {
211                        return Matrix::Identity(IdentityMatrixData { size: rows });
212                    }
213
214                    if diagonal_elements
215                        .iter()
216                        .all(|elem| elem == &diagonal_elements[0])
217                    {
218                        return Matrix::Scalar(ScalarMatrixData {
219                            size: rows,
220                            scalar_value: diagonal_elements[0].clone(),
221                        });
222                    }
223
224                    return Matrix::Diagonal(DiagonalMatrixData { diagonal_elements });
225                }
226
227                Matrix::Dense(data)
228            }
229
230            Matrix::Diagonal(data) => {
231                if data
232                    .diagonal_elements
233                    .iter()
234                    .all(|elem| elem == &Expression::integer(1))
235                {
236                    return Matrix::Identity(IdentityMatrixData {
237                        size: data.diagonal_elements.len(),
238                    });
239                }
240
241                if data
242                    .diagonal_elements
243                    .iter()
244                    .all(|elem| elem.is_zero_fast())
245                {
246                    let size = data.diagonal_elements.len();
247                    return Matrix::Zero(ZeroMatrixData {
248                        rows: size,
249                        cols: size,
250                    });
251                }
252
253                if !data.diagonal_elements.is_empty()
254                    && data
255                        .diagonal_elements
256                        .iter()
257                        .all(|elem| elem == &data.diagonal_elements[0])
258                {
259                    return Matrix::Scalar(ScalarMatrixData {
260                        size: data.diagonal_elements.len(),
261                        scalar_value: data.diagonal_elements[0].clone(),
262                    });
263                }
264
265                Matrix::Diagonal(data)
266            }
267
268            other => other,
269        }
270    }
271
272    /// Create a dense matrix from rows
273    ///
274    /// # Examples
275    ///
276    /// ```rust
277    /// use mathhook_core::matrices::Matrix;
278    /// use mathhook_core::Expression;
279    ///
280    /// let matrix = Matrix::dense(vec![
281    ///     vec![Expression::integer(1), Expression::integer(2)],
282    ///     vec![Expression::integer(3), Expression::integer(4)]
283    /// ]);
284    /// ```
285    pub fn dense(rows: Vec<Vec<Expression>>) -> Self {
286        Matrix::Dense(MatrixData { rows }).optimize()
287    }
288
289    /// Create an identity matrix of given size
290    /// Memory efficient: O(1) storage vs O(n²) for dense matrix
291    ///
292    /// # Examples
293    ///
294    /// ```rust
295    /// use mathhook_core::matrices::Matrix;
296    ///
297    /// let identity = Matrix::identity(3);
298    /// assert_eq!(identity.dimensions(), (3, 3));
299    /// assert!(identity.is_identity());
300    /// ```
301    pub fn identity(size: usize) -> Self {
302        Matrix::Identity(IdentityMatrixData { size })
303    }
304
305    /// Create a zero matrix of given dimensions
306    /// Memory efficient: O(1) storage vs O(n*m) for dense matrix
307    ///
308    /// # Examples
309    ///
310    /// ```rust
311    /// use mathhook_core::matrices::Matrix;
312    ///
313    /// let zero = Matrix::zero(2, 3);
314    /// assert_eq!(zero.dimensions(), (2, 3));
315    /// assert!(zero.is_zero());
316    /// ```
317    pub fn zero(rows: usize, cols: usize) -> Self {
318        Matrix::Zero(ZeroMatrixData { rows, cols })
319    }
320
321    /// Create a diagonal matrix from diagonal elements
322    /// Memory efficient: O(n) storage vs O(n²) for dense matrix
323    ///
324    /// # Examples
325    ///
326    /// ```rust
327    /// use mathhook_core::matrices::Matrix;
328    /// use mathhook_core::Expression;
329    ///
330    /// let diag = Matrix::diagonal(vec![
331    ///     Expression::integer(1),
332    ///     Expression::integer(2),
333    ///     Expression::integer(3)
334    /// ]);
335    /// assert_eq!(diag.dimensions(), (3, 3));
336    /// assert!(diag.is_diagonal());
337    /// ```
338    pub fn diagonal(diagonal_elements: Vec<Expression>) -> Self {
339        Matrix::Diagonal(DiagonalMatrixData { diagonal_elements }).optimize()
340    }
341
342    /// Create a scalar matrix (c*I)
343    /// Memory efficient: O(1) storage vs O(n²) for dense matrix
344    ///
345    /// # Examples
346    ///
347    /// ```rust
348    /// use mathhook_core::matrices::Matrix;
349    /// use mathhook_core::Expression;
350    ///
351    /// let scalar = Matrix::scalar(3, Expression::integer(5));
352    /// ```
353    pub fn scalar(size: usize, scalar_value: Expression) -> Self {
354        Matrix::Scalar(ScalarMatrixData { size, scalar_value })
355    }
356
357    /// Create an upper triangular matrix
358    /// Memory efficient: ~50% storage vs dense matrix
359    ///
360    /// # Examples
361    ///
362    /// ```rust
363    /// use mathhook_core::matrices::Matrix;
364    /// use mathhook_core::Expression;
365    ///
366    /// let upper = Matrix::upper_triangular(3, vec![
367    ///     Expression::integer(1), Expression::integer(2), Expression::integer(3),
368    ///     Expression::integer(4), Expression::integer(5),
369    ///     Expression::integer(6)
370    /// ]);
371    /// ```
372    pub fn upper_triangular(size: usize, elements: Vec<Expression>) -> Self {
373        Matrix::UpperTriangular(UpperTriangularMatrixData { size, elements })
374    }
375
376    /// Create a lower triangular matrix
377    /// Memory efficient: ~50% storage vs dense matrix
378    ///
379    /// # Examples
380    ///
381    /// ```rust
382    /// use mathhook_core::matrices::Matrix;
383    /// use mathhook_core::Expression;
384    ///
385    /// let lower = Matrix::lower_triangular(3, vec![
386    ///     Expression::integer(1),
387    ///     Expression::integer(2), Expression::integer(3),
388    ///     Expression::integer(4), Expression::integer(5), Expression::integer(6)
389    /// ]);
390    /// ```
391    pub fn lower_triangular(size: usize, elements: Vec<Expression>) -> Self {
392        Matrix::LowerTriangular(LowerTriangularMatrixData { size, elements })
393    }
394
395    /// Create a symmetric matrix
396    /// Memory efficient: ~50% storage vs dense matrix
397    ///
398    /// # Examples
399    ///
400    /// ```rust
401    /// use mathhook_core::matrices::Matrix;
402    /// use mathhook_core::Expression;
403    ///
404    /// let symmetric = Matrix::symmetric(3, vec![
405    ///     Expression::integer(1), Expression::integer(2), Expression::integer(3),
406    ///     Expression::integer(4), Expression::integer(5),
407    ///     Expression::integer(6)
408    /// ]);
409    /// ```
410    pub fn symmetric(size: usize, elements: Vec<Expression>) -> Self {
411        Matrix::Symmetric(SymmetricMatrixData { size, elements })
412    }
413
414    /// Create a permutation matrix
415    /// Memory efficient: O(n) storage vs O(n²) for dense matrix
416    ///
417    /// # Examples
418    ///
419    /// ```rust
420    /// use mathhook_core::matrices::Matrix;
421    ///
422    /// let perm = Matrix::permutation(vec![2, 0, 1]);
423    /// ```
424    pub fn permutation(permutation: Vec<usize>) -> Self {
425        Matrix::Permutation(PermutationMatrixData { permutation })
426    }
427
428    /// Create matrix from nested arrays (convenience method)
429    ///
430    /// # Examples
431    ///
432    /// ```rust
433    /// use mathhook_core::matrices::Matrix;
434    /// use mathhook_core::Expression;
435    ///
436    /// let matrix = Matrix::from_arrays([
437    ///     [1, 2, 3],
438    ///     [4, 5, 6]
439    /// ]);
440    /// ```
441    pub fn from_arrays<const R: usize, const C: usize>(arrays: [[i64; C]; R]) -> Self {
442        let rows: Vec<Vec<Expression>> = arrays
443            .iter()
444            .map(|row| row.iter().map(|&val| Expression::integer(val)).collect())
445            .collect();
446        Matrix::dense(rows)
447    }
448
449    /// Create matrix from flat vector (row-major order)
450    ///
451    /// # Examples
452    ///
453    /// ```rust
454    /// use mathhook_core::matrices::Matrix;
455    /// use mathhook_core::Expression;
456    ///
457    /// let matrix = Matrix::from_flat(2, 3, &[
458    ///     Expression::integer(1), Expression::integer(2), Expression::integer(3),
459    ///     Expression::integer(4), Expression::integer(5), Expression::integer(6)
460    /// ]);
461    /// ```
462    pub fn from_flat(rows: usize, cols: usize, elements: &[Expression]) -> Self {
463        if elements.len() != rows * cols {
464            return Matrix::zero(rows, cols);
465        }
466
467        let matrix_rows: Vec<Vec<Expression>> =
468            elements.chunks(cols).map(|chunk| chunk.to_vec()).collect();
469
470        Matrix::dense(matrix_rows)
471    }
472}