mathhook_core/matrices/
operations.rs1use super::{CoreMatrixOps, Matrix};
7use crate::core::Expression;
8use crate::core::Number;
9use crate::simplify::Simplify;
10
11pub trait MatrixOperations {
16 fn matrix_add(&self, other: &Expression) -> Expression;
36
37 fn matrix_subtract(&self, other: &Expression) -> Expression;
39
40 fn matrix_multiply(&self, other: &Expression) -> Expression;
42
43 fn matrix_scalar_multiply(&self, scalar: &Expression) -> Expression;
45
46 fn matrix_determinant(&self) -> Expression;
48
49 fn matrix_transpose(&self) -> Expression;
51
52 fn matrix_inverse(&self) -> Expression;
54
55 fn matrix_trace(&self) -> Expression;
57
58 fn matrix_power(&self, exponent: &Expression) -> Expression;
60
61 fn is_identity_matrix(&self) -> bool;
63
64 fn is_zero_matrix(&self) -> bool;
66
67 fn is_diagonal(&self) -> bool;
69
70 fn simplify_matrix(&self) -> Expression;
72}
73
74impl Expression {
75 pub fn matrix_dimensions(&self) -> Option<(usize, usize)> {
94 match self {
95 Expression::Matrix(matrix) => Some(matrix.dimensions()),
96 _ => None,
97 }
98 }
99
100 pub fn is_matrix(&self) -> bool {
102 matches!(self, Expression::Matrix(_))
103 }
104}
105
106impl MatrixOperations for Expression {
107 fn matrix_add(&self, other: &Expression) -> Expression {
108 match (self, other) {
109 (Expression::Matrix(a), Expression::Matrix(b)) => match a.add(b) {
110 Ok(result_matrix) => Expression::Matrix(Box::new(result_matrix)),
111 Err(_) => Expression::function("undefined", vec![]),
112 },
113 _ => Expression::function("undefined", vec![]),
114 }
115 }
116
117 fn matrix_subtract(&self, other: &Expression) -> Expression {
118 match (self, other) {
119 (Expression::Matrix(a), Expression::Matrix(b)) => {
120 let neg_b = b.scalar_multiply(&Expression::integer(-1));
121 match a.add(&neg_b) {
122 Ok(result_matrix) => Expression::Matrix(Box::new(result_matrix)),
123 Err(_) => Expression::function("undefined", vec![]),
124 }
125 }
126 _ => Expression::function("undefined", vec![]),
127 }
128 }
129
130 fn matrix_multiply(&self, other: &Expression) -> Expression {
131 match (self, other) {
132 (Expression::Matrix(a), Expression::Matrix(b)) => match a.multiply(b) {
133 Ok(result_matrix) => Expression::Matrix(Box::new(result_matrix)),
134 Err(_) => Expression::function("undefined", vec![]),
135 },
136 _ => Expression::function("undefined", vec![]),
137 }
138 }
139
140 fn matrix_scalar_multiply(&self, scalar: &Expression) -> Expression {
141 match self {
142 Expression::Matrix(matrix) => {
143 let result_matrix = matrix.scalar_multiply(scalar);
144 let result = Expression::Matrix(Box::new(result_matrix));
145 result.simplify()
146 }
147 _ => Expression::function("undefined", vec![]),
148 }
149 }
150
151 fn matrix_determinant(&self) -> Expression {
152 match self {
153 Expression::Matrix(matrix) => matrix
154 .determinant()
155 .unwrap_or_else(|_| Expression::function("undefined", vec![])),
156 _ => Expression::function("undefined", vec![]),
157 }
158 }
159
160 fn matrix_transpose(&self) -> Expression {
161 match self {
162 Expression::Matrix(matrix) => {
163 let transposed = matrix.transpose();
164 Expression::Matrix(Box::new(transposed))
165 }
166 _ => Expression::function("undefined", vec![]),
167 }
168 }
169
170 fn matrix_inverse(&self) -> Expression {
171 match self {
172 Expression::Matrix(matrix) => {
173 let inverse = matrix.inverse();
174 Expression::Matrix(Box::new(inverse))
175 }
176 _ => Expression::function("undefined", vec![]),
177 }
178 }
179
180 fn matrix_trace(&self) -> Expression {
181 match self {
182 Expression::Matrix(matrix) => matrix.trace(),
183 _ => Expression::function("undefined", vec![]),
184 }
185 }
186
187 fn matrix_power(&self, exponent: &Expression) -> Expression {
188 if !self.is_matrix() {
189 return Expression::function("undefined", vec![]);
190 }
191
192 if let Expression::Number(Number::Integer(n)) = exponent {
193 if *n < 0 {
194 let inv = self.matrix_inverse();
195 return inv.matrix_power(&Expression::integer(-n));
196 }
197
198 if *n == 0 {
199 if let Some((rows, cols)) = self.matrix_dimensions() {
200 if rows == cols {
201 return Expression::identity_matrix(rows);
202 }
203 }
204 return Expression::function("undefined", vec![]);
205 }
206
207 if *n == 1 {
208 return self.clone();
209 }
210
211 let mut result = self.clone();
212 for _ in 1..*n {
213 result = result.matrix_multiply(self);
214 }
215 result
216 } else {
217 Expression::function("undefined", vec![])
218 }
219 }
220
221 fn is_identity_matrix(&self) -> bool {
222 match self {
223 Expression::Matrix(matrix) => matches!(matrix.as_ref(), Matrix::Identity(_)),
224 _ => false,
225 }
226 }
227
228 fn is_zero_matrix(&self) -> bool {
229 match self {
230 Expression::Matrix(matrix) => matches!(matrix.as_ref(), Matrix::Zero(_)),
231 _ => false,
232 }
233 }
234
235 fn is_diagonal(&self) -> bool {
236 match self {
237 Expression::Matrix(matrix) => {
238 matches!(
239 matrix.as_ref(),
240 Matrix::Diagonal(_) | Matrix::Identity(_) | Matrix::Scalar(_)
241 )
242 }
243 _ => false,
244 }
245 }
246
247 fn simplify_matrix(&self) -> Expression {
248 match self {
249 Expression::Matrix(matrix) => {
250 let optimized = matrix.as_ref().clone().optimize();
251 Expression::Matrix(Box::new(optimized))
252 }
253 _ => self.clone(),
254 }
255 }
256}