Skip to main content

oak_rust/parser/
mod.rs

1use crate::{RustLanguage, lexer::RustLexer};
2use oak_core::{
3    GreenNode, OakError, TokenType,
4    parser::{Associativity, ParseCache, ParseOutput, Parser, ParserState, Pratt, PrattParser, binary, parse_with_lexer, unary},
5    source::{Source, TextEdit},
6};
7
8/// Rust element type definitions.
9pub mod element_type;
10pub use self::element_type::RustElementType;
11
12/// A parser for the Rust programming language.
13#[derive(Clone)]
14pub struct RustParser<'config> {
15    /// Reference to the Rust language configuration
16    #[allow(dead_code)]
17    config: &'config RustLanguage,
18}
19
20impl<'config> RustParser<'config> {
21    /// Creates a new Rust parser with the given language configuration.
22    pub fn new(config: &'config RustLanguage) -> Self {
23        Self { config }
24    }
25}
26
27impl<'config> Pratt<RustLanguage> for RustParser<'config> {
28    fn primary<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> &'a GreenNode<'a, RustLanguage> {
29        let cp = state.checkpoint();
30        match state.peek_kind() {
31            Some(crate::lexer::RustTokenType::Identifier) => {
32                state.bump();
33                state.finish_at(cp, crate::parser::element_type::RustElementType::IdentifierExpression)
34            }
35            Some(k) if k.is_literal() => {
36                state.bump();
37                state.finish_at(cp, crate::parser::element_type::RustElementType::LiteralExpression)
38            }
39            Some(crate::lexer::RustTokenType::LeftParen) => {
40                state.bump();
41                PrattParser::parse(state, 0, self);
42                state.expect(crate::lexer::RustTokenType::RightParen).ok();
43                state.finish_at(cp, crate::parser::element_type::RustElementType::ParenthesizedExpression)
44            }
45            _ => {
46                state.bump();
47                state.finish_at(cp, crate::parser::element_type::RustElementType::Error)
48            }
49        }
50    }
51
52    fn prefix<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> &'a GreenNode<'a, RustLanguage> {
53        use crate::{lexer::RustTokenType::*, parser::RustElementType as RE};
54        let kind = match state.peek_kind() {
55            Some(k) => k,
56            None => return self.primary(state),
57        };
58
59        match kind {
60            Minus | Bang | Ampersand | Star => unary(state, kind, 13, RE::UnaryExpression.into(), |s, p| PrattParser::parse(s, p, self)),
61            _ => self.primary(state),
62        }
63    }
64
65    fn infix<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>, left: &'a GreenNode<'a, RustLanguage>, min_precedence: u8) -> Option<&'a GreenNode<'a, RustLanguage>> {
66        use crate::{lexer::RustTokenType::*, parser::RustElementType as RE};
67        let kind = state.peek_kind()?;
68
69        let (prec, assoc) = match kind {
70            Eq | PlusEq | MinusEq | StarEq | SlashEq | PercentEq | AndEq | OrEq | CaretEq | ShlEq | ShrEq => (1, Associativity::Right),
71            DotDot | DotDotEq => (2, Associativity::Left),
72            OrOr => (3, Associativity::Left),
73            AndAnd => (4, Associativity::Left),
74            EqEq | Ne => (5, Associativity::Left),
75            Lt | Le | Gt | Ge => (6, Associativity::Left),
76            Pipe => (7, Associativity::Left),
77            Caret => (8, Associativity::Left),
78            Ampersand => (9, Associativity::Left),
79            Shl | Shr => (10, Associativity::Left),
80            Plus | Minus => (11, Associativity::Left),
81            Star | Slash | Percent => (12, Associativity::Left),
82            LeftParen | LeftBracket | Dot => (14, Associativity::Left),
83            _ => return None,
84        };
85
86        if prec < min_precedence {
87            return None;
88        }
89
90        match kind {
91            LeftParen => {
92                let cp = state.checkpoint();
93                state.push_child(left);
94                state.expect(LeftParen).ok();
95                if !state.at(RightParen) {
96                    loop {
97                        PrattParser::parse(state, 0, self);
98                        if !state.eat(Comma) {
99                            break;
100                        }
101                    }
102                }
103                state.expect(RightParen).ok();
104                Some(state.finish_at(cp, RE::CallExpression))
105            }
106            LeftBracket => {
107                let cp = state.checkpoint();
108                state.push_child(left);
109                state.expect(LeftBracket).ok();
110                PrattParser::parse(state, 0, self);
111                state.expect(RightBracket).ok();
112                Some(state.finish_at(cp, RE::IndexExpression))
113            }
114            Dot => {
115                let cp = state.checkpoint();
116                state.push_child(left);
117                state.expect(Dot).ok();
118                state.expect(Identifier).ok();
119                Some(state.finish_at(cp, RE::MemberExpression))
120            }
121            _ => Some(binary(state, left, kind, prec, assoc, RE::BinaryExpression.into(), |s, p| PrattParser::parse(s, p, self))),
122        }
123    }
124}
125
126impl<'config> Parser<RustLanguage> for RustParser<'config> {
127    fn parse<'a, S: Source + ?Sized>(&self, text: &'a S, edits: &[TextEdit], cache: &'a mut impl ParseCache<RustLanguage>) -> ParseOutput<'a, RustLanguage> {
128        let lexer = RustLexer::new(self.config);
129        parse_with_lexer(&lexer, text, edits, cache, |state| {
130            let cp = state.checkpoint();
131            while state.not_at_end() {
132                if state.current().map(|t| t.kind.is_ignored()).unwrap_or(false) {
133                    state.advance();
134                    continue;
135                }
136                self.parse_statement(state)?
137            }
138            let root = state.finish_at(cp, crate::parser::element_type::RustElementType::SourceFile);
139            Ok(root)
140        })
141    }
142}
143
144impl<'config> RustParser<'config> {
145    /// Parses a single Rust statement or item.
146    fn parse_statement<'a, S: Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
147        use crate::{lexer::RustTokenType, parser::RustElementType::*};
148
149        let kind = match state.peek_kind() {
150            Some(RustTokenType::Fn) => Some(Function),
151            Some(RustTokenType::Use) => Some(UseItem),
152            Some(RustTokenType::Mod) => Some(ModuleItem),
153            Some(RustTokenType::Struct) => Some(StructItem),
154            Some(RustTokenType::Enum) => Some(EnumItem),
155            Some(RustTokenType::Let) => Some(LetStatement),
156            Some(RustTokenType::If) => Some(IfExpression),
157            Some(RustTokenType::While) => Some(WhileExpression),
158            Some(RustTokenType::Loop) => Some(LoopExpression),
159            Some(RustTokenType::For) => Some(ForExpression),
160            Some(RustTokenType::Return) => Some(ReturnStatement),
161            Some(RustTokenType::LeftBrace) => Some(Block),
162            _ => None,
163        };
164
165        if let Some(k) = kind {
166            state.incremental_node(k.into(), |state| match k {
167                Function => self.parse_function_body(state),
168                UseItem => self.parse_use_item_body(state),
169                ModuleItem => self.parse_mod_item_body(state),
170                StructItem => self.parse_struct_item_body(state),
171                EnumItem => self.parse_enum_item_body(state),
172                LetStatement => self.parse_let_statement_body(state),
173                IfExpression => self.parse_if_expression_body(state),
174                WhileExpression => self.parse_while_expression_body(state),
175                LoopExpression => self.parse_loop_expression_body(state),
176                ForExpression => self.parse_for_expression_body(state),
177                ReturnStatement => self.parse_return_statement_body(state),
178                Block => self.parse_block_body(state),
179                _ => unreachable!(),
180            })
181        }
182        else {
183            PrattParser::parse(state, 0, self);
184            state.eat(RustTokenType::Semicolon);
185            Ok(())
186        }
187    }
188
189    fn parse_function_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
190        self.parse_function(state)
191    }
192
193    fn parse_use_item_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
194        self.parse_use_item(state)
195    }
196
197    fn parse_mod_item_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
198        self.parse_mod_item(state)
199    }
200
201    fn parse_struct_item_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
202        self.parse_struct_item(state)
203    }
204
205    fn parse_enum_item_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
206        self.parse_enum_item(state)
207    }
208
209    fn parse_let_statement_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
210        self.parse_let_statement(state)
211    }
212
213    fn parse_if_expression_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
214        self.parse_if_expression(state)
215    }
216
217    fn parse_while_expression_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
218        self.parse_while_expression(state)
219    }
220
221    fn parse_loop_expression_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
222        self.parse_loop_expression(state)
223    }
224
225    fn parse_for_expression_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
226        self.parse_for_expression(state)
227    }
228
229    fn parse_return_statement_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
230        self.parse_return_statement(state)
231    }
232
233    fn parse_block_body<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
234        self.parse_block(state)
235    }
236
237    /// Parses a function definition.
238    fn parse_function<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
239        use crate::lexer::token_type::RustTokenType;
240        let cp = state.checkpoint();
241        state.expect(RustTokenType::Fn).ok();
242        state.expect(RustTokenType::Identifier).ok();
243        self.parse_param_list(state)?;
244        if state.eat(RustTokenType::Arrow) {
245            while state.not_at_end() && !state.at(RustTokenType::LeftBrace) {
246                state.advance()
247            }
248        }
249        self.parse_block(state)?;
250        state.finish_at(cp, crate::parser::element_type::RustElementType::Function);
251        Ok(())
252    }
253
254    fn parse_param_list<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
255        use crate::lexer::RustTokenType::*;
256        let cp = state.checkpoint();
257        state.expect(LeftParen).ok();
258        while state.not_at_end() && !state.at(RightParen) {
259            state.advance()
260        }
261        state.expect(RightParen).ok();
262        state.finish_at(cp, crate::parser::element_type::RustElementType::ParameterList);
263        Ok(())
264    }
265
266    /// Parses a block of statements enclosed in braces.
267    fn parse_block<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
268        use crate::lexer::RustTokenType::*;
269        let cp = state.checkpoint();
270        state.expect(LeftBrace).ok();
271        while state.not_at_end() && !state.at(RightBrace) {
272            self.parse_statement(state)?
273        }
274        state.expect(RightBrace).ok();
275        state.finish_at(cp, crate::parser::element_type::RustElementType::BlockExpression);
276        Ok(())
277    }
278
279    /// Parses a `use` declaration.
280    fn parse_use_item<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
281        let cp = state.checkpoint();
282        state.expect(crate::lexer::RustTokenType::Use).ok();
283        // Simplified path handling
284        while !state.at(crate::lexer::RustTokenType::Semicolon) && state.not_at_end() {
285            state.bump()
286        }
287        state.eat(crate::lexer::RustTokenType::Semicolon);
288        state.finish_at(cp, crate::parser::element_type::RustElementType::UseItem);
289        Ok(())
290    }
291
292    /// Parses a module declaration.
293    fn parse_mod_item<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
294        let cp = state.checkpoint();
295        state.bump(); // mod
296        state.expect(crate::lexer::RustTokenType::Identifier).ok();
297        if state.at(crate::lexer::RustTokenType::LeftBrace) {
298            self.parse_block(state)?
299        }
300        else {
301            state.eat(crate::lexer::RustTokenType::Semicolon);
302        }
303        state.finish_at(cp, crate::parser::element_type::RustElementType::ModuleItem);
304        Ok(())
305    }
306
307    /// Parses a struct definition.
308    fn parse_struct_item<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
309        let cp = state.checkpoint();
310        state.bump(); // struct
311        state.expect(crate::lexer::RustTokenType::Identifier).ok();
312        while state.not_at_end() && !state.at(crate::lexer::RustTokenType::LeftBrace) && !state.at(crate::lexer::RustTokenType::Semicolon) {
313            state.advance()
314        }
315        if state.at(crate::lexer::RustTokenType::LeftBrace) {
316            self.parse_block(state)?
317        }
318        else {
319            state.eat(crate::lexer::RustTokenType::Semicolon);
320        }
321        state.finish_at(cp, crate::parser::element_type::RustElementType::StructItem);
322        Ok(())
323    }
324
325    /// Parses an enum definition.
326    fn parse_enum_item<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
327        let cp = state.checkpoint();
328        state.bump(); // enum
329        state.expect(crate::lexer::RustTokenType::Identifier).ok();
330        self.parse_block(state)?;
331        state.finish_at(cp, crate::parser::element_type::RustElementType::EnumItem);
332        Ok(())
333    }
334
335    /// Parses a `let` statement.
336    fn parse_let_statement<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
337        let cp = state.checkpoint();
338        state.bump(); // let
339        state.expect(crate::lexer::RustTokenType::Identifier).ok();
340        if state.eat(crate::lexer::RustTokenType::Eq) {
341            PrattParser::parse(state, 0, self);
342        }
343        state.eat(crate::lexer::RustTokenType::Semicolon);
344        state.finish_at(cp, crate::parser::element_type::RustElementType::LetStatement);
345        Ok(())
346    }
347
348    /// Parses an `if` expression.
349    fn parse_if_expression<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
350        let cp = state.checkpoint();
351        state.bump(); // if
352        PrattParser::parse(state, 0, self);
353        self.parse_block(state)?;
354        if state.eat(crate::lexer::RustTokenType::Else) {
355            if state.at(crate::lexer::RustTokenType::If) { self.parse_if_expression(state)? } else { self.parse_block(state)? }
356        }
357        state.finish_at(cp, crate::parser::element_type::RustElementType::IfExpression);
358        Ok(())
359    }
360
361    /// Parses a `while` loop.
362    fn parse_while_expression<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
363        let cp = state.checkpoint();
364        state.bump(); // while
365        PrattParser::parse(state, 0, self);
366        self.parse_block(state)?;
367        state.finish_at(cp, crate::parser::element_type::RustElementType::WhileExpression);
368        Ok(())
369    }
370
371    /// Parses a `loop` expression.
372    fn parse_loop_expression<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
373        let cp = state.checkpoint();
374        state.bump(); // loop
375        self.parse_block(state)?;
376        state.finish_at(cp, crate::parser::element_type::RustElementType::LoopExpression);
377        Ok(())
378    }
379
380    /// Parses a `for` loop.
381    fn parse_for_expression<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
382        let cp = state.checkpoint();
383        state.bump(); // for
384        state.expect(crate::lexer::RustTokenType::Identifier).ok();
385        state.expect(crate::lexer::RustTokenType::In).ok();
386        PrattParser::parse(state, 0, self);
387        self.parse_block(state)?;
388        state.finish_at(cp, crate::parser::element_type::RustElementType::ForExpression);
389        Ok(())
390    }
391
392    /// Parses a `return` statement.
393    fn parse_return_statement<'a, S: oak_core::source::Source + ?Sized>(&self, state: &mut ParserState<'a, RustLanguage, S>) -> Result<(), OakError> {
394        let cp = state.checkpoint();
395        state.bump(); // return
396        if !state.at(crate::lexer::RustTokenType::Semicolon) {
397            PrattParser::parse(state, 0, self);
398        }
399        state.eat(crate::lexer::RustTokenType::Semicolon);
400        state.finish_at(cp, crate::parser::element_type::RustElementType::ReturnStatement);
401        Ok(())
402    }
403}