1use super::{
67 m2_select_validator::M2SelectValidator, ComparisonOperator, Condition, OrderByClause,
68 ParsedQuery, QueryType, SortDirection, WhereClause,
69};
70use crate::{Config, Error, Result, TableId, Value};
71use std::collections::HashMap;
72
/// Stateless parser that turns a single CQL statement into a [`ParsedQuery`].
///
/// Construct it with [`QueryParser::new`] and call [`QueryParser::parse`].
#[derive(Debug)]
pub struct QueryParser {}
76
77fn empty_parsed(query_type: QueryType, table: Option<TableId>, cql: &str) -> ParsedQuery {
80 ParsedQuery {
81 query_type,
82 table,
83 columns: Vec::new(),
84 where_clause: None,
85 values: Vec::new(),
86 set_clause: HashMap::new(),
87 order_by: Vec::new(),
88 limit: None,
89 cql: cql.to_string(),
90 }
91}
92
93impl QueryParser {
94 pub fn new(_config: &Config) -> Self {
96 Self {}
97 }
98
99 pub fn parse(&self, cql: &str) -> Result<ParsedQuery> {
101 let cql = cql.trim();
102
103 let first_word = cql
104 .split_whitespace()
105 .next()
106 .ok_or_else(|| Error::query_execution("Empty query".to_string()))?;
107
108 if first_word.eq_ignore_ascii_case("SELECT") {
111 self.parse_select(cql)
112 } else if first_word.eq_ignore_ascii_case("INSERT") {
113 self.parse_insert(cql)
114 } else if first_word.eq_ignore_ascii_case("UPDATE") {
115 self.parse_update(cql)
116 } else if first_word.eq_ignore_ascii_case("DELETE") {
117 self.parse_delete(cql)
118 } else if first_word.eq_ignore_ascii_case("CREATE") {
119 self.parse_create(cql)
120 } else if first_word.eq_ignore_ascii_case("DROP") {
121 self.parse_drop(cql)
122 } else if first_word.eq_ignore_ascii_case("DESCRIBE")
123 || first_word.eq_ignore_ascii_case("DESC")
124 {
125 self.parse_describe(cql)
126 } else if first_word.eq_ignore_ascii_case("USE") {
127 self.parse_use(cql)
128 } else {
129 Err(Error::query_execution(format!(
130 "Unsupported query type: {}",
131 first_word.to_uppercase()
132 )))
133 }
134 }
135
136 fn parse_select(&self, cql: &str) -> Result<ParsedQuery> {
138 M2SelectValidator.validate_select(cql)?;
140
141 let upper = cql.to_uppercase();
143
144 let columns = match extract_between(cql, &upper, "SELECT", "FROM") {
146 Some(select_part) => {
147 let select_part = select_part.trim();
148 if select_part == "*" {
149 vec!["*".to_string()]
150 } else {
151 select_part
152 .split(',')
153 .map(|c| c.trim().to_string())
154 .collect()
155 }
156 }
157 None => Vec::new(),
158 };
159
160 let table = match extract_after(cql, &upper, "FROM") {
166 Some(from_part) => {
167 let qualified_name = from_part.split_whitespace().next().ok_or_else(|| {
168 Error::query_execution("Missing table name after FROM".to_string())
169 })?;
170 Some(TableId::new(qualified_name))
171 }
172 None => None,
173 };
174
175 let where_clause = extract_clause(cql, &upper, "WHERE", &["ORDER BY", "LIMIT"])
177 .map(|s| self.parse_where_clause(s))
178 .transpose()?;
179
180 let order_by = match extract_clause(cql, &upper, "ORDER BY", &["LIMIT"]) {
182 Some(part) => self.parse_order_by(part)?,
183 None => Vec::new(),
184 };
185
186 let limit = match extract_after(cql, &upper, "LIMIT") {
188 Some(limit_part) => {
189 let limit_str = limit_part
190 .split_whitespace()
191 .next()
192 .ok_or_else(|| Error::query_execution("Missing limit value".to_string()))?;
193 Some(
194 limit_str
195 .parse()
196 .map_err(|_| Error::query_execution("Invalid limit value".to_string()))?,
197 )
198 }
199 None => None,
200 };
201
202 let mut parsed = empty_parsed(QueryType::Select, table, cql);
203 parsed.columns = columns;
204 parsed.where_clause = where_clause;
205 parsed.order_by = order_by;
206 parsed.limit = limit;
207 Ok(parsed)
208 }
209
210 fn parse_insert(&self, cql: &str) -> Result<ParsedQuery> {
212 let upper = cql.to_uppercase();
213
214 let paren_pos = cql.find('(');
218 let values_pos = upper.find("VALUES").unwrap_or(cql.len());
219 let explicit_columns = matches!(paren_pos, Some(p) if p < values_pos);
220
221 let (table, columns) = if explicit_columns {
222 let table = extract_between(cql, &upper, "INTO", "(").map(|t| TableId::new(t.trim()));
223 let columns = extract_between(cql, &upper, "(", ")")
224 .map(|c| c.split(',').map(|col| col.trim().to_string()).collect())
225 .unwrap_or_default();
226 (table, columns)
227 } else {
228 let table =
230 extract_between(cql, &upper, "INTO", "VALUES").map(|t| TableId::new(t.trim()));
231 (table, Vec::new())
232 };
233
234 let values = match extract_between(cql, &upper, "VALUES (", ")") {
235 Some(values_part) => self.parse_values(values_part)?,
236 None => Vec::new(),
237 };
238
239 let mut parsed = empty_parsed(QueryType::Insert, table, cql);
240 parsed.columns = columns;
241 parsed.values = values;
242 Ok(parsed)
243 }
244
245 fn parse_update(&self, cql: &str) -> Result<ParsedQuery> {
247 let upper = cql.to_uppercase();
248
249 let table = cql.split_whitespace().nth(1).map(TableId::new);
251
252 let set_clause = match extract_clause(cql, &upper, "SET", &["WHERE"]) {
254 Some(part) => self.parse_set_clause(part)?,
255 None => HashMap::new(),
256 };
257
258 let where_clause = extract_after(cql, &upper, "WHERE")
259 .map(|s| self.parse_where_clause(s))
260 .transpose()?;
261
262 let mut parsed = empty_parsed(QueryType::Update, table, cql);
263 parsed.set_clause = set_clause;
264 parsed.where_clause = where_clause;
265 Ok(parsed)
266 }
267
268 fn parse_delete(&self, cql: &str) -> Result<ParsedQuery> {
270 let upper = cql.to_uppercase();
271
272 let table = extract_clause(cql, &upper, "FROM", &["WHERE"]).map(|t| TableId::new(t.trim()));
273
274 let where_clause = extract_after(cql, &upper, "WHERE")
275 .map(|s| self.parse_where_clause(s))
276 .transpose()?;
277
278 let mut parsed = empty_parsed(QueryType::Delete, table, cql);
279 parsed.where_clause = where_clause;
280 Ok(parsed)
281 }
282
283 fn parse_create(&self, cql: &str) -> Result<ParsedQuery> {
285 parse_keyword_target(
286 cql,
287 QueryType::CreateTable,
288 "TABLE",
289 "Unsupported CREATE statement",
290 )
291 }
292
293 fn parse_drop(&self, cql: &str) -> Result<ParsedQuery> {
295 parse_keyword_target(
296 cql,
297 QueryType::DropTable,
298 "TABLE",
299 "Unsupported DROP statement",
300 )
301 }
302
303 fn parse_describe(&self, cql: &str) -> Result<ParsedQuery> {
305 parse_single_target(cql, QueryType::Describe, "Missing table name for DESCRIBE")
306 }
307
308 fn parse_use(&self, cql: &str) -> Result<ParsedQuery> {
310 parse_single_target(cql, QueryType::Use, "Missing keyspace name for USE")
312 }
313
314 fn parse_where_clause(&self, where_part: &str) -> Result<WhereClause> {
316 let mut conditions = Vec::new();
317
318 let parts: Vec<&str> = where_part.split_whitespace().collect();
320 if parts.len() >= 3 {
321 conditions.push(Condition {
322 column: parts[0].to_string(),
323 operator: self.parse_operator(parts[1])?,
324 value: self.parse_value(parts[2])?,
325 });
326 }
327
328 Ok(WhereClause { conditions })
329 }
330
331 fn parse_operator(&self, op: &str) -> Result<ComparisonOperator> {
333 match op {
334 "=" => Ok(ComparisonOperator::Equal),
335 "<>" | "!=" => Ok(ComparisonOperator::NotEqual),
336 "<" => Ok(ComparisonOperator::LessThan),
337 "<=" => Ok(ComparisonOperator::LessThanOrEqual),
338 ">" => Ok(ComparisonOperator::GreaterThan),
339 ">=" => Ok(ComparisonOperator::GreaterThanOrEqual),
340 "IN" => Ok(ComparisonOperator::In),
341 "LIKE" => Ok(ComparisonOperator::Like),
342 _ => Err(Error::query_execution(format!("Unknown operator: {}", op))),
343 }
344 }
345
346 fn parse_value(&self, value_str: &str) -> Result<Value> {
348 let value_str = value_str.trim();
349
350 if value_str.starts_with('\'') && value_str.ends_with('\'') && value_str.len() >= 2 {
352 return Ok(Value::Text(value_str[1..value_str.len() - 1].to_string()));
353 }
354
355 if let Ok(int_val) = value_str.parse::<i32>() {
357 return Ok(Value::Integer(int_val));
358 }
359
360 if let Ok(float_val) = value_str.parse::<f64>() {
362 return Ok(Value::Float(float_val));
363 }
364
365 if value_str.eq_ignore_ascii_case("TRUE") {
367 return Ok(Value::Boolean(true));
368 }
369 if value_str.eq_ignore_ascii_case("FALSE") {
370 return Ok(Value::Boolean(false));
371 }
372 if value_str.eq_ignore_ascii_case("NULL") {
373 return Ok(Value::Null);
374 }
375
376 if is_uuid_literal(value_str) {
380 if let Some(bytes) = parse_uuid_literal(value_str) {
381 return Ok(Value::Uuid(bytes));
382 }
383 }
384
385 Ok(Value::Text(value_str.to_string()))
387 }
388
389 fn parse_values(&self, values_part: &str) -> Result<Vec<Value>> {
391 values_part
392 .split(',')
393 .map(|v| self.parse_value(v.trim()))
394 .collect()
395 }
396
397 fn parse_set_clause(&self, set_part: &str) -> Result<HashMap<String, Value>> {
399 let mut set_clause = HashMap::new();
400
401 for assignment in set_part.split(',') {
402 let parts: Vec<&str> = assignment.split('=').collect();
403 if parts.len() == 2 {
404 let column = parts[0].trim().to_string();
405 let value = self.parse_value(parts[1].trim())?;
406 set_clause.insert(column, value);
407 }
408 }
409
410 Ok(set_clause)
411 }
412
413 fn parse_order_by(&self, order_part: &str) -> Result<Vec<OrderByClause>> {
415 let mut order_by = Vec::new();
416
417 for order_item in order_part.split(',') {
418 let parts: Vec<&str> = order_item.split_whitespace().collect();
419 if let Some(&col) = parts.first() {
420 let direction = if parts.get(1).is_some_and(|d| d.eq_ignore_ascii_case("DESC")) {
421 SortDirection::Desc
422 } else {
423 SortDirection::Asc
424 };
425 order_by.push(OrderByClause {
426 column: col.to_string(),
427 direction,
428 });
429 }
430 }
431
432 Ok(order_by)
433 }
434}
435
/// Returns `true` for bytes that can appear in an unquoted identifier or
/// keyword: ASCII letters, digits and `_`.
fn is_word_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_'
}

/// Finds the first occurrence of `keyword` in `haystack`, at or after byte
/// offset `from`, that is not embedded in a longer identifier.
///
/// The boundary check only applies to the keyword's word-character edges. A
/// keyword that starts with a letter must not be preceded by a word byte, and
/// one that ends with a letter must not be followed by one. So `SET` is not
/// found inside `ASSETS`, while punctuation patterns such as `"("` or
/// `"VALUES ("` still match right next to other text. Returns `None` if
/// `from` is out of range or not on a char boundary.
fn find_keyword(haystack: &str, keyword: &str, from: usize) -> Option<usize> {
    let bytes = haystack.as_bytes();
    let check_before = keyword.as_bytes().first().is_some_and(|&b| is_word_byte(b));
    let check_after = keyword.as_bytes().last().is_some_and(|&b| is_word_byte(b));

    haystack
        .get(from..)?
        .match_indices(keyword)
        .map(|(offset, _)| from + offset)
        .find(|&pos| {
            let clear_before = !check_before || pos == 0 || !is_word_byte(bytes[pos - 1]);
            let clear_after = !check_after
                || !bytes
                    .get(pos + keyword.len())
                    .is_some_and(|&b| is_word_byte(b));
            clear_before && clear_after
        })
}

/// Returns the part of `text` strictly between the first standalone `start`
/// keyword and the next standalone `end` keyword after it.
///
/// The search runs on `upper`, which is expected to be an uppercased copy of
/// `text` with identical byte offsets (e.g. from `str::to_ascii_uppercase`).
/// `start` and `end` are matched ASCII case-insensitively. Returns `None`
/// when either keyword is missing, or when the resulting range is not a
/// valid slice of `text` (offsets misaligned with `text`) — this avoids a
/// panic.
fn extract_between<'a>(text: &'a str, upper: &str, start: &str, end: &str) -> Option<&'a str> {
    let start_kw = start.to_ascii_uppercase();
    let end_kw = end.to_ascii_uppercase();
    let body_start = find_keyword(upper, &start_kw, 0)? + start_kw.len();
    let body_end = find_keyword(upper, &end_kw, body_start)?;
    text.get(body_start..body_end)
}

/// Returns everything in `text` after the first standalone `pattern` keyword.
///
/// Uses the same `upper`/`text` alignment contract and keyword-boundary
/// rules as [`extract_between`]. Returns `None` when the keyword is missing
/// or the offset is not a valid position in `text`.
fn extract_after<'a>(text: &'a str, upper: &str, pattern: &str) -> Option<&'a str> {
    let pattern_kw = pattern.to_ascii_uppercase();
    let body_start = find_keyword(upper, &pattern_kw, 0)? + pattern_kw.len();
    text.get(body_start..)
}

/// Returns the clause introduced by `start`, ending at the first listed
/// terminator that follows it, or at the end of `text` when none does.
///
/// Terminators are tried in slice order, not by position, so they should be
/// listed in the order they can appear in the statement.
fn extract_clause<'a>(
    text: &'a str,
    upper: &str,
    start: &str,
    terminators: &[&str],
) -> Option<&'a str> {
    for term in terminators {
        if let Some(slice) = extract_between(text, upper, start, term) {
            return Some(slice);
        }
    }
    extract_after(text, upper, start)
}
473
474fn parse_keyword_target(
478 cql: &str,
479 query_type: QueryType,
480 expected_keyword: &str,
481 err_msg: &str,
482) -> Result<ParsedQuery> {
483 let words: Vec<&str> = cql.split_whitespace().collect();
484 if words.len() >= 3 && words[1].eq_ignore_ascii_case(expected_keyword) {
485 Ok(empty_parsed(query_type, Some(TableId::new(words[2])), cql))
486 } else {
487 Err(Error::query_execution(err_msg.to_string()))
488 }
489}
490
491fn parse_single_target(cql: &str, query_type: QueryType, err_msg: &str) -> Result<ParsedQuery> {
494 match cql.split_whitespace().nth(1) {
495 Some(name) => Ok(empty_parsed(query_type, Some(TableId::new(name)), cql)),
496 None => Err(Error::query_execution(err_msg.to_string())),
497 }
498}
499
/// Returns `true` if `s` is a canonical UUID: exactly 36 bytes, with `-` at
/// byte offsets 8, 13, 18 and 23 and an ASCII hex digit (either case)
/// everywhere else.
fn is_uuid_literal(s: &str) -> bool {
    const DASH_OFFSETS: [usize; 4] = [8, 13, 18, 23];

    s.len() == 36
        && s.bytes().enumerate().all(|(idx, byte)| {
            if DASH_OFFSETS.contains(&idx) {
                byte == b'-'
            } else {
                byte.is_ascii_hexdigit()
            }
        })
}
524
/// Decodes a UUID string into its 16 raw bytes, ignoring every `-`.
///
/// Returns `None` unless exactly 32 bytes remain after removing dashes and
/// all of them are hex digits (either case).
fn parse_uuid_literal(s: &str) -> Option<[u8; 16]> {
    // '-' is ASCII and never appears inside a multi-byte UTF-8 sequence, so
    // dropping those bytes is the same as dropping '-' characters.
    let digits: Vec<u8> = s.bytes().filter(|&b| b != b'-').collect();
    if digits.len() != 32 {
        return None;
    }

    let mut out = [0u8; 16];
    for (slot, pair) in out.iter_mut().zip(digits.chunks_exact(2)) {
        let high = char::from(pair[0]).to_digit(16)?;
        let low = char::from(pair[1]).to_digit(16)?;
        // Both nibbles are <= 15, so the combined value always fits in a u8.
        *slot = u8::try_from(high * 16 + low).ok()?;
    }
    Some(out)
}
541
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a parser with the default configuration.
    fn parser() -> QueryParser {
        QueryParser::new(&Config::default())
    }

    #[test]
    fn test_parse_select_basic() {
        let parsed = parser().parse("SELECT * FROM users").unwrap();

        assert_eq!(parsed.query_type, QueryType::Select);
        assert_eq!(parsed.table, Some(TableId::new("users")));
        assert_eq!(parsed.columns, vec!["*"]);
    }

    #[test]
    fn test_parse_select_with_columns() {
        let parsed = parser().parse("SELECT id, name FROM users").unwrap();

        assert_eq!(parsed.query_type, QueryType::Select);
        assert_eq!(parsed.columns, vec!["id", "name"]);
    }

    #[test]
    fn test_parse_select_with_where() {
        let parsed = parser().parse("SELECT * FROM users WHERE id = 1").unwrap();
        assert_eq!(parsed.query_type, QueryType::Select);

        let where_clause = parsed
            .where_clause
            .expect("WHERE clause should have been parsed");
        assert_eq!(where_clause.conditions.len(), 1);

        let condition = &where_clause.conditions[0];
        assert_eq!(condition.column, "id");
        assert_eq!(condition.operator, ComparisonOperator::Equal);
    }

    #[test]
    fn test_parse_insert() {
        let parsed = parser()
            .parse("INSERT INTO users (id, name) VALUES (1, 'Alice')")
            .unwrap();

        assert_eq!(parsed.query_type, QueryType::Insert);
        assert_eq!(parsed.table, Some(TableId::new("users")));
        assert_eq!(parsed.columns, vec!["id", "name"]);
        assert_eq!(parsed.values.len(), 2);
    }

    #[test]
    fn test_parse_update() {
        let parsed = parser()
            .parse("UPDATE users SET name = 'Bob' WHERE id = 1")
            .unwrap();

        assert_eq!(parsed.query_type, QueryType::Update);
        assert_eq!(parsed.table, Some(TableId::new("users")));
        assert!(!parsed.set_clause.is_empty());
        assert!(parsed.where_clause.is_some());
    }

    #[test]
    fn test_parse_delete() {
        let parsed = parser().parse("DELETE FROM users WHERE id = 1").unwrap();

        assert_eq!(parsed.query_type, QueryType::Delete);
        assert_eq!(parsed.table, Some(TableId::new("users")));
        assert!(parsed.where_clause.is_some());
    }

    #[test]
    fn test_parse_value_types() {
        let p = parser();

        assert_eq!(p.parse_value("123").unwrap(), Value::Integer(123));
        // 3.14 is literal query input here, not an approximation of PI.
        #[allow(clippy::approx_constant)]
        {
            assert_eq!(p.parse_value("3.14").unwrap(), Value::Float(3.14));
        }
        assert_eq!(
            p.parse_value("'hello'").unwrap(),
            Value::Text("hello".to_string())
        );
        assert_eq!(p.parse_value("true").unwrap(), Value::Boolean(true));
        assert_eq!(p.parse_value("NULL").unwrap(), Value::Null);
    }

    #[test]
    fn test_parse_select_with_qualified_table_name() {
        let parsed = parser()
            .parse("SELECT * FROM test_basic.simple_table LIMIT 5")
            .unwrap();

        assert_eq!(parsed.query_type, QueryType::Select);
        assert_eq!(parsed.table, Some(TableId::new("test_basic.simple_table")));
        assert_eq!(parsed.columns, vec!["*"]);
        assert_eq!(parsed.limit, Some(5));
    }

    #[test]
    fn test_parse_select_with_unqualified_table_name() {
        let parsed = parser().parse("SELECT * FROM simple_table LIMIT 5").unwrap();

        assert_eq!(parsed.query_type, QueryType::Select);
        assert_eq!(parsed.table, Some(TableId::new("simple_table")));
        assert_eq!(parsed.columns, vec!["*"]);
        assert_eq!(parsed.limit, Some(5));
    }
}