oak_sql/lexer/mod.rs

1use crate::{kind::SqlSyntaxKind, language::SqlLanguage};
2use oak_core::{
3    Lexer, LexerCache, LexerState, OakError, TextEdit,
4    lexer::{LexOutput, WhitespaceConfig},
5    source::Source,
6};
7use std::sync::LazyLock;
8
/// Shorthand for the generic lexer state specialized to the SQL language.
type State<'a, S> = LexerState<'a, S, SqlLanguage>;

/// Whitespace scanner shared by all lexer instances, initialized on first use.
/// `unicode_whitespace: true` presumably lets it accept Unicode whitespace beyond ASCII
/// space/tab — NOTE(review): confirm against `oak_core::lexer::WhitespaceConfig::scan`.
static SQL_WHITESPACE: LazyLock<WhitespaceConfig> = LazyLock::new(|| WhitespaceConfig { unicode_whitespace: true });
12
/// Hand-written lexer that turns SQL source text into `SqlSyntaxKind` tokens.
///
/// It borrows the [`SqlLanguage`] configuration it was created with; the lexing
/// methods do not currently read it (hence the leading underscore on the field).
#[derive(Clone)]
pub struct SqlLexer<'config> {
    // Language configuration supplied to `SqlLexer::new`; not consulted while lexing.
    _config: &'config SqlLanguage,
}
17
18impl<'config> Lexer<SqlLanguage> for SqlLexer<'config> {
19    fn lex<'a, S: Source + ?Sized>(&self, text: &'a S, _edits: &[TextEdit], cache: &'a mut impl LexerCache<SqlLanguage>) -> LexOutput<SqlLanguage> {
20        let mut state = State::new(text);
21        let result = self.run(&mut state);
22        if result.is_ok() {
23            state.add_eof();
24        }
25        state.finish_with_cache(result, cache)
26    }
27}
28
29impl<'config> SqlLexer<'config> {
30    pub fn new(config: &'config SqlLanguage) -> Self {
31        Self { _config: config }
32    }
33
34    fn run<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
35        while state.not_at_end() {
36            let safe_point = state.get_position();
37
38            if let Some(ch) = state.peek() {
39                match ch {
40                    ' ' | '\t' => {
41                        self.skip_whitespace(state);
42                    }
43                    '\n' | '\r' => {
44                        self.lex_newline(state);
45                    }
46                    '-' => {
47                        if state.starts_with("--") {
48                            self.skip_comment(state);
49                        }
50                        else {
51                            self.lex_operators(state);
52                        }
53                    }
54                    '/' => {
55                        if state.starts_with("/*") {
56                            self.skip_comment(state);
57                        }
58                        else {
59                            self.lex_operators(state);
60                        }
61                    }
62                    '\'' | '"' => {
63                        self.lex_string_literal(state);
64                    }
65                    '0'..='9' => {
66                        self.lex_number_literal(state);
67                    }
68                    'a'..='z' | 'A'..='Z' | '_' => {
69                        self.lex_identifier_or_keyword(state);
70                    }
71                    '<' | '>' | '!' | '=' | '+' | '*' | '%' => {
72                        self.lex_operators(state);
73                    }
74                    '(' | ')' | ',' | ';' | '.' => {
75                        self.lex_single_char_tokens(state);
76                    }
77                    _ => {
78                        // 如果没有匹配任何模式,跳过当前字符并添加错误 token
79                        state.advance(ch.len_utf8());
80                        state.add_token(SqlSyntaxKind::Error, safe_point, state.get_position());
81                    }
82                }
83            }
84
85            state.advance_if_dead_lock(safe_point);
86        }
87        Ok(())
88    }
89
90    /// 处理换行
91    fn lex_newline<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
92        let start_pos = state.get_position();
93
94        if let Some('\n') = state.peek() {
95            state.advance(1);
96            state.add_token(SqlSyntaxKind::Newline, start_pos, state.get_position());
97            true
98        }
99        else if let Some('\r') = state.peek() {
100            state.advance(1);
101            if let Some('\n') = state.peek() {
102                state.advance(1);
103            }
104            state.add_token(SqlSyntaxKind::Newline, start_pos, state.get_position());
105            true
106        }
107        else {
108            false
109        }
110    }
111
112    fn skip_whitespace<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
113        SQL_WHITESPACE.scan(state, SqlSyntaxKind::Whitespace)
114    }
115
116    fn skip_comment<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
117        let start = state.get_position();
118
119        // 行注释: -- ... 直到换行
120        if state.starts_with("--") {
121            state.advance(2);
122            state.take_while(|ch| ch != '\n' && ch != '\r');
123            state.add_token(SqlSyntaxKind::Comment, start, state.get_position());
124            return true;
125        }
126
127        // 块注释: /* ... */
128        if state.starts_with("/*") {
129            state.advance(2);
130            while state.not_at_end() {
131                if state.starts_with("*/") {
132                    state.advance(2);
133                    break;
134                }
135                if let Some(ch) = state.current() {
136                    state.advance(ch.len_utf8());
137                }
138            }
139            state.add_token(SqlSyntaxKind::Comment, start, state.get_position());
140            return true;
141        }
142
143        false
144    }
145
146    fn lex_string_literal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
147        let start = state.get_position();
148        if let Some(quote) = state.current() {
149            if quote != '\'' && quote != '"' {
150                return false;
151            }
152            state.advance(1);
153            let mut escaped = false;
154            while state.not_at_end() {
155                let ch = match state.peek() {
156                    Some(c) => c,
157                    None => break,
158                };
159
160                if ch == quote && !escaped {
161                    state.advance(1); // 消费结束引号
162                    break;
163                }
164                state.advance(ch.len_utf8());
165                if escaped {
166                    escaped = false;
167                    continue;
168                }
169                if ch == '\\' {
170                    escaped = true;
171                    continue;
172                }
173                if ch == '\n' || ch == '\r' {
174                    break;
175                }
176            }
177            state.add_token(SqlSyntaxKind::StringLiteral, start, state.get_position());
178            return true;
179        }
180        false
181    }
182
183    fn lex_number_literal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
184        let start = state.get_position();
185        let first = match state.current() {
186            Some(c) => c,
187            None => return false,
188        };
189
190        if !first.is_ascii_digit() {
191            return false;
192        }
193
194        let mut is_float = false;
195        state.advance(1);
196
197        // 整数部分
198        while let Some(c) = state.peek() {
199            if c.is_ascii_digit() || c == '_' {
200                state.advance(1);
201            }
202            else {
203                break;
204            }
205        }
206
207        // 小数部分
208        if state.peek() == Some('.') {
209            let next = state.peek_next_n(1);
210            if next.map(|c| c.is_ascii_digit()).unwrap_or(false) {
211                is_float = true;
212                state.advance(1); // 消费 '.'
213                while let Some(c) = state.peek() {
214                    if c.is_ascii_digit() || c == '_' {
215                        state.advance(1);
216                    }
217                    else {
218                        break;
219                    }
220                }
221            }
222        }
223
224        // 指数部分
225        if let Some(c) = state.peek() {
226            if c == 'e' || c == 'E' {
227                let next = state.peek_next_n(1);
228                if next == Some('+') || next == Some('-') || next.map(|d| d.is_ascii_digit()).unwrap_or(false) {
229                    is_float = true;
230                    state.advance(1);
231                    if let Some(sign) = state.peek() {
232                        if sign == '+' || sign == '-' {
233                            state.advance(1);
234                        }
235                    }
236                    while let Some(d) = state.peek() {
237                        if d.is_ascii_digit() || d == '_' {
238                            state.advance(1);
239                        }
240                        else {
241                            break;
242                        }
243                    }
244                }
245            }
246        }
247
248        let end = state.get_position();
249        state.add_token(if is_float { SqlSyntaxKind::FloatLiteral } else { SqlSyntaxKind::NumberLiteral }, start, end);
250        true
251    }
252
253    fn lex_identifier_or_keyword<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
254        let start = state.get_position();
255        let ch = match state.current() {
256            Some(c) => c,
257            None => return false,
258        };
259
260        if !ch.is_alphabetic() && ch != '_' {
261            return false;
262        }
263
264        state.advance(ch.len_utf8());
265        while let Some(c) = state.peek() {
266            if c.is_alphanumeric() || c == '_' {
267                state.advance(c.len_utf8());
268            }
269            else {
270                break;
271            }
272        }
273
274        let end = state.get_position();
275        let text = state.source().get_text_in(oak_core::Range { start, end }).to_uppercase();
276        let kind = match text.as_str() {
277            "SELECT" => SqlSyntaxKind::Select,
278            "FROM" => SqlSyntaxKind::From,
279            "WHERE" => SqlSyntaxKind::Where,
280            "INSERT" => SqlSyntaxKind::Insert,
281            "UPDATE" => SqlSyntaxKind::Update,
282            "DELETE" => SqlSyntaxKind::Delete,
283            "CREATE" => SqlSyntaxKind::Create,
284            "DROP" => SqlSyntaxKind::Drop,
285            "ALTER" => SqlSyntaxKind::Alter,
286            "TABLE" => SqlSyntaxKind::Table,
287            "INDEX" => SqlSyntaxKind::Index,
288            "INTO" => SqlSyntaxKind::Into,
289            "VALUES" => SqlSyntaxKind::Values,
290            "SET" => SqlSyntaxKind::Set,
291            "JOIN" => SqlSyntaxKind::Join,
292            "ON" => SqlSyntaxKind::On,
293            "AND" => SqlSyntaxKind::And,
294            "OR" => SqlSyntaxKind::Or,
295            "NOT" => SqlSyntaxKind::Not,
296            "NULL" => SqlSyntaxKind::Null,
297            "TRUE" => SqlSyntaxKind::True,
298            "FALSE" => SqlSyntaxKind::False,
299            "AS" => SqlSyntaxKind::As,
300            "BY" => SqlSyntaxKind::By,
301            "ORDER" => SqlSyntaxKind::Order,
302            "GROUP" => SqlSyntaxKind::Group,
303            "HAVING" => SqlSyntaxKind::Having,
304            "LIMIT" => SqlSyntaxKind::Limit,
305            "OFFSET" => SqlSyntaxKind::Offset,
306            "UNION" => SqlSyntaxKind::Union,
307            "ALL" => SqlSyntaxKind::All,
308            "DISTINCT" => SqlSyntaxKind::Distinct,
309            "PRIMARY" => SqlSyntaxKind::Primary,
310            "KEY" => SqlSyntaxKind::Key,
311            "FOREIGN" => SqlSyntaxKind::Foreign,
312            "REFERENCES" => SqlSyntaxKind::References,
313            "DEFAULT" => SqlSyntaxKind::Default,
314            "UNIQUE" => SqlSyntaxKind::Unique,
315            "AUTO_INCREMENT" => SqlSyntaxKind::AutoIncrement,
316            "INT" => SqlSyntaxKind::Int,
317            "INTEGER" => SqlSyntaxKind::Integer,
318            "VARCHAR" => SqlSyntaxKind::Varchar,
319            "CHAR" => SqlSyntaxKind::Char,
320            "TEXT" => SqlSyntaxKind::Text,
321            "DATE" => SqlSyntaxKind::Date,
322            "TIME" => SqlSyntaxKind::Time,
323            "TIMESTAMP" => SqlSyntaxKind::Timestamp,
324            "DECIMAL" => SqlSyntaxKind::Decimal,
325            "FLOAT" => SqlSyntaxKind::Float,
326            "DOUBLE" => SqlSyntaxKind::Double,
327            "BOOLEAN" => SqlSyntaxKind::Boolean,
328            _ => SqlSyntaxKind::Identifier,
329        };
330
331        state.add_token(kind, start, end);
332        true
333    }
334
335    fn lex_operators<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
336        let start = state.get_position();
337
338        let ops = [
339            ("<=", SqlSyntaxKind::LessEqual),
340            (">=", SqlSyntaxKind::GreaterEqual),
341            ("<>", SqlSyntaxKind::NotEqual),
342            ("!=", SqlSyntaxKind::NotEqual),
343            ("=", SqlSyntaxKind::Equal),
344            ("<", SqlSyntaxKind::Less),
345            (">", SqlSyntaxKind::Greater),
346            ("+", SqlSyntaxKind::Plus),
347            ("-", SqlSyntaxKind::Minus),
348            ("*", SqlSyntaxKind::Star),
349            ("/", SqlSyntaxKind::Slash),
350            ("%", SqlSyntaxKind::Percent),
351        ];
352
353        for (op, kind) in ops {
354            if state.starts_with(op) {
355                state.advance(op.len());
356                state.add_token(kind, start, state.get_position());
357                return true;
358            }
359        }
360
361        false
362    }
363
364    fn lex_single_char_tokens<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
365        let start = state.get_position();
366        let ch = match state.current() {
367            Some(c) => c,
368            None => return false,
369        };
370
371        let kind = match ch {
372            '(' => SqlSyntaxKind::LeftParen,
373            ')' => SqlSyntaxKind::RightParen,
374            ',' => SqlSyntaxKind::Comma,
375            ';' => SqlSyntaxKind::Semicolon,
376            '.' => SqlSyntaxKind::Dot,
377            _ => return false,
378        };
379
380        state.advance(ch.len_utf8());
381        state.add_token(kind, start, state.get_position());
382        true
383    }
384}