1use rigsql_core::{Span, Token, TokenKind};
2use smol_str::SmolStr;
3use thiserror::Error;
4
/// Errors that can occur while tokenizing SQL text.
///
/// All offsets are byte offsets from the start of the input string.
#[derive(Debug, Error)]
pub enum LexerError {
    /// A byte that cannot start any token in the active dialect.
    #[error("Unexpected character '{ch}' at offset {offset}")]
    UnexpectedChar { ch: char, offset: u32 },
    /// A `'...'` string literal with no closing quote before end of input.
    #[error("Unterminated string literal starting at offset {offset}")]
    UnterminatedString { offset: u32 },
    /// A `/* ... */` comment (possibly nested) left open at end of input.
    #[error("Unterminated block comment starting at offset {offset}")]
    UnterminatedBlockComment { offset: u32 },
    /// A quoted identifier (`"..."`, `` `...` ``, or `[...]`) left open.
    #[error("Unterminated quoted identifier starting at offset {offset}")]
    UnterminatedQuotedIdentifier { offset: u32 },
}
16
/// Dialect-specific switches controlling which optional token forms the
/// lexer recognizes. All flags default to `false` (plain ANSI behaviour).
#[derive(Debug, Clone, Default)]
pub struct LexerConfig {
    /// Recognize `::` as a single cast token (PostgreSQL-style).
    pub double_colon: bool,
    /// Recognize `[bracketed]` identifiers (T-SQL-style).
    pub bracket_identifiers: bool,
    /// Recognize `` `backtick` `` identifiers (MySQL-style).
    pub backtick_identifiers: bool,
    /// Recognize `@@name` system variables (T-SQL-style).
    pub double_at: bool,
    /// Recognize `$tag$ ... $tag$` dollar-quoted strings (PostgreSQL-style).
    pub dollar_quoting: bool,
}
31
32impl LexerConfig {
33 pub fn ansi() -> Self {
34 Self::default()
35 }
36
37 pub fn postgres() -> Self {
38 Self {
39 double_colon: true,
40 dollar_quoting: true,
41 ..Self::default()
42 }
43 }
44
45 pub fn tsql() -> Self {
46 Self {
47 bracket_identifiers: true,
48 double_at: true,
49 ..Self::default()
50 }
51 }
52}
53
/// A single-pass SQL tokenizer over one input string.
pub struct Lexer<'a> {
    /// The full input text; kept for UTF-8-aware slicing and token text.
    source: &'a str,
    /// Byte view of `source`, used for fast single-byte scanning.
    bytes: &'a [u8],
    /// Current byte offset into the input (always on a char boundary).
    pos: usize,
    /// Dialect switches controlling optional token forms.
    config: LexerConfig,
}
60
61impl<'a> Lexer<'a> {
62 pub fn new(source: &'a str, config: LexerConfig) -> Self {
63 Self {
64 source,
65 bytes: source.as_bytes(),
66 pos: 0,
67 config,
68 }
69 }
70
71 pub fn tokenize(&mut self) -> Result<Vec<Token>, LexerError> {
73 let mut tokens = Vec::new();
74 loop {
75 let token = self.next_token()?;
76 let is_eof = token.kind == TokenKind::Eof;
77 tokens.push(token);
78 if is_eof {
79 break;
80 }
81 }
82 Ok(tokens)
83 }
84
85 fn next_token(&mut self) -> Result<Token, LexerError> {
86 if self.pos >= self.bytes.len() {
87 return Ok(Token::new(
88 TokenKind::Eof,
89 Span::new(self.pos as u32, self.pos as u32),
90 "",
91 ));
92 }
93
94 let start = self.pos;
95 let ch = self.bytes[self.pos];
96
97 match ch {
98 b'\n' => {
100 self.pos += 1;
101 Ok(self.make_token(TokenKind::Newline, start))
102 }
103 b'\r' => {
104 self.pos += 1;
105 if self.peek() == Some(b'\n') {
106 self.pos += 1;
107 }
108 Ok(self.make_token(TokenKind::Newline, start))
109 }
110
111 b' ' | b'\t' => {
113 self.pos += 1;
114 while let Some(b) = self.peek() {
115 if b == b' ' || b == b'\t' {
116 self.pos += 1;
117 } else {
118 break;
119 }
120 }
121 Ok(self.make_token(TokenKind::Whitespace, start))
122 }
123
124 b'-' if self.peek_at(1) == Some(b'-') => {
126 self.pos += 2;
127 while let Some(b) = self.peek() {
128 if b == b'\n' || b == b'\r' {
129 break;
130 }
131 self.pos += 1;
132 }
133 Ok(self.make_token(TokenKind::LineComment, start))
134 }
135
136 b'/' if self.peek_at(1) == Some(b'*') => {
138 self.pos += 2;
139 let mut depth = 1u32;
140 while self.pos < self.bytes.len() && depth > 0 {
141 if self.bytes[self.pos] == b'/' && self.peek_at(1) == Some(b'*') {
142 depth += 1;
143 self.pos += 2;
144 } else if self.bytes[self.pos] == b'*' && self.peek_at(1) == Some(b'/') {
145 depth -= 1;
146 self.pos += 2;
147 } else {
148 self.pos += 1;
149 }
150 }
151 if depth > 0 {
152 return Err(LexerError::UnterminatedBlockComment {
153 offset: start as u32,
154 });
155 }
156 Ok(self.make_token(TokenKind::BlockComment, start))
157 }
158
159 b'\'' => self.lex_string_literal(start),
161
162 b'"' => self.lex_quoted_identifier(start, b'"'),
164
165 b'[' if self.config.bracket_identifiers => self.lex_bracket_identifier(start),
167
168 b'`' if self.config.backtick_identifiers => self.lex_quoted_identifier(start, b'`'),
170
171 b'0'..=b'9' => self.lex_number(start),
173
174 b'.' if self.peek_at(1).is_some_and(|b| b.is_ascii_digit()) => self.lex_number(start),
176
177 b'.' => {
179 self.pos += 1;
180 Ok(self.make_token(TokenKind::Dot, start))
181 }
182 b',' => {
183 self.pos += 1;
184 Ok(self.make_token(TokenKind::Comma, start))
185 }
186 b';' => {
187 self.pos += 1;
188 Ok(self.make_token(TokenKind::Semicolon, start))
189 }
190 b'(' => {
191 self.pos += 1;
192 Ok(self.make_token(TokenKind::LParen, start))
193 }
194 b')' => {
195 self.pos += 1;
196 Ok(self.make_token(TokenKind::RParen, start))
197 }
198 b'*' => {
199 self.pos += 1;
200 Ok(self.make_token(TokenKind::Star, start))
201 }
202 b'+' => {
203 self.pos += 1;
204 Ok(self.make_token(TokenKind::Plus, start))
205 }
206 b'-' => {
207 self.pos += 1;
209 Ok(self.make_token(TokenKind::Minus, start))
210 }
211 b'/' => {
212 self.pos += 1;
214 Ok(self.make_token(TokenKind::Slash, start))
215 }
216 b'%' => {
217 self.pos += 1;
218 Ok(self.make_token(TokenKind::Percent, start))
219 }
220 b'=' => {
221 self.pos += 1;
222 Ok(self.make_token(TokenKind::Eq, start))
223 }
224
225 b'<' => {
227 self.pos += 1;
228 match self.peek() {
229 Some(b'=') => {
230 self.pos += 1;
231 Ok(self.make_token(TokenKind::LtEq, start))
232 }
233 Some(b'>') => {
234 self.pos += 1;
235 Ok(self.make_token(TokenKind::Neq, start))
236 }
237 _ => Ok(self.make_token(TokenKind::Lt, start)),
238 }
239 }
240
241 b'>' => {
243 self.pos += 1;
244 if self.peek() == Some(b'=') {
245 self.pos += 1;
246 Ok(self.make_token(TokenKind::GtEq, start))
247 } else {
248 Ok(self.make_token(TokenKind::Gt, start))
249 }
250 }
251
252 b'!' if self.peek_at(1) == Some(b'=') => {
254 self.pos += 2;
255 Ok(self.make_token(TokenKind::Neq, start))
256 }
257
258 b'|' if self.peek_at(1) == Some(b'|') => {
260 self.pos += 2;
261 Ok(self.make_token(TokenKind::Concat, start))
262 }
263
264 b':' if self.config.double_colon && self.peek_at(1) == Some(b':') => {
266 self.pos += 2;
267 Ok(self.make_token(TokenKind::ColonColon, start))
268 }
269
270 b':' => {
272 self.pos += 1;
273 if self
274 .peek()
275 .is_some_and(|b| b.is_ascii_alphanumeric() || b == b'_')
276 {
277 while self
278 .peek()
279 .is_some_and(|b| b.is_ascii_alphanumeric() || b == b'_')
280 {
281 self.pos += 1;
282 }
283 Ok(self.make_token(TokenKind::Placeholder, start))
284 } else {
285 Ok(self.make_token(TokenKind::Colon, start))
286 }
287 }
288
289 b'@' => {
291 self.pos += 1;
292 if self.config.double_at && self.peek() == Some(b'@') {
293 self.pos += 1;
294 }
295 self.eat_word_chars();
297 Ok(self.make_token(TokenKind::AtSign, start))
298 }
299
300 b'?' => {
302 self.pos += 1;
303 Ok(self.make_token(TokenKind::Placeholder, start))
304 }
305
306 b'$' => {
308 if self.config.dollar_quoting {
309 self.lex_dollar_quote_or_param(start)
310 } else {
311 self.pos += 1;
312 while self.peek().is_some_and(|b| b.is_ascii_digit()) {
314 self.pos += 1;
315 }
316 Ok(self.make_token(TokenKind::Placeholder, start))
317 }
318 }
319
320 b if is_word_start(b) || b >= 0x80 => {
322 if b >= 0x80 {
323 let s = &self.source[self.pos..];
324 let first_char = s.chars().next().unwrap();
325 self.pos += first_char.len_utf8();
326 } else {
327 self.pos += 1;
328 }
329 self.eat_word_chars();
330 Ok(self.make_token(TokenKind::Word, start))
331 }
332
333 _ => {
334 let ch = self.source[self.pos..].chars().next().unwrap();
335 Err(LexerError::UnexpectedChar {
336 ch,
337 offset: start as u32,
338 })
339 }
340 }
341 }
342
343 fn lex_string_literal(&mut self, start: usize) -> Result<Token, LexerError> {
344 self.pos += 1; loop {
346 match self.peek() {
347 None => {
348 return Err(LexerError::UnterminatedString {
349 offset: start as u32,
350 })
351 }
352 Some(b'\'') => {
353 self.pos += 1;
354 if self.peek() == Some(b'\'') {
356 self.pos += 1;
357 continue;
358 }
359 return Ok(self.make_token(TokenKind::StringLiteral, start));
360 }
361 Some(_) => self.pos += 1,
362 }
363 }
364 }
365
366 fn lex_quoted_identifier(&mut self, start: usize, quote: u8) -> Result<Token, LexerError> {
367 self.pos += 1; loop {
369 match self.peek() {
370 None => {
371 return Err(LexerError::UnterminatedQuotedIdentifier {
372 offset: start as u32,
373 })
374 }
375 Some(b) if b == quote => {
376 self.pos += 1;
377 if self.peek() == Some(quote) {
379 self.pos += 1;
380 continue;
381 }
382 return Ok(self.make_token(TokenKind::QuotedIdentifier, start));
383 }
384 Some(_) => self.pos += 1,
385 }
386 }
387 }
388
389 fn lex_bracket_identifier(&mut self, start: usize) -> Result<Token, LexerError> {
390 self.pos += 1; loop {
392 match self.peek() {
393 None => {
394 return Err(LexerError::UnterminatedQuotedIdentifier {
395 offset: start as u32,
396 })
397 }
398 Some(b']') => {
399 self.pos += 1;
400 return Ok(self.make_token(TokenKind::QuotedIdentifier, start));
401 }
402 Some(_) => self.pos += 1,
403 }
404 }
405 }
406
407 fn lex_number(&mut self, start: usize) -> Result<Token, LexerError> {
408 while self.peek().is_some_and(|b| b.is_ascii_digit()) {
410 self.pos += 1;
411 }
412 if self.peek() == Some(b'.') && self.peek_at(1).is_some_and(|b| b.is_ascii_digit()) {
414 self.pos += 1; while self.peek().is_some_and(|b| b.is_ascii_digit()) {
416 self.pos += 1;
417 }
418 } else if self.bytes[start] == b'.' {
419 self.pos += 1; while self.peek().is_some_and(|b| b.is_ascii_digit()) {
422 self.pos += 1;
423 }
424 }
425 if self.peek().is_some_and(|b| b == b'e' || b == b'E') {
427 self.pos += 1;
428 if self.peek().is_some_and(|b| b == b'+' || b == b'-') {
429 self.pos += 1;
430 }
431 while self.peek().is_some_and(|b| b.is_ascii_digit()) {
432 self.pos += 1;
433 }
434 }
435 Ok(self.make_token(TokenKind::NumberLiteral, start))
436 }
437
438 fn lex_dollar_quote_or_param(&mut self, start: usize) -> Result<Token, LexerError> {
439 let after_dollar = self.pos + 1;
441 if after_dollar < self.bytes.len() {
442 if self.bytes[after_dollar] == b'$' {
444 self.pos += 2; let tag = "";
447 return self.lex_dollar_body(start, tag);
448 }
449 if self.bytes[after_dollar].is_ascii_alphabetic() || self.bytes[after_dollar] == b'_' {
450 let tag_start = after_dollar;
452 let mut p = after_dollar;
453 while p < self.bytes.len()
454 && (self.bytes[p].is_ascii_alphanumeric() || self.bytes[p] == b'_')
455 {
456 p += 1;
457 }
458 if p < self.bytes.len() && self.bytes[p] == b'$' {
459 let tag = &self.source[tag_start..p];
460 self.pos = p + 1; return self.lex_dollar_body(start, tag);
462 }
463 }
464 }
465
466 self.pos += 1;
468 while self.peek().is_some_and(|b| b.is_ascii_digit()) {
469 self.pos += 1;
470 }
471 Ok(self.make_token(TokenKind::Placeholder, start))
472 }
473
474 fn lex_dollar_body(&mut self, start: usize, tag: &str) -> Result<Token, LexerError> {
475 let end_tag = format!("${tag}$");
476 let end_bytes = end_tag.as_bytes();
477 while self.pos + end_bytes.len() <= self.bytes.len() {
478 if &self.bytes[self.pos..self.pos + end_bytes.len()] == end_bytes {
479 self.pos += end_bytes.len();
480 return Ok(self.make_token(TokenKind::StringLiteral, start));
481 }
482 self.pos += 1;
483 }
484 Err(LexerError::UnterminatedString {
486 offset: start as u32,
487 })
488 }
489
490 fn peek(&self) -> Option<u8> {
491 self.bytes.get(self.pos).copied()
492 }
493
494 fn peek_at(&self, offset: usize) -> Option<u8> {
495 self.bytes.get(self.pos + offset).copied()
496 }
497
498 fn eat_word_chars(&mut self) {
500 while self.pos < self.bytes.len() {
501 let b = self.bytes[self.pos];
502 if is_word_continue(b) {
503 self.pos += 1;
504 } else if b >= 0x80 {
505 let remaining = &self.source[self.pos..];
506 if let Some(c) = remaining.chars().next() {
507 if c.is_alphanumeric() || c == '_' {
508 self.pos += c.len_utf8();
509 } else {
510 break;
511 }
512 } else {
513 break;
514 }
515 } else {
516 break;
517 }
518 }
519 }
520
521 fn make_token(&self, kind: TokenKind, start: usize) -> Token {
522 let text = &self.source[start..self.pos];
523 Token::new(
524 kind,
525 Span::new(start as u32, self.pos as u32),
526 SmolStr::new(text),
527 )
528 }
529}
530
/// True when `b` may begin an unquoted word: ASCII letters, `_`, or `#`.
fn is_word_start(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'_' | b'#')
}
534
/// True when `b` may continue an unquoted word: ASCII alphanumerics, `_`, `#`.
fn is_word_continue(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'#')
}
538
#[cfg(test)]
mod tests {
    use super::*;

    /// Lexes `input` with the ANSI config, panicking on any lexer error.
    fn lex(input: &str) -> Vec<Token> {
        Lexer::new(input, LexerConfig::ansi()).tokenize().unwrap()
    }

    /// Convenience: just the sequence of token kinds for `input`.
    fn kinds(input: &str) -> Vec<TokenKind> {
        lex(input).into_iter().map(|t| t.kind).collect()
    }

    #[test]
    fn test_simple_select() {
        let toks = lex("SELECT 1");
        assert_eq!(toks.len(), 4);
        assert_eq!(toks[0].kind, TokenKind::Word);
        assert_eq!(toks[0].text.as_str(), "SELECT");
        assert_eq!(toks[1].kind, TokenKind::Whitespace);
        assert_eq!(toks[2].kind, TokenKind::NumberLiteral);
        assert_eq!(toks[2].text.as_str(), "1");
        assert_eq!(toks[3].kind, TokenKind::Eof);
    }

    #[test]
    fn test_select_star() {
        use TokenKind::*;
        assert_eq!(
            kinds("SELECT * FROM users;"),
            [Word, Whitespace, Star, Whitespace, Word, Whitespace, Word, Semicolon, Eof]
        );
    }

    #[test]
    fn test_string_literal() {
        let toks = lex("'hello world'");
        assert_eq!(toks[0].kind, TokenKind::StringLiteral);
        assert_eq!(toks[0].text.as_str(), "'hello world'");
    }

    #[test]
    fn test_escaped_string() {
        // A doubled quote stays inside the literal.
        let toks = lex("'it''s'");
        assert_eq!(toks[0].kind, TokenKind::StringLiteral);
        assert_eq!(toks[0].text.as_str(), "'it''s'");
    }

    #[test]
    fn test_line_comment() {
        let toks = lex("-- comment\nSELECT");
        assert_eq!(toks[0].kind, TokenKind::LineComment);
        assert_eq!(toks[0].text.as_str(), "-- comment");
        assert_eq!(toks[1].kind, TokenKind::Newline);
        assert_eq!(toks[2].kind, TokenKind::Word);
    }

    #[test]
    fn test_block_comment() {
        let toks = lex("/* multi\nline */");
        assert_eq!(toks[0].kind, TokenKind::BlockComment);
        assert_eq!(toks[0].text.as_str(), "/* multi\nline */");
    }

    #[test]
    fn test_nested_block_comment() {
        let toks = lex("/* outer /* inner */ end */");
        assert_eq!(toks[0].kind, TokenKind::BlockComment);
    }

    #[test]
    fn test_operators() {
        use TokenKind::*;
        assert_eq!(
            kinds("<= >= <> !="),
            [LtEq, Whitespace, GtEq, Whitespace, Neq, Whitespace, Neq, Eof]
        );
    }

    #[test]
    fn test_number_formats() {
        let toks = lex("42 3.14 .5 1e10 2.5E-3");
        let nums: Vec<&str> = toks
            .iter()
            .filter(|t| t.kind == TokenKind::NumberLiteral)
            .map(|t| t.text.as_str())
            .collect();
        assert_eq!(nums, ["42", "3.14", ".5", "1e10", "2.5E-3"]);
    }

    #[test]
    fn test_quoted_identifier() {
        let toks = lex("\"my column\"");
        assert_eq!(toks[0].kind, TokenKind::QuotedIdentifier);
        assert_eq!(toks[0].text.as_str(), "\"my column\"");
    }

    #[test]
    fn test_postgres_double_colon() {
        let toks = Lexer::new("col::int", LexerConfig::postgres())
            .tokenize()
            .unwrap();
        assert_eq!(toks[1].kind, TokenKind::ColonColon);
    }

    #[test]
    fn test_tsql_bracket_identifier() {
        let toks = Lexer::new("[my col]", LexerConfig::tsql())
            .tokenize()
            .unwrap();
        assert_eq!(toks[0].kind, TokenKind::QuotedIdentifier);
        assert_eq!(toks[0].text.as_str(), "[my col]");
    }

    #[test]
    fn test_newline_types() {
        use TokenKind::*;
        assert_eq!(kinds("a\nb\r\nc"), [Word, Newline, Word, Newline, Word, Eof]);
    }

    #[test]
    fn test_placeholder() {
        let toks = lex(":name ?");
        assert_eq!(toks[0].kind, TokenKind::Placeholder);
        assert_eq!(toks[0].text.as_str(), ":name");
        assert_eq!(toks[2].kind, TokenKind::Placeholder);
        assert_eq!(toks[2].text.as_str(), "?");
    }
}