//! Lexer for the Bash language (`oak_bash/lexer/mod.rs`).

1pub mod token_type;
2
3pub use token_type::BashTokenType;
4
5use crate::language::BashLanguage;
6use oak_core::{Lexer, LexerCache, LexerState, OakError, lexer::LexOutput, source::Source};
7use std::sync::LazyLock;
8
/// Shorthand for `LexerState` specialized to `BashLanguage`.
type State<'a, S> = LexerState<'a, S, BashLanguage>;
10
/// Hand-written lexer for Bash source text.
///
/// Holds a borrow of the `BashLanguage` configuration for its lifetime; the
/// field is stored by `new` but not read by any method in this module.
#[derive(Clone)]
pub struct BashLexer<'config> {
    // NOTE(review): unused here (hence the leading underscore); presumably
    // reserved for future configuration-driven lexing — confirm before removing.
    _config: &'config BashLanguage,
}
15
16impl<'config> Lexer<BashLanguage> for BashLexer<'config> {
17    fn lex<'a, S: Source + ?Sized>(&self, source: &S, _edits: &[oak_core::source::TextEdit], cache: &'a mut impl LexerCache<BashLanguage>) -> LexOutput<BashLanguage> {
18        let mut state = LexerState::new_with_cache(source, 0, cache);
19        let result = self.run(&mut state);
20        if result.is_ok() {
21            state.add_eof();
22        }
23        state.finish_with_cache(result, cache)
24    }
25}
26
27impl<'config> BashLexer<'config> {
    /// Creates a lexer bound to the given language configuration.
    /// The configuration is stored but not read anywhere in this module.
    pub fn new(config: &'config BashLanguage) -> Self {
        Self { _config: config }
    }
31
32    fn run<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
33        while state.not_at_end() {
34            let safe_point = state.get_position();
35            if self.skip_whitespace(state) {
36                continue;
37            }
38
39            if self.skip_comment(state) {
40                continue;
41            }
42
43            if self.lex_newline(state) {
44                continue;
45            }
46
47            if self.lex_string(state) {
48                continue;
49            }
50
51            if self.lex_variable(state) {
52                continue;
53            }
54
55            if self.lex_number(state) {
56                continue;
57            }
58
59            if self.lex_keyword_or_identifier(state) {
60                continue;
61            }
62
63            if self.lex_operator_or_delimiter(state) {
64                continue;
65            }
66
67            if self.lex_heredoc(state) {
68                continue;
69            }
70
71            if self.lex_glob_pattern(state) {
72                continue;
73            }
74
75            if self.lex_special_char(state) {
76                continue;
77            }
78
79            if self.lex_text(state) {
80                continue;
81            }
82
83            // 如果没有匹配任何模式,跳过一个字符并生成 Error token
84            let start_pos = state.get_position();
85            if let Some(ch) = state.peek() {
86                state.advance(ch.len_utf8());
87                state.add_token(BashTokenType::Error, start_pos, state.get_position());
88            }
89
90            state.advance_if_dead_lock(safe_point);
91        }
92        Ok(())
93    }
94
95    fn skip_whitespace<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
96        let start_pos = state.get_position();
97
98        while let Some(ch) = state.peek() {
99            if ch == ' ' || ch == '\t' {
100                state.advance(ch.len_utf8());
101            }
102            else {
103                break;
104            }
105        }
106
107        if state.get_position() > start_pos {
108            state.add_token(BashTokenType::Whitespace, start_pos, state.get_position());
109            true
110        }
111        else {
112            false
113        }
114    }
115
116    fn skip_comment<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
117        let start_pos = state.get_position();
118
119        if let Some('#') = state.peek() {
120            state.advance(1);
121            while let Some(ch) = state.peek() {
122                if ch == '\n' || ch == '\r' {
123                    break;
124                }
125                state.advance(ch.len_utf8());
126            }
127            state.add_token(BashTokenType::Comment, start_pos, state.get_position());
128            true
129        }
130        else {
131            false
132        }
133    }
134
135    fn lex_newline<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
136        let start_pos = state.get_position();
137
138        if let Some('\n') = state.peek() {
139            state.advance(1);
140            state.add_token(BashTokenType::Newline, start_pos, state.get_position());
141            true
142        }
143        else if let Some('\r') = state.peek() {
144            state.advance(1);
145            if let Some('\n') = state.peek() {
146                state.advance(1);
147            }
148            state.add_token(BashTokenType::Newline, start_pos, state.get_position());
149            true
150        }
151        else {
152            false
153        }
154    }
155
156    fn lex_string<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
157        let start_pos = state.get_position();
158
159        if let Some(quote) = state.peek() {
160            if quote == '"' || quote == '\'' {
161                state.advance(1);
162                let mut escaped = false;
163
164                while let Some(ch) = state.peek() {
165                    if escaped {
166                        escaped = false;
167                        state.advance(ch.len_utf8());
168                        continue;
169                    }
170
171                    if ch == '\\' {
172                        escaped = true;
173                        state.advance(1);
174                        continue;
175                    }
176
177                    if ch == quote {
178                        state.advance(1);
179                        break;
180                    }
181
182                    state.advance(ch.len_utf8());
183                }
184
185                state.add_token(BashTokenType::StringLiteral, start_pos, state.get_position());
186                return true;
187            }
188        }
189
190        false
191    }
192
193    fn lex_variable<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
194        let start_pos = state.get_position();
195
196        if let Some('$') = state.peek() {
197            state.advance(1);
198
199            // 处理特殊变量 $0, $1, $?, $$ 等
200            if let Some(ch) = state.peek() {
201                if ch.is_ascii_digit() || ch == '?' || ch == '$' || ch == '#' || ch == '@' || ch == '*' {
202                    state.advance(1);
203                    state.add_token(BashTokenType::Variable, start_pos, state.get_position());
204                    return true;
205                }
206            }
207
208            // 处理 ${var} 形式
209            if let Some('{') = state.peek() {
210                state.advance(1);
211                while let Some(ch) = state.peek() {
212                    if ch == '}' {
213                        state.advance(1);
214                        break;
215                    }
216                    state.advance(ch.len_utf8());
217                }
218                state.add_token(BashTokenType::Variable, start_pos, state.get_position());
219                return true;
220            }
221
222            // 处理普通变量名
223            if let Some(ch) = state.peek() {
224                if ch.is_alphabetic() || ch == '_' {
225                    state.advance(ch.len_utf8());
226                    while let Some(ch) = state.peek() {
227                        if ch.is_alphanumeric() || ch == '_' {
228                            state.advance(ch.len_utf8());
229                        }
230                        else {
231                            break;
232                        }
233                    }
234                    state.add_token(BashTokenType::Variable, start_pos, state.get_position());
235                    return true;
236                }
237            }
238
239            // 如果只有 $ 没有有效变量名,回退
240            state.set_position(start_pos);
241        }
242
243        false
244    }
245
246    fn lex_number<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
247        let start_pos = state.get_position();
248
249        if let Some(ch) = state.peek() {
250            if ch.is_ascii_digit() {
251                state.advance(1);
252                while let Some(ch) = state.peek() {
253                    if ch.is_ascii_digit() {
254                        state.advance(1);
255                    }
256                    else {
257                        break;
258                    }
259                }
260                state.add_token(BashTokenType::NumberLiteral, start_pos, state.get_position());
261                return true;
262            }
263        }
264
265        false
266    }
267
268    fn lex_keyword_or_identifier<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
269        let start_pos = state.get_position();
270
271        if let Some(ch) = state.peek() {
272            if ch.is_ascii_alphabetic() || ch == '_' {
273                state.advance(ch.len_utf8());
274                while let Some(ch) = state.peek() {
275                    if ch.is_ascii_alphanumeric() || ch == '_' {
276                        state.advance(ch.len_utf8());
277                    }
278                    else {
279                        break;
280                    }
281                }
282
283                let text = state.get_text_in((start_pos..state.get_position()).into());
284                let kind = if BASH_KEYWORDS.contains(&text.as_ref()) { BashTokenType::Keyword } else { BashTokenType::Identifier };
285
286                state.add_token(kind, start_pos, state.get_position());
287                return true;
288            }
289        }
290
291        false
292    }
293
294    fn lex_operator_or_delimiter<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
295        let start_pos = state.get_position();
296
297        if let Some(ch) = state.peek() {
298            let two_char = if let Some(next_ch) = state.peek_next_n(1) { format!("{}{}", ch, next_ch) } else { String::new() };
299
300            // 检查双字符操作符
301            if BASH_TWO_CHAR_OPERATORS.contains(&two_char.as_str()) {
302                state.advance(2);
303                state.add_token(BashTokenType::Operator, start_pos, state.get_position());
304                return true;
305            }
306
307            // 检查单字符操作符和分隔符
308            let ch_str = ch.to_string();
309            if BASH_OPERATORS.contains(&ch_str.as_str()) {
310                state.advance(1);
311                state.add_token(BashTokenType::Operator, start_pos, state.get_position());
312                return true;
313            }
314
315            if BASH_DELIMITERS.contains(&ch_str.as_str()) {
316                state.advance(1);
317                state.add_token(BashTokenType::Delimiter, start_pos, state.get_position());
318                return true;
319            }
320        }
321
322        false
323    }
324
325    fn lex_heredoc<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
326        let start_pos = state.get_position();
327
328        // 检查 << 开始的 heredoc
329        if let Some('<') = state.peek() {
330            if let Some('<') = state.peek_next_n(1) {
331                state.advance(2);
332
333                // 跳过可选的 -
334                if let Some('-') = state.peek() {
335                    state.advance(1);
336                }
337
338                // 读取标识符
339                while let Some(ch) = state.peek() {
340                    if ch.is_alphanumeric() || ch == '_' {
341                        state.advance(ch.len_utf8());
342                    }
343                    else {
344                        break;
345                    }
346                }
347
348                state.add_token(BashTokenType::Heredoc, start_pos, state.get_position());
349                return true;
350            }
351        }
352
353        false
354    }
355
356    fn lex_glob_pattern<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
357        let start_pos = state.get_position();
358
359        if let Some(ch) = state.peek() {
360            if ch == '*' || ch == '?' || ch == '[' {
361                state.advance(1);
362
363                if ch == '[' {
364                    // 处理字符类 [abc] 或 [!abc]
365                    if let Some('!') = state.peek() {
366                        state.advance(1);
367                    }
368                    while let Some(ch) = state.peek() {
369                        if ch == ']' {
370                            state.advance(1);
371                            break;
372                        }
373                        state.advance(ch.len_utf8());
374                    }
375                }
376
377                state.add_token(BashTokenType::GlobPattern, start_pos, state.get_position());
378                return true;
379            }
380        }
381
382        false
383    }
384
385    fn lex_special_char<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
386        let start_pos = state.get_position();
387
388        if let Some(ch) = state.peek() {
389            if BASH_SPECIAL_CHARS.contains(&ch) {
390                state.advance(1);
391                state.add_token(BashTokenType::SpecialChar, start_pos, state.get_position());
392                return true;
393            }
394        }
395
396        false
397    }
398
399    fn lex_text<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
400        let start_pos = state.get_position();
401
402        if let Some(ch) = state.peek() {
403            if !ch.is_whitespace() && !BASH_SPECIAL_CHARS.contains(&ch) {
404                state.advance(ch.len_utf8());
405                state.add_token(BashTokenType::Text, start_pos, state.get_position());
406                return true;
407            }
408        }
409
410        false
411    }
412}
413
/// Reserved words and builtin command names classified as `Keyword`.
///
/// Idiom fix: the slice is a compile-time constant, so wrapping it in a
/// `LazyLock` added indirection with zero lazy work — a plain `static` has
/// the identical `.contains` interface.
///
/// NOTE(review): the non-identifier entries (".", "[", "[[", "]]") can never
/// be produced by `lex_keyword_or_identifier`, which only accepts words
/// starting with an ASCII letter or `_`; they are retained for compatibility.
static BASH_KEYWORDS: &[&str] = &[
    "if", "then", "else", "elif", "fi", "case", "esac", "for", "while", "until", "do", "done", "function", "return", "break", "continue", "local", "export", "readonly", "declare", "typeset", "unset", "shift", "exit", "source", ".", "eval", "exec",
    "trap", "wait", "jobs", "bg", "fg", "disown", "suspend", "alias", "unalias", "history", "fc", "let", "test", "[", "[[", "]]", "time", "coproc", "select", "in",
];
420
/// Single-character operators. Idiom fix: a plain `static` replaces the
/// needless `LazyLock` around a compile-time-constant slice.
static BASH_OPERATORS: &[&str] = &["+", "-", "*", "/", "%", "=", "!", "<", ">", "&", "|", "^", "~"];
422
/// Multi-character operators. Despite the name, this table also holds the
/// three-character compound assignments `<<=` and `>>=`, so any consumer
/// must compare a three-character candidate as well as a two-character one.
/// Idiom fix: plain `static` replaces the needless `LazyLock`.
static BASH_TWO_CHAR_OPERATORS: &[&str] = &["==", "!=", "<=", ">=", "&&", "||", "<<", ">>", "++", "--", "+=", "-=", "*=", "/=", "%=", "&=", "|=", "^=", "<<=", ">>=", "**"];
424
/// Single-character delimiters (grouping and separators). Idiom fix:
/// plain `static` replaces the needless `LazyLock`.
static BASH_DELIMITERS: &[&str] = &["(", ")", "{", "}", "[", "]", ";", ",", ":", "."];
426
/// Characters that receive a dedicated `SpecialChar` token when no earlier
/// rule claims them. Fixes: the original list contained duplicate '\\' and
/// '`' entries (removed; membership behavior is unchanged), and the
/// needless `LazyLock` around a constant slice is now a plain `static`.
static BASH_SPECIAL_CHARS: &[char] = &['\\', '`', '~', '@', '#', '$', '%', '^', '&', '*', '(', ')', '-', '+', '=', '{', '}', '[', ']', '|', ':', ';', '"', '\'', '<', '>', ',', '.', '?', '/', '!'];