use async_trait::async_trait;
use crate::utils::{InternalQuery, QueryOperation, DataSource, Column, Predicate, PredicateOperator, PredicateValue, OrderBy, OrderColumn, OrderDirection};
use crate::utils::error::{QueryParsingError, NirvResult};
use sqlparser::ast::{Statement, Query, SelectItem, Expr, BinaryOperator, Value as SqlValue, OrderByExpr, FunctionArg, FunctionArgExpr};
use sqlparser::dialect::{PostgreSqlDialect, MySqlDialect, SQLiteDialect, GenericDialect};
use sqlparser::parser::Parser;
use regex::Regex;
#[async_trait]
pub trait QueryParser: Send + Sync {
async fn parse_sql(&self, sql: &str) -> NirvResult<InternalQuery>;
async fn validate_syntax(&self, sql: &str) -> NirvResult<bool>;
async fn extract_sources(&self, sql: &str) -> NirvResult<Vec<String>>;
}
pub struct DefaultQueryParser {
postgres_dialect: PostgreSqlDialect,
mysql_dialect: MySqlDialect,
sqlite_dialect: SQLiteDialect,
generic_dialect: GenericDialect,
source_regex: Regex,
}
impl DefaultQueryParser {
pub fn new() -> NirvResult<Self> {
let source_regex = Regex::new(r#"source\s*\(\s*['"]([^'"]+)['"]\s*\)"#)
.map_err(|e| QueryParsingError::InvalidSyntax(format!("Failed to compile source regex: {}", e)))?;
Ok(Self {
postgres_dialect: PostgreSqlDialect {},
mysql_dialect: MySqlDialect {},
sqlite_dialect: SQLiteDialect {},
generic_dialect: GenericDialect {},
source_regex,
})
}
pub fn parse(&self, sql: &str) -> NirvResult<InternalQuery> {
let statement = self.try_parse_with_dialects(sql)?;
match statement {
Statement::Query(query) => self.convert_query(*query),
_ => Err(QueryParsingError::UnsupportedFeature("Only SELECT queries are currently supported".to_string()).into()),
}
}
fn try_parse_with_dialects(&self, sql: &str) -> NirvResult<Statement> {
if let Ok(statements) = Parser::parse_sql(&self.postgres_dialect, sql) {
if let Some(statement) = statements.into_iter().next() {
return Ok(statement);
}
}
if let Ok(statements) = Parser::parse_sql(&self.mysql_dialect, sql) {
if let Some(statement) = statements.into_iter().next() {
return Ok(statement);
}
}
if let Ok(statements) = Parser::parse_sql(&self.sqlite_dialect, sql) {
if let Some(statement) = statements.into_iter().next() {
return Ok(statement);
}
}
if let Ok(statements) = Parser::parse_sql(&self.generic_dialect, sql) {
if let Some(statement) = statements.into_iter().next() {
return Ok(statement);
}
}
Err(QueryParsingError::InvalidSyntax("Failed to parse SQL with any supported dialect".to_string()).into())
}
fn convert_query(&self, query: Query) -> NirvResult<InternalQuery> {
let mut internal_query = InternalQuery::new(QueryOperation::Select);
if let sqlparser::ast::SetExpr::Select(body) = query.body.as_ref() {
internal_query.projections = self.extract_projections(&body.projection)?;
internal_query.sources = self.extract_sources(&body.from)?;
if let Some(selection) = &body.selection {
internal_query.predicates = self.extract_predicates(selection)?;
}
if !query.order_by.is_empty() {
internal_query.ordering = Some(self.extract_order_by(&query.order_by)?);
}
if let Some(limit) = &query.limit {
internal_query.limit = Some(self.extract_limit(limit)?);
}
} else {
return Err(QueryParsingError::UnsupportedFeature("Only SELECT queries are supported".to_string()).into());
}
Ok(internal_query)
}
fn extract_projections(&self, projection: &[SelectItem]) -> NirvResult<Vec<Column>> {
let mut columns = Vec::new();
for item in projection {
match item {
SelectItem::UnnamedExpr(expr) => {
let column = self.extract_column_from_expr(expr, None)?;
columns.push(column);
}
SelectItem::ExprWithAlias { expr, alias } => {
let column = self.extract_column_from_expr(expr, Some(alias.value.clone()))?;
columns.push(column);
}
SelectItem::Wildcard(_) => {
columns.push(Column {
name: "*".to_string(),
alias: None,
source: None,
});
}
SelectItem::QualifiedWildcard(object_name, _) => {
columns.push(Column {
name: "*".to_string(),
alias: None,
source: Some(object_name.to_string()),
});
}
}
}
Ok(columns)
}
fn extract_column_from_expr(&self, expr: &Expr, alias: Option<String>) -> NirvResult<Column> {
match expr {
Expr::Identifier(ident) => {
Ok(Column {
name: ident.value.clone(),
alias,
source: None,
})
}
Expr::CompoundIdentifier(idents) => {
if idents.len() == 2 {
Ok(Column {
name: idents[1].value.clone(),
alias,
source: Some(idents[0].value.clone()),
})
} else {
Ok(Column {
name: idents.last().unwrap().value.clone(),
alias,
source: None,
})
}
}
Expr::Function(func) => {
if func.name.to_string().to_lowercase() == "source" {
return Err(QueryParsingError::InvalidSourceFormat("source() function should be used in FROM clause, not SELECT".to_string()).into());
}
Ok(Column {
name: func.name.to_string(),
alias,
source: None,
})
}
_ => {
Ok(Column {
name: "expr".to_string(),
alias,
source: None,
})
}
}
}
fn extract_sources(&self, from: &[sqlparser::ast::TableWithJoins]) -> NirvResult<Vec<DataSource>> {
let mut sources = Vec::new();
for table_with_joins in from {
let source = self.extract_source_from_table(&table_with_joins.relation)?;
sources.push(source);
}
if sources.is_empty() {
return Err(QueryParsingError::MissingSource.into());
}
Ok(sources)
}
fn extract_source_from_table(&self, table: &sqlparser::ast::TableFactor) -> NirvResult<DataSource> {
match table {
sqlparser::ast::TableFactor::Table { name, alias, args, .. } => {
let table_name = name.to_string();
if table_name.to_lowercase() == "source" && args.is_some() {
let source_spec = self.extract_source_from_function_args(args.as_ref().unwrap())?;
Ok(DataSource {
object_type: source_spec.0,
identifier: source_spec.1,
alias: alias.as_ref().map(|a| a.name.value.clone()),
})
} else {
Ok(DataSource {
object_type: "table".to_string(),
identifier: table_name,
alias: alias.as_ref().map(|a| a.name.value.clone()),
})
}
}
sqlparser::ast::TableFactor::Derived { alias, .. } => {
Ok(DataSource {
object_type: "subquery".to_string(),
identifier: "derived".to_string(),
alias: alias.as_ref().map(|a| a.name.value.clone()),
})
}
sqlparser::ast::TableFactor::Function { name, args, alias, .. } => {
if name.to_string().to_lowercase() == "source" {
let source_spec = self.extract_source_from_function_args(args)?;
Ok(DataSource {
object_type: source_spec.0,
identifier: source_spec.1,
alias: alias.as_ref().map(|a| a.name.value.clone()),
})
} else {
Err(QueryParsingError::UnsupportedFeature(format!("Function {} not supported in FROM clause", name)).into())
}
}
_ => Err(QueryParsingError::UnsupportedFeature("Unsupported table reference type".to_string()).into()),
}
}
fn extract_source_function(&self, table_name: &str) -> NirvResult<Option<(String, String)>> {
if let Some(captures) = self.source_regex.captures(table_name) {
if let Some(source_spec) = captures.get(1) {
let spec = source_spec.as_str();
if let Some(dot_pos) = spec.find('.') {
let object_type = spec[..dot_pos].to_string();
let identifier = spec[dot_pos + 1..].to_string();
Ok(Some((object_type, identifier)))
} else {
Ok(Some(("table".to_string(), spec.to_string())))
}
} else {
Err(QueryParsingError::InvalidSourceFormat("Empty source specification".to_string()).into())
}
} else {
Ok(None)
}
}
fn extract_source_from_function_args(&self, args: &[FunctionArg]) -> NirvResult<(String, String)> {
if args.len() != 1 {
return Err(QueryParsingError::InvalidSourceFormat("source() function requires exactly one argument".to_string()).into());
}
if let FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(SqlValue::SingleQuotedString(spec)))) = &args[0] {
if let Some(dot_pos) = spec.find('.') {
let object_type = spec[..dot_pos].to_string();
let identifier = spec[dot_pos + 1..].to_string();
Ok((object_type, identifier))
} else {
Ok(("table".to_string(), spec.to_string()))
}
} else {
Err(QueryParsingError::InvalidSourceFormat("source() function argument must be a string literal".to_string()).into())
}
}
fn extract_predicates(&self, expr: &Expr) -> NirvResult<Vec<Predicate>> {
let mut predicates = Vec::new();
self.extract_predicates_recursive(expr, &mut predicates)?;
Ok(predicates)
}
fn extract_predicates_recursive(&self, expr: &Expr, predicates: &mut Vec<Predicate>) -> NirvResult<()> {
match expr {
Expr::BinaryOp { left, op, right } => {
match op {
BinaryOperator::And => {
self.extract_predicates_recursive(left, predicates)?;
self.extract_predicates_recursive(right, predicates)?;
}
BinaryOperator::Or => {
self.extract_predicates_recursive(left, predicates)?;
self.extract_predicates_recursive(right, predicates)?;
}
_ => {
let predicate = self.create_predicate_from_binary_op(left, op, right)?;
predicates.push(predicate);
}
}
}
Expr::IsNull(expr) => {
let column = self.extract_column_name_from_expr(expr)?;
predicates.push(Predicate {
column,
operator: PredicateOperator::IsNull,
value: PredicateValue::Null,
});
}
Expr::IsNotNull(expr) => {
let column = self.extract_column_name_from_expr(expr)?;
predicates.push(Predicate {
column,
operator: PredicateOperator::IsNotNull,
value: PredicateValue::Null,
});
}
_ => {
}
}
Ok(())
}
fn create_predicate_from_binary_op(&self, left: &Expr, op: &BinaryOperator, right: &Expr) -> NirvResult<Predicate> {
let column = self.extract_column_name_from_expr(left)?;
let operator = self.convert_binary_operator(op)?;
let value = self.extract_predicate_value_from_expr(right)?;
Ok(Predicate {
column,
operator,
value,
})
}
fn extract_column_name_from_expr(&self, expr: &Expr) -> NirvResult<String> {
match expr {
Expr::Identifier(ident) => Ok(ident.value.clone()),
Expr::CompoundIdentifier(idents) => {
if idents.len() >= 2 {
Ok(format!("{}.{}", idents[0].value, idents[1].value))
} else {
Ok(idents[0].value.clone())
}
}
_ => Err(QueryParsingError::InvalidSyntax("Expected column identifier in predicate".to_string()).into()),
}
}
fn convert_binary_operator(&self, op: &BinaryOperator) -> NirvResult<PredicateOperator> {
match op {
BinaryOperator::Eq => Ok(PredicateOperator::Equal),
BinaryOperator::NotEq => Ok(PredicateOperator::NotEqual),
BinaryOperator::Gt => Ok(PredicateOperator::GreaterThan),
BinaryOperator::GtEq => Ok(PredicateOperator::GreaterThanOrEqual),
BinaryOperator::Lt => Ok(PredicateOperator::LessThan),
BinaryOperator::LtEq => Ok(PredicateOperator::LessThanOrEqual),
_ => Err(QueryParsingError::UnsupportedFeature(format!("Operator {:?} not supported", op)).into()),
}
}
fn extract_predicate_value_from_expr(&self, expr: &Expr) -> NirvResult<PredicateValue> {
match expr {
Expr::Value(sql_value) => self.convert_sql_value(sql_value),
Expr::Identifier(ident) => Ok(PredicateValue::String(ident.value.clone())),
_ => Err(QueryParsingError::UnsupportedFeature("Complex expressions in predicates not yet supported".to_string()).into()),
}
}
fn convert_sql_value(&self, value: &SqlValue) -> NirvResult<PredicateValue> {
match value {
SqlValue::SingleQuotedString(s) | SqlValue::DoubleQuotedString(s) => {
Ok(PredicateValue::String(s.clone()))
}
SqlValue::Number(n, _) => {
if let Ok(int_val) = n.parse::<i64>() {
Ok(PredicateValue::Integer(int_val))
} else if let Ok(float_val) = n.parse::<f64>() {
Ok(PredicateValue::Number(float_val))
} else {
Err(QueryParsingError::InvalidSyntax(format!("Invalid number format: {}", n)).into())
}
}
SqlValue::Boolean(b) => Ok(PredicateValue::Boolean(*b)),
SqlValue::Null => Ok(PredicateValue::Null),
_ => Err(QueryParsingError::UnsupportedFeature(format!("Value type {:?} not supported", value)).into()),
}
}
fn extract_order_by(&self, order_by: &[OrderByExpr]) -> NirvResult<OrderBy> {
let mut columns = Vec::new();
for order_expr in order_by {
let column_name = self.extract_column_name_from_expr(&order_expr.expr)?;
let direction = if order_expr.asc.unwrap_or(true) {
OrderDirection::Ascending
} else {
OrderDirection::Descending
};
columns.push(OrderColumn {
column: column_name,
direction,
});
}
Ok(OrderBy { columns })
}
fn extract_limit(&self, limit_expr: &Expr) -> NirvResult<u64> {
match limit_expr {
Expr::Value(SqlValue::Number(n, _)) => {
n.parse::<u64>()
.map_err(|_| QueryParsingError::InvalidSyntax(format!("Invalid LIMIT value: {}", n)).into())
}
_ => Err(QueryParsingError::InvalidSyntax("LIMIT must be a number".to_string()).into()),
}
}
}
impl Default for DefaultQueryParser {
fn default() -> Self {
Self::new().expect("Failed to create default QueryParser")
}
}
#[async_trait]
impl QueryParser for DefaultQueryParser {
async fn parse_sql(&self, sql: &str) -> NirvResult<InternalQuery> {
self.parse(sql)
}
async fn validate_syntax(&self, sql: &str) -> NirvResult<bool> {
match self.try_parse_with_dialects(sql) {
Ok(_) => Ok(true),
Err(_) => Ok(false),
}
}
async fn extract_sources(&self, sql: &str) -> NirvResult<Vec<String>> {
let query = self.parse(sql)?;
Ok(query.sources.into_iter()
.map(|source| format!("{}.{}", source.object_type, source.identifier))
.collect())
}
}
#[cfg
(test)]
mod tests {
use super::*;
use crate::utils::{QueryOperation, DataSource, Column, Predicate, PredicateOperator, PredicateValue, OrderDirection};
fn create_parser() -> DefaultQueryParser {
DefaultQueryParser::new().expect("Failed to create parser")
}
#[test]
fn test_parser_creation() {
let parser = DefaultQueryParser::new();
assert!(parser.is_ok());
}
#[test]
fn test_simple_select_all() {
let parser = create_parser();
let sql = "SELECT * FROM source('postgres.users')";
let result = parser.parse(sql);
assert!(result.is_ok());
let query = result.unwrap();
assert_eq!(query.operation, QueryOperation::Select);
assert_eq!(query.projections.len(), 1);
assert_eq!(query.projections[0].name, "*");
assert_eq!(query.sources.len(), 1);
assert_eq!(query.sources[0].object_type, "postgres");
assert_eq!(query.sources[0].identifier, "users");
}
#[test]
fn test_select_with_columns() {
let parser = create_parser();
let sql = "SELECT id, name, email FROM source('postgres.users')";
let result = parser.parse(sql);
assert!(result.is_ok());
let query = result.unwrap();
assert_eq!(query.projections.len(), 3);
assert_eq!(query.projections[0].name, "id");
assert_eq!(query.projections[1].name, "name");
assert_eq!(query.projections[2].name, "email");
}
#[test]
fn test_select_with_aliases() {
let parser = create_parser();
let sql = "SELECT id as user_id, name as full_name FROM source('postgres.users') as u";
let result = parser.parse(sql);
assert!(result.is_ok());
let query = result.unwrap();
assert_eq!(query.projections.len(), 2);
assert_eq!(query.projections[0].name, "id");
assert_eq!(query.projections[0].alias, Some("user_id".to_string()));
assert_eq!(query.projections[1].name, "name");
assert_eq!(query.projections[1].alias, Some("full_name".to_string()));
assert_eq!(query.sources[0].alias, Some("u".to_string()));
}
#[test]
fn test_source_function_parsing() {
let parser = create_parser();
let test_cases = vec![
("SELECT * FROM source('postgres.users')", "postgres", "users"),
("SELECT * FROM source('file.data.csv')", "file", "data.csv"),
("SELECT * FROM source('api.endpoint')", "api", "endpoint"),
("SELECT * FROM source('users')", "table", "users"), ];
for (sql, expected_type, expected_id) in test_cases {
let result = parser.parse(sql);
assert!(result.is_ok(), "Failed to parse: {}", sql);
let query = result.unwrap();
assert_eq!(query.sources.len(), 1);
assert_eq!(query.sources[0].object_type, expected_type);
assert_eq!(query.sources[0].identifier, expected_id);
}
}
#[test]
fn test_where_clause_parsing() {
let parser = create_parser();
let sql = "SELECT * FROM source('postgres.users') WHERE age > 18 AND name = 'John'";
let result = parser.parse(sql);
assert!(result.is_ok());
let query = result.unwrap();
assert_eq!(query.predicates.len(), 2);
assert_eq!(query.predicates[0].column, "age");
assert_eq!(query.predicates[0].operator, PredicateOperator::GreaterThan);
assert_eq!(query.predicates[0].value, PredicateValue::Integer(18));
assert_eq!(query.predicates[1].column, "name");
assert_eq!(query.predicates[1].operator, PredicateOperator::Equal);
assert_eq!(query.predicates[1].value, PredicateValue::String("John".to_string()));
}
#[test]
fn test_various_operators() {
let parser = create_parser();
let test_cases = vec![
("SELECT * FROM source('test') WHERE id = 1", PredicateOperator::Equal),
("SELECT * FROM source('test') WHERE id != 1", PredicateOperator::NotEqual),
("SELECT * FROM source('test') WHERE id > 1", PredicateOperator::GreaterThan),
("SELECT * FROM source('test') WHERE id >= 1", PredicateOperator::GreaterThanOrEqual),
("SELECT * FROM source('test') WHERE id < 1", PredicateOperator::LessThan),
("SELECT * FROM source('test') WHERE id <= 1", PredicateOperator::LessThanOrEqual),
];
for (sql, expected_op) in test_cases {
let result = parser.parse(sql);
assert!(result.is_ok(), "Failed to parse: {}", sql);
let query = result.unwrap();
assert_eq!(query.predicates.len(), 1);
assert_eq!(query.predicates[0].operator, expected_op);
}
}
#[test]
fn test_null_predicates() {
let parser = create_parser();
let sql1 = "SELECT * FROM source('test') WHERE name IS NULL";
let result1 = parser.parse(sql1);
assert!(result1.is_ok());
let query1 = result1.unwrap();
assert_eq!(query1.predicates.len(), 1);
assert_eq!(query1.predicates[0].operator, PredicateOperator::IsNull);
let sql2 = "SELECT * FROM source('test') WHERE name IS NOT NULL";
let result2 = parser.parse(sql2);
assert!(result2.is_ok());
let query2 = result2.unwrap();
assert_eq!(query2.predicates.len(), 1);
assert_eq!(query2.predicates[0].operator, PredicateOperator::IsNotNull);
}
#[test]
fn test_order_by_clause() {
let parser = create_parser();
let sql = "SELECT * FROM source('postgres.users') ORDER BY name ASC, age DESC";
let result = parser.parse(sql);
assert!(result.is_ok());
let query = result.unwrap();
assert!(query.ordering.is_some());
let ordering = query.ordering.unwrap();
assert_eq!(ordering.columns.len(), 2);
assert_eq!(ordering.columns[0].column, "name");
assert_eq!(ordering.columns[0].direction, OrderDirection::Ascending);
assert_eq!(ordering.columns[1].column, "age");
assert_eq!(ordering.columns[1].direction, OrderDirection::Descending);
}
#[test]
fn test_limit_clause() {
let parser = create_parser();
let sql = "SELECT * FROM source('postgres.users') LIMIT 10";
let result = parser.parse(sql);
assert!(result.is_ok());
let query = result.unwrap();
assert!(query.limit.is_some());
assert_eq!(query.limit.unwrap(), 10);
}
#[test]
fn test_complex_query() {
let parser = create_parser();
let sql = "SELECT id, name as full_name FROM source('postgres.users') as u WHERE age >= 21 AND status = 'active' ORDER BY name ASC LIMIT 50";
let result = parser.parse(sql);
assert!(result.is_ok());
let query = result.unwrap();
assert_eq!(query.projections.len(), 2);
assert_eq!(query.projections[0].name, "id");
assert_eq!(query.projections[1].alias, Some("full_name".to_string()));
assert_eq!(query.sources.len(), 1);
assert_eq!(query.sources[0].object_type, "postgres");
assert_eq!(query.sources[0].identifier, "users");
assert_eq!(query.sources[0].alias, Some("u".to_string()));
assert_eq!(query.predicates.len(), 2);
assert!(query.ordering.is_some());
assert_eq!(query.limit, Some(50));
}
#[test]
fn test_postgresql_dialect() {
let parser = create_parser();
let sql = "SELECT id, name FROM source('postgres.users') WHERE created_at > '2023-01-01'";
let result = parser.parse(sql);
assert!(result.is_ok());
}
#[test]
fn test_mysql_dialect() {
let parser = create_parser();
let sql = "SELECT id, name FROM source('mysql.users') WHERE created_at > '2023-01-01'";
let result = parser.parse(sql);
assert!(result.is_ok());
}
#[test]
fn test_sqlite_dialect() {
let parser = create_parser();
let sql = "SELECT id, name FROM source('sqlite.users') WHERE created_at > '2023-01-01'";
let result = parser.parse(sql);
assert!(result.is_ok());
}
#[test]
fn test_invalid_sql_syntax() {
let parser = create_parser();
let sql = "INVALID SQL SYNTAX";
let result = parser.parse(sql);
assert!(result.is_err());
match result.unwrap_err() {
crate::utils::error::NirvError::QueryParsing(QueryParsingError::InvalidSyntax(_)) => {},
_ => panic!("Expected InvalidSyntax error"),
}
}
#[test]
fn test_missing_source_function() {
let parser = create_parser();
let sql = "SELECT * FROM users"; let result = parser.parse(sql);
assert!(result.is_ok()); let query = result.unwrap();
assert_eq!(query.sources[0].object_type, "table");
assert_eq!(query.sources[0].identifier, "users");
}
#[test]
fn test_invalid_source_format() {
let parser = create_parser();
let sql = "SELECT * FROM source()"; let result = parser.parse(sql);
assert!(result.is_err());
assert!(matches!(result.unwrap_err(),
crate::utils::error::NirvError::QueryParsing(QueryParsingError::InvalidSourceFormat(_)) |
crate::utils::error::NirvError::QueryParsing(QueryParsingError::InvalidSyntax(_))
));
}
#[test]
fn test_unsupported_query_type() {
let parser = create_parser();
let sql = "INSERT INTO users (name) VALUES ('John')";
let result = parser.parse(sql);
assert!(result.is_err());
match result.unwrap_err() {
crate::utils::error::NirvError::QueryParsing(QueryParsingError::UnsupportedFeature(_)) => {},
_ => panic!("Expected UnsupportedFeature error"),
}
}
#[test]
fn test_source_function_in_select_clause() {
let parser = create_parser();
let sql = "SELECT source('test') FROM users";
let result = parser.parse(sql);
assert!(result.is_err());
match result.unwrap_err() {
crate::utils::error::NirvError::QueryParsing(QueryParsingError::InvalidSourceFormat(_)) => {},
_ => panic!("Expected InvalidSourceFormat error"),
}
}
#[test]
fn test_compound_identifiers() {
let parser = create_parser();
let sql = "SELECT u.id, u.name FROM source('postgres.users') as u WHERE u.age > 18";
let result = parser.parse(sql);
assert!(result.is_ok());
let query = result.unwrap();
assert_eq!(query.projections.len(), 2);
assert_eq!(query.projections[0].name, "id");
assert_eq!(query.projections[0].source, Some("u".to_string()));
assert_eq!(query.projections[1].name, "name");
assert_eq!(query.projections[1].source, Some("u".to_string()));
assert_eq!(query.predicates.len(), 1);
assert_eq!(query.predicates[0].column, "u.age");
}
#[test]
fn test_various_value_types() {
let parser = create_parser();
let test_cases = vec![
("SELECT * FROM source('test') WHERE str_col = 'text'", PredicateValue::String("text".to_string())),
("SELECT * FROM source('test') WHERE int_col = 42", PredicateValue::Integer(42)),
("SELECT * FROM source('test') WHERE float_col = 3.14", PredicateValue::Number(3.14)),
("SELECT * FROM source('test') WHERE bool_col = true", PredicateValue::Boolean(true)),
("SELECT * FROM source('test') WHERE null_col = NULL", PredicateValue::Null),
];
for (sql, expected_value) in test_cases {
let result = parser.parse(sql);
assert!(result.is_ok(), "Failed to parse: {}", sql);
let query = result.unwrap();
assert_eq!(query.predicates.len(), 1);
assert_eq!(query.predicates[0].value, expected_value);
}
}
#[test]
fn test_double_quoted_strings() {
let parser = create_parser();
let sql = r#"SELECT * FROM source('postgres.users') WHERE name = "John""#;
let result = parser.parse(sql);
assert!(result.is_ok());
let query = result.unwrap();
assert_eq!(query.sources[0].object_type, "postgres");
assert_eq!(query.sources[0].identifier, "users");
assert_eq!(query.predicates[0].value, PredicateValue::String("John".to_string()));
}
#[test]
fn test_qualified_wildcard() {
let parser = create_parser();
let sql = "SELECT u.* FROM source('postgres.users') as u";
let result = parser.parse(sql);
assert!(result.is_ok());
let query = result.unwrap();
assert_eq!(query.projections.len(), 1);
assert_eq!(query.projections[0].name, "*");
assert_eq!(query.projections[0].source, Some("u".to_string()));
}
}