// sochdb_query/sql/lexer.rs
1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! SQL Lexer
19//!
20//! Converts SQL text into a stream of tokens.
21//! Handles string literals, numbers, identifiers, keywords, and operators.
22
23use super::token::{Span, Token, TokenKind};
24use std::iter::Peekable;
25use std::str::Chars;
26
/// SQL Lexer errors
#[derive(Debug, Clone, PartialEq)]
pub struct LexError {
    /// Human-readable description of what went wrong.
    pub message: String,
    /// Source location (start/end offsets plus line/column) of the error.
    pub span: Span,
}
33
34impl LexError {
35    pub fn new(message: impl Into<String>, span: Span) -> Self {
36        Self {
37            message: message.into(),
38            span,
39        }
40    }
41}
42
43impl std::fmt::Display for LexError {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        write!(
46            f,
47            "Lexer error at line {}, column {}: {}",
48            self.span.line, self.span.column, self.message
49        )
50    }
51}
52
// LexError has no underlying source error; the derived Debug plus the Display
// impl above satisfy the std::error::Error requirements.
impl std::error::Error for LexError {}
54
/// SQL Lexer - tokenizes SQL input
pub struct Lexer<'a> {
    /// Full source text; token literals are slices of this.
    input: &'a str,
    /// Peekable cursor over the input's characters.
    chars: Peekable<Chars<'a>>,
    /// Current byte offset into `input` (advanced by each char's UTF-8 length).
    pos: usize,
    /// 1-based line number of the cursor.
    line: usize,
    /// 1-based column number of the cursor (reset to 1 on newline).
    column: usize,
    /// Tokens produced so far.
    tokens: Vec<Token>,
    /// Errors collected so far; any entry makes `tokenize` return `Err`.
    errors: Vec<LexError>,
    /// Counter for `?` style placeholders (auto-incrementing)
    placeholder_counter: u32,
}
67
68impl<'a> Lexer<'a> {
69    /// Create a new lexer for the given SQL input
70    pub fn new(input: &'a str) -> Self {
71        Self {
72            input,
73            chars: input.chars().peekable(),
74            pos: 0,
75            line: 1,
76            column: 1,
77            tokens: Vec::new(),
78            errors: Vec::new(),
79            placeholder_counter: 0,
80        }
81    }
82
83    /// Tokenize the entire input
84    pub fn tokenize(mut self) -> Result<Vec<Token>, Vec<LexError>> {
85        while !self.is_at_end() {
86            self.scan_token();
87        }
88
89        // Add EOF token
90        self.tokens.push(Token::new(
91            TokenKind::Eof,
92            Span::new(self.pos, self.pos, self.line, self.column),
93            "",
94        ));
95
96        if self.errors.is_empty() {
97            Ok(self.tokens)
98        } else {
99            Err(self.errors)
100        }
101    }
102
    /// True once every character of the input has been consumed.
    /// Takes `&mut self` because `Peekable::peek` requires mutable access.
    fn is_at_end(&mut self) -> bool {
        self.chars.peek().is_none()
    }
106
107    fn advance(&mut self) -> Option<char> {
108        let c = self.chars.next()?;
109        self.pos += c.len_utf8();
110        if c == '\n' {
111            self.line += 1;
112            self.column = 1;
113        } else {
114            self.column += 1;
115        }
116        Some(c)
117    }
118
    /// Look at the next character without consuming it.
    fn peek(&mut self) -> Option<char> {
        self.chars.peek().copied()
    }
122
123    fn peek_next(&self) -> Option<char> {
124        let mut chars = self.chars.clone();
125        chars.next();
126        chars.next()
127    }
128
    /// Build a span from a recorded start (byte offset, line, column) up to
    /// the lexer's current position.
    fn make_span(&self, start: usize, start_line: usize, start_col: usize) -> Span {
        Span::new(start, self.pos, start_line, start_col)
    }
132
    /// Scan one token (or silently skip whitespace/comments), dispatching on
    /// the first character.
    ///
    /// The start offset/line/column are captured before consuming so every
    /// emitted token and error carries an accurate span. An unrecognized
    /// character is recorded as an error rather than panicking.
    fn scan_token(&mut self) {
        let start = self.pos;
        let start_line = self.line;
        let start_col = self.column;

        let c = match self.advance() {
            Some(c) => c,
            None => return,
        };

        match c {
            // Whitespace
            ' ' | '\t' | '\r' | '\n' => {
                // Skip whitespace, don't emit token
            }

            // Single-character tokens
            '(' => self.add_token(TokenKind::LParen, start, start_line, start_col),
            ')' => self.add_token(TokenKind::RParen, start, start_line, start_col),
            '[' => self.add_token(TokenKind::LBracket, start, start_line, start_col),
            ']' => self.add_token(TokenKind::RBracket, start, start_line, start_col),
            ',' => self.add_token(TokenKind::Comma, start, start_line, start_col),
            ';' => self.add_token(TokenKind::Semicolon, start, start_line, start_col),
            '+' => self.add_token(TokenKind::Plus, start, start_line, start_col),
            '*' => self.add_token(TokenKind::Star, start, start_line, start_col),
            '/' => {
                // `//` or `/*` opens a comment; a lone `/` is the division operator.
                if self.peek() == Some('/') || self.peek() == Some('*') {
                    self.scan_comment(start, start_line, start_col);
                } else {
                    self.add_token(TokenKind::Slash, start, start_line, start_col);
                }
            }
            '%' => self.add_token(TokenKind::Percent, start, start_line, start_col),
            '&' => self.add_token(TokenKind::BitAnd, start, start_line, start_col),
            '~' => self.add_token(TokenKind::BitNot, start, start_line, start_col),
            '?' => {
                // Auto-incrementing placeholder for JDBC/ODBC style ?
                self.placeholder_counter += 1;
                let span = self.make_span(start, start_line, start_col);
                self.tokens.push(Token::new(
                    TokenKind::Placeholder(self.placeholder_counter),
                    span,
                    "?",
                ));
            }
            '@' => self.add_token(TokenKind::At, start, start_line, start_col),

            // Two-character tokens
            '-' => {
                if self.peek() == Some('-') {
                    // Line comment
                    self.scan_line_comment(start, start_line, start_col);
                } else if self.peek() == Some('>') {
                    self.advance();
                    // Distinguish `->>` from `->`.
                    if self.peek() == Some('>') {
                        self.advance();
                        self.add_token(TokenKind::DoubleArrow, start, start_line, start_col);
                    } else {
                        self.add_token(TokenKind::Arrow, start, start_line, start_col);
                    }
                } else {
                    self.add_token(TokenKind::Minus, start, start_line, start_col);
                }
            }

            '=' => self.add_token(TokenKind::Eq, start, start_line, start_col),

            '!' => {
                // `!` is only valid as the first half of `!=`.
                if self.peek() == Some('=') {
                    self.advance();
                    self.add_token(TokenKind::Ne, start, start_line, start_col);
                } else {
                    self.add_error("Unexpected character '!'", start, start_line, start_col);
                }
            }

            '<' => {
                // `<=`, `<>` (alternate not-equal), `<<`, or plain `<`.
                if self.peek() == Some('=') {
                    self.advance();
                    self.add_token(TokenKind::Le, start, start_line, start_col);
                } else if self.peek() == Some('>') {
                    self.advance();
                    self.add_token(TokenKind::Ne, start, start_line, start_col);
                } else if self.peek() == Some('<') {
                    self.advance();
                    self.add_token(TokenKind::LeftShift, start, start_line, start_col);
                } else {
                    self.add_token(TokenKind::Lt, start, start_line, start_col);
                }
            }

            '>' => {
                // `>=`, `>>`, or plain `>`.
                if self.peek() == Some('=') {
                    self.advance();
                    self.add_token(TokenKind::Ge, start, start_line, start_col);
                } else if self.peek() == Some('>') {
                    self.advance();
                    self.add_token(TokenKind::RightShift, start, start_line, start_col);
                } else {
                    self.add_token(TokenKind::Gt, start, start_line, start_col);
                }
            }

            '|' => {
                // `||` is string concatenation; a single `|` is bitwise OR.
                if self.peek() == Some('|') {
                    self.advance();
                    self.add_token(TokenKind::Concat, start, start_line, start_col);
                } else {
                    self.add_token(TokenKind::BitOr, start, start_line, start_col);
                }
            }

            ':' => {
                // `::` vs a single `:`.
                if self.peek() == Some(':') {
                    self.advance();
                    self.add_token(TokenKind::DoubleColon, start, start_line, start_col);
                } else {
                    self.add_token(TokenKind::Colon, start, start_line, start_col);
                }
            }

            '.' => {
                // A dot directly followed by a digit starts a float like `.5`.
                if self.peek().map(|c| c.is_ascii_digit()).unwrap_or(false) {
                    self.scan_number(start, start_line, start_col, true);
                } else {
                    self.add_token(TokenKind::Dot, start, start_line, start_col);
                }
            }

            // String literals
            '\'' => self.scan_string(start, start_line, start_col, '\''),
            '"' => self.scan_quoted_identifier(start, start_line, start_col, '"'),
            '`' => self.scan_quoted_identifier(start, start_line, start_col, '`'),

            // Blob literal (X'...')
            'X' | 'x' if self.peek() == Some('\'') => {
                self.advance();
                self.scan_blob(start, start_line, start_col);
            }

            // Numbers
            '0'..='9' => self.scan_number(start, start_line, start_col, false),

            // Identifiers and keywords
            'a'..='z' | 'A'..='Z' | '_' => self.scan_identifier(start, start_line, start_col),

            // Placeholder ($1, $2, ...)
            '$' => self.scan_placeholder(start, start_line, start_col),

            _ => {
                self.add_error(
                    format!("Unexpected character '{}'", c),
                    start,
                    start_line,
                    start_col,
                );
            }
        }
    }
292
    /// Scan a string literal delimited by `quote`; the opening quote has
    /// already been consumed by the caller.
    ///
    /// Supports SQL doubled-quote escaping (`''`) and backslash escapes
    /// (`\n`, `\r`, `\t`, `\\`, `\'`, `\"`, `\0`); an unrecognized backslash
    /// escape is kept verbatim (backslash plus the character). Emits
    /// `TokenKind::String` on success, or records an "Unterminated string
    /// literal" error if input ends before the closing quote.
    fn scan_string(&mut self, start: usize, start_line: usize, start_col: usize, quote: char) {
        let mut value = String::new();

        while let Some(c) = self.peek() {
            if c == quote {
                self.advance();
                // Check for escaped quote ('')
                if self.peek() == Some(quote) {
                    self.advance();
                    value.push(quote);
                } else {
                    // End of string
                    let span = self.make_span(start, start_line, start_col);
                    self.tokens
                        .push(Token::new(TokenKind::String(value), span, ""));
                    return;
                }
            } else if c == '\\' {
                self.advance();
                // Handle escape sequences
                if let Some(escaped) = self.advance() {
                    match escaped {
                        'n' => value.push('\n'),
                        'r' => value.push('\r'),
                        't' => value.push('\t'),
                        '\\' => value.push('\\'),
                        '\'' => value.push('\''),
                        '"' => value.push('"'),
                        '0' => value.push('\0'),
                        _ => {
                            // Unknown escape: preserve the backslash and char.
                            value.push('\\');
                            value.push(escaped);
                        }
                    }
                }
            } else {
                self.advance();
                value.push(c);
            }
        }

        self.add_error("Unterminated string literal", start, start_line, start_col);
    }
336
    /// Scan a quoted identifier delimited by `quote` (`"` or backtick); the
    /// opening quote has already been consumed by the caller.
    ///
    /// A doubled quote inside the identifier escapes one literal quote
    /// character. Emits `TokenKind::QuotedIdentifier` on success, or records
    /// an "Unterminated quoted identifier" error at end of input.
    fn scan_quoted_identifier(
        &mut self,
        start: usize,
        start_line: usize,
        start_col: usize,
        quote: char,
    ) {
        let mut value = String::new();

        while let Some(c) = self.peek() {
            if c == quote {
                self.advance();
                // Check for escaped quote
                if self.peek() == Some(quote) {
                    self.advance();
                    value.push(quote);
                } else {
                    // Closing quote: emit the identifier token.
                    let span = self.make_span(start, start_line, start_col);
                    self.tokens
                        .push(Token::new(TokenKind::QuotedIdentifier(value), span, ""));
                    return;
                }
            } else {
                self.advance();
                value.push(c);
            }
        }

        self.add_error(
            "Unterminated quoted identifier",
            start,
            start_line,
            start_col,
        );
    }
372
373    fn scan_number(
374        &mut self,
375        start: usize,
376        start_line: usize,
377        start_col: usize,
378        started_with_dot: bool,
379    ) {
380        let num_start = start;
381        let mut has_dot = started_with_dot;
382        let mut has_exp = false;
383
384        // Consume integer part
385        while let Some(c) = self.peek() {
386            if c.is_ascii_digit() {
387                self.advance();
388            } else if c == '.' && !has_dot && !has_exp {
389                // Check it's not a range operator (..)
390                if self.peek_next() == Some('.') {
391                    break;
392                }
393                has_dot = true;
394                self.advance();
395            } else if (c == 'e' || c == 'E') && !has_exp {
396                has_exp = true;
397                self.advance();
398                // Optional sign
399                if self.peek() == Some('+') || self.peek() == Some('-') {
400                    self.advance();
401                }
402            } else {
403                break;
404            }
405        }
406
407        let literal = &self.input[num_start..self.pos];
408        let span = self.make_span(start, start_line, start_col);
409
410        if has_dot || has_exp {
411            match literal.parse::<f64>() {
412                Ok(n) => self
413                    .tokens
414                    .push(Token::new(TokenKind::Float(n), span, literal)),
415                Err(_) => self.add_error("Invalid float literal", start, start_line, start_col),
416            }
417        } else {
418            match literal.parse::<i64>() {
419                Ok(n) => self
420                    .tokens
421                    .push(Token::new(TokenKind::Integer(n), span, literal)),
422                Err(_) => self.add_error("Invalid integer literal", start, start_line, start_col),
423            }
424        }
425    }
426
427    fn scan_identifier(&mut self, start: usize, start_line: usize, start_col: usize) {
428        while let Some(c) = self.peek() {
429            if c.is_ascii_alphanumeric() || c == '_' {
430                self.advance();
431            } else {
432                break;
433            }
434        }
435
436        let literal = &self.input[start..self.pos];
437        let span = self.make_span(start, start_line, start_col);
438
439        // Check for keyword
440        let kind = TokenKind::from_keyword(literal)
441            .unwrap_or_else(|| TokenKind::Identifier(literal.to_string()));
442
443        self.tokens.push(Token::new(kind, span, literal));
444    }
445
446    fn scan_placeholder(&mut self, start: usize, start_line: usize, start_col: usize) {
447        let mut num = String::new();
448
449        while let Some(c) = self.peek() {
450            if c.is_ascii_digit() {
451                self.advance();
452                num.push(c);
453            } else {
454                break;
455            }
456        }
457
458        let span = self.make_span(start, start_line, start_col);
459
460        if num.is_empty() {
461            self.add_error("Expected number after $", start, start_line, start_col);
462        } else if let Ok(n) = num.parse::<u32>() {
463            self.tokens.push(Token::new(
464                TokenKind::Placeholder(n),
465                span,
466                &self.input[start..self.pos],
467            ));
468        } else {
469            self.add_error("Invalid placeholder number", start, start_line, start_col);
470        }
471    }
472
473    fn scan_comment(&mut self, start: usize, start_line: usize, start_col: usize) {
474        self.advance(); // consume second / or *
475
476        if self.peek() == Some('*') || self.input[start..self.pos].ends_with('*') {
477            // Block comment /* ... */
478            let mut depth = 1;
479
480            while depth > 0 && !self.is_at_end() {
481                let c = self.peek();
482                let next = self.peek_next();
483
484                if c == Some('*') && next == Some('/') {
485                    self.advance();
486                    self.advance();
487                    depth -= 1;
488                } else if c == Some('/') && next == Some('*') {
489                    self.advance();
490                    self.advance();
491                    depth += 1;
492                } else {
493                    self.advance();
494                }
495            }
496
497            if depth > 0 {
498                self.add_error("Unterminated block comment", start, start_line, start_col);
499            }
500        } else {
501            // Line comment //
502            while let Some(c) = self.peek() {
503                if c == '\n' {
504                    break;
505                }
506                self.advance();
507            }
508        }
509        // Don't emit comment tokens
510    }
511
512    fn scan_line_comment(&mut self, _start: usize, _start_line: usize, _start_col: usize) {
513        self.advance(); // consume second -
514
515        while let Some(c) = self.peek() {
516            if c == '\n' {
517                break;
518            }
519            self.advance();
520        }
521        // Don't emit comment tokens
522    }
523
524    fn scan_blob(&mut self, start: usize, start_line: usize, start_col: usize) {
525        let mut hex = String::new();
526
527        while let Some(c) = self.peek() {
528            if c == '\'' {
529                self.advance();
530                break;
531            } else if c.is_ascii_hexdigit() {
532                self.advance();
533                hex.push(c);
534            } else if c.is_whitespace() {
535                self.advance(); // Allow whitespace in blob
536            } else {
537                self.add_error(
538                    "Invalid hex digit in blob literal",
539                    start,
540                    start_line,
541                    start_col,
542                );
543                return;
544            }
545        }
546
547        if !hex.len().is_multiple_of(2) {
548            self.add_error(
549                "Blob literal must have even number of hex digits",
550                start,
551                start_line,
552                start_col,
553            );
554            return;
555        }
556
557        let bytes: Result<Vec<u8>, _> = (0..hex.len())
558            .step_by(2)
559            .map(|i| u8::from_str_radix(&hex[i..i + 2], 16))
560            .collect();
561
562        match bytes {
563            Ok(data) => {
564                let span = self.make_span(start, start_line, start_col);
565                self.tokens
566                    .push(Token::new(TokenKind::Blob(data), span, ""));
567            }
568            Err(_) => {
569                self.add_error("Invalid blob literal", start, start_line, start_col);
570            }
571        }
572    }
573
    /// Push a token of `kind` spanning from the recorded start to the current
    /// position, using the matching source slice as its literal text.
    fn add_token(&mut self, kind: TokenKind, start: usize, start_line: usize, start_col: usize) {
        let span = self.make_span(start, start_line, start_col);
        let literal = &self.input[start..self.pos];
        self.tokens.push(Token::new(kind, span, literal));
    }
579
    /// Record a lexing error spanning from the recorded start to the current
    /// position. Errors are accumulated; `tokenize` reports them all at once.
    fn add_error(
        &mut self,
        message: impl Into<String>,
        start: usize,
        start_line: usize,
        start_col: usize,
    ) {
        let span = self.make_span(start, start_line, start_col);
        self.errors.push(LexError::new(message, span));
    }
590}
591
#[cfg(test)]
mod tests {
    use super::*;

    // Basic statement: keyword, star, keyword, identifier, plus trailing EOF.
    #[test]
    fn test_simple_select() {
        let tokens = Lexer::new("SELECT * FROM users").tokenize().unwrap();
        assert_eq!(tokens.len(), 5); // SELECT, *, FROM, users, EOF
        assert_eq!(tokens[0].kind, TokenKind::Select);
        assert_eq!(tokens[1].kind, TokenKind::Star);
        assert_eq!(tokens[2].kind, TokenKind::From);
        assert!(matches!(tokens[3].kind, TokenKind::Identifier(_)));
    }

    // A doubled single-quote inside a string collapses to one literal quote.
    #[test]
    fn test_string_literal() {
        let tokens = Lexer::new("SELECT 'hello''world'").tokenize().unwrap();
        assert!(matches!(&tokens[1].kind, TokenKind::String(s) if s == "hello'world"));
    }

    // Integers, decimals, exponent notation, and leading-dot floats.
    #[test]
    #[allow(clippy::approx_constant)]
    fn test_numbers() {
        let tokens = Lexer::new("42 3.14 1e10 .5").tokenize().unwrap();
        assert!(matches!(tokens[0].kind, TokenKind::Integer(42)));
        assert!(matches!(tokens[1].kind, TokenKind::Float(f) if (f - 3.14).abs() < 0.001));
        assert!(matches!(tokens[2].kind, TokenKind::Float(_)));
        assert!(matches!(tokens[3].kind, TokenKind::Float(f) if (f - 0.5).abs() < 0.001));
    }

    // One- and two-character operators, including both not-equal spellings.
    #[test]
    fn test_operators() {
        let tokens = Lexer::new("= != <> < <= > >= || ->").tokenize().unwrap();
        assert_eq!(tokens[0].kind, TokenKind::Eq);
        assert_eq!(tokens[1].kind, TokenKind::Ne);
        assert_eq!(tokens[2].kind, TokenKind::Ne);
        assert_eq!(tokens[3].kind, TokenKind::Lt);
        assert_eq!(tokens[4].kind, TokenKind::Le);
        assert_eq!(tokens[5].kind, TokenKind::Gt);
        assert_eq!(tokens[6].kind, TokenKind::Ge);
        assert_eq!(tokens[7].kind, TokenKind::Concat);
        assert_eq!(tokens[8].kind, TokenKind::Arrow);
    }

    // Reserved words map to dedicated token kinds, not identifiers.
    #[test]
    fn test_keywords() {
        let tokens = Lexer::new("SELECT INSERT UPDATE DELETE FROM WHERE")
            .tokenize()
            .unwrap();
        assert_eq!(tokens[0].kind, TokenKind::Select);
        assert_eq!(tokens[1].kind, TokenKind::Insert);
        assert_eq!(tokens[2].kind, TokenKind::Update);
        assert_eq!(tokens[3].kind, TokenKind::Delete);
        assert_eq!(tokens[4].kind, TokenKind::From);
        assert_eq!(tokens[5].kind, TokenKind::Where);
    }

    // `$N` placeholders keep their explicit number (multi-digit included).
    #[test]
    fn test_placeholder() {
        let tokens = Lexer::new("$1 $2 $10").tokenize().unwrap();
        assert!(matches!(tokens[0].kind, TokenKind::Placeholder(1)));
        assert!(matches!(tokens[1].kind, TokenKind::Placeholder(2)));
        assert!(matches!(tokens[2].kind, TokenKind::Placeholder(10)));
    }

    // `--` comments are skipped entirely and emit no tokens.
    #[test]
    fn test_line_comment() {
        let tokens = Lexer::new("SELECT -- comment\n* FROM users")
            .tokenize()
            .unwrap();
        assert_eq!(tokens.len(), 5); // SELECT, *, FROM, users, EOF
        assert_eq!(tokens[0].kind, TokenKind::Select);
        assert_eq!(tokens[1].kind, TokenKind::Star);
    }

    // X'..' hex literal decodes to the raw byte sequence.
    #[test]
    fn test_blob_literal() {
        let tokens = Lexer::new("X'48454C4C4F'").tokenize().unwrap();
        assert!(matches!(&tokens[0].kind, TokenKind::Blob(b) if b == b"HELLO"));
    }
}