Skip to main content

sochdb_query/sql/
lexer.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
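//! SQL Lexer
//!
//! Converts SQL text into a stream of tokens terminated by an `Eof` token.
//! Handles string and blob (`X'..'`) literals, numbers, identifiers (plain and
//! quoted), keywords, operators, and `?` / `$N` placeholders. Whitespace and
//! comments (`--`, `//`, and nestable `/* */`) are skipped.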
22
23use super::token::{Span, Token, TokenKind};
24use std::borrow::Cow;
25
26/// SQL Lexer errors
27#[derive(Debug, Clone, PartialEq)]
28pub struct LexError {
29    pub message: String,
30    pub span: Span,
31}
32
33impl LexError {
34    pub fn new(message: impl Into<String>, span: Span) -> Self {
35        Self {
36            message: message.into(),
37            span,
38        }
39    }
40}
41
42impl std::fmt::Display for LexError {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        write!(
45            f,
46            "Lexer error at line {}, column {}: {}",
47            self.span.line, self.span.column, self.message
48        )
49    }
50}
51
52impl std::error::Error for LexError {}
53
54/// SQL Lexer - tokenizes SQL input
55pub struct Lexer<'a> {
56    input: &'a str,
57    bytes: &'a [u8],
58    pos: usize,
59    line: usize,
60    column: usize,
61    tokens: Vec<Token<'a>>,
62    errors: Vec<LexError>,
63    /// Counter for `?` style placeholders (auto-incrementing)
64    placeholder_counter: u32,
65}
66
67impl<'a> Lexer<'a> {
68    /// Create a new lexer for the given SQL input
69    pub fn new(input: &'a str) -> Self {
70        Self {
71            input,
72            bytes: input.as_bytes(),
73            pos: 0,
74            line: 1,
75            column: 1,
76            tokens: Vec::with_capacity(input.len() / 4),
77            errors: Vec::new(),
78            placeholder_counter: 0,
79        }
80    }
81
82    /// Tokenize the entire input
83    pub fn tokenize(mut self) -> Result<Vec<Token<'a>>, Vec<LexError>> {
84        while !self.is_at_end() {
85            self.scan_token();
86        }
87
88        // Add EOF token
89        self.tokens.push(Token::new(
90            TokenKind::Eof,
91            Span::new(self.pos, self.pos, self.line, self.column),
92            "",
93        ));
94
95        if self.errors.is_empty() {
96            Ok(self.tokens)
97        } else {
98            Err(self.errors)
99        }
100    }
101
102    fn is_at_end(&self) -> bool {
103        self.pos >= self.bytes.len()
104    }
105
106    fn advance(&mut self) -> Option<char> {
107        if self.pos >= self.bytes.len() {
108            return None;
109        }
110        let b = self.bytes[self.pos];
111        if b < 0x80 {
112            // ASCII fast path
113            self.pos += 1;
114            if b == b'\n' {
115                self.line += 1;
116                self.column = 1;
117            } else {
118                self.column += 1;
119            }
120            Some(b as char)
121        } else {
122            // Multi-byte UTF-8
123            let c = self.input[self.pos..].chars().next().unwrap();
124            self.pos += c.len_utf8();
125            self.column += 1;
126            Some(c)
127        }
128    }
129
130    fn peek(&self) -> Option<char> {
131        if self.pos >= self.bytes.len() {
132            return None;
133        }
134        let b = self.bytes[self.pos];
135        if b < 0x80 {
136            Some(b as char)
137        } else {
138            self.input[self.pos..].chars().next()
139        }
140    }
141
142    fn peek_next(&self) -> Option<char> {
143        if self.pos >= self.bytes.len() {
144            return None;
145        }
146        let first_len = if self.bytes[self.pos] < 0x80 {
147            1
148        } else {
149            self.input[self.pos..]
150                .chars()
151                .next()
152                .map_or(1, |c| c.len_utf8())
153        };
154        let next = self.pos + first_len;
155        if next >= self.bytes.len() {
156            return None;
157        }
158        let b = self.bytes[next];
159        if b < 0x80 {
160            Some(b as char)
161        } else {
162            self.input[next..].chars().next()
163        }
164    }
165
166    fn make_span(&self, start: usize, start_line: usize, start_col: usize) -> Span {
167        Span::new(start, self.pos, start_line, start_col)
168    }
169
170    fn scan_token(&mut self) {
171        let start = self.pos;
172        let start_line = self.line;
173        let start_col = self.column;
174
175        let c = match self.advance() {
176            Some(c) => c,
177            None => return,
178        };
179
180        match c {
181            // Whitespace
182            ' ' | '\t' | '\r' | '\n' => {
183                // Skip whitespace, don't emit token
184            }
185
186            // Single-character tokens
187            '(' => self.add_token(TokenKind::LParen, start, start_line, start_col),
188            ')' => self.add_token(TokenKind::RParen, start, start_line, start_col),
189            '[' => self.add_token(TokenKind::LBracket, start, start_line, start_col),
190            ']' => self.add_token(TokenKind::RBracket, start, start_line, start_col),
191            ',' => self.add_token(TokenKind::Comma, start, start_line, start_col),
192            ';' => self.add_token(TokenKind::Semicolon, start, start_line, start_col),
193            '+' => self.add_token(TokenKind::Plus, start, start_line, start_col),
194            '*' => self.add_token(TokenKind::Star, start, start_line, start_col),
195            '/' => {
196                if self.peek() == Some('/') || self.peek() == Some('*') {
197                    self.scan_comment(start, start_line, start_col);
198                } else {
199                    self.add_token(TokenKind::Slash, start, start_line, start_col);
200                }
201            }
202            '%' => self.add_token(TokenKind::Percent, start, start_line, start_col),
203            '&' => self.add_token(TokenKind::BitAnd, start, start_line, start_col),
204            '~' => self.add_token(TokenKind::BitNot, start, start_line, start_col),
205            '?' => {
206                // Auto-incrementing placeholder for JDBC/ODBC style ?
207                self.placeholder_counter += 1;
208                let span = self.make_span(start, start_line, start_col);
209                self.tokens.push(Token::new(
210                    TokenKind::Placeholder(self.placeholder_counter),
211                    span,
212                    "?",
213                ));
214            }
215            '@' => self.add_token(TokenKind::At, start, start_line, start_col),
216
217            // Two-character tokens
218            '-' => {
219                if self.peek() == Some('-') {
220                    // Line comment
221                    self.scan_line_comment(start, start_line, start_col);
222                } else if self.peek() == Some('>') {
223                    self.advance();
224                    if self.peek() == Some('>') {
225                        self.advance();
226                        self.add_token(TokenKind::DoubleArrow, start, start_line, start_col);
227                    } else {
228                        self.add_token(TokenKind::Arrow, start, start_line, start_col);
229                    }
230                } else {
231                    self.add_token(TokenKind::Minus, start, start_line, start_col);
232                }
233            }
234
235            '=' => self.add_token(TokenKind::Eq, start, start_line, start_col),
236
237            '!' => {
238                if self.peek() == Some('=') {
239                    self.advance();
240                    self.add_token(TokenKind::Ne, start, start_line, start_col);
241                } else {
242                    self.add_error("Unexpected character '!'", start, start_line, start_col);
243                }
244            }
245
246            '<' => {
247                if self.peek() == Some('=') {
248                    self.advance();
249                    self.add_token(TokenKind::Le, start, start_line, start_col);
250                } else if self.peek() == Some('-') {
251                    self.advance(); // consume '-'
252                    if self.peek() == Some('>') {
253                        self.advance(); // consume '>'
254                        self.add_token(TokenKind::BiArrow, start, start_line, start_col);
255                    } else {
256                        self.add_token(TokenKind::LeftArrow, start, start_line, start_col);
257                    }
258                } else if self.peek() == Some('>') {
259                    self.advance();
260                    self.add_token(TokenKind::Ne, start, start_line, start_col);
261                } else if self.peek() == Some('<') {
262                    self.advance();
263                    self.add_token(TokenKind::LeftShift, start, start_line, start_col);
264                } else {
265                    self.add_token(TokenKind::Lt, start, start_line, start_col);
266                }
267            }
268
269            '>' => {
270                if self.peek() == Some('=') {
271                    self.advance();
272                    self.add_token(TokenKind::Ge, start, start_line, start_col);
273                } else if self.peek() == Some('>') {
274                    self.advance();
275                    self.add_token(TokenKind::RightShift, start, start_line, start_col);
276                } else {
277                    self.add_token(TokenKind::Gt, start, start_line, start_col);
278                }
279            }
280
281            '|' => {
282                if self.peek() == Some('|') {
283                    self.advance();
284                    self.add_token(TokenKind::Concat, start, start_line, start_col);
285                } else {
286                    self.add_token(TokenKind::BitOr, start, start_line, start_col);
287                }
288            }
289
290            ':' => {
291                if self.peek() == Some(':') {
292                    self.advance();
293                    self.add_token(TokenKind::DoubleColon, start, start_line, start_col);
294                } else {
295                    self.add_token(TokenKind::Colon, start, start_line, start_col);
296                }
297            }
298
299            '.' => {
300                if self.peek().map(|c| c.is_ascii_digit()).unwrap_or(false) {
301                    self.scan_number(start, start_line, start_col, true);
302                } else {
303                    self.add_token(TokenKind::Dot, start, start_line, start_col);
304                }
305            }
306
307            // String literals
308            '\'' => self.scan_string(start, start_line, start_col, '\''),
309            '"' => self.scan_quoted_identifier(start, start_line, start_col, '"'),
310            '`' => self.scan_quoted_identifier(start, start_line, start_col, '`'),
311
312            // Blob literal (X'...')
313            'X' | 'x' if self.peek() == Some('\'') => {
314                self.advance();
315                self.scan_blob(start, start_line, start_col);
316            }
317
318            // Numbers
319            '0'..='9' => self.scan_number(start, start_line, start_col, false),
320
321            // Identifiers and keywords
322            'a'..='z' | 'A'..='Z' | '_' => self.scan_identifier(start, start_line, start_col),
323
324            // Placeholder ($1, $2, ...)
325            '$' => self.scan_placeholder(start, start_line, start_col),
326
327            _ => {
328                self.add_error(
329                    format!("Unexpected character '{}'", c),
330                    start,
331                    start_line,
332                    start_col,
333                );
334            }
335        }
336    }
337
338    fn scan_string(&mut self, start: usize, start_line: usize, start_col: usize, quote: char) {
339        let mut value = String::new();
340
341        while let Some(c) = self.peek() {
342            if c == quote {
343                self.advance();
344                // Check for escaped quote ('')
345                if self.peek() == Some(quote) {
346                    self.advance();
347                    value.push(quote);
348                } else {
349                    // End of string
350                    let span = self.make_span(start, start_line, start_col);
351                    let literal = &self.input[start..self.pos];
352                    self.tokens
353                        .push(Token::new(TokenKind::String(Cow::Owned(value)), span, literal));
354                    return;
355                }
356            } else if c == '\\' {
357                self.advance();
358                // Handle escape sequences
359                if let Some(escaped) = self.advance() {
360                    match escaped {
361                        'n' => value.push('\n'),
362                        'r' => value.push('\r'),
363                        't' => value.push('\t'),
364                        '\\' => value.push('\\'),
365                        '\'' => value.push('\''),
366                        '"' => value.push('"'),
367                        '0' => value.push('\0'),
368                        _ => {
369                            value.push('\\');
370                            value.push(escaped);
371                        }
372                    }
373                }
374            } else {
375                self.advance();
376                value.push(c);
377            }
378        }
379
380        self.add_error("Unterminated string literal", start, start_line, start_col);
381    }
382
383    fn scan_quoted_identifier(
384        &mut self,
385        start: usize,
386        start_line: usize,
387        start_col: usize,
388        quote: char,
389    ) {
390        let mut value = String::new();
391
392        while let Some(c) = self.peek() {
393            if c == quote {
394                self.advance();
395                // Check for escaped quote
396                if self.peek() == Some(quote) {
397                    self.advance();
398                    value.push(quote);
399                } else {
400                    let span = self.make_span(start, start_line, start_col);
401                    let literal = &self.input[start..self.pos];
402                    self.tokens
403                        .push(Token::new(TokenKind::QuotedIdentifier(Cow::Owned(value)), span, literal));
404                    return;
405                }
406            } else {
407                self.advance();
408                value.push(c);
409            }
410        }
411
412        self.add_error(
413            "Unterminated quoted identifier",
414            start,
415            start_line,
416            start_col,
417        );
418    }
419
420    fn scan_number(
421        &mut self,
422        start: usize,
423        start_line: usize,
424        start_col: usize,
425        started_with_dot: bool,
426    ) {
427        let num_start = start;
428        let mut has_dot = started_with_dot;
429        let mut has_exp = false;
430
431        // Consume integer part
432        while let Some(c) = self.peek() {
433            if c.is_ascii_digit() {
434                self.advance();
435            } else if c == '.' && !has_dot && !has_exp {
436                // Check it's not a range operator (..)
437                if self.peek_next() == Some('.') {
438                    break;
439                }
440                has_dot = true;
441                self.advance();
442            } else if (c == 'e' || c == 'E') && !has_exp {
443                has_exp = true;
444                self.advance();
445                // Optional sign
446                if self.peek() == Some('+') || self.peek() == Some('-') {
447                    self.advance();
448                }
449            } else {
450                break;
451            }
452        }
453
454        let literal = &self.input[num_start..self.pos];
455        let span = self.make_span(start, start_line, start_col);
456
457        if has_dot || has_exp {
458            match literal.parse::<f64>() {
459                Ok(n) => self
460                    .tokens
461                    .push(Token::new(TokenKind::Float(n), span, literal)),
462                Err(_) => self.add_error("Invalid float literal", start, start_line, start_col),
463            }
464        } else {
465            match literal.parse::<i64>() {
466                Ok(n) => self
467                    .tokens
468                    .push(Token::new(TokenKind::Integer(n), span, literal)),
469                Err(_) => self.add_error("Invalid integer literal", start, start_line, start_col),
470            }
471        }
472    }
473
474    fn scan_identifier(&mut self, start: usize, start_line: usize, start_col: usize) {
475        while let Some(c) = self.peek() {
476            if c.is_ascii_alphanumeric() || c == '_' {
477                self.advance();
478            } else {
479                break;
480            }
481        }
482
483        let literal = &self.input[start..self.pos];
484        let span = self.make_span(start, start_line, start_col);
485
486        // Check for keyword — zero allocation
487        let kind = TokenKind::from_keyword(literal)
488            .unwrap_or(TokenKind::Identifier(literal));
489
490        self.tokens.push(Token::new(kind, span, literal));
491    }
492
493    fn scan_placeholder(&mut self, start: usize, start_line: usize, start_col: usize) {
494        let mut num = String::new();
495
496        while let Some(c) = self.peek() {
497            if c.is_ascii_digit() {
498                self.advance();
499                num.push(c);
500            } else {
501                break;
502            }
503        }
504
505        let span = self.make_span(start, start_line, start_col);
506
507        if num.is_empty() {
508            self.add_error("Expected number after $", start, start_line, start_col);
509        } else if let Ok(n) = num.parse::<u32>() {
510            self.tokens.push(Token::new(
511                TokenKind::Placeholder(n),
512                span,
513                &self.input[start..self.pos],
514            ));
515        } else {
516            self.add_error("Invalid placeholder number", start, start_line, start_col);
517        }
518    }
519
520    fn scan_comment(&mut self, start: usize, start_line: usize, start_col: usize) {
521        self.advance(); // consume second / or *
522
523        if self.peek() == Some('*') || self.input[start..self.pos].ends_with('*') {
524            // Block comment /* ... */
525            let mut depth = 1;
526
527            while depth > 0 && !self.is_at_end() {
528                let c = self.peek();
529                let next = self.peek_next();
530
531                if c == Some('*') && next == Some('/') {
532                    self.advance();
533                    self.advance();
534                    depth -= 1;
535                } else if c == Some('/') && next == Some('*') {
536                    self.advance();
537                    self.advance();
538                    depth += 1;
539                } else {
540                    self.advance();
541                }
542            }
543
544            if depth > 0 {
545                self.add_error("Unterminated block comment", start, start_line, start_col);
546            }
547        } else {
548            // Line comment //
549            while let Some(c) = self.peek() {
550                if c == '\n' {
551                    break;
552                }
553                self.advance();
554            }
555        }
556        // Don't emit comment tokens
557    }
558
559    fn scan_line_comment(&mut self, _start: usize, _start_line: usize, _start_col: usize) {
560        self.advance(); // consume second -
561
562        while let Some(c) = self.peek() {
563            if c == '\n' {
564                break;
565            }
566            self.advance();
567        }
568        // Don't emit comment tokens
569    }
570
571    fn scan_blob(&mut self, start: usize, start_line: usize, start_col: usize) {
572        let mut hex = String::new();
573
574        while let Some(c) = self.peek() {
575            if c == '\'' {
576                self.advance();
577                break;
578            } else if c.is_ascii_hexdigit() {
579                self.advance();
580                hex.push(c);
581            } else if c.is_whitespace() {
582                self.advance(); // Allow whitespace in blob
583            } else {
584                self.add_error(
585                    "Invalid hex digit in blob literal",
586                    start,
587                    start_line,
588                    start_col,
589                );
590                return;
591            }
592        }
593
594        if !hex.len().is_multiple_of(2) {
595            self.add_error(
596                "Blob literal must have even number of hex digits",
597                start,
598                start_line,
599                start_col,
600            );
601            return;
602        }
603
604        let bytes: Result<Vec<u8>, _> = (0..hex.len())
605            .step_by(2)
606            .map(|i| u8::from_str_radix(&hex[i..i + 2], 16))
607            .collect();
608
609        match bytes {
610            Ok(data) => {
611                let span = self.make_span(start, start_line, start_col);
612                let literal = &self.input[start..self.pos];
613                self.tokens
614                    .push(Token::new(TokenKind::Blob(data), span, literal));
615            }
616            Err(_) => {
617                self.add_error("Invalid blob literal", start, start_line, start_col);
618            }
619        }
620    }
621
622    fn add_token(&mut self, kind: TokenKind<'a>, start: usize, start_line: usize, start_col: usize) {
623        let span = self.make_span(start, start_line, start_col);
624        let literal = &self.input[start..self.pos];
625        self.tokens.push(Token::new(kind, span, literal));
626    }
627
628    fn add_error(
629        &mut self,
630        message: impl Into<String>,
631        start: usize,
632        start_line: usize,
633        start_col: usize,
634    ) {
635        let span = self.make_span(start, start_line, start_col);
636        self.errors.push(LexError::new(message, span));
637    }
638}
639
#[cfg(test)]
mod tests {
    use super::*;

    /// Lexes `sql`, failing the test on any error, and returns the token kinds.
    fn kinds(sql: &str) -> Vec<TokenKind<'_>> {
        Lexer::new(sql)
            .tokenize()
            .unwrap()
            .into_iter()
            .map(|token| token.kind)
            .collect()
    }

    /// Lexes `sql`, failing the test if lexing succeeds, and returns the errors.
    fn errors(sql: &str) -> Vec<LexError> {
        match Lexer::new(sql).tokenize() {
            Ok(_) => panic!("expected {sql:?} to fail to lex"),
            Err(errs) => errs,
        }
    }

    #[test]
    fn test_simple_select() {
        let tokens = Lexer::new("SELECT * FROM users").tokenize().unwrap();
        assert_eq!(tokens.len(), 5); // SELECT, *, FROM, users, EOF
        assert_eq!(tokens[0].kind, TokenKind::Select);
        assert_eq!(tokens[1].kind, TokenKind::Star);
        assert_eq!(tokens[2].kind, TokenKind::From);
        assert!(matches!(tokens[3].kind, TokenKind::Identifier("users")));
    }

    #[test]
    fn test_string_literal() {
        let tokens = Lexer::new("SELECT 'hello''world'").tokenize().unwrap();
        assert!(matches!(&tokens[1].kind, TokenKind::String(s) if s == "hello'world"));
    }

    #[test]
    fn test_string_backslash_escapes() {
        let tokens = Lexer::new(r"'a\nb' 'it\'s'").tokenize().unwrap();
        assert!(matches!(&tokens[0].kind, TokenKind::String(s) if s == "a\nb"));
        assert!(matches!(&tokens[1].kind, TokenKind::String(s) if s == "it's"));
    }

    #[test]
    fn test_quoted_identifiers() {
        let tokens = Lexer::new("\"a\"\"b\" `c`").tokenize().unwrap();
        assert!(matches!(&tokens[0].kind, TokenKind::QuotedIdentifier(s) if s == "a\"b"));
        assert!(matches!(&tokens[1].kind, TokenKind::QuotedIdentifier(s) if s == "c"));
    }

    #[test]
    #[allow(clippy::approx_constant)] // 3.14 is test data, not an approximation of PI
    fn test_numbers() {
        let tokens = Lexer::new("42 3.14 1e10 .5").tokenize().unwrap();
        assert!(matches!(tokens[0].kind, TokenKind::Integer(42)));
        assert!(matches!(tokens[1].kind, TokenKind::Float(f) if (f - 3.14).abs() < 0.001));
        assert!(matches!(tokens[2].kind, TokenKind::Float(_)));
        assert!(matches!(tokens[3].kind, TokenKind::Float(f) if (f - 0.5).abs() < 0.001));
    }

    #[test]
    fn test_operators() {
        let tokens = Lexer::new("= != <> < <= > >= || ->").tokenize().unwrap();
        assert_eq!(tokens[0].kind, TokenKind::Eq);
        assert_eq!(tokens[1].kind, TokenKind::Ne);
        assert_eq!(tokens[2].kind, TokenKind::Ne);
        assert_eq!(tokens[3].kind, TokenKind::Lt);
        assert_eq!(tokens[4].kind, TokenKind::Le);
        assert_eq!(tokens[5].kind, TokenKind::Gt);
        assert_eq!(tokens[6].kind, TokenKind::Ge);
        assert_eq!(tokens[7].kind, TokenKind::Concat);
        assert_eq!(tokens[8].kind, TokenKind::Arrow);
    }

    #[test]
    fn test_multi_char_operators() {
        assert_eq!(
            kinds("->> << >> :: : | . ;"),
            vec![
                TokenKind::DoubleArrow,
                TokenKind::LeftShift,
                TokenKind::RightShift,
                TokenKind::DoubleColon,
                TokenKind::Colon,
                TokenKind::BitOr,
                TokenKind::Dot,
                TokenKind::Semicolon,
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_keywords() {
        let tokens = Lexer::new("SELECT INSERT UPDATE DELETE FROM WHERE")
            .tokenize()
            .unwrap();
        assert_eq!(tokens[0].kind, TokenKind::Select);
        assert_eq!(tokens[1].kind, TokenKind::Insert);
        assert_eq!(tokens[2].kind, TokenKind::Update);
        assert_eq!(tokens[3].kind, TokenKind::Delete);
        assert_eq!(tokens[4].kind, TokenKind::From);
        assert_eq!(tokens[5].kind, TokenKind::Where);
    }

    #[test]
    fn test_placeholder() {
        let tokens = Lexer::new("$1 $2 $10").tokenize().unwrap();
        assert!(matches!(tokens[0].kind, TokenKind::Placeholder(1)));
        assert!(matches!(tokens[1].kind, TokenKind::Placeholder(2)));
        assert!(matches!(tokens[2].kind, TokenKind::Placeholder(10)));
    }

    #[test]
    fn test_question_mark_placeholders_count_up() {
        assert_eq!(
            kinds("? ? $5 ?"),
            vec![
                TokenKind::Placeholder(1),
                TokenKind::Placeholder(2),
                TokenKind::Placeholder(5),
                TokenKind::Placeholder(3),
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_line_comment() {
        let tokens = Lexer::new("SELECT -- comment\n* FROM users")
            .tokenize()
            .unwrap();
        assert_eq!(tokens.len(), 5); // SELECT, *, FROM, users, EOF
        assert_eq!(tokens[0].kind, TokenKind::Select);
        assert_eq!(tokens[1].kind, TokenKind::Star);
    }

    #[test]
    fn test_slash_line_comment() {
        assert_eq!(
            kinds("SELECT 1 // trailing note\nFROM t"),
            vec![
                TokenKind::Select,
                TokenKind::Integer(1),
                TokenKind::From,
                TokenKind::Identifier("t"),
                TokenKind::Eof,
            ]
        );
    }

    #[test]
    fn test_block_comments_nest() {
        assert_eq!(
            kinds("SELECT /* outer /* inner */ still outer */ 1 /**/"),
            vec![TokenKind::Select, TokenKind::Integer(1), TokenKind::Eof]
        );
    }

    #[test]
    fn test_unterminated_block_comment() {
        let errs = errors("SELECT /* never closed");
        assert_eq!(errs.len(), 1);
        assert_eq!(errs[0].message, "Unterminated block comment");
    }

    #[test]
    fn test_blob_literal() {
        let tokens = Lexer::new("X'48454C4C4F'").tokenize().unwrap();
        assert!(matches!(&tokens[0].kind, TokenKind::Blob(b) if b == b"HELLO"));
    }

    #[test]
    fn test_invalid_literals_are_reported() {
        assert_eq!(errors("'abc")[0].message, "Unterminated string literal");
        assert_eq!(errors("\"abc")[0].message, "Unterminated quoted identifier");
        assert_eq!(errors("99999999999999999999")[0].message, "Invalid integer literal");
        assert_eq!(errors("$")[0].message, "Expected number after $");
        assert_eq!(
            errors("X'ABC'")[0].message,
            "Blob literal must have even number of hex digits"
        );
    }

    #[test]
    fn test_error_positions_are_one_based() {
        let errs = errors("SELECT\n  !");
        assert_eq!(errs.len(), 1);
        assert_eq!(errs[0].message, "Unexpected character '!'");
        assert_eq!((errs[0].span.line, errs[0].span.column), (2, 3));
    }

    #[test]
    fn test_columns_count_characters_not_bytes() {
        let errs = errors("'é' !");
        assert_eq!((errs[0].span.line, errs[0].span.column), (1, 5));
    }

    #[test]
    fn test_lexing_continues_after_errors() {
        assert_eq!(errors("! SELECT !").len(), 2);
    }

    #[test]
    fn test_left_arrow() {
        let tokens = Lexer::new("<-").tokenize().unwrap();
        assert_eq!(tokens[0].kind, TokenKind::LeftArrow);
    }

    #[test]
    fn test_biarrow() {
        let tokens = Lexer::new("<->").tokenize().unwrap();
        assert_eq!(tokens[0].kind, TokenKind::BiArrow);
    }

    #[test]
    fn test_arrow_tokens_in_context() {
        let tokens = Lexer::new("a -> b <- c <-> d").tokenize().unwrap();
        assert!(matches!(tokens[0].kind, TokenKind::Identifier("a")));
        assert_eq!(tokens[1].kind, TokenKind::Arrow);
        assert!(matches!(tokens[2].kind, TokenKind::Identifier("b")));
        assert_eq!(tokens[3].kind, TokenKind::LeftArrow);
        assert!(matches!(tokens[4].kind, TokenKind::Identifier("c")));
        assert_eq!(tokens[5].kind, TokenKind::BiArrow);
        assert!(matches!(tokens[6].kind, TokenKind::Identifier("d")));
    }

    #[test]
    fn test_relate_keyword() {
        let tokens = Lexer::new("RELATE LIVE CONTENT EVENT DIFF").tokenize().unwrap();
        assert_eq!(tokens[0].kind, TokenKind::Relate);
        assert_eq!(tokens[1].kind, TokenKind::Live);
        assert_eq!(tokens[2].kind, TokenKind::Content);
        assert_eq!(tokens[3].kind, TokenKind::Event);
        assert_eq!(tokens[4].kind, TokenKind::Diff);
    }
}