1use crate::Error;
5
6#[derive(Debug, Clone, PartialEq)]
7pub enum Token {
8 Keyword(Keyword),
10 Ident(String),
12 Integer(i64),
14 Float(f64),
15 StringLit(String),
16 Asterisk, Comma, Dot, Semicolon, OpenParen, CloseParen, Plus, Minus, Slash, Percent, Eq, NotEq, Lt, Gt, LtEq, GtEq, }
34
/// Reserved words, carried inside `Token::Keyword`.
///
/// This is a plain fieldless enum, so it is `Copy` and can be compared with
/// `==` or used as a key in hashed collections without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Keyword {
    // Query clauses and logical operators.
    Select,
    From,
    Where,
    And,
    Or,
    Not,
    As,
    Order,
    By,
    Asc,
    Desc,
    Limit,
    Offset,
    Group,
    Having,
    Distinct,

    // Data-modification and definition statements.
    Insert,
    Into,
    Values,
    Update,
    Set,
    Delete,
    Create,
    Table,

    // Joins.
    Join,
    Inner,
    Left,
    Right,
    On,

    // Literals and predicates.
    Null,
    True,
    False,
    Is,
    In,
    Between,
    Cast,

    // Aggregate function names.
    Count,
    Sum,
    Avg,
    Min,
    Max,

    // Type names.
    Int,
    Int2,
    Int4,
    Int8,
    Smallint,
    Integer,
    Bigint,
    Float4,
    Float8,
    Real,
    Double,
    Precision,
    Boolean,
    Bool,
    Varchar,
    Text,
    Char,
    Utf8,
    Blob,

    // Constraints and common table expressions.
    Primary,
    Key,
    With,
    Recursive,
}
103
104pub fn tokenize(sql: &str) -> Result<Vec<Token>, Error> {
105 let mut tokens = Vec::new();
106 let chars: Vec<char> = sql.chars().collect();
107 let len = chars.len();
108 let mut i = 0;
109
110 while i < len {
111 let c = chars[i];
112
113 if c.is_ascii_whitespace() {
115 i += 1;
116 continue;
117 }
118
119 if c == '-' && i + 1 < len && chars[i + 1] == '-' {
121 while i < len && chars[i] != '\n' {
122 i += 1;
123 }
124 continue;
125 }
126
127 match c {
129 '*' => {
130 tokens.push(Token::Asterisk);
131 i += 1;
132 continue;
133 }
134 ',' => {
135 tokens.push(Token::Comma);
136 i += 1;
137 continue;
138 }
139 '.' => {
140 tokens.push(Token::Dot);
141 i += 1;
142 continue;
143 }
144 ';' => {
145 tokens.push(Token::Semicolon);
146 i += 1;
147 continue;
148 }
149 '(' => {
150 tokens.push(Token::OpenParen);
151 i += 1;
152 continue;
153 }
154 ')' => {
155 tokens.push(Token::CloseParen);
156 i += 1;
157 continue;
158 }
159 '+' => {
160 tokens.push(Token::Plus);
161 i += 1;
162 continue;
163 }
164 '-' => {
165 tokens.push(Token::Minus);
166 i += 1;
167 continue;
168 }
169 '/' => {
170 tokens.push(Token::Slash);
171 i += 1;
172 continue;
173 }
174 '%' => {
175 tokens.push(Token::Percent);
176 i += 1;
177 continue;
178 }
179 '=' => {
180 tokens.push(Token::Eq);
181 i += 1;
182 continue;
183 }
184 '<' => {
185 if i + 1 < len && chars[i + 1] == '=' {
186 tokens.push(Token::LtEq);
187 i += 2;
188 } else if i + 1 < len && chars[i + 1] == '>' {
189 tokens.push(Token::NotEq);
190 i += 2;
191 } else {
192 tokens.push(Token::Lt);
193 i += 1;
194 }
195 continue;
196 }
197 '>' => {
198 if i + 1 < len && chars[i + 1] == '=' {
199 tokens.push(Token::GtEq);
200 i += 2;
201 } else {
202 tokens.push(Token::Gt);
203 i += 1;
204 }
205 continue;
206 }
207 '!' => {
208 if i + 1 < len && chars[i + 1] == '=' {
209 tokens.push(Token::NotEq);
210 i += 2;
211 continue;
212 }
213 return Err(Error(format!("unexpected character '!' at position {i}")));
214 }
215 _ => {}
216 }
217
218 if c == '\'' {
220 i += 1;
221 let mut s = String::new();
222 while i < len {
223 if chars[i] == '\'' {
224 if i + 1 < len && chars[i + 1] == '\'' {
226 s.push('\'');
227 i += 2;
228 } else {
229 break;
230 }
231 } else {
232 s.push(chars[i]);
233 i += 1;
234 }
235 }
236 if i >= len {
237 return Err(Error("unterminated string literal".into()));
238 }
239 i += 1; tokens.push(Token::StringLit(s));
241 continue;
242 }
243
244 if c.is_ascii_digit() {
246 let start = i;
247 while i < len && chars[i].is_ascii_digit() {
248 i += 1;
249 }
250 if i < len && chars[i] == '.' && i + 1 < len && chars[i + 1].is_ascii_digit() {
251 i += 1; while i < len && chars[i].is_ascii_digit() {
253 i += 1;
254 }
255 let text: String = chars[start..i].iter().collect();
256 let f: f64 = text.parse().map_err(|e| Error(format!("invalid float: {e}")))?;
257 tokens.push(Token::Float(f));
258 } else {
259 let text: String = chars[start..i].iter().collect();
260 let n: i64 = text.parse().map_err(|e| Error(format!("invalid integer: {e}")))?;
261 tokens.push(Token::Integer(n));
262 }
263 continue;
264 }
265
266 if c.is_ascii_alphabetic() || c == '_' {
268 let start = i;
269 while i < len && (chars[i].is_ascii_alphanumeric() || chars[i] == '_') {
270 i += 1;
271 }
272 let word: String = chars[start..i].iter().collect();
273 let upper = word.to_ascii_uppercase();
274 let token = match upper.as_str() {
275 "SELECT" => Token::Keyword(Keyword::Select),
276 "FROM" => Token::Keyword(Keyword::From),
277 "WHERE" => Token::Keyword(Keyword::Where),
278 "AND" => Token::Keyword(Keyword::And),
279 "OR" => Token::Keyword(Keyword::Or),
280 "NOT" => Token::Keyword(Keyword::Not),
281 "AS" => Token::Keyword(Keyword::As),
282 "ORDER" => Token::Keyword(Keyword::Order),
283 "BY" => Token::Keyword(Keyword::By),
284 "ASC" => Token::Keyword(Keyword::Asc),
285 "DESC" => Token::Keyword(Keyword::Desc),
286 "LIMIT" => Token::Keyword(Keyword::Limit),
287 "OFFSET" => Token::Keyword(Keyword::Offset),
288 "GROUP" => Token::Keyword(Keyword::Group),
289 "HAVING" => Token::Keyword(Keyword::Having),
290 "DISTINCT" => Token::Keyword(Keyword::Distinct),
291 "INSERT" => Token::Keyword(Keyword::Insert),
292 "INTO" => Token::Keyword(Keyword::Into),
293 "VALUES" => Token::Keyword(Keyword::Values),
294 "UPDATE" => Token::Keyword(Keyword::Update),
295 "SET" => Token::Keyword(Keyword::Set),
296 "DELETE" => Token::Keyword(Keyword::Delete),
297 "CREATE" => Token::Keyword(Keyword::Create),
298 "TABLE" => Token::Keyword(Keyword::Table),
299 "JOIN" => Token::Keyword(Keyword::Join),
300 "INNER" => Token::Keyword(Keyword::Inner),
301 "LEFT" => Token::Keyword(Keyword::Left),
302 "RIGHT" => Token::Keyword(Keyword::Right),
303 "ON" => Token::Keyword(Keyword::On),
304 "NULL" => Token::Keyword(Keyword::Null),
305 "TRUE" => Token::Keyword(Keyword::True),
306 "FALSE" => Token::Keyword(Keyword::False),
307 "IS" => Token::Keyword(Keyword::Is),
308 "IN" => Token::Keyword(Keyword::In),
309 "BETWEEN" => Token::Keyword(Keyword::Between),
310 "CAST" => Token::Keyword(Keyword::Cast),
311 "COUNT" => Token::Keyword(Keyword::Count),
312 "SUM" => Token::Keyword(Keyword::Sum),
313 "AVG" => Token::Keyword(Keyword::Avg),
314 "MIN" => Token::Keyword(Keyword::Min),
315 "MAX" => Token::Keyword(Keyword::Max),
316 "INT" => Token::Keyword(Keyword::Int),
317 "INT2" => Token::Keyword(Keyword::Int2),
318 "INT4" => Token::Keyword(Keyword::Int4),
319 "INT8" => Token::Keyword(Keyword::Int8),
320 "SMALLINT" => Token::Keyword(Keyword::Smallint),
321 "INTEGER" => Token::Keyword(Keyword::Integer),
322 "BIGINT" => Token::Keyword(Keyword::Bigint),
323 "FLOAT4" => Token::Keyword(Keyword::Float4),
324 "FLOAT8" => Token::Keyword(Keyword::Float8),
325 "REAL" => Token::Keyword(Keyword::Real),
326 "DOUBLE" => Token::Keyword(Keyword::Double),
327 "PRECISION" => Token::Keyword(Keyword::Precision),
328 "BOOLEAN" => Token::Keyword(Keyword::Boolean),
329 "BOOL" => Token::Keyword(Keyword::Bool),
330 "VARCHAR" => Token::Keyword(Keyword::Varchar),
331 "TEXT" => Token::Keyword(Keyword::Text),
332 "CHAR" => Token::Keyword(Keyword::Char),
333 "UTF8" => Token::Keyword(Keyword::Utf8),
334 "BLOB" => Token::Keyword(Keyword::Blob),
335 "PRIMARY" => Token::Keyword(Keyword::Primary),
336 "KEY" => Token::Keyword(Keyword::Key),
337 "WITH" => Token::Keyword(Keyword::With),
338 "RECURSIVE" => Token::Keyword(Keyword::Recursive),
339 _ => Token::Ident(word),
340 };
341 tokens.push(token);
342 continue;
343 }
344
345 return Err(Error(format!("unexpected character '{c}' at position {i}")));
346 }
347
348 Ok(tokens)
349}
350
// Unit tests for `tokenize`: one test per lexical category, plus the error
// paths for malformed input.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_select() {
        let tokens = tokenize("SELECT id, name FROM users").unwrap();
        assert_eq!(
            tokens,
            vec![
                Token::Keyword(Keyword::Select),
                Token::Ident("id".into()),
                Token::Comma,
                Token::Ident("name".into()),
                Token::Keyword(Keyword::From),
                Token::Ident("users".into()),
            ]
        );
    }

    #[test]
    fn test_string_literal() {
        let tokens = tokenize("SELECT 'hello'").unwrap();
        assert_eq!(
            tokens,
            vec![Token::Keyword(Keyword::Select), Token::StringLit("hello".into())]
        );
    }

    #[test]
    fn test_escaped_quote_in_string_literal() {
        // A doubled quote inside a literal stands for a single quote.
        let tokens = tokenize("'it''s'").unwrap();
        assert_eq!(tokens, vec![Token::StringLit("it's".into())]);
    }

    #[test]
    fn test_unterminated_string_literal_is_error() {
        assert!(tokenize("'abc").is_err());
        // A trailing escaped quote still leaves the literal open.
        assert!(tokenize("'abc''").is_err());
    }

    #[test]
    fn test_comparison_operators() {
        let tokens = tokenize("a <> b").unwrap();
        assert_eq!(
            tokens,
            vec![Token::Ident("a".into()), Token::NotEq, Token::Ident("b".into())]
        );
    }

    #[test]
    fn test_two_character_operators() {
        let tokens = tokenize("<= >= != < >").unwrap();
        assert_eq!(
            tokens,
            vec![Token::LtEq, Token::GtEq, Token::NotEq, Token::Lt, Token::Gt]
        );
    }

    #[test]
    fn test_numeric_literals() {
        // 2.5 is exactly representable as f64, so the equality check is
        // exact, and it does not resemble a well-known constant (a literal
        // such as 3.14 trips Clippy's `approx_constant` lint).
        let tokens = tokenize("42 2.5").unwrap();
        assert_eq!(tokens, vec![Token::Integer(42), Token::Float(2.5)]);
    }

    #[test]
    fn test_line_comment_is_skipped() {
        let tokens = tokenize("SELECT -- trailing comment\n1").unwrap();
        assert_eq!(tokens, vec![Token::Keyword(Keyword::Select), Token::Integer(1)]);
    }

    #[test]
    fn test_keywords_are_case_insensitive() {
        let tokens = tokenize("select FrOm Users").unwrap();
        assert_eq!(
            tokens,
            vec![
                Token::Keyword(Keyword::Select),
                Token::Keyword(Keyword::From),
                // Non-keywords keep their original spelling.
                Token::Ident("Users".into()),
            ]
        );
    }

    #[test]
    fn test_unexpected_character_is_error() {
        assert!(tokenize("@").is_err());
        // `!` is only valid as part of `!=`.
        assert!(tokenize("a ! b").is_err());
    }
}