//! SQL tokenizer for `rigsql_lexer` (lexer.rs): a dialect-aware lexer that
//! splits SQL text into trivia-preserving tokens (whitespace, comments, and
//! newlines are emitted as tokens rather than skipped).
1use rigsql_core::{Span, Token, TokenKind};
2use smol_str::SmolStr;
3use thiserror::Error;
4
/// Errors produced while tokenizing.
///
/// All `offset` fields are byte offsets into the original source string,
/// matching the `Span` coordinates used by the emitted tokens.
#[derive(Debug, Error)]
pub enum LexerError {
    /// A byte that starts no known token in the active dialect
    /// (e.g. a lone `|`, or `[` when bracket identifiers are disabled).
    #[error("Unexpected character '{ch}' at offset {offset}")]
    UnexpectedChar { ch: char, offset: u32 },
    /// A `'...` literal (or dollar-quoted string) with no closing delimiter.
    #[error("Unterminated string literal starting at offset {offset}")]
    UnterminatedString { offset: u32 },
    /// A `/* ...` comment left open at end of input (nesting counted).
    #[error("Unterminated block comment starting at offset {offset}")]
    UnterminatedBlockComment { offset: u32 },
    /// A `"..."`, backtick, or `[...]` identifier with no closing delimiter.
    #[error("Unterminated quoted identifier starting at offset {offset}")]
    UnterminatedQuotedIdentifier { offset: u32 },
}
16
/// Dialect-specific lexer configuration.
///
/// Every switch defaults to `false` (via the `Default` derive), which yields
/// plain ANSI behavior; see [`LexerConfig::postgres`] / [`LexerConfig::tsql`]
/// for dialect presets.
#[derive(Debug, Clone, Default)]
pub struct LexerConfig {
    /// Enable `::` as cast operator (PostgreSQL).
    pub double_colon: bool,
    /// Enable `[identifier]` quoting (SQL Server).
    pub bracket_identifiers: bool,
    /// Enable backtick identifier quoting (MySQL).
    pub backtick_identifiers: bool,
    /// Enable `@@variable` (SQL Server).
    pub double_at: bool,
    /// Enable dollar-quoted strings `$$...$$` (PostgreSQL).
    pub dollar_quoting: bool,
}
31
32impl LexerConfig {
33    pub fn ansi() -> Self {
34        Self::default()
35    }
36
37    pub fn postgres() -> Self {
38        Self {
39            double_colon: true,
40            dollar_quoting: true,
41            ..Self::default()
42        }
43    }
44
45    pub fn tsql() -> Self {
46        Self {
47            bracket_identifiers: true,
48            double_at: true,
49            ..Self::default()
50        }
51    }
52}
53
/// Streaming tokenizer over a borrowed SQL source string.
pub struct Lexer<'a> {
    /// Full source text; used for UTF-8-aware slicing and token text.
    source: &'a str,
    /// Byte view of `source`; the primary scanning representation.
    bytes: &'a [u8],
    /// Current byte offset; always kept on a UTF-8 character boundary
    /// (multi-byte characters are consumed whole).
    pos: usize,
    /// Dialect switches controlling optional token forms.
    config: LexerConfig,
}
60
61impl<'a> Lexer<'a> {
62    pub fn new(source: &'a str, config: LexerConfig) -> Self {
63        Self {
64            source,
65            bytes: source.as_bytes(),
66            pos: 0,
67            config,
68        }
69    }
70
71    /// Tokenize the entire source into a Vec of tokens.
72    pub fn tokenize(&mut self) -> Result<Vec<Token>, LexerError> {
73        let mut tokens = Vec::new();
74        loop {
75            let token = self.next_token()?;
76            let is_eof = token.kind == TokenKind::Eof;
77            tokens.push(token);
78            if is_eof {
79                break;
80            }
81        }
82        Ok(tokens)
83    }
84
    /// Scan and return the single next token starting at `self.pos`.
    ///
    /// Returns a zero-width `Eof` token at end of input (repeated calls keep
    /// returning `Eof`). NOTE: arm order in the `match` below is load-bearing —
    /// guarded arms (`--`, `/*`, `::`, `!=`, `||`, dialect quoting) must stay
    /// ahead of their single-byte fallback arms.
    fn next_token(&mut self) -> Result<Token, LexerError> {
        if self.pos >= self.bytes.len() {
            return Ok(Token::new(
                TokenKind::Eof,
                Span::new(self.pos as u32, self.pos as u32),
                "",
            ));
        }

        let start = self.pos;
        let ch = self.bytes[self.pos];

        match ch {
            // Newline
            b'\n' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Newline, start))
            }
            // `\r` or `\r\n` — either form becomes one Newline token.
            b'\r' => {
                self.pos += 1;
                if self.peek() == Some(b'\n') {
                    self.pos += 1;
                }
                Ok(self.make_token(TokenKind::Newline, start))
            }

            // Whitespace (not newline); a run collapses into one token.
            b' ' | b'\t' => {
                self.pos += 1;
                while let Some(b) = self.peek() {
                    if b == b' ' || b == b'\t' {
                        self.pos += 1;
                    } else {
                        break;
                    }
                }
                Ok(self.make_token(TokenKind::Whitespace, start))
            }

            // Line comment: -- ... (terminator newline is NOT consumed; it
            // becomes its own Newline token).
            b'-' if self.peek_at(1) == Some(b'-') => {
                self.pos += 2;
                while let Some(b) = self.peek() {
                    if b == b'\n' || b == b'\r' {
                        break;
                    }
                    self.pos += 1;
                }
                Ok(self.make_token(TokenKind::LineComment, start))
            }

            // Block comment: /* ... */, with nesting tracked via `depth`.
            b'/' if self.peek_at(1) == Some(b'*') => {
                self.pos += 2;
                let mut depth = 1u32;
                while self.pos < self.bytes.len() && depth > 0 {
                    if self.bytes[self.pos] == b'/' && self.peek_at(1) == Some(b'*') {
                        depth += 1;
                        self.pos += 2;
                    } else if self.bytes[self.pos] == b'*' && self.peek_at(1) == Some(b'/') {
                        depth -= 1;
                        self.pos += 2;
                    } else {
                        self.pos += 1;
                    }
                }
                if depth > 0 {
                    return Err(LexerError::UnterminatedBlockComment {
                        offset: start as u32,
                    });
                }
                Ok(self.make_token(TokenKind::BlockComment, start))
            }

            // String literal: 'hello'
            b'\'' => self.lex_string_literal(start),

            // Double-quoted identifier: "name"
            b'"' => self.lex_quoted_identifier(start, b'"'),

            // Bracket-quoted identifier: [name] (SQL Server). Without the
            // dialect flag, `[` falls through to UnexpectedChar below.
            b'[' if self.config.bracket_identifiers => self.lex_bracket_identifier(start),

            // Backtick identifier: `name` (MySQL)
            b'`' if self.config.backtick_identifiers => self.lex_quoted_identifier(start, b'`'),

            // Numbers
            b'0'..=b'9' => self.lex_number(start),

            // Dot starting a `.123` numeric literal; the dot itself is left
            // for lex_number to consume.
            b'.' if self.peek_at(1).is_some_and(|b| b.is_ascii_digit()) => self.lex_number(start),

            // Single-character operators & punctuation
            b'.' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Dot, start))
            }
            b',' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Comma, start))
            }
            b';' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Semicolon, start))
            }
            b'(' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::LParen, start))
            }
            b')' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::RParen, start))
            }
            b'*' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Star, start))
            }
            b'+' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Plus, start))
            }
            b'-' => {
                // Single minus (-- already handled above)
                self.pos += 1;
                Ok(self.make_token(TokenKind::Minus, start))
            }
            b'/' => {
                // Single slash (/* already handled above)
                self.pos += 1;
                Ok(self.make_token(TokenKind::Slash, start))
            }
            b'%' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Percent, start))
            }
            b'=' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Eq, start))
            }

            // < <= <> operators
            b'<' => {
                self.pos += 1;
                match self.peek() {
                    Some(b'=') => {
                        self.pos += 1;
                        Ok(self.make_token(TokenKind::LtEq, start))
                    }
                    Some(b'>') => {
                        self.pos += 1;
                        Ok(self.make_token(TokenKind::Neq, start))
                    }
                    _ => Ok(self.make_token(TokenKind::Lt, start)),
                }
            }

            // > >= operators
            b'>' => {
                self.pos += 1;
                if self.peek() == Some(b'=') {
                    self.pos += 1;
                    Ok(self.make_token(TokenKind::GtEq, start))
                } else {
                    Ok(self.make_token(TokenKind::Gt, start))
                }
            }

            // != operator (a lone `!` is an UnexpectedChar)
            b'!' if self.peek_at(1) == Some(b'=') => {
                self.pos += 2;
                Ok(self.make_token(TokenKind::Neq, start))
            }

            // || concat operator (a lone `|` is an UnexpectedChar)
            b'|' if self.peek_at(1) == Some(b'|') => {
                self.pos += 2;
                Ok(self.make_token(TokenKind::Concat, start))
            }

            // :: cast (PostgreSQL)
            b':' if self.config.double_colon && self.peek_at(1) == Some(b':') => {
                self.pos += 2;
                Ok(self.make_token(TokenKind::ColonColon, start))
            }

            // `:name` named parameter, or bare `:` when no name follows.
            // NOTE(review): the name here is ASCII-only, unlike `@`-variables
            // which accept non-ASCII — confirm whether that asymmetry is
            // intentional.
            b':' => {
                self.pos += 1;
                if self
                    .peek()
                    .is_some_and(|b| b.is_ascii_alphanumeric() || b == b'_')
                {
                    while self
                        .peek()
                        .is_some_and(|b| b.is_ascii_alphanumeric() || b == b'_')
                    {
                        self.pos += 1;
                    }
                    Ok(self.make_token(TokenKind::Placeholder, start))
                } else {
                    Ok(self.make_token(TokenKind::Colon, start))
                }
            }

            // @ or @@ (SQL Server). The variable name is folded into the
            // AtSign token's text rather than emitted separately.
            b'@' => {
                self.pos += 1;
                if self.config.double_at && self.peek() == Some(b'@') {
                    self.pos += 1;
                }
                // Read variable name (including non-ASCII chars like Japanese)
                self.eat_word_chars();
                Ok(self.make_token(TokenKind::AtSign, start))
            }

            // ? positional parameter
            b'?' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Placeholder, start))
            }

            // $ positional parameter ($1) or dollar-quoting (PostgreSQL).
            // Without dollar_quoting, a bare `$` with no digits still yields
            // a one-character Placeholder token.
            b'$' => {
                if self.config.dollar_quoting {
                    self.lex_dollar_quote_or_param(start)
                } else {
                    self.pos += 1;
                    // $1, $2 etc
                    while self.peek().is_some_and(|b| b.is_ascii_digit()) {
                        self.pos += 1;
                    }
                    Ok(self.make_token(TokenKind::Placeholder, start))
                }
            }

            // Word: keyword or identifier (including non-ASCII like Japanese).
            // Any byte >= 0x80 is a UTF-8 lead byte here, because `pos` is
            // always on a character boundary; consume the full first char.
            b if is_word_start(b) || b >= 0x80 => {
                if b >= 0x80 {
                    let s = &self.source[self.pos..];
                    let first_char = s.chars().next().unwrap();
                    self.pos += first_char.len_utf8();
                } else {
                    self.pos += 1;
                }
                self.eat_word_chars();
                Ok(self.make_token(TokenKind::Word, start))
            }

            // Anything else: report the full character (not just the byte).
            // `pos` is not advanced, so tokenize() stops here with the error.
            _ => {
                let ch = self.source[self.pos..].chars().next().unwrap();
                Err(LexerError::UnexpectedChar {
                    ch,
                    offset: start as u32,
                })
            }
        }
    }
342
343    fn lex_string_literal(&mut self, start: usize) -> Result<Token, LexerError> {
344        self.pos += 1; // skip opening quote
345        loop {
346            match self.peek() {
347                None => {
348                    return Err(LexerError::UnterminatedString {
349                        offset: start as u32,
350                    })
351                }
352                Some(b'\'') => {
353                    self.pos += 1;
354                    // Escaped quote ''
355                    if self.peek() == Some(b'\'') {
356                        self.pos += 1;
357                        continue;
358                    }
359                    return Ok(self.make_token(TokenKind::StringLiteral, start));
360                }
361                Some(_) => self.pos += 1,
362            }
363        }
364    }
365
366    fn lex_quoted_identifier(&mut self, start: usize, quote: u8) -> Result<Token, LexerError> {
367        self.pos += 1; // skip opening quote
368        loop {
369            match self.peek() {
370                None => {
371                    return Err(LexerError::UnterminatedQuotedIdentifier {
372                        offset: start as u32,
373                    })
374                }
375                Some(b) if b == quote => {
376                    self.pos += 1;
377                    // Escaped quote
378                    if self.peek() == Some(quote) {
379                        self.pos += 1;
380                        continue;
381                    }
382                    return Ok(self.make_token(TokenKind::QuotedIdentifier, start));
383                }
384                Some(_) => self.pos += 1,
385            }
386        }
387    }
388
389    fn lex_bracket_identifier(&mut self, start: usize) -> Result<Token, LexerError> {
390        self.pos += 1; // skip [
391        loop {
392            match self.peek() {
393                None => {
394                    return Err(LexerError::UnterminatedQuotedIdentifier {
395                        offset: start as u32,
396                    })
397                }
398                Some(b']') => {
399                    self.pos += 1;
400                    return Ok(self.make_token(TokenKind::QuotedIdentifier, start));
401                }
402                Some(_) => self.pos += 1,
403            }
404        }
405    }
406
407    fn lex_number(&mut self, start: usize) -> Result<Token, LexerError> {
408        // Integer part
409        while self.peek().is_some_and(|b| b.is_ascii_digit()) {
410            self.pos += 1;
411        }
412        // Decimal part
413        if self.peek() == Some(b'.') && self.peek_at(1).is_some_and(|b| b.is_ascii_digit()) {
414            self.pos += 1; // skip .
415            while self.peek().is_some_and(|b| b.is_ascii_digit()) {
416                self.pos += 1;
417            }
418        } else if self.bytes[start] == b'.' {
419            // .123 form — dot already consumed before we got here
420            self.pos += 1; // skip .
421            while self.peek().is_some_and(|b| b.is_ascii_digit()) {
422                self.pos += 1;
423            }
424        }
425        // Exponent part
426        if self.peek().is_some_and(|b| b == b'e' || b == b'E') {
427            self.pos += 1;
428            if self.peek().is_some_and(|b| b == b'+' || b == b'-') {
429                self.pos += 1;
430            }
431            while self.peek().is_some_and(|b| b.is_ascii_digit()) {
432                self.pos += 1;
433            }
434        }
435        Ok(self.make_token(TokenKind::NumberLiteral, start))
436    }
437
438    fn lex_dollar_quote_or_param(&mut self, start: usize) -> Result<Token, LexerError> {
439        // Check if it's a dollar-quoted string: $tag$...$tag$ or $$...$$
440        let after_dollar = self.pos + 1;
441        if after_dollar < self.bytes.len() {
442            // $$ or $tag$
443            if self.bytes[after_dollar] == b'$' {
444                // $$...$$ form
445                self.pos += 2; // skip $$
446                let tag = "";
447                return self.lex_dollar_body(start, tag);
448            }
449            if self.bytes[after_dollar].is_ascii_alphabetic() || self.bytes[after_dollar] == b'_' {
450                // $tag$...$tag$ form
451                let tag_start = after_dollar;
452                let mut p = after_dollar;
453                while p < self.bytes.len()
454                    && (self.bytes[p].is_ascii_alphanumeric() || self.bytes[p] == b'_')
455                {
456                    p += 1;
457                }
458                if p < self.bytes.len() && self.bytes[p] == b'$' {
459                    let tag = &self.source[tag_start..p];
460                    self.pos = p + 1; // skip closing $
461                    return self.lex_dollar_body(start, tag);
462                }
463            }
464        }
465
466        // Plain parameter: $1, $2, etc.
467        self.pos += 1;
468        while self.peek().is_some_and(|b| b.is_ascii_digit()) {
469            self.pos += 1;
470        }
471        Ok(self.make_token(TokenKind::Placeholder, start))
472    }
473
474    fn lex_dollar_body(&mut self, start: usize, tag: &str) -> Result<Token, LexerError> {
475        let end_tag = format!("${tag}$");
476        let end_bytes = end_tag.as_bytes();
477        while self.pos + end_bytes.len() <= self.bytes.len() {
478            if &self.bytes[self.pos..self.pos + end_bytes.len()] == end_bytes {
479                self.pos += end_bytes.len();
480                return Ok(self.make_token(TokenKind::StringLiteral, start));
481            }
482            self.pos += 1;
483        }
484        // If we hit EOF without closing, treat as unterminated string
485        Err(LexerError::UnterminatedString {
486            offset: start as u32,
487        })
488    }
489
490    fn peek(&self) -> Option<u8> {
491        self.bytes.get(self.pos).copied()
492    }
493
    /// Byte `offset` positions ahead of the current position, if in bounds
    /// (does not advance).
    fn peek_at(&self, offset: usize) -> Option<u8> {
        self.bytes.get(self.pos + offset).copied()
    }
497
498    /// Advance past word-like characters (ASCII alphanumeric, `_`, and non-ASCII alphanumeric).
499    fn eat_word_chars(&mut self) {
500        while self.pos < self.bytes.len() {
501            let b = self.bytes[self.pos];
502            if is_word_continue(b) {
503                self.pos += 1;
504            } else if b >= 0x80 {
505                let remaining = &self.source[self.pos..];
506                if let Some(c) = remaining.chars().next() {
507                    if c.is_alphanumeric() || c == '_' {
508                        self.pos += c.len_utf8();
509                    } else {
510                        break;
511                    }
512                } else {
513                    break;
514                }
515            } else {
516                break;
517            }
518        }
519    }
520
521    fn make_token(&self, kind: TokenKind, start: usize) -> Token {
522        let text = &self.source[start..self.pos];
523        Token::new(
524            kind,
525            Span::new(start as u32, self.pos as u32),
526            SmolStr::new(text),
527        )
528    }
529}
530
/// True if `b` can begin an ASCII word token: a letter, `_`, or `#`
/// (`#` presumably admits names like T-SQL `#temp` — NOTE(review): confirm).
fn is_word_start(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'_' | b'#')
}
534
/// True if `b` can continue an ASCII word token: a letter, digit, `_`, or `#`.
fn is_word_continue(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'#')
}
538
#[cfg(test)]
mod tests {
    use super::*;

    /// Lex `input` under the plain ANSI configuration, panicking on error.
    fn lex(input: &str) -> Vec<Token> {
        let mut lexer = Lexer::new(input, LexerConfig::ansi());
        lexer.tokenize().unwrap()
    }

    /// Like `lex`, but keep only the token kinds (trailing `Eof` included).
    fn kinds(input: &str) -> Vec<TokenKind> {
        lex(input).into_iter().map(|t| t.kind).collect()
    }

    #[test]
    fn test_simple_select() {
        let tokens = lex("SELECT 1");
        assert_eq!(tokens.len(), 4); // SELECT, WS, 1, EOF
        assert_eq!(tokens[0].kind, TokenKind::Word);
        assert_eq!(tokens[0].text.as_str(), "SELECT");
        assert_eq!(tokens[1].kind, TokenKind::Whitespace);
        assert_eq!(tokens[2].kind, TokenKind::NumberLiteral);
        assert_eq!(tokens[2].text.as_str(), "1");
        assert_eq!(tokens[3].kind, TokenKind::Eof);
    }

    #[test]
    fn test_select_star() {
        let k = kinds("SELECT * FROM users;");
        assert_eq!(
            k,
            vec![
                TokenKind::Word,       // SELECT
                TokenKind::Whitespace, // ' '
                TokenKind::Star,       // *
                TokenKind::Whitespace, // ' '
                TokenKind::Word,       // FROM
                TokenKind::Whitespace, // ' '
                TokenKind::Word,       // users
                TokenKind::Semicolon,  // ;
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_string_literal() {
        // Token text includes the surrounding quotes.
        let tokens = lex("'hello world'");
        assert_eq!(tokens[0].kind, TokenKind::StringLiteral);
        assert_eq!(tokens[0].text.as_str(), "'hello world'");
    }

    #[test]
    fn test_escaped_string() {
        // '' inside a literal is an escaped quote, not a terminator.
        let tokens = lex("'it''s'");
        assert_eq!(tokens[0].kind, TokenKind::StringLiteral);
        assert_eq!(tokens[0].text.as_str(), "'it''s'");
    }

    #[test]
    fn test_line_comment() {
        // The terminating newline is a separate Newline token.
        let tokens = lex("-- comment\nSELECT");
        assert_eq!(tokens[0].kind, TokenKind::LineComment);
        assert_eq!(tokens[0].text.as_str(), "-- comment");
        assert_eq!(tokens[1].kind, TokenKind::Newline);
        assert_eq!(tokens[2].kind, TokenKind::Word);
    }

    #[test]
    fn test_block_comment() {
        let tokens = lex("/* multi\nline */");
        assert_eq!(tokens[0].kind, TokenKind::BlockComment);
        assert_eq!(tokens[0].text.as_str(), "/* multi\nline */");
    }

    #[test]
    fn test_nested_block_comment() {
        // Nesting depth is tracked, so the inner */ does not end the comment.
        let tokens = lex("/* outer /* inner */ end */");
        assert_eq!(tokens[0].kind, TokenKind::BlockComment);
    }

    #[test]
    fn test_operators() {
        let k = kinds("<= >= <> !=");
        assert_eq!(
            k,
            vec![
                TokenKind::LtEq,
                TokenKind::Whitespace,
                TokenKind::GtEq,
                TokenKind::Whitespace,
                TokenKind::Neq,
                TokenKind::Whitespace,
                TokenKind::Neq,
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_number_formats() {
        // Integer, decimal, leading-dot, and exponent forms.
        let tokens = lex("42 3.14 .5 1e10 2.5E-3");
        let nums: Vec<&str> = tokens
            .iter()
            .filter(|t| t.kind == TokenKind::NumberLiteral)
            .map(|t| t.text.as_str())
            .collect();
        assert_eq!(nums, vec!["42", "3.14", ".5", "1e10", "2.5E-3"]);
    }

    #[test]
    fn test_quoted_identifier() {
        let tokens = lex("\"my column\"");
        assert_eq!(tokens[0].kind, TokenKind::QuotedIdentifier);
        assert_eq!(tokens[0].text.as_str(), "\"my column\"");
    }

    #[test]
    fn test_postgres_double_colon() {
        let mut lexer = Lexer::new("col::int", LexerConfig::postgres());
        let tokens = lexer.tokenize().unwrap();
        assert_eq!(tokens[1].kind, TokenKind::ColonColon);
    }

    #[test]
    fn test_tsql_bracket_identifier() {
        let mut lexer = Lexer::new("[my col]", LexerConfig::tsql());
        let tokens = lexer.tokenize().unwrap();
        assert_eq!(tokens[0].kind, TokenKind::QuotedIdentifier);
        assert_eq!(tokens[0].text.as_str(), "[my col]");
    }

    #[test]
    fn test_newline_types() {
        // \n and \r\n each produce exactly one Newline token.
        let k = kinds("a\nb\r\nc");
        assert_eq!(
            k,
            vec![
                TokenKind::Word,
                TokenKind::Newline,
                TokenKind::Word,
                TokenKind::Newline,
                TokenKind::Word,
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_placeholder() {
        let tokens = lex(":name ?");
        assert_eq!(tokens[0].kind, TokenKind::Placeholder);
        assert_eq!(tokens[0].text.as_str(), ":name");
        assert_eq!(tokens[2].kind, TokenKind::Placeholder);
        assert_eq!(tokens[2].text.as_str(), "?");
    }
}