Skip to main content

oak_ruby/lexer/
mod.rs

1use crate::{kind::RubySyntaxKind, language::RubyLanguage};
2use oak_core::{LexOutput, Lexer, LexerCache, LexerState, OakError, Source, TextEdit};
3
/// Shorthand for the shared `LexerState` with `RubyLanguage` fixed as its language parameter.
type State<'a, S> = LexerState<'a, S, RubyLanguage>;
5
/// Lexer that turns Ruby source text into `RubySyntaxKind` tokens.
///
/// It borrows a [`RubyLanguage`] configuration for the `'config` lifetime.
#[derive(Clone, Debug)]
pub struct RubyLexer<'config> {
    /// Language configuration passed to `new`. None of the lexing routines
    /// visible in this file read it, hence the leading underscore.
    _config: &'config RubyLanguage,
}
10
11impl<'config> Lexer<RubyLanguage> for RubyLexer<'config> {
12    fn lex<'a, S: Source + ?Sized>(&self, source: &S, _edits: &[TextEdit], cache: &'a mut impl LexerCache<RubyLanguage>) -> LexOutput<RubyLanguage> {
13        let mut state: State<'_, S> = LexerState::new(source);
14        let result = self.run(&mut state);
15        if result.is_ok() {
16            state.add_eof();
17        }
18        state.finish_with_cache(result, cache)
19    }
20}
21
22impl<'config> RubyLexer<'config> {
23    pub fn new(config: &'config RubyLanguage) -> Self {
24        Self { _config: config }
25    }
26
27    fn run<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
28        while state.not_at_end() {
29            let safe_point = state.get_position();
30
31            if self.skip_whitespace(state) {
32                continue;
33            }
34
35            if self.lex_newline(state) {
36                continue;
37            }
38
39            if self.skip_comment(state) {
40                continue;
41            }
42
43            if self.lex_string_literal(state) {
44                continue;
45            }
46
47            if self.lex_symbol(state) {
48                continue;
49            }
50
51            if self.lex_number_literal(state) {
52                continue;
53            }
54
55            if self.lex_identifier_or_keyword(state) {
56                continue;
57            }
58
59            if self.lex_operators(state) {
60                continue;
61            }
62
63            if self.lex_single_char_tokens(state) {
64                continue;
65            }
66
67            state.advance_if_dead_lock(safe_point);
68        }
69
70        Ok(())
71    }
72
73    /// 跳过空白字符
74    fn skip_whitespace<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
75        let start_pos = state.get_position();
76
77        while let Some(ch) = state.peek() {
78            if ch == ' ' || ch == '\t' {
79                state.advance(ch.len_utf8());
80            }
81            else {
82                break;
83            }
84        }
85
86        if state.get_position() > start_pos {
87            state.add_token(RubySyntaxKind::Whitespace, start_pos, state.get_position());
88            true
89        }
90        else {
91            false
92        }
93    }
94
95    /// 处理换行
96    fn lex_newline<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
97        let start_pos = state.get_position();
98
99        if let Some('\n') = state.peek() {
100            state.advance(1);
101            state.add_token(RubySyntaxKind::Newline, start_pos, state.get_position());
102            true
103        }
104        else if let Some('\r') = state.peek() {
105            state.advance(1);
106            if let Some('\n') = state.peek() {
107                state.advance(1);
108            }
109            state.add_token(RubySyntaxKind::Newline, start_pos, state.get_position());
110            true
111        }
112        else {
113            false
114        }
115    }
116
117    /// 处理注释
118    fn skip_comment<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
119        if let Some('#') = state.peek() {
120            let start_pos = state.get_position();
121            state.advance(1); // 跳过 '#'
122
123            // 读取到行
124            while let Some(ch) = state.peek() {
125                if ch == '\n' || ch == '\r' {
126                    break;
127                }
128                state.advance(ch.len_utf8());
129            }
130
131            state.add_token(RubySyntaxKind::Comment, start_pos, state.get_position());
132            true
133        }
134        else {
135            false
136        }
137    }
138
139    /// 处理字符串字面量
140    fn lex_string_literal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
141        let start_pos = state.get_position();
142
143        // 检查是否是字符串开
144        let quote_char = match state.peek() {
145            Some('"') => '"',
146            Some('\'') => '\'',
147            Some('`') => '`',
148            _ => return false,
149        };
150
151        state.advance(1); // 跳过开始引
152        let mut escaped = false;
153        while let Some(ch) = state.peek() {
154            if escaped {
155                escaped = false;
156                state.advance(ch.len_utf8());
157                continue;
158            }
159
160            if ch == '\\' {
161                escaped = true;
162                state.advance(1);
163                continue;
164            }
165
166            if ch == quote_char {
167                state.advance(1); // 跳过结束引号
168                break;
169            }
170            else if ch == '\n' || ch == '\r' {
171                // Ruby 字符串可以跨多行
172                state.advance(ch.len_utf8());
173            }
174            else {
175                state.advance(ch.len_utf8());
176            }
177        }
178
179        state.add_token(RubySyntaxKind::StringLiteral, start_pos, state.get_position());
180        true
181    }
182
183    /// 处理符号
184    fn lex_symbol<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
185        if let Some(':') = state.peek() {
186            let start_pos = state.get_position();
187            state.advance(1); // 跳过 ':'
188
189            // 检查下一个字符是否是标识符开
190            if let Some(ch) = state.peek() {
191                if ch.is_ascii_alphabetic() || ch == '_' {
192                    // 读取标识
193                    while let Some(ch) = state.peek() {
194                        if ch.is_ascii_alphanumeric() || ch == '_' || ch == '?' || ch == '!' {
195                            state.advance(1);
196                        }
197                        else {
198                            break;
199                        }
200                    }
201                    state.add_token(RubySyntaxKind::Symbol, start_pos, state.get_position());
202                    return true;
203                }
204                else if ch == '"' || ch == '\'' {
205                    // 引号符号
206                    let quote = ch;
207                    state.advance(1);
208
209                    let mut escaped = false;
210                    while let Some(ch) = state.peek() {
211                        if escaped {
212                            escaped = false;
213                            state.advance(ch.len_utf8());
214                            continue;
215                        }
216
217                        if ch == '\\' {
218                            escaped = true;
219                            state.advance(1);
220                            continue;
221                        }
222
223                        if ch == quote {
224                            state.advance(1);
225                            break;
226                        }
227                        else {
228                            state.advance(ch.len_utf8());
229                        }
230                    }
231                    state.add_token(RubySyntaxKind::Symbol, start_pos, state.get_position());
232                    return true;
233                }
234            }
235        }
236        false
237    }
238
239    /// 处理数字字面
240    fn lex_number_literal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
241        let start_pos = state.get_position();
242
243        if !state.peek().map_or(false, |c| c.is_ascii_digit()) {
244            return false;
245        }
246
247        let mut is_float = false;
248
249        // 检查进制前缀
250        if state.peek() == Some('0') {
251            let next_char = state.peek_next_n(1);
252            match next_char {
253                Some('b') | Some('B') => {
254                    state.advance(2); // 跳过 '0b' '0B'
255                    // 读取二进制数
256                    while let Some(ch) = state.peek() {
257                        if ch == '0' || ch == '1' {
258                            state.advance(1);
259                        }
260                        else if ch == '_' {
261                            state.advance(1); // 数字分隔
262                        }
263                        else {
264                            break;
265                        }
266                    }
267                }
268                Some('o') | Some('O') => {
269                    state.advance(2); // 跳过 '0o' '0O'
270                    // 读取八进制数
271                    while let Some(ch) = state.peek() {
272                        if ch.is_ascii_digit() && ch < '8' {
273                            state.advance(1);
274                        }
275                        else if ch == '_' {
276                            state.advance(1); // 数字分隔
277                        }
278                        else {
279                            break;
280                        }
281                    }
282                }
283                Some('x') | Some('X') => {
284                    state.advance(2); // 跳过 '0x' '0X'
285                    // 读取十六进制数字
286                    while let Some(ch) = state.peek() {
287                        if ch.is_ascii_hexdigit() {
288                            state.advance(1);
289                        }
290                        else if ch == '_' {
291                            state.advance(1); // 数字分隔
292                        }
293                        else {
294                            break;
295                        }
296                    }
297                }
298                _ => {
299                    // 十进制数
300                    self.lex_decimal_number(state, &mut is_float);
301                }
302            }
303        }
304        else {
305            // 十进制数
306            self.lex_decimal_number(state, &mut is_float);
307        }
308
309        let kind = if is_float { RubySyntaxKind::FloatLiteral } else { RubySyntaxKind::IntegerLiteral };
310
311        state.add_token(kind, start_pos, state.get_position());
312        true
313    }
314
315    /// 处理十进制数
316    fn lex_decimal_number<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>, is_float: &mut bool) {
317        // 读取整数部分
318        while let Some(ch) = state.peek() {
319            if ch.is_ascii_digit() {
320                state.advance(1);
321            }
322            else if ch == '_' {
323                state.advance(1); // 数字分隔            } else {
324                break;
325            }
326        }
327
328        // 检查小数点
329        if state.peek() == Some('.') && state.peek_next_n(1).map_or(false, |c| c.is_ascii_digit()) {
330            *is_float = true;
331            state.advance(1); // 跳过小数
332            // 读取小数部分
333            while let Some(ch) = state.peek() {
334                if ch.is_ascii_digit() {
335                    state.advance(1);
336                }
337                else if ch == '_' {
338                    state.advance(1); // 数字分隔
339                }
340                else {
341                    break;
342                }
343            }
344        }
345
346        // 检查科学计数法
347        if let Some('e') | Some('E') = state.peek() {
348            *is_float = true;
349            state.advance(1);
350
351            // 可选的符号
352            if let Some('+') | Some('-') = state.peek() {
353                state.advance(1);
354            }
355
356            // 指数部分
357            while let Some(ch) = state.peek() {
358                if ch.is_ascii_digit() {
359                    state.advance(1);
360                }
361                else if ch == '_' {
362                    state.advance(1); // 数字分隔                } else {
363                    break;
364                }
365            }
366        }
367    }
368
369    /// 处理标识符或关键
370    fn lex_identifier_or_keyword<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
371        let start_pos = state.get_position();
372
373        // 检查第一个字
374        if !state.peek().map_or(false, |c| c.is_ascii_alphabetic() || c == '_') {
375            return false;
376        }
377
378        // 构建标识符字符串
379        let mut buf = String::new();
380
381        // 读取标识
382        while let Some(ch) = state.peek() {
383            if ch.is_ascii_alphanumeric() || ch == '_' || ch == '?' || ch == '!' {
384                buf.push(ch);
385                state.advance(1);
386            }
387            else {
388                break;
389            }
390        }
391
392        // 检查是否是关键字
393        let kind = match buf.as_str() {
394            "if" => RubySyntaxKind::If,
395            "unless" => RubySyntaxKind::Unless,
396            "elsif" => RubySyntaxKind::Elsif,
397            "else" => RubySyntaxKind::Else,
398            "case" => RubySyntaxKind::Case,
399            "when" => RubySyntaxKind::When,
400            "then" => RubySyntaxKind::Then,
401            "for" => RubySyntaxKind::For,
402            "while" => RubySyntaxKind::While,
403            "until" => RubySyntaxKind::Until,
404            "break" => RubySyntaxKind::Break,
405            "next" => RubySyntaxKind::Next,
406            "redo" => RubySyntaxKind::Redo,
407            "retry" => RubySyntaxKind::Retry,
408            "return" => RubySyntaxKind::Return,
409            "yield" => RubySyntaxKind::Yield,
410            "def" => RubySyntaxKind::Def,
411            "class" => RubySyntaxKind::Class,
412            "module" => RubySyntaxKind::Module,
413            "end" => RubySyntaxKind::End,
414            "lambda" => RubySyntaxKind::Lambda,
415            "proc" => RubySyntaxKind::Proc,
416            "begin" => RubySyntaxKind::Begin,
417            "rescue" => RubySyntaxKind::Rescue,
418            "ensure" => RubySyntaxKind::Ensure,
419            "raise" => RubySyntaxKind::Raise,
420            "require" => RubySyntaxKind::Require,
421            "load" => RubySyntaxKind::Load,
422            "include" => RubySyntaxKind::Include,
423            "extend" => RubySyntaxKind::Extend,
424            "prepend" => RubySyntaxKind::Prepend,
425            "and" => RubySyntaxKind::And,
426            "or" => RubySyntaxKind::Or,
427            "not" => RubySyntaxKind::Not,
428            "in" => RubySyntaxKind::In,
429            "true" => RubySyntaxKind::True,
430            "false" => RubySyntaxKind::False,
431            "nil" => RubySyntaxKind::Nil,
432            "super" => RubySyntaxKind::Super,
433            "self" => RubySyntaxKind::Self_,
434            "alias" => RubySyntaxKind::Alias,
435            "undef" => RubySyntaxKind::Undef,
436            "defined?" => RubySyntaxKind::Defined,
437            "do" => RubySyntaxKind::Do,
438            _ => RubySyntaxKind::Identifier,
439        };
440
441        state.add_token(kind, start_pos, state.get_position());
442        true
443    }
444
445    /// 处理操作
446    fn lex_operators<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
447        let start_pos = state.get_position();
448
449        // 尝试匹配多字符操作符
450        let three_char_ops = ["<=>", "===", "**=", "<<=", ">>=", "||=", "&&=", "..."];
451        for op in &three_char_ops {
452            if state.peek() == op.chars().nth(0) && state.peek_next_n(1) == op.chars().nth(1) && state.peek_next_n(2) == op.chars().nth(2) {
453                state.advance(3);
454                let kind = match *op {
455                    "<=>" => RubySyntaxKind::Spaceship,
456                    "===" => RubySyntaxKind::EqualEqualEqual,
457                    "**=" => RubySyntaxKind::PowerAssign,
458                    "<<=" => RubySyntaxKind::LeftShiftAssign,
459                    ">>=" => RubySyntaxKind::RightShiftAssign,
460                    "||=" => RubySyntaxKind::OrOrAssign,
461                    "&&=" => RubySyntaxKind::AndAndAssign,
462                    "..." => RubySyntaxKind::DotDotDot,
463                    _ => RubySyntaxKind::Invalid,
464                };
465                state.add_token(kind, start_pos, state.get_position());
466                return true;
467            }
468        }
469
470        let two_char_ops = ["**", "<<", ">>", "<=", ">=", "==", "!=", "=~", "!~", "&&", "||", "+=", "-=", "*=", "/=", "%=", "&=", "|=", "^=", ".."];
471        for op in &two_char_ops {
472            if state.peek() == op.chars().nth(0) && state.peek_next_n(1) == op.chars().nth(1) {
473                state.advance(2);
474                let kind = match *op {
475                    "**" => RubySyntaxKind::Power,
476                    "<<" => RubySyntaxKind::LeftShift,
477                    ">>" => RubySyntaxKind::RightShift,
478                    "<=" => RubySyntaxKind::LessEqual,
479                    ">=" => RubySyntaxKind::GreaterEqual,
480                    "==" => RubySyntaxKind::EqualEqual,
481                    "!=" => RubySyntaxKind::NotEqual,
482                    "=~" => RubySyntaxKind::Match,
483                    "!~" => RubySyntaxKind::NotMatch,
484                    "&&" => RubySyntaxKind::AndAnd,
485                    "||" => RubySyntaxKind::OrOr,
486                    "+=" => RubySyntaxKind::PlusAssign,
487                    "-=" => RubySyntaxKind::MinusAssign,
488                    "*=" => RubySyntaxKind::MultiplyAssign,
489                    "/=" => RubySyntaxKind::DivideAssign,
490                    "%=" => RubySyntaxKind::ModuloAssign,
491                    "&=" => RubySyntaxKind::AndAssign,
492                    "|=" => RubySyntaxKind::OrAssign,
493                    "^=" => RubySyntaxKind::XorAssign,
494                    ".." => RubySyntaxKind::DotDot,
495                    _ => RubySyntaxKind::Invalid,
496                };
497                state.add_token(kind, start_pos, state.get_position());
498                return true;
499            }
500        }
501
502        // 尝试匹配单字符操作符
503        let single_char_ops = ['+', '-', '*', '/', '%', '=', '<', '>', '&', '|', '^', '!', '~', '?'];
504
505        if let Some(ch) = state.peek() {
506            if single_char_ops.contains(&ch) {
507                state.advance(1);
508                let kind = match ch {
509                    '+' => RubySyntaxKind::Plus,
510                    '-' => RubySyntaxKind::Minus,
511                    '*' => RubySyntaxKind::Multiply,
512                    '/' => RubySyntaxKind::Divide,
513                    '%' => RubySyntaxKind::Modulo,
514                    '=' => RubySyntaxKind::Assign,
515                    '<' => RubySyntaxKind::Less,
516                    '>' => RubySyntaxKind::Greater,
517                    '&' => RubySyntaxKind::BitAnd,
518                    '|' => RubySyntaxKind::BitOr,
519                    '^' => RubySyntaxKind::Xor,
520                    '!' => RubySyntaxKind::LogicalNot,
521                    '~' => RubySyntaxKind::Tilde,
522                    '?' => RubySyntaxKind::Question,
523                    _ => RubySyntaxKind::Invalid,
524                };
525                state.add_token(kind, start_pos, state.get_position());
526                return true;
527            }
528        }
529
530        false
531    }
532
533    /// 处理分隔
534    fn lex_single_char_tokens<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
535        let start_pos = state.get_position();
536
537        // 检查双冒号
538        if state.peek() == Some(':') && state.peek_next_n(1) == Some(':') {
539            state.advance(2);
540            state.add_token(RubySyntaxKind::DoubleColon, start_pos, state.get_position());
541            return true;
542        }
543
544        // 单字符分隔符
545        let delimiters = ['(', ')', '[', ']', '{', '}', ',', ';', '.', ':', '@', '$'];
546
547        if let Some(ch) = state.peek() {
548            if delimiters.contains(&ch) {
549                state.advance(1);
550                let kind = match ch {
551                    '(' => RubySyntaxKind::LeftParen,
552                    ')' => RubySyntaxKind::RightParen,
553                    '[' => RubySyntaxKind::LeftBracket,
554                    ']' => RubySyntaxKind::RightBracket,
555                    '{' => RubySyntaxKind::LeftBrace,
556                    '}' => RubySyntaxKind::RightBrace,
557                    ',' => RubySyntaxKind::Comma,
558                    ';' => RubySyntaxKind::Semicolon,
559                    '.' => RubySyntaxKind::Dot,
560                    ':' => RubySyntaxKind::Colon,
561                    '@' => RubySyntaxKind::At,
562                    '$' => RubySyntaxKind::Dollar,
563                    _ => RubySyntaxKind::Invalid,
564                };
565                state.add_token(kind, start_pos, state.get_position());
566                return true;
567            }
568        }
569
570        // 如果没有匹配任何已知字符,将其标记为 Invalid 并推进位置
571        if let Some(_ch) = state.peek() {
572            state.advance(1);
573            state.add_token(RubySyntaxKind::Invalid, start_pos, state.get_position());
574            return true;
575        }
576
577        false
578    }
579}