midenc_hir/patterns/
pattern_set.rs1use alloc::{boxed::Box, collections::BTreeMap, rc::Rc, vec, vec::Vec};
2
3use smallvec::SmallVec;
4
5use super::*;
6use crate::{Context, OperationName};
7
8pub struct RewritePatternSet {
9 context: Rc<Context>,
10 patterns: Vec<Box<dyn RewritePattern>>,
11}
12impl RewritePatternSet {
13 pub fn new(context: Rc<Context>) -> Self {
14 Self {
15 context,
16 patterns: vec![],
17 }
18 }
19
20 pub fn from_iter<P>(context: Rc<Context>, patterns: P) -> Self
21 where
22 P: IntoIterator<Item = Box<dyn RewritePattern>>,
23 {
24 Self {
25 context,
26 patterns: patterns.into_iter().collect(),
27 }
28 }
29
30 #[inline]
31 pub fn context(&self) -> Rc<Context> {
32 Rc::clone(&self.context)
33 }
34
35 #[inline]
36 pub fn patterns(&self) -> &[Box<dyn RewritePattern>] {
37 &self.patterns
38 }
39
40 pub fn push(&mut self, pattern: impl RewritePattern + 'static) {
41 self.patterns.push(Box::new(pattern));
42 }
43
44 pub fn extend<P>(&mut self, patterns: P)
45 where
46 P: IntoIterator<Item = Box<dyn RewritePattern>>,
47 {
48 self.patterns.extend(patterns);
49 }
50}
51
52pub struct FrozenRewritePatternSet {
53 context: Rc<Context>,
54 patterns: Vec<Rc<dyn RewritePattern>>,
55 op_specific_patterns: BTreeMap<OperationName, SmallVec<[Rc<dyn RewritePattern>; 2]>>,
56 any_op_patterns: SmallVec<[Rc<dyn RewritePattern>; 1]>,
57}
58impl FrozenRewritePatternSet {
59 pub fn new(patterns: RewritePatternSet) -> Self {
60 let RewritePatternSet { context, patterns } = patterns;
61 let mut this = Self {
62 context,
63 patterns: Default::default(),
64 op_specific_patterns: Default::default(),
65 any_op_patterns: Default::default(),
66 };
67
68 for pattern in patterns {
69 let pattern = Rc::<dyn RewritePattern>::from(pattern);
70 match pattern.kind() {
71 PatternKind::Operation(name) => {
72 this.op_specific_patterns
73 .entry(name.clone())
74 .or_default()
75 .push(Rc::clone(&pattern));
76 this.patterns.push(pattern);
77 }
78 PatternKind::Trait(ref trait_id) => {
79 for dialect in this.context.registered_dialects().values() {
80 for op in dialect.registered_ops().iter() {
81 if op.implements_trait_id(trait_id) {
82 this.op_specific_patterns
83 .entry(op.clone())
84 .or_default()
85 .push(Rc::clone(&pattern));
86 }
87 }
88 }
89 this.patterns.push(pattern);
90 }
91 PatternKind::Any => {
92 this.any_op_patterns.push(Rc::clone(&pattern));
93 this.patterns.push(pattern);
94 }
95 }
96 }
97
98 this
99 }
100
101 #[inline]
102 pub fn patterns(&self) -> &[Rc<dyn RewritePattern>] {
103 &self.patterns
104 }
105
106 #[inline]
107 pub fn op_specific_patterns(
108 &self,
109 ) -> &BTreeMap<OperationName, SmallVec<[Rc<dyn RewritePattern>; 2]>> {
110 &self.op_specific_patterns
111 }
112
113 #[inline]
114 pub fn any_op_patterns(&self) -> &[Rc<dyn RewritePattern>] {
115 &self.any_op_patterns
116 }
117}