1use crate::token::Token;
2
/// Largest accepted string literal, measured in bytes of its decoded contents (16 MiB).
const MAX_STRING_LITERAL: usize = 16 << 20;
6
/// Error produced by [`lex`] when the input cannot be tokenized.
///
/// Derives `Clone`/`PartialEq`/`Eq` so callers and tests can store, copy and
/// compare errors directly instead of inspecting the fields one by one.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LexError {
    /// Human-readable description of what went wrong.
    pub message: String,
    /// Location of the problem, counted in `char`s (not bytes) from the start
    /// of the input, because the lexer scans a `Vec<char>`.
    pub position: usize,
}
12
13impl std::fmt::Display for LexError {
14 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15 write!(f, "at position {}: {}", self.position, self.message)
16 }
17}
18
// `LexError` wraps no underlying cause, so the trait's default methods
// (e.g. `source()` returning `None`) are all that is needed.
impl std::error::Error for LexError {}
20
21pub fn lex(input: &str) -> Result<Vec<Token>, LexError> {
22 let mut tokens = Vec::new();
23 let chars: Vec<char> = input.chars().collect();
24 let mut pos = 0;
25
26 while pos < chars.len() {
27 if chars[pos].is_whitespace() {
29 pos += 1;
30 continue;
31 }
32
33 if chars[pos] == '#' {
35 while pos < chars.len() && chars[pos] != '\n' {
36 pos += 1;
37 }
38 continue;
39 }
40
41 if chars[pos] == '.'
43 && pos + 1 < chars.len()
44 && (chars[pos + 1].is_alphabetic() || chars[pos + 1] == '_')
45 {
46 pos += 1; let start = pos;
48 while pos < chars.len() && (chars[pos].is_alphanumeric() || chars[pos] == '_') {
49 pos += 1;
50 }
51 let name: String = chars[start..pos].iter().collect();
52 tokens.push(Token::DotIdent(name));
53 continue;
54 }
55
56 if chars[pos] == '$' {
58 pos += 1;
59 let start = pos;
60 while pos < chars.len() && (chars[pos].is_alphanumeric() || chars[pos] == '_') {
61 pos += 1;
62 }
63 let name: String = chars[start..pos].iter().collect();
64 tokens.push(Token::Param(name));
65 continue;
66 }
67
68 if chars[pos] == '"' {
70 pos += 1;
71 let mut s = String::new();
72 while pos < chars.len() && chars[pos] != '"' {
73 if chars[pos] == '\\' && pos + 1 < chars.len() {
74 match chars[pos + 1] {
75 '"' => {
76 s.push('"');
77 pos += 2;
78 }
79 '\\' => {
80 s.push('\\');
81 pos += 2;
82 }
83 'n' => {
84 s.push('\n');
85 pos += 2;
86 }
87 't' => {
88 s.push('\t');
89 pos += 2;
90 }
91 _ => {
92 s.push(chars[pos + 1]);
93 pos += 2;
94 }
95 }
96 } else {
97 s.push(chars[pos]);
98 pos += 1;
99 }
100 }
101 if pos >= chars.len() {
102 return Err(LexError {
103 message: "unterminated string".into(),
104 position: pos,
105 });
106 }
107 pos += 1; if s.len() > MAX_STRING_LITERAL {
109 return Err(LexError {
110 message: format!(
111 "string literal exceeds maximum size of {}MB",
112 MAX_STRING_LITERAL / (1024 * 1024)
113 ),
114 position: pos,
115 });
116 }
117 tokens.push(Token::StringLit(s));
118 continue;
119 }
120
121 if chars[pos].is_ascii_digit()
123 || (chars[pos] == '-' && pos + 1 < chars.len() && chars[pos + 1].is_ascii_digit())
124 {
125 let start = pos;
126 if chars[pos] == '-' {
127 pos += 1;
128 }
129 while pos < chars.len() && chars[pos].is_ascii_digit() {
130 pos += 1;
131 }
132 if pos < chars.len()
133 && chars[pos] == '.'
134 && pos + 1 < chars.len()
135 && chars[pos + 1].is_ascii_digit()
136 {
137 pos += 1;
138 while pos < chars.len() && chars[pos].is_ascii_digit() {
139 pos += 1;
140 }
141 let s: String = chars[start..pos].iter().collect();
142 let value = s.parse::<f64>().map_err(|_| LexError {
143 message: format!("float literal out of range: {s}"),
144 position: start,
145 })?;
146 tokens.push(Token::FloatLit(value));
147 } else {
148 let s: String = chars[start..pos].iter().collect();
149 let value = s.parse::<i64>().map_err(|_| LexError {
150 message: format!("integer literal out of range for i64: {s}"),
151 position: start,
152 })?;
153 tokens.push(Token::IntLit(value));
154 }
155 continue;
156 }
157
158 if chars[pos].is_alphabetic() || chars[pos] == '_' {
160 let start = pos;
161 while pos < chars.len() && (chars[pos].is_alphanumeric() || chars[pos] == '_') {
162 pos += 1;
163 }
164 let word: String = chars[start..pos].iter().collect();
165 let token = match word.as_str() {
166 "type" => Token::Type,
167 "filter" => Token::Filter,
168 "order" => Token::Order,
169 "limit" => Token::Limit,
170 "offset" => Token::Offset,
171 "insert" => Token::Insert,
172 "update" => Token::Update,
173 "delete" => Token::Delete,
174 "upsert" => Token::Upsert,
175 "conflict" => Token::Conflict,
176 "select" => Token::Select,
177 "required" => Token::Required,
178 "multi" => Token::Multi,
179 "link" => Token::Link,
180 "index" => Token::Index,
181 "on" => Token::On,
182 "asc" => Token::Asc,
183 "desc" => Token::Desc,
184 "and" => Token::And,
185 "or" => Token::Or,
186 "not" => Token::Not,
187 "exists" => Token::Exists,
188 "let" => Token::Let,
189 "as" => Token::As,
190 "match" => Token::Match,
191 "group" => Token::Group,
192 "join" => Token::Join,
193 "inner" => Token::Inner,
194 "left" => Token::LeftKw,
195 "right" => Token::RightKw,
196 "outer" => Token::Outer,
197 "cross" => Token::Cross,
198 "transaction" => Token::Transaction,
199 "view" => Token::View,
200 "materialized" => Token::Materialized,
201 "materialize" => Token::Materialized,
202 "refresh" => Token::Refresh,
203 "union" => Token::Union,
204 "having" => Token::Having,
205 "distinct" => Token::Distinct,
206 "in" => Token::In,
207 "between" => Token::Between,
208 "like" => Token::Like,
209 "count" => Token::Count,
210 "avg" => Token::Avg,
211 "sum" => Token::Sum,
212 "min" => Token::Min,
213 "max" => Token::Max,
214 "is" => Token::Is,
215 "null" => Token::Null,
216 "upper" => Token::Upper,
217 "lower" => Token::Lower,
218 "length" => Token::Length,
219 "trim" => Token::Trim,
220 "substring" => Token::Substring,
221 "concat" => Token::Concat,
222 "abs" => Token::Abs,
223 "round" => Token::Round,
224 "ceil" => Token::Ceil,
225 "floor" => Token::Floor,
226 "sqrt" => Token::Sqrt,
227 "pow" => Token::Pow,
228 "now" => Token::Now,
229 "extract" => Token::Extract,
230 "date_add" => Token::DateAdd,
231 "date_diff" => Token::DateDiff,
232 "cast" => Token::Cast,
233 "case" => Token::Case,
234 "when" => Token::When,
235 "then" => Token::Then,
236 "else" => Token::Else,
237 "end" => Token::End,
238 "over" => Token::Over,
239 "partition" => Token::Partition,
240 "row_number" => Token::RowNumber,
241 "rank" => Token::Rank,
242 "dense_rank" => Token::DenseRank,
243 "alter" => Token::Alter,
244 "drop" => Token::Drop,
245 "add" => Token::Add,
246 "column" => Token::Column,
247 "explain" => Token::Explain,
248 "true" => Token::BoolLit(true),
249 "false" => Token::BoolLit(false),
250 _ => Token::Ident(word),
251 };
252 tokens.push(token);
253 continue;
254 }
255
256 if pos + 1 < chars.len() {
258 let two: String = chars[pos..pos + 2].iter().collect();
259 match two.as_str() {
260 ":=" => {
261 tokens.push(Token::Assign);
262 pos += 2;
263 continue;
264 }
265 "->" => {
266 tokens.push(Token::Arrow);
267 pos += 2;
268 continue;
269 }
270 "!=" => {
271 tokens.push(Token::Neq);
272 pos += 2;
273 continue;
274 }
275 "<=" => {
276 tokens.push(Token::Lte);
277 pos += 2;
278 continue;
279 }
280 ">=" => {
281 tokens.push(Token::Gte);
282 pos += 2;
283 continue;
284 }
285 "??" => {
286 tokens.push(Token::Coalesce);
287 pos += 2;
288 continue;
289 }
290 _ => {}
291 }
292 }
293
294 let token = match chars[pos] {
296 '=' => Token::Eq,
297 '<' => Token::Lt,
298 '>' => Token::Gt,
299 '|' => Token::Pipe,
300 '+' => Token::Plus,
301 '-' => Token::Minus,
302 '*' => Token::Star,
303 '/' => Token::Slash,
304 '{' => Token::LBrace,
305 '}' => Token::RBrace,
306 '(' => Token::LParen,
307 ')' => Token::RParen,
308 ',' => Token::Comma,
309 ':' => Token::Colon,
310 '.' => Token::Dot,
311 c => {
312 return Err(LexError {
313 message: format!("unexpected character: {c}"),
314 position: pos,
315 })
316 }
317 };
318 tokens.push(token);
319 pos += 1;
320 }
321
322 tokens.push(Token::Eof);
323 Ok(tokens)
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use crate::token::Token;

    /// Lexes `src`, panicking if the lexer reports an error.
    fn lex_ok(src: &str) -> Vec<Token> {
        lex(src).unwrap()
    }

    /// Shorthand for an identifier token.
    fn ident(name: &str) -> Token {
        Token::Ident(name.into())
    }

    /// Shorthand for a `.field` token.
    fn dot_ident(name: &str) -> Token {
        Token::DotIdent(name.into())
    }

    #[test]
    fn test_lex_simple_query() {
        let expected = vec![
            ident("User"),
            Token::Filter,
            dot_ident("age"),
            Token::Gt,
            Token::IntLit(30),
            Token::Eof,
        ];
        assert_eq!(lex_ok("User filter .age > 30"), expected);
    }

    #[test]
    fn test_lex_projection() {
        let expected = vec![
            ident("User"),
            Token::LBrace,
            ident("name"),
            Token::Comma,
            ident("email"),
            Token::RBrace,
            Token::Eof,
        ];
        assert_eq!(lex_ok("User { name, email }"), expected);
    }

    #[test]
    fn test_lex_insert() {
        let expected = vec![
            Token::Insert,
            ident("User"),
            Token::LBrace,
            ident("name"),
            Token::Assign,
            Token::StringLit("Alice".into()),
            Token::Comma,
            ident("age"),
            Token::Assign,
            Token::IntLit(30),
            Token::RBrace,
            Token::Eof,
        ];
        assert_eq!(
            lex_ok(r#"insert User { name := "Alice", age := 30 }"#),
            expected
        );
    }

    #[test]
    fn test_lex_params() {
        let expected = vec![
            ident("User"),
            Token::Filter,
            dot_ident("age"),
            Token::Gt,
            Token::Param("min_age".into()),
            Token::Eof,
        ];
        assert_eq!(lex_ok("User filter .age > $min_age"), expected);
    }

    #[test]
    fn test_lex_string_with_escapes() {
        // `\"` inside the literal must decode to a plain double quote.
        let expected = vec![Token::StringLit("hello \"world\"".into()), Token::Eof];
        assert_eq!(lex_ok(r#""hello \"world\"""#), expected);
    }

    #[test]
    fn test_lex_aggregation() {
        let expected = vec![
            Token::Count,
            Token::LParen,
            ident("User"),
            Token::RParen,
            Token::Eof,
        ];
        assert_eq!(lex_ok("count(User)"), expected);
    }

    #[test]
    fn test_lex_intlit_overflow_returns_err() {
        // An integer wider than i64 must surface as a LexError, not a panic.
        let result = lex("4444444441111111144444");
        let err = result.expect_err("must error, not panic");
        assert!(
            err.message.contains("integer literal out of range"),
            "unexpected message: {}",
            err.message
        );
        assert_eq!(err.position, 0);
    }

    #[test]
    fn test_lex_fuzz_repro_issue_24() {
        // Fuzzer-found input: keyword, tabs, then an integer that overflows i64.
        let input = "as\t\t\t\t\t\t\t\t\t\t\t\t\t44444444411111114444\t\t\t\t\t\t";
        let err = lex(input).expect_err("fuzz reproducer must now error, not panic");
        assert!(err.message.contains("integer literal"));
    }
}