// oak_sql/lexer/mod.rs
#![doc = include_str!("readme.md")]
use oak_core::Source;
/// Token types for SQL.
pub mod token_type;
pub use token_type::SqlTokenType;

use crate::language::SqlLanguage;
use oak_core::{
    Lexer, LexerCache, LexerState, OakError, TextEdit,
    lexer::{LexOutput, WhitespaceConfig},
};
use std::sync::LazyLock;

/// Convenience alias: the shared [`LexerState`] specialized to SQL.
pub(crate) type State<'a, S> = LexerState<'a, S, SqlLanguage>;

/// Whitespace scanner shared by every `SqlLexer`.
// NOTE(review): `unicode_whitespace: true` presumably makes the scanner accept
// non-ASCII whitespace as well — semantics live in `oak_core`; confirm there.
static SQL_WHITESPACE: LazyLock<WhitespaceConfig> = LazyLock::new(|| WhitespaceConfig { unicode_whitespace: true });

/// Lexer for SQL.
///
/// This lexer is responsible for breaking down SQL source text into a stream of
/// tokens. It handles different SQL dialects and supports incremental lexing
/// through the [`Lexer`] trait.
///
/// # Supported Features
///
/// - Case-insensitive keywords
/// - Multiple identifier quoting styles (double quotes, backticks, brackets)
/// - Various literal types (strings, numbers, booleans)
/// - Comments (line and block)
#[derive(Clone, Debug)]
pub struct SqlLexer<'config> {
    // Dialect configuration (e.g. `backtick_identifiers`, `bracket_identifiers`);
    // borrowed so one `SqlLanguage` can back many lexers.
    config: &'config SqlLanguage,
}

35impl<'config> Lexer<SqlLanguage> for SqlLexer<'config> {
36    fn lex<'a, S: Source + ?Sized>(&self, text: &S, _edits: &[TextEdit], cache: &'a mut impl LexerCache<SqlLanguage>) -> LexOutput<SqlLanguage> {
37        let mut state = State::new(text);
38        let result = self.run(&mut state);
39        if result.is_ok() {
40            state.add_eof();
41        }
42        state.finish_with_cache(result, cache)
43    }
44}
45
impl<'config> SqlLexer<'config> {
    /// Creates a new `SqlLexer` with the given configuration.
    ///
    /// The configuration controls dialect-specific behavior such as which
    /// identifier quoting styles are recognized.
    pub fn new(config: &'config SqlLanguage) -> Self {
        Self { config }
    }

    /// Main scan loop: dispatches on the current character until end of input.
    ///
    /// Each iteration records a `safe_point`; if a branch fails to consume any
    /// input, `advance_if_dead_lock` forces progress so the loop cannot spin.
    fn run<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
        while state.not_at_end() {
            let safe_point = state.get_position();

            if let Some(ch) = state.peek() {
                match ch {
                    ' ' | '\t' => {
                        self.skip_whitespace(state);
                    }
                    '\n' | '\r' => {
                        self.lex_newline(state);
                    }
                    // "--" starts a line comment; a lone '-' is the minus operator.
                    '-' => {
                        if state.starts_with("--") {
                            self.skip_comment(state);
                        }
                        else {
                            self.lex_operators(state);
                        }
                    }
                    // "/*" starts a block comment; a lone '/' is division.
                    '/' => {
                        if state.starts_with("/*") {
                            self.skip_comment(state);
                        }
                        else {
                            self.lex_operators(state);
                        }
                    }
                    '\'' | '"' => {
                        self.lex_string_literal(state);
                    }
                    // Dialect-gated identifier quoting styles.
                    '`' if self.config.backtick_identifiers => {
                        self.lex_quoted_identifier(state, '`');
                    }
                    '[' if self.config.bracket_identifiers => {
                        self.lex_bracket_identifier(state);
                    }
                    '0'..='9' => {
                        self.lex_number_literal(state);
                    }
                    'a'..='z' | 'A'..='Z' | '_' => {
                        self.lex_identifier_or_keyword(state);
                    }
                    '<' | '>' | '!' | '=' | '+' | '*' | '%' => {
                        self.lex_operators(state);
                    }
                    // NOTE(review): `lex_single_char_tokens` has no arms for '['
                    // or ']', so those two fall through it unconsumed (no token
                    // emitted) and only make progress via `advance_if_dead_lock`
                    // below — confirm whether bracket tokens were intended here.
                    '(' | ')' | ',' | ';' | '.' | ':' | '[' | ']' => {
                        self.lex_single_char_tokens(state);
                    }
                    _ => {
                        // If no patterns match, skip current character and add error token
                        state.advance(ch.len_utf8());
                        state.add_token(SqlTokenType::Error, safe_point, state.get_position());
                    }
                }
            }

            state.advance_if_dead_lock(safe_point);
        }
        Ok(())
    }

114    /// Handles newlines
115    fn lex_newline<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
116        let start_pos = state.get_position();
117
118        if let Some('\n') = state.peek() {
119            state.advance(1);
120            state.add_token(SqlTokenType::Newline, start_pos, state.get_position());
121            true
122        }
123        else if let Some('\r') = state.peek() {
124            state.advance(1);
125            if let Some('\n') = state.peek() {
126                state.advance(1);
127            }
128            state.add_token(SqlTokenType::Newline, start_pos, state.get_position());
129            true
130        }
131        else {
132            false
133        }
134    }
135
    /// Consumes a run of whitespace via the shared [`SQL_WHITESPACE`] scanner
    /// and emits it as a single `Whitespace` token.
    fn skip_whitespace<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        SQL_WHITESPACE.scan(state, SqlTokenType::Whitespace);
        true
    }

141    fn skip_comment<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
142        let start = state.get_position();
143
144        // Line comment: -- ... until newline
145        if state.starts_with("--") {
146            state.advance(2);
147            state.take_while(|ch| ch != '\n' && ch != '\r');
148            state.add_token(SqlTokenType::Comment, start, state.get_position());
149            return true;
150        }
151
152        // Block comment: /* ... */
153        if state.starts_with("/*") {
154            state.advance(2);
155            while state.not_at_end() {
156                if state.starts_with("*/") {
157                    state.advance(2);
158                    break;
159                }
160                if let Some(ch) = state.current() {
161                    state.advance(ch.len_utf8());
162                }
163            }
164            state.add_token(SqlTokenType::Comment, start, state.get_position());
165            return true;
166        }
167
168        false
169    }
170
171    fn lex_string_literal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
172        let start = state.get_position();
173        let quote = match state.current() {
174            Some(c) if c == '\'' || c == '"' => {
175                state.advance(c.len_utf8());
176                c
177            }
178            _ => return false,
179        };
180
181        while let Some(ch) = state.current() {
182            if ch == quote {
183                state.advance(ch.len_utf8());
184                // Handle escaped quotes if necessary (e.g. '' in SQL)
185                if state.peek() == Some(quote) {
186                    state.advance(quote.len_utf8());
187                    continue;
188                }
189                break;
190            }
191            state.advance(ch.len_utf8());
192        }
193
194        state.add_token(SqlTokenType::StringLiteral, start, state.get_position());
195        true
196    }
197
198    fn lex_quoted_identifier<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>, quote: char) -> bool {
199        let start = state.get_position();
200        state.advance(quote.len_utf8());
201
202        while let Some(ch) = state.current() {
203            if ch == quote {
204                state.advance(ch.len_utf8());
205                break;
206            }
207            state.advance(ch.len_utf8());
208        }
209
210        state.add_token(SqlTokenType::Identifier_, start, state.get_position());
211        true
212    }
213
214    fn lex_bracket_identifier<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
215        let start = state.get_position();
216        state.advance(1); // '['
217
218        while let Some(ch) = state.current() {
219            if ch == ']' {
220                state.advance(1);
221                break;
222            }
223            state.advance(ch.len_utf8());
224        }
225
226        state.add_token(SqlTokenType::Identifier_, start, state.get_position());
227        true
228    }
229
230    fn lex_number_literal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
231        let start = state.get_position();
232        let first = match state.current() {
233            Some(c) => c,
234            None => return false,
235        };
236
237        if !first.is_ascii_digit() {
238            return false;
239        }
240
241        let mut is_float = false;
242        state.advance(1);
243
244        // Integer part
245        while let Some(c) = state.peek() {
246            if c.is_ascii_digit() || c == '_' {
247                state.advance(1);
248            }
249            else {
250                break;
251            }
252        }
253
254        // Decimal part
255        if state.peek() == Some('.') {
256            let next = state.peek_next_n(1);
257            if next.map(|c| c.is_ascii_digit()).unwrap_or(false) {
258                is_float = true;
259                state.advance(1); // consume '.'
260                while let Some(c) = state.peek() {
261                    if c.is_ascii_digit() || c == '_' {
262                        state.advance(1);
263                    }
264                    else {
265                        break;
266                    }
267                }
268            }
269        }
270
271        // Exponent part
272        if let Some(c) = state.peek() {
273            if c == 'e' || c == 'E' {
274                let next = state.peek_next_n(1);
275                if next == Some('+') || next == Some('-') || next.map(|d| d.is_ascii_digit()).unwrap_or(false) {
276                    is_float = true;
277                    state.advance(1);
278                    if let Some(sign) = state.peek() {
279                        if sign == '+' || sign == '-' {
280                            state.advance(1);
281                        }
282                    }
283                    while let Some(d) = state.peek() {
284                        if d.is_ascii_digit() || d == '_' {
285                            state.advance(1);
286                        }
287                        else {
288                            break;
289                        }
290                    }
291                }
292            }
293        }
294
295        let end = state.get_position();
296        state.add_token(if is_float { SqlTokenType::FloatLiteral } else { SqlTokenType::NumberLiteral }, start, end);
297        true
298    }
299
    /// Scans an identifier or keyword starting at an alphabetic char or '_'.
    ///
    /// The tail accepts alphanumerics (Unicode, via `is_alphanumeric`) and
    /// '_'. The consumed text is uppercased (one `String` allocation per
    /// token) so keyword matching is case-insensitive; anything not in the
    /// table below is emitted as `Identifier_`.
    fn lex_identifier_or_keyword<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        let start = state.get_position();
        let ch = match state.current() {
            Some(c) => c,
            None => return false,
        };

        if !ch.is_alphabetic() && ch != '_' {
            return false;
        }

        state.advance(ch.len_utf8());
        while let Some(c) = state.peek() {
            if c.is_alphanumeric() || c == '_' {
                state.advance(c.len_utf8());
            }
            else {
                break;
            }
        }

        let end = state.get_position();
        let text = state.source().get_text_in(oak_core::Range { start, end }).to_uppercase();
        let kind = match text.as_str() {
            // Statements and clauses
            "SELECT" => SqlTokenType::Select,
            "FROM" => SqlTokenType::From,
            "WHERE" => SqlTokenType::Where,
            "INSERT" => SqlTokenType::Insert,
            "UPDATE" => SqlTokenType::Update,
            "DELETE" => SqlTokenType::Delete,
            "CREATE" => SqlTokenType::Create,
            "DROP" => SqlTokenType::Drop,
            "ALTER" => SqlTokenType::Alter,
            "ADD" => SqlTokenType::Add,
            "COLUMN" => SqlTokenType::Column,
            "TABLE" => SqlTokenType::Table,
            "VIEW" => SqlTokenType::View,
            "INDEX" => SqlTokenType::Index,
            "INTO" => SqlTokenType::Into,
            "VALUES" => SqlTokenType::Values,
            "SET" => SqlTokenType::Set,
            // Joins
            "JOIN" => SqlTokenType::Join,
            "INNER" => SqlTokenType::Inner,
            "LEFT" => SqlTokenType::Left,
            "RIGHT" => SqlTokenType::Right,
            "FULL" => SqlTokenType::Full,
            "OUTER" => SqlTokenType::Outer,
            "ON" => SqlTokenType::On,
            // Logical operators and literals
            "AND" => SqlTokenType::And,
            "OR" => SqlTokenType::Or,
            "NOT" => SqlTokenType::Not,
            "NULL" => SqlTokenType::Null,
            "TRUE" => SqlTokenType::True,
            "FALSE" => SqlTokenType::False,
            // Triggers and control flow
            "TRIGGER" => SqlTokenType::Trigger,
            "AFTER" => SqlTokenType::After,
            "DELIMITER" => SqlTokenType::Delimiter,
            "FOR" => SqlTokenType::For,
            "EACH" => SqlTokenType::Each,
            "ROW" => SqlTokenType::Row,
            "CHECK" => SqlTokenType::Check,
            "BEGIN" => SqlTokenType::Begin,
            "END" => SqlTokenType::End,
            "IF" => SqlTokenType::If,
            "EXISTS" => SqlTokenType::Exists,
            "RENAME" => SqlTokenType::Rename,
            "TO" => SqlTokenType::To,
            "AS" => SqlTokenType::As,
            // Ordering, grouping, limiting
            "BY" => SqlTokenType::By,
            "ORDER" => SqlTokenType::Order,
            "ASC" => SqlTokenType::Asc,
            "DESC" => SqlTokenType::Desc,
            "GROUP" => SqlTokenType::Group,
            "HAVING" => SqlTokenType::Having,
            "LIMIT" => SqlTokenType::Limit,
            "OFFSET" => SqlTokenType::Offset,
            "UNION" => SqlTokenType::Union,
            "ALL" => SqlTokenType::All,
            "DISTINCT" => SqlTokenType::Distinct,
            // Constraints
            "PRIMARY" => SqlTokenType::Primary,
            "KEY" => SqlTokenType::Key,
            "FOREIGN" => SqlTokenType::Foreign,
            "REFERENCES" => SqlTokenType::References,
            "DEFAULT" => SqlTokenType::Default,
            "UNIQUE" => SqlTokenType::Unique,
            "AUTO_INCREMENT" | "AUTOINCREMENT" => SqlTokenType::AutoIncrement,
            // Type names
            "INT" => SqlTokenType::Int,
            "INTEGER" => SqlTokenType::Integer,
            "VARCHAR" => SqlTokenType::Varchar,
            "CHAR" => SqlTokenType::Char,
            "TEXT" => SqlTokenType::Text,
            "DATE" => SqlTokenType::Date,
            "TIME" => SqlTokenType::Time,
            "TIMESTAMP" => SqlTokenType::Timestamp,
            "DECIMAL" => SqlTokenType::Decimal,
            "FLOAT" => SqlTokenType::Float,
            "DOUBLE" => SqlTokenType::Double,
            "BOOLEAN" => SqlTokenType::Boolean,
            "SERIAL" => SqlTokenType::Serial,
            "BIGSERIAL" => SqlTokenType::BigSerial,
            // Dialect extensions (PostgreSQL / SQLite / MySQL)
            "CONFLICT" => SqlTokenType::Conflict,
            "DO" => SqlTokenType::Do,
            "NOTHING" => SqlTokenType::Nothing,
            "RETURNING" => SqlTokenType::Returning,
            "ILIKE" => SqlTokenType::Ilike,
            "STRICT" => SqlTokenType::Strict,
            "WITHOUT" => SqlTokenType::Without,
            "ROWID" => SqlTokenType::Rowid,
            "MAX" => SqlTokenType::Max,
            "EXPLAIN" => SqlTokenType::Explain,
            "PRAGMA" => SqlTokenType::Pragma,
            "SHOW" => SqlTokenType::Show,
            "DATABASE" => SqlTokenType::Database,
            "SCHEMA" => SqlTokenType::Schema,
            "VECTOR" => SqlTokenType::Vector,
            _ => SqlTokenType::Identifier_,
        };

        state.add_token(kind, start, end);
        true
    }

422    fn lex_operators<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
423        let start = state.get_position();
424
425        let ops = [
426            ("::", SqlTokenType::DoubleColon),
427            ("||", SqlTokenType::Concat),
428            ("<=", SqlTokenType::LessEqual),
429            (">=", SqlTokenType::GreaterEqual),
430            ("<>", SqlTokenType::NotEqual),
431            ("!=", SqlTokenType::NotEqual),
432            ("=", SqlTokenType::Equal),
433            ("<", SqlTokenType::Less),
434            (">", SqlTokenType::Greater),
435            ("+", SqlTokenType::Plus),
436            ("-", SqlTokenType::Minus),
437            ("*", SqlTokenType::Star),
438            ("/", SqlTokenType::Slash),
439            ("%", SqlTokenType::Percent),
440        ];
441
442        for (op, kind) in ops {
443            if state.starts_with(op) {
444                state.advance(op.len());
445                state.add_token(kind, start, state.get_position());
446                return true;
447            }
448        }
449
450        false
451    }
452
    /// Emits a single-character punctuation token.
    ///
    /// NOTE(review): `run` also routes '[' and ']' here (when bracket
    /// identifiers are disabled), but neither has a match arm below, so both
    /// return `false` without consuming input — confirm whether dedicated
    /// bracket token variants were intended.
    fn lex_single_char_tokens<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
        let start = state.get_position();
        let ch = match state.current() {
            Some(c) => c,
            None => return false,
        };

        let kind = match ch {
            '(' => SqlTokenType::LeftParen,
            ')' => SqlTokenType::RightParen,
            ',' => SqlTokenType::Comma,
            ';' => SqlTokenType::Semicolon,
            '.' => SqlTokenType::Dot,
            ':' => SqlTokenType::Colon,
            // Anything else (including '[' / ']') is left unconsumed.
            _ => return false,
        };

        state.advance(ch.len_utf8());
        state.add_token(kind, start, state.get_position());
        true
    }
}