mathhook_core/simplify/arithmetic/
matrix_ops.rs

1//! Fast-path matrix simplification helpers
2//!
3//! Provides optimized direct matrix computation paths to avoid unnecessary
4//! expression tree construction during simplification.
5
6use crate::core::Expression;
7use crate::error::MathError;
8use crate::matrices::CoreMatrixOps;
9
10/// Try to simplify A + B for matrices directly without building expression tree
11///
12/// Returns Some(Ok(expr)) if both operands are matrices with compatible dimensions.
13/// Returns Some(Err(e)) if both operands are matrices but dimensions are incompatible.
14/// Returns None if either operand is not a matrix.
15pub fn try_matrix_add(a: &Expression, b: &Expression) -> Option<Result<Expression, MathError>> {
16    match (a, b) {
17        (Expression::Matrix(ma), Expression::Matrix(mb)) => Some(
18            ma.add(mb)
19                .map(|result| Expression::Matrix(Box::new(result))),
20        ),
21        _ => None,
22    }
23}
24
25/// Try to simplify A * B for matrices directly without building expression tree
26///
27/// Returns Some(Ok(expr)) if both operands are matrices with compatible dimensions.
28/// Returns Some(Err(e)) if both operands are matrices but dimensions are incompatible.
29/// Returns None if either operand is not a matrix.
30pub fn try_matrix_multiply(
31    a: &Expression,
32    b: &Expression,
33) -> Option<Result<Expression, MathError>> {
34    match (a, b) {
35        (Expression::Matrix(ma), Expression::Matrix(mb)) => Some(
36            ma.multiply(mb)
37                .map(|result| Expression::Matrix(Box::new(result))),
38        ),
39        _ => None,
40    }
41}
42
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr;
    use crate::simplify::Simplify;

    /// Result type shared by both fast-path helpers under test.
    type FastPathOutcome = Option<Result<Expression, MathError>>;

    /// Builds the 2x2 fixture `[[1, 2], [3, 4]]`.
    fn left_2x2() -> Expression {
        Expression::matrix(vec![vec![expr!(1), expr!(2)], vec![expr!(3), expr!(4)]])
    }

    /// Builds the 2x2 fixture `[[5, 6], [7, 8]]`.
    fn right_2x2() -> Expression {
        Expression::matrix(vec![vec![expr!(5), expr!(6)], vec![expr!(7), expr!(8)]])
    }

    /// Builds the 1x2 fixture `[[1, 2]]`.
    fn row_1x2() -> Expression {
        Expression::matrix(vec![vec![expr!(1), expr!(2)]])
    }

    /// Asserts that the fast path produced a successful 2x2 matrix result.
    fn expect_2x2_matrix(outcome: FastPathOutcome) {
        assert!(outcome.is_some());

        let value = match outcome {
            None => panic!("Expected Some for matrix operands"),
            Some(Err(_)) => panic!("Expected Ok result for compatible matrices"),
            Some(Ok(value)) => value,
        };

        if let Expression::Matrix(m) = value {
            let (rows, cols) = m.dimensions();
            assert_eq!(rows, 2);
            assert_eq!(cols, 2);
        } else {
            panic!("Expected matrix result");
        }
    }

    /// Asserts that the fast path failed with a `DomainError` for the given
    /// operation whose reason mentions every expected fragment.
    fn expect_domain_error(
        outcome: FastPathOutcome,
        expected_operation: &str,
        expected_fragments: &[&str],
    ) {
        assert!(outcome.is_some());

        if let Some(Err(MathError::DomainError {
            operation, reason, ..
        })) = outcome
        {
            assert_eq!(operation, expected_operation);
            for &fragment in expected_fragments {
                assert!(reason.contains(fragment));
            }
        } else {
            panic!("Expected DomainError for incompatible dimensions");
        }
    }

    /// Adding two 2x2 matrices takes the fast path and yields a 2x2 matrix.
    #[test]
    fn test_matrix_add_fast_path_compatible() {
        let outcome = try_matrix_add(&left_2x2(), &right_2x2());
        expect_2x2_matrix(outcome);
    }

    /// Adding a 2x2 matrix to a 1x3 matrix reports a domain error that names
    /// both shapes.
    #[test]
    fn test_matrix_add_fast_path_incompatible() {
        let wide_row = Expression::matrix(vec![vec![expr!(5), expr!(6), expr!(7)]]);

        let outcome = try_matrix_add(&left_2x2(), &wide_row);
        expect_domain_error(outcome, "matrix_addition", &["2x2", "1x3"]);
    }

    /// A matrix plus a plain integer is not handled by the fast path.
    #[test]
    fn test_matrix_add_fast_path_non_matrix() {
        let scalar = expr!(42);

        assert!(try_matrix_add(&row_1x2(), &scalar).is_none());
    }

    /// Multiplying two 2x2 matrices takes the fast path and yields a 2x2 matrix.
    #[test]
    fn test_matrix_multiply_fast_path_compatible() {
        let outcome = try_matrix_multiply(&left_2x2(), &right_2x2());
        expect_2x2_matrix(outcome);
    }

    /// Multiplying a 1x2 matrix by a 3x1 matrix reports a domain error that
    /// names both shapes and the mismatched inner dimensions.
    #[test]
    fn test_matrix_multiply_fast_path_incompatible() {
        let tall_column =
            Expression::matrix(vec![vec![expr!(5)], vec![expr!(6)], vec![expr!(7)]]);

        let outcome = try_matrix_multiply(&row_1x2(), &tall_column);
        expect_domain_error(
            outcome,
            "matrix_multiplication",
            &["1x2", "3x1", "2 != 3"],
        );
    }

    /// A matrix times a bare symbol is not handled by the fast path.
    #[test]
    fn test_matrix_multiply_fast_path_non_matrix() {
        let symbol = expr!(x);

        assert!(try_matrix_multiply(&row_1x2(), &symbol).is_none());
    }

    /// Matrices with symbolic entries can still be added via the fast path.
    #[test]
    fn test_mixed_symbolic_matrix() {
        let symbolic = Expression::matrix(vec![vec![expr!(x), expr!(y)]]);
        let numeric = Expression::matrix(vec![vec![expr!(2), expr!(3)]]);

        let outcome = try_matrix_add(&symbolic, &numeric);
        assert!(outcome.is_some());
        assert!(matches!(outcome, Some(Ok(_))));
    }

    /// Quick verification that the benchmark operations don't stack overflow
    #[test]
    fn test_benchmark_operations_no_stack_overflow() {
        /// Builds a `size` x `size` matrix filled row-major with 1, 2, 3, ...
        fn create_test_matrix(size: usize) -> Expression {
            let rows: Vec<Vec<Expression>> = (0..size)
                .map(|i| {
                    (0..size)
                        .map(|j| Expression::integer((i * size + j + 1) as i64))
                        .collect::<Vec<Expression>>()
                })
                .collect();
            Expression::matrix(rows)
        }

        // Same sizes the benchmark uses.
        let benchmark_sizes = [2, 3, 4, 8, 16];

        for size in benchmark_sizes {
            let matrix_a = create_test_matrix(size);
            let matrix_b = create_test_matrix(size);

            // Mirror the benchmark body exactly (including the clones); only
            // completing without a stack overflow matters, so the results are
            // discarded.
            let _add_result = Expression::add(vec![matrix_a.clone(), matrix_b.clone()]).simplify();
            let _mul_result = Expression::mul(vec![matrix_a.clone(), matrix_b.clone()]).simplify();
        }
    }
}