use std::collections::HashMap;
use std::sync::Arc;
use crate::UOp;
use crate::op::pattern_derived::OpKey;
use super::RewriteResult;
/// A type-erased rewrite rule over a uop and a mutable context of type `C`.
///
/// A rule returns `RewriteResult::NoMatch` when it does not apply; any other result is
/// returned by the matcher as that rule's outcome. Rules are stored behind `Arc` so one
/// rule can be shared by several op keys and by cloned/combined matchers without copying
/// the closure, and they are `Send + Sync` so those shared handles may cross threads.
pub type PatternClosure<C> = Arc<dyn Fn(&Arc<UOp>, &mut C) -> RewriteResult + Send + Sync>;
/// Rewrite-pattern matcher that dispatches closures by [`OpKey`].
///
/// Closures registered for specific op keys are only tried for uops with that key;
/// wildcard closures are tried for every uop, after the keyed ones have all failed to
/// match. `C` is the mutable context handed to every closure (defaults to `()`).
pub struct SimplifiedPatternMatcher<C = ()> {
/// Keyed pattern lists, each tried in registration order. The methods in this module
/// only create an entry together with a push/extend, so stored lists are non-empty.
indexed: HashMap<OpKey, Vec<PatternClosure<C>>>,
/// Patterns tried for every uop, in registration order, after the keyed patterns for
/// that uop's key (if any) have all returned `NoMatch`.
wildcards: Vec<PatternClosure<C>>,
}
impl<C> SimplifiedPatternMatcher<C> {
    /// Creates a matcher with no patterns.
    pub fn new() -> Self {
        Self { indexed: HashMap::new(), wildcards: Vec::new() }
    }

    /// Registers `closure` for every op key in `keys`.
    ///
    /// * An empty `keys` slice registers the closure as a wildcard (same as
    ///   [`Self::add_wildcard`]).
    /// * Otherwise the closure is allocated once and an `Arc` handle to it is appended
    ///   to the pattern list of each *distinct* key. Repeated keys are ignored, so a
    ///   closure is never registered twice under the same key.
    ///
    /// Within one key (and within the wildcard list) patterns are tried in registration
    /// order, so earlier registrations take priority.
    pub fn add<F>(&mut self, keys: &[OpKey], closure: F)
    where
        F: Fn(&Arc<UOp>, &mut C) -> RewriteResult + Send + Sync + 'static,
    {
        let shared: PatternClosure<C> = Arc::new(closure);
        if keys.is_empty() {
            self.wildcards.push(shared);
            return;
        }
        for (pos, key) in keys.iter().enumerate() {
            // A duplicate key would make `rewrite` run the same closure twice for one uop
            // and would inflate `len()`; register each key only once.
            if keys[..pos].contains(key) {
                continue;
            }
            self.indexed.entry(key.clone()).or_default().push(Arc::clone(&shared));
        }
    }

    /// Registers `closure` as a wildcard pattern, tried for every uop after any keyed
    /// patterns for that uop's key have failed to match.
    pub fn add_wildcard<F>(&mut self, closure: F)
    where
        F: Fn(&Arc<UOp>, &mut C) -> RewriteResult + Send + Sync + 'static,
    {
        self.wildcards.push(Arc::new(closure));
    }

    /// Total number of registered pattern entries (keyed + wildcard).
    ///
    /// A closure registered under several keys counts once per key.
    pub fn len(&self) -> usize {
        self.indexed.values().map(Vec::len).sum::<usize>() + self.wildcards.len()
    }

    /// Returns `true` if no pattern is registered.
    ///
    /// Checking map emptiness is sufficient because keyed lists are never left empty:
    /// entries are only created together with a push.
    pub fn is_empty(&self) -> bool {
        self.indexed.is_empty() && self.wildcards.is_empty()
    }

    /// Number of wildcard pattern entries.
    pub fn wildcard_count(&self) -> usize {
        self.wildcards.len()
    }

    /// Number of distinct op keys that have at least one keyed pattern.
    ///
    /// Note: this counts *keys*, not pattern entries, so
    /// `indexed_count() + wildcard_count()` generally differs from `len()`.
    pub fn indexed_count(&self) -> usize {
        self.indexed.len()
    }

    /// Applies the first matching pattern to `uop`.
    ///
    /// Patterns registered under `uop`'s [`OpKey`] are tried first, in registration
    /// order; if all return `NoMatch`, the wildcard patterns are tried in registration
    /// order. The first result that is not `RewriteResult::NoMatch` is returned as-is;
    /// otherwise `RewriteResult::NoMatch` is returned.
    ///
    /// `ctx` is handed to every closure that is tried, so closures that end up returning
    /// `NoMatch` may still have modified it.
    pub fn rewrite(&self, uop: &Arc<UOp>, ctx: &mut C) -> RewriteResult {
        let key = OpKey::from_op(uop.op());
        if let Some(patterns) = self.indexed.get(&key) {
            tracing::trace!(
                op_key = ?key,
                pattern_count = patterns.len(),
                "trying indexed patterns"
            );
            if let Some((idx, result)) = Self::first_match(patterns, uop, ctx) {
                tracing::debug!(op_key = ?key, pattern_idx = idx, "pattern matched");
                return result;
            }
        }
        if !self.wildcards.is_empty() {
            tracing::trace!(wildcard_count = self.wildcards.len(), "trying wildcard patterns");
            if let Some((idx, result)) = Self::first_match(&self.wildcards, uop, ctx) {
                tracing::debug!(wildcard_idx = idx, "wildcard pattern matched");
                return result;
            }
        }
        RewriteResult::NoMatch
    }

    /// Runs `patterns` in order and returns the position and result of the first one
    /// whose result is not `RewriteResult::NoMatch`, or `None` if all of them decline.
    fn first_match(
        patterns: &[PatternClosure<C>],
        uop: &Arc<UOp>,
        ctx: &mut C,
    ) -> Option<(usize, RewriteResult)> {
        for (idx, closure) in patterns.iter().enumerate() {
            let result = closure(uop, ctx);
            if !matches!(result, RewriteResult::NoMatch) {
                return Some((idx, result));
            }
        }
        None
    }
}
impl<C> Clone for SimplifiedPatternMatcher<C> {
    /// Copies the pattern tables.
    ///
    /// Only the `Arc` handles are copied, so the clone shares its closures with the
    /// original. Implemented by hand because `#[derive(Clone)]` would add a `C: Clone`
    /// bound that the context type does not otherwise need.
    fn clone(&self) -> Self {
        let indexed = self.indexed.clone();
        let wildcards = self.wildcards.clone();
        Self { indexed, wildcards }
    }
}
impl<C> Default for SimplifiedPatternMatcher<C> {
fn default() -> Self {
Self::new()
}
}
impl SimplifiedPatternMatcher<()> {
    /// Re-targets this context-free matcher to an arbitrary context type `D`.
    ///
    /// Each closure is wrapped in an adapter that ignores the `&mut D` it receives and
    /// calls the original closure with `&mut ()`. The adapters hold `Arc` handles to the
    /// original closures, so no pattern is copied. Every closure keeps its op key
    /// (or wildcard status) and its relative order.
    pub fn with_context<D: 'static + Send + Sync>(&self) -> SimplifiedPatternMatcher<D> {
        let mut adapted = SimplifiedPatternMatcher::<D>::new();

        for (op_key, originals) in self.indexed.iter() {
            for original in originals.iter() {
                let inner = Arc::clone(original);
                let bucket = adapted.indexed.entry(op_key.clone()).or_default();
                bucket.push(Arc::new(move |uop: &Arc<UOp>, _ctx: &mut D| inner(uop, &mut ())));
            }
        }

        for original in self.wildcards.iter() {
            let inner = Arc::clone(original);
            adapted
                .wildcards
                .push(Arc::new(move |uop: &Arc<UOp>, _ctx: &mut D| inner(uop, &mut ())));
        }

        adapted
    }
}
/// Exposes the matcher through the parent module's `Matcher` trait by delegating to the
/// inherent [`SimplifiedPatternMatcher::rewrite`].
impl<C> super::Matcher<C> for SimplifiedPatternMatcher<C> {
// The trait method shares its name with the inherent method. The explicit
// `SimplifiedPatternMatcher::rewrite(self, ..)` path targets the inherent one, which
// makes it clear this is delegation and not a recursive call into this trait method.
fn rewrite(&self, uop: &Arc<UOp>, ctx: &mut C) -> RewriteResult {
SimplifiedPatternMatcher::rewrite(self, uop, ctx)
}
}
impl<C> std::ops::Add for SimplifiedPatternMatcher<C> {
type Output = Self;
fn add(mut self, rhs: Self) -> Self::Output {
for (key, patterns) in rhs.indexed {
self.indexed.entry(key).or_default().extend(patterns);
}
self.wildcards.extend(rhs.wildcards);
self
}
}
impl<C> std::ops::Add for &SimplifiedPatternMatcher<C> {
type Output = SimplifiedPatternMatcher<C>;
fn add(self, rhs: Self) -> Self::Output {
self.clone() + rhs.clone()
}
}
impl<C> std::ops::Add<&SimplifiedPatternMatcher<C>> for SimplifiedPatternMatcher<C> {
type Output = SimplifiedPatternMatcher<C>;
fn add(self, rhs: &SimplifiedPatternMatcher<C>) -> Self::Output {
self + rhs.clone()
}
}
impl<C> std::ops::Add<SimplifiedPatternMatcher<C>> for &SimplifiedPatternMatcher<C> {
type Output = SimplifiedPatternMatcher<C>;
fn add(self, rhs: SimplifiedPatternMatcher<C>) -> Self::Output {
self.clone() + rhs
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::BinaryOp;
    use crate::{ConstValue, Op, UOp};
    use morok_dtype::DType;

    fn const_int(v: i64) -> Arc<UOp> {
        UOp::const_(DType::Int32, ConstValue::Int(v))
    }

    fn binary(op: BinaryOp, lhs: Arc<UOp>, rhs: Arc<UOp>) -> Arc<UOp> {
        UOp::new(Op::Binary(op, lhs, rhs), DType::Int32)
    }

    #[test]
    fn test_empty_matcher() {
        let matcher = SimplifiedPatternMatcher::<()>::new();
        assert!(matcher.is_empty());
        assert_eq!(matcher.len(), 0);
    }

    #[test]
    fn test_add_indexed_pattern() {
        let mut matcher = SimplifiedPatternMatcher::<()>::new();
        matcher.add(&[OpKey::Binary(BinaryOp::Add)], |_uop, _ctx| RewriteResult::NoMatch);
        assert_eq!(matcher.len(), 1);
        assert!(!matcher.is_empty());
    }

    #[test]
    fn test_add_wildcard_pattern() {
        let mut matcher = SimplifiedPatternMatcher::<()>::new();
        matcher.add_wildcard(|_uop, _ctx| RewriteResult::NoMatch);
        assert_eq!(matcher.len(), 1);
        assert_eq!(matcher.wildcards.len(), 1);
    }

    #[test]
    fn test_combine_matchers() {
        let mut m1 = SimplifiedPatternMatcher::<()>::new();
        m1.add(&[OpKey::Binary(BinaryOp::Add)], |_, _| RewriteResult::NoMatch);
        let mut m2 = SimplifiedPatternMatcher::<()>::new();
        m2.add(&[OpKey::Binary(BinaryOp::Mul)], |_, _| RewriteResult::NoMatch);
        let combined = m1 + m2;
        assert_eq!(combined.len(), 2);
    }

    #[test]
    fn test_rewrite_basic() {
        let mut matcher = SimplifiedPatternMatcher::<()>::new();
        matcher.add(&[OpKey::Binary(BinaryOp::Add)], |uop, _ctx| {
            let Op::Binary(BinaryOp::Add, left, right) = uop.op() else {
                return RewriteResult::NoMatch;
            };
            if let Op::Const(cv) = right.op()
                && cv.0.is_zero()
            {
                return RewriteResult::Rewritten(left.clone());
            }
            if let Op::Const(cv) = left.op()
                && cv.0.is_zero()
            {
                return RewriteResult::Rewritten(right.clone());
            }
            RewriteResult::NoMatch
        });
        let five = const_int(5);
        let zero = const_int(0);
        let expr = binary(BinaryOp::Add, five.clone(), zero);
        let result = matcher.rewrite(&expr, &mut ());
        assert!(matches!(result, RewriteResult::Rewritten(ref r) if Arc::ptr_eq(r, &five)));
        let expr2 = binary(BinaryOp::Add, const_int(0), five.clone());
        let result2 = matcher.rewrite(&expr2, &mut ());
        assert!(matches!(result2, RewriteResult::Rewritten(ref r) if Arc::ptr_eq(r, &five)));
        let expr3 = binary(BinaryOp::Add, const_int(3), const_int(4));
        let result3 = matcher.rewrite(&expr3, &mut ());
        assert!(matches!(result3, RewriteResult::NoMatch));
    }

    #[test]
    fn test_wildcard_after_indexed() {
        let mut matcher = SimplifiedPatternMatcher::<()>::new();
        matcher.add(&[OpKey::Binary(BinaryOp::Add)], |_uop, _ctx| RewriteResult::NoMatch);
        matcher.add_wildcard(|uop, _ctx| RewriteResult::Rewritten(uop.clone()));
        let expr = binary(BinaryOp::Add, const_int(1), const_int(2));
        let result = matcher.rewrite(&expr, &mut ());
        assert!(matches!(result, RewriteResult::Rewritten(_)));
    }

    /// A closure registered under several keys fires for each of them; `indexed_count`
    /// counts keys while `len` counts per-key entries.
    #[test]
    fn test_multi_key_pattern_registered_per_key() {
        let mut matcher = SimplifiedPatternMatcher::<()>::new();
        matcher.add(&[OpKey::Binary(BinaryOp::Add), OpKey::Binary(BinaryOp::Mul)], |uop, _ctx| {
            RewriteResult::Rewritten(uop.clone())
        });
        assert_eq!(matcher.indexed_count(), 2);
        assert_eq!(matcher.wildcard_count(), 0);
        assert_eq!(matcher.len(), 2);

        let add_expr = binary(BinaryOp::Add, const_int(1), const_int(2));
        let mul_expr = binary(BinaryOp::Mul, const_int(1), const_int(2));
        assert!(matches!(matcher.rewrite(&add_expr, &mut ()), RewriteResult::Rewritten(_)));
        assert!(matches!(matcher.rewrite(&mul_expr, &mut ()), RewriteResult::Rewritten(_)));
        // An unregistered key with no wildcards falls through to NoMatch.
        assert!(matches!(matcher.rewrite(&const_int(3), &mut ()), RewriteResult::NoMatch));
    }

    /// `with_context` keeps every pattern and ignores (does not touch) the new context.
    #[test]
    fn test_with_context_adapts_patterns() {
        let mut base = SimplifiedPatternMatcher::<()>::new();
        base.add(&[OpKey::Binary(BinaryOp::Add)], |uop, _ctx| RewriteResult::Rewritten(uop.clone()));
        base.add_wildcard(|_uop, _ctx| RewriteResult::NoMatch);

        let adapted = base.with_context::<u32>();
        assert_eq!(adapted.len(), base.len());
        assert_eq!(adapted.indexed_count(), base.indexed_count());
        assert_eq!(adapted.wildcard_count(), base.wildcard_count());

        let expr = binary(BinaryOp::Add, const_int(1), const_int(2));
        let mut ctx = 7u32;
        let result = adapted.rewrite(&expr, &mut ctx);
        assert!(matches!(result, RewriteResult::Rewritten(ref r) if Arc::ptr_eq(r, &expr)));
        assert_eq!(ctx, 7);
    }

    /// Every `+` flavour puts the left operand's patterns first and leaves borrowed
    /// operands intact.
    #[test]
    fn test_combine_keeps_lhs_priority() {
        let mut lhs = SimplifiedPatternMatcher::<()>::new();
        lhs.add(&[OpKey::Binary(BinaryOp::Add)], |uop, _ctx| match uop.op() {
            Op::Binary(_, left, _) => RewriteResult::Rewritten(left.clone()),
            _ => RewriteResult::NoMatch,
        });
        let mut rhs = SimplifiedPatternMatcher::<()>::new();
        rhs.add(&[OpKey::Binary(BinaryOp::Add)], |uop, _ctx| match uop.op() {
            Op::Binary(_, _, right) => RewriteResult::Rewritten(right.clone()),
            _ => RewriteResult::NoMatch,
        });

        let one = const_int(1);
        let two = const_int(2);
        let expr = binary(BinaryOp::Add, one.clone(), two.clone());

        let by_ref = &lhs + &rhs;
        assert_eq!(by_ref.len(), 2);
        let result = by_ref.rewrite(&expr, &mut ());
        assert!(matches!(result, RewriteResult::Rewritten(ref r) if Arc::ptr_eq(r, &one)));
        assert_eq!(lhs.len(), 1);
        assert_eq!(rhs.len(), 1);

        let reversed = rhs.clone() + &lhs;
        let result = reversed.rewrite(&expr, &mut ());
        assert!(matches!(result, RewriteResult::Rewritten(ref r) if Arc::ptr_eq(r, &two)));

        let mixed = &lhs + rhs;
        assert_eq!(mixed.len(), 2);
        assert_eq!(lhs.len(), 1);
    }
}