1use crate::Error;
5
6#[derive(Debug, Clone, PartialEq)]
7pub enum Token {
8 Keyword(Keyword),
9
10 Ident(String),
11
12 Integer(i64),
13 Float(f64),
14 StringLit(String),
15
16 Asterisk,
17 Comma,
18 Dot,
19 Semicolon,
20 OpenParen,
21 CloseParen,
22 Plus,
23 Minus,
24 Slash,
25 Percent,
26 Eq,
27 NotEq,
28 Lt,
29 Gt,
30 LtEq,
31 GtEq,
32 Concat,
33}
34
/// Words that the tokenizer recognises as SQL keywords.
///
/// The tokenizer compares words case-insensitively, so `select`, `Select`
/// and `SELECT` all produce [`Keyword::Select`].
///
/// The enum carries no data, so it is `Copy` and can be used directly as a
/// key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Keyword {
    // Statements, clauses and expression keywords.
    Select,
    From,
    Where,
    And,
    Or,
    Not,
    As,
    Order,
    By,
    Asc,
    Desc,
    Limit,
    Offset,
    Group,
    Having,
    Distinct,
    Insert,
    Into,
    Values,
    Update,
    Set,
    Delete,
    Create,
    Table,
    Join,
    Inner,
    Left,
    Right,
    On,
    Null,
    True,
    False,
    Is,
    In,
    Between,
    Cast,
    Count,
    Sum,
    Avg,
    Min,
    Max,

    // Data type names and table-definition keywords.
    Int,
    Int2,
    Int4,
    Int8,
    Smallint,
    Integer,
    Bigint,
    Float4,
    Float8,
    Real,
    Double,
    Precision,
    Boolean,
    Bool,
    Varchar,
    Text,
    Char,
    Utf8,
    Blob,
    Primary,
    Key,
    With,
    Recursive,

    // Conditionals, set operations, pattern matching and remaining DDL.
    Case,
    When,
    Then,
    Else,
    End,
    Exists,
    Union,
    All,
    Intersect,
    Except,
    Like,
    Glob,
    If,
    /// The `FLOAT` keyword.
    FloatKw,
    Index,
    Unique,
    Drop,
    Cross,
    Outer,
    Full,
    Natural,
    Numeric,
}
126
127pub fn tokenize(sql: &str) -> Result<Vec<Token>, Error> {
128 let mut tokens = Vec::new();
129 let chars: Vec<char> = sql.chars().collect();
130 let len = chars.len();
131 let mut i = 0;
132
133 while i < len {
134 let c = chars[i];
135
136 if c.is_ascii_whitespace() {
137 i += 1;
138 continue;
139 }
140
141 if c == '-' && i + 1 < len && chars[i + 1] == '-' {
142 while i < len && chars[i] != '\n' {
143 i += 1;
144 }
145 continue;
146 }
147
148 match c {
149 '*' => {
150 tokens.push(Token::Asterisk);
151 i += 1;
152 continue;
153 }
154 ',' => {
155 tokens.push(Token::Comma);
156 i += 1;
157 continue;
158 }
159 '.' => {
160 tokens.push(Token::Dot);
161 i += 1;
162 continue;
163 }
164 ';' => {
165 tokens.push(Token::Semicolon);
166 i += 1;
167 continue;
168 }
169 '(' => {
170 tokens.push(Token::OpenParen);
171 i += 1;
172 continue;
173 }
174 ')' => {
175 tokens.push(Token::CloseParen);
176 i += 1;
177 continue;
178 }
179 '+' => {
180 tokens.push(Token::Plus);
181 i += 1;
182 continue;
183 }
184 '-' => {
185 tokens.push(Token::Minus);
186 i += 1;
187 continue;
188 }
189 '/' => {
190 tokens.push(Token::Slash);
191 i += 1;
192 continue;
193 }
194 '%' => {
195 tokens.push(Token::Percent);
196 i += 1;
197 continue;
198 }
199 '=' => {
200 tokens.push(Token::Eq);
201 i += 1;
202 continue;
203 }
204 '<' => {
205 if i + 1 < len && chars[i + 1] == '=' {
206 tokens.push(Token::LtEq);
207 i += 2;
208 } else if i + 1 < len && chars[i + 1] == '>' {
209 tokens.push(Token::NotEq);
210 i += 2;
211 } else {
212 tokens.push(Token::Lt);
213 i += 1;
214 }
215 continue;
216 }
217 '>' => {
218 if i + 1 < len && chars[i + 1] == '=' {
219 tokens.push(Token::GtEq);
220 i += 2;
221 } else {
222 tokens.push(Token::Gt);
223 i += 1;
224 }
225 continue;
226 }
227 '!' => {
228 if i + 1 < len && chars[i + 1] == '=' {
229 tokens.push(Token::NotEq);
230 i += 2;
231 continue;
232 }
233 return Err(Error(format!("unexpected character '!' at position {i}")));
234 }
235 '|' => {
236 if i + 1 < len && chars[i + 1] == '|' {
237 tokens.push(Token::Concat);
238 i += 2;
239 continue;
240 }
241 return Err(Error(format!("unexpected character '|' at position {i}")));
242 }
243 _ => {}
244 }
245
246 if c == '\'' {
247 i += 1;
248 let mut s = String::new();
249 while i < len {
250 if chars[i] == '\'' {
251 if i + 1 < len && chars[i + 1] == '\'' {
252 s.push('\'');
253 i += 2;
254 } else {
255 break;
256 }
257 } else {
258 s.push(chars[i]);
259 i += 1;
260 }
261 }
262 if i >= len {
263 return Err(Error("unterminated string literal".into()));
264 }
265 i += 1;
266 tokens.push(Token::StringLit(s));
267 continue;
268 }
269
270 if c.is_ascii_digit() {
271 let start = i;
272 while i < len && chars[i].is_ascii_digit() {
273 i += 1;
274 }
275 if i < len && chars[i] == '.' && i + 1 < len && chars[i + 1].is_ascii_digit() {
276 i += 1;
277 while i < len && chars[i].is_ascii_digit() {
278 i += 1;
279 }
280 let text: String = chars[start..i].iter().collect();
281 let f: f64 = text.parse().map_err(|e| Error(format!("invalid float: {e}")))?;
282 tokens.push(Token::Float(f));
283 } else {
284 let text: String = chars[start..i].iter().collect();
285 let n: i64 = text.parse().map_err(|e| Error(format!("invalid integer: {e}")))?;
286 tokens.push(Token::Integer(n));
287 }
288 continue;
289 }
290
291 if c.is_ascii_alphabetic() || c == '_' {
292 let start = i;
293 while i < len && (chars[i].is_ascii_alphanumeric() || chars[i] == '_') {
294 i += 1;
295 }
296 let word: String = chars[start..i].iter().collect();
297 let upper = word.to_ascii_uppercase();
298 let token = match upper.as_str() {
299 "SELECT" => Token::Keyword(Keyword::Select),
300 "FROM" => Token::Keyword(Keyword::From),
301 "WHERE" => Token::Keyword(Keyword::Where),
302 "AND" => Token::Keyword(Keyword::And),
303 "OR" => Token::Keyword(Keyword::Or),
304 "NOT" => Token::Keyword(Keyword::Not),
305 "AS" => Token::Keyword(Keyword::As),
306 "ORDER" => Token::Keyword(Keyword::Order),
307 "BY" => Token::Keyword(Keyword::By),
308 "ASC" => Token::Keyword(Keyword::Asc),
309 "DESC" => Token::Keyword(Keyword::Desc),
310 "LIMIT" => Token::Keyword(Keyword::Limit),
311 "OFFSET" => Token::Keyword(Keyword::Offset),
312 "GROUP" => Token::Keyword(Keyword::Group),
313 "HAVING" => Token::Keyword(Keyword::Having),
314 "DISTINCT" => Token::Keyword(Keyword::Distinct),
315 "INSERT" => Token::Keyword(Keyword::Insert),
316 "INTO" => Token::Keyword(Keyword::Into),
317 "VALUES" => Token::Keyword(Keyword::Values),
318 "UPDATE" => Token::Keyword(Keyword::Update),
319 "SET" => Token::Keyword(Keyword::Set),
320 "DELETE" => Token::Keyword(Keyword::Delete),
321 "CREATE" => Token::Keyword(Keyword::Create),
322 "TABLE" => Token::Keyword(Keyword::Table),
323 "JOIN" => Token::Keyword(Keyword::Join),
324 "INNER" => Token::Keyword(Keyword::Inner),
325 "LEFT" => Token::Keyword(Keyword::Left),
326 "RIGHT" => Token::Keyword(Keyword::Right),
327 "ON" => Token::Keyword(Keyword::On),
328 "NULL" => Token::Keyword(Keyword::Null),
329 "TRUE" => Token::Keyword(Keyword::True),
330 "FALSE" => Token::Keyword(Keyword::False),
331 "IS" => Token::Keyword(Keyword::Is),
332 "IN" => Token::Keyword(Keyword::In),
333 "BETWEEN" => Token::Keyword(Keyword::Between),
334 "CAST" => Token::Keyword(Keyword::Cast),
335 "COUNT" => Token::Keyword(Keyword::Count),
336 "SUM" => Token::Keyword(Keyword::Sum),
337 "AVG" => Token::Keyword(Keyword::Avg),
338 "MIN" => Token::Keyword(Keyword::Min),
339 "MAX" => Token::Keyword(Keyword::Max),
340 "INT" => Token::Keyword(Keyword::Int),
341 "INT2" => Token::Keyword(Keyword::Int2),
342 "INT4" => Token::Keyword(Keyword::Int4),
343 "INT8" => Token::Keyword(Keyword::Int8),
344 "SMALLINT" => Token::Keyword(Keyword::Smallint),
345 "INTEGER" => Token::Keyword(Keyword::Integer),
346 "BIGINT" => Token::Keyword(Keyword::Bigint),
347 "FLOAT4" => Token::Keyword(Keyword::Float4),
348 "FLOAT8" => Token::Keyword(Keyword::Float8),
349 "REAL" => Token::Keyword(Keyword::Real),
350 "DOUBLE" => Token::Keyword(Keyword::Double),
351 "PRECISION" => Token::Keyword(Keyword::Precision),
352 "BOOLEAN" => Token::Keyword(Keyword::Boolean),
353 "BOOL" => Token::Keyword(Keyword::Bool),
354 "VARCHAR" => Token::Keyword(Keyword::Varchar),
355 "TEXT" => Token::Keyword(Keyword::Text),
356 "CHAR" => Token::Keyword(Keyword::Char),
357 "UTF8" => Token::Keyword(Keyword::Utf8),
358 "BLOB" => Token::Keyword(Keyword::Blob),
359 "PRIMARY" => Token::Keyword(Keyword::Primary),
360 "KEY" => Token::Keyword(Keyword::Key),
361 "WITH" => Token::Keyword(Keyword::With),
362 "RECURSIVE" => Token::Keyword(Keyword::Recursive),
363 "CASE" => Token::Keyword(Keyword::Case),
364 "WHEN" => Token::Keyword(Keyword::When),
365 "THEN" => Token::Keyword(Keyword::Then),
366 "ELSE" => Token::Keyword(Keyword::Else),
367 "END" => Token::Keyword(Keyword::End),
368 "EXISTS" => Token::Keyword(Keyword::Exists),
369 "UNION" => Token::Keyword(Keyword::Union),
370 "ALL" => Token::Keyword(Keyword::All),
371 "INTERSECT" => Token::Keyword(Keyword::Intersect),
372 "EXCEPT" => Token::Keyword(Keyword::Except),
373 "LIKE" => Token::Keyword(Keyword::Like),
374 "GLOB" => Token::Keyword(Keyword::Glob),
375 "IF" => Token::Keyword(Keyword::If),
376 "FLOAT" => Token::Keyword(Keyword::FloatKw),
377 "INDEX" => Token::Keyword(Keyword::Index),
378 "UNIQUE" => Token::Keyword(Keyword::Unique),
379 "DROP" => Token::Keyword(Keyword::Drop),
380 "CROSS" => Token::Keyword(Keyword::Cross),
381 "OUTER" => Token::Keyword(Keyword::Outer),
382 "FULL" => Token::Keyword(Keyword::Full),
383 "NATURAL" => Token::Keyword(Keyword::Natural),
384 "NUMERIC" => Token::Keyword(Keyword::Numeric),
385 _ => Token::Ident(word),
386 };
387 tokens.push(token);
388 continue;
389 }
390
391 return Err(Error(format!("unexpected character '{c}' at position {i}")));
392 }
393
394 Ok(tokens)
395}
396
/// Unit tests for [`tokenize`].
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an identifier token.
    fn ident(name: &str) -> Token {
        Token::Ident(name.into())
    }

    #[test]
    fn test_simple_select() {
        let tokens = tokenize("SELECT id, name FROM users").unwrap();
        assert_eq!(
            tokens,
            vec![
                Token::Keyword(Keyword::Select),
                ident("id"),
                Token::Comma,
                ident("name"),
                Token::Keyword(Keyword::From),
                ident("users"),
            ]
        );
    }

    #[test]
    fn test_string_literal() {
        let tokens = tokenize("SELECT 'hello'").unwrap();
        assert_eq!(
            tokens,
            vec![Token::Keyword(Keyword::Select), Token::StringLit("hello".into())]
        );
    }

    #[test]
    fn test_escaped_quote_in_string_literal() {
        // A doubled quote inside a literal stands for a single quote character.
        let tokens = tokenize("'it''s'").unwrap();
        assert_eq!(tokens, vec![Token::StringLit("it's".into())]);
    }

    #[test]
    fn test_unterminated_string_is_error() {
        assert!(tokenize("SELECT 'oops").is_err());
    }

    #[test]
    fn test_comparison_operators() {
        let tokens = tokenize("a <> b").unwrap();
        assert_eq!(tokens, vec![ident("a"), Token::NotEq, ident("b")]);

        let tokens = tokenize("a != b <= c >= d < e > f = g").unwrap();
        assert_eq!(
            tokens,
            vec![
                ident("a"),
                Token::NotEq,
                ident("b"),
                Token::LtEq,
                ident("c"),
                Token::GtEq,
                ident("d"),
                Token::Lt,
                ident("e"),
                Token::Gt,
                ident("f"),
                Token::Eq,
                ident("g"),
            ]
        );
    }

    #[test]
    fn test_numeric_literals() {
        // 2.5 rather than 3.14: the latter trips Clippy's `approx_constant` lint.
        let tokens = tokenize("42 2.5").unwrap();
        assert_eq!(tokens, vec![Token::Integer(42), Token::Float(2.5)]);
    }

    #[test]
    fn test_line_comment_is_skipped() {
        let tokens = tokenize("SELECT 1 -- trailing comment\n, 2 -- no newline at end").unwrap();
        assert_eq!(
            tokens,
            vec![
                Token::Keyword(Keyword::Select),
                Token::Integer(1),
                Token::Comma,
                Token::Integer(2),
            ]
        );
    }

    #[test]
    fn test_keywords_are_case_insensitive() {
        let tokens = tokenize("select Users fRoM").unwrap();
        assert_eq!(
            tokens,
            vec![
                Token::Keyword(Keyword::Select),
                // Identifiers keep their original spelling.
                ident("Users"),
                Token::Keyword(Keyword::From),
            ]
        );
    }

    #[test]
    fn test_unexpected_character_is_error() {
        assert!(tokenize("a ! b").is_err());
        assert!(tokenize("a | b").is_err());
        assert!(tokenize("#").is_err());
    }
}