mathhook_core/simplify/arithmetic/
matrix_ops.rs1use std::sync::Arc;
7
8use crate::core::Expression;
9use crate::error::MathError;
10use crate::matrices::CoreMatrixOps;
11
12pub fn try_matrix_add(a: &Expression, b: &Expression) -> Option<Result<Expression, MathError>> {
18 match (a, b) {
19 (Expression::Matrix(ma), Expression::Matrix(mb)) => Some(
20 ma.add(mb)
21 .map(|result| Expression::Matrix(Arc::new(result))),
22 ),
23 _ => None,
24 }
25}
26
27pub fn try_matrix_multiply(
33 a: &Expression,
34 b: &Expression,
35) -> Option<Result<Expression, MathError>> {
36 match (a, b) {
37 (Expression::Matrix(ma), Expression::Matrix(mb)) => Some(
38 ma.multiply(mb)
39 .map(|result| Expression::Matrix(Arc::new(result))),
40 ),
41 _ => None,
42 }
43}
44
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr;
    use crate::simplify::Simplify;

    /// Two 2x2 matrices: the add helper applies and yields a 2x2 matrix.
    #[test]
    fn test_matrix_add_fast_path_compatible() {
        let lhs = Expression::matrix(vec![vec![expr!(1), expr!(2)], vec![expr!(3), expr!(4)]]);
        let rhs = Expression::matrix(vec![vec![expr!(5), expr!(6)], vec![expr!(7), expr!(8)]]);

        let outcome = try_matrix_add(&lhs, &rhs);
        assert!(outcome.is_some());

        match outcome {
            None => panic!("Expected Some for matrix operands"),
            Some(Err(_)) => panic!("Expected Ok result for compatible matrices"),
            Some(Ok(Expression::Matrix(sum))) => {
                let (row_count, col_count) = sum.dimensions();
                assert_eq!(row_count, 2);
                assert_eq!(col_count, 2);
            }
            Some(Ok(_)) => panic!("Expected matrix result"),
        }
    }

    /// A 2x2 plus a 1x3 matrix: the helper applies but reports a domain error
    /// whose reason mentions both shapes.
    #[test]
    fn test_matrix_add_fast_path_incompatible() {
        let lhs = Expression::matrix(vec![vec![expr!(1), expr!(2)], vec![expr!(3), expr!(4)]]);
        let rhs = Expression::matrix(vec![vec![expr!(5), expr!(6), expr!(7)]]);

        let outcome = try_matrix_add(&lhs, &rhs);
        assert!(outcome.is_some());

        if let Some(Err(MathError::DomainError {
            operation, reason, ..
        })) = outcome
        {
            assert_eq!(operation, "matrix_addition");
            for shape in ["2x2", "1x3"] {
                assert!(reason.contains(shape));
            }
        } else {
            panic!("Expected DomainError for incompatible dimensions");
        }
    }

    /// A matrix plus a plain integer: the add helper declines with `None`.
    #[test]
    fn test_matrix_add_fast_path_non_matrix() {
        let matrix = Expression::matrix(vec![vec![expr!(1), expr!(2)]]);
        let scalar = expr!(42);

        assert!(try_matrix_add(&matrix, &scalar).is_none());
    }

    /// Two 2x2 matrices: the multiply helper applies and yields a 2x2 matrix.
    #[test]
    fn test_matrix_multiply_fast_path_compatible() {
        let lhs = Expression::matrix(vec![vec![expr!(1), expr!(2)], vec![expr!(3), expr!(4)]]);
        let rhs = Expression::matrix(vec![vec![expr!(5), expr!(6)], vec![expr!(7), expr!(8)]]);

        let outcome = try_matrix_multiply(&lhs, &rhs);
        assert!(outcome.is_some());

        match outcome {
            None => panic!("Expected Some for matrix operands"),
            Some(Err(_)) => panic!("Expected Ok result for compatible matrices"),
            Some(Ok(Expression::Matrix(product))) => {
                let (row_count, col_count) = product.dimensions();
                assert_eq!(row_count, 2);
                assert_eq!(col_count, 2);
            }
            Some(Ok(_)) => panic!("Expected matrix result"),
        }
    }

    /// A 1x2 times a 3x1 matrix: the helper applies but reports a domain error
    /// whose reason mentions both shapes and the mismatching inner sizes.
    #[test]
    fn test_matrix_multiply_fast_path_incompatible() {
        let lhs = Expression::matrix(vec![vec![expr!(1), expr!(2)]]);
        let rhs = Expression::matrix(vec![vec![expr!(5)], vec![expr!(6)], vec![expr!(7)]]);

        let outcome = try_matrix_multiply(&lhs, &rhs);
        assert!(outcome.is_some());

        if let Some(Err(MathError::DomainError {
            operation, reason, ..
        })) = outcome
        {
            assert_eq!(operation, "matrix_multiplication");
            for fragment in ["1x2", "3x1", "2 != 3"] {
                assert!(reason.contains(fragment));
            }
        } else {
            panic!("Expected DomainError for incompatible dimensions");
        }
    }

    /// A matrix times a symbol: the multiply helper declines with `None`.
    #[test]
    fn test_matrix_multiply_fast_path_non_matrix() {
        let matrix = Expression::matrix(vec![vec![expr!(1), expr!(2)]]);
        let symbol = expr!(x);

        assert!(try_matrix_multiply(&matrix, &symbol).is_none());
    }

    /// A symbolic 1x2 matrix plus a numeric 1x2 matrix adds successfully.
    #[test]
    fn test_mixed_symbolic_matrix() {
        let symbolic = Expression::matrix(vec![vec![expr!(x), expr!(y)]]);
        let numeric = Expression::matrix(vec![vec![expr!(2), expr!(3)]]);

        let outcome = try_matrix_add(&symbolic, &numeric);
        assert!(outcome.is_some());
        assert!(outcome.unwrap().is_ok());
    }

    /// Simplifies matrix sums and products for several square sizes.
    ///
    /// The results are discarded; the test passes as long as simplification
    /// completes without panicking (per the test name, presumably a guard
    /// against stack overflow in the simplifier).
    #[test]
    fn test_benchmark_operations_no_stack_overflow() {
        /// Builds a `size` x `size` matrix filled row-major with 1, 2, 3, ...
        fn create_test_matrix(size: usize) -> Expression {
            let rows: Vec<Vec<Expression>> = (0..size)
                .map(|i| {
                    (0..size)
                        .map(|j| Expression::integer((i * size + j + 1) as i64))
                        .collect()
                })
                .collect();
            Expression::matrix(rows)
        }

        for size in [2, 3, 4, 8, 16] {
            let left = create_test_matrix(size);
            let right = create_test_matrix(size);

            let sum_operands = vec![left.clone(), right.clone()];
            let _sum = Expression::add(sum_operands).simplify();

            let product_operands = vec![left.clone(), right.clone()];
            let _product = Expression::mul(product_operands).simplify();
        }
    }
}