mathhook_core/matrices/
operations.rs

1//! High-level matrix operations for Expression integration
2//!
3//! This module provides the bridge between the Expression system and the
4//! unified matrix system, offering user-friendly matrix operations.
5
6use std::sync::Arc;
7
8use super::{CoreMatrixOps, Matrix};
9use crate::core::Expression;
10use crate::core::Number;
11use crate::simplify::Simplify;
12
13/// High-level matrix operations trait for Expression
14///
15/// Provides mathematical operations for matrices including addition, multiplication,
16/// transpose, inverse, and other linear algebra operations.
17pub trait MatrixOperations {
18    /// Add two matrices
19    ///
20    /// # Examples
21    ///
22    /// ```rust
23    /// use mathhook_core::Expression;
24    /// use mathhook_core::matrices::operations::MatrixOperations;
25    ///
26    /// let a = Expression::matrix(vec![
27    ///     vec![Expression::integer(1), Expression::integer(2)],
28    ///     vec![Expression::integer(3), Expression::integer(4)]
29    /// ]);
30    /// let b = Expression::matrix(vec![
31    ///     vec![Expression::integer(5), Expression::integer(6)],
32    ///     vec![Expression::integer(7), Expression::integer(8)]
33    /// ]);
34    /// let result = a.matrix_add(&b);
35    /// // Result: [[6, 8], [10, 12]]
36    /// ```
37    fn matrix_add(&self, other: &Expression) -> Expression;
38
39    /// Subtract two matrices
40    fn matrix_subtract(&self, other: &Expression) -> Expression;
41
42    /// Multiply two matrices
43    fn matrix_multiply(&self, other: &Expression) -> Expression;
44
45    /// Multiply matrix by scalar
46    fn matrix_scalar_multiply(&self, scalar: &Expression) -> Expression;
47
48    /// Get matrix determinant
49    fn matrix_determinant(&self) -> Expression;
50
51    /// Get matrix transpose
52    fn matrix_transpose(&self) -> Expression;
53
54    /// Get matrix inverse
55    fn matrix_inverse(&self) -> Expression;
56
57    /// Get matrix trace (sum of diagonal elements)
58    fn matrix_trace(&self) -> Expression;
59
60    /// Raise matrix to a power
61    fn matrix_power(&self, exponent: &Expression) -> Expression;
62
63    /// Check if matrix is identity matrix
64    fn is_identity_matrix(&self) -> bool;
65
66    /// Check if matrix is zero matrix
67    fn is_zero_matrix(&self) -> bool;
68
69    /// Check if matrix is diagonal
70    fn is_diagonal(&self) -> bool;
71
72    /// Simplify matrix expression
73    fn simplify_matrix(&self) -> Expression;
74}
75
76impl Expression {
77    /// Get matrix dimensions for any matrix type
78    ///
79    /// Returns (rows, columns) for all matrix types.
80    ///
81    /// # Examples
82    ///
83    /// ```rust
84    /// use mathhook_core::Expression;
85    ///
86    /// let matrix = Expression::matrix(vec![
87    ///     vec![Expression::integer(1), Expression::integer(2)],
88    ///     vec![Expression::integer(3), Expression::integer(4)]
89    /// ]);
90    /// assert_eq!(matrix.matrix_dimensions(), Some((2, 2)));
91    ///
92    /// let identity = Expression::identity_matrix(3);
93    /// assert_eq!(identity.matrix_dimensions(), Some((3, 3)));
94    /// ```
95    pub fn matrix_dimensions(&self) -> Option<(usize, usize)> {
96        match self {
97            Expression::Matrix(matrix) => Some(matrix.dimensions()),
98            _ => None,
99        }
100    }
101
102    /// Check if expression is any kind of matrix
103    pub fn is_matrix(&self) -> bool {
104        matches!(self, Expression::Matrix(_))
105    }
106}
107
108impl MatrixOperations for Expression {
109    fn matrix_add(&self, other: &Expression) -> Expression {
110        match (self, other) {
111            (Expression::Matrix(a), Expression::Matrix(b)) => match a.add(b) {
112                Ok(result_matrix) => Expression::Matrix(Arc::new(result_matrix)),
113                Err(_) => Expression::function("undefined", vec![]),
114            },
115            _ => Expression::function("undefined", vec![]),
116        }
117    }
118
119    fn matrix_subtract(&self, other: &Expression) -> Expression {
120        match (self, other) {
121            (Expression::Matrix(a), Expression::Matrix(b)) => {
122                let neg_b = b.scalar_multiply(&Expression::integer(-1));
123                match a.add(&neg_b) {
124                    Ok(result_matrix) => Expression::Matrix(Arc::new(result_matrix)),
125                    Err(_) => Expression::function("undefined", vec![]),
126                }
127            }
128            _ => Expression::function("undefined", vec![]),
129        }
130    }
131
132    fn matrix_multiply(&self, other: &Expression) -> Expression {
133        match (self, other) {
134            (Expression::Matrix(a), Expression::Matrix(b)) => match a.multiply(b) {
135                Ok(result_matrix) => Expression::Matrix(Arc::new(result_matrix)),
136                Err(_) => Expression::function("undefined", vec![]),
137            },
138            _ => Expression::function("undefined", vec![]),
139        }
140    }
141
142    fn matrix_scalar_multiply(&self, scalar: &Expression) -> Expression {
143        match self {
144            Expression::Matrix(matrix) => {
145                let result_matrix = matrix.scalar_multiply(scalar);
146                let result = Expression::Matrix(Arc::new(result_matrix));
147                result.simplify()
148            }
149            _ => Expression::function("undefined", vec![]),
150        }
151    }
152
153    fn matrix_determinant(&self) -> Expression {
154        match self {
155            Expression::Matrix(matrix) => matrix
156                .determinant()
157                .unwrap_or_else(|_| Expression::function("undefined", vec![])),
158            _ => Expression::function("undefined", vec![]),
159        }
160    }
161
162    fn matrix_transpose(&self) -> Expression {
163        match self {
164            Expression::Matrix(matrix) => {
165                let transposed = matrix.transpose();
166                Expression::Matrix(Arc::new(transposed))
167            }
168            _ => Expression::function("undefined", vec![]),
169        }
170    }
171
172    fn matrix_inverse(&self) -> Expression {
173        match self {
174            Expression::Matrix(matrix) => {
175                let inverse = matrix.inverse();
176                Expression::Matrix(Arc::new(inverse))
177            }
178            _ => Expression::function("undefined", vec![]),
179        }
180    }
181
182    fn matrix_trace(&self) -> Expression {
183        match self {
184            Expression::Matrix(matrix) => matrix.trace(),
185            _ => Expression::function("undefined", vec![]),
186        }
187    }
188
189    fn matrix_power(&self, exponent: &Expression) -> Expression {
190        if !self.is_matrix() {
191            return Expression::function("undefined", vec![]);
192        }
193
194        if let Expression::Number(Number::Integer(n)) = exponent {
195            if *n < 0 {
196                let inv = self.matrix_inverse();
197                return inv.matrix_power(&Expression::integer(-n));
198            }
199
200            if *n == 0 {
201                if let Some((rows, cols)) = self.matrix_dimensions() {
202                    if rows == cols {
203                        return Expression::identity_matrix(rows);
204                    }
205                }
206                return Expression::function("undefined", vec![]);
207            }
208
209            if *n == 1 {
210                return self.clone();
211            }
212
213            let mut result = self.clone();
214            for _ in 1..*n {
215                result = result.matrix_multiply(self);
216            }
217            result
218        } else {
219            Expression::function("undefined", vec![])
220        }
221    }
222
223    fn is_identity_matrix(&self) -> bool {
224        match self {
225            Expression::Matrix(matrix) => matches!(matrix.as_ref(), Matrix::Identity(_)),
226            _ => false,
227        }
228    }
229
230    fn is_zero_matrix(&self) -> bool {
231        match self {
232            Expression::Matrix(matrix) => matches!(matrix.as_ref(), Matrix::Zero(_)),
233            _ => false,
234        }
235    }
236
237    fn is_diagonal(&self) -> bool {
238        match self {
239            Expression::Matrix(matrix) => {
240                matches!(
241                    matrix.as_ref(),
242                    Matrix::Diagonal(_) | Matrix::Identity(_) | Matrix::Scalar(_)
243                )
244            }
245            _ => false,
246        }
247    }
248
249    fn simplify_matrix(&self) -> Expression {
250        match self {
251            Expression::Matrix(matrix) => {
252                let optimized = matrix.as_ref().clone().optimize();
253                Expression::Matrix(Arc::new(optimized))
254            }
255            _ => self.clone(),
256        }
257    }
258}