Skip to main content

pivot/
ast.rs

1use serde::{Deserialize, Serialize};
2
3#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
4pub enum AstNodeKind {
5    // Primitives
6    Integer,
7    Identifier,
8    // Unary operators
9    Not,
10    // Infix operators
11    NotEqual,
12    Equal,
13    Add,
14    Subtract,
15    Multiply,
16    Divide,
17    // Control flow
18    Block,
19    IfStatement,
20    WhileLoop,
21    Program,
22    // Functions and variables
23    FunctionCall,
24    FunctionReturn,
25    FunctionDefinition,
26    VariableDefinition,
27    VariableDeclaration,
28    Assign,
29    // Import
30    Import,
31    // Blank node
32    Null,
33}
34
35#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
36pub struct AstNode {
37    pub kind: AstNodeKind,
38    pub value: String,
39    pub subnodes: Vec<AstNode>,
40}
41impl AstNode {
42    pub fn new(kind: AstNodeKind, value: String, subnodes: Vec<AstNode>) -> AstNode {
43        AstNode {
44            kind,
45            value,
46            subnodes,
47        }
48    }
49
50    // Primitives
51    pub fn integer(num: i64) -> AstNode {
52        AstNode {
53            kind: AstNodeKind::Integer,
54            value: num.to_string(),
55            subnodes: vec![],
56        }
57    }
58    pub fn identifier(id: String) -> AstNode {
59        AstNode {
60            kind: AstNodeKind::Identifier,
61            value: id,
62            subnodes: vec![],
63        }
64    }
65    // Unary operators
66    pub fn not(operand: AstNode) -> AstNode {
67        AstNode {
68            kind: AstNodeKind::Not,
69            value: "not".into(),
70            subnodes: vec![operand],
71        }
72    }
73    // Infix operators
74    pub fn not_equal(left: AstNode, right: AstNode) -> AstNode {
75        AstNode {
76            kind: AstNodeKind::NotEqual,
77            value: "not_equal".into(),
78            subnodes: vec![left, right],
79        }
80    }
81    pub fn equal(left: AstNode, right: AstNode) -> AstNode {
82        AstNode {
83            kind: AstNodeKind::Equal,
84            value: "equal".into(),
85            subnodes: vec![left, right],
86        }
87    }
88    pub fn add(left: AstNode, right: AstNode) -> AstNode {
89        AstNode {
90            kind: AstNodeKind::Add,
91            value: "add".into(),
92            subnodes: vec![left, right],
93        }
94    }
95    pub fn subtract(left: AstNode, right: AstNode) -> AstNode {
96        AstNode {
97            kind: AstNodeKind::Subtract,
98            value: "subtract".into(),
99            subnodes: vec![left, right],
100        }
101    }
102    pub fn multiply(left: AstNode, right: AstNode) -> AstNode {
103        AstNode {
104            kind: AstNodeKind::Multiply,
105            value: "multiply".into(),
106            subnodes: vec![left, right],
107        }
108    }
109    pub fn divide(left: AstNode, right: AstNode) -> AstNode {
110        AstNode {
111            kind: AstNodeKind::Divide,
112            value: "divide".into(),
113            subnodes: vec![left, right],
114        }
115    }
116    // Control flow
117    pub fn block(statements: Vec<AstNode>) -> AstNode {
118        AstNode {
119            kind: AstNodeKind::Block,
120            value: "block".into(),
121            subnodes: statements,
122        }
123    }
124    pub fn if_statement(
125        conditional: AstNode,
126        consequence: AstNode,
127        alternative: AstNode,
128    ) -> AstNode {
129        AstNode {
130            kind: AstNodeKind::IfStatement,
131            value: "if_statement".into(),
132            subnodes: vec![conditional, consequence, alternative],
133        }
134    }
135    pub fn while_loop(conditional: AstNode, body: AstNode) -> AstNode {
136        AstNode {
137            kind: AstNodeKind::WhileLoop,
138            value: "while_loop".into(),
139            subnodes: vec![conditional, body],
140        }
141    }
142    pub fn program(statements: Vec<AstNode>) -> AstNode {
143        AstNode {
144            kind: AstNodeKind::Program,
145            value: "program".into(),
146            subnodes: statements,
147        }
148    }
149    // Functions and variables
150    pub fn function_call(name: String, parameters: Vec<AstNode>) -> AstNode {
151        AstNode {
152            kind: AstNodeKind::FunctionCall,
153            value: name,
154            subnodes: parameters,
155        }
156    }
157    pub fn function_return(operand: AstNode) -> AstNode {
158        AstNode {
159            kind: AstNodeKind::FunctionReturn,
160            value: "return".into(),
161            subnodes: vec![operand],
162        }
163    }
164    pub fn function_definition(name: String, parameters: Vec<AstNode>, body: AstNode) -> AstNode {
165        let mut params = vec![body];
166        for p in parameters {
167            params.push(p);
168        }
169        AstNode {
170            kind: AstNodeKind::FunctionDefinition,
171            value: name,
172            subnodes: params,
173        }
174    }
175    pub fn variable_definition(name: String, value: AstNode) -> AstNode {
176        AstNode {
177            kind: AstNodeKind::VariableDefinition,
178            value: name,
179            subnodes: vec![value],
180        }
181    }
182    pub fn variable_declaration(name: String) -> AstNode {
183        AstNode {
184            kind: AstNodeKind::VariableDeclaration,
185            value: name,
186            subnodes: vec![],
187        }
188    }
189    pub fn assign(name: String, value: AstNode) -> AstNode {
190        AstNode {
191            kind: AstNodeKind::Assign,
192            value: name,
193            subnodes: vec![value],
194        }
195    }
196    // Import
197    pub fn import(num_args: AstNode, returns_value: AstNode, mut fn_path: Vec<AstNode>) -> AstNode {
198        let mut data = vec![num_args, returns_value];
199        data.append(&mut fn_path);
200        AstNode {
201            kind: AstNodeKind::Import,
202            value: "import".into(),
203            subnodes: data,
204        }
205    }
206    // Blank node
207    pub fn null() -> AstNode {
208        AstNode {
209            kind: AstNodeKind::Null,
210            value: "".into(),
211            subnodes: vec![],
212        }
213    }
214
215    // Other
216    pub fn pretty_print(&self, f: &mut std::fmt::Formatter<'_>, indent: usize) -> std::fmt::Result {
217        for _ in 0..indent {
218            write!(f, " ")?;
219        }
220        write!(f, "{{\n")?;
221        for _ in 0..indent + 2 {
222            write!(f, " ")?;
223        }
224        write!(f, "kind: {:?}\n", self.kind)?;
225        for _ in 0..indent + 2 {
226            write!(f, " ")?;
227        }
228        write!(f, "value: {:?}\n", self.value)?;
229        if self.subnodes.len() > 0 {
230            for _ in 0..indent + 2 {
231                write!(f, " ")?;
232            }
233            write!(f, "subnodes: [\n")?;
234            for subnode in &self.subnodes {
235                subnode.pretty_print(f, indent + 4)?;
236                write!(f, ",\n")?;
237            }
238            for _ in 0..indent + 2 {
239                write!(f, " ")?;
240            }
241            write!(f, "]\n")?;
242        }
243        for _ in 0..indent {
244            write!(f, " ")?;
245        }
246        write!(f, "}}")
247    }
248}
249impl std::fmt::Display for AstNode {
250    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251        self.pretty_print(f, 0)
252    }
253}