//! rigsql_lexer/lexer.rs — dialect-configurable SQL tokenizer.

1use rigsql_core::{Span, Token, TokenKind};
2use smol_str::SmolStr;
3use thiserror::Error;
4
/// Errors produced while scanning.
///
/// Each variant carries the byte offset (into the original source)
/// where the offending construct began, matching `Span` offsets.
#[derive(Debug, Error)]
pub enum LexerError {
    /// A byte/char that starts no known token (e.g. a lone `!` or `&`).
    #[error("Unexpected character '{ch}' at offset {offset}")]
    UnexpectedChar { ch: char, offset: u32 },
    /// A `'...` string literal with no closing quote before EOF.
    #[error("Unterminated string literal starting at offset {offset}")]
    UnterminatedString { offset: u32 },
    /// A `/* ...` comment (possibly nested) still open at EOF.
    #[error("Unterminated block comment starting at offset {offset}")]
    UnterminatedBlockComment { offset: u32 },
    /// A `"..."`, backtick, or `[...]` identifier still open at EOF.
    #[error("Unterminated quoted identifier starting at offset {offset}")]
    UnterminatedQuotedIdentifier { offset: u32 },
}
16
/// Dialect-specific lexer configuration.
///
/// Every flag defaults to `false` (plain ANSI behaviour); the
/// constructors below flip on the extensions for a given dialect.
#[derive(Debug, Clone, Default)]
pub struct LexerConfig {
    /// Enable `::` as cast operator (PostgreSQL).
    pub double_colon: bool,
    /// Enable `[identifier]` quoting (SQL Server).
    pub bracket_identifiers: bool,
    /// Enable backtick identifier quoting (MySQL).
    pub backtick_identifiers: bool,
    /// Enable `@@variable` (SQL Server).
    pub double_at: bool,
    /// Enable dollar-quoted strings `$$...$$` (PostgreSQL).
    pub dollar_quoting: bool,
}

impl LexerConfig {
    /// Plain ANSI SQL: every dialect extension disabled.
    pub fn ansi() -> Self {
        LexerConfig::default()
    }

    /// PostgreSQL: `::` casts plus dollar-quoted strings.
    pub fn postgres() -> Self {
        LexerConfig {
            dollar_quoting: true,
            double_colon: true,
            ..LexerConfig::default()
        }
    }

    /// SQL Server (T-SQL): bracket identifiers plus `@@` variables.
    pub fn tsql() -> Self {
        LexerConfig {
            double_at: true,
            bracket_identifiers: true,
            ..LexerConfig::default()
        }
    }
}
53
/// Streaming SQL tokenizer over a borrowed source string.
pub struct Lexer<'a> {
    /// Original input; used for slicing token text and UTF-8 decoding.
    source: &'a str,
    /// Byte view of `source`; all scanning works on bytes.
    bytes: &'a [u8],
    /// Current byte offset into `bytes` (always on a char boundary).
    pos: usize,
    /// Dialect switches controlling optional token forms.
    config: LexerConfig,
}
60
61impl<'a> Lexer<'a> {
62    pub fn new(source: &'a str, config: LexerConfig) -> Self {
63        Self {
64            source,
65            bytes: source.as_bytes(),
66            pos: 0,
67            config,
68        }
69    }
70
71    /// Tokenize the entire source into a Vec of tokens.
72    pub fn tokenize(&mut self) -> Result<Vec<Token>, LexerError> {
73        let mut tokens = Vec::new();
74        loop {
75            let token = self.next_token()?;
76            let is_eof = token.kind == TokenKind::Eof;
77            tokens.push(token);
78            if is_eof {
79                break;
80            }
81        }
82        Ok(tokens)
83    }
84
    /// Scan one token starting at `self.pos`.
    ///
    /// Trivia (whitespace, newlines, comments) is emitted as ordinary
    /// tokens rather than skipped. Once the input is exhausted, every
    /// call returns a zero-width `Eof` token.
    ///
    /// NOTE: arm order matters — guarded arms (`--`, `/*`, `::`,
    /// dialect-gated `[`, `` ` ``) must precede the bare single-byte
    /// arms for the same leading byte.
    fn next_token(&mut self) -> Result<Token, LexerError> {
        if self.pos >= self.bytes.len() {
            return Ok(Token::new(
                TokenKind::Eof,
                Span::new(self.pos as u32, self.pos as u32),
                "",
            ));
        }

        let start = self.pos;
        let ch = self.bytes[self.pos];

        match ch {
            // Newline
            b'\n' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Newline, start))
            }
            b'\r' => {
                self.pos += 1;
                // Fold a CRLF pair into a single Newline token.
                if self.peek() == Some(b'\n') {
                    self.pos += 1;
                }
                Ok(self.make_token(TokenKind::Newline, start))
            }

            // Whitespace (not newline)
            b' ' | b'\t' => {
                self.pos += 1;
                while let Some(b) = self.peek() {
                    if b == b' ' || b == b'\t' {
                        self.pos += 1;
                    } else {
                        break;
                    }
                }
                Ok(self.make_token(TokenKind::Whitespace, start))
            }

            // Line comment: -- ...  (the terminating newline is NOT
            // consumed; it becomes its own Newline token)
            b'-' if self.peek_at(1) == Some(b'-') => {
                self.pos += 2;
                while let Some(b) = self.peek() {
                    if b == b'\n' || b == b'\r' {
                        break;
                    }
                    self.pos += 1;
                }
                Ok(self.make_token(TokenKind::LineComment, start))
            }

            // Block comment: /* ... */ — nesting is depth-counted
            b'/' if self.peek_at(1) == Some(b'*') => {
                self.pos += 2;
                let mut depth = 1u32;
                while self.pos < self.bytes.len() && depth > 0 {
                    if self.bytes[self.pos] == b'/' && self.peek_at(1) == Some(b'*') {
                        depth += 1;
                        self.pos += 2;
                    } else if self.bytes[self.pos] == b'*' && self.peek_at(1) == Some(b'/') {
                        depth -= 1;
                        self.pos += 2;
                    } else {
                        self.pos += 1;
                    }
                }
                if depth > 0 {
                    return Err(LexerError::UnterminatedBlockComment {
                        offset: start as u32,
                    });
                }
                Ok(self.make_token(TokenKind::BlockComment, start))
            }

            // String literal: 'hello'
            b'\'' => self.lex_string_literal(start),

            // Double-quoted identifier: "name"
            b'"' => self.lex_quoted_identifier(start, b'"'),

            // Bracket-quoted identifier: [name] (SQL Server)
            b'[' if self.config.bracket_identifiers => self.lex_bracket_identifier(start),

            // Array subscript brackets (PostgreSQL)
            b'[' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::LBracket, start))
            }
            b']' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::RBracket, start))
            }

            // Backtick identifier: `name` (MySQL)
            b'`' if self.config.backtick_identifiers => self.lex_quoted_identifier(start, b'`'),

            // Numbers
            b'0'..=b'9' => self.lex_number(start),

            // Dot (could be start of .123 numeric or just dot)
            b'.' if self.peek_at(1).is_some_and(|b| b.is_ascii_digit()) => self.lex_number(start),

            // Single-character operators & punctuation
            b'.' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Dot, start))
            }
            b',' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Comma, start))
            }
            b';' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Semicolon, start))
            }
            b'(' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::LParen, start))
            }
            b')' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::RParen, start))
            }
            b'*' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Star, start))
            }
            b'+' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Plus, start))
            }
            b'-' => {
                // Single minus (-- already handled above)
                self.pos += 1;
                Ok(self.make_token(TokenKind::Minus, start))
            }
            b'/' => {
                // Single slash (/* already handled above)
                self.pos += 1;
                Ok(self.make_token(TokenKind::Slash, start))
            }
            b'%' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Percent, start))
            }
            b'=' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Eq, start))
            }

            // < <= <> operators
            b'<' => {
                self.pos += 1;
                match self.peek() {
                    Some(b'=') => {
                        self.pos += 1;
                        Ok(self.make_token(TokenKind::LtEq, start))
                    }
                    Some(b'>') => {
                        self.pos += 1;
                        Ok(self.make_token(TokenKind::Neq, start))
                    }
                    _ => Ok(self.make_token(TokenKind::Lt, start)),
                }
            }

            // > >= operators
            b'>' => {
                self.pos += 1;
                if self.peek() == Some(b'=') {
                    self.pos += 1;
                    Ok(self.make_token(TokenKind::GtEq, start))
                } else {
                    Ok(self.make_token(TokenKind::Gt, start))
                }
            }

            // != operator (a lone `!` falls to the error arm below)
            b'!' if self.peek_at(1) == Some(b'=') => {
                self.pos += 2;
                Ok(self.make_token(TokenKind::Neq, start))
            }

            // || concat operator (a lone `|` falls to the error arm below)
            b'|' if self.peek_at(1) == Some(b'|') => {
                self.pos += 2;
                Ok(self.make_token(TokenKind::Concat, start))
            }

            // :: cast (PostgreSQL)
            b':' if self.config.double_colon && self.peek_at(1) == Some(b':') => {
                self.pos += 2;
                Ok(self.make_token(TokenKind::ColonColon, start))
            }

            // : named parameter (`:name`), or a bare Colon when no
            // word character follows
            b':' => {
                self.pos += 1;
                if self
                    .peek()
                    .is_some_and(|b| b.is_ascii_alphanumeric() || b == b'_')
                {
                    while self
                        .peek()
                        .is_some_and(|b| b.is_ascii_alphanumeric() || b == b'_')
                    {
                        self.pos += 1;
                    }
                    Ok(self.make_token(TokenKind::Placeholder, start))
                } else {
                    Ok(self.make_token(TokenKind::Colon, start))
                }
            }

            // @ or @@ (SQL Server); the variable name is folded into the
            // single AtSign token
            b'@' => {
                self.pos += 1;
                if self.config.double_at && self.peek() == Some(b'@') {
                    self.pos += 1;
                }
                // Read variable name (including non-ASCII chars like Japanese)
                self.eat_word_chars();
                Ok(self.make_token(TokenKind::AtSign, start))
            }

            // ? positional parameter
            b'?' => {
                self.pos += 1;
                Ok(self.make_token(TokenKind::Placeholder, start))
            }

            // $ positional parameter ($1) or dollar-quoting (PostgreSQL)
            b'$' => {
                if self.config.dollar_quoting {
                    self.lex_dollar_quote_or_param(start)
                } else {
                    self.pos += 1;
                    // $1, $2 etc
                    while self.peek().is_some_and(|b| b.is_ascii_digit()) {
                        self.pos += 1;
                    }
                    Ok(self.make_token(TokenKind::Placeholder, start))
                }
            }

            // Word: keyword or identifier (including non-ASCII like Japanese).
            // Any byte >= 0x80 is treated as a word start; the full char is
            // decoded from `source` so `pos` stays on a char boundary.
            b if is_word_start(b) || b >= 0x80 => {
                if b >= 0x80 {
                    let s = &self.source[self.pos..];
                    let first_char = s.chars().next().unwrap();
                    self.pos += first_char.len_utf8();
                } else {
                    self.pos += 1;
                }
                self.eat_word_chars();
                // N'...' is a Unicode/NVARCHAR string literal prefix
                let word = &self.source[start..self.pos];
                if word.eq_ignore_ascii_case("N") && self.peek() == Some(b'\'') {
                    return self.lex_string_literal(start);
                }
                Ok(self.make_token(TokenKind::Word, start))
            }

            // Anything else is an error; decode the full char for the message.
            _ => {
                let ch = self.source[self.pos..].chars().next().unwrap();
                Err(LexerError::UnexpectedChar {
                    ch,
                    offset: start as u32,
                })
            }
        }
    }
357
358    fn lex_string_literal(&mut self, start: usize) -> Result<Token, LexerError> {
359        self.pos += 1; // skip opening quote
360        loop {
361            match self.peek() {
362                None => {
363                    return Err(LexerError::UnterminatedString {
364                        offset: start as u32,
365                    })
366                }
367                Some(b'\'') => {
368                    self.pos += 1;
369                    // Escaped quote ''
370                    if self.peek() == Some(b'\'') {
371                        self.pos += 1;
372                        continue;
373                    }
374                    return Ok(self.make_token(TokenKind::StringLiteral, start));
375                }
376                Some(_) => self.pos += 1,
377            }
378        }
379    }
380
381    fn lex_quoted_identifier(&mut self, start: usize, quote: u8) -> Result<Token, LexerError> {
382        self.pos += 1; // skip opening quote
383        loop {
384            match self.peek() {
385                None => {
386                    return Err(LexerError::UnterminatedQuotedIdentifier {
387                        offset: start as u32,
388                    })
389                }
390                Some(b) if b == quote => {
391                    self.pos += 1;
392                    // Escaped quote
393                    if self.peek() == Some(quote) {
394                        self.pos += 1;
395                        continue;
396                    }
397                    return Ok(self.make_token(TokenKind::QuotedIdentifier, start));
398                }
399                Some(_) => self.pos += 1,
400            }
401        }
402    }
403
404    fn lex_bracket_identifier(&mut self, start: usize) -> Result<Token, LexerError> {
405        self.pos += 1; // skip [
406        loop {
407            match self.peek() {
408                None => {
409                    return Err(LexerError::UnterminatedQuotedIdentifier {
410                        offset: start as u32,
411                    })
412                }
413                Some(b']') => {
414                    self.pos += 1;
415                    return Ok(self.make_token(TokenKind::QuotedIdentifier, start));
416                }
417                Some(_) => self.pos += 1,
418            }
419        }
420    }
421
422    fn lex_number(&mut self, start: usize) -> Result<Token, LexerError> {
423        // Integer part
424        while self.peek().is_some_and(|b| b.is_ascii_digit()) {
425            self.pos += 1;
426        }
427        // Decimal part
428        if self.peek() == Some(b'.') && self.peek_at(1).is_some_and(|b| b.is_ascii_digit()) {
429            self.pos += 1; // skip .
430            while self.peek().is_some_and(|b| b.is_ascii_digit()) {
431                self.pos += 1;
432            }
433        } else if self.bytes[start] == b'.' {
434            // .123 form — dot already consumed before we got here
435            self.pos += 1; // skip .
436            while self.peek().is_some_and(|b| b.is_ascii_digit()) {
437                self.pos += 1;
438            }
439        }
440        // Exponent part
441        if self.peek().is_some_and(|b| b == b'e' || b == b'E') {
442            self.pos += 1;
443            if self.peek().is_some_and(|b| b == b'+' || b == b'-') {
444                self.pos += 1;
445            }
446            while self.peek().is_some_and(|b| b.is_ascii_digit()) {
447                self.pos += 1;
448            }
449        }
450        Ok(self.make_token(TokenKind::NumberLiteral, start))
451    }
452
    /// At a `$`, decide between a dollar-quoted string (`$$...$$` or
    /// `$tag$...$tag$`, PostgreSQL) and a plain positional parameter
    /// (`$1`, `$2`, ...). Only reached when `config.dollar_quoting` is on.
    fn lex_dollar_quote_or_param(&mut self, start: usize) -> Result<Token, LexerError> {
        // Check if it's a dollar-quoted string: $tag$...$tag$ or $$...$$
        let after_dollar = self.pos + 1;
        if after_dollar < self.bytes.len() {
            // $$ or $tag$
            if self.bytes[after_dollar] == b'$' {
                // $$...$$ form
                self.pos += 2; // skip $$
                let tag = "";
                return self.lex_dollar_body(start, tag);
            }
            if self.bytes[after_dollar].is_ascii_alphabetic() || self.bytes[after_dollar] == b'_' {
                // $tag$...$tag$ form: scan the candidate tag with a local
                // cursor `p`, committing `self.pos` only once the closing
                // `$` of the opener is confirmed.
                let tag_start = after_dollar;
                let mut p = after_dollar;
                while p < self.bytes.len()
                    && (self.bytes[p].is_ascii_alphanumeric() || self.bytes[p] == b'_')
                {
                    p += 1;
                }
                if p < self.bytes.len() && self.bytes[p] == b'$' {
                    let tag = &self.source[tag_start..p];
                    self.pos = p + 1; // skip closing $
                    return self.lex_dollar_body(start, tag);
                }
                // No closing `$` after the tag: fall through and lex the
                // `$` as a bare parameter instead.
            }
        }

        // Plain parameter: $1, $2, etc.
        self.pos += 1;
        while self.peek().is_some_and(|b| b.is_ascii_digit()) {
            self.pos += 1;
        }
        Ok(self.make_token(TokenKind::Placeholder, start))
    }
488
489    fn lex_dollar_body(&mut self, start: usize, tag: &str) -> Result<Token, LexerError> {
490        let end_tag = format!("${tag}$");
491        let end_bytes = end_tag.as_bytes();
492        while self.pos + end_bytes.len() <= self.bytes.len() {
493            if &self.bytes[self.pos..self.pos + end_bytes.len()] == end_bytes {
494                self.pos += end_bytes.len();
495                return Ok(self.make_token(TokenKind::StringLiteral, start));
496            }
497            self.pos += 1;
498        }
499        // If we hit EOF without closing, treat as unterminated string
500        Err(LexerError::UnterminatedString {
501            offset: start as u32,
502        })
503    }
504
505    fn peek(&self) -> Option<u8> {
506        self.bytes.get(self.pos).copied()
507    }
508
509    fn peek_at(&self, offset: usize) -> Option<u8> {
510        self.bytes.get(self.pos + offset).copied()
511    }
512
513    /// Advance past word-like characters (ASCII alphanumeric, `_`, and non-ASCII alphanumeric).
514    fn eat_word_chars(&mut self) {
515        while self.pos < self.bytes.len() {
516            let b = self.bytes[self.pos];
517            if is_word_continue(b) {
518                self.pos += 1;
519            } else if b >= 0x80 {
520                let remaining = &self.source[self.pos..];
521                if let Some(c) = remaining.chars().next() {
522                    if c.is_alphanumeric() || c == '_' {
523                        self.pos += c.len_utf8();
524                    } else {
525                        break;
526                    }
527                } else {
528                    break;
529                }
530            } else {
531                break;
532            }
533        }
534    }
535
536    fn make_token(&self, kind: TokenKind, start: usize) -> Token {
537        let text = &self.source[start..self.pos];
538        Token::new(
539            kind,
540            Span::new(start as u32, self.pos as u32),
541            SmolStr::new(text),
542        )
543    }
544}
545
/// True if `b` may begin a bare word (letter, `_`, or `#`).
fn is_word_start(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'_' | b'#')
}

/// True if `b` may continue a bare word (word-start chars plus digits).
fn is_word_continue(b: u8) -> bool {
    is_word_start(b) || b.is_ascii_digit()
}
553
#[cfg(test)]
mod tests {
    use super::*;

    /// Lex `input` under the ANSI config, panicking on lexer errors.
    fn lex(input: &str) -> Vec<Token> {
        let mut lexer = Lexer::new(input, LexerConfig::ansi());
        lexer.tokenize().unwrap()
    }

    /// Like `lex`, but reduced to just the token kinds.
    fn kinds(input: &str) -> Vec<TokenKind> {
        lex(input).into_iter().map(|t| t.kind).collect()
    }

    #[test]
    fn test_simple_select() {
        let tokens = lex("SELECT 1");
        assert_eq!(tokens.len(), 4); // SELECT, WS, 1, EOF
        assert_eq!(tokens[0].kind, TokenKind::Word);
        assert_eq!(tokens[0].text.as_str(), "SELECT");
        assert_eq!(tokens[1].kind, TokenKind::Whitespace);
        assert_eq!(tokens[2].kind, TokenKind::NumberLiteral);
        assert_eq!(tokens[2].text.as_str(), "1");
        assert_eq!(tokens[3].kind, TokenKind::Eof);
    }

    #[test]
    fn test_select_star() {
        let k = kinds("SELECT * FROM users;");
        assert_eq!(
            k,
            vec![
                TokenKind::Word,       // SELECT
                TokenKind::Whitespace, // ' '
                TokenKind::Star,       // *
                TokenKind::Whitespace, // ' '
                TokenKind::Word,       // FROM
                TokenKind::Whitespace, // ' '
                TokenKind::Word,       // users
                TokenKind::Semicolon,  // ;
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_string_literal() {
        let tokens = lex("'hello world'");
        assert_eq!(tokens[0].kind, TokenKind::StringLiteral);
        assert_eq!(tokens[0].text.as_str(), "'hello world'");
    }

    // The '' escape stays inside a single StringLiteral token.
    #[test]
    fn test_escaped_string() {
        let tokens = lex("'it''s'");
        assert_eq!(tokens[0].kind, TokenKind::StringLiteral);
        assert_eq!(tokens[0].text.as_str(), "'it''s'");
    }

    // The trailing newline is its own token, not part of the comment.
    #[test]
    fn test_line_comment() {
        let tokens = lex("-- comment\nSELECT");
        assert_eq!(tokens[0].kind, TokenKind::LineComment);
        assert_eq!(tokens[0].text.as_str(), "-- comment");
        assert_eq!(tokens[1].kind, TokenKind::Newline);
        assert_eq!(tokens[2].kind, TokenKind::Word);
    }

    #[test]
    fn test_block_comment() {
        let tokens = lex("/* multi\nline */");
        assert_eq!(tokens[0].kind, TokenKind::BlockComment);
        assert_eq!(tokens[0].text.as_str(), "/* multi\nline */");
    }

    // Block comments nest (depth-counted), so the whole input is one token.
    #[test]
    fn test_nested_block_comment() {
        let tokens = lex("/* outer /* inner */ end */");
        assert_eq!(tokens[0].kind, TokenKind::BlockComment);
    }

    // <> and != both lex as Neq.
    #[test]
    fn test_operators() {
        let k = kinds("<= >= <> !=");
        assert_eq!(
            k,
            vec![
                TokenKind::LtEq,
                TokenKind::Whitespace,
                TokenKind::GtEq,
                TokenKind::Whitespace,
                TokenKind::Neq,
                TokenKind::Whitespace,
                TokenKind::Neq,
                TokenKind::Eof,
            ]
        );
    }

    // Integer, decimal, leading-dot, and exponent (signed/unsigned) forms.
    #[test]
    fn test_number_formats() {
        let tokens = lex("42 3.14 .5 1e10 2.5E-3");
        let nums: Vec<&str> = tokens
            .iter()
            .filter(|t| t.kind == TokenKind::NumberLiteral)
            .map(|t| t.text.as_str())
            .collect();
        assert_eq!(nums, vec!["42", "3.14", ".5", "1e10", "2.5E-3"]);
    }

    #[test]
    fn test_quoted_identifier() {
        let tokens = lex("\"my column\"");
        assert_eq!(tokens[0].kind, TokenKind::QuotedIdentifier);
        assert_eq!(tokens[0].text.as_str(), "\"my column\"");
    }

    #[test]
    fn test_postgres_double_colon() {
        let mut lexer = Lexer::new("col::int", LexerConfig::postgres());
        let tokens = lexer.tokenize().unwrap();
        assert_eq!(tokens[1].kind, TokenKind::ColonColon);
    }

    #[test]
    fn test_tsql_bracket_identifier() {
        let mut lexer = Lexer::new("[my col]", LexerConfig::tsql());
        let tokens = lexer.tokenize().unwrap();
        assert_eq!(tokens[0].kind, TokenKind::QuotedIdentifier);
        assert_eq!(tokens[0].text.as_str(), "[my col]");
    }

    // LF and CRLF each produce exactly one Newline token.
    #[test]
    fn test_newline_types() {
        let k = kinds("a\nb\r\nc");
        assert_eq!(
            k,
            vec![
                TokenKind::Word,
                TokenKind::Newline,
                TokenKind::Word,
                TokenKind::Newline,
                TokenKind::Word,
                TokenKind::Eof,
            ]
        );
    }

    // Named (:name) and positional (?) parameters both lex as Placeholder.
    #[test]
    fn test_placeholder() {
        let tokens = lex(":name ?");
        assert_eq!(tokens[0].kind, TokenKind::Placeholder);
        assert_eq!(tokens[0].text.as_str(), ":name");
        assert_eq!(tokens[2].kind, TokenKind::Placeholder);
        assert_eq!(tokens[2].text.as_str(), "?");
    }
}