mathhook_core/matrices/
operations.rs1use std::sync::Arc;
7
8use super::{CoreMatrixOps, Matrix};
9use crate::core::Expression;
10use crate::core::Number;
11use crate::simplify::Simplify;
12
13pub trait MatrixOperations {
18 fn matrix_add(&self, other: &Expression) -> Expression;
38
39 fn matrix_subtract(&self, other: &Expression) -> Expression;
41
42 fn matrix_multiply(&self, other: &Expression) -> Expression;
44
45 fn matrix_scalar_multiply(&self, scalar: &Expression) -> Expression;
47
48 fn matrix_determinant(&self) -> Expression;
50
51 fn matrix_transpose(&self) -> Expression;
53
54 fn matrix_inverse(&self) -> Expression;
56
57 fn matrix_trace(&self) -> Expression;
59
60 fn matrix_power(&self, exponent: &Expression) -> Expression;
62
63 fn is_identity_matrix(&self) -> bool;
65
66 fn is_zero_matrix(&self) -> bool;
68
69 fn is_diagonal(&self) -> bool;
71
72 fn simplify_matrix(&self) -> Expression;
74}
75
76impl Expression {
77 pub fn matrix_dimensions(&self) -> Option<(usize, usize)> {
96 match self {
97 Expression::Matrix(matrix) => Some(matrix.dimensions()),
98 _ => None,
99 }
100 }
101
102 pub fn is_matrix(&self) -> bool {
104 matches!(self, Expression::Matrix(_))
105 }
106}
107
108impl MatrixOperations for Expression {
109 fn matrix_add(&self, other: &Expression) -> Expression {
110 match (self, other) {
111 (Expression::Matrix(a), Expression::Matrix(b)) => match a.add(b) {
112 Ok(result_matrix) => Expression::Matrix(Arc::new(result_matrix)),
113 Err(_) => Expression::function("undefined", vec![]),
114 },
115 _ => Expression::function("undefined", vec![]),
116 }
117 }
118
119 fn matrix_subtract(&self, other: &Expression) -> Expression {
120 match (self, other) {
121 (Expression::Matrix(a), Expression::Matrix(b)) => {
122 let neg_b = b.scalar_multiply(&Expression::integer(-1));
123 match a.add(&neg_b) {
124 Ok(result_matrix) => Expression::Matrix(Arc::new(result_matrix)),
125 Err(_) => Expression::function("undefined", vec![]),
126 }
127 }
128 _ => Expression::function("undefined", vec![]),
129 }
130 }
131
132 fn matrix_multiply(&self, other: &Expression) -> Expression {
133 match (self, other) {
134 (Expression::Matrix(a), Expression::Matrix(b)) => match a.multiply(b) {
135 Ok(result_matrix) => Expression::Matrix(Arc::new(result_matrix)),
136 Err(_) => Expression::function("undefined", vec![]),
137 },
138 _ => Expression::function("undefined", vec![]),
139 }
140 }
141
142 fn matrix_scalar_multiply(&self, scalar: &Expression) -> Expression {
143 match self {
144 Expression::Matrix(matrix) => {
145 let result_matrix = matrix.scalar_multiply(scalar);
146 let result = Expression::Matrix(Arc::new(result_matrix));
147 result.simplify()
148 }
149 _ => Expression::function("undefined", vec![]),
150 }
151 }
152
153 fn matrix_determinant(&self) -> Expression {
154 match self {
155 Expression::Matrix(matrix) => matrix
156 .determinant()
157 .unwrap_or_else(|_| Expression::function("undefined", vec![])),
158 _ => Expression::function("undefined", vec![]),
159 }
160 }
161
162 fn matrix_transpose(&self) -> Expression {
163 match self {
164 Expression::Matrix(matrix) => {
165 let transposed = matrix.transpose();
166 Expression::Matrix(Arc::new(transposed))
167 }
168 _ => Expression::function("undefined", vec![]),
169 }
170 }
171
172 fn matrix_inverse(&self) -> Expression {
173 match self {
174 Expression::Matrix(matrix) => {
175 let inverse = matrix.inverse();
176 Expression::Matrix(Arc::new(inverse))
177 }
178 _ => Expression::function("undefined", vec![]),
179 }
180 }
181
182 fn matrix_trace(&self) -> Expression {
183 match self {
184 Expression::Matrix(matrix) => matrix.trace(),
185 _ => Expression::function("undefined", vec![]),
186 }
187 }
188
189 fn matrix_power(&self, exponent: &Expression) -> Expression {
190 if !self.is_matrix() {
191 return Expression::function("undefined", vec![]);
192 }
193
194 if let Expression::Number(Number::Integer(n)) = exponent {
195 if *n < 0 {
196 let inv = self.matrix_inverse();
197 return inv.matrix_power(&Expression::integer(-n));
198 }
199
200 if *n == 0 {
201 if let Some((rows, cols)) = self.matrix_dimensions() {
202 if rows == cols {
203 return Expression::identity_matrix(rows);
204 }
205 }
206 return Expression::function("undefined", vec![]);
207 }
208
209 if *n == 1 {
210 return self.clone();
211 }
212
213 let mut result = self.clone();
214 for _ in 1..*n {
215 result = result.matrix_multiply(self);
216 }
217 result
218 } else {
219 Expression::function("undefined", vec![])
220 }
221 }
222
223 fn is_identity_matrix(&self) -> bool {
224 match self {
225 Expression::Matrix(matrix) => matches!(matrix.as_ref(), Matrix::Identity(_)),
226 _ => false,
227 }
228 }
229
230 fn is_zero_matrix(&self) -> bool {
231 match self {
232 Expression::Matrix(matrix) => matches!(matrix.as_ref(), Matrix::Zero(_)),
233 _ => false,
234 }
235 }
236
237 fn is_diagonal(&self) -> bool {
238 match self {
239 Expression::Matrix(matrix) => {
240 matches!(
241 matrix.as_ref(),
242 Matrix::Diagonal(_) | Matrix::Identity(_) | Matrix::Scalar(_)
243 )
244 }
245 _ => false,
246 }
247 }
248
249 fn simplify_matrix(&self) -> Expression {
250 match self {
251 Expression::Matrix(matrix) => {
252 let optimized = matrix.as_ref().clone().optimize();
253 Expression::Matrix(Arc::new(optimized))
254 }
255 _ => self.clone(),
256 }
257 }
258}