1use crate::token::Token;
2
/// Error produced by the `lex` function when the input cannot be tokenized.
///
/// `position` is an index into the input's `char`s (as yielded by `str::chars`), not a byte
/// offset; the two differ whenever the input contains non-ASCII text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LexError {
    /// Human-readable description of what went wrong.
    pub message: String,
    /// Character index at which the problem was detected.
    pub position: usize,
}

impl std::fmt::Display for LexError {
    /// Formats the error as `at position <position>: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "at position {}: {}", self.position, self.message)
    }
}

impl std::error::Error for LexError {}
16
17pub fn lex(input: &str) -> Result<Vec<Token>, LexError> {
18 let mut tokens = Vec::new();
19 let chars: Vec<char> = input.chars().collect();
20 let mut pos = 0;
21
22 while pos < chars.len() {
23 if chars[pos].is_whitespace() {
25 pos += 1;
26 continue;
27 }
28
29 if chars[pos] == '#' {
31 while pos < chars.len() && chars[pos] != '\n' {
32 pos += 1;
33 }
34 continue;
35 }
36
37 if chars[pos] == '.'
39 && pos + 1 < chars.len()
40 && (chars[pos + 1].is_alphabetic() || chars[pos + 1] == '_')
41 {
42 pos += 1; let start = pos;
44 while pos < chars.len() && (chars[pos].is_alphanumeric() || chars[pos] == '_') {
45 pos += 1;
46 }
47 let name: String = chars[start..pos].iter().collect();
48 tokens.push(Token::DotIdent(name));
49 continue;
50 }
51
52 if chars[pos] == '$' {
54 pos += 1;
55 let start = pos;
56 while pos < chars.len() && (chars[pos].is_alphanumeric() || chars[pos] == '_') {
57 pos += 1;
58 }
59 let name: String = chars[start..pos].iter().collect();
60 tokens.push(Token::Param(name));
61 continue;
62 }
63
64 if chars[pos] == '"' {
66 pos += 1;
67 let mut s = String::new();
68 while pos < chars.len() && chars[pos] != '"' {
69 if chars[pos] == '\\' && pos + 1 < chars.len() {
70 match chars[pos + 1] {
71 '"' => {
72 s.push('"');
73 pos += 2;
74 }
75 '\\' => {
76 s.push('\\');
77 pos += 2;
78 }
79 'n' => {
80 s.push('\n');
81 pos += 2;
82 }
83 't' => {
84 s.push('\t');
85 pos += 2;
86 }
87 _ => {
88 s.push(chars[pos + 1]);
89 pos += 2;
90 }
91 }
92 } else {
93 s.push(chars[pos]);
94 pos += 1;
95 }
96 }
97 if pos >= chars.len() {
98 return Err(LexError {
99 message: "unterminated string".into(),
100 position: pos,
101 });
102 }
103 pos += 1; tokens.push(Token::StringLit(s));
105 continue;
106 }
107
108 if chars[pos].is_ascii_digit()
110 || (chars[pos] == '-' && pos + 1 < chars.len() && chars[pos + 1].is_ascii_digit())
111 {
112 let start = pos;
113 if chars[pos] == '-' {
114 pos += 1;
115 }
116 while pos < chars.len() && chars[pos].is_ascii_digit() {
117 pos += 1;
118 }
119 if pos < chars.len()
120 && chars[pos] == '.'
121 && pos + 1 < chars.len()
122 && chars[pos + 1].is_ascii_digit()
123 {
124 pos += 1;
125 while pos < chars.len() && chars[pos].is_ascii_digit() {
126 pos += 1;
127 }
128 let s: String = chars[start..pos].iter().collect();
129 let value = s.parse::<f64>().map_err(|_| LexError {
130 message: format!("float literal out of range: {s}"),
131 position: start,
132 })?;
133 tokens.push(Token::FloatLit(value));
134 } else {
135 let s: String = chars[start..pos].iter().collect();
136 let value = s.parse::<i64>().map_err(|_| LexError {
137 message: format!("integer literal out of range for i64: {s}"),
138 position: start,
139 })?;
140 tokens.push(Token::IntLit(value));
141 }
142 continue;
143 }
144
145 if chars[pos].is_alphabetic() || chars[pos] == '_' {
147 let start = pos;
148 while pos < chars.len() && (chars[pos].is_alphanumeric() || chars[pos] == '_') {
149 pos += 1;
150 }
151 let word: String = chars[start..pos].iter().collect();
152 let token = match word.as_str() {
153 "type" => Token::Type,
154 "filter" => Token::Filter,
155 "order" => Token::Order,
156 "limit" => Token::Limit,
157 "offset" => Token::Offset,
158 "insert" => Token::Insert,
159 "update" => Token::Update,
160 "delete" => Token::Delete,
161 "upsert" => Token::Upsert,
162 "conflict" => Token::Conflict,
163 "select" => Token::Select,
164 "required" => Token::Required,
165 "multi" => Token::Multi,
166 "link" => Token::Link,
167 "index" => Token::Index,
168 "on" => Token::On,
169 "asc" => Token::Asc,
170 "desc" => Token::Desc,
171 "and" => Token::And,
172 "or" => Token::Or,
173 "not" => Token::Not,
174 "exists" => Token::Exists,
175 "let" => Token::Let,
176 "as" => Token::As,
177 "match" => Token::Match,
178 "group" => Token::Group,
179 "join" => Token::Join,
180 "inner" => Token::Inner,
181 "left" => Token::LeftKw,
182 "right" => Token::RightKw,
183 "outer" => Token::Outer,
184 "cross" => Token::Cross,
185 "transaction" => Token::Transaction,
186 "view" => Token::View,
187 "materialized" => Token::Materialized,
188 "materialize" => Token::Materialized,
189 "refresh" => Token::Refresh,
190 "union" => Token::Union,
191 "having" => Token::Having,
192 "distinct" => Token::Distinct,
193 "in" => Token::In,
194 "between" => Token::Between,
195 "like" => Token::Like,
196 "count" => Token::Count,
197 "avg" => Token::Avg,
198 "sum" => Token::Sum,
199 "min" => Token::Min,
200 "max" => Token::Max,
201 "is" => Token::Is,
202 "null" => Token::Null,
203 "upper" => Token::Upper,
204 "lower" => Token::Lower,
205 "length" => Token::Length,
206 "trim" => Token::Trim,
207 "substring" => Token::Substring,
208 "concat" => Token::Concat,
209 "abs" => Token::Abs,
210 "round" => Token::Round,
211 "ceil" => Token::Ceil,
212 "floor" => Token::Floor,
213 "sqrt" => Token::Sqrt,
214 "pow" => Token::Pow,
215 "now" => Token::Now,
216 "extract" => Token::Extract,
217 "date_add" => Token::DateAdd,
218 "date_diff" => Token::DateDiff,
219 "cast" => Token::Cast,
220 "case" => Token::Case,
221 "when" => Token::When,
222 "then" => Token::Then,
223 "else" => Token::Else,
224 "end" => Token::End,
225 "over" => Token::Over,
226 "partition" => Token::Partition,
227 "row_number" => Token::RowNumber,
228 "rank" => Token::Rank,
229 "dense_rank" => Token::DenseRank,
230 "alter" => Token::Alter,
231 "drop" => Token::Drop,
232 "add" => Token::Add,
233 "column" => Token::Column,
234 "explain" => Token::Explain,
235 "true" => Token::BoolLit(true),
236 "false" => Token::BoolLit(false),
237 _ => Token::Ident(word),
238 };
239 tokens.push(token);
240 continue;
241 }
242
243 if pos + 1 < chars.len() {
245 let two: String = chars[pos..pos + 2].iter().collect();
246 match two.as_str() {
247 ":=" => {
248 tokens.push(Token::Assign);
249 pos += 2;
250 continue;
251 }
252 "->" => {
253 tokens.push(Token::Arrow);
254 pos += 2;
255 continue;
256 }
257 "!=" => {
258 tokens.push(Token::Neq);
259 pos += 2;
260 continue;
261 }
262 "<=" => {
263 tokens.push(Token::Lte);
264 pos += 2;
265 continue;
266 }
267 ">=" => {
268 tokens.push(Token::Gte);
269 pos += 2;
270 continue;
271 }
272 "??" => {
273 tokens.push(Token::Coalesce);
274 pos += 2;
275 continue;
276 }
277 _ => {}
278 }
279 }
280
281 let token = match chars[pos] {
283 '=' => Token::Eq,
284 '<' => Token::Lt,
285 '>' => Token::Gt,
286 '|' => Token::Pipe,
287 '+' => Token::Plus,
288 '-' => Token::Minus,
289 '*' => Token::Star,
290 '/' => Token::Slash,
291 '{' => Token::LBrace,
292 '}' => Token::RBrace,
293 '(' => Token::LParen,
294 ')' => Token::RParen,
295 ',' => Token::Comma,
296 ':' => Token::Colon,
297 '.' => Token::Dot,
298 c => {
299 return Err(LexError {
300 message: format!("unexpected character: {c}"),
301 position: pos,
302 })
303 }
304 };
305 tokens.push(token);
306 pos += 1;
307 }
308
309 tokens.push(Token::Eof);
310 Ok(tokens)
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::token::Token;

    /// Lexes `src`, panicking with the lexer's error if tokenization fails.
    fn lex_ok(src: &str) -> Vec<Token> {
        lex(src).unwrap()
    }

    #[test]
    fn test_lex_simple_query() {
        let expected = vec![
            Token::Ident("User".into()),
            Token::Filter,
            Token::DotIdent("age".into()),
            Token::Gt,
            Token::IntLit(30),
            Token::Eof,
        ];
        assert_eq!(lex_ok("User filter .age > 30"), expected);
    }

    #[test]
    fn test_lex_projection() {
        let expected = vec![
            Token::Ident("User".into()),
            Token::LBrace,
            Token::Ident("name".into()),
            Token::Comma,
            Token::Ident("email".into()),
            Token::RBrace,
            Token::Eof,
        ];
        assert_eq!(lex_ok("User { name, email }"), expected);
    }

    #[test]
    fn test_lex_insert() {
        let expected = vec![
            Token::Insert,
            Token::Ident("User".into()),
            Token::LBrace,
            Token::Ident("name".into()),
            Token::Assign,
            Token::StringLit("Alice".into()),
            Token::Comma,
            Token::Ident("age".into()),
            Token::Assign,
            Token::IntLit(30),
            Token::RBrace,
            Token::Eof,
        ];
        assert_eq!(
            lex_ok(r#"insert User { name := "Alice", age := 30 }"#),
            expected
        );
    }

    #[test]
    fn test_lex_params() {
        let expected = vec![
            Token::Ident("User".into()),
            Token::Filter,
            Token::DotIdent("age".into()),
            Token::Gt,
            Token::Param("min_age".into()),
            Token::Eof,
        ];
        assert_eq!(lex_ok("User filter .age > $min_age"), expected);
    }

    #[test]
    fn test_lex_string_with_escapes() {
        let expected = vec![Token::StringLit("hello \"world\"".into()), Token::Eof];
        assert_eq!(lex_ok(r#""hello \"world\"""#), expected);
    }

    #[test]
    fn test_lex_aggregation() {
        let expected = vec![
            Token::Count,
            Token::LParen,
            Token::Ident("User".into()),
            Token::RParen,
            Token::Eof,
        ];
        assert_eq!(lex_ok("count(User)"), expected);
    }

    /// An integer literal wider than i64 must surface as a `LexError`, not a panic.
    #[test]
    fn test_lex_intlit_overflow_returns_err() {
        let err = lex("4444444441111111144444").expect_err("must error, not panic");
        let mentions_range = err.message.contains("integer literal out of range");
        assert!(mentions_range, "unexpected message: {}", err.message);
        assert_eq!(err.position, 0);
    }

    /// Fuzzer-found input (issue 24): a keyword, tabs, then an oversized integer literal.
    #[test]
    fn test_lex_fuzz_repro_issue_24() {
        let input = "as\t\t\t\t\t\t\t\t\t\t\t\t\t44444444411111114444\t\t\t\t\t\t";
        let outcome = lex(input);
        let err = outcome.expect_err("fuzz reproducer must now error, not panic");
        assert!(err.message.contains("integer literal"));
    }
}