Skip to main content

oak_mojo/parser/
mod.rs

1/// Element type definitions for the Mojo parser.
2pub mod element_type;
3pub use element_type::MojoElementType;
4
5use crate::{
6    MojoLanguage,
7    ast::*,
8    lexer::{MojoLexer, MojoTokenType},
9};
10use oak_core::{
11    OakError,
12    parser::{
13        ParseCache, ParseOutput, Parser, ParserState, parse_with_lexer,
14        pratt::{Associativity, Pratt, PrattParser},
15    },
16    source::{Source, TextEdit},
17    tree::GreenNode,
18};
19
20pub(crate) type State<'a, S> = ParserState<'a, MojoLanguage, S>;
21
22/// Mojo syntax parser
23pub struct MojoParser<'config> {
24    config: &'config MojoLanguage,
25}
26
27impl<'config> Parser<MojoLanguage> for MojoParser<'config> {
28    fn parse<'a, S: Source + ?Sized>(&self, source: &'a S, edits: &[TextEdit], cache: &'a mut impl ParseCache<MojoLanguage>) -> ParseOutput<'a, MojoLanguage> {
29        let lexer = MojoLexer::new(self.config);
30        parse_with_lexer(&lexer, source, edits, cache, |state| {
31            let cp = state.checkpoint();
32            while state.not_at_end() {
33                self.skip_trivia(state);
34                if !state.not_at_end() {
35                    break;
36                }
37                if state.at(MojoTokenType::Newline) {
38                    state.bump();
39                    continue;
40                }
41                self.parse_statement(state)?;
42            }
43            Ok(state.finish_at(cp, MojoElementType::Root.into()))
44        })
45    }
46}
47
48impl<'config> MojoParser<'config> {
49    /// Creates a new Mojo parser with the given language configuration.
50    pub fn new(config: &'config MojoLanguage) -> Self {
51        Self { config }
52    }
53
54    fn skip_trivia<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) {
55        while state.not_at_end() {
56            if let Some(kind) = state.peek_kind() {
57                if kind == MojoTokenType::Whitespace || kind == MojoTokenType::Comment {
58                    state.bump();
59                    continue;
60                }
61            }
62            break;
63        }
64    }
65
66    fn parse_statement<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
67        self.skip_trivia(state);
68        if state.at(MojoTokenType::Fn) {
69            self.parse_function_def(state)
70        }
71        else if state.at(MojoTokenType::Var) || state.at(MojoTokenType::Let) {
72            self.parse_variable_decl(state)
73        }
74        else if state.at(MojoTokenType::If) {
75            self.parse_if_stmt(state)
76        }
77        else if state.at(MojoTokenType::While) {
78            self.parse_while_stmt(state)
79        }
80        else if state.at(MojoTokenType::Return) {
81            self.parse_return_stmt(state)
82        }
83        else {
84            self.parse_expression_stmt(state)
85        }
86    }
87
88    fn parse_function_def<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
89        state.incremental_node(MojoElementType::FunctionDef.into(), |state| {
90            state.expect(MojoTokenType::Fn)?;
91            self.skip_trivia(state);
92            state.expect(MojoTokenType::Identifier)?;
93            self.skip_trivia(state);
94            state.expect(MojoTokenType::LeftParen)?;
95            self.parse_param_list(state)?;
96            state.expect(MojoTokenType::RightParen)?;
97            self.skip_trivia(state);
98            if state.eat(MojoTokenType::Arrow) {
99                self.skip_trivia(state);
100                state.expect(MojoTokenType::Identifier)?; // Return type
101                self.skip_trivia(state);
102            }
103            state.expect(MojoTokenType::Colon)?;
104            self.parse_block(state)
105        })
106    }
107
108    fn parse_param_list<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
109        state.incremental_node(MojoElementType::ParamList.into(), |state| {
110            while state.not_at_end() && !state.at(MojoTokenType::RightParen) {
111                self.skip_trivia(state);
112                state.expect(MojoTokenType::Identifier)?;
113                self.skip_trivia(state);
114                if state.eat(MojoTokenType::Colon) {
115                    self.skip_trivia(state);
116                    state.expect(MojoTokenType::Identifier)?; // Param type
117                    self.skip_trivia(state);
118                }
119                if !state.eat(MojoTokenType::Comma) {
120                    break;
121                }
122            }
123            Ok(())
124        })
125    }
126
127    fn parse_variable_decl<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
128        state.incremental_node(MojoElementType::VariableDecl.into(), |state| {
129            if state.at(MojoTokenType::Var) {
130                state.bump();
131            }
132            else {
133                state.expect(MojoTokenType::Let)?;
134            }
135            self.skip_trivia(state);
136            state.expect(MojoTokenType::Identifier)?;
137            self.skip_trivia(state);
138            if state.eat(MojoTokenType::Colon) {
139                self.skip_trivia(state);
140                state.expect(MojoTokenType::Identifier)?; // Type
141                self.skip_trivia(state);
142            }
143            if state.eat(MojoTokenType::Equal) {
144                self.skip_trivia(state);
145                self.parse_expression(state, 0);
146            }
147            state.eat(MojoTokenType::Newline);
148            Ok(())
149        })
150    }
151
152    fn parse_if_stmt<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
153        state.incremental_node(MojoElementType::IfStatement.into(), |state| {
154            state.expect(MojoTokenType::If)?;
155            self.skip_trivia(state);
156            self.parse_expression(state, 0);
157            self.skip_trivia(state);
158            state.expect(MojoTokenType::Colon)?;
159            self.parse_block(state)?;
160            self.skip_trivia(state);
161            if state.eat(MojoTokenType::Else) {
162                self.skip_trivia(state);
163                if state.at(MojoTokenType::If) {
164                    self.parse_if_stmt(state)?;
165                }
166                else {
167                    state.expect(MojoTokenType::Colon)?;
168                    self.parse_block(state)?;
169                }
170            }
171            Ok(())
172        })
173    }
174
175    fn parse_while_stmt<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
176        state.incremental_node(MojoElementType::WhileStatement.into(), |state| {
177            state.expect(MojoTokenType::While)?;
178            self.skip_trivia(state);
179            self.parse_expression(state, 0);
180            self.skip_trivia(state);
181            state.expect(MojoTokenType::Colon)?;
182            self.parse_block(state)
183        })
184    }
185
186    fn parse_return_stmt<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
187        state.incremental_node(MojoElementType::ReturnStatement.into(), |state| {
188            state.expect(MojoTokenType::Return)?;
189            self.skip_trivia(state);
190            if !state.at(MojoTokenType::Newline) && !state.at(MojoTokenType::EndOfStream) {
191                self.parse_expression(state, 0);
192            }
193            state.eat(MojoTokenType::Newline);
194            Ok(())
195        })
196    }
197
198    fn parse_expression_stmt<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
199        state.incremental_node(MojoElementType::ExpressionStatement.into(), |state| {
200            self.parse_expression(state, 0);
201            state.eat(MojoTokenType::Newline);
202            Ok(())
203        })
204    }
205
206    fn parse_block<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
207        state.incremental_node(MojoElementType::Block.into(), |state| {
208            // Skip spaces after colon
209            self.skip_trivia(state);
210            // Must have a newline
211            state.expect(MojoTokenType::Newline)?;
212            // Multiple empty lines may follow
213            while state.eat(MojoTokenType::Newline) {
214                self.skip_trivia(state);
215            }
216            // Indent starts
217            state.expect(MojoTokenType::Indent)?;
218            while state.not_at_end() && !state.at(MojoTokenType::Dedent) {
219                self.skip_trivia(state);
220                if state.eat(MojoTokenType::Newline) {
221                    continue;
222                }
223                self.parse_statement(state)?;
224            }
225            state.expect(MojoTokenType::Dedent)
226        })
227    }
228
229    fn parse_expression<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>, min_precedence: u8) -> &'a GreenNode<'a, MojoLanguage> {
230        PrattParser::parse(state, min_precedence, self)
231    }
232}
233
234impl<'config> Pratt<MojoLanguage> for MojoParser<'config> {
235    fn primary<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> &'a GreenNode<'a, MojoLanguage> {
236        self.skip_trivia(state);
237        let cp = state.checkpoint();
238        if state.at(MojoTokenType::Identifier) {
239            state.bump();
240            state.finish_at(cp, MojoElementType::IdentifierExpr.into())
241        }
242        else if state.at(MojoTokenType::Integer) || state.at(MojoTokenType::Float) || state.at(MojoTokenType::String) {
243            state.bump();
244            state.finish_at(cp, MojoElementType::LiteralExpr.into())
245        }
246        else if state.at(MojoTokenType::LeftParen) {
247            state.bump();
248            self.parse_expression(state, 0);
249            state.expect(MojoTokenType::RightParen).ok();
250            state.finish_at(cp, MojoElementType::Grouping.into())
251        }
252        else {
253            state.bump(); // Error recovery
254            state.finish_at(cp, MojoElementType::Error.into())
255        }
256    }
257
258    fn prefix<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> &'a GreenNode<'a, MojoLanguage> {
259        self.primary(state)
260    }
261
262    fn infix<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>, left: &'a GreenNode<'a, MojoLanguage>, min_precedence: u8) -> Option<&'a GreenNode<'a, MojoLanguage>> {
263        self.skip_trivia(state);
264        let kind = state.peek_kind()?;
265        let (precedence, associativity) = match kind {
266            MojoTokenType::Plus | MojoTokenType::Minus => (10, Associativity::Left),
267            MojoTokenType::Star | MojoTokenType::Slash | MojoTokenType::Percent => (20, Associativity::Left),
268            MojoTokenType::EqualEqual | MojoTokenType::NotEqual | MojoTokenType::Less | MojoTokenType::LessEqual | MojoTokenType::Greater | MojoTokenType::GreaterEqual => (5, Associativity::Left),
269            _ => return None,
270        };
271
272        if precedence < min_precedence {
273            return None;
274        }
275
276        let cp = state.checkpoint_before(left);
277        state.bump();
278        let next_prec = if associativity == Associativity::Left { precedence + 1 } else { precedence };
279        self.parse_expression(state, next_prec);
280        Some(state.finish_at(cp, MojoElementType::BinaryExpr.into()))
281    }
282}