Skip to main content

powdb_query/
lexer.rs

1use crate::token::Token;
2
/// Largest string literal the lexer accepts, measured in bytes: 16 MB.
///
/// Anything bigger is rejected so that a single query cannot smuggle in a
/// multi-gigabyte string and drive memory consumption without bound.
const MAX_STRING_LITERAL: usize = 16 << 20;
6
/// Error returned when the lexer cannot turn its input into tokens.
///
/// Derives `Clone`, `PartialEq` and `Eq` in addition to `Debug` so callers
/// and tests can store, duplicate and compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LexError {
    /// Human-readable description of what went wrong.
    pub message: String,
    /// Position in the input at which the problem was reported.
    pub position: usize,
}

impl std::fmt::Display for LexError {
    /// Renders the error as `at position <position>: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "at position {}: {}", self.position, self.message)
    }
}

// No underlying cause is tracked, so the default `Error` methods suffice.
impl std::error::Error for LexError {}
20
21pub fn lex(input: &str) -> Result<Vec<Token>, LexError> {
22    let mut tokens = Vec::new();
23    let chars: Vec<char> = input.chars().collect();
24    let mut pos = 0;
25
26    while pos < chars.len() {
27        // Skip whitespace
28        if chars[pos].is_whitespace() {
29            pos += 1;
30            continue;
31        }
32
33        // Skip comments
34        if chars[pos] == '#' {
35            while pos < chars.len() && chars[pos] != '\n' {
36                pos += 1;
37            }
38            continue;
39        }
40
41        // Dot-ident: .fieldname
42        if chars[pos] == '.'
43            && pos + 1 < chars.len()
44            && (chars[pos + 1].is_alphabetic() || chars[pos + 1] == '_')
45        {
46            pos += 1; // skip dot
47            let start = pos;
48            while pos < chars.len() && (chars[pos].is_alphanumeric() || chars[pos] == '_') {
49                pos += 1;
50            }
51            let name: String = chars[start..pos].iter().collect();
52            tokens.push(Token::DotIdent(name));
53            continue;
54        }
55
56        // Param: $name
57        if chars[pos] == '$' {
58            pos += 1;
59            let start = pos;
60            while pos < chars.len() && (chars[pos].is_alphanumeric() || chars[pos] == '_') {
61                pos += 1;
62            }
63            let name: String = chars[start..pos].iter().collect();
64            tokens.push(Token::Param(name));
65            continue;
66        }
67
68        // String literal
69        if chars[pos] == '"' {
70            pos += 1;
71            let mut s = String::new();
72            while pos < chars.len() && chars[pos] != '"' {
73                if chars[pos] == '\\' && pos + 1 < chars.len() {
74                    match chars[pos + 1] {
75                        '"' => {
76                            s.push('"');
77                            pos += 2;
78                        }
79                        '\\' => {
80                            s.push('\\');
81                            pos += 2;
82                        }
83                        'n' => {
84                            s.push('\n');
85                            pos += 2;
86                        }
87                        't' => {
88                            s.push('\t');
89                            pos += 2;
90                        }
91                        _ => {
92                            s.push(chars[pos + 1]);
93                            pos += 2;
94                        }
95                    }
96                } else {
97                    s.push(chars[pos]);
98                    pos += 1;
99                }
100            }
101            if pos >= chars.len() {
102                return Err(LexError {
103                    message: "unterminated string".into(),
104                    position: pos,
105                });
106            }
107            pos += 1; // closing quote
108            if s.len() > MAX_STRING_LITERAL {
109                return Err(LexError {
110                    message: format!(
111                        "string literal exceeds maximum size of {}MB",
112                        MAX_STRING_LITERAL / (1024 * 1024)
113                    ),
114                    position: pos,
115                });
116            }
117            tokens.push(Token::StringLit(s));
118            continue;
119        }
120
121        // Number (int or float)
122        if chars[pos].is_ascii_digit()
123            || (chars[pos] == '-' && pos + 1 < chars.len() && chars[pos + 1].is_ascii_digit())
124        {
125            let start = pos;
126            if chars[pos] == '-' {
127                pos += 1;
128            }
129            while pos < chars.len() && chars[pos].is_ascii_digit() {
130                pos += 1;
131            }
132            if pos < chars.len()
133                && chars[pos] == '.'
134                && pos + 1 < chars.len()
135                && chars[pos + 1].is_ascii_digit()
136            {
137                pos += 1;
138                while pos < chars.len() && chars[pos].is_ascii_digit() {
139                    pos += 1;
140                }
141                let s: String = chars[start..pos].iter().collect();
142                let value = s.parse::<f64>().map_err(|_| LexError {
143                    message: format!("float literal out of range: {s}"),
144                    position: start,
145                })?;
146                tokens.push(Token::FloatLit(value));
147            } else {
148                let s: String = chars[start..pos].iter().collect();
149                let value = s.parse::<i64>().map_err(|_| LexError {
150                    message: format!("integer literal out of range for i64: {s}"),
151                    position: start,
152                })?;
153                tokens.push(Token::IntLit(value));
154            }
155            continue;
156        }
157
158        // Identifiers and keywords
159        if chars[pos].is_alphabetic() || chars[pos] == '_' {
160            let start = pos;
161            while pos < chars.len() && (chars[pos].is_alphanumeric() || chars[pos] == '_') {
162                pos += 1;
163            }
164            let word: String = chars[start..pos].iter().collect();
165            let token = match word.as_str() {
166                "type" => Token::Type,
167                "filter" => Token::Filter,
168                "order" => Token::Order,
169                "limit" => Token::Limit,
170                "offset" => Token::Offset,
171                "insert" => Token::Insert,
172                "update" => Token::Update,
173                "delete" => Token::Delete,
174                "upsert" => Token::Upsert,
175                "conflict" => Token::Conflict,
176                "select" => Token::Select,
177                "required" => Token::Required,
178                "multi" => Token::Multi,
179                "link" => Token::Link,
180                "index" => Token::Index,
181                "on" => Token::On,
182                "asc" => Token::Asc,
183                "desc" => Token::Desc,
184                "and" => Token::And,
185                "or" => Token::Or,
186                "not" => Token::Not,
187                "exists" => Token::Exists,
188                "let" => Token::Let,
189                "as" => Token::As,
190                "match" => Token::Match,
191                "group" => Token::Group,
192                "join" => Token::Join,
193                "inner" => Token::Inner,
194                "left" => Token::LeftKw,
195                "right" => Token::RightKw,
196                "outer" => Token::Outer,
197                "cross" => Token::Cross,
198                "transaction" => Token::Transaction,
199                "view" => Token::View,
200                "materialized" => Token::Materialized,
201                "materialize" => Token::Materialized,
202                "refresh" => Token::Refresh,
203                "union" => Token::Union,
204                "having" => Token::Having,
205                "distinct" => Token::Distinct,
206                "in" => Token::In,
207                "between" => Token::Between,
208                "like" => Token::Like,
209                "count" => Token::Count,
210                "avg" => Token::Avg,
211                "sum" => Token::Sum,
212                "min" => Token::Min,
213                "max" => Token::Max,
214                "is" => Token::Is,
215                "null" => Token::Null,
216                "upper" => Token::Upper,
217                "lower" => Token::Lower,
218                "length" => Token::Length,
219                "trim" => Token::Trim,
220                "substring" => Token::Substring,
221                "concat" => Token::Concat,
222                "abs" => Token::Abs,
223                "round" => Token::Round,
224                "ceil" => Token::Ceil,
225                "floor" => Token::Floor,
226                "sqrt" => Token::Sqrt,
227                "pow" => Token::Pow,
228                "now" => Token::Now,
229                "extract" => Token::Extract,
230                "date_add" => Token::DateAdd,
231                "date_diff" => Token::DateDiff,
232                "cast" => Token::Cast,
233                "case" => Token::Case,
234                "when" => Token::When,
235                "then" => Token::Then,
236                "else" => Token::Else,
237                "end" => Token::End,
238                "over" => Token::Over,
239                "partition" => Token::Partition,
240                "row_number" => Token::RowNumber,
241                "rank" => Token::Rank,
242                "dense_rank" => Token::DenseRank,
243                "alter" => Token::Alter,
244                "drop" => Token::Drop,
245                "add" => Token::Add,
246                "column" => Token::Column,
247                "explain" => Token::Explain,
248                "true" => Token::BoolLit(true),
249                "false" => Token::BoolLit(false),
250                _ => Token::Ident(word),
251            };
252            tokens.push(token);
253            continue;
254        }
255
256        // Two-char operators
257        if pos + 1 < chars.len() {
258            let two: String = chars[pos..pos + 2].iter().collect();
259            match two.as_str() {
260                ":=" => {
261                    tokens.push(Token::Assign);
262                    pos += 2;
263                    continue;
264                }
265                "->" => {
266                    tokens.push(Token::Arrow);
267                    pos += 2;
268                    continue;
269                }
270                "!=" => {
271                    tokens.push(Token::Neq);
272                    pos += 2;
273                    continue;
274                }
275                "<=" => {
276                    tokens.push(Token::Lte);
277                    pos += 2;
278                    continue;
279                }
280                ">=" => {
281                    tokens.push(Token::Gte);
282                    pos += 2;
283                    continue;
284                }
285                "??" => {
286                    tokens.push(Token::Coalesce);
287                    pos += 2;
288                    continue;
289                }
290                _ => {}
291            }
292        }
293
294        // Single-char operators
295        let token = match chars[pos] {
296            '=' => Token::Eq,
297            '<' => Token::Lt,
298            '>' => Token::Gt,
299            '|' => Token::Pipe,
300            '+' => Token::Plus,
301            '-' => Token::Minus,
302            '*' => Token::Star,
303            '/' => Token::Slash,
304            '{' => Token::LBrace,
305            '}' => Token::RBrace,
306            '(' => Token::LParen,
307            ')' => Token::RParen,
308            ',' => Token::Comma,
309            ':' => Token::Colon,
310            '.' => Token::Dot,
311            c => {
312                return Err(LexError {
313                    message: format!("unexpected character: {c}"),
314                    position: pos,
315                })
316            }
317        };
318        tokens.push(token);
319        pos += 1;
320    }
321
322    tokens.push(Token::Eof);
323    Ok(tokens)
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use crate::token::Token;

    /// Lexes `src`, which the calling test expects to be valid.
    fn lex_ok(src: &str) -> Vec<Token> {
        lex(src).unwrap()
    }

    /// Shorthand for an identifier token.
    fn ident(name: &str) -> Token {
        Token::Ident(name.into())
    }

    /// Shorthand for a `.field` token.
    fn field(name: &str) -> Token {
        Token::DotIdent(name.into())
    }

    #[test]
    fn test_lex_simple_query() {
        let expected = vec![
            ident("User"),
            Token::Filter,
            field("age"),
            Token::Gt,
            Token::IntLit(30),
            Token::Eof,
        ];
        assert_eq!(lex_ok("User filter .age > 30"), expected);
    }

    #[test]
    fn test_lex_projection() {
        let expected = vec![
            ident("User"),
            Token::LBrace,
            ident("name"),
            Token::Comma,
            ident("email"),
            Token::RBrace,
            Token::Eof,
        ];
        assert_eq!(lex_ok("User { name, email }"), expected);
    }

    #[test]
    fn test_lex_insert() {
        let expected = vec![
            Token::Insert,
            ident("User"),
            Token::LBrace,
            ident("name"),
            Token::Assign,
            Token::StringLit("Alice".into()),
            Token::Comma,
            ident("age"),
            Token::Assign,
            Token::IntLit(30),
            Token::RBrace,
            Token::Eof,
        ];
        assert_eq!(
            lex_ok(r#"insert User { name := "Alice", age := 30 }"#),
            expected
        );
    }

    #[test]
    fn test_lex_params() {
        let expected = vec![
            ident("User"),
            Token::Filter,
            field("age"),
            Token::Gt,
            Token::Param("min_age".into()),
            Token::Eof,
        ];
        assert_eq!(lex_ok("User filter .age > $min_age"), expected);
    }

    #[test]
    fn test_lex_string_with_escapes() {
        // The escaped quotes must come back as plain quotes inside the literal.
        let expected = vec![Token::StringLit("hello \"world\"".into()), Token::Eof];
        assert_eq!(lex_ok(r#""hello \"world\"""#), expected);
    }

    #[test]
    fn test_lex_aggregation() {
        let expected = vec![
            Token::Count,
            Token::LParen,
            ident("User"),
            Token::RParen,
            Token::Eof,
        ];
        assert_eq!(lex_ok("count(User)"), expected);
    }

    /// Regression test for issue #24. An integer literal too long for i64
    /// used to hit `s.parse::<i64>().unwrap()` and panic; the lexer must
    /// report a `LexError` pointing at the start of the literal instead.
    #[test]
    fn test_lex_intlit_overflow_returns_err() {
        // 22 digits, comfortably beyond the 19 digits of i64::MAX.
        let result = lex("4444444441111111144444");
        let err = result.expect_err("must error, not panic");
        assert!(
            err.message.contains("integer literal out of range"),
            "unexpected message: {}",
            err.message
        );
        assert_eq!(err.position, 0);
    }

    /// The same overflow, driven by the exact libFuzzer reproducer attached
    /// to issue #24 (base64
    /// `YXMJCQkJCQkJCQkJCQkJNDQ0NDQ0NDQ0MTExMTExMTQ0NDQJCQkJCQk=`).
    #[test]
    fn test_lex_fuzz_repro_issue_24() {
        let input = "as\t\t\t\t\t\t\t\t\t\t\t\t\t44444444411111114444\t\t\t\t\t\t";
        let result = lex(input);
        let err = result.expect_err("fuzz reproducer must now error, not panic");
        assert!(err.message.contains("integer literal"));
    }
}