// oak_scala/lexer/mod.rs

1use crate::{kind::ScalaSyntaxKind, language::ScalaLanguage};
2use oak_core::{
3    Lexer, LexerCache, LexerState, OakError, TextEdit,
4    lexer::{CommentConfig, LexOutput, StringConfig, WhitespaceConfig},
5    source::Source,
6};
7use std::sync::LazyLock;
8
9type State<'s, S> = LexerState<'s, S, ScalaLanguage>;
10
11static SCALA_WHITESPACE: LazyLock<WhitespaceConfig> = LazyLock::new(|| WhitespaceConfig { unicode_whitespace: true });
12static SCALA_COMMENT: LazyLock<CommentConfig> = LazyLock::new(|| CommentConfig { line_marker: "//", block_start: "/*", block_end: "*/", nested_blocks: true });
13static SCALA_STRING: LazyLock<StringConfig> = LazyLock::new(|| StringConfig { quotes: &['"'], escape: Some('\\') });
14static SCALA_CHAR: LazyLock<StringConfig> = LazyLock::new(|| StringConfig { quotes: &['\''], escape: None });
15
/// Lexer for Scala source text.
///
/// Borrows a [`ScalaLanguage`] configuration for the lifetime `'config`.
#[derive(Clone, Debug)]
pub struct ScalaLexer<'config> {
    /// Language configuration supplied at construction; the leading underscore marks it as
    /// not read by the code in this module.
    _config: &'config ScalaLanguage,
}
20
21impl<'config> Lexer<ScalaLanguage> for ScalaLexer<'config> {
22    fn lex<'a, S: Source + ?Sized>(&self, source: &'a S, _edits: &[TextEdit], cache: &'a mut impl LexerCache<ScalaLanguage>) -> LexOutput<ScalaLanguage> {
23        let mut state: State<'_, S> = LexerState::new(source);
24        let result = self.run(&mut state);
25        if result.is_ok() {
26            state.add_eof();
27        }
28        state.finish_with_cache(result, cache)
29    }
30}
31
32impl<'config> ScalaLexer<'config> {
33    pub fn new(config: &'config ScalaLanguage) -> Self {
34        Self { _config: config }
35    }
36
37    fn run<'s, S: Source + ?Sized>(&self, state: &mut State<'s, S>) -> Result<(), OakError> {
38        while state.not_at_end() {
39            let safe_point = state.get_position();
40
41            if self.skip_whitespace(state) {
42                continue;
43            }
44
45            if self.lex_newline(state) {
46                continue;
47            }
48
49            if self.skip_comment(state) {
50                continue;
51            }
52
53            if self.lex_string_literal(state) {
54                continue;
55            }
56
57            if self.lex_char_literal(state) {
58                continue;
59            }
60
61            if self.lex_number_literal(state) {
62                continue;
63            }
64
65            if self.lex_identifier_or_keyword(state) {
66                continue;
67            }
68
69            if self.lex_operators(state) {
70                continue;
71            }
72
73            if self.lex_single_char_tokens(state) {
74                continue;
75            }
76
77            // 错误处理:如果没有匹配任何规则,跳过当前字符并标记为错误
78            let start_pos = state.get_position();
79            if let Some(ch) = state.peek() {
80                state.advance(ch.len_utf8());
81                state.add_token(ScalaSyntaxKind::Error, start_pos, state.get_position());
82            }
83
84            state.advance_if_dead_lock(safe_point);
85        }
86
87        Ok(())
88    }
89
    /// Delegates to the `SCALA_WHITESPACE` scanner, asking it to emit `Whitespace` tokens.
    ///
    /// Returns the scanner's result; `run` treats `true` as "something was consumed".
    fn skip_whitespace<'s, S: Source + ?Sized>(&self, state: &mut State<'s, S>) -> bool {
        SCALA_WHITESPACE.scan(state, ScalaSyntaxKind::Whitespace)
    }
93
94    /// 处理换行
95    fn lex_newline<'s, S: Source + ?Sized>(&self, state: &mut State<'s, S>) -> bool {
96        let start_pos = state.get_position();
97
98        if let Some('\n') = state.peek() {
99            state.advance(1);
100            state.add_token(ScalaSyntaxKind::Newline, start_pos, state.get_position());
101            true
102        }
103        else if let Some('\r') = state.peek() {
104            state.advance(1);
105            if let Some('\n') = state.peek() {
106                state.advance(1);
107            }
108            state.add_token(ScalaSyntaxKind::Newline, start_pos, state.get_position());
109            true
110        }
111        else {
112            false
113        }
114    }
115
116    fn skip_comment<'s, S: Source + ?Sized>(&self, state: &mut State<'s, S>) -> bool {
117        // 行注释 & 块注释
118        if SCALA_COMMENT.scan(state, ScalaSyntaxKind::LineComment, ScalaSyntaxKind::BlockComment) {
119            return true;
120        }
121
122        false
123    }
124
    /// Delegates to the `SCALA_STRING` scanner, asking it to emit `StringLiteral` tokens.
    ///
    /// Returns the scanner's result.
    fn lex_string_literal<'s, S: Source + ?Sized>(&self, state: &mut State<'s, S>) -> bool {
        SCALA_STRING.scan(state, ScalaSyntaxKind::StringLiteral)
    }
128
    /// Delegates to the `SCALA_CHAR` scanner, asking it to emit `CharLiteral` tokens.
    ///
    /// Returns the scanner's result.
    fn lex_char_literal<'s, S: Source + ?Sized>(&self, state: &mut State<'s, S>) -> bool {
        SCALA_CHAR.scan(state, ScalaSyntaxKind::CharLiteral)
    }
132
133    fn lex_number_literal<'s, S: Source + ?Sized>(&self, state: &mut State<'s, S>) -> bool {
134        if !state.current().map_or(false, |c| c.is_ascii_digit()) {
135            return false;
136        }
137
138        let start = state.get_position();
139        let mut len = 0;
140
141        // 跳过数字
142        while let Some(ch) = state.source().get_char_at(start + len) {
143            if ch.is_ascii_digit() {
144                len += ch.len_utf8();
145            }
146            else if ch == '.' {
147                // 浮点数
148                len += ch.len_utf8();
149                while let Some(ch) = state.source().get_char_at(start + len) {
150                    if ch.is_ascii_digit() {
151                        len += ch.len_utf8();
152                    }
153                    else {
154                        break;
155                    }
156                }
157                break;
158            }
159            else {
160                break;
161            }
162        }
163
164        state.advance(len);
165        let end = state.get_position();
166        state.add_token(ScalaSyntaxKind::IntegerLiteral, start, end);
167        true
168    }
169
170    fn lex_identifier_or_keyword<'s, S: Source + ?Sized>(&self, state: &mut State<'s, S>) -> bool {
171        let first_char = match state.current() {
172            Some(c) if c.is_alphabetic() || c == '_' => c,
173            _ => return false,
174        };
175
176        let start = state.get_position();
177        let mut len = first_char.len_utf8();
178
179        while let Some(ch) = state.source().get_char_at(start + len) {
180            if ch.is_alphanumeric() || ch == '_' {
181                len += ch.len_utf8();
182            }
183            else {
184                break;
185            }
186        }
187
188        let text = state.source().get_text_in((start..start + len).into());
189        state.advance(len);
190        let end = state.get_position();
191
192        let kind = match text.as_ref() {
193            "abstract" => ScalaSyntaxKind::Abstract,
194            "case" => ScalaSyntaxKind::Case,
195            "catch" => ScalaSyntaxKind::Catch,
196            "class" => ScalaSyntaxKind::Class,
197            "def" => ScalaSyntaxKind::Def,
198            "do" => ScalaSyntaxKind::Do,
199            "else" => ScalaSyntaxKind::Else,
200            "extends" => ScalaSyntaxKind::Extends,
201            "false" => ScalaSyntaxKind::False,
202            "final" => ScalaSyntaxKind::Final,
203            "finally" => ScalaSyntaxKind::Finally,
204            "for" => ScalaSyntaxKind::For,
205            "if" => ScalaSyntaxKind::If,
206            "implicit" => ScalaSyntaxKind::Implicit,
207            "import" => ScalaSyntaxKind::Import,
208            "lazy" => ScalaSyntaxKind::Lazy,
209            "match" => ScalaSyntaxKind::Match,
210            "new" => ScalaSyntaxKind::New,
211            "null" => ScalaSyntaxKind::Null,
212            "object" => ScalaSyntaxKind::Object,
213            "override" => ScalaSyntaxKind::Override,
214            "package" => ScalaSyntaxKind::Package,
215            "private" => ScalaSyntaxKind::Private,
216            "protected" => ScalaSyntaxKind::Protected,
217            "return" => ScalaSyntaxKind::Return,
218            "sealed" => ScalaSyntaxKind::Sealed,
219            "super" => ScalaSyntaxKind::Super,
220            "this" => ScalaSyntaxKind::This,
221            "throw" => ScalaSyntaxKind::Throw,
222            "trait" => ScalaSyntaxKind::Trait,
223            "true" => ScalaSyntaxKind::True,
224            "try" => ScalaSyntaxKind::Try,
225            "type" => ScalaSyntaxKind::Type,
226            "val" => ScalaSyntaxKind::Val,
227            "var" => ScalaSyntaxKind::Var,
228            "while" => ScalaSyntaxKind::While,
229            "with" => ScalaSyntaxKind::With,
230            "yield" => ScalaSyntaxKind::Yield,
231            _ => ScalaSyntaxKind::Identifier,
232        };
233
234        state.add_token(kind, start, end);
235        true
236    }
237
238    fn lex_operators<'s, S: Source + ?Sized>(&self, state: &mut State<'s, S>) -> bool {
239        let start = state.get_position();
240
241        // 多字符操作符
242        if state.starts_with("=>") {
243            state.advance(2);
244            state.add_token(ScalaSyntaxKind::Arrow, start, state.get_position());
245            return true;
246        }
247        if state.starts_with("<=") {
248            state.advance(2);
249            state.add_token(ScalaSyntaxKind::LessEqual, start, state.get_position());
250            return true;
251        }
252        if state.starts_with(">=") {
253            state.advance(2);
254            state.add_token(ScalaSyntaxKind::GreaterEqual, start, state.get_position());
255            return true;
256        }
257        if state.starts_with("==") {
258            state.advance(2);
259            state.add_token(ScalaSyntaxKind::EqualEqual, start, state.get_position());
260            return true;
261        }
262        if state.starts_with("!=") {
263            state.advance(2);
264            state.add_token(ScalaSyntaxKind::NotEqual, start, state.get_position());
265            return true;
266        }
267
268        false
269    }
270
271    fn lex_single_char_tokens<'s, S: Source + ?Sized>(&self, state: &mut State<'s, S>) -> bool {
272        let ch = match state.current() {
273            Some(c) => c,
274            None => return false,
275        };
276        let start = state.get_position();
277        state.advance(ch.len_utf8());
278        let end = state.get_position();
279
280        let kind = match ch {
281            '(' => ScalaSyntaxKind::LeftParen,
282            ')' => ScalaSyntaxKind::RightParen,
283            '[' => ScalaSyntaxKind::LeftBracket,
284            ']' => ScalaSyntaxKind::RightBracket,
285            '{' => ScalaSyntaxKind::LeftBrace,
286            '}' => ScalaSyntaxKind::RightBrace,
287            ',' => ScalaSyntaxKind::Comma,
288            ';' => ScalaSyntaxKind::Semicolon,
289            ':' => ScalaSyntaxKind::Colon,
290            '.' => ScalaSyntaxKind::Dot,
291            '+' => ScalaSyntaxKind::Plus,
292            '-' => ScalaSyntaxKind::Minus,
293            '*' => ScalaSyntaxKind::Star,
294            '/' => ScalaSyntaxKind::Slash,
295            '%' => ScalaSyntaxKind::Percent,
296            '=' => ScalaSyntaxKind::Eq,
297            '<' => ScalaSyntaxKind::Lt,
298            '>' => ScalaSyntaxKind::Gt,
299            '!' => ScalaSyntaxKind::Not,
300            '&' => ScalaSyntaxKind::And,
301            '|' => ScalaSyntaxKind::Or,
302            '^' => ScalaSyntaxKind::Xor,
303            '~' => ScalaSyntaxKind::Tilde,
304            '?' => ScalaSyntaxKind::Question,
305            '@' => ScalaSyntaxKind::At,
306            '#' => ScalaSyntaxKind::Hash,
307            _ => {
308                return false;
309            }
310        };
311
312        state.add_token(kind, start, end);
313        true
314    }
315}