midenc_hir/patterns/
pattern.rs

1use alloc::rc::Rc;
2use core::{any::TypeId, fmt};
3
4use smallvec::SmallVec;
5
6use super::Rewriter;
7use crate::{interner, Context, OperationName, OperationRef, Report};
8
9#[derive(Debug)]
10pub enum PatternKind {
11    /// The pattern root matches any operation
12    Any,
13    /// The pattern root matches a specific named operation
14    Operation(OperationName),
15    /// The pattern root matches a specific trait
16    Trait(TypeId),
17}
18impl fmt::Display for PatternKind {
19    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20        match self {
21            Self::Any => f.write_str("for any"),
22            Self::Operation(name) => write!(f, "for operation '{name}'"),
23            Self::Trait(_) => write!(f, "for trait"),
24        }
25    }
26}
27
/// The benefit of applying a pattern, i.e. the inverse of its cost.
///
/// Patterns with a greater benefit are preferred over those with a lesser one, and a pattern whose
/// benefit is [PatternBenefit::NONE] can never match, so it may be discarded outright.
///
/// This drives both the selection of patterns to apply and the order in which they are tried.
/// Note that the [Ord] implementation places the *most* beneficial values first: a more beneficial
/// value compares as less than a less beneficial one, and [PatternBenefit::NONE] compares greater
/// than everything else.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)]
#[repr(transparent)]
pub struct PatternBenefit(Option<core::num::NonZeroU16>);

impl PatternBenefit {
    /// The greatest benefit a pattern can have
    pub const MAX: Self = Self(core::num::NonZeroU16::new(u16::MAX));
    /// The smallest benefit a pattern can have while still being able to match
    pub const MIN: Self = Self(core::num::NonZeroU16::new(1));
    /// The benefit of a pattern that can never match; this is also the [Default] value
    pub const NONE: Self = Self(None);

    /// Construct a [PatternBenefit] from a raw [u16].
    ///
    /// `u16::MAX` is reserved to mean "impossible to match" ([PatternBenefit::NONE]). Every other
    /// value in `0..=65534` maps, in order, onto the range [PatternBenefit::MIN] through
    /// [PatternBenefit::MAX].
    pub fn new(benefit: u16) -> Self {
        // Offset by one so that `0` still has a non-zero representation. The single value that
        // cannot be offset, `u16::MAX`, makes `checked_add` yield `None`, i.e. `NONE`.
        Self(benefit.checked_add(1).and_then(core::num::NonZeroU16::new))
    }

    /// Returns true if this benefit denotes a pattern which can never match
    #[inline]
    pub fn is_impossible_to_match(&self) -> bool {
        self.0.is_none()
    }
}

impl PartialOrd for PatternBenefit {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for PatternBenefit {
    /// Orders the most beneficial values first, and values which can never match last.
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        // `false < true`, so every matchable benefit sorts ahead of every unmatchable one
        let by_matchability = self.0.is_none().cmp(&other.0.is_none());
        // Among matchable benefits, comparing `other` against `self` puts the larger value first
        by_matchability.then_with(|| other.0.cmp(&self.0))
    }
}
82
83pub trait Pattern {
84    fn info(&self) -> &PatternInfo;
85    /// A name used when printing diagnostics related to this pattern
86    #[inline(always)]
87    fn name(&self) -> &'static str {
88        self.info().name
89    }
90    /// The kind of value used to select candidate root operations for this pattern.
91    #[inline(always)]
92    fn kind(&self) -> &PatternKind {
93        &self.info().kind
94    }
95    /// Returns the benefit - the inverse of "cost" - of matching this pattern.
96    ///
97    /// The benefit of a [Pattern] is always static - rewrites that may have dynamic benefit can be
98    /// instantiated multiple times (different instances), for each benefit that they may return,
99    /// and be guarded by different match condition predicates.
100    #[inline(always)]
101    fn benefit(&self) -> &PatternBenefit {
102        &self.info().benefit
103    }
104    /// Returns true if this pattern is known to result in recursive application, i.e. this pattern
105    /// may generate IR that also matches this pattern, but is known to bound the recursion. This
106    /// signals to the rewrite driver that it is safe to apply this pattern recursively to the
107    /// generated IR.
108    #[inline(always)]
109    fn has_bounded_rewrite_recursion(&self) -> bool {
110        self.info().has_bounded_recursion
111    }
112    /// Return a list of operations that may be generated when rewriting an operation instance
113    /// with this pattern.
114    #[inline(always)]
115    fn generated_ops(&self) -> &[OperationName] {
116        &self.info().generated_ops
117    }
118    /// Return the root operation that this pattern matches.
119    ///
120    /// Patterns that can match multiple root types return `None`
121    #[inline(always)]
122    fn get_root_operation(&self) -> Option<OperationName> {
123        self.info().root_operation()
124    }
125    /// Return the trait id used to match the root operation of this pattern.
126    ///
127    /// If the pattern does not use a trait id for deciding the root match, this returns `None`
128    #[inline(always)]
129    fn get_root_trait(&self) -> Option<TypeId> {
130        self.info().get_root_trait()
131    }
132}
133
134/// [PatternBase] describes all of the data related to a pattern, but does not express any actual
135/// pattern logic, i.e. it is solely used for metadata about a pattern.
136pub struct PatternInfo {
137    #[allow(unused)]
138    context: Rc<Context>,
139    name: &'static str,
140    kind: PatternKind,
141    #[allow(unused)]
142    labels: SmallVec<[interner::Symbol; 1]>,
143    benefit: PatternBenefit,
144    has_bounded_recursion: bool,
145    generated_ops: SmallVec<[OperationName; 0]>,
146}
147
148impl PatternInfo {
149    /// Create a new [Pattern] from its component parts.
150    pub fn new(
151        context: Rc<Context>,
152        name: &'static str,
153        kind: PatternKind,
154        benefit: PatternBenefit,
155    ) -> Self {
156        Self {
157            context,
158            name,
159            kind,
160            labels: SmallVec::default(),
161            benefit,
162            has_bounded_recursion: false,
163            generated_ops: SmallVec::default(),
164        }
165    }
166
167    /// Set whether or not this pattern has bounded rewrite recursion
168    #[inline(always)]
169    pub fn with_bounded_rewrite_recursion(&mut self, yes: bool) -> &mut Self {
170        self.has_bounded_recursion = yes;
171        self
172    }
173
174    /// Return the root operation that this pattern matches.
175    ///
176    /// Patterns that can match multiple root types return `None`
177    pub fn root_operation(&self) -> Option<OperationName> {
178        match self.kind {
179            PatternKind::Operation(ref name) => Some(name.clone()),
180            _ => None,
181        }
182    }
183
184    /// Return the trait id used to match the root operation of this pattern.
185    ///
186    /// If the pattern does not use a trait id for deciding the root match, this returns `None`
187    pub fn root_trait(&self) -> Option<TypeId> {
188        match self.kind {
189            PatternKind::Trait(type_id) => Some(type_id),
190            _ => None,
191        }
192    }
193}
194
195impl Pattern for PatternInfo {
196    #[inline(always)]
197    fn info(&self) -> &PatternInfo {
198        self
199    }
200}
201
202/// A [RewritePattern] represents two things:
203///
204/// * A pattern which matches some IR that we're interested in, typically to replace with something
205///   else.
206/// * A rewrite which replaces IR that maches the pattern, with new IR, i.e. a DAG-to-DAG
207///   replacement
208pub trait RewritePattern: Pattern {
209    /// Attempt to match this pattern against the IR rooted at the specified operation, and rewrite
210    /// it if the match is successful.
211    ///
212    /// If applied, this rewrites the IR rooted at the matched operation, using the provided
213    /// [Rewriter] to generate new blocks and/or operations, or apply any modifications.
214    ///
215    /// If an unexpected error is encountered, i.e. an internal compiler error, it is emitted
216    /// through the normal diagnostic system, and the IR is left in a valid state.
217    fn match_and_rewrite(
218        &self,
219        op: OperationRef,
220        rewriter: &mut dyn Rewriter,
221    ) -> Result<bool, Report>;
222}
223
#[cfg(test)]
mod tests {
    use alloc::{rc::Rc, string::ToString};

    use pretty_assertions::{assert_eq, assert_str_eq};

    use super::*;
    use crate::{
        dialects::{builtin::*, test::*},
        patterns::*,
        *,
    };

    /// In Miden, `n << 1` is vastly inferior to `n * 2` in cost, so reverse it
    ///
    /// NOTE: These two ops have slightly different semantics, a real implementation would have
    /// to handle the edge cases.
    struct ConvertShiftLeftBy1ToMultiply {
        info: PatternInfo,
    }
    impl ConvertShiftLeftBy1ToMultiply {
        pub fn new(context: Rc<Context>) -> Self {
            let dialect = context.get_or_register_dialect::<TestDialect>();
            let op_name = dialect.expect_registered_name::<Shl>();
            let mut info = PatternInfo::new(
                context,
                "convert-shl1-to-mul2",
                PatternKind::Operation(op_name),
                PatternBenefit::new(1),
            );
            // The rewrite only ever produces a constant and a `mul`, never another `shl`
            info.with_bounded_rewrite_recursion(true);
            Self { info }
        }
    }
    impl Pattern for ConvertShiftLeftBy1ToMultiply {
        fn info(&self) -> &PatternInfo {
            &self.info
        }
    }
    impl RewritePattern for ConvertShiftLeftBy1ToMultiply {
        fn match_and_rewrite(
            &self,
            op: OperationRef,
            rewriter: &mut dyn Rewriter,
        ) -> Result<bool, Report> {
            use crate::matchers::{self, match_chain, match_op, MatchWith, Matcher};

            // Succeeds only when the `shift` operand folds to an immediate equal to 1
            let binder = MatchWith(|op: &UnsafeIntrusiveEntityRef<Shl>| {
                log::trace!(
                    "found matching 'hir.shl' operation, checking if `shift` operand is foldable"
                );
                let op = op.borrow();
                let shift = op.shift().as_operand_ref();
                let matched = matchers::foldable_operand_of::<Immediate>().matches(&shift);
                matched.and_then(|imm| {
                    log::trace!("`shift` operand is an immediate: {imm}");
                    let imm = imm.as_u64();
                    if imm.is_none() {
                        log::trace!("`shift` operand is not a valid u64 value");
                    }
                    (imm == Some(1)).then_some(())
                })
            });
            log::trace!("attempting to match '{}'", self.name());
            let matched = match_chain(match_op::<Shl>(), binder).matches(&op.borrow()).is_some();
            log::trace!("'{}' matched: {matched}", self.name());

            if !matched {
                return Ok(false);
            }

            log::trace!("found match, rewriting '{}'", op.borrow().name());
            let (span, lhs) = {
                let shl = op.borrow();
                let shl = shl
                    .downcast_ref::<Shl>()
                    .expect("operation was matched by `match_op::<Shl>`, so it must be a Shl");
                let span = shl.span();
                let lhs = shl.lhs().as_value_ref();
                (span, lhs)
            };
            // Materialize the multiplier `2` that replaces the shift amount of `1`
            let constant_builder = rewriter.create::<Constant, _>(span);
            let constant: UnsafeIntrusiveEntityRef<Constant> =
                constant_builder(Immediate::U32(2)).unwrap();
            let two = constant.borrow().result().as_value_ref();
            let mul_builder = rewriter.create::<Mul, _>(span);
            let mul = mul_builder(lhs, two, Overflow::Wrapping).unwrap();
            let mul = mul.as_operation_ref();
            log::trace!("replacing shl with mul");
            rewriter.replace_op(op, mul);

            Ok(true)
        }
    }

    #[test]
    fn rewrite_pattern_api_test() {
        // `init` panics if a global logger is already installed (e.g. by another test running in
        // the same process), so ignore that case instead of failing the test.
        let _ = env_logger::Builder::from_env("MIDENC_TRACE").try_init();

        let context = Rc::new(Context::default());
        let pattern = ConvertShiftLeftBy1ToMultiply::new(Rc::clone(&context));

        let mut builder = OpBuilder::new(Rc::clone(&context));
        let function = {
            let builder = builder.create::<Function, (_, _)>(SourceSpan::default());
            let name = Ident::new("test".into(), SourceSpan::default());
            let signature = Signature::new([AbiParam::new(Type::U32)], [AbiParam::new(Type::U32)]);
            builder(name, signature).unwrap()
        };

        // Define function body
        {
            let mut builder = FunctionBuilder::new(function, &mut builder);
            let shift = builder.u32(1, SourceSpan::default()).unwrap();
            let block = builder.current_block();
            let lhs = block.borrow().arguments()[0].upcast();
            let result = builder.shl(lhs, shift, SourceSpan::default()).unwrap();
            builder.ret(Some(result), SourceSpan::default()).unwrap();
        }

        // Construct pattern set
        let mut rewrites = RewritePatternSet::new(builder.context_rc());
        rewrites.push(pattern);
        let rewrites = Rc::new(FrozenRewritePatternSet::new(rewrites));

        // Execute pattern driver
        let mut config = GreedyRewriteConfig::default();
        config.with_region_simplification_level(RegionSimplificationLevel::None);
        let result =
            apply_patterns_and_fold_greedily(function.as_operation_ref(), rewrites, config);

        // The rewrite should converge and modify the IR
        assert_eq!(result, Ok(true));

        // Confirm that the expected rewrite occurred
        let func = function.borrow();
        let output = func.as_operation().to_string();
        let expected = "\
public builtin.function @test(v0: u32) -> u32 {
^block0(v0: u32):
    v3 = test.constant 2 : u32;
    v4 = test.mul v0, v3 : u32 #[overflow = wrapping];
    builtin.ret v4;
};";
        assert_str_eq!(output.as_str(), expected);
    }
}