oak_python/lexer/mod.rs

use crate::{kind::PythonSyntaxKind, language::PythonLanguage};
use oak_core::{
    Lexer, LexerCache, LexerState, OakError,
    lexer::LexOutput,
    source::{Source, TextEdit},
};

type State<'a, S> = LexerState<'a, S, PythonLanguage>;

#[derive(Clone)]
pub struct PythonLexer<'config> {
    _config: &'config PythonLanguage,
}

impl<'config> PythonLexer<'config> {
    pub fn new(config: &'config PythonLanguage) -> Self {
        Self { _config: config }
    }

    /// Skip whitespace characters (spaces and tabs).
    fn skip_whitespace<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        let start_pos = state.get_position();

        while let Some(ch) = state.current() {
            if ch == ' ' || ch == '\t' {
                state.advance(ch.len_utf8());
            }
            else {
                break;
            }
        }

        if state.get_position() > start_pos {
            state.add_token(PythonSyntaxKind::Whitespace, start_pos, state.get_position());
            true
        }
        else {
            false
        }
    }

    /// Lex a newline (`\n`, `\r`, or `\r\n`).
    fn lex_newline<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        let start_pos = state.get_position();

        if let Some('\n') = state.current() {
            state.advance(1);
            state.add_token(PythonSyntaxKind::Newline, start_pos, state.get_position());
            true
        }
        else if let Some('\r') = state.current() {
            state.advance(1);
            if let Some('\n') = state.current() {
                state.advance(1);
            }
            state.add_token(PythonSyntaxKind::Newline, start_pos, state.get_position());
            true
        }
        else {
            false
        }
    }

    /// Lex a `#` comment running to the end of the line.
    fn lex_comment<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        if let Some('#') = state.current() {
            let start_pos = state.get_position();
            state.advance(1); // skip the '#'

            // Read up to the end of the line.
            while let Some(ch) = state.current() {
                if ch == '\n' || ch == '\r' {
                    break;
                }
                state.advance(ch.len_utf8());
            }

            state.add_token(PythonSyntaxKind::Comment, start_pos, state.get_position());
            true
        }
        else {
            false
        }
    }

    /// Lex a single-line string literal.
    fn lex_string<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        let start_pos = state.get_position();

        // Check whether this is the start of a string.
        let quote_char = match state.current() {
            Some('"') => '"',
            Some('\'') => '\'',
            _ => return false,
        };

        state.advance(1); // skip the opening quote

        // Simplified implementation: triple-quoted strings are not supported.
        let mut escaped = false;
        while let Some(ch) = state.current() {
            if escaped {
                escaped = false;
                state.advance(ch.len_utf8());
                continue;
            }

            if ch == '\\' {
                escaped = true;
                state.advance(1);
                continue;
            }

            if ch == quote_char {
                state.advance(1); // skip the closing quote
                break;
            }
            else if ch == '\n' || ch == '\r' {
                // A single-line string cannot contain a newline.
                break;
            }
            else {
                state.advance(ch.len_utf8());
            }
        }

        state.add_token(PythonSyntaxKind::String, start_pos, state.get_position());
        true
    }
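    // Behavior notes (a reading of the loop above, not a change to it): a
    // backslash keeps the following character in the token without decoding
    // it, so `\"` does not close the string; an unterminated string stops at
    // the end of the line but is still emitted as a String token, leaving
    // error reporting to the parser. Because triple quotes are not recognized,
    // `"""doc"""` lexes as three String tokens: `""`, `"doc"`, and `""`.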

    /// Lex a numeric literal.
    fn lex_number<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        let start_pos = state.get_position();

        if !state.current().map_or(false, |c| c.is_ascii_digit()) {
            return false;
        }

        // Simplified implementation: only basic decimal numbers (digits and `.`) are handled.
        while let Some(ch) = state.current() {
            if ch.is_ascii_digit() || ch == '.' {
                state.advance(1);
            }
            else {
                break;
            }
        }

        state.add_token(PythonSyntaxKind::Number, start_pos, state.get_position());
        true
    }

    /// Lex an identifier or keyword.
    fn lex_identifier_or_keyword<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        let start_pos = state.get_position();

        // Check the first character.
        if !state.current().map_or(false, |c| c.is_ascii_alphabetic() || c == '_') {
            return false;
        }

        // Read the identifier.
        let mut text = String::new();
        while let Some(ch) = state.current() {
            if ch.is_ascii_alphanumeric() || ch == '_' {
                text.push(ch);
                state.advance(ch.len_utf8());
            }
            else {
                break;
            }
        }

        // Check whether it is a keyword.
        let kind = match text.as_str() {
            "and" => PythonSyntaxKind::AndKeyword,
            "as" => PythonSyntaxKind::AsKeyword,
            "assert" => PythonSyntaxKind::AssertKeyword,
            "async" => PythonSyntaxKind::AsyncKeyword,
            "await" => PythonSyntaxKind::AwaitKeyword,
            "break" => PythonSyntaxKind::BreakKeyword,
            "class" => PythonSyntaxKind::ClassKeyword,
            "continue" => PythonSyntaxKind::ContinueKeyword,
            "def" => PythonSyntaxKind::DefKeyword,
            "del" => PythonSyntaxKind::DelKeyword,
            "elif" => PythonSyntaxKind::ElifKeyword,
            "else" => PythonSyntaxKind::ElseKeyword,
            "except" => PythonSyntaxKind::ExceptKeyword,
            "False" => PythonSyntaxKind::FalseKeyword,
            "finally" => PythonSyntaxKind::FinallyKeyword,
            "for" => PythonSyntaxKind::ForKeyword,
            "from" => PythonSyntaxKind::FromKeyword,
            "global" => PythonSyntaxKind::GlobalKeyword,
            "if" => PythonSyntaxKind::IfKeyword,
            "import" => PythonSyntaxKind::ImportKeyword,
            "in" => PythonSyntaxKind::InKeyword,
            "is" => PythonSyntaxKind::IsKeyword,
            "lambda" => PythonSyntaxKind::LambdaKeyword,
            "None" => PythonSyntaxKind::NoneKeyword,
            "nonlocal" => PythonSyntaxKind::NonlocalKeyword,
            "not" => PythonSyntaxKind::NotKeyword,
            "or" => PythonSyntaxKind::OrKeyword,
            "pass" => PythonSyntaxKind::PassKeyword,
            "raise" => PythonSyntaxKind::RaiseKeyword,
            "return" => PythonSyntaxKind::ReturnKeyword,
            "True" => PythonSyntaxKind::TrueKeyword,
            "try" => PythonSyntaxKind::TryKeyword,
            "while" => PythonSyntaxKind::WhileKeyword,
            "with" => PythonSyntaxKind::WithKeyword,
            "yield" => PythonSyntaxKind::YieldKeyword,
            _ => PythonSyntaxKind::Identifier,
        };

        state.add_token(kind, start_pos, state.get_position());
        true
    }
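    // Note: keyword lookup is exact and case-sensitive, matching Python 3 where
    // `True`, `False`, and `None` are keywords; `true` or `TRUE` therefore lex
    // as plain Identifier tokens. Identifiers are ASCII-only here, so a
    // non-ASCII start character is not handled by this function and falls
    // through to the Error arm in `run` below.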

    /// Lex an operator.
    fn lex_operator<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        let start_pos = state.get_position();

        // Simplified implementation: only single-character operators are handled.
        if let Some(ch) = state.current() {
            let kind = match ch {
                '+' => PythonSyntaxKind::Plus,
                '-' => PythonSyntaxKind::Minus,
                '*' => PythonSyntaxKind::Star,
                '/' => PythonSyntaxKind::Slash,
                '%' => PythonSyntaxKind::Percent,
                '=' => PythonSyntaxKind::Assign,
                '<' => PythonSyntaxKind::Less,
                '>' => PythonSyntaxKind::Greater,
                '&' => PythonSyntaxKind::Ampersand,
                '|' => PythonSyntaxKind::Pipe,
                '^' => PythonSyntaxKind::Caret,
                '~' => PythonSyntaxKind::Tilde,
                '@' => PythonSyntaxKind::At,
                _ => return false,
            };

            state.advance(1);
            state.add_token(kind, start_pos, state.get_position());
            return true;
        }

        false
    }

    /// Lex a delimiter.
    fn lex_delimiter<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        let start_pos = state.get_position();

        if let Some(ch) = state.current() {
            let kind = match ch {
                '(' => PythonSyntaxKind::LeftParen,
                ')' => PythonSyntaxKind::RightParen,
                '[' => PythonSyntaxKind::LeftBracket,
                ']' => PythonSyntaxKind::RightBracket,
                '{' => PythonSyntaxKind::LeftBrace,
                '}' => PythonSyntaxKind::RightBrace,
                ',' => PythonSyntaxKind::Comma,
                ':' => PythonSyntaxKind::Colon,
                ';' => PythonSyntaxKind::Semicolon,
                '.' => PythonSyntaxKind::Dot, // simplified: the `...` ellipsis is not recognized
                _ => return false,
            };

            state.advance(1);
            state.add_token(kind, start_pos, state.get_position());
            return true;
        }

        false
    }
}

impl<'config> Lexer<PythonLanguage> for PythonLexer<'config> {
    fn lex<'a, S: Source + ?Sized>(&self, source: &S, _edits: &[TextEdit], cache: &'a mut impl LexerCache<PythonLanguage>) -> LexOutput<PythonLanguage> {
        let mut state: State<'_, S> = LexerState::new(source);
        let result = self.run(&mut state);
        if result.is_ok() {
            state.add_eof();
        }
        state.finish_with_cache(result, cache)
    }
}
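// Usage sketch (comment only, not compiled): roughly how a caller might drive
// this impl. `SourceText` and `DefaultLexerCache` are hypothetical stand-ins
// for whatever `Source` and `LexerCache` implementations oak_core provides,
// and `PythonLanguage::default()` is assumed; only `PythonLexer::new` and
// `Lexer::lex` are defined in this file.
//
//     let language = PythonLanguage::default();
//     let lexer = PythonLexer::new(&language);
//     let source = SourceText::from("def f():\n    return 1\n");
//     let mut cache = DefaultLexerCache::default();
//     let output = lexer.lex(&source, &[], &mut cache);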

impl<'config> PythonLexer<'config> {
    pub(crate) fn run<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
        let mut indent_stack = vec![0];
        let mut bracket_level: usize = 0;
        let mut at_line_start = true;

        while state.not_at_end() {
            let safe_point = state.get_position();

            if at_line_start && bracket_level == 0 {
                self.handle_indentation(state, &mut indent_stack);
                at_line_start = false;
                continue;
            }

            if let Some(ch) = state.peek() {
                match ch {
                    ' ' | '\t' => {
                        self.skip_whitespace(state);
                    }
                    '\n' | '\r' => {
                        self.lex_newline(state);
                        at_line_start = true;
                    }
                    '#' => {
                        self.lex_comment(state);
                    }
                    '"' | '\'' => {
                        self.lex_string(state);
                    }
                    '0'..='9' => {
                        self.lex_number(state);
                    }
                    'a'..='z' | 'A'..='Z' | '_' => {
                        self.lex_identifier_or_keyword(state);
                    }
                    '(' | '[' | '{' => {
                        bracket_level += 1;
                        self.lex_delimiter(state);
                    }
                    ')' | ']' | '}' => {
                        bracket_level = bracket_level.saturating_sub(1);
                        self.lex_delimiter(state);
                    }
                    '+' | '-' | '*' | '/' | '%' | '=' | '<' | '>' | '&' | '|' | '^' | '~' | '@' => {
                        self.lex_operator(state);
                    }
                    ',' | ':' | ';' | '.' => {
                        self.lex_delimiter(state);
                    }
                    _ => {
                        // Unrecognized character: emit an Error token.
                        state.advance(ch.len_utf8());
                        state.add_token(PythonSyntaxKind::Error, safe_point, state.get_position());
                    }
                }
            }

            state.advance_if_dead_lock(safe_point);
        }

        // Emit any remaining dedents at end of input.
        while indent_stack.len() > 1 {
            indent_stack.pop();
            let pos = state.get_position();
            state.add_token(PythonSyntaxKind::Dedent, pos, pos);
        }

        Ok(())
    }
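    // Indentation handling in a nutshell: `run` keeps a stack of indentation
    // widths, starting at 0. At the start of each logical line outside
    // brackets, `handle_indentation` measures the leading whitespace; a larger
    // width pushes onto the stack and emits an Indent token, while a smaller
    // width pops the stack and emits one Dedent per popped level. For example:
    //
    //     if x:        width 0
    //         y        width 4: push 4, emit Indent
    //     z            width 0: pop 4, emit Dedent
    //
    // Blank and comment-only lines return early and never change the stack.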

    fn handle_indentation<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>, stack: &mut Vec<usize>) {
        let start_pos = state.get_position();

        // Measure the leading whitespace of the line without consuming it yet.
        let mut temp_state = state.get_position();
        let mut indent = 0;
        while let Some(ch) = state.get_char_at(temp_state) {
            if ch == ' ' {
                indent += 1;
            }
            else if ch == '\t' {
                // Count a tab as 8 columns. CPython rounds up to the next
                // multiple of 8; this flat width is a simplification.
                indent += 8;
            }
            else {
                break;
            }
            temp_state += 1;
        }

        match state.get_char_at(temp_state) {
            Some('\n') | Some('\r') | Some('#') => {
                // Empty or comment-only line: ignore any indentation change.
                return;
            }
            None => return, // EOF
            _ => {}
        }
        let current_indent = indent;

        // Consume the indentation we just measured as a Whitespace token.
        if current_indent > 0 {
            state.add_token(PythonSyntaxKind::Whitespace, start_pos, temp_state);
            state.set_position(temp_state);
        }

        let last_indent = *stack.last().unwrap();
        if current_indent > last_indent {
            stack.push(current_indent);
            state.add_token(PythonSyntaxKind::Indent, state.get_position(), state.get_position());
        }
        else {
            while current_indent < *stack.last().unwrap() {
                stack.pop();
                state.add_token(PythonSyntaxKind::Dedent, state.get_position(), state.get_position());
            }
            // If current_indent doesn't match any previous level it is an
            // indentation error; for now we simply stop at the closest level.
        }
    }
}
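// Known simplifications, as called out inline above: strings have no
// triple-quoted, raw, byte, or f-string forms, and escapes are skipped rather
// than decoded; numbers consume only decimal digits and `.`, so hex, binary,
// octal, exponent, underscore, and imaginary forms are not recognized (and
// `1.2.3` lexes as a single Number token); operators are single characters, so
// `==`, `**`, `//`, and `->` lex as two tokens each and `!` falls through to
// an Error token; `...` lexes as three Dot tokens; and a tab counts as a flat
// 8 columns of indentation instead of rounding up to the next multiple of 8.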