Skip to main content

oak_cpp/parser/
mod.rs

1#![doc = include_str!("readme.md")]
2/// Element type definition.
3pub mod element_type;
4pub use element_type::CppElementType;
5
6use crate::{
7    language::CppLanguage,
8    lexer::{CppLexer, CppTokenType},
9};
10use oak_core::{
11    GreenNode, OakError,
12    parser::{Associativity, ParseCache, ParseOutput, Parser, ParserState, Pratt, PrattParser, parse_with_lexer},
13    source::{Source, TextEdit},
14};
15
16pub(crate) type State<'a, S> = ParserState<'a, CppLanguage, S>;
17
18/// Parser for the C++ language.
19///
20/// This parser transforms a stream of tokens into a green tree of C++ syntax nodes,
21/// using a combination of top-down recursive descent and Pratt parsing for expressions.
22pub struct CppParser<'config> {
23    pub(crate) config: &'config CppLanguage,
24}
25
26impl<'config> CppParser<'config> {
27    /// Creates a new `CppParser` with the given configuration.
28    pub fn new(config: &'config CppLanguage) -> Self {
29        Self { config }
30    }
31
32    /// Parses a single C++ statement.
33    ///
34    /// This includes keywords, compound statements, preprocessor directives,
35    /// and expressions followed by a semicolon.
36    fn parse_statement<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
37        use crate::lexer::CppTokenType::*;
38        self.skip_trivia(state);
39        match state.peek_kind() {
40            Some(LeftBrace) => self.parse_compound_statement(state)?,
41            Some(Preprocessor) => {
42                // Skip preprocessor directives
43                while state.not_at_end() && !state.at(Newline) {
44                    state.bump();
45                }
46            }
47            _ => {
48                // Skip any tokens until semicolon or brace
49                while state.not_at_end() && !state.at(Semicolon) && !state.at(LeftBrace) && !state.at(RightBrace) {
50                    state.bump();
51                }
52                if state.at(Semicolon) {
53                    state.bump();
54                }
55                else if state.at(LeftBrace) {
56                    self.parse_compound_statement(state)?;
57                }
58            }
59        }
60        Ok(())
61    }
62
63    /// Skips trivia tokens (whitespace and comments).
64    fn skip_trivia<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) {
65        use crate::lexer::CppTokenType::*;
66        while let Some(kind) = state.peek_kind() {
67            if matches!(kind, Whitespace | Newline | Comment) {
68                state.bump();
69            }
70            else {
71                break;
72            }
73        }
74    }
75
76    /// Parses a compound statement (a block of statements enclosed in braces).
77    fn parse_compound_statement<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
78        let cp = state.checkpoint();
79        if !state.eat(CppTokenType::LeftBrace) {
80            // Skip until right brace or end of file
81            while state.not_at_end() && !state.at(CppTokenType::RightBrace) {
82                state.bump();
83            }
84            if state.at(CppTokenType::RightBrace) {
85                state.bump();
86            }
87            state.finish_at(cp, CppElementType::CompoundStatement);
88            return Ok(());
89        }
90
91        while state.not_at_end() && !state.at(CppTokenType::RightBrace) {
92            self.parse_statement(state)?;
93        }
94
95        if !state.eat(CppTokenType::RightBrace) {
96            // Skip until end of file or next statement
97            while state.not_at_end() && !state.at(CppTokenType::Semicolon) && !state.at(CppTokenType::LeftBrace) {
98                state.bump();
99            }
100        }
101
102        state.finish_at(cp, CppElementType::CompoundStatement);
103        Ok(())
104    }
105
106    /// Parses an if statement.
107    fn parse_if_statement<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
108        let cp = state.checkpoint();
109        state.bump(); // if
110        self.skip_trivia(state);
111        state.expect(CppTokenType::LeftParen).ok();
112        while state.not_at_end() && !state.at(CppTokenType::RightParen) {
113            state.bump();
114        }
115        state.expect(CppTokenType::RightParen).ok();
116        self.parse_statement(state)?;
117        self.skip_trivia(state);
118        if state.at(CppTokenType::Keyword) {
119            state.bump();
120            self.parse_statement(state)?;
121        }
122        state.finish_at(cp, CppElementType::IfStatement);
123        Ok(())
124    }
125
126    /// Parses a while statement.
127    fn parse_while_statement<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
128        let cp = state.checkpoint();
129        state.bump(); // while
130        self.skip_trivia(state);
131        state.expect(CppTokenType::LeftParen).ok();
132        let expr = PrattParser::parse(state, 0, self);
133        state.push_child(expr);
134        state.expect(CppTokenType::RightParen).ok();
135        self.parse_statement(state)?;
136        state.finish_at(cp, CppElementType::WhileStatement);
137        Ok(())
138    }
139
140    /// Parses a for statement.
141    fn parse_for_statement<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
142        let cp = state.checkpoint();
143        state.bump(); // for
144        self.skip_trivia(state);
145        state.expect(CppTokenType::LeftParen).ok();
146
147        // Initialize
148        self.skip_trivia(state);
149        if !state.at(CppTokenType::Semicolon) {
150            let expr = PrattParser::parse(state, 0, self);
151            state.push_child(expr);
152        }
153        state.expect(CppTokenType::Semicolon).ok();
154
155        // Condition
156        self.skip_trivia(state);
157        if !state.at(CppTokenType::Semicolon) {
158            let expr = PrattParser::parse(state, 0, self);
159            state.push_child(expr);
160        }
161        state.expect(CppTokenType::Semicolon).ok();
162
163        // Increment
164        self.skip_trivia(state);
165        if !state.at(CppTokenType::RightParen) {
166            let expr = PrattParser::parse(state, 0, self);
167            state.push_child(expr);
168        }
169        state.expect(CppTokenType::RightParen).ok();
170
171        self.parse_statement(state)?;
172        state.finish_at(cp, CppElementType::ForStatement);
173        Ok(())
174    }
175
176    /// Parses a return statement.
177    fn parse_return_statement<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
178        let cp = state.checkpoint();
179        state.bump(); // return
180        self.skip_trivia(state);
181        if !state.at(CppTokenType::Semicolon) {
182            let expr = PrattParser::parse(state, 0, self);
183            state.push_child(expr);
184        }
185        state.eat(CppTokenType::Semicolon);
186        state.finish_at(cp, CppElementType::ReturnStatement);
187        Ok(())
188    }
189
190    /// Parses a declaration statement.
191    fn parse_declaration<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
192        let cp = state.checkpoint();
193
194        // Parse type specifiers and qualifiers
195        while state.not_at_end() && !state.at(CppTokenType::Semicolon) && !state.at(CppTokenType::LeftParen) {
196            state.bump();
197        }
198
199        // Check if it's a function declaration
200        if state.at(CppTokenType::LeftParen) {
201            // Function declaration
202            self.parse_function_definition(state)?;
203        }
204        else {
205            // Variable declaration
206            while state.not_at_end() && !state.at(CppTokenType::Semicolon) {
207                state.bump();
208            }
209            state.eat(CppTokenType::Semicolon);
210            state.finish_at(cp, CppElementType::DeclarationStatement);
211        }
212
213        Ok(())
214    }
215
216    /// Parses a function definition.
217    fn parse_function_definition<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
218        let cp = state.checkpoint();
219
220        // Parse function name
221        if state.at(CppTokenType::Identifier) {
222            state.bump();
223        }
224
225        // Parse parameters
226        self.skip_trivia(state);
227        state.expect(CppTokenType::LeftParen).ok();
228        while state.not_at_end() && !state.at(CppTokenType::RightParen) {
229            self.skip_trivia(state);
230            // Parse parameter
231            while state.not_at_end() && !state.at(CppTokenType::Comma) && !state.at(CppTokenType::RightParen) {
232                state.bump();
233            }
234            if state.at(CppTokenType::Comma) {
235                state.bump();
236            }
237        }
238        state.expect(CppTokenType::RightParen).ok();
239
240        // Parse function body
241        self.skip_trivia(state);
242        if state.at(CppTokenType::LeftBrace) {
243            self.parse_compound_statement(state)?;
244        }
245
246        state.finish_at(cp, CppElementType::FunctionDefinition);
247        Ok(())
248    }
249
250    /// Parses a class definition.
251    fn parse_class_definition<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
252        let cp = state.checkpoint();
253        state.bump(); // class/struct/enum
254
255        // Parse class name
256        self.skip_trivia(state);
257        if state.at(CppTokenType::Identifier) {
258            state.bump();
259        }
260
261        // Parse template parameters (if any)
262        self.skip_trivia(state);
263        if state.at(CppTokenType::Less) {
264            while state.not_at_end() && !state.at(CppTokenType::Greater) {
265                state.bump();
266            }
267            state.eat(CppTokenType::Greater);
268        }
269
270        // Parse base classes (if any)
271        self.skip_trivia(state);
272        if state.at(CppTokenType::Colon) {
273            state.bump();
274            while state.not_at_end() && !state.at(CppTokenType::LeftBrace) {
275                state.bump();
276            }
277        }
278
279        // Parse class body
280        self.skip_trivia(state);
281        if state.at(CppTokenType::LeftBrace) {
282            state.bump();
283            while state.not_at_end() && !state.at(CppTokenType::RightBrace) {
284                // Parse class members
285                self.skip_trivia(state);
286                if state.at(CppTokenType::Keyword) {
287                    state.bump();
288                    if state.at(CppTokenType::Colon) {
289                        state.bump();
290                    }
291                }
292                // Parse member declaration
293                while state.not_at_end() && !state.at(CppTokenType::Semicolon) && !state.at(CppTokenType::RightBrace) {
294                    state.bump();
295                }
296                state.eat(CppTokenType::Semicolon);
297            }
298            state.expect(CppTokenType::RightBrace).ok();
299        }
300
301        state.eat(CppTokenType::Semicolon);
302        state.finish_at(cp, CppElementType::ClassDefinition);
303        Ok(())
304    }
305
306    /// Parses a namespace definition.
307    fn parse_namespace_definition<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
308        let cp = state.checkpoint();
309        state.bump(); // namespace
310
311        // Parse namespace name
312        self.skip_trivia(state);
313        if state.at(CppTokenType::Identifier) {
314            state.bump();
315            // Handle nested namespaces (::)
316            while state.at(CppTokenType::Scope) {
317                state.bump();
318                if state.at(CppTokenType::Identifier) {
319                    state.bump();
320                }
321            }
322        }
323
324        // Parse namespace body
325        self.skip_trivia(state);
326        if state.at(CppTokenType::LeftBrace) {
327            state.bump();
328            while state.not_at_end() && !state.at(CppTokenType::RightBrace) {
329                self.parse_statement(state)?;
330            }
331            state.expect(CppTokenType::RightBrace).ok();
332        }
333
334        state.finish_at(cp, CppElementType::NamespaceDefinition);
335        Ok(())
336    }
337}
338
339impl<'config> Parser<CppLanguage> for CppParser<'config> {
340    /// Parses the entire C++ source file.
341    fn parse<'a, S: Source + ?Sized>(&self, text: &'a S, edits: &[TextEdit], cache: &'a mut impl ParseCache<CppLanguage>) -> ParseOutput<'a, CppLanguage> {
342        let lexer = CppLexer::new(self.config);
343        parse_with_lexer(&lexer, text, edits, cache, |state| {
344            let cp = state.checkpoint();
345            while state.not_at_end() {
346                self.parse_statement(state)?
347            }
348            Ok(state.finish_at(cp, CppElementType::SourceFile))
349        })
350    }
351}
352
353impl<'config> Pratt<CppLanguage> for CppParser<'config> {
354    /// Parses a primary expression (e.g., identifiers, literals, parenthesized expressions).
355    fn primary<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> &'a GreenNode<'a, CppLanguage> {
356        use crate::lexer::CppTokenType::*;
357        self.skip_trivia(state);
358        let cp = state.checkpoint();
359        match state.peek_kind() {
360            Some(Identifier) => {
361                state.bump();
362                state.finish_at(cp, CppElementType::Token(Identifier))
363            }
364            Some(IntegerLiteral) | Some(FloatLiteral) | Some(CharacterLiteral) | Some(StringLiteral) | Some(BooleanLiteral) => {
365                state.bump();
366                state.finish_at(cp, CppElementType::ExpressionStatement)
367            }
368            Some(LeftParen) => {
369                state.bump();
370                let expr = PrattParser::parse(state, 0, self);
371                state.push_child(expr);
372                self.skip_trivia(state);
373                state.expect(RightParen).ok();
374                state.finish_at(cp, CppElementType::ExpressionStatement)
375            }
376            _ => {
377                state.bump();
378                state.finish_at(cp, CppElementType::Error)
379            }
380        }
381    }
382
383    fn prefix<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> &'a GreenNode<'a, CppLanguage> {
384        self.primary(state)
385    }
386
387    fn infix<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>, left: &'a GreenNode<'a, CppLanguage>, min_precedence: u8) -> Option<&'a GreenNode<'a, CppLanguage>> {
388        use crate::lexer::CppTokenType::*;
389        self.skip_trivia(state);
390        let kind = state.peek_kind()?;
391
392        let (prec, assoc) = match kind {
393            Assign | PlusAssign | MinusAssign | StarAssign | SlashAssign | PercentAssign | AndAssign | OrAssign | XorAssign | LeftShiftAssign | RightShiftAssign => (1, Associativity::Right),
394            LogicalOr => (2, Associativity::Left),
395            LogicalAnd => (3, Associativity::Left),
396            Equal | NotEqual | Less | Greater | LessEqual | GreaterEqual => (4, Associativity::Left),
397            Plus | Minus => (10, Associativity::Left),
398            Star | Slash | Percent => (11, Associativity::Left),
399            LeftParen | LeftBracket | Dot | Arrow => (15, Associativity::Left),
400            Scope => (16, Associativity::Left),
401            _ => return None,
402        };
403
404        if prec < min_precedence {
405            return None;
406        }
407
408        match kind {
409            LeftParen => {
410                let cp = state.checkpoint();
411                state.push_child(left);
412                state.expect(LeftParen).ok();
413                while state.not_at_end() && !state.at(RightParen) {
414                    self.skip_trivia(state);
415                    let expr = PrattParser::parse(state, 0, self);
416                    state.push_child(expr);
417                    self.skip_trivia(state);
418                    if !state.eat(Comma) {
419                        break;
420                    }
421                }
422                state.expect(RightParen).ok();
423                Some(state.finish_at(cp, CppElementType::FunctionCall))
424            }
425            _ => {
426                let cp = state.checkpoint();
427                state.push_child(left);
428                state.bump();
429                self.skip_trivia(state);
430                let right = PrattParser::parse(state, prec + (assoc as u8), self);
431                state.push_child(right);
432                Some(state.finish_at(cp, CppElementType::ExpressionStatement))
433            }
434        }
435    }
436}