oak_haskell/lexer/
mod.rs

1use crate::{kind::HaskellSyntaxKind, language::HaskellLanguage};
2use oak_core::{IncrementalCache, Lexer, LexerState, lexer::LexOutput, source::Source};
3
4#[derive(Clone)]
5pub struct HaskellLexer<'config> {
6    config: &'config HaskellLanguage,
7}
8
9impl<'config> HaskellLexer<'config> {
10    pub fn new(config: &'config HaskellLanguage) -> Self {
11        Self { config }
12    }
13
14    fn skip_whitespace<S: Source>(&self, state: &mut LexerState<S, HaskellLanguage>) -> bool {
15        let start_pos = state.get_position();
16        while let Some(ch) = state.peek() {
17            if ch == ' ' || ch == '\t' {
18                state.advance(1);
19            }
20            else {
21                break;
22            }
23        }
24
25        if state.get_position() > start_pos {
26            state.add_token(HaskellSyntaxKind::Whitespace, start_pos, state.get_position());
27            true
28        }
29        else {
30            false
31        }
32    }
33
34    fn lex_newline<S: Source>(&self, state: &mut LexerState<S, HaskellLanguage>) -> bool {
35        let start_pos = state.get_position();
36
37        if let Some('\n') = state.peek() {
38            state.advance(1);
39            state.add_token(HaskellSyntaxKind::Newline, start_pos, state.get_position());
40            true
41        }
42        else if let Some('\r') = state.peek() {
43            state.advance(1);
44            if let Some('\n') = state.peek() {
45                state.advance(1);
46            }
47            state.add_token(HaskellSyntaxKind::Newline, start_pos, state.get_position());
48            true
49        }
50        else {
51            false
52        }
53    }
54
55    fn lex_single_line_comment<S: Source>(&self, state: &mut LexerState<S, HaskellLanguage>) -> bool {
56        let start_pos = state.get_position();
57
58        if let Some('-') = state.peek() {
59            if let Some('-') = state.peek_next_n(1) {
60                state.advance(2);
61                while let Some(ch) = state.peek() {
62                    if ch == '\n' || ch == '\r' {
63                        break;
64                    }
65                    state.advance(1);
66                }
67                state.add_token(HaskellSyntaxKind::Comment, start_pos, state.get_position());
68                true
69            }
70            else {
71                false
72            }
73        }
74        else {
75            false
76        }
77    }
78
79    fn lex_multi_line_comment<S: Source>(&self, state: &mut LexerState<S, HaskellLanguage>) -> bool {
80        let start_pos = state.get_position();
81
82        if let Some('{') = state.peek() {
83            if let Some('-') = state.peek_next_n(1) {
84                state.advance(2);
85                let mut depth = 1;
86                while let Some(ch) = state.peek() {
87                    if ch == '{' && state.peek_next_n(1) == Some('-') {
88                        depth += 1;
89                        state.advance(2);
90                    }
91                    else if ch == '-' && state.peek_next_n(1) == Some('}') {
92                        depth -= 1;
93                        state.advance(2);
94                        if depth == 0 {
95                            break;
96                        }
97                    }
98                    else {
99                        state.advance(1);
100                    }
101                }
102                state.add_token(HaskellSyntaxKind::Comment, start_pos, state.get_position());
103                true
104            }
105            else {
106                false
107            }
108        }
109        else {
110            false
111        }
112    }
113
114    fn lex_identifier_or_keyword<S: Source>(&self, state: &mut LexerState<S, HaskellLanguage>) -> bool {
115        let start_pos = state.get_position();
116
117        if let Some(ch) = state.peek() {
118            if ch.is_ascii_alphabetic() || ch == '_' {
119                state.advance(1);
120
121                while let Some(ch) = state.peek() {
122                    if ch.is_ascii_alphanumeric() || ch == '_' || ch == '\'' {
123                        state.advance(1);
124                    }
125                    else {
126                        break;
127                    }
128                }
129
130                let end_pos = state.get_position();
131                let text = state.get_text_in((start_pos..end_pos).into());
132                let kind = self.keyword_or_identifier(text.as_ref());
133
134                state.add_token(kind, start_pos, end_pos);
135                true
136            }
137            else {
138                false
139            }
140        }
141        else {
142            false
143        }
144    }
145
146    fn keyword_or_identifier(&self, text: &str) -> HaskellSyntaxKind {
147        match text {
148            "case" => HaskellSyntaxKind::Case,
149            "class" => HaskellSyntaxKind::Class,
150            "data" => HaskellSyntaxKind::Data,
151            "default" => HaskellSyntaxKind::Default,
152            "deriving" => HaskellSyntaxKind::Deriving,
153            "do" => HaskellSyntaxKind::Do,
154            "else" => HaskellSyntaxKind::Else,
155            "if" => HaskellSyntaxKind::If,
156            "import" => HaskellSyntaxKind::Import,
157            "in" => HaskellSyntaxKind::In,
158            "infix" => HaskellSyntaxKind::Infix,
159            "infixl" => HaskellSyntaxKind::Infixl,
160            "infixr" => HaskellSyntaxKind::Infixr,
161            "instance" => HaskellSyntaxKind::Instance,
162            "let" => HaskellSyntaxKind::Let,
163            "module" => HaskellSyntaxKind::Module,
164            "newtype" => HaskellSyntaxKind::Newtype,
165            "of" => HaskellSyntaxKind::Of,
166            "then" => HaskellSyntaxKind::Then,
167            "type" => HaskellSyntaxKind::Type,
168            "where" => HaskellSyntaxKind::Where,
169            _ => HaskellSyntaxKind::Identifier,
170        }
171    }
172
173    fn lex_number<S: Source>(&self, state: &mut LexerState<S, HaskellLanguage>) -> bool {
174        let start_pos = state.get_position();
175
176        if let Some(ch) = state.peek() {
177            if ch.is_ascii_digit() {
178                state.advance(1);
179
180                while let Some(ch) = state.peek() {
181                    if ch.is_ascii_digit() {
182                        state.advance(1);
183                    }
184                    else if ch == '.' {
185                        state.advance(1);
186                        while let Some(ch) = state.peek() {
187                            if ch.is_ascii_digit() {
188                                state.advance(ch.len_utf8());
189                            }
190                            else {
191                                break;
192                            }
193                        }
194                        break;
195                    }
196                    else {
197                        break;
198                    }
199                }
200
201                state.add_token(HaskellSyntaxKind::Number, start_pos, state.get_position());
202                true
203            }
204            else {
205                false
206            }
207        }
208        else {
209            false
210        }
211    }
212
213    fn lex_string<S: Source>(&self, state: &mut LexerState<S, HaskellLanguage>) -> bool {
214        let start_pos = state.get_position();
215
216        if let Some('"') = state.peek() {
217            state.advance(1);
218
219            while let Some(ch) = state.peek() {
220                if ch == '"' {
221                    state.advance(1);
222                    state.add_token(HaskellSyntaxKind::StringLiteral, start_pos, state.get_position());
223                    return true;
224                }
225                else if ch == '\\' {
226                    state.advance(1);
227                    if let Some(_) = state.peek() {
228                        state.advance(1);
229                    }
230                }
231                else {
232                    state.advance(1);
233                }
234            }
235
236            state.add_token(HaskellSyntaxKind::StringLiteral, start_pos, state.get_position());
237            true
238        }
239        else {
240            false
241        }
242    }
243
244    fn lex_char<S: Source>(&self, state: &mut LexerState<S, HaskellLanguage>) -> bool {
245        let start_pos = state.get_position();
246
247        if let Some('\'') = state.peek() {
248            state.advance(1);
249
250            if let Some(ch) = state.peek() {
251                if ch == '\\' {
252                    state.advance(1);
253                    if let Some(_) = state.peek() {
254                        state.advance(1);
255                    }
256                }
257                else if ch != '\'' {
258                    state.advance(1);
259                }
260            }
261
262            if let Some('\'') = state.peek() {
263                state.advance(1);
264                state.add_token(HaskellSyntaxKind::CharLiteral, start_pos, state.get_position());
265                true
266            }
267            else {
268                state.add_token(HaskellSyntaxKind::CharLiteral, start_pos, state.get_position());
269                true
270            }
271        }
272        else {
273            false
274        }
275    }
276
277    fn lex_operators<S: Source>(&self, state: &mut LexerState<S, HaskellLanguage>) -> bool {
278        let start_pos = state.get_position();
279
280        if let Some(ch) = state.peek() {
281            let token_kind = match ch {
282                '+' => {
283                    state.advance(1);
284                    if let Some('+') = state.peek() {
285                        state.advance(1);
286                        HaskellSyntaxKind::Append
287                    }
288                    else {
289                        HaskellSyntaxKind::Plus
290                    }
291                }
292                '-' => {
293                    state.advance(1);
294                    if let Some('>') = state.peek() {
295                        state.advance(1);
296                        HaskellSyntaxKind::Arrow
297                    }
298                    else {
299                        HaskellSyntaxKind::Minus
300                    }
301                }
302                '*' => {
303                    state.advance(1);
304                    HaskellSyntaxKind::Star
305                }
306                '/' => {
307                    state.advance(1);
308                    HaskellSyntaxKind::Slash
309                }
310                '=' => {
311                    state.advance(1);
312                    if let Some('=') = state.peek() {
313                        state.advance(1);
314                        HaskellSyntaxKind::Equal
315                    }
316                    else {
317                        HaskellSyntaxKind::Assign
318                    }
319                }
320                '<' => {
321                    state.advance(1);
322                    if let Some('=') = state.peek() {
323                        state.advance(1);
324                        HaskellSyntaxKind::LessEqual
325                    }
326                    else if let Some('-') = state.peek() {
327                        state.advance(1);
328                        HaskellSyntaxKind::LeftArrow
329                    }
330                    else {
331                        HaskellSyntaxKind::Less
332                    }
333                }
334                '>' => {
335                    state.advance(1);
336                    if let Some('=') = state.peek() {
337                        state.advance(1);
338                        HaskellSyntaxKind::GreaterEqual
339                    }
340                    else {
341                        HaskellSyntaxKind::Greater
342                    }
343                }
344                ':' => {
345                    state.advance(1);
346                    if let Some(':') = state.peek() {
347                        state.advance(1);
348                        HaskellSyntaxKind::DoubleColon
349                    }
350                    else {
351                        HaskellSyntaxKind::Colon
352                    }
353                }
354                '|' => {
355                    state.advance(1);
356                    HaskellSyntaxKind::Pipe
357                }
358                '&' => {
359                    state.advance(1);
360                    HaskellSyntaxKind::Ampersand
361                }
362                '!' => {
363                    state.advance(1);
364                    HaskellSyntaxKind::Bang
365                }
366                '?' => {
367                    state.advance(1);
368                    HaskellSyntaxKind::Question
369                }
370                ';' => {
371                    state.advance(1);
372                    HaskellSyntaxKind::Semicolon
373                }
374                ',' => {
375                    state.advance(1);
376                    HaskellSyntaxKind::Comma
377                }
378                '.' => {
379                    state.advance(1);
380                    if let Some('.') = state.peek() {
381                        state.advance(1);
382                        HaskellSyntaxKind::DoubleDot
383                    }
384                    else {
385                        HaskellSyntaxKind::Dot
386                    }
387                }
388                '$' => {
389                    state.advance(1);
390                    HaskellSyntaxKind::Dollar
391                }
392                '@' => {
393                    state.advance(1);
394                    HaskellSyntaxKind::At
395                }
396                '~' => {
397                    state.advance(1);
398                    HaskellSyntaxKind::Tilde
399                }
400                '\\' => {
401                    state.advance(1);
402                    HaskellSyntaxKind::Backslash
403                }
404                '`' => {
405                    state.advance(1);
406                    HaskellSyntaxKind::Backtick
407                }
408                _ => return false,
409            };
410
411            state.add_token(token_kind, start_pos, state.get_position());
412            true
413        }
414        else {
415            false
416        }
417    }
418
419    fn lex_delimiters<S: Source>(&self, state: &mut LexerState<S, HaskellLanguage>) -> bool {
420        let start_pos = state.get_position();
421
422        if let Some(ch) = state.peek() {
423            let token_kind = match ch {
424                '(' => {
425                    state.advance(1);
426                    HaskellSyntaxKind::LeftParen
427                }
428                ')' => {
429                    state.advance(1);
430                    HaskellSyntaxKind::RightParen
431                }
432                '[' => {
433                    state.advance(1);
434                    HaskellSyntaxKind::LeftBracket
435                }
436                ']' => {
437                    state.advance(1);
438                    HaskellSyntaxKind::RightBracket
439                }
440                '{' => {
441                    state.advance(1);
442                    HaskellSyntaxKind::LeftBrace
443                }
444                '}' => {
445                    state.advance(1);
446                    HaskellSyntaxKind::RightBrace
447                }
448                _ => return false,
449            };
450
451            state.add_token(token_kind, start_pos, state.get_position());
452            true
453        }
454        else {
455            false
456        }
457    }
458}
459
460impl<'config> Lexer<HaskellLanguage> for HaskellLexer<'config> {
461    fn lex_incremental(
462        &self,
463        source: impl Source,
464        changed: usize,
465        cache: IncrementalCache<HaskellLanguage>,
466    ) -> LexOutput<HaskellLanguage> {
467        let mut state = LexerState::new_with_cache(source, changed, cache);
468
469        while state.not_at_end() {
470            if self.skip_whitespace(&mut state) {
471                continue;
472            }
473
474            if self.lex_newline(&mut state) {
475                continue;
476            }
477
478            if self.lex_single_line_comment(&mut state) {
479                continue;
480            }
481
482            if self.lex_multi_line_comment(&mut state) {
483                continue;
484            }
485
486            if self.lex_identifier_or_keyword(&mut state) {
487                continue;
488            }
489
490            if self.lex_number(&mut state) {
491                continue;
492            }
493
494            if self.lex_string(&mut state) {
495                continue;
496            }
497
498            if self.lex_char(&mut state) {
499                continue;
500            }
501
502            if self.lex_operators(&mut state) {
503                continue;
504            }
505
506            if self.lex_delimiters(&mut state) {
507                continue;
508            }
509
510            // 如果没有匹配到任何模式,跳过当前字符并标记为错误
511            let start_pos = state.get_position();
512            if let Some(ch) = state.peek() {
513                state.advance(1);
514                state.add_token(HaskellSyntaxKind::Error, start_pos, state.get_position());
515            }
516        }
517
518        // 添加 EOF token
519        let pos = state.get_position();
520        state.add_token(HaskellSyntaxKind::Eof, pos, pos);
521
522        state.finish(Ok(()))
523    }
524}