use super::{
m2_select_validator::M2SelectValidator, ComparisonOperator, Condition, OrderByClause,
ParsedQuery, QueryType, SortDirection, WhereClause,
};
use crate::{Config, Error, Result, TableId, Value};
use std::collections::HashMap;
/// Keyword-driven parser that turns CQL statement text into a [`ParsedQuery`].
///
/// The type has no fields; every call to [`QueryParser::parse`] works only on
/// the statement text it is given.
#[derive(Debug)]
pub struct QueryParser {}
/// Builds a [`ParsedQuery`] skeleton for `query_type`.
///
/// Only the query type, the optional target table and an owned copy of the
/// original statement text are filled in; every clause-specific field starts
/// out empty (`None` or an empty collection) so callers can assign just the
/// parts their statement kind actually has.
fn empty_parsed(query_type: QueryType, table: Option<TableId>, cql: &str) -> ParsedQuery {
    let statement_text = cql.to_string();
    ParsedQuery {
        // Identity of the statement.
        query_type,
        table,
        cql: statement_text,
        // Optional clauses default to "absent".
        where_clause: None,
        limit: None,
        // Collection-valued clauses default to "empty".
        columns: Vec::new(),
        values: Vec::new(),
        order_by: Vec::new(),
        set_clause: HashMap::new(),
    }
}
impl QueryParser {
pub fn new(_config: &Config) -> Self {
Self {}
}
pub fn parse(&self, cql: &str) -> Result<ParsedQuery> {
let cql = cql.trim();
let first_word = cql
.split_whitespace()
.next()
.ok_or_else(|| Error::query_execution("Empty query".to_string()))?;
if first_word.eq_ignore_ascii_case("SELECT") {
self.parse_select(cql)
} else if first_word.eq_ignore_ascii_case("INSERT") {
self.parse_insert(cql)
} else if first_word.eq_ignore_ascii_case("UPDATE") {
self.parse_update(cql)
} else if first_word.eq_ignore_ascii_case("DELETE") {
self.parse_delete(cql)
} else if first_word.eq_ignore_ascii_case("CREATE") {
self.parse_create(cql)
} else if first_word.eq_ignore_ascii_case("DROP") {
self.parse_drop(cql)
} else if first_word.eq_ignore_ascii_case("DESCRIBE")
|| first_word.eq_ignore_ascii_case("DESC")
{
self.parse_describe(cql)
} else if first_word.eq_ignore_ascii_case("USE") {
self.parse_use(cql)
} else {
Err(Error::query_execution(format!(
"Unsupported query type: {}",
first_word.to_uppercase()
)))
}
}
fn parse_select(&self, cql: &str) -> Result<ParsedQuery> {
M2SelectValidator.validate_select(cql)?;
let upper = cql.to_uppercase();
let columns = match extract_between(cql, &upper, "SELECT", "FROM") {
Some(select_part) => {
let select_part = select_part.trim();
if select_part == "*" {
vec!["*".to_string()]
} else {
select_part
.split(',')
.map(|c| c.trim().to_string())
.collect()
}
}
None => Vec::new(),
};
let table = match extract_after(cql, &upper, "FROM") {
Some(from_part) => {
let qualified_name = from_part.split_whitespace().next().ok_or_else(|| {
Error::query_execution("Missing table name after FROM".to_string())
})?;
Some(TableId::new(qualified_name))
}
None => None,
};
let where_clause = extract_clause(cql, &upper, "WHERE", &["ORDER BY", "LIMIT"])
.map(|s| self.parse_where_clause(s))
.transpose()?;
let order_by = match extract_clause(cql, &upper, "ORDER BY", &["LIMIT"]) {
Some(part) => self.parse_order_by(part)?,
None => Vec::new(),
};
let limit = match extract_after(cql, &upper, "LIMIT") {
Some(limit_part) => {
let limit_str = limit_part
.split_whitespace()
.next()
.ok_or_else(|| Error::query_execution("Missing limit value".to_string()))?;
Some(
limit_str
.parse()
.map_err(|_| Error::query_execution("Invalid limit value".to_string()))?,
)
}
None => None,
};
let mut parsed = empty_parsed(QueryType::Select, table, cql);
parsed.columns = columns;
parsed.where_clause = where_clause;
parsed.order_by = order_by;
parsed.limit = limit;
Ok(parsed)
}
fn parse_insert(&self, cql: &str) -> Result<ParsedQuery> {
let upper = cql.to_uppercase();
let paren_pos = cql.find('(');
let values_pos = upper.find("VALUES").unwrap_or(cql.len());
let explicit_columns = matches!(paren_pos, Some(p) if p < values_pos);
let (table, columns) = if explicit_columns {
let table = extract_between(cql, &upper, "INTO", "(").map(|t| TableId::new(t.trim()));
let columns = extract_between(cql, &upper, "(", ")")
.map(|c| c.split(',').map(|col| col.trim().to_string()).collect())
.unwrap_or_default();
(table, columns)
} else {
let table =
extract_between(cql, &upper, "INTO", "VALUES").map(|t| TableId::new(t.trim()));
(table, Vec::new())
};
let values = match extract_between(cql, &upper, "VALUES (", ")") {
Some(values_part) => self.parse_values(values_part)?,
None => Vec::new(),
};
let mut parsed = empty_parsed(QueryType::Insert, table, cql);
parsed.columns = columns;
parsed.values = values;
Ok(parsed)
}
fn parse_update(&self, cql: &str) -> Result<ParsedQuery> {
let upper = cql.to_uppercase();
let table = cql.split_whitespace().nth(1).map(TableId::new);
let set_clause = match extract_clause(cql, &upper, "SET", &["WHERE"]) {
Some(part) => self.parse_set_clause(part)?,
None => HashMap::new(),
};
let where_clause = extract_after(cql, &upper, "WHERE")
.map(|s| self.parse_where_clause(s))
.transpose()?;
let mut parsed = empty_parsed(QueryType::Update, table, cql);
parsed.set_clause = set_clause;
parsed.where_clause = where_clause;
Ok(parsed)
}
fn parse_delete(&self, cql: &str) -> Result<ParsedQuery> {
let upper = cql.to_uppercase();
let table = extract_clause(cql, &upper, "FROM", &["WHERE"]).map(|t| TableId::new(t.trim()));
let where_clause = extract_after(cql, &upper, "WHERE")
.map(|s| self.parse_where_clause(s))
.transpose()?;
let mut parsed = empty_parsed(QueryType::Delete, table, cql);
parsed.where_clause = where_clause;
Ok(parsed)
}
fn parse_create(&self, cql: &str) -> Result<ParsedQuery> {
parse_keyword_target(
cql,
QueryType::CreateTable,
"TABLE",
"Unsupported CREATE statement",
)
}
fn parse_drop(&self, cql: &str) -> Result<ParsedQuery> {
parse_keyword_target(
cql,
QueryType::DropTable,
"TABLE",
"Unsupported DROP statement",
)
}
fn parse_describe(&self, cql: &str) -> Result<ParsedQuery> {
parse_single_target(cql, QueryType::Describe, "Missing table name for DESCRIBE")
}
fn parse_use(&self, cql: &str) -> Result<ParsedQuery> {
parse_single_target(cql, QueryType::Use, "Missing keyspace name for USE")
}
fn parse_where_clause(&self, where_part: &str) -> Result<WhereClause> {
let mut conditions = Vec::new();
let parts: Vec<&str> = where_part.split_whitespace().collect();
if parts.len() >= 3 {
conditions.push(Condition {
column: parts[0].to_string(),
operator: self.parse_operator(parts[1])?,
value: self.parse_value(parts[2])?,
});
}
Ok(WhereClause { conditions })
}
fn parse_operator(&self, op: &str) -> Result<ComparisonOperator> {
match op {
"=" => Ok(ComparisonOperator::Equal),
"<>" | "!=" => Ok(ComparisonOperator::NotEqual),
"<" => Ok(ComparisonOperator::LessThan),
"<=" => Ok(ComparisonOperator::LessThanOrEqual),
">" => Ok(ComparisonOperator::GreaterThan),
">=" => Ok(ComparisonOperator::GreaterThanOrEqual),
"IN" => Ok(ComparisonOperator::In),
"LIKE" => Ok(ComparisonOperator::Like),
_ => Err(Error::query_execution(format!("Unknown operator: {}", op))),
}
}
fn parse_value(&self, value_str: &str) -> Result<Value> {
let value_str = value_str.trim();
if value_str.starts_with('\'') && value_str.ends_with('\'') && value_str.len() >= 2 {
return Ok(Value::Text(value_str[1..value_str.len() - 1].to_string()));
}
if let Ok(int_val) = value_str.parse::<i32>() {
return Ok(Value::Integer(int_val));
}
if let Ok(float_val) = value_str.parse::<f64>() {
return Ok(Value::Float(float_val));
}
if value_str.eq_ignore_ascii_case("TRUE") {
return Ok(Value::Boolean(true));
}
if value_str.eq_ignore_ascii_case("FALSE") {
return Ok(Value::Boolean(false));
}
if value_str.eq_ignore_ascii_case("NULL") {
return Ok(Value::Null);
}
if is_uuid_literal(value_str) {
if let Some(bytes) = parse_uuid_literal(value_str) {
return Ok(Value::Uuid(bytes));
}
}
Ok(Value::Text(value_str.to_string()))
}
fn parse_values(&self, values_part: &str) -> Result<Vec<Value>> {
values_part
.split(',')
.map(|v| self.parse_value(v.trim()))
.collect()
}
fn parse_set_clause(&self, set_part: &str) -> Result<HashMap<String, Value>> {
let mut set_clause = HashMap::new();
for assignment in set_part.split(',') {
let parts: Vec<&str> = assignment.split('=').collect();
if parts.len() == 2 {
let column = parts[0].trim().to_string();
let value = self.parse_value(parts[1].trim())?;
set_clause.insert(column, value);
}
}
Ok(set_clause)
}
fn parse_order_by(&self, order_part: &str) -> Result<Vec<OrderByClause>> {
let mut order_by = Vec::new();
for order_item in order_part.split(',') {
let parts: Vec<&str> = order_item.split_whitespace().collect();
if let Some(&col) = parts.first() {
let direction = if parts.get(1).is_some_and(|d| d.eq_ignore_ascii_case("DESC")) {
SortDirection::Desc
} else {
SortDirection::Asc
};
order_by.push(OrderByClause {
column: col.to_string(),
direction,
});
}
}
Ok(order_by)
}
}
/// Returns `true` for bytes that can be part of an unquoted identifier.
fn is_word_byte(byte: u8) -> bool {
    byte.is_ascii_alphanumeric() || byte == b'_'
}

/// Finds `needle` in `upper` at or after byte offset `from`, as a whole word.
///
/// When the needle starts (or ends) with an identifier character, the
/// character just before (or after) the match must not be one, so `SET` is
/// not found inside `DATASET` and `FROM` is not found inside `FROM_DATE`.
/// Needles that begin/end with punctuation such as `(` are matched as-is on
/// that side. Returns the absolute byte offset of the match.
fn find_keyword(upper: &str, needle: &str, from: usize) -> Option<usize> {
    let bytes = upper.as_bytes();
    let guard_front = needle.bytes().next().is_some_and(is_word_byte);
    let guard_back = needle.bytes().last().is_some_and(is_word_byte);
    let mut cursor = from;
    loop {
        let begin = cursor + upper.get(cursor..)?.find(needle)?;
        let end = begin + needle.len();
        let front_clear = !guard_front || begin == 0 || !is_word_byte(bytes[begin - 1]);
        let back_clear = !guard_back || end == bytes.len() || !is_word_byte(bytes[end]);
        if front_clear && back_clear {
            return Some(begin);
        }
        // Rejected: resume the search one character past this match start.
        cursor = begin + upper[begin..].chars().next().map_or(1, char::len_utf8);
    }
}

/// Returns the slice of `text` strictly between the first `start` keyword and
/// the first `end` keyword that follows it.
///
/// `upper` must be an upper-cased copy of `text` with identical byte offsets
/// (e.g. produced by `to_ascii_uppercase`); the search runs on `upper` and
/// the result is cut from `text`. Both patterns are matched
/// case-insensitively and as whole words.
///
/// Returns `None` when either keyword is missing, or when the offsets found
/// in `upper` do not describe a valid range of `text`.
fn extract_between<'a>(text: &'a str, upper: &str, start: &str, end: &str) -> Option<&'a str> {
    let start_kw = start.to_uppercase();
    // Advance by the length of what was actually matched, not of `start`.
    let body_start = find_keyword(upper, &start_kw, 0)? + start_kw.len();
    let body_end = find_keyword(upper, &end.to_uppercase(), body_start)?;
    // `get` instead of indexing: a misaligned `upper` must not panic.
    text.get(body_start..body_end)
}

/// Returns everything in `text` after the first whole-word occurrence of
/// `pattern` (case-insensitive).
///
/// `upper` has the same contract as in [`extract_between`]. Returns `None`
/// when the keyword is missing or the offset is not valid for `text`.
fn extract_after<'a>(text: &'a str, upper: &str, pattern: &str) -> Option<&'a str> {
    let keyword = pattern.to_uppercase();
    let tail_start = find_keyword(upper, &keyword, 0)? + keyword.len();
    text.get(tail_start..)
}

/// Extracts the clause introduced by `start`.
///
/// The terminators are tried in the order given; the clause ends at the
/// first one that is present after `start`. When none is present the clause
/// runs to the end of `text`. Returns `None` when `start` itself is missing.
fn extract_clause<'a>(
    text: &'a str,
    upper: &str,
    start: &str,
    terminators: &[&str],
) -> Option<&'a str> {
    terminators
        .iter()
        .find_map(|terminator| extract_between(text, upper, start, terminator))
        .or_else(|| extract_after(text, upper, start))
}
fn parse_keyword_target(
cql: &str,
query_type: QueryType,
expected_keyword: &str,
err_msg: &str,
) -> Result<ParsedQuery> {
let words: Vec<&str> = cql.split_whitespace().collect();
if words.len() >= 3 && words[1].eq_ignore_ascii_case(expected_keyword) {
Ok(empty_parsed(query_type, Some(TableId::new(words[2])), cql))
} else {
Err(Error::query_execution(err_msg.to_string()))
}
}
fn parse_single_target(cql: &str, query_type: QueryType, err_msg: &str) -> Result<ParsedQuery> {
match cql.split_whitespace().nth(1) {
Some(name) => Ok(empty_parsed(query_type, Some(TableId::new(name)), cql)),
None => Err(Error::query_execution(err_msg.to_string())),
}
}
/// Reports whether `s` has the canonical 8-4-4-4-12 UUID text shape.
///
/// The string must be exactly 36 bytes long, with `-` at byte offsets 8, 13,
/// 18 and 23 and an ASCII hex digit (either case) at every other offset.
fn is_uuid_literal(s: &str) -> bool {
    const HYPHEN_OFFSETS: [usize; 4] = [8, 13, 18, 23];
    s.len() == 36
        && s.bytes().enumerate().all(|(offset, byte)| {
            if HYPHEN_OFFSETS.contains(&offset) {
                byte == b'-'
            } else {
                byte.is_ascii_hexdigit()
            }
        })
}
/// Decodes UUID text into its 16 raw bytes.
///
/// Every `-` is discarded wherever it appears; what remains must be exactly
/// 32 hex digits, which are combined pairwise (high nibble first). Returns
/// `None` when the digit count is wrong or any character is not a hex digit.
fn parse_uuid_literal(s: &str) -> Option<[u8; 16]> {
    let digits: Vec<u8> = s.bytes().filter(|&byte| byte != b'-').collect();
    if digits.len() != 32 {
        return None;
    }
    let mut decoded = [0u8; 16];
    for (slot, pair) in decoded.iter_mut().zip(digits.chunks_exact(2)) {
        let high = char::from(pair[0]).to_digit(16)?;
        let low = char::from(pair[1]).to_digit(16)?;
        // Both nibbles are below 16, so the combined value always fits a byte.
        *slot = u8::try_from((high << 4) | low).ok()?;
    }
    Some(decoded)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a parser from the default configuration.
    fn parser() -> QueryParser {
        QueryParser::new(&Config::default())
    }

    /// Parses `cql`, panicking when the statement is rejected.
    fn parse_ok(cql: &str) -> ParsedQuery {
        parser().parse(cql).unwrap()
    }

    #[test]
    fn test_parse_select_basic() {
        let parsed = parse_ok("SELECT * FROM users");
        assert_eq!(parsed.query_type, QueryType::Select);
        assert_eq!(parsed.table, Some(TableId::new("users")));
        assert_eq!(parsed.columns, vec!["*"]);
    }

    #[test]
    fn test_parse_select_with_columns() {
        let parsed = parse_ok("SELECT id, name FROM users");
        assert_eq!(parsed.query_type, QueryType::Select);
        assert_eq!(parsed.columns, vec!["id", "name"]);
    }

    #[test]
    fn test_parse_select_with_where() {
        let parsed = parse_ok("SELECT * FROM users WHERE id = 1");
        assert_eq!(parsed.query_type, QueryType::Select);
        let filter = parsed
            .where_clause
            .expect("SELECT with WHERE must produce a where clause");
        assert_eq!(filter.conditions.len(), 1);
        let condition = &filter.conditions[0];
        assert_eq!(condition.column, "id");
        assert_eq!(condition.operator, ComparisonOperator::Equal);
    }

    #[test]
    fn test_parse_insert() {
        let parsed = parse_ok("INSERT INTO users (id, name) VALUES (1, 'Alice')");
        assert_eq!(parsed.query_type, QueryType::Insert);
        assert_eq!(parsed.table, Some(TableId::new("users")));
        assert_eq!(parsed.columns, vec!["id", "name"]);
        assert_eq!(parsed.values.len(), 2);
    }

    #[test]
    fn test_parse_update() {
        let parsed = parse_ok("UPDATE users SET name = 'Bob' WHERE id = 1");
        assert_eq!(parsed.query_type, QueryType::Update);
        assert_eq!(parsed.table, Some(TableId::new("users")));
        assert!(!parsed.set_clause.is_empty());
        assert!(parsed.where_clause.is_some());
    }

    #[test]
    fn test_parse_delete() {
        let parsed = parse_ok("DELETE FROM users WHERE id = 1");
        assert_eq!(parsed.query_type, QueryType::Delete);
        assert_eq!(parsed.table, Some(TableId::new("users")));
        assert!(parsed.where_clause.is_some());
    }

    #[test]
    fn test_parse_value_types() {
        let parser = parser();
        assert_eq!(parser.parse_value("123").unwrap(), Value::Integer(123));
        // 3.14 is just a sample float literal here, not an approximation of pi.
        #[allow(clippy::approx_constant)]
        {
            assert_eq!(parser.parse_value("3.14").unwrap(), Value::Float(3.14));
        }
        assert_eq!(
            parser.parse_value("'hello'").unwrap(),
            Value::Text("hello".to_string())
        );
        assert_eq!(parser.parse_value("true").unwrap(), Value::Boolean(true));
        assert_eq!(parser.parse_value("NULL").unwrap(), Value::Null);
    }

    #[test]
    fn test_parse_select_with_qualified_table_name() {
        let parsed = parse_ok("SELECT * FROM test_basic.simple_table LIMIT 5");
        assert_eq!(parsed.query_type, QueryType::Select);
        assert_eq!(parsed.table, Some(TableId::new("test_basic.simple_table")));
        assert_eq!(parsed.columns, vec!["*"]);
        assert_eq!(parsed.limit, Some(5));
    }

    #[test]
    fn test_parse_select_with_unqualified_table_name() {
        let parsed = parse_ok("SELECT * FROM simple_table LIMIT 5");
        assert_eq!(parsed.query_type, QueryType::Select);
        assert_eq!(parsed.table, Some(TableId::new("simple_table")));
        assert_eq!(parsed.columns, vec!["*"]);
        assert_eq!(parsed.limit, Some(5));
    }
}