1use super::{merge, NameId, Span};
2use cranelift_entity::{entity_impl, PrimaryMap};
3use std::collections::HashMap;
4
5#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
6pub struct ExpressionId(u32);
7entity_impl!(ExpressionId, "expression");
8
9#[derive(Clone, Debug, Default)]
10pub struct ExpressionData {
11 expressions: PrimaryMap<ExpressionId, Expression>,
12 expression_spans: HashMap<ExpressionId, Span>,
13}
14
15impl ExpressionData {
16 pub fn alloc(&mut self, expression: Expression, span: Span) -> ExpressionId {
17 let id = self.expressions.push(expression);
18 self.expression_spans.insert(id, span);
19 id
20 }
21
22 pub fn get_exp(&self, id: ExpressionId) -> &Expression {
23 self.expressions.get(id).unwrap()
24 }
25
26 pub fn get_span(&self, id: ExpressionId) -> Span {
27 *self.expression_spans.get(&id).unwrap()
28 }
29
30 pub fn expressions(&self) -> &PrimaryMap<ExpressionId, Expression> {
31 &self.expressions
32 }
33
34 pub fn alloc_ident(&mut self, ident: NameId, span: Span) -> ExpressionId {
35 let expr = Expression::Identifier(Identifier { ident });
36 self.alloc(expr, span)
37 }
38
39 pub fn alloc_literal(&mut self, literal: Literal, span: Span) -> ExpressionId {
40 self.alloc(Expression::Literal(literal), span)
41 }
42
43 pub fn alloc_call(
44 &mut self,
45 ident: NameId,
46 args: Vec<ExpressionId>,
47 span: Span,
48 ) -> ExpressionId {
49 let expr = Expression::Call(Call { ident, args });
50 self.alloc(expr, span)
51 }
52
53 pub fn alloc_unary_op(&mut self, op: UnaryOp, inner: ExpressionId, span: Span) -> ExpressionId {
54 let expr = match op {
55 UnaryOp::Negate => Expression::Unary(UnaryExpression { op, inner }),
56 };
57 self.alloc(expr, span)
58 }
59
60 pub fn alloc_bin_op(
61 &mut self,
62 op: BinaryOp,
63 left: ExpressionId,
64 right: ExpressionId,
65 ) -> ExpressionId {
66 let span = merge(&self.get_span(left), &self.get_span(right));
67 self.alloc(
68 Expression::Binary(BinaryExpression { op, left, right }),
69 span,
70 )
71 }
72}
73
/// Equality that needs an external context to decide.
///
/// For example, two values holding different name ids may still be equal when both ids
/// resolve to the same name in the context.
pub trait ContextEq<Context> {
    /// Returns `true` when `self` and `other` are equal as judged through `context`.
    fn context_eq(&self, other: &Self, context: &Context) -> bool;
}
77
78#[derive(Debug, PartialEq, Clone)]
79pub enum Expression {
80 Identifier(Identifier),
81 Enum(EnumLiteral),
82 Literal(Literal),
83 Call(Call),
84 Unary(UnaryExpression),
85 Binary(BinaryExpression),
86}
87
88impl ContextEq<super::Component> for ExpressionId {
89 fn context_eq(&self, other: &Self, context: &super::Component) -> bool {
90 let self_span = context.expr().get_span(*self);
91 let other_span = context.expr().get_span(*other);
92 if self_span != other_span {
93 dbg!(self_span, other_span);
94 return false;
95 }
96
97 let self_expr = context.expr().get_exp(*self);
98 let other_expr = context.expr().get_exp(*other);
99 if !self_expr.context_eq(other_expr, context) {
100 dbg!(self_expr, other_expr);
101 return false;
102 }
103 true
104 }
105}
106
107impl ContextEq<super::Component> for Expression {
108 fn context_eq(&self, other: &Self, context: &super::Component) -> bool {
109 match (self, other) {
110 (Expression::Identifier(left), Expression::Identifier(right)) => {
111 left.context_eq(right, context)
112 }
113 (Expression::Literal(left), Expression::Literal(right)) => {
114 left.context_eq(right, context)
115 }
116 (Expression::Call(left), Expression::Call(right)) => left.context_eq(right, context),
117 (Expression::Unary(left), Expression::Unary(right)) => left.context_eq(right, context),
118 (Expression::Binary(left), Expression::Binary(right)) => {
119 left.context_eq(right, context)
120 }
121 _ => false,
122 }
123 }
124}
125
126#[derive(Debug, PartialEq, Clone)]
127pub struct Identifier {
128 pub ident: NameId,
129}
130
131impl ContextEq<super::Component> for Identifier {
132 fn context_eq(&self, other: &Self, context: &super::Component) -> bool {
133 context.get_name(self.ident) == context.get_name(other.ident)
134 }
135}
136
137#[derive(Debug, PartialEq, Clone)]
138pub struct EnumLiteral {
139 pub enum_name: NameId,
140 pub case_name: NameId,
141}
142
143impl ContextEq<super::Component> for EnumLiteral {
144 fn context_eq(&self, other: &Self, context: &super::Component) -> bool {
145 context.get_name(self.enum_name) == context.get_name(other.enum_name)
146 && context.get_name(self.case_name) == context.get_name(other.case_name)
147 }
148}
149
/// A literal constant value.
#[derive(Clone, Debug, PartialEq)]
pub enum Literal {
    /// Unsigned integer literal.
    Integer(u64),
    /// Floating-point literal.
    Float(f64),
    /// String literal.
    String(String),
}
156
157impl ContextEq<super::Component> for Literal {
158 fn context_eq(&self, other: &Self, _context: &super::Component) -> bool {
159 self == other
160 }
161}
162
163#[derive(Debug, PartialEq, Clone)]
164pub struct Call {
165 pub ident: NameId,
166 pub args: Vec<ExpressionId>,
167}
168
169impl ContextEq<super::Component> for Call {
170 fn context_eq(&self, other: &Self, context: &super::Component) -> bool {
171 let ident_eq = self.ident.context_eq(&other.ident, context);
172 let args_eq = self
173 .args
174 .iter()
175 .zip(other.args.iter())
176 .map(|(l, r)| l.context_eq(r, context))
177 .all(|v| v);
178
179 ident_eq && args_eq
180 }
181}
182
/// Operators that take a single operand.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum UnaryOp {
    /// Negation of the operand.
    Negate,
}
189
190#[derive(Debug, PartialEq, Clone)]
191pub struct UnaryExpression {
192 pub op: UnaryOp,
193 pub inner: ExpressionId,
194}
195
196impl ContextEq<super::Component> for UnaryExpression {
197 fn context_eq(&self, other: &Self, context: &super::Component) -> bool {
198 let self_inner = context.expr().get_exp(self.inner);
199 let other_inner = context.expr().get_exp(other.inner);
200 self_inner.context_eq(other_inner, context)
201 }
202}
203
/// Operators that take two operands.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum BinaryOp {
    // Arithmetic.
    Multiply,
    Divide,
    Modulo,
    Add,
    Subtract,

    // Shifts; `ArithShiftR` is distinguished from the plain `BitShiftR`.
    BitShiftL,
    BitShiftR,
    ArithShiftR,

    // Comparisons (these are what `BinaryExpression::is_relation` reports).
    LessThan,
    LessThanEqual,
    GreaterThan,
    GreaterThanEqual,
    Equals,
    NotEquals,

    // Bitwise.
    BitOr,
    BitXor,
    BitAnd,

    // Logical.
    LogicalOr,
    LogicalAnd,
}
237
238#[derive(Debug, PartialEq, Clone)]
239pub struct BinaryExpression {
240 pub op: BinaryOp,
241 pub left: ExpressionId,
242 pub right: ExpressionId,
243}
244
245impl ContextEq<super::Component> for BinaryExpression {
246 fn context_eq(&self, other: &Self, context: &super::Component) -> bool {
247 let self_left = context.expr().get_exp(self.left);
248 let other_left = context.expr().get_exp(other.left);
249 let left_eq = self_left.context_eq(other_left, context);
250
251 let self_right = context.expr().get_exp(self.right);
252 let other_right = context.expr().get_exp(other.right);
253 let right_eq = self_right.context_eq(other_right, context);
254
255 left_eq && right_eq
256 }
257}
258
259impl BinaryExpression {
260 pub fn is_relation(&self) -> bool {
261 use BinaryOp as BE;
262 matches!(
263 self.op,
264 BE::LessThan
265 | BE::LessThanEqual
266 | BE::GreaterThan
267 | BE::GreaterThanEqual
268 | BE::Equals
269 | BE::NotEquals
270 )
271 }
272}