1use alloc::rc::Rc;
2use core::{any::TypeId, fmt};
3
4use smallvec::SmallVec;
5
6use super::Rewriter;
7use crate::{interner, Context, OperationName, OperationRef, Report};
8
/// What a pattern is anchored on.
///
/// A pattern is rooted either on no particular operation, on a single named
/// operation, or on a trait.
#[derive(Debug)]
pub enum PatternKind {
    /// Not restricted to a specific operation name or trait.
    Any,
    /// Rooted on the operation with the given name.
    Operation(OperationName),
    /// Rooted on a trait, identified by its [`TypeId`].
    Trait(TypeId),
}
18impl fmt::Display for PatternKind {
19 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20 match self {
21 Self::Any => f.write_str("for any"),
22 Self::Operation(name) => write!(f, "for operation '{name}'"),
23 Self::Trait(_) => write!(f, "for trait"),
24 }
25 }
26}
27
/// The benefit of applying a pattern.
///
/// A benefit is either a concrete value in `0..u16::MAX` or
/// [`PatternBenefit::NONE`], which marks a pattern that is impossible to match.
/// The derived [`Default`] is `NONE`.
///
/// The value is stored off by one inside a `NonZeroU16`, so the "impossible"
/// state occupies the `None` niche of the `Option`.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)]
#[repr(transparent)]
pub struct PatternBenefit(Option<core::num::NonZeroU16>);

impl PatternBenefit {
    /// The largest matchable benefit; same as `PatternBenefit::new(u16::MAX - 1)`.
    pub const MAX: Self = Self(core::num::NonZeroU16::new(u16::MAX));
    /// The smallest matchable benefit; same as `PatternBenefit::new(0)`.
    pub const MIN: Self = Self(core::num::NonZeroU16::new(1));
    /// A benefit marking the pattern as impossible to match.
    pub const NONE: Self = Self(None);

    /// Creates a benefit from a raw value.
    ///
    /// `u16::MAX` has no stored representation and yields [`PatternBenefit::NONE`];
    /// every other value `b` is stored as `b + 1`.
    pub fn new(benefit: u16) -> Self {
        // `checked_add` fails exactly for `u16::MAX`; for every other input the
        // incremented value is non-zero, so `NonZeroU16::new` always succeeds.
        Self(benefit.checked_add(1).and_then(core::num::NonZeroU16::new))
    }

    /// Returns `true` if this benefit is [`PatternBenefit::NONE`].
    #[inline]
    pub fn is_impossible_to_match(&self) -> bool {
        self.0.is_none()
    }
}

impl PartialOrd for PatternBenefit {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Orders benefits by priority.
///
/// A *larger* benefit compares as `Ordering::Less`, and [`PatternBenefit::NONE`]
/// compares greater than every matchable benefit. As a result, an ascending sort
/// lists the most beneficial patterns first and impossible ones last.
impl Ord for PatternBenefit {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        // `Option` orders `None` before any `Some`, and `NonZeroU16` orders by
        // value. Comparing with the operands swapped therefore gives exactly the
        // reversed ordering described above.
        other.0.cmp(&self.0)
    }
}
82
83pub trait Pattern {
84 fn info(&self) -> &PatternInfo;
85 #[inline(always)]
87 fn name(&self) -> &'static str {
88 self.info().name
89 }
90 #[inline(always)]
92 fn kind(&self) -> &PatternKind {
93 &self.info().kind
94 }
95 #[inline(always)]
101 fn benefit(&self) -> &PatternBenefit {
102 &self.info().benefit
103 }
104 #[inline(always)]
109 fn has_bounded_rewrite_recursion(&self) -> bool {
110 self.info().has_bounded_recursion
111 }
112 #[inline(always)]
115 fn generated_ops(&self) -> &[OperationName] {
116 &self.info().generated_ops
117 }
118 #[inline(always)]
122 fn get_root_operation(&self) -> Option<OperationName> {
123 self.info().root_operation()
124 }
125 #[inline(always)]
129 fn get_root_trait(&self) -> Option<TypeId> {
130 self.info().get_root_trait()
131 }
132}
133
134pub struct PatternInfo {
137 #[allow(unused)]
138 context: Rc<Context>,
139 name: &'static str,
140 kind: PatternKind,
141 #[allow(unused)]
142 labels: SmallVec<[interner::Symbol; 1]>,
143 benefit: PatternBenefit,
144 has_bounded_recursion: bool,
145 generated_ops: SmallVec<[OperationName; 0]>,
146}
147
148impl PatternInfo {
149 pub fn new(
151 context: Rc<Context>,
152 name: &'static str,
153 kind: PatternKind,
154 benefit: PatternBenefit,
155 ) -> Self {
156 Self {
157 context,
158 name,
159 kind,
160 labels: SmallVec::default(),
161 benefit,
162 has_bounded_recursion: false,
163 generated_ops: SmallVec::default(),
164 }
165 }
166
167 #[inline(always)]
169 pub fn with_bounded_rewrite_recursion(&mut self, yes: bool) -> &mut Self {
170 self.has_bounded_recursion = yes;
171 self
172 }
173
174 pub fn root_operation(&self) -> Option<OperationName> {
178 match self.kind {
179 PatternKind::Operation(ref name) => Some(name.clone()),
180 _ => None,
181 }
182 }
183
184 pub fn root_trait(&self) -> Option<TypeId> {
188 match self.kind {
189 PatternKind::Trait(type_id) => Some(type_id),
190 _ => None,
191 }
192 }
193}
194
195impl Pattern for PatternInfo {
196 #[inline(always)]
197 fn info(&self) -> &PatternInfo {
198 self
199 }
200}
201
/// A [`Pattern`] that can attempt to match an operation and rewrite it.
pub trait RewritePattern: Pattern {
    /// Attempts to match `op` and, on success, rewrites it through `rewriter`.
    ///
    /// Return-value convention, following the test pattern in this module:
    /// `Ok(true)` means the pattern matched and a rewrite was performed, and
    /// `Ok(false)` means `op` did not match. `Err` reports a failure.
    ///
    /// NOTE(review): whether `op` may be left partially modified when returning
    /// `Ok(false)` or `Err` is not established here; confirm with the rewrite driver.
    fn match_and_rewrite(
        &self,
        op: OperationRef,
        rewriter: &mut dyn Rewriter,
    ) -> Result<bool, Report>;
}
223
#[cfg(test)]
mod tests {
    use alloc::{rc::Rc, string::ToString};

    use pretty_assertions::{assert_eq, assert_str_eq};

    use super::*;
    use crate::{
        dialects::{builtin::*, test::*},
        patterns::*,
        *,
    };

    /// Test pattern that rewrites a `Shl` whose shift operand folds to the
    /// immediate `1` into a wrapping `Mul` by the constant `2`.
    struct ConvertShiftLeftBy1ToMultiply {
        info: PatternInfo,
    }
    impl ConvertShiftLeftBy1ToMultiply {
        /// Builds the pattern rooted on the test dialect's `Shl` operation, with
        /// benefit 1 and bounded rewrite recursion enabled.
        pub fn new(context: Rc<Context>) -> Self {
            let dialect = context.get_or_register_dialect::<TestDialect>();
            let op_name = dialect.expect_registered_name::<Shl>();
            let mut info = PatternInfo::new(
                context,
                "convert-shl1-to-mul2",
                PatternKind::Operation(op_name),
                PatternBenefit::new(1),
            );
            info.with_bounded_rewrite_recursion(true);
            Self { info }
        }
    }
    impl Pattern for ConvertShiftLeftBy1ToMultiply {
        fn info(&self) -> &PatternInfo {
            &self.info
        }
    }
    impl RewritePattern for ConvertShiftLeftBy1ToMultiply {
        fn match_and_rewrite(
            &self,
            op: OperationRef,
            rewriter: &mut dyn Rewriter,
        ) -> Result<bool, Report> {
            use crate::matchers::{self, match_chain, match_op, MatchWith, Matcher};

            // Matches only when the `shift` operand folds to an immediate equal to 1.
            let binder = MatchWith(|op: &UnsafeIntrusiveEntityRef<Shl>| {
                log::trace!(
                    "found matching 'hir.shl' operation, checking if `shift` operand is foldable"
                );
                let op = op.borrow();
                let shift = op.shift().as_operand_ref();
                let matched = matchers::foldable_operand_of::<Immediate>().matches(&shift);
                matched.and_then(|imm| {
                    log::trace!("`shift` operand is an immediate: {imm}");
                    let imm = imm.as_u64();
                    if imm.is_none() {
                        log::trace!("`shift` operand is not a valid u64 value");
                    }
                    if imm.is_some_and(|imm| imm == 1) {
                        Some(())
                    } else {
                        None
                    }
                })
            });
            log::trace!("attempting to match '{}'", self.name());
            let matched = match_chain(match_op::<Shl>(), binder).matches(&op.borrow()).is_some();
            log::trace!("'{}' matched: {matched}", self.name());

            if !matched {
                return Ok(false);
            }

            log::trace!("found match, rewriting '{}'", op.borrow().name());
            // Copy out what we need so the borrow of `op` ends before rewriting.
            let (span, lhs) = {
                let shl = op.borrow();
                let shl = shl.downcast_ref::<Shl>().unwrap();
                let span = shl.span();
                let lhs = shl.lhs().as_value_ref();
                (span, lhs)
            };
            let constant_builder = rewriter.create::<Constant, _>(span);
            let constant: UnsafeIntrusiveEntityRef<Constant> =
                constant_builder(Immediate::U32(2)).unwrap();
            let shift = constant.borrow().result().as_value_ref();
            let mul_builder = rewriter.create::<Mul, _>(span);
            let mul = mul_builder(lhs, shift, Overflow::Wrapping).unwrap();
            let mul = mul.as_operation_ref();
            log::trace!("replacing shl with mul");
            rewriter.replace_op(op, mul);

            Ok(true)
        }
    }

    /// End-to-end check: `ret (shl v0, 1)` is rewritten to `ret (mul v0, 2)`
    /// by the greedy driver.
    #[test]
    fn rewrite_pattern_api_test() {
        // `try_init` instead of `init`: a global logger can only be installed once
        // per process, and other tests in the same test binary may already have
        // installed one. `init` would panic in that case.
        let _ = env_logger::Builder::from_env("MIDENC_TRACE").is_test(true).try_init();

        let context = Rc::new(Context::default());
        let pattern = ConvertShiftLeftBy1ToMultiply::new(Rc::clone(&context));

        let mut builder = OpBuilder::new(Rc::clone(&context));
        let function = {
            let builder = builder.create::<Function, (_, _)>(SourceSpan::default());
            let name = Ident::new("test".into(), SourceSpan::default());
            let signature = Signature::new([AbiParam::new(Type::U32)], [AbiParam::new(Type::U32)]);
            builder(name, signature).unwrap()
        };

        // Build the body: `ret (shl v0, 1)`.
        {
            let mut builder = FunctionBuilder::new(function, &mut builder);
            let shift = builder.u32(1, SourceSpan::default()).unwrap();
            let block = builder.current_block();
            let lhs = block.borrow().arguments()[0].upcast();
            let result = builder.shl(lhs, shift, SourceSpan::default()).unwrap();
            builder.ret(Some(result), SourceSpan::default()).unwrap();
        }

        let mut rewrites = RewritePatternSet::new(builder.context_rc());
        rewrites.push(pattern);
        let rewrites = Rc::new(FrozenRewritePatternSet::new(rewrites));

        let mut config = GreedyRewriteConfig::default();
        config.with_region_simplification_level(RegionSimplificationLevel::None);
        let result =
            apply_patterns_and_fold_greedily(function.as_operation_ref(), rewrites, config);

        assert_eq!(result, Ok(true));

        let func = function.borrow();
        let output = func.as_operation().to_string();
        let expected = "\
public builtin.function @test(v0: u32) -> u32 {
^block0(v0: u32):
 v3 = test.constant 2 : u32;
 v4 = test.mul v0, v3 : u32 #[overflow = wrapping];
 builtin.ret v4;
};";
        assert_str_eq!(output.as_str(), expected);
    }
}
373}