Skip to main content

oak_sql/lexer/
mod.rs

1use crate::{kind::SqlSyntaxKind, language::SqlLanguage};
2use oak_core::{
3    Lexer, LexerCache, LexerState, OakError, TextEdit,
4    lexer::{LexOutput, WhitespaceConfig},
5    source::Source,
6};
7use std::sync::LazyLock;
8
/// Shorthand for the lexer state specialised to [`SqlLanguage`].
type State<'a, S> = LexerState<'a, S, SqlLanguage>;

/// Whitespace scanning configuration used by `SqlLexer::skip_whitespace`.
///
/// Built on first access with `unicode_whitespace` enabled.
static SQL_WHITESPACE: LazyLock<WhitespaceConfig> = LazyLock::new(|| WhitespaceConfig { unicode_whitespace: true });
12
/// Lexer that turns SQL source text into a stream of `SqlSyntaxKind` tokens.
#[derive(Clone, Debug)]
pub struct SqlLexer<'config> {
    /// Borrowed language configuration; not read by any lexing routine in this file.
    _config: &'config SqlLanguage,
}
17
18impl<'config> Lexer<SqlLanguage> for SqlLexer<'config> {
19    fn lex<'a, S: Source + ?Sized>(&self, text: &'a S, _edits: &[TextEdit], cache: &'a mut impl LexerCache<SqlLanguage>) -> LexOutput<SqlLanguage> {
20        let mut state = State::new(text);
21        let result = self.run(&mut state);
22        if result.is_ok() {
23            state.add_eof();
24        }
25        state.finish_with_cache(result, cache)
26    }
27}
28
29impl<'config> SqlLexer<'config> {
30    pub fn new(config: &'config SqlLanguage) -> Self {
31        Self { _config: config }
32    }
33
34    fn run<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
35        while state.not_at_end() {
36            let safe_point = state.get_position();
37
38            if let Some(ch) = state.peek() {
39                match ch {
40                    ' ' | '\t' => {
41                        self.skip_whitespace(state);
42                    }
43                    '\n' | '\r' => {
44                        self.lex_newline(state);
45                    }
46                    '-' => {
47                        if state.starts_with("--") {
48                            self.skip_comment(state);
49                        }
50                        else {
51                            self.lex_operators(state);
52                        }
53                    }
54                    '/' => {
55                        if state.starts_with("/*") {
56                            self.skip_comment(state);
57                        }
58                        else {
59                            self.lex_operators(state);
60                        }
61                    }
62                    '\'' | '"' => {
63                        self.lex_string_literal(state);
64                    }
65                    '0'..='9' => {
66                        self.lex_number_literal(state);
67                    }
68                    'a'..='z' | 'A'..='Z' | '_' => {
69                        self.lex_identifier_or_keyword(state);
70                    }
71                    '<' | '>' | '!' | '=' | '+' | '*' | '%' => {
72                        self.lex_operators(state);
73                    }
74                    '(' | ')' | ',' | ';' | '.' => {
75                        self.lex_single_char_tokens(state);
76                    }
77                    _ => {
78                        // 如果没有匹配任何模式,跳过当前字符并添加错误 token
79                        state.advance(ch.len_utf8());
80                        state.add_token(SqlSyntaxKind::Error, safe_point, state.get_position());
81                    }
82                }
83            }
84
85            state.advance_if_dead_lock(safe_point);
86        }
87        Ok(())
88    }
89
90    /// 处理换行
91    fn lex_newline<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
92        let start_pos = state.get_position();
93
94        if let Some('\n') = state.peek() {
95            state.advance(1);
96            state.add_token(SqlSyntaxKind::Newline, start_pos, state.get_position());
97            true
98        }
99        else if let Some('\r') = state.peek() {
100            state.advance(1);
101            if let Some('\n') = state.peek() {
102                state.advance(1);
103            }
104            state.add_token(SqlSyntaxKind::Newline, start_pos, state.get_position());
105            true
106        }
107        else {
108            false
109        }
110    }
111
    /// Delegates whitespace scanning to the shared `SQL_WHITESPACE` configuration,
    /// tagging the scanned span as `SqlSyntaxKind::Whitespace`, and returns the
    /// scanner's result.
    fn skip_whitespace<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        SQL_WHITESPACE.scan(state, SqlSyntaxKind::Whitespace)
    }
115
116    fn skip_comment<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
117        let start = state.get_position();
118
119        // 行注释: -- ... 直到换行
120        if state.starts_with("--") {
121            state.advance(2);
122            state.take_while(|ch| ch != '\n' && ch != '\r');
123            state.add_token(SqlSyntaxKind::Comment, start, state.get_position());
124            return true;
125        }
126
127        // 块注释: /* ... */
128        if state.starts_with("/*") {
129            state.advance(2);
130            while state.not_at_end() {
131                if state.starts_with("*/") {
132                    state.advance(2);
133                    break;
134                }
135                if let Some(ch) = state.current() {
136                    state.advance(ch.len_utf8());
137                }
138            }
139            state.add_token(SqlSyntaxKind::Comment, start, state.get_position());
140            return true;
141        }
142
143        false
144    }
145
146    fn lex_string_literal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
147        let start = state.get_position();
148        if let Some(quote) = state.current() {
149            if quote != '\'' && quote != '"' {
150                return false;
151            }
152            state.advance(1);
153            let mut escaped = false;
154            while state.not_at_end() {
155                let ch = match state.peek() {
156                    Some(c) => c,
157                    None => break,
158                };
159
160                if ch == quote && !escaped {
161                    state.advance(1); // 消费结束引号
162                    break;
163                }
164                state.advance(ch.len_utf8());
165                if escaped {
166                    escaped = false;
167                    continue;
168                }
169                if ch == '\\' {
170                    escaped = true;
171                    continue;
172                }
173                if ch == '\n' || ch == '\r' {
174                    break;
175                }
176            }
177            state.add_token(SqlSyntaxKind::StringLiteral, start, state.get_position());
178            return true;
179        }
180        false
181    }
182
183    fn lex_number_literal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
184        let start = state.get_position();
185        let first = match state.current() {
186            Some(c) => c,
187            None => return false,
188        };
189
190        if !first.is_ascii_digit() {
191            return false;
192        }
193
194        let mut is_float = false;
195        state.advance(1);
196
197        // 整数部分
198        while let Some(c) = state.peek() {
199            if c.is_ascii_digit() || c == '_' {
200                state.advance(1);
201            }
202            else {
203                break;
204            }
205        }
206
207        // 小数部分
208        if state.peek() == Some('.') {
209            let next = state.peek_next_n(1);
210            if next.map(|c| c.is_ascii_digit()).unwrap_or(false) {
211                is_float = true;
212                state.advance(1); // 消费 '.'
213                while let Some(c) = state.peek() {
214                    if c.is_ascii_digit() || c == '_' {
215                        state.advance(1);
216                    }
217                    else {
218                        break;
219                    }
220                }
221            }
222        }
223
224        // 指数部分
225        if let Some(c) = state.peek() {
226            if c == 'e' || c == 'E' {
227                let next = state.peek_next_n(1);
228                if next == Some('+') || next == Some('-') || next.map(|d| d.is_ascii_digit()).unwrap_or(false) {
229                    is_float = true;
230                    state.advance(1);
231                    if let Some(sign) = state.peek() {
232                        if sign == '+' || sign == '-' {
233                            state.advance(1);
234                        }
235                    }
236                    while let Some(d) = state.peek() {
237                        if d.is_ascii_digit() || d == '_' {
238                            state.advance(1);
239                        }
240                        else {
241                            break;
242                        }
243                    }
244                }
245            }
246        }
247
248        let end = state.get_position();
249        state.add_token(if is_float { SqlSyntaxKind::FloatLiteral } else { SqlSyntaxKind::NumberLiteral }, start, end);
250        true
251    }
252
253    fn lex_identifier_or_keyword<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
254        let start = state.get_position();
255        let ch = match state.current() {
256            Some(c) => c,
257            None => return false,
258        };
259
260        if !ch.is_alphabetic() && ch != '_' {
261            return false;
262        }
263
264        state.advance(ch.len_utf8());
265        while let Some(c) = state.peek() {
266            if c.is_alphanumeric() || c == '_' {
267                state.advance(c.len_utf8());
268            }
269            else {
270                break;
271            }
272        }
273
274        let end = state.get_position();
275        let text = state.source().get_text_in(oak_core::Range { start, end }).to_uppercase();
276        let kind = match text.as_str() {
277            "SELECT" => SqlSyntaxKind::Select,
278            "FROM" => SqlSyntaxKind::From,
279            "WHERE" => SqlSyntaxKind::Where,
280            "INSERT" => SqlSyntaxKind::Insert,
281            "UPDATE" => SqlSyntaxKind::Update,
282            "DELETE" => SqlSyntaxKind::Delete,
283            "CREATE" => SqlSyntaxKind::Create,
284            "DROP" => SqlSyntaxKind::Drop,
285            "ALTER" => SqlSyntaxKind::Alter,
286            "TABLE" => SqlSyntaxKind::Table,
287            "INDEX" => SqlSyntaxKind::Index,
288            "INTO" => SqlSyntaxKind::Into,
289            "VALUES" => SqlSyntaxKind::Values,
290            "SET" => SqlSyntaxKind::Set,
291            "JOIN" => SqlSyntaxKind::Join,
292            "INNER" => SqlSyntaxKind::Inner,
293            "LEFT" => SqlSyntaxKind::Left,
294            "RIGHT" => SqlSyntaxKind::Right,
295            "FULL" => SqlSyntaxKind::Full,
296            "OUTER" => SqlSyntaxKind::Outer,
297            "ON" => SqlSyntaxKind::On,
298            "AND" => SqlSyntaxKind::And,
299            "OR" => SqlSyntaxKind::Or,
300            "NOT" => SqlSyntaxKind::Not,
301            "NULL" => SqlSyntaxKind::Null,
302            "TRUE" => SqlSyntaxKind::True,
303            "FALSE" => SqlSyntaxKind::False,
304            "AS" => SqlSyntaxKind::As,
305            "BY" => SqlSyntaxKind::By,
306            "ORDER" => SqlSyntaxKind::Order,
307            "ASC" => SqlSyntaxKind::Asc,
308            "DESC" => SqlSyntaxKind::Desc,
309            "GROUP" => SqlSyntaxKind::Group,
310            "HAVING" => SqlSyntaxKind::Having,
311            "LIMIT" => SqlSyntaxKind::Limit,
312            "OFFSET" => SqlSyntaxKind::Offset,
313            "UNION" => SqlSyntaxKind::Union,
314            "ALL" => SqlSyntaxKind::All,
315            "DISTINCT" => SqlSyntaxKind::Distinct,
316            "PRIMARY" => SqlSyntaxKind::Primary,
317            "KEY" => SqlSyntaxKind::Key,
318            "FOREIGN" => SqlSyntaxKind::Foreign,
319            "REFERENCES" => SqlSyntaxKind::References,
320            "DEFAULT" => SqlSyntaxKind::Default,
321            "UNIQUE" => SqlSyntaxKind::Unique,
322            "AUTO_INCREMENT" => SqlSyntaxKind::AutoIncrement,
323            "INT" => SqlSyntaxKind::Int,
324            "INTEGER" => SqlSyntaxKind::Integer,
325            "VARCHAR" => SqlSyntaxKind::Varchar,
326            "CHAR" => SqlSyntaxKind::Char,
327            "TEXT" => SqlSyntaxKind::Text,
328            "DATE" => SqlSyntaxKind::Date,
329            "TIME" => SqlSyntaxKind::Time,
330            "TIMESTAMP" => SqlSyntaxKind::Timestamp,
331            "DECIMAL" => SqlSyntaxKind::Decimal,
332            "FLOAT" => SqlSyntaxKind::Float,
333            "DOUBLE" => SqlSyntaxKind::Double,
334            "BOOLEAN" => SqlSyntaxKind::Boolean,
335            _ => SqlSyntaxKind::Identifier,
336        };
337
338        state.add_token(kind, start, end);
339        true
340    }
341
342    fn lex_operators<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
343        let start = state.get_position();
344
345        let ops = [
346            ("<=", SqlSyntaxKind::LessEqual),
347            (">=", SqlSyntaxKind::GreaterEqual),
348            ("<>", SqlSyntaxKind::NotEqual),
349            ("!=", SqlSyntaxKind::NotEqual),
350            ("=", SqlSyntaxKind::Equal),
351            ("<", SqlSyntaxKind::Less),
352            (">", SqlSyntaxKind::Greater),
353            ("+", SqlSyntaxKind::Plus),
354            ("-", SqlSyntaxKind::Minus),
355            ("*", SqlSyntaxKind::Star),
356            ("/", SqlSyntaxKind::Slash),
357            ("%", SqlSyntaxKind::Percent),
358        ];
359
360        for (op, kind) in ops {
361            if state.starts_with(op) {
362                state.advance(op.len());
363                state.add_token(kind, start, state.get_position());
364                return true;
365            }
366        }
367
368        false
369    }
370
371    fn lex_single_char_tokens<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
372        let start = state.get_position();
373        let ch = match state.current() {
374            Some(c) => c,
375            None => return false,
376        };
377
378        let kind = match ch {
379            '(' => SqlSyntaxKind::LeftParen,
380            ')' => SqlSyntaxKind::RightParen,
381            ',' => SqlSyntaxKind::Comma,
382            ';' => SqlSyntaxKind::Semicolon,
383            '.' => SqlSyntaxKind::Dot,
384            _ => return false,
385        };
386
387        state.advance(ch.len_utf8());
388        state.add_token(kind, start, state.get_position());
389        true
390    }
391}