// slac/compiler.rs

1use crate::{error::Result, operator::Operator};
2use std::vec;
3
4use crate::{
5    ast::Expression,
6    error::Error,
7    token::{Precedence, Token},
8};
9
10/// A compiler to transform a list of [`Tokens`](Token) into a single nested [`Expression`] tree.
11///
12/// # Remarks
13///
14/// Uses a Pratt-Parser to build the AST based on the tokens `Precedence`.
15pub struct Compiler {
16    tokens: Vec<Token>,
17    current: usize,
18}
19
20impl Compiler {
21    /// Compiles a structured [`Expression`] tree from a list of [`Tokens`](Token).
22    ///
23    /// # Errors
24    ///
25    /// Returns an [`Error`] when encountering an invalid combination of [`Tokens`](Token).
26    pub fn compile_ast(tokens: Vec<Token>) -> Result<Expression> {
27        let mut compiler = Compiler { tokens, current: 0 };
28        compiler.compile()
29    }
30
31    fn compile(&mut self) -> Result<Expression> {
32        let expression = self.expression()?;
33
34        match self.current() {
35            Some(token) => Err(Error::MultipleExpressions(token.clone())),
36            None => Ok(expression),
37        }
38    }
39
40    fn expression(&mut self) -> Result<Expression> {
41        if self.current < self.tokens.len() {
42            self.parse_precedence(Precedence::Or)
43        } else {
44            Err(Error::Eof)
45        }
46    }
47
48    fn parse_precedence(&mut self, precedence: Precedence) -> Result<Expression> {
49        self.advance();
50        let mut expression = self.do_prefix()?;
51
52        while self
53            .current()
54            .is_some_and(|t| precedence <= Precedence::from(t))
55        {
56            self.advance();
57            expression = self.do_infix(expression)?;
58        }
59
60        Ok(expression)
61    }
62
63    fn do_prefix(&mut self) -> Result<Expression> {
64        let previous = self.previous()?;
65        match previous {
66            Token::Literal(value) => Ok(Expression::Literal {
67                value: value.clone(),
68            }),
69            Token::Identifier(name) => Ok(Expression::Variable { name: name.clone() }),
70            Token::LeftParen => self.grouping(),
71            Token::LeftBracket => self.array(),
72            Token::Not | Token::Minus => self.unary(),
73            _ => Err(Error::NoValidPrefixToken(previous.clone())),
74        }
75    }
76
77    fn do_infix(&mut self, left: Expression) -> Result<Expression> {
78        let previous = self.previous()?;
79        match previous {
80            Token::Minus
81            | Token::Plus
82            | Token::Star
83            | Token::Slash
84            | Token::Div
85            | Token::Mod
86            | Token::Equal
87            | Token::NotEqual
88            | Token::Greater
89            | Token::GreaterEqual
90            | Token::Less
91            | Token::LessEqual
92            | Token::And
93            | Token::Or
94            | Token::Xor => self.binary(left),
95            Token::LeftParen => self.call(left),
96            _ => Err(Error::NoValidInfixToken(previous.clone())),
97        }
98    }
99
100    fn expression_list(&mut self, end_token: &Token) -> Result<Vec<Expression>> {
101        let mut expressions: Vec<Expression> = vec![];
102
103        while self.current().is_some_and(|t| t != end_token) {
104            expressions.push(self.expression()?);
105
106            if self.current() == Some(&Token::Comma) {
107                self.advance();
108            }
109        }
110
111        self.chomp(end_token)?;
112
113        Ok(expressions)
114    }
115
116    fn call(&mut self, left: Expression) -> Result<Expression> {
117        if let Expression::Variable { name } = left {
118            Ok(Expression::Call {
119                name,
120                params: self.expression_list(&Token::RightParen)?,
121            })
122        } else {
123            Err(Error::CallNotOnVariable(self.previous()?.clone()))
124        }
125    }
126
127    fn array(&mut self) -> Result<Expression> {
128        Ok(Expression::Array {
129            expressions: self.expression_list(&Token::RightBracket)?,
130        })
131    }
132
133    fn binary(&mut self, left: Expression) -> Result<Expression> {
134        let operator = Operator::try_from(self.previous()?)?;
135        let right = self.parse_precedence(Precedence::from(self.previous()?).next())?;
136
137        Ok(Expression::Binary {
138            left: Box::new(left),
139            right: Box::new(right),
140            operator,
141        })
142    }
143
144    fn unary(&mut self) -> Result<Expression> {
145        let operator = Operator::try_from(self.previous()?)?;
146        let right = self.parse_precedence(Precedence::Unary)?;
147
148        Ok(Expression::Unary {
149            right: Box::new(right),
150            operator,
151        })
152    }
153
154    fn grouping(&mut self) -> Result<Expression> {
155        let expression = self.expression()?;
156        self.chomp(&Token::RightParen)?;
157
158        Ok(expression)
159    }
160
161    fn advance(&mut self) {
162        if self.current < self.tokens.len() {
163            self.current += 1;
164        }
165    }
166
167    fn current(&self) -> Option<&Token> {
168        self.tokens.get(self.current)
169    }
170
171    fn previous(&self) -> Result<&Token> {
172        self.tokens
173            .get(self.current - 1)
174            .ok_or(Error::PreviousTokenNotFound)
175    }
176
177    fn chomp(&mut self, token: &Token) -> Result<()> {
178        if self.current() == Some(token) {
179            self.advance();
180            Ok(())
181        } else {
182            Err(self
183                .current()
184                .map_or(Error::Eof, |t| Error::InvalidToken(t.clone())))
185        }
186    }
187}
188
#[cfg(test)]
mod test {
    use crate::{ast::Expression, error::Error, operator::Operator, token::Token, value::Value};

    use super::Compiler;

    /// Token for a numeric literal.
    fn number_token(value: f64) -> Token {
        Token::Literal(Value::Number(value))
    }

    /// Token for a boolean literal.
    fn boolean_token(value: bool) -> Token {
        Token::Literal(Value::Boolean(value))
    }

    /// Token for an identifier.
    fn identifier(name: &str) -> Token {
        Token::Identifier(String::from(name))
    }

    /// Expected literal expression holding a number.
    fn number(value: f64) -> Expression {
        Expression::Literal {
            value: Value::Number(value),
        }
    }

    /// Expected literal expression holding a boolean.
    fn boolean(value: bool) -> Expression {
        Expression::Literal {
            value: Value::Boolean(value),
        }
    }

    /// Expected variable expression.
    fn variable(name: &str) -> Expression {
        Expression::Variable {
            name: String::from(name),
        }
    }

    /// Expected binary expression `left <operator> right`.
    fn binary(left: Expression, operator: Operator, right: Expression) -> Expression {
        Expression::Binary {
            left: Box::new(left),
            right: Box::new(right),
            operator,
        }
    }

    #[test]
    fn single_literal() {
        let compiled = Compiler::compile_ast(vec![boolean_token(true)]);
        assert_eq!(compiled, Ok(boolean(true)));
    }

    #[test]
    fn single_variable() {
        let compiled = Compiler::compile_ast(vec![identifier("test")]);
        assert_eq!(compiled, Ok(variable("test")));
    }

    #[test]
    fn expression_group() {
        let tokens = vec![Token::LeftParen, boolean_token(true), Token::RightParen];
        assert_eq!(Compiler::compile_ast(tokens), Ok(boolean(true)));
    }

    #[test]
    fn unary_literal() {
        let tokens = vec![Token::Minus, number_token(42.0)];
        let negated = Expression::Unary {
            right: Box::new(number(42.0)),
            operator: Operator::Minus,
        };

        assert_eq!(Compiler::compile_ast(tokens), Ok(negated));
    }

    #[test]
    fn multiply_number() {
        let tokens = vec![number_token(3.0), Token::Star, number_token(2.0)];
        let product = binary(number(3.0), Operator::Multiply, number(2.0));

        assert_eq!(Compiler::compile_ast(tokens), Ok(product));
    }

    #[test]
    fn add_number() {
        let tokens = vec![number_token(3.0), Token::Plus, number_token(2.0)];
        let sum = binary(number(3.0), Operator::Plus, number(2.0));

        assert_eq!(Compiler::compile_ast(tokens), Ok(sum));
    }

    #[test]
    fn precedence_multiply_addition() {
        // 1 + 2 * 3 must group as 1 + (2 * 3).
        let tokens = vec![
            number_token(1.0),
            Token::Plus,
            number_token(2.0),
            Token::Star,
            number_token(3.0),
        ];
        let product = binary(number(2.0), Operator::Multiply, number(3.0));
        let sum = binary(number(1.0), Operator::Plus, product);

        assert_eq!(Compiler::compile_ast(tokens), Ok(sum));
    }

    #[test]
    fn comparison_equal() {
        let tokens = vec![number_token(5.0), Token::Equal, number_token(7.0)];
        let comparison = binary(number(5.0), Operator::Equal, number(7.0));

        assert_eq!(Compiler::compile_ast(tokens), Ok(comparison));
    }

    #[test]
    fn boolean_and() {
        let tokens = vec![boolean_token(true), Token::And, boolean_token(false)];
        let conjunction = binary(boolean(true), Operator::And, boolean(false));

        assert_eq!(Compiler::compile_ast(tokens), Ok(conjunction));
    }

    #[test]
    fn variable_add() {
        // (5 + SOME_VAR) * 4
        let tokens = vec![
            Token::LeftParen,
            number_token(5.0),
            Token::Plus,
            identifier("SOME_VAR"),
            Token::RightParen,
            Token::Star,
            number_token(4.0),
        ];
        let group = binary(number(5.0), Operator::Plus, variable("SOME_VAR"));
        let product = binary(group, Operator::Multiply, number(4.0));

        assert_eq!(Compiler::compile_ast(tokens), Ok(product));
    }

    #[test]
    fn variable_mul() {
        let tokens = vec![identifier("SOME_VAR"), Token::Star, number_token(4.0)];
        let product = binary(variable("SOME_VAR"), Operator::Multiply, number(4.0));

        assert_eq!(Compiler::compile_ast(tokens), Ok(product));
    }

    #[test]
    fn function_call() {
        let tokens = vec![
            identifier("max"),
            Token::LeftParen,
            number_token(1.0),
            Token::Comma,
            number_token(2.0),
            Token::RightParen,
        ];
        let call = Expression::Call {
            name: String::from("max"),
            params: vec![number(1.0), number(2.0)],
        };

        assert_eq!(Compiler::compile_ast(tokens), Ok(call));
    }

    #[test]
    fn err_open_function_call() {
        let tokens = vec![identifier("max"), Token::LeftParen];
        assert_eq!(Compiler::compile_ast(tokens), Err(Error::Eof));
    }

    #[test]
    fn err_open_array() {
        let tokens = vec![Token::LeftBracket, boolean_token(false)];
        assert_eq!(Compiler::compile_ast(tokens), Err(Error::Eof));
    }

    #[test]
    fn err_open_group() {
        assert_eq!(
            Compiler::compile_ast(vec![Token::LeftParen]),
            Err(Error::Eof)
        );

        let tokens = vec![identifier("test"), Token::And, Token::LeftParen];
        assert_eq!(Compiler::compile_ast(tokens), Err(Error::Eof));
    }

    #[test]
    fn err_array_empty_expressions() {
        let tokens = vec![Token::LeftBracket, Token::Comma, Token::RightBracket];
        assert_eq!(
            Compiler::compile_ast(tokens),
            Err(Error::NoValidPrefixToken(Token::Comma))
        );
    }
}