mathhook_core/core/expression/
matrix_methods.rs1use super::Expression;
7use crate::core::symbol::SymbolType;
8
9impl Expression {
10 pub fn transpose(&self) -> Expression {
38 match self {
39 Expression::Symbol(s) if s.symbol_type() == SymbolType::Matrix => {
40 Expression::function("transpose", vec![Expression::symbol(s.clone())])
41 }
42
43 Expression::Mul(factors) => {
44 let all_matrices = factors.iter().all(|f| {
45 matches!(f, Expression::Symbol(s) if s.symbol_type() == SymbolType::Matrix)
46 || matches!(f, Expression::Matrix(_))
47 });
48
49 if all_matrices && factors.len() > 1 {
50 let transposed_factors: Vec<Expression> =
51 factors.iter().rev().map(|f| f.transpose()).collect();
52
53 Expression::mul(transposed_factors)
54 } else {
55 Expression::function("transpose", vec![self.clone()])
56 }
57 }
58
59 Expression::Add(terms) => {
60 let transposed_terms: Vec<Expression> =
61 terms.iter().map(|term| term.transpose()).collect();
62
63 Expression::add(transposed_terms)
64 }
65
66 Expression::Matrix(matrix) => {
67 use crate::matrices::CoreMatrixOps;
68 Expression::Matrix(Box::new(matrix.transpose()))
69 }
70
71 Expression::Number(_) | Expression::Constant(_) => self.clone(),
72
73 _ => Expression::function("transpose", vec![self.clone()]),
74 }
75 }
76
77 pub fn inverse(&self) -> Expression {
105 match self {
106 Expression::Symbol(s) if s.symbol_type() == SymbolType::Matrix => {
107 Expression::function("inverse", vec![Expression::symbol(s.clone())])
108 }
109
110 Expression::Mul(factors) => {
111 let all_matrices = factors.iter().all(|f| {
112 matches!(f, Expression::Symbol(s) if s.symbol_type() == SymbolType::Matrix)
113 || matches!(f, Expression::Matrix(_))
114 });
115
116 if all_matrices && factors.len() > 1 {
117 let inverse_factors: Vec<Expression> =
118 factors.iter().rev().map(|f| f.inverse()).collect();
119
120 Expression::mul(inverse_factors)
121 } else {
122 Expression::function("inverse", vec![self.clone()])
123 }
124 }
125
126 Expression::Matrix(matrix) => {
127 use crate::matrices::CoreMatrixOps;
128 Expression::Matrix(Box::new(matrix.inverse()))
129 }
130
131 Expression::Number(_) => Expression::pow(self.clone(), Expression::integer(-1)),
132
133 _ => Expression::function("inverse", vec![self.clone()]),
134 }
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    #[test]
    fn test_transpose_single_matrix_symbol() {
        let a = symbol!(A; matrix);
        let expr = Expression::symbol(a.clone());
        let transposed = expr.transpose();

        match transposed {
            Expression::Function { name, args } => {
                assert_eq!(name, "transpose");
                assert_eq!(args.len(), 1);
                assert_eq!(args[0], Expression::symbol(a));
            }
            _ => panic!("Expected Function expression for transpose"),
        }
    }

    #[test]
    fn test_function_expression_commutativity() {
        use crate::core::commutativity::Commutativity;

        let a = symbol!(A; matrix);
        let a_t = Expression::function("transpose", vec![Expression::symbol(a.clone())]);

        assert_eq!(
            a_t.commutativity(),
            Commutativity::Noncommutative,
            "transpose(A) should be noncommutative since A is a matrix"
        );
    }

    #[test]
    fn test_mul_preserves_noncommutative_function_order() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let a_t = Expression::function("transpose", vec![Expression::symbol(a.clone())]);
        let b_t = Expression::function("transpose", vec![Expression::symbol(b.clone())]);

        let product = Expression::mul(vec![b_t.clone(), a_t.clone()]);

        match product {
            Expression::Mul(ref factors) => {
                assert_eq!(factors.len(), 2);
                assert_eq!(factors[0], b_t, "Expected B^T to be first");
                assert_eq!(factors[1], a_t, "Expected A^T to be second");
            }
            _ => panic!("Expected Mul expression, got {:?}", product),
        }
    }

    #[test]
    fn test_transpose_product_reverses_order_two_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let product = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ]);

        let transposed_product = product.transpose();

        let a_t = Expression::function("transpose", vec![Expression::symbol(a.clone())]);
        let b_t = Expression::function("transpose", vec![Expression::symbol(b.clone())]);
        let expected = Expression::mul(vec![b_t.clone(), a_t.clone()]);

        assert_eq!(transposed_product, expected);
    }

    #[test]
    fn test_transpose_product_reverses_order_three_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);

        let product = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(c.clone()),
        ]);

        let transposed_product = product.transpose();

        let a_t = Expression::function("transpose", vec![Expression::symbol(a.clone())]);
        let b_t = Expression::function("transpose", vec![Expression::symbol(b.clone())]);
        let c_t = Expression::function("transpose", vec![Expression::symbol(c.clone())]);
        let expected = Expression::mul(vec![c_t, b_t, a_t]);

        assert_eq!(transposed_product, expected);
    }

    #[test]
    fn test_transpose_sum_distributes() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let sum = Expression::add(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ]);

        let transposed_sum = sum.transpose();

        let a_t = Expression::function("transpose", vec![Expression::symbol(a.clone())]);
        let b_t = Expression::function("transpose", vec![Expression::symbol(b.clone())]);
        let expected = Expression::add(vec![a_t, b_t]);

        assert_eq!(transposed_sum, expected);
    }

    #[test]
    fn test_transpose_scalar_unchanged() {
        let x = Expression::integer(42);
        let transposed = x.transpose();
        assert_eq!(transposed, x);
    }

    #[test]
    fn test_inverse_single_matrix_symbol() {
        let a = symbol!(A; matrix);
        let expr = Expression::symbol(a.clone());
        let inverse = expr.inverse();

        match inverse {
            Expression::Function { name, args } => {
                assert_eq!(name, "inverse");
                assert_eq!(args.len(), 1);
                assert_eq!(args[0], Expression::symbol(a));
            }
            _ => panic!("Expected Function expression for inverse"),
        }
    }

    #[test]
    fn test_inverse_product_reverses_order_two_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let product = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ]);

        let inverse_product = product.inverse();

        let a_inv = Expression::function("inverse", vec![Expression::symbol(a.clone())]);
        let b_inv = Expression::function("inverse", vec![Expression::symbol(b.clone())]);
        let expected = Expression::mul(vec![b_inv, a_inv]);

        assert_eq!(inverse_product, expected);
    }

    #[test]
    fn test_inverse_product_reverses_order_three_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);

        let product = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(c.clone()),
        ]);

        let inverse_product = product.inverse();

        let a_inv = Expression::function("inverse", vec![Expression::symbol(a.clone())]);
        let b_inv = Expression::function("inverse", vec![Expression::symbol(b.clone())]);
        let c_inv = Expression::function("inverse", vec![Expression::symbol(c.clone())]);
        let expected = Expression::mul(vec![c_inv, b_inv, a_inv]);

        assert_eq!(inverse_product, expected);
    }

    #[test]
    fn test_inverse_scalar_becomes_reciprocal() {
        let x = Expression::integer(5);
        let inverse = x.inverse();
        let expected = Expression::pow(Expression::integer(5), Expression::integer(-1));
        assert_eq!(inverse, expected);
    }

    /// `(A + B)^-1` is not `A^-1 + B^-1`, so `inverse` must leave a sum unevaluated.
    #[test]
    fn test_inverse_does_not_distribute_over_sum() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let sum = Expression::add(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ]);

        let inverse = sum.inverse();
        let expected = Expression::function("inverse", vec![sum.clone()]);

        assert_eq!(inverse, expected);
    }

    #[test]
    fn test_transpose_nested_product() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);
        let d = symbol!(D; matrix);

        let ab = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ]);
        let cd = Expression::mul(vec![
            Expression::symbol(c.clone()),
            Expression::symbol(d.clone()),
        ]);

        let product = Expression::mul(vec![ab.clone(), cd.clone()]);
        let transposed = product.transpose();

        let cd_t = cd.transpose();
        let ab_t = ab.transpose();
        let expected = Expression::mul(vec![cd_t, ab_t]);

        assert_eq!(transposed, expected);
    }

    #[test]
    fn test_inverse_nested_product() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);
        let d = symbol!(D; matrix);

        let ab = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ]);
        let cd = Expression::mul(vec![
            Expression::symbol(c.clone()),
            Expression::symbol(d.clone()),
        ]);

        let product = Expression::mul(vec![ab.clone(), cd.clone()]);
        let inverse = product.inverse();

        let cd_inv = cd.inverse();
        let ab_inv = ab.inverse();
        let expected = Expression::mul(vec![cd_inv, ab_inv]);

        assert_eq!(inverse, expected);
    }

    #[test]
    fn test_transpose_concrete_matrix() {
        let matrix = Expression::matrix(vec![
            vec![Expression::integer(1), Expression::integer(2)],
            vec![Expression::integer(3), Expression::integer(4)],
        ]);

        let transposed = matrix.transpose();

        let expected = Expression::matrix(vec![
            vec![Expression::integer(1), Expression::integer(3)],
            vec![Expression::integer(2), Expression::integer(4)],
        ]);

        assert_eq!(transposed, expected);
    }

    /// Transposing twice is mathematically the identity (transpose is an
    /// involution), but `transpose` performs no such simplification: the result is
    /// the nested call `transpose(transpose(A))`.
    #[test]
    fn test_double_transpose_is_left_unsimplified() {
        let a = symbol!(A; matrix);
        let expr = Expression::symbol(a.clone());
        let transposed_once = expr.transpose();
        let transposed_twice = transposed_once.clone().transpose();

        match transposed_twice {
            Expression::Function { name, args } => {
                assert_eq!(name, "transpose");
                assert_eq!(args.len(), 1);
                assert_eq!(args[0], transposed_once);
            }
            _ => panic!("Expected nested transpose function"),
        }
    }

    #[test]
    fn test_symbolic_matrix_operations_combined() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let ab = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ]);

        let ab_t = ab.transpose();
        let ab_inv = ab.inverse();

        let a_t = Expression::function("transpose", vec![Expression::symbol(a.clone())]);
        let b_t = Expression::function("transpose", vec![Expression::symbol(b.clone())]);
        let expected_transpose = Expression::mul(vec![b_t, a_t]);

        let a_inv = Expression::function("inverse", vec![Expression::symbol(a.clone())]);
        let b_inv = Expression::function("inverse", vec![Expression::symbol(b.clone())]);
        let expected_inverse = Expression::mul(vec![b_inv, a_inv]);

        assert_eq!(ab_t, expected_transpose);
        assert_eq!(ab_inv, expected_inverse);
    }
}