1use crate::{kind::SqlSyntaxKind, language::SqlLanguage};
2use oak_core::{
3 IncrementalCache, Lexer, LexerState, OakError,
4 lexer::{CommentLine, LexOutput, StringConfig, WhitespaceConfig},
5 source::Source,
6};
7use std::sync::LazyLock;
8
/// Lexer state specialised to [`SqlLanguage`] for a source of type `S`.
type State<S> = LexerState<S, SqlLanguage>;

/// Whitespace scanner settings with `unicode_whitespace` enabled; used by `SqlLexer::skip_whitespace`.
static SQL_WHITESPACE: LazyLock<WhitespaceConfig> = LazyLock::new(|| WhitespaceConfig { unicode_whitespace: true });
/// Line-comment settings with `--` as the only marker.
///
/// NOTE(review): not referenced by the lexer methods visible in this file, which scan comments by hand — confirm it is still needed.
static SQL_COMMENT: LazyLock<CommentLine> = LazyLock::new(|| CommentLine { line_markers: &["--"] });
/// String settings: `"` and `'` as quotes, backslash as the escape character.
///
/// NOTE(review): not referenced by the lexer methods visible in this file, which scan strings by hand — confirm it is still needed.
static SQL_STRING: LazyLock<StringConfig> = LazyLock::new(|| StringConfig { quotes: &['"', '\''], escape: Some('\\') });
14
/// Lexer for SQL source text that borrows a [`SqlLanguage`] configuration.
#[derive(Clone)]
pub struct SqlLexer<'config> {
    /// Language configuration supplied to [`SqlLexer::new`].
    // NOTE(review): stored but not read by any method visible in this file.
    config: &'config SqlLanguage,
}
19
20impl<'config> Lexer<SqlLanguage> for SqlLexer<'config> {
21 fn lex_incremental(
22 &self,
23 source: impl Source,
24 changed: usize,
25 cache: IncrementalCache<SqlLanguage>,
26 ) -> LexOutput<SqlLanguage> {
27 let mut state = LexerState::new_with_cache(source, changed, cache);
28 let result = self.run(&mut state);
29 state.finish(result)
30 }
31}
32
33impl<'config> SqlLexer<'config> {
34 pub fn new(config: &'config SqlLanguage) -> Self {
35 Self { config }
36 }
37
38 fn run<S: Source>(&self, state: &mut State<S>) -> Result<(), OakError> {
39 while state.not_at_end() {
40 let safe_point = state.get_position();
41
42 if self.skip_whitespace(state) {
43 continue;
44 }
45
46 if self.lex_newline(state) {
47 continue;
48 }
49
50 if self.skip_comment(state) {
51 continue;
52 }
53
54 if self.lex_string_literal(state) {
55 continue;
56 }
57
58 if self.lex_number_literal(state) {
59 continue;
60 }
61
62 if self.lex_identifier_or_keyword(state) {
63 continue;
64 }
65
66 if self.lex_operators(state) {
67 continue;
68 }
69
70 if self.lex_single_char_tokens(state) {
71 continue;
72 }
73
74 if let Some(ch) = state.peek() {
76 state.advance(ch.len_utf8());
77 state.add_token(SqlSyntaxKind::Error, safe_point, state.get_position());
78 }
79 }
80
81 let eof_pos = state.get_position();
83 state.add_token(SqlSyntaxKind::Eof, eof_pos, eof_pos);
84 Ok(())
85 }
86
87 fn lex_newline<S: Source>(&self, state: &mut State<S>) -> bool {
89 let start_pos = state.get_position();
90
91 if let Some('\n') = state.peek() {
92 state.advance(1);
93 state.add_token(SqlSyntaxKind::Newline, start_pos, state.get_position());
94 true
95 }
96 else if let Some('\r') = state.peek() {
97 state.advance(1);
98 if let Some('\n') = state.peek() {
99 state.advance(1);
100 }
101 state.add_token(SqlSyntaxKind::Newline, start_pos, state.get_position());
102 true
103 }
104 else {
105 false
106 }
107 }
108
109 fn skip_whitespace<S: Source>(&self, state: &mut State<S>) -> bool {
110 match SQL_WHITESPACE.scan(state.rest(), state.get_position(), SqlSyntaxKind::Whitespace) {
111 Some(token) => {
112 state.advance_with(token);
113 true
114 }
115 None => false,
116 }
117 }
118
119 fn skip_comment<S: Source>(&self, state: &mut State<S>) -> bool {
120 let start = state.get_position();
121 let rest = state.rest();
122
123 if rest.starts_with("--") {
125 state.advance(2);
126 while let Some(ch) = state.peek() {
127 if ch == '\n' || ch == '\r' {
128 break;
129 }
130 state.advance(ch.len_utf8());
131 }
132 state.add_token(SqlSyntaxKind::Comment, start, state.get_position());
133 return true;
134 }
135
136 if rest.starts_with("/*") {
138 state.advance(2);
139 while let Some(ch) = state.peek() {
140 if ch == '*' && state.peek_next_n(1) == Some('/') {
141 state.advance(2);
142 break;
143 }
144 state.advance(ch.len_utf8());
145 }
146 state.add_token(SqlSyntaxKind::Comment, start, state.get_position());
147 return true;
148 }
149
150 false
151 }
152
153 fn lex_string_literal<S: Source>(&self, state: &mut State<S>) -> bool {
154 let start = state.get_position();
155 let ch = match state.current() {
156 Some(c) => c,
157 None => return false,
158 };
159
160 if ch == '\'' || ch == '"' {
161 let quote = ch;
162 state.advance(1);
163 let mut escaped = false;
164
165 while let Some(ch) = state.peek() {
166 if ch == quote && !escaped {
167 state.advance(1); break;
169 }
170 state.advance(ch.len_utf8());
171 if escaped {
172 escaped = false;
173 continue;
174 }
175 if ch == '\\' {
176 escaped = true;
177 continue;
178 }
179 if ch == '\n' || ch == '\r' {
180 break;
181 }
182 }
183 state.add_token(SqlSyntaxKind::StringLiteral, start, state.get_position());
184 return true;
185 }
186 false
187 }
188
189 fn lex_number_literal<S: Source>(&self, state: &mut State<S>) -> bool {
190 let start = state.get_position();
191 let first = match state.current() {
192 Some(c) => c,
193 None => return false,
194 };
195
196 if !first.is_ascii_digit() {
197 return false;
198 }
199
200 let mut is_float = false;
201 state.advance(1);
202
203 while let Some(c) = state.peek() {
205 if c.is_ascii_digit() || c == '_' {
206 state.advance(1);
207 }
208 else {
209 break;
210 }
211 }
212
213 if state.peek() == Some('.') {
215 let next = state.peek_next_n(1);
216 if next.map(|c| c.is_ascii_digit()).unwrap_or(false) {
217 is_float = true;
218 state.advance(1); while let Some(c) = state.peek() {
220 if c.is_ascii_digit() || c == '_' {
221 state.advance(1);
222 }
223 else {
224 break;
225 }
226 }
227 }
228 }
229
230 if let Some(c) = state.peek() {
232 if c == 'e' || c == 'E' {
233 let next = state.peek_next_n(1);
234 if next == Some('+') || next == Some('-') || next.map(|d| d.is_ascii_digit()).unwrap_or(false) {
235 is_float = true;
236 state.advance(1);
237 if let Some(sign) = state.peek() {
238 if sign == '+' || sign == '-' {
239 state.advance(1);
240 }
241 }
242 while let Some(d) = state.peek() {
243 if d.is_ascii_digit() || d == '_' {
244 state.advance(1);
245 }
246 else {
247 break;
248 }
249 }
250 }
251 }
252 }
253
254 let end = state.get_position();
255 state.add_token(if is_float { SqlSyntaxKind::FloatLiteral } else { SqlSyntaxKind::NumberLiteral }, start, end);
256 true
257 }
258
259 fn lex_identifier_or_keyword<S: Source>(&self, state: &mut State<S>) -> bool {
260 let start = state.get_position();
261 let ch = match state.current() {
262 Some(c) => c,
263 None => return false,
264 };
265
266 if !(ch.is_ascii_alphabetic() || ch == '_') {
267 return false;
268 }
269
270 state.advance(1);
271 while let Some(c) = state.current() {
272 if c.is_ascii_alphanumeric() || c == '_' {
273 state.advance(1);
274 }
275 else {
276 break;
277 }
278 }
279
280 let end = state.get_position();
281 let text = state.get_text_in((start..end).into());
282 let kind = self.keyword_kind(&text).unwrap_or(SqlSyntaxKind::Identifier);
283 state.add_token(kind, start, end);
284 true
285 }
286
287 fn keyword_kind(&self, text: &str) -> Option<SqlSyntaxKind> {
288 match text.to_uppercase().as_str() {
289 "SELECT" => Some(SqlSyntaxKind::Select),
290 "FROM" => Some(SqlSyntaxKind::From),
291 "WHERE" => Some(SqlSyntaxKind::Where),
292 "INSERT" => Some(SqlSyntaxKind::Insert),
293 "INTO" => Some(SqlSyntaxKind::Into),
294 "VALUES" => Some(SqlSyntaxKind::Values),
295 "UPDATE" => Some(SqlSyntaxKind::Update),
296 "SET" => Some(SqlSyntaxKind::Set),
297 "DELETE" => Some(SqlSyntaxKind::Delete),
298 "CREATE" => Some(SqlSyntaxKind::Create),
299 "DROP" => Some(SqlSyntaxKind::Drop),
300 "ALTER" => Some(SqlSyntaxKind::Alter),
301 "ADD" => Some(SqlSyntaxKind::Add),
302 "COLUMN" => Some(SqlSyntaxKind::Column),
303 "TABLE" => Some(SqlSyntaxKind::Table),
304 "PRIMARY" => Some(SqlSyntaxKind::Primary),
305 "KEY" => Some(SqlSyntaxKind::Key),
306 "FOREIGN" => Some(SqlSyntaxKind::Foreign),
307 "REFERENCES" => Some(SqlSyntaxKind::References),
308 "INDEX" => Some(SqlSyntaxKind::Index),
309 "UNIQUE" => Some(SqlSyntaxKind::Unique),
310 "NOT" => Some(SqlSyntaxKind::Not),
311 "NULL" => Some(SqlSyntaxKind::Null),
312 "DEFAULT" => Some(SqlSyntaxKind::Default),
313 "AUTO_INCREMENT" => Some(SqlSyntaxKind::AutoIncrement),
314 "AND" => Some(SqlSyntaxKind::And),
315 "OR" => Some(SqlSyntaxKind::Or),
316 "IN" => Some(SqlSyntaxKind::In),
317 "LIKE" => Some(SqlSyntaxKind::Like),
318 "BETWEEN" => Some(SqlSyntaxKind::Between),
319 "IS" => Some(SqlSyntaxKind::Is),
320 "AS" => Some(SqlSyntaxKind::As),
321 "JOIN" => Some(SqlSyntaxKind::Join),
322 "INNER" => Some(SqlSyntaxKind::Inner),
323 "LEFT" => Some(SqlSyntaxKind::Left),
324 "RIGHT" => Some(SqlSyntaxKind::Right),
325 "FULL" => Some(SqlSyntaxKind::Full),
326 "OUTER" => Some(SqlSyntaxKind::Outer),
327 "ON" => Some(SqlSyntaxKind::On),
328 "GROUP" => Some(SqlSyntaxKind::Group),
329 "BY" => Some(SqlSyntaxKind::By),
330 "HAVING" => Some(SqlSyntaxKind::Having),
331 "ORDER" => Some(SqlSyntaxKind::Order),
332 "ASC" => Some(SqlSyntaxKind::Asc),
333 "DESC" => Some(SqlSyntaxKind::Desc),
334 "LIMIT" => Some(SqlSyntaxKind::Limit),
335 "OFFSET" => Some(SqlSyntaxKind::Offset),
336 "UNION" => Some(SqlSyntaxKind::Union),
337 "ALL" => Some(SqlSyntaxKind::All),
338 "DISTINCT" => Some(SqlSyntaxKind::Distinct),
339 "COUNT" => Some(SqlSyntaxKind::Count),
340 "SUM" => Some(SqlSyntaxKind::Sum),
341 "AVG" => Some(SqlSyntaxKind::Avg),
342 "MIN" => Some(SqlSyntaxKind::Min),
343 "MAX" => Some(SqlSyntaxKind::Max),
344 "VIEW" => Some(SqlSyntaxKind::View),
345 "DATABASE" => Some(SqlSyntaxKind::Database),
346 "SCHEMA" => Some(SqlSyntaxKind::Schema),
347 "TRUE" => Some(SqlSyntaxKind::True),
348 "FALSE" => Some(SqlSyntaxKind::False),
349 "EXISTS" => Some(SqlSyntaxKind::Exists),
350 "CASE" => Some(SqlSyntaxKind::Case),
351 "WHEN" => Some(SqlSyntaxKind::When),
352 "THEN" => Some(SqlSyntaxKind::Then),
353 "ELSE" => Some(SqlSyntaxKind::Else),
354 "END" => Some(SqlSyntaxKind::End),
355 "IF" => Some(SqlSyntaxKind::If),
356 "BEGIN" => Some(SqlSyntaxKind::Begin),
357 "COMMIT" => Some(SqlSyntaxKind::Commit),
358 "ROLLBACK" => Some(SqlSyntaxKind::Rollback),
359 "TRANSACTION" => Some(SqlSyntaxKind::Transaction),
360 "INT" => Some(SqlSyntaxKind::Int),
362 "INTEGER" => Some(SqlSyntaxKind::Integer),
363 "VARCHAR" => Some(SqlSyntaxKind::Varchar),
364 "CHAR" => Some(SqlSyntaxKind::Char),
365 "TEXT" => Some(SqlSyntaxKind::Text),
366 "DATE" => Some(SqlSyntaxKind::Date),
367 "TIME" => Some(SqlSyntaxKind::Time),
368 "TIMESTAMP" => Some(SqlSyntaxKind::Timestamp),
369 "DECIMAL" => Some(SqlSyntaxKind::Decimal),
370 "FLOAT" => Some(SqlSyntaxKind::Float),
371 "DOUBLE" => Some(SqlSyntaxKind::Double),
372 "BOOLEAN" => Some(SqlSyntaxKind::Boolean),
373 _ => None,
374 }
375 }
376
377 fn lex_operators<S: Source>(&self, state: &mut State<S>) -> bool {
378 let start = state.get_position();
379 let rest = state.rest();
380
381 let patterns: &[(&str, SqlSyntaxKind)] = &[
383 ("<=", SqlSyntaxKind::Le),
384 (">=", SqlSyntaxKind::Ge),
385 ("!=", SqlSyntaxKind::Ne),
386 ("<>", SqlSyntaxKind::Ne),
387 ("||", SqlSyntaxKind::Concat),
388 ];
389
390 for (pat, kind) in patterns {
391 if rest.starts_with(pat) {
392 state.advance(pat.len());
393 state.add_token(*kind, start, state.get_position());
394 return true;
395 }
396 }
397
398 if let Some(ch) = state.current() {
399 let kind = match ch {
400 '=' => Some(SqlSyntaxKind::Equal),
401 '<' => Some(SqlSyntaxKind::Lt),
402 '>' => Some(SqlSyntaxKind::Gt),
403 '+' => Some(SqlSyntaxKind::Plus),
404 '-' => Some(SqlSyntaxKind::Minus),
405 '*' => Some(SqlSyntaxKind::Star),
406 '/' => Some(SqlSyntaxKind::Slash),
407 '%' => Some(SqlSyntaxKind::Percent),
408 '.' => Some(SqlSyntaxKind::Dot),
409 _ => None,
410 };
411 if let Some(k) = kind {
412 state.advance(ch.len_utf8());
413 state.add_token(k, start, state.get_position());
414 return true;
415 }
416 }
417 false
418 }
419
420 fn lex_single_char_tokens<S: Source>(&self, state: &mut State<S>) -> bool {
421 let start = state.get_position();
422 if let Some(ch) = state.current() {
423 let kind = match ch {
424 '(' => SqlSyntaxKind::LeftParen,
425 ')' => SqlSyntaxKind::RightParen,
426 '{' => SqlSyntaxKind::LeftBrace,
427 '}' => SqlSyntaxKind::RightBrace,
428 '[' => SqlSyntaxKind::LeftBracket,
429 ']' => SqlSyntaxKind::RightBracket,
430 ',' => SqlSyntaxKind::Comma,
431 ';' => SqlSyntaxKind::Semicolon,
432 ':' => SqlSyntaxKind::Colon,
433 '?' => SqlSyntaxKind::Question,
434 _ => return false,
435 };
436 state.advance(ch.len_utf8());
437 state.add_token(kind, start, state.get_position());
438 return true;
439 }
440 false
441 }
442}