Skip to main content

oak_python/lexer/
mod.rs

1use crate::{kind::PythonSyntaxKind, language::PythonLanguage};
2use oak_core::{
3    Lexer, LexerCache, LexerState, OakError,
4    lexer::LexOutput,
5    source::{Source, TextEdit},
6};
7
8type State<'a, S> = LexerState<'a, S, PythonLanguage>;
9
10#[derive(Clone)]
11pub struct PythonLexer<'config> {
12    _config: &'config PythonLanguage,
13}
14
15impl<'config> Lexer<PythonLanguage> for PythonLexer<'config> {
16    fn lex<'a, S: Source + ?Sized>(&self, source: &S, _edits: &[TextEdit], cache: &'a mut impl LexerCache<PythonLanguage>) -> LexOutput<PythonLanguage> {
17        let mut state = State::new_with_cache(source, 0, cache);
18        let result = self.run(&mut state);
19        if result.is_ok() {
20            state.add_eof();
21        }
22        state.finish_with_cache(result, cache)
23    }
24}
25
26impl<'config> PythonLexer<'config> {
27    pub fn new(config: &'config PythonLanguage) -> Self {
28        Self { _config: config }
29    }
30
31    /// 跳过空白字符
32    fn skip_whitespace<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
33        let start_pos = state.get_position();
34
35        while let Some(ch) = state.current() {
36            if ch == ' ' || ch == '\t' {
37                state.advance(ch.len_utf8());
38            }
39            else {
40                break;
41            }
42        }
43
44        if state.get_position() > start_pos {
45            state.add_token(PythonSyntaxKind::Whitespace, start_pos, state.get_position());
46            true
47        }
48        else {
49            false
50        }
51    }
52
53    /// 处理换行
54    fn lex_newline<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
55        let start_pos = state.get_position();
56
57        if let Some('\n') = state.current() {
58            state.advance(1);
59            state.add_token(PythonSyntaxKind::Newline, start_pos, state.get_position());
60            true
61        }
62        else if let Some('\r') = state.current() {
63            state.advance(1);
64            if let Some('\n') = state.current() {
65                state.advance(1);
66            }
67            state.add_token(PythonSyntaxKind::Newline, start_pos, state.get_position());
68            true
69        }
70        else {
71            false
72        }
73    }
74
75    /// 处理注释
76    fn lex_comment<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
77        if let Some('#') = state.current() {
78            let start_pos = state.get_position();
79            state.advance(1); // 跳过 '#'
80
81            // 读取到行尾
82            while let Some(ch) = state.current() {
83                if ch == '\n' || ch == '\r' {
84                    break;
85                }
86                state.advance(ch.len_utf8());
87            }
88
89            state.add_token(PythonSyntaxKind::Comment, start_pos, state.get_position());
90            true
91        }
92        else {
93            false
94        }
95    }
96
97    /// 处理字符串字面量
98    fn lex_string<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
99        let start_pos = state.get_position();
100
101        // 检查是否是字符串开始
102        let quote_char = match state.current() {
103            Some('"') => '"',
104            Some('\'') => '\'',
105            _ => return false,
106        };
107
108        state.advance(1); // 跳过开始引号
109
110        // 检查是否是三引号字符串 - 简化实现,不支持三引号
111        let mut escaped = false;
112        while let Some(ch) = state.current() {
113            if escaped {
114                escaped = false;
115                state.advance(ch.len_utf8());
116                continue;
117            }
118
119            if ch == '\\' {
120                escaped = true;
121                state.advance(1);
122                continue;
123            }
124
125            if ch == quote_char {
126                state.advance(1); // 跳过结束引号
127                break;
128            }
129            else if ch == '\n' || ch == '\r' {
130                // 单行字符串不能包含换行符
131                break;
132            }
133            else {
134                state.advance(ch.len_utf8());
135            }
136        }
137
138        state.add_token(PythonSyntaxKind::String, start_pos, state.get_position());
139        true
140    }
141
142    /// 处理数字字面量
143    fn lex_number<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
144        let start_pos = state.get_position();
145
146        if !state.current().map_or(false, |c| c.is_ascii_digit()) {
147            return false;
148        }
149
150        // 简化实现:只处理基本的十进制数字
151        while let Some(ch) = state.current() {
152            if ch.is_ascii_digit() || ch == '.' {
153                state.advance(1);
154            }
155            else {
156                break;
157            }
158        }
159
160        state.add_token(PythonSyntaxKind::Number, start_pos, state.get_position());
161        true
162    }
163
164    /// 处理标识符或关键字
165    fn lex_identifier_or_keyword<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
166        let start_pos = state.get_position();
167
168        // 检查第一个字符
169        if !state.current().map_or(false, |c| c.is_ascii_alphabetic() || c == '_') {
170            return false;
171        }
172
173        // 读取标识符
174        let mut text = String::new();
175        while let Some(ch) = state.current() {
176            if ch.is_ascii_alphanumeric() || ch == '_' {
177                text.push(ch);
178                state.advance(ch.len_utf8());
179            }
180            else {
181                break;
182            }
183        }
184
185        // 检查是否是关键字
186        let kind = match text.as_str() {
187            "and" => PythonSyntaxKind::AndKeyword,
188            "as" => PythonSyntaxKind::AsKeyword,
189            "assert" => PythonSyntaxKind::AssertKeyword,
190            "async" => PythonSyntaxKind::AsyncKeyword,
191            "await" => PythonSyntaxKind::AwaitKeyword,
192            "break" => PythonSyntaxKind::BreakKeyword,
193            "class" => PythonSyntaxKind::ClassKeyword,
194            "continue" => PythonSyntaxKind::ContinueKeyword,
195            "def" => PythonSyntaxKind::DefKeyword,
196            "del" => PythonSyntaxKind::DelKeyword,
197            "elif" => PythonSyntaxKind::ElifKeyword,
198            "else" => PythonSyntaxKind::ElseKeyword,
199            "except" => PythonSyntaxKind::ExceptKeyword,
200            "False" => PythonSyntaxKind::FalseKeyword,
201            "finally" => PythonSyntaxKind::FinallyKeyword,
202            "for" => PythonSyntaxKind::ForKeyword,
203            "from" => PythonSyntaxKind::FromKeyword,
204            "global" => PythonSyntaxKind::GlobalKeyword,
205            "if" => PythonSyntaxKind::IfKeyword,
206            "import" => PythonSyntaxKind::ImportKeyword,
207            "in" => PythonSyntaxKind::InKeyword,
208            "is" => PythonSyntaxKind::IsKeyword,
209            "lambda" => PythonSyntaxKind::LambdaKeyword,
210            "None" => PythonSyntaxKind::NoneKeyword,
211            "nonlocal" => PythonSyntaxKind::NonlocalKeyword,
212            "not" => PythonSyntaxKind::NotKeyword,
213            "or" => PythonSyntaxKind::OrKeyword,
214            "pass" => PythonSyntaxKind::PassKeyword,
215            "raise" => PythonSyntaxKind::RaiseKeyword,
216            "return" => PythonSyntaxKind::ReturnKeyword,
217            "True" => PythonSyntaxKind::TrueKeyword,
218            "try" => PythonSyntaxKind::TryKeyword,
219            "while" => PythonSyntaxKind::WhileKeyword,
220            "with" => PythonSyntaxKind::WithKeyword,
221            "yield" => PythonSyntaxKind::YieldKeyword,
222            _ => PythonSyntaxKind::Identifier,
223        };
224
225        state.add_token(kind, start_pos, state.get_position());
226        true
227    }
228
229    /// 处理操作符
230    fn lex_operator<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
231        let start_pos = state.get_position();
232
233        if let Some(ch) = state.current() {
234            let kind = match ch {
235                '+' => {
236                    state.advance(1);
237                    if let Some('=') = state.current() {
238                        state.advance(1);
239                        PythonSyntaxKind::PlusAssign
240                    }
241                    else {
242                        PythonSyntaxKind::Plus
243                    }
244                }
245                '-' => {
246                    state.advance(1);
247                    if let Some('=') = state.current() {
248                        state.advance(1);
249                        PythonSyntaxKind::MinusAssign
250                    }
251                    else if let Some('>') = state.current() {
252                        state.advance(1);
253                        PythonSyntaxKind::Arrow
254                    }
255                    else {
256                        PythonSyntaxKind::Minus
257                    }
258                }
259                '*' => {
260                    state.advance(1);
261                    if let Some('=') = state.current() {
262                        state.advance(1);
263                        PythonSyntaxKind::StarAssign
264                    }
265                    else if let Some('*') = state.current() {
266                        state.advance(1);
267                        if let Some('=') = state.current() {
268                            state.advance(1);
269                            PythonSyntaxKind::DoubleStarAssign
270                        }
271                        else {
272                            PythonSyntaxKind::DoubleStar
273                        }
274                    }
275                    else {
276                        PythonSyntaxKind::Star
277                    }
278                }
279                '/' => {
280                    state.advance(1);
281                    if let Some('=') = state.current() {
282                        state.advance(1);
283                        PythonSyntaxKind::SlashAssign
284                    }
285                    else if let Some('/') = state.current() {
286                        state.advance(1);
287                        if let Some('=') = state.current() {
288                            state.advance(1);
289                            PythonSyntaxKind::DoubleSlashAssign
290                        }
291                        else {
292                            PythonSyntaxKind::DoubleSlash
293                        }
294                    }
295                    else {
296                        PythonSyntaxKind::Slash
297                    }
298                }
299                '%' => {
300                    state.advance(1);
301                    if let Some('=') = state.current() {
302                        state.advance(1);
303                        PythonSyntaxKind::PercentAssign
304                    }
305                    else {
306                        PythonSyntaxKind::Percent
307                    }
308                }
309                '=' => {
310                    state.advance(1);
311                    if let Some('=') = state.current() {
312                        state.advance(1);
313                        PythonSyntaxKind::Eq
314                    }
315                    else {
316                        PythonSyntaxKind::Assign
317                    }
318                }
319                '<' => {
320                    state.advance(1);
321                    if let Some('=') = state.current() {
322                        state.advance(1);
323                        PythonSyntaxKind::LessEqual
324                    }
325                    else if let Some('<') = state.current() {
326                        state.advance(1);
327                        if let Some('=') = state.current() {
328                            state.advance(1);
329                            PythonSyntaxKind::LeftShiftAssign
330                        }
331                        else {
332                            PythonSyntaxKind::LeftShift
333                        }
334                    }
335                    else {
336                        PythonSyntaxKind::Less
337                    }
338                }
339                '>' => {
340                    state.advance(1);
341                    if let Some('=') = state.current() {
342                        state.advance(1);
343                        PythonSyntaxKind::GreaterEqual
344                    }
345                    else if let Some('>') = state.current() {
346                        state.advance(1);
347                        if let Some('=') = state.current() {
348                            state.advance(1);
349                            PythonSyntaxKind::RightShiftAssign
350                        }
351                        else {
352                            PythonSyntaxKind::RightShift
353                        }
354                    }
355                    else {
356                        PythonSyntaxKind::Greater
357                    }
358                }
359                '!' => {
360                    state.advance(1);
361                    if let Some('=') = state.current() {
362                        state.advance(1);
363                        PythonSyntaxKind::NotEqual
364                    }
365                    else {
366                        return false;
367                    }
368                }
369                '&' => {
370                    state.advance(1);
371                    if let Some('=') = state.current() {
372                        state.advance(1);
373                        PythonSyntaxKind::AmpersandAssign
374                    }
375                    else {
376                        PythonSyntaxKind::Ampersand
377                    }
378                }
379                '|' => {
380                    state.advance(1);
381                    if let Some('=') = state.current() {
382                        state.advance(1);
383                        PythonSyntaxKind::PipeAssign
384                    }
385                    else {
386                        PythonSyntaxKind::Pipe
387                    }
388                }
389                '^' => {
390                    state.advance(1);
391                    if let Some('=') = state.current() {
392                        state.advance(1);
393                        PythonSyntaxKind::CaretAssign
394                    }
395                    else {
396                        PythonSyntaxKind::Caret
397                    }
398                }
399                '~' => {
400                    state.advance(1);
401                    PythonSyntaxKind::Tilde
402                }
403                '@' => {
404                    state.advance(1);
405                    if let Some('=') = state.current() {
406                        state.advance(1);
407                        PythonSyntaxKind::AtAssign
408                    }
409                    else {
410                        PythonSyntaxKind::At
411                    }
412                }
413                _ => return false,
414            };
415
416            state.add_token(kind, start_pos, state.get_position());
417            return true;
418        }
419
420        false
421    }
422
423    /// 处理分隔符
424    fn lex_delimiter<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
425        let start_pos = state.get_position();
426
427        if let Some(ch) = state.current() {
428            let kind = match ch {
429                '(' => PythonSyntaxKind::LeftParen,
430                ')' => PythonSyntaxKind::RightParen,
431                '[' => PythonSyntaxKind::LeftBracket,
432                ']' => PythonSyntaxKind::RightBracket,
433                '{' => PythonSyntaxKind::LeftBrace,
434                '}' => PythonSyntaxKind::RightBrace,
435                ',' => PythonSyntaxKind::Comma,
436                ':' => PythonSyntaxKind::Colon,
437                ';' => PythonSyntaxKind::Semicolon,
438                '.' => PythonSyntaxKind::Dot, // 简化处理,不支持省略号
439                _ => return false,
440            };
441
442            state.advance(1);
443            state.add_token(kind, start_pos, state.get_position());
444            return true;
445        }
446
447        false
448    }
449}
450
451impl<'config> PythonLexer<'config> {
452    pub(crate) fn run<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
453        let mut indent_stack = vec![0];
454        let mut bracket_level: usize = 0;
455        let mut at_line_start = true;
456
457        while state.not_at_end() {
458            let safe_point = state.get_position();
459
460            if at_line_start && bracket_level == 0 {
461                self.handle_indentation(state, &mut indent_stack);
462                at_line_start = false;
463                continue;
464            }
465
466            if let Some(ch) = state.peek() {
467                match ch {
468                    ' ' | '\t' => {
469                        self.skip_whitespace(state);
470                    }
471                    '\n' | '\r' => {
472                        self.lex_newline(state);
473                        at_line_start = true;
474                    }
475                    '#' => {
476                        self.lex_comment(state);
477                    }
478                    '"' | '\'' => {
479                        self.lex_string(state);
480                    }
481                    '0'..='9' => {
482                        self.lex_number(state);
483                    }
484                    'a'..='z' | 'A'..='Z' | '_' => {
485                        self.lex_identifier_or_keyword(state);
486                    }
487                    '(' | '[' | '{' => {
488                        bracket_level += 1;
489                        self.lex_delimiter(state);
490                    }
491                    ')' | ']' | '}' => {
492                        bracket_level = bracket_level.saturating_sub(1);
493                        self.lex_delimiter(state);
494                    }
495                    '+' | '-' | '*' | '/' | '%' | '=' | '<' | '>' | '&' | '|' | '^' | '~' | '@' => {
496                        self.lex_operator(state);
497                    }
498                    ',' | ':' | ';' | '.' => {
499                        self.lex_delimiter(state);
500                    }
501                    _ => {
502                        // Fallback to error
503                        state.advance(ch.len_utf8());
504                        state.add_token(PythonSyntaxKind::Error, safe_point, state.get_position());
505                    }
506                }
507            }
508
509            state.advance_if_dead_lock(safe_point);
510        }
511
512        // Emit remaining dedents
513        while indent_stack.len() > 1 {
514            indent_stack.pop();
515            let pos = state.get_position();
516            state.add_token(PythonSyntaxKind::Dedent, pos, pos);
517        }
518
519        Ok(())
520    }
521
522    fn handle_indentation<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>, stack: &mut Vec<usize>) {
523        let start_pos = state.get_position();
524        let current_indent;
525
526        // Skip comments and empty lines at start of line
527        let mut temp_state = state.get_position();
528        loop {
529            let mut indent = 0;
530            while let Some(ch) = state.get_char_at(temp_state) {
531                if ch == ' ' {
532                    indent += 1;
533                }
534                else if ch == '\t' {
535                    indent += 8;
536                }
537                // Standard Python tab width
538                else {
539                    break;
540                }
541                temp_state += 1;
542            }
543
544            match state.get_char_at(temp_state) {
545                Some('\n') | Some('\r') | Some('#') => {
546                    // This is an empty line or comment-only line, ignore indentation change
547                    return;
548                }
549                None => return, // EOF
550                _ => {
551                    current_indent = indent;
552                    break;
553                }
554            }
555        }
556
557        // Advance state to skip the indentation we just measured
558        if current_indent > 0 {
559            let end_pos = state.get_position() + (temp_state - state.get_position());
560            state.add_token(PythonSyntaxKind::Whitespace, start_pos, end_pos);
561            state.set_position(end_pos);
562        }
563
564        let last_indent = *stack.last().unwrap();
565        if current_indent > last_indent {
566            stack.push(current_indent);
567            state.add_token(PythonSyntaxKind::Indent, state.get_position(), state.get_position());
568        }
569        else {
570            while current_indent < *stack.last().unwrap() {
571                stack.pop();
572                state.add_token(PythonSyntaxKind::Dedent, state.get_position(), state.get_position());
573            }
574            // If current_indent doesn't match any previous level, it's an indentation error,
575            // but for now we just stop at the closest level.
576        }
577    }
578}