mathhook_core/simplify/arithmetic/
matrix_ops.rs

//! Fast-path matrix simplification helpers.
//!
//! When both operands of an addition or multiplication are matrices, these helpers compute
//! the result directly, avoiding the construction of an unnecessary intermediate expression
//! tree during simplification.
5
6use std::sync::Arc;
7
8use crate::core::Expression;
9use crate::error::MathError;
10use crate::matrices::CoreMatrixOps;
11
12/// Try to simplify A + B for matrices directly without building expression tree
13///
14/// Returns Some(Ok(expr)) if both operands are matrices with compatible dimensions.
15/// Returns Some(Err(e)) if both operands are matrices but dimensions are incompatible.
16/// Returns None if either operand is not a matrix.
17pub fn try_matrix_add(a: &Expression, b: &Expression) -> Option<Result<Expression, MathError>> {
18    match (a, b) {
19        (Expression::Matrix(ma), Expression::Matrix(mb)) => Some(
20            ma.add(mb)
21                .map(|result| Expression::Matrix(Arc::new(result))),
22        ),
23        _ => None,
24    }
25}
26
27/// Try to simplify A * B for matrices directly without building expression tree
28///
29/// Returns Some(Ok(expr)) if both operands are matrices with compatible dimensions.
30/// Returns Some(Err(e)) if both operands are matrices but dimensions are incompatible.
31/// Returns None if either operand is not a matrix.
32pub fn try_matrix_multiply(
33    a: &Expression,
34    b: &Expression,
35) -> Option<Result<Expression, MathError>> {
36    match (a, b) {
37        (Expression::Matrix(ma), Expression::Matrix(mb)) => Some(
38            ma.multiply(mb)
39                .map(|result| Expression::Matrix(Arc::new(result))),
40        ),
41        _ => None,
42    }
43}
44
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr;
    use crate::simplify::Simplify;

    /// Left-hand 2x2 operand `[[1, 2], [3, 4]]` shared by several tests.
    fn lhs_2x2() -> Expression {
        Expression::matrix(vec![vec![expr!(1), expr!(2)], vec![expr!(3), expr!(4)]])
    }

    /// Right-hand 2x2 operand `[[5, 6], [7, 8]]` shared by the compatible-shape tests.
    fn rhs_2x2() -> Expression {
        Expression::matrix(vec![vec![expr!(5), expr!(6)], vec![expr!(7), expr!(8)]])
    }

    #[test]
    fn test_matrix_add_fast_path_compatible() {
        let outcome = try_matrix_add(&lhs_2x2(), &rhs_2x2());
        assert!(outcome.is_some());

        let sum = match outcome {
            Some(Ok(Expression::Matrix(m))) => m,
            Some(Err(_)) => panic!("Expected Ok result for compatible matrices"),
            None => panic!("Expected Some for matrix operands"),
            _ => panic!("Expected matrix result"),
        };
        assert_eq!(sum.dimensions(), (2, 2));
    }

    #[test]
    fn test_matrix_add_fast_path_incompatible() {
        let row_1x3 = Expression::matrix(vec![vec![expr!(5), expr!(6), expr!(7)]]);

        let outcome = try_matrix_add(&lhs_2x2(), &row_1x3);
        assert!(outcome.is_some());

        // The error must name the operation and report both offending shapes.
        let Some(Err(MathError::DomainError {
            operation, reason, ..
        })) = outcome
        else {
            panic!("Expected DomainError for incompatible dimensions");
        };
        assert_eq!(operation, "matrix_addition");
        assert!(reason.contains("2x2"));
        assert!(reason.contains("1x3"));
    }

    #[test]
    fn test_matrix_add_fast_path_non_matrix() {
        let row = Expression::matrix(vec![vec![expr!(1), expr!(2)]]);
        let scalar = expr!(42);

        assert!(try_matrix_add(&row, &scalar).is_none());
    }

    #[test]
    fn test_matrix_multiply_fast_path_compatible() {
        let outcome = try_matrix_multiply(&lhs_2x2(), &rhs_2x2());
        assert!(outcome.is_some());

        let product = match outcome {
            Some(Ok(Expression::Matrix(m))) => m,
            Some(Err(_)) => panic!("Expected Ok result for compatible matrices"),
            None => panic!("Expected Some for matrix operands"),
            _ => panic!("Expected matrix result"),
        };
        assert_eq!(product.dimensions(), (2, 2));
    }

    #[test]
    fn test_matrix_multiply_fast_path_incompatible() {
        let row_1x2 = Expression::matrix(vec![vec![expr!(1), expr!(2)]]);
        let column_3x1 =
            Expression::matrix(vec![vec![expr!(5)], vec![expr!(6)], vec![expr!(7)]]);

        let outcome = try_matrix_multiply(&row_1x2, &column_3x1);
        assert!(outcome.is_some());

        // Besides both shapes, the message must spell out the inner-dimension mismatch.
        let Some(Err(MathError::DomainError {
            operation, reason, ..
        })) = outcome
        else {
            panic!("Expected DomainError for incompatible dimensions");
        };
        assert_eq!(operation, "matrix_multiplication");
        assert!(reason.contains("1x2"));
        assert!(reason.contains("3x1"));
        assert!(reason.contains("2 != 3"));
    }

    #[test]
    fn test_matrix_multiply_fast_path_non_matrix() {
        let row = Expression::matrix(vec![vec![expr!(1), expr!(2)]]);
        let symbol = expr!(x);

        assert!(try_matrix_multiply(&row, &symbol).is_none());
    }

    #[test]
    fn test_mixed_symbolic_matrix() {
        let symbols = Expression::matrix(vec![vec![expr!(x), expr!(y)]]);
        let numbers = Expression::matrix(vec![vec![expr!(2), expr!(3)]]);

        let outcome = try_matrix_add(&symbols, &numbers);
        assert!(matches!(outcome, Some(Ok(_))));
    }

    /// Smoke test: the operations exercised by the matrix benchmarks must complete without
    /// overflowing the stack.
    #[test]
    fn test_benchmark_operations_no_stack_overflow() {
        // Square `size x size` matrix whose entries are 1, 2, ..., size * size in row-major
        // order.
        fn numbered_square(size: usize) -> Expression {
            let rows: Vec<Vec<Expression>> = (0..size)
                .map(|i| {
                    (0..size)
                        .map(|j| Expression::integer((i * size + j + 1) as i64))
                        .collect()
                })
                .collect();
            Expression::matrix(rows)
        }

        // Same sizes the benchmark uses.
        for size in [2, 3, 4, 8, 16] {
            let lhs = numbered_square(size);
            let rhs = numbered_square(size);

            // Mirror the benchmark: build the sum / product expression and simplify it.
            let _sum = Expression::add(vec![lhs.clone(), rhs.clone()]).simplify();
            let _product = Expression::mul(vec![lhs, rhs]).simplify();
        }
    }
}