oak_sql/lexer/
mod.rs

1use crate::{kind::SqlSyntaxKind, language::SqlLanguage};
2use oak_core::{
3    IncrementalCache, Lexer, LexerState, OakError,
4    lexer::{CommentLine, LexOutput, StringConfig, WhitespaceConfig},
5    source::Source,
6};
7use std::sync::LazyLock;
8
9type State<S> = LexerState<S, SqlLanguage>;
10
11static SQL_WHITESPACE: LazyLock<WhitespaceConfig> = LazyLock::new(|| WhitespaceConfig { unicode_whitespace: true });
12static SQL_COMMENT: LazyLock<CommentLine> = LazyLock::new(|| CommentLine { line_markers: &["--"] });
13static SQL_STRING: LazyLock<StringConfig> = LazyLock::new(|| StringConfig { quotes: &['"', '\''], escape: Some('\\') });
14
15#[derive(Clone)]
16pub struct SqlLexer<'config> {
17    config: &'config SqlLanguage,
18}
19
20impl<'config> Lexer<SqlLanguage> for SqlLexer<'config> {
21    fn lex_incremental(
22        &self,
23        source: impl Source,
24        changed: usize,
25        cache: IncrementalCache<SqlLanguage>,
26    ) -> LexOutput<SqlLanguage> {
27        let mut state = LexerState::new_with_cache(source, changed, cache);
28        let result = self.run(&mut state);
29        state.finish(result)
30    }
31}
32
33impl<'config> SqlLexer<'config> {
34    pub fn new(config: &'config SqlLanguage) -> Self {
35        Self { config }
36    }
37
38    fn run<S: Source>(&self, state: &mut State<S>) -> Result<(), OakError> {
39        while state.not_at_end() {
40            let safe_point = state.get_position();
41
42            if self.skip_whitespace(state) {
43                continue;
44            }
45
46            if self.lex_newline(state) {
47                continue;
48            }
49
50            if self.skip_comment(state) {
51                continue;
52            }
53
54            if self.lex_string_literal(state) {
55                continue;
56            }
57
58            if self.lex_number_literal(state) {
59                continue;
60            }
61
62            if self.lex_identifier_or_keyword(state) {
63                continue;
64            }
65
66            if self.lex_operators(state) {
67                continue;
68            }
69
70            if self.lex_single_char_tokens(state) {
71                continue;
72            }
73
74            // 如果没有匹配任何模式,跳过当前字符并添加错误 token
75            if let Some(ch) = state.peek() {
76                state.advance(ch.len_utf8());
77                state.add_token(SqlSyntaxKind::Error, safe_point, state.get_position());
78            }
79        }
80
81        // 添加 EOF token
82        let eof_pos = state.get_position();
83        state.add_token(SqlSyntaxKind::Eof, eof_pos, eof_pos);
84        Ok(())
85    }
86
87    /// 处理换行
88    fn lex_newline<S: Source>(&self, state: &mut State<S>) -> bool {
89        let start_pos = state.get_position();
90
91        if let Some('\n') = state.peek() {
92            state.advance(1);
93            state.add_token(SqlSyntaxKind::Newline, start_pos, state.get_position());
94            true
95        }
96        else if let Some('\r') = state.peek() {
97            state.advance(1);
98            if let Some('\n') = state.peek() {
99                state.advance(1);
100            }
101            state.add_token(SqlSyntaxKind::Newline, start_pos, state.get_position());
102            true
103        }
104        else {
105            false
106        }
107    }
108
109    fn skip_whitespace<S: Source>(&self, state: &mut State<S>) -> bool {
110        match SQL_WHITESPACE.scan(state.rest(), state.get_position(), SqlSyntaxKind::Whitespace) {
111            Some(token) => {
112                state.advance_with(token);
113                true
114            }
115            None => false,
116        }
117    }
118
119    fn skip_comment<S: Source>(&self, state: &mut State<S>) -> bool {
120        let start = state.get_position();
121        let rest = state.rest();
122
123        // 行注释: -- ... 直到换行
124        if rest.starts_with("--") {
125            state.advance(2);
126            while let Some(ch) = state.peek() {
127                if ch == '\n' || ch == '\r' {
128                    break;
129                }
130                state.advance(ch.len_utf8());
131            }
132            state.add_token(SqlSyntaxKind::Comment, start, state.get_position());
133            return true;
134        }
135
136        // 块注释: /* ... */
137        if rest.starts_with("/*") {
138            state.advance(2);
139            while let Some(ch) = state.peek() {
140                if ch == '*' && state.peek_next_n(1) == Some('/') {
141                    state.advance(2);
142                    break;
143                }
144                state.advance(ch.len_utf8());
145            }
146            state.add_token(SqlSyntaxKind::Comment, start, state.get_position());
147            return true;
148        }
149
150        false
151    }
152
153    fn lex_string_literal<S: Source>(&self, state: &mut State<S>) -> bool {
154        let start = state.get_position();
155        let ch = match state.current() {
156            Some(c) => c,
157            None => return false,
158        };
159
160        if ch == '\'' || ch == '"' {
161            let quote = ch;
162            state.advance(1);
163            let mut escaped = false;
164
165            while let Some(ch) = state.peek() {
166                if ch == quote && !escaped {
167                    state.advance(1); // 消费结束引号
168                    break;
169                }
170                state.advance(ch.len_utf8());
171                if escaped {
172                    escaped = false;
173                    continue;
174                }
175                if ch == '\\' {
176                    escaped = true;
177                    continue;
178                }
179                if ch == '\n' || ch == '\r' {
180                    break;
181                }
182            }
183            state.add_token(SqlSyntaxKind::StringLiteral, start, state.get_position());
184            return true;
185        }
186        false
187    }
188
189    fn lex_number_literal<S: Source>(&self, state: &mut State<S>) -> bool {
190        let start = state.get_position();
191        let first = match state.current() {
192            Some(c) => c,
193            None => return false,
194        };
195
196        if !first.is_ascii_digit() {
197            return false;
198        }
199
200        let mut is_float = false;
201        state.advance(1);
202
203        // 整数部分
204        while let Some(c) = state.peek() {
205            if c.is_ascii_digit() || c == '_' {
206                state.advance(1);
207            }
208            else {
209                break;
210            }
211        }
212
213        // 小数部分
214        if state.peek() == Some('.') {
215            let next = state.peek_next_n(1);
216            if next.map(|c| c.is_ascii_digit()).unwrap_or(false) {
217                is_float = true;
218                state.advance(1); // 消费 '.'
219                while let Some(c) = state.peek() {
220                    if c.is_ascii_digit() || c == '_' {
221                        state.advance(1);
222                    }
223                    else {
224                        break;
225                    }
226                }
227            }
228        }
229
230        // 指数部分
231        if let Some(c) = state.peek() {
232            if c == 'e' || c == 'E' {
233                let next = state.peek_next_n(1);
234                if next == Some('+') || next == Some('-') || next.map(|d| d.is_ascii_digit()).unwrap_or(false) {
235                    is_float = true;
236                    state.advance(1);
237                    if let Some(sign) = state.peek() {
238                        if sign == '+' || sign == '-' {
239                            state.advance(1);
240                        }
241                    }
242                    while let Some(d) = state.peek() {
243                        if d.is_ascii_digit() || d == '_' {
244                            state.advance(1);
245                        }
246                        else {
247                            break;
248                        }
249                    }
250                }
251            }
252        }
253
254        let end = state.get_position();
255        state.add_token(if is_float { SqlSyntaxKind::FloatLiteral } else { SqlSyntaxKind::NumberLiteral }, start, end);
256        true
257    }
258
259    fn lex_identifier_or_keyword<S: Source>(&self, state: &mut State<S>) -> bool {
260        let start = state.get_position();
261        let ch = match state.current() {
262            Some(c) => c,
263            None => return false,
264        };
265
266        if !(ch.is_ascii_alphabetic() || ch == '_') {
267            return false;
268        }
269
270        state.advance(1);
271        while let Some(c) = state.current() {
272            if c.is_ascii_alphanumeric() || c == '_' {
273                state.advance(1);
274            }
275            else {
276                break;
277            }
278        }
279
280        let end = state.get_position();
281        let text = state.get_text_in((start..end).into());
282        let kind = self.keyword_kind(&text).unwrap_or(SqlSyntaxKind::Identifier);
283        state.add_token(kind, start, end);
284        true
285    }
286
287    fn keyword_kind(&self, text: &str) -> Option<SqlSyntaxKind> {
288        match text.to_uppercase().as_str() {
289            "SELECT" => Some(SqlSyntaxKind::Select),
290            "FROM" => Some(SqlSyntaxKind::From),
291            "WHERE" => Some(SqlSyntaxKind::Where),
292            "INSERT" => Some(SqlSyntaxKind::Insert),
293            "INTO" => Some(SqlSyntaxKind::Into),
294            "VALUES" => Some(SqlSyntaxKind::Values),
295            "UPDATE" => Some(SqlSyntaxKind::Update),
296            "SET" => Some(SqlSyntaxKind::Set),
297            "DELETE" => Some(SqlSyntaxKind::Delete),
298            "CREATE" => Some(SqlSyntaxKind::Create),
299            "DROP" => Some(SqlSyntaxKind::Drop),
300            "ALTER" => Some(SqlSyntaxKind::Alter),
301            "ADD" => Some(SqlSyntaxKind::Add),
302            "COLUMN" => Some(SqlSyntaxKind::Column),
303            "TABLE" => Some(SqlSyntaxKind::Table),
304            "PRIMARY" => Some(SqlSyntaxKind::Primary),
305            "KEY" => Some(SqlSyntaxKind::Key),
306            "FOREIGN" => Some(SqlSyntaxKind::Foreign),
307            "REFERENCES" => Some(SqlSyntaxKind::References),
308            "INDEX" => Some(SqlSyntaxKind::Index),
309            "UNIQUE" => Some(SqlSyntaxKind::Unique),
310            "NOT" => Some(SqlSyntaxKind::Not),
311            "NULL" => Some(SqlSyntaxKind::Null),
312            "DEFAULT" => Some(SqlSyntaxKind::Default),
313            "AUTO_INCREMENT" => Some(SqlSyntaxKind::AutoIncrement),
314            "AND" => Some(SqlSyntaxKind::And),
315            "OR" => Some(SqlSyntaxKind::Or),
316            "IN" => Some(SqlSyntaxKind::In),
317            "LIKE" => Some(SqlSyntaxKind::Like),
318            "BETWEEN" => Some(SqlSyntaxKind::Between),
319            "IS" => Some(SqlSyntaxKind::Is),
320            "AS" => Some(SqlSyntaxKind::As),
321            "JOIN" => Some(SqlSyntaxKind::Join),
322            "INNER" => Some(SqlSyntaxKind::Inner),
323            "LEFT" => Some(SqlSyntaxKind::Left),
324            "RIGHT" => Some(SqlSyntaxKind::Right),
325            "FULL" => Some(SqlSyntaxKind::Full),
326            "OUTER" => Some(SqlSyntaxKind::Outer),
327            "ON" => Some(SqlSyntaxKind::On),
328            "GROUP" => Some(SqlSyntaxKind::Group),
329            "BY" => Some(SqlSyntaxKind::By),
330            "HAVING" => Some(SqlSyntaxKind::Having),
331            "ORDER" => Some(SqlSyntaxKind::Order),
332            "ASC" => Some(SqlSyntaxKind::Asc),
333            "DESC" => Some(SqlSyntaxKind::Desc),
334            "LIMIT" => Some(SqlSyntaxKind::Limit),
335            "OFFSET" => Some(SqlSyntaxKind::Offset),
336            "UNION" => Some(SqlSyntaxKind::Union),
337            "ALL" => Some(SqlSyntaxKind::All),
338            "DISTINCT" => Some(SqlSyntaxKind::Distinct),
339            "COUNT" => Some(SqlSyntaxKind::Count),
340            "SUM" => Some(SqlSyntaxKind::Sum),
341            "AVG" => Some(SqlSyntaxKind::Avg),
342            "MIN" => Some(SqlSyntaxKind::Min),
343            "MAX" => Some(SqlSyntaxKind::Max),
344            "VIEW" => Some(SqlSyntaxKind::View),
345            "DATABASE" => Some(SqlSyntaxKind::Database),
346            "SCHEMA" => Some(SqlSyntaxKind::Schema),
347            "TRUE" => Some(SqlSyntaxKind::True),
348            "FALSE" => Some(SqlSyntaxKind::False),
349            "EXISTS" => Some(SqlSyntaxKind::Exists),
350            "CASE" => Some(SqlSyntaxKind::Case),
351            "WHEN" => Some(SqlSyntaxKind::When),
352            "THEN" => Some(SqlSyntaxKind::Then),
353            "ELSE" => Some(SqlSyntaxKind::Else),
354            "END" => Some(SqlSyntaxKind::End),
355            "IF" => Some(SqlSyntaxKind::If),
356            "BEGIN" => Some(SqlSyntaxKind::Begin),
357            "COMMIT" => Some(SqlSyntaxKind::Commit),
358            "ROLLBACK" => Some(SqlSyntaxKind::Rollback),
359            "TRANSACTION" => Some(SqlSyntaxKind::Transaction),
360            // 数据类型
361            "INT" => Some(SqlSyntaxKind::Int),
362            "INTEGER" => Some(SqlSyntaxKind::Integer),
363            "VARCHAR" => Some(SqlSyntaxKind::Varchar),
364            "CHAR" => Some(SqlSyntaxKind::Char),
365            "TEXT" => Some(SqlSyntaxKind::Text),
366            "DATE" => Some(SqlSyntaxKind::Date),
367            "TIME" => Some(SqlSyntaxKind::Time),
368            "TIMESTAMP" => Some(SqlSyntaxKind::Timestamp),
369            "DECIMAL" => Some(SqlSyntaxKind::Decimal),
370            "FLOAT" => Some(SqlSyntaxKind::Float),
371            "DOUBLE" => Some(SqlSyntaxKind::Double),
372            "BOOLEAN" => Some(SqlSyntaxKind::Boolean),
373            _ => None,
374        }
375    }
376
377    fn lex_operators<S: Source>(&self, state: &mut State<S>) -> bool {
378        let start = state.get_position();
379        let rest = state.rest();
380
381        // 优先匹配较长的操作符
382        let patterns: &[(&str, SqlSyntaxKind)] = &[
383            ("<=", SqlSyntaxKind::Le),
384            (">=", SqlSyntaxKind::Ge),
385            ("!=", SqlSyntaxKind::Ne),
386            ("<>", SqlSyntaxKind::Ne),
387            ("||", SqlSyntaxKind::Concat),
388        ];
389
390        for (pat, kind) in patterns {
391            if rest.starts_with(pat) {
392                state.advance(pat.len());
393                state.add_token(*kind, start, state.get_position());
394                return true;
395            }
396        }
397
398        if let Some(ch) = state.current() {
399            let kind = match ch {
400                '=' => Some(SqlSyntaxKind::Equal),
401                '<' => Some(SqlSyntaxKind::Lt),
402                '>' => Some(SqlSyntaxKind::Gt),
403                '+' => Some(SqlSyntaxKind::Plus),
404                '-' => Some(SqlSyntaxKind::Minus),
405                '*' => Some(SqlSyntaxKind::Star),
406                '/' => Some(SqlSyntaxKind::Slash),
407                '%' => Some(SqlSyntaxKind::Percent),
408                '.' => Some(SqlSyntaxKind::Dot),
409                _ => None,
410            };
411            if let Some(k) = kind {
412                state.advance(ch.len_utf8());
413                state.add_token(k, start, state.get_position());
414                return true;
415            }
416        }
417        false
418    }
419
420    fn lex_single_char_tokens<S: Source>(&self, state: &mut State<S>) -> bool {
421        let start = state.get_position();
422        if let Some(ch) = state.current() {
423            let kind = match ch {
424                '(' => SqlSyntaxKind::LeftParen,
425                ')' => SqlSyntaxKind::RightParen,
426                '{' => SqlSyntaxKind::LeftBrace,
427                '}' => SqlSyntaxKind::RightBrace,
428                '[' => SqlSyntaxKind::LeftBracket,
429                ']' => SqlSyntaxKind::RightBracket,
430                ',' => SqlSyntaxKind::Comma,
431                ';' => SqlSyntaxKind::Semicolon,
432                ':' => SqlSyntaxKind::Colon,
433                '?' => SqlSyntaxKind::Question,
434                _ => return false,
435            };
436            state.advance(ch.len_utf8());
437            state.add_token(kind, start, state.get_position());
438            return true;
439        }
440        false
441    }
442}