oak_ruby/lexer/
mod.rs

1use crate::{kind::RubySyntaxKind, language::RubyLanguage};
2use oak_core::{
3    IncrementalCache, Lexer, LexerState, OakError,
4    lexer::{CommentLine, LexOutput, StringConfig, WhitespaceConfig},
5    source::Source,
6};
7use std::sync::LazyLock;
8
9type State<S> = LexerState<S, RubyLanguage>;
10
11static RUBY_WHITESPACE: LazyLock<WhitespaceConfig> = LazyLock::new(|| WhitespaceConfig { unicode_whitespace: true });
12static RUBY_COMMENT: LazyLock<CommentLine> = LazyLock::new(|| CommentLine { line_markers: &["#"] });
13static RUBY_STRING: LazyLock<StringConfig> = LazyLock::new(|| StringConfig { quotes: &['"', '\'', '`'], escape: Some('\\') });
14
15#[derive(Clone)]
16pub struct RubyLexer<'config> {
17    config: &'config RubyLanguage,
18}
19
20impl<'config> Lexer<RubyLanguage> for RubyLexer<'config> {
21    fn lex_incremental(
22        &self,
23        source: impl Source,
24        changed: usize,
25        cache: IncrementalCache<RubyLanguage>,
26    ) -> LexOutput<RubyLanguage> {
27        let mut state = LexerState::new_with_cache(source, changed, cache);
28        let result = self.run(&mut state);
29        state.finish(result)
30    }
31}
32
33impl<'config> RubyLexer<'config> {
34    pub fn new(config: &'config RubyLanguage) -> Self {
35        Self { config }
36    }
37
38    fn run<S: Source>(&self, state: &mut State<S>) -> Result<(), OakError> {
39        while state.not_at_end() {
40            let safe_point = state.get_position();
41
42            if self.skip_whitespace(state) {
43                continue;
44            }
45
46            if self.lex_newline(state) {
47                continue;
48            }
49
50            if self.skip_comment(state) {
51                continue;
52            }
53
54            if self.lex_string_literal(state) {
55                continue;
56            }
57
58            if self.lex_symbol(state) {
59                continue;
60            }
61
62            if self.lex_number_literal(state) {
63                continue;
64            }
65
66            if self.lex_identifier_or_keyword(state) {
67                continue;
68            }
69
70            if self.lex_operators(state) {
71                continue;
72            }
73
74            if self.lex_single_char_tokens(state) {
75                continue;
76            }
77
78            state.safe_check(safe_point);
79        }
80
81        // 添加 EOF token
82        let eof_pos = state.get_position();
83        state.add_token(RubySyntaxKind::Eof, eof_pos, eof_pos);
84        Ok(())
85    }
86
87    /// 跳过空白字符
88    fn skip_whitespace<S: Source>(&self, state: &mut State<S>) -> bool {
89        let start_pos = state.get_position();
90
91        while let Some(ch) = state.peek() {
92            if ch == ' ' || ch == '\t' {
93                state.advance(ch.len_utf8());
94            }
95            else {
96                break;
97            }
98        }
99
100        if state.get_position() > start_pos {
101            state.add_token(RubySyntaxKind::Whitespace, start_pos, state.get_position());
102            true
103        }
104        else {
105            false
106        }
107    }
108
109    /// 处理换行
110    fn lex_newline<S: Source>(&self, state: &mut State<S>) -> bool {
111        let start_pos = state.get_position();
112
113        if let Some('\n') = state.peek() {
114            state.advance(1);
115            state.add_token(RubySyntaxKind::Newline, start_pos, state.get_position());
116            true
117        }
118        else if let Some('\r') = state.peek() {
119            state.advance(1);
120            if let Some('\n') = state.peek() {
121                state.advance(1);
122            }
123            state.add_token(RubySyntaxKind::Newline, start_pos, state.get_position());
124            true
125        }
126        else {
127            false
128        }
129    }
130
131    /// 处理注释
132    fn skip_comment<S: Source>(&self, state: &mut State<S>) -> bool {
133        if let Some('#') = state.peek() {
134            let start_pos = state.get_position();
135            state.advance(1); // 跳过 '#'
136
137            // 读取到行
138            while let Some(ch) = state.peek() {
139                if ch == '\n' || ch == '\r' {
140                    break;
141                }
142                state.advance(ch.len_utf8());
143            }
144
145            state.add_token(RubySyntaxKind::Comment, start_pos, state.get_position());
146            true
147        }
148        else {
149            false
150        }
151    }
152
153    /// 处理字符串字面量
154    fn lex_string_literal<S: Source>(&self, state: &mut State<S>) -> bool {
155        let start_pos = state.get_position();
156
157        // 检查是否是字符串开
158        let quote_char = match state.peek() {
159            Some('"') => '"',
160            Some('\'') => '\'',
161            Some('`') => '`',
162            _ => return false,
163        };
164
165        state.advance(1); // 跳过开始引
166        let mut escaped = false;
167        while let Some(ch) = state.peek() {
168            if escaped {
169                escaped = false;
170                state.advance(ch.len_utf8());
171                continue;
172            }
173
174            if ch == '\\' {
175                escaped = true;
176                state.advance(1);
177                continue;
178            }
179
180            if ch == quote_char {
181                state.advance(1); // 跳过结束引号
182                break;
183            }
184            else if ch == '\n' || ch == '\r' {
185                // Ruby 字符串可以跨多行
186                state.advance(ch.len_utf8());
187            }
188            else {
189                state.advance(ch.len_utf8());
190            }
191        }
192
193        state.add_token(RubySyntaxKind::StringLiteral, start_pos, state.get_position());
194        true
195    }
196
197    /// 处理符号
198    fn lex_symbol<S: Source>(&self, state: &mut State<S>) -> bool {
199        if let Some(':') = state.peek() {
200            let start_pos = state.get_position();
201            state.advance(1); // 跳过 ':'
202
203            // 检查下一个字符是否是标识符开
204            if let Some(ch) = state.peek() {
205                if ch.is_ascii_alphabetic() || ch == '_' {
206                    // 读取标识
207                    while let Some(ch) = state.peek() {
208                        if ch.is_ascii_alphanumeric() || ch == '_' || ch == '?' || ch == '!' {
209                            state.advance(1);
210                        }
211                        else {
212                            break;
213                        }
214                    }
215                    state.add_token(RubySyntaxKind::Symbol, start_pos, state.get_position());
216                    return true;
217                }
218                else if ch == '"' || ch == '\'' {
219                    // 引号符号
220                    let quote = ch;
221                    state.advance(1);
222
223                    let mut escaped = false;
224                    while let Some(ch) = state.peek() {
225                        if escaped {
226                            escaped = false;
227                            state.advance(ch.len_utf8());
228                            continue;
229                        }
230
231                        if ch == '\\' {
232                            escaped = true;
233                            state.advance(1);
234                            continue;
235                        }
236
237                        if ch == quote {
238                            state.advance(1);
239                            break;
240                        }
241                        else {
242                            state.advance(ch.len_utf8());
243                        }
244                    }
245                    state.add_token(RubySyntaxKind::Symbol, start_pos, state.get_position());
246                    return true;
247                }
248            }
249        }
250        false
251    }
252
253    /// 处理正则表达
254    fn lex_regex<S: Source>(&self, state: &mut State<S>) -> bool {
255        if let Some('/') = state.peek() {
256            let start_pos = state.get_position();
257            state.advance(1); // 跳过开始的 '/'
258
259            let mut escaped = false;
260            while let Some(ch) = state.peek() {
261                if escaped {
262                    escaped = false;
263                    state.advance(ch.len_utf8());
264                    continue;
265                }
266
267                if ch == '\\' {
268                    escaped = true;
269                    state.advance(1);
270                    continue;
271                }
272
273                if ch == '/' {
274                    state.advance(1); // 跳过结束'/'
275
276                    // 读取修饰
277                    while let Some(ch) = state.peek() {
278                        if ch.is_ascii_alphabetic() {
279                            state.advance(1);
280                        }
281                        else {
282                            break;
283                        }
284                    }
285                    break;
286                }
287                else if ch == '\n' || ch == '\r' {
288                    // 正则表达式不能跨
289                    break;
290                }
291                else {
292                    state.advance(ch.len_utf8());
293                }
294            }
295
296            state.add_token(RubySyntaxKind::RegexLiteral, start_pos, state.get_position());
297            true
298        }
299        else {
300            false
301        }
302    }
303
304    /// 处理数字字面
305    fn lex_number_literal<S: Source>(&self, state: &mut State<S>) -> bool {
306        let start_pos = state.get_position();
307
308        if !state.peek().map_or(false, |c| c.is_ascii_digit()) {
309            return false;
310        }
311
312        let mut is_float = false;
313
314        // 检查进制前缀
315        if state.peek() == Some('0') {
316            let next_char = state.peek_next_n(1);
317            match next_char {
318                Some('b') | Some('B') => {
319                    state.advance(2); // 跳过 '0b' '0B'
320                    // 读取二进制数
321                    while let Some(ch) = state.peek() {
322                        if ch == '0' || ch == '1' {
323                            state.advance(1);
324                        }
325                        else if ch == '_' {
326                            state.advance(1); // 数字分隔
327                        }
328                        else {
329                            break;
330                        }
331                    }
332                }
333                Some('o') | Some('O') => {
334                    state.advance(2); // 跳过 '0o' '0O'
335                    // 读取八进制数
336                    while let Some(ch) = state.peek() {
337                        if ch.is_ascii_digit() && ch < '8' {
338                            state.advance(1);
339                        }
340                        else if ch == '_' {
341                            state.advance(1); // 数字分隔
342                        }
343                        else {
344                            break;
345                        }
346                    }
347                }
348                Some('x') | Some('X') => {
349                    state.advance(2); // 跳过 '0x' '0X'
350                    // 读取十六进制数字
351                    while let Some(ch) = state.peek() {
352                        if ch.is_ascii_hexdigit() {
353                            state.advance(1);
354                        }
355                        else if ch == '_' {
356                            state.advance(1); // 数字分隔
357                        }
358                        else {
359                            break;
360                        }
361                    }
362                }
363                _ => {
364                    // 十进制数
365                    self.lex_decimal_number(state, &mut is_float);
366                }
367            }
368        }
369        else {
370            // 十进制数
371            self.lex_decimal_number(state, &mut is_float);
372        }
373
374        let kind = if is_float { RubySyntaxKind::FloatLiteral } else { RubySyntaxKind::IntegerLiteral };
375
376        state.add_token(kind, start_pos, state.get_position());
377        true
378    }
379
380    /// 处理十进制数
381    fn lex_decimal_number<S: Source>(&self, state: &mut State<S>, is_float: &mut bool) {
382        // 读取整数部分
383        while let Some(ch) = state.peek() {
384            if ch.is_ascii_digit() {
385                state.advance(1);
386            }
387            else if ch == '_' {
388                state.advance(1); // 数字分隔            } else {
389                break;
390            }
391        }
392
393        // 检查小数点
394        if state.peek() == Some('.') && state.peek_next_n(1).map_or(false, |c| c.is_ascii_digit()) {
395            *is_float = true;
396            state.advance(1); // 跳过小数
397            // 读取小数部分
398            while let Some(ch) = state.peek() {
399                if ch.is_ascii_digit() {
400                    state.advance(1);
401                }
402                else if ch == '_' {
403                    state.advance(1); // 数字分隔
404                }
405                else {
406                    break;
407                }
408            }
409        }
410
411        // 检查科学计数法
412        if let Some('e') | Some('E') = state.peek() {
413            *is_float = true;
414            state.advance(1);
415
416            // 可选的符号
417            if let Some('+') | Some('-') = state.peek() {
418                state.advance(1);
419            }
420
421            // 指数部分
422            while let Some(ch) = state.peek() {
423                if ch.is_ascii_digit() {
424                    state.advance(1);
425                }
426                else if ch == '_' {
427                    state.advance(1); // 数字分隔                } else {
428                    break;
429                }
430            }
431        }
432    }
433
434    /// 处理标识符或关键
435    fn lex_identifier_or_keyword<S: Source>(&self, state: &mut State<S>) -> bool {
436        let start_pos = state.get_position();
437
438        // 检查第一个字
439        if !state.peek().map_or(false, |c| c.is_ascii_alphabetic() || c == '_') {
440            return false;
441        }
442
443        // 构建标识符字符串
444        let mut buf = String::new();
445
446        // 读取标识
447        while let Some(ch) = state.peek() {
448            if ch.is_ascii_alphanumeric() || ch == '_' || ch == '?' || ch == '!' {
449                buf.push(ch);
450                state.advance(1);
451            }
452            else {
453                break;
454            }
455        }
456
457        // 检查是否是关键字
458        let kind = match buf.as_str() {
459            "if" => RubySyntaxKind::If,
460            "unless" => RubySyntaxKind::Unless,
461            "elsif" => RubySyntaxKind::Elsif,
462            "else" => RubySyntaxKind::Else,
463            "case" => RubySyntaxKind::Case,
464            "when" => RubySyntaxKind::When,
465            "then" => RubySyntaxKind::Then,
466            "for" => RubySyntaxKind::For,
467            "while" => RubySyntaxKind::While,
468            "until" => RubySyntaxKind::Until,
469            "break" => RubySyntaxKind::Break,
470            "next" => RubySyntaxKind::Next,
471            "redo" => RubySyntaxKind::Redo,
472            "retry" => RubySyntaxKind::Retry,
473            "return" => RubySyntaxKind::Return,
474            "yield" => RubySyntaxKind::Yield,
475            "def" => RubySyntaxKind::Def,
476            "class" => RubySyntaxKind::Class,
477            "module" => RubySyntaxKind::Module,
478            "end" => RubySyntaxKind::End,
479            "lambda" => RubySyntaxKind::Lambda,
480            "proc" => RubySyntaxKind::Proc,
481            "begin" => RubySyntaxKind::Begin,
482            "rescue" => RubySyntaxKind::Rescue,
483            "ensure" => RubySyntaxKind::Ensure,
484            "raise" => RubySyntaxKind::Raise,
485            "require" => RubySyntaxKind::Require,
486            "load" => RubySyntaxKind::Load,
487            "include" => RubySyntaxKind::Include,
488            "extend" => RubySyntaxKind::Extend,
489            "prepend" => RubySyntaxKind::Prepend,
490            "and" => RubySyntaxKind::And,
491            "or" => RubySyntaxKind::Or,
492            "not" => RubySyntaxKind::Not,
493            "in" => RubySyntaxKind::In,
494            "true" => RubySyntaxKind::True,
495            "false" => RubySyntaxKind::False,
496            "nil" => RubySyntaxKind::Nil,
497            "super" => RubySyntaxKind::Super,
498            "self" => RubySyntaxKind::Self_,
499            "alias" => RubySyntaxKind::Alias,
500            "undef" => RubySyntaxKind::Undef,
501            "defined?" => RubySyntaxKind::Defined,
502            "do" => RubySyntaxKind::Do,
503            _ => RubySyntaxKind::Identifier,
504        };
505
506        state.add_token(kind, start_pos, state.get_position());
507        true
508    }
509
510    /// 处理操作
511    fn lex_operators<S: Source>(&self, state: &mut State<S>) -> bool {
512        let start_pos = state.get_position();
513
514        // 尝试匹配多字符操作符
515        let three_char_ops = ["<=>", "===", "**=", "<<=", ">>=", "||=", "&&=", "..."];
516        for op in &three_char_ops {
517            if state.peek() == op.chars().nth(0)
518                && state.peek_next_n(1) == op.chars().nth(1)
519                && state.peek_next_n(2) == op.chars().nth(2)
520            {
521                state.advance(3);
522                let kind = match *op {
523                    "<=>" => RubySyntaxKind::Spaceship,
524                    "===" => RubySyntaxKind::EqualEqualEqual,
525                    "**=" => RubySyntaxKind::PowerAssign,
526                    "<<=" => RubySyntaxKind::LeftShiftAssign,
527                    ">>=" => RubySyntaxKind::RightShiftAssign,
528                    "||=" => RubySyntaxKind::OrOrAssign,
529                    "&&=" => RubySyntaxKind::AndAndAssign,
530                    "..." => RubySyntaxKind::DotDotDot,
531                    _ => RubySyntaxKind::Invalid,
532                };
533                state.add_token(kind, start_pos, state.get_position());
534                return true;
535            }
536        }
537
538        let two_char_ops = [
539            "**", "<<", ">>", "<=", ">=", "==", "!=", "=~", "!~", "&&", "||", "+=", "-=", "*=", "/=", "%=", "&=", "|=", "^=",
540            "..",
541        ];
542        for op in &two_char_ops {
543            if state.peek() == op.chars().nth(0) && state.peek_next_n(1) == op.chars().nth(1) {
544                state.advance(2);
545                let kind = match *op {
546                    "**" => RubySyntaxKind::Power,
547                    "<<" => RubySyntaxKind::LeftShift,
548                    ">>" => RubySyntaxKind::RightShift,
549                    "<=" => RubySyntaxKind::LessEqual,
550                    ">=" => RubySyntaxKind::GreaterEqual,
551                    "==" => RubySyntaxKind::EqualEqual,
552                    "!=" => RubySyntaxKind::NotEqual,
553                    "=~" => RubySyntaxKind::Match,
554                    "!~" => RubySyntaxKind::NotMatch,
555                    "&&" => RubySyntaxKind::AndAnd,
556                    "||" => RubySyntaxKind::OrOr,
557                    "+=" => RubySyntaxKind::PlusAssign,
558                    "-=" => RubySyntaxKind::MinusAssign,
559                    "*=" => RubySyntaxKind::MultiplyAssign,
560                    "/=" => RubySyntaxKind::DivideAssign,
561                    "%=" => RubySyntaxKind::ModuloAssign,
562                    "&=" => RubySyntaxKind::AndAssign,
563                    "|=" => RubySyntaxKind::OrAssign,
564                    "^=" => RubySyntaxKind::XorAssign,
565                    ".." => RubySyntaxKind::DotDot,
566                    _ => RubySyntaxKind::Invalid,
567                };
568                state.add_token(kind, start_pos, state.get_position());
569                return true;
570            }
571        }
572
573        // 尝试匹配单字符操作符
574        let single_char_ops = ['+', '-', '*', '/', '%', '=', '<', '>', '&', '|', '^', '!', '~', '?'];
575
576        if let Some(ch) = state.peek() {
577            if single_char_ops.contains(&ch) {
578                state.advance(1);
579                let kind = match ch {
580                    '+' => RubySyntaxKind::Plus,
581                    '-' => RubySyntaxKind::Minus,
582                    '*' => RubySyntaxKind::Multiply,
583                    '/' => RubySyntaxKind::Divide,
584                    '%' => RubySyntaxKind::Modulo,
585                    '=' => RubySyntaxKind::Assign,
586                    '<' => RubySyntaxKind::Less,
587                    '>' => RubySyntaxKind::Greater,
588                    '&' => RubySyntaxKind::BitAnd,
589                    '|' => RubySyntaxKind::BitOr,
590                    '^' => RubySyntaxKind::Xor,
591                    '!' => RubySyntaxKind::LogicalNot,
592                    '~' => RubySyntaxKind::Tilde,
593                    '?' => RubySyntaxKind::Question,
594                    _ => RubySyntaxKind::Invalid,
595                };
596                state.add_token(kind, start_pos, state.get_position());
597                return true;
598            }
599        }
600
601        false
602    }
603
604    /// 处理分隔
605    fn lex_single_char_tokens<S: Source>(&self, state: &mut State<S>) -> bool {
606        let start_pos = state.get_position();
607
608        // 检查双冒号
609        if state.peek() == Some(':') && state.peek_next_n(1) == Some(':') {
610            state.advance(2);
611            state.add_token(RubySyntaxKind::DoubleColon, start_pos, state.get_position());
612            return true;
613        }
614
615        // 单字符分隔符
616        let delimiters = ['(', ')', '[', ']', '{', '}', ',', ';', '.', ':', '@', '$'];
617
618        if let Some(ch) = state.peek() {
619            if delimiters.contains(&ch) {
620                state.advance(1);
621                let kind = match ch {
622                    '(' => RubySyntaxKind::LeftParen,
623                    ')' => RubySyntaxKind::RightParen,
624                    '[' => RubySyntaxKind::LeftBracket,
625                    ']' => RubySyntaxKind::RightBracket,
626                    '{' => RubySyntaxKind::LeftBrace,
627                    '}' => RubySyntaxKind::RightBrace,
628                    ',' => RubySyntaxKind::Comma,
629                    ';' => RubySyntaxKind::Semicolon,
630                    '.' => RubySyntaxKind::Dot,
631                    ':' => RubySyntaxKind::Colon,
632                    '@' => RubySyntaxKind::At,
633                    '$' => RubySyntaxKind::Dollar,
634                    _ => RubySyntaxKind::Invalid,
635                };
636                state.add_token(kind, start_pos, state.get_position());
637                return true;
638            }
639        }
640
641        // 如果没有匹配任何已知字符,将其标记为 Invalid 并推进位置
642        if let Some(_ch) = state.peek() {
643            state.advance(1);
644            state.add_token(RubySyntaxKind::Invalid, start_pos, state.get_position());
645            return true;
646        }
647
648        false
649    }
650}