Skip to main content

hematite/parser/
lexer.rs

1//! SQL query lexer for tokenizing SQL statements
2
3use crate::error::{HematiteError, Result};
4
/// A lexical token produced by [`Lexer::tokenize`].
///
/// Keyword variants are matched against UPPERCASE spellings in
/// `read_identifier`; any other word becomes [`Token::Identifier`].
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // Keywords (transaction control, DML/DDL, clauses, type names, constraints)
    Begin,
    Commit,
    Rollback,
    Savepoint,
    Release,
    Select,
    Update,
    From,
    Insert,
    Delete,
    Drop,
    Explain,
    Describe,
    Show,
    Tables,
    Views,
    Indexes,
    Triggers,
    Alter,
    Add,
    If,
    Into,
    Set,
    Values,
    Create,
    View,
    Trigger,
    Index,
    Exists,
    Union,
    Intersect,
    Except,
    All,
    With,
    Recursive,
    Left,
    Right,
    Full,
    Outer,
    Inner,
    Join,
    On,
    As,
    Distinct,
    Cast,
    Table,
    Column,
    Where,
    Group,
    Having,
    Order,
    By,
    Asc,
    Desc,
    Over,
    Partition,
    Interval,
    Limit,
    Offset,
    Count,
    Sum,
    Avg,
    Min,
    Max,
    Int32,
    Text,
    Boolean,
    Float,
    Int8,
    Int16,
    Int64,
    Int128,
    Int,
    UInt8,
    UInt16,
    UInt64,
    UInt128,
    UInt32,
    UInt,
    Bool,
    Float32,
    Float64,
    Decimal,
    Blob,
    Date,
    Time,
    DateTime,
    Zone,
    Char,
    Varchar,
    // Named `BinaryType` (not `Binary`) to avoid clashing; lexed from "BINARY".
    BinaryType,
    VarBinary,
    Enum,
    AutoIncrement,
    Unique,
    Primary,
    Key,
    Duplicate,
    Constraint,
    Check,
    Foreign,
    References,
    Cascade,
    Restrict,
    Rename,
    To,
    After,
    Not,
    Is,
    Null,
    Default,
    In,
    Between,
    Like,
    Case,
    When,
    Then,
    Else,
    End,
    And,
    Or,

    // Comparison operators
    Equal,
    NotEqual,
    LessThan,
    LessThanOrEqual,
    GreaterThan,
    GreaterThanOrEqual,
    // Note: logical operators are tokenized as keywords (AND/OR) and as symbols (&&/||)

    // Punctuation and arithmetic symbols
    Comma,
    Dot,
    Semicolon,
    LeftParen,
    RightParen,
    Plus,
    Minus,
    Asterisk,
    Slash,
    Percent,
    // `?` — positional bind parameter in prepared statements.
    Placeholder,

    // Literals (payload carries the lexed value; numbers keep their
    // original textual form so the parser can choose int vs float)
    Identifier(String),
    StringLiteral(String),
    BlobLiteral(Vec<u8>),
    NumberLiteral(String),
    BooleanLiteral(bool),
    NullLiteral,
}
160
/// Single-pass SQL lexer: walks `input` by byte offset and accumulates
/// the resulting [`Token`]s.
#[derive(Debug, Clone)]
pub struct Lexer {
    /// Raw SQL text being tokenized.
    input: String,
    /// Current byte offset into `input` (always kept on a `char` boundary
    /// by advancing in `len_utf8` steps).
    position: usize,
    /// Tokens produced so far; filled by `tokenize`, read via `get_tokens`.
    tokens: Vec<Token>,
}
167
168impl Lexer {
169    pub fn new(input: String) -> Self {
170        Self {
171            input,
172            position: 0,
173            tokens: Vec::new(),
174        }
175    }
176
177    pub fn tokenize(&mut self) -> Result<()> {
178        while self.position < self.input.len() {
179            self.skip_whitespace();
180
181            if self.position >= self.input.len() {
182                break;
183            }
184
185            let ch = self.current_char();
186
187            // Handle identifiers and keywords
188            if ch == 'X' && self.peek_char() == Some('\'') {
189                self.read_blob_literal()?;
190            } else if ch.is_alphabetic() || ch == '_' {
191                self.read_identifier()?;
192            } else if ch == '`' {
193                self.read_quoted_identifier()?;
194            }
195            // Handle string literals
196            else if ch == '\'' {
197                self.read_string_literal()?;
198            }
199            // Handle numbers
200            else if ch.is_ascii_digit() {
201                self.read_number()?;
202            }
203            // Handle operators and punctuation
204            else {
205                self.read_operator_or_punctuation()?;
206            }
207        }
208
209        Ok(())
210    }
211
212    pub fn get_tokens(&self) -> &[Token] {
213        &self.tokens
214    }
215
216    fn skip_whitespace(&mut self) {
217        while self.position < self.input.len() {
218            let ch = self.current_char();
219            if !ch.is_whitespace() {
220                break;
221            }
222            self.advance_char();
223        }
224    }
225
226    fn current_char(&self) -> char {
227        self.input[self.position..].chars().next().unwrap_or('\0')
228    }
229
230    fn peek_char(&self) -> Option<char> {
231        let mut chars = self.input[self.position..].chars();
232        chars.next()?;
233        chars.next()
234    }
235
236    fn advance_char(&mut self) {
237        if self.position < self.input.len() {
238            self.position += self.current_char().len_utf8();
239        }
240    }
241
242    fn read_identifier(&mut self) -> Result<()> {
243        let start = self.position;
244
245        while self.position < self.input.len() {
246            let ch = self.current_char();
247            if ch.is_alphanumeric() || ch == '_' {
248                self.advance_char();
249            } else {
250                break;
251            }
252        }
253
254        let identifier = &self.input[start..self.position];
255        let token = match identifier {
256            "BEGIN" => Token::Begin,
257            "COMMIT" => Token::Commit,
258            "ROLLBACK" => Token::Rollback,
259            "SAVEPOINT" => Token::Savepoint,
260            "RELEASE" => Token::Release,
261            "SELECT" => Token::Select,
262            "UPDATE" => Token::Update,
263            "FROM" => Token::From,
264            "INSERT" => Token::Insert,
265            "DELETE" => Token::Delete,
266            "DROP" => Token::Drop,
267            "EXPLAIN" => Token::Explain,
268            "DESCRIBE" => Token::Describe,
269            "SHOW" => Token::Show,
270            "TABLES" => Token::Tables,
271            "VIEWS" => Token::Views,
272            "INDEXES" => Token::Indexes,
273            "TRIGGERS" => Token::Triggers,
274            "ALTER" => Token::Alter,
275            "ADD" => Token::Add,
276            "IF" => Token::If,
277            "INTO" => Token::Into,
278            "SET" => Token::Set,
279            "VALUES" => Token::Values,
280            "CREATE" => Token::Create,
281            "VIEW" => Token::View,
282            "TRIGGER" => Token::Trigger,
283            "INDEX" => Token::Index,
284            "EXISTS" => Token::Exists,
285            "UNION" => Token::Union,
286            "INTERSECT" => Token::Intersect,
287            "EXCEPT" => Token::Except,
288            "ALL" => Token::All,
289            "WITH" => Token::With,
290            "RECURSIVE" => Token::Recursive,
291            "LEFT" => Token::Left,
292            "RIGHT" => Token::Right,
293            "FULL" => Token::Full,
294            "OUTER" => Token::Outer,
295            "INNER" => Token::Inner,
296            "JOIN" => Token::Join,
297            "ON" => Token::On,
298            "AS" => Token::As,
299            "DISTINCT" => Token::Distinct,
300            "CAST" => Token::Cast,
301            "TABLE" => Token::Table,
302            "COLUMN" => Token::Column,
303            "WHERE" => Token::Where,
304            "GROUP" => Token::Group,
305            "HAVING" => Token::Having,
306            "ORDER" => Token::Order,
307            "BY" => Token::By,
308            "ASC" => Token::Asc,
309            "DESC" => Token::Desc,
310            "OVER" => Token::Over,
311            "PARTITION" => Token::Partition,
312            "INTERVAL" => Token::Interval,
313            "LIMIT" => Token::Limit,
314            "OFFSET" => Token::Offset,
315            "COUNT" => Token::Count,
316            "SUM" => Token::Sum,
317            "AVG" => Token::Avg,
318            "MIN" => Token::Min,
319            "MAX" => Token::Max,
320            "INT8" => Token::Int8,
321            "INT16" => Token::Int16,
322            "INT64" => Token::Int64,
323            "INT128" => Token::Int128,
324            "INT32" => Token::Int32,
325            "INT" => Token::Int,
326            "UINT8" => Token::UInt8,
327            "UINT16" => Token::UInt16,
328            "UINT64" => Token::UInt64,
329            "UINT128" => Token::UInt128,
330            "UINT32" => Token::UInt32,
331            "UINT" => Token::UInt,
332            "TEXT" => Token::Text,
333            "BOOLEAN" => Token::Boolean,
334            "BOOL" => Token::Bool,
335            "FLOAT" => Token::Float,
336            "FLOAT32" => Token::Float32,
337            "FLOAT64" => Token::Float64,
338            "DECIMAL" => Token::Decimal,
339            "BLOB" => Token::Blob,
340            "DATE" => Token::Date,
341            "TIME" => Token::Time,
342            "DATETIME" => Token::DateTime,
343            "ZONE" => Token::Zone,
344            "CHAR" => Token::Char,
345            "VARCHAR" => Token::Varchar,
346            "BINARY" => Token::BinaryType,
347            "VARBINARY" => Token::VarBinary,
348            "ENUM" => Token::Enum,
349            "AUTO_INCREMENT" => Token::AutoIncrement,
350            "UNIQUE" => Token::Unique,
351            "PRIMARY" => Token::Primary,
352            "KEY" => Token::Key,
353            "DUPLICATE" => Token::Duplicate,
354            "CONSTRAINT" => Token::Constraint,
355            "CHECK" => Token::Check,
356            "FOREIGN" => Token::Foreign,
357            "REFERENCES" => Token::References,
358            "CASCADE" => Token::Cascade,
359            "RESTRICT" => Token::Restrict,
360            "RENAME" => Token::Rename,
361            "TO" => Token::To,
362            "AFTER" => Token::After,
363            "NOT" => Token::Not,
364            "IS" => Token::Is,
365            "NULL" => Token::Null,
366            "DEFAULT" => Token::Default,
367            "IN" => Token::In,
368            "BETWEEN" => Token::Between,
369            "LIKE" => Token::Like,
370            "CASE" => Token::Case,
371            "WHEN" => Token::When,
372            "THEN" => Token::Then,
373            "ELSE" => Token::Else,
374            "END" => Token::End,
375            "AND" => Token::And,
376            "OR" => Token::Or,
377            "TRUE" => Token::BooleanLiteral(true),
378            "FALSE" => Token::BooleanLiteral(false),
379            _ => Token::Identifier(identifier.to_string()),
380        };
381
382        self.tokens.push(token);
383        Ok(())
384    }
385
386    fn read_quoted_identifier(&mut self) -> Result<()> {
387        self.advance_char();
388        let mut identifier = String::new();
389
390        while self.position < self.input.len() {
391            let ch = self.current_char();
392            if ch == '`' {
393                if self.peek_char() == Some('`') {
394                    identifier.push('`');
395                    self.advance_char();
396                    self.advance_char();
397                    continue;
398                }
399
400                self.advance_char();
401                self.tokens.push(Token::Identifier(identifier));
402                return Ok(());
403            }
404
405            identifier.push(ch);
406            self.advance_char();
407        }
408
409        Err(HematiteError::ParseError(
410            "Unterminated quoted identifier".to_string(),
411        ))
412    }
413
414    fn read_string_literal(&mut self) -> Result<()> {
415        self.advance_char(); // Skip opening quote
416        let mut literal = String::new();
417
418        while self.position < self.input.len() {
419            let ch = self.current_char();
420            if ch == '\'' {
421                if self.peek_char() == Some('\'') {
422                    literal.push('\'');
423                    self.advance_char();
424                    self.advance_char();
425                    continue;
426                }
427
428                self.tokens.push(Token::StringLiteral(literal));
429                self.advance_char(); // Skip closing quote
430                return Ok(());
431            }
432
433            if ch == '\\' {
434                if let Some(next) = self.peek_char() {
435                    if next == '\'' || next == '\\' {
436                        literal.push(next);
437                        self.advance_char();
438                        self.advance_char();
439                        continue;
440                    }
441                }
442            }
443
444            literal.push(ch);
445            self.advance_char();
446        }
447
448        Err(HematiteError::ParseError(
449            "Unterminated string literal".to_string(),
450        ))
451    }
452
453    fn read_blob_literal(&mut self) -> Result<()> {
454        self.advance_char(); // Skip X
455        self.advance_char(); // Skip opening quote
456        let mut literal = String::new();
457
458        while self.position < self.input.len() {
459            let ch = self.current_char();
460            if ch == '\'' {
461                if literal.len() % 2 != 0 {
462                    return Err(HematiteError::ParseError(
463                        "Hex blob literal must contain an even number of digits".to_string(),
464                    ));
465                }
466
467                let mut bytes = Vec::with_capacity(literal.len() / 2);
468                for index in (0..literal.len()).step_by(2) {
469                    let byte =
470                        u8::from_str_radix(&literal[index..index + 2], 16).map_err(|_| {
471                            HematiteError::ParseError("Invalid hex blob literal".to_string())
472                        })?;
473                    bytes.push(byte);
474                }
475
476                self.tokens.push(Token::BlobLiteral(bytes));
477                self.advance_char(); // Skip closing quote
478                return Ok(());
479            }
480
481            if !ch.is_ascii_hexdigit() {
482                return Err(HematiteError::ParseError(
483                    "Hex blob literal may only contain hexadecimal digits".to_string(),
484                ));
485            }
486
487            literal.push(ch);
488            self.advance_char();
489        }
490
491        Err(HematiteError::ParseError(
492            "Unterminated hex blob literal".to_string(),
493        ))
494    }
495
496    fn read_number(&mut self) -> Result<()> {
497        let start = self.position;
498        let mut has_decimal = false;
499
500        while self.position < self.input.len() {
501            let ch = self.current_char();
502            if ch == '.' {
503                if has_decimal {
504                    return Err(HematiteError::ParseError(
505                        "Invalid number format".to_string(),
506                    ));
507                }
508                has_decimal = true;
509                self.advance_char();
510            } else if ch.is_ascii_digit() {
511                self.advance_char();
512            } else {
513                break;
514            }
515        }
516
517        let number_str = &self.input[start..self.position];
518        if has_decimal {
519            number_str
520                .parse::<f64>()
521                .map_err(|_| HematiteError::ParseError("Invalid number".to_string()))?;
522        } else {
523            number_str
524                .parse::<i128>()
525                .map_err(|_| HematiteError::ParseError("Invalid integer".to_string()))?;
526        }
527
528        self.tokens
529            .push(Token::NumberLiteral(number_str.to_string()));
530
531        Ok(())
532    }
533
534    fn read_operator_or_punctuation(&mut self) -> Result<()> {
535        let ch = self.current_char();
536        let token = match ch {
537            '=' => Token::Equal,
538            '!' => {
539                if self.peek_char() == Some('=') {
540                    self.advance_char();
541                    Token::NotEqual
542                } else {
543                    Token::Not
544                }
545            }
546            '<' => {
547                if self.peek_char() == Some('=') {
548                    self.advance_char();
549                    Token::LessThanOrEqual
550                } else if self.peek_char() == Some('>') {
551                    self.advance_char();
552                    Token::NotEqual
553                } else {
554                    Token::LessThan
555                }
556            }
557            '>' => {
558                if self.peek_char() == Some('=') {
559                    self.advance_char();
560                    Token::GreaterThanOrEqual
561                } else {
562                    Token::GreaterThan
563                }
564            }
565            '&' => {
566                if self.peek_char() == Some('&') {
567                    self.advance_char();
568                    Token::And
569                } else {
570                    return Err(HematiteError::ParseError("Invalid operator".to_string()));
571                }
572            }
573            '|' => {
574                if self.peek_char() == Some('|') {
575                    self.advance_char();
576                    Token::Or
577                } else {
578                    return Err(HematiteError::ParseError("Invalid operator".to_string()));
579                }
580            }
581            ',' => Token::Comma,
582            '.' => Token::Dot,
583            ';' => Token::Semicolon,
584            '(' => Token::LeftParen,
585            ')' => Token::RightParen,
586            '+' => Token::Plus,
587            '-' => Token::Minus,
588            '*' => Token::Asterisk,
589            '/' => Token::Slash,
590            '%' => Token::Percent,
591            '?' => Token::Placeholder,
592            _ => {
593                return Err(HematiteError::ParseError(format!(
594                    "Unexpected character: {}",
595                    ch
596                )))
597            }
598        };
599
600        self.advance_char();
601        self.tokens.push(token);
602        Ok(())
603    }
604}