use alloc::{rc::Rc, vec::Vec};
use core::cell::RefCell;
use smallvec::SmallVec;
use super::{
ForwardingListener, FrozenRewritePatternSet, PatternApplicator, PatternRewriter, Rewriter,
RewriterListener,
};
use crate::{
BlockRef, Builder, Context, Forward, InsertionGuard, Listener, OpFoldResult, OperationFolder,
OperationRef, ProgramPoint, RawWalk, Region, RegionRef, Report, SourceSpan, Spanned, Value,
ValueRef, WalkResult,
adt::SmallSet,
formatter::DisplayValues,
patterns::{PatternApplicationError, RewritePattern, TracingRewriterListener},
traits::{ConstantLike, Foldable, IsolatedFromAbove},
};
/// Greedily applies `patterns` (plus operation folding and, if enabled, region simplification)
/// to every operation nested in `region`, iterating until a fixpoint or until the configured
/// `max_iterations` is exhausted.
///
/// If `config` has no scope, the scope defaults to `region` itself.
///
/// # Panics
///
/// Panics if `region` has no parent operation, or if that parent does not implement
/// [`IsolatedFromAbove`].
///
/// # Returns
///
/// `Ok(changed)` when the rewrite converged, `Err(changed)` when it did not. In both cases
/// `changed` reports whether the IR was modified.
pub fn apply_patterns_and_fold_region_greedily(
    region: RegionRef,
    patterns: Rc<FrozenRewritePatternSet>,
    mut config: GreedyRewriteConfig,
) -> Result<bool, bool> {
    // Confine the borrow of the owning op to this block so it is released before rewriting.
    let context = {
        let owner = region.parent().unwrap().borrow();
        assert!(
            owner.implements::<dyn IsolatedFromAbove>(),
            "patterns can only be applied to operations which are isolated from above"
        );
        owner.context_rc()
    };

    config.scope = config.scope.or(Some(region));

    let mut driver = RegionPatternRewriteDriver::new(context, patterns, config, region);
    let outcome = driver.simplify();
    if outcome.is_err() {
        match driver.driver.config.max_iterations {
            Some(max_iterations) => {
                log::trace!(target: "pattern-rewrite-driver", "pattern rewrite did not converge after scanning {max_iterations} times");
            }
            None => {
                log::trace!(target: "pattern-rewrite-driver", "pattern rewrite did not converge");
            }
        }
    }
    outcome
}
/// Runs [`apply_patterns_and_fold_region_greedily`] on each region of `op`, one after another.
///
/// Every region is processed even if an earlier one fails to converge.
///
/// # Returns
///
/// `Ok(changed)` if every region converged, otherwise `Err(changed)`. `changed` is `true` if
/// any region was modified.
pub fn apply_patterns_and_fold_greedily(
    op: OperationRef,
    patterns: Rc<FrozenRewritePatternSet>,
    config: GreedyRewriteConfig,
) -> Result<bool, bool> {
    let op = op.borrow();
    let mut changed = false;
    let mut all_converged = true;

    let mut regions = op.regions().front();
    while let Some(region) = regions.as_pointer() {
        // Step the cursor forward before rewriting the region it currently points at.
        regions.move_next();
        let (region_changed, region_converged) =
            match apply_patterns_and_fold_region_greedily(region, Rc::clone(&patterns), config.clone()) {
                Ok(region_changed) => (region_changed, true),
                Err(region_changed) => (region_changed, false),
            };
        changed |= region_changed;
        all_converged &= region_converged;
    }

    if all_converged {
        Ok(changed)
    } else {
        Err(changed)
    }
}
/// Summarizes what [`apply_patterns_and_fold`] did to the operations it was given.
///
/// `Erased` takes precedence over `Changed`: if every input op was erased the effect is
/// `Erased`, regardless of what else happened.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ApplyPatternsAndFoldEffect {
    /// Nothing was rewritten, and at least one input op survived.
    None,
    /// Some rewrite was applied, and at least one input op survived.
    Changed,
    /// Every one of the input ops was erased.
    Erased,
}

/// Outcome of [`apply_patterns_and_fold`].
///
/// `Ok` means the rewrite converged and `Err` means it stopped early (rewrite limit reached).
/// Both variants carry the observed [`ApplyPatternsAndFoldEffect`].
pub type ApplyPatternsAndFoldResult =
    Result<ApplyPatternsAndFoldEffect, ApplyPatternsAndFoldEffect>;
/// Greedily applies `patterns` (and folding) to the given `ops`.
///
/// The worklist is seeded only with `ops`. Whether ops created or touched during the rewrite
/// are also processed depends on the configured [`GreedyRewriteStrictness`].
///
/// If `config` has a scope, every op must lie within it (asserted). Otherwise the scope is set
/// to the common ancestor region of `ops`.
///
/// # Returns
///
/// `Ok(effect)` if the worklist was fully drained, or `Err(effect)` if the rewrite limit
/// stopped processing first. An empty `ops` slice yields `Ok(ApplyPatternsAndFoldEffect::None)`.
pub fn apply_patterns_and_fold(
    ops: &[OperationRef],
    patterns: Rc<FrozenRewritePatternSet>,
    mut config: GreedyRewriteConfig,
) -> ApplyPatternsAndFoldResult {
    let Some(first) = ops.first() else {
        return Ok(ApplyPatternsAndFoldEffect::None);
    };

    match config.scope.as_ref() {
        Some(scope) => {
            let all_ops_in_scope =
                ops.iter().all(|op| scope.borrow().find_ancestor_op(*op).is_some());
            assert!(all_ops_in_scope, "ops must be within the specified scope");
        }
        None => {
            config.scope = Region::find_common_ancestor(ops);
        }
    }

    let max_rewrites = config.max_rewrites.map_or(u32::MAX, |max| max.get());
    let context = first.borrow().context_rc();
    let mut driver = MultiOpPatternRewriteDriver::new(context, patterns, config, ops);
    let converged = driver.simplify(ops);
    let changed = match converged {
        Ok(flag) | Err(flag) => flag,
    };

    let effect = if driver.inner.surviving_ops.borrow().is_empty() {
        ApplyPatternsAndFoldEffect::Erased
    } else if changed {
        ApplyPatternsAndFoldEffect::Changed
    } else {
        ApplyPatternsAndFoldEffect::None
    };

    match converged {
        Ok(_) => Ok(effect),
        Err(_) => {
            log::trace!(target: "pattern-rewrite-driver", "pattern rewrite did not converge after {max_rewrites} rewrites");
            Err(effect)
        }
    }
}
/// Restricts which operations the greedy driver is allowed to put on its worklist.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum GreedyRewriteStrictness {
    /// No filtering: any operation within scope may be enqueued.
    #[default]
    Any,
    /// Only operations present when the driver was created, plus those inserted while rewriting.
    ExistingAndNew,
    /// Only operations that were present when the driver was created.
    Existing,
}
/// Level of region simplification the region driver performs after each worklist pass.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum RegionSimplificationLevel {
    /// Region simplification is skipped entirely.
    None,
    /// Standard simplification (the default); the exact transformations are decided by
    /// `Region::simplify_all`.
    #[default]
    Normal,
    /// More aggressive simplification; presumably a superset of `Normal` — see
    /// `Region::simplify_all` to confirm.
    Aggressive,
}
/// Configuration knobs for the greedy pattern rewrite drivers.
///
/// Construct via [`Default`] or [`GreedyRewriteConfig::new_with_listener`], then adjust with
/// the `with_*` builder methods.
#[derive(Clone)]
pub struct GreedyRewriteConfig {
    /// Optional observer that receives every rewrite notification forwarded by the driver.
    listener: Option<Rc<dyn RewriterListener>>,
    /// Region that bounds which ops may be enqueued; `None` means no bounding region.
    scope: Option<RegionRef>,
    /// Cap on full region-driver iterations; `None` means unlimited.
    max_iterations: Option<core::num::NonZeroU32>,
    /// Cap on rewrites per worklist pass; `None` means unlimited.
    max_rewrites: Option<core::num::NonZeroU32>,
    /// Region simplification performed after each region-driver iteration.
    region_simplification: RegionSimplificationLevel,
    /// Which ops are eligible to be placed on the worklist.
    restrict: GreedyRewriteStrictness,
    /// Seed the worklist in preorder (`true`) rather than postorder (`false`).
    use_top_down_traversal: bool,
}
impl Default for GreedyRewriteConfig {
fn default() -> Self {
Self {
listener: if log::log_enabled!(target: "rewriter", log::Level::Trace) {
Some(Rc::new(TracingRewriterListener))
} else {
None
},
scope: None,
max_iterations: core::num::NonZeroU32::new(10),
max_rewrites: None,
region_simplification: Default::default(),
restrict: Default::default(),
use_top_down_traversal: false,
}
}
}
impl GreedyRewriteConfig {
    /// Returns the default configuration with `listener` installed as the rewrite listener,
    /// replacing any default tracing listener.
    pub fn new_with_listener(listener: impl RewriterListener + 'static) -> Self {
        let listener: Rc<dyn RewriterListener> = Rc::new(listener);
        Self {
            listener: Some(listener),
            ..Self::default()
        }
    }

    /// Bounds the rewrite to operations nested within `region`.
    pub fn with_scope(&mut self, region: RegionRef) -> &mut Self {
        self.scope = Some(region);
        self
    }

    /// Caps the number of region-driver iterations.
    ///
    /// Passing `0` removes the limit, since `0` does not fit in a `NonZeroU32`.
    pub fn with_max_iterations(&mut self, max: u32) -> &mut Self {
        self.max_iterations = core::num::NonZeroU32::new(max);
        self
    }

    /// Caps the number of rewrites applied per worklist pass.
    ///
    /// Passing `0` removes the limit.
    pub fn with_max_rewrites(&mut self, max: u32) -> &mut Self {
        self.max_rewrites = core::num::NonZeroU32::new(max);
        self
    }

    /// Selects how aggressively regions are simplified between iterations.
    pub fn with_region_simplification_level(
        &mut self,
        level: RegionSimplificationLevel,
    ) -> &mut Self {
        self.region_simplification = level;
        self
    }

    /// Selects which operations may be added to the worklist.
    pub fn with_restrictions(&mut self, level: GreedyRewriteStrictness) -> &mut Self {
        self.restrict = level;
        self
    }

    /// Chooses preorder (`true`) or postorder (`false`) seeding of the worklist.
    pub fn with_top_down_traversal(&mut self, yes: bool) -> &mut Self {
        self.use_top_down_traversal = yes;
        self
    }

    /// The region bounding this rewrite, if any.
    #[inline]
    pub fn scope(&self) -> Option<RegionRef> {
        self.scope
    }

    /// The iteration cap; `None` means unlimited.
    #[inline]
    pub fn max_iterations(&self) -> Option<core::num::NonZeroU32> {
        self.max_iterations
    }

    /// The per-pass rewrite cap; `None` means unlimited.
    #[inline]
    pub fn max_rewrites(&self) -> Option<core::num::NonZeroU32> {
        self.max_rewrites
    }

    /// The configured region simplification level.
    #[inline]
    pub fn region_simplification_level(&self) -> RegionSimplificationLevel {
        self.region_simplification
    }

    /// The configured worklist strictness.
    #[inline]
    pub fn strictness(&self) -> GreedyRewriteStrictness {
        self.restrict
    }

    /// Whether the worklist is seeded in preorder.
    #[inline]
    pub fn use_top_down_traversal(&self) -> bool {
        self.use_top_down_traversal
    }
}
/// Core worklist-based driver shared by the region and multi-op drivers.
///
/// It also acts as the rewriter's listener, so it can keep the worklist up to date as
/// operations are inserted, modified and erased.
pub struct GreedyPatternRewriteDriver {
    /// Context used to build the `PatternRewriter` and `OperationFolder`.
    context: Rc<Context>,
    /// Operations still waiting to be processed.
    worklist: RefCell<Worklist>,
    /// Settings controlling limits, scope, strictness and traversal.
    config: GreedyRewriteConfig,
    /// Ops eligible for the worklist when strictness is not `Any`.
    filtered_ops: RefCell<SmallSet<OperationRef, 8>>,
    /// Applicator holding the frozen pattern set, ordered by the default cost model.
    matcher: RefCell<PatternApplicator>,
}
impl GreedyPatternRewriteDriver {
    /// Creates a driver with an empty worklist and no filtered ops.
    ///
    /// The pattern applicator for `patterns` is prepared with its default cost model.
    pub fn new(
        context: Rc<Context>,
        patterns: Rc<FrozenRewritePatternSet>,
        config: GreedyRewriteConfig,
    ) -> Self {
        let matcher = {
            let mut applicator = PatternApplicator::new(patterns);
            applicator.apply_default_cost_model();
            RefCell::new(applicator)
        };
        Self {
            context,
            worklist: RefCell::default(),
            config,
            filtered_ops: RefCell::default(),
            matcher,
        }
    }
}
impl GreedyPatternRewriteDriver {
    /// Pushes `op` by itself onto the worklist, honoring the configured strictness.
    ///
    /// Under [`GreedyRewriteStrictness::Any`] the op is always accepted. Under the other levels
    /// it is accepted only if it is present in `filtered_ops`. Ancestors are not considered
    /// here; see [`Self::add_to_worklist`].
    pub fn add_single_op_to_worklist(&self, op: OperationRef) {
        if matches!(self.config.restrict, GreedyRewriteStrictness::Any)
            || self.filtered_ops.borrow().contains(&op)
        {
            log::trace!(target: "pattern-rewrite-driver", "adding single op '{}' to worklist", op.name());
            self.worklist.borrow_mut().push(op);
        } else {
            log::trace!(
                target: "pattern-rewrite-driver", "skipped adding single op '{}' to worklist due to strictness level",
                op.name()
            );
        }
    }

    /// Adds `op` and each of its ancestor ops below the configured scope region to the worklist.
    ///
    /// The walk moves upward from `op` via `grandparent()` (its containing region) and that
    /// region's parent op. When the containing region equals the scope (which may be `None`),
    /// every gathered op is enqueued, innermost first, via [`Self::add_single_op_to_worklist`].
    /// If the walk runs out of ancestors without reaching the scope, nothing is enqueued.
    pub fn add_to_worklist(&self, op: OperationRef) {
        let mut ancestors = SmallVec::<[OperationRef; 8]>::default();
        let mut op = Some(op);
        while let Some(ancestor_op) = op.take() {
            let region = ancestor_op.grandparent();
            if self.config.scope.as_ref() == region.as_ref() {
                ancestors.push(ancestor_op);
                for op in ancestors {
                    self.add_single_op_to_worklist(op);
                }
                return;
            } else {
                log::trace!(target: "pattern-rewrite-driver", "gathering ancestors of '{}' for worklist", ancestor_op.name());
                ancestors.push(ancestor_op);
            }
            if let Some(region) = region {
                op = region.parent();
            } else {
                log::trace!(target: "pattern-rewrite-driver", "reached top level op while searching for ancestors");
            }
        }
    }

    /// Pops and processes worklist items until the worklist is empty or `max_rewrites` is hit.
    ///
    /// Returns `true` if any item was erased, folded, or rewritten. When the rewrite limit is
    /// reached, the remaining items are left on the worklist.
    pub fn process_worklist(self: Rc<Self>) -> bool {
        log::debug!(target: "pattern-rewrite-driver", "starting processing of greedy pattern rewrite driver worklist");
        let mut rewriter =
            PatternRewriter::new_with_listener(self.context.clone(), Rc::clone(&self));
        let mut changed = false;
        let mut num_rewrites = 0u32;
        while self.config.max_rewrites.is_none_or(|max| num_rewrites < max.get()) {
            let Some(op) = self.worklist.borrow_mut().pop() else {
                log::debug!(target: "pattern-rewrite-driver", "processing worklist complete, rewrites have converged");
                return changed;
            };
            if self.process_worklist_item(&mut rewriter, op) {
                changed = true;
                num_rewrites += 1;
            }
        }
        log::debug!(
            target: "pattern-rewrite-driver", "processing worklist was canceled after {} rewrites without converging (reached max \
            rewrite limit)",
            self.config.max_rewrites.map(|max| max.get()).unwrap_or(u32::MAX)
        );
        changed
    }

    /// Processes a single operation, trying three steps in order:
    ///
    /// 1. Erase it if it is trivially dead.
    /// 2. Fold it (skipped for constant-like ops). An empty fold result means the op was
    ///    modified in place. Otherwise the op is replaced by the folded values, with attributes
    ///    materialized as constants.
    /// 3. Otherwise, match and apply one of the driver's patterns.
    ///
    /// Returns `true` if any of these steps changed the IR.
    fn process_worklist_item(
        &self,
        rewriter: &mut PatternRewriter<Rc<Self>>,
        op_ref: OperationRef,
    ) -> bool {
        log::trace!(target: "pattern-rewrite-driver", "processing operation '{op_ref}'");
        let op = op_ref.borrow();
        if op.is_trivially_dead() {
            // Release the borrow before the rewriter erases the op.
            drop(op);
            rewriter.erase_op(op_ref);
            log::trace!(target: "pattern-rewrite-driver", "processing complete: operation is trivially dead");
            return true;
        }
        // Constant-like ops are never folded. Folding one would presumably just rematerialize an
        // equivalent constant, which would then be re-queued indefinitely.
        if op.implements::<dyn ConstantLike>() {
            log::trace!(target: "pattern-rewrite-driver", "skipping fold: operation is constant-like");
        } else {
            let mut results = SmallVec::<[OpFoldResult; 1]>::default();
            log::trace!(target: "pattern-rewrite-driver", "attempting to fold operation..");
            if op.fold(&mut results).is_ok() {
                if results.is_empty() {
                    // A successful fold with no results means the op was updated in place.
                    self.notify_operation_modified(op_ref);
                    log::trace!(
                        target: "pattern-rewrite-driver",
                        "operation was successfully folded/modified in-place"
                    );
                    return true;
                } else {
                    log::trace!(
                        target: "pattern-rewrite-driver",
                        "operation was successfully folded away, to be replaced with: {}",
                        DisplayValues::new(results.iter())
                    );
                }
                assert_eq!(
                    results.len(),
                    op.num_results(),
                    "folder produced incorrect number of results"
                );
                // Materialized constants are inserted right before the folded op. The guard
                // restores the rewriter's insertion point when this block ends.
                let mut rewriter = InsertionGuard::new(&mut **rewriter);
                rewriter.set_insertion_point(ProgramPoint::before(op_ref));
                log::trace!(target: "pattern-rewrite-driver", "replacing op with fold results..");
                let mut replacements = SmallVec::<[Option<ValueRef>; 2]>::default();
                let mut materialization_succeeded = true;
                for (fold_result, result_ty) in results
                    .into_iter()
                    .zip(op.results().all().iter().map(|r| r.borrow().ty().clone()))
                {
                    match fold_result {
                        OpFoldResult::Value(value) => {
                            assert_eq!(
                                value.borrow().ty(),
                                &result_ty,
                                "folder produced value of incorrect type"
                            );
                            replacements.push(Some(value));
                        }
                        OpFoldResult::Attribute(attr) => {
                            let span = op.span();
                            log::trace!(
                                target: "pattern-rewrite-driver",
                                "materializing constant for value '{attr:?}' and type '{result_ty}'",
                            );
                            let constant_op = op.dialect().materialize_constant(
                                &mut *rewriter,
                                attr,
                                &result_ty,
                                span,
                            );
                            match constant_op {
                                None => {
                                    log::trace!(
                                        target: "pattern-rewrite-driver",
                                        "materialization failed: cleaning up any materialized ops \
                                        for {} previous results",
                                        replacements.len()
                                    );
                                    // Undo the partial materialization: erase each distinct
                                    // defining op of the replacements gathered so far.
                                    let mut replacement_ops =
                                        SmallVec::<[OperationRef; 2]>::default();
                                    for replacement in replacements.iter().filter_map(|repl| *repl)
                                    {
                                        let replacement = replacement.borrow();
                                        assert!(
                                            !replacement.is_used(),
                                            "folder reused existing op for one result, but \
                                            constant materialization failed for another result"
                                        );
                                        let replacement_op = replacement.get_defining_op().unwrap();
                                        if replacement_ops.contains(&replacement_op) {
                                            continue;
                                        }
                                        replacement_ops.push(replacement_op);
                                    }
                                    for replacement_op in replacement_ops {
                                        rewriter.erase_op(replacement_op);
                                    }
                                    materialization_succeeded = false;
                                    break;
                                }
                                Some(constant_op) => {
                                    let const_op = constant_op.borrow();
                                    assert!(
                                        const_op.implements::<dyn ConstantLike>(),
                                        "materialize_constant produced op that does not implement \
                                        ConstantLike"
                                    );
                                    let result: ValueRef = const_op.results().all()[0].upcast();
                                    assert_eq!(
                                        result.borrow().ty(),
                                        &result_ty,
                                        "materialize_constant produced incorrect result type"
                                    );
                                    log::trace!(
                                        target: "pattern-rewrite-driver",
                                        "successfully materialized constant as {}",
                                        result.borrow().id()
                                    );
                                    replacements.push(Some(result));
                                }
                            }
                        }
                    }
                }
                if materialization_succeeded {
                    log::trace!(
                        target: "pattern-rewrite-driver",
                        "materialization of fold results was successful, performing replacement.."
                    );
                    // Release the borrow before the rewriter replaces (and erases) the op.
                    drop(op);
                    rewriter.replace_op_with_values(op_ref, &replacements);
                    log::trace!(
                        target: "pattern-rewrite-driver",
                        "fold succeeded: operation was replaced with materialized constants"
                    );
                    return true;
                } else {
                    log::trace!(
                        target: "pattern-rewrite-driver",
                        "materialization of fold results failed, proceeding without folding"
                    );
                }
            } else {
                log::trace!(target: "pattern-rewrite-driver", "operation could not be folded");
            }
        }
        drop(op);
        log::trace!(target: "pattern-rewrite-driver", "attempting to match and rewrite one of the input patterns..");
        let result = if let Some(listener) = self.config.listener.as_deref() {
            // With a user listener installed, bracket each pattern attempt with begin/end
            // notifications.
            let op_name = op_ref.name();
            let can_apply = |pattern: &dyn RewritePattern| {
                log::trace!(target: "pattern-rewrite-driver", "applying pattern {} to op {}", pattern.name(), &op_name);
                listener.notify_pattern_begin(pattern, op_ref);
                true
            };
            let on_failure = |pattern: &dyn RewritePattern| {
                log::trace!(target: "pattern-rewrite-driver", "pattern failed to match");
                listener.notify_pattern_end(pattern, false);
            };
            let on_success = |pattern: &dyn RewritePattern| {
                log::trace!(target: "pattern-rewrite-driver", "pattern applied successfully");
                listener.notify_pattern_end(pattern, true);
                Ok(())
            };
            self.matcher.borrow_mut().match_and_rewrite(
                op_ref,
                &mut **rewriter,
                can_apply,
                on_failure,
                on_success,
            )
        } else {
            self.matcher.borrow_mut().match_and_rewrite(
                op_ref,
                &mut **rewriter,
                |_| true,
                |_| {},
                |_| Ok(()),
            )
        };
        match result {
            Ok(_) => {
                log::trace!(target: "pattern-rewrite-driver", "processing complete: pattern matched and operation was rewritten");
                true
            }
            Err(PatternApplicationError::NoMatchesFound) => {
                log::debug!(target: "pattern-rewrite-driver", "processing complete: exhausted all patterns without finding a match");
                false
            }
            Err(PatternApplicationError::Report(report)) => {
                log::debug!(
                    target: "pattern-rewrite-driver", "processing complete: error occurred during match and rewrite: {report}"
                );
                false
            }
        }
    }

    /// Re-enqueues operand producers of `op`, which is about to be erased.
    ///
    /// For each operand whose value has a defining op, that defining op is added to the worklist
    /// if the value has at most one distinct user other than `op`. Once `op` is gone, such a
    /// producer may be dead or newly simplifiable.
    fn add_operands_to_worklist(&self, op: OperationRef) {
        let current_op = op.borrow();
        for operand in current_op.operands().all() {
            let operand = operand.borrow();
            let Some(def_op) = operand.value().get_defining_op() else {
                continue;
            };
            let mut other_user = None;
            let mut has_more_than_two_uses = false;
            for user in operand.value().iter_uses() {
                if user.owner == op || other_user.as_ref().is_some_and(|ou| ou == &user.owner) {
                    continue;
                }
                if other_user.is_none() {
                    other_user = Some(user.owner);
                    continue;
                }
                has_more_than_two_uses = true;
                break;
            }
            if !has_more_than_two_uses {
                self.add_to_worklist(def_op);
            }
        }
    }
}
/// Forwards structural notifications to the user-configured listener (if any).
///
/// Newly inserted operations are also fed into the worklist.
impl Listener for GreedyPatternRewriteDriver {
    fn kind(&self) -> crate::ListenerType {
        crate::ListenerType::Rewriter
    }

    fn notify_block_inserted(
        &self,
        block: crate::BlockRef,
        prev: Option<RegionRef>,
        ip: Option<crate::BlockRef>,
    ) {
        let Some(user_listener) = self.config.listener.as_deref() else {
            return;
        };
        user_listener.notify_block_inserted(block, prev, ip);
    }

    fn notify_operation_inserted(&self, op: OperationRef, prev: ProgramPoint) {
        if let Some(user_listener) = self.config.listener.as_deref() {
            user_listener.notify_operation_inserted(op, prev);
        }
        // Under `ExistingAndNew`, ops created during the rewrite become eligible for the worklist.
        if self.config.restrict == GreedyRewriteStrictness::ExistingAndNew {
            self.filtered_ops.borrow_mut().insert(op);
        }
        self.add_to_worklist(op);
    }
}
/// Forwards rewrite notifications to the user-configured listener (if any).
///
/// The worklist and `filtered_ops` are kept consistent with modifications and erasures.
impl RewriterListener for GreedyPatternRewriteDriver {
    fn notify_block_erased(&self, block: BlockRef) {
        if let Some(listener) = self.config.listener.as_deref() {
            listener.notify_block_erased(block);
        }
    }

    /// A modified op is re-enqueued (with its in-scope ancestors) so it is revisited.
    fn notify_operation_modified(&self, op: OperationRef) {
        if let Some(listener) = self.config.listener.as_deref() {
            listener.notify_operation_modified(op);
        }
        self.add_to_worklist(op);
    }

    /// Handles the erasure of `op`.
    ///
    /// Operand producers that may have become dead or simplifiable are re-enqueued, and `op` is
    /// removed from both the worklist and `filtered_ops`.
    ///
    /// # Panics
    ///
    /// Panics if `op` is the parent op of the configured scope region.
    fn notify_operation_erased(&self, op: OperationRef) {
        if let Some(scope) = self.config.scope.as_ref() {
            // Only erasing the scope's *parent op* is forbidden. A scope region without a parent
            // op cannot be invalidated this way, so it must not trip the assertion.
            assert!(
                scope.parent().is_none_or(|parent_op| parent_op != op),
                "scope region must not be erased during greedy pattern rewrite"
            );
        }
        if let Some(listener) = self.config.listener.as_deref() {
            listener.notify_operation_erased(op);
        }
        self.add_operands_to_worklist(op);
        self.worklist.borrow_mut().remove(&op);
        if self.config.restrict != GreedyRewriteStrictness::Any {
            self.filtered_ops.borrow_mut().remove(&op);
        }
    }

    fn notify_operation_replaced_with_values(
        &self,
        op: OperationRef,
        replacement: &[Option<ValueRef>],
    ) {
        if let Some(listener) = self.config.listener.as_deref() {
            listener.notify_operation_replaced_with_values(op, replacement);
        }
    }

    fn notify_match_failure(&self, span: SourceSpan, reason: Report) {
        if let Some(listener) = self.config.listener.as_deref() {
            listener.notify_match_failure(span, reason);
        }
    }
}
/// Drives the greedy rewriter over every operation nested in a single region.
///
/// It repeats whole-region passes until a fixpoint or the iteration limit is reached.
pub struct RegionPatternRewriteDriver {
    /// Shared worklist driver; also installed as the rewriter's listener.
    driver: Rc<GreedyPatternRewriteDriver>,
    /// The region whose nested operations are rewritten.
    region: RegionRef,
}
impl RegionPatternRewriteDriver {
    /// Creates a region driver.
    ///
    /// When strictness is not [`GreedyRewriteStrictness::Any`], every operation currently
    /// nested in `region` is recorded as eligible for the worklist.
    pub fn new(
        context: Rc<Context>,
        patterns: Rc<FrozenRewritePatternSet>,
        config: GreedyRewriteConfig,
        region: RegionRef,
    ) -> Self {
        let mut driver = GreedyPatternRewriteDriver::new(context, patterns, config);
        if driver.config.restrict != GreedyRewriteStrictness::Any {
            let filtered_ops = driver.filtered_ops.get_mut();
            region.raw_postwalk_all::<Forward, _>(|op| {
                filtered_ops.insert(op);
            });
        }
        Self {
            driver: Rc::new(driver),
            region,
        }
    }

    /// Repeatedly seeds the worklist from the region, processes it, and (optionally)
    /// simplifies the region.
    ///
    /// This stops once an iteration makes no change or `max_iterations` is exhausted.
    ///
    /// In each iteration, known constants are first registered with an `OperationFolder`
    /// (which may CSE them) and are not added to the worklist.
    ///
    /// # Returns
    ///
    /// `Ok(changed)` if the last iteration made no change (converged), otherwise
    /// `Err(changed)`. `changed` is `true` iff at least one iteration modified the IR.
    pub fn simplify(&mut self) -> Result<bool, bool> {
        use crate::matchers::Matcher;
        let mut continue_rewrites = false;
        // Track "anything changed" explicitly rather than deriving it from the iteration count.
        // When the iteration limit is hit, `iteration` ends at `max_iterations` (not one past
        // it), so `iteration > 1` would wrongly report "unchanged" when `max_iterations == 1`
        // and that single iteration did rewrite something.
        let mut changed = false;
        let mut iteration = 0;
        while self.driver.config.max_iterations.is_none_or(|max| iteration < max.get()) {
            log::trace!(target: "pattern-rewrite-driver", "starting iteration {iteration} of region pattern rewrite driver");
            iteration += 1;
            self.driver.worklist.borrow_mut().clear();
            let context = self.driver.context.clone();
            let mut folder = OperationFolder::new(context, Rc::clone(&self.driver));
            // Returns `true` if `op` is a constant the folder did not accept as a known constant;
            // such ops are left off the worklist.
            let mut insert_known_constant = |op: OperationRef| {
                let operation = op.borrow();
                if let Some(const_value) = crate::matchers::constant().matches(&operation) {
                    drop(operation);
                    if !folder.insert_known_constant(op, Some(const_value)) {
                        return true;
                    }
                }
                false
            };
            if !self.driver.config.use_top_down_traversal {
                log::trace!(target: "pattern-rewrite-driver", "adding operations in postorder");
                self.region.raw_postwalk_all::<Forward, _>(|op| {
                    if !insert_known_constant(op) {
                        self.driver.add_to_worklist(op);
                    }
                });
            } else {
                log::trace!(target: "pattern-rewrite-driver", "adding operations in preorder");
                self.region
                    .raw_prewalk::<Forward, _, _>(|op| {
                        if !insert_known_constant(op) {
                            self.driver.add_to_worklist(op);
                            WalkResult::<Report>::Continue(())
                        } else {
                            WalkResult::Skip
                        }
                    })
                    .into_result()
                    .expect("unexpected error occurred while walking region");
                // The worklist pops from the back; reverse so ops are processed in preorder.
                self.driver.worklist.borrow_mut().reverse();
            }
            continue_rewrites = self.driver.clone().process_worklist();
            log::trace!(
                target: "pattern-rewrite-driver", "processing of worklist for this iteration has completed, \
                changed={continue_rewrites}"
            );
            if self.driver.config.region_simplification != RegionSimplificationLevel::None {
                let mut rewriter = PatternRewriter::new_with_listener(
                    self.driver.context.clone(),
                    Rc::clone(&self.driver),
                );
                continue_rewrites |= Region::simplify_all(
                    &[self.region],
                    &mut *rewriter,
                    self.driver.config.region_simplification,
                )
                .is_ok();
            } else {
                log::debug!(target: "pattern-rewrite-driver", "region simplification was disabled, skipping simplification rewrites");
            }
            changed |= continue_rewrites;
            if !continue_rewrites {
                log::trace!(target: "pattern-rewrite-driver", "region pattern rewrites have converged");
                break;
            }
        }
        if !continue_rewrites {
            Ok(changed)
        } else {
            Err(changed)
        }
    }
}
/// Drives the greedy rewriter over an explicit list of operations.
///
/// Unlike the region driver, it makes a single worklist pass rather than iterating over a
/// region.
pub struct MultiOpPatternRewriteDriver {
    /// Shared worklist driver; also installed as the rewriter's listener.
    driver: Rc<GreedyPatternRewriteDriver>,
    /// Bookkeeping for which of the input ops are still alive.
    inner: Rc<MultiOpPatternRewriteDriverImpl>,
}
/// Listener-side state of [`MultiOpPatternRewriteDriver`].
struct MultiOpPatternRewriteDriverImpl {
    /// The input ops that have not been erased yet.
    surviving_ops: RefCell<SmallSet<OperationRef, 8>>,
}
impl MultiOpPatternRewriteDriver {
pub fn new(
context: Rc<Context>,
patterns: Rc<FrozenRewritePatternSet>,
mut config: GreedyRewriteConfig,
ops: &[OperationRef],
) -> Self {
let surviving_ops = SmallSet::from_iter(ops.iter().copied());
let inner = Rc::new(MultiOpPatternRewriteDriverImpl {
surviving_ops: RefCell::new(surviving_ops),
});
let listener = Rc::new(ForwardingListener::new(config.listener.take(), Rc::clone(&inner)));
config.listener = Some(listener);
let mut driver = GreedyPatternRewriteDriver::new(context.clone(), patterns, config);
if driver.config.restrict != GreedyRewriteStrictness::Any {
driver.filtered_ops.get_mut().extend(ops.iter().cloned());
}
Self {
driver: Rc::new(driver),
inner,
}
}
pub fn simplify(&mut self, ops: &[OperationRef]) -> Result<bool, bool> {
for op in ops.iter().copied() {
self.driver.add_single_op_to_worklist(op);
}
let changed = self.driver.clone().process_worklist();
if self.driver.worklist.borrow().is_empty() {
Ok(changed)
} else {
Err(changed)
}
}
}
/// Identifies the surviving-op tracker as a rewriter listener. Every other notification uses
/// the trait's defaults.
impl Listener for MultiOpPatternRewriteDriverImpl {
    fn kind(&self) -> crate::ListenerType {
        crate::ListenerType::Rewriter
    }
}
/// Keeps `surviving_ops` accurate by dropping any op that gets erased.
impl RewriterListener for MultiOpPatternRewriteDriverImpl {
    fn notify_operation_erased(&self, op: OperationRef) {
        let mut survivors = self.surviving_ops.borrow_mut();
        survivors.remove(&op);
    }
}
/// A LIFO stack of operations that never holds the same op twice.
///
/// Membership checks are linear scans over the backing vector.
#[derive(Default)]
struct Worklist(Vec<OperationRef>);

impl Worklist {
    /// Drops every pending operation.
    #[inline]
    pub fn clear(&mut self) {
        self.0.clear()
    }

    /// Returns `true` if no operations are pending.
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Pushes `op` on top, unless it is already queued (in which case nothing changes).
    pub fn push(&mut self, op: OperationRef) {
        let already_queued = self.0.contains(&op);
        if !already_queued {
            self.0.push(op);
        }
    }

    /// Removes and returns the most recently pushed operation.
    #[inline]
    pub fn pop(&mut self) -> Option<OperationRef> {
        self.0.pop()
    }

    /// Removes the first occurrence of `op`, if present, preserving the order of the rest.
    pub fn remove(&mut self, op: &OperationRef) {
        let found = self.0.iter().position(|queued| queued == op);
        if let Some(slot) = found {
            self.0.remove(slot);
        }
    }

    /// Reverses the pending order, so the oldest entry is popped next.
    pub fn reverse(&mut self) {
        self.0.reverse();
    }
}