oak_cpp/lexer/
mod.rs

1use crate::{kind::CppSyntaxKind, language::CppLanguage};
2use oak_core::{IncrementalCache, Lexer, LexerState, lexer::LexOutput, source::Source};
3
4type State<S> = LexerState<S, CppLanguage>;
5
6pub struct CppLexer<'config> {
7    config: &'config CppLanguage,
8}
9
10/// C 词法分析器类型别名
11pub type CLexer<'config> = CppLexer<'config>;
12
13impl<'config> CppLexer<'config> {
14    pub fn new(config: &'config CppLanguage) -> Self {
15        Self { config }
16    }
17
18    /// 跳过空白字符
19    fn skip_whitespace<S: Source>(&self, state: &mut State<S>) -> bool {
20        let start_pos = state.get_position();
21
22        while let Some(ch) = state.peek() {
23            if ch == ' ' || ch == '\t' {
24                state.advance(ch.len_utf8());
25            }
26            else {
27                break;
28            }
29        }
30
31        if state.get_position() > start_pos {
32            state.add_token(CppSyntaxKind::Whitespace, start_pos, state.get_position());
33            true
34        }
35        else {
36            false
37        }
38    }
39
40    /// 处理换行
41    fn lex_newline<S: Source>(&self, state: &mut State<S>) -> bool {
42        let start_pos = state.get_position();
43
44        if let Some('\n') = state.peek() {
45            state.advance(1);
46            state.add_token(CppSyntaxKind::Newline, start_pos, state.get_position());
47            true
48        }
49        else if let Some('\r') = state.peek() {
50            state.advance(1);
51            if let Some('\n') = state.peek() {
52                state.advance(1);
53            }
54            state.add_token(CppSyntaxKind::Newline, start_pos, state.get_position());
55            true
56        }
57        else {
58            false
59        }
60    }
61
62    /// 处理注释
63    fn lex_comment<S: Source>(&self, state: &mut State<S>) -> bool {
64        let start_pos = state.get_position();
65
66        if let Some('/') = state.peek() {
67            if let Some('/') = state.peek_next_n(1) {
68                // 单行注释
69                state.advance(2);
70                while let Some(ch) = state.peek() {
71                    if ch == '\n' || ch == '\r' {
72                        break;
73                    }
74                    state.advance(ch.len_utf8());
75                }
76                state.add_token(CppSyntaxKind::Comment, start_pos, state.get_position());
77                true
78            }
79            else if let Some('*') = state.peek_next_n(1) {
80                // 多行注释
81                state.advance(2);
82                while let Some(ch) = state.peek() {
83                    if ch == '*' && state.peek_next_n(1) == Some('/') {
84                        state.advance(2);
85                        break;
86                    }
87                    state.advance(ch.len_utf8());
88                }
89                state.add_token(CppSyntaxKind::Comment, start_pos, state.get_position());
90                true
91            }
92            else {
93                false
94            }
95        }
96        else {
97            false
98        }
99    }
100
101    /// 处理字符串字面量
102    fn lex_string<S: Source>(&self, state: &mut State<S>) -> bool {
103        let start_pos = state.get_position();
104
105        if let Some('"') = state.peek() {
106            state.advance(1);
107
108            let mut escaped = false;
109            while let Some(ch) = state.peek() {
110                if escaped {
111                    escaped = false;
112                    state.advance(ch.len_utf8());
113                    continue;
114                }
115
116                if ch == '\\' {
117                    escaped = true;
118                    state.advance(1);
119                    continue;
120                }
121
122                if ch == '"' {
123                    state.advance(1);
124                    break;
125                }
126
127                if ch == '\n' || ch == '\r' {
128                    break; // 未闭合的字符
129                }
130
131                state.advance(ch.len_utf8());
132            }
133
134            state.add_token(CppSyntaxKind::StringLiteral, start_pos, state.get_position());
135            true
136        }
137        else {
138            false
139        }
140    }
141
142    /// 处理字符字面量
143    fn lex_character<S: Source>(&self, state: &mut State<S>) -> bool {
144        let start_pos = state.get_position();
145
146        if let Some('\'') = state.peek() {
147            state.advance(1);
148
149            let mut escaped = false;
150            while let Some(ch) = state.peek() {
151                if escaped {
152                    escaped = false;
153                    state.advance(ch.len_utf8());
154                    continue;
155                }
156
157                if ch == '\\' {
158                    escaped = true;
159                    state.advance(1);
160                    continue;
161                }
162
163                if ch == '\'' {
164                    state.advance(1);
165                    break;
166                }
167
168                if ch == '\n' || ch == '\r' {
169                    break; // 未闭合的字符
170                }
171
172                state.advance(ch.len_utf8());
173            }
174
175            state.add_token(CppSyntaxKind::CharacterLiteral, start_pos, state.get_position());
176            true
177        }
178        else {
179            false
180        }
181    }
182
183    /// 处理数字字面量
184    fn lex_number<S: Source>(&self, state: &mut State<S>) -> bool {
185        let start_pos = state.get_position();
186
187        if let Some(ch) = state.peek() {
188            if ch.is_ascii_digit() || (ch == '.' && state.peek_next_n(1).map_or(false, |c| c.is_ascii_digit())) {
189                let mut is_float = false;
190
191                // 处理十六进制、八进制、二进制
192                if ch == '0' {
193                    if let Some(next_ch) = state.peek_next_n(1) {
194                        if next_ch == 'x' || next_ch == 'X' {
195                            // 十六进制
196                            state.advance(2);
197                            while let Some(ch) = state.peek() {
198                                if ch.is_ascii_hexdigit() {
199                                    state.advance(1);
200                                }
201                                else {
202                                    break;
203                                }
204                            }
205                        }
206                        else if next_ch == 'b' || next_ch == 'B' {
207                            // 二进
208                            state.advance(2);
209                            while let Some(ch) = state.peek() {
210                                if ch == '0' || ch == '1' {
211                                    state.advance(1);
212                                }
213                                else {
214                                    break;
215                                }
216                            }
217                        }
218                        else if next_ch.is_ascii_digit() {
219                            // 八进
220                            while let Some(ch) = state.peek() {
221                                if ch.is_ascii_digit() {
222                                    state.advance(1);
223                                }
224                                else {
225                                    break;
226                                }
227                            }
228                        }
229                        else {
230                            state.advance(1); // 只是 '0'
231                        }
232                    }
233                    else {
234                        state.advance(1); // 只是 '0'
235                    }
236                }
237                else {
238                    // 十进制整数部
239                    while let Some(ch) = state.peek() {
240                        if ch.is_ascii_digit() {
241                            state.advance(1);
242                        }
243                        else {
244                            break;
245                        }
246                    }
247                }
248
249                // 检查小数点
250                if let Some('.') = state.peek() {
251                    if let Some(next_ch) = state.peek_next_n(1) {
252                        if next_ch.is_ascii_digit() {
253                            is_float = true;
254                            state.advance(1); // 消费小数
255                            while let Some(ch) = state.peek() {
256                                if ch.is_ascii_digit() {
257                                    state.advance(1);
258                                }
259                                else {
260                                    break;
261                                }
262                            }
263                        }
264                    }
265                }
266
267                // 检查科学记数法
268                if let Some(ch) = state.peek() {
269                    if ch == 'e' || ch == 'E' {
270                        is_float = true;
271                        state.advance(1);
272                        if let Some(sign) = state.peek() {
273                            if sign == '+' || sign == '-' {
274                                state.advance(1);
275                            }
276                        }
277                        while let Some(ch) = state.peek() {
278                            if ch.is_ascii_digit() {
279                                state.advance(1);
280                            }
281                            else {
282                                break;
283                            }
284                        }
285                    }
286                }
287
288                // 检查后缀
289                while let Some(ch) = state.peek() {
290                    if ch.is_ascii_alphabetic() {
291                        state.advance(1);
292                    }
293                    else {
294                        break;
295                    }
296                }
297
298                let token_kind = if is_float { CppSyntaxKind::FloatLiteral } else { CppSyntaxKind::IntegerLiteral };
299
300                state.add_token(token_kind, start_pos, state.get_position());
301                true
302            }
303            else {
304                false
305            }
306        }
307        else {
308            false
309        }
310    }
311
312    /// 处理关键字或标识符
313    fn lex_keyword_or_identifier<S: Source>(&self, state: &mut State<S>) -> bool {
314        let start_pos = state.get_position();
315
316        if let Some(ch) = state.peek() {
317            if ch.is_ascii_alphabetic() || ch == '_' {
318                while let Some(ch) = state.peek() {
319                    if ch.is_ascii_alphanumeric() || ch == '_' {
320                        state.advance(ch.len_utf8());
321                    }
322                    else {
323                        break;
324                    }
325                }
326
327                let text = state.get_text_in((start_pos..state.get_position()).into());
328                let token_kind = match text {
329                    // C++ 关键
330                    "alignas" | "alignof" | "and" | "and_eq" | "asm" | "atomic_cancel" | "atomic_commit"
331                    | "atomic_noexcept" | "auto" | "bitand" | "bitor" | "bool" | "break" | "case" | "catch" | "char"
332                    | "char8_t" | "char16_t" | "char32_t" | "class" | "compl" | "concept" | "const" | "consteval"
333                    | "constexpr" | "constinit" | "const_cast" | "continue" | "co_await" | "co_return" | "co_yield"
334                    | "decltype" | "default" | "delete" | "do" | "double" | "dynamic_cast" | "else" | "enum" | "explicit"
335                    | "export" | "extern" | "false" | "float" | "for" | "friend" | "goto" | "if" | "inline" | "int"
336                    | "long" | "mutable" | "namespace" | "new" | "noexcept" | "not" | "not_eq" | "nullptr" | "operator"
337                    | "or" | "or_eq" | "private" | "protected" | "public" | "reflexpr" | "register" | "reinterpret_cast"
338                    | "requires" | "return" | "short" | "signed" | "sizeof" | "static" | "static_assert" | "static_cast"
339                    | "struct" | "switch" | "synchronized" | "template" | "this" | "thread_local" | "throw" | "true"
340                    | "try" | "typedef" | "typeid" | "typename" | "union" | "unsigned" | "using" | "virtual" | "void"
341                    | "volatile" | "wchar_t" | "while" | "xor" | "xor_eq" => CppSyntaxKind::Keyword,
342                    "true" | "false" => CppSyntaxKind::BooleanLiteral,
343                    _ => CppSyntaxKind::Identifier,
344                };
345
346                state.add_token(token_kind, start_pos, state.get_position());
347                true
348            }
349            else {
350                false
351            }
352        }
353        else {
354            false
355        }
356    }
357
358    /// 处理操作符
359    fn lex_operator<S: Source>(&self, state: &mut State<S>) -> bool {
360        let start_pos = state.get_position();
361
362        if let Some(ch) = state.peek() {
363            let (token_kind, advance_count) = match ch {
364                '+' => {
365                    if let Some('+') = state.peek_next_n(1) {
366                        (CppSyntaxKind::Increment, 2)
367                    }
368                    else if let Some('=') = state.peek_next_n(1) {
369                        (CppSyntaxKind::PlusAssign, 2)
370                    }
371                    else {
372                        (CppSyntaxKind::Plus, 1)
373                    }
374                }
375                '-' => {
376                    if let Some('-') = state.peek_next_n(1) {
377                        (CppSyntaxKind::Decrement, 2)
378                    }
379                    else if let Some('=') = state.peek_next_n(1) {
380                        (CppSyntaxKind::MinusAssign, 2)
381                    }
382                    else if let Some('>') = state.peek_next_n(1) {
383                        (CppSyntaxKind::Arrow, 2)
384                    }
385                    else {
386                        (CppSyntaxKind::Minus, 1)
387                    }
388                }
389                '*' => {
390                    if let Some('=') = state.peek_next_n(1) {
391                        (CppSyntaxKind::StarAssign, 2)
392                    }
393                    else {
394                        (CppSyntaxKind::Star, 1)
395                    }
396                }
397                '/' => {
398                    if let Some('=') = state.peek_next_n(1) {
399                        (CppSyntaxKind::SlashAssign, 2)
400                    }
401                    else {
402                        (CppSyntaxKind::Slash, 1)
403                    }
404                }
405                '%' => {
406                    if let Some('=') = state.peek_next_n(1) {
407                        (CppSyntaxKind::PercentAssign, 2)
408                    }
409                    else {
410                        (CppSyntaxKind::Percent, 1)
411                    }
412                }
413                '=' => {
414                    if let Some('=') = state.peek_next_n(1) {
415                        (CppSyntaxKind::Equal, 2)
416                    }
417                    else {
418                        (CppSyntaxKind::Assign, 1)
419                    }
420                }
421                '!' => {
422                    if let Some('=') = state.peek_next_n(1) {
423                        (CppSyntaxKind::NotEqual, 2)
424                    }
425                    else {
426                        (CppSyntaxKind::LogicalNot, 1)
427                    }
428                }
429                '<' => {
430                    if let Some('<') = state.peek_next_n(1) {
431                        if let Some('=') = state.peek_next_n(2) {
432                            (CppSyntaxKind::LeftShiftAssign, 3)
433                        }
434                        else {
435                            (CppSyntaxKind::LeftShift, 2)
436                        }
437                    }
438                    else if let Some('=') = state.peek_next_n(1) {
439                        (CppSyntaxKind::LessEqual, 2)
440                    }
441                    else {
442                        (CppSyntaxKind::Less, 1)
443                    }
444                }
445                '>' => {
446                    if let Some('>') = state.peek_next_n(1) {
447                        if let Some('=') = state.peek_next_n(2) {
448                            (CppSyntaxKind::RightShiftAssign, 3)
449                        }
450                        else {
451                            (CppSyntaxKind::RightShift, 2)
452                        }
453                    }
454                    else if let Some('=') = state.peek_next_n(1) {
455                        (CppSyntaxKind::GreaterEqual, 2)
456                    }
457                    else {
458                        (CppSyntaxKind::Greater, 1)
459                    }
460                }
461                '&' => {
462                    if let Some('&') = state.peek_next_n(1) {
463                        (CppSyntaxKind::LogicalAnd, 2)
464                    }
465                    else if let Some('=') = state.peek_next_n(1) {
466                        (CppSyntaxKind::AndAssign, 2)
467                    }
468                    else {
469                        (CppSyntaxKind::BitAnd, 1)
470                    }
471                }
472                '|' => {
473                    if let Some('|') = state.peek_next_n(1) {
474                        (CppSyntaxKind::LogicalOr, 2)
475                    }
476                    else if let Some('=') = state.peek_next_n(1) {
477                        (CppSyntaxKind::OrAssign, 2)
478                    }
479                    else {
480                        (CppSyntaxKind::BitOr, 1)
481                    }
482                }
483                '^' => {
484                    if let Some('=') = state.peek_next_n(1) {
485                        (CppSyntaxKind::XorAssign, 2)
486                    }
487                    else {
488                        (CppSyntaxKind::BitXor, 1)
489                    }
490                }
491                '~' => (CppSyntaxKind::BitNot, 1),
492                '?' => (CppSyntaxKind::Question, 1),
493                ':' => {
494                    if let Some(':') = state.peek_next_n(1) {
495                        (CppSyntaxKind::Scope, 2)
496                    }
497                    else {
498                        (CppSyntaxKind::Colon, 1)
499                    }
500                }
501                '.' => (CppSyntaxKind::Dot, 1),
502                _ => return false,
503            };
504
505            state.advance(advance_count);
506            state.add_token(token_kind, start_pos, state.get_position());
507            true
508        }
509        else {
510            false
511        }
512    }
513
514    /// 处理分隔符
515    fn lex_delimiter<S: Source>(&self, state: &mut State<S>) -> bool {
516        let start_pos = state.get_position();
517
518        if let Some(ch) = state.peek() {
519            let token_kind = match ch {
520                '(' => CppSyntaxKind::LeftParen,
521                ')' => CppSyntaxKind::RightParen,
522                '[' => CppSyntaxKind::LeftBracket,
523                ']' => CppSyntaxKind::RightBracket,
524                '{' => CppSyntaxKind::LeftBrace,
525                '}' => CppSyntaxKind::RightBrace,
526                ',' => CppSyntaxKind::Comma,
527                ';' => CppSyntaxKind::Semicolon,
528                _ => return false,
529            };
530
531            state.advance(1);
532            state.add_token(token_kind, start_pos, state.get_position());
533            true
534        }
535        else {
536            false
537        }
538    }
539
540    /// 处理预处理指令
541    fn lex_preprocessor<S: Source>(&self, state: &mut State<S>) -> bool {
542        let start_pos = state.get_position();
543
544        if let Some('#') = state.peek() {
545            // 读取到行
546            while let Some(ch) = state.peek() {
547                if ch == '\n' || ch == '\r' {
548                    break;
549                }
550                state.advance(ch.len_utf8());
551            }
552
553            state.add_token(CppSyntaxKind::Preprocessor, start_pos, state.get_position());
554            true
555        }
556        else {
557            false
558        }
559    }
560}
561
562impl<'config> Lexer<CppLanguage> for CppLexer<'config> {
563    fn lex_incremental(
564        &self,
565        source: impl Source,
566        _changed: usize,
567        _cache: IncrementalCache<CppLanguage>,
568    ) -> LexOutput<CppLanguage> {
569        let mut state = LexerState::new_with_cache(source, _changed, _cache);
570
571        loop {
572            // 检查是否到达文件末尾
573            if state.not_at_end() == false {
574                break;
575            }
576
577            // 尝试各种词法规则
578            if self.skip_whitespace(&mut state) {
579                continue;
580            }
581
582            if self.lex_newline(&mut state) {
583                continue;
584            }
585
586            if self.lex_comment(&mut state) {
587                continue;
588            }
589
590            if self.lex_string(&mut state) {
591                continue;
592            }
593
594            if self.lex_character(&mut state) {
595                continue;
596            }
597
598            if self.lex_number(&mut state) {
599                continue;
600            }
601
602            if self.lex_keyword_or_identifier(&mut state) {
603                continue;
604            }
605
606            if self.lex_preprocessor(&mut state) {
607                continue;
608            }
609
610            if self.lex_operator(&mut state) {
611                continue;
612            }
613
614            if self.lex_delimiter(&mut state) {
615                continue;
616            }
617
618            // 如果所有规则都不匹配,跳过当前字符并标记为错误
619            let start_pos = state.get_position();
620            if let Some(ch) = state.peek() {
621                state.advance(ch.len_utf8());
622                state.add_token(CppSyntaxKind::Error, start_pos, state.get_position());
623            }
624            else {
625                // 如果没有字符可读,退出循环
626                break;
627            }
628        }
629
630        // 添加 EOF kind
631        let eof_pos = state.get_position();
632        state.add_token(CppSyntaxKind::Eof, eof_pos, eof_pos);
633
634        state.finish(Ok(()))
635    }
636}