mathhook_core/core/expression/
matrix_methods.rs

//! Matrix-specific methods for the `Expression` type.
//!
//! This module provides symbolic transpose and inverse operations that respect
//! noncommutativity and apply the order-reversal rules
//! (AB)^T = B^T A^T and (AB)^(-1) = B^(-1) A^(-1).
5
6use std::sync::Arc;
7
8use super::Expression;
9use crate::core::symbol::SymbolType;
10
11impl Expression {
12    /// Compute transpose of matrix expression
13    ///
14    /// For symbolic matrix expressions, this implements proper order reversal
15    /// according to the mathematical rule: (AB)^T = B^T A^T
16    ///
17    /// # Mathematical Rules
18    ///
19    /// - For products: (AB)^T = B^T A^T (order reverses)
20    /// - For sums: (A+B)^T = A^T + B^T (distributes)
21    /// - For scalars: scalar^T = scalar (no change)
22    /// - For matrix symbols: A^T creates transpose function
23    ///
24    /// # Examples
25    ///
26    /// ```rust
27    /// use mathhook_core::{Expression, symbol};
28    ///
29    /// let a = symbol!(A; matrix);
30    /// let b = symbol!(B; matrix);
31    ///
32    /// let product = Expression::mul(vec![
33    ///     Expression::symbol(a.clone()),
34    ///     Expression::symbol(b.clone()),
35    /// ]);
36    ///
37    /// let transposed = product.transpose();
38    /// ```
39    pub fn transpose(&self) -> Expression {
40        match self {
41            Expression::Symbol(s) if s.symbol_type() == SymbolType::Matrix => {
42                Expression::function("transpose", vec![Expression::symbol(s.clone())])
43            }
44
45            Expression::Mul(factors) => {
46                let all_matrices = factors.iter().all(|f| {
47                    matches!(f, Expression::Symbol(s) if s.symbol_type() == SymbolType::Matrix)
48                        || matches!(f, Expression::Matrix(_))
49                });
50
51                if all_matrices && factors.len() > 1 {
52                    let transposed_factors: Vec<Expression> =
53                        factors.iter().rev().map(|f| f.transpose()).collect();
54
55                    Expression::mul(transposed_factors)
56                } else {
57                    Expression::function("transpose", vec![self.clone()])
58                }
59            }
60
61            Expression::Add(terms) => {
62                let transposed_terms: Vec<Expression> =
63                    terms.iter().map(|term| term.transpose()).collect();
64
65                Expression::add(transposed_terms)
66            }
67
68            Expression::Matrix(matrix) => {
69                use crate::matrices::CoreMatrixOps;
70                Expression::Matrix(Arc::new(matrix.transpose()))
71            }
72
73            Expression::Number(_) | Expression::Constant(_) => self.clone(),
74
75            _ => Expression::function("transpose", vec![self.clone()]),
76        }
77    }
78
79    /// Compute inverse of matrix expression
80    ///
81    /// For symbolic matrix expressions, this implements proper order reversal
82    /// according to the mathematical rule: (AB)^(-1) = B^(-1) A^(-1)
83    ///
84    /// # Mathematical Rules
85    ///
86    /// - For products: (AB)^(-1) = B^(-1) A^(-1) (order reverses)
87    /// - For matrix symbols: A^(-1) creates inverse function
88    /// - For identity: I^(-1) = I
89    /// - For scalars: a^(-1) = 1/a (reciprocal)
90    ///
91    /// # Examples
92    ///
93    /// ```rust
94    /// use mathhook_core::{Expression, symbol};
95    ///
96    /// let a = symbol!(A; matrix);
97    /// let b = symbol!(B; matrix);
98    ///
99    /// let product = Expression::mul(vec![
100    ///     Expression::symbol(a.clone()),
101    ///     Expression::symbol(b.clone()),
102    /// ]);
103    ///
104    /// let inverse = product.inverse();
105    /// ```
106    pub fn inverse(&self) -> Expression {
107        match self {
108            Expression::Symbol(s) if s.symbol_type() == SymbolType::Matrix => {
109                Expression::function("inverse", vec![Expression::symbol(s.clone())])
110            }
111
112            Expression::Mul(factors) => {
113                let all_matrices = factors.iter().all(|f| {
114                    matches!(f, Expression::Symbol(s) if s.symbol_type() == SymbolType::Matrix)
115                        || matches!(f, Expression::Matrix(_))
116                });
117
118                if all_matrices && factors.len() > 1 {
119                    let inverse_factors: Vec<Expression> =
120                        factors.iter().rev().map(|f| f.inverse()).collect();
121
122                    Expression::mul(inverse_factors)
123                } else {
124                    Expression::function("inverse", vec![self.clone()])
125                }
126            }
127
128            Expression::Matrix(matrix) => {
129                use crate::matrices::CoreMatrixOps;
130                Expression::Matrix(Arc::new(matrix.inverse()))
131            }
132
133            Expression::Number(_) => Expression::pow(self.clone(), Expression::integer(-1)),
134
135            _ => Expression::function("inverse", vec![self.clone()]),
136        }
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    #[test]
    fn test_transpose_single_matrix_symbol() {
        let a = symbol!(A; matrix);
        let expr = Expression::symbol(a.clone());
        let transposed = expr.transpose();

        match transposed {
            Expression::Function { name, args } => {
                assert_eq!(name.as_ref(), "transpose");
                assert_eq!(args.len(), 1);
                assert_eq!(args[0], Expression::symbol(a));
            }
            _ => panic!("Expected Function expression for transpose"),
        }
    }

    #[test]
    fn test_function_expression_commutativity() {
        use crate::core::commutativity::Commutativity;

        let a = symbol!(A; matrix);
        let a_t = Expression::function("transpose", vec![Expression::symbol(a.clone())]);

        assert_eq!(
            a_t.commutativity(),
            Commutativity::Noncommutative,
            "transpose(A) should be noncommutative since A is a matrix"
        );
    }

    #[test]
    fn test_mul_preserves_noncommutative_function_order() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let a_t = Expression::function("transpose", vec![Expression::symbol(a.clone())]);
        let b_t = Expression::function("transpose", vec![Expression::symbol(b.clone())]);

        // Create product B^T * A^T
        let product = Expression::mul(vec![b_t.clone(), a_t.clone()]);

        // The order should be preserved since these are noncommutative
        match product {
            Expression::Mul(ref factors) => {
                assert_eq!(factors.len(), 2);
                // Check that B^T comes first
                assert_eq!(factors[0], b_t, "Expected B^T to be first");
                assert_eq!(factors[1], a_t, "Expected A^T to be second");
            }
            _ => panic!("Expected Mul expression, got {:?}", product),
        }
    }

    #[test]
    fn test_transpose_product_reverses_order_two_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let product = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ]);

        let transposed_product = product.transpose();

        let a_t = Expression::function("transpose", vec![Expression::symbol(a.clone())]);
        let b_t = Expression::function("transpose", vec![Expression::symbol(b.clone())]);
        let expected = Expression::mul(vec![b_t.clone(), a_t.clone()]);

        assert_eq!(transposed_product, expected);

        // `expected` goes through the same `Expression::mul`, so also pin the
        // factor order directly to prove the reversal actually happened.
        match transposed_product {
            Expression::Mul(ref factors) => {
                assert_eq!(factors.len(), 2);
                assert_eq!(factors[0], b_t, "Expected B^T to be first");
                assert_eq!(factors[1], a_t, "Expected A^T to be second");
            }
            _ => panic!("Expected Mul expression, got {:?}", transposed_product),
        }
    }

    #[test]
    fn test_transpose_product_reverses_order_three_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);

        let product = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(c.clone()),
        ]);

        let transposed_product = product.transpose();

        let a_t = Expression::function("transpose", vec![Expression::symbol(a.clone())]);
        let b_t = Expression::function("transpose", vec![Expression::symbol(b.clone())]);
        let c_t = Expression::function("transpose", vec![Expression::symbol(c.clone())]);
        let expected = Expression::mul(vec![c_t, b_t, a_t]);

        assert_eq!(transposed_product, expected);
    }

    #[test]
    fn test_transpose_sum_distributes() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let sum = Expression::add(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ]);

        let transposed_sum = sum.transpose();

        let a_t = Expression::function("transpose", vec![Expression::symbol(a.clone())]);
        let b_t = Expression::function("transpose", vec![Expression::symbol(b.clone())]);
        let expected = Expression::add(vec![a_t, b_t]);

        assert_eq!(transposed_sum, expected);
    }

    #[test]
    fn test_transpose_scalar_unchanged() {
        let x = Expression::integer(42);
        let transposed = x.transpose();
        assert_eq!(transposed, x);
    }

    #[test]
    fn test_inverse_single_matrix_symbol() {
        let a = symbol!(A; matrix);
        let expr = Expression::symbol(a.clone());
        let inverse = expr.inverse();

        match inverse {
            Expression::Function { name, args } => {
                assert_eq!(name.as_ref(), "inverse");
                assert_eq!(args.len(), 1);
                assert_eq!(args[0], Expression::symbol(a));
            }
            _ => panic!("Expected Function expression for inverse"),
        }
    }

    #[test]
    fn test_inverse_product_reverses_order_two_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let product = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ]);

        let inverse_product = product.inverse();

        let a_inv = Expression::function("inverse", vec![Expression::symbol(a.clone())]);
        let b_inv = Expression::function("inverse", vec![Expression::symbol(b.clone())]);
        let expected = Expression::mul(vec![b_inv, a_inv]);

        assert_eq!(inverse_product, expected);
    }

    // `expected` is built through `Expression::mul`, so this compares against
    // whatever canonical form `mul` produces for the reversed factor list.
    #[test]
    fn test_inverse_product_reverses_order_three_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);

        let product = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(c.clone()),
        ]);

        let inverse_product = product.inverse();

        let a_inv = Expression::function("inverse", vec![Expression::symbol(a.clone())]);
        let b_inv = Expression::function("inverse", vec![Expression::symbol(b.clone())]);
        let c_inv = Expression::function("inverse", vec![Expression::symbol(c.clone())]);
        let expected = Expression::mul(vec![c_inv, b_inv, a_inv]);

        assert_eq!(inverse_product, expected);
    }

    #[test]
    fn test_inverse_scalar_becomes_reciprocal() {
        let x = Expression::integer(5);
        let inverse = x.inverse();
        let expected = Expression::pow(Expression::integer(5), Expression::integer(-1));
        assert_eq!(inverse, expected);
    }

    #[test]
    fn test_transpose_nested_product() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);
        let d = symbol!(D; matrix);

        let ab = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ]);
        let cd = Expression::mul(vec![
            Expression::symbol(c.clone()),
            Expression::symbol(d.clone()),
        ]);

        let product = Expression::mul(vec![ab.clone(), cd.clone()]);
        let transposed = product.transpose();

        let cd_t = cd.transpose();
        let ab_t = ab.transpose();
        let expected = Expression::mul(vec![cd_t, ab_t]);

        assert_eq!(transposed, expected);
    }

    #[test]
    fn test_inverse_nested_product() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);
        let d = symbol!(D; matrix);

        let ab = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ]);
        let cd = Expression::mul(vec![
            Expression::symbol(c.clone()),
            Expression::symbol(d.clone()),
        ]);

        let product = Expression::mul(vec![ab.clone(), cd.clone()]);
        let inverse = product.inverse();

        let cd_inv = cd.inverse();
        let ab_inv = ab.inverse();
        let expected = Expression::mul(vec![cd_inv, ab_inv]);

        assert_eq!(inverse, expected);
    }

    #[test]
    fn test_transpose_concrete_matrix() {
        let matrix = Expression::matrix(vec![
            vec![Expression::integer(1), Expression::integer(2)],
            vec![Expression::integer(3), Expression::integer(4)],
        ]);

        let transposed = matrix.transpose();

        let expected = Expression::matrix(vec![
            vec![Expression::integer(1), Expression::integer(3)],
            vec![Expression::integer(2), Expression::integer(4)],
        ]);

        assert_eq!(transposed, expected);
    }

    // Transpose is an involution ((A^T)^T = A), but the current implementation
    // does not simplify a double transpose; it nests the function calls.
    #[test]
    fn test_transpose_twice_nests_without_simplifying() {
        let a = symbol!(A; matrix);
        let expr = Expression::symbol(a.clone());
        let transposed_once = expr.transpose();
        let transposed_twice = transposed_once.transpose();

        match transposed_twice {
            Expression::Function { name, args } => {
                assert_eq!(name.as_ref(), "transpose");
                assert_eq!(args.len(), 1);
                assert_eq!(args[0], transposed_once);
            }
            _ => panic!("Expected nested transpose function"),
        }
    }

    #[test]
    fn test_symbolic_matrix_operations_combined() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let ab = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ]);

        let ab_t = ab.transpose();
        let ab_inv = ab.inverse();

        let a_t = Expression::function("transpose", vec![Expression::symbol(a.clone())]);
        let b_t = Expression::function("transpose", vec![Expression::symbol(b.clone())]);
        let expected_transpose = Expression::mul(vec![b_t, a_t]);

        let a_inv = Expression::function("inverse", vec![Expression::symbol(a.clone())]);
        let b_inv = Expression::function("inverse", vec![Expression::symbol(b.clone())]);
        let expected_inverse = Expression::mul(vec![b_inv, a_inv]);

        assert_eq!(ab_t, expected_transpose);
        assert_eq!(ab_inv, expected_inverse);
    }
}