midenc_hir/patterns/
pattern_set.rs

1use alloc::{boxed::Box, collections::BTreeMap, rc::Rc, vec, vec::Vec};
2
3use smallvec::SmallVec;
4
5use super::*;
6use crate::{Context, OperationName};
7
8pub struct RewritePatternSet {
9    context: Rc<Context>,
10    patterns: Vec<Box<dyn RewritePattern>>,
11}
12impl RewritePatternSet {
13    pub fn new(context: Rc<Context>) -> Self {
14        Self {
15            context,
16            patterns: vec![],
17        }
18    }
19
20    pub fn from_iter<P>(context: Rc<Context>, patterns: P) -> Self
21    where
22        P: IntoIterator<Item = Box<dyn RewritePattern>>,
23    {
24        Self {
25            context,
26            patterns: patterns.into_iter().collect(),
27        }
28    }
29
30    #[inline]
31    pub fn context(&self) -> Rc<Context> {
32        Rc::clone(&self.context)
33    }
34
35    #[inline]
36    pub fn patterns(&self) -> &[Box<dyn RewritePattern>] {
37        &self.patterns
38    }
39
40    pub fn push(&mut self, pattern: impl RewritePattern + 'static) {
41        self.patterns.push(Box::new(pattern));
42    }
43
44    pub fn extend<P>(&mut self, patterns: P)
45    where
46        P: IntoIterator<Item = Box<dyn RewritePattern>>,
47    {
48        self.patterns.extend(patterns);
49    }
50}
51
52pub struct FrozenRewritePatternSet {
53    context: Rc<Context>,
54    patterns: Vec<Rc<dyn RewritePattern>>,
55    op_specific_patterns: BTreeMap<OperationName, SmallVec<[Rc<dyn RewritePattern>; 2]>>,
56    any_op_patterns: SmallVec<[Rc<dyn RewritePattern>; 1]>,
57}
58impl FrozenRewritePatternSet {
59    pub fn new(patterns: RewritePatternSet) -> Self {
60        let RewritePatternSet { context, patterns } = patterns;
61        let mut this = Self {
62            context,
63            patterns: Default::default(),
64            op_specific_patterns: Default::default(),
65            any_op_patterns: Default::default(),
66        };
67
68        for pattern in patterns {
69            let pattern = Rc::<dyn RewritePattern>::from(pattern);
70            match pattern.kind() {
71                PatternKind::Operation(name) => {
72                    this.op_specific_patterns
73                        .entry(name.clone())
74                        .or_default()
75                        .push(Rc::clone(&pattern));
76                    this.patterns.push(pattern);
77                }
78                PatternKind::Trait(ref trait_id) => {
79                    for dialect in this.context.registered_dialects().values() {
80                        for op in dialect.registered_ops().iter() {
81                            if op.implements_trait_id(trait_id) {
82                                this.op_specific_patterns
83                                    .entry(op.clone())
84                                    .or_default()
85                                    .push(Rc::clone(&pattern));
86                            }
87                        }
88                    }
89                    this.patterns.push(pattern);
90                }
91                PatternKind::Any => {
92                    this.any_op_patterns.push(Rc::clone(&pattern));
93                    this.patterns.push(pattern);
94                }
95            }
96        }
97
98        this
99    }
100
101    #[inline]
102    pub fn patterns(&self) -> &[Rc<dyn RewritePattern>] {
103        &self.patterns
104    }
105
106    #[inline]
107    pub fn op_specific_patterns(
108        &self,
109    ) -> &BTreeMap<OperationName, SmallVec<[Rc<dyn RewritePattern>; 2]>> {
110        &self.op_specific_patterns
111    }
112
113    #[inline]
114    pub fn any_op_patterns(&self) -> &[Rc<dyn RewritePattern>] {
115        &self.any_op_patterns
116    }
117}