1#![doc = include_str!("readme.md")]
2use oak_core::Source;
3pub mod token_type;
5pub use token_type::SqlTokenType;
6
7use crate::language::SqlLanguage;
8use oak_core::{
9 Lexer, LexerCache, LexerState, OakError, TextEdit,
10 lexer::{LexOutput, WhitespaceConfig},
11};
12use std::sync::LazyLock;
13
/// Lexer state specialised to [`SqlLanguage`], generic over the source type `S`.
pub(crate) type State<'a, S> = LexerState<'a, S, SqlLanguage>;

/// Whitespace scanning configuration used by [`SqlLexer`]; built lazily on first use
/// with `unicode_whitespace` enabled.
static SQL_WHITESPACE: LazyLock<WhitespaceConfig> = LazyLock::new(|| WhitespaceConfig { unicode_whitespace: true });
17
/// Lexer that turns SQL source text into a stream of [`SqlTokenType`] tokens.
///
/// The lexer borrows its [`SqlLanguage`] configuration for the `'config` lifetime.
/// Within this file the configuration is consulted for the `backtick_identifiers`
/// and `bracket_identifiers` switches, which decide whether `` ` `` and `[` open a
/// quoted identifier.
#[derive(Clone, Debug)]
pub struct SqlLexer<'config> {
    /// Language configuration consulted while scanning.
    config: &'config SqlLanguage,
}
34
35impl<'config> Lexer<SqlLanguage> for SqlLexer<'config> {
36 fn lex<'a, S: Source + ?Sized>(&self, text: &S, _edits: &[TextEdit], cache: &'a mut impl LexerCache<SqlLanguage>) -> LexOutput<SqlLanguage> {
37 let mut state = State::new(text);
38 let result = self.run(&mut state);
39 if result.is_ok() {
40 state.add_eof();
41 }
42 state.finish_with_cache(result, cache)
43 }
44}
45
46impl<'config> SqlLexer<'config> {
47 pub fn new(config: &'config SqlLanguage) -> Self {
49 Self { config }
50 }
51
52 fn run<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
53 while state.not_at_end() {
54 let safe_point = state.get_position();
55
56 if let Some(ch) = state.peek() {
57 match ch {
58 ' ' | '\t' => {
59 self.skip_whitespace(state);
60 }
61 '\n' | '\r' => {
62 self.lex_newline(state);
63 }
64 '-' => {
65 if state.starts_with("--") {
66 self.skip_comment(state);
67 }
68 else {
69 self.lex_operators(state);
70 }
71 }
72 '/' => {
73 if state.starts_with("/*") {
74 self.skip_comment(state);
75 }
76 else {
77 self.lex_operators(state);
78 }
79 }
80 '\'' | '"' => {
81 self.lex_string_literal(state);
82 }
83 '`' if self.config.backtick_identifiers => {
84 self.lex_quoted_identifier(state, '`');
85 }
86 '[' if self.config.bracket_identifiers => {
87 self.lex_bracket_identifier(state);
88 }
89 '0'..='9' => {
90 self.lex_number_literal(state);
91 }
92 'a'..='z' | 'A'..='Z' | '_' => {
93 self.lex_identifier_or_keyword(state);
94 }
95 '<' | '>' | '!' | '=' | '+' | '*' | '%' => {
96 self.lex_operators(state);
97 }
98 '(' | ')' | ',' | ';' | '.' | ':' | '[' | ']' => {
99 self.lex_single_char_tokens(state);
100 }
101 _ => {
102 state.advance(ch.len_utf8());
104 state.add_token(SqlTokenType::Error, safe_point, state.get_position());
105 }
106 }
107 }
108
109 state.advance_if_dead_lock(safe_point);
110 }
111 Ok(())
112 }
113
114 fn lex_newline<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
116 let start_pos = state.get_position();
117
118 if let Some('\n') = state.peek() {
119 state.advance(1);
120 state.add_token(SqlTokenType::Newline, start_pos, state.get_position());
121 true
122 }
123 else if let Some('\r') = state.peek() {
124 state.advance(1);
125 if let Some('\n') = state.peek() {
126 state.advance(1);
127 }
128 state.add_token(SqlTokenType::Newline, start_pos, state.get_position());
129 true
130 }
131 else {
132 false
133 }
134 }
135
136 fn skip_whitespace<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
137 SQL_WHITESPACE.scan(state, SqlTokenType::Whitespace);
138 true
139 }
140
141 fn skip_comment<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
142 let start = state.get_position();
143
144 if state.starts_with("--") {
146 state.advance(2);
147 state.take_while(|ch| ch != '\n' && ch != '\r');
148 state.add_token(SqlTokenType::Comment, start, state.get_position());
149 return true;
150 }
151
152 if state.starts_with("/*") {
154 state.advance(2);
155 while state.not_at_end() {
156 if state.starts_with("*/") {
157 state.advance(2);
158 break;
159 }
160 if let Some(ch) = state.current() {
161 state.advance(ch.len_utf8());
162 }
163 }
164 state.add_token(SqlTokenType::Comment, start, state.get_position());
165 return true;
166 }
167
168 false
169 }
170
171 fn lex_string_literal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
172 let start = state.get_position();
173 let quote = match state.current() {
174 Some(c) if c == '\'' || c == '"' => {
175 state.advance(c.len_utf8());
176 c
177 }
178 _ => return false,
179 };
180
181 while let Some(ch) = state.current() {
182 if ch == quote {
183 state.advance(ch.len_utf8());
184 if state.peek() == Some(quote) {
186 state.advance(quote.len_utf8());
187 continue;
188 }
189 break;
190 }
191 state.advance(ch.len_utf8());
192 }
193
194 state.add_token(SqlTokenType::StringLiteral, start, state.get_position());
195 true
196 }
197
198 fn lex_quoted_identifier<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>, quote: char) -> bool {
199 let start = state.get_position();
200 state.advance(quote.len_utf8());
201
202 while let Some(ch) = state.current() {
203 if ch == quote {
204 state.advance(ch.len_utf8());
205 break;
206 }
207 state.advance(ch.len_utf8());
208 }
209
210 state.add_token(SqlTokenType::Identifier_, start, state.get_position());
211 true
212 }
213
214 fn lex_bracket_identifier<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
215 let start = state.get_position();
216 state.advance(1); while let Some(ch) = state.current() {
219 if ch == ']' {
220 state.advance(1);
221 break;
222 }
223 state.advance(ch.len_utf8());
224 }
225
226 state.add_token(SqlTokenType::Identifier_, start, state.get_position());
227 true
228 }
229
230 fn lex_number_literal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
231 let start = state.get_position();
232 let first = match state.current() {
233 Some(c) => c,
234 None => return false,
235 };
236
237 if !first.is_ascii_digit() {
238 return false;
239 }
240
241 let mut is_float = false;
242 state.advance(1);
243
244 while let Some(c) = state.peek() {
246 if c.is_ascii_digit() || c == '_' {
247 state.advance(1);
248 }
249 else {
250 break;
251 }
252 }
253
254 if state.peek() == Some('.') {
256 let next = state.peek_next_n(1);
257 if next.map(|c| c.is_ascii_digit()).unwrap_or(false) {
258 is_float = true;
259 state.advance(1); while let Some(c) = state.peek() {
261 if c.is_ascii_digit() || c == '_' {
262 state.advance(1);
263 }
264 else {
265 break;
266 }
267 }
268 }
269 }
270
271 if let Some(c) = state.peek() {
273 if c == 'e' || c == 'E' {
274 let next = state.peek_next_n(1);
275 if next == Some('+') || next == Some('-') || next.map(|d| d.is_ascii_digit()).unwrap_or(false) {
276 is_float = true;
277 state.advance(1);
278 if let Some(sign) = state.peek() {
279 if sign == '+' || sign == '-' {
280 state.advance(1);
281 }
282 }
283 while let Some(d) = state.peek() {
284 if d.is_ascii_digit() || d == '_' {
285 state.advance(1);
286 }
287 else {
288 break;
289 }
290 }
291 }
292 }
293 }
294
295 let end = state.get_position();
296 state.add_token(if is_float { SqlTokenType::FloatLiteral } else { SqlTokenType::NumberLiteral }, start, end);
297 true
298 }
299
300 fn lex_identifier_or_keyword<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
301 let start = state.get_position();
302 let ch = match state.current() {
303 Some(c) => c,
304 None => return false,
305 };
306
307 if !ch.is_alphabetic() && ch != '_' {
308 return false;
309 }
310
311 state.advance(ch.len_utf8());
312 while let Some(c) = state.peek() {
313 if c.is_alphanumeric() || c == '_' {
314 state.advance(c.len_utf8());
315 }
316 else {
317 break;
318 }
319 }
320
321 let end = state.get_position();
322 let text = state.source().get_text_in(oak_core::Range { start, end }).to_uppercase();
323 let kind = match text.as_str() {
324 "SELECT" => SqlTokenType::Select,
325 "FROM" => SqlTokenType::From,
326 "WHERE" => SqlTokenType::Where,
327 "INSERT" => SqlTokenType::Insert,
328 "UPDATE" => SqlTokenType::Update,
329 "DELETE" => SqlTokenType::Delete,
330 "CREATE" => SqlTokenType::Create,
331 "DROP" => SqlTokenType::Drop,
332 "ALTER" => SqlTokenType::Alter,
333 "ADD" => SqlTokenType::Add,
334 "COLUMN" => SqlTokenType::Column,
335 "TABLE" => SqlTokenType::Table,
336 "VIEW" => SqlTokenType::View,
337 "INDEX" => SqlTokenType::Index,
338 "INTO" => SqlTokenType::Into,
339 "VALUES" => SqlTokenType::Values,
340 "SET" => SqlTokenType::Set,
341 "JOIN" => SqlTokenType::Join,
342 "INNER" => SqlTokenType::Inner,
343 "LEFT" => SqlTokenType::Left,
344 "RIGHT" => SqlTokenType::Right,
345 "FULL" => SqlTokenType::Full,
346 "OUTER" => SqlTokenType::Outer,
347 "ON" => SqlTokenType::On,
348 "AND" => SqlTokenType::And,
349 "OR" => SqlTokenType::Or,
350 "NOT" => SqlTokenType::Not,
351 "NULL" => SqlTokenType::Null,
352 "TRUE" => SqlTokenType::True,
353 "FALSE" => SqlTokenType::False,
354 "TRIGGER" => SqlTokenType::Trigger,
355 "AFTER" => SqlTokenType::After,
356 "DELIMITER" => SqlTokenType::Delimiter,
357 "FOR" => SqlTokenType::For,
358 "EACH" => SqlTokenType::Each,
359 "ROW" => SqlTokenType::Row,
360 "CHECK" => SqlTokenType::Check,
361 "BEGIN" => SqlTokenType::Begin,
362 "END" => SqlTokenType::End,
363 "IF" => SqlTokenType::If,
364 "EXISTS" => SqlTokenType::Exists,
365 "RENAME" => SqlTokenType::Rename,
366 "TO" => SqlTokenType::To,
367 "AS" => SqlTokenType::As,
368 "BY" => SqlTokenType::By,
369 "ORDER" => SqlTokenType::Order,
370 "ASC" => SqlTokenType::Asc,
371 "DESC" => SqlTokenType::Desc,
372 "GROUP" => SqlTokenType::Group,
373 "HAVING" => SqlTokenType::Having,
374 "LIMIT" => SqlTokenType::Limit,
375 "OFFSET" => SqlTokenType::Offset,
376 "UNION" => SqlTokenType::Union,
377 "ALL" => SqlTokenType::All,
378 "DISTINCT" => SqlTokenType::Distinct,
379 "PRIMARY" => SqlTokenType::Primary,
380 "KEY" => SqlTokenType::Key,
381 "FOREIGN" => SqlTokenType::Foreign,
382 "REFERENCES" => SqlTokenType::References,
383 "DEFAULT" => SqlTokenType::Default,
384 "UNIQUE" => SqlTokenType::Unique,
385 "AUTO_INCREMENT" | "AUTOINCREMENT" => SqlTokenType::AutoIncrement,
386 "INT" => SqlTokenType::Int,
387 "INTEGER" => SqlTokenType::Integer,
388 "VARCHAR" => SqlTokenType::Varchar,
389 "CHAR" => SqlTokenType::Char,
390 "TEXT" => SqlTokenType::Text,
391 "DATE" => SqlTokenType::Date,
392 "TIME" => SqlTokenType::Time,
393 "TIMESTAMP" => SqlTokenType::Timestamp,
394 "DECIMAL" => SqlTokenType::Decimal,
395 "FLOAT" => SqlTokenType::Float,
396 "DOUBLE" => SqlTokenType::Double,
397 "BOOLEAN" => SqlTokenType::Boolean,
398 "SERIAL" => SqlTokenType::Serial,
399 "BIGSERIAL" => SqlTokenType::BigSerial,
400 "CONFLICT" => SqlTokenType::Conflict,
401 "DO" => SqlTokenType::Do,
402 "NOTHING" => SqlTokenType::Nothing,
403 "RETURNING" => SqlTokenType::Returning,
404 "ILIKE" => SqlTokenType::Ilike,
405 "STRICT" => SqlTokenType::Strict,
406 "WITHOUT" => SqlTokenType::Without,
407 "ROWID" => SqlTokenType::Rowid,
408 "MAX" => SqlTokenType::Max,
409 "EXPLAIN" => SqlTokenType::Explain,
410 "PRAGMA" => SqlTokenType::Pragma,
411 "SHOW" => SqlTokenType::Show,
412 "DATABASE" => SqlTokenType::Database,
413 "SCHEMA" => SqlTokenType::Schema,
414 "VECTOR" => SqlTokenType::Vector,
415 _ => SqlTokenType::Identifier_,
416 };
417
418 state.add_token(kind, start, end);
419 true
420 }
421
422 fn lex_operators<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
423 let start = state.get_position();
424
425 let ops = [
426 ("::", SqlTokenType::DoubleColon),
427 ("||", SqlTokenType::Concat),
428 ("<=", SqlTokenType::LessEqual),
429 (">=", SqlTokenType::GreaterEqual),
430 ("<>", SqlTokenType::NotEqual),
431 ("!=", SqlTokenType::NotEqual),
432 ("=", SqlTokenType::Equal),
433 ("<", SqlTokenType::Less),
434 (">", SqlTokenType::Greater),
435 ("+", SqlTokenType::Plus),
436 ("-", SqlTokenType::Minus),
437 ("*", SqlTokenType::Star),
438 ("/", SqlTokenType::Slash),
439 ("%", SqlTokenType::Percent),
440 ];
441
442 for (op, kind) in ops {
443 if state.starts_with(op) {
444 state.advance(op.len());
445 state.add_token(kind, start, state.get_position());
446 return true;
447 }
448 }
449
450 false
451 }
452
453 fn lex_single_char_tokens<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
454 let start = state.get_position();
455 let ch = match state.current() {
456 Some(c) => c,
457 None => return false,
458 };
459
460 let kind = match ch {
461 '(' => SqlTokenType::LeftParen,
462 ')' => SqlTokenType::RightParen,
463 ',' => SqlTokenType::Comma,
464 ';' => SqlTokenType::Semicolon,
465 '.' => SqlTokenType::Dot,
466 ':' => SqlTokenType::Colon,
467 _ => return false,
468 };
469
470 state.advance(ch.len_utf8());
471 state.add_token(kind, start, state.get_position());
472 true
473 }
474}