mathhook_core/matrices/unified/
operations.rs

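//! Matrix arithmetic and basic operations for the unified `Matrix` type:
//! trace, determinant, scalar multiplication, addition, multiplication,
//! transpose, and inverse.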
2
3use crate::core::Expression;
4use crate::error::MathError;
5use crate::matrices::types::*;
6use crate::matrices::unified::Matrix;
7use crate::simplify::Simplify;
8
9impl Matrix {
10    /// Get the trace (sum of diagonal elements) efficiently
11    #[inline]
12    pub fn trace(&self) -> Expression {
13        match self {
14            Matrix::Identity(data) => Expression::integer(data.size as i64),
15            Matrix::Zero(_) => Expression::integer(0),
16            Matrix::Scalar(data) => Expression::mul(vec![
17                Expression::integer(data.size as i64),
18                data.scalar_value.clone(),
19            ])
20            .simplify(),
21            Matrix::Diagonal(data) => Expression::add(data.diagonal_elements.clone()).simplify(),
22            _ => {
23                let (rows, _) = self.dimensions();
24                let diagonal_elements: Vec<Expression> =
25                    (0..rows).map(|i| self.get_element(i, i)).collect();
26                Expression::add(diagonal_elements).simplify()
27            }
28        }
29    }
30
31    /// Get the determinant efficiently (for square matrices)
32    ///
33    /// # Returns
34    /// Result containing the determinant expression, or MathError for non-square matrices
35    ///
36    /// # Errors
37    /// Returns DomainError if matrix is not square
38    ///
39    /// # Algorithm
40    /// - Special matrices (Identity, Zero, Scalar, Diagonal): O(1) or O(n)
41    /// - Small matrices (1x1, 2x2): Direct formulas
42    /// - Numeric matrices: NumericMatrix fast-path with O(n³) LU decomposition
43    /// - Larger symbolic matrices (n≥3): LU decomposition O(n³)
44    pub fn determinant(&self) -> Result<Expression, MathError> {
45        let (rows, cols) = self.dimensions();
46
47        if rows != cols {
48            return Err(MathError::DomainError {
49                operation: "determinant".to_string(),
50                value: Expression::function("matrix", vec![]),
51                reason: format!("Determinant requires square matrix, got {}x{}", rows, cols),
52            });
53        }
54
55        match self {
56            Matrix::Identity(_) => return Ok(Expression::integer(1)),
57            Matrix::Zero(_) => return Ok(Expression::integer(0)),
58            Matrix::Scalar(data) => {
59                return Ok(Expression::pow(
60                    data.scalar_value.clone(),
61                    Expression::integer(data.size as i64),
62                )
63                .simplify())
64            }
65            Matrix::Diagonal(data) => {
66                return Ok(Expression::mul(data.diagonal_elements.clone()).simplify())
67            }
68            Matrix::Permutation(_data) => return Ok(Expression::integer(1)),
69            _ => {}
70        }
71
72        if let Some(numeric) = self.as_numeric() {
73            let det = numeric.determinant()?;
74            return Ok(Expression::float(det));
75        }
76
77        if rows == 0 {
78            return Ok(Expression::integer(1));
79        }
80
81        if rows == 1 {
82            return Ok(self.get_element(0, 0));
83        }
84
85        if rows == 2 {
86            let a = self.get_element(0, 0);
87            let b = self.get_element(0, 1);
88            let c = self.get_element(1, 0);
89            let d = self.get_element(1, 1);
90
91            let ad = Expression::mul(vec![a, d]);
92            let bc = Expression::mul(vec![b, c]);
93            return Ok(Expression::add(vec![
94                ad,
95                Expression::mul(vec![Expression::integer(-1), bc]),
96            ])
97            .simplify());
98        }
99
100        Ok(self.determinant_lu())
101    }
102
103    /// Compute determinant using LU decomposition: det(A) = det(L) * det(U) * sign(P)
104    ///
105    /// O(n³) algorithm significantly faster than cofactor expansion O(n!)
106    ///
107    /// For symbolic matrices, falls back to cofactor expansion on small minors
108    fn determinant_lu(&self) -> Expression {
109        let (n, _) = self.dimensions();
110
111        let mut a: Vec<Vec<Expression>> = (0..n)
112            .map(|i| (0..n).map(|j| self.get_element(i, j)).collect())
113            .collect();
114
115        let mut sign = 1i64;
116
117        for k in 0..n {
118            let pivot = ((k + 1)..n).find(|&i| !a[i][k].is_zero_fast()).unwrap_or(k);
119
120            if pivot != k {
121                a.swap(k, pivot);
122                sign = -sign;
123            }
124
125            if a[k][k].is_zero_fast() {
126                return Expression::integer(0);
127            }
128
129            for i in (k + 1)..n {
130                let factor = Expression::mul(vec![
131                    a[i][k].clone(),
132                    Expression::pow(a[k][k].clone(), Expression::integer(-1)),
133                ]);
134
135                let pivot_row: Vec<Expression> = a[k][k..n].to_vec();
136                for (j_offset, pivot_val) in pivot_row.into_iter().enumerate() {
137                    let j = k + j_offset;
138                    let subtraction = Expression::mul(vec![factor.clone(), pivot_val]);
139                    a[i][j] = a[i][j].clone() - subtraction;
140                }
141            }
142        }
143
144        let det_u: Vec<Expression> = (0..n).map(|i| a[i][i].clone()).collect();
145
146        Expression::mul(vec![Expression::integer(sign), Expression::mul(det_u)])
147    }
148
149    /// Scalar multiplication
150    pub fn scalar_multiply(&self, scalar: &Expression) -> Matrix {
151        match self {
152            Matrix::Zero(data) => Matrix::Zero(data.clone()),
153            Matrix::Identity(data) => Matrix::Scalar(ScalarMatrixData {
154                size: data.size,
155                scalar_value: scalar.clone(),
156            }),
157            Matrix::Scalar(data) => Matrix::Scalar(ScalarMatrixData {
158                size: data.size,
159                scalar_value: Expression::mul(vec![scalar.clone(), data.scalar_value.clone()])
160                    .simplify(),
161            }),
162            Matrix::Diagonal(data) => {
163                let scaled_elements: Vec<Expression> = data
164                    .diagonal_elements
165                    .iter()
166                    .map(|elem| Expression::mul(vec![scalar.clone(), elem.clone()]).simplify())
167                    .collect();
168                Matrix::Diagonal(DiagonalMatrixData {
169                    diagonal_elements: scaled_elements,
170                })
171            }
172            _ => {
173                let (rows, cols) = self.dimensions();
174                let mut result_rows = Vec::with_capacity(rows);
175                for i in 0..rows {
176                    let mut row = Vec::with_capacity(cols);
177                    for j in 0..cols {
178                        let elem = self.get_element(i, j);
179                        let product = Expression::mul(vec![scalar.clone(), elem]).simplify();
180                        row.push(product);
181                    }
182                    result_rows.push(row);
183                }
184                Matrix::Dense(MatrixData { rows: result_rows }).optimize()
185            }
186        }
187    }
188}
189
190/// Core matrix operations that work directly on Matrix types
191pub trait CoreMatrixOps {
192    fn add(&self, other: &Matrix) -> Result<Matrix, MathError>;
193    fn multiply(&self, other: &Matrix) -> Result<Matrix, MathError>;
194    fn transpose(&self) -> Matrix;
195    fn inverse(&self) -> Matrix;
196}
197
198impl CoreMatrixOps for Matrix {
199    fn add(&self, other: &Matrix) -> Result<Matrix, MathError> {
200        let (rows1, cols1) = self.dimensions();
201        let (rows2, cols2) = other.dimensions();
202
203        if rows1 != rows2 || cols1 != cols2 {
204            return Err(MathError::DomainError {
205                operation: "matrix_addition".to_string(),
206                value: Expression::function("incompatible_matrices", vec![]),
207                reason: format!(
208                    "Cannot add {}x{} matrix to {}x{} matrix",
209                    rows1, cols1, rows2, cols2
210                ),
211            });
212        }
213
214        let result = match (self, other) {
215            (Matrix::Zero(_), other) => other.clone(),
216            (this, Matrix::Zero(_)) => this.clone(),
217
218            (Matrix::Identity(id), Matrix::Dense(dense))
219            | (Matrix::Dense(dense), Matrix::Identity(id)) => {
220                let mut result_rows = dense.rows.clone();
221                for i in 0..id.size.min(result_rows.len()) {
222                    if let Some(row) = result_rows.get_mut(i) {
223                        if let Some(elem) = row.get_mut(i) {
224                            *elem = Expression::add(vec![elem.clone(), Expression::integer(1)]);
225                        }
226                    }
227                }
228                Matrix::Dense(MatrixData { rows: result_rows })
229            }
230
231            (Matrix::Diagonal(d1), Matrix::Diagonal(d2))
232                if d1.diagonal_elements.len() == d2.diagonal_elements.len() =>
233            {
234                let result_elements: Vec<Expression> = d1
235                    .diagonal_elements
236                    .iter()
237                    .zip(d2.diagonal_elements.iter())
238                    .map(|(a, b)| Expression::add(vec![a.clone(), b.clone()]).simplify())
239                    .collect();
240                Matrix::Diagonal(DiagonalMatrixData {
241                    diagonal_elements: result_elements,
242                })
243            }
244
245            (Matrix::Identity(id), Matrix::Diagonal(diag))
246            | (Matrix::Diagonal(diag), Matrix::Identity(id))
247                if diag.diagonal_elements.len() == id.size =>
248            {
249                let result_elements: Vec<Expression> = diag
250                    .diagonal_elements
251                    .iter()
252                    .map(|elem| {
253                        Expression::add(vec![elem.clone(), Expression::integer(1)]).simplify()
254                    })
255                    .collect();
256                Matrix::Diagonal(DiagonalMatrixData {
257                    diagonal_elements: result_elements,
258                })
259            }
260
261            (Matrix::Scalar(s1), Matrix::Scalar(s2)) if s1.size == s2.size => {
262                Matrix::Scalar(ScalarMatrixData {
263                    size: s1.size,
264                    scalar_value: Expression::add(vec![
265                        s1.scalar_value.clone(),
266                        s2.scalar_value.clone(),
267                    ])
268                    .simplify(),
269                })
270            }
271
272            _ => {
273                let mut result_rows = Vec::with_capacity(rows1);
274                for i in 0..rows1 {
275                    let mut row = Vec::with_capacity(cols1);
276                    for j in 0..cols1 {
277                        let elem1 = self.get_element(i, j);
278                        let elem2 = other.get_element(i, j);
279                        let sum = Expression::add(vec![elem1, elem2]).simplify();
280                        row.push(sum);
281                    }
282                    result_rows.push(row);
283                }
284
285                Matrix::Dense(MatrixData { rows: result_rows }).optimize()
286            }
287        };
288
289        Ok(result)
290    }
291
292    fn multiply(&self, other: &Matrix) -> Result<Matrix, MathError> {
293        let (rows1, cols1) = self.dimensions();
294        let (rows2, cols2) = other.dimensions();
295
296        if cols1 != rows2 {
297            return Err(MathError::DomainError {
298                operation: "matrix_multiplication".to_string(),
299                value: Expression::function("incompatible_matrices", vec![]),
300                reason: format!(
301                    "Cannot multiply {}x{} matrix by {}x{} matrix (inner dimensions {} != {})",
302                    rows1, cols1, rows2, cols2, cols1, rows2
303                ),
304            });
305        }
306
307        let result = match (self, other) {
308            (Matrix::Zero(_), _) => Matrix::Zero(ZeroMatrixData {
309                rows: rows1,
310                cols: cols2,
311            }),
312            (_, Matrix::Zero(_)) => Matrix::Zero(ZeroMatrixData {
313                rows: rows1,
314                cols: cols2,
315            }),
316
317            (Matrix::Identity(_), other) => other.clone(),
318            (this, Matrix::Identity(_)) => this.clone(),
319
320            (Matrix::Diagonal(d1), Matrix::Diagonal(d2))
321                if d1.diagonal_elements.len() == d2.diagonal_elements.len() =>
322            {
323                let result_elements: Vec<Expression> = d1
324                    .diagonal_elements
325                    .iter()
326                    .zip(d2.diagonal_elements.iter())
327                    .map(|(a, b)| Expression::mul(vec![a.clone(), b.clone()]))
328                    .collect();
329                Matrix::Diagonal(DiagonalMatrixData {
330                    diagonal_elements: result_elements,
331                })
332            }
333
334            (Matrix::Scalar(s1), Matrix::Scalar(s2)) if s1.size == s2.size => {
335                let product_scalar =
336                    Expression::mul(vec![s1.scalar_value.clone(), s2.scalar_value.clone()]);
337                Matrix::Scalar(ScalarMatrixData {
338                    size: s1.size,
339                    scalar_value: product_scalar,
340                })
341            }
342            (Matrix::Scalar(s), other) => other.scalar_multiply(&s.scalar_value),
343            (this, Matrix::Scalar(s)) => this.scalar_multiply(&s.scalar_value),
344
345            _ => {
346                if let (Some(num1), Some(num2)) = (self.as_numeric(), other.as_numeric()) {
347                    let result = num1.multiply(&num2)?;
348                    return Ok(result.to_matrix().optimize());
349                }
350
351                let mut result_rows = Vec::with_capacity(rows1);
352                for i in 0..rows1 {
353                    let mut row = Vec::with_capacity(cols2);
354                    for j in 0..cols2 {
355                        let mut sum_terms = Vec::with_capacity(cols1);
356                        for k in 0..cols1 {
357                            let elem1 = self.get_element(i, k);
358                            let elem2 = other.get_element(k, j);
359                            sum_terms.push(Expression::mul(vec![elem1, elem2]));
360                        }
361                        let sum = Expression::add(sum_terms);
362                        row.push(sum);
363                    }
364                    result_rows.push(row);
365                }
366
367                Matrix::Dense(MatrixData { rows: result_rows })
368            }
369        };
370
371        Ok(result)
372    }
373
374    fn transpose(&self) -> Matrix {
375        match self {
376            Matrix::Identity(data) => Matrix::Identity(data.clone()),
377            Matrix::Zero(data) => Matrix::Zero(ZeroMatrixData {
378                rows: data.cols,
379                cols: data.rows,
380            }),
381            Matrix::Scalar(data) => Matrix::Scalar(data.clone()),
382            Matrix::Diagonal(data) => Matrix::Diagonal(data.clone()),
383            Matrix::Symmetric(data) => Matrix::Symmetric(data.clone()),
384            Matrix::UpperTriangular(data) => Matrix::LowerTriangular(LowerTriangularMatrixData {
385                size: data.size,
386                elements: data.elements.clone(),
387            }),
388            Matrix::LowerTriangular(data) => Matrix::UpperTriangular(UpperTriangularMatrixData {
389                size: data.size,
390                elements: data.elements.clone(),
391            }),
392            _ => {
393                let (rows, cols) = self.dimensions();
394                let mut result_rows = Vec::with_capacity(cols);
395                for j in 0..cols {
396                    let mut row = Vec::with_capacity(rows);
397                    for i in 0..rows {
398                        row.push(self.get_element(i, j));
399                    }
400                    result_rows.push(row);
401                }
402                Matrix::Dense(MatrixData { rows: result_rows }).optimize()
403            }
404        }
405    }
406
407    fn inverse(&self) -> Matrix {
408        match self {
409            Matrix::Identity(data) => Matrix::Identity(data.clone()),
410            Matrix::Scalar(data) => {
411                let inverse_scalar =
412                    Expression::pow(data.scalar_value.clone(), Expression::integer(-1)).simplify();
413                Matrix::Scalar(ScalarMatrixData {
414                    size: data.size,
415                    scalar_value: inverse_scalar,
416                })
417            }
418            Matrix::Diagonal(data) => {
419                let inverse_elements: Vec<Expression> = data
420                    .diagonal_elements
421                    .iter()
422                    .map(|elem| Expression::pow(elem.clone(), Expression::integer(-1)).simplify())
423                    .collect();
424                Matrix::Diagonal(DiagonalMatrixData {
425                    diagonal_elements: inverse_elements,
426                })
427            }
428            _ => {
429                if let Some(numeric) = self.as_numeric() {
430                    if let Ok(inv) = numeric.inverse() {
431                        return inv.to_matrix();
432                    }
433                }
434
435                if let Some(inv) = self.inverse_via_lu() {
436                    inv
437                } else {
438                    self.gauss_jordan_inverse()
439                }
440            }
441        }
442    }
443}