mathhook_core/pattern/matching/engine/
core.rs1use super::{apply_replacement, match_commutative, PatternMatches};
6use crate::core::Expression;
7use crate::pattern::matching::patterns::Pattern;
8use std::collections::HashMap;
9
10pub trait Matchable {
12 fn matches(&self, pattern: &Pattern) -> Option<PatternMatches>;
51
52 fn replace(&self, pattern: &Pattern, replacement: &Pattern) -> Expression;
106}
107
108impl Matchable for Expression {
109 fn matches(&self, pattern: &Pattern) -> Option<PatternMatches> {
110 let mut bindings = HashMap::new();
111 if match_recursive(self, pattern, &mut bindings) {
112 Some(bindings)
113 } else {
114 None
115 }
116 }
117
118 fn replace(&self, pattern: &Pattern, replacement: &Pattern) -> Expression {
119 if let Some(bindings) = self.matches(pattern) {
120 apply_replacement(replacement, &bindings)
121 } else {
122 match self {
123 Expression::Add(terms) => {
124 let new_terms: Vec<Expression> = terms
125 .iter()
126 .map(|t| t.replace(pattern, replacement))
127 .collect();
128 Expression::Add(Box::new(new_terms))
129 }
130
131 Expression::Mul(factors) => {
132 let new_factors: Vec<Expression> = factors
133 .iter()
134 .map(|f| f.replace(pattern, replacement))
135 .collect();
136 Expression::Mul(Box::new(new_factors))
137 }
138
139 Expression::Pow(base, exp) => {
140 let new_base = base.replace(pattern, replacement);
141 let new_exp = exp.replace(pattern, replacement);
142 Expression::Pow(Box::new(new_base), Box::new(new_exp))
143 }
144
145 Expression::Function { name, args } => {
146 let new_args: Vec<Expression> = args
147 .iter()
148 .map(|a| a.replace(pattern, replacement))
149 .collect();
150 Expression::Function {
151 name: name.clone(),
152 args: Box::new(new_args),
153 }
154 }
155
156 _ => self.clone(),
157 }
158 }
159 }
160}
161
162pub(super) fn match_recursive(
167 expr: &Expression,
168 pattern: &Pattern,
169 bindings: &mut PatternMatches,
170) -> bool {
171 match pattern {
172 Pattern::Wildcard { name, constraints } => {
173 if let Some(constraints) = constraints {
174 if !constraints.is_satisfied_by(expr) {
175 return false;
176 }
177 }
178
179 if let Some(existing) = bindings.get(name) {
180 expr == existing
181 } else {
182 bindings.insert(name.clone(), expr.clone());
183 true
184 }
185 }
186
187 Pattern::Exact(pattern_expr) => expr == pattern_expr,
188
189 Pattern::Add(pattern_terms) => {
190 if let Expression::Add(expr_terms) = expr {
191 match_commutative(expr_terms, pattern_terms, bindings)
192 } else {
193 false
194 }
195 }
196
197 Pattern::Mul(pattern_factors) => {
198 if let Expression::Mul(expr_factors) = expr {
199 match_commutative(expr_factors, pattern_factors, bindings)
200 } else {
201 false
202 }
203 }
204
205 Pattern::Pow(pattern_base, pattern_exp) => {
206 if let Expression::Pow(expr_base, expr_exp) = expr {
207 match_recursive(expr_base, pattern_base, bindings)
208 && match_recursive(expr_exp, pattern_exp, bindings)
209 } else {
210 false
211 }
212 }
213
214 Pattern::Function { name, args } => {
215 if let Expression::Function {
216 name: expr_name,
217 args: expr_args,
218 } = expr
219 {
220 if expr_name != name {
221 return false;
222 }
223
224 if expr_args.len() != args.len() {
225 return false;
226 }
227
228 for (expr_arg, pattern_arg) in expr_args.iter().zip(args.iter()) {
229 if !match_recursive(expr_arg, pattern_arg, bindings) {
230 return false;
231 }
232 }
233
234 true
235 } else {
236 false
237 }
238 }
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pattern::matching::patterns::Pattern;
    use crate::prelude::*;

    /// A bare wildcard matches anything and captures it under its name.
    #[test]
    fn test_wildcard_pattern_matches() {
        let expr = Expression::integer(42);
        let pattern = Pattern::wildcard("x");

        let found = expr.matches(&pattern);
        assert!(found.is_some());

        let captured = found.as_ref().and_then(|b| b.get("x"));
        assert_eq!(captured, Some(&Expression::integer(42)));
    }

    /// An exact pattern matches an equal expression.
    #[test]
    fn test_exact_pattern_matches() {
        let pattern = Pattern::Exact(Expression::integer(42));

        assert!(Expression::integer(42).matches(&pattern).is_some());
    }

    /// An exact pattern rejects a different expression.
    #[test]
    fn test_exact_pattern_no_match() {
        let pattern = Pattern::Exact(Expression::integer(43));

        assert!(Expression::integer(42).matches(&pattern).is_none());
    }

    /// Two distinct wildcards in a sum capture the two terms in either order.
    #[test]
    fn test_addition_pattern() {
        let x = symbol!(x);
        let expr = Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]);

        let pattern = Pattern::Add(vec![Pattern::wildcard("a"), Pattern::wildcard("b")]);

        let found = expr.matches(&pattern);
        assert!(found.is_some());

        let a_val = found.as_ref().and_then(|b| b.get("a"));
        let b_val = found.as_ref().and_then(|b| b.get("b"));

        let sym = Expression::symbol(x.clone());
        let one = Expression::integer(1);

        let in_order = a_val == Some(&sym) && b_val == Some(&one);
        let swapped = a_val == Some(&one) && b_val == Some(&sym);
        assert!(in_order || swapped);
    }

    /// An exact factor plus a wildcard factor captures the remaining factor.
    #[test]
    fn test_multiplication_pattern() {
        let x = symbol!(x);
        let expr = Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]);

        let pattern = Pattern::Mul(vec![
            Pattern::Exact(Expression::integer(2)),
            Pattern::wildcard("x"),
        ]);

        let found = expr.matches(&pattern);
        assert!(found.is_some());

        let captured = found.as_ref().and_then(|b| b.get("x"));
        assert_eq!(captured, Some(&Expression::symbol(x.clone())));
    }

    /// A power pattern with a fixed exponent captures the base.
    #[test]
    fn test_power_pattern() {
        let x = symbol!(x);
        let expr = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));

        let pattern = Pattern::Pow(
            Box::new(Pattern::wildcard("base")),
            Box::new(Pattern::Exact(Expression::integer(2))),
        );

        let found = expr.matches(&pattern);
        assert!(found.is_some());

        let captured = found.as_ref().and_then(|b| b.get("base"));
        assert_eq!(captured, Some(&Expression::symbol(x.clone())));
    }

    /// A function pattern with a matching name captures the argument.
    #[test]
    fn test_function_pattern() {
        let x = symbol!(x);
        let expr = Expression::function("sin".to_string(), vec![Expression::symbol(x.clone())]);

        let pattern = Pattern::Function {
            name: "sin".to_string(),
            args: vec![Pattern::wildcard("arg")],
        };

        let found = expr.matches(&pattern);
        assert!(found.is_some());

        let captured = found.as_ref().and_then(|b| b.get("arg"));
        assert_eq!(captured, Some(&Expression::symbol(x.clone())));
    }

    /// The same wildcard used twice matches when both terms are equal.
    #[test]
    fn test_wildcard_consistency() {
        let x = symbol!(x);
        // Built with the raw variant so the sum keeps both `x` terms as written.
        let expr = Expression::Add(Box::new(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(x.clone()),
        ]));

        let pattern = Pattern::Add(vec![Pattern::wildcard("a"), Pattern::wildcard("a")]);

        let found = expr.matches(&pattern);
        assert!(found.is_some());

        let captured = found.as_ref().and_then(|b| b.get("a"));
        assert_eq!(captured, Some(&Expression::symbol(x.clone())));
    }

    /// The same wildcard used twice fails when the terms differ.
    #[test]
    fn test_wildcard_inconsistency() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr = Expression::add(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);

        let pattern = Pattern::Add(vec![Pattern::wildcard("a"), Pattern::wildcard("a")]);

        assert!(expr.matches(&pattern).is_none());
    }

    /// An excluding wildcard rejects the excluded expression and anything containing it.
    #[test]
    fn test_wildcard_with_exclude() {
        let x = symbol!(x);
        let y = symbol!(y);

        let pattern = Pattern::wildcard_excluding("a", vec![Expression::symbol(x.clone())]);

        let only_x = Expression::symbol(x.clone());
        assert!(only_x.matches(&pattern).is_none());

        let only_y = Expression::symbol(y.clone());
        assert!(only_y.matches(&pattern).is_some());

        let expr_with_x =
            Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]);
        assert!(expr_with_x.matches(&pattern).is_none());
    }

    /// A property-constrained wildcard only matches expressions passing the predicate.
    #[test]
    fn test_wildcard_with_property() {
        fn is_integer(expr: &Expression) -> bool {
            matches!(expr, Expression::Number(_))
        }

        let pattern = Pattern::wildcard_with_properties("n", vec![is_integer]);

        let number = Expression::integer(42);
        assert!(number.matches(&pattern).is_some());

        let x = symbol!(x);
        let symbol_expr = Expression::symbol(x.clone());
        assert!(symbol_expr.matches(&pattern).is_none());
    }
}