Skip to main content

oak_rust/parser/
mod.rs

1use crate::{RustLanguage, lexer::RustLexer};
2use oak_core::{
3    GreenNode, OakError, TokenType,
4    parser::{Associativity, ParseCache, ParseOutput, Parser, ParserState, Pratt, PrattParser, binary, parse_with_lexer, unary},
5    source::{Source, TextEdit},
6};
7
8/// Rust element type definitions.
9pub mod element_type;
10pub use self::element_type::RustElementType;
11
12/// A parser for the Rust programming language.
13#[derive(Clone)]
14pub struct RustParser<'config> {
15    /// Reference to the Rust language configuration
16    #[allow(dead_code)]
17    config: &'config RustLanguage,
18}
19
20impl<'config> RustParser<'config> {
21    /// Creates a new Rust parser with the given language configuration.
22    pub fn new(config: &'config RustLanguage) -> Self {
23        Self { config }
24    }
25}
26
27impl<'config> Pratt<RustLanguage> for RustParser<'config> {
28    fn primary<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> &'a GreenNode<'a, RustLanguage> {
29        let cp = state.checkpoint();
30        match state.peek_kind() {
31            Some(crate::lexer::RustTokenType::Identifier) => {
32                state.bump();
33                state.finish_at(cp, crate::parser::element_type::RustElementType::IdentifierExpression)
34            }
35            Some(k) if k.is_literal() => {
36                state.bump();
37                state.finish_at(cp, crate::parser::element_type::RustElementType::LiteralExpression)
38            }
39            Some(crate::lexer::RustTokenType::LeftParen) => {
40                state.bump();
41                PrattParser::parse(state, 0, self);
42                state.expect(crate::lexer::RustTokenType::RightParen).ok();
43                state.finish_at(cp, crate::parser::element_type::RustElementType::ParenthesizedExpression)
44            }
45            _ => {
46                state.bump();
47                state.finish_at(cp, crate::parser::element_type::RustElementType::Error)
48            }
49        }
50    }
51
52    fn prefix<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> &'a GreenNode<'a, RustLanguage> {
53        use crate::{lexer::RustTokenType::*, parser::RustElementType as RE};
54        let kind = match state.peek_kind() {
55            Some(k) => k,
56            None => return self.primary(state),
57        };
58
59        match kind {
60            Minus | Bang | Ampersand | Star => unary(state, kind, 13, RE::UnaryExpression.into(), |s, p| PrattParser::parse(s, p, self)),
61            _ => self.primary(state),
62        }
63    }
64
65    fn infix<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>, left: &'a GreenNode<'a, RustLanguage>, min_precedence: u8) -> Option<&'a GreenNode<'a, RustLanguage>> {
66        use crate::{lexer::RustTokenType::*, parser::RustElementType as RE};
67        let kind = state.peek_kind()?;
68
69        let (prec, assoc) = match kind {
70            Eq | PlusEq | MinusEq | StarEq | SlashEq | PercentEq | AndEq | OrEq | CaretEq | ShlEq | ShrEq => (1, Associativity::Right),
71            DotDot | DotDotEq => (2, Associativity::Left),
72            OrOr => (3, Associativity::Left),
73            AndAnd => (4, Associativity::Left),
74            EqEq | Ne => (5, Associativity::Left),
75            Lt | Le | Gt | Ge => (6, Associativity::Left),
76            Pipe => (7, Associativity::Left),
77            Caret => (8, Associativity::Left),
78            Ampersand => (9, Associativity::Left),
79            Shl | Shr => (10, Associativity::Left),
80            Plus | Minus => (11, Associativity::Left),
81            Star | Slash | Percent => (12, Associativity::Left),
82            LeftParen | LeftBracket | Dot => (14, Associativity::Left),
83            _ => return None,
84        };
85
86        if prec < min_precedence {
87            return None;
88        }
89
90        match kind {
91            LeftParen => {
92                let cp = state.checkpoint();
93                state.push_child(left);
94                state.expect(LeftParen).ok();
95                if !state.at(RightParen) {
96                    loop {
97                        PrattParser::parse(state, 0, self);
98                        if !state.eat(Comma) {
99                            break;
100                        }
101                    }
102                }
103                state.expect(RightParen).ok();
104                Some(state.finish_at(cp, RE::CallExpression))
105            }
106            LeftBracket => {
107                let cp = state.checkpoint();
108                state.push_child(left);
109                state.expect(LeftBracket).ok();
110                PrattParser::parse(state, 0, self);
111                state.expect(RightBracket).ok();
112                Some(state.finish_at(cp, RE::IndexExpression))
113            }
114            Dot => {
115                let cp = state.checkpoint();
116                state.push_child(left);
117                state.expect(Dot).ok();
118                state.expect(Identifier).ok();
119                Some(state.finish_at(cp, RE::MemberExpression))
120            }
121            _ => Some(binary(state, left, kind, prec, assoc, RE::BinaryExpression.into(), |s, p| PrattParser::parse(s, p, self))),
122        }
123    }
124}
125
126impl<'config> Parser<RustLanguage> for RustParser<'config> {
127    fn parse<'a, S: Source + ?Sized>(&self, text: &'a S, edits: &[TextEdit], cache: &'a mut impl ParseCache<RustLanguage>) -> ParseOutput<'a, RustLanguage> {
128        let lexer = RustLexer::new(self.config);
129        parse_with_lexer(&lexer, text, edits, cache, |state| self.parse_source_file(state))
130    }
131}
132
133impl<'config> RustParser<'config> {
134    /// Parses a complete Rust source file.
135    pub(crate) fn parse_source_file<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<&'a GreenNode<'a, RustLanguage>, OakError> {
136        let cp = state.checkpoint();
137        while state.not_at_end() {
138            if state.current().map(|t| t.kind.is_ignored()).unwrap_or(false) {
139                state.advance();
140                continue;
141            }
142            self.parse_statement(state)?
143        }
144        let root = state.finish_at(cp, crate::parser::element_type::RustElementType::SourceFile);
145        Ok(root)
146    }
147
148    /// Parses a single Rust statement or item.
149    fn parse_statement<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
150        use crate::{lexer::RustTokenType, parser::RustElementType::*};
151
152        let kind = match state.peek_kind() {
153            Some(RustTokenType::Fn) => Some(Function),
154            Some(RustTokenType::Use) => Some(UseItem),
155            Some(RustTokenType::Mod) => Some(ModuleItem),
156            Some(RustTokenType::Struct) => Some(StructItem),
157            Some(RustTokenType::Enum) => Some(EnumItem),
158            Some(RustTokenType::Let) => Some(LetStatement),
159            Some(RustTokenType::If) => Some(IfExpression),
160            Some(RustTokenType::While) => Some(WhileExpression),
161            Some(RustTokenType::Loop) => Some(LoopExpression),
162            Some(RustTokenType::For) => Some(ForExpression),
163            Some(RustTokenType::Return) => Some(ReturnStatement),
164            Some(RustTokenType::LeftBrace) => Some(Block),
165            _ => None,
166        };
167
168        if let Some(k) = kind {
169            state.incremental_node(k.into(), |state| match k {
170                Function => self.parse_function_body(state),
171                UseItem => self.parse_use_item_body(state),
172                ModuleItem => self.parse_mod_item_body(state),
173                StructItem => self.parse_struct_item_body(state),
174                EnumItem => self.parse_enum_item_body(state),
175                LetStatement => self.parse_let_statement_body(state),
176                IfExpression => self.parse_if_expression_body(state),
177                WhileExpression => self.parse_while_expression_body(state),
178                LoopExpression => self.parse_loop_expression_body(state),
179                ForExpression => self.parse_for_expression_body(state),
180                ReturnStatement => self.parse_return_statement_body(state),
181                Block => self.parse_block_body(state),
182                _ => unreachable!(),
183            })
184        }
185        else {
186            PrattParser::parse(state, 0, self);
187            state.eat(RustTokenType::Semicolon);
188            Ok(())
189        }
190    }
191
192    fn parse_function_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
193        self.parse_function(state)
194    }
195
196    fn parse_use_item_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
197        self.parse_use_item(state)
198    }
199
200    fn parse_mod_item_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
201        self.parse_mod_item(state)
202    }
203
204    fn parse_struct_item_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
205        self.parse_struct_item(state)
206    }
207
208    fn parse_enum_item_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
209        self.parse_enum_item(state)
210    }
211
212    fn parse_let_statement_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
213        self.parse_let_statement(state)
214    }
215
216    fn parse_if_expression_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
217        self.parse_if_expression(state)
218    }
219
220    fn parse_while_expression_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
221        self.parse_while_expression(state)
222    }
223
224    fn parse_loop_expression_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
225        self.parse_loop_expression(state)
226    }
227
228    fn parse_for_expression_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
229        self.parse_for_expression(state)
230    }
231
232    fn parse_return_statement_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
233        self.parse_return_statement(state)
234    }
235
236    fn parse_block_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
237        self.parse_block(state)
238    }
239
240    /// Parses a function definition.
241    fn parse_function<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
242        use crate::lexer::token_type::RustTokenType;
243        let cp = state.checkpoint();
244        state.expect(RustTokenType::Fn).ok();
245        state.expect(RustTokenType::Identifier).ok();
246        self.parse_param_list(state)?;
247        if state.eat(RustTokenType::Arrow) {
248            while state.not_at_end() && !state.at(RustTokenType::LeftBrace) {
249                state.advance()
250            }
251        }
252        self.parse_block(state)?;
253        state.finish_at(cp, crate::parser::element_type::RustElementType::Function);
254        Ok(())
255    }
256
257    fn parse_param_list<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
258        use crate::lexer::RustTokenType::*;
259        let cp = state.checkpoint();
260        state.expect(LeftParen).ok();
261        while state.not_at_end() && !state.at(RightParen) {
262            state.advance()
263        }
264        state.expect(RightParen).ok();
265        state.finish_at(cp, crate::parser::element_type::RustElementType::ParameterList);
266        Ok(())
267    }
268
269    /// Parses a block of statements enclosed in braces.
270    fn parse_block<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
271        use crate::lexer::RustTokenType::*;
272        let cp = state.checkpoint();
273        state.expect(LeftBrace).ok();
274        while state.not_at_end() && !state.at(RightBrace) {
275            self.parse_statement(state)?
276        }
277        state.expect(RightBrace).ok();
278        state.finish_at(cp, crate::parser::element_type::RustElementType::BlockExpression);
279        Ok(())
280    }
281
282    /// Parses a `use` declaration.
283    fn parse_use_item<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
284        let cp = state.checkpoint();
285        state.expect(crate::lexer::RustTokenType::Use).ok();
286        // Simplified path handling
287        while !state.at(crate::lexer::RustTokenType::Semicolon) && state.not_at_end() {
288            state.bump()
289        }
290        state.eat(crate::lexer::RustTokenType::Semicolon);
291        state.finish_at(cp, crate::parser::element_type::RustElementType::UseItem);
292        Ok(())
293    }
294
295    /// Parses a module declaration.
296    fn parse_mod_item<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
297        let cp = state.checkpoint();
298        state.bump(); // mod
299        state.expect(crate::lexer::RustTokenType::Identifier).ok();
300        if state.at(crate::lexer::RustTokenType::LeftBrace) {
301            self.parse_block(state)?
302        }
303        else {
304            state.eat(crate::lexer::RustTokenType::Semicolon);
305        }
306        state.finish_at(cp, crate::parser::element_type::RustElementType::ModuleItem);
307        Ok(())
308    }
309
310    /// Parses a struct definition.
311    fn parse_struct_item<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
312        let cp = state.checkpoint();
313        state.bump(); // struct
314        state.expect(crate::lexer::RustTokenType::Identifier).ok();
315        while state.not_at_end() && !state.at(crate::lexer::RustTokenType::LeftBrace) && !state.at(crate::lexer::RustTokenType::Semicolon) {
316            state.advance()
317        }
318        if state.at(crate::lexer::RustTokenType::LeftBrace) {
319            self.parse_block(state)?
320        }
321        else {
322            state.eat(crate::lexer::RustTokenType::Semicolon);
323        }
324        state.finish_at(cp, crate::parser::element_type::RustElementType::StructItem);
325        Ok(())
326    }
327
328    /// Parses an enum definition.
329    fn parse_enum_item<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
330        let cp = state.checkpoint();
331        state.bump(); // enum
332        state.expect(crate::lexer::RustTokenType::Identifier).ok();
333        self.parse_block(state)?;
334        state.finish_at(cp, crate::parser::element_type::RustElementType::EnumItem);
335        Ok(())
336    }
337
338    /// Parses a `let` statement.
339    fn parse_let_statement<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
340        let cp = state.checkpoint();
341        state.bump(); // let
342        state.expect(crate::lexer::RustTokenType::Identifier).ok();
343        if state.eat(crate::lexer::RustTokenType::Eq) {
344            PrattParser::parse(state, 0, self);
345        }
346        state.eat(crate::lexer::RustTokenType::Semicolon);
347        state.finish_at(cp, crate::parser::element_type::RustElementType::LetStatement);
348        Ok(())
349    }
350
351    /// Parses an `if` expression.
352    fn parse_if_expression<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
353        let cp = state.checkpoint();
354        state.bump(); // if
355        PrattParser::parse(state, 0, self);
356        self.parse_block(state)?;
357        if state.eat(crate::lexer::RustTokenType::Else) {
358            if state.at(crate::lexer::RustTokenType::If) { self.parse_if_expression(state)? } else { self.parse_block(state)? }
359        }
360        state.finish_at(cp, crate::parser::element_type::RustElementType::IfExpression);
361        Ok(())
362    }
363
364    /// Parses a `while` loop.
365    fn parse_while_expression<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
366        let cp = state.checkpoint();
367        state.bump(); // while
368        PrattParser::parse(state, 0, self);
369        self.parse_block(state)?;
370        state.finish_at(cp, crate::parser::element_type::RustElementType::WhileExpression);
371        Ok(())
372    }
373
374    /// Parses a `loop` expression.
375    fn parse_loop_expression<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
376        let cp = state.checkpoint();
377        state.bump(); // loop
378        self.parse_block(state)?;
379        state.finish_at(cp, crate::parser::element_type::RustElementType::LoopExpression);
380        Ok(())
381    }
382
383    /// Parses a `for` loop.
384    fn parse_for_expression<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
385        let cp = state.checkpoint();
386        state.bump(); // for
387        state.expect(crate::lexer::RustTokenType::Identifier).ok();
388        state.expect(crate::lexer::RustTokenType::In).ok();
389        PrattParser::parse(state, 0, self);
390        self.parse_block(state)?;
391        state.finish_at(cp, crate::parser::element_type::RustElementType::ForExpression);
392        Ok(())
393    }
394
395    /// Parses a `return` statement.
396    fn parse_return_statement<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
397        let cp = state.checkpoint();
398        state.bump(); // return
399        if !state.at(crate::lexer::RustTokenType::Semicolon) {
400            PrattParser::parse(state, 0, self);
401        }
402        state.eat(crate::lexer::RustTokenType::Semicolon);
403        state.finish_at(cp, crate::parser::element_type::RustElementType::ReturnStatement);
404        Ok(())
405    }
406}