use super::select_ast::*;
use crate::{Error, Result, TableId, Value};
/// Recursive-descent parser for CQL `SELECT` statements.
///
/// The parser keeps a single token of lookahead in `current_token` and pulls
/// further tokens on demand from the owned [`Tokenizer`].
#[derive(Debug)]
pub struct SelectParser {
    /// Lookahead token; populated by `SelectParser::new` and every `advance`.
    current_token: Option<Token>,
    /// Lexer supplying the remaining tokens.
    tokenizer: Tokenizer,
}
/// Lexical tokens produced by [`Tokenizer`].
///
/// Keywords are recognised case-insensitively. The two-word keywords
/// `GROUP BY` and `ORDER BY` are each collapsed into a single token.
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    // ---- Clause keywords -------------------------------------------------
    Select,
    Distinct,
    From,
    Where,
    /// `GROUP BY` (both words consumed together).
    GroupBy,
    Having,
    /// `ORDER BY` (both words consumed together).
    OrderBy,
    Limit,
    Offset,
    // ---- Logical / predicate keywords ------------------------------------
    And,
    Or,
    Not,
    Like,
    In,
    Between,
    As,
    Asc,
    Desc,
    Allow,
    Filtering,
    // ---- Aggregate function names ----------------------------------------
    Count,
    Sum,
    Avg,
    Min,
    Max,
    // ---- Null tests and collection predicates ----------------------------
    Is,
    Null,
    Contains,
    Key,
    // ---- Operators -------------------------------------------------------
    /// `=`
    Equal,
    /// `!=` or `<>`
    NotEqual,
    /// `<`
    LessThan,
    /// `<=`
    LessThanEqual,
    /// `>`
    GreaterThan,
    /// `>=`
    GreaterThanEqual,
    /// `+`
    Plus,
    /// `-` (also used as a sign; numeric literals themselves are unsigned)
    Minus,
    /// `*` (also the "all columns" selector)
    Multiply,
    /// `/`
    Divide,
    /// `%`
    Modulo,
    // ---- Literals and names ----------------------------------------------
    Integer(i64),
    Float(f64),
    String(String),
    /// `TRUE` / `FALSE`
    Boolean(bool),
    Identifier(String),
    // ---- Punctuation -----------------------------------------------------
    LeftParen,
    RightParen,
    LeftBracket,
    RightBracket,
    LeftBrace,
    RightBrace,
    Comma,
    Semicolon,
    Dot,
    Question,
    // ---- Stream markers --------------------------------------------------
    /// End of input; returned repeatedly once the input is exhausted.
    Eof,
    /// Not currently emitted by `Tokenizer::next_token`.
    Newline,
    /// Not currently emitted by `Tokenizer::next_token`.
    Whitespace,
}
/// Looks up the keyword token spelled by `ident`, ignoring ASCII case.
///
/// Returns `None` for anything that is not a reserved word, so the caller can
/// fall back to `Token::Identifier`. `TRUE`/`FALSE` map to boolean literal
/// tokens. `GROUP`/`ORDER` are absent because the tokenizer recognises those
/// two-word keywords itself.
fn keyword_for(ident: &str) -> Option<Token> {
    // All keywords are pure ASCII, so comparing the ASCII-uppercased input is
    // equivalent to a case-insensitive ASCII comparison.
    let token = match ident.to_ascii_uppercase().as_str() {
        "SELECT" => Token::Select,
        "DISTINCT" => Token::Distinct,
        "FROM" => Token::From,
        "WHERE" => Token::Where,
        "HAVING" => Token::Having,
        "LIMIT" => Token::Limit,
        "OFFSET" => Token::Offset,
        "AND" => Token::And,
        "OR" => Token::Or,
        "NOT" => Token::Not,
        "LIKE" => Token::Like,
        "IN" => Token::In,
        "BETWEEN" => Token::Between,
        "AS" => Token::As,
        "ASC" => Token::Asc,
        "DESC" => Token::Desc,
        "ALLOW" => Token::Allow,
        "FILTERING" => Token::Filtering,
        "COUNT" => Token::Count,
        "SUM" => Token::Sum,
        "AVG" => Token::Avg,
        "MIN" => Token::Min,
        "MAX" => Token::Max,
        "IS" => Token::Is,
        "NULL" => Token::Null,
        "CONTAINS" => Token::Contains,
        "KEY" => Token::Key,
        "TRUE" => Token::Boolean(true),
        "FALSE" => Token::Boolean(false),
        _ => return None,
    };
    Some(token)
}
/// Maps an aggregate keyword token (`COUNT`, `SUM`, `AVG`, `MIN`, `MAX`) to
/// its [`AggregateType`]; every other token yields `None`.
fn aggregate_for(token: &Token) -> Option<AggregateType> {
    let kind = match *token {
        Token::Count => AggregateType::Count,
        Token::Sum => AggregateType::Sum,
        Token::Avg => AggregateType::Avg,
        Token::Min => AggregateType::Min,
        Token::Max => AggregateType::Max,
        _ => return None,
    };
    Some(kind)
}
/// Character-level lexer that converts CQL text into [`Token`]s.
#[derive(Debug)]
pub struct Tokenizer {
    /// Whole input, split into `char`s so lookahead is plain indexing.
    input: Vec<char>,
    /// Index in `input` of `current_char`.
    position: usize,
    /// The character under the cursor, or `None` once input is exhausted.
    current_char: Option<char>,
}
impl Tokenizer {
    /// Creates a tokenizer positioned on the first character of `input`.
    pub fn new(input: &str) -> Self {
        let chars: Vec<char> = input.chars().collect();
        let current_char = chars.first().copied();
        Self {
            input: chars,
            position: 0,
            current_char,
        }
    }

    /// Steps to the next character; `current_char` becomes `None` past the end.
    fn advance(&mut self) {
        self.position += 1;
        self.current_char = self.input.get(self.position).copied();
    }

    /// Returns the character after `current_char` without consuming anything.
    fn peek(&self) -> Option<char> {
        self.input.get(self.position + 1).copied()
    }

    /// Consumes a run of whitespace characters (newlines included).
    fn skip_whitespace(&mut self) {
        while let Some(ch) = self.current_char {
            if ch.is_whitespace() {
                self.advance();
            } else {
                break;
            }
        }
    }

    /// Reads a quoted literal whose opening delimiter `quote_char` is the
    /// current character, returning its decoded contents.
    ///
    /// Escapes:
    /// - a doubled delimiter (`''` inside `'...'`, `""` inside `"..."`) is one
    ///   literal quote, as in standard CQL: `'It''s'` decodes to `It's`;
    /// - `\n`, `\t`, `\r`, `\\`, `\'` and `\"` are decoded;
    /// - any other backslash sequence is kept verbatim (backslash included).
    ///
    /// # Errors
    /// Returns a parse error if the input ends before the closing delimiter.
    fn read_string(&mut self, quote_char: char) -> Result<String> {
        let mut value = String::new();
        self.advance(); // step past the opening delimiter
        while let Some(ch) = self.current_char {
            if ch == quote_char {
                if self.peek() == Some(quote_char) {
                    // Doubled delimiter: an escaped quote, not the end of the
                    // literal. Without this, `'It''s'` would lex as two
                    // separate string tokens.
                    value.push(quote_char);
                    self.advance();
                    self.advance();
                    continue;
                }
                self.advance(); // step past the closing delimiter
                return Ok(value);
            } else if ch == '\\' {
                self.advance();
                if let Some(escaped) = self.current_char {
                    match escaped {
                        'n' => value.push('\n'),
                        't' => value.push('\t'),
                        'r' => value.push('\r'),
                        '\\' | '\'' | '"' => value.push(escaped),
                        _ => {
                            // Unknown escape: keep it exactly as written.
                            value.push('\\');
                            value.push(escaped);
                        }
                    }
                    self.advance();
                }
                // A backslash at end of input falls through to the
                // unterminated-literal error below.
            } else {
                value.push(ch);
                self.advance();
            }
        }
        Err(Error::cql_parse("Unterminated string literal"))
    }

    /// Reads an unsigned decimal literal: ASCII digits with at most one `.`.
    ///
    /// A literal containing `.` becomes [`Token::Float`], otherwise
    /// [`Token::Integer`]. A leading sign is lexed separately as
    /// [`Token::Minus`]/[`Token::Plus`]; exponents are not recognised.
    ///
    /// # Errors
    /// Returns a parse error if the digits do not fit the target type
    /// (e.g. an integer larger than `i64::MAX`).
    fn read_number(&mut self) -> Result<Token> {
        let mut value = String::new();
        let mut has_dot = false;
        while let Some(ch) = self.current_char {
            if ch.is_ascii_digit() {
                value.push(ch);
                self.advance();
            } else if ch == '.' && !has_dot {
                has_dot = true;
                value.push(ch);
                self.advance();
            } else {
                break;
            }
        }
        if has_dot {
            value
                .parse::<f64>()
                .map(Token::Float)
                .map_err(|_| Error::cql_parse(format!("Invalid float: {}", value)))
        } else {
            value
                .parse::<i64>()
                .map(Token::Integer)
                .map_err(|_| Error::cql_parse(format!("Invalid integer: {}", value)))
        }
    }

    /// Reads a run of alphanumeric characters and underscores.
    fn read_identifier(&mut self) -> String {
        let mut value = String::new();
        while let Some(ch) = self.current_char {
            if ch.is_alphanumeric() || ch == '_' {
                value.push(ch);
                self.advance();
            } else {
                break;
            }
        }
        value
    }

    /// Consumes the `BY` that must follow `GROUP`/`ORDER`; `after` names the
    /// preceding word for the error message.
    fn expect_by_keyword(&mut self, after: &str) -> Result<()> {
        self.skip_whitespace();
        let next = self.read_identifier();
        if next.eq_ignore_ascii_case("BY") {
            Ok(())
        } else {
            Err(Error::cql_parse(format!("Expected BY after {}", after)))
        }
    }

    /// Returns the next token, skipping any leading whitespace.
    ///
    /// Once the input is exhausted this returns [`Token::Eof`] on every call.
    ///
    /// NOTE(review): double-quoted text is lexed as a string literal here,
    /// whereas CQL treats `"..."` as a case-sensitive quoted identifier —
    /// confirm this is intended before relying on it.
    ///
    /// # Errors
    /// Returns a parse error on an unexpected character, an unterminated
    /// string, a malformed number, or `GROUP`/`ORDER` not followed by `BY`.
    pub fn next_token(&mut self) -> Result<Token> {
        loop {
            let ch = match self.current_char {
                None => return Ok(Token::Eof),
                Some(c) => c,
            };
            let single = match ch {
                '(' => Some(Token::LeftParen),
                ')' => Some(Token::RightParen),
                '[' => Some(Token::LeftBracket),
                ']' => Some(Token::RightBracket),
                '{' => Some(Token::LeftBrace),
                '}' => Some(Token::RightBrace),
                ',' => Some(Token::Comma),
                ';' => Some(Token::Semicolon),
                '.' => Some(Token::Dot),
                '?' => Some(Token::Question),
                '+' => Some(Token::Plus),
                '-' => Some(Token::Minus),
                '*' => Some(Token::Multiply),
                '/' => Some(Token::Divide),
                '%' => Some(Token::Modulo),
                '=' => Some(Token::Equal),
                _ => None,
            };
            if let Some(tok) = single {
                self.advance();
                return Ok(tok);
            }
            match ch {
                // Whitespace is skipped and the loop tries again.
                c if c.is_whitespace() => self.skip_whitespace(),
                '!' => {
                    if self.peek() == Some('=') {
                        self.advance();
                        self.advance();
                        return Ok(Token::NotEqual);
                    }
                    return Err(Error::cql_parse("Unexpected character: !"));
                }
                '<' => {
                    return Ok(match self.peek() {
                        Some('=') => {
                            self.advance();
                            self.advance();
                            Token::LessThanEqual
                        }
                        Some('>') => {
                            self.advance();
                            self.advance();
                            Token::NotEqual
                        }
                        _ => {
                            self.advance();
                            Token::LessThan
                        }
                    });
                }
                '>' => {
                    return Ok(if self.peek() == Some('=') {
                        self.advance();
                        self.advance();
                        Token::GreaterThanEqual
                    } else {
                        self.advance();
                        Token::GreaterThan
                    });
                }
                '\'' | '"' => return self.read_string(ch).map(Token::String),
                c if c.is_ascii_digit() => return self.read_number(),
                c if c.is_alphabetic() || c == '_' => {
                    let identifier = self.read_identifier();
                    // GROUP and ORDER are only meaningful as `GROUP BY` /
                    // `ORDER BY`, so they are folded into one token here.
                    if identifier.eq_ignore_ascii_case("GROUP") {
                        self.expect_by_keyword("GROUP")?;
                        return Ok(Token::GroupBy);
                    }
                    if identifier.eq_ignore_ascii_case("ORDER") {
                        self.expect_by_keyword("ORDER")?;
                        return Ok(Token::OrderBy);
                    }
                    return Ok(keyword_for(&identifier).unwrap_or(Token::Identifier(identifier)));
                }
                other => return Err(Error::cql_parse(format!("Unexpected character: {}", other))),
            }
        }
    }
}
impl SelectParser {
pub fn new(cql: &str) -> Result<Self> {
let mut tokenizer = Tokenizer::new(cql);
let current_token = Some(tokenizer.next_token()?);
Ok(Self {
current_token,
tokenizer,
})
}
fn advance(&mut self) -> Result<()> {
self.current_token = Some(self.tokenizer.next_token()?);
Ok(())
}
fn peek(&self) -> &Token {
self.current_token.as_ref().unwrap_or(&Token::Eof)
}
fn at(&self, tok: &Token) -> bool {
self.current_token
.as_ref()
.is_some_and(|cur| std::mem::discriminant(cur) == std::mem::discriminant(tok))
}
fn eat(&mut self, tok: &Token) -> Result<bool> {
if self.at(tok) {
self.advance()?;
Ok(true)
} else {
Ok(false)
}
}
fn expect(&mut self, expected: Token) -> Result<()> {
if let Some(ref current) = self.current_token {
if std::mem::discriminant(current) == std::mem::discriminant(&expected) {
self.advance()?;
Ok(())
} else {
Err(Error::cql_parse(format!(
"Expected {:?}, found {:?}",
expected, current
)))
}
} else {
Err(Error::cql_parse("Unexpected end of input"))
}
}
fn expect_integer(&mut self, context: &str) -> Result<i64> {
if let Some(Token::Integer(n)) = self.current_token {
self.advance()?;
Ok(n)
} else {
Err(Error::cql_parse(format!(
"Expected integer after {}",
context
)))
}
}
fn parse_column_ref(&mut self, table_or_column: String) -> Result<ColumnRef> {
if !self.eat(&Token::Dot)? {
return Ok(ColumnRef::new(table_or_column));
}
if let Some(Token::Identifier(column)) = self.current_token.clone() {
self.advance()?;
Ok(ColumnRef::qualified(table_or_column, column))
} else {
Err(Error::cql_parse(
"Expected column name after table qualifier",
))
}
}
pub fn parse_select_statement(&mut self) -> Result<SelectStatement> {
self.expect(Token::Select)?;
let select_clause = self.parse_select_clause()?;
let from_clause = if self.eat(&Token::From)? {
Some(self.parse_from_clause()?)
} else {
None
};
let where_clause = if self.eat(&Token::Where)? {
Some(self.parse_where_expression()?)
} else {
None
};
let group_by = if self.eat(&Token::GroupBy)? {
Some(self.parse_group_by_clause()?)
} else {
None
};
let having_clause = if self.eat(&Token::Having)? {
Some(self.parse_where_expression()?)
} else {
None
};
let order_by = if self.eat(&Token::OrderBy)? {
Some(self.parse_order_by_clause()?)
} else {
None
};
let limit = if self.eat(&Token::Limit)? {
Some(self.parse_limit_clause()?)
} else {
None
};
let offset = if self.eat(&Token::Offset)? {
Some(self.expect_integer("OFFSET")? as u64)
} else {
None
};
let allow_filtering = if self.eat(&Token::Allow)? {
self.expect(Token::Filtering)?;
true
} else {
false
};
Ok(SelectStatement {
select_clause,
from_clause,
where_clause,
group_by,
having_clause,
order_by,
limit,
offset,
allow_filtering,
})
}
fn parse_select_clause(&mut self) -> Result<SelectClause> {
let distinct = self.eat(&Token::Distinct)?;
if self.eat(&Token::Multiply)? {
return Ok(SelectClause::All);
}
let mut expressions = Vec::new();
loop {
expressions.push(self.parse_select_expression()?);
if !self.eat(&Token::Comma)? {
break;
}
}
if distinct {
Ok(SelectClause::Distinct(expressions))
} else {
Ok(SelectClause::Columns(expressions))
}
}
fn parse_select_expression(&mut self) -> Result<SelectExpression> {
let expr = self.parse_primary_expression()?;
if self.eat(&Token::As)? {
if let Some(Token::Identifier(alias)) = self.current_token.clone() {
self.advance()?;
return Ok(SelectExpression::Aliased(Box::new(expr), alias));
}
return Err(Error::cql_parse("Expected alias name after AS"));
}
Ok(expr)
}
fn parse_primary_expression(&mut self) -> Result<SelectExpression> {
if let Some(agg) = aggregate_for(self.peek()) {
self.advance()?;
return self.parse_aggregate_function(agg);
}
match self.current_token.clone() {
Some(Token::Identifier(name)) => {
self.advance()?;
if self.eat(&Token::LeftParen)? {
let mut args = Vec::new();
if !self.at(&Token::RightParen) {
loop {
args.push(self.parse_select_expression()?);
if !self.eat(&Token::Comma)? {
break;
}
}
}
self.expect(Token::RightParen)?;
return Ok(SelectExpression::Function(FunctionCall { name, args }));
}
let col = self.parse_column_ref(name)?;
Ok(SelectExpression::Column(col))
}
Some(Token::Integer(n)) => {
self.advance()?;
Ok(SelectExpression::Literal(Value::BigInt(n)))
}
Some(Token::Float(f)) => {
self.advance()?;
Ok(SelectExpression::Literal(Value::Float(f)))
}
Some(Token::String(s)) => {
self.advance()?;
Ok(SelectExpression::Literal(Value::Text(s)))
}
Some(Token::Boolean(b)) => {
self.advance()?;
Ok(SelectExpression::Literal(Value::Boolean(b)))
}
Some(Token::Null) => {
self.advance()?;
Ok(SelectExpression::Literal(Value::Null))
}
Some(Token::LeftParen) => {
self.advance()?;
let expr = self.parse_select_expression()?;
self.expect(Token::RightParen)?;
Ok(expr)
}
other => Err(Error::cql_parse(format!(
"Unexpected token in expression: {:?}",
other
))),
}
}
fn parse_aggregate_function(&mut self, agg_type: AggregateType) -> Result<SelectExpression> {
self.expect(Token::LeftParen)?;
let distinct = self.eat(&Token::Distinct)?;
let mut args = Vec::new();
if !self.at(&Token::RightParen) {
if self.eat(&Token::Multiply)? {
args.push(SelectExpression::Column(ColumnRef::new("*".to_string())));
} else {
loop {
args.push(self.parse_select_expression()?);
if !self.eat(&Token::Comma)? {
break;
}
}
}
}
self.expect(Token::RightParen)?;
Ok(SelectExpression::Aggregate(AggregateFunction {
function: agg_type,
args,
distinct,
}))
}
fn parse_from_clause(&mut self) -> Result<FromClause> {
let Some(Token::Identifier(first_identifier)) = self.current_token.clone() else {
return Err(Error::cql_parse("Expected table name in FROM clause"));
};
self.advance()?;
let table_name = if self.eat(&Token::Dot)? {
if let Some(Token::Identifier(actual_table)) = self.current_token.clone() {
self.advance()?;
format!("{}.{}", first_identifier, actual_table)
} else {
return Err(Error::cql_parse("Expected table name after keyspace"));
}
} else {
first_identifier
};
let table = TableId::new(table_name);
const CLAUSE_KEYWORDS: &[&str] = &["WHERE", "GROUP", "ORDER", "HAVING", "LIMIT"];
if let Some(Token::Identifier(alias)) = self.current_token.clone() {
let is_clause_kw = CLAUSE_KEYWORDS
.iter()
.any(|kw| alias.eq_ignore_ascii_case(kw));
if !is_clause_kw {
self.advance()?;
return Ok(FromClause::TableAlias(table, alias));
}
}
Ok(FromClause::Table(table))
}
fn parse_where_expression(&mut self) -> Result<WhereExpression> {
self.parse_or_expression()
}
fn parse_or_expression(&mut self) -> Result<WhereExpression> {
let first = self.parse_and_expression()?;
let mut or_exprs = vec![first];
while self.eat(&Token::Or)? {
or_exprs.push(self.parse_and_expression()?);
}
Ok(unwrap_singleton(or_exprs, WhereExpression::Or))
}
fn parse_and_expression(&mut self) -> Result<WhereExpression> {
let first = self.parse_not_expression()?;
let mut and_exprs = vec![first];
while self.eat(&Token::And)? {
and_exprs.push(self.parse_not_expression()?);
}
Ok(unwrap_singleton(and_exprs, WhereExpression::And))
}
fn parse_not_expression(&mut self) -> Result<WhereExpression> {
if self.eat(&Token::Not)? {
let expr = self.parse_comparison_expression()?;
Ok(WhereExpression::Not(Box::new(expr)))
} else {
self.parse_comparison_expression()
}
}
fn parse_comparison_expression(&mut self) -> Result<WhereExpression> {
if self.eat(&Token::LeftParen)? {
let expr = self.parse_where_expression()?;
self.expect(Token::RightParen)?;
return Ok(WhereExpression::Parentheses(Box::new(expr)));
}
let left = self.parse_select_expression()?;
let simple_op = match self.peek() {
Token::Equal => Some(ComparisonOperator::Equal),
Token::NotEqual => Some(ComparisonOperator::NotEqual),
Token::LessThan => Some(ComparisonOperator::LessThan),
Token::LessThanEqual => Some(ComparisonOperator::LessThanOrEqual),
Token::GreaterThan => Some(ComparisonOperator::GreaterThan),
Token::GreaterThanEqual => Some(ComparisonOperator::GreaterThanOrEqual),
Token::Like => Some(ComparisonOperator::Like),
_ => None,
};
if let Some(op) = simple_op {
self.advance()?;
let right = ComparisonRightSide::Value(self.parse_select_expression()?);
return Ok(WhereExpression::Comparison(ComparisonExpression {
left,
operator: op,
right,
}));
}
let operator = match self.peek() {
Token::In => {
self.advance()?;
let right = self.parse_in_expression()?;
return Ok(WhereExpression::Comparison(ComparisonExpression {
left,
operator: ComparisonOperator::In,
right,
}));
}
Token::Between => {
self.advance()?;
let start = self.parse_select_expression()?;
self.expect(Token::And)?;
let end = self.parse_select_expression()?;
return Ok(WhereExpression::Comparison(ComparisonExpression {
left,
operator: ComparisonOperator::Between,
right: ComparisonRightSide::Range(start, end),
}));
}
Token::Is => {
self.advance()?;
let op = if self.eat(&Token::Not)? {
ComparisonOperator::IsNotNull
} else {
ComparisonOperator::IsNull
};
self.expect(Token::Null)?;
op
}
Token::Contains => {
self.advance()?;
if self.eat(&Token::Key)? {
ComparisonOperator::ContainsKey
} else {
ComparisonOperator::Contains
}
}
other => {
return Err(Error::cql_parse(format!(
"Expected comparison operator, found {:?}",
other
)));
}
};
let right = match operator {
ComparisonOperator::IsNull | ComparisonOperator::IsNotNull => {
ComparisonRightSide::Value(SelectExpression::Literal(Value::Null))
}
_ => ComparisonRightSide::Value(self.parse_select_expression()?),
};
Ok(WhereExpression::Comparison(ComparisonExpression {
left,
operator,
right,
}))
}
fn parse_in_expression(&mut self) -> Result<ComparisonRightSide> {
self.expect(Token::LeftParen)?;
let mut values = Vec::new();
if !self.at(&Token::RightParen) {
loop {
values.push(self.parse_select_expression()?);
if !self.eat(&Token::Comma)? {
break;
}
}
}
self.expect(Token::RightParen)?;
Ok(ComparisonRightSide::ValueList(values))
}
fn parse_group_by_clause(&mut self) -> Result<GroupByClause> {
let mut columns = Vec::new();
loop {
let Some(Token::Identifier(name)) = self.current_token.clone() else {
return Err(Error::cql_parse("Expected column name in GROUP BY"));
};
self.advance()?;
columns.push(self.parse_column_ref(name)?);
if !self.eat(&Token::Comma)? {
break;
}
}
Ok(GroupByClause { columns })
}
fn parse_order_by_clause(&mut self) -> Result<OrderByClause> {
let mut items = Vec::new();
loop {
let expression = self.parse_select_expression()?;
let direction = if self.eat(&Token::Desc)? {
SortDirection::Descending
} else if self.eat(&Token::Asc)? {
SortDirection::Ascending
} else {
SortDirection::Ascending
};
items.push(OrderByItem {
expression,
direction,
});
if !self.eat(&Token::Comma)? {
break;
}
}
Ok(OrderByClause { items })
}
fn parse_limit_clause(&mut self) -> Result<LimitClause> {
let count = self.expect_integer("LIMIT")? as u64;
Ok(LimitClause {
count,
per_partition: false, })
}
}
/// Collapses a one-element operand list to that element; otherwise builds
/// the combined node with `wrap` (e.g. `WhereExpression::And`).
///
/// This keeps `a` from being represented as `And([a])` or `Or([a])`.
fn unwrap_singleton<F>(exprs: Vec<WhereExpression>, wrap: F) -> WhereExpression
where
    F: FnOnce(Vec<WhereExpression>) -> WhereExpression,
{
    match exprs.len() {
        1 => exprs.into_iter().next().expect("checked len == 1"),
        _ => wrap(exprs),
    }
}
pub fn parse_select(cql: &str) -> Result<SelectStatement> {
let mut parser = SelectParser::new(cql)?;
parser.parse_select_statement()
}
#[cfg(all(test, feature = "state_machine"))]
mod tests {
    use super::*;

    /// `SELECT *` over a bare table name.
    #[test]
    fn test_simple_select() {
        let stmt = parse_select("SELECT * FROM users").unwrap();
        assert_eq!(stmt.select_clause, SelectClause::All);
        match stmt.from_clause {
            Some(FromClause::Table(table)) => assert_eq!(table.name(), "users"),
            _ => panic!("Expected Table in FROM clause"),
        }
    }

    /// An explicit column list keeps every column.
    #[test]
    fn test_select_with_columns() {
        let stmt = parse_select("SELECT id, name, email FROM users").unwrap();
        let SelectClause::Columns(exprs) = stmt.select_clause else {
            panic!("Expected Columns in SELECT clause");
        };
        assert_eq!(exprs.len(), 3);
    }

    /// A FROM-less constant select yields one integer literal.
    #[test]
    fn test_select_constant() {
        let stmt = parse_select("SELECT 1").unwrap();
        assert!(stmt.from_clause.is_none());
        let SelectClause::Columns(exprs) = stmt.select_clause else {
            panic!("Expected Columns in SELECT clause");
        };
        assert_eq!(exprs.len(), 1);
        assert!(
            matches!(&exprs[0], SelectExpression::Literal(Value::BigInt(1))),
            "Expected literal BigInt 1, got: {:?}",
            &exprs[0]
        );
    }

    /// A simple equality predicate populates the WHERE clause.
    #[test]
    fn test_select_with_where() {
        let stmt = parse_select("SELECT * FROM users WHERE id = 123").unwrap();
        assert!(stmt.where_clause.is_some());
    }

    /// Aggregates plus GROUP BY are detected as an aggregation query.
    #[test]
    fn test_select_with_aggregates() {
        let stmt = parse_select("SELECT COUNT(*), AVG(age) FROM users GROUP BY city").unwrap();
        assert!(stmt.requires_aggregation());
        assert!(stmt.group_by.is_some());
    }

    /// AND combined with a parenthesised OR group.
    #[test]
    fn test_complex_where_clause() {
        let sql = "SELECT * FROM users WHERE age > 21 AND (city = 'NYC' OR city = 'LA')";
        let stmt = parse_select(sql).unwrap();
        assert!(stmt.where_clause.is_some());
    }

    /// Multi-key ORDER BY followed by LIMIT.
    #[test]
    fn test_order_by_and_limit() {
        let sql = "SELECT * FROM users ORDER BY created_at DESC, name ASC LIMIT 10";
        let stmt = parse_select(sql).unwrap();
        assert!(stmt.order_by.is_some());
        assert!(stmt.limit.is_some());
        assert_eq!(stmt.limit.map(|limit| limit.count), Some(10));
    }

    /// IN with a list of string literals.
    #[test]
    fn test_in_clause() {
        let sql = "SELECT * FROM users WHERE status IN ('active', 'pending', 'suspended')";
        let stmt = parse_select(sql).unwrap();
        assert!(stmt.where_clause.is_some());
    }

    /// BETWEEN consumes its own AND.
    #[test]
    fn test_between_clause() {
        let sql = "SELECT * FROM events WHERE created_at BETWEEN '2024-01-01' AND '2024-12-31'";
        let stmt = parse_select(sql).unwrap();
        assert!(stmt.where_clause.is_some());
    }
}