use alloc::{collections::BTreeMap, rc::Rc};
use smallvec::SmallVec;
use super::{FrozenRewritePatternSet, PatternBenefit, RewritePattern, Rewriter};
use crate::{OperationName, OperationRef, ProgramPoint, Report};
/// The reason [`PatternApplicator::match_and_rewrite`] did not report a successful rewrite.
pub enum PatternApplicationError {
    /// Every candidate pattern was tried (or skipped) and none matched the operation.
    NoMatchesFound,
    /// An error was reported, either by a pattern's `match_and_rewrite` or by the
    /// success callback after a pattern matched.
    Report(Report),
}
/// Selects rewrite patterns from a [`FrozenRewritePatternSet`] and applies them to
/// individual operations.
///
/// The active pattern lists are (re)built from the full set by a cost model; see
/// `apply_cost_model` / `apply_default_cost_model`.
pub struct PatternApplicator {
    /// The complete pattern set this applicator draws from.
    rewrite_patterns_set: Rc<FrozenRewritePatternSet>,
    /// Op-specific patterns kept by the most recent cost model, keyed by operation name,
    /// in the order produced by `apply_cost_model`.
    patterns: BTreeMap<OperationName, SmallVec<[Rc<dyn RewritePattern>; 2]>>,
    /// Op-agnostic patterns kept by the most recent cost model, in the order produced by
    /// `apply_cost_model`.
    match_any_patterns: SmallVec<[Rc<dyn RewritePattern>; 1]>,
}
impl PatternApplicator {
    /// Create an applicator that draws its patterns from `rewrite_patterns_set`.
    ///
    /// No patterns are active until a cost model has been applied with
    /// [`Self::apply_cost_model`] or [`Self::apply_default_cost_model`]. Until then,
    /// [`Self::match_and_rewrite`] returns [`PatternApplicationError::NoMatchesFound`].
    pub fn new(rewrite_patterns_set: Rc<FrozenRewritePatternSet>) -> Self {
        Self {
            rewrite_patterns_set,
            patterns: Default::default(),
            match_any_patterns: Default::default(),
        }
    }

    /// Rebuild the active pattern lists, using `model` to assign each pattern a benefit.
    ///
    /// Lists computed by any previous cost model are discarded. A pattern is dropped if its
    /// own benefit, or the benefit returned by `model`, is impossible-to-match (in the
    /// former case `model` is not consulted at all).
    ///
    /// The surviving op-specific patterns (grouped per operation name) and op-agnostic
    /// patterns are each ordered from **highest to lowest** benefit. [`Self::match_and_rewrite`]
    /// tries them front-to-back. Patterns with equal benefit keep their relative order from
    /// the pattern set, because the sort is stable.
    pub fn apply_cost_model<CostModel>(&mut self, model: CostModel)
    where
        CostModel: Fn(&dyn RewritePattern) -> PatternBenefit,
    {
        self.match_any_patterns.clear();
        self.patterns.clear();

        // Scratch buffer of (pattern, benefit) pairs, reused (via `drain`) for every op.
        let mut benefits = SmallVec::<[_; 4]>::default();
        for (op, op_patterns) in self.rewrite_patterns_set.op_specific_patterns().iter() {
            benefits
                .extend(op_patterns.iter().filter_map(|p| filter_map_pattern_benefit(p, &model)));
            if benefits.is_empty() {
                // Nothing survived for this op. A missing entry and an empty entry are
                // treated identically by `match_and_rewrite`, so don't store one.
                continue;
            }
            // Highest benefit first; `sort_by` is stable so ties keep insertion order.
            benefits.sort_by(|(_, a), (_, b)| b.cmp(a));
            self.patterns
                .insert(op.clone(), benefits.drain(..).map(|(pat, _)| pat).collect());
        }

        benefits.extend(
            self.rewrite_patterns_set
                .any_op_patterns()
                .iter()
                .filter_map(|p| filter_map_pattern_benefit(p, &model)),
        );
        // Same ordering as above: highest benefit first, stable among equals.
        benefits.sort_by(|(_, a), (_, b)| b.cmp(a));
        self.match_any_patterns.extend(benefits.into_iter().map(|(pat, _)| pat));
    }

    /// Apply the default cost model, which uses each pattern's own `benefit()` unchanged.
    #[inline]
    pub fn apply_default_cost_model(&mut self) {
        log::debug!(target: "pattern-rewrite-driver", "applying default cost model");
        self.apply_cost_model(|pattern| *pattern.benefit());
    }

    /// Invoke `callback` once for every pattern in the underlying pattern set.
    ///
    /// Op-specific patterns are visited first, then op-agnostic ones. This walks the full
    /// set, including patterns a cost model may have filtered out of the active lists.
    pub fn walk_all_patterns<F>(&self, mut callback: F)
    where
        F: FnMut(Rc<dyn RewritePattern>),
    {
        for patterns in self.rewrite_patterns_set.op_specific_patterns().values() {
            for pattern in patterns {
                callback(Rc::clone(pattern));
            }
        }
        for pattern in self.rewrite_patterns_set.any_op_patterns() {
            callback(Rc::clone(pattern));
        }
    }

    /// Try the active patterns on `op` until one matches.
    ///
    /// Candidates are drawn by merging the op-specific patterns for `op`'s name with the
    /// op-agnostic patterns. At each step, the next op-agnostic pattern is chosen only if
    /// its benefit is strictly higher than that of the next op-specific pattern, so ties
    /// favor the op-specific pattern.
    ///
    /// For each candidate:
    /// * if `can_apply` returns `false`, the pattern is skipped (no callback is invoked);
    /// * otherwise the rewriter's insertion point is set to just before `op` and the
    ///   pattern's `match_and_rewrite` is run;
    /// * on a match, `on_success` is called and the search stops;
    /// * on a non-match or an error, `on_failure` is called and the search continues.
    ///
    /// # Errors
    ///
    /// * [`PatternApplicationError::NoMatchesFound`] if no pattern matched and none errored.
    /// * [`PatternApplicationError::Report`] if `on_success` returned an error.
    /// * [`PatternApplicationError::Report`] if some pattern errored and no later pattern
    ///   matched. In that case the most recent error is returned.
    pub fn match_and_rewrite<A, F, S, R>(
        &mut self,
        op: OperationRef,
        rewriter: &mut R,
        can_apply: A,
        mut on_failure: F,
        mut on_success: S,
    ) -> Result<(), PatternApplicationError>
    where
        A: for<'a> Fn(&'a dyn RewritePattern) -> bool,
        F: for<'a> FnMut(&'a dyn RewritePattern),
        S: for<'a> FnMut(&'a dyn RewritePattern) -> Result<(), Report>,
        R: Rewriter,
    {
        // Scope the borrow of `op` so it is released before any pattern runs.
        let op_name = {
            let op = op.borrow();
            op.name()
        };
        let op_specific_patterns = self.patterns.get(&op_name).map(|p| p.as_slice()).unwrap_or(&[]);
        if op_specific_patterns.is_empty() {
            log::trace!(
                target: "pattern-rewrite-driver",
                dialect = op_name.dialect().as_str(),
                op = op_name.name().as_str();
                "no op-specific patterns found for '{op_name}'"
            );
        } else {
            log::trace!(
                target: "pattern-rewrite-driver",
                dialect = op_name.dialect().as_str(),
                op = op_name.name().as_str();
                "found {} op-specific patterns for '{op_name}'",
                op_specific_patterns.len()
            );
        }
        log::trace!(
            target: "pattern-rewrite-driver",
            dialect = op_name.dialect().as_str(),
            op = op_name.name().as_str();
            "{} op-agnostic patterns available",
            self.match_any_patterns.len()
        );

        let mut op_patterns = op_specific_patterns.iter().peekable();
        let mut any_op_patterns = self.match_any_patterns.iter().peekable();
        let mut result = Err(PatternApplicationError::NoMatchesFound);
        loop {
            // Merge step over two highest-benefit-first lists. NOTE: this compares each
            // pattern's own `benefit()`, not the value assigned by the cost model.
            let mut best_pattern = op_patterns.peek().copied();
            if let Some(next_any_pattern) = any_op_patterns
                .next_if(|p| best_pattern.is_none_or(|bp| bp.benefit() < p.benefit()))
            {
                if let Some(best_pattern) = best_pattern {
                    log::trace!(
                        target: "pattern-rewrite-driver",
                        dialect = op_name.dialect().as_str(),
                        op = op_name.name().as_str();
                        "selected op-agnostic pattern '{}' because its benefit is higher than the \
                         next best op-specific pattern '{}'",
                        next_any_pattern.name(),
                        best_pattern.name()
                    );
                } else {
                    log::trace!(
                        target: "pattern-rewrite-driver",
                        dialect = op_name.dialect().as_str(),
                        op = op_name.name().as_str();
                        "selected op-agnostic pattern '{}' because no op-specific pattern is \
                         available",
                        next_any_pattern.name()
                    );
                }
                best_pattern.replace(next_any_pattern);
            } else {
                if let Some(best_pattern) = best_pattern {
                    log::trace!(
                        target: "pattern-rewrite-driver",
                        dialect = op_name.dialect().as_str(),
                        op = op_name.name().as_str();
                        "selected op-specific pattern '{}'",
                        best_pattern.name()
                    );
                }
                // Consume the op-specific pattern we just peeked (if any).
                best_pattern = op_patterns.next();
            }

            let Some(best_pattern) = best_pattern else {
                log::trace!(
                    target: "pattern-rewrite-driver",
                    dialect = op_name.dialect().as_str(),
                    op = op_name.name().as_str();
                    "all patterns have been exhausted"
                );
                break;
            };

            let applicable = can_apply(&**best_pattern);
            if !applicable {
                log::trace!(
                    target: "pattern-rewrite-driver",
                    dialect = op_name.dialect().as_str(),
                    op = op_name.name().as_str();
                    "skipping pattern: can_apply returned false"
                );
                continue;
            }

            // Each attempt starts with new IR being inserted immediately before `op`.
            rewriter.set_insertion_point(ProgramPoint::before(op));
            log::debug!(
                target: "pattern-rewrite-driver",
                dialect = op_name.dialect().as_str(),
                op = op_name.name().as_str();
                "trying to match '{}'",
                best_pattern.name()
            );
            match best_pattern.match_and_rewrite(op, rewriter) {
                Ok(matched) => {
                    if matched {
                        log::trace!(
                            target: "pattern-rewrite-driver",
                            dialect = op_name.dialect().as_str(),
                            op = op_name.name().as_str();
                            "pattern matched successfully"
                        );
                        result =
                            on_success(&**best_pattern).map_err(PatternApplicationError::Report);
                        break;
                    } else {
                        log::trace!(
                            target: "pattern-rewrite-driver",
                            dialect = op_name.dialect().as_str(),
                            op = op_name.name().as_str();
                            "failed to match pattern"
                        );
                        on_failure(&**best_pattern);
                    }
                }
                Err(err) => {
                    log::error!(
                        target: "pattern-rewrite-driver",
                        dialect = op_name.dialect().as_str(),
                        op = op_name.name().as_str();
                        "error occurred during match_and_rewrite: {err}"
                    );
                    // The error is recorded but the search continues. A later successful
                    // match overwrites it. NOTE(review): confirm that discarding the error
                    // in that case is intended.
                    result = Err(PatternApplicationError::Report(err));
                    on_failure(&**best_pattern);
                }
            }
        }
        result
    }
}
/// Compute the cost-model benefit of `pattern`, or `None` if it should be discarded.
///
/// A pattern whose own benefit is already impossible-to-match is discarded without
/// consulting `cost_model`. Otherwise, the benefit returned by `cost_model` decides. If that
/// is impossible-to-match, the pattern is discarded (and a debug message logged). In every
/// other case, the caller receives a new handle to the pattern paired with its benefit.
fn filter_map_pattern_benefit<CostModel>(
    pattern: &Rc<dyn RewritePattern>,
    cost_model: &CostModel,
) -> Option<(Rc<dyn RewritePattern>, PatternBenefit)>
where
    CostModel: Fn(&dyn RewritePattern) -> PatternBenefit,
{
    let inherently_impossible = pattern.benefit().is_impossible_to_match();
    let assigned = match inherently_impossible {
        true => PatternBenefit::NONE,
        false => cost_model(&**pattern),
    };

    // Keep anything that can still match.
    if !assigned.is_impossible_to_match() {
        return Some((Rc::clone(pattern), assigned));
    }

    log::debug!(
        target: "pattern-rewrite-driver",
        "ignoring pattern '{}' ({}) because it is impossible to match or cannot lead to legal \
         IR (by cost model)",
        pattern.name(),
        pattern.kind(),
    );
    None
}