1use rigsql_core::{Span, Token, TokenKind};
2use smol_str::SmolStr;
3use thiserror::Error;
4
/// Errors produced while tokenizing SQL source text.
///
/// All offsets are byte offsets into the input string, pointing at the
/// first byte of the construct that failed.
#[derive(Debug, Error)]
pub enum LexerError {
    /// A byte/character that does not begin any known token.
    #[error("Unexpected character '{ch}' at offset {offset}")]
    UnexpectedChar { ch: char, offset: u32 },
    /// A `'...'` string literal whose closing quote was never found.
    #[error("Unterminated string literal starting at offset {offset}")]
    UnterminatedString { offset: u32 },
    /// A `/* ... */` comment left open at end of input (nesting counted).
    #[error("Unterminated block comment starting at offset {offset}")]
    UnterminatedBlockComment { offset: u32 },
    /// A quoted (`"..."`, `` `...` ``) or bracketed (`[...]`) identifier
    /// left open at end of input.
    #[error("Unterminated quoted identifier starting at offset {offset}")]
    UnterminatedQuotedIdentifier { offset: u32 },
}
16
/// Dialect-specific lexing options.
///
/// The `Default` value (all flags `false`) is plain ANSI SQL; see the
/// `ansi`/`postgres`/`tsql` constructors for preset combinations.
#[derive(Debug, Clone, Default)]
pub struct LexerConfig {
    /// Recognize `::` as a single cast token (PostgreSQL-style).
    pub double_colon: bool,
    /// Recognize `[bracketed]` identifiers (T-SQL-style).
    pub bracket_identifiers: bool,
    /// Recognize `` `backtick` `` quoted identifiers (MySQL-style).
    pub backtick_identifiers: bool,
    /// Allow a second `@` at the start of an `@` token (T-SQL `@@var`).
    pub double_at: bool,
    /// Recognize `$$ ... $$` / `$tag$ ... $tag$` dollar-quoted strings
    /// (PostgreSQL-style).
    pub dollar_quoting: bool,
}
31
32impl LexerConfig {
33 pub fn ansi() -> Self {
34 Self::default()
35 }
36
37 pub fn postgres() -> Self {
38 Self {
39 double_colon: true,
40 dollar_quoting: true,
41 ..Self::default()
42 }
43 }
44
45 pub fn tsql() -> Self {
46 Self {
47 bracket_identifiers: true,
48 double_at: true,
49 ..Self::default()
50 }
51 }
52}
53
/// A single-pass SQL tokenizer over a borrowed source string.
///
/// Scans byte-wise (`bytes` mirrors `source`) and only decodes full UTF-8
/// characters on the word and error paths.
pub struct Lexer<'a> {
    /// The full input text; token texts are sliced from it.
    source: &'a str,
    /// `source.as_bytes()`, cached for byte-wise scanning.
    bytes: &'a [u8],
    /// Current byte offset into `bytes`/`source`.
    pos: usize,
    /// Dialect flags controlling optional token forms.
    config: LexerConfig,
}
60
impl<'a> Lexer<'a> {
    /// Creates a lexer over `source` using the given dialect options.
    pub fn new(source: &'a str, config: LexerConfig) -> Self {
        Self {
            source,
            bytes: source.as_bytes(),
            pos: 0,
            config,
        }
    }

    /// Tokenizes the whole input.
    ///
    /// Returns every token in order, always ending with a `TokenKind::Eof`
    /// token, or the first `LexerError` encountered.
    pub fn tokenize(&mut self) -> Result<Vec<Token>, LexerError> {
        let mut tokens = Vec::new();
        loop {
            let token = self.next_token()?;
            let is_eof = token.kind == TokenKind::Eof;
            tokens.push(token);
            if is_eof {
                break;
            }
        }
        Ok(tokens)
    }

    /// Scans and returns the next token starting at `self.pos`.
    ///
    /// Dispatches on the first byte of the remaining input.
    fn next_token(&mut self) -> Result<Token, LexerError> {
        // Past the end: emit a zero-width Eof token.
        if self.pos >= self.bytes.len() {
            return Ok(Token::new(
                TokenKind::Eof,
                Span::new(self.pos as u32, self.pos as u32),
                "",
            ));
        }

        let start = self.pos;
        let ch = self.bytes[self.pos];

        match ch {
            // Newlines are their own tokens; `\r\n` collapses into one.
            b'\n' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Newline, start))
            }
            b'\r' => {
                self.pos += 1;
                if self.peek() == Some(b'\n') {
                    self.pos += 1;
                }
                Ok(self.make_token(TokenKind::Newline, start))
            }

            // A run of spaces/tabs becomes a single Whitespace token.
            b' ' | b'\t' => {
                self.pos += 1;
                while let Some(b) = self.peek() {
                    if b == b' ' || b == b'\t' {
                        self.pos += 1;
                    } else {
                        break;
                    }
                }
                Ok(self.make_token(TokenKind::Whitespace, start))
            }

            // `--` line comment: runs to (but does not consume) the newline.
            b'-' if self.peek_at(1) == Some(b'-') => {
                self.pos += 2;
                while let Some(b) = self.peek() {
                    if b == b'\n' || b == b'\r' {
                        break;
                    }
                    self.pos += 1;
                }
                Ok(self.make_token(TokenKind::LineComment, start))
            }

            // `/* ... */` block comment; nesting is tracked via `depth`.
            b'/' if self.peek_at(1) == Some(b'*') => {
                self.pos += 2;
                let mut depth = 1u32;
                while self.pos < self.bytes.len() && depth > 0 {
                    if self.bytes[self.pos] == b'/' && self.peek_at(1) == Some(b'*') {
                        depth += 1;
                        self.pos += 2;
                    } else if self.bytes[self.pos] == b'*' && self.peek_at(1) == Some(b'/') {
                        depth -= 1;
                        self.pos += 2;
                    } else {
                        self.pos += 1;
                    }
                }
                if depth > 0 {
                    return Err(LexerError::UnterminatedBlockComment {
                        offset: start as u32,
                    });
                }
                Ok(self.make_token(TokenKind::BlockComment, start))
            }

            b'\'' => self.lex_string_literal(start),

            b'"' => self.lex_quoted_identifier(start, b'"'),

            // `[ident]` identifiers only when the dialect enables them;
            // otherwise `[` / `]` are plain bracket tokens.
            b'[' if self.config.bracket_identifiers => self.lex_bracket_identifier(start),

            b'[' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::LBracket, start))
            }
            b']' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::RBracket, start))
            }

            b'`' if self.config.backtick_identifiers => self.lex_quoted_identifier(start, b'`'),

            b'0'..=b'9' => self.lex_number(start),

            // A leading dot starts a number only when a digit follows (".5").
            b'.' if self.peek_at(1).is_some_and(|b| b.is_ascii_digit()) => self.lex_number(start),

            b'.' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Dot, start))
            }
            b',' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Comma, start))
            }
            b';' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Semicolon, start))
            }
            b'(' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::LParen, start))
            }
            b')' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::RParen, start))
            }
            b'*' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Star, start))
            }
            b'+' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Plus, start))
            }
            b'-' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Minus, start))
            }
            b'/' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Slash, start))
            }
            b'%' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Percent, start))
            }
            b'=' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Eq, start))
            }

            // `<`, `<=`, `<>` (SQL not-equal).
            b'<' => {
                self.pos += 1;
                match self.peek() {
                    Some(b'=') => {
                        self.pos += 1;
                        Ok(self.make_token(TokenKind::LtEq, start))
                    }
                    Some(b'>') => {
                        self.pos += 1;
                        Ok(self.make_token(TokenKind::Neq, start))
                    }
                    _ => Ok(self.make_token(TokenKind::Lt, start)),
                }
            }

            // `>`, `>=`.
            b'>' => {
                self.pos += 1;
                if self.peek() == Some(b'=') {
                    self.pos += 1;
                    Ok(self.make_token(TokenKind::GtEq, start))
                } else {
                    Ok(self.make_token(TokenKind::Gt, start))
                }
            }

            // `!=`; a lone `!` falls through to the error arm below.
            b'!' if self.peek_at(1) == Some(b'=') => {
                self.pos += 2;
                Ok(self.make_token(TokenKind::Neq, start))
            }

            // `||` string concatenation.
            b'|' if self.peek_at(1) == Some(b'|') => {
                self.pos += 2;
                Ok(self.make_token(TokenKind::Concat, start))
            }

            // `::` cast operator, only when the dialect enables it.
            b':' if self.config.double_colon && self.peek_at(1) == Some(b':') => {
                self.pos += 2;
                Ok(self.make_token(TokenKind::ColonColon, start))
            }

            // `:name` named placeholder, or a bare `:`.
            b':' => {
                self.pos += 1;
                if self.peek().is_some_and(|b| b.is_ascii_alphanumeric() || b == b'_') {
                    while self.peek().is_some_and(|b| b.is_ascii_alphanumeric() || b == b'_') {
                        self.pos += 1;
                    }
                    Ok(self.make_token(TokenKind::Placeholder, start))
                } else {
                    Ok(self.make_token(TokenKind::Colon, start))
                }
            }

            // `@name` (and `@@name` with `double_at`): the whole variable,
            // name included, is emitted as one AtSign token.
            b'@' => {
                self.pos += 1;
                if self.config.double_at && self.peek() == Some(b'@') {
                    self.pos += 1;
                }
                self.eat_word_chars();
                Ok(self.make_token(TokenKind::AtSign, start))
            }

            // `?` positional placeholder.
            b'?' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Placeholder, start))
            }

            // `$`: dollar-quoted string (if enabled) or `$n` placeholder.
            b'$' => {
                if self.config.dollar_quoting {
                    self.lex_dollar_quote_or_param(start)
                } else {
                    self.pos += 1;
                    while self.peek().is_some_and(|b| b.is_ascii_digit()) {
                        self.pos += 1;
                    }
                    Ok(self.make_token(TokenKind::Placeholder, start))
                }
            }

            // Bare words: ASCII word-start bytes plus any non-ASCII lead
            // byte. NOTE(review): a non-ASCII *first* character is accepted
            // without checking it is alphanumeric (eat_word_chars does check
            // continuation chars) — confirm this leniency is intended.
            b if is_word_start(b) || b >= 0x80 => {
                if b >= 0x80 {
                    // Advance over the whole UTF-8 character, not one byte.
                    let s = &self.source[self.pos..];
                    let first_char = s.chars().next().unwrap();
                    self.pos += first_char.len_utf8();
                } else {
                    self.pos += 1;
                }
                self.eat_word_chars();
                let word = &self.source[start..self.pos];
                // `N'...'` national string literal: the resulting token text
                // includes the `N` prefix.
                if word.eq_ignore_ascii_case("N") && self.peek() == Some(b'\'') {
                    return self.lex_string_literal(start);
                }
                Ok(self.make_token(TokenKind::Word, start))
            }

            // Anything else is an error; decode the full char for reporting.
            _ => {
                let ch = self.source[self.pos..].chars().next().unwrap();
                Err(LexerError::UnexpectedChar {
                    ch,
                    offset: start as u32,
                })
            }
        }
    }

    /// Lexes a `'...'` string literal. `start` is the offset of the opening
    /// quote (or of an `N` prefix); `self.pos` must be at the quote itself.
    /// A doubled `''` inside the literal is an escaped quote.
    fn lex_string_literal(&mut self, start: usize) -> Result<Token, LexerError> {
        self.pos += 1; // skip the opening quote
        loop {
            match self.peek() {
                None => {
                    return Err(LexerError::UnterminatedString {
                        offset: start as u32,
                    })
                }
                Some(b'\'') => {
                    self.pos += 1;
                    // `''` is an escape, not a terminator: keep scanning.
                    if self.peek() == Some(b'\'') {
                        self.pos += 1;
                        continue;
                    }
                    return Ok(self.make_token(TokenKind::StringLiteral, start));
                }
                Some(_) => self.pos += 1,
            }
        }
    }

    /// Lexes an identifier delimited by `quote` (`"` or `` ` ``). A doubled
    /// delimiter inside the identifier is an escape.
    fn lex_quoted_identifier(&mut self, start: usize, quote: u8) -> Result<Token, LexerError> {
        self.pos += 1; // skip the opening delimiter
        loop {
            match self.peek() {
                None => {
                    return Err(LexerError::UnterminatedQuotedIdentifier {
                        offset: start as u32,
                    })
                }
                Some(b) if b == quote => {
                    self.pos += 1;
                    // Doubled delimiter: escaped, keep scanning.
                    if self.peek() == Some(quote) {
                        self.pos += 1;
                        continue;
                    }
                    return Ok(self.make_token(TokenKind::QuotedIdentifier, start));
                }
                Some(_) => self.pos += 1,
            }
        }
    }

    /// Lexes a `[bracketed]` identifier (T-SQL style).
    /// NOTE(review): T-SQL allows `]]` as an escaped closing bracket;
    /// that escape is not handled here — confirm whether it is needed.
    fn lex_bracket_identifier(&mut self, start: usize) -> Result<Token, LexerError> {
        self.pos += 1; // skip `[`
        loop {
            match self.peek() {
                None => {
                    return Err(LexerError::UnterminatedQuotedIdentifier {
                        offset: start as u32,
                    })
                }
                Some(b']') => {
                    self.pos += 1;
                    return Ok(self.make_token(TokenKind::QuotedIdentifier, start));
                }
                Some(_) => self.pos += 1,
            }
        }
    }

    /// Lexes a numeric literal: integer, decimal (`3.14`), leading-dot
    /// decimal (`.5`), and an optional signed `e`/`E` exponent.
    fn lex_number(&mut self, start: usize) -> Result<Token, LexerError> {
        // Integer part (empty when the literal starts with '.').
        while self.peek().is_some_and(|b| b.is_ascii_digit()) {
            self.pos += 1;
        }
        if self.peek() == Some(b'.') && self.peek_at(1).is_some_and(|b| b.is_ascii_digit()) {
            // Fractional part: '.' followed by at least one digit.
            self.pos += 1;
            while self.peek().is_some_and(|b| b.is_ascii_digit()) {
                self.pos += 1;
            }
        } else if self.bytes[start] == b'.' {
            // Defensive leading-dot handling; the branch above already
            // covers the inputs the dispatcher sends ('.' + digit).
            self.pos += 1;
            while self.peek().is_some_and(|b| b.is_ascii_digit()) {
                self.pos += 1;
            }
        }
        // Optional exponent, e.g. `1e10`, `2.5E-3`.
        if self.peek().is_some_and(|b| b == b'e' || b == b'E') {
            self.pos += 1;
            if self.peek().is_some_and(|b| b == b'+' || b == b'-') {
                self.pos += 1;
            }
            while self.peek().is_some_and(|b| b.is_ascii_digit()) {
                self.pos += 1;
            }
        }
        Ok(self.make_token(TokenKind::NumberLiteral, start))
    }

    /// Disambiguates `$`: a `$$...$$` / `$tag$...$tag$` dollar-quoted
    /// string, or a `$n` positional placeholder. `self.pos` is at the `$`.
    fn lex_dollar_quote_or_param(&mut self, start: usize) -> Result<Token, LexerError> {
        let after_dollar = self.pos + 1;
        if after_dollar < self.bytes.len() {
            // Untagged form `$$ ... $$`.
            if self.bytes[after_dollar] == b'$' {
                self.pos += 2; // past the opening `$$`
                let tag = "";
                return self.lex_dollar_body(start, tag);
            }
            // Tagged form: `$tag$` where the tag is [A-Za-z_][A-Za-z0-9_]*.
            if self.bytes[after_dollar].is_ascii_alphabetic() || self.bytes[after_dollar] == b'_' {
                let tag_start = after_dollar;
                let mut p = after_dollar;
                while p < self.bytes.len()
                    && (self.bytes[p].is_ascii_alphanumeric() || self.bytes[p] == b'_')
                {
                    p += 1;
                }
                // Only a full `$tag$` opener counts; otherwise fall through.
                if p < self.bytes.len() && self.bytes[p] == b'$' {
                    let tag = &self.source[tag_start..p];
                    self.pos = p + 1; // past `$tag$`
                    return self.lex_dollar_body(start, tag);
                }
            }
        }

        // Not a dollar quote: `$` plus optional digits is a placeholder.
        self.pos += 1;
        while self.peek().is_some_and(|b| b.is_ascii_digit()) {
            self.pos += 1;
        }
        Ok(self.make_token(TokenKind::Placeholder, start))
    }

    /// Scans a dollar-quoted body until the matching `$tag$` terminator.
    /// `self.pos` must be just past the opening delimiter.
    fn lex_dollar_body(&mut self, start: usize, tag: &str) -> Result<Token, LexerError> {
        let end_tag = format!("${tag}$");
        let end_bytes = end_tag.as_bytes();
        while self.pos + end_bytes.len() <= self.bytes.len() {
            if &self.bytes[self.pos..self.pos + end_bytes.len()] == end_bytes {
                self.pos += end_bytes.len();
                return Ok(self.make_token(TokenKind::StringLiteral, start));
            }
            self.pos += 1;
        }
        Err(LexerError::UnterminatedString {
            offset: start as u32,
        })
    }

    /// Byte at the current position, or `None` at end of input.
    fn peek(&self) -> Option<u8> {
        self.bytes.get(self.pos).copied()
    }

    /// Byte at `self.pos + offset`, or `None` past end of input.
    fn peek_at(&self, offset: usize) -> Option<u8> {
        self.bytes.get(self.pos + offset).copied()
    }

    /// Advances past word-continuation characters: ASCII word bytes via
    /// `is_word_continue`, plus non-ASCII chars that are alphanumeric or
    /// `_` (decoded as full UTF-8 characters).
    fn eat_word_chars(&mut self) {
        while self.pos < self.bytes.len() {
            let b = self.bytes[self.pos];
            if is_word_continue(b) {
                self.pos += 1;
            } else if b >= 0x80 {
                let remaining = &self.source[self.pos..];
                if let Some(c) = remaining.chars().next() {
                    if c.is_alphanumeric() || c == '_' {
                        self.pos += c.len_utf8();
                    } else {
                        break;
                    }
                } else {
                    break;
                }
            } else {
                break;
            }
        }
    }

    /// Builds a token of `kind` spanning `start..self.pos`, with its text
    /// sliced directly from the source.
    fn make_token(&self, kind: TokenKind, start: usize) -> Token {
        let text = &self.source[start..self.pos];
        Token::new(
            kind,
            Span::new(start as u32, self.pos as u32),
            SmolStr::new(text),
        )
    }
}
545
/// True if `b` may begin a bare word token: an ASCII letter, `_`, or `#`.
fn is_word_start(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'_' | b'#')
}
549
/// True if `b` may continue a bare word token: ASCII alphanumeric, `_`, or `#`.
fn is_word_continue(b: u8) -> bool {
    matches!(b, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'_' | b'#')
}
553
#[cfg(test)]
mod tests {
    use super::*;

    /// Lexes `input` under the plain ANSI configuration, panicking on error.
    fn lex(input: &str) -> Vec<Token> {
        Lexer::new(input, LexerConfig::ansi()).tokenize().unwrap()
    }

    /// Lexes `input` and keeps only the token kinds.
    fn kinds(input: &str) -> Vec<TokenKind> {
        let mut out = Vec::new();
        for tok in lex(input) {
            out.push(tok.kind);
        }
        out
    }

    #[test]
    fn test_simple_select() {
        let toks = lex("SELECT 1");
        assert_eq!(toks.len(), 4);
        assert_eq!(toks[0].kind, TokenKind::Word);
        assert_eq!(toks[0].text.as_str(), "SELECT");
        assert_eq!(toks[1].kind, TokenKind::Whitespace);
        assert_eq!(toks[2].kind, TokenKind::NumberLiteral);
        assert_eq!(toks[2].text.as_str(), "1");
        assert_eq!(toks[3].kind, TokenKind::Eof);
    }

    #[test]
    fn test_select_star() {
        let expected = vec![
            TokenKind::Word,
            TokenKind::Whitespace,
            TokenKind::Star,
            TokenKind::Whitespace,
            TokenKind::Word,
            TokenKind::Whitespace,
            TokenKind::Word,
            TokenKind::Semicolon,
            TokenKind::Eof,
        ];
        assert_eq!(kinds("SELECT * FROM users;"), expected);
    }

    #[test]
    fn test_string_literal() {
        let toks = lex("'hello world'");
        assert_eq!(toks[0].kind, TokenKind::StringLiteral);
        assert_eq!(toks[0].text.as_str(), "'hello world'");
    }

    #[test]
    fn test_escaped_string() {
        // A doubled quote inside the literal is an escape, not a terminator.
        let toks = lex("'it''s'");
        assert_eq!(toks[0].kind, TokenKind::StringLiteral);
        assert_eq!(toks[0].text.as_str(), "'it''s'");
    }

    #[test]
    fn test_line_comment() {
        let toks = lex("-- comment\nSELECT");
        assert_eq!(toks[0].kind, TokenKind::LineComment);
        assert_eq!(toks[0].text.as_str(), "-- comment");
        // The newline is not part of the comment; it is its own token.
        assert_eq!(toks[1].kind, TokenKind::Newline);
        assert_eq!(toks[2].kind, TokenKind::Word);
    }

    #[test]
    fn test_block_comment() {
        let toks = lex("/* multi\nline */");
        assert_eq!(toks[0].kind, TokenKind::BlockComment);
        assert_eq!(toks[0].text.as_str(), "/* multi\nline */");
    }

    #[test]
    fn test_nested_block_comment() {
        let toks = lex("/* outer /* inner */ end */");
        assert_eq!(toks[0].kind, TokenKind::BlockComment);
    }

    #[test]
    fn test_operators() {
        let expected = vec![
            TokenKind::LtEq,
            TokenKind::Whitespace,
            TokenKind::GtEq,
            TokenKind::Whitespace,
            TokenKind::Neq,
            TokenKind::Whitespace,
            TokenKind::Neq,
            TokenKind::Eof,
        ];
        assert_eq!(kinds("<= >= <> !="), expected);
    }

    #[test]
    fn test_number_formats() {
        let mut nums = Vec::new();
        for tok in lex("42 3.14 .5 1e10 2.5E-3") {
            if tok.kind == TokenKind::NumberLiteral {
                nums.push(tok.text.as_str().to_owned());
            }
        }
        assert_eq!(nums, vec!["42", "3.14", ".5", "1e10", "2.5E-3"]);
    }

    #[test]
    fn test_quoted_identifier() {
        let toks = lex("\"my column\"");
        assert_eq!(toks[0].kind, TokenKind::QuotedIdentifier);
        assert_eq!(toks[0].text.as_str(), "\"my column\"");
    }

    #[test]
    fn test_postgres_double_colon() {
        let toks = Lexer::new("col::int", LexerConfig::postgres())
            .tokenize()
            .unwrap();
        assert_eq!(toks[1].kind, TokenKind::ColonColon);
    }

    #[test]
    fn test_tsql_bracket_identifier() {
        let toks = Lexer::new("[my col]", LexerConfig::tsql())
            .tokenize()
            .unwrap();
        assert_eq!(toks[0].kind, TokenKind::QuotedIdentifier);
        assert_eq!(toks[0].text.as_str(), "[my col]");
    }

    #[test]
    fn test_newline_types() {
        // Both `\n` and `\r\n` produce a single Newline token.
        let expected = vec![
            TokenKind::Word,
            TokenKind::Newline,
            TokenKind::Word,
            TokenKind::Newline,
            TokenKind::Word,
            TokenKind::Eof,
        ];
        assert_eq!(kinds("a\nb\r\nc"), expected);
    }

    #[test]
    fn test_placeholder() {
        let toks = lex(":name ?");
        assert_eq!(toks[0].kind, TokenKind::Placeholder);
        assert_eq!(toks[0].text.as_str(), ":name");
        assert_eq!(toks[2].kind, TokenKind::Placeholder);
        assert_eq!(toks[2].text.as_str(), "?");
    }
}