use sqlparser::ast::{
Expr as SqlExpr, Ident, ObjectName, Query, Select, SelectItem, SetExpr, Statement,
};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
use thiserror::Error;
/// Error type returned by the parsing functions in this file.
///
/// Wraps a single human-readable message; the `#[error("{0}")]` attribute makes
/// `Display` print that message unchanged. The field is private, so code outside
/// this module reads the message through `Display`/`to_string()`.
#[derive(Error, Debug)]
#[error("{0}")]
pub struct ParseError(String);
pub use sqlparser::ast;
/// Result of `parse_spark_sql`: either one of the statement forms recognised by
/// this file's own token matching, or a plain `sqlparser` statement.
#[derive(Debug, Clone, PartialEq)]
pub enum SparkStatement {
/// Fallback: the statement as parsed by `sqlparser` when no other variant matched.
Sqlparser(Box<Statement>),
/// `DESCRIBE DETAIL <table>` / `DESC DETAIL <table>`.
DescribeDetail { table: ObjectName },
/// `SHOW DATABASES`.
ShowDatabases,
/// `SHOW TABLES`, optionally followed by `IN <db>` or `FROM <db>`.
ShowTables { db: Option<ObjectName> },
/// `DESCRIBE` / `DESC` of a table (any form other than `DETAIL`).
Describe {
/// The first token after the keyword that is neither `TABLE` nor `EXTENDED`.
table: ObjectName,
/// The token immediately following the table name, if there is one.
col: Option<Ident>,
/// `true` when any token after the leading keyword equals `EXTENDED` (case-insensitive).
extended: bool,
},
/// `CREATE OR REPLACE TABLE <table> USING <format> AS <query>`.
CreateOrReplaceTableAs {
/// Target table name (the tokens between `TABLE` and `USING`).
table: ObjectName,
/// The token following `USING`, stored exactly as written in the input.
format: String,
/// The query that follows the `AS` keyword.
query: Box<Query>,
},
}
/// Parses `query` with the generic `sqlparser` dialect and returns its only statement.
///
/// # Errors
///
/// Returns a [`ParseError`] when `sqlparser` rejects the input, or when the input
/// does not contain exactly one statement (zero or several).
fn parse_one_statement_raw(query: &str) -> Result<Statement, ParseError> {
    let mut parsed = match Parser::parse_sql(&GenericDialect {}, query) {
        Ok(statements) => statements,
        Err(e) => {
            return Err(ParseError(format!(
                "SQL parse error: {}. Hint: supported statements include SELECT, CREATE TABLE/VIEW/FUNCTION/SCHEMA/DATABASE, DROP TABLE/VIEW/SCHEMA.",
                e
            )));
        }
    };

    // Record the count before popping so the error message reports the real number.
    let count = parsed.len();
    match (parsed.pop(), count) {
        (Some(only), 1) => Ok(only),
        _ => Err(ParseError(format!(
            "SQL: expected exactly one statement, got {}. Hint: run one statement at a time.",
            count
        ))),
    }
}
/// Builds an [`ObjectName`] from a dot-separated name such as `schema.table`.
///
/// Surrounding whitespace is trimmed from the whole name and from every segment;
/// empty segments are dropped.
///
/// # Errors
///
/// Returns a [`ParseError`] when the name is blank or contains no non-empty segment.
fn parse_object_name(name: &str) -> Result<ObjectName, ParseError> {
    let s = name.trim();
    if s.is_empty() {
        return Err(ParseError(
            "SQL: expected an object name, got empty string.".to_string(),
        ));
    }

    let mut idents: Vec<Ident> = Vec::new();
    for segment in s.split('.') {
        let segment = segment.trim();
        if !segment.is_empty() {
            idents.push(Ident::new(segment));
        }
    }

    if idents.is_empty() {
        // Reached for inputs made only of dots and whitespace, e.g. ".".
        return Err(ParseError(format!(
            "SQL: expected an object name, got '{s}'."
        )));
    }
    Ok(ObjectName::from(idents))
}
/// Splits `s` on runs of whitespace and returns the non-empty pieces in order.
///
/// The returned slices borrow from `s`.
fn tokenize_ws(s: &str) -> Vec<&str> {
    let mut tokens = Vec::new();
    tokens.extend(s.split_whitespace());
    tokens
}
fn try_parse_create_or_replace_table_as(
query: &str,
toks: &[&str],
) -> Result<Option<SparkStatement>, ParseError> {
if toks.len() < 8 {
return Ok(None);
}
if !(toks[0].eq_ignore_ascii_case("CREATE")
&& toks[1].eq_ignore_ascii_case("OR")
&& toks[2].eq_ignore_ascii_case("REPLACE")
&& toks[3].eq_ignore_ascii_case("TABLE"))
{
return Ok(None);
}
let using_pos = toks[4..]
.iter()
.position(|t| t.eq_ignore_ascii_case("USING"))
.map(|i| i + 4);
let using_pos = match using_pos {
Some(pos) => pos,
None => return Ok(None), };
let table_tokens = &toks[4..using_pos];
if table_tokens.is_empty() {
return Err(ParseError(
"SQL: CREATE OR REPLACE TABLE requires a table name.".to_string(),
));
}
let table_name_str = table_tokens.join(" ");
let table = parse_object_name(&table_name_str)?;
if using_pos + 1 >= toks.len() {
return Err(ParseError(
"SQL: CREATE OR REPLACE TABLE ... USING requires a format (e.g., delta, parquet)."
.to_string(),
));
}
let format = toks[using_pos + 1].to_string();
let as_pos = toks[using_pos + 2..]
.iter()
.position(|t| t.eq_ignore_ascii_case("AS"))
.map(|i| i + using_pos + 2);
let as_pos = match as_pos {
Some(pos) => pos,
None => {
return Err(ParseError(
"SQL: CREATE OR REPLACE TABLE ... USING <format> requires AS SELECT ..."
.to_string(),
));
}
};
let subquery_tokens = &toks[as_pos + 1..];
if subquery_tokens.is_empty() {
return Err(ParseError(
"SQL: CREATE OR REPLACE TABLE ... AS requires a SELECT query.".to_string(),
));
}
let format_token = &toks[using_pos + 1];
let query_lower = query.to_lowercase();
let format_byte_pos = query_lower.find(&format_token.to_lowercase()).unwrap_or(0);
let search_start = format_byte_pos + format_token.len();
let remaining = &query[search_start..];
let remaining_lower = remaining.to_lowercase();
let as_offset = find_standalone_as(&remaining_lower);
let as_offset = match as_offset {
Some(offset) => offset,
None => {
return Err(ParseError(
"SQL: Could not locate AS keyword in CREATE OR REPLACE TABLE statement."
.to_string(),
));
}
};
let after_as = &remaining[as_offset..];
let as_lower = after_as.to_lowercase();
let as_keyword_pos = as_lower.find("as").unwrap_or(0);
let after_as_keyword = &after_as[as_keyword_pos + 2..];
let subquery_str = after_as_keyword.trim_start();
if subquery_str.is_empty() {
return Err(ParseError(
"SQL: CREATE OR REPLACE TABLE ... AS requires a SELECT query.".to_string(),
));
}
let dialect = GenericDialect {};
let stmts = Parser::parse_sql(&dialect, subquery_str).map_err(|e| {
ParseError(format!(
"SQL parse error in CREATE OR REPLACE TABLE subquery: {}",
e
))
})?;
if stmts.len() != 1 {
return Err(ParseError(format!(
"SQL: CREATE OR REPLACE TABLE subquery must be a single SELECT statement, got {} statements.",
stmts.len()
)));
}
let stmt = stmts.into_iter().next().expect("len == 1");
let query_ast = match stmt {
Statement::Query(q) => q,
_ => {
return Err(ParseError(
"SQL: CREATE OR REPLACE TABLE ... AS requires a SELECT query.".to_string(),
));
}
};
Ok(Some(SparkStatement::CreateOrReplaceTableAs {
table,
format,
query: query_ast,
}))
}
/// Finds a standalone `AS` keyword (ASCII case-insensitive) in `s`.
///
/// A keyword is "standalone" when it is followed by ASCII whitespace or the end
/// of the string, and is either preceded by ASCII whitespace or starts the string.
///
/// Returns the byte offset of the whitespace byte that precedes the first
/// whitespace-preceded keyword. Only when no such match exists is a keyword at
/// the very start of `s` considered, in which case `Some(0)` is returned.
///
/// The comparison is done on bytes, so input containing multi-byte UTF-8
/// characters never causes a slicing panic.
fn find_standalone_as(s: &str) -> Option<usize> {
    /// True when `bytes[start..start + 2]` is `as` (any case) and is followed by
    /// ASCII whitespace or the end of the input.
    fn keyword_at(bytes: &[u8], start: usize) -> bool {
        let is_as = matches!(
            bytes.get(start..start + 2),
            Some(word) if word.eq_ignore_ascii_case(b"as")
        );
        let ends_cleanly = match bytes.get(start + 2) {
            Some(next) => next.is_ascii_whitespace(),
            None => true,
        };
        is_as && ends_cleanly
    }

    let bytes = s.as_bytes();
    (0..bytes.len())
        .find(|&i| bytes[i].is_ascii_whitespace() && keyword_at(bytes, i + 1))
        .or_else(|| if keyword_at(bytes, 0) { Some(0) } else { None })
}
/// Parses one SQL statement, recognising a few statement forms by token matching
/// before falling back to `sqlparser`.
///
/// Forms checked, in order (keywords are matched ASCII case-insensitively):
///
/// 1. `CREATE OR REPLACE TABLE ... USING ... AS ...`
/// 2. `SHOW DATABASES`
/// 3. `SHOW TABLES [IN|FROM <db>]`
/// 4. `DESCRIBE|DESC DETAIL <table>`
/// 5. `DESCRIBE|DESC [TABLE] [EXTENDED] <table> [<col>]`
///
/// Anything else is handed to the generic parser and wrapped in
/// [`SparkStatement::Sqlparser`].
///
/// # Errors
///
/// Returns a [`ParseError`] when the input is empty, when a recognised form is
/// malformed, or when the fallback parser rejects the statement.
pub fn parse_spark_sql(query: &str) -> Result<SparkStatement, ParseError> {
    fn is_describe(tok: &str) -> bool {
        tok.eq_ignore_ascii_case("DESCRIBE") || tok.eq_ignore_ascii_case("DESC")
    }

    let trimmed = query.trim();
    if trimmed.is_empty() {
        // Delegate so that empty input produces the generic parser's error.
        parse_one_statement_raw(trimmed)?;
    }

    let words = tokenize_ws(trimmed);
    if let Some(ctas) = try_parse_create_or_replace_table_as(trimmed, &words)? {
        return Ok(ctas);
    }

    match words.as_slice() {
        [first, second, ..]
            if first.eq_ignore_ascii_case("SHOW") && second.eq_ignore_ascii_case("DATABASES") =>
        {
            return Ok(SparkStatement::ShowDatabases);
        }
        [first, second, tail @ ..]
            if first.eq_ignore_ascii_case("SHOW") && second.eq_ignore_ascii_case("TABLES") =>
        {
            let db = match tail {
                [prep, name, ..]
                    if prep.eq_ignore_ascii_case("IN") || prep.eq_ignore_ascii_case("FROM") =>
                {
                    Some(parse_object_name(name)?)
                }
                _ => None,
            };
            return Ok(SparkStatement::ShowTables { db });
        }
        [first, second, target @ ..]
            if is_describe(first) && second.eq_ignore_ascii_case("DETAIL") =>
        {
            // A bare `DESCRIBE DETAIL` with no table falls through to the generic parser.
            if !target.is_empty() {
                let table = parse_object_name(&target.join(" "))?;
                return Ok(SparkStatement::DescribeDetail { table });
            }
        }
        [first, rest @ ..] if is_describe(first) => {
            // EXTENDED counts wherever it appears after the leading keyword.
            let extended = rest.iter().any(|t| t.eq_ignore_ascii_case("EXTENDED"));
            let mut operands = rest.iter().skip_while(|t| {
                t.eq_ignore_ascii_case("TABLE") || t.eq_ignore_ascii_case("EXTENDED")
            });
            if let Some(table_tok) = operands.next() {
                let table = parse_object_name(table_tok)?;
                let col = operands.next().map(|c| Ident::new(*c));
                return Ok(SparkStatement::Describe {
                    table,
                    col,
                    extended,
                });
            }
        }
        _ => {}
    }

    let fallback = parse_one_statement_raw(query)?;
    Ok(SparkStatement::Sqlparser(Box::new(fallback)))
}
/// Parses a single select-list expression, with an optional `AS alias`.
///
/// The expression is wrapped as `SELECT <expr> FROM <placeholder table>`, parsed
/// as a statement, and the first projection item is returned together with its
/// alias (if any).
///
/// # Errors
///
/// Returns a [`ParseError`] when the input is blank, when the wrapped statement
/// does not parse as a plain `SELECT`, or when the first projection item is
/// neither a bare expression nor an aliased expression.
pub fn parse_select_expr(expr_str: &str) -> Result<(SqlExpr, Option<Ident>), ParseError> {
    const TMP_TABLE: &str = "__spark_sql_parser_expr_t";

    let e = expr_str.trim();
    if e.is_empty() {
        return Err(ParseError(
            "SQL: expected an expression string, got empty.".to_string(),
        ));
    }

    let wrapped = format!("SELECT {e} FROM {TMP_TABLE}");
    let query_ast: Query = match parse_one_statement_raw(&wrapped)? {
        Statement::Query(q) => *q,
        other => {
            return Err(ParseError(format!(
                "SQL: expected SELECT when parsing expression, got {other:?}."
            )));
        }
    };
    let select: Select = match *query_ast.body {
        SetExpr::Select(s) => *s,
        other => {
            return Err(ParseError(format!(
                "SQL: expected SELECT when parsing expression, got {other:?}."
            )));
        }
    };

    // The AST is owned here, so the first item can be moved out instead of cloned.
    match select.projection.into_iter().next() {
        Some(SelectItem::UnnamedExpr(ex)) => Ok((ex, None)),
        Some(SelectItem::ExprWithAlias { expr, alias }) => Ok((expr, Some(alias))),
        Some(other) => Err(ParseError(format!(
            "SQL: unsupported expression form in SELECT list: {other:?}."
        ))),
        None => Err(ParseError(
            "SQL: expected non-empty SELECT list when parsing expression.".to_string(),
        )),
    }
}
/// Parses exactly one SQL statement and accepts it only if its kind is on the
/// allow-list below.
///
/// # Errors
///
/// Returns a [`ParseError`] when the input does not parse as exactly one
/// statement, or when the parsed statement is of a kind not listed here.
pub fn parse_sql(query: &str) -> Result<Statement, ParseError> {
    let stmt = parse_one_statement_raw(query)?;

    let supported = matches!(
        &stmt,
        // Queries
        Statement::Query(_)
            // CREATE
            | Statement::CreateSchema { .. }
            | Statement::CreateDatabase { .. }
            | Statement::CreateTable(_)
            | Statement::CreateView(_)
            | Statement::CreateFunction(_)
            // ALTER
            | Statement::AlterTable(_)
            | Statement::AlterView { .. }
            | Statement::AlterSchema(_)
            // DROP (only these object types)
            | Statement::Drop {
                object_type: sqlparser::ast::ObjectType::Table
                    | sqlparser::ast::ObjectType::View
                    | sqlparser::ast::ObjectType::Schema
                    | sqlparser::ast::ObjectType::Database,
                ..
            }
            | Statement::DropFunction(_)
            // Session / utility
            | Statement::Use(_)
            | Statement::Truncate(_)
            | Statement::Declare { .. }
            // SHOW
            | Statement::ShowTables { .. }
            | Statement::ShowDatabases { .. }
            | Statement::ShowSchemas { .. }
            | Statement::ShowFunctions { .. }
            | Statement::ShowColumns { .. }
            | Statement::ShowViews { .. }
            | Statement::ShowCreate { .. }
            // Data modification
            | Statement::Insert(_)
            | Statement::Directory { .. }
            | Statement::LoadData { .. }
            | Statement::Update(_)
            | Statement::Delete(_)
            // DESCRIBE / EXPLAIN
            | Statement::ExplainTable { .. }
            | Statement::Explain { .. }
            // Configuration and caching
            | Statement::Set(_)
            | Statement::Reset(_)
            | Statement::Cache { .. }
            | Statement::UNCache { .. }
    );

    if supported {
        Ok(stmt)
    } else {
        Err(ParseError(format!(
            "SQL: statement type not supported, got {:?}.",
            stmt
        )))
    }
}
// Unit tests for `parse_sql`, `parse_spark_sql` and `parse_select_expr`.
#[cfg(test)]
mod tests {
use super::*;
use sqlparser::ast::{ObjectType, Statement};
// Parses `sql` with `parse_sql` (panicking with the error message on failure)
// and asserts that `check` returns true for the resulting statement.
fn assert_parses_to<F>(sql: &str, check: F)
where
F: FnOnce(&Statement) -> bool,
{
let stmt = parse_sql(sql).unwrap_or_else(|e| panic!("parse_sql failed: {e}"));
assert!(check(&stmt), "expected match for: {sql}");
}
// ---- parse_sql: error paths ----
#[test]
fn error_multiple_statements() {
let err = parse_sql("SELECT 1; SELECT 2").unwrap_err();
assert!(err.0.contains("expected exactly one statement"));
assert!(err.0.contains("2"));
}
#[test]
fn error_zero_statements() {
let err = parse_sql("").unwrap_err();
assert!(err.0.contains("expected exactly one statement") || err.0.contains("parse error"));
}
#[test]
fn error_unsupported_statement_type() {
let err = parse_sql("COMMIT").unwrap_err();
assert!(err.0.contains("not supported"));
}
#[test]
fn error_syntax() {
let err = parse_sql("SELECT FROM").unwrap_err();
assert!(!err.0.is_empty());
}
// ---- parse_sql: accepted statement kinds ----
#[test]
fn query_select_simple() {
assert_parses_to("SELECT 1", |s| matches!(s, Statement::Query(_)));
}
#[test]
fn query_select_with_from() {
assert_parses_to("SELECT a FROM t", |s| matches!(s, Statement::Query(_)));
}
#[test]
fn query_with_cte() {
assert_parses_to("WITH cte AS (SELECT 1) SELECT * FROM cte", |s| {
matches!(s, Statement::Query(_))
});
}
#[test]
fn query_create_schema() {
assert_parses_to("CREATE SCHEMA s", |s| {
matches!(s, Statement::CreateSchema { .. })
});
}
#[test]
fn query_create_database() {
assert_parses_to("CREATE DATABASE d", |s| {
matches!(s, Statement::CreateDatabase { .. })
});
}
#[test]
fn test_issue_652_create_table() {
assert_parses_to("CREATE TABLE t (a INT)", |s| {
matches!(s, Statement::CreateTable(_))
});
}
#[test]
fn test_issue_652_create_view() {
assert_parses_to("CREATE VIEW v AS SELECT 1", |s| {
matches!(s, Statement::CreateView(_))
});
}
#[test]
fn test_issue_652_create_function() {
assert_parses_to("CREATE FUNCTION f() AS 'com.example.UDF'", |s| {
matches!(s, Statement::CreateFunction(_))
});
}
#[test]
fn test_issue_653_alter_table() {
assert_parses_to("ALTER TABLE t ADD COLUMN c INT", |s| {
matches!(s, Statement::AlterTable(_))
});
}
#[test]
fn test_issue_653_alter_view() {
assert_parses_to("ALTER VIEW v AS SELECT 1", |s| {
matches!(s, Statement::AlterView { .. })
});
}
#[test]
fn test_issue_653_alter_schema() {
assert_parses_to("ALTER SCHEMA db RENAME TO db2", |s| {
matches!(s, Statement::AlterSchema(_))
});
}
#[test]
fn test_issue_654_drop_table() {
let stmt = parse_sql("DROP TABLE t").unwrap();
match &stmt {
Statement::Drop {
object_type: ObjectType::Table,
..
} => {}
_ => panic!("expected Drop Table: {stmt:?}"),
}
}
#[test]
fn test_issue_654_drop_view() {
let stmt = parse_sql("DROP VIEW v").unwrap();
match &stmt {
Statement::Drop {
object_type: ObjectType::View,
..
} => {}
_ => panic!("expected Drop View: {stmt:?}"),
}
}
#[test]
fn test_issue_654_drop_schema() {
let stmt = parse_sql("DROP SCHEMA s").unwrap();
match &stmt {
Statement::Drop {
object_type: ObjectType::Schema,
..
} => {}
_ => panic!("expected Drop Schema: {stmt:?}"),
}
}
#[test]
fn test_issue_654_drop_function() {
assert_parses_to("DROP FUNCTION f", |s| {
matches!(s, Statement::DropFunction(_))
});
}
#[test]
fn test_issue_655_use() {
assert_parses_to("USE db1", |s| matches!(s, Statement::Use(_)));
}
#[test]
fn test_issue_655_truncate() {
assert_parses_to("TRUNCATE TABLE t", |s| matches!(s, Statement::Truncate(_)));
}
#[test]
fn test_issue_655_declare() {
assert_parses_to("DECLARE c CURSOR FOR SELECT 1", |s| {
matches!(s, Statement::Declare { .. })
});
}
#[test]
fn test_issue_656_show_tables() {
assert_parses_to("SHOW TABLES", |s| matches!(s, Statement::ShowTables { .. }));
}
#[test]
fn test_issue_656_show_databases() {
assert_parses_to("SHOW DATABASES", |s| {
matches!(s, Statement::ShowDatabases { .. })
});
}
#[test]
fn test_issue_656_show_schemas() {
assert_parses_to("SHOW SCHEMAS", |s| {
matches!(s, Statement::ShowSchemas { .. })
});
}
#[test]
fn test_issue_656_show_functions() {
assert_parses_to("SHOW FUNCTIONS", |s| {
matches!(s, Statement::ShowFunctions { .. })
});
}
#[test]
fn test_issue_656_show_columns() {
assert_parses_to("SHOW COLUMNS FROM t", |s| {
matches!(s, Statement::ShowColumns { .. })
});
}
#[test]
fn test_issue_656_show_views() {
assert_parses_to("SHOW VIEWS", |s| matches!(s, Statement::ShowViews { .. }));
}
#[test]
fn test_issue_656_show_create_table() {
assert_parses_to("SHOW CREATE TABLE t", |s| {
matches!(s, Statement::ShowCreate { .. })
});
}
#[test]
fn test_issue_657_insert() {
assert_parses_to("INSERT INTO t SELECT 1", |s| {
matches!(s, Statement::Insert(_))
});
}
#[test]
fn test_issue_657_directory() {
assert_parses_to("INSERT OVERWRITE DIRECTORY '/path' SELECT 1", |s| {
matches!(s, Statement::Directory { .. })
});
}
#[test]
fn test_issue_658_describe_table() {
assert_parses_to("DESCRIBE t", |s| {
matches!(s, Statement::ExplainTable { .. })
});
}
#[test]
fn test_issue_659_set() {
assert_parses_to("SET x = 1", |s| matches!(s, Statement::Set(_)));
}
#[test]
fn test_issue_659_reset() {
assert_parses_to("RESET x", |s| matches!(s, Statement::Reset(_)));
}
#[test]
fn test_issue_659_cache() {
assert_parses_to("CACHE TABLE t", |s| matches!(s, Statement::Cache { .. }));
}
#[test]
fn test_issue_659_uncache() {
assert_parses_to("UNCACHE TABLE t", |s| {
matches!(s, Statement::UNCache { .. })
});
}
#[test]
fn test_issue_659_uncache_if_exists() {
assert_parses_to("UNCACHE TABLE IF EXISTS t", |s| {
matches!(s, Statement::UNCache { .. })
});
}
#[test]
fn test_issue_660_explain() {
assert_parses_to("EXPLAIN SELECT 1", |s| {
matches!(s, Statement::Explain { .. })
});
}
// ---- parse_spark_sql: SHOW / DESCRIBE forms ----
#[test]
fn spark_show_databases() {
let s = parse_spark_sql("SHOW DATABASES").unwrap();
assert!(matches!(s, SparkStatement::ShowDatabases));
}
#[test]
fn spark_show_tables_in_db() {
let s = parse_spark_sql("SHOW TABLES IN my_db").unwrap();
match s {
SparkStatement::ShowTables { db: Some(db) } => {
assert_eq!(db.to_string(), "my_db");
}
other => panic!("expected ShowTables with db, got {other:?}"),
}
}
#[test]
fn spark_describe_detail() {
let s = parse_spark_sql("DESCRIBE DETAIL schema1.tbl1").unwrap();
match s {
SparkStatement::DescribeDetail { table } => {
assert_eq!(table.to_string(), "schema1.tbl1");
}
other => panic!("expected DescribeDetail, got {other:?}"),
}
}
#[test]
fn spark_describe_optional_col() {
let s = parse_spark_sql("DESCRIBE t age").unwrap();
match s {
SparkStatement::Describe {
table,
col: Some(c),
extended: false,
} => {
assert_eq!(table.to_string(), "t");
assert_eq!(c.value, "age");
}
other => panic!("expected Describe with col, got {other:?}"),
}
}
#[test]
fn spark_describe_table_extended() {
let s = parse_spark_sql("DESCRIBE TABLE EXTENDED t").unwrap();
match s {
SparkStatement::Describe {
table,
col: None,
extended: true,
} => {
assert_eq!(table.to_string(), "t");
}
other => panic!("expected Describe extended, got {other:?}"),
}
}
// ---- parse_select_expr: alias handling ----
#[test]
fn parse_select_expr_with_alias() {
let (e, a) = parse_select_expr("upper(Name) AS u").unwrap();
let _ = e; assert_eq!(a.unwrap().value, "u");
}
#[test]
fn parse_select_expr_without_alias() {
let (_e, a) = parse_select_expr("ltrim(rtrim(Value))").unwrap();
assert!(a.is_none());
}
// ---- parse_spark_sql: case-insensitivity and further SHOW / DESCRIBE variants ----
#[test]
fn spark_show_databases_case_insensitive() {
for sql in ["show databases", "Show Databases", "SHOW DATABASES"] {
let s = parse_spark_sql(sql).unwrap();
assert!(
matches!(s, SparkStatement::ShowDatabases),
"failed for: {sql}"
);
}
}
#[test]
fn spark_show_tables_no_db() {
let s = parse_spark_sql("SHOW TABLES").unwrap();
match s {
SparkStatement::ShowTables { db: None } => {}
other => panic!("expected ShowTables with db=None, got {other:?}"),
}
}
#[test]
fn spark_show_tables_from_db() {
let s = parse_spark_sql("SHOW TABLES FROM other_db").unwrap();
match s {
SparkStatement::ShowTables { db: Some(db) } => assert_eq!(db.to_string(), "other_db"),
other => panic!("expected ShowTables with db, got {other:?}"),
}
}
#[test]
fn spark_show_tables_in_db_case_insensitive() {
let s = parse_spark_sql("show tables in MySchema").unwrap();
match s {
SparkStatement::ShowTables { db: Some(db) } => assert_eq!(db.to_string(), "MySchema"),
other => panic!("expected ShowTables with db, got {other:?}"),
}
}
#[test]
fn spark_describe_detail_single_table() {
let s = parse_spark_sql("DESCRIBE DETAIL t").unwrap();
match s {
SparkStatement::DescribeDetail { table } => assert_eq!(table.to_string(), "t"),
other => panic!("expected DescribeDetail, got {other:?}"),
}
}
#[test]
fn spark_describe_detail_case_insensitive() {
let s = parse_spark_sql("describe detail my_table").unwrap();
match s {
SparkStatement::DescribeDetail { table } => assert_eq!(table.to_string(), "my_table"),
other => panic!("expected DescribeDetail, got {other:?}"),
}
}
#[test]
fn spark_desc_detail_synonym() {
let s = parse_spark_sql("DESC DETAIL catalog.schema.tbl").unwrap();
match s {
SparkStatement::DescribeDetail { table } => {
assert_eq!(table.to_string(), "catalog.schema.tbl")
}
other => panic!("expected DescribeDetail, got {other:?}"),
}
}
#[test]
fn spark_describe_table_only() {
let s = parse_spark_sql("DESCRIBE my_tbl").unwrap();
match s {
SparkStatement::Describe {
table,
col: None,
extended: false,
} => assert_eq!(table.to_string(), "my_tbl"),
other => panic!("expected Describe table only, got {other:?}"),
}
}
#[test]
fn spark_describe_extended_only() {
let s = parse_spark_sql("DESCRIBE EXTENDED t").unwrap();
match s {
SparkStatement::Describe {
table,
col: None,
extended: true,
} => assert_eq!(table.to_string(), "t"),
other => panic!("expected Describe extended, got {other:?}"),
}
}
#[test]
fn spark_desc_short_form() {
let s = parse_spark_sql("DESC t col_x").unwrap();
match s {
SparkStatement::Describe {
table,
col: Some(c),
extended: false,
} => {
assert_eq!(table.to_string(), "t");
assert_eq!(c.value, "col_x");
}
other => panic!("expected Describe with col, got {other:?}"),
}
}
#[test]
fn spark_describe_qualified_table_with_col() {
let s = parse_spark_sql("DESCRIBE global_temp.v id").unwrap();
match s {
SparkStatement::Describe {
table,
col: Some(c),
extended: false,
} => {
assert_eq!(table.to_string(), "global_temp.v");
assert_eq!(c.value, "id");
}
other => panic!("expected Describe qualified table + col, got {other:?}"),
}
}
// ---- parse_spark_sql: error paths and fallback to the generic parser ----
#[test]
fn spark_parse_spark_sql_empty_fails() {
let err = parse_spark_sql("").unwrap_err();
assert!(
err.0.contains("expected exactly one statement") || err.0.contains("parse error"),
"unexpected error: {}",
err.0
);
}
#[test]
fn spark_parse_spark_sql_whitespace_only_fails() {
let err = parse_spark_sql(" \t\n ").unwrap_err();
assert!(!err.0.is_empty(), "expected some error message");
}
#[test]
fn spark_parse_spark_sql_multiple_statements_fails() {
let err = parse_spark_sql("SELECT 1; SELECT 2").unwrap_err();
assert!(err.0.contains("expected exactly one statement"));
}
#[test]
fn spark_parse_spark_sql_fallback_select() {
let s = parse_spark_sql("SELECT 1 AS x").unwrap();
match s {
SparkStatement::Sqlparser(stmt) if matches!(stmt.as_ref(), Statement::Query(_)) => {}
other => panic!("expected Sqlparser(Query), got {other:?}"),
}
}
#[test]
fn spark_parse_spark_sql_fallback_create_schema() {
let s = parse_spark_sql("CREATE SCHEMA foo").unwrap();
match s {
SparkStatement::Sqlparser(stmt)
if matches!(stmt.as_ref(), Statement::CreateSchema { .. }) => {}
other => panic!("expected Sqlparser(CreateSchema), got {other:?}"),
}
}
#[test]
fn spark_parse_spark_sql_fallback_drop_table() {
let s = parse_spark_sql("DROP TABLE IF EXISTS t").unwrap();
match s {
SparkStatement::Sqlparser(stmt) if matches!(stmt.as_ref(), Statement::Drop { .. }) => {}
other => panic!("expected Sqlparser(Drop), got {other:?}"),
}
}
// ---- parse_select_expr: error paths and expression kinds ----
#[test]
fn parse_select_expr_empty_fails() {
let err = parse_select_expr("").unwrap_err();
assert!(err.0.contains("expected an expression"));
}
#[test]
fn parse_select_expr_whitespace_only_fails() {
let err = parse_select_expr(" \n\t ").unwrap_err();
assert!(err.0.contains("expected an expression"));
}
#[test]
fn parse_select_expr_literal_number() {
let (e, a) = parse_select_expr("42").unwrap();
assert!(matches!(e, SqlExpr::Value(_)));
assert!(a.is_none());
}
#[test]
fn parse_select_expr_literal_string() {
let (e, _) = parse_select_expr("'hello'").unwrap();
assert!(matches!(e, SqlExpr::Value(_)));
}
#[test]
fn parse_select_expr_literal_null() {
let (e, _) = parse_select_expr("NULL").unwrap();
assert!(matches!(e, SqlExpr::Value(_)));
}
#[test]
fn parse_select_expr_identifier() {
let (e, _) = parse_select_expr("column_name").unwrap();
assert!(matches!(e, SqlExpr::Identifier(_)));
}
#[test]
fn parse_select_expr_compound_identifier() {
let (e, _) = parse_select_expr("t.id").unwrap();
assert!(matches!(e, SqlExpr::CompoundIdentifier(_)));
}
#[test]
fn parse_select_expr_binary_op() {
let (e, _) = parse_select_expr("a + b").unwrap();
assert!(matches!(e, SqlExpr::BinaryOp { .. }));
}
#[test]
fn parse_select_expr_function_call() {
let (e, a) = parse_select_expr("COUNT(*)").unwrap();
assert!(matches!(e, SqlExpr::Function(_)));
assert!(a.is_none());
}
#[test]
fn parse_select_expr_function_with_alias() {
let (e, a) = parse_select_expr("SUM(amount) AS total").unwrap();
assert!(matches!(e, SqlExpr::Function(_)));
assert_eq!(a.as_ref().map(|i| i.value.as_str()), Some("total"));
}
#[test]
fn parse_select_expr_nested_function() {
let (_e, a) = parse_select_expr("UPPER(TRIM(name))").unwrap();
assert!(a.is_none());
}
#[test]
fn parse_select_expr_case_when() {
let (e, _) = parse_select_expr("CASE WHEN x > 0 THEN 1 ELSE 0 END").unwrap();
assert!(matches!(e, SqlExpr::Case { .. }));
}
#[test]
fn parse_select_expr_comparison() {
let (e, _) = parse_select_expr("id = 1").unwrap();
assert!(matches!(e, SqlExpr::BinaryOp { .. }));
}
#[test]
fn parse_select_expr_invalid_syntax_fails() {
let err = parse_select_expr("( unclosed").unwrap_err();
assert!(!err.0.is_empty());
}
// ---- parse_spark_sql: CREATE OR REPLACE TABLE ... USING ... AS ... ----
#[test]
fn spark_create_or_replace_table_using_delta_as_select() {
let sql = "CREATE OR REPLACE TABLE my_table USING delta AS SELECT id, name FROM source";
let s = parse_spark_sql(sql).unwrap();
match s {
SparkStatement::CreateOrReplaceTableAs {
table,
format,
query,
} => {
assert_eq!(table.to_string(), "my_table");
assert_eq!(format.to_lowercase(), "delta");
assert!(matches!(query.body.as_ref(), SetExpr::Select(_)));
}
other => panic!("expected CreateOrReplaceTableAs, got {other:?}"),
}
}
#[test]
fn spark_create_or_replace_table_qualified_name() {
let sql =
"CREATE OR REPLACE TABLE schema1.my_table USING parquet AS SELECT * FROM other_table";
let s = parse_spark_sql(sql).unwrap();
match s {
SparkStatement::CreateOrReplaceTableAs {
table,
format,
query: _,
} => {
assert_eq!(table.to_string(), "schema1.my_table");
assert_eq!(format.to_lowercase(), "parquet");
}
other => panic!("expected CreateOrReplaceTableAs, got {other:?}"),
}
}
// The statement spans several lines and contains a second `AS` (a column alias)
// inside the sub-query.
#[test]
fn spark_create_or_replace_table_multiline() {
let sql = r#"
CREATE OR REPLACE TABLE clean_events
USING delta AS
SELECT user_id, name, value, '2025-01-01' AS processed_at
FROM raw_events
"#;
let s = parse_spark_sql(sql).unwrap();
match s {
SparkStatement::CreateOrReplaceTableAs {
table,
format,
query: _,
} => {
assert_eq!(table.to_string(), "clean_events");
assert_eq!(format.to_lowercase(), "delta");
}
other => panic!("expected CreateOrReplaceTableAs, got {other:?}"),
}
}
#[test]
fn spark_create_or_replace_table_case_insensitive() {
let sql = "create or replace table T using DELTA as select 1";
let s = parse_spark_sql(sql).unwrap();
match s {
SparkStatement::CreateOrReplaceTableAs { table, format, .. } => {
assert_eq!(table.to_string(), "T");
assert_eq!(format.to_uppercase(), "DELTA");
}
other => panic!("expected CreateOrReplaceTableAs, got {other:?}"),
}
}
// A plain CREATE TABLE (no OR REPLACE) is handled by the generic parser.
#[test]
fn spark_create_table_without_or_replace_falls_through() {
let sql = "CREATE TABLE t (id INT)";
let s = parse_spark_sql(sql).unwrap();
match s {
SparkStatement::Sqlparser(stmt) => {
assert!(matches!(stmt.as_ref(), Statement::CreateTable(_)));
}
other => panic!("expected Sqlparser(CreateTable), got {other:?}"),
}
}
}