mathhook_core/simplify/arithmetic/
matrix_ops.rs1use crate::core::Expression;
7use crate::error::MathError;
8use crate::matrices::CoreMatrixOps;
9
10pub fn try_matrix_add(a: &Expression, b: &Expression) -> Option<Result<Expression, MathError>> {
16 match (a, b) {
17 (Expression::Matrix(ma), Expression::Matrix(mb)) => Some(
18 ma.add(mb)
19 .map(|result| Expression::Matrix(Box::new(result))),
20 ),
21 _ => None,
22 }
23}
24
25pub fn try_matrix_multiply(
31 a: &Expression,
32 b: &Expression,
33) -> Option<Result<Expression, MathError>> {
34 match (a, b) {
35 (Expression::Matrix(ma), Expression::Matrix(mb)) => Some(
36 ma.multiply(mb)
37 .map(|result| Expression::Matrix(Box::new(result))),
38 ),
39 _ => None,
40 }
41}
42
#[cfg(test)]
mod tests {
    // Tests for the matrix fast paths above, plus a smoke test that simplifying
    // matrix sums/products of moderate size completes.
    use super::*;
    use crate::expr;
    use crate::simplify::Simplify;

    #[test]
    fn test_matrix_add_fast_path_compatible() {
        let a = Expression::matrix(vec![vec![expr!(1), expr!(2)], vec![expr!(3), expr!(4)]]);
        let b = Expression::matrix(vec![vec![expr!(5), expr!(6)], vec![expr!(7), expr!(8)]]);

        // The match covers `None` itself, so no separate `is_some()` assert is needed.
        match try_matrix_add(&a, &b) {
            Some(Ok(Expression::Matrix(m))) => {
                let (rows, cols) = m.dimensions();
                assert_eq!(rows, 2);
                assert_eq!(cols, 2);
            }
            Some(Ok(other)) => panic!("Expected matrix result, got {:?}", other),
            Some(Err(err)) => {
                panic!("Expected Ok result for compatible matrices, got {:?}", err)
            }
            None => panic!("Expected Some for matrix operands"),
        }
    }

    #[test]
    fn test_matrix_add_fast_path_incompatible() {
        let a = Expression::matrix(vec![vec![expr!(1), expr!(2)], vec![expr!(3), expr!(4)]]);
        let b = Expression::matrix(vec![vec![expr!(5), expr!(6), expr!(7)]]);

        match try_matrix_add(&a, &b) {
            Some(Err(MathError::DomainError {
                operation, reason, ..
            })) => {
                assert_eq!(operation, "matrix_addition");
                assert!(reason.contains("2x2"), "reason should mention 2x2: {}", reason);
                assert!(reason.contains("1x3"), "reason should mention 1x3: {}", reason);
            }
            other => panic!(
                "Expected DomainError for incompatible dimensions, got {:?}",
                other
            ),
        }
    }

    #[test]
    fn test_matrix_add_fast_path_non_matrix() {
        let a = Expression::matrix(vec![vec![expr!(1), expr!(2)]]);
        let b = expr!(42);

        // Either operand being a non-matrix must disable the fast path.
        assert!(try_matrix_add(&a, &b).is_none());
        assert!(try_matrix_add(&b, &a).is_none());
    }

    #[test]
    fn test_matrix_multiply_fast_path_compatible() {
        let a = Expression::matrix(vec![vec![expr!(1), expr!(2)], vec![expr!(3), expr!(4)]]);
        let b = Expression::matrix(vec![vec![expr!(5), expr!(6)], vec![expr!(7), expr!(8)]]);

        match try_matrix_multiply(&a, &b) {
            Some(Ok(Expression::Matrix(m))) => {
                let (rows, cols) = m.dimensions();
                assert_eq!(rows, 2);
                assert_eq!(cols, 2);
            }
            Some(Ok(other)) => panic!("Expected matrix result, got {:?}", other),
            Some(Err(err)) => {
                panic!("Expected Ok result for compatible matrices, got {:?}", err)
            }
            None => panic!("Expected Some for matrix operands"),
        }
    }

    #[test]
    fn test_matrix_multiply_fast_path_incompatible() {
        // 1x2 times 3x1: inner dimensions 2 and 3 disagree.
        let a = Expression::matrix(vec![vec![expr!(1), expr!(2)]]);
        let b = Expression::matrix(vec![vec![expr!(5)], vec![expr!(6)], vec![expr!(7)]]);

        match try_matrix_multiply(&a, &b) {
            Some(Err(MathError::DomainError {
                operation, reason, ..
            })) => {
                assert_eq!(operation, "matrix_multiplication");
                assert!(reason.contains("1x2"), "reason should mention 1x2: {}", reason);
                assert!(reason.contains("3x1"), "reason should mention 3x1: {}", reason);
                assert!(reason.contains("2 != 3"), "reason should mention 2 != 3: {}", reason);
            }
            other => panic!(
                "Expected DomainError for incompatible dimensions, got {:?}",
                other
            ),
        }
    }

    #[test]
    fn test_matrix_multiply_fast_path_non_matrix() {
        let a = Expression::matrix(vec![vec![expr!(1), expr!(2)]]);
        let b = expr!(x);

        // Either operand being a non-matrix must disable the fast path.
        assert!(try_matrix_multiply(&a, &b).is_none());
        assert!(try_matrix_multiply(&b, &a).is_none());
    }

    #[test]
    fn test_mixed_symbolic_matrix() {
        let a = Expression::matrix(vec![vec![expr!(x), expr!(y)]]);
        let b = Expression::matrix(vec![vec![expr!(2), expr!(3)]]);

        // Symbolic entries must not prevent the fast path; the sum keeps the 1x2 shape.
        match try_matrix_add(&a, &b) {
            Some(Ok(Expression::Matrix(m))) => {
                let (rows, cols) = m.dimensions();
                assert_eq!(rows, 1);
                assert_eq!(cols, 2);
            }
            Some(Ok(other)) => panic!("Expected matrix result, got {:?}", other),
            Some(Err(err)) => panic!("Expected Ok result for symbolic matrices, got {:?}", err),
            None => panic!("Expected Some for matrix operands"),
        }
    }

    // Smoke test: simplifying sums and products of square matrices up to 16x16
    // must complete (no stack overflow). The results are intentionally not inspected.
    #[test]
    fn test_benchmark_operations_no_stack_overflow() {
        // Builds a `size` x `size` matrix filled row-major with 1, 2, ..., size * size.
        fn create_test_matrix(size: usize) -> Expression {
            let rows: Vec<Vec<Expression>> = (0..size)
                .map(|i| {
                    (0..size)
                        .map(|j| Expression::integer((i * size + j + 1) as i64))
                        .collect::<Vec<_>>()
                })
                .collect();
            Expression::matrix(rows)
        }

        for size in [2, 3, 4, 8, 16] {
            let matrix_a = create_test_matrix(size);
            let matrix_b = create_test_matrix(size);

            let _add_result = Expression::add(vec![matrix_a.clone(), matrix_b.clone()]).simplify();
            // Last use of both operands: move them instead of cloning again.
            let _mul_result = Expression::mul(vec![matrix_a, matrix_b]).simplify();
        }
    }
}