mathhook_core/core/expression/
matrix_methods.rs1use std::sync::Arc;
7
8use super::Expression;
9use crate::core::symbol::SymbolType;
10
11impl Expression {
12 pub fn transpose(&self) -> Expression {
40 match self {
41 Expression::Symbol(s) if s.symbol_type() == SymbolType::Matrix => {
42 Expression::function("transpose", vec![Expression::symbol(s.clone())])
43 }
44
45 Expression::Mul(factors) => {
46 let all_matrices = factors.iter().all(|f| {
47 matches!(f, Expression::Symbol(s) if s.symbol_type() == SymbolType::Matrix)
48 || matches!(f, Expression::Matrix(_))
49 });
50
51 if all_matrices && factors.len() > 1 {
52 let transposed_factors: Vec<Expression> =
53 factors.iter().rev().map(|f| f.transpose()).collect();
54
55 Expression::mul(transposed_factors)
56 } else {
57 Expression::function("transpose", vec![self.clone()])
58 }
59 }
60
61 Expression::Add(terms) => {
62 let transposed_terms: Vec<Expression> =
63 terms.iter().map(|term| term.transpose()).collect();
64
65 Expression::add(transposed_terms)
66 }
67
68 Expression::Matrix(matrix) => {
69 use crate::matrices::CoreMatrixOps;
70 Expression::Matrix(Arc::new(matrix.transpose()))
71 }
72
73 Expression::Number(_) | Expression::Constant(_) => self.clone(),
74
75 _ => Expression::function("transpose", vec![self.clone()]),
76 }
77 }
78
79 pub fn inverse(&self) -> Expression {
107 match self {
108 Expression::Symbol(s) if s.symbol_type() == SymbolType::Matrix => {
109 Expression::function("inverse", vec![Expression::symbol(s.clone())])
110 }
111
112 Expression::Mul(factors) => {
113 let all_matrices = factors.iter().all(|f| {
114 matches!(f, Expression::Symbol(s) if s.symbol_type() == SymbolType::Matrix)
115 || matches!(f, Expression::Matrix(_))
116 });
117
118 if all_matrices && factors.len() > 1 {
119 let inverse_factors: Vec<Expression> =
120 factors.iter().rev().map(|f| f.inverse()).collect();
121
122 Expression::mul(inverse_factors)
123 } else {
124 Expression::function("inverse", vec![self.clone()])
125 }
126 }
127
128 Expression::Matrix(matrix) => {
129 use crate::matrices::CoreMatrixOps;
130 Expression::Matrix(Arc::new(matrix.inverse()))
131 }
132
133 Expression::Number(_) => Expression::pow(self.clone(), Expression::integer(-1)),
134
135 _ => Expression::function("inverse", vec![self.clone()]),
136 }
137 }
138}
139
#[cfg(test)]
mod tests {
    //! Tests for the symbolic and concrete `transpose` / `inverse` rules.

    use super::*;
    use crate::symbol;

    /// A lone matrix symbol is wrapped as the unevaluated call `transpose(A)`.
    #[test]
    fn test_transpose_single_matrix_symbol() {
        let a = symbol!(A; matrix);
        let expr = Expression::symbol(a.clone());
        let transposed = expr.transpose();

        match transposed {
            Expression::Function { name, args } => {
                assert_eq!(name.as_ref(), "transpose");
                assert_eq!(args.len(), 1);
                assert_eq!(args[0], Expression::symbol(a));
            }
            _ => panic!("Expected Function expression for transpose"),
        }
    }

    /// `transpose(A)` inherits noncommutativity from its matrix argument.
    #[test]
    fn test_function_expression_commutativity() {
        use crate::core::commutativity::Commutativity;

        let a = symbol!(A; matrix);
        let a_t = Expression::function("transpose", vec![Expression::symbol(a)]);

        assert_eq!(
            a_t.commutativity(),
            Commutativity::Noncommutative,
            "transpose(A) should be noncommutative since A is a matrix"
        );
    }

    /// `Expression::mul` must not reorder noncommutative function factors.
    #[test]
    fn test_mul_preserves_noncommutative_function_order() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let a_t = Expression::function("transpose", vec![Expression::symbol(a)]);
        let b_t = Expression::function("transpose", vec![Expression::symbol(b)]);

        let product = Expression::mul(vec![b_t.clone(), a_t.clone()]);

        match product {
            Expression::Mul(ref factors) => {
                assert_eq!(factors.len(), 2);
                assert_eq!(factors[0], b_t, "Expected B^T to be first");
                assert_eq!(factors[1], a_t, "Expected A^T to be second");
            }
            _ => panic!("Expected Mul expression, got {:?}", product),
        }
    }

    /// `(AB)ᵀ = BᵀAᵀ`.
    #[test]
    fn test_transpose_product_reverses_order_two_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let product = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ]);

        let transposed_product = product.transpose();

        let a_t = Expression::function("transpose", vec![Expression::symbol(a)]);
        let b_t = Expression::function("transpose", vec![Expression::symbol(b)]);
        let expected = Expression::mul(vec![b_t, a_t]);

        assert_eq!(transposed_product, expected);
    }

    /// `(ABC)ᵀ = CᵀBᵀAᵀ`.
    #[test]
    fn test_transpose_product_reverses_order_three_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);

        let product = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(c.clone()),
        ]);

        let transposed_product = product.transpose();

        let a_t = Expression::function("transpose", vec![Expression::symbol(a)]);
        let b_t = Expression::function("transpose", vec![Expression::symbol(b)]);
        let c_t = Expression::function("transpose", vec![Expression::symbol(c)]);
        let expected = Expression::mul(vec![c_t, b_t, a_t]);

        assert_eq!(transposed_product, expected);
    }

    /// `(A + B)ᵀ = Aᵀ + Bᵀ`.
    #[test]
    fn test_transpose_sum_distributes() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let sum = Expression::add(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ]);

        let transposed_sum = sum.transpose();

        let a_t = Expression::function("transpose", vec![Expression::symbol(a)]);
        let b_t = Expression::function("transpose", vec![Expression::symbol(b)]);
        let expected = Expression::add(vec![a_t, b_t]);

        assert_eq!(transposed_sum, expected);
    }

    /// A number is its own transpose.
    #[test]
    fn test_transpose_scalar_unchanged() {
        let x = Expression::integer(42);
        let transposed = x.transpose();
        assert_eq!(transposed, x);
    }

    /// A lone matrix symbol is wrapped as the unevaluated call `inverse(A)`.
    #[test]
    fn test_inverse_single_matrix_symbol() {
        let a = symbol!(A; matrix);
        let expr = Expression::symbol(a.clone());
        let inverse = expr.inverse();

        match inverse {
            Expression::Function { name, args } => {
                assert_eq!(name.as_ref(), "inverse");
                assert_eq!(args.len(), 1);
                assert_eq!(args[0], Expression::symbol(a));
            }
            _ => panic!("Expected Function expression for inverse"),
        }
    }

    /// `(AB)⁻¹ = B⁻¹A⁻¹`.
    #[test]
    fn test_inverse_product_reverses_order_two_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let product = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ]);

        let inverse_product = product.inverse();

        let a_inv = Expression::function("inverse", vec![Expression::symbol(a)]);
        let b_inv = Expression::function("inverse", vec![Expression::symbol(b)]);
        let expected = Expression::mul(vec![b_inv, a_inv]);

        assert_eq!(inverse_product, expected);
    }

    /// `(ABC)⁻¹ = C⁻¹B⁻¹A⁻¹`.
    #[test]
    fn test_inverse_product_reverses_order_three_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);

        let product = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(c.clone()),
        ]);

        let inverse_product = product.inverse();

        let a_inv = Expression::function("inverse", vec![Expression::symbol(a)]);
        let b_inv = Expression::function("inverse", vec![Expression::symbol(b)]);
        let c_inv = Expression::function("inverse", vec![Expression::symbol(c)]);
        let expected = Expression::mul(vec![c_inv, b_inv, a_inv]);

        assert_eq!(inverse_product, expected);
    }

    /// The inverse of a number is its reciprocal `n^-1`.
    #[test]
    fn test_inverse_scalar_becomes_reciprocal() {
        let x = Expression::integer(5);
        let inverse = x.inverse();
        let expected = Expression::pow(Expression::integer(5), Expression::integer(-1));
        assert_eq!(inverse, expected);
    }

    /// `((AB)(CD))ᵀ = (CD)ᵀ(AB)ᵀ`.
    #[test]
    fn test_transpose_nested_product() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);
        let d = symbol!(D; matrix);

        let ab = Expression::mul(vec![Expression::symbol(a), Expression::symbol(b)]);
        let cd = Expression::mul(vec![Expression::symbol(c), Expression::symbol(d)]);

        let product = Expression::mul(vec![ab.clone(), cd.clone()]);
        let transposed = product.transpose();

        let expected = Expression::mul(vec![cd.transpose(), ab.transpose()]);

        assert_eq!(transposed, expected);
    }

    /// `((AB)(CD))⁻¹ = (CD)⁻¹(AB)⁻¹`.
    #[test]
    fn test_inverse_nested_product() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);
        let d = symbol!(D; matrix);

        let ab = Expression::mul(vec![Expression::symbol(a), Expression::symbol(b)]);
        let cd = Expression::mul(vec![Expression::symbol(c), Expression::symbol(d)]);

        let product = Expression::mul(vec![ab.clone(), cd.clone()]);
        let inverse = product.inverse();

        let expected = Expression::mul(vec![cd.inverse(), ab.inverse()]);

        assert_eq!(inverse, expected);
    }

    /// A concrete matrix is transposed eagerly, element by element.
    #[test]
    fn test_transpose_concrete_matrix() {
        let matrix = Expression::matrix(vec![
            vec![Expression::integer(1), Expression::integer(2)],
            vec![Expression::integer(3), Expression::integer(4)],
        ]);

        let transposed = matrix.transpose();

        let expected = Expression::matrix(vec![
            vec![Expression::integer(1), Expression::integer(3)],
            vec![Expression::integer(2), Expression::integer(4)],
        ]);

        assert_eq!(transposed, expected);
    }

    /// Transpose is an involution (`(Aᵀ)ᵀ = A`), but `transpose()` does not
    /// currently simplify a nested transpose. This pins the observed
    /// behaviour: applying it twice yields `transpose(transpose(A))`.
    #[test]
    fn test_double_transpose_stays_nested() {
        let a = symbol!(A; matrix);
        let expr = Expression::symbol(a);
        let transposed_once = expr.transpose();
        let transposed_twice = transposed_once.transpose();

        match transposed_twice {
            Expression::Function { name, args } => {
                assert_eq!(name.as_ref(), "transpose");
                assert_eq!(args.len(), 1);
                assert_eq!(args[0], transposed_once);
            }
            _ => panic!("Expected nested transpose function"),
        }
    }

    /// Transpose and inverse of the same product both reverse factor order.
    #[test]
    fn test_symbolic_matrix_operations_combined() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let ab = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ]);

        let ab_t = ab.transpose();
        let ab_inv = ab.inverse();

        let a_t = Expression::function("transpose", vec![Expression::symbol(a.clone())]);
        let b_t = Expression::function("transpose", vec![Expression::symbol(b.clone())]);
        let expected_transpose = Expression::mul(vec![b_t, a_t]);

        let a_inv = Expression::function("inverse", vec![Expression::symbol(a)]);
        let b_inv = Expression::function("inverse", vec![Expression::symbol(b)]);
        let expected_inverse = Expression::mul(vec![b_inv, a_inv]);

        assert_eq!(ab_t, expected_transpose);
        assert_eq!(ab_inv, expected_inverse);
    }
}