use std::collections::{HashMap, HashSet};
use std::sync::{
Arc, OnceLock,
atomic::{AtomicBool, Ordering as AtomicOrdering},
};
use crate::SqlResult;
use crate::SqlValue;
use arrow::record_batch::RecordBatch;
use llkv_executor::SelectExecution;
use llkv_expr::literal::Literal;
use llkv_plan::validation::{
ensure_known_columns_case_insensitive, ensure_non_empty, ensure_unique_case_insensitive,
};
use llkv_result::Error;
use llkv_runtime::storage_namespace::TEMPORARY_NAMESPACE_ID;
use llkv_runtime::{
AggregateExpr, AssignmentValue, ColumnAssignment, ColumnSpec, CreateIndexPlan, CreateTablePlan,
CreateTableSource, DeletePlan, ForeignKeyAction, ForeignKeySpec, IndexColumnPlan, InsertPlan,
InsertSource, OrderByPlan, OrderSortType, OrderTarget, PlanStatement, PlanValue,
RuntimeContext, RuntimeEngine, RuntimeSession, RuntimeStatementResult, SelectPlan,
SelectProjection, UpdatePlan, extract_rows_from_range,
};
use llkv_storage::pager::Pager;
use llkv_table::catalog::{IdentifierContext, IdentifierResolver};
use regex::Regex;
use simd_r_drive_entry_handle::EntryHandle;
use sqlparser::ast::{
Assignment, AssignmentTarget, BeginTransactionKind, BinaryOperator, ColumnOption,
ColumnOptionDef, ConstraintCharacteristics, DataType as SqlDataType, Delete, ExceptionWhen,
Expr as SqlExpr, FromTable, FunctionArg, FunctionArgExpr, FunctionArguments, GroupByExpr,
Ident, LimitClause, NullsDistinctOption, ObjectName, ObjectNamePart, ObjectType, OrderBy,
OrderByKind, Query, ReferentialAction, SchemaName, Select, SelectItem,
SelectItemQualifiedWildcardKind, Set, SetExpr, SqlOption, Statement, TableConstraint,
TableFactor, TableObject, TableWithJoins, TransactionMode, TransactionModifier, UnaryOperator,
UpdateTableFromKind, Value, ValueWithSpan,
};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
/// SQL front end: parses SQL text and executes the resulting statements on a
/// [`RuntimeEngine`] backed by a pager `P`.
pub struct SqlEngine<P>
where
P: Pager<Blob = EntryHandle> + Send + Sync,
{
/// Runtime engine that planned statements are handed to.
engine: RuntimeEngine<P>,
/// Flag accessed with relaxed atomics. NOTE(review): presumably the default
/// NULLS FIRST sort behaviour; the code that consumes it is outside this view.
default_nulls_first: AtomicBool,
}
// Message text for a table dropped by a concurrent transaction.
// NOTE(review): the code that uses this constant is outside the visible chunk.
const DROPPED_TABLE_TRANSACTION_ERR: &str = "another transaction has dropped this table";
impl<P> Clone for SqlEngine<P>
where
    P: Pager<Blob = EntryHandle> + Send + Sync + 'static,
{
    /// Clones the wrapped engine and snapshots the current `default_nulls_first`
    /// flag into a fresh atomic. A warning is logged on every call because the
    /// clone does not share the original's session (per the log message).
    fn clone(&self) -> Self {
        tracing::warn!(
            "[SQL_ENGINE] SqlEngine::clone() called - will create new Engine with new session!"
        );
        let engine = self.engine.clone();
        let nulls_first = self.default_nulls_first.load(AtomicOrdering::Relaxed);
        Self {
            engine,
            default_nulls_first: AtomicBool::new(nulls_first),
        }
    }
}
#[allow(dead_code)]
impl<P> SqlEngine<P>
where
P: Pager<Blob = EntryHandle> + Send + Sync + 'static,
{
fn map_table_error(table_name: &str, err: Error) -> Error {
match err {
Error::NotFound => Self::table_not_found_error(table_name),
Error::InvalidArgumentError(msg) if msg.contains("unknown table") => {
Self::table_not_found_error(table_name)
}
other => other,
}
}
fn table_not_found_error(table_name: &str) -> Error {
Error::CatalogError(format!(
"Catalog Error: Table '{table_name}' does not exist"
))
}
fn is_table_missing_error(err: &Error) -> bool {
match err {
Error::NotFound => true,
Error::CatalogError(msg) => {
msg.contains("Catalog Error: Table") || msg.contains("unknown table")
}
Error::InvalidArgumentError(msg) => {
msg.contains("Catalog Error: Table") || msg.contains("unknown table")
}
_ => false,
}
}
/// Runs a planned statement on the runtime engine. When the plan names a
/// table, failures are passed through [`Self::map_table_error`] so missing
/// tables are reported uniformly.
fn execute_plan_statement(
    &self,
    statement: PlanStatement,
) -> SqlResult<RuntimeStatementResult<P>> {
    // Capture the table name up front: the statement is moved into the engine.
    let table = llkv_runtime::statement_table_name(&statement).map(str::to_string);
    match self.engine.execute_statement(statement) {
        Ok(result) => Ok(result),
        Err(err) => Err(match table {
            Some(table_name) => Self::map_table_error(&table_name, err),
            None => err,
        }),
    }
}
/// Creates an engine on top of `pager`, with `default_nulls_first` initially
/// `false`.
pub fn new(pager: Arc<P>) -> Self {
    Self {
        engine: RuntimeEngine::new(pager),
        default_nulls_first: AtomicBool::new(false),
    }
}
fn preprocess_exclude_syntax(sql: &str) -> String {
static EXCLUDE_REGEX: OnceLock<Regex> = OnceLock::new();
let re = EXCLUDE_REGEX.get_or_init(|| {
Regex::new(
r"(?i)EXCLUDE\s*\(\s*([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)+)\s*\)",
)
.expect("valid EXCLUDE qualifier regex")
});
re.replace_all(sql, |caps: ®ex::Captures| {
let qualified_name = &caps[1];
format!("EXCLUDE (\"{}\")", qualified_name)
})
.to_string()
}
/// Returns the runtime context handle exposed by the wrapped engine.
pub(crate) fn context_arc(&self) -> Arc<RuntimeContext<P>> {
    let engine = &self.engine;
    engine.context()
}
/// Builds an engine from an existing runtime `context`, seeding the
/// `default_nulls_first` flag with the supplied value.
pub fn with_context(context: Arc<RuntimeContext<P>>, default_nulls_first: bool) -> Self {
    let engine = RuntimeEngine::from_context(context);
    let nulls_first_flag = AtomicBool::new(default_nulls_first);
    Self {
        engine,
        default_nulls_first: nulls_first_flag,
    }
}
/// Test-only accessor for the current `default_nulls_first` flag.
#[cfg(test)]
fn default_nulls_first_for_tests(&self) -> bool {
    let flag = &self.default_nulls_first;
    flag.load(AtomicOrdering::Relaxed)
}
/// Reports whether the engine's session currently has an active transaction.
fn has_active_transaction(&self) -> bool {
    let session = self.engine.session();
    session.has_active_transaction()
}
/// Borrows the runtime session owned by the wrapped engine.
pub fn session(&self) -> &RuntimeSession<P> {
    let engine = &self.engine;
    engine.session()
}
/// Parses `sql` (after EXCLUDE preprocessing) with the generic dialect and
/// executes every statement in order.
///
/// # Errors
/// Returns `Error::InvalidArgumentError` when the text fails to parse, or the
/// first error produced by a statement; later statements are not run.
pub fn execute(&self, sql: &str) -> SqlResult<Vec<RuntimeStatementResult<P>>> {
    tracing::trace!("DEBUG SQL execute: {}", sql);
    let processed_sql = Self::preprocess_exclude_syntax(sql);
    let dialect = GenericDialect {};
    let statements = Parser::parse_sql(&dialect, &processed_sql)
        .map_err(|err| Error::InvalidArgumentError(format!("failed to parse SQL: {err}")))?;
    tracing::trace!("DEBUG SQL execute: parsed {} statements", statements.len());
    let mut results = Vec::with_capacity(statements.len());
    // Consume the parsed statements so each AST is moved into the executor
    // rather than deep-cloned.
    for (i, statement) in statements.into_iter().enumerate() {
        tracing::trace!("DEBUG SQL execute: processing statement {}", i);
        results.push(self.execute_statement(statement)?);
        tracing::trace!("DEBUG SQL execute: statement {} completed", i);
    }
    tracing::trace!("DEBUG SQL execute completed successfully");
    Ok(results)
}
pub fn sql(&self, sql: &str) -> SqlResult<Vec<RecordBatch>> {
let mut results = self.execute(sql)?;
if results.len() != 1 {
return Err(Error::InvalidArgumentError(
"SqlEngine::sql expects exactly one SQL statement".into(),
));
}
match results.pop().expect("checked length above") {
RuntimeStatementResult::Select { execution, .. } => execution.collect(),
other => Err(Error::InvalidArgumentError(format!(
"SqlEngine::sql requires a SELECT statement, got {other:?}",
))),
}
}
/// Dispatches one parsed statement: transaction control (`StartTransaction`,
/// `Commit`, `Rollback`) goes to the dedicated handlers, everything else to
/// [`Self::execute_statement_non_transactional`].
fn execute_statement(&self, statement: Statement) -> SqlResult<RuntimeStatementResult<P>> {
// Trace a compact label for the common statement kinds; anything else is
// dumped through its Debug representation.
tracing::trace!(
"DEBUG SQL execute_statement: {:?}",
match &statement {
Statement::Insert(insert) =>
format!("Insert(table={:?})", Self::table_name_from_insert(insert)),
Statement::Query(_) => "Query".to_string(),
Statement::StartTransaction { .. } => "StartTransaction".to_string(),
Statement::Commit { .. } => "Commit".to_string(),
Statement::Rollback { .. } => "Rollback".to_string(),
Statement::CreateTable(_) => "CreateTable".to_string(),
Statement::Update { .. } => "Update".to_string(),
Statement::Delete(_) => "Delete".to_string(),
other => format!("Other({:?})", other),
}
);
match statement {
// Every field is forwarded; the handlers live outside this chunk.
Statement::StartTransaction {
modes,
begin,
transaction,
modifier,
statements,
exception,
has_end_keyword,
} => self.handle_start_transaction(
modes,
begin,
transaction,
modifier,
statements,
exception,
has_end_keyword,
),
Statement::Commit {
chain,
end,
modifier,
} => self.handle_commit(chain, end, modifier),
Statement::Rollback { chain, savepoint } => self.handle_rollback(chain, savepoint),
other => self.execute_statement_non_transactional(other),
}
}
/// Routes a non-transaction-control statement to its handler.
///
/// Supported kinds: `CREATE TABLE`, `CREATE INDEX`, `CREATE SCHEMA`, `INSERT`,
/// queries, `UPDATE`, `DELETE`, `DROP`, `SET` and `PRAGMA`.
///
/// # Errors
/// Returns `Error::InvalidArgumentError` for any other statement kind, or
/// whatever error the selected handler produces.
fn execute_statement_non_transactional(
    &self,
    statement: Statement,
) -> SqlResult<RuntimeStatementResult<P>> {
    tracing::trace!("DEBUG SQL execute_statement_non_transactional called");
    match statement {
        Statement::CreateTable(stmt) => {
            tracing::trace!("DEBUG SQL execute_statement_non_transactional: CreateTable");
            self.handle_create_table(stmt)
        }
        Statement::CreateIndex(stmt) => {
            tracing::trace!("DEBUG SQL execute_statement_non_transactional: CreateIndex");
            self.handle_create_index(stmt)
        }
        Statement::CreateSchema {
            schema_name,
            if_not_exists,
            with,
            options,
            default_collate_spec,
            clone,
        } => {
            tracing::trace!("DEBUG SQL execute_statement_non_transactional: CreateSchema");
            self.handle_create_schema(
                schema_name,
                if_not_exists,
                with,
                options,
                default_collate_spec,
                clone,
            )
        }
        Statement::Insert(stmt) => {
            // The table name is only needed for the trace line, so it is
            // computed inside the macro and costs nothing when tracing is off.
            tracing::trace!(
                "DEBUG SQL execute_statement_non_transactional: Insert(table={})",
                Self::table_name_from_insert(&stmt).unwrap_or_else(|_| "unknown".to_string())
            );
            self.handle_insert(stmt)
        }
        Statement::Query(query) => {
            tracing::trace!("DEBUG SQL execute_statement_non_transactional: Query");
            self.handle_query(*query)
        }
        Statement::Update {
            table,
            assignments,
            from,
            selection,
            returning,
            ..
        } => {
            tracing::trace!("DEBUG SQL execute_statement_non_transactional: Update");
            self.handle_update(table, assignments, from, selection, returning)
        }
        Statement::Delete(delete) => {
            tracing::trace!("DEBUG SQL execute_statement_non_transactional: Delete");
            self.handle_delete(delete)
        }
        Statement::Drop {
            object_type,
            if_exists,
            names,
            cascade,
            restrict,
            purge,
            temporary,
            ..
        } => {
            tracing::trace!("DEBUG SQL execute_statement_non_transactional: Drop");
            self.handle_drop(
                object_type,
                if_exists,
                names,
                cascade,
                restrict,
                purge,
                temporary,
            )
        }
        Statement::Set(set_stmt) => {
            tracing::trace!("DEBUG SQL execute_statement_non_transactional: Set");
            self.handle_set(set_stmt)
        }
        Statement::Pragma { name, value, is_eq } => {
            tracing::trace!("DEBUG SQL execute_statement_non_transactional: Pragma");
            self.handle_pragma(name, value, is_eq)
        }
        other => {
            tracing::trace!(
                "DEBUG SQL execute_statement_non_transactional: Other({:?})",
                other
            );
            Err(Error::InvalidArgumentError(format!(
                "unsupported SQL statement: {other:?}"
            )))
        }
    }
}
fn table_name_from_insert(insert: &sqlparser::ast::Insert) -> SqlResult<String> {
match &insert.table {
TableObject::TableName(name) => Self::object_name_to_string(name),
_ => Err(Error::InvalidArgumentError(
"INSERT requires a plain table name".into(),
)),
}
}
fn table_name_from_update(table: &TableWithJoins) -> SqlResult<Option<String>> {
if !table.joins.is_empty() {
return Err(Error::InvalidArgumentError(
"UPDATE with JOIN targets is not supported yet".into(),
));
}
Self::table_with_joins_name(table)
}
fn table_name_from_delete(delete: &Delete) -> SqlResult<Option<String>> {
if !delete.tables.is_empty() {
return Err(Error::InvalidArgumentError(
"multi-table DELETE is not supported yet".into(),
));
}
let from_tables = match &delete.from {
FromTable::WithFromKeyword(tables) | FromTable::WithoutKeyword(tables) => tables,
};
if from_tables.is_empty() {
return Ok(None);
}
if from_tables.len() != 1 {
return Err(Error::InvalidArgumentError(
"DELETE over multiple tables is not supported yet".into(),
));
}
Self::table_with_joins_name(&from_tables[0])
}
/// Returns the display form of `name` (the first element of the pair produced
/// by `canonical_object_name`).
fn object_name_to_string(name: &ObjectName) -> SqlResult<String> {
    let names = canonical_object_name(name)?;
    Ok(names.0)
}
/// Returns the table name for a plain table object, or `None` for a table
/// function.
#[allow(dead_code)]
fn table_object_to_name(table: &TableObject) -> SqlResult<Option<String>> {
    match table {
        TableObject::TableName(name) => Self::object_name_to_string(name).map(Some),
        TableObject::TableFunction(_) => Ok(None),
    }
}
/// Returns the name of the base relation when it is a plain table, `None`
/// for any other kind of table factor.
fn table_with_joins_name(table: &TableWithJoins) -> SqlResult<Option<String>> {
    if let TableFactor::Table { name, .. } = &table.relation {
        let display = Self::object_name_to_string(name)?;
        Ok(Some(display))
    } else {
        Ok(None)
    }
}
fn tables_in_query(query: &Query) -> SqlResult<Vec<String>> {
let mut tables = Vec::new();
if let sqlparser::ast::SetExpr::Select(select) = query.body.as_ref() {
for table in &select.from {
if let TableFactor::Table { name, .. } = &table.relation {
tables.push(Self::object_name_to_string(name)?);
}
}
}
Ok(tables)
}
/// Collects the lower-cased column names of a table.
///
/// Returns an empty set when the table cannot be found, so callers can skip
/// column validation. A table marked as dropped is reported as not found.
fn collect_known_columns(
    &self,
    display_name: &str,
    canonical_name: &str,
) -> SqlResult<HashSet<String>> {
    let context = self.engine.context();
    if context.is_table_marked_dropped(canonical_name) {
        return Err(Self::table_not_found_error(display_name));
    }
    let specs = match context.table_column_specs(display_name) {
        Ok(specs) => specs,
        // A missing table is not an error here: report "no known columns".
        Err(err) if Self::is_table_missing_error(&err) => return Ok(HashSet::new()),
        Err(err) => return Err(Self::map_table_error(display_name, err)),
    };
    let lowered = specs
        .into_iter()
        .map(|spec| spec.name.to_ascii_lowercase())
        .collect();
    Ok(lowered)
}
/// Checks whether the context has `table_name` (compared in ASCII lower case)
/// marked as dropped. The visible code path never returns `Err`.
fn is_table_marked_dropped(&self, table_name: &str) -> SqlResult<bool> {
    let context = self.engine.context();
    let dropped = context.is_table_marked_dropped(&table_name.to_ascii_lowercase());
    Ok(dropped)
}
/// Handles `CREATE TABLE`.
///
/// Validates the statement, resolves the (optionally schema-qualified or
/// temporary) table name, delegates `CREATE TABLE ... AS` forms, and otherwise
/// builds a `CreateTablePlan` from the column definitions and table-level
/// constraints before executing it.
///
/// # Errors
/// Returns an error for unsupported or inconsistent definitions (unknown
/// schema, empty name, no columns, duplicate/unknown columns, unsupported
/// constraint forms) or when plan execution fails.
fn handle_create_table(
&self,
mut stmt: sqlparser::ast::CreateTable,
) -> SqlResult<RuntimeStatementResult<P>> {
validate_create_table_common(&stmt)?;
let (mut schema_name, table_name) = parse_schema_qualified_name(&stmt.name)?;
// Temporary tables may not name a schema and are placed in the temporary namespace.
let namespace = if stmt.temporary {
if schema_name.is_some() {
return Err(Error::InvalidArgumentError(
"temporary tables cannot specify an explicit schema".into(),
));
}
schema_name = None;
Some(TEMPORARY_NAMESPACE_ID.to_string())
} else {
None
};
// An explicit schema must already be registered in the catalog.
if let Some(ref schema) = schema_name {
let catalog = self.engine.context().table_catalog();
if !catalog.schema_exists(schema) {
return Err(Error::CatalogError(format!(
"Schema '{}' does not exist",
schema
)));
}
}
let display_name = match &schema_name {
Some(schema) => format!("{}.{}", schema, table_name),
None => table_name.clone(),
};
let canonical_name = display_name.to_ascii_lowercase();
tracing::trace!(
"\n=== HANDLE_CREATE_TABLE: table='{}' columns={} ===",
display_name,
stmt.columns.len()
);
if display_name.is_empty() {
return Err(Error::InvalidArgumentError(
"table name must not be empty".into(),
));
}
// CREATE TABLE ... AS <query>: try the range() special case first, then
// fall back to the general CTAS handler.
if let Some(query) = stmt.query.take() {
validate_create_table_as(&stmt)?;
if let Some(result) = self.try_handle_range_ctas(
&display_name,
&canonical_name,
&query,
stmt.if_not_exists,
stmt.or_replace,
namespace.clone(),
)? {
return Ok(result);
}
return self.handle_create_table_as(
display_name,
canonical_name,
*query,
stmt.if_not_exists,
stmt.or_replace,
namespace.clone(),
);
}
if stmt.columns.is_empty() {
return Err(Error::InvalidArgumentError(
"CREATE TABLE requires at least one column".into(),
));
}
validate_create_table_definition(&stmt)?;
// Move the column and constraint lists out of the statement so they can be consumed.
let column_defs_ast = std::mem::take(&mut stmt.columns);
let constraints = std::mem::take(&mut stmt.constraints);
let column_names: Vec<String> = column_defs_ast
.iter()
.map(|column_def| column_def.name.value.clone())
.collect();
ensure_unique_case_insensitive(column_names.iter().map(|name| name.as_str()), |dup| {
format!(
"duplicate column name '{}' in table '{}'",
dup, display_name
)
})?;
// Lower-cased names used for case-insensitive membership checks below.
let column_names_lower: HashSet<String> = column_names
.iter()
.map(|name| name.to_ascii_lowercase())
.collect();
let mut columns: Vec<ColumnSpec> = Vec::with_capacity(column_defs_ast.len());
let mut primary_key_columns: HashSet<String> = HashSet::new();
let mut foreign_keys: Vec<ForeignKeySpec> = Vec::new();
// Pass 1: translate each column definition and its inline options.
for column_def in column_defs_ast {
// Nullable unless an explicit NOT NULL option is present.
let is_nullable = column_def
.options
.iter()
.all(|opt| !matches!(opt.option, ColumnOption::NotNull));
let is_primary_key = column_def.options.iter().any(|opt| {
matches!(
opt.option,
ColumnOption::Unique {
is_primary: true,
characteristics: _
}
)
});
// Matches any Unique option, so a PRIMARY KEY column also counts as unique.
let has_unique_constraint = column_def
.options
.iter()
.any(|opt| matches!(opt.option, ColumnOption::Unique { .. }));
// Only the first CHECK option on the column is considered.
let check_expr = column_def.options.iter().find_map(|opt| {
if let ColumnOption::Check(expr) = &opt.option {
Some(expr)
} else {
None
}
});
if let Some(check_expr) = check_expr {
let all_col_refs: Vec<&str> = column_names.iter().map(|s| s.as_str()).collect();
validate_check_constraint(check_expr, &display_name, &all_col_refs)?;
}
let check_expr_str = check_expr.map(|e| e.to_string());
// Inline (column-level) FOREIGN KEY options reference exactly this column.
for opt in &column_def.options {
if let ColumnOption::ForeignKey {
foreign_table,
referred_columns,
on_delete,
on_update,
characteristics,
} = &opt.option
{
let spec = self.build_foreign_key_spec(
&display_name,
&canonical_name,
vec![column_def.name.value.clone()],
foreign_table,
referred_columns,
*on_delete,
*on_update,
characteristics,
&column_names_lower,
None,
)?;
foreign_keys.push(spec);
}
}
tracing::trace!(
"DEBUG CREATE TABLE column '{}' is_primary_key={} has_unique={} check_expr={:?}",
column_def.name.value,
is_primary_key,
has_unique_constraint,
check_expr_str
);
let mut column = ColumnSpec::new(
column_def.name.value.clone(),
arrow_type_from_sql(&column_def.data_type)?,
is_nullable,
);
tracing::trace!(
"DEBUG ColumnSpec after new(): primary_key={} unique={}",
column.primary_key,
column.unique
);
column = column
.with_primary_key(is_primary_key)
.with_unique(has_unique_constraint)
.with_check(check_expr_str);
// Primary key columns are forced non-nullable regardless of the declaration.
if is_primary_key {
column.nullable = false;
primary_key_columns.insert(column.name.to_ascii_lowercase());
}
tracing::trace!(
"DEBUG ColumnSpec after with_primary_key({})/with_unique({}): primary_key={} unique={} check_expr={:?}",
is_primary_key,
has_unique_constraint,
column.primary_key,
column.unique,
column.check_expr
);
columns.push(column);
}
// Pass 2: apply table-level constraints onto the column specs built above.
if !constraints.is_empty() {
// Lower-cased column name -> index into `columns`.
let mut column_lookup: HashMap<String, usize> = HashMap::with_capacity(columns.len());
for (idx, column) in columns.iter().enumerate() {
column_lookup.insert(column.name.to_ascii_lowercase(), idx);
}
for constraint in constraints {
match constraint {
TableConstraint::PrimaryKey {
columns: constraint_columns,
..
} => {
// Rejected when any primary key column was already recorded
// (inline or by an earlier table-level constraint).
if !primary_key_columns.is_empty() {
return Err(Error::InvalidArgumentError(
"multiple PRIMARY KEY constraints are not supported".into(),
));
}
ensure_non_empty(&constraint_columns, || {
"PRIMARY KEY requires at least one column".into()
})?;
let mut pk_column_names: Vec<String> =
Vec::with_capacity(constraint_columns.len());
for index_col in &constraint_columns {
let column_ident = extract_index_column_name(
index_col,
"PRIMARY KEY",
false, false, )?;
pk_column_names.push(column_ident);
}
ensure_unique_case_insensitive(
pk_column_names.iter().map(|name| name.as_str()),
|dup| format!("duplicate column '{}' in PRIMARY KEY constraint", dup),
)?;
ensure_known_columns_case_insensitive(
pk_column_names.iter().map(|name| name.as_str()),
&column_names_lower,
|unknown| {
format!("unknown column '{}' in PRIMARY KEY constraint", unknown)
},
)?;
// Mark every listed column as primary key, unique and non-nullable.
for column_ident in pk_column_names {
let normalized = column_ident.to_ascii_lowercase();
let idx = column_lookup.get(&normalized).copied().ok_or_else(|| {
Error::InvalidArgumentError(format!(
"unknown column '{}' in PRIMARY KEY constraint",
column_ident
))
})?;
let column = columns.get_mut(idx).expect("column index valid");
column.primary_key = true;
column.unique = true;
column.nullable = false;
primary_key_columns.insert(normalized);
}
}
TableConstraint::Unique {
columns: constraint_columns,
index_type,
index_options,
characteristics,
nulls_distinct,
..
} => {
// Reject every UNIQUE variation that is not a plain column list.
if !matches!(nulls_distinct, NullsDistinctOption::None) {
return Err(Error::InvalidArgumentError(
"UNIQUE constraints with NULLS DISTINCT/NOT DISTINCT are not supported yet".into(),
));
}
if index_type.is_some() {
return Err(Error::InvalidArgumentError(
"UNIQUE constraints with index types are not supported yet".into(),
));
}
if !index_options.is_empty() {
return Err(Error::InvalidArgumentError(
"UNIQUE constraints with index options are not supported yet"
.into(),
));
}
if characteristics.is_some() {
return Err(Error::InvalidArgumentError(
"UNIQUE constraint characteristics are not supported yet".into(),
));
}
ensure_non_empty(&constraint_columns, || {
"UNIQUE constraint requires at least one column".into()
})?;
let mut unique_column_names: Vec<String> =
Vec::with_capacity(constraint_columns.len());
for index_column in &constraint_columns {
let column_ident = extract_index_column_name(
index_column,
"UNIQUE constraint",
false, false, )?;
unique_column_names.push(column_ident);
}
// Only single-column UNIQUE constraints are accepted.
if unique_column_names.len() > 1 {
return Err(Error::InvalidArgumentError(
"multi-column UNIQUE constraints are not supported yet".into(),
));
}
ensure_unique_case_insensitive(
unique_column_names.iter().map(|name| name.as_str()),
|dup| format!("duplicate column '{}' in UNIQUE constraint", dup),
)?;
ensure_known_columns_case_insensitive(
unique_column_names.iter().map(|name| name.as_str()),
&column_names_lower,
|unknown| format!("unknown column '{}' in UNIQUE constraint", unknown),
)?;
let column_ident = unique_column_names
.into_iter()
.next()
.expect("unique constraint checked for emptiness");
let normalized = column_ident.to_ascii_lowercase();
let idx = column_lookup.get(&normalized).copied().ok_or_else(|| {
Error::InvalidArgumentError(format!(
"unknown column '{}' in UNIQUE constraint",
column_ident
))
})?;
let column = columns
.get_mut(idx)
.expect("column index from lookup must be valid");
column.unique = true;
}
TableConstraint::ForeignKey {
name,
index_name,
columns: fk_columns,
foreign_table,
referred_columns,
on_delete,
on_update,
characteristics,
..
} => {
if index_name.is_some() {
return Err(Error::InvalidArgumentError(
"FOREIGN KEY index clauses are not supported yet".into(),
));
}
let referencing_columns: Vec<String> =
fk_columns.into_iter().map(|ident| ident.value).collect();
let spec = self.build_foreign_key_spec(
&display_name,
&canonical_name,
referencing_columns,
&foreign_table,
&referred_columns,
on_delete,
on_update,
&characteristics,
&column_names_lower,
name.map(|ident| ident.value),
)?;
foreign_keys.push(spec);
}
unsupported => {
return Err(Error::InvalidArgumentError(format!(
"table-level constraint {:?} is not supported",
unsupported
)));
}
}
}
}
let plan = CreateTablePlan {
name: display_name,
if_not_exists: stmt.if_not_exists,
or_replace: stmt.or_replace,
columns,
source: None,
namespace,
foreign_keys,
};
self.execute_plan_statement(PlanStatement::CreateTable(plan))
}
/// Handles `CREATE INDEX`.
///
/// Rejects every clause that is not supported, resolves the target table,
/// validates the indexed columns, and executes a `CreateIndexPlan`.
///
/// # Errors
/// Returns an error for unsupported clauses, an unknown schema, a dropped
/// table, duplicate or unknown columns, a non-unique multi-column index, or
/// when plan execution fails.
fn handle_create_index(
&self,
stmt: sqlparser::ast::CreateIndex,
) -> SqlResult<RuntimeStatementResult<P>> {
let sqlparser::ast::CreateIndex {
name,
table_name,
using,
columns,
unique,
concurrently,
if_not_exists,
include,
nulls_distinct,
with,
predicate,
index_options,
alter_options,
..
} = stmt;
// Reject each unsupported clause with its own message.
if concurrently {
return Err(Error::InvalidArgumentError(
"CREATE INDEX CONCURRENTLY is not supported".into(),
));
}
if using.is_some() {
return Err(Error::InvalidArgumentError(
"CREATE INDEX USING clauses are not supported".into(),
));
}
if !include.is_empty() {
return Err(Error::InvalidArgumentError(
"CREATE INDEX INCLUDE columns are not supported".into(),
));
}
if nulls_distinct.is_some() {
return Err(Error::InvalidArgumentError(
"CREATE INDEX NULLS DISTINCT is not supported".into(),
));
}
if !with.is_empty() {
return Err(Error::InvalidArgumentError(
"CREATE INDEX WITH options are not supported".into(),
));
}
if predicate.is_some() {
return Err(Error::InvalidArgumentError(
"partial CREATE INDEX is not supported".into(),
));
}
if !index_options.is_empty() {
return Err(Error::InvalidArgumentError(
"CREATE INDEX options are not supported".into(),
));
}
if !alter_options.is_empty() {
return Err(Error::InvalidArgumentError(
"CREATE INDEX ALTER options are not supported".into(),
));
}
if columns.is_empty() {
return Err(Error::InvalidArgumentError(
"CREATE INDEX requires at least one column".into(),
));
}
let (schema_name, base_table_name) = parse_schema_qualified_name(&table_name)?;
// An explicit schema must already be registered in the catalog.
if let Some(ref schema) = schema_name {
let catalog = self.engine.context().table_catalog();
if !catalog.schema_exists(schema) {
return Err(Error::CatalogError(format!(
"Schema '{}' does not exist",
schema
)));
}
}
let display_table_name = schema_name
.as_ref()
.map(|schema| format!("{}.{}", schema, base_table_name))
.unwrap_or_else(|| base_table_name.clone());
let canonical_table_name = display_table_name.to_ascii_lowercase();
// An empty set means the table's columns could not be resolved (table
// missing); column existence is then not checked here.
let known_columns =
self.collect_known_columns(&display_table_name, &canonical_table_name)?;
let enforce_known_columns = !known_columns.is_empty();
let index_name = match name {
Some(name_obj) => Some(Self::object_name_to_string(&name_obj)?),
None => None,
};
let mut index_columns: Vec<IndexColumnPlan> = Vec::with_capacity(columns.len());
// Lower-cased names seen so far, for case-insensitive duplicate detection.
let mut seen_column_names: HashSet<String> = HashSet::new();
for item in columns {
if item.column.with_fill.is_some() {
return Err(Error::InvalidArgumentError(
"CREATE INDEX column WITH FILL is not supported".into(),
));
}
let column_name = extract_index_column_name(
&item,
"CREATE INDEX",
true, true, )?;
let order_expr = &item.column;
// Defaults when the column gives no explicit ordering: ascending, NULLS LAST.
let ascending = order_expr.options.asc.unwrap_or(true);
let nulls_first = order_expr.options.nulls_first.unwrap_or(false);
let normalized = column_name.to_ascii_lowercase();
if !seen_column_names.insert(normalized.clone()) {
return Err(Error::InvalidArgumentError(format!(
"duplicate column '{}' in CREATE INDEX",
column_name
)));
}
if enforce_known_columns && !known_columns.contains(&normalized) {
return Err(Error::InvalidArgumentError(format!(
"column '{}' does not exist in table '{}'",
column_name, display_table_name
)));
}
let column_plan = IndexColumnPlan::new(column_name).with_sort(ascending, nulls_first);
index_columns.push(column_plan);
}
// Multi-column indexes are accepted only when declared UNIQUE.
if index_columns.len() > 1 && !unique {
return Err(Error::InvalidArgumentError(
"multi-column CREATE INDEX currently supports UNIQUE indexes only".into(),
));
}
let plan = CreateIndexPlan::new(display_table_name)
.with_name(index_name)
.with_unique(unique)
.with_if_not_exists(if_not_exists)
.with_columns(index_columns);
self.execute_plan_statement(PlanStatement::CreateIndex(plan))
}
fn map_referential_action(
action: Option<ReferentialAction>,
kind: &str,
) -> SqlResult<ForeignKeyAction> {
match action {
None | Some(ReferentialAction::NoAction) => Ok(ForeignKeyAction::NoAction),
Some(ReferentialAction::Restrict) => Ok(ForeignKeyAction::Restrict),
Some(other) => Err(Error::InvalidArgumentError(format!(
"FOREIGN KEY ON {kind} {:?} is not supported yet",
other
))),
}
}
/// Validates a FOREIGN KEY definition and converts it into a `ForeignKeySpec`.
///
/// `known_columns_lower` holds the lower-cased column names of the referencing
/// table. Referenced columns are checked against that same set for
/// self-references, or against the referenced table's columns when those can
/// be resolved. `_referencing_display` is currently unused.
///
/// # Errors
/// Returns an error for constraint characteristics, empty/duplicate/unknown
/// referencing columns, duplicate or unknown referenced columns, mismatched
/// column counts, or unsupported ON DELETE / ON UPDATE actions.
#[allow(clippy::too_many_arguments)]
fn build_foreign_key_spec(
&self,
_referencing_display: &str,
referencing_canonical: &str,
referencing_columns: Vec<String>,
foreign_table: &ObjectName,
referenced_columns: &[Ident],
on_delete: Option<ReferentialAction>,
on_update: Option<ReferentialAction>,
characteristics: &Option<ConstraintCharacteristics>,
known_columns_lower: &HashSet<String>,
name: Option<String>,
) -> SqlResult<ForeignKeySpec> {
if characteristics.is_some() {
return Err(Error::InvalidArgumentError(
"FOREIGN KEY constraint characteristics are not supported yet".into(),
));
}
ensure_non_empty(&referencing_columns, || {
"FOREIGN KEY constraint requires at least one referencing column".into()
})?;
ensure_unique_case_insensitive(
referencing_columns.iter().map(|name| name.as_str()),
|dup| format!("duplicate column '{}' in FOREIGN KEY constraint", dup),
)?;
ensure_known_columns_case_insensitive(
referencing_columns.iter().map(|name| name.as_str()),
known_columns_lower,
|unknown| format!("unknown column '{}' in FOREIGN KEY constraint", unknown),
)?;
let referenced_columns_vec: Vec<String> = referenced_columns
.iter()
.map(|ident| ident.value.clone())
.collect();
ensure_unique_case_insensitive(
referenced_columns_vec.iter().map(|name| name.as_str()),
|dup| {
format!(
"duplicate referenced column '{}' in FOREIGN KEY constraint",
dup
)
},
)?;
// An empty referenced-column list is allowed; otherwise counts must match.
if !referenced_columns_vec.is_empty()
&& referenced_columns_vec.len() != referencing_columns.len()
{
return Err(Error::InvalidArgumentError(
"FOREIGN KEY referencing and referenced column counts must match".into(),
));
}
let (referenced_display, referenced_canonical) = canonical_object_name(foreign_table)?;
if referenced_canonical == referencing_canonical {
// Self-reference: validate against the columns of the table being defined.
ensure_known_columns_case_insensitive(
referenced_columns_vec.iter().map(|name| name.as_str()),
known_columns_lower,
|unknown| {
format!(
"unknown referenced column '{}' in FOREIGN KEY constraint",
unknown
)
},
)?;
} else {
// Other table: validate only when its columns could be resolved
// (an empty set means the table was not found).
let known_columns =
self.collect_known_columns(&referenced_display, &referenced_canonical)?;
if !known_columns.is_empty() {
ensure_known_columns_case_insensitive(
referenced_columns_vec.iter().map(|name| name.as_str()),
&known_columns,
|unknown| {
format!(
"unknown referenced column '{}' in FOREIGN KEY constraint",
unknown
)
},
)?;
}
}
let on_delete_action = Self::map_referential_action(on_delete, "DELETE")?;
let on_update_action = Self::map_referential_action(on_update, "UPDATE")?;
Ok(ForeignKeySpec {
name,
columns: referencing_columns,
referenced_table: referenced_display,
referenced_columns: referenced_columns_vec,
on_delete: on_delete_action,
on_update: on_update_action,
})
}
/// Handles `CREATE SCHEMA [IF NOT EXISTS] <name>`.
///
/// Only the plain form is supported: CLONE, WITH options, generic options,
/// DEFAULT COLLATE and authorization clauses are rejected. The schema is
/// registered in the table catalog under its canonical name.
///
/// # Errors
/// Returns `Error::InvalidArgumentError` for unsupported clauses or an empty
/// name, and `Error::CatalogError` when registration fails.
fn handle_create_schema(
    &self,
    schema_name: SchemaName,
    if_not_exists: bool,
    with: Option<Vec<SqlOption>>,
    options: Option<Vec<SqlOption>>,
    default_collate_spec: Option<SqlExpr>,
    clone: Option<ObjectName>,
) -> SqlResult<RuntimeStatementResult<P>> {
    if clone.is_some() {
        return Err(Error::InvalidArgumentError(
            "CREATE SCHEMA ... CLONE is not supported".into(),
        ));
    }
    if with.as_ref().is_some_and(|opts| !opts.is_empty()) {
        return Err(Error::InvalidArgumentError(
            "CREATE SCHEMA ... WITH options are not supported".into(),
        ));
    }
    if options.as_ref().is_some_and(|opts| !opts.is_empty()) {
        return Err(Error::InvalidArgumentError(
            "CREATE SCHEMA options are not supported".into(),
        ));
    }
    if default_collate_spec.is_some() {
        return Err(Error::InvalidArgumentError(
            "CREATE SCHEMA DEFAULT COLLATE is not supported".into(),
        ));
    }
    let schema_name = match schema_name {
        SchemaName::Simple(name) => name,
        _ => {
            return Err(Error::InvalidArgumentError(
                "CREATE SCHEMA authorization is not supported".into(),
            ));
        }
    };
    let (display_name, canonical) = canonical_object_name(&schema_name)?;
    if display_name.is_empty() {
        return Err(Error::InvalidArgumentError(
            "schema name must not be empty".into(),
        ));
    }
    let catalog = self.engine.context().table_catalog();
    // IF NOT EXISTS turns creation of an existing schema into a no-op.
    if if_not_exists && catalog.schema_exists(&canonical) {
        return Ok(RuntimeStatementResult::NoOp);
    }
    catalog.register_schema(&canonical).map_err(|err| {
        Error::CatalogError(format!(
            "Failed to create schema '{}': {}",
            display_name, err
        ))
    })?;
    Ok(RuntimeStatementResult::NoOp)
}
fn try_handle_range_ctas(
&self,
display_name: &str,
_canonical_name: &str,
query: &Query,
if_not_exists: bool,
or_replace: bool,
namespace: Option<String>,
) -> SqlResult<Option<RuntimeStatementResult<P>>> {
let select = match query.body.as_ref() {
SetExpr::Select(select) => select,
_ => return Ok(None),
};
if select.from.len() != 1 {
return Ok(None);
}
let table_with_joins = &select.from[0];
if !table_with_joins.joins.is_empty() {
return Ok(None);
}
let (range_size, range_alias) = match &table_with_joins.relation {
TableFactor::Table {
name,
args: Some(args),
alias,
..
} => {
let func_name = name.to_string().to_ascii_lowercase();
if func_name != "range" {
return Ok(None);
}
if args.args.len() != 1 {
return Err(Error::InvalidArgumentError(
"range table function expects a single argument".into(),
));
}
let size_expr = &args.args[0];
let range_size = match size_expr {
FunctionArg::Unnamed(FunctionArgExpr::Expr(SqlExpr::Value(value))) => {
match &value.value {
Value::Number(raw, _) => raw.parse::<i64>().map_err(|e| {
Error::InvalidArgumentError(format!(
"invalid range size literal {}: {}",
raw, e
))
})?,
other => {
return Err(Error::InvalidArgumentError(format!(
"unsupported range size value: {:?}",
other
)));
}
}
}
_ => {
return Err(Error::InvalidArgumentError(
"unsupported range argument".into(),
));
}
};
(range_size, alias.as_ref().map(|a| a.name.value.clone()))
}
_ => return Ok(None),
};
if range_size < 0 {
return Err(Error::InvalidArgumentError(
"range size must be non-negative".into(),
));
}
if select.projection.is_empty() {
return Err(Error::InvalidArgumentError(
"CREATE TABLE AS SELECT requires at least one projected column".into(),
));
}
let mut column_specs = Vec::with_capacity(select.projection.len());
let mut column_names = Vec::with_capacity(select.projection.len());
let mut row_template = Vec::with_capacity(select.projection.len());
for item in &select.projection {
match item {
SelectItem::ExprWithAlias { expr, alias } => {
let (value, data_type) = match expr {
SqlExpr::Value(value_with_span) => match &value_with_span.value {
Value::Number(raw, _) => {
let parsed = raw.parse::<i64>().map_err(|e| {
Error::InvalidArgumentError(format!(
"invalid numeric literal {}: {}",
raw, e
))
})?;
(
PlanValue::Integer(parsed),
arrow::datatypes::DataType::Int64,
)
}
Value::SingleQuotedString(s) => (
PlanValue::String(s.clone()),
arrow::datatypes::DataType::Utf8,
),
other => {
return Err(Error::InvalidArgumentError(format!(
"unsupported SELECT expression in range CTAS: {:?}",
other
)));
}
},
SqlExpr::Identifier(ident) => {
let ident_lower = ident.value.to_ascii_lowercase();
if range_alias
.as_ref()
.map(|a| a.eq_ignore_ascii_case(&ident_lower))
.unwrap_or(false)
|| ident_lower == "range"
{
return Err(Error::InvalidArgumentError(
"range() table function columns are not supported yet".into(),
));
}
return Err(Error::InvalidArgumentError(format!(
"unsupported identifier '{}' in range CTAS projection",
ident.value
)));
}
other => {
return Err(Error::InvalidArgumentError(format!(
"unsupported SELECT expression in range CTAS: {:?}",
other
)));
}
};
let column_name = alias.value.clone();
column_specs.push(ColumnSpec::new(column_name.clone(), data_type, true));
column_names.push(column_name);
row_template.push(value);
}
other => {
return Err(Error::InvalidArgumentError(format!(
"unsupported projection {:?} in range CTAS",
other
)));
}
}
}
let plan = CreateTablePlan {
name: display_name.to_string(),
if_not_exists,
or_replace,
columns: column_specs,
source: None,
namespace,
foreign_keys: Vec::new(),
};
let create_result = self.execute_plan_statement(PlanStatement::CreateTable(plan))?;
let row_count = range_size
.try_into()
.map_err(|_| Error::InvalidArgumentError("range size exceeds usize".into()))?;
if row_count > 0 {
let rows = vec![row_template; row_count];
let insert_plan = InsertPlan {
table: display_name.to_string(),
columns: column_names,
source: InsertSource::Rows(rows),
};
self.execute_plan_statement(PlanStatement::Insert(insert_plan))?;
}
Ok(Some(create_result))
}
fn try_handle_pragma_table_info(
&self,
query: &Query,
) -> SqlResult<Option<RuntimeStatementResult<P>>> {
let select = match query.body.as_ref() {
SetExpr::Select(select) => select,
_ => return Ok(None),
};
if select.from.len() != 1 {
return Ok(None);
}
let table_with_joins = &select.from[0];
if !table_with_joins.joins.is_empty() {
return Ok(None);
}
let table_name = match &table_with_joins.relation {
TableFactor::Table {
name,
args: Some(args),
..
} => {
let func_name = name.to_string().to_ascii_lowercase();
if func_name != "pragma_table_info" {
return Ok(None);
}
if args.args.len() != 1 {
return Err(Error::InvalidArgumentError(
"pragma_table_info expects exactly one argument".into(),
));
}
match &args.args[0] {
FunctionArg::Unnamed(FunctionArgExpr::Expr(SqlExpr::Value(value))) => {
match &value.value {
Value::SingleQuotedString(s) => s.clone(),
Value::DoubleQuotedString(s) => s.clone(),
_ => {
return Err(Error::InvalidArgumentError(
"pragma_table_info argument must be a string".into(),
));
}
}
}
_ => {
return Err(Error::InvalidArgumentError(
"pragma_table_info argument must be a string literal".into(),
));
}
}
}
_ => return Ok(None),
};
let context = self.engine.context();
let columns = context.table_column_specs(&table_name)?;
use arrow::array::{BooleanArray, Int32Array, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
let mut cid_values = Vec::new();
let mut name_values = Vec::new();
let mut type_values = Vec::new();
let mut notnull_values = Vec::new();
let mut dflt_value_values: Vec<Option<String>> = Vec::new();
let mut pk_values = Vec::new();
for (idx, col) in columns.iter().enumerate() {
cid_values.push(idx as i32);
name_values.push(col.name.clone());
type_values.push(format!("{:?}", col.data_type)); notnull_values.push(!col.nullable);
dflt_value_values.push(None); pk_values.push(col.primary_key);
}
let schema = Arc::new(Schema::new(vec![
Field::new("cid", DataType::Int32, false),
Field::new("name", DataType::Utf8, false),
Field::new("type", DataType::Utf8, false),
Field::new("notnull", DataType::Boolean, false),
Field::new("dflt_value", DataType::Utf8, true),
Field::new("pk", DataType::Boolean, false),
]));
use arrow::array::ArrayRef;
let mut batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int32Array::from(cid_values)) as ArrayRef,
Arc::new(StringArray::from(name_values)) as ArrayRef,
Arc::new(StringArray::from(type_values)) as ArrayRef,
Arc::new(BooleanArray::from(notnull_values)) as ArrayRef,
Arc::new(StringArray::from(dflt_value_values)) as ArrayRef,
Arc::new(BooleanArray::from(pk_values)) as ArrayRef,
],
)
.map_err(|e| Error::Internal(format!("failed to create pragma_table_info batch: {}", e)))?;
let projection_indices: Vec<usize> = select
.projection
.iter()
.filter_map(|item| {
match item {
SelectItem::UnnamedExpr(SqlExpr::Identifier(ident)) => {
schema.index_of(&ident.value).ok()
}
SelectItem::ExprWithAlias { expr, .. } => {
if let SqlExpr::Identifier(ident) = expr {
schema.index_of(&ident.value).ok()
} else {
None
}
}
SelectItem::Wildcard(_) => None, _ => None,
}
})
.collect();
let projected_schema;
if !projection_indices.is_empty() {
let projected_fields: Vec<Field> = projection_indices
.iter()
.map(|&idx| schema.field(idx).clone())
.collect();
projected_schema = Arc::new(Schema::new(projected_fields));
let projected_columns: Vec<ArrayRef> = projection_indices
.iter()
.map(|&idx| Arc::clone(batch.column(idx)))
.collect();
batch = RecordBatch::try_new(Arc::clone(&projected_schema), projected_columns)
.map_err(|e| Error::Internal(format!("failed to project columns: {}", e)))?;
} else {
projected_schema = schema;
}
if let Some(order_by) = &query.order_by {
use arrow::compute::SortColumn;
use arrow::compute::lexsort_to_indices;
use sqlparser::ast::OrderByKind;
let exprs = match &order_by.kind {
OrderByKind::Expressions(exprs) => exprs,
_ => {
return Err(Error::InvalidArgumentError(
"unsupported ORDER BY clause".into(),
));
}
};
let mut sort_columns = Vec::new();
for order_expr in exprs {
if let SqlExpr::Identifier(ident) = &order_expr.expr
&& let Ok(col_idx) = projected_schema.index_of(&ident.value)
{
let options = arrow::compute::SortOptions {
descending: !order_expr.options.asc.unwrap_or(true),
nulls_first: order_expr.options.nulls_first.unwrap_or(false),
};
sort_columns.push(SortColumn {
values: Arc::clone(batch.column(col_idx)),
options: Some(options),
});
}
}
if !sort_columns.is_empty() {
let indices = lexsort_to_indices(&sort_columns, None)
.map_err(|e| Error::Internal(format!("failed to sort: {}", e)))?;
use arrow::compute::take;
let sorted_columns: Result<Vec<ArrayRef>, _> = batch
.columns()
.iter()
.map(|col| take(col.as_ref(), &indices, None))
.collect();
batch = RecordBatch::try_new(
Arc::clone(&projected_schema),
sorted_columns
.map_err(|e| Error::Internal(format!("failed to apply sort: {}", e)))?,
)
.map_err(|e| Error::Internal(format!("failed to create sorted batch: {}", e)))?;
}
}
let execution = SelectExecution::new_single_batch(
table_name.clone(),
Arc::clone(&projected_schema),
batch,
);
Ok(Some(RuntimeStatementResult::Select {
table_name,
schema: projected_schema,
execution,
}))
}
/// Executes `CREATE TABLE <name> AS <query>`.
///
/// The query is planned with `build_select_plan` and attached to a `CreateTablePlan`
/// that declares no explicit columns and no foreign keys. `_canonical_name` is accepted
/// but not used.
///
/// # Errors
/// Returns `Error::InvalidArgumentError` when the planned query yields neither
/// projections nor aggregates, and propagates planning and execution errors.
fn handle_create_table_as(
    &self,
    display_name: String,
    _canonical_name: String,
    query: Query,
    if_not_exists: bool,
    or_replace: bool,
    namespace: Option<String>,
) -> SqlResult<RuntimeStatementResult<P>> {
    let source_plan = self.build_select_plan(query)?;

    let yields_columns =
        !source_plan.projections.is_empty() || !source_plan.aggregates.is_empty();
    if !yields_columns {
        return Err(Error::InvalidArgumentError(
            "CREATE TABLE AS SELECT requires at least one projected column".into(),
        ));
    }

    let source = CreateTableSource::Select {
        plan: Box::new(source_plan),
    };
    let statement = PlanStatement::CreateTable(CreateTablePlan {
        name: display_name,
        if_not_exists,
        or_replace,
        columns: Vec::new(),
        source: Some(source),
        namespace,
        foreign_keys: Vec::new(),
    });
    self.execute_plan_statement(statement)
}
fn handle_insert(&self, stmt: sqlparser::ast::Insert) -> SqlResult<RuntimeStatementResult<P>> {
let table_name_debug =
Self::table_name_from_insert(&stmt).unwrap_or_else(|_| "unknown".to_string());
tracing::trace!(
"DEBUG SQL handle_insert called for table={}",
table_name_debug
);
if !self.engine.session().has_active_transaction()
&& self.is_table_marked_dropped(&table_name_debug)?
{
return Err(Error::TransactionContextError(
DROPPED_TABLE_TRANSACTION_ERR.into(),
));
}
if stmt.replace_into || stmt.ignore || stmt.or.is_some() {
return Err(Error::InvalidArgumentError(
"non-standard INSERT forms are not supported".into(),
));
}
if stmt.overwrite {
return Err(Error::InvalidArgumentError(
"INSERT OVERWRITE is not supported".into(),
));
}
if !stmt.assignments.is_empty() {
return Err(Error::InvalidArgumentError(
"INSERT ... SET is not supported".into(),
));
}
if stmt.partitioned.is_some() || !stmt.after_columns.is_empty() {
return Err(Error::InvalidArgumentError(
"partitioned INSERT is not supported".into(),
));
}
if stmt.returning.is_some() {
return Err(Error::InvalidArgumentError(
"INSERT ... RETURNING is not supported".into(),
));
}
if stmt.format_clause.is_some() || stmt.settings.is_some() {
return Err(Error::InvalidArgumentError(
"INSERT with FORMAT or SETTINGS is not supported".into(),
));
}
let (display_name, _canonical_name) = match &stmt.table {
TableObject::TableName(name) => canonical_object_name(name)?,
_ => {
return Err(Error::InvalidArgumentError(
"INSERT requires a plain table name".into(),
));
}
};
let columns: Vec<String> = stmt
.columns
.iter()
.map(|ident| ident.value.clone())
.collect();
let source_expr = stmt
.source
.as_ref()
.ok_or_else(|| Error::InvalidArgumentError("INSERT requires a VALUES clause".into()))?;
validate_simple_query(source_expr)?;
let insert_source = match source_expr.body.as_ref() {
SetExpr::Values(values) => {
if values.rows.is_empty() {
return Err(Error::InvalidArgumentError(
"INSERT VALUES list must contain at least one row".into(),
));
}
let mut rows: Vec<Vec<SqlValue>> = Vec::with_capacity(values.rows.len());
for row in &values.rows {
let mut converted = Vec::with_capacity(row.len());
for expr in row {
converted.push(SqlValue::try_from_expr(expr)?);
}
rows.push(converted);
}
InsertSource::Rows(
rows.into_iter()
.map(|row| row.into_iter().map(PlanValue::from).collect())
.collect(),
)
}
SetExpr::Select(select) => {
if let Some(rows) = extract_constant_select_rows(select.as_ref())? {
InsertSource::Rows(rows)
} else if let Some(range_rows) = extract_rows_from_range(select.as_ref())? {
InsertSource::Rows(range_rows.into_rows())
} else {
let select_plan = self.build_select_plan((**source_expr).clone())?;
InsertSource::Select {
plan: Box::new(select_plan),
}
}
}
_ => {
return Err(Error::InvalidArgumentError(
"unsupported INSERT source".into(),
));
}
};
let plan = InsertPlan {
table: display_name.clone(),
columns,
source: insert_source,
};
tracing::trace!(
"DEBUG SQL handle_insert: about to execute insert for table={}",
display_name
);
self.execute_plan_statement(PlanStatement::Insert(plan))
}
/// Translates a single-table `UPDATE` into an `UpdatePlan` and executes it.
///
/// Each assignment value is first tried as a literal. When literal conversion fails
/// with an `InvalidArgumentError` whose message contains
/// `"unsupported literal expression"`, the value is translated as a scalar expression
/// resolved against the target table instead. The optional `selection` becomes the
/// plan's filter.
///
/// # Errors
/// * `Error::InvalidArgumentError` for `UPDATE ... FROM`, `RETURNING`, an empty
///   assignment list, or a column assigned more than once (compared ASCII
///   case-insensitively).
/// * `Error::TransactionContextError` when no transaction is active and the table is
///   marked as dropped.
/// * Any error from table extraction, expression translation, or execution.
fn handle_update(
    &self,
    table: TableWithJoins,
    assignments: Vec<Assignment>,
    from: Option<UpdateTableFromKind>,
    selection: Option<SqlExpr>,
    returning: Option<Vec<SelectItem>>,
) -> SqlResult<RuntimeStatementResult<P>> {
    if from.is_some() {
        return Err(Error::InvalidArgumentError(
            "UPDATE ... FROM is not supported yet".into(),
        ));
    }
    if returning.is_some() {
        return Err(Error::InvalidArgumentError(
            "UPDATE ... RETURNING is not supported".into(),
        ));
    }
    if assignments.is_empty() {
        return Err(Error::InvalidArgumentError(
            "UPDATE requires at least one assignment".into(),
        ));
    }

    let (display_name, canonical_name) = extract_single_table(std::slice::from_ref(&table))?;
    if !self.engine.session().has_active_transaction()
        && self
            .engine
            .context()
            .is_table_marked_dropped(&canonical_name)
    {
        return Err(Error::TransactionContextError(
            DROPPED_TABLE_TRANSACTION_ERR.into(),
        ));
    }

    let catalog = self.engine.context().table_catalog();
    let resolver = catalog.identifier_resolver();
    // One resolution context serves both the assignments and the WHERE clause.
    let id_context = IdentifierContext::new(catalog.table_id(&canonical_name));

    let mut column_assignments = Vec::with_capacity(assignments.len());
    // Lower-cased names already assigned, used to reject duplicates.
    let mut seen: HashSet<String> = HashSet::with_capacity(assignments.len());
    for assignment in assignments {
        let column_name = resolve_assignment_column_name(&assignment.target)?;
        if !seen.insert(column_name.to_ascii_lowercase()) {
            return Err(Error::InvalidArgumentError(format!(
                "duplicate column '{}' in UPDATE assignments",
                column_name
            )));
        }
        let value = match SqlValue::try_from_expr(&assignment.value) {
            Ok(literal) => AssignmentValue::Literal(PlanValue::from(literal)),
            // Not a literal: fall back to a general scalar expression.
            Err(Error::InvalidArgumentError(msg))
                if msg.contains("unsupported literal expression") =>
            {
                let translated =
                    translate_scalar_with_context(&resolver, id_context, &assignment.value)?;
                AssignmentValue::Expression(translated)
            }
            Err(err) => return Err(err),
        };
        column_assignments.push(ColumnAssignment {
            column: column_name,
            value,
        });
    }

    let filter = selection
        .map(|expr| translate_condition_with_context(&resolver, id_context, &expr))
        .transpose()?;

    let plan = UpdatePlan {
        table: display_name,
        assignments: column_assignments,
        filter,
    };
    self.execute_plan_statement(PlanStatement::Update(plan))
}
#[allow(clippy::collapsible_if)]
fn handle_delete(&self, delete: Delete) -> SqlResult<RuntimeStatementResult<P>> {
let Delete {
tables,
from,
using,
selection,
returning,
order_by,
limit,
} = delete;
if !tables.is_empty() {
return Err(Error::InvalidArgumentError(
"multi-table DELETE is not supported yet".into(),
));
}
if let Some(using_tables) = using {
if !using_tables.is_empty() {
return Err(Error::InvalidArgumentError(
"DELETE ... USING is not supported yet".into(),
));
}
}
if returning.is_some() {
return Err(Error::InvalidArgumentError(
"DELETE ... RETURNING is not supported".into(),
));
}
if !order_by.is_empty() {
return Err(Error::InvalidArgumentError(
"DELETE ... ORDER BY is not supported yet".into(),
));
}
if limit.is_some() {
return Err(Error::InvalidArgumentError(
"DELETE ... LIMIT is not supported yet".into(),
));
}
let from_tables = match from {
FromTable::WithFromKeyword(tables) | FromTable::WithoutKeyword(tables) => tables,
};
let (display_name, canonical_name) = extract_single_table(&from_tables)?;
if !self.engine.session().has_active_transaction()
&& self
.engine
.context()
.is_table_marked_dropped(&canonical_name)
{
return Err(Error::TransactionContextError(
DROPPED_TABLE_TRANSACTION_ERR.into(),
));
}
let catalog = self.engine.context().table_catalog();
let resolver = catalog.identifier_resolver();
let table_id = catalog.table_id(&canonical_name);
let filter = selection
.map(|expr| {
translate_condition_with_context(&resolver, IdentifierContext::new(table_id), &expr)
})
.transpose()?;
let plan = DeletePlan {
table: display_name.clone(),
filter,
};
self.execute_plan_statement(PlanStatement::Delete(plan))
}
/// Executes `DROP TABLE` and `DROP SCHEMA`.
///
/// * Tables are dropped one by one through the session; failures are passed through
///   `map_table_error`.
/// * A schema must exist unless `if_exists` is set. With `cascade`, every catalog table
///   whose lower-cased name starts with `<schema>.` is dropped first via
///   `drop_table_immediate`; without it, such a table makes the statement fail. The
///   schema is then unregistered from the catalog.
///
/// # Errors
/// * `Error::InvalidArgumentError` for `PURGE`/`TEMPORARY`, for `CASCADE`/`RESTRICT` on
///   tables, for `RESTRICT` on schemas, and for any other object type.
/// * `Error::CatalogError` for a missing schema (without `if_exists`) or a non-empty
///   schema dropped without `CASCADE`.
// One parameter per field of the parsed DROP statement.
#[allow(clippy::too_many_arguments)]
fn handle_drop(
    &self,
    object_type: ObjectType,
    if_exists: bool,
    names: Vec<ObjectName>,
    cascade: bool,
    restrict: bool,
    purge: bool,
    temporary: bool,
) -> SqlResult<RuntimeStatementResult<P>> {
    if purge || temporary {
        return Err(Error::InvalidArgumentError(
            "DROP purge/temporary options are not supported".into(),
        ));
    }

    match object_type {
        ObjectType::Table => {
            if cascade || restrict {
                return Err(Error::InvalidArgumentError(
                    "DROP TABLE CASCADE/RESTRICT is not supported".into(),
                ));
            }
            let session = self.engine.session();
            for name in names {
                let table_name = Self::object_name_to_string(&name)?;
                if let Err(err) = session.drop_table(&table_name, if_exists) {
                    return Err(Self::map_table_error(&table_name, err));
                }
            }
            Ok(RuntimeStatementResult::NoOp)
        }
        ObjectType::Schema => {
            if restrict {
                return Err(Error::InvalidArgumentError(
                    "DROP SCHEMA RESTRICT is not supported".into(),
                ));
            }
            let catalog = self.engine.context().table_catalog();
            for name in names {
                let (display_name, canonical_name) = canonical_object_name(&name)?;

                if !catalog.schema_exists(&canonical_name) {
                    if if_exists {
                        continue;
                    }
                    return Err(Error::CatalogError(format!(
                        "Schema '{display_name}' does not exist"
                    )));
                }

                // Tables of this schema are recognised by their `<schema>.` name prefix.
                let schema_prefix = format!("{canonical_name}.");
                let all_tables = catalog.table_names();
                if cascade {
                    let ctx = self.engine.context();
                    for table in all_tables {
                        if table.to_ascii_lowercase().starts_with(&schema_prefix) {
                            ctx.drop_table_immediate(&table, false)?;
                        }
                    }
                } else if all_tables
                    .iter()
                    .any(|t| t.to_ascii_lowercase().starts_with(&schema_prefix))
                {
                    return Err(Error::CatalogError(format!(
                        "Schema '{display_name}' is not empty. Use CASCADE to drop schema and all its tables"
                    )));
                }

                let unregistered = catalog.unregister_schema(&canonical_name);
                if !unregistered && !if_exists {
                    return Err(Error::CatalogError(format!(
                        "Schema '{display_name}' does not exist"
                    )));
                }
            }
            Ok(RuntimeStatementResult::NoOp)
        }
        _ => Err(Error::InvalidArgumentError(format!(
            "DROP {object_type} is not supported"
        ))),
    }
}
/// Executes a top-level query.
///
/// `pragma_table_info(...)` selects are answered by `try_handle_pragma_table_info`;
/// every other query is planned with `build_select_plan` and executed as a SELECT.
fn handle_query(&self, query: Query) -> SqlResult<RuntimeStatementResult<P>> {
    let pragma_result = self.try_handle_pragma_table_info(&query)?;
    match pragma_result {
        Some(result) => Ok(result),
        None => {
            let plan = self.build_select_plan(query)?;
            self.execute_plan_statement(PlanStatement::Select(plan))
        }
    }
}
/// Builds a `SelectPlan` from a parsed query.
///
/// The query must pass `validate_simple_query` and its body must be a plain SELECT.
/// An ORDER BY clause, when present, is translated and attached to the plan.
///
/// # Errors
/// * `Error::TransactionContextError` when the session has an active transaction that
///   is aborted.
/// * `Error::InvalidArgumentError` when the body is not a SELECT, or when ORDER BY is
///   combined with aggregates.
/// * Any error from validation or from SELECT / ORDER BY translation.
fn build_select_plan(&self, query: Query) -> SqlResult<SelectPlan> {
    let transaction_aborted = {
        let session = self.engine.session();
        session.has_active_transaction() && session.is_aborted()
    };
    if transaction_aborted {
        return Err(Error::TransactionContextError(
            "TransactionContext Error: transaction is aborted".into(),
        ));
    }

    validate_simple_query(&query)?;

    let catalog = self.engine.context().table_catalog();
    let resolver = catalog.identifier_resolver();

    let body = query.body.as_ref();
    let SetExpr::Select(select) = body else {
        return Err(Error::InvalidArgumentError(format!(
            "unsupported query expression: {body:?}"
        )));
    };
    let (select_plan, select_context) = self.translate_select(select.as_ref(), &resolver)?;

    let Some(order_by) = &query.order_by else {
        return Ok(select_plan);
    };
    if !select_plan.aggregates.is_empty() {
        return Err(Error::InvalidArgumentError(
            "ORDER BY is not supported for aggregate queries".into(),
        ));
    }
    let order_plan = self.translate_order_by(&resolver, select_context, order_by)?;
    Ok(select_plan.with_order_by(order_plan))
}
fn translate_select(
&self,
select: &Select,
resolver: &IdentifierResolver<'_>,
) -> SqlResult<(SelectPlan, IdentifierContext)> {
if select.distinct.is_some() {
return Err(Error::InvalidArgumentError(
"SELECT DISTINCT is not supported".into(),
));
}
if select.top.is_some() {
return Err(Error::InvalidArgumentError(
"SELECT TOP is not supported".into(),
));
}
if select.exclude.is_some() {
return Err(Error::InvalidArgumentError(
"SELECT EXCLUDE is not supported".into(),
));
}
if select.into.is_some() {
return Err(Error::InvalidArgumentError(
"SELECT INTO is not supported".into(),
));
}
if !select.lateral_views.is_empty() {
return Err(Error::InvalidArgumentError(
"LATERAL VIEW is not supported".into(),
));
}
if select.prewhere.is_some() {
return Err(Error::InvalidArgumentError(
"PREWHERE is not supported".into(),
));
}
if !group_by_is_empty(&select.group_by) || select.value_table_mode.is_some() {
return Err(Error::InvalidArgumentError(
"GROUP BY and SELECT AS VALUE/STRUCT are not supported".into(),
));
}
if !select.cluster_by.is_empty()
|| !select.distribute_by.is_empty()
|| !select.sort_by.is_empty()
{
return Err(Error::InvalidArgumentError(
"CLUSTER/DISTRIBUTE/SORT BY clauses are not supported".into(),
));
}
if select.having.is_some()
|| !select.named_window.is_empty()
|| select.qualify.is_some()
|| select.connect_by.is_some()
{
return Err(Error::InvalidArgumentError(
"advanced SELECT clauses are not supported".into(),
));
}
let table_alias = select
.from
.first()
.and_then(|table_with_joins| match &table_with_joins.relation {
TableFactor::Table { alias, .. } => alias.as_ref().map(|a| a.name.value.clone()),
_ => None,
});
if let Some(alias) = table_alias.as_ref() {
validate_projection_alias_qualifiers(&select.projection, alias)?;
}
let catalog = self.engine.context().table_catalog();
let (mut plan, id_context) = if select.from.is_empty() {
let mut p = SelectPlan::new("");
let projections = self.build_projection_list(
resolver,
IdentifierContext::new(None),
&select.projection,
)?;
p = p.with_projections(projections);
(p, IdentifierContext::new(None))
} else if select.from.len() == 1 {
let (display_name, canonical_name) = extract_single_table(&select.from)?;
let table_id = catalog.table_id(&canonical_name);
let mut p = SelectPlan::new(display_name.clone());
if let Some(aggregates) = self.detect_simple_aggregates(&select.projection)? {
p = p.with_aggregates(aggregates);
} else {
let projections = self.build_projection_list(
resolver,
IdentifierContext::new(table_id),
&select.projection,
)?;
p = p.with_projections(projections);
}
(p, IdentifierContext::new(table_id))
} else {
let tables = extract_tables(&select.from)?;
let mut p = SelectPlan::with_tables(tables);
let projections = self.build_projection_list(
resolver,
IdentifierContext::new(None),
&select.projection,
)?;
p = p.with_projections(projections);
(p, IdentifierContext::new(None))
};
let filter_expr = match &select.selection {
Some(expr) => Some(translate_condition_with_context(
resolver, id_context, expr,
)?),
None => None,
};
plan = plan.with_filter(filter_expr);
Ok((plan, id_context))
}
fn translate_order_by(
&self,
resolver: &IdentifierResolver<'_>,
id_context: IdentifierContext,
order_by: &OrderBy,
) -> SqlResult<Vec<OrderByPlan>> {
let exprs = match &order_by.kind {
OrderByKind::Expressions(exprs) => exprs,
_ => {
return Err(Error::InvalidArgumentError(
"unsupported ORDER BY clause".into(),
));
}
};
let base_nulls_first = self.default_nulls_first.load(AtomicOrdering::Relaxed);
let resolve_simple_column = |expr: &SqlExpr| -> SqlResult<String> {
let scalar = translate_scalar_with_context(resolver, id_context, expr)?;
match scalar {
llkv_expr::expr::ScalarExpr::Column(column) => Ok(column),
other => Err(Error::InvalidArgumentError(format!(
"ORDER BY expression must reference a simple column, found {other:?}"
))),
}
};
let mut plans = Vec::with_capacity(exprs.len());
for order_expr in exprs {
let ascending = order_expr.options.asc.unwrap_or(true);
let default_nulls_first_for_direction = if ascending {
base_nulls_first
} else {
!base_nulls_first
};
let nulls_first = order_expr
.options
.nulls_first
.unwrap_or(default_nulls_first_for_direction);
if let SqlExpr::Identifier(ident) = &order_expr.expr
&& ident.value.eq_ignore_ascii_case("ALL")
&& ident.quote_style.is_none()
{
plans.push(OrderByPlan {
target: OrderTarget::All,
sort_type: OrderSortType::Native,
ascending,
nulls_first,
});
continue;
}
let (target, sort_type) = match &order_expr.expr {
SqlExpr::Identifier(_) | SqlExpr::CompoundIdentifier(_) => (
OrderTarget::Column(resolve_simple_column(&order_expr.expr)?),
OrderSortType::Native,
),
SqlExpr::Cast {
expr,
data_type:
SqlDataType::Int(_)
| SqlDataType::Integer(_)
| SqlDataType::BigInt(_)
| SqlDataType::SmallInt(_)
| SqlDataType::TinyInt(_),
..
} => (
OrderTarget::Column(resolve_simple_column(expr)?),
OrderSortType::CastTextToInteger,
),
SqlExpr::Cast { data_type, .. } => {
return Err(Error::InvalidArgumentError(format!(
"ORDER BY CAST target type {:?} is not supported",
data_type
)));
}
SqlExpr::Value(value_with_span) => match &value_with_span.value {
Value::Number(raw, _) => {
let position: usize = raw.parse().map_err(|_| {
Error::InvalidArgumentError(format!(
"ORDER BY position '{}' is not a valid positive integer",
raw
))
})?;
if position == 0 {
return Err(Error::InvalidArgumentError(
"ORDER BY position must be at least 1".into(),
));
}
(OrderTarget::Index(position - 1), OrderSortType::Native)
}
other => {
return Err(Error::InvalidArgumentError(format!(
"unsupported ORDER BY literal expression: {other:?}"
)));
}
},
other => {
return Err(Error::InvalidArgumentError(format!(
"unsupported ORDER BY expression: {other:?}"
)));
}
};
plans.push(OrderByPlan {
target,
sort_type,
ascending,
nulls_first,
});
}
Ok(plans)
}
fn detect_simple_aggregates(
&self,
projection_items: &[SelectItem],
) -> SqlResult<Option<Vec<AggregateExpr>>> {
if projection_items.is_empty() {
return Ok(None);
}
let mut specs: Vec<AggregateExpr> = Vec::with_capacity(projection_items.len());
for (idx, item) in projection_items.iter().enumerate() {
let (expr, alias_opt) = match item {
SelectItem::UnnamedExpr(expr) => (expr, None),
SelectItem::ExprWithAlias { expr, alias } => (expr, Some(alias.value.clone())),
_ => return Ok(None),
};
let alias = alias_opt.unwrap_or_else(|| format!("col{}", idx + 1));
let SqlExpr::Function(func) = expr else {
return Ok(None);
};
if func.uses_odbc_syntax {
return Err(Error::InvalidArgumentError(
"ODBC function syntax is not supported in aggregate queries".into(),
));
}
if !matches!(func.parameters, FunctionArguments::None) {
return Err(Error::InvalidArgumentError(
"parameterized aggregate functions are not supported".into(),
));
}
if func.filter.is_some()
|| func.null_treatment.is_some()
|| func.over.is_some()
|| !func.within_group.is_empty()
{
return Err(Error::InvalidArgumentError(
"advanced aggregate clauses are not supported".into(),
));
}
let mut is_distinct = false;
let args_slice: &[FunctionArg] = match &func.args {
FunctionArguments::List(list) => {
if let Some(dup) = &list.duplicate_treatment {
use sqlparser::ast::DuplicateTreatment;
match dup {
DuplicateTreatment::All => {}
DuplicateTreatment::Distinct => is_distinct = true,
}
}
if !list.clauses.is_empty() {
return Err(Error::InvalidArgumentError(
"aggregate argument clauses are not supported".into(),
));
}
&list.args
}
FunctionArguments::None => &[],
FunctionArguments::Subquery(_) => {
return Err(Error::InvalidArgumentError(
"aggregate subquery arguments are not supported".into(),
));
}
};
let func_name = if func.name.0.len() == 1 {
match &func.name.0[0] {
ObjectNamePart::Identifier(ident) => ident.value.to_ascii_lowercase(),
_ => {
return Err(Error::InvalidArgumentError(
"unsupported aggregate function name".into(),
));
}
}
} else {
return Err(Error::InvalidArgumentError(
"qualified aggregate function names are not supported".into(),
));
};
let aggregate = match func_name.as_str() {
"count" => {
if args_slice.len() != 1 {
return Err(Error::InvalidArgumentError(
"COUNT accepts exactly one argument".into(),
));
}
match &args_slice[0] {
FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => {
if is_distinct {
return Err(Error::InvalidArgumentError(
"COUNT(DISTINCT *) is not supported".into(),
));
}
AggregateExpr::count_star(alias)
}
FunctionArg::Unnamed(FunctionArgExpr::Expr(arg_expr)) => {
let column = resolve_column_name(arg_expr)?;
if is_distinct {
AggregateExpr::count_distinct_column(column, alias)
} else {
AggregateExpr::count_column(column, alias)
}
}
FunctionArg::Named { .. } | FunctionArg::ExprNamed { .. } => {
return Err(Error::InvalidArgumentError(
"named COUNT arguments are not supported".into(),
));
}
FunctionArg::Unnamed(FunctionArgExpr::QualifiedWildcard(_)) => {
return Err(Error::InvalidArgumentError(
"COUNT does not support qualified wildcards".into(),
));
}
}
}
"sum" | "min" | "max" => {
if is_distinct {
return Err(Error::InvalidArgumentError(
"DISTINCT is not supported for this aggregate".into(),
));
}
if args_slice.len() != 1 {
return Err(Error::InvalidArgumentError(format!(
"{} accepts exactly one argument",
func_name.to_uppercase()
)));
}
let arg_expr = match &args_slice[0] {
FunctionArg::Unnamed(FunctionArgExpr::Expr(arg_expr)) => arg_expr,
FunctionArg::Unnamed(FunctionArgExpr::Wildcard)
| FunctionArg::Unnamed(FunctionArgExpr::QualifiedWildcard(_)) => {
return Err(Error::InvalidArgumentError(format!(
"{} does not support wildcard arguments",
func_name.to_uppercase()
)));
}
FunctionArg::Named { .. } | FunctionArg::ExprNamed { .. } => {
return Err(Error::InvalidArgumentError(format!(
"{} arguments must be column references",
func_name.to_uppercase()
)));
}
};
if func_name == "sum" {
if let Some(column) = parse_count_nulls_case(arg_expr)? {
AggregateExpr::count_nulls(column, alias)
} else {
let column = resolve_column_name(arg_expr)?;
AggregateExpr::sum_int64(column, alias)
}
} else {
let column = resolve_column_name(arg_expr)?;
if func_name == "min" {
AggregateExpr::min_int64(column, alias)
} else {
AggregateExpr::max_int64(column, alias)
}
}
}
_ => return Ok(None),
};
specs.push(aggregate);
}
if specs.is_empty() {
return Ok(None);
}
Ok(Some(specs))
}
fn build_projection_list(
&self,
resolver: &IdentifierResolver<'_>,
id_context: IdentifierContext,
projection_items: &[SelectItem],
) -> SqlResult<Vec<SelectProjection>> {
if projection_items.is_empty() {
return Err(Error::InvalidArgumentError(
"SELECT projection must include at least one column".into(),
));
}
let mut projections = Vec::with_capacity(projection_items.len());
for (idx, item) in projection_items.iter().enumerate() {
match item {
SelectItem::Wildcard(options) => {
if let Some(exclude) = &options.opt_exclude {
use sqlparser::ast::ExcludeSelectItem;
let exclude_cols = match exclude {
ExcludeSelectItem::Single(ident) => vec![ident.value.clone()],
ExcludeSelectItem::Multiple(idents) => {
idents.iter().map(|id| id.value.clone()).collect()
}
};
projections.push(SelectProjection::AllColumnsExcept {
exclude: exclude_cols,
});
} else {
projections.push(SelectProjection::AllColumns);
}
}
SelectItem::QualifiedWildcard(kind, _) => match kind {
SelectItemQualifiedWildcardKind::ObjectName(name) => {
projections.push(SelectProjection::Column {
name: name.to_string(),
alias: None,
});
}
SelectItemQualifiedWildcardKind::Expr(_) => {
return Err(Error::InvalidArgumentError(
"expression-qualified wildcards are not supported".into(),
));
}
},
SelectItem::UnnamedExpr(expr) => match expr {
SqlExpr::Identifier(ident) => {
let parts = vec![ident.value.clone()];
let resolution = resolver.resolve(&parts, id_context)?;
if resolution.is_simple() {
projections.push(SelectProjection::Column {
name: resolution.column().to_string(),
alias: None,
});
} else {
let alias = format!("col{}", idx + 1);
projections.push(SelectProjection::Computed {
expr: resolution.into_scalar_expr(),
alias,
});
}
}
SqlExpr::CompoundIdentifier(parts) => {
let name_parts: Vec<String> =
parts.iter().map(|part| part.value.clone()).collect();
let resolution = resolver.resolve(&name_parts, id_context)?;
if resolution.is_simple() {
projections.push(SelectProjection::Column {
name: resolution.column().to_string(),
alias: None,
});
} else {
let alias = format!("col{}", idx + 1);
projections.push(SelectProjection::Computed {
expr: resolution.into_scalar_expr(),
alias,
});
}
}
_ => {
let alias = format!("col{}", idx + 1);
let scalar = translate_scalar_with_context(resolver, id_context, expr)?;
projections.push(SelectProjection::Computed {
expr: scalar,
alias,
});
}
},
SelectItem::ExprWithAlias { expr, alias } => match expr {
SqlExpr::Identifier(ident) => {
let parts = vec![ident.value.clone()];
let resolution = resolver.resolve(&parts, id_context)?;
if resolution.is_simple() {
projections.push(SelectProjection::Column {
name: resolution.column().to_string(),
alias: Some(alias.value.clone()),
});
} else {
projections.push(SelectProjection::Computed {
expr: resolution.into_scalar_expr(),
alias: alias.value.clone(),
});
}
}
SqlExpr::CompoundIdentifier(parts) => {
let name_parts: Vec<String> =
parts.iter().map(|part| part.value.clone()).collect();
let resolution = resolver.resolve(&name_parts, id_context)?;
if resolution.is_simple() {
projections.push(SelectProjection::Column {
name: resolution.column().to_string(),
alias: Some(alias.value.clone()),
});
} else {
projections.push(SelectProjection::Computed {
expr: resolution.into_scalar_expr(),
alias: alias.value.clone(),
});
}
}
_ => {
let scalar = translate_scalar_with_context(resolver, id_context, expr)?;
projections.push(SelectProjection::Computed {
expr: scalar,
alias: alias.value.clone(),
});
}
},
}
}
Ok(projections)
}
#[allow(clippy::too_many_arguments)] fn handle_start_transaction(
&self,
modes: Vec<TransactionMode>,
begin: bool,
transaction: Option<BeginTransactionKind>,
modifier: Option<TransactionModifier>,
statements: Vec<Statement>,
exception: Option<Vec<ExceptionWhen>>,
has_end_keyword: bool,
) -> SqlResult<RuntimeStatementResult<P>> {
if !modes.is_empty() {
return Err(Error::InvalidArgumentError(
"transaction modes are not supported".into(),
));
}
if modifier.is_some() {
return Err(Error::InvalidArgumentError(
"transaction modifiers are not supported".into(),
));
}
if !statements.is_empty() || exception.is_some() || has_end_keyword {
return Err(Error::InvalidArgumentError(
"BEGIN blocks with inline statements or exceptions are not supported".into(),
));
}
if let Some(kind) = transaction {
match kind {
BeginTransactionKind::Transaction | BeginTransactionKind::Work => {}
}
}
if !begin {
tracing::warn!("Currently treat `START TRANSACTION` same as `BEGIN`")
}
self.execute_plan_statement(PlanStatement::BeginTransaction)
}
fn handle_commit(
&self,
chain: bool,
end: bool,
modifier: Option<TransactionModifier>,
) -> SqlResult<RuntimeStatementResult<P>> {
if chain {
return Err(Error::InvalidArgumentError(
"COMMIT AND [NO] CHAIN is not supported".into(),
));
}
if end {
return Err(Error::InvalidArgumentError(
"END blocks are not supported".into(),
));
}
if modifier.is_some() {
return Err(Error::InvalidArgumentError(
"transaction modifiers are not supported".into(),
));
}
self.execute_plan_statement(PlanStatement::CommitTransaction)
}
fn handle_rollback(
&self,
chain: bool,
savepoint: Option<Ident>,
) -> SqlResult<RuntimeStatementResult<P>> {
if chain {
return Err(Error::InvalidArgumentError(
"ROLLBACK AND [NO] CHAIN is not supported".into(),
));
}
if savepoint.is_some() {
return Err(Error::InvalidArgumentError(
"ROLLBACK TO SAVEPOINT is not supported".into(),
));
}
self.execute_plan_statement(PlanStatement::RollbackTransaction)
}
/// Handles `SET <variable> = <value>` for the small set of supported variables.
///
/// Supported variables (matched ASCII-case-insensitively):
/// * `default_null_order` — accepts `nulls_first` or `nulls_last`, either as a
///   string-like value or a bare identifier, and updates
///   `default_nulls_first`.
/// * `immediate_transaction_mode` — accepts `true`/`on`/`1` or
///   `false`/`off`/`0`; the value is validated but not stored, and disabling
///   it only logs a warning.
///
/// Any other `SET` form, scope/hivevar modifier, variable name, or value is
/// rejected with an `InvalidArgumentError`.
fn handle_set(&self, set_stmt: Set) -> SqlResult<RuntimeStatementResult<P>> {
    let (scope, hivevar, variable, values) = match set_stmt {
        Set::SingleAssignment {
            scope,
            hivevar,
            variable,
            values,
        } => (scope, hivevar, variable, values),
        other => {
            return Err(Error::InvalidArgumentError(format!(
                "unsupported SQL SET statement: {other:?}",
            )));
        }
    };
    if hivevar || scope.is_some() {
        return Err(Error::InvalidArgumentError(
            "SET modifiers are not supported".into(),
        ));
    }

    let variable_name_raw = variable.to_string();
    match variable_name_raw.to_ascii_lowercase().as_str() {
        "default_null_order" => {
            let [value_expr] = values.as_slice() else {
                return Err(Error::InvalidArgumentError(
                    "SET default_null_order expects exactly one value".into(),
                ));
            };
            let normalized = match value_expr {
                SqlExpr::Value(value_with_span) => value_with_span
                    .value
                    .clone()
                    .into_string()
                    .map(|s| s.to_ascii_lowercase()),
                SqlExpr::Identifier(ident) => Some(ident.value.to_ascii_lowercase()),
                _ => None,
            };
            let use_nulls_first = match normalized.as_deref() {
                Some("nulls_first") => true,
                Some("nulls_last") => false,
                _ => {
                    return Err(Error::InvalidArgumentError(format!(
                        "unsupported value for SET default_null_order: {value_expr:?}"
                    )));
                }
            };
            self.default_nulls_first
                .store(use_nulls_first, AtomicOrdering::Relaxed);
            Ok(RuntimeStatementResult::NoOp)
        }
        "immediate_transaction_mode" => {
            let [value_expr] = values.as_slice() else {
                return Err(Error::InvalidArgumentError(
                    "SET immediate_transaction_mode expects exactly one value".into(),
                ));
            };
            // The value is compared in its SQL display form, so a quoted
            // string literal keeps its quotes and will not match.
            let enabled = match value_expr.to_string().to_ascii_lowercase().as_str() {
                "true" | "on" | "1" => true,
                "false" | "off" | "0" => false,
                _ => {
                    return Err(Error::InvalidArgumentError(format!(
                        "unsupported value for SET immediate_transaction_mode: {}",
                        value_expr
                    )));
                }
            };
            if !enabled {
                tracing::warn!(
                    "SET immediate_transaction_mode=false has no effect; continuing with auto mode"
                );
            }
            Ok(RuntimeStatementResult::NoOp)
        }
        _ => Err(Error::InvalidArgumentError(format!(
            "unsupported SET variable: {variable_name_raw}"
        ))),
    }
}
fn handle_pragma(
&self,
name: ObjectName,
value: Option<Value>,
is_eq: bool,
) -> SqlResult<RuntimeStatementResult<P>> {
let (display, canonical) = canonical_object_name(&name)?;
if value.is_some() || is_eq {
return Err(Error::InvalidArgumentError(format!(
"PRAGMA '{display}' does not accept a value"
)));
}
match canonical.as_str() {
"enable_verification" | "disable_verification" => Ok(RuntimeStatementResult::NoOp),
_ => Err(Error::InvalidArgumentError(format!(
"unsupported PRAGMA '{}'",
display
))),
}
}
}
fn canonical_object_name(name: &ObjectName) -> SqlResult<(String, String)> {
if name.0.is_empty() {
return Err(Error::InvalidArgumentError(
"object name must not be empty".into(),
));
}
let mut parts: Vec<String> = Vec::with_capacity(name.0.len());
for part in &name.0 {
let ident = match part {
ObjectNamePart::Identifier(ident) => ident,
_ => {
return Err(Error::InvalidArgumentError(
"object names using functions are not supported".into(),
));
}
};
parts.push(ident.value.clone());
}
let display = parts.join(".");
let canonical = display.to_ascii_lowercase();
Ok((display, canonical))
}
fn parse_schema_qualified_name(name: &ObjectName) -> SqlResult<(Option<String>, String)> {
if name.0.is_empty() {
return Err(Error::InvalidArgumentError(
"object name must not be empty".into(),
));
}
let mut parts: Vec<String> = Vec::with_capacity(name.0.len());
for part in &name.0 {
let ident = match part {
ObjectNamePart::Identifier(ident) => ident,
_ => {
return Err(Error::InvalidArgumentError(
"object names using functions are not supported".into(),
));
}
};
parts.push(ident.value.clone());
}
match parts.len() {
1 => Ok((None, parts[0].clone())),
2 => Ok((Some(parts[0].clone()), parts[1].clone())),
_ => Err(Error::InvalidArgumentError(format!(
"table name has too many parts: {}",
name
))),
}
}
fn extract_index_column_name(
index_col: &sqlparser::ast::IndexColumn,
context: &str,
allow_sort_options: bool,
allow_compound: bool,
) -> SqlResult<String> {
use sqlparser::ast::Expr as SqlExpr;
if index_col.operator_class.is_some() {
return Err(Error::InvalidArgumentError(format!(
"{} operator classes are not supported",
context
)));
}
let order_expr = &index_col.column;
if allow_sort_options {
let ascending = order_expr.options.asc.unwrap_or(true);
let nulls_first = order_expr.options.nulls_first.unwrap_or(false);
if !ascending {
return Err(Error::InvalidArgumentError(format!(
"{} DESC ordering is not supported",
context
)));
}
if nulls_first {
return Err(Error::InvalidArgumentError(format!(
"{} NULLS FIRST ordering is not supported",
context
)));
}
} else {
if order_expr.options.asc.is_some()
|| order_expr.options.nulls_first.is_some()
|| order_expr.with_fill.is_some()
{
return Err(Error::InvalidArgumentError(format!(
"{} columns must be simple identifiers",
context
)));
}
}
let column_name = match &order_expr.expr {
SqlExpr::Identifier(ident) => ident.value.clone(),
SqlExpr::CompoundIdentifier(parts) => {
if allow_compound {
parts
.last()
.map(|ident| ident.value.clone())
.ok_or_else(|| {
Error::InvalidArgumentError(format!(
"invalid column reference in {}",
context
))
})?
} else if parts.len() == 1 {
parts[0].value.clone()
} else {
return Err(Error::InvalidArgumentError(format!(
"{} columns must be column identifiers",
context
)));
}
}
other => {
return Err(Error::InvalidArgumentError(format!(
"{} only supports column references, found {:?}",
context, other
)));
}
};
Ok(column_name)
}
/// Checks the parts of a `CREATE TABLE` statement shared by every creation path.
///
/// Rejects `LIKE` / `CLONE`, the combination of `OR REPLACE` with
/// `IF NOT EXISTS`, more than one table-level `PRIMARY KEY`, and any
/// table-level constraint other than `PRIMARY KEY`, `UNIQUE` or
/// `FOREIGN KEY`.
fn validate_create_table_common(stmt: &sqlparser::ast::CreateTable) -> SqlResult<()> {
    use sqlparser::ast::TableConstraint;

    if stmt.like.is_some() || stmt.clone.is_some() {
        return Err(Error::InvalidArgumentError(
            "CREATE TABLE LIKE/CLONE is not supported".into(),
        ));
    }
    if stmt.if_not_exists && stmt.or_replace {
        return Err(Error::InvalidArgumentError(
            "CREATE TABLE cannot combine OR REPLACE with IF NOT EXISTS".into(),
        ));
    }

    let mut primary_key_count = 0usize;
    for constraint in &stmt.constraints {
        match constraint {
            TableConstraint::PrimaryKey { .. } => {
                primary_key_count += 1;
                if primary_key_count > 1 {
                    return Err(Error::InvalidArgumentError(
                        "multiple PRIMARY KEY constraints are not supported".into(),
                    ));
                }
            }
            // Accepted here without further inspection.
            TableConstraint::Unique { .. } | TableConstraint::ForeignKey { .. } => {}
            other => {
                return Err(Error::InvalidArgumentError(format!(
                    "table-level constraint {:?} is not supported",
                    other
                )));
            }
        }
    }
    Ok(())
}
fn validate_check_constraint(
check_expr: &sqlparser::ast::Expr,
table_name: &str,
column_names: &[&str],
) -> SqlResult<()> {
use sqlparser::ast::Expr as SqlExpr;
let column_names_lower: HashSet<String> = column_names
.iter()
.map(|name| name.to_ascii_lowercase())
.collect();
let mut stack: Vec<&SqlExpr> = vec![check_expr];
while let Some(expr) = stack.pop() {
match expr {
SqlExpr::Subquery(_) => {
return Err(Error::InvalidArgumentError(
"Subqueries are not allowed in CHECK constraints".into(),
));
}
SqlExpr::Function(func) => {
let func_name = func.name.to_string().to_uppercase();
if matches!(func_name.as_str(), "SUM" | "AVG" | "COUNT" | "MIN" | "MAX") {
return Err(Error::InvalidArgumentError(
"Aggregate functions are not allowed in CHECK constraints".into(),
));
}
if let sqlparser::ast::FunctionArguments::List(list) = &func.args {
for arg in &list.args {
if let sqlparser::ast::FunctionArg::Unnamed(
sqlparser::ast::FunctionArgExpr::Expr(expr),
) = arg
{
stack.push(expr);
}
}
}
}
SqlExpr::Identifier(ident) => {
if !column_names_lower.contains(&ident.value.to_ascii_lowercase()) {
return Err(Error::InvalidArgumentError(format!(
"Column '{}' referenced in CHECK constraint does not exist",
ident.value
)));
}
}
SqlExpr::CompoundIdentifier(idents) => {
if idents.len() == 2 {
let first = idents[0].value.as_str();
let second = &idents[1].value;
if column_names_lower.contains(&first.to_ascii_lowercase()) {
continue;
}
if !first.eq_ignore_ascii_case(table_name) {
return Err(Error::InvalidArgumentError(format!(
"CHECK constraint references column from different table '{}'",
first
)));
}
if !column_names_lower.contains(&second.to_ascii_lowercase()) {
return Err(Error::InvalidArgumentError(format!(
"Column '{}' referenced in CHECK constraint does not exist",
second
)));
}
} else if idents.len() == 3 {
let first = &idents[0].value;
let second = &idents[1].value;
let third = &idents[2].value;
if first.eq_ignore_ascii_case(table_name) {
if !column_names_lower.contains(&second.to_ascii_lowercase()) {
return Err(Error::InvalidArgumentError(format!(
"Column '{}' referenced in CHECK constraint does not exist",
second
)));
}
} else if second.eq_ignore_ascii_case(table_name) {
if !column_names_lower.contains(&third.to_ascii_lowercase()) {
return Err(Error::InvalidArgumentError(format!(
"Column '{}' referenced in CHECK constraint does not exist",
third
)));
}
} else {
return Err(Error::InvalidArgumentError(format!(
"CHECK constraint references column from different table '{}'",
second
)));
}
}
}
SqlExpr::BinaryOp { left, right, .. } => {
stack.push(left);
stack.push(right);
}
SqlExpr::UnaryOp { expr, .. } | SqlExpr::Nested(expr) => {
stack.push(expr);
}
SqlExpr::Value(_) | SqlExpr::TypedString { .. } => {}
_ => {}
}
}
Ok(())
}
/// Validates the per-column options of a `CREATE TABLE` with explicit columns.
///
/// `NULL`, `NOT NULL`, `UNIQUE`, `CHECK` and `FOREIGN KEY` options are
/// accepted. `DEFAULT` and every other column option are rejected with an
/// `InvalidArgumentError` naming the offending column.
fn validate_create_table_definition(stmt: &sqlparser::ast::CreateTable) -> SqlResult<()> {
    let column_options = stmt.columns.iter().flat_map(|column| {
        column
            .options
            .iter()
            .map(move |ColumnOptionDef { option, .. }| (column, option))
    });
    for (column, option) in column_options {
        match option {
            ColumnOption::Null
            | ColumnOption::NotNull
            | ColumnOption::Unique { .. }
            | ColumnOption::Check(_)
            | ColumnOption::ForeignKey { .. } => {}
            ColumnOption::Default(_) => {
                return Err(Error::InvalidArgumentError(format!(
                    "DEFAULT values are not supported for column '{}'",
                    column.name
                )));
            }
            other => {
                return Err(Error::InvalidArgumentError(format!(
                    "unsupported column option {:?} on '{}'",
                    other, column.name
                )));
            }
        }
    }
    Ok(())
}
fn validate_create_table_as(stmt: &sqlparser::ast::CreateTable) -> SqlResult<()> {
if !stmt.columns.is_empty() {
return Err(Error::InvalidArgumentError(
"CREATE TABLE AS SELECT does not support column definitions yet".into(),
));
}
Ok(())
}
fn validate_simple_query(query: &Query) -> SqlResult<()> {
if query.with.is_some() {
return Err(Error::InvalidArgumentError(
"WITH clauses are not supported".into(),
));
}
if let Some(limit_clause) = &query.limit_clause {
match limit_clause {
LimitClause::LimitOffset {
offset: Some(_), ..
}
| LimitClause::OffsetCommaLimit { .. } => {
return Err(Error::InvalidArgumentError(
"OFFSET clauses are not supported".into(),
));
}
LimitClause::LimitOffset { limit_by, .. } if !limit_by.is_empty() => {
return Err(Error::InvalidArgumentError(
"LIMIT BY clauses are not supported".into(),
));
}
_ => {}
}
}
if query.fetch.is_some() {
return Err(Error::InvalidArgumentError(
"FETCH clauses are not supported".into(),
));
}
Ok(())
}
fn resolve_column_name(expr: &SqlExpr) -> SqlResult<String> {
match expr {
SqlExpr::Identifier(ident) => Ok(ident.value.clone()),
SqlExpr::CompoundIdentifier(parts) => {
if let Some(last) = parts.last() {
Ok(last.value.clone())
} else {
Err(Error::InvalidArgumentError(
"empty column identifier".into(),
))
}
}
_ => Err(Error::InvalidArgumentError(
"aggregate arguments must be plain column identifiers".into(),
)),
}
}
fn validate_projection_alias_qualifiers(
projection_items: &[SelectItem],
alias: &str,
) -> SqlResult<()> {
let alias_lower = alias.to_ascii_lowercase();
for item in projection_items {
match item {
SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } => {
if let SqlExpr::CompoundIdentifier(parts) = expr
&& parts.len() >= 2
&& let Some(first) = parts.first()
&& !first.value.eq_ignore_ascii_case(&alias_lower)
{
return Err(Error::InvalidArgumentError(format!(
"Binder Error: table '{}' not found",
first.value
)));
}
}
_ => {}
}
}
Ok(())
}
/// Reports whether a scalar expression contains an aggregate call anywhere in
/// its tree.
///
/// The match is intentionally exhaustive so that adding a `ScalarExpr`
/// variant forces this function to be revisited.
#[allow(dead_code)]
fn expr_contains_aggregate(expr: &llkv_expr::expr::ScalarExpr<String>) -> bool {
    use llkv_expr::expr::ScalarExpr;

    match expr {
        ScalarExpr::Column(_) | ScalarExpr::Literal(_) => false,
        ScalarExpr::Aggregate(_) => true,
        ScalarExpr::GetField { base, .. } => expr_contains_aggregate(base),
        ScalarExpr::Binary { left, right, .. } => [left, right]
            .into_iter()
            .any(|side| expr_contains_aggregate(side)),
    }
}
fn try_parse_aggregate_function(
func: &sqlparser::ast::Function,
) -> SqlResult<Option<llkv_expr::expr::AggregateCall<String>>> {
use sqlparser::ast::{FunctionArg, FunctionArgExpr, FunctionArguments, ObjectNamePart};
if func.uses_odbc_syntax {
return Ok(None);
}
if !matches!(func.parameters, FunctionArguments::None) {
return Ok(None);
}
if func.filter.is_some()
|| func.null_treatment.is_some()
|| func.over.is_some()
|| !func.within_group.is_empty()
{
return Ok(None);
}
let func_name = if func.name.0.len() == 1 {
match &func.name.0[0] {
ObjectNamePart::Identifier(ident) => ident.value.to_ascii_lowercase(),
_ => return Ok(None),
}
} else {
return Ok(None);
};
let args_slice: &[FunctionArg] = match &func.args {
FunctionArguments::List(list) => {
if list.duplicate_treatment.is_some() || !list.clauses.is_empty() {
return Ok(None);
}
&list.args
}
FunctionArguments::None => &[],
FunctionArguments::Subquery(_) => return Ok(None),
};
let agg_call = match func_name.as_str() {
"count" => {
if args_slice.len() != 1 {
return Err(Error::InvalidArgumentError(
"COUNT accepts exactly one argument".into(),
));
}
match &args_slice[0] {
FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => {
llkv_expr::expr::AggregateCall::CountStar
}
FunctionArg::Unnamed(FunctionArgExpr::Expr(arg_expr)) => {
let column = resolve_column_name(arg_expr)?;
llkv_expr::expr::AggregateCall::Count(column)
}
_ => {
return Err(Error::InvalidArgumentError(
"unsupported COUNT argument".into(),
));
}
}
}
"sum" => {
if args_slice.len() != 1 {
return Err(Error::InvalidArgumentError(
"SUM accepts exactly one argument".into(),
));
}
let arg_expr = match &args_slice[0] {
FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => expr,
_ => {
return Err(Error::InvalidArgumentError(
"SUM requires a column argument".into(),
));
}
};
if let Some(column) = parse_count_nulls_case(arg_expr)? {
llkv_expr::expr::AggregateCall::CountNulls(column)
} else {
let column = resolve_column_name(arg_expr)?;
llkv_expr::expr::AggregateCall::Sum(column)
}
}
"min" => {
if args_slice.len() != 1 {
return Err(Error::InvalidArgumentError(
"MIN accepts exactly one argument".into(),
));
}
let arg_expr = match &args_slice[0] {
FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => expr,
_ => {
return Err(Error::InvalidArgumentError(
"MIN requires a column argument".into(),
));
}
};
let column = resolve_column_name(arg_expr)?;
llkv_expr::expr::AggregateCall::Min(column)
}
"max" => {
if args_slice.len() != 1 {
return Err(Error::InvalidArgumentError(
"MAX accepts exactly one argument".into(),
));
}
let arg_expr = match &args_slice[0] {
FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => expr,
_ => {
return Err(Error::InvalidArgumentError(
"MAX requires a column argument".into(),
));
}
};
let column = resolve_column_name(arg_expr)?;
llkv_expr::expr::AggregateCall::Max(column)
}
_ => return Ok(None),
};
Ok(Some(agg_call))
}
fn parse_count_nulls_case(expr: &SqlExpr) -> SqlResult<Option<String>> {
let SqlExpr::Case {
operand,
conditions,
else_result,
..
} = expr
else {
return Ok(None);
};
if operand.is_some() || conditions.len() != 1 {
return Ok(None);
}
let case_when = &conditions[0];
if !is_integer_literal(&case_when.result, 1) {
return Ok(None);
}
let else_expr = match else_result {
Some(expr) => expr.as_ref(),
None => return Ok(None),
};
if !is_integer_literal(else_expr, 0) {
return Ok(None);
}
let inner = match &case_when.condition {
SqlExpr::IsNull(inner) => inner.as_ref(),
_ => return Ok(None),
};
resolve_column_name(inner).map(Some)
}
/// Returns `true` when `expr` is a numeric literal whose text parses as an
/// `i64` equal to `expected`.
fn is_integer_literal(expr: &SqlExpr, expected: i64) -> bool {
    let SqlExpr::Value(ValueWithSpan {
        value: Value::Number(text, _),
        ..
    }) = expr
    else {
        return false;
    };
    text.parse::<i64>().is_ok_and(|parsed| parsed == expected)
}
fn translate_condition_with_context(
resolver: &IdentifierResolver<'_>,
context: IdentifierContext,
expr: &SqlExpr,
) -> SqlResult<llkv_expr::expr::Expr<'static, String>> {
match expr {
SqlExpr::IsNull(inner) => {
let scalar = translate_scalar_with_context(resolver, context, inner)?;
match scalar {
llkv_expr::expr::ScalarExpr::Column(column) => {
Ok(llkv_expr::expr::Expr::Pred(llkv_expr::expr::Filter {
field_id: column,
op: llkv_expr::expr::Operator::IsNull,
}))
}
_ => Err(Error::InvalidArgumentError(
"IS NULL predicates currently support column references only".into(),
)),
}
}
SqlExpr::IsNotNull(inner) => {
let scalar = translate_scalar_with_context(resolver, context, inner)?;
match scalar {
llkv_expr::expr::ScalarExpr::Column(column) => {
Ok(llkv_expr::expr::Expr::Pred(llkv_expr::expr::Filter {
field_id: column,
op: llkv_expr::expr::Operator::IsNotNull,
}))
}
_ => Err(Error::InvalidArgumentError(
"IS NOT NULL predicates currently support column references only".into(),
)),
}
}
SqlExpr::BinaryOp { left, op, right } => match op {
BinaryOperator::And => Ok(llkv_expr::expr::Expr::And(vec![
translate_condition_with_context(resolver, context, left)?,
translate_condition_with_context(resolver, context, right)?,
])),
BinaryOperator::Or => Ok(llkv_expr::expr::Expr::Or(vec![
translate_condition_with_context(resolver, context, left)?,
translate_condition_with_context(resolver, context, right)?,
])),
BinaryOperator::Eq
| BinaryOperator::NotEq
| BinaryOperator::Lt
| BinaryOperator::LtEq
| BinaryOperator::Gt
| BinaryOperator::GtEq => {
translate_comparison_with_context(resolver, context, left, op.clone(), right)
}
other => Err(Error::InvalidArgumentError(format!(
"unsupported binary operator in WHERE clause: {other:?}"
))),
},
SqlExpr::UnaryOp {
op: UnaryOperator::Not,
expr,
} => Ok(llkv_expr::expr::Expr::not(
translate_condition_with_context(resolver, context, expr)?,
)),
SqlExpr::Nested(inner) => translate_condition_with_context(resolver, context, inner),
other => Err(Error::InvalidArgumentError(format!(
"unsupported WHERE clause: {other:?}"
))),
}
}
fn translate_comparison_with_context(
resolver: &IdentifierResolver<'_>,
context: IdentifierContext,
left: &SqlExpr,
op: BinaryOperator,
right: &SqlExpr,
) -> SqlResult<llkv_expr::expr::Expr<'static, String>> {
let left_scalar = translate_scalar_with_context(resolver, context, left)?;
let right_scalar = translate_scalar_with_context(resolver, context, right)?;
let compare_op = match op {
BinaryOperator::Eq => llkv_expr::expr::CompareOp::Eq,
BinaryOperator::NotEq => llkv_expr::expr::CompareOp::NotEq,
BinaryOperator::Lt => llkv_expr::expr::CompareOp::Lt,
BinaryOperator::LtEq => llkv_expr::expr::CompareOp::LtEq,
BinaryOperator::Gt => llkv_expr::expr::CompareOp::Gt,
BinaryOperator::GtEq => llkv_expr::expr::CompareOp::GtEq,
other => {
return Err(Error::InvalidArgumentError(format!(
"unsupported comparison operator: {other:?}"
)));
}
};
if let (
llkv_expr::expr::ScalarExpr::Column(column),
llkv_expr::expr::ScalarExpr::Literal(literal),
) = (&left_scalar, &right_scalar)
&& let Some(op) = compare_op_to_filter_operator(compare_op, literal)
{
return Ok(llkv_expr::expr::Expr::Pred(llkv_expr::expr::Filter {
field_id: column.clone(),
op,
}));
}
if let (
llkv_expr::expr::ScalarExpr::Literal(literal),
llkv_expr::expr::ScalarExpr::Column(column),
) = (&left_scalar, &right_scalar)
&& let Some(flipped) = flip_compare_op(compare_op)
&& let Some(op) = compare_op_to_filter_operator(flipped, literal)
{
return Ok(llkv_expr::expr::Expr::Pred(llkv_expr::expr::Filter {
field_id: column.clone(),
op,
}));
}
Ok(llkv_expr::expr::Expr::Compare {
left: left_scalar,
op: compare_op,
right: right_scalar,
})
}
/// Maps a comparison against `literal` to the equivalent filter operator.
///
/// Returns `None` for `NotEq`, which has no corresponding filter operator
/// here; every other comparison yields an operator owning a clone of
/// `literal`.
fn compare_op_to_filter_operator(
    op: llkv_expr::expr::CompareOp,
    literal: &Literal,
) -> Option<llkv_expr::expr::Operator<'static>> {
    use llkv_expr::expr::{CompareOp, Operator};

    let operator = match op {
        CompareOp::NotEq => return None,
        CompareOp::Eq => Operator::Equals(literal.clone()),
        CompareOp::Lt => Operator::LessThan(literal.clone()),
        CompareOp::LtEq => Operator::LessThanOrEquals(literal.clone()),
        CompareOp::Gt => Operator::GreaterThan(literal.clone()),
        CompareOp::GtEq => Operator::GreaterThanOrEquals(literal.clone()),
    };
    Some(operator)
}
/// Returns the comparison that holds when the two operands are swapped
/// (`a < b` ⇔ `b > a`).
///
/// `Eq` maps to itself; `NotEq` yields `None`.
fn flip_compare_op(op: llkv_expr::expr::CompareOp) -> Option<llkv_expr::expr::CompareOp> {
    use llkv_expr::expr::CompareOp;

    let mirrored = match op {
        CompareOp::NotEq => return None,
        CompareOp::Eq => CompareOp::Eq,
        CompareOp::Lt => CompareOp::Gt,
        CompareOp::LtEq => CompareOp::GtEq,
        CompareOp::Gt => CompareOp::Lt,
        CompareOp::GtEq => CompareOp::LtEq,
    };
    Some(mirrored)
}
/// Translates a scalar SQL expression, resolving top-level identifiers
/// through `resolver`.
///
/// A bare or compound identifier is resolved with `context` and converted via
/// `into_scalar_expr`. Every other expression is handed to `translate_scalar`,
/// which does not consult the resolver — so identifiers nested inside, say, an
/// arithmetic expression are translated without resolution.
///
/// # Errors
/// Returns `InvalidArgumentError` for an empty compound identifier, and
/// propagates resolver and `translate_scalar` errors.
fn translate_scalar_with_context(
    resolver: &IdentifierResolver<'_>,
    context: IdentifierContext,
    expr: &SqlExpr,
) -> SqlResult<llkv_expr::expr::ScalarExpr<String>> {
    let parts: Vec<String> = match expr {
        SqlExpr::Identifier(ident) => vec![ident.value.clone()],
        SqlExpr::CompoundIdentifier(idents) if idents.is_empty() => {
            return Err(Error::InvalidArgumentError(
                "invalid compound identifier".into(),
            ));
        }
        SqlExpr::CompoundIdentifier(idents) => {
            idents.iter().map(|ident| ident.value.clone()).collect()
        }
        _ => return translate_scalar(expr),
    };
    let resolution = resolver.resolve(&parts, context)?;
    Ok(resolution.into_scalar_expr())
}
fn translate_scalar(expr: &SqlExpr) -> SqlResult<llkv_expr::expr::ScalarExpr<String>> {
match expr {
SqlExpr::Identifier(ident) => Ok(llkv_expr::expr::ScalarExpr::column(ident.value.clone())),
SqlExpr::CompoundIdentifier(idents) => {
if idents.is_empty() {
return Err(Error::InvalidArgumentError(
"invalid compound identifier".into(),
));
}
let column_name = idents[0].value.clone();
let mut result = llkv_expr::expr::ScalarExpr::column(column_name);
for part in &idents[1..] {
let field_name = part.value.clone();
result = llkv_expr::expr::ScalarExpr::get_field(result, field_name);
}
Ok(result)
}
SqlExpr::Value(value) => literal_from_value(value),
SqlExpr::BinaryOp { left, op, right } => {
let left_expr = translate_scalar(left)?;
let right_expr = translate_scalar(right)?;
let op = match op {
BinaryOperator::Plus => llkv_expr::expr::BinaryOp::Add,
BinaryOperator::Minus => llkv_expr::expr::BinaryOp::Subtract,
BinaryOperator::Multiply => llkv_expr::expr::BinaryOp::Multiply,
BinaryOperator::Divide => llkv_expr::expr::BinaryOp::Divide,
BinaryOperator::Modulo => llkv_expr::expr::BinaryOp::Modulo,
other => {
return Err(Error::InvalidArgumentError(format!(
"unsupported scalar binary operator: {other:?}"
)));
}
};
Ok(llkv_expr::expr::ScalarExpr::binary(
left_expr, op, right_expr,
))
}
SqlExpr::UnaryOp {
op: UnaryOperator::Minus,
expr,
} => match translate_scalar(expr)? {
llkv_expr::expr::ScalarExpr::Literal(lit) => match lit {
Literal::Integer(v) => {
Ok(llkv_expr::expr::ScalarExpr::literal(Literal::Integer(-v)))
}
Literal::Float(v) => Ok(llkv_expr::expr::ScalarExpr::literal(Literal::Float(-v))),
Literal::Boolean(_) => Err(Error::InvalidArgumentError(
"cannot negate boolean literal".into(),
)),
Literal::String(_) => Err(Error::InvalidArgumentError(
"cannot negate string literal".into(),
)),
Literal::Struct(_) => Err(Error::InvalidArgumentError(
"cannot negate struct literal".into(),
)),
Literal::Null => Err(Error::InvalidArgumentError(
"cannot negate null literal".into(),
)),
},
_ => Err(Error::InvalidArgumentError(
"cannot negate non-literal expression".into(),
)),
},
SqlExpr::UnaryOp {
op: UnaryOperator::Plus,
expr,
} => translate_scalar(expr),
SqlExpr::Nested(inner) => translate_scalar(inner),
SqlExpr::Function(func) => {
if let Some(agg_call) = try_parse_aggregate_function(func)? {
Ok(llkv_expr::expr::ScalarExpr::aggregate(agg_call))
} else {
Err(Error::InvalidArgumentError(format!(
"unsupported function in scalar expression: {:?}",
func.name
)))
}
}
SqlExpr::Dictionary(fields) => {
let mut struct_fields = Vec::new();
for entry in fields {
let key = entry.key.value.clone(); let value_expr = translate_scalar(&entry.value)?;
match value_expr {
llkv_expr::expr::ScalarExpr::Literal(lit) => {
struct_fields.push((key, Box::new(lit)));
}
_ => {
return Err(Error::InvalidArgumentError(
"Dictionary values must be literals".to_string(),
));
}
}
}
Ok(llkv_expr::expr::ScalarExpr::literal(Literal::Struct(
struct_fields,
)))
}
other => Err(Error::InvalidArgumentError(format!(
"unsupported scalar expression: {other:?}"
))),
}
}
fn literal_from_value(value: &ValueWithSpan) -> SqlResult<llkv_expr::expr::ScalarExpr<String>> {
match &value.value {
Value::Number(text, _) => {
if text.contains(['.', 'e', 'E']) {
let parsed = text.parse::<f64>().map_err(|err| {
Error::InvalidArgumentError(format!("invalid float literal: {err}"))
})?;
Ok(llkv_expr::expr::ScalarExpr::literal(Literal::Float(parsed)))
} else {
let parsed = text.parse::<i128>().map_err(|err| {
Error::InvalidArgumentError(format!("invalid integer literal: {err}"))
})?;
Ok(llkv_expr::expr::ScalarExpr::literal(Literal::Integer(
parsed,
)))
}
}
Value::Boolean(value) => Ok(llkv_expr::expr::ScalarExpr::literal(Literal::Boolean(
*value,
))),
Value::Null => Ok(llkv_expr::expr::ScalarExpr::literal(Literal::Null)),
other => {
if let Some(text) = other.clone().into_string() {
Ok(llkv_expr::expr::ScalarExpr::literal(Literal::String(text)))
} else {
Err(Error::InvalidArgumentError(format!(
"unsupported literal: {other:?}"
)))
}
}
}
}
fn resolve_assignment_column_name(target: &AssignmentTarget) -> SqlResult<String> {
match target {
AssignmentTarget::ColumnName(name) => {
if name.0.len() != 1 {
return Err(Error::InvalidArgumentError(
"qualified column names in UPDATE assignments are not supported yet".into(),
));
}
match &name.0[0] {
ObjectNamePart::Identifier(ident) => Ok(ident.value.clone()),
other => Err(Error::InvalidArgumentError(format!(
"unsupported column reference in UPDATE assignment: {other:?}"
))),
}
}
AssignmentTarget::Tuple(_) => Err(Error::InvalidArgumentError(
"tuple assignments are not supported yet".into(),
)),
}
}
fn arrow_type_from_sql(data_type: &SqlDataType) -> SqlResult<arrow::datatypes::DataType> {
use arrow::datatypes::DataType;
match data_type {
SqlDataType::Int(_)
| SqlDataType::Integer(_)
| SqlDataType::BigInt(_)
| SqlDataType::SmallInt(_)
| SqlDataType::TinyInt(_) => Ok(DataType::Int64),
SqlDataType::Float(_)
| SqlDataType::Real
| SqlDataType::Double(_)
| SqlDataType::DoublePrecision => Ok(DataType::Float64),
SqlDataType::Text
| SqlDataType::String(_)
| SqlDataType::Varchar(_)
| SqlDataType::Char(_)
| SqlDataType::Uuid => Ok(DataType::Utf8),
SqlDataType::Date => Ok(DataType::Date32),
SqlDataType::Decimal(_) | SqlDataType::Numeric(_) => Ok(DataType::Float64),
SqlDataType::Boolean => Ok(DataType::Boolean),
SqlDataType::Custom(name, args) => {
if name.0.len() == 1
&& let ObjectNamePart::Identifier(ident) = &name.0[0]
&& ident.value.eq_ignore_ascii_case("row")
{
return row_type_to_arrow(data_type, args);
}
Err(Error::InvalidArgumentError(format!(
"unsupported SQL data type: {data_type:?}"
)))
}
other => Err(Error::InvalidArgumentError(format!(
"unsupported SQL data type: {other:?}"
))),
}
}
/// Converts a `ROW(...)` type into an Arrow struct type.
///
/// `tokens` are the type's raw argument tokens; `data_type` is only used to
/// render the type in error messages. Each parsed field becomes a nullable
/// Arrow field whose type comes from `arrow_type_from_sql`.
///
/// # Errors
/// Returns `InvalidArgumentError` when `tokens` is empty, when the field list
/// cannot be parsed, or when a field type is unsupported.
fn row_type_to_arrow(
    data_type: &SqlDataType,
    tokens: &[String],
) -> SqlResult<arrow::datatypes::DataType> {
    use arrow::datatypes::{DataType, Field, FieldRef, Fields};

    if tokens.is_empty() {
        return Err(Error::InvalidArgumentError(
            "ROW type must define at least one field".into(),
        ));
    }

    let field_definitions =
        resolve_row_field_types(tokens, &GenericDialect {}).map_err(|err| {
            let row_str = data_type.to_string();
            Error::InvalidArgumentError(format!("unable to parse ROW type '{row_str}': {err}"))
        })?;

    let fields = field_definitions
        .into_iter()
        .map(|(field_name, field_type)| -> SqlResult<FieldRef> {
            let arrow_field_type = arrow_type_from_sql(&field_type)?;
            Ok(Arc::new(Field::new(field_name, arrow_field_type, true)))
        })
        .collect::<SqlResult<Vec<FieldRef>>>()?;

    Ok(DataType::Struct(Fields::from(fields)))
}
/// Parses the raw argument tokens of a `ROW` type into `(field name, SQL type)`
/// pairs.
///
/// The token list may optionally be wrapped in one `"("` ... `")"` pair, which
/// is stripped. `","` tokens are treated as separators and skipped. Each field
/// is one name token (normalised by `normalize_row_field_name`) followed by
/// one or more type tokens.
///
/// Because a type may span several tokens, the type is found greedily: ever
/// longer runs of tokens are joined with spaces and handed to
/// `parse_sql_data_type`, and the longest run that parses wins.
///
/// NOTE(review): every candidate run invokes `parse_sql_data_type` again, so
/// the work per field grows with the number of tokens scanned.
///
/// # Errors
/// Returns `InvalidArgumentError` when the token list is empty, parentheses
/// are unbalanced, no field definitions remain after stripping, a field has no
/// type tokens, or no candidate run parses as a type.
fn resolve_row_field_types(
tokens: &[String],
dialect: &GenericDialect,
) -> SqlResult<Vec<(String, SqlDataType)>> {
if tokens.is_empty() {
return Err(Error::InvalidArgumentError(
"ROW type must define at least one field".into(),
));
}
// `start..end` delimits the tokens that hold field definitions.
let mut start = 0;
let mut end = tokens.len();
if tokens[start] == "(" {
// `end == 0` cannot hold here because `tokens` is non-empty; the check
// only guards the `end - 1` index.
if end == 0 || tokens[end - 1] != ")" {
return Err(Error::InvalidArgumentError(
"ROW type is missing closing ')'".into(),
));
}
start += 1;
end -= 1;
} else if tokens[end - 1] == ")" {
return Err(Error::InvalidArgumentError(
"ROW type contains unmatched ')'".into(),
));
}
let slice = &tokens[start..end];
if slice.is_empty() {
return Err(Error::InvalidArgumentError(
"ROW type did not provide any field definitions".into(),
));
}
let mut fields = Vec::new();
let mut index = 0;
while index < slice.len() {
// Skip stray separators between (or before) fields.
if slice[index] == "," {
index += 1;
continue;
}
// First token of a field is its name; everything after is type text.
let field_name = normalize_row_field_name(&slice[index])?;
index += 1;
if index >= slice.len() {
return Err(Error::InvalidArgumentError(format!(
"ROW field '{field_name}' is missing a type specification"
)));
}
// Longest token run `slice[index..type_end]` that parsed as a type, stored
// together with the exclusive end position of that run.
let mut last_success: Option<(usize, SqlDataType)> = None;
let mut type_end = index;
while type_end <= slice.len() {
let candidate = slice[index..type_end].join(" ");
// The first pass has an empty range (`index..index`); just widen it.
if candidate.trim().is_empty() {
type_end += 1;
continue;
}
// Parse failures are ignored: a longer run may still succeed.
if let Ok(parsed_type) = parse_sql_data_type(&candidate, dialect) {
last_success = Some((type_end, parsed_type));
}
if type_end == slice.len() {
break;
}
// A separator ends the field only once some run has already parsed.
if slice[type_end] == "," && last_success.is_some() {
break;
}
type_end += 1;
}
let Some((next_index, data_type)) = last_success else {
return Err(Error::InvalidArgumentError(format!(
"failed to parse ROW field type for '{field_name}'"
)));
};
fields.push((field_name, data_type));
// Resume right after the tokens consumed by the winning run, not where the
// scan stopped.
index = next_index;
if index < slice.len() && slice[index] == "," {
index += 1;
}
}
// Reached when the slice held nothing but separators.
if fields.is_empty() {
return Err(Error::InvalidArgumentError(
"ROW type did not provide any field definitions".into(),
));
}
Ok(fields)
}
fn normalize_row_field_name(raw: &str) -> SqlResult<String> {
let trimmed = raw.trim();
if trimmed.is_empty() {
return Err(Error::InvalidArgumentError(
"ROW field name must not be empty".into(),
));
}
if let Some(stripped) = trimmed.strip_prefix('"') {
let without_end = stripped.strip_suffix('"').ok_or_else(|| {
Error::InvalidArgumentError(format!("unterminated quoted ROW field name: {trimmed}"))
})?;
let name = without_end.replace("\"\"", "\"");
return Ok(name);
}
Ok(trimmed.to_string())
}
fn parse_sql_data_type(type_str: &str, dialect: &GenericDialect) -> SqlResult<SqlDataType> {
let trimmed = type_str.trim();
let sql = format!("CREATE TABLE __row(__field {trimmed});");
let statements = Parser::parse_sql(dialect, &sql).map_err(|err| {
Error::InvalidArgumentError(format!("failed to parse ROW field type '{trimmed}': {err}"))
})?;
let stmt = statements.into_iter().next().ok_or_else(|| {
Error::InvalidArgumentError(format!(
"ROW field type '{trimmed}' did not produce a statement"
))
})?;
match stmt {
Statement::CreateTable(table) => table
.columns
.first()
.map(|col| col.data_type.clone())
.ok_or_else(|| {
Error::InvalidArgumentError(format!(
"ROW field type '{trimmed}' missing column definition"
))
}),
other => Err(Error::InvalidArgumentError(format!(
"unexpected statement while parsing ROW field type: {other:?}"
))),
}
}
/// Evaluates a `SELECT` that has no `FROM` clause into a single row of
/// constant values.
///
/// Returns `Ok(None)` when the statement has a `FROM` clause (it is not a
/// constant select). Otherwise every projection item must be a plain or
/// aliased expression that `SqlValue::try_from_expr` can convert; the
/// converted values form the one row that is returned.
///
/// # Errors
///
/// Returns [`Error::InvalidArgumentError`] when the statement uses any of the
/// clauses checked below (WHERE, HAVING, named windows, QUALIFY, DISTINCT,
/// TOP, INTO, PREWHERE, lateral views, value-table mode, GROUP BY), when the
/// projection list is empty, or when a projection item is a wildcard or other
/// unsupported form. Errors from `SqlValue::try_from_expr` are propagated.
fn extract_constant_select_rows(select: &Select) -> SqlResult<Option<Vec<Vec<PlanValue>>>> {
    if !select.from.is_empty() {
        return Ok(None);
    }

    let uses_advanced_clause = select.selection.is_some()
        || select.having.is_some()
        || !select.named_window.is_empty()
        || select.qualify.is_some()
        || select.distinct.is_some()
        || select.top.is_some()
        || select.into.is_some()
        || select.prewhere.is_some()
        || !select.lateral_views.is_empty()
        || select.value_table_mode.is_some()
        || !group_by_is_empty(&select.group_by);
    if uses_advanced_clause {
        return Err(Error::InvalidArgumentError(
            "constant SELECT statements do not support advanced clauses".into(),
        ));
    }

    if select.projection.is_empty() {
        return Err(Error::InvalidArgumentError(
            "constant SELECT requires at least one projection".into(),
        ));
    }

    // Convert items in order; the first failing item aborts the whole row.
    let row = select
        .projection
        .iter()
        .map(|item| -> SqlResult<PlanValue> {
            let expr = match item {
                SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } => expr,
                other => {
                    return Err(Error::InvalidArgumentError(format!(
                        "unsupported projection in constant SELECT: {other:?}"
                    )));
                }
            };
            let value = SqlValue::try_from_expr(expr)?;
            Ok(PlanValue::from(value))
        })
        .collect::<SqlResult<Vec<PlanValue>>>()?;

    Ok(Some(vec![row]))
}
fn extract_single_table(from: &[TableWithJoins]) -> SqlResult<(String, String)> {
if from.len() != 1 {
return Err(Error::InvalidArgumentError(
"queries over multiple tables are not supported yet".into(),
));
}
let item = &from[0];
if !item.joins.is_empty() {
return Err(Error::InvalidArgumentError(
"JOIN clauses are not supported yet".into(),
));
}
match &item.relation {
TableFactor::Table { name, .. } => canonical_object_name(name),
_ => Err(Error::InvalidArgumentError(
"queries require a plain table name".into(),
)),
}
}
/// Converts every entry of a `FROM` clause into a planner table reference.
///
/// Each entry must be a join-free, plain table factor. Its name is split with
/// `parse_schema_qualified_name`; a missing schema becomes the default
/// (empty) value. References are returned in clause order, and an empty
/// clause yields an empty vector.
///
/// # Errors
///
/// Returns [`Error::InvalidArgumentError`] for the first entry that has joins
/// or whose relation is not a plain table name. Errors from
/// `parse_schema_qualified_name` are propagated.
fn extract_tables(from: &[TableWithJoins]) -> SqlResult<Vec<llkv_plan::TableRef>> {
    from.iter()
        .map(|item| -> SqlResult<llkv_plan::TableRef> {
            if !item.joins.is_empty() {
                return Err(Error::InvalidArgumentError(
                    "JOIN clauses are not supported yet".into(),
                ));
            }
            let TableFactor::Table { name, .. } = &item.relation else {
                return Err(Error::InvalidArgumentError(
                    "queries require a plain table name".into(),
                ));
            };
            let (schema_opt, table) = parse_schema_qualified_name(name)?;
            Ok(llkv_plan::TableRef::new(
                schema_opt.unwrap_or_default(),
                table,
            ))
        })
        .collect()
}
/// Reports whether a `GROUP BY` clause is effectively absent.
///
/// Only the expression-list form with no expressions and no modifiers counts
/// as empty; every other variant is treated as non-empty.
fn group_by_is_empty(expr: &GroupByExpr) -> bool {
    if let GroupByExpr::Expressions(exprs, modifiers) = expr {
        exprs.is_empty() && modifiers.is_empty()
    } else {
        false
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, Int64Array, StringArray};
    use arrow::record_batch::RecordBatch;
    use llkv_storage::pager::MemPager;

    /// Builds a SQL engine backed by a fresh in-memory pager.
    fn new_engine() -> SqlEngine<MemPager> {
        SqlEngine::new(Arc::new(MemPager::default()))
    }

    /// Executes `sql`, expects its first result to be a SELECT, and collects
    /// all record batches that SELECT produces.
    fn select_batches(engine: &SqlEngine<MemPager>, sql: &str) -> Vec<RecordBatch> {
        let mut results = engine.execute(sql).expect("select rows");
        match results.remove(0) {
            RuntimeStatementResult::Select { execution, .. } => {
                execution.collect().expect("collect batches")
            }
            _ => panic!("expected select result"),
        }
    }

    /// Flattens the first column of every batch into owned optional strings,
    /// preserving row order and mapping SQL NULL to `None`.
    fn extract_string_options(batches: &[RecordBatch]) -> Vec<Option<String>> {
        let mut values: Vec<Option<String>> = Vec::new();
        for batch in batches {
            let column = batch
                .column(0)
                .as_any()
                .downcast_ref::<StringArray>()
                .expect("string column");
            values.extend(column.iter().map(|value| value.map(str::to_string)));
        }
        values
    }

    /// Creates a table, inserts two rows, and reads one back through a
    /// filtered SELECT.
    #[test]
    fn create_insert_select_roundtrip() {
        let engine = new_engine();

        let result = engine
            .execute("CREATE TABLE people (id INT NOT NULL, name TEXT NOT NULL)")
            .expect("create table");
        assert!(matches!(
            result[0],
            RuntimeStatementResult::CreateTable { .. }
        ));

        let result = engine
            .execute("INSERT INTO people (id, name) VALUES (1, 'alice'), (2, 'bob')")
            .expect("insert rows");
        assert!(matches!(
            result[0],
            RuntimeStatementResult::Insert {
                rows_inserted: 2,
                ..
            }
        ));

        let batches = select_batches(&engine, "SELECT name FROM people WHERE id = 2");
        assert_eq!(batches.len(), 1);
        let column = batches[0]
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .expect("string column");
        assert_eq!(column.len(), 1);
        assert_eq!(column.value(0), "bob");
    }

    /// `INSERT ... SELECT <constant>` inserts exactly one row, including when
    /// the constant is a typed NULL.
    #[test]
    fn insert_select_constant_including_null() {
        let engine = new_engine();
        engine
            .execute("CREATE TABLE integers(i INTEGER)")
            .expect("create table");

        let result = engine
            .execute("INSERT INTO integers SELECT 42")
            .expect("insert literal");
        assert!(matches!(
            result[0],
            RuntimeStatementResult::Insert {
                rows_inserted: 1,
                ..
            }
        ));

        let result = engine
            .execute("INSERT INTO integers SELECT CAST(NULL AS VARCHAR)")
            .expect("insert null literal");
        assert!(matches!(
            result[0],
            RuntimeStatementResult::Insert {
                rows_inserted: 1,
                ..
            }
        ));

        let batches = select_batches(&engine, "SELECT * FROM integers");
        let mut values: Vec<Option<i64>> = Vec::new();
        for batch in &batches {
            let column = batch
                .column(0)
                .as_any()
                .downcast_ref::<Int64Array>()
                .expect("int column");
            values.extend(column.iter());
        }
        // NOTE(review): the query has no ORDER BY, so this assertion relies on
        // rows coming back in insertion order.
        assert_eq!(values, vec![Some(42), None]);
    }

    /// An UPDATE with a WHERE clause touches only the matching row.
    #[test]
    fn update_with_where_clause_filters_rows() {
        let engine = new_engine();
        engine
            .execute("SET default_null_order='nulls_first'")
            .expect("set default null order");
        engine
            .execute("CREATE TABLE strings(a VARCHAR)")
            .expect("create table");
        engine
            .execute("INSERT INTO strings VALUES ('3'), ('4'), (NULL)")
            .expect("insert seed rows");

        let result = engine
            .execute("UPDATE strings SET a = 13 WHERE a = '3'")
            .expect("update rows");
        assert!(matches!(
            result[0],
            RuntimeStatementResult::Update {
                rows_updated: 1,
                ..
            }
        ));

        let batches = select_batches(
            &engine,
            "SELECT * FROM strings ORDER BY cast(a AS INTEGER)",
        );
        let mut values = extract_string_options(&batches);
        // Re-sort locally (NULLs first, then numerically) so this test checks
        // which rows exist rather than the engine's ordering; ordering is
        // covered by `order_by_honors_configured_default_null_order`.
        values.sort_by(|a, b| match (a, b) {
            (None, None) => std::cmp::Ordering::Equal,
            (None, Some(_)) => std::cmp::Ordering::Less,
            (Some(_), None) => std::cmp::Ordering::Greater,
            (Some(av), Some(bv)) => {
                let a_val = av.parse::<i64>().unwrap_or_default();
                let b_val = bv.parse::<i64>().unwrap_or_default();
                a_val.cmp(&b_val)
            }
        });
        assert_eq!(
            values,
            vec![None, Some("4".to_string()), Some("13".to_string())]
        );
    }

    /// ORDER BY places NULLs last until `default_null_order` is switched to
    /// `nulls_first`, after which NULLs sort first.
    #[test]
    fn order_by_honors_configured_default_null_order() {
        let engine = new_engine();
        engine
            .execute("CREATE TABLE strings(a VARCHAR)")
            .expect("create table");
        engine
            .execute("INSERT INTO strings VALUES ('3'), ('4'), (NULL)")
            .expect("insert values");
        engine
            .execute("UPDATE strings SET a = 13 WHERE a = '3'")
            .expect("update value");

        let ordered_query = "SELECT * FROM strings ORDER BY cast(a AS INTEGER)";

        // Before the setting is changed: NULL comes last.
        let batches = select_batches(&engine, ordered_query);
        let values = extract_string_options(&batches);
        assert_eq!(
            values,
            vec![Some("4".to_string()), Some("13".to_string()), None]
        );
        assert!(!engine.default_nulls_first_for_tests());

        engine
            .execute("SET default_null_order='nulls_first'")
            .expect("set default null order");
        assert!(engine.default_nulls_first_for_tests());

        // After the setting is changed: NULL comes first.
        let batches = select_batches(&engine, ordered_query);
        let values = extract_string_options(&batches);
        assert_eq!(
            values,
            vec![None, Some("4".to_string()), Some("13".to_string())]
        );
    }

    /// A SQL `ROW(...)` column type converts to an Arrow struct whose fields
    /// keep their names, order, and mapped element types.
    #[test]
    fn arrow_type_from_row_returns_struct_fields() {
        let dialect = GenericDialect {};
        let statements = Parser::parse_sql(
            &dialect,
            "CREATE TABLE row_types(payload ROW(a INTEGER, b VARCHAR));",
        )
        .expect("parse ROW type definition");
        let data_type = match &statements[0] {
            Statement::CreateTable(stmt) => stmt.columns[0].data_type.clone(),
            other => panic!("unexpected statement: {other:?}"),
        };

        let arrow_type = arrow_type_from_sql(&data_type).expect("convert ROW type");
        match arrow_type {
            arrow::datatypes::DataType::Struct(fields) => {
                assert_eq!(fields.len(), 2, "unexpected field count");
                assert_eq!(fields[0].name(), "a");
                assert_eq!(fields[1].name(), "b");
                assert_eq!(fields[0].data_type(), &arrow::datatypes::DataType::Int64);
                assert_eq!(fields[1].data_type(), &arrow::datatypes::DataType::Utf8);
            }
            other => panic!("expected struct type, got {other:?}"),
        }
    }
}