1use crate::{kind::SqlSyntaxKind, language::SqlLanguage};
2use oak_core::{
3 Lexer, LexerCache, LexerState, OakError, TextEdit,
4 lexer::{LexOutput, WhitespaceConfig},
5 source::Source,
6};
7use std::sync::LazyLock;
8
/// Shorthand for the `oak_core` lexer state specialised to [`SqlLanguage`];
/// `S` is the source type the state reads from.
type State<'a, S> = LexerState<'a, S, SqlLanguage>;
10
/// Whitespace scanning configuration used by `SqlLexer::skip_whitespace`.
///
/// Built lazily on first access, with `unicode_whitespace` enabled.
static SQL_WHITESPACE: LazyLock<WhitespaceConfig> = LazyLock::new(|| WhitespaceConfig { unicode_whitespace: true });
12
/// Lexer that turns SQL source text into `SqlSyntaxKind` tokens.
///
/// Borrows its [`SqlLanguage`] configuration for the lifetime `'config`.
#[derive(Clone, Debug)]
pub struct SqlLexer<'config> {
    /// Language configuration supplied at construction.
    /// NOTE(review): not read anywhere in the code visible here (hence the
    /// leading underscore) — confirm before relying on it.
    _config: &'config SqlLanguage,
}
17
18impl<'config> Lexer<SqlLanguage> for SqlLexer<'config> {
19 fn lex<'a, S: Source + ?Sized>(&self, text: &'a S, _edits: &[TextEdit], cache: &'a mut impl LexerCache<SqlLanguage>) -> LexOutput<SqlLanguage> {
20 let mut state = State::new(text);
21 let result = self.run(&mut state);
22 if result.is_ok() {
23 state.add_eof();
24 }
25 state.finish_with_cache(result, cache)
26 }
27}
28
29impl<'config> SqlLexer<'config> {
30 pub fn new(config: &'config SqlLanguage) -> Self {
31 Self { _config: config }
32 }
33
34 fn run<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
35 while state.not_at_end() {
36 let safe_point = state.get_position();
37
38 if let Some(ch) = state.peek() {
39 match ch {
40 ' ' | '\t' => {
41 self.skip_whitespace(state);
42 }
43 '\n' | '\r' => {
44 self.lex_newline(state);
45 }
46 '-' => {
47 if state.starts_with("--") {
48 self.skip_comment(state);
49 }
50 else {
51 self.lex_operators(state);
52 }
53 }
54 '/' => {
55 if state.starts_with("/*") {
56 self.skip_comment(state);
57 }
58 else {
59 self.lex_operators(state);
60 }
61 }
62 '\'' | '"' => {
63 self.lex_string_literal(state);
64 }
65 '0'..='9' => {
66 self.lex_number_literal(state);
67 }
68 'a'..='z' | 'A'..='Z' | '_' => {
69 self.lex_identifier_or_keyword(state);
70 }
71 '<' | '>' | '!' | '=' | '+' | '*' | '%' => {
72 self.lex_operators(state);
73 }
74 '(' | ')' | ',' | ';' | '.' => {
75 self.lex_single_char_tokens(state);
76 }
77 _ => {
78 state.advance(ch.len_utf8());
80 state.add_token(SqlSyntaxKind::Error, safe_point, state.get_position());
81 }
82 }
83 }
84
85 state.advance_if_dead_lock(safe_point);
86 }
87 Ok(())
88 }
89
90 fn lex_newline<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
92 let start_pos = state.get_position();
93
94 if let Some('\n') = state.peek() {
95 state.advance(1);
96 state.add_token(SqlSyntaxKind::Newline, start_pos, state.get_position());
97 true
98 }
99 else if let Some('\r') = state.peek() {
100 state.advance(1);
101 if let Some('\n') = state.peek() {
102 state.advance(1);
103 }
104 state.add_token(SqlSyntaxKind::Newline, start_pos, state.get_position());
105 true
106 }
107 else {
108 false
109 }
110 }
111
112 fn skip_whitespace<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
113 SQL_WHITESPACE.scan(state, SqlSyntaxKind::Whitespace)
114 }
115
116 fn skip_comment<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
117 let start = state.get_position();
118
119 if state.starts_with("--") {
121 state.advance(2);
122 state.take_while(|ch| ch != '\n' && ch != '\r');
123 state.add_token(SqlSyntaxKind::Comment, start, state.get_position());
124 return true;
125 }
126
127 if state.starts_with("/*") {
129 state.advance(2);
130 while state.not_at_end() {
131 if state.starts_with("*/") {
132 state.advance(2);
133 break;
134 }
135 if let Some(ch) = state.current() {
136 state.advance(ch.len_utf8());
137 }
138 }
139 state.add_token(SqlSyntaxKind::Comment, start, state.get_position());
140 return true;
141 }
142
143 false
144 }
145
146 fn lex_string_literal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
147 let start = state.get_position();
148 if let Some(quote) = state.current() {
149 if quote != '\'' && quote != '"' {
150 return false;
151 }
152 state.advance(1);
153 let mut escaped = false;
154 while state.not_at_end() {
155 let ch = match state.peek() {
156 Some(c) => c,
157 None => break,
158 };
159
160 if ch == quote && !escaped {
161 state.advance(1); break;
163 }
164 state.advance(ch.len_utf8());
165 if escaped {
166 escaped = false;
167 continue;
168 }
169 if ch == '\\' {
170 escaped = true;
171 continue;
172 }
173 if ch == '\n' || ch == '\r' {
174 break;
175 }
176 }
177 state.add_token(SqlSyntaxKind::StringLiteral, start, state.get_position());
178 return true;
179 }
180 false
181 }
182
183 fn lex_number_literal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
184 let start = state.get_position();
185 let first = match state.current() {
186 Some(c) => c,
187 None => return false,
188 };
189
190 if !first.is_ascii_digit() {
191 return false;
192 }
193
194 let mut is_float = false;
195 state.advance(1);
196
197 while let Some(c) = state.peek() {
199 if c.is_ascii_digit() || c == '_' {
200 state.advance(1);
201 }
202 else {
203 break;
204 }
205 }
206
207 if state.peek() == Some('.') {
209 let next = state.peek_next_n(1);
210 if next.map(|c| c.is_ascii_digit()).unwrap_or(false) {
211 is_float = true;
212 state.advance(1); while let Some(c) = state.peek() {
214 if c.is_ascii_digit() || c == '_' {
215 state.advance(1);
216 }
217 else {
218 break;
219 }
220 }
221 }
222 }
223
224 if let Some(c) = state.peek() {
226 if c == 'e' || c == 'E' {
227 let next = state.peek_next_n(1);
228 if next == Some('+') || next == Some('-') || next.map(|d| d.is_ascii_digit()).unwrap_or(false) {
229 is_float = true;
230 state.advance(1);
231 if let Some(sign) = state.peek() {
232 if sign == '+' || sign == '-' {
233 state.advance(1);
234 }
235 }
236 while let Some(d) = state.peek() {
237 if d.is_ascii_digit() || d == '_' {
238 state.advance(1);
239 }
240 else {
241 break;
242 }
243 }
244 }
245 }
246 }
247
248 let end = state.get_position();
249 state.add_token(if is_float { SqlSyntaxKind::FloatLiteral } else { SqlSyntaxKind::NumberLiteral }, start, end);
250 true
251 }
252
253 fn lex_identifier_or_keyword<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
254 let start = state.get_position();
255 let ch = match state.current() {
256 Some(c) => c,
257 None => return false,
258 };
259
260 if !ch.is_alphabetic() && ch != '_' {
261 return false;
262 }
263
264 state.advance(ch.len_utf8());
265 while let Some(c) = state.peek() {
266 if c.is_alphanumeric() || c == '_' {
267 state.advance(c.len_utf8());
268 }
269 else {
270 break;
271 }
272 }
273
274 let end = state.get_position();
275 let text = state.source().get_text_in(oak_core::Range { start, end }).to_uppercase();
276 let kind = match text.as_str() {
277 "SELECT" => SqlSyntaxKind::Select,
278 "FROM" => SqlSyntaxKind::From,
279 "WHERE" => SqlSyntaxKind::Where,
280 "INSERT" => SqlSyntaxKind::Insert,
281 "UPDATE" => SqlSyntaxKind::Update,
282 "DELETE" => SqlSyntaxKind::Delete,
283 "CREATE" => SqlSyntaxKind::Create,
284 "DROP" => SqlSyntaxKind::Drop,
285 "ALTER" => SqlSyntaxKind::Alter,
286 "TABLE" => SqlSyntaxKind::Table,
287 "INDEX" => SqlSyntaxKind::Index,
288 "INTO" => SqlSyntaxKind::Into,
289 "VALUES" => SqlSyntaxKind::Values,
290 "SET" => SqlSyntaxKind::Set,
291 "JOIN" => SqlSyntaxKind::Join,
292 "INNER" => SqlSyntaxKind::Inner,
293 "LEFT" => SqlSyntaxKind::Left,
294 "RIGHT" => SqlSyntaxKind::Right,
295 "FULL" => SqlSyntaxKind::Full,
296 "OUTER" => SqlSyntaxKind::Outer,
297 "ON" => SqlSyntaxKind::On,
298 "AND" => SqlSyntaxKind::And,
299 "OR" => SqlSyntaxKind::Or,
300 "NOT" => SqlSyntaxKind::Not,
301 "NULL" => SqlSyntaxKind::Null,
302 "TRUE" => SqlSyntaxKind::True,
303 "FALSE" => SqlSyntaxKind::False,
304 "AS" => SqlSyntaxKind::As,
305 "BY" => SqlSyntaxKind::By,
306 "ORDER" => SqlSyntaxKind::Order,
307 "ASC" => SqlSyntaxKind::Asc,
308 "DESC" => SqlSyntaxKind::Desc,
309 "GROUP" => SqlSyntaxKind::Group,
310 "HAVING" => SqlSyntaxKind::Having,
311 "LIMIT" => SqlSyntaxKind::Limit,
312 "OFFSET" => SqlSyntaxKind::Offset,
313 "UNION" => SqlSyntaxKind::Union,
314 "ALL" => SqlSyntaxKind::All,
315 "DISTINCT" => SqlSyntaxKind::Distinct,
316 "PRIMARY" => SqlSyntaxKind::Primary,
317 "KEY" => SqlSyntaxKind::Key,
318 "FOREIGN" => SqlSyntaxKind::Foreign,
319 "REFERENCES" => SqlSyntaxKind::References,
320 "DEFAULT" => SqlSyntaxKind::Default,
321 "UNIQUE" => SqlSyntaxKind::Unique,
322 "AUTO_INCREMENT" => SqlSyntaxKind::AutoIncrement,
323 "INT" => SqlSyntaxKind::Int,
324 "INTEGER" => SqlSyntaxKind::Integer,
325 "VARCHAR" => SqlSyntaxKind::Varchar,
326 "CHAR" => SqlSyntaxKind::Char,
327 "TEXT" => SqlSyntaxKind::Text,
328 "DATE" => SqlSyntaxKind::Date,
329 "TIME" => SqlSyntaxKind::Time,
330 "TIMESTAMP" => SqlSyntaxKind::Timestamp,
331 "DECIMAL" => SqlSyntaxKind::Decimal,
332 "FLOAT" => SqlSyntaxKind::Float,
333 "DOUBLE" => SqlSyntaxKind::Double,
334 "BOOLEAN" => SqlSyntaxKind::Boolean,
335 _ => SqlSyntaxKind::Identifier,
336 };
337
338 state.add_token(kind, start, end);
339 true
340 }
341
342 fn lex_operators<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
343 let start = state.get_position();
344
345 let ops = [
346 ("<=", SqlSyntaxKind::LessEqual),
347 (">=", SqlSyntaxKind::GreaterEqual),
348 ("<>", SqlSyntaxKind::NotEqual),
349 ("!=", SqlSyntaxKind::NotEqual),
350 ("=", SqlSyntaxKind::Equal),
351 ("<", SqlSyntaxKind::Less),
352 (">", SqlSyntaxKind::Greater),
353 ("+", SqlSyntaxKind::Plus),
354 ("-", SqlSyntaxKind::Minus),
355 ("*", SqlSyntaxKind::Star),
356 ("/", SqlSyntaxKind::Slash),
357 ("%", SqlSyntaxKind::Percent),
358 ];
359
360 for (op, kind) in ops {
361 if state.starts_with(op) {
362 state.advance(op.len());
363 state.add_token(kind, start, state.get_position());
364 return true;
365 }
366 }
367
368 false
369 }
370
371 fn lex_single_char_tokens<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
372 let start = state.get_position();
373 let ch = match state.current() {
374 Some(c) => c,
375 None => return false,
376 };
377
378 let kind = match ch {
379 '(' => SqlSyntaxKind::LeftParen,
380 ')' => SqlSyntaxKind::RightParen,
381 ',' => SqlSyntaxKind::Comma,
382 ';' => SqlSyntaxKind::Semicolon,
383 '.' => SqlSyntaxKind::Dot,
384 _ => return false,
385 };
386
387 state.advance(ch.len_utf8());
388 state.add_token(kind, start, state.get_position());
389 true
390 }
391}