Skip to main content

reifydb_sql/
token.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use crate::Error;
5
6#[derive(Debug, Clone, PartialEq)]
7pub enum Token {
8	// Keywords
9	Keyword(Keyword),
10	// Identifiers
11	Ident(String),
12	// Literals
13	Integer(i64),
14	Float(f64),
15	StringLit(String),
16	// Operators & punctuation
17	Asterisk,   // *
18	Comma,      // ,
19	Dot,        // .
20	Semicolon,  // ;
21	OpenParen,  // (
22	CloseParen, // )
23	Plus,       // +
24	Minus,      // -
25	Slash,      // /
26	Percent,    // %
27	Eq,         // =
28	NotEq,      // <> or !=
29	Lt,         // <
30	Gt,         // >
31	LtEq,       // <=
32	GtEq,       // >=
33}
34
/// A reserved word recognised by the SQL tokenizer.
///
/// Keywords are matched case-insensitively, so `select`, `Select` and
/// `SELECT` all map to [`Keyword::Select`]. The enum carries no data, so it
/// is `Copy` and can be hashed or used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Keyword {
	// Query structure and clauses
	Select,
	From,
	Where,
	And,
	Or,
	Not,
	As,
	Order,
	By,
	Asc,
	Desc,
	Limit,
	Offset,
	Group,
	Having,
	Distinct,
	// Data modification and DDL
	Insert,
	Into,
	Values,
	Update,
	Set,
	Delete,
	Create,
	Table,
	// Joins
	Join,
	Inner,
	Left,
	Right,
	On,
	// Literals and predicates
	Null,
	True,
	False,
	Is,
	In,
	Between,
	Cast,
	// Aggregate function names
	Count,
	Sum,
	Avg,
	Min,
	Max,
	// SQL type names (`Double` + `Precision` spell `DOUBLE PRECISION`)
	Int,
	Int2,
	Int4,
	Int8,
	Smallint,
	Integer,
	Bigint,
	Float4,
	Float8,
	Real,
	Double,
	Precision,
	Boolean,
	Bool,
	Varchar,
	Text,
	Char,
	Utf8,
	Blob,
	// Constraints
	Primary,
	Key,
	// Common table expressions
	With,
	Recursive,
}
103
104pub fn tokenize(sql: &str) -> Result<Vec<Token>, Error> {
105	let mut tokens = Vec::new();
106	let chars: Vec<char> = sql.chars().collect();
107	let len = chars.len();
108	let mut i = 0;
109
110	while i < len {
111		let c = chars[i];
112
113		// Skip whitespace
114		if c.is_ascii_whitespace() {
115			i += 1;
116			continue;
117		}
118
119		// Skip line comments (-- ...)
120		if c == '-' && i + 1 < len && chars[i + 1] == '-' {
121			while i < len && chars[i] != '\n' {
122				i += 1;
123			}
124			continue;
125		}
126
127		// Operators and punctuation
128		match c {
129			'*' => {
130				tokens.push(Token::Asterisk);
131				i += 1;
132				continue;
133			}
134			',' => {
135				tokens.push(Token::Comma);
136				i += 1;
137				continue;
138			}
139			'.' => {
140				tokens.push(Token::Dot);
141				i += 1;
142				continue;
143			}
144			';' => {
145				tokens.push(Token::Semicolon);
146				i += 1;
147				continue;
148			}
149			'(' => {
150				tokens.push(Token::OpenParen);
151				i += 1;
152				continue;
153			}
154			')' => {
155				tokens.push(Token::CloseParen);
156				i += 1;
157				continue;
158			}
159			'+' => {
160				tokens.push(Token::Plus);
161				i += 1;
162				continue;
163			}
164			'-' => {
165				tokens.push(Token::Minus);
166				i += 1;
167				continue;
168			}
169			'/' => {
170				tokens.push(Token::Slash);
171				i += 1;
172				continue;
173			}
174			'%' => {
175				tokens.push(Token::Percent);
176				i += 1;
177				continue;
178			}
179			'=' => {
180				tokens.push(Token::Eq);
181				i += 1;
182				continue;
183			}
184			'<' => {
185				if i + 1 < len && chars[i + 1] == '=' {
186					tokens.push(Token::LtEq);
187					i += 2;
188				} else if i + 1 < len && chars[i + 1] == '>' {
189					tokens.push(Token::NotEq);
190					i += 2;
191				} else {
192					tokens.push(Token::Lt);
193					i += 1;
194				}
195				continue;
196			}
197			'>' => {
198				if i + 1 < len && chars[i + 1] == '=' {
199					tokens.push(Token::GtEq);
200					i += 2;
201				} else {
202					tokens.push(Token::Gt);
203					i += 1;
204				}
205				continue;
206			}
207			'!' => {
208				if i + 1 < len && chars[i + 1] == '=' {
209					tokens.push(Token::NotEq);
210					i += 2;
211					continue;
212				}
213				return Err(Error(format!("unexpected character '!' at position {i}")));
214			}
215			_ => {}
216		}
217
218		// String literals
219		if c == '\'' {
220			i += 1;
221			let mut s = String::new();
222			while i < len {
223				if chars[i] == '\'' {
224					// Check for escaped single quote ''
225					if i + 1 < len && chars[i + 1] == '\'' {
226						s.push('\'');
227						i += 2;
228					} else {
229						break;
230					}
231				} else {
232					s.push(chars[i]);
233					i += 1;
234				}
235			}
236			if i >= len {
237				return Err(Error("unterminated string literal".into()));
238			}
239			i += 1; // skip closing quote
240			tokens.push(Token::StringLit(s));
241			continue;
242		}
243
244		// Numeric literals
245		if c.is_ascii_digit() {
246			let start = i;
247			while i < len && chars[i].is_ascii_digit() {
248				i += 1;
249			}
250			if i < len && chars[i] == '.' && i + 1 < len && chars[i + 1].is_ascii_digit() {
251				i += 1; // skip dot
252				while i < len && chars[i].is_ascii_digit() {
253					i += 1;
254				}
255				let text: String = chars[start..i].iter().collect();
256				let f: f64 = text.parse().map_err(|e| Error(format!("invalid float: {e}")))?;
257				tokens.push(Token::Float(f));
258			} else {
259				let text: String = chars[start..i].iter().collect();
260				let n: i64 = text.parse().map_err(|e| Error(format!("invalid integer: {e}")))?;
261				tokens.push(Token::Integer(n));
262			}
263			continue;
264		}
265
266		// Identifiers and keywords
267		if c.is_ascii_alphabetic() || c == '_' {
268			let start = i;
269			while i < len && (chars[i].is_ascii_alphanumeric() || chars[i] == '_') {
270				i += 1;
271			}
272			let word: String = chars[start..i].iter().collect();
273			let upper = word.to_ascii_uppercase();
274			let token = match upper.as_str() {
275				"SELECT" => Token::Keyword(Keyword::Select),
276				"FROM" => Token::Keyword(Keyword::From),
277				"WHERE" => Token::Keyword(Keyword::Where),
278				"AND" => Token::Keyword(Keyword::And),
279				"OR" => Token::Keyword(Keyword::Or),
280				"NOT" => Token::Keyword(Keyword::Not),
281				"AS" => Token::Keyword(Keyword::As),
282				"ORDER" => Token::Keyword(Keyword::Order),
283				"BY" => Token::Keyword(Keyword::By),
284				"ASC" => Token::Keyword(Keyword::Asc),
285				"DESC" => Token::Keyword(Keyword::Desc),
286				"LIMIT" => Token::Keyword(Keyword::Limit),
287				"OFFSET" => Token::Keyword(Keyword::Offset),
288				"GROUP" => Token::Keyword(Keyword::Group),
289				"HAVING" => Token::Keyword(Keyword::Having),
290				"DISTINCT" => Token::Keyword(Keyword::Distinct),
291				"INSERT" => Token::Keyword(Keyword::Insert),
292				"INTO" => Token::Keyword(Keyword::Into),
293				"VALUES" => Token::Keyword(Keyword::Values),
294				"UPDATE" => Token::Keyword(Keyword::Update),
295				"SET" => Token::Keyword(Keyword::Set),
296				"DELETE" => Token::Keyword(Keyword::Delete),
297				"CREATE" => Token::Keyword(Keyword::Create),
298				"TABLE" => Token::Keyword(Keyword::Table),
299				"JOIN" => Token::Keyword(Keyword::Join),
300				"INNER" => Token::Keyword(Keyword::Inner),
301				"LEFT" => Token::Keyword(Keyword::Left),
302				"RIGHT" => Token::Keyword(Keyword::Right),
303				"ON" => Token::Keyword(Keyword::On),
304				"NULL" => Token::Keyword(Keyword::Null),
305				"TRUE" => Token::Keyword(Keyword::True),
306				"FALSE" => Token::Keyword(Keyword::False),
307				"IS" => Token::Keyword(Keyword::Is),
308				"IN" => Token::Keyword(Keyword::In),
309				"BETWEEN" => Token::Keyword(Keyword::Between),
310				"CAST" => Token::Keyword(Keyword::Cast),
311				"COUNT" => Token::Keyword(Keyword::Count),
312				"SUM" => Token::Keyword(Keyword::Sum),
313				"AVG" => Token::Keyword(Keyword::Avg),
314				"MIN" => Token::Keyword(Keyword::Min),
315				"MAX" => Token::Keyword(Keyword::Max),
316				"INT" => Token::Keyword(Keyword::Int),
317				"INT2" => Token::Keyword(Keyword::Int2),
318				"INT4" => Token::Keyword(Keyword::Int4),
319				"INT8" => Token::Keyword(Keyword::Int8),
320				"SMALLINT" => Token::Keyword(Keyword::Smallint),
321				"INTEGER" => Token::Keyword(Keyword::Integer),
322				"BIGINT" => Token::Keyword(Keyword::Bigint),
323				"FLOAT4" => Token::Keyword(Keyword::Float4),
324				"FLOAT8" => Token::Keyword(Keyword::Float8),
325				"REAL" => Token::Keyword(Keyword::Real),
326				"DOUBLE" => Token::Keyword(Keyword::Double),
327				"PRECISION" => Token::Keyword(Keyword::Precision),
328				"BOOLEAN" => Token::Keyword(Keyword::Boolean),
329				"BOOL" => Token::Keyword(Keyword::Bool),
330				"VARCHAR" => Token::Keyword(Keyword::Varchar),
331				"TEXT" => Token::Keyword(Keyword::Text),
332				"CHAR" => Token::Keyword(Keyword::Char),
333				"UTF8" => Token::Keyword(Keyword::Utf8),
334				"BLOB" => Token::Keyword(Keyword::Blob),
335				"PRIMARY" => Token::Keyword(Keyword::Primary),
336				"KEY" => Token::Keyword(Keyword::Key),
337				"WITH" => Token::Keyword(Keyword::With),
338				"RECURSIVE" => Token::Keyword(Keyword::Recursive),
339				_ => Token::Ident(word),
340			};
341			tokens.push(token);
342			continue;
343		}
344
345		return Err(Error(format!("unexpected character '{c}' at position {i}")));
346	}
347
348	Ok(tokens)
349}
350
#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn test_simple_select() {
		let tokens = tokenize("SELECT id, name FROM users").unwrap();
		assert_eq!(
			tokens,
			vec![
				Token::Keyword(Keyword::Select),
				Token::Ident("id".into()),
				Token::Comma,
				Token::Ident("name".into()),
				Token::Keyword(Keyword::From),
				Token::Ident("users".into()),
			]
		);
	}

	#[test]
	fn test_keywords_are_case_insensitive_and_idents_keep_case() {
		let tokens = tokenize("select My_Col from T").unwrap();
		assert_eq!(
			tokens,
			vec![
				Token::Keyword(Keyword::Select),
				Token::Ident("My_Col".into()),
				Token::Keyword(Keyword::From),
				Token::Ident("T".into()),
			]
		);
	}

	#[test]
	fn test_string_literal() {
		let tokens = tokenize("SELECT 'hello'").unwrap();
		assert_eq!(tokens, vec![Token::Keyword(Keyword::Select), Token::StringLit("hello".into()),]);
	}

	#[test]
	fn test_escaped_quote_in_string_literal() {
		let tokens = tokenize("'it''s'").unwrap();
		assert_eq!(tokens, vec![Token::StringLit("it's".into())]);
	}

	#[test]
	fn test_unterminated_string_literal_is_error() {
		assert!(tokenize("SELECT 'abc").is_err());
	}

	#[test]
	fn test_comparison_operators() {
		let tokens = tokenize("a <> b").unwrap();
		assert_eq!(tokens, vec![Token::Ident("a".into()), Token::NotEq, Token::Ident("b".into()),]);
	}

	#[test]
	fn test_all_comparison_operators() {
		let tokens = tokenize("a != b <= c >= d < e > f = g").unwrap();
		assert_eq!(
			tokens,
			vec![
				Token::Ident("a".into()),
				Token::NotEq,
				Token::Ident("b".into()),
				Token::LtEq,
				Token::Ident("c".into()),
				Token::GtEq,
				Token::Ident("d".into()),
				Token::Lt,
				Token::Ident("e".into()),
				Token::Gt,
				Token::Ident("f".into()),
				Token::Eq,
				Token::Ident("g".into()),
			]
		);
	}

	#[test]
	fn test_arithmetic_and_punctuation() {
		let tokens = tokenize("f(a.b) * 2 + 1 - 3 / 4 % 5;").unwrap();
		assert_eq!(
			tokens,
			vec![
				Token::Ident("f".into()),
				Token::OpenParen,
				Token::Ident("a".into()),
				Token::Dot,
				Token::Ident("b".into()),
				Token::CloseParen,
				Token::Asterisk,
				Token::Integer(2),
				Token::Plus,
				Token::Integer(1),
				Token::Minus,
				Token::Integer(3),
				Token::Slash,
				Token::Integer(4),
				Token::Percent,
				Token::Integer(5),
				Token::Semicolon,
			]
		);
	}

	#[test]
	fn test_line_comments_are_skipped() {
		let tokens = tokenize("-- header\nSELECT 1 -- trailing").unwrap();
		assert_eq!(tokens, vec![Token::Keyword(Keyword::Select), Token::Integer(1)]);
	}

	#[test]
	fn test_numeric_literals() {
		// 2.5 rather than 3.14: the latter trips clippy::approx_constant (deny-by-default).
		let tokens = tokenize("42 2.5").unwrap();
		assert_eq!(tokens, vec![Token::Integer(42), Token::Float(2.5),]);
	}

	#[test]
	fn test_integer_overflow_is_error() {
		assert!(tokenize("9223372036854775808").is_err());
	}

	#[test]
	fn test_unexpected_character_reports_position() {
		let err = tokenize("a # b").unwrap_err();
		assert_eq!(err.0, "unexpected character '#' at position 2");
		// A `!` not followed by `=` is rejected as well.
		assert!(tokenize("a ! b").is_err());
	}
}