Skip to main content

powdb_query/
lexer.rs

1use crate::token::Token;
2
/// Error produced by [`lex`] when the input cannot be tokenized.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LexError {
    /// Human-readable description of the problem.
    pub message: String,
    /// Index of the offending character, counted in `char`s from the start of
    /// the input (not a byte offset). As produced by [`lex`], an unterminated
    /// string literal reports the input length, i.e. one past the last char.
    pub position: usize,
}

impl std::fmt::Display for LexError {
    /// Formats the error as `at position <position>: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "at position {}: {}", self.position, self.message)
    }
}

/// A `LexError` wraps no underlying cause, so the default `source()` (`None`)
/// is used.
impl std::error::Error for LexError {}
16
17pub fn lex(input: &str) -> Result<Vec<Token>, LexError> {
18    let mut tokens = Vec::new();
19    let chars: Vec<char> = input.chars().collect();
20    let mut pos = 0;
21
22    while pos < chars.len() {
23        // Skip whitespace
24        if chars[pos].is_whitespace() {
25            pos += 1;
26            continue;
27        }
28
29        // Skip comments
30        if chars[pos] == '#' {
31            while pos < chars.len() && chars[pos] != '\n' {
32                pos += 1;
33            }
34            continue;
35        }
36
37        // Dot-ident: .fieldname
38        if chars[pos] == '.'
39            && pos + 1 < chars.len()
40            && (chars[pos + 1].is_alphabetic() || chars[pos + 1] == '_')
41        {
42            pos += 1; // skip dot
43            let start = pos;
44            while pos < chars.len() && (chars[pos].is_alphanumeric() || chars[pos] == '_') {
45                pos += 1;
46            }
47            let name: String = chars[start..pos].iter().collect();
48            tokens.push(Token::DotIdent(name));
49            continue;
50        }
51
52        // Param: $name
53        if chars[pos] == '$' {
54            pos += 1;
55            let start = pos;
56            while pos < chars.len() && (chars[pos].is_alphanumeric() || chars[pos] == '_') {
57                pos += 1;
58            }
59            let name: String = chars[start..pos].iter().collect();
60            tokens.push(Token::Param(name));
61            continue;
62        }
63
64        // String literal
65        if chars[pos] == '"' {
66            pos += 1;
67            let mut s = String::new();
68            while pos < chars.len() && chars[pos] != '"' {
69                if chars[pos] == '\\' && pos + 1 < chars.len() {
70                    match chars[pos + 1] {
71                        '"' => {
72                            s.push('"');
73                            pos += 2;
74                        }
75                        '\\' => {
76                            s.push('\\');
77                            pos += 2;
78                        }
79                        'n' => {
80                            s.push('\n');
81                            pos += 2;
82                        }
83                        't' => {
84                            s.push('\t');
85                            pos += 2;
86                        }
87                        _ => {
88                            s.push(chars[pos + 1]);
89                            pos += 2;
90                        }
91                    }
92                } else {
93                    s.push(chars[pos]);
94                    pos += 1;
95                }
96            }
97            if pos >= chars.len() {
98                return Err(LexError {
99                    message: "unterminated string".into(),
100                    position: pos,
101                });
102            }
103            pos += 1; // closing quote
104            tokens.push(Token::StringLit(s));
105            continue;
106        }
107
108        // Number (int or float)
109        if chars[pos].is_ascii_digit()
110            || (chars[pos] == '-' && pos + 1 < chars.len() && chars[pos + 1].is_ascii_digit())
111        {
112            let start = pos;
113            if chars[pos] == '-' {
114                pos += 1;
115            }
116            while pos < chars.len() && chars[pos].is_ascii_digit() {
117                pos += 1;
118            }
119            if pos < chars.len()
120                && chars[pos] == '.'
121                && pos + 1 < chars.len()
122                && chars[pos + 1].is_ascii_digit()
123            {
124                pos += 1;
125                while pos < chars.len() && chars[pos].is_ascii_digit() {
126                    pos += 1;
127                }
128                let s: String = chars[start..pos].iter().collect();
129                let value = s.parse::<f64>().map_err(|_| LexError {
130                    message: format!("float literal out of range: {s}"),
131                    position: start,
132                })?;
133                tokens.push(Token::FloatLit(value));
134            } else {
135                let s: String = chars[start..pos].iter().collect();
136                let value = s.parse::<i64>().map_err(|_| LexError {
137                    message: format!("integer literal out of range for i64: {s}"),
138                    position: start,
139                })?;
140                tokens.push(Token::IntLit(value));
141            }
142            continue;
143        }
144
145        // Identifiers and keywords
146        if chars[pos].is_alphabetic() || chars[pos] == '_' {
147            let start = pos;
148            while pos < chars.len() && (chars[pos].is_alphanumeric() || chars[pos] == '_') {
149                pos += 1;
150            }
151            let word: String = chars[start..pos].iter().collect();
152            let token = match word.as_str() {
153                "type" => Token::Type,
154                "filter" => Token::Filter,
155                "order" => Token::Order,
156                "limit" => Token::Limit,
157                "offset" => Token::Offset,
158                "insert" => Token::Insert,
159                "update" => Token::Update,
160                "delete" => Token::Delete,
161                "upsert" => Token::Upsert,
162                "conflict" => Token::Conflict,
163                "select" => Token::Select,
164                "required" => Token::Required,
165                "multi" => Token::Multi,
166                "link" => Token::Link,
167                "index" => Token::Index,
168                "on" => Token::On,
169                "asc" => Token::Asc,
170                "desc" => Token::Desc,
171                "and" => Token::And,
172                "or" => Token::Or,
173                "not" => Token::Not,
174                "exists" => Token::Exists,
175                "let" => Token::Let,
176                "as" => Token::As,
177                "match" => Token::Match,
178                "group" => Token::Group,
179                "join" => Token::Join,
180                "inner" => Token::Inner,
181                "left" => Token::LeftKw,
182                "right" => Token::RightKw,
183                "outer" => Token::Outer,
184                "cross" => Token::Cross,
185                "transaction" => Token::Transaction,
186                "view" => Token::View,
187                "materialized" => Token::Materialized,
188                "materialize" => Token::Materialized,
189                "refresh" => Token::Refresh,
190                "union" => Token::Union,
191                "having" => Token::Having,
192                "distinct" => Token::Distinct,
193                "in" => Token::In,
194                "between" => Token::Between,
195                "like" => Token::Like,
196                "count" => Token::Count,
197                "avg" => Token::Avg,
198                "sum" => Token::Sum,
199                "min" => Token::Min,
200                "max" => Token::Max,
201                "is" => Token::Is,
202                "null" => Token::Null,
203                "upper" => Token::Upper,
204                "lower" => Token::Lower,
205                "length" => Token::Length,
206                "trim" => Token::Trim,
207                "substring" => Token::Substring,
208                "concat" => Token::Concat,
209                "abs" => Token::Abs,
210                "round" => Token::Round,
211                "ceil" => Token::Ceil,
212                "floor" => Token::Floor,
213                "sqrt" => Token::Sqrt,
214                "pow" => Token::Pow,
215                "now" => Token::Now,
216                "extract" => Token::Extract,
217                "date_add" => Token::DateAdd,
218                "date_diff" => Token::DateDiff,
219                "cast" => Token::Cast,
220                "case" => Token::Case,
221                "when" => Token::When,
222                "then" => Token::Then,
223                "else" => Token::Else,
224                "end" => Token::End,
225                "over" => Token::Over,
226                "partition" => Token::Partition,
227                "row_number" => Token::RowNumber,
228                "rank" => Token::Rank,
229                "dense_rank" => Token::DenseRank,
230                "alter" => Token::Alter,
231                "drop" => Token::Drop,
232                "add" => Token::Add,
233                "column" => Token::Column,
234                "explain" => Token::Explain,
235                "true" => Token::BoolLit(true),
236                "false" => Token::BoolLit(false),
237                _ => Token::Ident(word),
238            };
239            tokens.push(token);
240            continue;
241        }
242
243        // Two-char operators
244        if pos + 1 < chars.len() {
245            let two: String = chars[pos..pos + 2].iter().collect();
246            match two.as_str() {
247                ":=" => {
248                    tokens.push(Token::Assign);
249                    pos += 2;
250                    continue;
251                }
252                "->" => {
253                    tokens.push(Token::Arrow);
254                    pos += 2;
255                    continue;
256                }
257                "!=" => {
258                    tokens.push(Token::Neq);
259                    pos += 2;
260                    continue;
261                }
262                "<=" => {
263                    tokens.push(Token::Lte);
264                    pos += 2;
265                    continue;
266                }
267                ">=" => {
268                    tokens.push(Token::Gte);
269                    pos += 2;
270                    continue;
271                }
272                "??" => {
273                    tokens.push(Token::Coalesce);
274                    pos += 2;
275                    continue;
276                }
277                _ => {}
278            }
279        }
280
281        // Single-char operators
282        let token = match chars[pos] {
283            '=' => Token::Eq,
284            '<' => Token::Lt,
285            '>' => Token::Gt,
286            '|' => Token::Pipe,
287            '+' => Token::Plus,
288            '-' => Token::Minus,
289            '*' => Token::Star,
290            '/' => Token::Slash,
291            '{' => Token::LBrace,
292            '}' => Token::RBrace,
293            '(' => Token::LParen,
294            ')' => Token::RParen,
295            ',' => Token::Comma,
296            ':' => Token::Colon,
297            '.' => Token::Dot,
298            c => {
299                return Err(LexError {
300                    message: format!("unexpected character: {c}"),
301                    position: pos,
302                })
303            }
304        };
305        tokens.push(token);
306        pos += 1;
307    }
308
309    tokens.push(Token::Eof);
310    Ok(tokens)
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::token::Token;

    /// Lexes `src` (panicking on a lex error) and checks that the result is
    /// exactly `body` followed by the terminating `Token::Eof`.
    fn assert_lexes(src: &str, body: Vec<Token>) {
        let expected: Vec<Token> = body
            .into_iter()
            .chain(std::iter::once(Token::Eof))
            .collect();
        assert_eq!(lex(src).unwrap(), expected);
    }

    #[test]
    fn test_lex_simple_query() {
        assert_lexes(
            "User filter .age > 30",
            vec![
                Token::Ident("User".into()),
                Token::Filter,
                Token::DotIdent("age".into()),
                Token::Gt,
                Token::IntLit(30),
            ],
        );
    }

    #[test]
    fn test_lex_projection() {
        assert_lexes(
            "User { name, email }",
            vec![
                Token::Ident("User".into()),
                Token::LBrace,
                Token::Ident("name".into()),
                Token::Comma,
                Token::Ident("email".into()),
                Token::RBrace,
            ],
        );
    }

    #[test]
    fn test_lex_insert() {
        assert_lexes(
            r#"insert User { name := "Alice", age := 30 }"#,
            vec![
                Token::Insert,
                Token::Ident("User".into()),
                Token::LBrace,
                Token::Ident("name".into()),
                Token::Assign,
                Token::StringLit("Alice".into()),
                Token::Comma,
                Token::Ident("age".into()),
                Token::Assign,
                Token::IntLit(30),
                Token::RBrace,
            ],
        );
    }

    #[test]
    fn test_lex_params() {
        assert_lexes(
            "User filter .age > $min_age",
            vec![
                Token::Ident("User".into()),
                Token::Filter,
                Token::DotIdent("age".into()),
                Token::Gt,
                Token::Param("min_age".into()),
            ],
        );
    }

    #[test]
    fn test_lex_string_with_escapes() {
        assert_lexes(
            r#""hello \"world\"""#,
            vec![Token::StringLit("hello \"world\"".into())],
        );
    }

    #[test]
    fn test_lex_aggregation() {
        assert_lexes(
            "count(User)",
            vec![
                Token::Count,
                Token::LParen,
                Token::Ident("User".into()),
                Token::RParen,
            ],
        );
    }

    /// Regression for issue #24: an integer literal with more digits than
    /// i64 can hold used to hit `s.parse::<i64>().unwrap()` and panic. It
    /// must produce a `LexError` instead.
    #[test]
    fn test_lex_intlit_overflow_returns_err() {
        // 22 digits — well past i64::MAX (19 digits).
        let error = lex("4444444441111111144444").expect_err("must error, not panic");
        let message = &error.message;
        assert!(
            message.contains("integer literal out of range"),
            "unexpected message: {}",
            message
        );
        assert_eq!(error.position, 0);
    }

    /// Same bug, triggered by the exact fuzzer reproducer from the libFuzzer
    /// artifact attached to issue #24 (base64
    /// `YXMJCQkJCQkJCQkJCQkJNDQ0NDQ0NDQ0MTExMTExMTQ0NDQJCQkJCQk=`).
    #[test]
    fn test_lex_fuzz_repro_issue_24() {
        let reproducer = "as\t\t\t\t\t\t\t\t\t\t\t\t\t44444444411111114444\t\t\t\t\t\t";
        let error = lex(reproducer).expect_err("fuzz reproducer must now error, not panic");
        assert!(error.message.contains("integer literal"));
    }
}