claw_ast/
expressions.rs

1use super::{merge, NameId, Span};
2use cranelift_entity::{entity_impl, PrimaryMap};
3use std::collections::HashMap;
4
5#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
6pub struct ExpressionId(u32);
7entity_impl!(ExpressionId, "expression");
8
9#[derive(Clone, Debug, Default)]
10pub struct ExpressionData {
11    expressions: PrimaryMap<ExpressionId, Expression>,
12    expression_spans: HashMap<ExpressionId, Span>,
13}
14
15impl ExpressionData {
16    pub fn alloc(&mut self, expression: Expression, span: Span) -> ExpressionId {
17        let id = self.expressions.push(expression);
18        self.expression_spans.insert(id, span);
19        id
20    }
21
22    pub fn get_exp(&self, id: ExpressionId) -> &Expression {
23        self.expressions.get(id).unwrap()
24    }
25
26    pub fn get_span(&self, id: ExpressionId) -> Span {
27        *self.expression_spans.get(&id).unwrap()
28    }
29
30    pub fn expressions(&self) -> &PrimaryMap<ExpressionId, Expression> {
31        &self.expressions
32    }
33
34    pub fn alloc_ident(&mut self, ident: NameId, span: Span) -> ExpressionId {
35        let expr = Expression::Identifier(Identifier { ident });
36        self.alloc(expr, span)
37    }
38
39    pub fn alloc_literal(&mut self, literal: Literal, span: Span) -> ExpressionId {
40        self.alloc(Expression::Literal(literal), span)
41    }
42
43    pub fn alloc_call(
44        &mut self,
45        ident: NameId,
46        args: Vec<ExpressionId>,
47        span: Span,
48    ) -> ExpressionId {
49        let expr = Expression::Call(Call { ident, args });
50        self.alloc(expr, span)
51    }
52
53    pub fn alloc_unary_op(&mut self, op: UnaryOp, inner: ExpressionId, span: Span) -> ExpressionId {
54        let expr = match op {
55            UnaryOp::Negate => Expression::Unary(UnaryExpression { op, inner }),
56        };
57        self.alloc(expr, span)
58    }
59
60    pub fn alloc_bin_op(
61        &mut self,
62        op: BinaryOp,
63        left: ExpressionId,
64        right: ExpressionId,
65    ) -> ExpressionId {
66        let span = merge(&self.get_span(left), &self.get_span(right));
67        self.alloc(
68            Expression::Binary(BinaryExpression { op, left, right }),
69            span,
70        )
71    }
72}
73
/// Equality that can only be decided with the help of an external context.
///
/// Used where plain `PartialEq` is insufficient, e.g. for values holding ids that must be
/// resolved through `context` (such as name ids) before they can be compared.
pub trait ContextEq<Context> {
    /// Returns `true` when `self` and `other` are equal as interpreted through `context`.
    fn context_eq(&self, other: &Self, context: &Context) -> bool;
}
77
78#[derive(Debug, PartialEq, Clone)]
79pub enum Expression {
80    Identifier(Identifier),
81    Enum(EnumLiteral),
82    Literal(Literal),
83    Call(Call),
84    Unary(UnaryExpression),
85    Binary(BinaryExpression),
86}
87
88impl ContextEq<super::Component> for ExpressionId {
89    fn context_eq(&self, other: &Self, context: &super::Component) -> bool {
90        let self_span = context.expr().get_span(*self);
91        let other_span = context.expr().get_span(*other);
92        if self_span != other_span {
93            dbg!(self_span, other_span);
94            return false;
95        }
96
97        let self_expr = context.expr().get_exp(*self);
98        let other_expr = context.expr().get_exp(*other);
99        if !self_expr.context_eq(other_expr, context) {
100            dbg!(self_expr, other_expr);
101            return false;
102        }
103        true
104    }
105}
106
107impl ContextEq<super::Component> for Expression {
108    fn context_eq(&self, other: &Self, context: &super::Component) -> bool {
109        match (self, other) {
110            (Expression::Identifier(left), Expression::Identifier(right)) => {
111                left.context_eq(right, context)
112            }
113            (Expression::Literal(left), Expression::Literal(right)) => {
114                left.context_eq(right, context)
115            }
116            (Expression::Call(left), Expression::Call(right)) => left.context_eq(right, context),
117            (Expression::Unary(left), Expression::Unary(right)) => left.context_eq(right, context),
118            (Expression::Binary(left), Expression::Binary(right)) => {
119                left.context_eq(right, context)
120            }
121            _ => false,
122        }
123    }
124}
125
126#[derive(Debug, PartialEq, Clone)]
127pub struct Identifier {
128    pub ident: NameId,
129}
130
131impl ContextEq<super::Component> for Identifier {
132    fn context_eq(&self, other: &Self, context: &super::Component) -> bool {
133        context.get_name(self.ident) == context.get_name(other.ident)
134    }
135}
136
137#[derive(Debug, PartialEq, Clone)]
138pub struct EnumLiteral {
139    pub enum_name: NameId,
140    pub case_name: NameId,
141}
142
143impl ContextEq<super::Component> for EnumLiteral {
144    fn context_eq(&self, other: &Self, context: &super::Component) -> bool {
145        context.get_name(self.enum_name) == context.get_name(other.enum_name)
146            && context.get_name(self.case_name) == context.get_name(other.case_name)
147    }
148}
149
/// A literal value written directly in the source.
#[derive(Clone, Debug, PartialEq)]
pub enum Literal {
    /// Unsigned integer literal.
    Integer(u64),
    /// Floating-point literal.
    Float(f64),
    /// String literal.
    String(String),
}
156
157impl ContextEq<super::Component> for Literal {
158    fn context_eq(&self, other: &Self, _context: &super::Component) -> bool {
159        self == other
160    }
161}
162
163#[derive(Debug, PartialEq, Clone)]
164pub struct Call {
165    pub ident: NameId,
166    pub args: Vec<ExpressionId>,
167}
168
169impl ContextEq<super::Component> for Call {
170    fn context_eq(&self, other: &Self, context: &super::Component) -> bool {
171        let ident_eq = self.ident.context_eq(&other.ident, context);
172        let args_eq = self
173            .args
174            .iter()
175            .zip(other.args.iter())
176            .map(|(l, r)| l.context_eq(r, context))
177            .all(|v| v);
178
179        ident_eq && args_eq
180    }
181}
182
183// Unary Operators
184
/// Operators that take a single operand.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum UnaryOp {
    /// Negation of the operand.
    Negate,
}
189
190#[derive(Debug, PartialEq, Clone)]
191pub struct UnaryExpression {
192    pub op: UnaryOp,
193    pub inner: ExpressionId,
194}
195
196impl ContextEq<super::Component> for UnaryExpression {
197    fn context_eq(&self, other: &Self, context: &super::Component) -> bool {
198        let self_inner = context.expr().get_exp(self.inner);
199        let other_inner = context.expr().get_exp(other.inner);
200        self_inner.context_eq(other_inner, context)
201    }
202}
203
204// Binary Operators
205
/// Operators that combine two operands.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum BinaryOp {
    // -- arithmetic --
    Multiply,
    Divide,
    Modulo,
    Add,
    Subtract,

    // -- shifts --
    BitShiftL,
    BitShiftR,
    ArithShiftR,

    // -- comparisons (the set `BinaryExpression::is_relation` recognises) --
    LessThan,
    LessThanEqual,
    GreaterThan,
    GreaterThanEqual,
    Equals,
    NotEquals,

    // -- bitwise --
    BitOr,
    BitXor,
    BitAnd,

    // -- logical --
    LogicalOr,
    LogicalAnd,
}
237
238#[derive(Debug, PartialEq, Clone)]
239pub struct BinaryExpression {
240    pub op: BinaryOp,
241    pub left: ExpressionId,
242    pub right: ExpressionId,
243}
244
245impl ContextEq<super::Component> for BinaryExpression {
246    fn context_eq(&self, other: &Self, context: &super::Component) -> bool {
247        let self_left = context.expr().get_exp(self.left);
248        let other_left = context.expr().get_exp(other.left);
249        let left_eq = self_left.context_eq(other_left, context);
250
251        let self_right = context.expr().get_exp(self.right);
252        let other_right = context.expr().get_exp(other.right);
253        let right_eq = self_right.context_eq(other_right, context);
254
255        left_eq && right_eq
256    }
257}
258
259impl BinaryExpression {
260    pub fn is_relation(&self) -> bool {
261        use BinaryOp as BE;
262        matches!(
263            self.op,
264            BE::LessThan
265                | BE::LessThanEqual
266                | BE::GreaterThan
267                | BE::GreaterThanEqual
268                | BE::Equals
269                | BE::NotEquals
270        )
271    }
272}