kodept_ast/node/
expression.rs

1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3
4use BinaryExpressionKind::*;
5use kodept_core::structure::{rlt};
6use kodept_core::structure::rlt::new_types::{BinaryOperationSymbol, UnaryOperationSymbol};
7use kodept_core::structure::span::CodeHolder;
8use UnaryExpressionKind::*;
9
10use crate::{BlockLevel, CodeFlow, Lit, node, node_sub_enum, Param, Term};
11use crate::graph::{Identity, SyntaxTreeBuilder};
12use crate::graph::NodeId;
13use crate::graph::tags::*;
14use crate::traits::{Linker, PopulateTree};
15
// AST operation node: one alternative per operation-like node declared in this
// module (`Appl`, `Acc`, `UnExpr`, `BinExpr`, `Exprs`) plus an `Expr` entry for
// `Expression`.
// NOTE(review): the `forward` keyword is `node_sub_enum!` syntax; it presumably
// makes the alternatives of the nested sub-enum reachable through this one —
// confirm against the macro definition.
node_sub_enum! {
    #[derive(Debug, PartialEq)]
    #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
    pub enum Operation {
        Appl(Appl),
        Acc(Acc),
        Unary(UnExpr),
        Binary(BinExpr),
        Block(Exprs),
        Expr(forward Expression),
    }
}
28
// AST expression node: either a `Lambda` declared in this module or one of the
// sub-enums `CodeFlow`, `Lit` and `Term` imported from the crate root.
// NOTE(review): `forward` is `node_sub_enum!` syntax; presumably it delegates to
// the nested sub-enum's own alternatives — confirm against the macro definition.
node_sub_enum! {
    #[derive(Debug, PartialEq)]
    #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
    pub enum Expression {
        Lambda(Lambda),
        CodeFlow(forward CodeFlow),
        Lit(forward Lit),
        Term(forward Term)
    }
}
39
// Application node: the applied operation (`expr`, tagged PRIMARY) together with
// its argument operations (`params`, tagged SECONDARY). Both children have the
// same type, so the tags are what tells them apart.
// The `;` right after `{` is part of the `node!` input syntax. Judging by
// `UnExpr`/`BinExpr`, plain data fields go before it and child links after it
// (assumption — confirm against the macro definition).
node! {
    #[derive(Debug, PartialEq)]
    #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
    pub struct Appl {;
        pub expr: Identity<Operation> as PRIMARY,
        pub params: Vec<Operation> as SECONDARY,
    }
}
48
// Access node: links a `left` operation (tagged LEFT) with a `right` operation
// (tagged RIGHT). It is created by `build_access` for `rlt::Operation::Access`.
// The node declares no fields before the `;` separator of the `node!` syntax.
node! {
    #[derive(Debug, PartialEq)]
    #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
    pub struct Acc {;
        pub left: Identity<Operation> as LEFT,
        pub right: Identity<Operation> as RIGHT,
    }
}
57
// Unary expression node: `kind` records which unary operator was used and `expr`
// is the single operand. It is created by `build_unary` for
// `rlt::Operation::TopUnary`.
// `kind` sits before the `;` separator of the `node!` syntax and is passed to
// `UnExpr::uninit(..)`; `expr` sits after it and carries no explicit tag.
node! {
    #[derive(Debug, PartialEq)]
    #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
    pub struct UnExpr {
        pub kind: UnaryExpressionKind,;
        pub expr: Identity<Operation>,
    }
}
66
// Binary expression node: `kind` records the operator, `left` (tagged LEFT) and
// `right` (tagged RIGHT) are the operands. It is created by `build_binary` for
// `rlt::Operation::Binary`.
// `kind` sits before the `;` separator of the `node!` syntax and is passed to
// `BinExpr::uninit(..)`; the operands sit after it.
node! {
    #[derive(Debug, PartialEq)]
    #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
    pub struct BinExpr {
        pub kind: BinaryExpressionKind,;
        pub left: Identity<Operation> as LEFT,
        pub right: Identity<Operation> as RIGHT,
    }
}
76
// Lambda node: the bound parameters (`binds`, tagged PRIMARY) and the body
// operation (`expr`, tagged SECONDARY). It is created when converting
// `rlt::Expression::Lambda`.
node! {
    #[derive(Debug, PartialEq)]
    #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
    pub struct Lambda {;
        // NOTE: without explicit tags `binds` somehow ends up wrapped in an
        // operation, which makes `expr` fail => the tags below are required.
        pub binds: Vec<Param> as PRIMARY,
        pub expr: Identity<Operation> as SECONDARY,
    }
}
86
// Expression block node: holds its block-level children in `items`. It is
// created when converting `rlt::ExpressionBlock`.
// The node declares no fields before the `;` separator of the `node!` syntax,
// and `items` carries no explicit tag.
node! {
    #[derive(Debug, PartialEq)]
    #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
    pub struct Exprs {;
        pub items: Vec<BlockLevel>,
    }
}
94
/// Kind of a unary operation stored in an `UnExpr` node.
///
/// The variants mirror `UnaryOperationSymbol` one-to-one (see `build_unary`).
/// Equality on this field-less enum is total, so `Eq` is derived alongside
/// `PartialEq`.
#[derive(Debug, PartialEq, Eq, Clone)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub enum UnaryExpressionKind {
    /// Produced for `UnaryOperationSymbol::Neg`.
    Neg,
    /// Produced for `UnaryOperationSymbol::Not`.
    Not,
    /// Produced for `UnaryOperationSymbol::Inv`.
    Inv,
    /// Produced for `UnaryOperationSymbol::Plus`.
    Plus,
}
103
/// Ordering comparison carried by `BinaryExpressionKind::Cmp`.
///
/// Equality on this field-less enum is total, so `Eq` is derived alongside
/// `PartialEq`.
#[derive(Debug, PartialEq, Eq, Clone)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub enum ComparisonKind {
    /// The `<` operator.
    Less,
    /// The `<=` operator.
    LessEq,
    /// The `>` operator.
    Greater,
    /// The `>=` operator.
    GreaterEq,
}
112
/// Equality test carried by `BinaryExpressionKind::Eq`.
///
/// Equality on this field-less enum is total, so `Eq` is derived alongside
/// `PartialEq`.
#[derive(Debug, PartialEq, Eq, Clone)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub enum EqKind {
    /// The `==` operator.
    Eq,
    /// The `!=` operator.
    NEq,
}
119
/// Logical connective carried by `BinaryExpressionKind::Logic`.
///
/// Equality on this field-less enum is total, so `Eq` is derived alongside
/// `PartialEq`.
#[derive(Debug, PartialEq, Eq, Clone)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub enum LogicKind {
    /// Disjunction, the `||` operator.
    Disj,
    /// Conjunction, the `&&` operator.
    Conj,
}
126
/// Bitwise operation carried by `BinaryExpressionKind::Bit`.
///
/// Equality on this field-less enum is total, so `Eq` is derived alongside
/// `PartialEq`.
#[derive(Debug, PartialEq, Eq, Clone)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub enum BitKind {
    /// The `|` operator.
    Or,
    /// The `&` operator.
    And,
    /// The `^` operator.
    Xor,
}
134
/// Arithmetic operation carried by `BinaryExpressionKind::Math`.
///
/// Equality on this field-less enum is total, so `Eq` is derived alongside
/// `PartialEq`.
#[derive(Debug, PartialEq, Eq, Clone)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub enum MathKind {
    /// The `+` operator.
    Add,
    /// The `-` operator.
    Sub,
    /// The `*` operator.
    Mul,
    /// Produced for every `BinaryOperationSymbol::Pow` symbol.
    Pow,
    /// The `/` operator.
    Div,
    /// The `%` operator.
    Mod,
}
145
146#[derive(Debug, PartialEq, Clone)]
147#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
148pub enum BinaryExpressionKind {
149    Math(MathKind),
150    Cmp(ComparisonKind),
151    Eq(EqKind),
152    Bit(BitKind),
153    Logic(LogicKind),
154    ComplexComparison,
155    Assign
156}
157
158impl PopulateTree for rlt::ExpressionBlock {
159    type Output = Exprs;
160
161    fn convert(
162        &self,
163        builder: &mut SyntaxTreeBuilder,
164        context: &mut (impl Linker + CodeHolder),
165    ) -> NodeId<Self::Output> {
166        builder
167            .add_node(Exprs::uninit())
168            .with_children_from(self.expression.as_ref(), context)
169            .with_rlt(context, self)
170            .id()
171    }
172}
173
174impl PopulateTree for rlt::Operation {
175    type Output = Operation;
176
177    fn convert(
178        &self,
179        builder: &mut SyntaxTreeBuilder,
180        context: &mut (impl Linker + CodeHolder),
181    ) -> NodeId<Self::Output> {
182        match self {
183            rlt::Operation::Block(x) => x.convert(builder, context).cast(),
184            rlt::Operation::Access { left, right, .. } => {
185                build_access(self, builder, context, left, right)
186            }
187            rlt::Operation::TopUnary { operator, expr } => {
188                build_unary(self, builder, context, operator, expr)
189            }
190            rlt::Operation::Binary {
191                left,
192                operation,
193                right,
194            } => build_binary(self, builder, context, left, operation, right),
195            rlt::Operation::Application(x) => x.convert(builder, context).cast(),
196            rlt::Operation::Expression(x) => x.convert(builder, context).cast(),
197        }
198    }
199}
200
201fn build_binary(
202    node: &rlt::Operation,
203    builder: &mut SyntaxTreeBuilder,
204    context: &mut (impl Linker + CodeHolder + Sized),
205    left: &rlt::Operation,
206    operation: &BinaryOperationSymbol,
207    right: &rlt::Operation,
208) -> NodeId<Operation> {
209    let binding = context.get_chunk_located(operation);
210    let op_text = binding.as_ref();
211    
212    builder
213        .add_node(BinExpr::uninit(match (operation, op_text) {
214            (BinaryOperationSymbol::Pow(_), _) => Math(MathKind::Pow),
215            (BinaryOperationSymbol::Mul(_), "*") => Math(MathKind::Mul),
216            (BinaryOperationSymbol::Mul(_), "/") => Math(MathKind::Div),
217            (BinaryOperationSymbol::Mul(_), "%") => Math(MathKind::Mod),
218            (BinaryOperationSymbol::Add(_), "+") => Math(MathKind::Add),
219            (BinaryOperationSymbol::Add(_), "-") => Math(MathKind::Sub),
220            (BinaryOperationSymbol::ComplexComparison(_), _) => ComplexComparison,
221            (BinaryOperationSymbol::CompoundComparison(_), "<=") => Cmp(ComparisonKind::LessEq),
222            (BinaryOperationSymbol::CompoundComparison(_), ">=") => Cmp(ComparisonKind::GreaterEq),
223            (BinaryOperationSymbol::CompoundComparison(_), "!=") => Eq(EqKind::NEq),
224            (BinaryOperationSymbol::CompoundComparison(_), "==") => Eq(EqKind::Eq),
225            (BinaryOperationSymbol::Comparison(_), "<") => Cmp(ComparisonKind::Less),
226            (BinaryOperationSymbol::Comparison(_), ">") => Cmp(ComparisonKind::Greater),
227            (BinaryOperationSymbol::Bit(_), "|") => Bit(BitKind::Or),
228            (BinaryOperationSymbol::Bit(_), "&") => Bit(BitKind::And),
229            (BinaryOperationSymbol::Bit(_), "^") => Bit(BitKind::Xor),
230            (BinaryOperationSymbol::Logic(_), "||") => Logic(LogicKind::Disj),
231            (BinaryOperationSymbol::Logic(_), "&&") => Logic(LogicKind::Conj),
232            (BinaryOperationSymbol::Assign(_), "=") => Assign,
233            
234            (BinaryOperationSymbol::Mul(_), x) => panic!("Unknown mul operator found: {x}"),
235            (BinaryOperationSymbol::Add(_), x) => panic!("Unknown add operator found: {x}"),
236            (BinaryOperationSymbol::CompoundComparison(_), x) => panic!("Unknown cmp operator found: {x}"),
237            (BinaryOperationSymbol::Comparison(_), x) => panic!("Unknown cmp operator found: {x}"),
238            (BinaryOperationSymbol::Bit(_), x) => panic!("Unknown bit operator found: {x}"),
239            (BinaryOperationSymbol::Logic(_), x) => panic!("Unknown logic operator found: {x}"),
240            (BinaryOperationSymbol::Assign(_), x) => panic!("Unknown assign operator found: {x}")
241        }))
242        .with_children_from::<LEFT, _>([left], context)
243        .with_children_from::<RIGHT, _>([right], context)
244        .with_rlt(context, node)
245        .id()
246        .cast()
247}
248
249fn build_unary(
250    node: &rlt::Operation,
251    builder: &mut SyntaxTreeBuilder,
252    context: &mut (impl Linker + CodeHolder + Sized),
253    operator: &UnaryOperationSymbol,
254    expr: &rlt::Operation,
255) -> NodeId<Operation> {
256    builder
257        .add_node(UnExpr::uninit(match operator {
258            UnaryOperationSymbol::Neg(_) => Neg,
259            UnaryOperationSymbol::Not(_) => Not,
260            UnaryOperationSymbol::Inv(_) => Inv,
261            UnaryOperationSymbol::Plus(_) => Plus,
262        }))
263        .with_children_from([expr], context)
264        .with_rlt(context, node)
265        .id()
266        .cast()
267}
268
269fn build_access(
270    node: &rlt::Operation,
271    builder: &mut SyntaxTreeBuilder,
272    context: &mut (impl Linker + CodeHolder + Sized),
273    left: &rlt::Operation,
274    right: &rlt::Operation,
275) -> NodeId<Operation> {
276    builder
277        .add_node(Acc::uninit())
278        .with_children_from::<LEFT, _>([left], context)
279        .with_children_from::<RIGHT, _>([right], context)
280        .with_rlt(context, node)
281        .id()
282        .cast()
283}
284
285impl PopulateTree for rlt::Application {
286    type Output = Appl;
287
288    fn convert(
289        &self,
290        builder: &mut SyntaxTreeBuilder,
291        context: &mut (impl Linker + CodeHolder),
292    ) -> NodeId<Self::Output> {
293        builder
294            .add_node(Appl::uninit())
295            .with_children_from::<PRIMARY, _>([&self.expr], context)
296            .with_children_from::<SECONDARY, _>(
297                self.params
298                    .as_ref()
299                    .map_or([].as_slice(), |x| x.inner.as_ref()),
300                context,
301            )
302            .with_rlt(context, self)
303            .id()
304    }
305}
306
307impl PopulateTree for rlt::Expression {
308    type Output = Expression;
309
310    fn convert(
311        &self,
312        builder: &mut SyntaxTreeBuilder,
313        context: &mut (impl Linker + CodeHolder),
314    ) -> NodeId<Self::Output> {
315        match self {
316            rlt::Expression::Lambda { binds, expr, .. } => builder
317                .add_node(Lambda::uninit())
318                .with_children_from(binds.as_ref(), context)
319                .with_children_from([expr.as_ref()], context)
320                .with_rlt(context, self)
321                .id()
322                .cast(),
323            rlt::Expression::Term(x) => x.convert(builder, context).cast(),
324            rlt::Expression::Literal(x) => x.convert(builder, context).cast(),
325            rlt::Expression::If(x) => x.convert(builder, context).cast::<CodeFlow>().cast(),
326        }
327    }
328}