mathhook_core/matrices/operations.rs

//! High-level matrix operations for Expression integration
//!
//! This module provides the bridge between the Expression system and the
//! unified matrix system, offering user-friendly matrix operations.
5
6use super::{CoreMatrixOps, Matrix};
7use crate::core::Expression;
8use crate::core::Number;
9use crate::simplify::Simplify;
10
11/// High-level matrix operations trait for Expression
12///
13/// Provides mathematical operations for matrices including addition, multiplication,
14/// transpose, inverse, and other linear algebra operations.
15pub trait MatrixOperations {
16    /// Add two matrices
17    ///
18    /// # Examples
19    ///
20    /// ```rust
21    /// use mathhook_core::Expression;
22    /// use mathhook_core::matrices::operations::MatrixOperations;
23    ///
24    /// let a = Expression::matrix(vec![
25    ///     vec![Expression::integer(1), Expression::integer(2)],
26    ///     vec![Expression::integer(3), Expression::integer(4)]
27    /// ]);
28    /// let b = Expression::matrix(vec![
29    ///     vec![Expression::integer(5), Expression::integer(6)],
30    ///     vec![Expression::integer(7), Expression::integer(8)]
31    /// ]);
32    /// let result = a.matrix_add(&b);
33    /// // Result: [[6, 8], [10, 12]]
34    /// ```
35    fn matrix_add(&self, other: &Expression) -> Expression;
36
37    /// Subtract two matrices
38    fn matrix_subtract(&self, other: &Expression) -> Expression;
39
40    /// Multiply two matrices
41    fn matrix_multiply(&self, other: &Expression) -> Expression;
42
43    /// Multiply matrix by scalar
44    fn matrix_scalar_multiply(&self, scalar: &Expression) -> Expression;
45
46    /// Get matrix determinant
47    fn matrix_determinant(&self) -> Expression;
48
49    /// Get matrix transpose
50    fn matrix_transpose(&self) -> Expression;
51
52    /// Get matrix inverse
53    fn matrix_inverse(&self) -> Expression;
54
55    /// Get matrix trace (sum of diagonal elements)
56    fn matrix_trace(&self) -> Expression;
57
58    /// Raise matrix to a power
59    fn matrix_power(&self, exponent: &Expression) -> Expression;
60
61    /// Check if matrix is identity matrix
62    fn is_identity_matrix(&self) -> bool;
63
64    /// Check if matrix is zero matrix
65    fn is_zero_matrix(&self) -> bool;
66
67    /// Check if matrix is diagonal
68    fn is_diagonal(&self) -> bool;
69
70    /// Simplify matrix expression
71    fn simplify_matrix(&self) -> Expression;
72}
73
74impl Expression {
75    /// Get matrix dimensions for any matrix type
76    ///
77    /// Returns (rows, columns) for all matrix types.
78    ///
79    /// # Examples
80    ///
81    /// ```rust
82    /// use mathhook_core::Expression;
83    ///
84    /// let matrix = Expression::matrix(vec![
85    ///     vec![Expression::integer(1), Expression::integer(2)],
86    ///     vec![Expression::integer(3), Expression::integer(4)]
87    /// ]);
88    /// assert_eq!(matrix.matrix_dimensions(), Some((2, 2)));
89    ///
90    /// let identity = Expression::identity_matrix(3);
91    /// assert_eq!(identity.matrix_dimensions(), Some((3, 3)));
92    /// ```
93    pub fn matrix_dimensions(&self) -> Option<(usize, usize)> {
94        match self {
95            Expression::Matrix(matrix) => Some(matrix.dimensions()),
96            _ => None,
97        }
98    }
99
100    /// Check if expression is any kind of matrix
101    pub fn is_matrix(&self) -> bool {
102        matches!(self, Expression::Matrix(_))
103    }
104}
105
106impl MatrixOperations for Expression {
107    fn matrix_add(&self, other: &Expression) -> Expression {
108        match (self, other) {
109            (Expression::Matrix(a), Expression::Matrix(b)) => match a.add(b) {
110                Ok(result_matrix) => Expression::Matrix(Box::new(result_matrix)),
111                Err(_) => Expression::function("undefined", vec![]),
112            },
113            _ => Expression::function("undefined", vec![]),
114        }
115    }
116
117    fn matrix_subtract(&self, other: &Expression) -> Expression {
118        match (self, other) {
119            (Expression::Matrix(a), Expression::Matrix(b)) => {
120                let neg_b = b.scalar_multiply(&Expression::integer(-1));
121                match a.add(&neg_b) {
122                    Ok(result_matrix) => Expression::Matrix(Box::new(result_matrix)),
123                    Err(_) => Expression::function("undefined", vec![]),
124                }
125            }
126            _ => Expression::function("undefined", vec![]),
127        }
128    }
129
130    fn matrix_multiply(&self, other: &Expression) -> Expression {
131        match (self, other) {
132            (Expression::Matrix(a), Expression::Matrix(b)) => match a.multiply(b) {
133                Ok(result_matrix) => Expression::Matrix(Box::new(result_matrix)),
134                Err(_) => Expression::function("undefined", vec![]),
135            },
136            _ => Expression::function("undefined", vec![]),
137        }
138    }
139
140    fn matrix_scalar_multiply(&self, scalar: &Expression) -> Expression {
141        match self {
142            Expression::Matrix(matrix) => {
143                let result_matrix = matrix.scalar_multiply(scalar);
144                let result = Expression::Matrix(Box::new(result_matrix));
145                result.simplify()
146            }
147            _ => Expression::function("undefined", vec![]),
148        }
149    }
150
151    fn matrix_determinant(&self) -> Expression {
152        match self {
153            Expression::Matrix(matrix) => matrix
154                .determinant()
155                .unwrap_or_else(|_| Expression::function("undefined", vec![])),
156            _ => Expression::function("undefined", vec![]),
157        }
158    }
159
160    fn matrix_transpose(&self) -> Expression {
161        match self {
162            Expression::Matrix(matrix) => {
163                let transposed = matrix.transpose();
164                Expression::Matrix(Box::new(transposed))
165            }
166            _ => Expression::function("undefined", vec![]),
167        }
168    }
169
170    fn matrix_inverse(&self) -> Expression {
171        match self {
172            Expression::Matrix(matrix) => {
173                let inverse = matrix.inverse();
174                Expression::Matrix(Box::new(inverse))
175            }
176            _ => Expression::function("undefined", vec![]),
177        }
178    }
179
180    fn matrix_trace(&self) -> Expression {
181        match self {
182            Expression::Matrix(matrix) => matrix.trace(),
183            _ => Expression::function("undefined", vec![]),
184        }
185    }
186
187    fn matrix_power(&self, exponent: &Expression) -> Expression {
188        if !self.is_matrix() {
189            return Expression::function("undefined", vec![]);
190        }
191
192        if let Expression::Number(Number::Integer(n)) = exponent {
193            if *n < 0 {
194                let inv = self.matrix_inverse();
195                return inv.matrix_power(&Expression::integer(-n));
196            }
197
198            if *n == 0 {
199                if let Some((rows, cols)) = self.matrix_dimensions() {
200                    if rows == cols {
201                        return Expression::identity_matrix(rows);
202                    }
203                }
204                return Expression::function("undefined", vec![]);
205            }
206
207            if *n == 1 {
208                return self.clone();
209            }
210
211            let mut result = self.clone();
212            for _ in 1..*n {
213                result = result.matrix_multiply(self);
214            }
215            result
216        } else {
217            Expression::function("undefined", vec![])
218        }
219    }
220
221    fn is_identity_matrix(&self) -> bool {
222        match self {
223            Expression::Matrix(matrix) => matches!(matrix.as_ref(), Matrix::Identity(_)),
224            _ => false,
225        }
226    }
227
228    fn is_zero_matrix(&self) -> bool {
229        match self {
230            Expression::Matrix(matrix) => matches!(matrix.as_ref(), Matrix::Zero(_)),
231            _ => false,
232        }
233    }
234
235    fn is_diagonal(&self) -> bool {
236        match self {
237            Expression::Matrix(matrix) => {
238                matches!(
239                    matrix.as_ref(),
240                    Matrix::Diagonal(_) | Matrix::Identity(_) | Matrix::Scalar(_)
241                )
242            }
243            _ => false,
244        }
245    }
246
247    fn simplify_matrix(&self) -> Expression {
248        match self {
249            Expression::Matrix(matrix) => {
250                let optimized = matrix.as_ref().clone().optimize();
251                Expression::Matrix(Box::new(optimized))
252            }
253            _ => self.clone(),
254        }
255    }
256}