// oak_sql/lexer/mod.rs
#![doc = include_str!("readme.md")]
use oak_core::Source;
pub mod token_type;
pub use token_type::SqlTokenType;

use crate::language::SqlLanguage;
use oak_core::{
    Lexer, LexerCache, LexerState, OakError, TextEdit,
    lexer::{LexOutput, WhitespaceConfig},
};
use std::sync::LazyLock;

/// Lexer state specialized for SQL over an arbitrary source `S`.
type State<'a, S> = LexerState<'a, S, SqlLanguage>;

/// Shared whitespace-scan configuration, built lazily once; Unicode whitespace is enabled.
static SQL_WHITESPACE: LazyLock<WhitespaceConfig> = LazyLock::new(|| WhitespaceConfig { unicode_whitespace: true });
16
/// Hand-written lexer for SQL.
///
/// Holds a borrow of the language configuration; the field is named `_config`
/// because no current code path reads it (kept for future configurability).
#[derive(Clone, Debug)]
pub struct SqlLexer<'config> {
    _config: &'config SqlLanguage,
}
21
22impl<'config> Lexer<SqlLanguage> for SqlLexer<'config> {
23    fn lex<'a, S: Source + ?Sized>(&self, text: &'a S, _edits: &[TextEdit], cache: &'a mut impl LexerCache<SqlLanguage>) -> LexOutput<SqlLanguage> {
24        let mut state = State::new(text);
25        let result = self.run(&mut state);
26        if result.is_ok() {
27            state.add_eof();
28        }
29        state.finish_with_cache(result, cache)
30    }
31}
32
impl<'config> SqlLexer<'config> {
    /// Creates a lexer bound to the given language configuration.
    /// The configuration is stored but not consulted yet (see `_config`).
    pub fn new(config: &'config SqlLanguage) -> Self {
        Self { _config: config }
    }

    /// Main dispatch loop: peeks one character and routes to the matching
    /// sub-lexer until end of input.
    ///
    /// `advance_if_dead_lock(safe_point)` runs after every iteration so that
    /// an iteration which consumed nothing (e.g. a handler returning `false`)
    /// cannot make the loop spin forever.
    fn run<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
        while state.not_at_end() {
            let safe_point = state.get_position();

            if let Some(ch) = state.peek() {
                match ch {
                    ' ' | '\t' => {
                        self.skip_whitespace(state);
                    }
                    '\n' | '\r' => {
                        self.lex_newline(state);
                    }
                    '-' => {
                        // `--` opens a line comment; a lone `-` is the minus operator.
                        if state.starts_with("--") {
                            self.skip_comment(state);
                        }
                        else {
                            self.lex_operators(state);
                        }
                    }
                    '/' => {
                        // `/*` opens a block comment; a lone `/` is the division operator.
                        if state.starts_with("/*") {
                            self.skip_comment(state);
                        }
                        else {
                            self.lex_operators(state);
                        }
                    }
                    '\'' | '"' => {
                        self.lex_string_literal(state);
                    }
                    '0'..='9' => {
                        self.lex_number_literal(state);
                    }
                    'a'..='z' | 'A'..='Z' | '_' => {
                        self.lex_identifier_or_keyword(state);
                    }
                    '<' | '>' | '!' | '=' | '+' | '*' | '%' => {
                        self.lex_operators(state);
                    }
                    '(' | ')' | ',' | ';' | '.' => {
                        self.lex_single_char_tokens(state);
                    }
                    _ => {
                        // No pattern matched: skip the current character and emit an Error token.
                        // NOTE(review): non-ASCII Unicode whitespace also lands here and becomes
                        // an Error token, even though SQL_WHITESPACE enables unicode_whitespace,
                        // because the whitespace arm above only matches ' ' and '\t' — confirm
                        // whether that is intended.
                        state.advance(ch.len_utf8());
                        state.add_token(SqlTokenType::Error, safe_point, state.get_position());
                    }
                }
            }

            state.advance_if_dead_lock(safe_point);
        }
        Ok(())
    }

    /// Handles a newline: emits one `Newline` token for `\n`, `\r`, or the
    /// combined `\r\n` pair. Returns `false` when the current character is
    /// not a newline (nothing is consumed in that case).
    fn lex_newline<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        let start_pos = state.get_position();

        if let Some('\n') = state.peek() {
            state.advance(1);
            state.add_token(SqlTokenType::Newline, start_pos, state.get_position());
            true
        }
        else if let Some('\r') = state.peek() {
            state.advance(1);
            // Fold an immediately following '\n' into the same token (CRLF).
            if let Some('\n') = state.peek() {
                state.advance(1);
            }
            state.add_token(SqlTokenType::Newline, start_pos, state.get_position());
            true
        }
        else {
            false
        }
    }

    /// Consumes a run of whitespace via the shared `SQL_WHITESPACE` config,
    /// emitting a `Whitespace` token. Always reports success.
    fn skip_whitespace<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        SQL_WHITESPACE.scan(state, SqlTokenType::Whitespace);
        true
    }

    /// Consumes a `--` line comment or a `/* ... */` block comment and emits
    /// a single `Comment` token covering it (delimiters included).
    /// Returns `false` if the cursor is not at a comment opener.
    /// An unterminated block comment is consumed to end of input and still
    /// emitted as a `Comment` token.
    fn skip_comment<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        let start = state.get_position();

        // Line comment: `--` ... up to (but not including) the newline.
        if state.starts_with("--") {
            state.advance(2);
            state.take_while(|ch| ch != '\n' && ch != '\r');
            state.add_token(SqlTokenType::Comment, start, state.get_position());
            return true;
        }

        // Block comment: `/* ... */` (no nesting — the first `*/` closes it).
        if state.starts_with("/*") {
            state.advance(2);
            while state.not_at_end() {
                if state.starts_with("*/") {
                    state.advance(2);
                    break;
                }
                if let Some(ch) = state.current() {
                    state.advance(ch.len_utf8());
                }
            }
            state.add_token(SqlTokenType::Comment, start, state.get_position());
            return true;
        }

        false
    }

    /// Lexes a single- or double-quoted string literal with backslash escapes
    /// (a `\` makes the next character literal, including the quote).
    ///
    /// An unterminated literal — end of input or a raw newline before the
    /// closing quote — is still emitted as a `StringLiteral` token covering
    /// what was consumed.
    ///
    /// NOTE(review): standard SQL escapes a quote by doubling it (`''`); this
    /// lexer only supports backslash escaping, so `'it''s'` lexes as two
    /// adjacent literals — confirm whether `''` doubling should be handled.
    fn lex_string_literal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        let start = state.get_position();
        if let Some(quote) = state.current() {
            if quote != '\'' && quote != '"' {
                return false;
            }
            state.advance(1);
            let mut escaped = false;
            while state.not_at_end() {
                let ch = match state.peek() {
                    Some(c) => c,
                    None => break,
                };

                if ch == quote && !escaped {
                    state.advance(1); // consume the closing quote
                    break;
                }
                state.advance(ch.len_utf8());
                if escaped {
                    // Previous char was '\'; this char is taken literally.
                    escaped = false;
                    continue;
                }
                if ch == '\\' {
                    escaped = true;
                    continue;
                }
                if ch == '\n' || ch == '\r' {
                    // Unescaped newline terminates (an unterminated) literal.
                    break;
                }
            }
            state.add_token(SqlTokenType::StringLiteral, start, state.get_position());
            true
        }
        else {
            false
        }
    }

    /// Lexes a numeric literal: integer digits with optional `_` separators,
    /// an optional fraction (`.` must be followed by a digit, so `1.` leaves
    /// the dot for the `Dot` token), and an optional exponent
    /// (`e`/`E` with optional sign). Emits `FloatLiteral` when a fraction or
    /// exponent was seen, otherwise `NumberLiteral`.
    fn lex_number_literal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        let start = state.get_position();
        let first = match state.current() {
            Some(c) => c,
            None => return false,
        };

        if !first.is_ascii_digit() {
            return false;
        }

        let mut is_float = false;
        state.advance(1);

        // Integer part.
        while let Some(c) = state.peek() {
            if c.is_ascii_digit() || c == '_' {
                state.advance(1);
            }
            else {
                break;
            }
        }

        // Fractional part — only if the '.' is followed by a digit.
        if state.peek() == Some('.') {
            let next = state.peek_next_n(1);
            if next.map(|c| c.is_ascii_digit()).unwrap_or(false) {
                is_float = true;
                state.advance(1); // consume '.'
                while let Some(c) = state.peek() {
                    if c.is_ascii_digit() || c == '_' {
                        state.advance(1);
                    }
                    else {
                        break;
                    }
                }
            }
        }

        // Exponent part — 'e'/'E' only counts if followed by a sign or digit,
        // so an identifier like `1e` is not swallowed as an exponent.
        if let Some(c) = state.peek() {
            if c == 'e' || c == 'E' {
                let next = state.peek_next_n(1);
                if next == Some('+') || next == Some('-') || next.map(|d| d.is_ascii_digit()).unwrap_or(false) {
                    is_float = true;
                    state.advance(1);
                    if let Some(sign) = state.peek() {
                        if sign == '+' || sign == '-' {
                            state.advance(1);
                        }
                    }
                    while let Some(d) = state.peek() {
                        if d.is_ascii_digit() || d == '_' {
                            state.advance(1);
                        }
                        else {
                            break;
                        }
                    }
                }
            }
        }

        let end = state.get_position();
        state.add_token(if is_float { SqlTokenType::FloatLiteral } else { SqlTokenType::NumberLiteral }, start, end);
        true
    }

    /// Lexes an identifier (`[A-Za-z_][alphanumeric_]*`, Unicode letters
    /// accepted via `is_alphabetic`/`is_alphanumeric`), then classifies it:
    /// the text is uppercased and matched against the keyword table below,
    /// making keyword recognition case-insensitive; anything else becomes
    /// `Identifier`.
    fn lex_identifier_or_keyword<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        let start = state.get_position();
        let ch = match state.current() {
            Some(c) => c,
            None => return false,
        };

        if !ch.is_alphabetic() && ch != '_' {
            return false;
        }

        state.advance(ch.len_utf8());
        while let Some(c) = state.peek() {
            if c.is_alphanumeric() || c == '_' {
                state.advance(c.len_utf8());
            }
            else {
                break;
            }
        }

        let end = state.get_position();
        let text = state.source().get_text_in(oak_core::Range { start, end }).to_uppercase();
        let kind = match text.as_str() {
            "SELECT" => SqlTokenType::Select,
            "FROM" => SqlTokenType::From,
            "WHERE" => SqlTokenType::Where,
            "INSERT" => SqlTokenType::Insert,
            "UPDATE" => SqlTokenType::Update,
            "DELETE" => SqlTokenType::Delete,
            "CREATE" => SqlTokenType::Create,
            "DROP" => SqlTokenType::Drop,
            "ALTER" => SqlTokenType::Alter,
            "TABLE" => SqlTokenType::Table,
            "INDEX" => SqlTokenType::Index,
            "INTO" => SqlTokenType::Into,
            "VALUES" => SqlTokenType::Values,
            "SET" => SqlTokenType::Set,
            "JOIN" => SqlTokenType::Join,
            "INNER" => SqlTokenType::Inner,
            "LEFT" => SqlTokenType::Left,
            "RIGHT" => SqlTokenType::Right,
            "FULL" => SqlTokenType::Full,
            "OUTER" => SqlTokenType::Outer,
            "ON" => SqlTokenType::On,
            "AND" => SqlTokenType::And,
            "OR" => SqlTokenType::Or,
            "NOT" => SqlTokenType::Not,
            "NULL" => SqlTokenType::Null,
            "TRUE" => SqlTokenType::True,
            "FALSE" => SqlTokenType::False,
            "AS" => SqlTokenType::As,
            "BY" => SqlTokenType::By,
            "ORDER" => SqlTokenType::Order,
            "ASC" => SqlTokenType::Asc,
            "DESC" => SqlTokenType::Desc,
            "GROUP" => SqlTokenType::Group,
            "HAVING" => SqlTokenType::Having,
            "LIMIT" => SqlTokenType::Limit,
            "OFFSET" => SqlTokenType::Offset,
            "UNION" => SqlTokenType::Union,
            "ALL" => SqlTokenType::All,
            "DISTINCT" => SqlTokenType::Distinct,
            "PRIMARY" => SqlTokenType::Primary,
            "KEY" => SqlTokenType::Key,
            "FOREIGN" => SqlTokenType::Foreign,
            "REFERENCES" => SqlTokenType::References,
            "DEFAULT" => SqlTokenType::Default,
            "UNIQUE" => SqlTokenType::Unique,
            "AUTO_INCREMENT" => SqlTokenType::AutoIncrement,
            "INT" => SqlTokenType::Int,
            "INTEGER" => SqlTokenType::Integer,
            "VARCHAR" => SqlTokenType::Varchar,
            "CHAR" => SqlTokenType::Char,
            "TEXT" => SqlTokenType::Text,
            "DATE" => SqlTokenType::Date,
            "TIME" => SqlTokenType::Time,
            "TIMESTAMP" => SqlTokenType::Timestamp,
            "DECIMAL" => SqlTokenType::Decimal,
            "FLOAT" => SqlTokenType::Float,
            "DOUBLE" => SqlTokenType::Double,
            "BOOLEAN" => SqlTokenType::Boolean,
            _ => SqlTokenType::Identifier,
        };

        state.add_token(kind, start, end);
        true
    }

    /// Lexes comparison/arithmetic operators. The table is ordered with
    /// two-character operators first so `<=`/`>=`/`<>`/`!=` win over their
    /// one-character prefixes (longest match). Returns `false` — consuming
    /// nothing — when no operator matches (e.g. a lone `!`); `run`'s
    /// dead-lock guard then forces progress.
    fn lex_operators<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        let start = state.get_position();

        let ops = [
            ("<=", SqlTokenType::LessEqual),
            (">=", SqlTokenType::GreaterEqual),
            ("<>", SqlTokenType::NotEqual),
            ("!=", SqlTokenType::NotEqual),
            ("=", SqlTokenType::Equal),
            ("<", SqlTokenType::Less),
            (">", SqlTokenType::Greater),
            ("+", SqlTokenType::Plus),
            ("-", SqlTokenType::Minus),
            ("*", SqlTokenType::Star),
            ("/", SqlTokenType::Slash),
            ("%", SqlTokenType::Percent),
        ];

        for (op, kind) in ops {
            if state.starts_with(op) {
                state.advance(op.len());
                state.add_token(kind, start, state.get_position());
                return true;
            }
        }

        false
    }

    /// Lexes single-character punctuation: parentheses, comma, semicolon,
    /// and dot. Returns `false` (consuming nothing) for any other character.
    fn lex_single_char_tokens<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        let start = state.get_position();
        let ch = match state.current() {
            Some(c) => c,
            None => return false,
        };

        let kind = match ch {
            '(' => SqlTokenType::LeftParen,
            ')' => SqlTokenType::RightParen,
            ',' => SqlTokenType::Comma,
            ';' => SqlTokenType::Semicolon,
            '.' => SqlTokenType::Dot,
            _ => return false,
        };

        state.advance(ch.len_utf8());
        state.add_token(kind, start, state.get_position());
        true
    }
}