1use crate::Error;
5
6#[derive(Debug, Clone, PartialEq)]
7pub enum Token {
8 Keyword(Keyword),
10 Ident(String),
12 Integer(i64),
14 Float(f64),
15 StringLit(String),
16 Asterisk, Comma, Dot, Semicolon, OpenParen, CloseParen, Plus, Minus, Slash, Percent, Eq, NotEq, Lt, Gt, LtEq, GtEq, Concat, }
35
/// A reserved word recognised by [`tokenize`].
///
/// Keywords are matched case-insensitively. Each variant corresponds to the
/// upper-cased spelling of its own name (`Select` is `SELECT`, `Utf8` is
/// `UTF8`, `Smallint` is `SMALLINT`), with one exception: [`Keyword::FloatKw`]
/// is the keyword `FLOAT`.
///
/// The enum carries no data, so it derives `Eq` and `Hash` in addition to
/// `PartialEq`, which lets keywords be used as keys in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Keyword {
    // Query clauses and modifiers.
    Select,
    From,
    Where,
    And,
    Or,
    Not,
    As,
    Order,
    By,
    Asc,
    Desc,
    Limit,
    Offset,
    Group,
    Having,
    Distinct,
    // Data-modification and table-definition statements.
    Insert,
    Into,
    Values,
    Update,
    Set,
    Delete,
    Create,
    Table,
    // Joins.
    Join,
    Inner,
    Left,
    Right,
    On,
    // Literal values and predicate / conversion words.
    Null,
    True,
    False,
    Is,
    In,
    Between,
    Cast,
    // Aggregate function names.
    Count,
    Sum,
    Avg,
    Min,
    Max,
    // Column type names.
    Int,
    Int2,
    Int4,
    Int8,
    Smallint,
    Integer,
    Bigint,
    Float4,
    Float8,
    Real,
    Double,
    Precision,
    Boolean,
    Bool,
    Varchar,
    Text,
    Char,
    Utf8,
    Blob,
    // Constraints.
    Primary,
    Key,
    // Common table expressions.
    With,
    Recursive,
    // CASE expressions.
    Case,
    When,
    Then,
    Else,
    End,
    // Miscellaneous: subquery, set-operation, pattern, DDL, join-modifier and
    // extra type keywords.
    Exists,
    Union,
    All,
    Intersect,
    Except,
    Like,
    Glob,
    If,
    /// The keyword `FLOAT`. The suffix is presumably there to avoid confusion
    /// with `Token::Float` (the literal token).
    FloatKw,
    Index,
    Unique,
    Drop,
    Cross,
    Outer,
    Full,
    Natural,
    Numeric,
}
127
128pub fn tokenize(sql: &str) -> Result<Vec<Token>, Error> {
129 let mut tokens = Vec::new();
130 let chars: Vec<char> = sql.chars().collect();
131 let len = chars.len();
132 let mut i = 0;
133
134 while i < len {
135 let c = chars[i];
136
137 if c.is_ascii_whitespace() {
139 i += 1;
140 continue;
141 }
142
143 if c == '-' && i + 1 < len && chars[i + 1] == '-' {
145 while i < len && chars[i] != '\n' {
146 i += 1;
147 }
148 continue;
149 }
150
151 match c {
153 '*' => {
154 tokens.push(Token::Asterisk);
155 i += 1;
156 continue;
157 }
158 ',' => {
159 tokens.push(Token::Comma);
160 i += 1;
161 continue;
162 }
163 '.' => {
164 tokens.push(Token::Dot);
165 i += 1;
166 continue;
167 }
168 ';' => {
169 tokens.push(Token::Semicolon);
170 i += 1;
171 continue;
172 }
173 '(' => {
174 tokens.push(Token::OpenParen);
175 i += 1;
176 continue;
177 }
178 ')' => {
179 tokens.push(Token::CloseParen);
180 i += 1;
181 continue;
182 }
183 '+' => {
184 tokens.push(Token::Plus);
185 i += 1;
186 continue;
187 }
188 '-' => {
189 tokens.push(Token::Minus);
190 i += 1;
191 continue;
192 }
193 '/' => {
194 tokens.push(Token::Slash);
195 i += 1;
196 continue;
197 }
198 '%' => {
199 tokens.push(Token::Percent);
200 i += 1;
201 continue;
202 }
203 '=' => {
204 tokens.push(Token::Eq);
205 i += 1;
206 continue;
207 }
208 '<' => {
209 if i + 1 < len && chars[i + 1] == '=' {
210 tokens.push(Token::LtEq);
211 i += 2;
212 } else if i + 1 < len && chars[i + 1] == '>' {
213 tokens.push(Token::NotEq);
214 i += 2;
215 } else {
216 tokens.push(Token::Lt);
217 i += 1;
218 }
219 continue;
220 }
221 '>' => {
222 if i + 1 < len && chars[i + 1] == '=' {
223 tokens.push(Token::GtEq);
224 i += 2;
225 } else {
226 tokens.push(Token::Gt);
227 i += 1;
228 }
229 continue;
230 }
231 '!' => {
232 if i + 1 < len && chars[i + 1] == '=' {
233 tokens.push(Token::NotEq);
234 i += 2;
235 continue;
236 }
237 return Err(Error(format!("unexpected character '!' at position {i}")));
238 }
239 '|' => {
240 if i + 1 < len && chars[i + 1] == '|' {
241 tokens.push(Token::Concat);
242 i += 2;
243 continue;
244 }
245 return Err(Error(format!("unexpected character '|' at position {i}")));
246 }
247 _ => {}
248 }
249
250 if c == '\'' {
252 i += 1;
253 let mut s = String::new();
254 while i < len {
255 if chars[i] == '\'' {
256 if i + 1 < len && chars[i + 1] == '\'' {
258 s.push('\'');
259 i += 2;
260 } else {
261 break;
262 }
263 } else {
264 s.push(chars[i]);
265 i += 1;
266 }
267 }
268 if i >= len {
269 return Err(Error("unterminated string literal".into()));
270 }
271 i += 1; tokens.push(Token::StringLit(s));
273 continue;
274 }
275
276 if c.is_ascii_digit() {
278 let start = i;
279 while i < len && chars[i].is_ascii_digit() {
280 i += 1;
281 }
282 if i < len && chars[i] == '.' && i + 1 < len && chars[i + 1].is_ascii_digit() {
283 i += 1; while i < len && chars[i].is_ascii_digit() {
285 i += 1;
286 }
287 let text: String = chars[start..i].iter().collect();
288 let f: f64 = text.parse().map_err(|e| Error(format!("invalid float: {e}")))?;
289 tokens.push(Token::Float(f));
290 } else {
291 let text: String = chars[start..i].iter().collect();
292 let n: i64 = text.parse().map_err(|e| Error(format!("invalid integer: {e}")))?;
293 tokens.push(Token::Integer(n));
294 }
295 continue;
296 }
297
298 if c.is_ascii_alphabetic() || c == '_' {
300 let start = i;
301 while i < len && (chars[i].is_ascii_alphanumeric() || chars[i] == '_') {
302 i += 1;
303 }
304 let word: String = chars[start..i].iter().collect();
305 let upper = word.to_ascii_uppercase();
306 let token = match upper.as_str() {
307 "SELECT" => Token::Keyword(Keyword::Select),
308 "FROM" => Token::Keyword(Keyword::From),
309 "WHERE" => Token::Keyword(Keyword::Where),
310 "AND" => Token::Keyword(Keyword::And),
311 "OR" => Token::Keyword(Keyword::Or),
312 "NOT" => Token::Keyword(Keyword::Not),
313 "AS" => Token::Keyword(Keyword::As),
314 "ORDER" => Token::Keyword(Keyword::Order),
315 "BY" => Token::Keyword(Keyword::By),
316 "ASC" => Token::Keyword(Keyword::Asc),
317 "DESC" => Token::Keyword(Keyword::Desc),
318 "LIMIT" => Token::Keyword(Keyword::Limit),
319 "OFFSET" => Token::Keyword(Keyword::Offset),
320 "GROUP" => Token::Keyword(Keyword::Group),
321 "HAVING" => Token::Keyword(Keyword::Having),
322 "DISTINCT" => Token::Keyword(Keyword::Distinct),
323 "INSERT" => Token::Keyword(Keyword::Insert),
324 "INTO" => Token::Keyword(Keyword::Into),
325 "VALUES" => Token::Keyword(Keyword::Values),
326 "UPDATE" => Token::Keyword(Keyword::Update),
327 "SET" => Token::Keyword(Keyword::Set),
328 "DELETE" => Token::Keyword(Keyword::Delete),
329 "CREATE" => Token::Keyword(Keyword::Create),
330 "TABLE" => Token::Keyword(Keyword::Table),
331 "JOIN" => Token::Keyword(Keyword::Join),
332 "INNER" => Token::Keyword(Keyword::Inner),
333 "LEFT" => Token::Keyword(Keyword::Left),
334 "RIGHT" => Token::Keyword(Keyword::Right),
335 "ON" => Token::Keyword(Keyword::On),
336 "NULL" => Token::Keyword(Keyword::Null),
337 "TRUE" => Token::Keyword(Keyword::True),
338 "FALSE" => Token::Keyword(Keyword::False),
339 "IS" => Token::Keyword(Keyword::Is),
340 "IN" => Token::Keyword(Keyword::In),
341 "BETWEEN" => Token::Keyword(Keyword::Between),
342 "CAST" => Token::Keyword(Keyword::Cast),
343 "COUNT" => Token::Keyword(Keyword::Count),
344 "SUM" => Token::Keyword(Keyword::Sum),
345 "AVG" => Token::Keyword(Keyword::Avg),
346 "MIN" => Token::Keyword(Keyword::Min),
347 "MAX" => Token::Keyword(Keyword::Max),
348 "INT" => Token::Keyword(Keyword::Int),
349 "INT2" => Token::Keyword(Keyword::Int2),
350 "INT4" => Token::Keyword(Keyword::Int4),
351 "INT8" => Token::Keyword(Keyword::Int8),
352 "SMALLINT" => Token::Keyword(Keyword::Smallint),
353 "INTEGER" => Token::Keyword(Keyword::Integer),
354 "BIGINT" => Token::Keyword(Keyword::Bigint),
355 "FLOAT4" => Token::Keyword(Keyword::Float4),
356 "FLOAT8" => Token::Keyword(Keyword::Float8),
357 "REAL" => Token::Keyword(Keyword::Real),
358 "DOUBLE" => Token::Keyword(Keyword::Double),
359 "PRECISION" => Token::Keyword(Keyword::Precision),
360 "BOOLEAN" => Token::Keyword(Keyword::Boolean),
361 "BOOL" => Token::Keyword(Keyword::Bool),
362 "VARCHAR" => Token::Keyword(Keyword::Varchar),
363 "TEXT" => Token::Keyword(Keyword::Text),
364 "CHAR" => Token::Keyword(Keyword::Char),
365 "UTF8" => Token::Keyword(Keyword::Utf8),
366 "BLOB" => Token::Keyword(Keyword::Blob),
367 "PRIMARY" => Token::Keyword(Keyword::Primary),
368 "KEY" => Token::Keyword(Keyword::Key),
369 "WITH" => Token::Keyword(Keyword::With),
370 "RECURSIVE" => Token::Keyword(Keyword::Recursive),
371 "CASE" => Token::Keyword(Keyword::Case),
372 "WHEN" => Token::Keyword(Keyword::When),
373 "THEN" => Token::Keyword(Keyword::Then),
374 "ELSE" => Token::Keyword(Keyword::Else),
375 "END" => Token::Keyword(Keyword::End),
376 "EXISTS" => Token::Keyword(Keyword::Exists),
377 "UNION" => Token::Keyword(Keyword::Union),
378 "ALL" => Token::Keyword(Keyword::All),
379 "INTERSECT" => Token::Keyword(Keyword::Intersect),
380 "EXCEPT" => Token::Keyword(Keyword::Except),
381 "LIKE" => Token::Keyword(Keyword::Like),
382 "GLOB" => Token::Keyword(Keyword::Glob),
383 "IF" => Token::Keyword(Keyword::If),
384 "FLOAT" => Token::Keyword(Keyword::FloatKw),
385 "INDEX" => Token::Keyword(Keyword::Index),
386 "UNIQUE" => Token::Keyword(Keyword::Unique),
387 "DROP" => Token::Keyword(Keyword::Drop),
388 "CROSS" => Token::Keyword(Keyword::Cross),
389 "OUTER" => Token::Keyword(Keyword::Outer),
390 "FULL" => Token::Keyword(Keyword::Full),
391 "NATURAL" => Token::Keyword(Keyword::Natural),
392 "NUMERIC" => Token::Keyword(Keyword::Numeric),
393 _ => Token::Ident(word),
394 };
395 tokens.push(token);
396 continue;
397 }
398
399 return Err(Error(format!("unexpected character '{c}' at position {i}")));
400 }
401
402 Ok(tokens)
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an identifier token.
    fn ident(name: &str) -> Token {
        Token::Ident(name.into())
    }

    #[test]
    fn test_simple_select() {
        let tokens = tokenize("SELECT id, name FROM users").unwrap();
        assert_eq!(
            tokens,
            vec![
                Token::Keyword(Keyword::Select),
                ident("id"),
                Token::Comma,
                ident("name"),
                Token::Keyword(Keyword::From),
                ident("users"),
            ]
        );
    }

    #[test]
    fn test_string_literal() {
        let tokens = tokenize("SELECT 'hello'").unwrap();
        assert_eq!(
            tokens,
            vec![Token::Keyword(Keyword::Select), Token::StringLit("hello".into())]
        );
    }

    #[test]
    fn test_string_literal_escaped_quote() {
        // `''` inside a string literal is one literal quote.
        let tokens = tokenize("'it''s'").unwrap();
        assert_eq!(tokens, vec![Token::StringLit("it's".into())]);
    }

    #[test]
    fn test_unterminated_string_literal() {
        let err = tokenize("SELECT 'oops").unwrap_err();
        assert_eq!(err.0, "unterminated string literal");
    }

    #[test]
    fn test_comparison_operators() {
        let tokens = tokenize("a <> b").unwrap();
        assert_eq!(tokens, vec![ident("a"), Token::NotEq, ident("b")]);
    }

    #[test]
    fn test_two_char_operators() {
        let tokens = tokenize("a != b <= c >= d || e").unwrap();
        assert_eq!(
            tokens,
            vec![
                ident("a"),
                Token::NotEq,
                ident("b"),
                Token::LtEq,
                ident("c"),
                Token::GtEq,
                ident("d"),
                Token::Concat,
                ident("e"),
            ]
        );
    }

    #[test]
    fn test_lone_bang_is_rejected() {
        let err = tokenize("a ! b").unwrap_err();
        assert_eq!(err.0, "unexpected character '!' at position 2");
    }

    #[test]
    fn test_numeric_literals() {
        // 2.5 is exactly representable and, unlike 3.14, does not trip the
        // deny-by-default `clippy::approx_constant` lint.
        let tokens = tokenize("42 2.5").unwrap();
        assert_eq!(tokens, vec![Token::Integer(42), Token::Float(2.5)]);
    }

    #[test]
    fn test_trailing_dot_is_not_a_float() {
        assert_eq!(tokenize("3.").unwrap(), vec![Token::Integer(3), Token::Dot]);
    }

    #[test]
    fn test_line_comment_and_case_insensitive_keywords() {
        let tokens = tokenize("-- header\nselect X from t").unwrap();
        assert_eq!(
            tokens,
            vec![
                Token::Keyword(Keyword::Select),
                ident("X"),
                Token::Keyword(Keyword::From),
                ident("t"),
            ]
        );
    }
}