//! Rust parser module (`oak_rust/parser/mod.rs`).

1use crate::{RustLanguage, lexer::RustLexer};
2use oak_core::{
3    GreenNode, OakError, TokenType,
4    parser::{Associativity, ParseCache, ParseOutput, Parser, ParserState, Pratt, PrattParser, binary, parse_with_lexer, unary},
5    source::{Source, TextEdit},
6};
7
8/// Rust element type definitions.
9pub mod element_type;
10pub use self::element_type::RustElementType;
11
12/// A parser for the Rust programming language.
13#[derive(Clone)]
14pub struct RustParser<'config> {
15    /// Reference to the Rust language configuration
16    #[allow(dead_code)]
17    config: &'config RustLanguage,
18}
19
20impl<'config> RustParser<'config> {
21    /// Creates a new Rust parser with the given language configuration.
22    pub fn new(config: &'config RustLanguage) -> Self {
23        Self { config }
24    }
25}
26
27impl<'config> Pratt<RustLanguage> for RustParser<'config> {
28    fn primary<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> &'a GreenNode<'a, RustLanguage> {
29        let cp = state.checkpoint();
30        match state.peek_kind() {
31            Some(crate::lexer::RustTokenType::Identifier) => {
32                state.bump();
33                state.finish_at(cp, crate::parser::element_type::RustElementType::IdentifierExpression)
34            }
35            Some(k) if k.is_literal() => {
36                state.bump();
37                state.finish_at(cp, crate::parser::element_type::RustElementType::LiteralExpression)
38            }
39            Some(crate::lexer::RustTokenType::LeftParen) => {
40                state.bump();
41                PrattParser::parse(state, 0, self);
42                state.expect(crate::lexer::RustTokenType::RightParen).ok();
43                state.finish_at(cp, crate::parser::element_type::RustElementType::ParenthesizedExpression)
44            }
45            _ => {
46                state.bump();
47                state.finish_at(cp, crate::parser::element_type::RustElementType::Error)
48            }
49        }
50    }
51
52    fn prefix<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> &'a GreenNode<'a, RustLanguage> {
53        use crate::{lexer::RustTokenType::*, parser::RustElementType as RE};
54        let kind = match state.peek_kind() {
55            Some(k) => k,
56            None => return self.primary(state),
57        };
58
59        match kind {
60            Minus | Bang | Ampersand | Star => unary(state, kind, 13, RE::UnaryExpression.into(), |s, p| PrattParser::parse(s, p, self)),
61            _ => self.primary(state),
62        }
63    }
64
65    fn infix<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>, left: &'a GreenNode<'a, RustLanguage>, min_precedence: u8) -> Option<&'a GreenNode<'a, RustLanguage>> {
66        use crate::{lexer::RustTokenType::*, parser::RustElementType as RE};
67        let kind = state.peek_kind()?;
68
69        let (prec, assoc) = match kind {
70            Eq | PlusEq | MinusEq | StarEq | SlashEq | PercentEq | AndEq | OrEq | CaretEq | ShlEq | ShrEq => (1, Associativity::Right),
71            DotDot | DotDotEq => (2, Associativity::Left),
72            OrOr => (3, Associativity::Left),
73            AndAnd => (4, Associativity::Left),
74            EqEq | Ne => (5, Associativity::Left),
75            Lt | Le | Gt | Ge => (6, Associativity::Left),
76            Pipe => (7, Associativity::Left),
77            Caret => (8, Associativity::Left),
78            Ampersand => (9, Associativity::Left),
79            Shl | Shr => (10, Associativity::Left),
80            Plus | Minus => (11, Associativity::Left),
81            Star | Slash | Percent => (12, Associativity::Left),
82            LeftParen | LeftBracket | Dot => (14, Associativity::Left),
83            _ => return None,
84        };
85
86        if prec < min_precedence {
87            return None;
88        }
89
90        match kind {
91            LeftParen => {
92                let cp = state.checkpoint();
93                state.push_child(left);
94                state.expect(LeftParen).ok();
95                if !state.at(RightParen) {
96                    loop {
97                        PrattParser::parse(state, 0, self);
98                        if !state.eat(Comma) {
99                            break;
100                        }
101                    }
102                }
103                state.expect(RightParen).ok();
104                Some(state.finish_at(cp, RE::CallExpression))
105            }
106            LeftBracket => {
107                let cp = state.checkpoint();
108                state.push_child(left);
109                state.expect(LeftBracket).ok();
110                PrattParser::parse(state, 0, self);
111                state.expect(RightBracket).ok();
112                Some(state.finish_at(cp, RE::IndexExpression))
113            }
114            Dot => {
115                let cp = state.checkpoint();
116                state.push_child(left);
117                state.expect(Dot).ok();
118                state.expect(Identifier).ok();
119                Some(state.finish_at(cp, RE::MemberExpression))
120            }
121            _ => Some(binary(state, left, kind, prec, assoc, RE::BinaryExpression.into(), |s, p| PrattParser::parse(s, p, self))),
122        }
123    }
124}
125
126impl<'config> Parser<RustLanguage> for RustParser<'config> {
127    fn parse<'a, S: Source + ?Sized>(&self, text: &'a S, edits: &[TextEdit], cache: &'a mut impl ParseCache<RustLanguage>) -> ParseOutput<'a, RustLanguage> {
128        let lexer = RustLexer::new(self.config);
129        parse_with_lexer(&lexer, text, edits, cache, |state| {
130            let cp = state.checkpoint();
131            while state.not_at_end() {
132                if state.current().map(|t| t.kind.is_ignored()).unwrap_or(false) {
133                    state.advance();
134                    continue;
135                }
136                self.parse_statement(state)?
137            }
138            let root = state.finish_at(cp, crate::parser::element_type::RustElementType::SourceFile);
139            Ok(root)
140        })
141    }
142}
143
144impl<'config> RustParser<'config> {
145    /// Parses a single Rust statement or item.
146    ///
147    /// This method identifies the type of statement or item at the current position
148    /// and dispatches to the appropriate parsing method. If no specific statement
149    /// type is recognized, it parses an expression followed by a semicolon.
150    ///
151    /// # Arguments
152    /// * `state` - The current parser state
153    ///
154    /// # Returns
155    /// * `Result<(), OakError>` - Ok if parsing succeeds, Err otherwise
156    fn parse_statement<'a, S: Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
157        use crate::{lexer::RustTokenType, parser::RustElementType::*};
158
159        let kind = match state.peek_kind() {
160            Some(RustTokenType::Fn) => Some(Function),
161            Some(RustTokenType::Use) => Some(UseItem),
162            Some(RustTokenType::Mod) => Some(ModuleItem),
163            Some(RustTokenType::Struct) => Some(StructItem),
164            Some(RustTokenType::Enum) => Some(EnumItem),
165            Some(RustTokenType::Let) => Some(LetStatement),
166            Some(RustTokenType::If) => Some(IfExpression),
167            Some(RustTokenType::While) => Some(WhileExpression),
168            Some(RustTokenType::Loop) => Some(LoopExpression),
169            Some(RustTokenType::For) => Some(ForExpression),
170            Some(RustTokenType::Return) => Some(ReturnStatement),
171            Some(RustTokenType::LeftBrace) => Some(Block),
172            _ => None,
173        };
174
175        if let Some(k) = kind {
176            state.incremental_node(k.into(), |state| match k {
177                Function => self.parse_function_body(state),
178                UseItem => self.parse_use_item_body(state),
179                ModuleItem => self.parse_mod_item_body(state),
180                StructItem => self.parse_struct_item_body(state),
181                EnumItem => self.parse_enum_item_body(state),
182                LetStatement => self.parse_let_statement_body(state),
183                IfExpression => self.parse_if_expression_body(state),
184                WhileExpression => self.parse_while_expression_body(state),
185                LoopExpression => self.parse_loop_expression_body(state),
186                ForExpression => self.parse_for_expression_body(state),
187                ReturnStatement => self.parse_return_statement_body(state),
188                Block => self.parse_block_body(state),
189                _ => unreachable!(),
190            })
191        }
192        else {
193            PrattParser::parse(state, 0, self);
194            state.eat(RustTokenType::Semicolon);
195            Ok(())
196        }
197    }
198
199    fn parse_function_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
200        self.parse_function(state)
201    }
202
203    fn parse_use_item_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
204        self.parse_use_item(state)
205    }
206
207    fn parse_mod_item_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
208        self.parse_mod_item(state)
209    }
210
211    fn parse_struct_item_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
212        self.parse_struct_item(state)
213    }
214
215    fn parse_enum_item_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
216        self.parse_enum_item(state)
217    }
218
219    fn parse_let_statement_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
220        self.parse_let_statement(state)
221    }
222
223    fn parse_if_expression_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
224        self.parse_if_expression(state)
225    }
226
227    fn parse_while_expression_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
228        self.parse_while_expression(state)
229    }
230
231    fn parse_loop_expression_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
232        self.parse_loop_expression(state)
233    }
234
235    fn parse_for_expression_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
236        self.parse_for_expression(state)
237    }
238
239    fn parse_return_statement_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
240        self.parse_return_statement(state)
241    }
242
243    fn parse_block_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
244        self.parse_block(state)
245    }
246
247    /// Parses a function definition.
248    ///
249    /// This method parses a complete Rust function definition, including the function
250    /// keyword, name, parameters, return type (if specified), and body.
251    ///
252    /// # Arguments
253    /// * `state` - The current parser state
254    ///
255    /// # Returns
256    /// * `Result<(), OakError>` - Ok if parsing succeeds, Err otherwise
257    fn parse_function<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
258        use crate::lexer::RustTokenType;
259        let cp = state.checkpoint();
260        state.expect(RustTokenType::Fn).ok();
261        state.expect(RustTokenType::Identifier).ok();
262        self.parse_param_list(state)?;
263        if state.eat(RustTokenType::Arrow) {
264            while state.not_at_end() && !state.at(RustTokenType::LeftBrace) {
265                state.advance()
266            }
267        }
268        self.parse_block(state)?;
269        state.finish_at(cp, crate::parser::element_type::RustElementType::Function);
270        Ok(())
271    }
272
273    /// Parses a function parameter list.
274    ///
275    /// This method parses the parameters of a function definition, enclosed in parentheses.
276    ///
277    /// # Arguments
278    /// * `state` - The current parser state
279    ///
280    /// # Returns
281    /// * `Result<(), OakError>` - Ok if parsing succeeds, Err otherwise
282    fn parse_param_list<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
283        use crate::lexer::RustTokenType::*;
284        let cp = state.checkpoint();
285        state.expect(LeftParen).ok();
286        while state.not_at_end() && !state.at(RightParen) {
287            state.advance()
288        }
289        state.expect(RightParen).ok();
290        state.finish_at(cp, crate::parser::element_type::RustElementType::ParameterList);
291        Ok(())
292    }
293
294    /// Parses a block of statements enclosed in braces.
295    ///
296    /// This method parses a block of code enclosed in curly braces, which can contain
297    /// multiple statements and nested blocks.
298    ///
299    /// # Arguments
300    /// * `state` - The current parser state
301    ///
302    /// # Returns
303    /// * `Result<(), OakError>` - Ok if parsing succeeds, Err otherwise
304    fn parse_block<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
305        use crate::lexer::RustTokenType::*;
306        let cp = state.checkpoint();
307        state.expect(LeftBrace).ok();
308        while state.not_at_end() && !state.at(RightBrace) {
309            self.parse_statement(state)?
310        }
311        state.expect(RightBrace).ok();
312        state.finish_at(cp, crate::parser::element_type::RustElementType::BlockExpression);
313        Ok(())
314    }
315
316    /// Parses a `use` declaration.
317    fn parse_use_item<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
318        let cp = state.checkpoint();
319        state.expect(crate::lexer::RustTokenType::Use).ok();
320        // Simplified path handling
321        while !state.at(crate::lexer::RustTokenType::Semicolon) && state.not_at_end() {
322            state.bump()
323        }
324        state.eat(crate::lexer::RustTokenType::Semicolon);
325        state.finish_at(cp, crate::parser::element_type::RustElementType::UseItem);
326        Ok(())
327    }
328
329    /// Parses a module declaration.
330    fn parse_mod_item<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
331        let cp = state.checkpoint();
332        state.bump(); // mod
333        state.expect(crate::lexer::RustTokenType::Identifier).ok();
334        if state.at(crate::lexer::RustTokenType::LeftBrace) {
335            self.parse_block(state)?
336        }
337        else {
338            state.eat(crate::lexer::RustTokenType::Semicolon);
339        }
340        state.finish_at(cp, crate::parser::element_type::RustElementType::ModuleItem);
341        Ok(())
342    }
343
344    /// Parses a struct definition.
345    fn parse_struct_item<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
346        let cp = state.checkpoint();
347        state.bump(); // struct
348        state.expect(crate::lexer::RustTokenType::Identifier).ok();
349        while state.not_at_end() && !state.at(crate::lexer::RustTokenType::LeftBrace) && !state.at(crate::lexer::RustTokenType::Semicolon) {
350            state.advance()
351        }
352        if state.at(crate::lexer::RustTokenType::LeftBrace) {
353            self.parse_block(state)?
354        }
355        else {
356            state.eat(crate::lexer::RustTokenType::Semicolon);
357        }
358        state.finish_at(cp, crate::parser::element_type::RustElementType::StructItem);
359        Ok(())
360    }
361
362    /// Parses an enum definition.
363    fn parse_enum_item<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
364        let cp = state.checkpoint();
365        state.bump(); // enum
366        state.expect(crate::lexer::RustTokenType::Identifier).ok();
367        self.parse_block(state)?;
368        state.finish_at(cp, crate::parser::element_type::RustElementType::EnumItem);
369        Ok(())
370    }
371
372    /// Parses a `let` statement.
373    fn parse_let_statement<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
374        let cp = state.checkpoint();
375        state.bump(); // let
376        state.expect(crate::lexer::RustTokenType::Identifier).ok();
377        if state.eat(crate::lexer::RustTokenType::Eq) {
378            PrattParser::parse(state, 0, self);
379        }
380        state.eat(crate::lexer::RustTokenType::Semicolon);
381        state.finish_at(cp, crate::parser::element_type::RustElementType::LetStatement);
382        Ok(())
383    }
384
385    /// Parses an `if` expression.
386    fn parse_if_expression<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
387        let cp = state.checkpoint();
388        state.bump(); // if
389        PrattParser::parse(state, 0, self);
390        self.parse_block(state)?;
391        if state.eat(crate::lexer::RustTokenType::Else) {
392            if state.at(crate::lexer::RustTokenType::If) { self.parse_if_expression(state)? } else { self.parse_block(state)? }
393        }
394        state.finish_at(cp, crate::parser::element_type::RustElementType::IfExpression);
395        Ok(())
396    }
397
398    /// Parses a `while` loop.
399    fn parse_while_expression<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
400        let cp = state.checkpoint();
401        state.bump(); // while
402        PrattParser::parse(state, 0, self);
403        self.parse_block(state)?;
404        state.finish_at(cp, crate::parser::element_type::RustElementType::WhileExpression);
405        Ok(())
406    }
407
408    /// Parses a `loop` expression.
409    fn parse_loop_expression<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
410        let cp = state.checkpoint();
411        state.bump(); // loop
412        self.parse_block(state)?;
413        state.finish_at(cp, crate::parser::element_type::RustElementType::LoopExpression);
414        Ok(())
415    }
416
417    /// Parses a `for` loop.
418    fn parse_for_expression<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
419        let cp = state.checkpoint();
420        state.bump(); // for
421        state.expect(crate::lexer::RustTokenType::Identifier).ok();
422        state.expect(crate::lexer::RustTokenType::In).ok();
423        PrattParser::parse(state, 0, self);
424        self.parse_block(state)?;
425        state.finish_at(cp, crate::parser::element_type::RustElementType::ForExpression);
426        Ok(())
427    }
428
429    /// Parses a `return` statement.
430    fn parse_return_statement<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
431        let cp = state.checkpoint();
432        state.bump(); // return
433        if !state.at(crate::lexer::RustTokenType::Semicolon) {
434            PrattParser::parse(state, 0, self);
435        }
436        state.eat(crate::lexer::RustTokenType::Semicolon);
437        state.finish_at(cp, crate::parser::element_type::RustElementType::ReturnStatement);
438        Ok(())
439    }
440}