1use std::collections::HashMap;
33use std::sync::Arc;
34
35use crate::UOp;
36use crate::op::pattern_derived::OpKey;
37
38use super::RewriteResult;
39
/// Shared, thread-safe rewrite closure.
///
/// The closure receives the node under consideration and a mutable context of
/// type `C`, and returns a [`RewriteResult`]. It is wrapped in an [`Arc`] so a
/// single closure can be stored under several keys without being duplicated.
pub type PatternClosure<C> = Arc<dyn Fn(&Arc<UOp>, &mut C) -> RewriteResult + Send + Sync>;
46
/// Pattern matcher that dispatches rewrite closures by operation key.
///
/// Closures are either stored under one or more [`OpKey`]s or kept in a
/// wildcard list. The context type `C` defaults to `()`.
pub struct SimplifiedPatternMatcher<C = ()> {
    /// Closures grouped by the `OpKey` they were registered under, in
    /// registration order within each key.
    indexed: HashMap<OpKey, Vec<PatternClosure<C>>>,
    /// Closures registered without any key, in registration order.
    wildcards: Vec<PatternClosure<C>>,
}
100
101impl<C> SimplifiedPatternMatcher<C> {
102 pub fn new() -> Self {
104 Self { indexed: HashMap::new(), wildcards: Vec::new() }
105 }
106
107 pub fn add<F>(&mut self, keys: &[OpKey], closure: F)
112 where
113 F: Fn(&Arc<UOp>, &mut C) -> RewriteResult + Send + Sync + 'static,
114 {
115 if keys.is_empty() {
116 self.wildcards.push(Arc::new(closure));
118 } else if keys.len() == 1 {
119 self.indexed.entry(keys[0].clone()).or_default().push(Arc::new(closure));
121 } else {
122 let shared: PatternClosure<C> = Arc::new(closure);
124 for key in keys {
125 self.indexed.entry(key.clone()).or_default().push(Arc::clone(&shared));
126 }
127 }
128 }
129
130 pub fn add_wildcard<F>(&mut self, closure: F)
134 where
135 F: Fn(&Arc<UOp>, &mut C) -> RewriteResult + Send + Sync + 'static,
136 {
137 self.wildcards.push(Arc::new(closure));
138 }
139
140 pub fn len(&self) -> usize {
142 self.indexed.values().map(|v| v.len()).sum::<usize>() + self.wildcards.len()
143 }
144
145 pub fn is_empty(&self) -> bool {
147 self.indexed.is_empty() && self.wildcards.is_empty()
148 }
149
150 pub fn wildcard_count(&self) -> usize {
152 self.wildcards.len()
153 }
154
155 pub fn indexed_count(&self) -> usize {
157 self.indexed.len()
158 }
159
160 pub fn rewrite(&self, uop: &Arc<UOp>, ctx: &mut C) -> RewriteResult {
172 let key = OpKey::from_op(uop.op());
173
174 if let Some(patterns) = self.indexed.get(&key) {
176 let pattern_count = patterns.len();
177 tracing::trace!(op_key = ?key, pattern_count, "trying indexed patterns");
178
179 for (idx, closure) in patterns.iter().enumerate() {
180 let result = closure(uop, ctx);
181 if !matches!(result, RewriteResult::NoMatch) {
182 tracing::debug!(op_key = ?key, pattern_idx = idx, "pattern matched");
183 return result;
184 }
185 }
186 }
187
188 if !self.wildcards.is_empty() {
190 tracing::trace!(wildcard_count = self.wildcards.len(), "trying wildcard patterns");
191
192 for (idx, closure) in self.wildcards.iter().enumerate() {
193 let result = closure(uop, ctx);
194 if !matches!(result, RewriteResult::NoMatch) {
195 tracing::debug!(wildcard_idx = idx, "wildcard pattern matched");
196 return result;
197 }
198 }
199 }
200
201 RewriteResult::NoMatch
202 }
203}
204
205impl<C> Clone for SimplifiedPatternMatcher<C> {
206 fn clone(&self) -> Self {
207 Self { indexed: self.indexed.clone(), wildcards: self.wildcards.clone() }
208 }
209}
210
211impl<C> Default for SimplifiedPatternMatcher<C> {
212 fn default() -> Self {
213 Self::new()
214 }
215}
216
217impl SimplifiedPatternMatcher<()> {
218 pub fn with_context<D: 'static + Send + Sync>(&self) -> SimplifiedPatternMatcher<D> {
230 let mut result = SimplifiedPatternMatcher::<D>::new();
231 for (key, closures) in &self.indexed {
232 for closure in closures {
233 let closure = Arc::clone(closure);
234 result
235 .indexed
236 .entry(key.clone())
237 .or_default()
238 .push(Arc::new(move |uop: &Arc<UOp>, _ctx: &mut D| closure(uop, &mut ())));
239 }
240 }
241 for closure in &self.wildcards {
242 let closure = Arc::clone(closure);
243 result.wildcards.push(Arc::new(move |uop: &Arc<UOp>, _ctx: &mut D| closure(uop, &mut ())));
244 }
245 result
246 }
247}
248
249impl<C> super::Matcher<C> for SimplifiedPatternMatcher<C> {
251 fn rewrite(&self, uop: &Arc<UOp>, ctx: &mut C) -> RewriteResult {
252 SimplifiedPatternMatcher::rewrite(self, uop, ctx)
254 }
255}
256
257impl<C> std::ops::Add for SimplifiedPatternMatcher<C> {
259 type Output = Self;
260
261 fn add(mut self, rhs: Self) -> Self::Output {
263 for (key, patterns) in rhs.indexed {
265 self.indexed.entry(key).or_default().extend(patterns);
266 }
267 self.wildcards.extend(rhs.wildcards);
269 self
270 }
271}
272
273impl<C> std::ops::Add for &SimplifiedPatternMatcher<C> {
276 type Output = SimplifiedPatternMatcher<C>;
277
278 fn add(self, rhs: Self) -> Self::Output {
279 self.clone() + rhs.clone()
280 }
281}
282
283impl<C> std::ops::Add<&SimplifiedPatternMatcher<C>> for SimplifiedPatternMatcher<C> {
284 type Output = SimplifiedPatternMatcher<C>;
285
286 fn add(self, rhs: &SimplifiedPatternMatcher<C>) -> Self::Output {
287 self + rhs.clone()
288 }
289}
290
291impl<C> std::ops::Add<SimplifiedPatternMatcher<C>> for &SimplifiedPatternMatcher<C> {
292 type Output = SimplifiedPatternMatcher<C>;
293
294 fn add(self, rhs: SimplifiedPatternMatcher<C>) -> Self::Output {
295 self.clone() + rhs
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::BinaryOp;
    use crate::{ConstValue, Op, UOp};
    use morok_dtype::DType;

    /// Builds an `Int32` constant node holding `v`.
    fn const_int(v: i64) -> Arc<UOp> {
        UOp::const_(DType::Int32, ConstValue::Int(v))
    }

    /// Builds an `Int32` binary node applying `op` to `lhs` and `rhs`.
    fn binary(op: BinaryOp, lhs: Arc<UOp>, rhs: Arc<UOp>) -> Arc<UOp> {
        UOp::new(Op::Binary(op, lhs, rhs), DType::Int32)
    }

    /// A freshly constructed matcher reports itself as empty.
    #[test]
    fn test_empty_matcher() {
        let fresh = SimplifiedPatternMatcher::<()>::new();
        assert!(fresh.is_empty());
        assert_eq!(fresh.len(), 0);
    }

    /// Registering one keyed closure makes the matcher non-empty with length 1.
    #[test]
    fn test_add_indexed_pattern() {
        let mut matcher = SimplifiedPatternMatcher::<()>::new();
        let add_key = [OpKey::Binary(BinaryOp::Add)];

        matcher.add(&add_key, |_uop, _ctx| RewriteResult::NoMatch);

        assert_eq!(matcher.len(), 1);
        assert!(!matcher.is_empty());
    }

    /// A wildcard closure counts towards both `len` and the wildcard list.
    #[test]
    fn test_add_wildcard_pattern() {
        let mut matcher = SimplifiedPatternMatcher::<()>::new();

        matcher.add_wildcard(|_uop, _ctx| RewriteResult::NoMatch);

        assert_eq!(matcher.len(), 1);
        assert_eq!(matcher.wildcard_count(), 1);
    }

    /// Adding two matchers yields one holding the patterns of both.
    #[test]
    fn test_combine_matchers() {
        let mut add_only = SimplifiedPatternMatcher::<()>::new();
        add_only.add(&[OpKey::Binary(BinaryOp::Add)], |_, _| RewriteResult::NoMatch);

        let mut mul_only = SimplifiedPatternMatcher::<()>::new();
        mul_only.add(&[OpKey::Binary(BinaryOp::Mul)], |_, _| RewriteResult::NoMatch);

        let merged = add_only + mul_only;
        assert_eq!(merged.len(), 2);
    }

    /// An `x + 0` / `0 + x` rule rewrites to `x` and leaves other sums alone.
    #[test]
    fn test_rewrite_basic() {
        let mut matcher = SimplifiedPatternMatcher::<()>::new();

        matcher.add(&[OpKey::Binary(BinaryOp::Add)], |uop, _ctx| match uop.op() {
            // x + 0 => x (checked first, as in the sequential form of the rule)
            Op::Binary(BinaryOp::Add, left, right)
                if matches!(right.op(), Op::Const(cv) if cv.0.is_zero()) =>
            {
                RewriteResult::Rewritten(left.clone())
            }
            // 0 + x => x
            Op::Binary(BinaryOp::Add, left, right)
                if matches!(left.op(), Op::Const(cv) if cv.0.is_zero()) =>
            {
                RewriteResult::Rewritten(right.clone())
            }
            _ => RewriteResult::NoMatch,
        });

        let five = const_int(5);
        let zero = const_int(0);

        // 5 + 0 rewrites to the very same `five` node.
        let five_plus_zero = binary(BinaryOp::Add, five.clone(), zero);
        let right_identity = matcher.rewrite(&five_plus_zero, &mut ());
        assert!(matches!(right_identity, RewriteResult::Rewritten(ref r) if Arc::ptr_eq(r, &five)));

        // 0 + 5 rewrites to the very same `five` node.
        let zero_plus_five = binary(BinaryOp::Add, const_int(0), five.clone());
        let left_identity = matcher.rewrite(&zero_plus_five, &mut ());
        assert!(matches!(left_identity, RewriteResult::Rewritten(ref r) if Arc::ptr_eq(r, &five)));

        // 3 + 4 has no zero operand, so nothing matches.
        let three_plus_four = binary(BinaryOp::Add, const_int(3), const_int(4));
        let untouched = matcher.rewrite(&three_plus_four, &mut ());
        assert!(matches!(untouched, RewriteResult::NoMatch));
    }

    /// When the keyed closure declines, the wildcard closure still gets a turn.
    #[test]
    fn test_wildcard_after_indexed() {
        let mut matcher = SimplifiedPatternMatcher::<()>::new();

        // Keyed closure that never matches.
        matcher.add(&[OpKey::Binary(BinaryOp::Add)], |_uop, _ctx| RewriteResult::NoMatch);

        // Wildcard closure that always rewrites.
        matcher.add_wildcard(|uop, _ctx| RewriteResult::Rewritten(uop.clone()));

        let sum = binary(BinaryOp::Add, const_int(1), const_int(2));

        let outcome = matcher.rewrite(&sum, &mut ());
        assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    }
}