use alloc::rc::Rc;
use core::{any::TypeId, fmt};
use smallvec::SmallVec;
use super::Rewriter;
use crate::{Context, OperationName, OperationRef, Report, interner};
/// Describes what a pattern is rooted at.
///
/// `PatternInfo::root_operation` and `PatternInfo::root_trait` in this file read the payload
/// of the `Operation` and `Trait` variants respectively.
#[derive(Debug)]
pub enum PatternKind {
/// Not tied to a particular operation name or trait; displayed as "for any".
Any,
/// Rooted at the operation with the given name.
Operation(OperationName),
/// Rooted at a trait, identified by the trait's `TypeId`.
Trait(TypeId),
}
/// Renders a short human-readable phrase describing the kind of root a pattern has.
impl fmt::Display for PatternKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Only this variant carries something worth interpolating.
            Self::Operation(name) => write!(f, "for operation '{name}'"),
            Self::Any => f.write_str("for any"),
            // The `TypeId` payload is deliberately not printed.
            Self::Trait(_) => f.write_str("for trait"),
        }
    }
}
/// The expected benefit of applying a pattern, or a marker that it can never match.
///
/// A raw benefit `b` is stored as the non-zero value `b + 1`; `None` is reserved for the
/// "impossible to match" state, which is also what `Default` produces.
///
/// Ordering is deliberately inverted relative to the raw integer: a higher benefit compares
/// as *smaller*, so an ascending sort puts the most beneficial pattern first and
/// impossible-to-match values last.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)]
#[repr(transparent)]
pub struct PatternBenefit(Option<core::num::NonZeroU16>);

impl PatternBenefit {
    /// The largest representable benefit; equal to `PatternBenefit::new(u16::MAX - 1)`.
    pub const MAX: Self = Self(core::num::NonZeroU16::new(u16::MAX));
    /// The smallest representable benefit; equal to `PatternBenefit::new(0)`.
    pub const MIN: Self = Self(core::num::NonZeroU16::new(1));
    /// The benefit of a pattern that is impossible to match.
    pub const NONE: Self = Self(None);

    /// Creates a benefit from a raw `u16`.
    ///
    /// `u16::MAX` yields the impossible-to-match value; every other input yields a valid
    /// benefit.
    pub fn new(benefit: u16) -> Self {
        // `checked_add` fails only for `u16::MAX`, which is exactly the sentinel input, and
        // any successful sum is at least 1, so `NonZeroU16::new` always succeeds on it.
        let stored = benefit.checked_add(1).and_then(core::num::NonZeroU16::new);
        Self(stored)
    }

    /// Returns `true` if this value marks a pattern that can never match.
    #[inline]
    pub fn is_impossible_to_match(&self) -> bool {
        matches!(self.0, None)
    }
}

impl PartialOrd for PatternBenefit {
    /// Total ordering, delegated to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for PatternBenefit {
    /// Orders higher benefits before lower ones, with impossible-to-match values last.
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        use core::cmp::Ordering::{Equal, Greater, Less};

        match (self.0, other.0) {
            // Operands are swapped on purpose: a larger stored value must compare as smaller.
            (Some(lhs), Some(rhs)) => rhs.cmp(&lhs),
            (Some(_), None) => Less,
            (None, Some(_)) => Greater,
            (None, None) => Equal,
        }
    }
}
/// Read-only metadata interface shared by rewrite patterns.
///
/// Implementors only need to provide [`Pattern::info`]; every other method is a provided
/// accessor that reads from the returned [`PatternInfo`].
pub trait Pattern {
    /// Returns the [`PatternInfo`] that describes this pattern.
    fn info(&self) -> &PatternInfo;

    /// Returns the pattern's name.
    #[inline(always)]
    fn name(&self) -> &'static str {
        self.info().name
    }

    /// Returns the kind of root this pattern is anchored to.
    #[inline(always)]
    fn kind(&self) -> &PatternKind {
        &self.info().kind
    }

    /// Returns the benefit associated with this pattern.
    #[inline(always)]
    fn benefit(&self) -> &PatternBenefit {
        &self.info().benefit
    }

    /// Returns the pattern's bounded-rewrite-recursion flag.
    ///
    /// NOTE(review): presumably this signals that the pattern may re-match its own output but
    /// is known to terminate — confirm against the pattern driver.
    #[inline(always)]
    fn has_bounded_rewrite_recursion(&self) -> bool {
        self.info().has_bounded_recursion
    }

    /// Returns the operation names recorded as generated by this pattern.
    #[inline(always)]
    fn generated_ops(&self) -> &[OperationName] {
        &self.info().generated_ops
    }

    /// Returns the root operation name, or `None` if the pattern's kind is not
    /// [`PatternKind::Operation`].
    #[inline(always)]
    fn get_root_operation(&self) -> Option<OperationName> {
        self.info().root_operation()
    }

    /// Returns the root trait id, or `None` if the pattern's kind is not
    /// [`PatternKind::Trait`].
    #[inline(always)]
    fn get_root_trait(&self) -> Option<TypeId> {
        // This must call the inherent `PatternInfo::root_trait`. `PatternInfo` has no inherent
        // `get_root_trait`, and it implements `Pattern` with `info()` returning itself, so
        // spelling the call as `get_root_trait` resolves back to this provided method and
        // recurses without bound.
        self.info().root_trait()
    }
}
/// Metadata describing a single pattern: its name, what it is rooted at, and its benefit.
///
/// Instances are created with `PatternInfo::new`; the fields are read through the provided
/// methods of the `Pattern` trait.
pub struct PatternInfo {
/// The context this pattern was created with. Stored by `new` but not read in this file.
#[allow(unused)]
context: Rc<Context>,
/// The pattern's name, as returned by `Pattern::name`.
name: &'static str,
/// What the pattern is rooted at (any operation, a named operation, or a trait).
kind: PatternKind,
/// Labels attached to the pattern. Initialised empty by `new` and not read in this file.
#[allow(unused)]
labels: SmallVec<[interner::Symbol; 1]>,
/// The benefit associated with the pattern.
benefit: PatternBenefit,
/// Bounded-rewrite-recursion flag; `false` until `with_bounded_rewrite_recursion` sets it.
has_bounded_recursion: bool,
/// Operation names recorded as generated by the pattern. Initialised empty by `new`; nothing
/// in this file adds to it.
generated_ops: SmallVec<[OperationName; 0]>,
}
impl PatternInfo {
    /// Builds the metadata record for a pattern.
    ///
    /// The label and generated-op lists start out empty, and the bounded-rewrite-recursion
    /// flag starts out `false`.
    pub fn new(
        context: Rc<Context>,
        name: &'static str,
        kind: PatternKind,
        benefit: PatternBenefit,
    ) -> Self {
        let labels = SmallVec::new();
        let generated_ops = SmallVec::new();
        Self {
            context,
            name,
            kind,
            labels,
            benefit,
            has_bounded_recursion: false,
            generated_ops,
        }
    }

    /// Sets the bounded-rewrite-recursion flag to `yes` and returns `self` for chaining.
    #[inline(always)]
    pub fn with_bounded_rewrite_recursion(&mut self, yes: bool) -> &mut Self {
        self.has_bounded_recursion = yes;
        self
    }

    /// Returns a clone of the root operation name when the kind is
    /// [`PatternKind::Operation`], and `None` otherwise.
    pub fn root_operation(&self) -> Option<OperationName> {
        if let PatternKind::Operation(name) = &self.kind {
            Some(name.clone())
        } else {
            None
        }
    }

    /// Returns the root trait id when the kind is [`PatternKind::Trait`], and `None`
    /// otherwise.
    pub fn root_trait(&self) -> Option<TypeId> {
        match &self.kind {
            PatternKind::Trait(type_id) => Some(*type_id),
            PatternKind::Any | PatternKind::Operation(_) => None,
        }
    }
}
/// Lets a bare `PatternInfo` be used wherever a `Pattern` is expected.
impl Pattern for PatternInfo {
/// Returns `self`: a `PatternInfo` is its own metadata.
#[inline(always)]
fn info(&self) -> &PatternInfo {
self
}
}
/// A `Pattern` that can attempt to match an operation and rewrite it.
pub trait RewritePattern: Pattern {
/// Attempts to match `op` and, if it matches, rewrite it using `rewriter`.
///
/// Failures are reported as `Err(Report)`.
///
/// NOTE(review): the meaning of the returned `bool` is not stated here. The test pattern in
/// this file returns `Ok(false)` when the operation does not match and `Ok(true)` after it
/// has replaced the operation — confirm that this is the general contract.
fn match_and_rewrite(
&self,
op: OperationRef,
rewriter: &mut dyn Rewriter,
) -> Result<bool, Report>;
}
// Exercises the pattern API end to end: defines a small rewrite pattern, applies it to a
// function through the greedy driver, and compares the printed result with an expected string.
#[cfg(test)]
mod tests {
use alloc::{rc::Rc, string::ToString};
use pretty_assertions::{assert_eq, assert_str_eq};
use super::*;
use crate::{
attributes::IntegerLikeAttr,
dialects::{builtin::*, test::*},
patterns::*,
testing::Test,
*,
};
/// Test pattern that rewrites a `Shl` whose `shift` operand is the constant 1 into a `Mul` of
/// the same left-hand side by a `u32` constant 2.
struct ConvertShiftLeftBy1ToMultiply {
info: PatternInfo,
}
impl ConvertShiftLeftBy1ToMultiply {
/// Builds the pattern, rooted at the `Shl` operation name obtained from `TestDialect`, with
/// a raw benefit of 1 and the bounded-rewrite-recursion flag enabled.
pub fn new(context: Rc<Context>) -> Self {
let dialect = context.get_or_register_dialect::<TestDialect>();
let op_name = dialect.expect_registered_name::<Shl>();
let mut info = PatternInfo::new(
context,
"convert-shl1-to-mul2",
PatternKind::Operation(op_name),
PatternBenefit::new(1),
);
info.with_bounded_rewrite_recursion(true);
Self { info }
}
}
impl Pattern for ConvertShiftLeftBy1ToMultiply {
fn info(&self) -> &PatternInfo {
&self.info
}
}
impl RewritePattern for ConvertShiftLeftBy1ToMultiply {
/// Returns `Ok(false)` without touching the IR when `op` does not match; otherwise creates
/// a constant and a multiply, replaces `op` with the multiply, and returns `Ok(true)`.
fn match_and_rewrite(
&self,
op: OperationRef,
rewriter: &mut dyn Rewriter,
) -> Result<bool, Report> {
use crate::matchers::{self, MatchWith, Matcher, match_chain, match_op};
// Second stage of the match: given a `Shl`, succeed only if its `shift` operand is an
// immediate whose `as_u64()` value is exactly 1.
let binder = MatchWith(|op: &UnsafeIntrusiveEntityRef<Shl>| {
log::trace!(
"found matching 'hir.shl' operation, checking if `shift` operand is foldable"
);
let op = op.borrow();
let shift = op.shift().as_operand_ref();
// NOTE(review): presumably this matcher succeeds when the operand folds to an attribute
// implementing `IntegerLikeAttr` — confirm against the `matchers` module.
let matched =
matchers::foldable_operand_of_trait::<dyn IntegerLikeAttr>().matches(&shift);
matched.and_then(|imm| {
let imm = imm.borrow().as_immediate();
log::trace!("`shift` operand is an immediate: {imm}");
let imm = imm.as_u64();
// Only logged for diagnostics; a `None` here falls through to the `None` result below.
if imm.is_none() {
log::trace!("`shift` operand is not a valid u64 value");
}
if imm.is_some_and(|imm| imm == 1) {
Some(())
} else {
None
}
})
});
log::trace!("attempting to match '{}'", self.name());
// First match the operation as a `Shl`, then run `binder` on the result.
let matched = match_chain(match_op::<Shl>(), binder).matches(&op.borrow()).is_some();
log::trace!("'{}' matched: {matched}", self.name());
if !matched {
return Ok(false);
}
log::trace!("found match, rewriting '{}'", op.borrow().name());
// Read the span and left-hand side inside a block so the borrow of `op` ends before the
// rewriter is used.
let (span, lhs) = {
let shl = op.borrow();
let shl = shl.downcast_ref::<Shl>().unwrap();
let span = shl.span();
let lhs = shl.lhs().as_value_ref();
(span, lhs)
};
// Build the constant 2 that becomes the right-hand side of the multiply.
let constant_builder = rewriter.create::<Constant, _>(span);
let constant: UnsafeIntrusiveEntityRef<Constant> =
constant_builder(Immediate::U32(2)).unwrap();
let shift = constant.borrow().result().as_value_ref();
// Build the wrapping multiply of the original left-hand side by that constant.
let mul_builder = rewriter.create::<Mul, _>(span);
let mul = mul_builder(lhs, shift, Overflow::Wrapping).unwrap();
let mul = mul.as_operation_ref();
log::trace!("replacing shl with mul");
rewriter.replace_op(op, mul);
Ok(true)
}
}
/// Builds a function that returns its `u32` argument shifted left by the constant 1, applies
/// `ConvertShiftLeftBy1ToMultiply` greedily, and checks the printed function.
#[test]
fn rewrite_pattern_api_test() {
let mut test = Test::new("rewrite_pattern_api_test", &[Type::U32], &[Type::U32]);
let pattern = ConvertShiftLeftBy1ToMultiply::new(test.context_rc());
// The builder is scoped so it is dropped before the function is rewritten.
{
let mut builder = test.function_builder();
let shift = builder.u32(1, SourceSpan::default()).unwrap();
let block = builder.current_block();
let lhs = block.borrow().arguments()[0] as ValueRef;
let result = builder.shl(lhs, shift, SourceSpan::default()).unwrap();
builder.ret(Some(result), SourceSpan::default()).unwrap();
}
let mut rewrites = RewritePatternSet::new(test.context_rc());
rewrites.push(pattern);
let rewrites = Rc::new(FrozenRewritePatternSet::new(rewrites));
let mut config = GreedyRewriteConfig::default();
config.with_region_simplification_level(RegionSimplificationLevel::None);
let result =
apply_patterns_and_fold_greedily(test.function().as_operation_ref(), rewrites, config);
assert_eq!(result, Ok(true));
let func = test.function().borrow();
let output = func.as_operation().to_string();
// The expected text has no shift left and no constant 1: only the new constant 2, the
// multiply, and the return remain.
let expected = "\
builtin.function public extern(\"C\") @rewrite_pattern_api_test(%0: u32) -> u32 {
%3 = test.constant 2 : u32;
%4 = test.mul %0, %3 <{ overflow = #builtin.overflow<wrapping> }>;
builtin.ret %4 : (u32);
};";
assert_str_eq!(output.as_str(), expected);
}
}