1#![doc = include_str!("readme.md")]
2use oak_core::Source;
3pub mod token_type;
4pub use token_type::SqlTokenType;
5
6use crate::language::SqlLanguage;
7use oak_core::{
8 Lexer, LexerCache, LexerState, OakError, TextEdit,
9 lexer::{LexOutput, WhitespaceConfig},
10};
11use std::sync::LazyLock;
12
13type State<'a, S> = LexerState<'a, S, SqlLanguage>;
14
15static SQL_WHITESPACE: LazyLock<WhitespaceConfig> = LazyLock::new(|| WhitespaceConfig { unicode_whitespace: true });
16
17#[derive(Clone, Debug)]
18pub struct SqlLexer<'config> {
19 _config: &'config SqlLanguage,
20}
21
22impl<'config> Lexer<SqlLanguage> for SqlLexer<'config> {
23 fn lex<'a, S: Source + ?Sized>(&self, text: &'a S, _edits: &[TextEdit], cache: &'a mut impl LexerCache<SqlLanguage>) -> LexOutput<SqlLanguage> {
24 let mut state = State::new(text);
25 let result = self.run(&mut state);
26 if result.is_ok() {
27 state.add_eof();
28 }
29 state.finish_with_cache(result, cache)
30 }
31}
32
33impl<'config> SqlLexer<'config> {
34 pub fn new(config: &'config SqlLanguage) -> Self {
35 Self { _config: config }
36 }
37
38 fn run<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> Result<(), OakError> {
39 while state.not_at_end() {
40 let safe_point = state.get_position();
41
42 if let Some(ch) = state.peek() {
43 match ch {
44 ' ' | '\t' => {
45 self.skip_whitespace(state);
46 }
47 '\n' | '\r' => {
48 self.lex_newline(state);
49 }
50 '-' => {
51 if state.starts_with("--") {
52 self.skip_comment(state);
53 }
54 else {
55 self.lex_operators(state);
56 }
57 }
58 '/' => {
59 if state.starts_with("/*") {
60 self.skip_comment(state);
61 }
62 else {
63 self.lex_operators(state);
64 }
65 }
66 '\'' | '"' => {
67 self.lex_string_literal(state);
68 }
69 '0'..='9' => {
70 self.lex_number_literal(state);
71 }
72 'a'..='z' | 'A'..='Z' | '_' => {
73 self.lex_identifier_or_keyword(state);
74 }
75 '<' | '>' | '!' | '=' | '+' | '*' | '%' => {
76 self.lex_operators(state);
77 }
78 '(' | ')' | ',' | ';' | '.' => {
79 self.lex_single_char_tokens(state);
80 }
81 _ => {
82 state.advance(ch.len_utf8());
84 state.add_token(SqlTokenType::Error, safe_point, state.get_position());
85 }
86 }
87 }
88
89 state.advance_if_dead_lock(safe_point);
90 }
91 Ok(())
92 }
93
94 fn lex_newline<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
96 let start_pos = state.get_position();
97
98 if let Some('\n') = state.peek() {
99 state.advance(1);
100 state.add_token(SqlTokenType::Newline, start_pos, state.get_position());
101 true
102 }
103 else if let Some('\r') = state.peek() {
104 state.advance(1);
105 if let Some('\n') = state.peek() {
106 state.advance(1);
107 }
108 state.add_token(SqlTokenType::Newline, start_pos, state.get_position());
109 true
110 }
111 else {
112 false
113 }
114 }
115
116 fn skip_whitespace<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
117 SQL_WHITESPACE.scan(state, SqlTokenType::Whitespace);
118 true
119 }
120
121 fn skip_comment<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
122 let start = state.get_position();
123
124 if state.starts_with("--") {
126 state.advance(2);
127 state.take_while(|ch| ch != '\n' && ch != '\r');
128 state.add_token(SqlTokenType::Comment, start, state.get_position());
129 return true;
130 }
131
132 if state.starts_with("/*") {
134 state.advance(2);
135 while state.not_at_end() {
136 if state.starts_with("*/") {
137 state.advance(2);
138 break;
139 }
140 if let Some(ch) = state.current() {
141 state.advance(ch.len_utf8());
142 }
143 }
144 state.add_token(SqlTokenType::Comment, start, state.get_position());
145 return true;
146 }
147
148 false
149 }
150
151 fn lex_string_literal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
152 let start = state.get_position();
153 if let Some(quote) = state.current() {
154 if quote != '\'' && quote != '"' {
155 return false;
156 }
157 state.advance(1);
158 let mut escaped = false;
159 while state.not_at_end() {
160 let ch = match state.peek() {
161 Some(c) => c,
162 None => break,
163 };
164
165 if ch == quote && !escaped {
166 state.advance(1); break;
168 }
169 state.advance(ch.len_utf8());
170 if escaped {
171 escaped = false;
172 continue;
173 }
174 if ch == '\\' {
175 escaped = true;
176 continue;
177 }
178 if ch == '\n' || ch == '\r' {
179 break;
180 }
181 }
182 state.add_token(SqlTokenType::StringLiteral, start, state.get_position());
183 true
184 }
185 else {
186 false
187 }
188 }
189
190 fn lex_number_literal<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
191 let start = state.get_position();
192 let first = match state.current() {
193 Some(c) => c,
194 None => return false,
195 };
196
197 if !first.is_ascii_digit() {
198 return false;
199 }
200
201 let mut is_float = false;
202 state.advance(1);
203
204 while let Some(c) = state.peek() {
206 if c.is_ascii_digit() || c == '_' {
207 state.advance(1);
208 }
209 else {
210 break;
211 }
212 }
213
214 if state.peek() == Some('.') {
216 let next = state.peek_next_n(1);
217 if next.map(|c| c.is_ascii_digit()).unwrap_or(false) {
218 is_float = true;
219 state.advance(1); while let Some(c) = state.peek() {
221 if c.is_ascii_digit() || c == '_' {
222 state.advance(1);
223 }
224 else {
225 break;
226 }
227 }
228 }
229 }
230
231 if let Some(c) = state.peek() {
233 if c == 'e' || c == 'E' {
234 let next = state.peek_next_n(1);
235 if next == Some('+') || next == Some('-') || next.map(|d| d.is_ascii_digit()).unwrap_or(false) {
236 is_float = true;
237 state.advance(1);
238 if let Some(sign) = state.peek() {
239 if sign == '+' || sign == '-' {
240 state.advance(1);
241 }
242 }
243 while let Some(d) = state.peek() {
244 if d.is_ascii_digit() || d == '_' {
245 state.advance(1);
246 }
247 else {
248 break;
249 }
250 }
251 }
252 }
253 }
254
255 let end = state.get_position();
256 state.add_token(if is_float { SqlTokenType::FloatLiteral } else { SqlTokenType::NumberLiteral }, start, end);
257 true
258 }
259
260 fn lex_identifier_or_keyword<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
261 let start = state.get_position();
262 let ch = match state.current() {
263 Some(c) => c,
264 None => return false,
265 };
266
267 if !ch.is_alphabetic() && ch != '_' {
268 return false;
269 }
270
271 state.advance(ch.len_utf8());
272 while let Some(c) = state.peek() {
273 if c.is_alphanumeric() || c == '_' {
274 state.advance(c.len_utf8());
275 }
276 else {
277 break;
278 }
279 }
280
281 let end = state.get_position();
282 let text = state.source().get_text_in(oak_core::Range { start, end }).to_uppercase();
283 let kind = match text.as_str() {
284 "SELECT" => SqlTokenType::Select,
285 "FROM" => SqlTokenType::From,
286 "WHERE" => SqlTokenType::Where,
287 "INSERT" => SqlTokenType::Insert,
288 "UPDATE" => SqlTokenType::Update,
289 "DELETE" => SqlTokenType::Delete,
290 "CREATE" => SqlTokenType::Create,
291 "DROP" => SqlTokenType::Drop,
292 "ALTER" => SqlTokenType::Alter,
293 "TABLE" => SqlTokenType::Table,
294 "INDEX" => SqlTokenType::Index,
295 "INTO" => SqlTokenType::Into,
296 "VALUES" => SqlTokenType::Values,
297 "SET" => SqlTokenType::Set,
298 "JOIN" => SqlTokenType::Join,
299 "INNER" => SqlTokenType::Inner,
300 "LEFT" => SqlTokenType::Left,
301 "RIGHT" => SqlTokenType::Right,
302 "FULL" => SqlTokenType::Full,
303 "OUTER" => SqlTokenType::Outer,
304 "ON" => SqlTokenType::On,
305 "AND" => SqlTokenType::And,
306 "OR" => SqlTokenType::Or,
307 "NOT" => SqlTokenType::Not,
308 "NULL" => SqlTokenType::Null,
309 "TRUE" => SqlTokenType::True,
310 "FALSE" => SqlTokenType::False,
311 "AS" => SqlTokenType::As,
312 "BY" => SqlTokenType::By,
313 "ORDER" => SqlTokenType::Order,
314 "ASC" => SqlTokenType::Asc,
315 "DESC" => SqlTokenType::Desc,
316 "GROUP" => SqlTokenType::Group,
317 "HAVING" => SqlTokenType::Having,
318 "LIMIT" => SqlTokenType::Limit,
319 "OFFSET" => SqlTokenType::Offset,
320 "UNION" => SqlTokenType::Union,
321 "ALL" => SqlTokenType::All,
322 "DISTINCT" => SqlTokenType::Distinct,
323 "PRIMARY" => SqlTokenType::Primary,
324 "KEY" => SqlTokenType::Key,
325 "FOREIGN" => SqlTokenType::Foreign,
326 "REFERENCES" => SqlTokenType::References,
327 "DEFAULT" => SqlTokenType::Default,
328 "UNIQUE" => SqlTokenType::Unique,
329 "AUTO_INCREMENT" => SqlTokenType::AutoIncrement,
330 "INT" => SqlTokenType::Int,
331 "INTEGER" => SqlTokenType::Integer,
332 "VARCHAR" => SqlTokenType::Varchar,
333 "CHAR" => SqlTokenType::Char,
334 "TEXT" => SqlTokenType::Text,
335 "DATE" => SqlTokenType::Date,
336 "TIME" => SqlTokenType::Time,
337 "TIMESTAMP" => SqlTokenType::Timestamp,
338 "DECIMAL" => SqlTokenType::Decimal,
339 "FLOAT" => SqlTokenType::Float,
340 "DOUBLE" => SqlTokenType::Double,
341 "BOOLEAN" => SqlTokenType::Boolean,
342 _ => SqlTokenType::Identifier,
343 };
344
345 state.add_token(kind, start, end);
346 true
347 }
348
349 fn lex_operators<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
350 let start = state.get_position();
351
352 let ops = [
353 ("<=", SqlTokenType::LessEqual),
354 (">=", SqlTokenType::GreaterEqual),
355 ("<>", SqlTokenType::NotEqual),
356 ("!=", SqlTokenType::NotEqual),
357 ("=", SqlTokenType::Equal),
358 ("<", SqlTokenType::Less),
359 (">", SqlTokenType::Greater),
360 ("+", SqlTokenType::Plus),
361 ("-", SqlTokenType::Minus),
362 ("*", SqlTokenType::Star),
363 ("/", SqlTokenType::Slash),
364 ("%", SqlTokenType::Percent),
365 ];
366
367 for (op, kind) in ops {
368 if state.starts_with(op) {
369 state.advance(op.len());
370 state.add_token(kind, start, state.get_position());
371 return true;
372 }
373 }
374
375 false
376 }
377
378 fn lex_single_char_tokens<'a, S: Source + ?Sized>(&self, state: &mut State<'a, S>) -> bool {
379 let start = state.get_position();
380 let ch = match state.current() {
381 Some(c) => c,
382 None => return false,
383 };
384
385 let kind = match ch {
386 '(' => SqlTokenType::LeftParen,
387 ')' => SqlTokenType::RightParen,
388 ',' => SqlTokenType::Comma,
389 ';' => SqlTokenType::Semicolon,
390 '.' => SqlTokenType::Dot,
391 _ => return false,
392 };
393
394 state.advance(ch.len_utf8());
395 state.add_token(kind, start, state.get_position());
396 true
397 }
398}