use crate::ast::{ComparisonOperator, Expr, Literal, Predicate};
use crate::error::SqlGenerationError;
use base64::{engine::general_purpose, Engine as _};
fn generate_expr_sql(
expr: &Expr,
placeholder_count: &mut Option<usize>,
found_placeholders: &mut usize,
buffer: &mut String,
) -> Result<(), SqlGenerationError> {
match expr {
Expr::Placeholder => {
*found_placeholders += 1;
if let Some(expected) = placeholder_count {
if *found_placeholders > *expected {
return Err(SqlGenerationError::PlaceholderCountMismatch { expected: *expected, found: *found_placeholders });
}
}
buffer.push('?');
}
Expr::Literal(lit) => match lit {
Literal::I16(i) => {
buffer.push_str(&i.to_string());
}
Literal::I32(i) => {
buffer.push_str(&i.to_string());
}
Literal::I64(i) => {
buffer.push_str(&i.to_string());
}
Literal::F64(f) => {
buffer.push_str(&f.to_string());
}
Literal::Bool(b) => {
buffer.push_str(if *b { "true" } else { "false" });
}
Literal::String(s) => {
buffer.push('\'');
for c in s.chars() {
match c {
'\'' => buffer.push_str("''"), '\0' => {
continue;
}
_ => buffer.push(c),
}
}
buffer.push('\'');
}
Literal::EntityId(ulid) => {
buffer.push('\'');
buffer.push_str(&general_purpose::URL_SAFE_NO_PAD.encode(ulid.to_bytes()));
buffer.push('\'');
}
Literal::Object(bytes) | Literal::Binary(bytes) => {
buffer.push('\'');
buffer.push_str(&String::from_utf8_lossy(bytes));
buffer.push('\'');
}
Literal::Json(value) => {
buffer.push('\'');
buffer.push_str(&value.to_string());
buffer.push('\'');
}
},
Expr::Path(path) => {
for (i, step) in path.steps.iter().enumerate() {
if i > 0 {
buffer.push('.');
}
buffer.push('"');
buffer.push_str(step);
buffer.push('"');
}
}
Expr::ExprList(exprs) => {
buffer.push('(');
for (i, expr) in exprs.iter().enumerate() {
if i > 0 {
buffer.push_str(", ");
}
match expr {
Expr::Placeholder => {
*found_placeholders += 1;
if let Some(expected) = placeholder_count {
if *found_placeholders > *expected {
return Err(SqlGenerationError::PlaceholderCountMismatch {
expected: *expected,
found: *found_placeholders,
});
}
}
buffer.push('?');
}
Expr::Literal(lit) => match lit {
Literal::I16(i) => {
buffer.push_str(&i.to_string());
}
Literal::I32(i) => {
buffer.push_str(&i.to_string());
}
Literal::I64(i) => {
buffer.push_str(&i.to_string());
}
Literal::F64(f) => {
buffer.push_str(&f.to_string());
}
Literal::String(s) => {
buffer.push('\'');
for c in s.chars() {
match c {
'\'' => buffer.push_str("''"), '\0' => {
continue;
}
_ => buffer.push(c),
}
}
buffer.push('\'');
}
Literal::Bool(b) => {
buffer.push_str(if *b { "true" } else { "false" });
}
Literal::EntityId(ulid) => {
buffer.push('\'');
buffer.push_str(&general_purpose::URL_SAFE_NO_PAD.encode(ulid.to_bytes()));
buffer.push('\'');
}
Literal::Object(_bytes) | Literal::Binary(_bytes) => {
todo!("Object and Binary literals");
}
Literal::Json(value) => {
buffer.push('\'');
buffer.push_str(&value.to_string());
buffer.push('\'');
}
},
_ => {
return Err(SqlGenerationError::InvalidExpression(
"Only literal expressions and placeholders are supported in IN lists".to_string(),
))
}
}
}
buffer.push(')');
}
_ => return Err(SqlGenerationError::InvalidExpression("Only literal, identifier, and list expressions are supported".to_string())),
}
Ok(())
}
/// Maps a comparison operator to its SQL spelling.
///
/// # Errors
///
/// Returns [`SqlGenerationError::UnsupportedOperator`] for `Between`, which has no SQL
/// rendering here yet.
fn comparison_op_to_sql(op: &ComparisonOperator) -> Result<&'static str, SqlGenerationError> {
    match op {
        ComparisonOperator::Equal => Ok("="),
        ComparisonOperator::NotEqual => Ok("<>"),
        ComparisonOperator::GreaterThan => Ok(">"),
        ComparisonOperator::GreaterThanOrEqual => Ok(">="),
        ComparisonOperator::LessThan => Ok("<"),
        ComparisonOperator::LessThanOrEqual => Ok("<="),
        ComparisonOperator::In => Ok("IN"),
        ComparisonOperator::Between => Err(SqlGenerationError::UnsupportedOperator("BETWEEN operator is not yet supported")),
    }
}
pub fn generate_selection_sql(predicate: &Predicate, expected_placeholders: Option<usize>) -> Result<String, SqlGenerationError> {
let mut placeholder_count = expected_placeholders;
let mut found_placeholders = 0;
let mut buffer = String::new();
generate_selection_sql_inner(predicate, &mut placeholder_count, &mut found_placeholders, &mut buffer)?;
if let Some(expected) = expected_placeholders {
if found_placeholders != expected {
return Err(SqlGenerationError::PlaceholderCountMismatch { expected, found: found_placeholders });
}
}
Ok(buffer)
}
/// Recursively appends the SQL for `predicate` to `buffer`.
///
/// `OR` groups and `NOT` operands are always parenthesized; `AND` chains are emitted flat.
/// Placeholder bookkeeping is threaded through to the expression renderer.
///
/// # Errors
///
/// Forwards errors from expression and operator rendering, and returns
/// [`SqlGenerationError::InvalidExpression`] for a `Predicate::Placeholder`.
fn generate_selection_sql_inner(
    predicate: &Predicate,
    placeholder_count: &mut Option<usize>,
    found_placeholders: &mut usize,
    buffer: &mut String,
) -> Result<(), SqlGenerationError> {
    match predicate {
        Predicate::True => buffer.push_str("TRUE"),
        Predicate::False => buffer.push_str("FALSE"),
        Predicate::Placeholder => {
            let message = "Placeholder must be transformed before SQL generation".to_string();
            return Err(SqlGenerationError::InvalidExpression(message));
        }
        Predicate::IsNull(operand) => {
            generate_expr_sql(operand, placeholder_count, found_placeholders, buffer)?;
            buffer.push_str(" IS NULL");
        }
        Predicate::Not(negated) => {
            buffer.push_str("NOT (");
            generate_selection_sql_inner(negated, placeholder_count, found_placeholders, buffer)?;
            buffer.push(')');
        }
        Predicate::Comparison { left: lhs, operator: op, right: rhs } => {
            generate_expr_sql(lhs, placeholder_count, found_placeholders, buffer)?;
            buffer.push(' ');
            buffer.push_str(comparison_op_to_sql(op)?);
            buffer.push(' ');
            generate_expr_sql(rhs, placeholder_count, found_placeholders, buffer)?;
        }
        Predicate::And(lhs, rhs) => {
            generate_selection_sql_inner(lhs, placeholder_count, found_placeholders, buffer)?;
            buffer.push_str(" AND ");
            generate_selection_sql_inner(rhs, placeholder_count, found_placeholders, buffer)?;
        }
        Predicate::Or(lhs, rhs) => {
            // Parenthesized so the group keeps its meaning when nested inside an AND chain.
            buffer.push('(');
            generate_selection_sql_inner(lhs, placeholder_count, found_placeholders, buffer)?;
            buffer.push_str(" OR ");
            generate_selection_sql_inner(rhs, placeholder_count, found_placeholders, buffer)?;
            buffer.push(')');
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{ComparisonOperator, Expr, Literal, PathExpr, Predicate};
    use crate::error::SqlGenerationError;
    use crate::parser::parse_selection;
    use anyhow::Result;

    /// Parses `query` and renders its predicate, propagating parse and generation errors.
    fn render(query: &'static str, expected_placeholders: Option<usize>) -> Result<String> {
        let selection = parse_selection(query)?;
        Ok(generate_selection_sql(&selection.predicate, expected_placeholders)?)
    }

    /// Asserts that rendering `query` with `expected_count` placeholders fails with a
    /// `PlaceholderCountMismatch` reporting `expected_count` and `found_count`.
    fn assert_placeholder_mismatch(query: &'static str, expected_count: usize, found_count: usize) -> Result<()> {
        let selection = parse_selection(query)?;
        match generate_selection_sql(&selection.predicate, Some(expected_count)) {
            Err(SqlGenerationError::PlaceholderCountMismatch { expected, found }) => {
                assert_eq!(expected, expected_count);
                assert_eq!(found, found_count);
            }
            _ => panic!("Expected PlaceholderCountMismatch error"),
        }
        Ok(())
    }

    /// Builds `<column> = '<value>'` directly from AST nodes, bypassing the parser.
    fn string_equality(column: &str, value: &str) -> Predicate {
        Predicate::Comparison {
            left: Box::new(Expr::Path(PathExpr::simple(column))),
            operator: ComparisonOperator::Equal,
            right: Box::new(Expr::Literal(Literal::String(value.to_string()))),
        }
    }

    #[test]
    fn test_simple_equality() -> Result<()> {
        assert_eq!(render("name = 'Alice'", None)?, r#""name" = 'Alice'"#);
        Ok(())
    }

    #[test]
    fn test_and_condition() -> Result<()> {
        assert_eq!(render("name = 'Alice' AND age = '30'", None)?, r#""name" = 'Alice' AND "age" = '30'"#);
        Ok(())
    }

    #[test]
    fn test_complex_condition() -> Result<()> {
        let sql = render("(name = 'Alice' OR name = 'Charlie') AND age >= '30' AND age <= '40'", None)?;
        assert_eq!(sql, r#"("name" = 'Alice' OR "name" = 'Charlie') AND "age" >= '30' AND "age" <= '40'"#);
        Ok(())
    }

    #[test]
    fn test_including_collection_identifier() -> Result<()> {
        assert_eq!(render("person.name = 'Alice'", None)?, r#""person"."name" = 'Alice'"#);
        Ok(())
    }

    #[test]
    fn test_in_operator() -> Result<()> {
        assert_eq!(render("name IN ('Alice', 'Bob', 'Charlie')", None)?, r#""name" IN ('Alice', 'Bob', 'Charlie')"#);
        Ok(())
    }

    #[test]
    fn test_placeholder_with_none_count() -> Result<()> {
        assert_eq!(render("user_id = ?", None)?, r#""user_id" = ?"#);
        Ok(())
    }

    #[test]
    fn test_placeholder_with_exact_count() -> Result<()> {
        assert_eq!(render("user_id = ? AND status = ?", Some(2))?, r#""user_id" = ? AND "status" = ?"#);
        Ok(())
    }

    // "Too few" refers to the expected count: the query has two placeholders but only one is expected.
    #[test]
    fn test_placeholder_count_mismatch_too_few() -> Result<()> {
        assert_placeholder_mismatch("user_id = ? AND status = ?", 1, 2)
    }

    // "Too many" refers to the expected count: the query has one placeholder but two are expected.
    #[test]
    fn test_placeholder_count_mismatch_too_many() -> Result<()> {
        assert_placeholder_mismatch("user_id = ?", 2, 1)
    }

    #[test]
    fn test_placeholder_in_lists() -> Result<()> {
        assert_eq!(render("status IN (?, ?, ?)", Some(3))?, r#""status" IN (?, ?, ?)"#);
        Ok(())
    }

    #[test]
    fn test_placeholder_with_zero_count() -> Result<()> {
        assert_eq!(render("user_id = 123", Some(0))?, r#""user_id" = 123"#);
        Ok(())
    }

    #[test]
    fn test_string_escaping() -> Result<()> {
        let sql = generate_selection_sql(&string_equality("name", "O'Brien"), None)?;
        assert_eq!(sql, r#""name" = 'O''Brien'"#);
        Ok(())
    }

    #[test]
    fn test_null_byte_handling() -> Result<()> {
        let sql = generate_selection_sql(&string_equality("data", "test\0data"), None)?;
        assert_eq!(sql, r#""data" = 'testdata'"#);
        Ok(())
    }

    #[test]
    fn test_placeholder_with_zero_count_but_has_placeholder() -> Result<()> {
        assert_placeholder_mismatch("user_id = ?", 0, 1)
    }
}