ferric_parser/parser/
parser.rs

1//! The ferric_parser module provides tools for parsing a sequence of tokens
2//! into an abstract syntax tree (AST) and performing optimizations on it.
3
4use ferric_lexer::{Keyword, Operator, Token};
5use crate::parser::code_gen::ast_node::ASTNode;
6use std::collections::HashSet;
7use crate::parser::symbol_table::SymbolTable;
8
9/// Represents the parser which takes a sequence of tokens and produces an AST.
10///
11/// # Example
12///
13/// ```rust
14/// use ferric_parser::parser::Parser;
15/// use ferric_lexer::{Token, Lexer};
16///
17/// let code = "let x = 5; let y = x + 3;";
18/// let tokens = Lexer::new(code).lex();
19/// let mut parser = Parser::new(tokens);
20/// let ast = parser.parse().unwrap();
21/// ```
22pub struct Parser {
23    tokens: Vec<Token>,
24    position: usize,
25    symtable: SymbolTable,
26}
27
28/// Propagates constants through the AST, simplifying expressions
29/// where possible using the symbol table.
30fn propagate_constants(node: ASTNode, symtable: &mut SymbolTable) -> ASTNode {
31    match node {
32        ASTNode::Program(statements) => {
33            let optimized_statements = statements
34                .into_iter()
35                .map(|stmt| propagate_constants(stmt, symtable))
36                .collect();
37            ASTNode::Program(optimized_statements)
38        }
39        ASTNode::Statement(inner) => {
40            ASTNode::Statement(Box::new(propagate_constants(*inner, symtable)))
41        }
42        ASTNode::Assign { name, value } => {
43            let new_value = propagate_constants(*value, symtable);
44
45            // If the new_value is a number, update the symbol table
46            if let ASTNode::Number(val) = &new_value {
47                symtable.set(&name, *val);
48            }
49
50            ASTNode::Assign {
51                name,
52                value: Box::new(new_value),
53            }
54        }
55        ASTNode::BuiltInFunctionCall { name, args } => {
56            let optimized_args = args
57                .into_iter()
58                .map(|arg| propagate_constants(arg, symtable))
59                .collect();
60            ASTNode::BuiltInFunctionCall {
61                name,
62                args: optimized_args,
63            }
64        }
65        ASTNode::Variable(name) => {
66            if let Some(value) = symtable.get(&name) {
67                ASTNode::Number(value)
68            } else {
69                ASTNode::Variable(name)
70            }
71        }
72        ASTNode::BinaryOp { op, left, right } => ASTNode::BinaryOp {
73            op,
74            left: Box::new(propagate_constants(*left, symtable)),
75            right: Box::new(propagate_constants(*right, symtable)),
76        },
77        _ => node,
78    }
79}
80
81/// Evaluates and simplifies binary operations on constants.
82fn fold_constants(node: ASTNode) -> ASTNode {
83    match node {
84        ASTNode::Program(statements) => {
85            let folded_statements = statements.into_iter().map(fold_constants).collect();
86            ASTNode::Program(folded_statements)
87        }
88        ASTNode::Statement(inner) => ASTNode::Statement(Box::new(fold_constants(*inner))),
89        ASTNode::Assign { name, value } => ASTNode::Assign {
90            name,
91            value: Box::new(fold_constants(*value)),
92        },
93        ASTNode::BuiltInFunctionCall { name, args } => {
94            let folded_args = args.into_iter().map(fold_constants).collect();
95            ASTNode::BuiltInFunctionCall {
96                name,
97                args: folded_args,
98            }
99        }
100        ASTNode::BinaryOp { op, left, right } => {
101            let left = fold_constants(*left);
102            let right = fold_constants(*right);
103            if let (ASTNode::Number(l_val), ASTNode::Number(r_val)) = (&left, &right) {
104                let result = match op {
105                    Operator::Plus => l_val + r_val,
106                    Operator::Minus => l_val - r_val,
107                    Operator::Multiply => l_val * r_val,
108                    Operator::Divide => l_val / r_val,
109                    Operator::Modulo => l_val % r_val,
110                    _ => {
111                        return ASTNode::BinaryOp {
112                            op,
113                            left: Box::new(left),
114                            right: Box::new(right),
115                        }
116                    }
117                };
118                ASTNode::Number(result)
119            } else {
120                ASTNode::BinaryOp {
121                    op,
122                    left: Box::new(left),
123                    right: Box::new(right),
124                }
125            }
126        }
127        _ => node,
128    }
129}
130
131/// Collects the variables used within the AST.
132fn collect_used_variables(node: &ASTNode, used_vars: &mut HashSet<String>) {
133    match node {
134        ASTNode::Variable(name) => {
135            used_vars.insert(name.clone());
136        }
137        ASTNode::Program(statements) => {
138            for stmt in statements {
139                collect_used_variables(stmt, used_vars);
140            }
141        }
142        ASTNode::BinaryOp { left, right, .. } => {
143            collect_used_variables(left, used_vars);
144            collect_used_variables(right, used_vars);
145        }
146        ASTNode::Assign { value, .. } => {
147            collect_used_variables(value, used_vars);
148        }
149        ASTNode::BuiltInFunctionCall { args, .. } | ASTNode::UserFunctionCall { args, .. } => {
150            for arg in args {
151                collect_used_variables(arg, used_vars);
152            }
153        }
154        ASTNode::Statement(inner) => match &**inner {
155            ASTNode::Assign { value, .. } => {
156                collect_used_variables(value, used_vars);
157            }
158            _ => collect_used_variables(&*inner, used_vars),
159        },
160        // Handle other node types as necessary
161        _ => {}
162    }
163}
164
165/// Filters out assignments to variables that are not used in the AST.
166fn filter_unused_assignments(node: ASTNode, used_vars: &HashSet<String>) -> Option<ASTNode> {
167    let result = match node {
168        ASTNode::Program(statements) => {
169            let filtered_statements: Vec<ASTNode> = statements
170                .into_iter()
171                .filter_map(|stmt| {
172                    if let ASTNode::Assign { name, .. } = &stmt {
173                        if !used_vars.contains(name) {
174                            return None; // Filter out unused assignments
175                        }
176                        return filter_unused_assignments(stmt, used_vars);
177                    }
178                    Some(stmt)
179                })
180                .collect();
181            Some(ASTNode::Program(filtered_statements))
182        }
183        ASTNode::Statement(inner) => {
184            if let Some(filtered_inner) = filter_unused_assignments(*inner, used_vars) {
185                Some(ASTNode::Statement(Box::new(filtered_inner)))
186            } else {
187                None
188            }
189        }
190        // ... handle other cases
191        _ => Some(node),
192    };
193
194    result
195}
196
197/// Removes any variables from the AST that aren't being used.
198fn remove_unused_variables(ast: ASTNode) -> ASTNode {
199    // Step 1: Collect used variables
200    let mut used_vars = HashSet::new();
201    collect_used_variables(&ast, &mut used_vars);
202
203    // Step 2: Filter out unused assignments
204    filter_unused_assignments(ast, &used_vars).unwrap()
205}
206
207impl Parser {
208    /// Creates a new parser for a given sequence of tokens.
209    pub fn new(tokens: Vec<Token>) -> Self {
210        Parser {
211            tokens,
212            position: 0,
213            symtable: SymbolTable::new(),
214        }
215    }
216
217    /// Parses the sequence of tokens into an AST.
218    pub fn parse(&mut self) -> Result<ASTNode, String> {
219        let mut statements = Vec::new();
220
221        while self.position < self.tokens.len() {
222            let stmt = self.statement()?;
223            statements.push(stmt);
224
225            if let Some(Token::EoL) = self.peek_next_token() {
226                self.position += 1; // Consume EoL
227            } else {
228                return Err(format!("Expected EoL, found {:?}", self.peek_next_token()));
229            }
230        }
231        let mut ast = ASTNode::Program(statements);
232        let mut previous_ast = ast.clone();
233
234        loop {
235            let optimized_ast = remove_unused_variables(self.optimize(ast.clone()));
236            if optimized_ast == previous_ast {
237                break;
238            }
239            previous_ast = ast;
240            ast = optimized_ast;
241        }
242
243        Ok(ast)
244    }
245
246    /// Consumes optional EoL tokens from the token stream.
247    fn consume_optional_eol(&mut self) {
248        if let Some(Token::EoL) = self.peek_next_token() {
249            self.position += 1; // Consume EoL
250        }
251    }
252
253    /// Parses a statement from the token stream.
254    fn statement(&mut self) -> Result<ASTNode, String> {
255        let stmt = match self.peek_next_token() {
256            Some(Token::Keyword(Keyword::Let)) => parse_variable_declaration!(self),
257            Some(Token::BuiltInFunction(_)) => parse_func_call!(self),
258            Some(Token::Keyword(Keyword::If)) => parse_if_statement!(self),
259            Some(Token::Keyword(Keyword::While)) => parse_while_loop!(self),
260            Some(Token::Number(_)) => self.expression(),
261            Some(Token::Identifier(_)) => {
262                if let Some(Token::Assign) = self.peek_nth_token(1) {
263                    self.expect_assignment()
264                } else {
265                    self.expression()
266                }
267            }
268            _ => Err(format!(
269                "Expected a statement, found {:?}",
270                self.peek_next_token()
271            )),
272        }?;
273
274        Ok(ASTNode::Statement(Box::new(stmt)))
275    }
276
277    /// Parses function arguments from the token stream.
278    fn parse_function_arguments(&mut self) -> Result<Vec<ASTNode>, String> {
279        let mut args = Vec::new();
280
281        // Expecting an open parenthesis after function name
282        match &self.tokens[self.position] {
283            Token::OpenParen => self.position += 1,
284            _ => {
285                return Err(format!(
286                    "Expected '(' after function name, found {:?}",
287                    &self.tokens[self.position]
288                ))
289            }
290        }
291
292        // Parse arguments until a close parenthesis is found
293        while self.tokens[self.position] != Token::CloseParen {
294            let expr = self.expression()?;
295            args.push(expr);
296
297            // If the next token is a comma, consume it and continue parsing the next argument
298            if self.tokens[self.position] == Token::Comma {
299                self.position += 1;
300            }
301        }
302
303        // Consuming the close parenthesis
304        self.position += 1;
305
306        Ok(args)
307    }
308
309    /// Parses a block (sequence of statements) from the token stream.
310    fn parse_block(&mut self) -> Result<ASTNode, String> {
311        let mut statements = Vec::new();
312
313        loop {
314            match self.peek_next_token() {
315                Some(Token::CloseBrace) => {
316                    self.position += 1;
317                    break;
318                }
319                Some(Token::EoL) => {
320                    self.position += 1; // consume the EoL
321                }
322                _ => {
323                    let stmt = self.statement()?;
324                    statements.push(stmt);
325                    self.consume_optional_eol();
326                }
327            }
328        }
329
330        Ok(ASTNode::Program(statements))
331    }
332
333    /// Parses an expression from the token stream.
334    fn expression(&mut self) -> Result<ASTNode, String> {
335        let mut left = self.term()?;
336
337        while let Some(op) = self.peek_next_operator(&[
338            Operator::Plus,
339            Operator::Minus,
340            //Operator::Modulo,
341            Operator::Equal,
342            Operator::NotEqual,
343            Operator::GreaterThan,
344            Operator::LessThan,
345            Operator::GreaterThanOrEqual,
346            Operator::LessThanOrEqual,
347        ]) {
348            self.position += 1; // Consume Operator
349            let right = self.term()?;
350            left = ASTNode::BinaryOp {
351                op,
352                left: Box::new(left),
353                right: Box::new(right),
354            };
355        }
356
357        Ok(left)
358    }
359
360    /// Parses a term from the token stream.
361    fn term(&mut self) -> Result<ASTNode, String> {
362        let mut left = self.factor()?;
363
364        while let Some(op) =
365            self.peek_next_operator(&[Operator::Multiply, Operator::Divide, Operator::Modulo])
366        {
367            self.position += 1; // Consume Operator
368            let right = self.factor()?;
369            left = ASTNode::BinaryOp {
370                op,
371                left: Box::new(left),
372                right: Box::new(right),
373            };
374        }
375
376        Ok(left)
377    }
378
379    /// Parses a factor from the token stream.
380    fn factor(&mut self) -> Result<ASTNode, String> {
381        if let Some(Token::Operator(Operator::Plus)) = self.peek_next_token() {
382            self.position += 1; // Consume '+'
383            return self.factor();
384        }
385
386        if let Some(Token::Operator(Operator::Minus)) = self.peek_next_token() {
387            self.position += 1; // Consume '-'
388            let right = self.primary()?;
389            return Ok(ASTNode::BinaryOp {
390                op: Operator::Minus,
391                left: Box::new(ASTNode::Number(0)), // Represent as 0 - right
392                right: Box::new(right),
393            });
394        }
395
396        self.primary()
397    }
398
399    /// Parses a primary expression from the token stream.
400    fn primary(&mut self) -> Result<ASTNode, String> {
401        let next_token = self.peek_next_token().cloned();
402        // First, handle expressions enclosed in parentheses
403        if let Some(Token::OpenParen) = next_token {
404            self.position += 1; // Consume '('
405            let expr = self.expression()?; // Evaluate the enclosed expression
406            if let Some(Token::CloseParen) = self.peek_next_token() {
407                self.position += 1; // Consume ')'
408                return Ok(expr);
409            } else {
410                return Err(format!("Expected ')', found {:?}", self.peek_next_token()));
411            }
412        }
413
414        if let Some(Token::BuiltInFunction(_)) = next_token {
415            return parse_func_call!(self);
416        }
417
418        if let Some(Token::StringLiteral(s)) = next_token {
419            self.position += 1;
420            return Ok(ASTNode::StringLiteral(s));
421        }
422
423        // Next, handle numbers and variables
424        match next_token {
425            Some(Token::Number(_)) => self.expect_number(),
426            Some(Token::Identifier(_)) => {
427                if let Some(Token::Assign) = self.peek_nth_token(1) {
428                    self.expect_assignment()
429                } else {
430                    self.expect_variable()
431                }
432            }
433            // Here you can extend to handle other primary expressions
434            _ => Err(format!(
435                "Expected a primary expression, found {:?}",
436                self.peek_next_token()
437            )),
438        }
439    }
440
441    /// Peeks at the next token in the token stream without consuming it.
442    fn peek_next_token(&self) -> Option<&Token> {
443        if self.position < self.tokens.len() {
444            Some(&self.tokens[self.position])
445        } else {
446            None
447        }
448    }
449
450    /// Peeks at the next operator token in the token stream without consuming it.
451    fn peek_next_operator(&self, ops: &[Operator]) -> Option<Operator> {
452        if self.position >= self.tokens.len() {
453            return None;
454        }
455
456        match self.tokens[self.position] {
457            Token::Operator(ref op) if ops.contains(op) => Some(op.clone()),
458            _ => None,
459        }
460    }
461
462    /// Peeks at the nth token in the token stream without consuming it.
463    fn peek_nth_token(&self, n: usize) -> Option<&Token> {
464        if self.position + n < self.tokens.len() {
465            Some(&self.tokens[self.position + n])
466        } else {
467            None
468        }
469    }
470
471    /// Expects and parses an assignment from the token stream.
472    fn expect_assignment(&mut self) -> Result<ASTNode, String> {
473        let name = match &self.tokens[self.position] {
474            Token::Identifier(ref ident) => ident.clone(),
475            _ => {
476                return Err(format!(
477                    "Expected an identifier, found {:?}",
478                    self.tokens[self.position]
479                ))
480            }
481        };
482        self.position += 2; // Consume Identifier and Assign
483        let value = self.expression()?;
484        Ok(ASTNode::Assign {
485            name,
486            value: Box::new(value),
487        })
488    }
489
490    /// Expects and parses a variable from the token stream.
491    fn expect_variable(&mut self) -> Result<ASTNode, String> {
492        expect_token!(self, Identifier(name) => Variable)
493    }
494
495    /// Expects and parses a number from the token stream.
496    fn expect_number(&mut self) -> Result<ASTNode, String> {
497        expect_token!(self, Number(val) => Number)
498    }
499
500    /// Parses and optimizes an AST.
501    ///
502    /// This method first parses the tokens into an AST and then performs
503    /// constant propagation and constant folding to optimize the AST.
504    ///
505    /// # Example
506    ///
507    /// Assuming Ferric code looks like:
508    /// ```
509    /// let x = 5;
510    /// let y = x + 3;
511    /// ```
512    ///
513    /// The resulting AST after parsing would represent the above code structure.
514    ///
515    /// ```rust
516    /// use ferric_parser::parser::Parser;
517    /// use ferric_lexer::Lexer;
518    ///
519    /// let code = "let x = 5; let y = x + 3;";
520    /// let tokens = Lexer::new(code).lex();
521    /// let mut parser = Parser::new(tokens);
522    /// let ast = parser.parse().unwrap();
523    /// let optimized_ast = parser.optimize(ast);
524    /// ```
525    pub fn optimize(&mut self, node: ASTNode) -> ASTNode {
526        let node = propagate_constants(node, &mut self.symtable);
527        fold_constants(node)
528    }
529}