#![deny(
clippy::unwrap_used,
clippy::todo,
clippy::unimplemented,
)]
#![warn(
clippy::expect_used,
clippy::panic,
clippy::indexing_slicing,
)]
#![allow(missing_docs)]
#![warn(rust_2018_idioms)]
#![warn(
clippy::pedantic,
clippy::nursery,
clippy::cargo,
)]
#![allow(clippy::cargo_common_metadata)]
#![allow(
clippy::module_name_repetitions,
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::must_use_candidate,
// Stylistic preferences - not safety issues
clippy::similar_names,
clippy::redundant_else,
clippy::needless_continue,
clippy::needless_pass_by_ref_mut,
clippy::uninlined_format_args,
clippy::redundant_closure_for_method_calls,
clippy::match_same_arms,
clippy::doc_markdown,
clippy::items_after_statements,
clippy::option_if_let_else,
clippy::struct_excessive_bools,
clippy::unused_self,
clippy::unused_async,
clippy::return_self_not_must_use,
clippy::if_not_else,
clippy::manual_let_else,
clippy::single_char_add_str,
clippy::unreadable_literal,
clippy::needless_raw_string_hashes,
clippy::or_fun_call,
clippy::derive_partial_eq_without_eq,
clippy::redundant_clone,
clippy::map_unwrap_or,
clippy::needless_borrow,
clippy::format_push_string,
clippy::default_trait_access,
clippy::empty_line_after_doc_comments,
clippy::needless_pass_by_value,
clippy::wildcard_enum_match_arm,
clippy::match_wildcard_for_single_variants,
clippy::suboptimal_flops,
clippy::wildcard_imports,
clippy::ref_option,
clippy::needless_collect,
clippy::bool_to_int_with_if,
clippy::useless_format,
clippy::used_underscore_binding,
clippy::str_to_string,
clippy::implicit_hasher,
clippy::string_add_assign,
clippy::explicit_iter_loop,
clippy::single_match_else,
clippy::manual_string_new,
clippy::derivable_impls,
clippy::too_many_arguments,
clippy::too_many_lines,
clippy::branches_sharing_code,
clippy::manual_strip,
clippy::upper_case_acronyms,
clippy::struct_field_names,
clippy::assigning_clones,
clippy::should_implement_trait,
clippy::boxed_local,
clippy::collapsible_if,
clippy::field_reassign_with_default,
clippy::unnecessary_cast,
clippy::type_complexity,
clippy::manual_is_ascii_check,
clippy::borrow_as_ptr,
clippy::cognitive_complexity,
clippy::fn_params_excessive_bools,
clippy::iter_without_into_iter,
clippy::unit_cmp,
clippy::ptr_arg,
clippy::use_debug,
clippy::redundant_closure,
clippy::clone_on_copy,
clippy::new_without_default,
clippy::manual_range_contains,
clippy::manual_range_patterns,
clippy::if_then_some_else_none,
clippy::match_like_matches_macro,
clippy::option_as_ref_cloned,
clippy::collapsible_match,
clippy::filter_map_identity,
clippy::get_first,
clippy::implicit_clone,
clippy::len_zero,
clippy::write_with_newline,
clippy::single_char_pattern,
clippy::let_and_return,
clippy::redundant_pattern_matching,
clippy::match_ref_pats,
clippy::if_same_then_else,
clippy::semicolon_if_nothing_returned,
clippy::iter_over_hash_type,
clippy::iter_on_single_items,
clippy::iter_on_empty_collections,
clippy::useless_vec,
clippy::vec_init_then_push,
clippy::iter_nth_zero,
clippy::unwrap_or_default,
clippy::trivial_regex,
clippy::map_entry,
clippy::enum_glob_use,
clippy::unnested_or_patterns,
clippy::manual_clamp,
clippy::cast_ptr_alignment,
clippy::ptr_as_ptr,
clippy::imprecise_flops,
clippy::future_not_send,
clippy::significant_drop_in_scrutinee,
clippy::collection_is_never_read,
clippy::manual_div_ceil,
clippy::checked_conversions,
clippy::as_underscore,
clippy::as_ptr_cast_mut,
clippy::trim_split_whitespace,
clippy::string_lit_chars_any,
clippy::large_enum_variant,
clippy::doc_lazy_continuation,
clippy::too_long_first_doc_paragraph,
clippy::useless_conversion,
clippy::multiple_crate_versions,
clippy::unit_arg,
clippy::inherent_to_string,
clippy::to_string_trait_impl,
clippy::borrow_deref_ref,
clippy::manual_map,
clippy::manual_filter_map,
clippy::option_map_unit_fn,
clippy::result_map_unit_fn,
clippy::manual_is_multiple_of,
clippy::print_literal,
clippy::iter_kv_map,
clippy::manual_find,
clippy::write_literal,
clippy::explicit_into_iter_loop,
clippy::manual_ok_or,
clippy::bind_instead_of_map,
clippy::manual_retain,
clippy::io_other_error,
clippy::clone_on_ref_ptr,
clippy::bool_comparison,
clippy::single_match,
clippy::iter_next_loop,
clippy::str_split_at_newline,
clippy::option_as_ref_deref,
clippy::arithmetic_side_effects,
clippy::cloned_instead_of_copied,
clippy::string_slice,
clippy::inconsistent_struct_constructor,
clippy::unnecessary_literal_unwrap,
clippy::ref_binding_to_reference,
clippy::match_bool,
clippy::partialeq_to_none,
clippy::redundant_static_lifetimes,
clippy::char_lit_as_u8,
clippy::manual_is_power_of_two,
clippy::filter_map_bool_then,
clippy::manual_flatten,
clippy::manual_next_back,
clippy::maybe_infinite_iter,
clippy::needless_option_as_deref,
clippy::suspicious_else_formatting,
clippy::useless_transmute,
// Casting lints - intentional in numeric code
clippy::cast_precision_loss,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_sign_loss,
clippy::cast_lossless,
// Performance lints that add noise
clippy::trivially_copy_pass_by_ref,
clippy::significant_drop_tightening,
clippy::unnecessary_wraps,
clippy::missing_const_for_fn,
clippy::use_self,
// Lazy static using once_cell is idiomatic
clippy::non_std_lazy_statics,
)]
#![allow(dead_code)]
#![allow(unused_variables)]
pub mod storage;
pub mod compute;
pub mod optimizer;
pub mod vector;
pub mod protocol;
pub mod protocols; pub mod crypto;
pub mod tenant;
pub mod sql;
pub mod audit;
pub mod network;
pub mod repl;
pub mod api; pub mod cli; pub mod session; pub mod ai; pub mod multi_tenant; pub mod git_integration; pub mod runtime; pub mod graph; pub mod search;
#[cfg(feature = "code-graph")]
pub mod code_graph;
#[cfg(feature = "graph-rag")]
pub mod graph_rag;
#[cfg(feature = "mcp-endpoint")]
pub mod mcp;
#[cfg(any(
feature = "ha-tier1",
feature = "ha-tier2",
feature = "ha-tier3",
feature = "ha-dedup",
feature = "ha-branch-replication"
))]
pub mod replication;
#[cfg(feature = "ha-ab-testing")]
pub mod ab_testing;
mod error;
mod types;
pub mod config;
mod embedded_db_dump;
pub use error::{Error, Result};
pub use types::{DataType, Value, Tuple, Schema, Column, ColumnStorageMode, VectorStoreInfo, AgentSession, AgentMessage, DocumentData, DocumentMetadata};
pub use config::{Config, KeySource, ZkeMode, ZkeEncryptionConfig};
pub use storage::StorageEngine;
pub use crypto::{
ZkeConfig, ZkeDerivedKeys, ZkeKeyDerivation, ZkeRequestContext,
ZeroKnowledgeSession, NonceTracker, TimestampValidator,
};
fn convert_logical_referential_action(action: &sql::logical_plan::ReferentialAction) -> sql::constraints::ReferentialAction {
match action {
sql::logical_plan::ReferentialAction::NoAction => sql::constraints::ReferentialAction::NoAction,
sql::logical_plan::ReferentialAction::Restrict => sql::constraints::ReferentialAction::Restrict,
sql::logical_plan::ReferentialAction::Cascade => sql::constraints::ReferentialAction::Cascade,
sql::logical_plan::ReferentialAction::SetNull => sql::constraints::ReferentialAction::SetNull,
sql::logical_plan::ReferentialAction::SetDefault => sql::constraints::ReferentialAction::SetDefault,
}
}
pub struct EmbeddedDatabase {
pub storage: std::sync::Arc<storage::StorageEngine>,
config: Config,
current_transaction: std::sync::Arc<std::sync::Mutex<Option<storage::Transaction>>>,
pub tenant_manager: std::sync::Arc<crate::tenant::TenantManager>,
pub trigger_registry: std::sync::Arc<sql::TriggerRegistry>,
pub function_registry: std::sync::Arc<sql::FunctionRegistry>,
mv_scheduler: std::sync::Arc<storage::MVScheduler>,
auto_refresh_worker: std::sync::Arc<parking_lot::RwLock<Option<storage::AutoRefreshWorker>>>,
pub dump_manager: std::sync::Arc<storage::DumpManager>,
pub session_manager: std::sync::Arc<crate::session::SessionManager>,
pub lock_manager: std::sync::Arc<storage::LockManager>,
pub dirty_tracker: std::sync::Arc<storage::DirtyTracker>,
session_transactions: std::sync::Arc<dashmap::DashMap<crate::session::SessionId, storage::Transaction>>,
prepared_statements: std::sync::Arc<parking_lot::RwLock<std::collections::HashMap<String, sql::LogicalPlan>>>,
savepoints: std::sync::Arc<parking_lot::RwLock<Vec<SavepointState>>>,
plan_cache: std::sync::Arc<std::sync::Mutex<lru::LruCache<String, std::sync::Arc<sql::LogicalPlan>>>>,
parse_cache: std::sync::Arc<std::sync::Mutex<lru::LruCache<String, sqlparser::ast::Statement>>>,
result_cache: std::sync::Arc<std::sync::Mutex<lru::LruCache<String, std::sync::Arc<Vec<Tuple>>>>>,
art_undo_log: std::sync::Arc<parking_lot::RwLock<Vec<(String, u64, std::collections::HashMap<String, Value>)>>>,
}
impl Drop for EmbeddedDatabase {
fn drop(&mut self) {
if let Some(ref worker) = *self.auto_refresh_worker.read() {
worker.request_stop();
}
self.session_transactions.clear();
self.prepared_statements.write().clear();
if let Ok(mut cache) = self.plan_cache.lock() {
cache.clear();
}
if let Ok(mut cache) = self.parse_cache.lock() {
cache.clear();
}
self.savepoints.write().clear();
tracing::debug!("EmbeddedDatabase dropped, resources cleaned up");
}
}
#[derive(Clone)]
struct SavepointState {
name: String,
write_set_snapshot: Vec<(Vec<u8>, Option<Vec<u8>>)>,
}
#[inline]
fn starts_with_icase(s: &str, prefix: &str) -> bool {
#[allow(clippy::indexing_slicing)]
{ s.len() >= prefix.len()
&& s.as_bytes()[..prefix.len()].eq_ignore_ascii_case(prefix.as_bytes()) }
}
#[cfg(feature = "code-graph")]
struct CodeGraphBranchGuard<'a> {
db: Option<&'a EmbeddedDatabase>,
previous: Option<String>,
target: Option<String>,
}
#[cfg(feature = "code-graph")]
impl<'a> CodeGraphBranchGuard<'a> {
fn noop() -> Self {
Self { db: None, previous: None, target: None }
}
fn switch_to(db: &'a EmbeddedDatabase, target: String) -> Self {
let previous = db.storage.get_current_branch();
let current = previous.as_deref().unwrap_or("main");
if current == target {
return Self::noop();
}
let _ = db.switch_branch(&target);
Self { db: Some(db), previous, target: Some(target) }
}
}
#[cfg(feature = "code-graph")]
impl<'a> Drop for CodeGraphBranchGuard<'a> {
fn drop(&mut self) {
let Some(db) = self.db else { return };
let Some(target) = self.target.as_deref() else { return };
match self.previous.as_deref() {
Some(prev) if prev != target => {
let _ = db.switch_branch(prev);
}
None if target != "main" => {
let _ = db.switch_branch("main");
}
_ => {}
}
}
}
impl EmbeddedDatabase {
fn is_transaction_control(sql: &str) -> bool {
let trimmed = sql.trim().trim_end_matches(';').trim();
starts_with_icase(trimmed, "BEGIN") ||
starts_with_icase(trimmed, "START TRANSACTION") ||
trimmed.eq_ignore_ascii_case("COMMIT") ||
trimmed.eq_ignore_ascii_case("ROLLBACK")
}
fn handle_transaction_control(&self, sql: &str) -> Result<u64> {
let trimmed = sql.trim().trim_end_matches(';').trim();
if starts_with_icase(trimmed, "BEGIN") || starts_with_icase(trimmed, "START TRANSACTION") {
self.begin_transaction_internal()?;
Ok(0)
} else if trimmed.eq_ignore_ascii_case("COMMIT") {
self.commit_internal()?;
Ok(0)
} else if trimmed.eq_ignore_ascii_case("ROLLBACK") {
self.rollback_internal()?;
Ok(0)
} else {
Err(Error::query_execution("Unknown transaction control statement"))
}
}
fn begin_transaction_internal(&self) -> Result<()> {
use crate::error::LockResultExt;
let mut txn_ref = self.current_transaction.lock()
.map_lock_err("Failed to acquire transaction lock for begin")?;
if txn_ref.is_some() {
return Err(Error::transaction("Transaction already active"));
}
let txn = self.storage.begin_transaction()?;
*txn_ref = Some(txn);
Ok(())
}
fn commit_internal(&self) -> Result<()> {
use crate::error::LockResultExt;
let mut txn_ref = self.current_transaction.lock()
.map_lock_err("Failed to acquire transaction lock for commit")?;
if let Some(txn) = txn_ref.take() {
txn.commit()?;
self.art_undo_log.write().clear();
self.storage.increment_lsn();
Ok(())
} else {
Err(Error::transaction("No active transaction to commit"))
}
}
fn rollback_internal(&self) -> Result<()> {
use crate::error::LockResultExt;
let mut txn_ref = self.current_transaction.lock()
.map_lock_err("Failed to acquire transaction lock for rollback")?;
if let Some(txn) = txn_ref.take() {
txn.rollback()?;
let undo_entries: Vec<_> = self.art_undo_log.write().drain(..).collect();
for (table_name, row_id, col_values) in undo_entries {
if let Err(e) = self.storage.art_indexes().on_delete(&table_name, row_id, &col_values) {
tracing::debug!("ART rollback for '{}' row {}: {}", table_name, row_id, e);
}
}
Ok(())
} else {
Err(Error::transaction("No active transaction to rollback"))
}
}
#[cfg(feature = "ha-tier1")]
fn try_parse_ha_command(sql: &str) -> Result<Option<sql::LogicalPlan>> {
if sql::Parser::is_switchover(sql) {
let target_node = sql::Parser::parse_switchover_sql(sql)?;
Ok(Some(sql::LogicalPlan::Switchover { target_node }))
} else if sql::Parser::is_switchover_check(sql) {
let target_node = sql::Parser::parse_switchover_check_sql(sql)?;
Ok(Some(sql::LogicalPlan::SwitchoverCheck { target_node }))
} else if sql::Parser::is_cluster_status(sql) {
Ok(Some(sql::LogicalPlan::ClusterStatus))
} else if sql::Parser::is_set_node_alias(sql) {
let (node_id, alias) = sql::Parser::parse_set_node_alias_sql(sql)?;
Ok(Some(sql::LogicalPlan::SetNodeAlias { node_id, alias }))
} else if sql::Parser::is_show_topology(sql) {
Ok(Some(sql::LogicalPlan::ShowTopology))
} else {
Ok(None)
}
}
#[cfg(not(feature = "ha-tier1"))]
fn try_parse_ha_command(_sql: &str) -> Result<Option<sql::LogicalPlan>> {
Ok(None)
}
fn execute_in_transaction(&self, sql: &str, txn: &storage::Transaction) -> Result<u64> {
self.execute_in_transaction_inner(sql, txn, false)
}
fn execute_in_transaction_no_fast_path(&self, sql: &str, txn: &storage::Transaction) -> Result<u64> {
self.execute_in_transaction_inner(sql, txn, true)
}
fn execute_in_transaction_inner(&self, sql: &str, txn: &storage::Transaction, skip_fast_paths: bool) -> Result<u64> {
if let Some(context) = self.tenant_manager.get_current_context() {
self.tenant_manager.record_query(context.tenant_id)
.map_err(|e| Error::query_execution(format!("Quota exceeded: {}", e)))?;
}
let has_savepoints = !self.savepoints.read().is_empty();
let has_session_txns = !self.session_transactions.is_empty();
let use_fast_paths = !skip_fast_paths && !has_savepoints && !has_session_txns;
if use_fast_paths {
if let Some(result) = self.try_fast_insert(sql) {
return result;
}
}
if use_fast_paths {
if let Some(result) = self.try_fast_update(sql) {
return result;
}
}
let plan = if sql::Parser::is_create_branch(sql) {
let (branch_name, parent, as_of_clause, with_options) = sql::Parser::parse_create_branch_sql(sql)?;
sql::phase3::branching::BranchingParser::parse_create_branch(
branch_name,
parent,
&as_of_clause,
with_options.as_deref(),
)?
} else if sql::Parser::is_drop_branch(sql) {
let (branch_name, if_exists) = sql::Parser::parse_drop_branch_sql(sql)?;
sql::phase3::branching::BranchingParser::parse_drop_branch(branch_name, if_exists)?
} else if sql::Parser::is_merge_branch(sql) {
let (source, target, with_options) = sql::Parser::parse_merge_branch_sql(sql)?;
sql::phase3::branching::BranchingParser::parse_merge_branch(
source,
target,
with_options.as_deref(),
)?
} else if sql::Parser::is_use_branch(sql) {
let branch_name = sql::Parser::parse_use_branch_sql(sql)?;
sql::LogicalPlan::UseBranch { branch_name }
} else if sql::Parser::is_show_branches(sql) {
sql::LogicalPlan::ShowBranches
} else if sql::Parser::is_refresh_materialized_view(sql) {
let (view_name, concurrent, incremental) = sql::Parser::parse_refresh_materialized_view_sql(sql)?;
sql::LogicalPlan::RefreshMaterializedView {
name: view_name,
concurrent,
incremental,
}
} else if sql::Parser::is_drop_materialized_view(sql) {
let (view_name, if_exists) = sql::Parser::parse_drop_materialized_view_sql(sql)?;
sql::LogicalPlan::DropMaterializedView {
name: view_name,
if_exists,
}
} else if sql::Parser::is_alter_materialized_view(sql) {
let (view_name, options) = sql::Parser::parse_alter_materialized_view_sql(sql)?;
sql::LogicalPlan::AlterMaterializedView {
name: view_name,
options,
}
} else if sql::Parser::is_alter_column_storage(sql) {
let (table_name, column_name, storage_mode) = sql::Parser::parse_alter_column_storage(sql)?;
sql::LogicalPlan::AlterColumnStorage {
table_name,
column_name,
storage_mode,
}
} else if sql::Parser::is_pg_create_procedure(sql) || sql::Parser::is_pg_create_or_replace_procedure(sql) {
let (name, or_replace, params, language, body) = sql::Parser::parse_pg_create_procedure(sql)?;
let param_list: Vec<sql::logical_plan::FunctionParam> = params.into_iter().map(|(pname, ptype)| {
sql::logical_plan::FunctionParam {
name: pname,
data_type: sql::Planner::parse_data_type_string(&ptype).unwrap_or(DataType::Text),
mode: sql::logical_plan::ParamMode::In,
default: None,
}
}).collect();
sql::LogicalPlan::CreateProcedure {
name,
or_replace,
params: param_list,
body,
language,
}
} else if let Some(plan) = Self::try_parse_ha_command(sql)? {
plan
} else {
let (statement, _) = self.parse_cached(sql)?;
let catalog = self.storage.catalog();
let planner = sql::Planner::with_catalog(&catalog)
.with_sql(sql.to_string());
planner.statement_to_plan(statement)?
};
if matches!(&plan,
sql::LogicalPlan::CreateTable { .. } |
sql::LogicalPlan::DropTable { .. } |
sql::LogicalPlan::CreateMaterializedView { .. } |
sql::LogicalPlan::DropMaterializedView { .. } |
sql::LogicalPlan::Truncate { .. }
) {
self.invalidate_plan_cache();
}
match &plan {
sql::LogicalPlan::CreateTable { name, columns, constraints, if_not_exists, .. } => {
if *if_not_exists && self.storage.catalog().table_exists(name).unwrap_or(false) {
return Ok(0);
}
let schema_columns: Vec<Column> = columns.iter().map(|col_def| {
let default_expr = col_def.default.as_ref().map(|expr| {
serde_json::to_string(expr).unwrap_or_default()
});
Column {
name: col_def.name.clone(),
data_type: col_def.data_type.clone(),
nullable: !col_def.not_null,
primary_key: col_def.primary_key,
source_table: None,
source_table_name: None,
default_expr,
unique: col_def.unique,
storage_mode: col_def.storage_mode,
}
}).collect();
let schema = Schema::new(schema_columns);
let catalog = self.storage.catalog();
if let Err(e) = self.storage.log_create_table(name, &schema) {
tracing::warn!("Failed to log CREATE TABLE to WAL: {}", e);
}
catalog.create_table(name, schema)?;
if !constraints.is_empty() {
let mut table_constraints = sql::TableConstraints::new();
for constraint in constraints {
match constraint {
sql::logical_plan::TableConstraint::ForeignKey {
name: fk_name,
columns: fk_cols,
references_table,
references_columns,
on_delete,
on_update,
deferrable,
initially_deferred,
} => {
let fk = sql::ForeignKeyConstraint::new(
fk_name.clone().unwrap_or_else(|| {
sql::ForeignKeyConstraint::generate_name(name, fk_cols, references_table)
}),
name.clone(),
fk_cols.clone(),
references_table.clone(),
references_columns.clone(),
);
let fk = if let Some(action) = on_delete {
fk.on_delete(convert_logical_referential_action(action))
} else {
fk
};
let fk = if let Some(action) = on_update {
fk.on_update(convert_logical_referential_action(action))
} else {
fk
};
let fk = if *deferrable {
fk.deferrable(*initially_deferred)
} else {
fk
};
table_constraints.add_foreign_key(fk);
}
sql::logical_plan::TableConstraint::PrimaryKey { name: pk_name, columns: pk_cols } => {
table_constraints.add_unique(sql::UniqueConstraint::new(
pk_name.clone().unwrap_or_else(|| format!("{}_pkey", name)),
name.clone(),
pk_cols.clone(),
true,
));
}
sql::logical_plan::TableConstraint::Unique { name: uq_name, columns: uq_cols } => {
table_constraints.add_unique(sql::UniqueConstraint::new(
uq_name.clone().unwrap_or_else(|| format!("{}_unique", name)),
name.clone(),
uq_cols.clone(),
false,
));
}
sql::logical_plan::TableConstraint::Check { name: ck_name, expression } => {
table_constraints.add_check(sql::CheckConstraint::new(
ck_name.clone().unwrap_or_else(|| format!("{}_check", name)),
name.clone(),
serde_json::to_string(expression).unwrap_or_default(),
));
}
}
}
catalog.save_table_constraints(name, &table_constraints)?;
}
let catalog = self.storage.catalog();
let mut col_constraints = sql::TableConstraints::new();
let mut has_col_constraints = false;
for col_def in columns {
if col_def.primary_key {
col_constraints.add_unique(sql::UniqueConstraint::new(
format!("{}_{}_pkey", name, col_def.name),
name.clone(),
vec![col_def.name.clone()],
true, ));
has_col_constraints = true;
} else if col_def.unique {
col_constraints.add_unique(sql::UniqueConstraint::new(
format!("{}_{}_unique", name, col_def.name),
name.clone(),
vec![col_def.name.clone()],
false, ));
has_col_constraints = true;
}
}
if has_col_constraints {
if let Ok(existing) = catalog.load_table_constraints(name) {
for fk in existing.foreign_keys {
col_constraints.foreign_keys.push(fk);
}
for check in existing.check_constraints {
col_constraints.check_constraints.push(check);
}
for unique in existing.unique_constraints {
col_constraints.unique_constraints.push(unique);
}
}
catalog.save_table_constraints(name, &col_constraints)?;
}
Ok(1)
}
sql::LogicalPlan::Insert { table_name, columns, values, returning, on_conflict } => {
let catalog = self.storage.catalog();
let schema = catalog.get_table_schema(table_name)?;
let evaluator = sql::Evaluator::new(std::sync::Arc::new(Schema {
columns: vec![],
}));
let empty_tuple = Tuple::new(vec![]);
let bulk_threshold = self.storage.smfi_bulk_load_threshold();
let _smfi_guard = if values.len() >= bulk_threshold {
Some(self.storage.suspend_smfi_for_bulk_load(
table_name,
storage::BulkLoadReason::MultiRowInsert,
))
} else {
None
};
let mut trigger_context = sql::TriggerContext::new();
let trigger_event = sql::logical_plan::TriggerEvent::Insert;
let has_triggers = self.trigger_registry.has_triggers_for_table(table_name);
let mut returned_tuples: Vec<Tuple> = Vec::new();
let has_returning = returning.is_some();
let default_exprs: Vec<Option<sql::LogicalExpr>> = schema.columns.iter()
.map(|col| {
col.default_expr.as_ref().and_then(|json| {
serde_json::from_str(json).ok()
})
})
.collect();
let column_indices: Option<Vec<usize>> = columns.as_ref().map(|cols| {
cols.iter()
.filter_map(|col_name| schema.get_column_index(col_name))
.collect()
});
let mut count = 0;
for value_row in values {
let mut tuple_values: Vec<Option<Value>> = vec![None; schema.columns.len()];
for (val_idx, expr) in value_row.iter().enumerate() {
let target_col_idx = if let Some(ref indices) = column_indices {
if val_idx >= indices.len() {
return Err(Error::query_execution(
"More values than columns specified"
));
}
*indices.get(val_idx).ok_or_else(|| Error::internal("column index out of bounds"))?
} else {
val_idx
};
if matches!(expr, sql::LogicalExpr::DefaultValue) {
continue;
}
let target_col = schema.get_column_at(target_col_idx)
.ok_or_else(|| Error::query_execution(format!(
"Too many values for INSERT: table has {} columns",
schema.columns.len()
)))?;
let target_type = &target_col.data_type;
let mut value = evaluator.evaluate(expr, &empty_tuple)?;
let needs_cast = match (&value, target_type) {
(Value::Null, _) => false,
(Value::Vector(_), DataType::Vector(_)) => false,
(Value::String(_), DataType::Vector(_)) => true,
(Value::String(_), DataType::Json | DataType::Jsonb) => true,
(Value::Int4(_), DataType::Int4) => false,
(Value::Int8(_), DataType::Int8) => false,
(Value::Float4(_), DataType::Float4) => false,
(Value::Float8(_), DataType::Float8) => false,
(Value::String(_), DataType::Text | DataType::Varchar(_)) => false,
(Value::Boolean(_), DataType::Boolean) => false,
(Value::Json(_), DataType::Json | DataType::Jsonb) => false,
_ => true,
};
if needs_cast {
value = evaluator.cast_value(value, target_type)?;
}
if let Some(target_col_ref) = schema.get_column_at(target_col_idx) {
if matches!(value, Value::Null) && !target_col_ref.nullable {
return Err(Error::constraint_violation(format!(
"NOT NULL constraint violated: cannot insert NULL into column '{}'",
target_col_ref.name
)));
}
}
let tv = tuple_values.get_mut(target_col_idx)
.ok_or_else(|| Error::internal("column index out of bounds"))?;
*tv = Some(value);
}
let final_values: Result<Vec<Value>> = tuple_values
.into_iter()
.enumerate()
.map(|(idx, opt_val)| {
if let Some(val) = opt_val {
Ok(val)
} else {
let col = schema.get_column_at(idx)
.ok_or_else(|| Error::internal("column index out of bounds"))?;
if let Some(ref default_expr) = default_exprs.get(idx).and_then(|d| d.as_ref()) {
let mut value = evaluator.evaluate(default_expr, &empty_tuple)?;
if value.data_type() != col.data_type {
value = evaluator.cast_value(value, &col.data_type)?;
}
Ok(value)
} else if col.primary_key {
Ok(Value::Null)
} else if col.nullable {
Ok(Value::Null)
} else {
Err(Error::query_execution(format!(
"Column '{}' does not have a default value and is not nullable",
col.name
)))
}
}
})
.collect();
let final_values_vec = final_values?;
for (idx, col) in schema.columns.iter().enumerate() {
if !col.nullable && !col.primary_key {
if matches!(final_values_vec.get(idx), Some(Value::Null)) {
return Err(Error::constraint_violation(format!(
"NOT NULL constraint violated: cannot insert NULL into column '{}'",
col.name
)));
}
}
}
let mut tuple = Tuple::new(final_values_vec.clone());
let table_constraints = catalog.load_table_constraints(table_name)?;
for fk in &table_constraints.foreign_keys {
if fk.enforcement == sql::ConstraintEnforcement::Immediate {
let fk_values: Vec<Value> = fk.columns.iter()
.map(|col_name| {
schema.columns.iter()
.position(|c| &c.name == col_name)
.and_then(|idx| final_values_vec.get(idx).cloned())
.unwrap_or(Value::Null)
})
.collect();
if fk_values.iter().any(|v| matches!(v, Value::Null)) {
continue;
}
let key = crate::storage::ArtIndexManager::encode_key(&fk_values);
let exists = if let Some(found) = self.storage.art_indexes().pk_index_contains(&fk.references_table, &key) {
found
} else {
self.check_foreign_key_exists(
&fk.references_table,
&fk.references_columns,
&fk_values,
)?
};
if !exists {
return Err(Error::constraint_violation(format!(
"Foreign key constraint '{}' violated: referenced row in table '{}' does not exist",
fk.name, fk.references_table
)));
}
}
}
for check in &table_constraints.check_constraints {
let check_result = self.evaluate_check_constraint(
&check.expression,
&schema,
&final_values_vec,
)?;
if !check_result {
return Err(Error::constraint_violation(format!(
"CHECK constraint '{}' violated: expression '{}' evaluated to false",
check.name, check.expression
)));
}
}
{
let mut col_values_map = std::collections::HashMap::new();
for (i, col) in schema.columns.iter().enumerate() {
if let Some(v) = final_values_vec.get(i) {
col_values_map.insert(col.name.clone(), v.clone());
}
}
if let Err(e) = self.storage.art_indexes().check_unique_constraints(table_name, &col_values_map) {
match on_conflict {
Some(sql::logical_plan::OnConflictAction::DoNothing) => {
continue;
}
Some(sql::logical_plan::OnConflictAction::DoUpdate { assignments }) => {
let err_msg = e.to_string();
let mut excluded_map = std::collections::HashMap::new();
for (i, col) in schema.columns.iter().enumerate() {
if let Some(v) = final_values_vec.get(i) {
excluded_map.insert(col.name.to_lowercase(), v.clone());
}
}
let existing_row_id = {
let mut found_row_id: Option<u64> = None;
for (i, col) in schema.columns.iter().enumerate() {
if (col.unique || col.primary_key) && !col.primary_key {
if let Some(val) = final_values_vec.get(i) {
if !matches!(val, Value::Null) {
let scan_sql = format!(
"SELECT {} FROM {} WHERE {} = '{}'",
schema.columns.iter().find(|c| c.primary_key).map(|c| c.name.as_str()).unwrap_or("rowid"),
table_name,
col.name,
val.to_string().trim_matches('\'')
);
if let Ok(rows) = self.query(&scan_sql, &[]) {
if let Some(row) = rows.first() {
if let Some(pk_val) = row.values.first() {
match pk_val {
Value::Int8(id) => { found_row_id = Some(*id as u64); }
Value::Int4(id) => { found_row_id = Some(*id as u64); }
_ => {}
}
}
}
}
if found_row_id.is_some() { break; }
}
}
}
}
if found_row_id.is_none() {
let pk_cols: Vec<(usize, &crate::Column)> = schema.columns.iter().enumerate()
.filter(|(_, c)| c.primary_key)
.collect();
let pk_values: Vec<Value> = pk_cols.iter()
.filter_map(|(idx, _)| final_values_vec.get(*idx).cloned())
.collect();
if !pk_values.is_empty() && !pk_values.iter().any(|v| matches!(v, Value::Null)) {
let pk_key = crate::storage::ArtIndexManager::encode_key(&pk_values);
found_row_id = self.storage.art_indexes().pk_index_lookup(table_name, &pk_key);
}
}
found_row_id.ok_or_else(|| Error::query_execution(
format!("ON CONFLICT DO UPDATE: could not find existing row ({})", err_msg)
))?
};
let existing_key = self.storage.branch_aware_data_key(table_name, existing_row_id);
let existing_raw = match txn.get(&existing_key)? {
Some(raw) => raw,
None => self.storage.get(&existing_key)?
.ok_or_else(|| Error::query_execution(
"ON CONFLICT DO UPDATE: existing row not found in storage"
))?,
};
let mut existing_tuple: Tuple = bincode::deserialize(&existing_raw)
.map_err(|err| Error::storage(format!("Failed to deserialize tuple: {}", err)))?;
existing_tuple.row_id = Some(existing_row_id);
let update_evaluator = sql::Evaluator::new(std::sync::Arc::new(schema.clone()));
for (col_name, expr) in assignments {
let target_idx = schema.columns.iter()
.position(|c| c.name.eq_ignore_ascii_case(col_name))
.ok_or_else(|| Error::query_execution(format!(
"ON CONFLICT DO UPDATE: column '{}' not found", col_name
)))?;
let resolved_expr = Self::resolve_excluded_refs(expr, &excluded_map);
let mut new_val = update_evaluator.evaluate(&resolved_expr, &existing_tuple)?;
let target_type = &schema.columns.get(target_idx)
.ok_or_else(|| Error::internal("column index out of bounds"))?
.data_type;
if new_val.data_type() != *target_type && !matches!(new_val, Value::Null) {
new_val = update_evaluator.cast_value(new_val, target_type)?;
}
if target_idx < existing_tuple.values.len() {
#[allow(clippy::indexing_slicing)]
{ existing_tuple.values[target_idx] = new_val; }
}
}
let updated_val = bincode::serialize(&existing_tuple)
.map_err(|err| Error::storage(err.to_string()))?;
txn.put(existing_key.clone(), updated_val.clone())?;
{
let mut updated_col_values = std::collections::HashMap::new();
for (i, col) in schema.columns.iter().enumerate() {
if let Some(v) = existing_tuple.values.get(i) {
updated_col_values.insert(col.name.clone(), v.clone());
}
}
let _ = self.storage.art_indexes().on_delete(table_name, existing_row_id, &col_values_map);
let _ = self.storage.art_indexes().on_insert(table_name, existing_row_id, &updated_col_values);
}
if !skip_fast_paths && self.storage.is_wal_enabled() {
self.storage.log_data_insert(table_name, &existing_key, &updated_val)?;
}
self.invalidate_result_cache();
count += 1;
if has_returning {
if let Some(projected) = Self::project_returning_columns(&existing_tuple, &schema, returning) {
returned_tuples.push(projected);
}
}
continue;
}
None => {
return Err(Error::constraint_violation(e.to_string()));
}
}
}
}
if has_triggers {
let row_context = sql::triggers::TriggerRowContext::for_insert(tuple.clone());
let db_ref = self.clone_for_trigger();
let mut executor_fn = |stmt: &sql::LogicalPlan, _ctx: &sql::triggers::TriggerRowContext| -> Result<()> {
db_ref.execute_plan_internal(stmt)?;
Ok(())
};
let action = self.trigger_registry.execute_triggers(
table_name,
&trigger_event,
&sql::logical_plan::TriggerTiming::Before,
&row_context,
&mut trigger_context,
Some(std::sync::Arc::new(schema.clone())),
&mut executor_fn,
)?;
match action {
sql::triggers::TriggerAction::Abort(msg) => {
return Err(Error::query_execution(format!("INSERT aborted by trigger: {}", msg)));
}
sql::triggers::TriggerAction::Skip => {
continue;
}
sql::triggers::TriggerAction::Continue => {
}
}
}
let row_id = catalog.next_row_id(table_name)?;
let key = self.storage.branch_aware_data_key(table_name, row_id);
for (i, col) in schema.columns.iter().enumerate() {
if col.primary_key {
if let Some(v) = tuple.values.get(i) {
if matches!(v, Value::Null) {
if i < tuple.values.len() {
#[allow(clippy::indexing_slicing)]
match col.data_type {
DataType::Int2 => { tuple.values[i] = Value::Int2(row_id as i16); }
DataType::Int4 => { tuple.values[i] = Value::Int4(row_id as i32); }
_ => { tuple.values[i] = Value::Int8(row_id as i64); }
}
}
}
}
}
}
let mut col_values = std::collections::HashMap::new();
for (i, col) in schema.columns.iter().enumerate() {
if let Some(v) = tuple.values.get(i) {
col_values.insert(col.name.clone(), v.clone());
}
}
self.check_fk_constraints_on_write(table_name, &col_values, Some(txn))?;
let val = bincode::serialize(&tuple).map_err(|e| Error::storage(e.to_string()))?;
txn.put(key.clone(), val.clone())?;
if !skip_fast_paths && self.storage.is_wal_enabled() {
self.storage.log_data_insert(table_name, &key, &val)?;
}
{
if let Err(e) = self.storage.art_indexes().on_insert(table_name, row_id, &col_values) {
tracing::debug!("ART index insert for '{}': {}", table_name, e);
}
if skip_fast_paths {
self.art_undo_log.write().push((table_name.clone(), row_id, col_values));
}
}
count += 1;
if has_returning {
let mut returned_tuple = tuple.clone();
returned_tuple.row_id = Some(row_id);
if let Some(projected) = Self::project_returning_columns(&returned_tuple, &schema, returning) {
returned_tuples.push(projected);
}
}
if let Some(context) = self.tenant_manager.get_current_context() {
let tuple_size = val.len() as u64;
if let Some(current_quota) = self.tenant_manager.get_quota_tracking(context.tenant_id) {
let new_storage = current_quota.storage_bytes_used + tuple_size;
if let Err(e) = self.tenant_manager.update_storage_usage(context.tenant_id, new_storage) {
return Err(Error::query_execution(format!("Storage quota exceeded: {}", e)));
}
}
let new_values = serde_json::to_string(&tuple.values)
.unwrap_or_else(|_| "[]".to_string());
self.tenant_manager.record_change_event(
crate::tenant::ChangeType::Insert,
table_name.to_string(),
row_id.to_string(),
None, Some(new_values),
context.tenant_id,
None, );
}
if has_triggers {
let row_context = sql::triggers::TriggerRowContext::for_insert(tuple.clone());
let db_ref = self.clone_for_trigger();
let mut executor_fn = |stmt: &sql::LogicalPlan, _ctx: &sql::triggers::TriggerRowContext| -> Result<()> {
db_ref.execute_plan_internal(stmt)?;
Ok(())
};
let action = self.trigger_registry.execute_triggers(
table_name,
&trigger_event,
&sql::logical_plan::TriggerTiming::After,
&row_context,
&mut trigger_context,
Some(std::sync::Arc::new(schema.clone())),
&mut executor_fn,
)?;
if let sql::triggers::TriggerAction::Abort(msg) = action {
return Err(Error::query_execution(format!("INSERT aborted by AFTER trigger: {}", msg)));
}
}
}
Ok(count)
}
sql::LogicalPlan::InsertSelect { table_name, columns, source, returning } => {
let mut executor = sql::Executor::with_storage(&self.storage)
.with_timeout(self.config.storage.query_timeout_ms);
let source_rows = executor.execute(source)?;
let catalog = self.storage.catalog();
let schema = catalog.get_table_schema(table_name)?;
let evaluator = sql::Evaluator::new(std::sync::Arc::new(Schema {
columns: vec![],
}));
let empty_tuple = Tuple::new(vec![]);
let column_indices: Option<Vec<usize>> = columns.as_ref().map(|cols| {
cols.iter()
.filter_map(|col_name| schema.get_column_index(col_name))
.collect()
});
let default_exprs: Vec<Option<sql::LogicalExpr>> = schema.columns.iter()
.map(|col| {
col.default_expr.as_ref().and_then(|json| {
serde_json::from_str(json).ok()
})
})
.collect();
let mut trigger_context = sql::TriggerContext::new();
let trigger_event = sql::logical_plan::TriggerEvent::Insert;
let has_triggers = self.trigger_registry.has_triggers_for_table(table_name);
let has_returning = returning.is_some();
let mut returned_tuples: Vec<Tuple> = Vec::new();
let bulk_threshold = self.storage.smfi_bulk_load_threshold();
let _smfi_guard = if source_rows.len() >= bulk_threshold {
Some(self.storage.suspend_smfi_for_bulk_load(
table_name,
storage::BulkLoadReason::MultiRowInsert,
))
} else {
None
};
let mut count = 0u64;
for source_row in &source_rows {
let mut tuple_values: Vec<Option<Value>> = vec![None; schema.columns.len()];
for (val_idx, value) in source_row.values.iter().enumerate() {
let target_col_idx = if let Some(ref indices) = column_indices {
if val_idx >= indices.len() {
return Err(Error::query_execution(
"More values than columns specified"
));
}
*indices.get(val_idx).ok_or_else(|| Error::internal("column index out of bounds"))?
} else {
val_idx
};
let target_col = schema.get_column_at(target_col_idx)
.ok_or_else(|| Error::query_execution(format!(
"Too many values for INSERT: table has {} columns",
schema.columns.len()
)))?;
let target_type = &target_col.data_type;
let mut val = value.clone();
let needs_cast = match (&val, target_type) {
(Value::Null, _) => false,
(Value::Vector(_), DataType::Vector(_)) => false,
(Value::String(_), DataType::Vector(_)) => true,
(Value::String(_), DataType::Json | DataType::Jsonb) => true,
(Value::Int4(_), DataType::Int4) => false,
(Value::Int8(_), DataType::Int8) => false,
(Value::Float4(_), DataType::Float4) => false,
(Value::Float8(_), DataType::Float8) => false,
(Value::String(_), DataType::Text | DataType::Varchar(_)) => false,
(Value::Boolean(_), DataType::Boolean) => false,
(Value::Json(_), DataType::Json | DataType::Jsonb) => false,
_ => true,
};
if needs_cast {
val = evaluator.cast_value(val, target_type)?;
}
if let Some(target_col_ref) = schema.get_column_at(target_col_idx) {
if matches!(val, Value::Null) && !target_col_ref.nullable {
return Err(Error::constraint_violation(format!(
"NOT NULL constraint violated: cannot insert NULL into column '{}'",
target_col_ref.name
)));
}
}
let tv = tuple_values.get_mut(target_col_idx)
.ok_or_else(|| Error::internal("column index out of bounds"))?;
*tv = Some(val);
}
let final_values: Result<Vec<Value>> = tuple_values
.into_iter()
.enumerate()
.map(|(idx, opt_val)| {
if let Some(val) = opt_val {
Ok(val)
} else {
let col = schema.get_column_at(idx)
.ok_or_else(|| Error::internal("column index out of bounds"))?;
if let Some(ref default_expr) = default_exprs.get(idx).and_then(|d| d.as_ref()) {
let mut value = evaluator.evaluate(default_expr, &empty_tuple)?;
if value.data_type() != col.data_type {
value = evaluator.cast_value(value, &col.data_type)?;
}
Ok(value)
} else if col.primary_key {
Ok(Value::Null)
} else if col.nullable {
Ok(Value::Null)
} else {
Err(Error::query_execution(format!(
"Column '{}' does not have a default value and is not nullable",
col.name
)))
}
}
})
.collect();
let final_values_vec = final_values?;
let tuple = Tuple::new(final_values_vec.clone());
let table_constraints = catalog.load_table_constraints(table_name)?;
for fk in &table_constraints.foreign_keys {
if fk.enforcement == sql::ConstraintEnforcement::Immediate {
let fk_values: Vec<Value> = fk.columns.iter()
.map(|col_name| {
schema.columns.iter()
.position(|c| &c.name == col_name)
.and_then(|idx| final_values_vec.get(idx).cloned())
.unwrap_or(Value::Null)
})
.collect();
if fk_values.iter().any(|v| matches!(v, Value::Null)) {
continue;
}
let key = crate::storage::ArtIndexManager::encode_key(&fk_values);
let exists = if let Some(found) = self.storage.art_indexes().pk_index_contains(&fk.references_table, &key) {
found
} else {
self.check_foreign_key_exists(
&fk.references_table,
&fk.references_columns,
&fk_values,
)?
};
if !exists {
return Err(Error::constraint_violation(format!(
"Foreign key constraint '{}' violated: referenced row in table '{}' does not exist",
fk.name, fk.references_table
)));
}
}
}
for check in &table_constraints.check_constraints {
let check_result = self.evaluate_check_constraint(
&check.expression,
&schema,
&final_values_vec,
)?;
if !check_result {
return Err(Error::constraint_violation(format!(
"CHECK constraint '{}' violated: expression '{}' evaluated to false",
check.name, check.expression
)));
}
}
if !table_constraints.unique_constraints.is_empty() {
for uc in &table_constraints.unique_constraints {
let uc_values: Vec<Value> = uc.columns.iter()
.map(|col_name| {
schema.columns.iter()
.position(|c| &c.name == col_name)
.and_then(|idx| final_values_vec.get(idx).cloned())
.unwrap_or(Value::Null)
})
.collect();
if uc_values.iter().any(|v| matches!(v, Value::Null)) {
continue;
}
let key = crate::storage::ArtIndexManager::encode_key(&uc_values);
if self.storage.art_indexes().pk_index_contains(table_name, &key) == Some(true) {
return Err(Error::constraint_violation(format!(
"UNIQUE constraint '{}' violated: duplicate value for columns ({})",
uc.name,
uc.columns.join(", ")
)));
}
}
}
if has_triggers {
let row_context = sql::triggers::TriggerRowContext::for_insert(tuple.clone());
let db_ref = self.clone_for_trigger();
let mut executor_fn = |stmt: &sql::LogicalPlan, _ctx: &sql::triggers::TriggerRowContext| -> Result<()> {
db_ref.execute_plan_internal(stmt)?;
Ok(())
};
let action = self.trigger_registry.execute_triggers(
table_name,
&trigger_event,
&sql::logical_plan::TriggerTiming::Before,
&row_context,
&mut trigger_context,
Some(std::sync::Arc::new(schema.clone())),
&mut executor_fn,
)?;
match action {
sql::triggers::TriggerAction::Abort(msg) => {
return Err(Error::query_execution(format!("INSERT aborted by trigger: {}", msg)));
}
sql::triggers::TriggerAction::Skip => {
continue;
}
sql::triggers::TriggerAction::Continue => {}
}
}
let row_id = self.storage.insert_tuple_branch_aware_with_schema(table_name, tuple.clone(), &schema)?;
{
let mut col_values = std::collections::HashMap::new();
for (i, col) in schema.columns.iter().enumerate() {
if let Some(v) = final_values_vec.get(i) {
col_values.insert(col.name.clone(), v.clone());
}
}
if let Err(e) = self.storage.art_indexes().on_insert(table_name, row_id, &col_values) {
tracing::debug!("ART index insert for '{}': {}", table_name, e);
}
}
count += 1;
if has_returning {
let mut returned_tuple = tuple.clone();
returned_tuple.row_id = Some(row_id);
if let Some(projected) = Self::project_returning_columns(&returned_tuple, &schema, returning) {
returned_tuples.push(projected);
}
}
if has_triggers {
let row_context = sql::triggers::TriggerRowContext::for_insert(tuple.clone());
let db_ref = self.clone_for_trigger();
let mut executor_fn = |stmt: &sql::LogicalPlan, _ctx: &sql::triggers::TriggerRowContext| -> Result<()> {
db_ref.execute_plan_internal(stmt)?;
Ok(())
};
let action = self.trigger_registry.execute_triggers(
table_name,
&trigger_event,
&sql::logical_plan::TriggerTiming::After,
&row_context,
&mut trigger_context,
Some(std::sync::Arc::new(schema.clone())),
&mut executor_fn,
)?;
if let sql::triggers::TriggerAction::Abort(msg) = action {
return Err(Error::query_execution(format!("INSERT aborted by AFTER trigger: {}", msg)));
}
}
}
Ok(count)
}
sql::LogicalPlan::Update { table_name, assignments, selection, returning } => {
let catalog = self.storage.catalog();
let schema = catalog.get_table_schema(table_name)?;
let eval_schema = schema.clone().with_source_table_name(table_name);
let evaluator = sql::Evaluator::with_parameters(
std::sync::Arc::new(eval_schema),
vec![],
);
let mut trigger_context = sql::TriggerContext::new();
let updated_columns: Vec<String> = assignments.iter().map(|(col, _)| col.clone()).collect();
let trigger_event = sql::logical_plan::TriggerEvent::Update(Some(updated_columns));
let has_triggers = self.trigger_registry.has_triggers_for_table(table_name);
let on_branch = self.storage.get_current_branch().is_some();
let tuples = if !on_branch {
if let Some(pk_value) = Self::try_extract_pk_value(selection.as_ref(), &schema) {
match self.storage.get_row_by_pk(table_name, &pk_value)? {
Some(tuple) => vec![tuple],
None => vec![],
}
} else {
self.storage.scan_table_branch_aware(table_name)?
}
} else {
self.storage.scan_table_branch_aware(table_name)?
};
let mut updates: Vec<(u64, Tuple)> = Vec::new();
for old_tuple in tuples {
let matches = if let Some(predicate) = selection {
let result = evaluator.evaluate(predicate, &old_tuple)?;
match result {
Value::Boolean(b) => b,
_ => false,
}
} else {
true
};
if matches {
let mut new_tuple = old_tuple.clone();
for (col_name, value_expr) in assignments {
let bound = self.materialize_scalar_subqueries_for_row(
value_expr, &old_tuple, &schema, table_name,
)?;
let mut new_value = evaluator.evaluate(&bound, &old_tuple)?;
let col_index = evaluator.schema().get_column_index(col_name)
.ok_or_else(|| Error::query_execution(format!("Column '{}' not found", col_name)))?;
let target_col = schema.get_column_at(col_index)
.ok_or_else(|| Error::query_execution(format!("Column '{}' not found", col_name)))?;
let target_type = &target_col.data_type;
let needs_cast = !matches!(&new_value, Value::Null)
&& !matches!(
(&new_value, target_type),
(Value::Vector(_), DataType::Vector(_))
| (Value::Int2(_), DataType::Int2)
| (Value::Int4(_), DataType::Int4)
| (Value::Int8(_), DataType::Int8)
| (Value::Float4(_), DataType::Float4)
| (Value::Float8(_), DataType::Float8)
| (Value::String(_), DataType::Text | DataType::Varchar(_))
| (Value::Boolean(_), DataType::Boolean)
| (Value::Json(_), DataType::Json | DataType::Jsonb)
| (Value::Timestamp(_), DataType::Timestamp | DataType::Timestamptz)
| (Value::Date(_), DataType::Date)
);
if needs_cast {
new_value = evaluator.cast_value(new_value, target_type)?;
}
*new_tuple.values.get_mut(col_index)
.ok_or_else(|| Error::internal("column index out of bounds"))? = new_value;
}
if has_triggers {
let row_context = sql::triggers::TriggerRowContext::for_update(old_tuple.clone(), new_tuple.clone());
let db_ref = self.clone_for_trigger();
let mut executor_fn = |stmt: &sql::LogicalPlan, _ctx: &sql::triggers::TriggerRowContext| -> Result<()> {
db_ref.execute_plan_internal(stmt)?;
Ok(())
};
let action = self.trigger_registry.execute_triggers(
table_name,
&trigger_event,
&sql::logical_plan::TriggerTiming::Before,
&row_context,
&mut trigger_context,
Some(evaluator.schema().clone()),
&mut executor_fn,
)?;
match action {
sql::triggers::TriggerAction::Abort(msg) => {
return Err(Error::query_execution(format!("UPDATE aborted by trigger: {}", msg)));
}
sql::triggers::TriggerAction::Skip => {
continue;
}
sql::triggers::TriggerAction::Continue => {
}
}
}
let mut new_col_values = std::collections::HashMap::with_capacity(schema.columns.len());
for (i, col) in schema.columns.iter().enumerate() {
if let Some(v) = new_tuple.values.get(i) {
new_col_values.insert(col.name.clone(), v.clone());
}
}
self.check_fk_constraints_on_write(table_name, &new_col_values, Some(txn))?;
let row_id = new_tuple.row_id.unwrap_or(0);
updates.push((row_id, new_tuple.clone()));
if let Some(context) = self.tenant_manager.get_current_context() {
let old_values = serde_json::to_string(&old_tuple.values)
.unwrap_or_else(|_| "[]".to_string());
let new_values = serde_json::to_string(&new_tuple.values)
.unwrap_or_else(|_| "[]".to_string());
self.tenant_manager.record_change_event(
crate::tenant::ChangeType::Update,
table_name.to_string(),
row_id.to_string(),
Some(old_values),
Some(new_values),
context.tenant_id,
None,
);
}
if has_triggers {
let row_context = sql::triggers::TriggerRowContext::for_update(old_tuple.clone(), new_tuple.clone());
let db_ref = self.clone_for_trigger();
let mut executor_fn = |stmt: &sql::LogicalPlan, _ctx: &sql::triggers::TriggerRowContext| -> Result<()> {
db_ref.execute_plan_internal(stmt)?;
Ok(())
};
let action = self.trigger_registry.execute_triggers(
table_name,
&trigger_event,
&sql::logical_plan::TriggerTiming::After,
&row_context,
&mut trigger_context,
Some(evaluator.schema().clone()),
&mut executor_fn,
)?;
if let sql::triggers::TriggerAction::Abort(msg) = action {
return Err(Error::query_execution(format!("UPDATE aborted by AFTER trigger: {}", msg)));
}
}
}
}
let update_count = updates.len() as u64;
for (row_id, tuple) in &updates {
let key = self.storage.branch_aware_data_key(table_name, *row_id);
let value = bincode::serialize(tuple)
.map_err(|e| Error::storage(format!("Failed to serialize tuple: {}", e)))?;
txn.put(key.clone(), value.clone())?;
if !skip_fast_paths && self.storage.is_wal_enabled() {
self.storage.log_data_update(table_name, &key, &value)?;
}
self.storage.row_cache().invalidate(table_name, *row_id);
}
if let Some(context) = self.tenant_manager.get_current_context() {
let mut storage_delta: i64 = 0;
for (_row_id, new_tuple) in &updates {
let new_size = bincode::serialize(new_tuple)
.map(|bytes| bytes.len() as i64)
.unwrap_or(256);
storage_delta += new_size;
}
if let Some(current_quota) = self.tenant_manager.get_quota_tracking(context.tenant_id) {
let new_storage = (current_quota.storage_bytes_used as i64 + storage_delta).max(0) as u64;
if let Err(e) = self.tenant_manager.update_storage_usage(context.tenant_id, new_storage) {
return Err(Error::query_execution(format!("Storage quota exceeded: {}", e)));
}
}
}
let returned_tuples: Vec<Tuple> = if returning.is_some() {
updates.iter()
.filter_map(|(_, tuple)| Self::project_returning_columns(tuple, &schema, returning))
.collect()
} else {
Vec::new()
};
let _ = returned_tuples;
Ok(update_count)
}
sql::LogicalPlan::Delete { table_name, selection, returning } => {
let catalog = self.storage.catalog();
let schema = catalog.get_table_schema(table_name)?;
let schema_arc = std::sync::Arc::new(schema);
let eval_schema = std::sync::Arc::new(
(*schema_arc).clone().with_source_table_name(table_name),
);
let evaluator = sql::Evaluator::with_parameters(
eval_schema,
vec![],
);
let mut trigger_context = sql::TriggerContext::new();
let trigger_event = sql::logical_plan::TriggerEvent::Delete;
let has_triggers = self.trigger_registry.has_triggers_for_table(table_name);
let on_branch = self.storage.get_current_branch().is_some();
let tuples = if !on_branch {
if let Some(pk_value) = Self::try_extract_pk_value(selection.as_ref(), &schema_arc) {
match self.storage.get_row_by_pk(table_name, &pk_value)? {
Some(tuple) => vec![tuple],
None => vec![],
}
} else {
self.storage.scan_table_branch_aware(table_name)?
}
} else {
self.storage.scan_table_branch_aware(table_name)?
};
let mut row_ids_to_delete: Vec<u64> = Vec::new();
let mut deleted_tuples: Vec<(u64, Tuple)> = Vec::new();
let mut returned_tuples: Vec<Tuple> = Vec::new();
let has_returning = returning.is_some();
for tuple in tuples {
let matches = if let Some(predicate) = selection {
let result = evaluator.evaluate(predicate, &tuple)?;
match result {
Value::Boolean(b) => b,
_ => false,
}
} else {
true
};
if matches {
if has_triggers {
let row_context = sql::triggers::TriggerRowContext::for_delete(tuple.clone());
let db_ref = self.clone_for_trigger();
let mut executor_fn = |stmt: &sql::LogicalPlan, _ctx: &sql::triggers::TriggerRowContext| -> Result<()> {
db_ref.execute_plan_internal(stmt)?;
Ok(())
};
let action = self.trigger_registry.execute_triggers(
table_name,
&trigger_event,
&sql::logical_plan::TriggerTiming::Before,
&row_context,
&mut trigger_context,
Some(evaluator.schema().clone()),
&mut executor_fn,
)?;
match action {
sql::triggers::TriggerAction::Abort(msg) => {
return Err(Error::query_execution(format!("DELETE aborted by trigger: {}", msg)));
}
sql::triggers::TriggerAction::Skip => {
continue;
}
sql::triggers::TriggerAction::Continue => {
}
}
}
if let Some(row_id) = tuple.row_id {
let referencing_fks = catalog.get_referencing_fks(table_name)?;
for fk in &referencing_fks {
if fk.enforcement == sql::ConstraintEnforcement::Immediate {
let ref_values: Vec<Value> = fk.references_columns.iter()
.map(|col_name| {
schema_arc.columns.iter()
.position(|c| &c.name == col_name)
.and_then(|idx| tuple.values.get(idx).cloned())
.unwrap_or(Value::Null)
})
.collect();
let has_refs = self.check_referencing_rows_exist(
&fk.table_name,
&fk.columns,
&ref_values,
Some(txn),
)?;
if has_refs {
match fk.on_delete {
sql::constraints::ReferentialAction::NoAction |
sql::constraints::ReferentialAction::Restrict => {
return Err(Error::constraint_violation(format!(
"Foreign key constraint '{}' violated: cannot delete row from '{}' - referenced by '{}'",
fk.name, table_name, fk.table_name
)));
}
sql::constraints::ReferentialAction::Cascade => {
self.cascade_delete_referencing_rows(
&fk.table_name,
&fk.columns,
&ref_values,
)?;
}
sql::constraints::ReferentialAction::SetNull => {
self.set_null_referencing_rows(
&fk.table_name,
&fk.columns,
&ref_values,
)?;
}
sql::constraints::ReferentialAction::SetDefault => {
return Err(Error::constraint_violation(format!(
"Foreign key constraint '{}' with SET DEFAULT action: not implemented",
fk.name
)));
}
}
}
}
}
row_ids_to_delete.push(row_id);
deleted_tuples.push((row_id, tuple.clone()));
if has_returning {
if let Some(projected) = Self::project_returning_columns(&tuple, &schema_arc, returning) {
returned_tuples.push(projected);
}
}
if let Some(context) = self.tenant_manager.get_current_context() {
let old_values = serde_json::to_string(&tuple.values)
.unwrap_or_else(|_| "[]".to_string());
self.tenant_manager.record_change_event(
crate::tenant::ChangeType::Delete,
table_name.to_string(),
row_id.to_string(),
Some(old_values),
None, context.tenant_id,
None,
);
}
}
if has_triggers {
let row_context = sql::triggers::TriggerRowContext::for_delete(tuple.clone());
let db_ref = self.clone_for_trigger();
let mut executor_fn = |stmt: &sql::LogicalPlan, _ctx: &sql::triggers::TriggerRowContext| -> Result<()> {
db_ref.execute_plan_internal(stmt)?;
Ok(())
};
let action = self.trigger_registry.execute_triggers(
table_name,
&trigger_event,
&sql::logical_plan::TriggerTiming::After,
&row_context,
&mut trigger_context,
Some(evaluator.schema().clone()),
&mut executor_fn,
)?;
if let sql::triggers::TriggerAction::Abort(msg) = action {
return Err(Error::query_execution(format!("DELETE aborted by AFTER trigger: {}", msg)));
}
}
}
}
let storage_reclaimed: u64 = if self.tenant_manager.get_current_context().is_some() {
(row_ids_to_delete.len() as u64) * 256
} else {
0
};
let delete_count = row_ids_to_delete.len() as u64;
if let Some(branch_id) = self.storage.get_current_branch_id() {
for row_id in &row_ids_to_delete {
let delete_key = format!("bdel:{}:{}:{}", branch_id, table_name, row_id).into_bytes();
txn.put(delete_key, vec![])?;
self.storage.row_cache().invalidate(table_name, *row_id);
}
} else {
for row_id in &row_ids_to_delete {
let key = format!("data:{}:{}", table_name, row_id).into_bytes();
txn.delete(key.clone())?;
if !skip_fast_paths && self.storage.is_wal_enabled() {
self.storage.log_data_delete(table_name, &key)?;
}
self.storage.row_cache().invalidate(table_name, *row_id);
}
}
if let Some(context) = self.tenant_manager.get_current_context() {
if let Some(current_quota) = self.tenant_manager.get_quota_tracking(context.tenant_id) {
let new_storage = current_quota.storage_bytes_used.saturating_sub(storage_reclaimed);
let _ = self.tenant_manager.update_storage_usage(context.tenant_id, new_storage);
}
}
for (row_id, tuple) in &deleted_tuples {
let mut col_values = std::collections::HashMap::new();
for (i, col) in schema_arc.columns.iter().enumerate() {
if let Some(v) = tuple.values.get(i) {
col_values.insert(col.name.clone(), v.clone());
}
}
if let Err(e) = self.storage.art_indexes().on_delete(table_name, *row_id, &col_values) {
tracing::debug!("ART index delete for table '{}': {}", table_name, e);
}
}
let _ = returned_tuples;
Ok(delete_count)
}
sql::LogicalPlan::CreateFunction { name, or_replace, params, return_type, body, language, volatility } => {
let stored_func = sql::StoredFunction {
name: name.clone(),
or_replace: *or_replace,
params: params.clone(),
return_type: return_type.clone(),
body: body.clone(),
language: language.clone(),
volatility: volatility.clone(),
created_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_millis() as u64)
.unwrap_or(0),
};
self.function_registry.register_function(stored_func.clone())?;
if let Ok(definition) = bincode::serialize(&stored_func) {
if let Err(e) = self.storage.log_create_function(name, &definition) {
tracing::warn!("Failed to log CREATE FUNCTION to WAL: {}", e);
}
}
Ok(0)
}
sql::LogicalPlan::CreateProcedure { name, or_replace, params, body, language } => {
let stored_proc = sql::StoredProcedure {
name: name.clone(),
or_replace: *or_replace,
params: params.clone(),
body: body.clone(),
language: language.clone(),
created_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_millis() as u64)
.unwrap_or(0),
};
self.function_registry.register_procedure(stored_proc.clone())?;
if let Ok(definition) = bincode::serialize(&stored_proc) {
if let Err(e) = self.storage.log_create_procedure(name, &definition) {
tracing::warn!("Failed to log CREATE PROCEDURE to WAL: {}", e);
}
}
Ok(0)
}
sql::LogicalPlan::DropFunction { name, if_exists } => {
self.function_registry.drop_function(name, *if_exists)?;
if let Err(e) = self.storage.log_drop_function(name) {
tracing::warn!("Failed to log DROP FUNCTION to WAL: {}", e);
}
Ok(0)
}
sql::LogicalPlan::DropProcedure { name, if_exists } => {
self.function_registry.drop_procedure(name, *if_exists)?;
if let Err(e) = self.storage.log_drop_procedure(name) {
tracing::warn!("Failed to log DROP PROCEDURE to WAL: {}", e);
}
Ok(0)
}
sql::LogicalPlan::CreateTrigger {
name,
table_name,
timing,
events,
for_each,
when_condition,
body,
if_not_exists,
referencing,
characteristics,
trigger_type,
from_constraint,
} => {
if let Ok(Some(_)) = self.trigger_registry.get_trigger(table_name, name) {
if *if_not_exists {
return Ok(0);
} else {
return Err(Error::query_execution(format!(
"Trigger '{}' already exists on table '{}'",
name, table_name
)));
}
}
let definition = sql::triggers::TriggerDefinition {
name: name.clone(),
table_name: table_name.clone(),
timing: timing.clone(),
events: events.clone(),
for_each: for_each.clone(),
when_condition: when_condition.clone(),
body: body.clone(),
enabled: true,
created_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64,
referencing: referencing.clone(),
characteristics: characteristics.clone(),
trigger_type: trigger_type.clone(),
from_constraint: from_constraint.clone(),
};
self.trigger_registry.register_trigger(definition.clone())?;
if let Ok(serialized) = bincode::serialize(&definition) {
if let Err(e) = self.storage.log_create_trigger(name, table_name, &serialized) {
tracing::warn!("Failed to log CREATE TRIGGER to WAL: {}", e);
}
}
Ok(0)
}
sql::LogicalPlan::DropTrigger { name, table_name, if_exists } => {
let tbl = table_name.as_ref().ok_or_else(|| {
Error::query_execution("DROP TRIGGER requires ON <table_name> clause".to_string())
})?;
let dropped = self.trigger_registry.drop_trigger(tbl, name)?;
if !dropped && !*if_exists {
return Err(Error::query_execution(format!(
"Trigger '{}' does not exist on table '{}'",
name, tbl
)));
}
if let Err(e) = self.storage.log_drop_trigger(name, table_name.as_deref()) {
tracing::warn!("Failed to log DROP TRIGGER to WAL: {}", e);
}
Ok(0)
}
sql::LogicalPlan::Call { name, args } => {
let schema = std::sync::Arc::new(Schema { columns: vec![] });
let evaluator = sql::Evaluator::new(schema);
let arg_values: Vec<Value> = args.iter()
.map(|expr| evaluator.evaluate(expr, &Tuple::new(vec![])))
.collect::<Result<Vec<_>>>()?;
let db_clone = self.clone_for_trigger();
let sql_executor = |sql: &str| -> Result<Vec<Vec<Value>>> {
let sql_trimmed = sql.trim();
if starts_with_icase(sql_trimmed, "SELECT") || starts_with_icase(sql_trimmed, "WITH") {
let tuples = db_clone.query(sql, &[])?;
Ok(tuples.iter().map(|t| t.values.clone()).collect())
} else {
db_clone.execute(sql)?;
Ok(vec![])
}
};
self.function_registry.execute_procedure(name, &arg_values, sql_executor)?;
Ok(0)
}
sql::LogicalPlan::AlterColumnStorage { table_name, column_name, storage_mode } => {
let catalog = self.storage.catalog();
let mut schema = catalog.get_table_schema(table_name)?;
let col_idx = schema.columns.iter()
.position(|c| c.name == *column_name)
.ok_or_else(|| Error::query_execution(format!(
"Column '{}' not found in table '{}'", column_name, table_name
)))?;
let col_ref = schema.get_column_at(col_idx)
.ok_or_else(|| Error::internal("column index out of bounds"))?;
let old_mode = col_ref.storage_mode;
if old_mode == *storage_mode {
return Ok(0);
}
let column = col_ref.clone();
let rows_migrated = self.storage.migrate_column_storage(
table_name,
col_idx,
&column,
old_mode,
*storage_mode,
)?;
schema.get_column_at_mut(col_idx)
.ok_or_else(|| Error::internal("column index out of bounds"))?
.storage_mode = *storage_mode;
catalog.update_table_schema(table_name, &schema)?;
if let Err(e) = self.storage.log_alter_column_storage(table_name, column_name, storage_mode) {
tracing::warn!("Failed to log ALTER COLUMN STORAGE to WAL: {}", e);
}
tracing::info!(
"Altered {}.{} storage from {:?} to {:?}, migrated {} rows",
table_name, column_name, old_mode, storage_mode, rows_migrated
);
Ok(rows_migrated as u64)
}
sql::LogicalPlan::AlterTableAddColumn { table_name, column_def, if_not_exists } => {
let catalog = self.storage.catalog();
let mut schema = catalog.get_table_schema(table_name)?;
if schema.columns.iter().any(|c| c.name == column_def.name) {
if *if_not_exists {
return Ok(0);
}
return Err(Error::query_execution(format!(
"Column '{}' already exists in table '{}'", column_def.name, table_name
)));
}
let new_column = Column {
name: column_def.name.clone(),
data_type: column_def.data_type.clone(),
nullable: !column_def.not_null,
primary_key: column_def.primary_key,
source_table: None,
source_table_name: Some(table_name.clone()),
default_expr: column_def.default.as_ref().map(|e| format!("{:?}", e)),
unique: column_def.unique,
storage_mode: column_def.storage_mode,
};
schema.columns.push(new_column);
catalog.update_table_schema(table_name, &schema)?;
let rows_updated = self.storage.add_column_to_rows(
table_name,
&column_def.default,
)?;
tracing::info!(
"Added column '{}' to table '{}', updated {} rows",
column_def.name, table_name, rows_updated
);
Ok(rows_updated as u64)
}
sql::LogicalPlan::AlterTableDropColumn { table_name, column_name, if_exists, cascade } => {
let catalog = self.storage.catalog();
let mut schema = catalog.get_table_schema(table_name)?;
let col_idx = schema.columns.iter()
.position(|c| c.name == *column_name);
match col_idx {
Some(idx) => {
let is_pk = schema.get_column_at(idx)
.ok_or_else(|| Error::internal("column index out of bounds"))?
.primary_key;
if is_pk && !cascade {
return Err(Error::query_execution(format!(
"Cannot drop primary key column '{}' without CASCADE", column_name
)));
}
schema.columns.remove(idx);
catalog.update_table_schema(table_name, &schema)?;
let rows_updated = self.storage.drop_column_from_rows(table_name, idx)?;
tracing::info!(
"Dropped column '{}' from table '{}', updated {} rows",
column_name, table_name, rows_updated
);
Ok(rows_updated as u64)
}
None => {
if *if_exists {
Ok(0)
} else {
Err(Error::query_execution(format!(
"Column '{}' does not exist in table '{}'", column_name, table_name
)))
}
}
}
}
sql::LogicalPlan::AlterTableRenameColumn { table_name, old_column_name, new_column_name } => {
let catalog = self.storage.catalog();
let mut schema = catalog.get_table_schema(table_name)?;
if schema.columns.iter().any(|c| c.name == *new_column_name) {
return Err(Error::query_execution(format!(
"Column '{}' already exists in table '{}'", new_column_name, table_name
)));
}
let col_idx = schema.columns.iter()
.position(|c| c.name == *old_column_name)
.ok_or_else(|| Error::query_execution(format!(
"Column '{}' does not exist in table '{}'", old_column_name, table_name
)))?;
schema.get_column_at_mut(col_idx)
.ok_or_else(|| Error::internal("column index out of bounds"))?
.name = new_column_name.clone();
catalog.update_table_schema(table_name, &schema)?;
tracing::info!(
"Renamed column '{}' to '{}' in table '{}'",
old_column_name, new_column_name, table_name
);
Ok(0)
}
sql::LogicalPlan::AlterTableRename { table_name, new_table_name } => {
let catalog = self.storage.catalog();
if catalog.get_table_schema(new_table_name).is_ok() {
return Err(Error::query_execution(format!(
"Table '{}' already exists", new_table_name
)));
}
self.storage.rename_table(table_name, new_table_name)?;
tracing::info!(
"Renamed table '{}' to '{}'",
table_name, new_table_name
);
Ok(0)
}
sql::LogicalPlan::AlterTableAddForeignKey {
table_name,
constraint_name,
columns,
references_table,
references_columns,
on_delete,
on_update,
deferrable,
initially_deferred,
} => {
let catalog = self.storage.catalog();
catalog.get_table_schema(table_name)?;
catalog.get_table_schema(references_table)?;
let fk_name = constraint_name
.clone()
.unwrap_or_else(|| sql::ForeignKeyConstraint::generate_name(
table_name, columns, references_table,
));
let mut fk = sql::ForeignKeyConstraint::new(
fk_name,
table_name.clone(),
columns.clone(),
references_table.clone(),
references_columns.clone(),
);
if let Some(action) = on_delete {
fk = fk.on_delete(convert_logical_referential_action(action));
}
if let Some(action) = on_update {
fk = fk.on_update(convert_logical_referential_action(action));
}
if *deferrable {
fk = fk.deferrable(*initially_deferred);
}
catalog.add_foreign_key(fk)?;
Ok(0)
}
sql::LogicalPlan::AlterTableMulti { operations } => {
let mut total_rows = 0u64;
for sub_plan in operations {
total_rows += self.execute_alter_table_op(sub_plan)?;
}
Ok(total_rows)
}
sql::LogicalPlan::Savepoint { ref name } => {
let write_set_snapshot = txn.savepoint_snapshot();
let savepoint = SavepointState {
name: name.clone(),
write_set_snapshot,
};
self.savepoints.write().push(savepoint);
Ok(0)
}
sql::LogicalPlan::ReleaseSavepoint { ref name } => {
let mut savepoints = self.savepoints.write();
if let Some(pos) = savepoints.iter().rposition(|s| &s.name == name) {
savepoints.truncate(pos);
Ok(0)
} else {
Err(Error::query_execution(format!("Savepoint '{}' does not exist", name)))
}
}
sql::LogicalPlan::RollbackToSavepoint { ref name } => {
let savepoints = self.savepoints.read();
if let Some(pos) = savepoints.iter().rposition(|s| &s.name == name) {
let snapshot = savepoints.get(pos)
.map(|s| s.write_set_snapshot.clone());
drop(savepoints);
if let Some(snapshot) = snapshot {
txn.rollback_to_savepoint(&snapshot);
}
let mut savepoints = self.savepoints.write();
savepoints.truncate(pos + 1);
Ok(0)
} else {
Err(Error::query_execution(format!("Savepoint '{}' does not exist", name)))
}
}
sql::LogicalPlan::Truncate { ref table_name } => {
let catalog = self.storage.catalog();
let _schema = catalog.get_table_schema(table_name)?;
let rows = self.storage.scan_table(table_name)?;
let mut count = 0u64;
for tuple in &rows {
if let Some(row_id) = tuple.row_id {
let key = format!("data:{}:{}", table_name, row_id).into_bytes();
txn.delete(key)?;
self.storage.row_cache().invalidate(table_name, row_id);
count += 1;
}
}
self.storage.art_indexes().clear_table_indexes(table_name);
Ok(count)
}
sql::LogicalPlan::CreateDatabase { name, if_not_exists } => {
let (count, _) = self.handle_create_database(name, *if_not_exists)?;
Ok(count)
}
sql::LogicalPlan::DropDatabase { name, if_exists } => {
let (count, _) = self.handle_drop_database(name, *if_exists)?;
Ok(count)
}
_ => {
let mut executor = sql::Executor::with_storage(&self.storage)
.with_timeout(self.config.storage.query_timeout_ms)
.with_transaction(txn);
let results = executor.execute(&plan)?;
let is_select = matches!(plan,
sql::LogicalPlan::Scan { .. } |
sql::LogicalPlan::Filter { .. } |
sql::LogicalPlan::Project { .. } |
sql::LogicalPlan::Aggregate { .. } |
sql::LogicalPlan::Join { .. } |
sql::LogicalPlan::Sort { .. } |
sql::LogicalPlan::Limit { .. } |
sql::LogicalPlan::With { .. } |
sql::LogicalPlan::TableFunction { .. } |
sql::LogicalPlan::SystemView { .. }
);
let _ = is_select; Ok(results.len() as u64)
}
}
}
#[allow(clippy::expect_used)] pub fn new(path: impl AsRef<std::path::Path>) -> Result<Self> {
let config = Config::default();
let storage = std::sync::Arc::new(storage::StorageEngine::open(path.as_ref(), &config)?);
let mv_scheduler = std::sync::Arc::new(storage::MVScheduler::new(
storage::SchedulerConfig::default(),
std::sync::Arc::clone(&storage),
));
let dump_manager = std::sync::Arc::new(storage::DumpManager::new(
path.as_ref().to_path_buf(),
storage::DumpCompressionType::Zstd,
));
let session_manager = std::sync::Arc::new(crate::session::SessionManager::new());
let lock_manager = std::sync::Arc::new(storage::LockManager::with_default_timeout());
let dirty_tracker = std::sync::Arc::new(storage::DirtyTracker::new());
let catalog = storage::Catalog::new(&storage);
if let Err(e) = catalog.rebuild_all_indexes() {
tracing::warn!("ART rebuild on open failed: {} — falling back to scan paths", e);
}
Ok(Self {
storage,
config,
current_transaction: std::sync::Arc::new(std::sync::Mutex::new(None)),
tenant_manager: std::sync::Arc::new(crate::tenant::TenantManager::new()),
trigger_registry: std::sync::Arc::new(sql::TriggerRegistry::new()),
function_registry: std::sync::Arc::new(sql::FunctionRegistry::new()),
mv_scheduler,
auto_refresh_worker: std::sync::Arc::new(parking_lot::RwLock::new(None)),
dump_manager,
session_manager,
lock_manager,
dirty_tracker,
session_transactions: std::sync::Arc::new(dashmap::DashMap::new()),
prepared_statements: std::sync::Arc::new(parking_lot::RwLock::new(std::collections::HashMap::new())),
savepoints: std::sync::Arc::new(parking_lot::RwLock::new(Vec::new())),
plan_cache: std::sync::Arc::new(std::sync::Mutex::new(lru::LruCache::new(std::num::NonZeroUsize::new(256).expect("256 is non-zero")))),
parse_cache: std::sync::Arc::new(std::sync::Mutex::new(lru::LruCache::new(std::num::NonZeroUsize::new(512).expect("512 is non-zero")))),
result_cache: std::sync::Arc::new(std::sync::Mutex::new(lru::LruCache::new(std::num::NonZeroUsize::new(128).expect("128 is non-zero")))),
art_undo_log: std::sync::Arc::new(parking_lot::RwLock::new(Vec::new())),
})
}
#[allow(clippy::expect_used)] pub fn new_in_memory() -> Result<Self> {
let config = Config::in_memory();
let storage = std::sync::Arc::new(storage::StorageEngine::open_in_memory(&config)?);
let mv_scheduler = std::sync::Arc::new(storage::MVScheduler::new(
storage::SchedulerConfig::default(),
std::sync::Arc::clone(&storage),
));
let dump_path = std::env::temp_dir().join("heliosdb_dumps");
let dump_manager = std::sync::Arc::new(storage::DumpManager::new(
dump_path,
storage::DumpCompressionType::Zstd,
));
let session_manager = std::sync::Arc::new(crate::session::SessionManager::new());
let lock_manager = std::sync::Arc::new(storage::LockManager::with_default_timeout());
let dirty_tracker = std::sync::Arc::new(storage::DirtyTracker::new());
Ok(Self {
storage,
config,
current_transaction: std::sync::Arc::new(std::sync::Mutex::new(None)),
tenant_manager: std::sync::Arc::new(crate::tenant::TenantManager::new()),
trigger_registry: std::sync::Arc::new(sql::TriggerRegistry::new()),
function_registry: std::sync::Arc::new(sql::FunctionRegistry::new()),
mv_scheduler,
auto_refresh_worker: std::sync::Arc::new(parking_lot::RwLock::new(None)),
dump_manager,
session_manager,
lock_manager,
dirty_tracker,
session_transactions: std::sync::Arc::new(dashmap::DashMap::new()),
prepared_statements: std::sync::Arc::new(parking_lot::RwLock::new(std::collections::HashMap::new())),
savepoints: std::sync::Arc::new(parking_lot::RwLock::new(Vec::new())),
plan_cache: std::sync::Arc::new(std::sync::Mutex::new(lru::LruCache::new(std::num::NonZeroUsize::new(256).expect("256 is non-zero")))),
parse_cache: std::sync::Arc::new(std::sync::Mutex::new(lru::LruCache::new(std::num::NonZeroUsize::new(512).expect("512 is non-zero")))),
result_cache: std::sync::Arc::new(std::sync::Mutex::new(lru::LruCache::new(std::num::NonZeroUsize::new(128).expect("128 is non-zero")))),
art_undo_log: std::sync::Arc::new(parking_lot::RwLock::new(Vec::new())),
})
}
#[allow(clippy::expect_used)] pub fn with_config(config: Config) -> Result<Self> {
let memory_only = config.storage.memory_only;
let storage = std::sync::Arc::new(if memory_only {
storage::StorageEngine::open_in_memory(&config)?
} else {
let path = config.storage.path.as_ref()
.ok_or_else(|| Error::config("Storage path not specified for non-memory database".to_string()))?;
storage::StorageEngine::open(path, &config)?
});
let mv_scheduler = std::sync::Arc::new(storage::MVScheduler::new(
storage::SchedulerConfig::default(),
std::sync::Arc::clone(&storage),
));
let dump_path = if let Some(ref p) = config.storage.path {
p.clone()
} else {
std::env::temp_dir().join("heliosdb_dumps")
};
let dump_manager = std::sync::Arc::new(storage::DumpManager::new(
dump_path,
storage::DumpCompressionType::Zstd,
));
let session_manager = std::sync::Arc::new(crate::session::SessionManager::new());
let lock_manager = std::sync::Arc::new(storage::LockManager::with_default_timeout());
let dirty_tracker = std::sync::Arc::new(storage::DirtyTracker::new());
if !memory_only {
let catalog = storage::Catalog::new(&storage);
if let Err(e) = catalog.rebuild_all_indexes() {
tracing::warn!("ART rebuild on open failed: {} — falling back to scan paths", e);
}
}
Ok(Self {
storage,
config,
current_transaction: std::sync::Arc::new(std::sync::Mutex::new(None)),
tenant_manager: std::sync::Arc::new(crate::tenant::TenantManager::new()),
trigger_registry: std::sync::Arc::new(sql::TriggerRegistry::new()),
function_registry: std::sync::Arc::new(sql::FunctionRegistry::new()),
mv_scheduler,
auto_refresh_worker: std::sync::Arc::new(parking_lot::RwLock::new(None)),
dump_manager,
session_manager,
lock_manager,
dirty_tracker,
session_transactions: std::sync::Arc::new(dashmap::DashMap::new()),
prepared_statements: std::sync::Arc::new(parking_lot::RwLock::new(std::collections::HashMap::new())),
savepoints: std::sync::Arc::new(parking_lot::RwLock::new(Vec::new())),
plan_cache: std::sync::Arc::new(std::sync::Mutex::new(lru::LruCache::new(std::num::NonZeroUsize::new(256).expect("256 is non-zero")))),
parse_cache: std::sync::Arc::new(std::sync::Mutex::new(lru::LruCache::new(std::num::NonZeroUsize::new(512).expect("512 is non-zero")))),
result_cache: std::sync::Arc::new(std::sync::Mutex::new(lru::LruCache::new(std::num::NonZeroUsize::new(128).expect("128 is non-zero")))),
art_undo_log: std::sync::Arc::new(parking_lot::RwLock::new(Vec::new())),
})
}
pub fn query_timeout_ms(&self) -> Option<u64> {
self.config.storage.query_timeout_ms
}
fn plan_contains_join(plan: &sql::LogicalPlan) -> bool {
match plan {
sql::LogicalPlan::Join { .. } => true,
sql::LogicalPlan::Filter { input, .. }
| sql::LogicalPlan::Project { input, .. }
| sql::LogicalPlan::Sort { input, .. }
| sql::LogicalPlan::Limit { input, .. }
| sql::LogicalPlan::Aggregate { input, .. } => Self::plan_contains_join(input),
_ => false,
}
}
fn log_slow_query(&self, sql: &str, elapsed: std::time::Duration, rows: u64) {
if let Some(threshold) = self.config.storage.slow_query_threshold_ms {
let elapsed_ms = elapsed.as_millis() as u64;
if elapsed_ms >= threshold {
tracing::warn!(
duration_ms = elapsed_ms,
rows = rows,
"Slow query ({}ms, {} rows): {:.200}",
elapsed_ms,
rows,
sql
);
}
}
}
pub fn execute_batch(&self, statements: &[&str]) -> Result<u64> {
let start = std::time::Instant::now();
let txn_start = std::time::Instant::now();
let txn = self.storage.begin_transaction()?;
tracing::trace!(phase = "txn_begin", duration_us = txn_start.elapsed().as_micros() as u64, "Batch transaction started");
let mut total_rows = 0u64;
for sql in statements {
match self.execute_in_transaction(sql, &txn) {
Ok(count) => total_rows += count,
Err(e) => {
let _ = txn.rollback();
return Err(e);
}
}
}
let commit_start = std::time::Instant::now();
txn.commit()?;
self.storage.increment_lsn();
tracing::debug!(phase = "txn_commit", duration_us = commit_start.elapsed().as_micros() as u64, rows = total_rows, "Batch transaction committed");
let elapsed = start.elapsed();
tracing::debug!(phase = "execute", duration_us = elapsed.as_micros() as u64, "Batch executed ({} statements)", statements.len());
Ok(total_rows)
}
#[cfg(feature = "code-graph")]
pub fn register_grammar(
&self,
name: impl Into<String>,
grammar: tree_sitter::Language,
) -> Option<tree_sitter::Language> {
code_graph::parse::register_grammar(name, grammar)
}
#[cfg(feature = "code-graph")]
pub fn unregister_grammar(&self, name: &str) -> Option<tree_sitter::Language> {
code_graph::parse::unregister_grammar(name)
}
#[cfg(feature = "code-graph")]
pub fn registered_grammars(&self) -> Vec<String> {
code_graph::parse::registered_grammars()
}
#[cfg(feature = "code-graph")]
pub fn register_extractor(
&self,
name: impl Into<String>,
extractor: std::sync::Arc<dyn code_graph::SymbolExtractor>,
) -> Option<std::sync::Arc<dyn code_graph::SymbolExtractor>> {
code_graph::register_extractor(name, extractor)
}
#[cfg(feature = "code-graph")]
pub fn unregister_extractor(
&self,
name: &str,
) -> Option<std::sync::Arc<dyn code_graph::SymbolExtractor>> {
code_graph::unregister_extractor(name)
}
#[cfg(feature = "code-graph")]
pub fn registered_extractors(&self) -> Vec<String> {
code_graph::registered_extractors()
}
#[cfg(feature = "code-graph")]
pub fn code_index(
&self,
opts: code_graph::CodeIndexOptions,
) -> Result<code_graph::CodeIndexStats> {
let stats = code_graph::storage::code_index(self, opts)?;
#[cfg(feature = "graph-rag")]
{
let _ = self.graph_rag_project_symbols();
}
Ok(stats)
}
#[cfg(feature = "code-graph")]
pub fn lsp_definition(
&self,
name: &str,
hint: &code_graph::DefinitionHint,
) -> Result<Vec<code_graph::DefinitionRow>> {
code_graph::lsp::lsp_definition(self, name, hint)
}
#[cfg(feature = "code-graph")]
pub fn lsp_references(
&self,
symbol_id: i64,
) -> Result<Vec<code_graph::ReferenceRow>> {
code_graph::lsp::lsp_references(self, symbol_id)
}
#[cfg(feature = "code-graph")]
pub fn lsp_call_hierarchy(
&self,
symbol_id: i64,
direction: code_graph::lsp::CallDirection,
depth: u32,
) -> Result<Vec<code_graph::lsp::CallHierarchyRow>> {
code_graph::lsp::lsp_call_hierarchy(self, symbol_id, direction, depth)
}
#[cfg(feature = "code-graph")]
pub fn lsp_hover(&self, symbol_id: i64) -> Result<Option<code_graph::HoverRow>> {
code_graph::lsp::lsp_hover(self, symbol_id)
}
#[cfg(feature = "code-graph")]
pub fn lsp_references_diff(
&self,
symbol_id: i64,
at_a: &code_graph::AsOfRef,
at_b: &code_graph::AsOfRef,
) -> Result<Vec<code_graph::RefDiffRow>> {
code_graph::diff::lsp_references_diff(self, symbol_id, at_a, at_b)
}
#[cfg(feature = "code-graph")]
pub fn lsp_body_diff(
&self,
symbol_id: i64,
at_a: &code_graph::AsOfRef,
at_b: &code_graph::AsOfRef,
) -> Result<Vec<code_graph::BodyDiffLine>> {
code_graph::diff::lsp_body_diff(self, symbol_id, at_a, at_b)
}
#[cfg(feature = "code-graph")]
pub fn ast_diff(
&self,
file_path: &str,
at_a: &code_graph::AsOfRef,
at_b: &code_graph::AsOfRef,
) -> Result<Vec<code_graph::AstDiffRow>> {
code_graph::diff::ast_diff(self, file_path, at_a, at_b)
}
#[cfg(feature = "code-graph")]
pub fn code_graph_merkle_refresh(
&self,
) -> Result<code_graph::MerkleStats> {
code_graph::merkle_refresh(self)
}
#[cfg(feature = "code-graph")]
pub fn lsp_rename_apply(
&self,
symbol_id: i64,
new_name: &str,
opts: &code_graph::RenameApplyOptions,
) -> Result<code_graph::RenameApplyStats> {
code_graph::rename_apply(self, symbol_id, new_name, opts)
}
#[cfg(feature = "graph-rag")]
fn run_with_context(
&self,
inner_sql: &str,
opts: &graph_rag::WithContextOptions,
) -> Result<Vec<Tuple>> {
let hits =
graph_rag::graph_rag_expand_with_context(self, inner_sql, opts)?;
Ok(hits
.into_iter()
.map(|h| {
Tuple::new(vec![
Value::Int8(h.node_id),
Value::String(h.node_kind),
h.title.map(Value::String).unwrap_or(Value::Null),
h.text.map(Value::String).unwrap_or(Value::Null),
h.source_ref.map(Value::String).unwrap_or(Value::Null),
Value::Int4(h.hop_distance as i32),
])
})
.collect())
}
#[inline]
fn maybe_rewrite_code_graph<'a>(&self, sql: &'a str) -> std::borrow::Cow<'a, str> {
#[cfg(feature = "code-graph")]
{
let rewritten = code_graph::rewrite_lsp_calls(sql);
if rewritten != sql {
return std::borrow::Cow::Owned(rewritten);
}
}
std::borrow::Cow::Borrowed(sql)
}
#[cfg(feature = "code-graph")]
fn rewrite_and_scope(&self, sql: &str) -> (String, CodeGraphBranchGuard<'_>) {
let rewrite = code_graph::rewrite_lsp_calls_full(sql);
let guard = match rewrite.branch_override {
Some(target) => CodeGraphBranchGuard::switch_to(self, target),
None => CodeGraphBranchGuard::noop(),
};
(rewrite.sql, guard)
}
#[cfg(feature = "code-graph")]
fn handle_create_ast_index(
&self,
ddl: code_graph::AstIndexDdl,
) -> Result<u64> {
let existing = code_graph::storage::get_ast_index(&ddl.index_name);
if existing.is_some() && !ddl.if_not_exists {
return Err(Error::query_execution(format!(
"AST index '{}' already exists",
ddl.index_name
)));
}
let meta = code_graph::AstIndexMeta {
index_name: ddl.index_name.clone(),
table: ddl.table.clone(),
content_col: ddl.content_col,
lang_col: ddl.lang_col,
embed_endpoint: ddl.embed_endpoint.clone(),
embed_bearer: ddl.embed_bearer.clone(),
embed_bodies: ddl.embed_bodies,
auto_reparse: ddl.auto_reparse,
resolve_cross_file: ddl.resolve_cross_file,
paused: false,
};
code_graph::register_ast_index(meta.clone());
let opts = code_graph::CodeIndexOptions {
source_table: ddl.table,
embed_bodies: meta.embed_bodies,
embed_endpoint: meta.embed_endpoint,
embed_bearer: meta.embed_bearer,
force_reparse: false,
parallelism: None,
chunk_size: None,
};
self.code_index(opts)?;
Ok(0)
}
#[cfg(feature = "code-graph")]
fn handle_create_semantic_hash_index(
&self,
ddl: code_graph::SemanticHashIndexDdl,
) -> Result<u64> {
let stats = code_graph::merkle_refresh(self)?;
tracing::info!(
index = %ddl.index_name,
files_hashed = stats.files_hashed,
files_unchanged = stats.files_unchanged,
symbols_hashed = stats.symbols_hashed,
"CREATE SEMANTIC HASH INDEX completed"
);
let _ = ddl.if_not_exists; Ok(stats.files_hashed)
}
#[cfg(feature = "code-graph")]
fn handle_pause_resume(
&self,
pr: code_graph::PauseResume,
) -> Result<u64> {
let (name, paused) = match pr {
code_graph::PauseResume::Pause(n) => (n, true),
code_graph::PauseResume::Resume(n) => (n, false),
};
if !code_graph::storage::set_ast_index_paused(&name, paused) {
return Err(Error::query_execution(format!(
"AST index '{name}' is not registered"
)));
}
Ok(0)
}
#[cfg(feature = "code-graph")]
fn maybe_auto_reparse(&self, touched_table: Option<&str>) {
let Some(tbl) = touched_table else { return };
for idx in code_graph::storage::ast_indexes_for_table(tbl) {
if !idx.auto_reparse {
continue;
}
let opts = code_graph::CodeIndexOptions {
source_table: idx.table.clone(),
embed_bodies: idx.embed_bodies,
embed_endpoint: idx.embed_endpoint.clone(),
embed_bearer: idx.embed_bearer.clone(),
force_reparse: false,
parallelism: None,
chunk_size: None,
};
let _ = self.code_index(opts);
}
}
#[cfg(feature = "code-graph")]
fn touched_table_from_sql(sql: &str) -> Option<String> {
let s = sql.trim_start();
let low = s.to_ascii_lowercase();
if low.starts_with("insert into") {
let rest = s.get("insert into".len()..)?.trim_start();
Some(Self::take_ident(rest))
} else if low.starts_with("update") {
let rest = s.get("update".len()..)?.trim_start();
Some(Self::take_ident(rest))
} else if low.starts_with("delete from") {
let rest = s.get("delete from".len()..)?.trim_start();
Some(Self::take_ident(rest))
} else {
None
}
}
#[cfg(feature = "code-graph")]
fn take_ident(s: &str) -> String {
let mut out = String::new();
let mut it = s.chars().peekable();
if matches!(it.peek(), Some('"')) {
it.next();
for c in it {
if c == '"' {
break;
}
out.push(c);
}
return out;
}
for c in it {
if c.is_alphanumeric() || c == '_' {
out.push(c);
} else {
break;
}
}
out
}
#[cfg(feature = "graph-rag")]
pub fn graph_rag_project_symbols(&self) -> Result<graph_rag::GraphRagStats> {
let mut stats = graph_rag::GraphRagStats::default();
graph_rag::project_code_symbols(self, &mut stats)?;
Ok(stats)
}
#[cfg(feature = "graph-rag")]
pub fn graph_rag_search(
&self,
opts: &graph_rag::GraphRagOptions,
) -> Result<Vec<graph_rag::GraphRagHit>> {
graph_rag::graph_rag_search(self, opts)
}
#[cfg(feature = "graph-rag")]
pub fn graph_rag_link_exact(
&self,
extra_kinds: &[&str],
) -> Result<graph_rag::LinkerStats> {
graph_rag::link_exact_qualified(self, extra_kinds)
}
#[cfg(feature = "graph-rag")]
pub fn graph_rag_link_vector(
&self,
text_queries: &[graph_rag::TextEmbedding],
symbol_targets: &[graph_rag::SymbolEmbedding],
top_k: usize,
threshold: f32,
) -> Result<graph_rag::LinkerStats> {
graph_rag::link_vector_similar(self, text_queries, symbol_targets, top_k, threshold)
}
#[cfg(feature = "graph-rag")]
pub fn graph_rag_ingest_docs(
&self,
opts: &graph_rag::IngestDocsOptions,
) -> Result<graph_rag::IngestStats> {
graph_rag::ingest_docs(self, opts)
}
#[cfg(feature = "graph-rag")]
pub fn graph_rag_ingest_pdf(
&self,
opts: &graph_rag::DoclingIngestOptions,
) -> Result<graph_rag::IngestStats> {
graph_rag::docling_ingest_pdf(self, opts)
}
#[cfg(feature = "graph-rag")]
pub fn graph_rag_ingest_office(
&self,
opts: &graph_rag::DoclingIngestOptions,
) -> Result<graph_rag::IngestStats> {
graph_rag::docling_ingest_office(self, opts)
}
#[cfg(feature = "graph-rag")]
pub fn graph_rag_ingest_audio(
&self,
opts: &graph_rag::DoclingIngestOptions,
) -> Result<graph_rag::IngestStats> {
graph_rag::docling_ingest_audio(self, opts)
}
#[cfg(feature = "graph-rag")]
pub fn graph_rag_ingest_image(
&self,
opts: &graph_rag::DoclingIngestOptions,
) -> Result<graph_rag::IngestStats> {
graph_rag::docling_ingest_image(self, opts)
}
#[cfg(feature = "graph-rag")]
pub fn graph_rag_ingest_email(
&self,
opts: &graph_rag::IngestEmailOptions,
) -> Result<graph_rag::IngestStats> {
graph_rag::ingest_email(self, opts)
}
#[cfg(feature = "graph-rag")]
pub fn graph_rag_ingest_issues(
&self,
opts: &graph_rag::IngestIssuesOptions,
) -> Result<graph_rag::IngestStats> {
graph_rag::ingest_issues(self, opts)
}
#[cfg(feature = "graph-rag")]
pub fn graph_rag_ingest_qa(
&self,
opts: &graph_rag::IngestQaOptions,
) -> Result<graph_rag::IngestStats> {
graph_rag::ingest_qa(self, opts)
}
pub fn execute(&self, sql: &str) -> Result<u64> {
use crate::error::LockResultExt;
if let Some((_, _)) = crate::sql::sqlite_compat::parse_pragma(sql) {
tracing::debug!("PRAGMA stubbed via execute(): {}", sql.trim());
return Ok(0);
}
#[cfg(feature = "code-graph")]
if let Some(ddl) = code_graph::detect_create_ast_index(sql) {
return self.handle_create_ast_index(ddl);
}
#[cfg(feature = "code-graph")]
if let Some(pr) = code_graph::detect_pause_resume(sql) {
return self.handle_pause_resume(pr);
}
#[cfg(feature = "code-graph")]
if let Some(ddl) = code_graph::detect_create_semantic_hash_index(sql) {
return self.handle_create_semantic_hash_index(ddl);
}
#[cfg(feature = "code-graph")]
let (rewritten_owned, _branch_guard) = self.rewrite_and_scope(sql);
#[cfg(feature = "code-graph")]
let sql: &str = &rewritten_owned;
#[cfg(not(feature = "code-graph"))]
let sql: &str = sql;
let start = std::time::Instant::now();
if Self::is_transaction_control(sql) {
return self.handle_transaction_control(sql);
}
let has_active_txn = {
let txn_lock = self.current_transaction.lock()
.map_lock_err("Failed to acquire transaction lock for execute")?;
txn_lock.is_some()
};
let result = if has_active_txn {
let txn_lock = self.current_transaction.lock()
.map_lock_err("Failed to acquire transaction lock for execute")?;
let txn_ref = txn_lock.as_ref()
.ok_or_else(|| Error::transaction("Transaction lock in invalid state"))?;
self.execute_in_transaction_no_fast_path(sql, txn_ref)
} else {
self.execute_with_implicit_transaction(sql)
};
if result.is_ok() {
self.invalidate_result_cache();
#[cfg(feature = "code-graph")]
{
let touched = Self::touched_table_from_sql(sql);
self.maybe_auto_reparse(touched.as_deref());
}
}
let rows = result.as_ref().copied().unwrap_or(0);
self.log_slow_query(sql, start.elapsed(), rows);
result
}
pub fn execute_returning(&self, sql: &str) -> Result<(u64, Vec<Tuple>)> {
self.execute_params_returning(sql, &[])
}
fn execute_with_implicit_transaction(&self, sql: &str) -> Result<u64> {
let txn_start = std::time::Instant::now();
let txn = self.storage.begin_transaction()?;
tracing::trace!(phase = "txn_begin", duration_us = txn_start.elapsed().as_micros() as u64, "Transaction started");
let exec_start = std::time::Instant::now();
let result = self.execute_in_transaction(sql, &txn);
tracing::debug!(phase = "execute", duration_us = exec_start.elapsed().as_micros() as u64, "Query executed");
match result {
Ok(count) => {
let commit_start = std::time::Instant::now();
txn.commit()?;
self.storage.increment_lsn();
tracing::debug!(phase = "txn_commit", duration_us = commit_start.elapsed().as_micros() as u64, rows = count, "Transaction committed");
Ok(count)
}
Err(e) => {
let _ = txn.rollback(); Err(e)
}
}
}
fn invalidate_plan_cache(&self) {
if let Ok(mut cache) = self.plan_cache.lock() {
cache.clear();
}
if let Ok(mut cache) = self.parse_cache.lock() {
cache.clear();
}
self.invalidate_result_cache();
}
fn invalidate_result_cache(&self) {
if let Ok(mut cache) = self.result_cache.lock() {
cache.clear();
}
}
fn handle_pragma_query(&self, name: &str, arg: Option<&str>) -> Result<Vec<Tuple>> {
match name.to_lowercase().as_str() {
"table_info" => {
let table = arg
.unwrap_or("")
.trim()
.trim_matches(|c| c == '\'' || c == '"' || c == '`');
if table.is_empty() {
return Ok(vec![]);
}
let catalog = self.storage.catalog();
let schema = catalog.get_table_schema(table)?;
let mut rows = Vec::with_capacity(schema.columns.len());
for (idx, col) in schema.columns.iter().enumerate() {
rows.push(Tuple::new(vec![
Value::Int4(idx as i32),
Value::String(col.name.clone()),
Value::String(format!("{:?}", col.data_type).to_uppercase()),
Value::Int4(if col.nullable { 0 } else { 1 }),
col.default_expr
.as_ref()
.map(|d| Value::String(d.clone()))
.unwrap_or(Value::Null),
Value::Int4(if col.primary_key { 1 } else { 0 }),
]));
}
Ok(rows)
}
_ => {
tracing::debug!("PRAGMA stubbed (no-op rows): {} = {:?}", name, arg);
Ok(vec![])
}
}
}
fn execute_alter_table_op(&self, plan: &sql::LogicalPlan) -> Result<u64> {
match plan {
sql::LogicalPlan::AlterTableAddColumn { table_name, column_def, if_not_exists } => {
let catalog = self.storage.catalog();
let mut schema = catalog.get_table_schema(table_name)?;
if schema.columns.iter().any(|c| c.name == column_def.name) {
if *if_not_exists {
return Ok(0);
}
return Err(Error::query_execution(format!(
"Column '{}' already exists in table '{}'", column_def.name, table_name
)));
}
let new_column = Column {
name: column_def.name.clone(),
data_type: column_def.data_type.clone(),
nullable: !column_def.not_null,
primary_key: column_def.primary_key,
source_table: None,
source_table_name: Some(table_name.clone()),
default_expr: column_def.default.as_ref().map(|e| format!("{:?}", e)),
unique: column_def.unique,
storage_mode: column_def.storage_mode,
};
schema.columns.push(new_column);
catalog.update_table_schema(table_name, &schema)?;
let rows_updated = self.storage.add_column_to_rows(
table_name,
&column_def.default,
)?;
tracing::info!(
"Added column '{}' to table '{}', updated {} rows",
column_def.name, table_name, rows_updated
);
Ok(rows_updated as u64)
}
sql::LogicalPlan::AlterTableDropColumn { table_name, column_name, if_exists, cascade } => {
let catalog = self.storage.catalog();
let mut schema = catalog.get_table_schema(table_name)?;
let col_idx = schema.columns.iter()
.position(|c| c.name == *column_name);
match col_idx {
Some(idx) => {
let is_pk = schema.get_column_at(idx)
.ok_or_else(|| Error::internal("column index out of bounds"))?
.primary_key;
if is_pk && !cascade {
return Err(Error::query_execution(format!(
"Cannot drop primary key column '{}' without CASCADE", column_name
)));
}
schema.columns.remove(idx);
catalog.update_table_schema(table_name, &schema)?;
let rows_updated = self.storage.drop_column_from_rows(table_name, idx)?;
tracing::info!(
"Dropped column '{}' from table '{}', updated {} rows",
column_name, table_name, rows_updated
);
Ok(rows_updated as u64)
}
None => {
if *if_exists {
Ok(0)
} else {
Err(Error::query_execution(format!(
"Column '{}' does not exist in table '{}'", column_name, table_name
)))
}
}
}
}
sql::LogicalPlan::AlterTableRenameColumn { table_name, old_column_name, new_column_name } => {
let catalog = self.storage.catalog();
let mut schema = catalog.get_table_schema(table_name)?;
if schema.columns.iter().any(|c| c.name == *new_column_name) {
return Err(Error::query_execution(format!(
"Column '{}' already exists in table '{}'", new_column_name, table_name
)));
}
let col_idx = schema.columns.iter()
.position(|c| c.name == *old_column_name)
.ok_or_else(|| Error::query_execution(format!(
"Column '{}' does not exist in table '{}'", old_column_name, table_name
)))?;
schema.get_column_at_mut(col_idx)
.ok_or_else(|| Error::internal("column index out of bounds"))?
.name = new_column_name.clone();
catalog.update_table_schema(table_name, &schema)?;
tracing::info!(
"Renamed column '{}' to '{}' in table '{}'",
old_column_name, new_column_name, table_name
);
Ok(0)
}
sql::LogicalPlan::AlterTableRename { table_name, new_table_name } => {
let catalog = self.storage.catalog();
if catalog.get_table_schema(new_table_name).is_ok() {
return Err(Error::query_execution(format!(
"Table '{}' already exists", new_table_name
)));
}
self.storage.rename_table(table_name, new_table_name)?;
tracing::info!(
"Renamed table '{}' to '{}'",
table_name, new_table_name
);
Ok(0)
}
_ => Err(Error::internal(format!(
"execute_alter_table_op called with non-ALTER plan: {:?}",
plan.plan_type_name()
))),
}
}
#[allow(clippy::indexing_slicing)] fn try_fast_insert(&self, sql: &str) -> Option<Result<u64>> {
let trimmed = sql.trim();
if trimmed.len() < 20 || !trimmed.as_bytes().get(..6)?.eq_ignore_ascii_case(b"INSERT") {
return None;
}
let upper = trimmed.to_ascii_uppercase();
if upper.contains("RETURNING") || upper.contains("ON CONFLICT")
|| upper.contains("DEFAULT") || upper.contains("SELECT") {
return None;
}
let after_insert = trimmed.get(6..)?.trim_start();
if !after_insert.as_bytes().get(..4)?.eq_ignore_ascii_case(b"INTO") {
return None;
}
let after_into = after_insert.get(4..)?.trim_start();
let table_end = after_into.find(|c: char| c == '(' || c.is_whitespace())?;
let table_name = after_into.get(..table_end)?.trim().trim_matches('"');
if table_name.is_empty() {
return None;
}
let rest = after_into.get(table_end..)?.trim_start();
if !rest.starts_with('(') {
return None;
}
let col_end = rest.find(')')?;
let col_list_str = rest.get(1..col_end)?;
let columns: Vec<&str> = col_list_str.split(',').map(|s| s.trim()).collect();
if columns.is_empty() || columns.iter().any(|c| c.is_empty()) {
return None;
}
let after_cols = rest.get(col_end + 1..)?.trim_start();
if after_cols.len() < 6 || !after_cols.as_bytes().get(..6)?.eq_ignore_ascii_case(b"VALUES") {
return None;
}
let values_rest = after_cols.get(6..)?.trim_start();
if !values_rest.starts_with('(') {
return None;
}
let values_inner = values_rest.get(1..)?;
let close_idx = Self::find_closing_paren(values_inner)?;
let values_str = values_inner.get(..close_idx)?;
let after_values = values_inner.get(close_idx + 1..)?.trim();
if !after_values.is_empty() && after_values != ";" {
return None; }
if self.tenant_manager.should_apply_rls(table_name, "INSERT") {
return None;
}
if self.trigger_registry.has_triggers_for_table(table_name) {
return None;
}
let catalog = self.storage.catalog();
let schema = match catalog.get_table_schema(table_name) {
Ok(s) => s,
Err(_) => return None, };
if let Ok(tc) = catalog.load_table_constraints(table_name) {
if !tc.foreign_keys.is_empty() {
return None;
}
}
if columns.len() != Self::fast_parse_value_count(values_str) {
return None; }
let mut target_types = Vec::with_capacity(columns.len());
let mut col_indices = Vec::with_capacity(columns.len());
for col_name in &columns {
match schema.get_column_index(col_name) {
Some(idx) => {
col_indices.push(idx);
match schema.get_column_at(idx) {
Some(col) => target_types.push(col.data_type.clone()),
None => return None,
}
}
None => return None, }
}
let values = Self::fast_parse_values(values_str, &target_types)?;
let mut tuple_values = vec![Value::Null; schema.columns.len()];
let mut user_provided = vec![false; schema.columns.len()];
for (i, &col_idx) in col_indices.iter().enumerate() {
if let Some(val) = values.get(i) {
if col_idx < tuple_values.len() {
tuple_values[col_idx] = val.clone();
user_provided[col_idx] = true;
}
}
}
if let Err(e) = Self::apply_defaults_and_check_not_null(
&mut tuple_values, &schema, &user_provided,
) {
return Some(Err(e));
}
let tuple = Tuple::new(tuple_values);
if self.storage.get_current_branch_id().is_none() {
Some(self.storage.insert_tuple_fast(table_name, tuple, &schema).map(|_| 1))
} else {
Some(self.storage.insert_tuple_branch_aware_with_schema(table_name, tuple, &schema).map(|_| 1))
}
}
fn try_fast_update(&self, sql: &str) -> Option<Result<u64>> {
let trimmed = sql.trim();
if trimmed.len() < 20 || !trimmed.as_bytes().get(..6)?.eq_ignore_ascii_case(b"UPDATE") {
return None;
}
let upper = trimmed.to_ascii_uppercase();
if upper.contains("RETURNING") || upper.contains("JOIN")
|| upper.contains("FROM") || upper.contains("SELECT")
|| upper.contains("CASE") || upper.contains("COALESCE") {
return None;
}
let after_update = trimmed.get(6..)?.trim_start();
let table_end = after_update.find(|c: char| c.is_whitespace())?;
let table_name = after_update.get(..table_end)?.trim();
if table_name.is_empty() {
return None;
}
let rest = after_update.get(table_end..)?.trim_start();
if rest.len() < 3 || !rest.as_bytes().get(..3)?.eq_ignore_ascii_case(b"SET") {
return None;
}
let after_set = rest.get(3..)?.trim_start();
let where_pos = {
let upper_rest = after_set.to_ascii_uppercase();
let pos = upper_rest.find("WHERE")?;
if pos == 0 { return None; }
let prev = after_set.as_bytes().get(pos - 1)?;
if !prev.is_ascii_whitespace() { return None; }
pos
};
let set_clause = after_set.get(..where_pos)?.trim();
let where_clause = after_set.get(where_pos + 5..)?.trim();
if set_clause.contains(',') {
return None; }
let eq_pos = set_clause.find('=')?;
let set_col = set_clause.get(..eq_pos)?.trim();
let set_val_str = set_clause.get(eq_pos + 1..)?.trim();
if set_col.is_empty() || set_val_str.is_empty() {
return None;
}
let where_clause = where_clause.strip_suffix(';').unwrap_or(where_clause).trim();
let where_upper = where_clause.to_ascii_uppercase();
if where_upper.contains("AND") || where_upper.contains("OR")
|| where_upper.contains("IN") || where_upper.contains("BETWEEN") {
return None;
}
let weq_pos = where_clause.find('=')?;
let pk_col = where_clause.get(..weq_pos)?.trim();
let pk_val_str = where_clause.get(weq_pos + 1..)?.trim();
if pk_col.is_empty() || pk_val_str.is_empty() {
return None;
}
if self.tenant_manager.should_apply_rls(table_name, "UPDATE") {
return None;
}
if self.trigger_registry.has_triggers_for_table(table_name) {
return None;
}
if self.storage.get_current_branch_id().is_some() {
return None;
}
let catalog = self.storage.catalog();
let schema = match catalog.get_table_schema(table_name) {
Ok(s) => s,
Err(_) => return None,
};
let pk_col_idx = schema.get_column_index(pk_col)?;
let pk_column = schema.get_column_at(pk_col_idx)?;
if !pk_column.primary_key {
return None; }
let set_col_idx = schema.get_column_index(set_col)?;
let set_column = schema.get_column_at(set_col_idx)?;
if self.storage.art_indexes().has_fk(table_name) {
return None;
}
if let Ok(constraints) = self.storage.catalog().load_table_constraints(table_name) {
if !constraints.foreign_keys.is_empty() {
return None; }
}
let (pk_value, _) = Self::fast_parse_one_value(pk_val_str, &pk_column.data_type)?;
let existing_row = match self.storage.get_row_by_pk_with_schema(table_name, &pk_value, &schema) {
Ok(Some(row)) => row,
Ok(None) => return Some(Ok(0)), Err(e) => return Some(Err(e)),
};
let row_id = existing_row.row_id.unwrap_or(0);
if row_id == 0 {
return None; }
let new_value = if let Some((val, _)) = Self::fast_parse_one_value(set_val_str, &set_column.data_type) {
val
} else if let Some(val) = Self::fast_eval_simple_expr(set_val_str, set_col, set_col_idx, &existing_row) {
val
} else {
return None; };
if !set_column.nullable && matches!(new_value, Value::Null) {
return Some(Err(Error::constraint_violation(format!(
"Column '{}' cannot be null", set_col
))));
}
let mut new_values = existing_row.values.clone();
if set_col_idx < new_values.len() {
#[allow(clippy::indexing_slicing)]
{ new_values[set_col_idx] = new_value; }
} else {
return None;
}
let new_tuple = Tuple::new(new_values);
Some(self.storage.update_tuple_fast(table_name, row_id, new_tuple, &existing_row, &schema))
}
fn try_fast_select(&self, sql: &str) -> Option<Result<Vec<Tuple>>> {
let trimmed = sql.trim();
if trimmed.len() < 20 || !trimmed.as_bytes().get(..6)?.eq_ignore_ascii_case(b"SELECT") {
return None;
}
let after_select = trimmed.get(6..)?.trim_start();
if !after_select.starts_with('*') {
return None;
}
let after_star = after_select.get(1..)?.trim_start();
if after_star.len() < 4 || !after_star.as_bytes().get(..4)?.eq_ignore_ascii_case(b"FROM") {
return None;
}
let after_from = after_star.get(4..)?.trim_start();
let table_end = after_from.find(|c: char| c.is_whitespace())?;
let table_name = after_from.get(..table_end)?.trim();
if table_name.is_empty() {
return None;
}
let rest = after_from.get(table_end..)?.trim_start();
if rest.len() < 5 || !rest.as_bytes().get(..5)?.eq_ignore_ascii_case(b"WHERE") {
return None;
}
let where_clause = rest.get(5..)?.trim_start();
let upper = where_clause.to_ascii_uppercase();
if upper.contains("AND") || upper.contains("OR")
|| upper.contains("JOIN") || upper.contains("ORDER")
|| upper.contains("GROUP") || upper.contains("LIMIT") {
return None;
}
let where_clause = where_clause.strip_suffix(';').unwrap_or(where_clause).trim();
let eq_pos = where_clause.find('=')?;
let pk_col = where_clause.get(..eq_pos)?.trim();
let pk_val_str = where_clause.get(eq_pos + 1..)?.trim();
if pk_col.is_empty() || pk_val_str.is_empty() {
return None;
}
if self.tenant_manager.should_apply_rls(table_name, "SELECT") {
return None;
}
let catalog = self.storage.catalog();
let schema = match catalog.get_table_schema(table_name) {
Ok(s) => s,
Err(_) => return None,
};
let pk_col_idx = schema.get_column_index(pk_col)?;
let pk_column = schema.get_column_at(pk_col_idx)?;
if !pk_column.primary_key {
return None; }
let (pk_value, _) = Self::fast_parse_one_value(pk_val_str, &pk_column.data_type)?;
match self.storage.get_row_by_pk_with_schema(table_name, &pk_value, &schema) {
Ok(Some(row)) => Some(Ok(vec![row])),
Ok(None) => Some(Ok(vec![])),
Err(e) => Some(Err(e)),
}
}
fn fast_eval_simple_expr(expr: &str, col_name: &str, col_idx: usize, row: &Tuple) -> Option<Value> {
let expr = expr.trim();
if !expr.starts_with(col_name) {
return None;
}
let after_col = expr.get(col_name.len()..)?.trim_start();
if after_col.is_empty() {
return None;
}
let current = row.values.get(col_idx)?;
let (op, operand_str) = if let Some(rest) = after_col.strip_prefix('+') {
('+', rest.trim())
} else if let Some(rest) = after_col.strip_prefix('-') {
('-', rest.trim())
} else if let Some(rest) = after_col.strip_prefix('*') {
('*', rest.trim())
} else {
return None;
};
match (current, op) {
(Value::Int2(v), '+') => { let n: i16 = operand_str.parse().ok()?; Some(Value::Int2(v.checked_add(n)?)) }
(Value::Int2(v), '-') => { let n: i16 = operand_str.parse().ok()?; Some(Value::Int2(v.checked_sub(n)?)) }
(Value::Int2(v), '*') => { let n: i16 = operand_str.parse().ok()?; Some(Value::Int2(v.checked_mul(n)?)) }
(Value::Int4(v), '+') => { let n: i32 = operand_str.parse().ok()?; Some(Value::Int4(v.checked_add(n)?)) }
(Value::Int4(v), '-') => { let n: i32 = operand_str.parse().ok()?; Some(Value::Int4(v.checked_sub(n)?)) }
(Value::Int4(v), '*') => { let n: i32 = operand_str.parse().ok()?; Some(Value::Int4(v.checked_mul(n)?)) }
(Value::Int8(v), '+') => { let n: i64 = operand_str.parse().ok()?; Some(Value::Int8(v.checked_add(n)?)) }
(Value::Int8(v), '-') => { let n: i64 = operand_str.parse().ok()?; Some(Value::Int8(v.checked_sub(n)?)) }
(Value::Int8(v), '*') => { let n: i64 = operand_str.parse().ok()?; Some(Value::Int8(v.checked_mul(n)?)) }
(Value::Float4(v), '+') => { let n: f32 = operand_str.parse().ok()?; Some(Value::Float4(v + n)) }
(Value::Float4(v), '-') => { let n: f32 = operand_str.parse().ok()?; Some(Value::Float4(v - n)) }
(Value::Float4(v), '*') => { let n: f32 = operand_str.parse().ok()?; Some(Value::Float4(v * n)) }
(Value::Float8(v), '+') => { let n: f64 = operand_str.parse().ok()?; Some(Value::Float8(v + n)) }
(Value::Float8(v), '-') => { let n: f64 = operand_str.parse().ok()?; Some(Value::Float8(v - n)) }
(Value::Float8(v), '*') => { let n: f64 = operand_str.parse().ok()?; Some(Value::Float8(v * n)) }
_ => None,
}
}
#[allow(clippy::indexing_slicing)] fn find_closing_paren(s: &str) -> Option<usize> {
let mut in_string = false;
let bytes = s.as_bytes();
let mut i = 0;
while i < bytes.len() {
let b = bytes[i];
if in_string {
if b == b'\'' {
if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
i += 2;
continue;
}
in_string = false;
}
} else if b == b'\'' {
in_string = true;
} else if b == b')' {
return Some(i);
}
i += 1;
}
None
}
#[allow(clippy::indexing_slicing)] fn fast_parse_value_count(s: &str) -> usize {
if s.trim().is_empty() {
return 0;
}
let mut count = 1;
let mut in_string = false;
let bytes = s.as_bytes();
let mut i = 0;
while i < bytes.len() {
let b = bytes[i];
if in_string {
if b == b'\'' {
if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
i += 2;
continue;
}
in_string = false;
}
} else if b == b'\'' {
in_string = true;
} else if b == b',' {
count += 1;
}
i += 1;
}
count
}
#[allow(clippy::indexing_slicing)] fn fast_parse_values(s: &str, target_types: &[DataType]) -> Option<Vec<Value>> {
let mut values = Vec::with_capacity(target_types.len());
let mut remaining = s;
let mut type_idx = 0;
while !remaining.is_empty() && type_idx < target_types.len() {
remaining = remaining.trim_start();
if remaining.is_empty() {
break;
}
let (value, rest) = Self::fast_parse_one_value(remaining, &target_types[type_idx])?;
values.push(value);
type_idx += 1;
let rest = rest.trim_start();
if rest.starts_with(',') {
remaining = rest.get(1..)?;
} else {
remaining = rest;
}
}
if values.len() == target_types.len() {
Some(values)
} else {
None
}
}
#[allow(clippy::indexing_slicing)] fn fast_parse_one_value<'a>(s: &'a str, target_type: &DataType) -> Option<(Value, &'a str)> {
let s = s.trim_start();
if s.is_empty() {
return None;
}
let first = s.as_bytes().first()?;
if *first == b'\'' {
match target_type {
DataType::Int2 | DataType::Int4 | DataType::Int8
| DataType::Float4 | DataType::Float8
| DataType::Boolean => return None,
_ => {}
}
let inner = s.get(1..)?;
let mut end = 0;
let bytes = inner.as_bytes();
let mut result = String::new();
let mut seg_start = 0; while end < bytes.len() {
if bytes[end] == b'\'' {
if seg_start < end {
result.push_str(inner.get(seg_start..end)?);
}
if end + 1 < bytes.len() && bytes[end + 1] == b'\'' {
result.push('\'');
end += 2;
seg_start = end;
continue;
}
let rest = inner.get(end + 1..)?;
let typed = match target_type {
DataType::Uuid => uuid::Uuid::parse_str(&result)
.map(Value::Uuid)
.unwrap_or(Value::String(result)),
DataType::Date => result
.parse::<chrono::NaiveDate>()
.map(Value::Date)
.unwrap_or(Value::String(result)),
DataType::Timestamp => chrono::DateTime::parse_from_rfc3339(&result)
.map(|t| Value::Timestamp(t.to_utc()))
.unwrap_or(Value::String(result)),
_ => Value::String(result),
};
return Some((typed, rest));
}
end += 1;
}
return None; }
if s.len() >= 4 && s.as_bytes().get(..4)?.eq_ignore_ascii_case(b"NULL") {
let rest = s.get(4..)?;
if rest.is_empty() || rest.starts_with(',') || rest.starts_with(')') || rest.starts_with(' ') {
return Some((Value::Null, rest));
}
}
if s.len() >= 4 && s.as_bytes().get(..4)?.eq_ignore_ascii_case(b"TRUE") {
let rest = s.get(4..)?;
if rest.is_empty() || rest.starts_with(',') || rest.starts_with(')') || rest.starts_with(' ') {
return Some((Value::Boolean(true), rest));
}
}
if s.len() >= 5 && s.as_bytes().get(..5)?.eq_ignore_ascii_case(b"FALSE") {
let rest = s.get(5..)?;
if rest.is_empty() || rest.starts_with(',') || rest.starts_with(')') || rest.starts_with(' ') {
return Some((Value::Boolean(false), rest));
}
}
if first.is_ascii_digit() || *first == b'-' || *first == b'+' || *first == b'.' {
let end = s.find([',', ')', ' '])
.unwrap_or(s.len());
let num_str = s.get(..end)?.trim();
let rest = s.get(end..)?;
let value = match target_type {
DataType::Int4 => {
let n: i32 = num_str.parse().ok()?;
Value::Int4(n)
}
DataType::Int8 => {
let n: i64 = num_str.parse().ok()?;
Value::Int8(n)
}
DataType::Float4 => {
let f: f32 = num_str.parse().ok()?;
Value::Float4(f)
}
DataType::Float8 => {
let f: f64 = num_str.parse().ok()?;
Value::Float8(f)
}
DataType::Numeric => {
if let Ok(n) = num_str.parse::<i64>() {
Value::Int8(n)
} else if let Ok(f) = num_str.parse::<f64>() {
Value::Float8(f)
} else {
return None;
}
}
_ => {
if num_str.contains('.') {
let f: f64 = num_str.parse().ok()?;
Value::Float8(f)
} else if let Ok(n) = num_str.parse::<i32>() {
Value::Int4(n)
} else if let Ok(n) = num_str.parse::<i64>() {
Value::Int8(n)
} else {
return None;
}
}
};
return Some((value, rest));
}
None
}
pub(crate) fn parse_cached(&self, sql: &str) -> Result<(sqlparser::ast::Statement, bool)> {
if let Ok(mut cache) = self.parse_cache.lock() {
if let Some(stmt) = cache.get(sql) {
return Ok((stmt.clone(), true));
}
}
let parser = sql::Parser::new();
let statement = parser.parse_one(sql)?;
if let Ok(mut cache) = self.parse_cache.lock() {
cache.put(sql.to_string(), statement.clone());
}
Ok((statement, false))
}
fn execute_internal(&self, sql: &str) -> Result<u64> {
if let Some(context) = self.tenant_manager.get_current_context() {
self.tenant_manager.record_query(context.tenant_id)
.map_err(|e| Error::query_execution(format!("Quota exceeded: {}", e)))?;
}
let parse_start = std::time::Instant::now();
let (statement, parse_cached) = self.parse_cached(sql)?;
let parse_elapsed = parse_start.elapsed();
if parse_cached {
tracing::debug!(phase = "parse", duration_us = parse_elapsed.as_micros() as u64, "SQL parsed (AST cached)");
} else {
tracing::debug!(phase = "parse", duration_us = parse_elapsed.as_micros() as u64, "SQL parsed");
}
let plan_start = std::time::Instant::now();
let catalog = self.storage.catalog();
let planner = sql::Planner::with_catalog(&catalog);
let plan = planner.statement_to_plan(statement)?;
let plan_elapsed = plan_start.elapsed();
tracing::debug!(phase = "plan", duration_us = plan_elapsed.as_micros() as u64, "Logical plan created");
if matches!(&plan,
sql::LogicalPlan::CreateTable { .. } |
sql::LogicalPlan::DropTable { .. } |
sql::LogicalPlan::AlterTableAddColumn { .. } |
sql::LogicalPlan::AlterTableDropColumn { .. } |
sql::LogicalPlan::AlterTableRename { .. } |
sql::LogicalPlan::AlterTableRenameColumn { .. } |
sql::LogicalPlan::AlterTableAddForeignKey { .. } |
sql::LogicalPlan::AlterTableMulti { .. } |
sql::LogicalPlan::CreateIndex { .. } |
sql::LogicalPlan::Truncate { .. } |
sql::LogicalPlan::CreateMaterializedView { .. } |
sql::LogicalPlan::DropMaterializedView { .. }
) {
self.invalidate_plan_cache();
}
match &plan {
sql::LogicalPlan::CreateTable { name, columns, if_not_exists, .. } => {
if *if_not_exists && self.storage.catalog().table_exists(name).unwrap_or(false) {
return Ok(0);
}
let schema_columns: Vec<Column> = columns.iter().map(|col_def| {
Column {
name: col_def.name.clone(),
data_type: col_def.data_type.clone(),
nullable: !col_def.not_null, primary_key: col_def.primary_key,
source_table: None,
source_table_name: None,
default_expr: None,
unique: false,
storage_mode: col_def.storage_mode,
}
}).collect();
let schema = Schema::new(schema_columns);
let catalog = self.storage.catalog();
if let Err(e) = self.storage.log_create_table(name, &schema) {
tracing::warn!("Failed to log CREATE TABLE to WAL: {}", e);
}
catalog.create_table(name, schema)?;
Ok(1) }
sql::LogicalPlan::Insert { table_name, columns, values, returning, on_conflict: _ } => {
let rls_enforced = self.tenant_manager.should_apply_rls(table_name, "INSERT");
let rls_check = if rls_enforced {
self.tenant_manager.get_rls_conditions(table_name, "INSERT")
} else {
None
};
let catalog = self.storage.catalog();
let schema = catalog.get_table_schema(table_name)?;
let evaluator = sql::Evaluator::new(std::sync::Arc::new(Schema {
columns: vec![], }));
let empty_tuple = Tuple::new(vec![]);
let mut count = 0;
for value_row in values {
let mut tuple_values: Vec<Value> = Vec::new();
for (col_idx, expr) in value_row.iter().enumerate() {
let target_col_idx = if let Some(ref cols) = columns {
let col_name = cols.get(col_idx)
.ok_or_else(|| Error::internal("column index out of bounds"))?;
schema.get_column_index(col_name)
.ok_or_else(|| Error::query_execution(format!("Column '{}' not found", col_name)))?
} else {
col_idx
};
let target_col = schema.get_column_at(target_col_idx)
.ok_or_else(|| Error::query_execution(format!(
"Too many values for INSERT: table has {} columns",
schema.columns.len()
)))?;
let target_type = &target_col.data_type;
let mut value = evaluator.evaluate(expr, &empty_tuple)?;
let needs_cast = match (&value, target_type) {
(Value::Null, _) => false, (Value::Vector(_), DataType::Vector(_)) => false,
(Value::String(_), DataType::Vector(_)) => true, (Value::String(_), DataType::Json | DataType::Jsonb) => true, (Value::Int4(_), DataType::Int4) => false,
(Value::Int8(_), DataType::Int8) => false,
(Value::Float4(_), DataType::Float4) => false,
(Value::Float8(_), DataType::Float8) => false,
(Value::String(_), DataType::Text | DataType::Varchar(_)) => false,
(Value::Boolean(_), DataType::Boolean) => false,
(Value::Json(_), DataType::Json | DataType::Jsonb) => false,
_ => true, };
if needs_cast {
value = evaluator.cast_value(value, target_type)?;
}
tuple_values.push(value);
}
let tuple = Tuple::new(tuple_values);
if let Some((_, with_check)) = &rls_check {
if let Some(ref with_check_expr) = with_check {
let tenant_context = self.tenant_manager.get_current_context();
let rls_evaluator = tenant::RLSExpressionEvaluator::new(
std::sync::Arc::new(schema.clone()),
tenant_context
);
let expr = rls_evaluator.parse(with_check_expr)?;
let satisfies_policy = rls_evaluator.evaluate(&expr, &tuple)?;
if !satisfies_policy {
return Err(Error::query_execution(format!(
"Row-Level Security policy violation: inserted row does not satisfy WITH CHECK expression"
)));
}
}
}
self.storage.insert_tuple_branch_aware_with_schema(table_name, tuple, &schema)?;
count += 1;
}
Ok(count)
}
sql::LogicalPlan::InsertSelect { table_name, columns, source, returning: _ } => {
let mut executor = sql::Executor::with_storage(&self.storage)
.with_timeout(self.config.storage.query_timeout_ms);
let source_rows = executor.execute(source)?;
let rls_enforced = self.tenant_manager.should_apply_rls(table_name, "INSERT");
let rls_check = if rls_enforced {
self.tenant_manager.get_rls_conditions(table_name, "INSERT")
} else {
None
};
let catalog = self.storage.catalog();
let schema = catalog.get_table_schema(table_name)?;
let evaluator = sql::Evaluator::new(std::sync::Arc::new(Schema {
columns: vec![],
}));
let empty_tuple = Tuple::new(vec![]);
let column_indices: Option<Vec<usize>> = columns.as_ref().map(|cols| {
cols.iter()
.filter_map(|col_name| schema.get_column_index(col_name))
.collect()
});
let default_exprs: Vec<Option<sql::LogicalExpr>> = schema.columns.iter()
.map(|col| {
col.default_expr.as_ref().and_then(|json| {
serde_json::from_str(json).ok()
})
})
.collect();
let mut count = 0u64;
for source_row in &source_rows {
let mut tuple_values: Vec<Option<Value>> = vec![None; schema.columns.len()];
for (val_idx, value) in source_row.values.iter().enumerate() {
let target_col_idx = if let Some(ref indices) = column_indices {
if val_idx >= indices.len() {
return Err(Error::query_execution("More values than columns specified"));
}
*indices.get(val_idx).ok_or_else(|| Error::internal("column index out of bounds"))?
} else {
val_idx
};
let target_col = schema.get_column_at(target_col_idx)
.ok_or_else(|| Error::query_execution(format!(
"Too many values for INSERT: table has {} columns",
schema.columns.len()
)))?;
let target_type = &target_col.data_type;
let mut val = value.clone();
let needs_cast = match (&val, target_type) {
(Value::Null, _) => false,
(Value::Vector(_), DataType::Vector(_)) => false,
(Value::String(_), DataType::Vector(_)) => true,
(Value::String(_), DataType::Json | DataType::Jsonb) => true,
(Value::Int4(_), DataType::Int4) => false,
(Value::Int8(_), DataType::Int8) => false,
(Value::Float4(_), DataType::Float4) => false,
(Value::Float8(_), DataType::Float8) => false,
(Value::String(_), DataType::Text | DataType::Varchar(_)) => false,
(Value::Boolean(_), DataType::Boolean) => false,
(Value::Json(_), DataType::Json | DataType::Jsonb) => false,
_ => true,
};
if needs_cast {
val = evaluator.cast_value(val, target_type)?;
}
let tv = tuple_values.get_mut(target_col_idx)
.ok_or_else(|| Error::internal("column index out of bounds"))?;
*tv = Some(val);
}
let final_values: Result<Vec<Value>> = tuple_values
.into_iter()
.enumerate()
.map(|(idx, opt_val)| {
if let Some(val) = opt_val {
Ok(val)
} else {
let col = schema.get_column_at(idx)
.ok_or_else(|| Error::internal("column index out of bounds"))?;
if let Some(ref default_expr) = default_exprs.get(idx).and_then(|d| d.as_ref()) {
let mut value = evaluator.evaluate(default_expr, &empty_tuple)?;
if value.data_type() != col.data_type {
value = evaluator.cast_value(value, &col.data_type)?;
}
Ok(value)
} else if col.primary_key {
Ok(Value::Null)
} else if col.nullable {
Ok(Value::Null)
} else {
Err(Error::query_execution(format!(
"Column '{}' does not have a default value and is not nullable",
col.name
)))
}
}
})
.collect();
let tuple = Tuple::new(final_values?);
if let Some((_, with_check)) = &rls_check {
if let Some(ref with_check_expr) = with_check {
let tenant_context = self.tenant_manager.get_current_context();
let rls_evaluator = tenant::RLSExpressionEvaluator::new(
std::sync::Arc::new(schema.clone()),
tenant_context
);
let expr = rls_evaluator.parse(with_check_expr)?;
let satisfies_policy = rls_evaluator.evaluate(&expr, &tuple)?;
if !satisfies_policy {
return Err(Error::query_execution(
"Row-Level Security policy violation: inserted row does not satisfy WITH CHECK expression"
));
}
}
}
self.storage.insert_tuple_branch_aware_with_schema(table_name, tuple, &schema)?;
count += 1;
}
Ok(count)
}
sql::LogicalPlan::Update { table_name, assignments, selection, returning } => {
let rls_enforced = self.tenant_manager.should_apply_rls(table_name, "UPDATE");
let rls_condition = if rls_enforced {
self.tenant_manager.get_rls_conditions(table_name, "UPDATE")
} else {
None
};
let catalog = self.storage.catalog();
let schema = catalog.get_table_schema(table_name)?;
let eval_schema = schema.clone().with_source_table_name(table_name);
let evaluator = sql::Evaluator::new(std::sync::Arc::new(eval_schema));
let tuples = self.storage.scan_table_branch_aware(table_name)?;
let mut updates: Vec<(u64, Tuple)> = Vec::new();
for mut tuple in tuples {
let where_matches = if let Some(predicate) = selection {
let result = evaluator.evaluate(predicate, &tuple)?;
match result {
Value::Boolean(b) => b,
_ => false,
}
} else {
true };
let rls_matches = if let Some((using_expr, _)) = &rls_condition {
let tenant_context = self.tenant_manager.get_current_context();
let rls_evaluator = tenant::RLSExpressionEvaluator::new(
std::sync::Arc::new(schema.clone()),
tenant_context
);
let expr = rls_evaluator.parse(using_expr)?;
rls_evaluator.evaluate(&expr, &tuple)?
} else {
true };
if where_matches && rls_matches {
for (col_name, value_expr) in assignments {
let bound = self.materialize_scalar_subqueries_for_row(
value_expr, &tuple, &schema, table_name,
)?;
let mut new_value = evaluator.evaluate(&bound, &tuple)?;
let col_index = evaluator.schema().get_column_index(col_name)
.ok_or_else(|| Error::query_execution(format!("Column '{}' not found", col_name)))?;
let target_col = schema.get_column_at(col_index)
.ok_or_else(|| Error::query_execution(format!("Column '{}' not found", col_name)))?;
let target_type = &target_col.data_type;
let needs_cast = !matches!(&new_value, Value::Null)
&& !matches!(
(&new_value, target_type),
(Value::Vector(_), DataType::Vector(_))
| (Value::Int2(_), DataType::Int2)
| (Value::Int4(_), DataType::Int4)
| (Value::Int8(_), DataType::Int8)
| (Value::Float4(_), DataType::Float4)
| (Value::Float8(_), DataType::Float8)
| (Value::String(_), DataType::Text | DataType::Varchar(_))
| (Value::Boolean(_), DataType::Boolean)
| (Value::Json(_), DataType::Json | DataType::Jsonb)
| (Value::Timestamp(_), DataType::Timestamp | DataType::Timestamptz)
| (Value::Date(_), DataType::Date)
);
if needs_cast {
new_value = evaluator.cast_value(new_value, target_type)?;
}
if let Some(slot) = tuple.values.get_mut(col_index) {
*slot = new_value;
}
}
let mut new_col_values = std::collections::HashMap::with_capacity(schema.columns.len());
for (i, col) in schema.columns.iter().enumerate() {
if let Some(v) = tuple.values.get(i) {
new_col_values.insert(col.name.clone(), v.clone());
}
}
self.check_fk_constraints_on_write(table_name, &new_col_values, None)?;
let row_id = tuple.row_id.unwrap_or(0);
updates.push((row_id, tuple));
}
}
let update_count = self.storage.update_tuples_branch_aware(table_name, updates)?;
Ok(update_count)
}
sql::LogicalPlan::Delete { table_name, selection, returning } => {
let rls_enforced = self.tenant_manager.should_apply_rls(table_name, "DELETE");
let rls_condition = if rls_enforced {
self.tenant_manager.get_rls_conditions(table_name, "DELETE")
} else {
None
};
let catalog = self.storage.catalog();
let schema = catalog.get_table_schema(table_name)?;
let eval_schema = schema.clone().with_source_table_name(table_name);
let evaluator = sql::Evaluator::new(std::sync::Arc::new(eval_schema));
let tuples = self.storage.scan_table_branch_aware(table_name)?;
let mut row_ids_to_delete: Vec<u64> = Vec::new();
let mut deleted_tuples: Vec<(u64, Tuple)> = Vec::new();
for tuple in tuples {
let where_matches = if let Some(predicate) = selection {
let result = evaluator.evaluate(predicate, &tuple)?;
match result {
Value::Boolean(b) => b,
_ => false,
}
} else {
true };
let rls_matches = if let Some((using_expr, _)) = &rls_condition {
let tenant_context = self.tenant_manager.get_current_context();
let rls_evaluator = tenant::RLSExpressionEvaluator::new(
std::sync::Arc::new(schema.clone()),
tenant_context
);
let expr = rls_evaluator.parse(using_expr)?;
rls_evaluator.evaluate(&expr, &tuple)?
} else {
true };
if where_matches && rls_matches {
if let Some(row_id) = tuple.row_id {
row_ids_to_delete.push(row_id);
deleted_tuples.push((row_id, tuple.clone()));
}
}
}
for (row_id, tuple) in &deleted_tuples {
let mut col_values = std::collections::HashMap::new();
for (i, col) in schema.columns.iter().enumerate() {
if let Some(v) = tuple.values.get(i) {
col_values.insert(col.name.clone(), v.clone());
}
}
if let Err(e) = self.storage.art_indexes().on_delete(table_name, *row_id, &col_values) {
tracing::debug!("ART index delete for table '{}': {}", table_name, e);
}
}
let delete_count = self.storage.delete_tuples_branch_aware(table_name, row_ids_to_delete)?;
Ok(delete_count)
}
sql::LogicalPlan::DropTable { name, if_exists } => {
let catalog = self.storage.catalog();
let exists = catalog.table_exists(name)?;
if exists {
let mv_catalog = self.storage.mv_catalog();
if let Ok(mv_names) = mv_catalog.list_views() {
let mut dependent_mvs = Vec::new();
for mv_name in &mv_names {
if let Ok(metadata) = mv_catalog.get_view(mv_name) {
if metadata.base_tables.iter().any(|t| t == name) {
dependent_mvs.push(mv_name.clone());
}
}
}
if !dependent_mvs.is_empty() {
tracing::warn!(
"Dropping table '{}' which is used by materialized view(s): {}. Those views will be stale.",
name,
dependent_mvs.join(", ")
);
}
}
catalog.drop_table(name)?;
if let Err(e) = self.trigger_registry.drop_table_triggers(name) {
tracing::warn!("Failed to clean up triggers for dropped table '{}': {}", name, e);
}
self.storage.predicate_pushdown().remove_table(name);
self.storage.row_cache().invalidate_table(name);
if let Err(e) = self.storage.log_drop_table(name) {
tracing::warn!("Failed to log DROP TABLE to WAL: {}", e);
}
Ok(0) } else if *if_exists {
Ok(0)
} else {
Err(Error::query_execution(format!("Table '{}' does not exist", name)))
}
}
sql::LogicalPlan::Truncate { table_name } => {
let mut trigger_context = sql::TriggerContext::new();
let trigger_event = sql::logical_plan::TriggerEvent::Truncate;
let row_context = sql::triggers::TriggerRowContext {
old_tuple: None,
new_tuple: None,
transition_tables: None,
};
let db_ref = self.clone_for_trigger();
let mut executor_fn = |stmt: &sql::LogicalPlan, _ctx: &sql::triggers::TriggerRowContext| -> Result<()> {
db_ref.execute_plan_internal(stmt)?;
Ok(())
};
let action = self.trigger_registry.execute_triggers(
table_name,
&trigger_event,
&sql::logical_plan::TriggerTiming::Before,
&row_context,
&mut trigger_context,
None, &mut executor_fn,
)?;
match action {
sql::triggers::TriggerAction::Abort(msg) => {
return Err(Error::query_execution(format!("TRUNCATE aborted by trigger: {}", msg)));
}
sql::triggers::TriggerAction::Skip => {
return Ok(0);
}
sql::triggers::TriggerAction::Continue => {
}
}
let prefix = format!("data:{}:", table_name);
let prefix_bytes = prefix.as_bytes();
let mut keys_to_delete = Vec::new();
let iter = self.storage.db.iterator(rocksdb::IteratorMode::Start);
for item in iter {
let (key, _) = item.map_err(|e| Error::storage(format!("Iterator error: {}", e)))?;
if !key.starts_with(prefix_bytes) {
if !key.is_empty() && key.first() > prefix_bytes.first() {
break;
}
continue;
}
keys_to_delete.push(key.to_vec());
}
for key in &keys_to_delete {
self.storage.delete(key)?;
}
self.storage.row_cache().invalidate_table(table_name);
let has_user_branches = self.storage.list_branches()
.map(|b| b.iter().any(|br| br.name != "main"))
.unwrap_or(false);
if !has_user_branches {
self.storage.art_indexes().clear_table_indexes(table_name);
}
let action = self.trigger_registry.execute_triggers(
table_name,
&trigger_event,
&sql::logical_plan::TriggerTiming::After,
&row_context,
&mut trigger_context,
None, &mut executor_fn,
)?;
if let sql::triggers::TriggerAction::Abort(msg) = action {
return Err(Error::query_execution(format!("TRUNCATE failed in AFTER trigger: {}", msg)));
}
if let Err(e) = self.storage.log_truncate(table_name) {
tracing::warn!("Failed to log TRUNCATE to WAL: {}", e);
}
Ok(keys_to_delete.len() as u64) }
sql::LogicalPlan::AlterColumnStorage { table_name, column_name, storage_mode } => {
let catalog = self.storage.catalog();
let mut schema = catalog.get_table_schema(table_name)?;
let col_idx = schema.columns.iter()
.position(|c| c.name == *column_name)
.ok_or_else(|| Error::query_execution(format!(
"Column '{}' not found in table '{}'", column_name, table_name
)))?;
let col_ref = schema.get_column_at(col_idx)
.ok_or_else(|| Error::internal("column index out of bounds"))?;
let old_mode = col_ref.storage_mode;
if old_mode == *storage_mode {
return Ok(0);
}
let column = col_ref.clone();
let rows_migrated = self.storage.migrate_column_storage(
table_name,
col_idx,
&column,
old_mode,
*storage_mode,
)?;
schema.get_column_at_mut(col_idx)
.ok_or_else(|| Error::internal("column index out of bounds"))?
.storage_mode = *storage_mode;
catalog.update_table_schema(table_name, &schema)?;
if let Err(e) = self.storage.log_alter_column_storage(table_name, column_name, storage_mode) {
tracing::warn!("Failed to log ALTER COLUMN STORAGE to WAL: {}", e);
}
tracing::info!(
"Altered {}.{} storage from {:?} to {:?}, migrated {} rows",
table_name, column_name, old_mode, storage_mode, rows_migrated
);
Ok(rows_migrated as u64)
}
sql::LogicalPlan::AlterTableAddColumn { table_name, column_def, if_not_exists } => {
let catalog = self.storage.catalog();
let mut schema = catalog.get_table_schema(table_name)?;
if schema.columns.iter().any(|c| c.name == column_def.name) {
if *if_not_exists {
return Ok(0);
}
return Err(Error::query_execution(format!(
"Column '{}' already exists in table '{}'", column_def.name, table_name
)));
}
let new_column = Column {
name: column_def.name.clone(),
data_type: column_def.data_type.clone(),
nullable: !column_def.not_null,
primary_key: column_def.primary_key,
source_table: None,
source_table_name: Some(table_name.clone()),
default_expr: column_def.default.as_ref().map(|e| format!("{:?}", e)),
unique: column_def.unique,
storage_mode: column_def.storage_mode,
};
schema.columns.push(new_column);
catalog.update_table_schema(table_name, &schema)?;
let rows_updated = self.storage.add_column_to_rows(table_name, &column_def.default)?;
Ok(rows_updated as u64)
}
sql::LogicalPlan::AlterTableDropColumn { table_name, column_name, if_exists, cascade } => {
let catalog = self.storage.catalog();
let mut schema = catalog.get_table_schema(table_name)?;
let col_idx = schema.columns.iter().position(|c| c.name == *column_name);
match col_idx {
Some(idx) => {
if schema.get_column_at(idx).is_some_and(|c| c.primary_key) && !cascade {
return Err(Error::query_execution(format!(
"Cannot drop primary key column '{}' without CASCADE", column_name
)));
}
schema.columns.remove(idx);
catalog.update_table_schema(table_name, &schema)?;
let rows_updated = self.storage.drop_column_from_rows(table_name, idx)?;
Ok(rows_updated as u64)
}
None => {
if *if_exists {
Ok(0)
} else {
Err(Error::query_execution(format!(
"Column '{}' does not exist in table '{}'", column_name, table_name
)))
}
}
}
}
sql::LogicalPlan::AlterTableRenameColumn { table_name, old_column_name, new_column_name } => {
let catalog = self.storage.catalog();
let mut schema = catalog.get_table_schema(table_name)?;
if schema.columns.iter().any(|c| c.name == *new_column_name) {
return Err(Error::query_execution(format!(
"Column '{}' already exists in table '{}'", new_column_name, table_name
)));
}
let col_idx = schema.columns.iter()
.position(|c| c.name == *old_column_name)
.ok_or_else(|| Error::query_execution(format!(
"Column '{}' does not exist in table '{}'", old_column_name, table_name
)))?;
schema.get_column_at_mut(col_idx)
.ok_or_else(|| Error::internal("column index out of bounds"))?
.name = new_column_name.clone();
catalog.update_table_schema(table_name, &schema)?;
Ok(0)
}
sql::LogicalPlan::AlterTableRename { table_name, new_table_name } => {
let catalog = self.storage.catalog();
if catalog.get_table_schema(new_table_name).is_ok() {
return Err(Error::query_execution(format!(
"Table '{}' already exists", new_table_name
)));
}
self.storage.rename_table(table_name, new_table_name)?;
Ok(0)
}
sql::LogicalPlan::AlterTableMulti { operations } => {
let mut total_rows = 0u64;
for sub_plan in operations {
total_rows += self.execute_alter_table_op(sub_plan)?;
}
Ok(total_rows)
}
sql::LogicalPlan::Savepoint { ref name } => {
let txn = self.current_transaction.lock()
.map_err(|_| Error::query_execution("Failed to lock transaction"))?;
let write_set_snapshot = match txn.as_ref() {
Some(t) => t.savepoint_snapshot(),
None => return Err(Error::query_execution("SAVEPOINT can only be used within a transaction")),
};
drop(txn);
let savepoint = SavepointState {
name: name.clone(),
write_set_snapshot,
};
self.savepoints.write().push(savepoint);
Ok(0)
}
sql::LogicalPlan::ReleaseSavepoint { ref name } => {
let mut savepoints = self.savepoints.write();
if let Some(pos) = savepoints.iter().rposition(|s| &s.name == name) {
savepoints.truncate(pos);
Ok(0)
} else {
Err(Error::query_execution(format!("Savepoint '{}' does not exist", name)))
}
}
sql::LogicalPlan::RollbackToSavepoint { ref name } => {
let savepoints = self.savepoints.read();
if let Some(pos) = savepoints.iter().rposition(|s| &s.name == name) {
let snapshot = savepoints.get(pos)
.map(|s| s.write_set_snapshot.clone());
drop(savepoints);
if let Some(snapshot) = snapshot {
let txn = self.current_transaction.lock()
.map_err(|_| Error::query_execution("Failed to lock transaction"))?;
if let Some(t) = txn.as_ref() {
t.rollback_to_savepoint(&snapshot);
}
drop(txn);
}
let mut savepoints = self.savepoints.write();
savepoints.truncate(pos + 1);
Ok(0)
} else {
Err(Error::query_execution(format!("Savepoint '{}' does not exist", name)))
}
}
sql::LogicalPlan::CreateDatabase { name, if_not_exists } => {
let (count, _) = self.handle_create_database(name, *if_not_exists)?;
Ok(count)
}
sql::LogicalPlan::DropDatabase { name, if_exists } => {
let (count, _) = self.handle_drop_database(name, *if_exists)?;
Ok(count)
}
_ => {
let mut executor = sql::Executor::with_storage(&self.storage)
.with_timeout(self.config.storage.query_timeout_ms);
let results = executor.execute(&plan)?;
Ok(results.len() as u64)
}
}
}
pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<u64> {
let (statement, _) = self.parse_cached(sql)?;
let catalog = self.storage.catalog();
let planner = sql::Planner::with_catalog(&catalog);
let plan = planner.statement_to_plan(statement)?;
let (count, _tuples) = self.execute_plan_with_params(&plan, params)?;
Ok(count)
}
pub fn execute_params_returning(&self, sql: &str, params: &[Value]) -> Result<(u64, Vec<Tuple>)> {
let (statement, _) = self.parse_cached(sql)?;
let catalog = self.storage.catalog();
let planner = sql::Planner::with_catalog(&catalog);
let plan = planner.statement_to_plan(statement)?;
let out = self.execute_plan_with_params(&plan, params);
#[cfg(feature = "code-graph")]
if out.is_ok() {
let touched = Self::touched_table_from_sql(sql);
self.maybe_auto_reparse(touched.as_deref());
}
out
}
fn project_returning_columns(
tuple: &Tuple,
schema: &Schema,
returning_items: &Option<Vec<sql::logical_plan::ReturningItem>>,
) -> Option<Tuple> {
let items = returning_items.as_ref()?;
let evaluator = sql::Evaluator::new(std::sync::Arc::new(schema.clone()));
let mut projected_values = Vec::with_capacity(items.len());
for item in items {
match item {
sql::logical_plan::ReturningItem::Wildcard => {
return Some(tuple.clone());
}
sql::logical_plan::ReturningItem::Column(col_name) => {
if let Some(col_idx) = schema.get_column_index(col_name) {
if let Some(val) = tuple.values.get(col_idx) {
projected_values.push(val.clone());
} else {
projected_values.push(Value::Null);
}
} else {
projected_values.push(Value::Null);
}
}
sql::logical_plan::ReturningItem::Expression { expr, .. } => {
match evaluator.evaluate(expr, tuple) {
Ok(val) => projected_values.push(val),
Err(_) => projected_values.push(Value::Null),
}
}
}
}
Some(Tuple::new(projected_values))
}
fn resolve_excluded_refs(
expr: &sql::logical_plan::LogicalExpr,
excluded_map: &std::collections::HashMap<String, Value>,
) -> sql::logical_plan::LogicalExpr {
match expr {
sql::logical_plan::LogicalExpr::Column { table: Some(tbl), name }
if tbl.eq_ignore_ascii_case("excluded") =>
{
if let Some(val) = excluded_map.get(&name.to_lowercase()) {
sql::logical_plan::LogicalExpr::Literal(val.clone())
} else {
expr.clone()
}
}
sql::logical_plan::LogicalExpr::BinaryExpr { left, op, right } => {
sql::logical_plan::LogicalExpr::BinaryExpr {
left: Box::new(Self::resolve_excluded_refs(left, excluded_map)),
op: op.clone(),
right: Box::new(Self::resolve_excluded_refs(right, excluded_map)),
}
}
sql::logical_plan::LogicalExpr::UnaryExpr { op, expr: inner } => {
sql::logical_plan::LogicalExpr::UnaryExpr {
op: op.clone(),
expr: Box::new(Self::resolve_excluded_refs(inner, excluded_map)),
}
}
sql::logical_plan::LogicalExpr::Cast { expr: inner, data_type } => {
sql::logical_plan::LogicalExpr::Cast {
expr: Box::new(Self::resolve_excluded_refs(inner, excluded_map)),
data_type: data_type.clone(),
}
}
sql::logical_plan::LogicalExpr::ScalarFunction { fun, args } => {
sql::logical_plan::LogicalExpr::ScalarFunction {
fun: fun.clone(),
args: args.iter().map(|a| Self::resolve_excluded_refs(a, excluded_map)).collect(),
}
}
other => other.clone(),
}
}
pub(crate) fn returning_schema(
table_schema: &Schema,
returning_items: &[sql::logical_plan::ReturningItem],
) -> Schema {
let columns = returning_items.iter()
.flat_map(|item| {
match item {
sql::logical_plan::ReturningItem::Wildcard => {
table_schema.columns.clone()
}
sql::logical_plan::ReturningItem::Column(col_name) => {
if let Some(col) = table_schema.columns.iter().find(|c| &c.name == col_name) {
vec![col.clone()]
} else {
vec![Column {
name: col_name.clone(),
data_type: DataType::Text,
nullable: true,
primary_key: false,
source_table: None,
source_table_name: None,
default_expr: None,
unique: false,
storage_mode: crate::ColumnStorageMode::Default,
}]
}
}
sql::logical_plan::ReturningItem::Expression { alias, .. } => {
vec![Column {
name: alias.clone(),
data_type: DataType::Text,
nullable: true,
primary_key: false,
source_table: None,
source_table_name: None,
default_expr: None,
unique: false,
storage_mode: crate::ColumnStorageMode::Default,
}]
}
}
})
.collect();
Schema { columns }
}
fn execute_plan_with_params(&self, plan: &sql::LogicalPlan, params: &[Value]) -> Result<(u64, Vec<Tuple>)> {
let result = self.execute_plan_with_params_inner(plan, params);
if result.is_ok()
&& matches!(
plan,
sql::LogicalPlan::Insert { .. }
| sql::LogicalPlan::InsertSelect { .. }
| sql::LogicalPlan::Update { .. }
| sql::LogicalPlan::Delete { .. }
)
{
self.invalidate_result_cache();
}
result
}
fn execute_plan_with_params_inner(&self, plan: &sql::LogicalPlan, params: &[Value]) -> Result<(u64, Vec<Tuple>)> {
if let sql::LogicalPlan::CreateDatabase { name, if_not_exists } = plan {
return self.handle_create_database(name, *if_not_exists);
}
if let sql::LogicalPlan::DropDatabase { name, if_exists } = plan {
return self.handle_drop_database(name, *if_exists);
}
match plan {
sql::LogicalPlan::Insert { table_name, columns, values, returning, on_conflict } => {
let catalog = self.storage.catalog();
let schema = catalog.get_table_schema(table_name)?;
let evaluator = sql::Evaluator::with_parameters(
std::sync::Arc::new(Schema { columns: vec![] }),
params.to_vec(),
);
let empty_tuple = Tuple::new(vec![]);
let has_returning = returning.is_some();
let mut returned_tuples: Vec<Tuple> = Vec::new();
let mut count = 0;
for value_row in values {
let mut tuple_values: Vec<Value> = vec![Value::Null; schema.columns.len()];
let mut user_provided: Vec<bool> = vec![false; schema.columns.len()];
for (col_idx, expr) in value_row.iter().enumerate() {
let target_col_idx = if let Some(ref cols) = columns {
let col_name = cols.get(col_idx)
.ok_or_else(|| Error::internal("column index out of bounds"))?;
schema.get_column_index(col_name)
.ok_or_else(|| Error::query_execution(format!("Column '{}' not found", col_name)))?
} else {
col_idx
};
if matches!(expr, sql::LogicalExpr::DefaultValue) {
continue;
}
let target_col = schema.get_column_at(target_col_idx)
.ok_or_else(|| Error::query_execution(format!(
"Too many values for INSERT: table has {} columns",
schema.columns.len()
)))?;
let target_type = &target_col.data_type;
let mut value = evaluator.evaluate(expr, &empty_tuple)?;
let needs_cast = match (&value, target_type) {
(Value::Null, _) => false,
(Value::Vector(_), DataType::Vector(_)) => false,
(Value::String(_), DataType::Vector(_)) => true,
(Value::String(_), DataType::Json | DataType::Jsonb) => true,
(Value::Int4(_), DataType::Int4) => false,
(Value::Int8(_), DataType::Int8) => false,
(Value::Float4(_), DataType::Float4) => false,
(Value::Float8(_), DataType::Float8) => false,
(Value::String(_), DataType::Text | DataType::Varchar(_)) => false,
(Value::Boolean(_), DataType::Boolean) => false,
(Value::Json(_), DataType::Json | DataType::Jsonb) => false,
_ => true,
};
if needs_cast {
value = evaluator.cast_value(value, target_type)?;
}
if let Some(slot) = tuple_values.get_mut(target_col_idx) {
*slot = value;
}
if let Some(flag) = user_provided.get_mut(target_col_idx) {
*flag = true;
}
}
Self::apply_defaults_and_check_not_null(
&mut tuple_values, &schema, &user_provided,
)?;
let tuple = Tuple::new(tuple_values);
let mut col_values_map = std::collections::HashMap::with_capacity(schema.columns.len());
for (i, col) in schema.columns.iter().enumerate() {
if let Some(v) = tuple.values.get(i) {
col_values_map.insert(col.name.clone(), v.clone());
}
}
let conflict = self.storage.art_indexes()
.check_unique_constraints(table_name, &col_values_map);
match (conflict, on_conflict) {
(Ok(()), _) => {
self.check_fk_constraints_on_write(table_name, &col_values_map, None)?;
let row_id = self.storage.insert_tuple_branch_aware_with_schema(
table_name, tuple.clone(), &schema,
)?;
if has_returning {
let mut filled = tuple;
for (i, col) in schema.columns.iter().enumerate() {
if col.primary_key {
if let Some(v) = filled.values.get(i) {
if matches!(v, Value::Null) {
if let Some(slot) = filled.values.get_mut(i) {
*slot = match col.data_type {
DataType::Int2 => Value::Int2(row_id as i16),
DataType::Int4 => Value::Int4(row_id as i32),
_ => Value::Int8(row_id as i64),
};
}
}
}
}
}
filled.row_id = Some(row_id);
if let Some(projected) = Self::project_returning_columns(&filled, &schema, returning) {
returned_tuples.push(projected);
}
}
count += 1;
}
(Err(_), Some(sql::logical_plan::OnConflictAction::DoNothing)) => {
}
(Err(e), Some(sql::logical_plan::OnConflictAction::DoUpdate { assignments })) => {
let mut found_row_id: Option<u64> = None;
let art = self.storage.art_indexes();
for (i, col) in schema.columns.iter().enumerate() {
if (col.unique || col.primary_key) && !col.primary_key {
if let Some(val) = tuple.values.get(i) {
if !matches!(val, Value::Null) {
if let Some(name) = art.find_column_index(table_name, &col.name) {
let key = storage::ArtIndexManager::encode_key(&[val.clone()]);
let row_ids = art.index_get_all(&name, &key);
if let Some(rid) = row_ids.first() {
found_row_id = Some(*rid);
}
}
if found_row_id.is_some() { break; }
}
}
}
}
if found_row_id.is_none() {
let pk_values: Vec<Value> = schema.columns.iter().enumerate()
.filter(|(_, c)| c.primary_key)
.filter_map(|(i, _)| tuple.values.get(i).cloned())
.collect();
if !pk_values.is_empty()
&& !pk_values.iter().any(|v| matches!(v, Value::Null))
{
let pk_key = crate::storage::ArtIndexManager::encode_key(&pk_values);
found_row_id = self.storage.art_indexes()
.pk_index_lookup(table_name, &pk_key);
}
}
let existing_row_id = found_row_id.ok_or_else(|| Error::query_execution(
format!("ON CONFLICT DO UPDATE: could not find existing row ({})", e)
))?;
let existing_key = self.storage.branch_aware_data_key(
table_name, existing_row_id,
);
let existing_raw = self.storage.get(&existing_key)?
.ok_or_else(|| Error::query_execution(
"ON CONFLICT DO UPDATE: existing row not found in storage"
))?;
let mut updated_tuple: Tuple = bincode::deserialize(&existing_raw)
.map_err(|err| Error::storage(format!("Failed to deserialize tuple: {}", err)))?;
updated_tuple.row_id = Some(existing_row_id);
let old_tuple_for_art: Tuple = bincode::deserialize(&existing_raw)
.map_err(|err| Error::storage(format!("Failed to deserialize tuple: {}", err)))?;
let mut excluded_map = std::collections::HashMap::new();
for (i, col) in schema.columns.iter().enumerate() {
if let Some(v) = tuple.values.get(i) {
excluded_map.insert(col.name.to_lowercase(), v.clone());
}
}
let update_eval = sql::Evaluator::with_parameters(
std::sync::Arc::new(schema.clone()),
params.to_vec(),
);
for (col_name, expr) in assignments {
let target_idx = schema.columns.iter()
.position(|c| c.name.eq_ignore_ascii_case(col_name))
.ok_or_else(|| Error::query_execution(format!(
"ON CONFLICT DO UPDATE: column '{}' not found", col_name
)))?;
let resolved_expr = Self::resolve_excluded_refs(expr, &excluded_map);
let mut new_val = update_eval.evaluate(&resolved_expr, &updated_tuple)?;
let target_type = &schema.columns.get(target_idx)
.ok_or_else(|| Error::internal("column index out of bounds"))?
.data_type;
if new_val.data_type() != *target_type
&& !matches!(new_val, Value::Null)
{
new_val = update_eval.cast_value(new_val, target_type)?;
}
if target_idx < updated_tuple.values.len() {
#[allow(clippy::indexing_slicing)]
{ updated_tuple.values[target_idx] = new_val; }
}
}
self.storage.update_tuple_fast(
table_name,
existing_row_id,
updated_tuple.clone(),
&old_tuple_for_art,
&schema,
)?;
if has_returning {
if let Some(projected) = Self::project_returning_columns(
&updated_tuple, &schema, returning,
) {
returned_tuples.push(projected);
}
}
count += 1;
}
(Err(e), None) => {
return Err(Error::constraint_violation(e.to_string()));
}
}
}
Ok((count, returned_tuples))
}
sql::LogicalPlan::InsertSelect { table_name, columns, source, returning } => {
let mut executor = sql::Executor::with_storage(&self.storage)
.with_timeout(self.config.storage.query_timeout_ms);
let source_rows = executor.execute(source)?;
let catalog = self.storage.catalog();
let schema = catalog.get_table_schema(table_name)?;
let evaluator = sql::Evaluator::new(std::sync::Arc::new(Schema { columns: vec![] }));
let column_indices: Option<Vec<usize>> = columns.as_ref().map(|cols| {
cols.iter()
.filter_map(|col_name| schema.get_column_index(col_name))
.collect()
});
let has_returning = returning.is_some();
let mut returned_tuples: Vec<Tuple> = Vec::new();
let mut count = 0u64;
for source_row in &source_rows {
let mut tuple_values: Vec<Value> = Vec::new();
for (val_idx, value) in source_row.values.iter().enumerate() {
let target_col_idx = if let Some(ref indices) = column_indices {
*indices.get(val_idx).ok_or_else(|| Error::internal("column index out of bounds"))?
} else {
val_idx
};
let target_col = schema.get_column_at(target_col_idx)
.ok_or_else(|| Error::query_execution(format!(
"Too many values for INSERT: table has {} columns",
schema.columns.len()
)))?;
let target_type = &target_col.data_type;
let mut val = value.clone();
let needs_cast = match (&val, target_type) {
(Value::Null, _) => false,
(Value::Vector(_), DataType::Vector(_)) => false,
(Value::String(_), DataType::Vector(_)) => true,
(Value::String(_), DataType::Json | DataType::Jsonb) => true,
(Value::Int4(_), DataType::Int4) => false,
(Value::Int8(_), DataType::Int8) => false,
(Value::Float4(_), DataType::Float4) => false,
(Value::Float8(_), DataType::Float8) => false,
(Value::String(_), DataType::Text | DataType::Varchar(_)) => false,
(Value::Boolean(_), DataType::Boolean) => false,
(Value::Json(_), DataType::Json | DataType::Jsonb) => false,
_ => true,
};
if needs_cast {
val = evaluator.cast_value(val, target_type)?;
}
tuple_values.push(val);
}
let tuple = Tuple::new(tuple_values);
if has_returning {
if let Some(projected) = Self::project_returning_columns(&tuple, &schema, returning) {
returned_tuples.push(projected);
}
}
self.storage.insert_tuple_branch_aware_with_schema(table_name, tuple, &schema)?;
count += 1;
}
Ok((count, returned_tuples))
}
sql::LogicalPlan::Update { table_name, assignments, selection, returning } => {
let catalog = self.storage.catalog();
let schema = catalog.get_table_schema(table_name)?;
let eval_schema = schema.clone().with_source_table_name(table_name);
let evaluator = sql::Evaluator::with_parameters(
std::sync::Arc::new(eval_schema),
params.to_vec(),
);
let tuples = self.storage.scan_table_branch_aware(table_name)?;
let mut updates: Vec<(u64, Tuple)> = Vec::new();
for mut tuple in tuples {
let matches = if let Some(predicate) = selection {
let result = evaluator.evaluate(predicate, &tuple)?;
match result {
Value::Boolean(b) => b,
_ => false,
}
} else {
true
};
if matches {
for (col_name, value_expr) in assignments {
let bound = self.materialize_scalar_subqueries_for_row(
value_expr, &tuple, &schema, table_name,
)?;
let mut new_value = evaluator.evaluate(&bound, &tuple)?;
let col_index = evaluator.schema().get_column_index(col_name)
.ok_or_else(|| Error::query_execution(format!("Column '{}' not found", col_name)))?;
let target_col = schema.get_column_at(col_index)
.ok_or_else(|| Error::query_execution(format!("Column '{}' not found", col_name)))?;
let target_type = &target_col.data_type;
let needs_cast = match (&new_value, target_type) {
(Value::Null, _) => false,
(Value::Vector(_), DataType::Vector(_)) => false,
(Value::String(_), DataType::Vector(_)) => true,
(Value::String(_), DataType::Json | DataType::Jsonb) => true,
(Value::Int2(_), DataType::Int2) => false,
(Value::Int4(_), DataType::Int4) => false,
(Value::Int8(_), DataType::Int8) => false,
(Value::Float4(_), DataType::Float4) => false,
(Value::Float8(_), DataType::Float8) => false,
(Value::String(_), DataType::Text | DataType::Varchar(_)) => false,
(Value::Boolean(_), DataType::Boolean) => false,
(Value::Json(_), DataType::Json | DataType::Jsonb) => false,
(Value::Timestamp(_), DataType::Timestamp | DataType::Timestamptz) => false,
(Value::Date(_), DataType::Date) => false,
_ => true,
};
if needs_cast {
new_value = evaluator.cast_value(new_value, target_type)?;
}
if let Some(slot) = tuple.values.get_mut(col_index) {
*slot = new_value;
}
}
let mut new_col_values = std::collections::HashMap::with_capacity(schema.columns.len());
for (i, col) in schema.columns.iter().enumerate() {
if let Some(v) = tuple.values.get(i) {
new_col_values.insert(col.name.clone(), v.clone());
}
}
self.check_fk_constraints_on_write(table_name, &new_col_values, None)?;
let row_id = tuple.row_id.unwrap_or(0);
updates.push((row_id, tuple));
}
}
let returned_tuples: Vec<Tuple> = if returning.is_some() {
updates.iter()
.filter_map(|(_, tuple)| Self::project_returning_columns(tuple, &schema, returning))
.collect()
} else {
Vec::new()
};
let count = self.storage.update_tuples_branch_aware(table_name, updates)?;
Ok((count, returned_tuples))
}
sql::LogicalPlan::Delete { table_name, selection, returning } => {
let catalog = self.storage.catalog();
let schema = catalog.get_table_schema(table_name)?;
let eval_schema = schema.clone().with_source_table_name(table_name);
let evaluator = sql::Evaluator::with_parameters(
std::sync::Arc::new(eval_schema),
params.to_vec(),
);
let tuples = self.storage.scan_table_branch_aware(table_name)?;
let mut row_ids_to_delete: Vec<u64> = Vec::new();
let mut deleted_tuples: Vec<(u64, Tuple)> = Vec::new();
let mut returned_tuples: Vec<Tuple> = Vec::new();
let has_returning = returning.is_some();
for tuple in tuples {
let matches = if let Some(predicate) = selection {
let result = evaluator.evaluate(predicate, &tuple)?;
match result {
Value::Boolean(b) => b,
_ => false,
}
} else {
true
};
if matches {
if has_returning {
if let Some(projected) = Self::project_returning_columns(&tuple, &schema, returning) {
returned_tuples.push(projected);
}
}
if let Some(row_id) = tuple.row_id {
row_ids_to_delete.push(row_id);
deleted_tuples.push((row_id, tuple.clone()));
}
}
}
for (row_id, tuple) in &deleted_tuples {
let mut col_values = std::collections::HashMap::new();
for (i, col) in schema.columns.iter().enumerate() {
if let Some(v) = tuple.values.get(i) {
col_values.insert(col.name.clone(), v.clone());
}
}
if let Err(e) = self.storage.art_indexes().on_delete(table_name, *row_id, &col_values) {
tracing::debug!("ART index delete for table '{}': {}", table_name, e);
}
}
let count = self.storage.delete_tuples_branch_aware(table_name, row_ids_to_delete)?;
Ok((count, returned_tuples))
}
sql::LogicalPlan::StartTransaction => {
self.begin_transaction_internal()?;
Ok((0, Vec::new()))
}
sql::LogicalPlan::Commit => {
self.commit_internal()?;
Ok((0, Vec::new()))
}
sql::LogicalPlan::Rollback => {
self.rollback_internal()?;
Ok((0, Vec::new()))
}
sql::LogicalPlan::Savepoint { name } => {
let txn = self.current_transaction.lock()
.map_err(|_| Error::query_execution("Failed to lock transaction"))?;
let write_set_snapshot = match txn.as_ref() {
Some(t) => t.savepoint_snapshot(),
None => return Err(Error::query_execution("SAVEPOINT can only be used within a transaction")),
};
drop(txn);
let savepoint = SavepointState {
name: name.clone(),
write_set_snapshot,
};
self.savepoints.write().push(savepoint);
Ok((0, Vec::new()))
}
sql::LogicalPlan::ReleaseSavepoint { name } => {
let mut savepoints = self.savepoints.write();
if let Some(pos) = savepoints.iter().rposition(|s| &s.name == name) {
savepoints.truncate(pos);
Ok((0, Vec::new()))
} else {
Err(Error::query_execution(format!("Savepoint '{}' does not exist", name)))
}
}
sql::LogicalPlan::RollbackToSavepoint { name } => {
let savepoints = self.savepoints.read();
if let Some(pos) = savepoints.iter().rposition(|s| &s.name == name) {
let snapshot = savepoints.get(pos)
.map(|s| s.write_set_snapshot.clone());
drop(savepoints);
if let Some(snapshot) = snapshot {
let txn = self.current_transaction.lock()
.map_err(|_| Error::query_execution("Failed to lock transaction"))?;
if let Some(t) = txn.as_ref() {
t.rollback_to_savepoint(&snapshot);
}
drop(txn);
}
let mut savepoints = self.savepoints.write();
savepoints.truncate(pos + 1);
Ok((0, Vec::new()))
} else {
Err(Error::query_execution(format!("Savepoint '{}' does not exist", name)))
}
}
sql::LogicalPlan::Prepare { name, statement, .. } => {
self.prepared_statements.write().insert(name.clone(), *statement.clone());
Ok((0, Vec::new()))
}
sql::LogicalPlan::Execute { name, parameters } => {
let stmt = {
let stmts = self.prepared_statements.read();
stmts.get(name).cloned()
};
if let Some(plan) = stmt {
let empty_tuple = Tuple::new(vec![]);
let empty_schema = std::sync::Arc::new(Schema { columns: vec![] });
let evaluator = sql::Evaluator::new(empty_schema);
let param_values: Result<Vec<Value>> = parameters.iter()
.map(|expr| evaluator.evaluate(expr, &empty_tuple))
.collect();
self.execute_plan_with_params(&plan, ¶m_values?)
} else {
Err(Error::query_execution(format!("Prepared statement '{}' does not exist", name)))
}
}
sql::LogicalPlan::Deallocate { name } => {
if let Some(ref stmt_name) = name {
let removed = self.prepared_statements.write().remove(stmt_name);
if removed.is_none() {
return Err(Error::query_execution(format!("Prepared statement '{}' does not exist", stmt_name)));
}
} else {
self.prepared_statements.write().clear();
}
Ok((0, Vec::new()))
}
_ => {
let mut executor = sql::Executor::with_storage(&self.storage)
.with_timeout(self.config.storage.query_timeout_ms)
.with_parameters(params.to_vec());
let results = executor.execute(plan)?;
Ok((results.len() as u64, Vec::new()))
}
}
}
pub fn query(&self, sql: &str, _params: &[&dyn std::fmt::Display]) -> Result<Vec<Tuple>> {
if let Some((name, arg)) = crate::sql::sqlite_compat::parse_pragma(sql) {
return self.handle_pragma_query(&name, arg.as_deref());
}
#[cfg(feature = "graph-rag")]
if let Some((inner, opts)) = graph_rag::detect_with_context(sql) {
return self.run_with_context(&inner, &opts);
}
#[cfg(feature = "code-graph")]
let (rewritten_owned, _branch_guard) = self.rewrite_and_scope(sql);
#[cfg(feature = "code-graph")]
let sql: &str = &rewritten_owned;
#[cfg(not(feature = "code-graph"))]
let sql: &str = sql;
let start = std::time::Instant::now();
{
let upper = sql.trim().to_uppercase();
let is_dml = upper.starts_with("INSERT")
|| upper.starts_with("UPDATE")
|| upper.starts_with("DELETE");
if is_dml && upper.contains("RETURNING") {
let (_count, tuples) = self.execute_returning(sql)?;
self.invalidate_result_cache();
self.log_slow_query(sql, start.elapsed(), tuples.len() as u64);
return Ok(tuples);
}
}
{
use crate::error::LockResultExt;
let has_active_txn = {
let txn_lock = self.current_transaction.lock()
.map_lock_err("Failed to acquire transaction lock for query")?;
txn_lock.is_some()
};
if has_active_txn {
let txn_lock = self.current_transaction.lock()
.map_lock_err("Failed to acquire transaction lock for query")?;
let txn_ref = txn_lock.as_ref()
.ok_or_else(|| Error::transaction("Transaction lock in invalid state"))?;
let (statement, _) = self.parse_cached(sql)?;
let catalog = self.storage.catalog();
let planner = sql::Planner::with_catalog(&catalog)
.with_sql(sql.to_string());
let plan = planner.statement_to_plan(statement)?;
let mut executor = sql::Executor::with_storage(&self.storage)
.with_timeout(self.config.storage.query_timeout_ms)
.with_transaction(txn_ref);
let results = executor.execute(&plan)?;
self.log_slow_query(sql, start.elapsed(), results.len() as u64);
return Ok(results);
}
}
let is_non_deterministic = {
let up = sql.to_ascii_uppercase();
["NEXTVAL", "SETVAL", "CURRVAL", "GEN_RANDOM_UUID",
"UUID_GENERATE_V4", "RANDOM(", "NOW(", "CLOCK_TIMESTAMP"]
.iter()
.any(|m| up.contains(m))
};
if !is_non_deterministic {
if let Some(cached_results) = self.result_cache.lock().ok()
.and_then(|mut cache| cache.get(sql).map(std::sync::Arc::clone))
{
tracing::debug!(phase = "result_cache", "Result cache hit");
self.log_slow_query(sql, start.elapsed(), cached_results.len() as u64);
return Ok((*cached_results).clone());
}
}
if let Some(result) = self.try_fast_select(sql) {
let results = result?;
self.log_slow_query(sql, start.elapsed(), results.len() as u64);
return Ok(results);
}
let cached_plan = self.plan_cache.lock().ok().and_then(|mut cache| cache.get(sql).map(std::sync::Arc::clone));
if let Some(arc_plan) = cached_plan {
tracing::debug!(phase = "parse", duration_us = 0_u64, "SQL parsed (cached)");
tracing::debug!(phase = "plan", duration_us = 0_u64, "Logical plan created (cached)");
if self.tenant_manager.get_current_context().is_none() {
let exec_start = std::time::Instant::now();
let mut executor = sql::Executor::with_storage(&self.storage)
.with_timeout(self.config.storage.query_timeout_ms);
let results = executor.execute(&arc_plan)?;
tracing::debug!(phase = "execute", duration_us = exec_start.elapsed().as_micros() as u64, rows = results.len() as u64, "Query executed");
self.log_slow_query(sql, start.elapsed(), results.len() as u64);
if let Ok(mut cache) = self.result_cache.lock() {
cache.put(sql.to_string(), std::sync::Arc::new(results.clone()));
}
return Ok(results);
}
let plan = self.apply_rls_to_plan((*arc_plan).clone())?;
let exec_start = std::time::Instant::now();
let mut executor = sql::Executor::with_storage(&self.storage)
.with_timeout(self.config.storage.query_timeout_ms);
let results = executor.execute(&plan)?;
tracing::debug!(phase = "execute", duration_us = exec_start.elapsed().as_micros() as u64, rows = results.len() as u64, "Query executed");
self.log_slow_query(sql, start.elapsed(), results.len() as u64);
return Ok(results);
}
let parse_start = std::time::Instant::now();
let (statement, _parse_cached) = self.parse_cached(sql)?;
tracing::debug!(phase = "parse", duration_us = parse_start.elapsed().as_micros() as u64, "SQL parsed");
let plan_start = std::time::Instant::now();
let catalog = self.storage.catalog();
let planner = sql::Planner::with_catalog(&catalog)
.with_sql(sql.to_string());
let plan = planner.statement_to_plan(statement)?;
tracing::debug!(phase = "plan", duration_us = plan_start.elapsed().as_micros() as u64, "Logical plan created");
let plan = {
let opt_start = std::time::Instant::now();
let stats = optimizer::cost::StatsCatalog::new();
let rules: Vec<Box<dyn optimizer::rules::OptimizationRule>> = vec![
Box::new(optimizer::rules::ConstantFoldingRule::new()),
Box::new(optimizer::rules::SelectionPushdownRule::new()),
Box::new(optimizer::rules::JoinPredicatePushdownRule::new()),
Box::new(optimizer::rules::ProjectionPruningRule::new()),
];
let opt = optimizer::Optimizer::with_rules(
stats,
rules,
optimizer::OptimizerConfig::default(),
);
let optimized = opt.optimize_recursive(plan)?;
tracing::debug!(phase = "optimize", duration_us = opt_start.elapsed().as_micros() as u64, "Plan optimized");
optimized
};
if let Ok(mut cache) = self.plan_cache.lock() {
cache.put(sql.to_string(), std::sync::Arc::new(plan.clone()));
}
let plan = self.apply_rls_to_plan(plan)?;
let exec_start = std::time::Instant::now();
let mut executor = sql::Executor::with_storage(&self.storage)
.with_timeout(self.config.storage.query_timeout_ms);
let results = executor.execute(&plan)?;
tracing::debug!(phase = "execute", duration_us = exec_start.elapsed().as_micros() as u64, rows = results.len() as u64, "Query executed");
self.log_slow_query(sql, start.elapsed(), results.len() as u64);
if let Ok(mut cache) = self.result_cache.lock() {
cache.put(sql.to_string(), std::sync::Arc::new(results.clone()));
}
Ok(results)
}
pub fn query_with_columns(&self, sql: &str) -> Result<(Vec<Tuple>, Vec<String>)> {
let plan = if sql::Parser::is_show_branches(sql) {
sql::LogicalPlan::ShowBranches
} else {
let (statement, _) = self.parse_cached(sql)?;
let catalog = self.storage.catalog();
let planner = sql::Planner::with_catalog(&catalog)
.with_sql(sql.to_string());
planner.statement_to_plan(statement)?
};
let plan = {
let stats = optimizer::cost::StatsCatalog::new();
let rules: Vec<Box<dyn optimizer::rules::OptimizationRule>> = vec![
Box::new(optimizer::rules::ConstantFoldingRule::new()),
Box::new(optimizer::rules::SelectionPushdownRule::new()),
Box::new(optimizer::rules::JoinPredicatePushdownRule::new()),
Box::new(optimizer::rules::ProjectionPruningRule::new()),
];
let opt = optimizer::Optimizer::with_rules(
stats,
rules,
optimizer::OptimizerConfig::default(),
);
opt.optimize_recursive(plan)?
};
let mut executor = sql::Executor::with_storage(&self.storage)
.with_timeout(self.config.storage.query_timeout_ms);
executor.execute_with_columns(&plan)
}
pub fn dump_full(&self, path: &std::path::Path) -> Result<storage::DumpMetadata> {
self.dump_manager.create_full_dump(path, self)
}
pub fn dump_sql(&self, path: &std::path::Path) -> Result<storage::DumpMetadata> {
self.dump_manager.create_sql_dump(path, self)
}
pub fn dump_incremental(&self, path: &std::path::Path) -> Result<storage::DumpMetadata> {
self.dump_manager.create_incremental_dump(path, self, false)
}
pub fn dump_incremental_append(&self, path: &std::path::Path) -> Result<storage::DumpMetadata> {
self.dump_manager.create_incremental_dump(path, self, true)
}
pub fn restore_from_dump(&mut self, path: &std::path::Path) -> Result<()> {
let dump_manager = self.dump_manager.clone();
dump_manager.restore_from_dump(path, self)
}
pub fn dump_full_compressed(&self, path: &std::path::Path, compression: storage::DumpCompressionType) -> Result<storage::DumpMetadata> {
let manager = storage::DumpManager::new(
path.parent().unwrap_or(std::path::Path::new(".")).to_path_buf(),
compression
);
manager.create_full_dump(path, self)
}
pub fn dump_full_uncompressed(&self, path: &std::path::Path) -> Result<storage::DumpMetadata> {
self.dump_full_compressed(path, storage::DumpCompressionType::None)
}
pub fn dump_tables(&self, path: &std::path::Path, tables: Vec<&str>) -> Result<storage::DumpMetadata> {
self.dump_full(path)
}
pub fn restore_tables(&mut self, path: &std::path::Path, _tables: Vec<&str>) -> Result<()> {
self.restore_from_dump(path)
}
pub fn read_dump_metadata(&self, path: &std::path::Path) -> Result<storage::DumpMetadata> {
use std::io::{Read, Seek, SeekFrom};
let file = std::fs::File::open(path).map_err(|e| Error::io(e.to_string()))?;
let mut reader = std::io::BufReader::new(file);
reader.seek(SeekFrom::Start(12)).map_err(|e| Error::io(e.to_string()))?;
let mut len_bytes = [0u8; 4];
reader.read_exact(&mut len_bytes).map_err(|e| Error::io(e.to_string()))?;
let len = u32::from_le_bytes(len_bytes) as usize;
if len == 0 || len > 8192 {
return Err(Error::io("Invalid metadata length".to_string()));
}
let mut json_bytes = vec![0u8; len];
reader.read_exact(&mut json_bytes).map_err(|e| Error::io(e.to_string()))?;
let metadata: storage::DumpMetadata = serde_json::from_slice(&json_bytes)
.map_err(|e| Error::io(format!("Failed to deserialize metadata: {}", e)))?;
Ok(metadata)
}
pub fn create_session(&self, user_name: &str, isolation: crate::session::IsolationLevel) -> Result<crate::session::SessionId> {
let user = crate::session::User::new_passwordless(user_name);
self.session_manager.create_session(&user, isolation)
}
pub fn destroy_session(&self, session_id: crate::session::SessionId) -> Result<()> {
self.session_manager.destroy_session(session_id)
}
pub fn begin_transaction_for_session(&self, session_id: crate::session::SessionId) -> Result<()> {
let session_lock = self.session_manager.get_session(session_id)?;
let mut session = session_lock.write();
if session.active_txn.is_some() {
return Err(Error::transaction("Session already has an active transaction"));
}
let txn = storage::Transaction::new_with_session(
self.storage.db.clone(),
self.storage.next_timestamp(),
self.storage.snapshot_manager_arc(),
session_id,
session.isolation_level,
self.lock_manager.clone(),
self.dirty_tracker.clone(),
)?;
let txn_id = txn.snapshot_id();
session.active_txn = Some(txn_id);
session.stats.transactions_started += 1;
self.session_transactions.insert(session_id, txn);
Ok(())
}
pub fn commit_transaction_for_session(&self, session_id: crate::session::SessionId) -> Result<()> {
let session_lock = self.session_manager.get_session(session_id)?;
let mut session = session_lock.write();
if session.active_txn.is_none() {
return Err(Error::transaction("Session has no active transaction to commit"));
}
if let Some((_, txn)) = self.session_transactions.remove(&session_id) {
txn.commit_with_timestamp(self.storage.next_timestamp())?;
self.storage.increment_lsn();
}
self.invalidate_result_cache();
session.active_txn = None;
session.stats.transactions_committed += 1;
Ok(())
}
pub fn rollback_transaction_for_session(&self, session_id: crate::session::SessionId) -> Result<()> {
let session_lock = self.session_manager.get_session(session_id)?;
let mut session = session_lock.write();
if session.active_txn.is_none() {
return Err(Error::transaction("Session has no active transaction to rollback"));
}
if let Some((_, txn)) = self.session_transactions.remove(&session_id) {
txn.rollback()?;
}
self.invalidate_result_cache();
session.active_txn = None;
session.stats.transactions_aborted += 1;
Ok(())
}
pub fn execute_in_session(&self, session_id: crate::session::SessionId, sql: &str) -> Result<u64> {
let session_lock = self.session_manager.get_session(session_id)?;
let mut session = session_lock.write();
session.touch();
session.stats.queries_executed += 1;
if self.session_transactions.contains_key(&session_id) {
if session.isolation_level == crate::session::IsolationLevel::ReadCommitted {
if let Some(mut txn) = self.session_transactions.get_mut(&session_id) {
txn.refresh_snapshot(self.storage.current_timestamp());
}
}
let txn = self.session_transactions.get(&session_id)
.ok_or_else(|| Error::transaction("Session transaction disappeared during execute"))?;
self.execute_in_transaction_no_fast_path(sql, &txn)
} else {
let txn = storage::Transaction::new_with_session(
self.storage.db.clone(),
self.storage.next_timestamp(),
self.storage.snapshot_manager_arc(),
session_id,
session.isolation_level,
self.lock_manager.clone(),
self.dirty_tracker.clone(),
)?;
let result = self.execute_in_transaction_no_fast_path(sql, &txn);
match result {
Ok(count) => {
txn.commit_with_timestamp(self.storage.next_timestamp())?;
self.storage.increment_lsn();
Ok(count)
}
Err(e) => {
let _ = txn.rollback();
Err(e)
}
}
}
}
pub fn query_in_session(&self, session_id: crate::session::SessionId, sql: &str, _params: &[&dyn std::fmt::Display]) -> Result<Vec<Tuple>> {
let session_lock = self.session_manager.get_session(session_id)?;
let mut session = session_lock.write();
session.touch();
session.stats.queries_executed += 1;
if self.session_transactions.contains_key(&session_id) {
if session.isolation_level == crate::session::IsolationLevel::ReadCommitted {
if let Some(mut txn) = self.session_transactions.get_mut(&session_id) {
txn.refresh_snapshot(self.storage.current_timestamp());
}
}
let txn = self.session_transactions.get(&session_id)
.ok_or_else(|| Error::transaction("Session transaction disappeared during query"))?;
let (statement, _) = self.parse_cached(sql)?;
let catalog = self.storage.catalog();
let planner = sql::Planner::with_catalog(&catalog)
.with_sql(sql.to_string());
let plan = planner.statement_to_plan(statement)?;
let mut executor = sql::Executor::with_storage(&self.storage)
.with_timeout(self.config.storage.query_timeout_ms)
.with_transaction(&txn);
executor.execute(&plan)
} else {
self.query(sql, _params)
}
}
pub fn set_session_quota(&self, _user_name: &str, _max_sessions: usize) -> Result<()> {
Ok(())
}
pub fn set_memory_quota(&self, _user_name: &str, _max_bytes: usize) -> Result<()> {
Ok(())
}
pub fn is_dirty(&self) -> bool {
self.dirty_tracker.is_dirty()
}
pub fn mark_table_dirty(&self, table: &str) {
let _ = self.dirty_tracker.track_insert(table, "dummy_key", &[]);
}
pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<Vec<Tuple>> {
#[cfg(feature = "code-graph")]
let (rewritten_owned, _branch_guard) = self.rewrite_and_scope(sql);
#[cfg(feature = "code-graph")]
let sql: &str = &rewritten_owned;
#[cfg(not(feature = "code-graph"))]
let sql: &str = sql;
let start = std::time::Instant::now();
let parse_start = std::time::Instant::now();
let (statement, _) = self.parse_cached(sql)?;
tracing::debug!(phase = "parse", duration_us = parse_start.elapsed().as_micros() as u64, "SQL parsed");
let plan_start = std::time::Instant::now();
let catalog = self.storage.catalog();
let planner = sql::Planner::with_catalog(&catalog)
.with_sql(sql.to_string());
let mut plan = planner.statement_to_plan(statement)?;
tracing::debug!(phase = "plan", duration_us = plan_start.elapsed().as_micros() as u64, "Logical plan created");
plan = self.apply_rls_to_plan(plan)?;
let exec_start = std::time::Instant::now();
let results = self.query_plan_with_params(&plan, params)?;
tracing::debug!(phase = "execute", duration_us = exec_start.elapsed().as_micros() as u64, rows = results.len() as u64, "Query executed");
self.log_slow_query(sql, start.elapsed(), results.len() as u64);
Ok(results)
}
fn query_plan_with_params(&self, plan: &sql::LogicalPlan, params: &[Value]) -> Result<Vec<Tuple>> {
let mut executor = sql::Executor::with_storage(&self.storage)
.with_timeout(self.config.storage.query_timeout_ms)
.with_parameters(params.to_vec());
executor.execute(plan)
}
pub fn begin(&self) -> Result<()> {
self.begin_transaction_internal()
}
pub fn commit(&self) -> Result<()> {
self.commit_internal()
}
pub fn rollback(&self) -> Result<()> {
self.rollback_internal()
}
pub fn in_transaction(&self) -> bool {
self.current_transaction.lock()
.map(|txn| txn.is_some())
.unwrap_or(false)
}
#[allow(dead_code)]
pub(crate) fn bulk_insert_tuples(
&self,
table_name: &str,
tuples: Vec<Tuple>,
) -> Result<Vec<u64>> {
let catalog = self.storage.catalog();
let schema = catalog.get_table_schema(table_name)?;
let mut row_ids: Vec<u64> = Vec::with_capacity(tuples.len());
let mut art_updates: Vec<(u64, std::collections::HashMap<String, Value>)> =
Vec::with_capacity(tuples.len());
for mut tuple in tuples {
let row_id = catalog.next_row_id(table_name)?;
for (i, col) in schema.columns.iter().enumerate() {
if col.primary_key {
if let Some(v) = tuple.values.get(i) {
if matches!(v, Value::Null) && i < tuple.values.len() {
#[allow(clippy::indexing_slicing)]
match col.data_type {
DataType::Int2 => { tuple.values[i] = Value::Int2(row_id as i16); }
DataType::Int4 => { tuple.values[i] = Value::Int4(row_id as i32); }
_ => { tuple.values[i] = Value::Int8(row_id as i64); }
}
}
}
}
}
let mut col_values = std::collections::HashMap::with_capacity(schema.columns.len());
for (i, col) in schema.columns.iter().enumerate() {
if let Some(v) = tuple.values.get(i) {
col_values.insert(col.name.clone(), v.clone());
}
}
if let Err(e) = self.storage.art_indexes()
.check_unique_constraints(table_name, &col_values)
{
return Err(Error::constraint_violation(e.to_string()));
}
let val = bincode::serialize(&tuple)
.map_err(|e| Error::storage(format!("Failed to serialize tuple: {}", e)))?;
let key = self.storage.branch_aware_data_key(table_name, row_id);
self.storage.put(&key, &val)?;
art_updates.push((row_id, col_values));
row_ids.push(row_id);
}
for (row_id, col_values) in art_updates {
if let Err(e) = self.storage.art_indexes()
.on_insert(table_name, row_id, &col_values)
{
tracing::debug!("ART on_insert {}: {}", table_name, e);
}
}
self.invalidate_result_cache();
if let Ok(mut cache) = self.plan_cache.lock() {
cache.clear();
}
Ok(row_ids)
}
#[deprecated(since = "2.1.0", note = "Use `begin()`, `commit()`, and `rollback()` instead")]
pub fn begin_transaction(&self) -> Result<Transaction<'_>> {
let tx = self.storage.begin_transaction()?;
Ok(Transaction { tx, db: self })
}
pub fn current_lsn(&self) -> Option<u64> {
let txn_id = self.storage.snapshot_manager().current_transaction_id();
if txn_id > 1 {
Some(txn_id - 1)
} else {
Some(0)
}
}
pub fn close(self) -> Result<()> {
Ok(())
}
pub fn list_vector_stores(&self) -> Result<Vec<VectorStoreInfo>> {
use crate::vector::DistanceMetric;
let vector_mgr = self.storage.vector_indexes();
let metadata_list = vector_mgr.list_all_metadata();
Ok(metadata_list.iter().map(|meta| {
let (vector_count, metric, index_type) = match vector_mgr.get_index_stats(&meta.name) {
Ok(stats) => (
stats.num_vectors as u64,
match &meta.index_type {
storage::VectorIndexType::Standard(cfg) => match cfg.distance_metric {
DistanceMetric::L2 => "l2".to_string(),
DistanceMetric::Cosine => "cosine".to_string(),
DistanceMetric::InnerProduct => "inner_product".to_string(),
},
storage::VectorIndexType::Quantized(cfg) => match cfg.distance_metric {
DistanceMetric::L2 => "l2".to_string(),
DistanceMetric::Cosine => "cosine".to_string(),
DistanceMetric::InnerProduct => "inner_product".to_string(),
},
},
match &meta.index_type {
storage::VectorIndexType::Standard(_) => "hnsw".to_string(),
storage::VectorIndexType::Quantized(_) => "hnsw_pq".to_string(),
},
),
Err(_) => (0, "cosine".to_string(), "hnsw".to_string()),
};
let dimensions = match &meta.index_type {
storage::VectorIndexType::Standard(cfg) => cfg.dimension as u32,
storage::VectorIndexType::Quantized(cfg) => cfg.dimension as u32,
};
VectorStoreInfo {
name: meta.name.clone(),
dimensions,
vector_count,
created_at: "N/A".to_string(),
metric,
index_type,
}
}).collect())
}
pub fn create_vector_store(&self, name: &str, dimensions: u32) -> Result<VectorStoreInfo> {
use crate::vector::DistanceMetric;
let vector_mgr = self.storage.vector_indexes();
vector_mgr.create_index(
name.to_string(),
name.to_string(), "embedding".to_string(), dimensions as usize,
DistanceMetric::Cosine, )?;
Ok(VectorStoreInfo {
name: name.to_string(),
dimensions,
vector_count: 0,
created_at: chrono::Utc::now().to_rfc3339(),
metric: "cosine".to_string(),
index_type: "hnsw".to_string(),
})
}
pub fn get_vector_store(&self, name: &str) -> Result<VectorStoreInfo> {
use crate::vector::DistanceMetric;
let vector_mgr = self.storage.vector_indexes();
let meta = vector_mgr.get_metadata(name)?;
let stats = vector_mgr.get_index_stats(name)?;
let metric = match &meta.index_type {
storage::VectorIndexType::Standard(cfg) => match cfg.distance_metric {
DistanceMetric::L2 => "l2".to_string(),
DistanceMetric::Cosine => "cosine".to_string(),
DistanceMetric::InnerProduct => "inner_product".to_string(),
},
storage::VectorIndexType::Quantized(cfg) => match cfg.distance_metric {
DistanceMetric::L2 => "l2".to_string(),
DistanceMetric::Cosine => "cosine".to_string(),
DistanceMetric::InnerProduct => "inner_product".to_string(),
},
};
let index_type = match &meta.index_type {
storage::VectorIndexType::Standard(_) => "hnsw".to_string(),
storage::VectorIndexType::Quantized(_) => "hnsw_pq".to_string(),
};
Ok(VectorStoreInfo {
name: name.to_string(),
dimensions: stats.dimensions as u32,
vector_count: stats.num_vectors as u64,
created_at: "N/A".to_string(),
metric,
index_type,
})
}
pub fn delete_vector_store(&self, name: &str) -> Result<()> {
let vector_mgr = self.storage.vector_indexes();
vector_mgr.drop_index(name)
}
pub fn insert_vectors(&self, store: &str, vectors: Vec<Vec<f32>>) -> Result<Vec<String>> {
let vector_mgr = self.storage.vector_indexes();
let _ = vector_mgr.get_metadata(store)?;
let mut ids = Vec::with_capacity(vectors.len());
for vector in vectors {
let id = self.storage.next_timestamp();
let id_str = format!("vec_{}", id);
vector_mgr.insert_vector(store, id, &vector)?;
ids.push(id_str);
}
Ok(ids)
}
pub fn upsert_vectors(&self, store: &str, vectors: Vec<(String, Vec<f32>)>) -> Result<()> {
let vector_mgr = self.storage.vector_indexes();
let _ = vector_mgr.get_metadata(store)?;
for (id_str, vector) in vectors {
let id = id_str.strip_prefix("vec_")
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or_else(|| {
self.storage.next_timestamp()
});
let _ = vector_mgr.delete_vector(store, id);
vector_mgr.insert_vector(store, id, &vector)?;
}
Ok(())
}
pub fn search_vectors(&self, store: &str, query: Vec<f32>, k: usize) -> Result<Vec<(String, f32)>> {
let vector_mgr = self.storage.vector_indexes();
let _ = vector_mgr.get_metadata(store)?;
let results = vector_mgr.search(store, &query, k)?;
Ok(results.into_iter()
.map(|(row_id, distance)| (format!("vec_{}", row_id), distance))
.collect())
}
pub fn text_search(&self, _query: &str) -> Result<Vec<String>> {
Err(Error::Generic("Text search requires embedding model - not yet implemented".to_string()))
}
pub fn store_texts(&self, _store: &str, _texts: Vec<String>) -> Result<Vec<String>> {
Err(Error::Generic("Text storage requires embedding model - not yet implemented".to_string()))
}
pub fn hybrid_search(&self, _store: &str, _query: &str, _k: usize) -> Result<Vec<(String, f32)>> {
Err(Error::Generic("Hybrid search requires embedding model - not yet implemented".to_string()))
}
pub fn delete_vectors(&self, store: &str, ids: Vec<String>) -> Result<()> {
let vector_mgr = self.storage.vector_indexes();
let _ = vector_mgr.get_metadata(store)?;
for id_str in ids {
if let Some(id) = id_str.strip_prefix("vec_").and_then(|s| s.parse::<u64>().ok()) {
vector_mgr.delete_vector(store, id)?;
}
}
Ok(())
}
pub fn fetch_vectors(&self, _store: &str, _ids: Vec<String>) -> Result<Vec<(String, Vec<f32>)>> {
Err(Error::Generic("Vector fetch not yet implemented - HNSW index doesn't store raw vectors".to_string()))
}
pub fn list_agent_sessions(&self) -> Result<Vec<AgentSession>> {
Ok(vec![])
}
pub fn create_agent_session(&self, _name: &str) -> Result<AgentSession> {
Err(Error::Generic("Agent sessions not yet implemented".to_string()))
}
pub fn get_agent_session(&self, _id: &str) -> Result<AgentSession> {
Err(Error::Generic("Agent sessions not yet implemented".to_string()))
}
pub fn delete_agent_session(&self, _id: &str) -> Result<()> {
Err(Error::Generic("Agent sessions not yet implemented".to_string()))
}
pub fn add_agent_message(&self, _session_id: &str, _role: &str, _content: &str) -> Result<AgentMessage> {
Err(Error::Generic("Agent messages not yet implemented".to_string()))
}
pub fn get_agent_messages(&self, _session_id: &str) -> Result<Vec<AgentMessage>> {
Ok(vec![])
}
pub fn clear_agent_messages(&self, _session_id: &str) -> Result<()> {
Err(Error::Generic("Agent messages not yet implemented".to_string()))
}
pub fn generate_schema(&self, _table_name: &str) -> Result<String> {
Err(Error::Generic("Schema generation not yet implemented".to_string()))
}
pub fn chat_completion(&self, _messages: Vec<(String, String)>) -> Result<String> {
Err(Error::Generic("Chat completions not yet implemented".to_string()))
}
pub fn nl_to_sql(&self, _query: &str) -> Result<String> {
Err(Error::Generic("Natural language to SQL not yet implemented".to_string()))
}
pub fn store_document(&self, _collection: &str, _id: &str, _content: &str, _metadata: Option<serde_json::Value>) -> Result<()> {
Err(Error::Generic("Document storage not yet implemented".to_string()))
}
pub fn get_document(&self, _collection: &str, _id: &str) -> Result<DocumentData> {
Err(Error::Generic("Document storage not yet implemented".to_string()))
}
pub fn delete_document(&self, _collection: &str, _id: &str) -> Result<()> {
Err(Error::Generic("Document storage not yet implemented".to_string()))
}
pub fn update_document(&self, _collection: &str, _id: &str, _content: &str, _metadata: Option<serde_json::Value>) -> Result<()> {
Err(Error::Generic("Document storage not yet implemented".to_string()))
}
pub fn list_documents(&self, _collection: &str) -> Result<Vec<DocumentMetadata>> {
Ok(vec![])
}
pub fn search_documents(&self, _collection: &str, _query: &str) -> Result<Vec<DocumentData>> {
Ok(vec![])
}
pub fn create_collection(&self, _name: &str) -> Result<()> {
Err(Error::Generic("Collections not yet implemented".to_string()))
}
pub fn delete_collection(&self, _name: &str) -> Result<()> {
Err(Error::Generic("Collections not yet implemented".to_string()))
}
pub fn list_collections(&self) -> Result<Vec<String>> {
Ok(vec![])
}
pub fn batch_create_documents(&self, _collection: &str, _docs: Vec<DocumentData>) -> Result<Vec<String>> {
Err(Error::Generic("Batch document creation not yet implemented".to_string()))
}
pub fn batch_infer_schema(&self, _data: Vec<Vec<Value>>) -> Result<Schema> {
Err(Error::Generic("Batch schema inference not yet implemented".to_string()))
}
pub fn chat_completion_stream(&self, _messages: Vec<(String, String)>) -> Result<String> {
Err(Error::Generic("Chat completion streaming not yet implemented".to_string()))
}
pub fn compare_schemas(&self, _schema1: &Schema, _schema2: &Schema) -> Result<serde_json::Value> {
Err(Error::Generic("Schema comparison not yet implemented".to_string()))
}
pub fn create_embeddings(&self, _texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
Err(Error::Generic("Embeddings not yet implemented".to_string()))
}
pub fn create_document(&self, _collection: &str, _id: &str, _content: &str, _metadata: Option<serde_json::Value>) -> Result<String> {
Ok("document_id".to_string())
}
pub fn find_similar_documents(&self, _collection: &str, _query: &str, _limit: usize) -> Result<Vec<(DocumentData, f32)>> {
Err(Error::Generic("Similar document search not yet implemented".to_string()))
}
pub fn fork_agent_session(&self, _session_id: &str, _new_name: &str) -> Result<AgentSession> {
Err(Error::Generic("Agent session forking not yet implemented".to_string()))
}
pub fn generate_schema_from_description(&self, _description: &str) -> Result<Schema> {
Err(Error::Generic("Schema generation from description not yet implemented".to_string()))
}
pub fn get_agent_context(&self, _session_id: &str) -> Result<serde_json::Value> {
Err(Error::Generic("Agent context retrieval not yet implemented".to_string()))
}
pub fn get_chat_model(&self, _model_id: &str) -> Result<serde_json::Value> {
Err(Error::Generic("Chat model retrieval not yet implemented".to_string()))
}
pub fn get_document_chunks(&self, _collection: &str, _id: &str) -> Result<Vec<(String, f32)>> {
Err(Error::Generic("Document chunking not yet implemented".to_string()))
}
pub fn infer_schema(&self, _data: Vec<Vec<Value>>) -> Result<Schema> {
Err(Error::Generic("Schema inference not yet implemented".to_string()))
}
pub fn infer_schema_from_file(&self, _path: &str) -> Result<Schema> {
Err(Error::Generic("Schema inference from file not yet implemented".to_string()))
}
pub fn instantiate_schema_template(&self, _template_name: &str, _params: serde_json::Value) -> Result<Schema> {
Err(Error::Generic("Schema template instantiation not yet implemented".to_string()))
}
pub fn list_chat_models(&self) -> Result<Vec<serde_json::Value>> {
Ok(vec![])
}
pub fn list_schema_templates(&self) -> Result<Vec<serde_json::Value>> {
Ok(vec![])
}
pub fn optimize_schema(&self, _schema: &Schema) -> Result<Schema> {
Err(Error::Generic("Schema optimization not yet implemented".to_string()))
}
pub fn validate_schema(&self, _schema: &Schema) -> Result<bool> {
Err(Error::Generic("Schema validation not yet implemented".to_string()))
}
pub fn rag_search(&self, _collection: &str, _query: &str, _k: usize) -> Result<Vec<(DocumentData, f32, String)>> {
Err(Error::Generic("RAG search not yet implemented".to_string()))
}
pub fn rechunk_document(&self, _collection: &str, _id: &str, _chunk_size: usize) -> Result<Vec<String>> {
Err(Error::Generic("Document rechunking not yet implemented".to_string()))
}
pub fn search_agent_memory(&self, _session_id: &str, _query: &str) -> Result<Vec<(AgentMessage, f32)>> {
Err(Error::Generic("Agent memory search not yet implemented".to_string()))
}
pub fn summarize_agent_memory(&self, _session_id: &str) -> Result<String> {
Err(Error::Generic("Agent memory summarization not yet implemented".to_string()))
}
pub fn create_branch(&self, name: &str) -> Result<u64> {
self.execute(&format!("CREATE BRANCH {name}"))
}
pub fn switch_branch(&self, name: &str) -> Result<u64> {
self.execute(&format!("USE BRANCH {name}"))
}
pub fn merge_branch(&self, source: &str) -> Result<u64> {
self.execute(&format!("MERGE BRANCH {source}"))
}
pub fn drop_branch(&self, name: &str) -> Result<u64> {
self.execute(&format!("DROP BRANCH {name}"))
}
pub fn list_branches(&self) -> Result<Vec<Tuple>> {
self.query("LIST BRANCHES", &[])
}
pub fn explain(&self, sql: &str) -> Result<Vec<Tuple>> {
self.query(&format!("EXPLAIN {sql}"), &[])
}
pub fn explain_analyze(&self, sql: &str) -> Result<Vec<Tuple>> {
self.query(&format!("EXPLAIN ANALYZE {sql}"), &[])
}
pub fn refresh_materialized_view(&self, name: &str) -> Result<u64> {
self.execute(&format!("REFRESH MATERIALIZED VIEW {name}"))
}
fn clone_for_trigger(&self) -> Self {
Self {
storage: self.storage.clone(),
config: self.config.clone(),
current_transaction: self.current_transaction.clone(),
tenant_manager: self.tenant_manager.clone(),
trigger_registry: self.trigger_registry.clone(),
function_registry: self.function_registry.clone(),
mv_scheduler: self.mv_scheduler.clone(),
auto_refresh_worker: self.auto_refresh_worker.clone(),
dump_manager: self.dump_manager.clone(),
session_manager: self.session_manager.clone(),
lock_manager: self.lock_manager.clone(),
dirty_tracker: self.dirty_tracker.clone(),
session_transactions: self.session_transactions.clone(),
prepared_statements: self.prepared_statements.clone(),
savepoints: self.savepoints.clone(),
plan_cache: self.plan_cache.clone(),
parse_cache: self.parse_cache.clone(),
result_cache: self.result_cache.clone(),
art_undo_log: self.art_undo_log.clone(),
}
}
fn check_foreign_key_exists(
&self,
table_name: &str,
column_names: &[String],
values: &[Value],
) -> Result<bool> {
let catalog = self.storage.catalog();
let schema = catalog.get_table_schema(table_name)?;
let tuples = self.storage.scan_table(table_name)?;
for tuple in tuples {
let mut matches = true;
for (col_name, expected_value) in column_names.iter().zip(values.iter()) {
let col_idx = schema.columns.iter()
.position(|c| &c.name == col_name);
if let Some(idx) = col_idx {
match tuple.values.get(idx) {
Some(actual_value) if actual_value == expected_value => {}
_ => { matches = false; break; }
}
} else {
matches = false;
break;
}
}
if matches {
return Ok(true);
}
}
Ok(false)
}
fn check_unique_violation(
&self,
table_name: &str,
column_names: &[String],
values: &[Value],
) -> Result<bool> {
let catalog = self.storage.catalog();
let schema = catalog.get_table_schema(table_name)?;
let tuples = self.storage.scan_table(table_name)?;
for tuple in tuples {
let mut matches = true;
for (col_name, expected_value) in column_names.iter().zip(values.iter()) {
let col_idx = schema.columns.iter()
.position(|c| &c.name == col_name);
if let Some(idx) = col_idx {
match tuple.values.get(idx) {
Some(actual_value) if actual_value == expected_value => {}
_ => { matches = false; break; }
}
} else {
matches = false;
break;
}
}
if matches {
return Ok(true); }
}
Ok(false) }
fn cascade_delete_referencing_rows(
&self,
table_name: &str,
fk_columns: &[String],
parent_values: &[Value],
) -> Result<()> {
let catalog = self.storage.catalog();
let schema = catalog.get_table_schema(table_name)?;
let tuples = self.storage.scan_table(table_name)?;
let mut row_ids_to_delete: Vec<u64> = Vec::new();
for tuple in tuples {
let mut matches = true;
for (fk_col, parent_val) in fk_columns.iter().zip(parent_values.iter()) {
let col_idx = schema.columns.iter().position(|c| &c.name == fk_col);
if let Some(idx) = col_idx {
match tuple.values.get(idx) {
Some(val) if val == parent_val => {}
_ => { matches = false; break; }
}
} else {
matches = false;
break;
}
}
if matches {
if let Some(row_id) = tuple.row_id {
row_ids_to_delete.push(row_id);
}
}
}
let txn = self.storage.begin_transaction()?;
for row_id in row_ids_to_delete {
let key = self.storage.branch_aware_data_key(table_name, row_id);
txn.delete(key.clone())?;
self.storage.log_data_delete(table_name, &key)?;
}
txn.commit()?;
Ok(())
}
fn set_null_referencing_rows(
&self,
table_name: &str,
fk_columns: &[String],
parent_values: &[Value],
) -> Result<()> {
let catalog = self.storage.catalog();
let schema = catalog.get_table_schema(table_name)?;
let tuples = self.storage.scan_table(table_name)?;
let mut rows_to_update: Vec<(u64, Tuple)> = Vec::new();
for tuple in tuples {
let mut matches = true;
for (fk_col, parent_val) in fk_columns.iter().zip(parent_values.iter()) {
let col_idx = schema.columns.iter().position(|c| &c.name == fk_col);
if let Some(idx) = col_idx {
match tuple.values.get(idx) {
Some(val) if val == parent_val => {}
_ => { matches = false; break; }
}
} else {
matches = false;
break;
}
}
if matches {
if let Some(row_id) = tuple.row_id {
let mut new_values = tuple.values.clone();
for fk_col in fk_columns {
if let Some(idx) = schema.columns.iter().position(|c| &c.name == fk_col) {
if let Some(slot) = new_values.get_mut(idx) {
*slot = Value::Null;
}
}
}
let new_tuple = Tuple::new(new_values);
rows_to_update.push((row_id, new_tuple));
}
}
}
let txn = self.storage.begin_transaction()?;
for (row_id, new_tuple) in rows_to_update {
let key = self.storage.branch_aware_data_key(table_name, row_id);
let val = bincode::serialize(&new_tuple).map_err(|e| Error::storage(e.to_string()))?;
txn.put(key.clone(), val.clone())?;
self.storage.log_data_update(table_name, &key, &val)?;
}
txn.commit()?;
Ok(())
}
fn evaluate_check_constraint(
&self,
expression: &str,
schema: &Schema,
values: &[Value],
) -> Result<bool> {
let tuple = Tuple::new(values.to_vec());
let logical_expr = if expression.starts_with('{') || expression.starts_with('[') {
serde_json::from_str::<sql::LogicalExpr>(expression)
.map_err(|e| Error::query_execution(format!(
"Failed to deserialize CHECK constraint expression '{}': {}",
expression, e
)))?
} else {
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser as SqlParser;
let sql = format!("SELECT * FROM dummy WHERE {}", expression);
let dialect = PostgreSqlDialect {};
let mut statements = SqlParser::parse_sql(&dialect, &sql)
.map_err(|e| Error::query_execution(format!(
"Failed to parse CHECK constraint expression '{}': {}",
expression, e
)))?;
if statements.len() != 1 {
return Err(Error::query_execution(
"Invalid CHECK constraint expression: expected single statement"
));
}
let statement = statements.remove(0);
let selection = if let sqlparser::ast::Statement::Query(query) = statement {
if let sqlparser::ast::SetExpr::Select(select) = *query.body {
select.selection
} else {
None
}
} else {
None
};
let selection = selection.ok_or_else(|| Error::query_execution(format!(
"Failed to extract expression from CHECK constraint: {}",
expression
)))?;
let catalog = self.storage.catalog();
let planner = sql::Planner::with_catalog(&catalog);
planner.convert_expr_to_logical(&selection, Some(schema))?
};
let evaluator = sql::Evaluator::new(std::sync::Arc::new(schema.clone()));
let result = evaluator.evaluate(&logical_expr, &tuple)?;
match result {
Value::Boolean(b) => Ok(b),
Value::Null => Ok(true), _ => Err(Error::constraint_violation(format!(
"CHECK constraint expression '{}' did not evaluate to boolean",
expression
))),
}
}
fn check_fk_constraints_on_write(
&self,
table_name: &str,
col_values: &std::collections::HashMap<String, Value>,
active_txn: Option<&storage::Transaction>,
) -> Result<()> {
let catalog = self.storage.catalog();
let constraints = match catalog.load_table_constraints(table_name) {
Ok(c) => c,
Err(_) => return Ok(()), };
if constraints.foreign_keys.is_empty() {
return Ok(());
}
for fk in &constraints.foreign_keys {
let mut parent_values: Vec<Value> = Vec::with_capacity(fk.columns.len());
let mut any_null = false;
for col in &fk.columns {
match col_values.get(col) {
Some(v) if matches!(v, Value::Null) => {
any_null = true;
break;
}
Some(v) => parent_values.push(v.clone()),
None => {
any_null = true;
break;
}
}
}
if any_null {
continue;
}
let parent_exists = self.check_referencing_rows_exist(
&fk.references_table,
&fk.references_columns,
&parent_values,
active_txn,
)?;
if !parent_exists {
let parent_repr: Vec<String> = parent_values.iter()
.map(|v| format!("{v}"))
.collect();
return Err(Error::constraint_violation(format!(
"Foreign key constraint '{}' violated: row references \
non-existent {}({}) = ({})",
fk.name,
fk.references_table,
fk.references_columns.join(", "),
parent_repr.join(", "),
)));
}
}
Ok(())
}
fn check_referencing_rows_exist(
&self,
table_name: &str,
column_names: &[String],
values: &[Value],
active_txn: Option<&storage::Transaction>,
) -> Result<bool> {
if active_txn.is_none() && column_names.len() == values.len() && !column_names.is_empty() {
let art = self.storage.art_indexes();
let index_name = if column_names.len() == 1 {
#[allow(clippy::indexing_slicing)]
art.find_column_index(table_name, &column_names[0])
} else {
None
};
if let Some(name) = index_name {
let key = storage::ArtIndexManager::encode_key(values);
let row_ids = art.index_get_all(&name, &key);
if !row_ids.is_empty() {
return Ok(true);
}
return Ok(false);
}
}
let catalog = self.storage.catalog();
let schema = catalog.get_table_schema(table_name)?;
let base = self.storage.scan_table(table_name)?;
let tuples = if let Some(txn) = active_txn {
txn.merge_with_write_set(table_name, base)?
} else {
base
};
for tuple in tuples {
let mut matches = true;
for (col_name, expected_value) in column_names.iter().zip(values.iter()) {
let col_idx = schema.columns.iter()
.position(|c| &c.name == col_name);
if let Some(idx) = col_idx {
match tuple.values.get(idx) {
Some(actual_value) if actual_value == expected_value => {}
_ => { matches = false; break; }
}
} else {
matches = false;
break;
}
}
if matches {
return Ok(true);
}
}
Ok(false)
}
#[cfg(feature = "server")]
pub fn start_qps_reset_task(&self) -> tokio::task::JoinHandle<()> {
let tenant_manager = self.tenant_manager.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(1));
loop {
interval.tick().await;
let tenants = tenant_manager.list_tenants();
for tenant in tenants {
let _ = tenant_manager.reset_qps_window(tenant.id);
}
}
})
}
pub fn reset_all_qps_windows(&self) {
let tenants = self.tenant_manager.list_tenants();
for tenant in tenants {
let _ = self.tenant_manager.reset_qps_window(tenant.id);
}
}
fn execute_plan_internal(&self, plan: &sql::LogicalPlan) -> Result<u64> {
let (count, _tuples) = self.execute_plan_with_params(plan, &[])?;
Ok(count)
}
pub(crate) const RESERVED_DATABASE_NAMES: &'static [&'static str] = &["heliosdb", "postgres"];
fn database_name_is_reserved(name: &str) -> bool {
let lower = name.to_ascii_lowercase();
Self::RESERVED_DATABASE_NAMES.iter().any(|r| *r == lower)
}
fn handle_create_database(&self, name: &str, if_not_exists: bool) -> Result<(u64, Vec<Tuple>)> {
let trimmed = name.trim();
if trimmed.is_empty() {
return Err(Error::query_execution("CREATE DATABASE requires a non-empty name"));
}
if Self::database_name_is_reserved(trimmed) {
if if_not_exists {
return Ok((0, vec![]));
}
return Err(Error::query_execution(format!(
"database \"{trimmed}\" is a reserved system database and cannot be created"
)));
}
let already_exists = self
.tenant_manager
.list_tenants()
.iter()
.any(|t| t.name.eq_ignore_ascii_case(trimmed));
if already_exists {
if if_not_exists {
return Ok((0, vec![]));
}
return Err(Error::query_execution(format!(
"database \"{trimmed}\" already exists"
)));
}
self.tenant_manager.register_tenant_with_plan(
trimmed.to_string(),
crate::tenant::IsolationMode::DatabasePerTenant,
"free",
);
Ok((0, vec![]))
}
fn handle_drop_database(&self, name: &str, if_exists: bool) -> Result<(u64, Vec<Tuple>)> {
let trimmed = name.trim();
if trimmed.is_empty() {
return Err(Error::query_execution("DROP DATABASE requires a non-empty name"));
}
if Self::database_name_is_reserved(trimmed) {
return Err(Error::query_execution(format!(
"database \"{trimmed}\" is a reserved system database and cannot be dropped"
)));
}
let target = self
.tenant_manager
.list_tenants()
.into_iter()
.find(|t| t.name.eq_ignore_ascii_case(trimmed));
let Some(tenant) = target else {
if if_exists {
return Ok((0, vec![]));
}
return Err(Error::query_execution(format!(
"database \"{trimmed}\" does not exist"
)));
};
self.tenant_manager
.delete_tenant(tenant.id)
.map_err(Error::query_execution)?;
Ok((0, vec![]))
}
pub fn database_name_is_valid(&self, name: &str) -> bool {
let trimmed = name.trim();
if trimmed.is_empty() {
return false;
}
if Self::database_name_is_reserved(trimmed) {
return true;
}
self.tenant_manager
.list_tenants()
.iter()
.any(|t| t.name.eq_ignore_ascii_case(trimmed))
}
pub(crate) fn apply_defaults_and_check_not_null(
tuple_values: &mut [Value],
schema: &Schema,
user_provided: &[bool],
) -> Result<()> {
let evaluator = sql::Evaluator::new(std::sync::Arc::new(schema.clone()));
let empty_tuple = Tuple::new(vec![]);
for (idx, col) in schema.columns.iter().enumerate() {
let slot = match tuple_values.get_mut(idx) {
Some(s) => s,
None => continue,
};
let is_null = matches!(slot, Value::Null);
let was_omitted = user_provided.get(idx).copied().unwrap_or(false) == false;
if is_null && was_omitted {
if let Some(json) = &col.default_expr {
if let Ok(default_expr) = serde_json::from_str::<sql::LogicalExpr>(json) {
if let Ok(v) = evaluator.evaluate(&default_expr, &empty_tuple) {
let casted = if v.data_type() != col.data_type {
evaluator.cast_value(v, &col.data_type).unwrap_or(Value::Null)
} else {
v
};
*slot = casted;
}
}
}
}
if !col.nullable && !col.primary_key {
if matches!(slot, Value::Null) {
return Err(Error::constraint_violation(format!(
"NOT NULL constraint violated: cannot insert NULL into column '{}'",
col.name
)));
}
}
}
Ok(())
}
fn materialize_scalar_subqueries_for_row(
&self,
expr: &sql::LogicalExpr,
outer_row: &Tuple,
outer_schema: &Schema,
outer_table: &str,
) -> Result<sql::LogicalExpr> {
use sql::LogicalExpr;
match expr {
LogicalExpr::ScalarSubquery { subquery } => {
let bound_plan = Self::bind_outer_refs_in_plan(
subquery.as_ref(),
outer_row,
outer_schema,
outer_table,
);
let mut executor = sql::Executor::with_storage(&self.storage)
.with_timeout(self.config.storage.query_timeout_ms);
let rows = executor.execute(&bound_plan)?;
let value = rows.first()
.and_then(|t| t.values.first().cloned())
.unwrap_or(Value::Null);
Ok(LogicalExpr::Literal(value))
}
LogicalExpr::BinaryExpr { left, op, right } => Ok(LogicalExpr::BinaryExpr {
left: Box::new(self.materialize_scalar_subqueries_for_row(left, outer_row, outer_schema, outer_table)?),
op: *op,
right: Box::new(self.materialize_scalar_subqueries_for_row(right, outer_row, outer_schema, outer_table)?),
}),
LogicalExpr::UnaryExpr { op, expr: inner } => Ok(LogicalExpr::UnaryExpr {
op: *op,
expr: Box::new(self.materialize_scalar_subqueries_for_row(inner, outer_row, outer_schema, outer_table)?),
}),
LogicalExpr::Cast { expr: inner, data_type } => Ok(LogicalExpr::Cast {
expr: Box::new(self.materialize_scalar_subqueries_for_row(inner, outer_row, outer_schema, outer_table)?),
data_type: data_type.clone(),
}),
other => Ok(other.clone()),
}
}
fn bind_outer_refs_in_plan(
plan: &sql::LogicalPlan,
outer_row: &Tuple,
outer_schema: &Schema,
outer_table: &str,
) -> sql::LogicalPlan {
use sql::LogicalPlan;
match plan {
LogicalPlan::Filter { input, predicate } => LogicalPlan::Filter {
input: Box::new(Self::bind_outer_refs_in_plan(input, outer_row, outer_schema, outer_table)),
predicate: Self::bind_outer_refs_in_expr(predicate, outer_row, outer_schema, outer_table),
},
LogicalPlan::Project { input, exprs, aliases, distinct, distinct_on } => LogicalPlan::Project {
input: Box::new(Self::bind_outer_refs_in_plan(input, outer_row, outer_schema, outer_table)),
exprs: exprs.iter()
.map(|e| Self::bind_outer_refs_in_expr(e, outer_row, outer_schema, outer_table))
.collect(),
aliases: aliases.clone(),
distinct: *distinct,
distinct_on: distinct_on.clone(),
},
LogicalPlan::Limit { input, limit, offset, limit_param, offset_param } => LogicalPlan::Limit {
input: Box::new(Self::bind_outer_refs_in_plan(input, outer_row, outer_schema, outer_table)),
limit: *limit,
offset: *offset,
limit_param: *limit_param,
offset_param: *offset_param,
},
LogicalPlan::Sort { input, exprs, asc } => LogicalPlan::Sort {
input: Box::new(Self::bind_outer_refs_in_plan(input, outer_row, outer_schema, outer_table)),
exprs: exprs.iter()
.map(|e| Self::bind_outer_refs_in_expr(e, outer_row, outer_schema, outer_table))
.collect(),
asc: asc.clone(),
},
other => other.clone(),
}
}
fn bind_outer_refs_in_expr(
expr: &sql::LogicalExpr,
outer_row: &Tuple,
outer_schema: &Schema,
outer_table: &str,
) -> sql::LogicalExpr {
use sql::LogicalExpr;
match expr {
LogicalExpr::Column { table: Some(tbl), name } if tbl.eq_ignore_ascii_case(outer_table) => {
if let Some(idx) = outer_schema.columns.iter().position(|c| c.name.eq_ignore_ascii_case(name)) {
if let Some(v) = outer_row.values.get(idx) {
return LogicalExpr::Literal(v.clone());
}
}
expr.clone()
}
LogicalExpr::BinaryExpr { left, op, right } => LogicalExpr::BinaryExpr {
left: Box::new(Self::bind_outer_refs_in_expr(left, outer_row, outer_schema, outer_table)),
op: *op,
right: Box::new(Self::bind_outer_refs_in_expr(right, outer_row, outer_schema, outer_table)),
},
LogicalExpr::UnaryExpr { op, expr: inner } => LogicalExpr::UnaryExpr {
op: *op,
expr: Box::new(Self::bind_outer_refs_in_expr(inner, outer_row, outer_schema, outer_table)),
},
LogicalExpr::IsNull { expr: inner, is_null } => LogicalExpr::IsNull {
expr: Box::new(Self::bind_outer_refs_in_expr(inner, outer_row, outer_schema, outer_table)),
is_null: *is_null,
},
LogicalExpr::Between { expr: inner, low, high, negated } => LogicalExpr::Between {
expr: Box::new(Self::bind_outer_refs_in_expr(inner, outer_row, outer_schema, outer_table)),
low: Box::new(Self::bind_outer_refs_in_expr(low, outer_row, outer_schema, outer_table)),
high: Box::new(Self::bind_outer_refs_in_expr(high, outer_row, outer_schema, outer_table)),
negated: *negated,
},
LogicalExpr::ScalarSubquery { subquery } => LogicalExpr::ScalarSubquery {
subquery: Box::new(Self::bind_outer_refs_in_plan(subquery, outer_row, outer_schema, outer_table)),
},
other => other.clone(),
}
}
fn try_extract_pk_value(selection: Option<&sql::LogicalExpr>, schema: &Schema) -> Option<Value> {
let predicate = selection?;
let pk_col = schema.columns.iter().find(|c| c.primary_key)?;
if let sql::LogicalExpr::BinaryExpr { left, op: sql::BinaryOperator::Eq, right } = predicate {
match (left.as_ref(), right.as_ref()) {
(sql::LogicalExpr::Column { name, .. }, sql::LogicalExpr::Literal(val))
if name == &pk_col.name => Some(val.clone()),
(sql::LogicalExpr::Literal(val), sql::LogicalExpr::Column { name, .. })
if name == &pk_col.name => Some(val.clone()),
_ => None,
}
} else {
None
}
}
fn apply_rls_to_plan(&self, plan: sql::LogicalPlan) -> Result<sql::LogicalPlan> {
if self.tenant_manager.get_current_context().is_none() {
return Ok(plan);
}
self.apply_rls_to_plan_recursive(plan)
}
fn apply_rls_to_plan_recursive(&self, plan: sql::LogicalPlan) -> Result<sql::LogicalPlan> {
match plan {
sql::LogicalPlan::Scan { table_name, alias, schema, projection, as_of } => {
if self.tenant_manager.should_apply_rls(&table_name, "SELECT") {
if let Some((using_expr, _)) = self.tenant_manager.get_rls_conditions(&table_name, "SELECT") {
let tenant_context = self.tenant_manager.get_current_context();
let rls_evaluator = tenant::RLSExpressionEvaluator::new(
schema.clone(),
tenant_context
);
let filter_expr = rls_evaluator.parse(&using_expr)?;
let scan_plan = sql::LogicalPlan::Scan {
table_name,
alias: alias.clone(),
schema,
projection,
as_of,
};
return Ok(sql::LogicalPlan::Filter {
input: Box::new(scan_plan),
predicate: filter_expr,
});
}
}
Ok(sql::LogicalPlan::Scan { table_name, alias, schema, projection, as_of })
}
sql::LogicalPlan::Filter { input, predicate } => {
Ok(sql::LogicalPlan::Filter {
input: Box::new(self.apply_rls_to_plan_recursive(*input)?),
predicate,
})
}
sql::LogicalPlan::Project { input, exprs, aliases, distinct, distinct_on } => {
Ok(sql::LogicalPlan::Project {
input: Box::new(self.apply_rls_to_plan_recursive(*input)?),
exprs,
aliases,
distinct,
distinct_on,
})
}
sql::LogicalPlan::Aggregate { input, group_by, aggr_exprs, having } => {
Ok(sql::LogicalPlan::Aggregate {
input: Box::new(self.apply_rls_to_plan_recursive(*input)?),
group_by,
aggr_exprs,
having,
})
}
sql::LogicalPlan::Join { left, right, join_type, on, lateral } => {
Ok(sql::LogicalPlan::Join {
left: Box::new(self.apply_rls_to_plan_recursive(*left)?),
right: Box::new(self.apply_rls_to_plan_recursive(*right)?),
join_type,
on,
lateral,
})
}
sql::LogicalPlan::Sort { input, exprs, asc } => {
Ok(sql::LogicalPlan::Sort {
input: Box::new(self.apply_rls_to_plan_recursive(*input)?),
exprs,
asc,
})
}
sql::LogicalPlan::Limit { input, limit, offset, limit_param, offset_param } => {
Ok(sql::LogicalPlan::Limit {
input: Box::new(self.apply_rls_to_plan_recursive(*input)?),
limit,
offset,
limit_param,
offset_param,
})
}
sql::LogicalPlan::FilteredScan { table_name, alias, schema, projection, predicate, as_of } => {
if self.tenant_manager.should_apply_rls(&table_name, "SELECT") {
if let Some((using_expr, _)) = self.tenant_manager.get_rls_conditions(&table_name, "SELECT") {
let tenant_context = self.tenant_manager.get_current_context();
let rls_evaluator = tenant::RLSExpressionEvaluator::new(
schema.clone(),
tenant_context
);
let rls_predicate = rls_evaluator.parse(&using_expr)?;
let combined_predicate = if let Some(existing) = predicate {
Some(sql::LogicalExpr::BinaryExpr {
left: Box::new(existing),
op: sql::BinaryOperator::And,
right: Box::new(rls_predicate),
})
} else {
Some(rls_predicate)
};
return Ok(sql::LogicalPlan::FilteredScan {
table_name,
alias,
schema,
projection,
predicate: combined_predicate,
as_of,
});
}
}
Ok(sql::LogicalPlan::FilteredScan { table_name, alias, schema, projection, predicate, as_of })
}
other => Ok(other),
}
}
pub async fn start_auto_refresh(
&self,
config: Option<storage::AutoRefreshConfig>,
) -> Result<()> {
let worker_config = config.unwrap_or_else(|| {
storage::AutoRefreshConfig::default()
.with_enabled(true)
.with_interval_seconds(self.config.materialized_views.refresh_check_interval_secs)
.with_staleness_threshold(300) .with_max_cpu_percent(self.config.materialized_views.default_max_cpu_percent as f64)
.with_max_concurrent(self.config.materialized_views.max_concurrent_refreshes)
});
let mut worker = storage::AutoRefreshWorker::new(
worker_config,
std::sync::Arc::clone(&self.storage),
std::sync::Arc::clone(&self.mv_scheduler),
);
worker.start().await?;
*self.auto_refresh_worker.write() = Some(worker);
tracing::info!("Materialized view auto-refresh worker started");
Ok(())
}
pub async fn stop_auto_refresh(&self) -> Result<()> {
let worker = {
let mut worker_guard = self.auto_refresh_worker.write();
worker_guard.take()
};
if let Some(mut worker) = worker {
worker.stop().await?;
tracing::info!("Materialized view auto-refresh worker stopped");
}
Ok(())
}
pub fn is_auto_refresh_running(&self) -> bool {
self.auto_refresh_worker.read().as_ref()
.map(|w| w.is_running())
.unwrap_or(false)
}
pub fn mv_scheduler(&self) -> &std::sync::Arc<storage::MVScheduler> {
&self.mv_scheduler
}
pub fn check_mv_staleness_now(&self) -> Result<()> {
let worker_guard = self.auto_refresh_worker.read();
if let Some(ref worker) = *worker_guard {
worker.check_now()?;
Ok(())
} else {
Err(Error::query_execution("Auto-refresh worker is not running"))
}
}
}
pub struct Transaction<'a> {
tx: storage::Transaction,
db: &'a EmbeddedDatabase,
}
impl Transaction<'_> {
pub fn commit(self) -> Result<()> {
self.tx.commit()
}
pub fn rollback(self) -> Result<()> {
self.tx.rollback()
}
pub fn execute(&self, sql: &str) -> Result<u64> {
self.db.execute_in_transaction_no_fast_path(sql, &self.tx)
}
pub fn query(&self, sql: &str, _params: &[&dyn std::fmt::Display]) -> Result<Vec<Tuple>> {
let (statement, _) = self.db.parse_cached(sql)?;
let catalog = self.db.storage.catalog();
let planner = sql::Planner::with_catalog(&catalog)
.with_sql(sql.to_string());
let plan = planner.statement_to_plan(statement)?;
let mut executor = sql::Executor::with_storage(&self.db.storage)
.with_timeout(self.db.config.storage.query_timeout_ms)
.with_transaction(&self.tx);
executor.execute(&plan)
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
#[allow(
clippy::unwrap_used,
clippy::expect_used,
clippy::panic,
clippy::indexing_slicing,
)]
mod tests {
use super::*;
#[test]
fn test_embedded_database_creation() {
let db = EmbeddedDatabase::new_in_memory();
assert!(db.is_ok());
}
#[test]
fn test_savepoint_basic_via_execute_works_in_transaction() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE sp_basic (id INT, val TEXT)").unwrap();
db.execute("BEGIN").unwrap();
let result = db.execute("SAVEPOINT s1");
assert!(result.is_ok(),
"SAVEPOINT via execute() in BEGIN block should succeed, got: {:?}", result.err());
db.execute("ROLLBACK").unwrap();
}
#[test]
fn test_savepoint_outside_transaction_succeeds_in_implicit_txn() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let result = db.execute("SAVEPOINT s1");
assert!(result.is_ok(),
"SAVEPOINT in implicit transaction should succeed, got: {:?}", result.err());
}
#[test]
fn test_savepoint_via_execute_returning_path() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE sp_ret (id INT, val TEXT)").unwrap();
db.execute_returning("BEGIN").unwrap();
let result = db.execute_returning("SAVEPOINT s1");
if result.is_ok() {
let _ = db.execute_returning("INSERT INTO sp_ret VALUES (1, 'test')");
let release_result = db.execute_returning("RELEASE SAVEPOINT s1");
assert!(release_result.is_ok(), "RELEASE SAVEPOINT should work via returning path");
let _ = db.execute_returning("COMMIT");
} else {
let err = result.unwrap_err().to_string();
assert!(err.contains("not yet implemented") || err.contains("SAVEPOINT"),
"Unexpected error: {}", err);
let _ = db.execute_returning("ROLLBACK");
}
}
#[test]
fn test_savepoint_nonexistent_rollback_errors() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE sp_noexist (id INT)").unwrap();
db.execute("BEGIN").unwrap();
let result = db.execute("ROLLBACK TO SAVEPOINT nonexistent");
assert!(result.is_err(), "ROLLBACK TO nonexistent savepoint should fail");
db.execute("ROLLBACK").unwrap();
}
#[test]
fn test_savepoint_nonexistent_release_errors() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("BEGIN").unwrap();
let result = db.execute("RELEASE SAVEPOINT nonexistent");
assert!(result.is_err(), "RELEASE nonexistent savepoint should fail");
db.execute("ROLLBACK").unwrap();
}
#[test]
fn test_savepoint_nested_release_via_returning() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE sp_nested_rel (id INT, val TEXT)").unwrap();
db.execute("BEGIN").unwrap();
let sp1 = db.execute_returning("SAVEPOINT s1");
if sp1.is_err() {
db.execute("ROLLBACK").unwrap();
return;
}
db.execute("INSERT INTO sp_nested_rel VALUES (1, 'A')").unwrap();
db.execute_returning("SAVEPOINT s2").unwrap();
db.execute("INSERT INTO sp_nested_rel VALUES (2, 'B')").unwrap();
db.execute_returning("RELEASE SAVEPOINT s2").unwrap();
db.execute_returning("RELEASE SAVEPOINT s1").unwrap();
db.execute("COMMIT").unwrap();
let rows = db.query("SELECT * FROM sp_nested_rel", &[]).unwrap();
assert_eq!(rows.len(), 2, "Both A and B should be preserved after nested RELEASE + COMMIT");
}
#[test]
fn test_savepoint_rollback_to_undoes_inserts() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE sp_stub (id INT, val TEXT)").unwrap();
db.execute("BEGIN").unwrap();
let sp1 = db.execute_returning("SAVEPOINT s1");
if sp1.is_err() {
db.execute("ROLLBACK").unwrap();
return;
}
db.execute("INSERT INTO sp_stub VALUES (1, 'should_vanish')").unwrap();
let rb = db.execute_returning("ROLLBACK TO SAVEPOINT s1");
if rb.is_err() {
db.execute("ROLLBACK").unwrap();
return;
}
db.execute("COMMIT").unwrap();
let rows = db.query("SELECT * FROM sp_stub", &[]).unwrap();
assert_eq!(rows.len(), 0,
"ROLLBACK TO SAVEPOINT should undo INSERTs via transaction write set");
}
#[test]
fn test_savepoint_reuse_name_after_release_via_returning() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE sp_reuse (id INT)").unwrap();
db.execute("BEGIN").unwrap();
let sp1 = db.execute_returning("SAVEPOINT s1");
if sp1.is_err() {
db.execute("ROLLBACK").unwrap();
return;
}
db.execute("INSERT INTO sp_reuse VALUES (1)").unwrap();
db.execute_returning("RELEASE SAVEPOINT s1").unwrap();
db.execute_returning("SAVEPOINT s1").unwrap();
db.execute("INSERT INTO sp_reuse VALUES (2)").unwrap();
db.execute_returning("RELEASE SAVEPOINT s1").unwrap();
db.execute("COMMIT").unwrap();
let rows = db.query("SELECT * FROM sp_reuse", &[]).unwrap();
assert_eq!(rows.len(), 2, "Both inserts should persist after reuse of savepoint name");
}
#[test]
fn test_transaction_read_committed() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE iso_rc (id INT, val TEXT)").unwrap();
let s1 = db.create_session("user1", crate::session::IsolationLevel::ReadCommitted).unwrap();
let s2 = db.create_session("user2", crate::session::IsolationLevel::ReadCommitted).unwrap();
db.begin_transaction_for_session(s1).unwrap();
db.execute_in_session(s1, "INSERT INTO iso_rc VALUES (1, 'uncommitted')").unwrap();
let rows = db.query_in_session(s2, "SELECT * FROM iso_rc", &[]).unwrap();
assert_eq!(rows.len(), 0,
"Uncommitted writes from S1 should be invisible to S2 (read committed)");
db.commit_transaction_for_session(s1).unwrap();
let rows = db.query_in_session(s2, "SELECT * FROM iso_rc WHERE 1=1", &[]).unwrap();
assert_eq!(rows.len(), 1, "After S1 commits, S2 should see the row");
db.destroy_session(s1).unwrap();
db.destroy_session(s2).unwrap();
}
#[test]
fn test_transaction_dirty_read_prevented() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE iso_dirty (id INT, val TEXT)").unwrap();
db.execute("INSERT INTO iso_dirty VALUES (1, 'visible')").unwrap();
let s1 = db.create_session("writer", crate::session::IsolationLevel::ReadCommitted).unwrap();
let s2 = db.create_session("reader", crate::session::IsolationLevel::ReadCommitted).unwrap();
db.begin_transaction_for_session(s1).unwrap();
db.execute_in_session(s1, "INSERT INTO iso_dirty VALUES (2, 'dirty')").unwrap();
let rows = db.query_in_session(s2, "SELECT * FROM iso_dirty", &[]).unwrap();
assert_eq!(rows.len(), 1, "S2 should only see committed data, not dirty writes");
assert_eq!(rows[0].get(1), Some(&Value::String("visible".to_string())));
db.rollback_transaction_for_session(s1).unwrap();
db.destroy_session(s1).unwrap();
db.destroy_session(s2).unwrap();
}
#[test]
fn test_transaction_rollback_visibility() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE iso_rb_vis (id INT, val TEXT)").unwrap();
let s1 = db.create_session("writer", crate::session::IsolationLevel::ReadCommitted).unwrap();
let s2 = db.create_session("reader", crate::session::IsolationLevel::ReadCommitted).unwrap();
db.begin_transaction_for_session(s1).unwrap();
db.execute_in_session(s1, "INSERT INTO iso_rb_vis VALUES (1, 'rolled_back')").unwrap();
db.rollback_transaction_for_session(s1).unwrap();
let rows = db.query_in_session(s2, "SELECT * FROM iso_rb_vis", &[]).unwrap();
assert_eq!(rows.len(), 0, "Rolled-back data should never be visible");
let rows = db.query("SELECT * FROM iso_rb_vis", &[]).unwrap();
assert_eq!(rows.len(), 0, "Rolled-back data should be invisible via default query path too");
db.destroy_session(s1).unwrap();
db.destroy_session(s2).unwrap();
}
#[test]
fn test_concurrent_inserts_different_rows() {
use std::sync::Arc;
let db = Arc::new(EmbeddedDatabase::new_in_memory().unwrap());
db.execute("CREATE TABLE conc_ins (id INT, thread_id INT)").unwrap();
let num_threads = 4;
let rows_per_thread = 25;
let mut handles = Vec::new();
for t in 0..num_threads {
let db_clone = Arc::clone(&db);
let handle = std::thread::spawn(move || {
for i in 0..rows_per_thread {
let id = t * rows_per_thread + i;
db_clone.execute(
&format!("INSERT INTO conc_ins VALUES ({}, {})", id, t)
).unwrap();
}
});
handles.push(handle);
}
for h in handles {
h.join().expect("Thread panicked");
}
let rows = db.query("SELECT * FROM conc_ins", &[]).unwrap();
assert_eq!(rows.len(), (num_threads * rows_per_thread) as usize,
"All inserts from all threads should be visible");
}
#[test]
fn test_concurrent_reads_during_write() {
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
let db = Arc::new(EmbeddedDatabase::new_in_memory().unwrap());
db.execute("CREATE TABLE conc_rw (id INT, val TEXT)").unwrap();
let done = Arc::new(AtomicBool::new(false));
let db_w = Arc::clone(&db);
let done_w = Arc::clone(&done);
let writer = std::thread::spawn(move || {
for i in 0..50 {
db_w.execute(&format!("INSERT INTO conc_rw VALUES ({}, 'row_{}')", i, i)).unwrap();
}
done_w.store(true, Ordering::Release);
});
let mut readers = Vec::new();
for t in 0..3_usize {
let db_r = Arc::clone(&db);
let done_r = Arc::clone(&done);
let reader = std::thread::spawn(move || {
let mut query_count = 0_usize;
while !done_r.load(Ordering::Acquire) {
let sql = format!(
"SELECT * FROM conc_rw WHERE 1=1 /* t{}q{} */", t, query_count
);
let rows = db_r.query(&sql, &[]).unwrap();
assert!(rows.len() <= 50, "Should never exceed 50 rows");
query_count += 1;
std::thread::yield_now();
}
assert!(query_count > 0, "Reader should have executed at least one query");
});
readers.push(reader);
}
writer.join().expect("Writer panicked");
for r in readers {
r.join().expect("Reader panicked");
}
let final_rows = db.query("SELECT * FROM conc_rw WHERE 1=1 /* final */", &[]).unwrap();
assert_eq!(final_rows.len(), 50, "All 50 rows should be visible after writer completes");
}
#[test]
fn test_concurrent_counter_increment() {
use std::sync::Arc;
let db = Arc::new(EmbeddedDatabase::new_in_memory().unwrap());
db.execute("CREATE TABLE conc_counter (id INT, cnt INT)").unwrap();
db.execute("INSERT INTO conc_counter VALUES (1, 0)").unwrap();
let num_threads = 4;
let increments_per_thread = 10;
let mut handles = Vec::new();
for _ in 0..num_threads {
let db_clone = Arc::clone(&db);
let handle = std::thread::spawn(move || {
for _ in 0..increments_per_thread {
let rows = db_clone.query("SELECT cnt FROM conc_counter WHERE id = 1", &[]).unwrap();
if let Some(row) = rows.first() {
if let Some(Value::Int4(current)) = row.get(0) {
let new_val = current + 1;
db_clone.execute(
&format!("UPDATE conc_counter SET cnt = {} WHERE id = 1", new_val)
).unwrap();
}
}
}
});
handles.push(handle);
}
for h in handles {
h.join().expect("Thread panicked");
}
let rows = db.query("SELECT cnt FROM conc_counter WHERE id = 1", &[]).unwrap();
assert_eq!(rows.len(), 1, "Counter row should still exist");
if let Some(Value::Int4(final_val)) = rows[0].get(0) {
let max_expected = (num_threads * increments_per_thread) as i32;
assert!(*final_val > 0, "Counter should have been incremented at least once");
assert!(*final_val <= max_expected,
"Counter {} should not exceed {}", final_val, max_expected);
if *final_val < max_expected {
}
} else {
panic!("Counter value should be Int4");
}
}
#[test]
fn test_concurrent_transactions_different_tables() {
use std::sync::Arc;
let db = Arc::new(EmbeddedDatabase::new_in_memory().unwrap());
let num_threads = 4;
let mut handles = Vec::new();
for t in 0..num_threads {
db.execute(&format!("CREATE TABLE conc_tbl_{} (id INT, val TEXT)", t)).unwrap();
}
for t in 0..num_threads {
let db_clone = Arc::clone(&db);
let handle = std::thread::spawn(move || {
let session = db_clone.create_session(
&format!("user{}", t),
crate::session::IsolationLevel::ReadCommitted,
).unwrap();
db_clone.begin_transaction_for_session(session).unwrap();
for i in 0..10 {
db_clone.execute_in_session(session,
&format!("INSERT INTO conc_tbl_{} VALUES ({}, 'val_{}')", t, i, i)
).unwrap();
}
db_clone.commit_transaction_for_session(session).unwrap();
db_clone.destroy_session(session).unwrap();
});
handles.push(handle);
}
for h in handles {
h.join().expect("Thread panicked");
}
for t in 0..num_threads {
let rows = db.query(&format!("SELECT * FROM conc_tbl_{}", t), &[]).unwrap();
assert_eq!(rows.len(), 10,
"Table conc_tbl_{} should have 10 rows, got {}", t, rows.len());
}
}
#[test]
fn test_transaction_double_commit() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE dbl_commit (id INT)").unwrap();
db.execute("BEGIN").unwrap();
db.execute("INSERT INTO dbl_commit VALUES (1)").unwrap();
db.execute("COMMIT").unwrap();
let result = db.execute("COMMIT");
assert!(result.is_err(), "Second COMMIT without active transaction should fail");
}
#[test]
fn test_transaction_double_rollback() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("BEGIN").unwrap();
db.execute("ROLLBACK").unwrap();
let result = db.execute("ROLLBACK");
assert!(result.is_err(), "Second ROLLBACK without active transaction should fail");
}
#[test]
fn test_autocommit_mode() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE autocommit (id INT, val TEXT)").unwrap();
db.execute("INSERT INTO autocommit VALUES (1, 'a')").unwrap();
db.execute("INSERT INTO autocommit VALUES (2, 'b')").unwrap();
db.execute("INSERT INTO autocommit VALUES (3, 'c')").unwrap();
let rows = db.query("SELECT * FROM autocommit", &[]).unwrap();
assert_eq!(rows.len(), 3, "All auto-committed inserts should be visible");
db.execute("UPDATE autocommit SET val = 'updated' WHERE id = 2").unwrap();
let rows = db.query("SELECT val FROM autocommit WHERE id = 2", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get(0), Some(&Value::String("updated".to_string())));
db.execute("DELETE FROM autocommit WHERE id = 3").unwrap();
let rows = db.query("SELECT * FROM autocommit", &[]).unwrap();
assert_eq!(rows.len(), 2, "Delete should be auto-committed");
}
#[test]
fn test_ddl_in_transaction_commit() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("BEGIN").unwrap();
db.execute("CREATE TABLE ddl_txn (id INT, val TEXT)").unwrap();
db.execute("INSERT INTO ddl_txn VALUES (1, 'hello')").unwrap();
db.execute("COMMIT").unwrap();
let rows = db.query("SELECT * FROM ddl_txn", &[]).unwrap();
assert_eq!(rows.len(), 1, "DDL + DML in committed transaction should persist");
}
#[test]
fn test_ddl_in_transaction_rollback() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("BEGIN").unwrap();
db.execute("CREATE TABLE ddl_rb (id INT, val TEXT)").unwrap();
db.execute("INSERT INTO ddl_rb VALUES (1, 'hello')").unwrap();
db.execute("ROLLBACK").unwrap();
let query_result = db.query("SELECT * FROM ddl_rb", &[]);
if let Ok(rows) = query_result {
assert!(rows.is_empty() || rows.len() == 1,
"DDL rollback behavior: table exists with {} rows", rows.len());
}
}
#[test]
fn test_empty_transaction_commit() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("BEGIN").unwrap();
assert!(db.in_transaction());
db.execute("COMMIT").unwrap();
assert!(!db.in_transaction());
}
#[test]
fn test_empty_transaction_rollback() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("BEGIN").unwrap();
assert!(db.in_transaction());
db.execute("ROLLBACK").unwrap();
assert!(!db.in_transaction());
}
#[test]
fn test_transaction_after_error() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE txn_err (id INT, val TEXT)").unwrap();
db.execute("BEGIN").unwrap();
let result = db.execute("INSERT INTO nonexistent_table VALUES (1)");
assert!(result.is_err(), "Insert into nonexistent table should fail");
assert!(db.in_transaction(), "Transaction should still be active after statement error");
db.execute("INSERT INTO txn_err VALUES (1, 'after_error')").unwrap();
db.execute("COMMIT").unwrap();
let rows = db.query("SELECT * FROM txn_err", &[]).unwrap();
assert_eq!(rows.len(), 1, "Valid insert after error should be committed");
assert_eq!(rows[0].get(1), Some(&Value::String("after_error".to_string())));
}
#[test]
fn test_begin_while_in_transaction_errors() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("BEGIN").unwrap();
let result = db.execute("BEGIN");
assert!(result.is_err(), "Nested BEGIN should fail");
assert!(result.unwrap_err().to_string().contains("already active"),
"Error should mention transaction already active");
db.execute("ROLLBACK").unwrap();
}
#[test]
fn test_transaction_commit_then_new_transaction() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE txn_seq (id INT)").unwrap();
db.execute("BEGIN").unwrap();
db.execute("INSERT INTO txn_seq VALUES (1)").unwrap();
db.execute("COMMIT").unwrap();
db.execute("BEGIN").unwrap();
db.execute("INSERT INTO txn_seq VALUES (2)").unwrap();
db.execute("COMMIT").unwrap();
db.execute("BEGIN").unwrap();
db.execute("INSERT INTO txn_seq VALUES (3)").unwrap();
db.execute("ROLLBACK").unwrap();
db.execute("BEGIN").unwrap();
db.execute("INSERT INTO txn_seq VALUES (4)").unwrap();
db.execute("COMMIT").unwrap();
let rows = db.query("SELECT * FROM txn_seq", &[]).unwrap();
assert_eq!(rows.len(), 3, "Rows from txn 1, 2, 4 should exist (txn 3 rolled back)");
}
#[test]
fn test_insert_rollback_pk_reuse() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE pk_reuse (id INT PRIMARY KEY, val TEXT)").unwrap();
db.execute("BEGIN").unwrap();
db.execute("INSERT INTO pk_reuse VALUES (1, 'rolled_back')").unwrap();
db.execute("ROLLBACK").unwrap();
db.execute("INSERT INTO pk_reuse VALUES (1, 'final')").unwrap();
let rows = db.query("SELECT val FROM pk_reuse WHERE id = 1", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get(0), Some(&Value::String("final".to_string())));
}
#[test]
fn test_update_rollback_preserves_original() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE upd_rb (id INT, val TEXT)").unwrap();
db.execute("INSERT INTO upd_rb VALUES (1, 'original')").unwrap();
db.execute("BEGIN").unwrap();
db.execute("UPDATE upd_rb SET val = 'changed' WHERE id = 1").unwrap();
db.execute("ROLLBACK").unwrap();
let rows = db.query("SELECT val FROM upd_rb WHERE id = 1", &[]).unwrap();
assert_eq!(rows.len(), 1);
let val = rows[0].get(0);
assert_eq!(val, Some(&Value::String("original".to_string())),
"ROLLBACK should undo the UPDATE");
if true {
}
}
#[test]
fn test_delete_rollback_preserves_row() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE del_rb (id INT, val TEXT)").unwrap();
db.execute("INSERT INTO del_rb VALUES (1, 'keep_me')").unwrap();
db.execute("BEGIN").unwrap();
db.execute("DELETE FROM del_rb WHERE id = 1").unwrap();
db.execute("ROLLBACK").unwrap();
let rows = db.query("SELECT * FROM del_rb", &[]).unwrap();
assert_eq!(rows.len(), 1, "ROLLBACK should undo the DELETE");
assert_eq!(rows[0].get(1), Some(&Value::String("keep_me".to_string())));
}
#[test]
fn test_insert_commit_data_integrity() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE integrity (id INT, name TEXT, score FLOAT, active BOOLEAN)").unwrap();
db.execute("BEGIN").unwrap();
db.execute("INSERT INTO integrity VALUES (42, 'test_name', 3.14, true)").unwrap();
db.execute("COMMIT").unwrap();
let rows = db.query("SELECT * FROM integrity WHERE id = 42", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get(0), Some(&Value::Int4(42)));
assert_eq!(rows[0].get(1), Some(&Value::String("test_name".to_string())));
if let Some(Value::Float8(f)) = rows[0].get(2) {
assert!((f - 3.14).abs() < 0.001, "Float should be ~3.14, got {}", f);
} else if let Some(Value::Float4(f)) = rows[0].get(2) {
assert!((f - 3.14_f32).abs() < 0.01, "Float should be ~3.14, got {}", f);
} else {
panic!("Score should be a float type, got {:?}", rows[0].get(2));
}
assert_eq!(rows[0].get(3), Some(&Value::Boolean(true)));
}
#[test]
fn test_multiple_inserts_rollback_clears_all() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE multi_rb (id INT, val TEXT)").unwrap();
db.execute("BEGIN").unwrap();
for i in 1..=10 {
db.execute(&format!("INSERT INTO multi_rb VALUES ({}, 'row_{}')", i, i)).unwrap();
}
db.execute("ROLLBACK").unwrap();
let rows = db.query("SELECT * FROM multi_rb", &[]).unwrap();
assert_eq!(rows.len(), 0, "All 10 inserts should be rolled back");
}
#[test]
fn test_transaction_with_multiple_tables() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE multi_a (id INT, val TEXT)").unwrap();
db.execute("CREATE TABLE multi_b (id INT, ref_id INT)").unwrap();
db.execute("BEGIN").unwrap();
db.execute("INSERT INTO multi_a VALUES (1, 'parent')").unwrap();
db.execute("INSERT INTO multi_b VALUES (100, 1)").unwrap();
db.execute("INSERT INTO multi_b VALUES (101, 1)").unwrap();
db.execute("COMMIT").unwrap();
let rows_a = db.query("SELECT * FROM multi_a", &[]).unwrap();
let rows_b = db.query("SELECT * FROM multi_b", &[]).unwrap();
assert_eq!(rows_a.len(), 1, "Parent table should have 1 row");
assert_eq!(rows_b.len(), 2, "Child table should have 2 rows");
}
#[test]
fn test_transaction_with_multiple_tables_rollback() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE multi_rb_a (id INT, val TEXT)").unwrap();
db.execute("CREATE TABLE multi_rb_b (id INT, ref_id INT)").unwrap();
db.execute("BEGIN").unwrap();
db.execute("INSERT INTO multi_rb_a VALUES (1, 'parent')").unwrap();
db.execute("INSERT INTO multi_rb_b VALUES (100, 1)").unwrap();
db.execute("INSERT INTO multi_rb_b VALUES (101, 1)").unwrap();
db.execute("ROLLBACK").unwrap();
let rows_a = db.query("SELECT * FROM multi_rb_a", &[]).unwrap();
let rows_b = db.query("SELECT * FROM multi_rb_b", &[]).unwrap();
assert_eq!(rows_a.len(), 0, "Parent table should be empty after rollback");
assert_eq!(rows_b.len(), 0, "Child table should be empty after rollback");
}
#[test]
fn test_transaction_handle_commit() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE txn_handle (id INT, val TEXT)").unwrap();
let tx = db.begin_transaction().unwrap();
tx.execute("INSERT INTO txn_handle VALUES (1, 'via_handle')").unwrap();
tx.execute("INSERT INTO txn_handle VALUES (2, 'via_handle')").unwrap();
tx.commit().unwrap();
let rows = db.query("SELECT * FROM txn_handle", &[]).unwrap();
assert_eq!(rows.len(), 2, "Both inserts via Transaction handle should be committed");
}
#[test]
fn test_transaction_handle_rollback() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE txn_h_rb (id INT, val TEXT)").unwrap();
db.execute("INSERT INTO txn_h_rb VALUES (0, 'pre_existing')").unwrap();
let tx = db.begin_transaction().unwrap();
tx.execute("INSERT INTO txn_h_rb VALUES (1, 'will_rollback')").unwrap();
tx.rollback().unwrap();
let rows = db.query("SELECT * FROM txn_h_rb", &[]).unwrap();
assert_eq!(rows.len(), 1, "Only pre-existing row should remain after rollback");
assert_eq!(rows[0].get(1), Some(&Value::String("pre_existing".to_string())));
}
#[test]
fn test_transaction_handle_query() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE txn_h_q (id INT, val TEXT)").unwrap();
db.execute("INSERT INTO txn_h_q VALUES (1, 'committed')").unwrap();
let tx = db.begin_transaction().unwrap();
tx.execute("INSERT INTO txn_h_q VALUES (2, 'in_txn')").unwrap();
let rows = tx.query("SELECT * FROM txn_h_q", &[]).unwrap();
assert!(rows.len() >= 1, "Should see at least the committed row");
tx.commit().unwrap();
let rows = db.query("SELECT * FROM txn_h_q", &[]).unwrap();
assert_eq!(rows.len(), 2, "Both rows should be visible after commit");
}
#[test]
fn test_session_sequential_transactions() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE sess_seq (id INT, val TEXT)").unwrap();
let s1 = db.create_session("user1", crate::session::IsolationLevel::ReadCommitted).unwrap();
db.begin_transaction_for_session(s1).unwrap();
db.execute_in_session(s1, "INSERT INTO sess_seq VALUES (1, 'first')").unwrap();
db.commit_transaction_for_session(s1).unwrap();
db.begin_transaction_for_session(s1).unwrap();
db.execute_in_session(s1, "INSERT INTO sess_seq VALUES (2, 'second')").unwrap();
db.commit_transaction_for_session(s1).unwrap();
let rows = db.query("SELECT * FROM sess_seq", &[]).unwrap();
assert_eq!(rows.len(), 2, "Both sequential transactions should have committed");
db.destroy_session(s1).unwrap();
}
#[test]
fn test_session_rollback_then_new_transaction() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE sess_rb_new (id INT, val TEXT)").unwrap();
let s1 = db.create_session("user1", crate::session::IsolationLevel::ReadCommitted).unwrap();
db.begin_transaction_for_session(s1).unwrap();
db.execute_in_session(s1, "INSERT INTO sess_rb_new VALUES (1, 'rolled_back')").unwrap();
db.rollback_transaction_for_session(s1).unwrap();
db.begin_transaction_for_session(s1).unwrap();
db.execute_in_session(s1, "INSERT INTO sess_rb_new VALUES (2, 'committed')").unwrap();
db.commit_transaction_for_session(s1).unwrap();
let rows = db.query("SELECT * FROM sess_rb_new", &[]).unwrap();
assert_eq!(rows.len(), 1, "Only the committed transaction's data should exist");
assert_eq!(rows[0].get(1), Some(&Value::String("committed".to_string())));
db.destroy_session(s1).unwrap();
}
#[test]
fn test_session_double_begin_errors() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let s1 = db.create_session("user1", crate::session::IsolationLevel::ReadCommitted).unwrap();
db.begin_transaction_for_session(s1).unwrap();
let result = db.begin_transaction_for_session(s1);
assert!(result.is_err(), "Double BEGIN on same session should fail");
db.rollback_transaction_for_session(s1).unwrap();
db.destroy_session(s1).unwrap();
}
#[test]
fn test_session_commit_without_transaction_errors() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let s1 = db.create_session("user1", crate::session::IsolationLevel::ReadCommitted).unwrap();
let result = db.commit_transaction_for_session(s1);
assert!(result.is_err(), "COMMIT without active transaction should fail");
db.destroy_session(s1).unwrap();
}
#[test]
fn test_session_rollback_without_transaction_errors() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let s1 = db.create_session("user1", crate::session::IsolationLevel::ReadCommitted).unwrap();
let result = db.rollback_transaction_for_session(s1);
assert!(result.is_err(), "ROLLBACK without active transaction should fail");
db.destroy_session(s1).unwrap();
}
#[test]
fn test_insert_visible_in_same_transaction() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE t_ryow (id INT PRIMARY KEY, v TEXT)").unwrap();
db.begin().unwrap();
db.execute("INSERT INTO t_ryow VALUES (1, 'hello')").unwrap();
let rows = db.query("SELECT * FROM t_ryow", &[]).unwrap();
assert_eq!(rows.len(), 1, "INSERT must be visible to SELECT within the same transaction");
db.commit().unwrap();
}
#[test]
fn test_update_visible_in_same_transaction() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE t_ryow2 (id INT PRIMARY KEY, v TEXT)").unwrap();
db.execute("INSERT INTO t_ryow2 VALUES (1, 'before')").unwrap();
db.begin().unwrap();
db.execute("UPDATE t_ryow2 SET v = 'after' WHERE id = 1").unwrap();
let rows = db.query("SELECT * FROM t_ryow2 WHERE id = 1", &[]).unwrap();
assert_eq!(rows.len(), 1);
let val = &rows[0].values[1];
assert_eq!(val, &Value::String("after".to_string()),
"UPDATE must be visible to SELECT within the same transaction");
db.commit().unwrap();
}
#[test]
fn test_delete_visible_in_same_transaction() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE t_ryow3 (id INT PRIMARY KEY, v TEXT)").unwrap();
db.execute("INSERT INTO t_ryow3 VALUES (1, 'gone')").unwrap();
db.begin().unwrap();
db.execute("DELETE FROM t_ryow3 WHERE id = 1").unwrap();
let rows = db.query("SELECT * FROM t_ryow3", &[]).unwrap();
assert_eq!(rows.len(), 0,
"DELETE must be reflected in SELECT within the same transaction");
db.commit().unwrap();
}
#[test]
fn test_multiple_inserts_visible_in_same_transaction() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE t_ryow4 (id INT PRIMARY KEY, v TEXT)").unwrap();
db.begin().unwrap();
db.execute("INSERT INTO t_ryow4 VALUES (1, 'a')").unwrap();
db.execute("INSERT INTO t_ryow4 VALUES (2, 'b')").unwrap();
db.execute("INSERT INTO t_ryow4 VALUES (3, 'c')").unwrap();
let rows = db.query("SELECT * FROM t_ryow4", &[]).unwrap();
assert_eq!(rows.len(), 3,
"All INSERTs must be visible to SELECT within the same transaction");
db.commit().unwrap();
}
#[test]
fn test_rollback_hides_inserts() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE t_ryow5 (id INT PRIMARY KEY, v TEXT)").unwrap();
db.begin().unwrap();
db.execute("INSERT INTO t_ryow5 VALUES (1, 'temp')").unwrap();
let rows = db.query("SELECT * FROM t_ryow5", &[]).unwrap();
assert_eq!(rows.len(), 1);
db.rollback().unwrap();
let rows = db.query("SELECT * FROM t_ryow5", &[]).unwrap();
assert_eq!(rows.len(), 0,
"After ROLLBACK, inserted data must not be visible");
}
fn setup_window_test_db() -> EmbeddedDatabase {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute(
"CREATE TABLE employees (id INT PRIMARY KEY, name TEXT, dept TEXT, salary INT, age INT)",
)
.unwrap();
db.execute("INSERT INTO employees (id, name, dept, salary, age) VALUES (1, 'Alice', 'Engineering', 120000, 35)").unwrap();
db.execute("INSERT INTO employees (id, name, dept, salary, age) VALUES (2, 'Bob', 'Engineering', 110000, 28)").unwrap();
db.execute("INSERT INTO employees (id, name, dept, salary, age) VALUES (3, 'Charlie', 'Engineering', 110000, 32)").unwrap();
db.execute("INSERT INTO employees (id, name, dept, salary, age) VALUES (4, 'Dave', 'Sales', 90000, 40)").unwrap();
db.execute("INSERT INTO employees (id, name, dept, salary, age) VALUES (5, 'Eve', 'Sales', 95000, 25)").unwrap();
db.execute("INSERT INTO employees (id, name, dept, salary, age) VALUES (6, 'Frank', 'Marketing', 80000, 45)").unwrap();
db
}
#[test]
fn test_window_row_number_basic() {
let db = setup_window_test_db();
let results = db
.query(
"SELECT name, salary, ROW_NUMBER() OVER (ORDER BY salary DESC) FROM employees",
&[],
)
.unwrap();
assert_eq!(results.len(), 6, "Should return all 6 employees");
let row_nums: std::collections::HashSet<i64> = results
.iter()
.map(|r| match r.get(2).unwrap() {
Value::Int8(v) => *v,
_ => panic!("expected Int8"),
})
.collect();
assert_eq!(row_nums.len(), 6);
for n in 1..=6 {
assert!(row_nums.contains(&n), "Should contain row_number {}", n);
}
for row in &results {
let sal = match row.get(1).unwrap() {
Value::Int4(v) => *v as i64,
Value::Int8(v) => *v,
_ => panic!("unexpected type"),
};
let rn = match row.get(2).unwrap() {
Value::Int8(v) => *v,
_ => panic!("expected Int8"),
};
if sal == 120000 {
assert_eq!(rn, 1, "Highest salary should have row_number 1");
}
}
}
#[test]
fn test_window_row_number_partitioned() {
let db = setup_window_test_db();
let results = db
.query(
"SELECT name, dept, ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC) FROM employees",
&[],
)
.unwrap();
assert_eq!(results.len(), 6);
let mut dept_row_nums: std::collections::HashMap<String, Vec<i64>> =
std::collections::HashMap::new();
for row in &results {
if let (Some(Value::String(dept)), Some(Value::Int8(rn))) = (row.get(1), row.get(2)) {
dept_row_nums
.entry(dept.clone())
.or_default()
.push(*rn);
}
}
if let Some(eng) = dept_row_nums.get("Engineering") {
let mut sorted = eng.clone();
sorted.sort();
assert_eq!(sorted, vec![1, 2, 3]);
}
if let Some(mkt) = dept_row_nums.get("Marketing") {
assert_eq!(mkt, &vec![1]);
}
}
#[test]
fn test_window_rank_basic() {
let db = setup_window_test_db();
let results = db
.query(
"SELECT name, salary, RANK() OVER (ORDER BY salary DESC) FROM employees",
&[],
)
.unwrap();
assert_eq!(results.len(), 6);
let ranks: Vec<(i64, i64)> = results
.iter()
.map(|r| {
let sal = match r.get(1).unwrap() {
Value::Int4(v) => *v as i64,
Value::Int8(v) => *v,
_ => panic!("unexpected salary type"),
};
let rank = match r.get(2).unwrap() {
Value::Int8(v) => *v,
_ => panic!("unexpected rank type"),
};
(sal, rank)
})
.collect();
let rank_120k: Vec<i64> = ranks.iter().filter(|(s, _)| *s == 120000).map(|(_, r)| *r).collect();
assert!(rank_120k.iter().all(|r| *r == 1), "120000 should have rank 1");
let all_ranks: std::collections::HashSet<i64> = ranks.iter().map(|(_, r)| *r).collect();
assert_eq!(all_ranks.len(), 5, "RANK correctly detects ties on ORDER BY columns");
}
#[test]
fn test_window_rank_with_ties() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE scores (id INT PRIMARY KEY, score INT)").unwrap();
db.execute("INSERT INTO scores (id, score) VALUES (1, 100)").unwrap();
db.execute("INSERT INTO scores (id, score) VALUES (2, 90)").unwrap();
db.execute("INSERT INTO scores (id, score) VALUES (3, 90)").unwrap();
db.execute("INSERT INTO scores (id, score) VALUES (4, 80)").unwrap();
let results = db
.query(
"SELECT id, score, RANK() OVER (ORDER BY score DESC) FROM scores",
&[],
)
.unwrap();
assert_eq!(results.len(), 4);
let ranks: Vec<i64> = results
.iter()
.map(|r| match r.get(2).unwrap() {
Value::Int8(v) => *v,
_ => panic!("expected Int8"),
})
.collect();
let mut sorted_ranks = ranks.clone();
sorted_ranks.sort();
assert_eq!(
sorted_ranks,
vec![1, 2, 2, 4],
"RANK correctly detects ties per SQL standard"
);
}
#[test]
fn test_window_dense_rank_basic() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE scores (id INT PRIMARY KEY, score INT)").unwrap();
db.execute("INSERT INTO scores (id, score) VALUES (1, 100)").unwrap();
db.execute("INSERT INTO scores (id, score) VALUES (2, 90)").unwrap();
db.execute("INSERT INTO scores (id, score) VALUES (3, 90)").unwrap();
db.execute("INSERT INTO scores (id, score) VALUES (4, 80)").unwrap();
let results = db
.query(
"SELECT id, score, DENSE_RANK() OVER (ORDER BY score DESC) FROM scores",
&[],
)
.unwrap();
assert_eq!(results.len(), 4);
let ranks: Vec<i64> = results
.iter()
.map(|r| match r.get(2).unwrap() {
Value::Int8(v) => *v,
_ => panic!("expected Int8"),
})
.collect();
let mut sorted_ranks = ranks.clone();
sorted_ranks.sort();
assert_eq!(
sorted_ranks,
vec![1, 2, 2, 3],
"DENSE_RANK correctly detects ties per SQL standard (no gaps)"
);
}
#[test]
fn test_window_ntile_basic() {
let db = setup_window_test_db();
let results = db
.query(
"SELECT name, NTILE(3) OVER (ORDER BY salary) FROM employees",
&[],
)
.unwrap();
assert_eq!(results.len(), 6);
let buckets: Vec<i64> = results
.iter()
.map(|r| match r.get(1).unwrap() {
Value::Int8(v) => *v,
_ => panic!("expected Int8"),
})
.collect();
assert!(
buckets.iter().all(|b| *b >= 1 && *b <= 3),
"All buckets should be 1..=3"
);
for bucket in 1..=3 {
let count = buckets.iter().filter(|&&b| b == bucket).count();
assert_eq!(count, 2, "Bucket {} should have 2 rows", bucket);
}
}
#[test]
fn test_window_ntile_uneven() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
for i in 1..=7 {
db.execute(&format!(
"INSERT INTO nums (id, val) VALUES ({}, {})",
i,
i * 10
))
.unwrap();
}
let results = db
.query("SELECT val, NTILE(3) OVER (ORDER BY val) FROM nums", &[])
.unwrap();
assert_eq!(results.len(), 7);
let buckets: Vec<i64> = results
.iter()
.map(|r| match r.get(1).unwrap() {
Value::Int8(v) => *v,
_ => panic!("expected Int8"),
})
.collect();
assert!(
buckets.iter().all(|b| *b >= 1 && *b <= 3),
"All buckets should be 1..=3"
);
}
#[test]
fn test_window_lag_basic() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (1, 10)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (2, 20)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (3, 30)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (4, 40)").unwrap();
let results = db
.query(
"SELECT val, LAG(val, 1) OVER (ORDER BY val) FROM nums",
&[],
)
.unwrap();
assert_eq!(results.len(), 4);
assert_eq!(results[0].get(1).unwrap(), &Value::Null);
let lag_val = results[1].get(1).unwrap();
assert!(
matches!(lag_val, Value::Int4(10) | Value::Int8(10)),
"LAG of second row should be 10, got {:?}",
lag_val
);
}
#[test]
fn test_window_lag_offset_2() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
for i in 1..=5 {
db.execute(&format!(
"INSERT INTO nums (id, val) VALUES ({}, {})",
i,
i * 10
))
.unwrap();
}
let results = db
.query(
"SELECT val, LAG(val, 2) OVER (ORDER BY val) FROM nums",
&[],
)
.unwrap();
assert_eq!(results.len(), 5);
assert_eq!(results[0].get(1).unwrap(), &Value::Null);
assert_eq!(results[1].get(1).unwrap(), &Value::Null);
let lag_val = results[2].get(1).unwrap();
assert!(
matches!(lag_val, Value::Int4(10) | Value::Int8(10)),
"LAG(val,2) of third row should be 10, got {:?}",
lag_val
);
}
#[test]
fn test_window_lag_default_offset() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (1, 100)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (2, 200)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (3, 300)").unwrap();
let results = db
.query(
"SELECT val, LAG(val) OVER (ORDER BY val) FROM nums",
&[],
)
.unwrap();
assert_eq!(results.len(), 3);
assert_eq!(results[0].get(1).unwrap(), &Value::Null);
let lag_val = results[1].get(1).unwrap();
assert!(
matches!(lag_val, Value::Int4(100) | Value::Int8(100)),
"LAG with default offset should be 100 for second row, got {:?}",
lag_val
);
}
#[test]
fn test_window_lead_basic() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (1, 10)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (2, 20)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (3, 30)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (4, 40)").unwrap();
let results = db
.query(
"SELECT val, LEAD(val, 1) OVER (ORDER BY val) FROM nums",
&[],
)
.unwrap();
assert_eq!(results.len(), 4);
assert_eq!(results[3].get(1).unwrap(), &Value::Null);
let lead_val = results[0].get(1).unwrap();
assert!(
matches!(lead_val, Value::Int4(20) | Value::Int8(20)),
"LEAD of first row should be 20, got {:?}",
lead_val
);
}
#[test]
fn test_window_lead_offset_2() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
for i in 1..=5 {
db.execute(&format!(
"INSERT INTO nums (id, val) VALUES ({}, {})",
i,
i * 10
))
.unwrap();
}
let results = db
.query(
"SELECT val, LEAD(val, 2) OVER (ORDER BY val) FROM nums",
&[],
)
.unwrap();
assert_eq!(results.len(), 5);
assert_eq!(results[3].get(1).unwrap(), &Value::Null);
assert_eq!(results[4].get(1).unwrap(), &Value::Null);
let lead_val = results[0].get(1).unwrap();
assert!(
matches!(lead_val, Value::Int4(30) | Value::Int8(30)),
"LEAD(val,2) of first row should be 30, got {:?}",
lead_val
);
}
#[test]
fn test_window_lead_default_offset() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (1, 100)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (2, 200)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (3, 300)").unwrap();
let results = db
.query(
"SELECT val, LEAD(val) OVER (ORDER BY val) FROM nums",
&[],
)
.unwrap();
assert_eq!(results.len(), 3);
assert_eq!(results[2].get(1).unwrap(), &Value::Null);
let lead_val = results[0].get(1).unwrap();
assert!(
matches!(lead_val, Value::Int4(200) | Value::Int8(200)),
"LEAD with default offset should be 200 for first row, got {:?}",
lead_val
);
}
#[test]
fn test_window_first_value_basic() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (1, 10)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (2, 20)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (3, 30)").unwrap();
let results = db
.query(
"SELECT val, FIRST_VALUE(val) OVER (ORDER BY val) FROM nums",
&[],
)
.unwrap();
assert_eq!(results.len(), 3);
for row in &results {
let fv = row.get(1).unwrap();
assert!(
matches!(fv, Value::Int4(10) | Value::Int8(10)),
"FIRST_VALUE should be 10, got {:?}",
fv
);
}
}
#[test]
fn test_window_first_value_partitioned() {
let db = setup_window_test_db();
let results = db.query(
"SELECT name, dept, salary, FIRST_VALUE(salary) OVER (PARTITION BY dept ORDER BY salary DESC) FROM employees",
&[],
).unwrap();
assert_eq!(results.len(), 6);
for row in &results {
if let Some(Value::String(dept)) = row.get(1) {
let fv = row.get(3).unwrap();
let expected = match dept.as_str() {
"Engineering" => 120000,
"Sales" => 95000,
"Marketing" => 80000,
_ => panic!("unexpected dept"),
};
assert!(
matches!(fv, Value::Int4(v) if *v == expected)
|| matches!(fv, Value::Int8(v) if *v == expected as i64),
"FIRST_VALUE for {} should be {}, got {:?}",
dept,
expected,
fv
);
}
}
}
#[test]
fn test_window_first_value_with_nulls() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE null_first (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO null_first (id) VALUES (1)").unwrap(); db.execute("INSERT INTO null_first (id, val) VALUES (2, 20)").unwrap();
db.execute("INSERT INTO null_first (id, val) VALUES (3, 30)").unwrap();
let results = db
.query(
"SELECT id, val, FIRST_VALUE(val) OVER (ORDER BY id) FROM null_first",
&[],
)
.unwrap();
assert_eq!(results.len(), 3);
for row in &results {
assert_eq!(
row.get(2).unwrap(),
&Value::Null,
"FIRST_VALUE should be NULL when first row has NULL val"
);
}
}
#[test]
fn test_window_last_value_with_order_by() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (1, 10)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (2, 20)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (3, 30)").unwrap();
let results = db
.query(
"SELECT val, LAST_VALUE(val) OVER (ORDER BY val) FROM nums",
&[],
)
.unwrap();
assert_eq!(results.len(), 3);
for row in &results {
let val = row.get(0).unwrap();
let lv = row.get(1).unwrap();
assert_eq!(
val, lv,
"LAST_VALUE with default frame (ORDER BY) should equal current row value"
);
}
}
#[test]
fn test_window_last_value_no_order_by() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (1, 10)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (2, 20)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (3, 30)").unwrap();
let results = db
.query("SELECT val, LAST_VALUE(val) OVER () FROM nums", &[])
.unwrap();
assert_eq!(results.len(), 3);
let last_vals: Vec<&Value> = results.iter().map(|r| r.get(1).unwrap()).collect();
assert!(
last_vals.windows(2).all(|w| w[0] == w[1]),
"All LAST_VALUE results without ORDER BY should be equal"
);
}
#[test]
fn test_window_sum_partitioned() {
let db = setup_window_test_db();
let results = db
.query(
"SELECT name, dept, salary, SUM(salary) OVER (PARTITION BY dept) FROM employees",
&[],
)
.unwrap();
assert_eq!(results.len(), 6);
for row in &results {
if let Some(Value::String(dept)) = row.get(1) {
let sum_val = row.get(3).unwrap();
let expected: f64 = match dept.as_str() {
"Engineering" => 340_000.0,
"Sales" => 185_000.0,
"Marketing" => 80_000.0,
_ => panic!("unexpected dept"),
};
assert!(
matches!(sum_val, Value::Float8(v) if (*v - expected).abs() < 0.01),
"SUM for {} should be {}, got {:?}",
dept,
expected,
sum_val
);
}
}
}
#[test]
fn test_window_sum_running_total() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (1, 10)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (2, 20)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (3, 30)").unwrap();
let results = db
.query(
"SELECT val, SUM(val) OVER (ORDER BY val) FROM nums",
&[],
)
.unwrap();
assert_eq!(results.len(), 3);
let sums: Vec<f64> = results
.iter()
.map(|r| match r.get(1).unwrap() {
Value::Float8(v) => *v,
other => panic!("expected Float8, got {:?}", other),
})
.collect();
assert!((sums[0] - 10.0).abs() < 0.01, "Running sum row 1 = 10");
assert!((sums[1] - 30.0).abs() < 0.01, "Running sum row 2 = 30");
assert!((sums[2] - 60.0).abs() < 0.01, "Running sum row 3 = 60");
}
#[test]
fn test_window_count_partitioned() {
let db = setup_window_test_db();
let results = db
.query(
"SELECT name, dept, COUNT(salary) OVER (PARTITION BY dept) FROM employees",
&[],
)
.unwrap();
assert_eq!(results.len(), 6);
for row in &results {
if let Some(Value::String(dept)) = row.get(1) {
let count = row.get(2).unwrap();
let expected = match dept.as_str() {
"Engineering" => 3,
"Sales" => 2,
"Marketing" => 1,
_ => panic!("unexpected dept"),
};
assert_eq!(
count,
&Value::Int8(expected),
"COUNT for {} should be {}",
dept,
expected
);
}
}
}
#[test]
fn test_window_count_running() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
for i in 1..=5 {
db.execute(&format!(
"INSERT INTO nums (id, val) VALUES ({}, {})",
i,
i * 10
))
.unwrap();
}
let results = db
.query(
"SELECT val, COUNT(val) OVER (ORDER BY val) FROM nums",
&[],
)
.unwrap();
assert_eq!(results.len(), 5);
let counts: Vec<i64> = results
.iter()
.map(|r| match r.get(1).unwrap() {
Value::Int8(v) => *v,
_ => panic!("expected Int8"),
})
.collect();
assert_eq!(counts, vec![1, 2, 3, 4, 5], "Running count should be 1..=5");
}
#[test]
fn test_window_avg_partitioned() {
let db = setup_window_test_db();
let results = db
.query(
"SELECT name, dept, salary, AVG(salary) OVER (PARTITION BY dept) FROM employees",
&[],
)
.unwrap();
assert_eq!(results.len(), 6);
for row in &results {
if let Some(Value::String(dept)) = row.get(1) {
let avg_val = row.get(3).unwrap();
let expected: f64 = match dept.as_str() {
"Engineering" => 340_000.0 / 3.0,
"Sales" => 92_500.0,
"Marketing" => 80_000.0,
_ => panic!("unexpected dept"),
};
assert!(
matches!(avg_val, Value::Float8(v) if (*v - expected).abs() < 1.0),
"AVG for {} should be ~{}, got {:?}",
dept,
expected,
avg_val
);
}
}
}
#[test]
fn test_window_avg_running() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (1, 10)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (2, 20)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (3, 30)").unwrap();
let results = db
.query(
"SELECT val, AVG(val) OVER (ORDER BY val) FROM nums",
&[],
)
.unwrap();
assert_eq!(results.len(), 3);
let avgs: Vec<f64> = results
.iter()
.map(|r| match r.get(1).unwrap() {
Value::Float8(v) => *v,
other => panic!("expected Float8, got {:?}", other),
})
.collect();
assert!((avgs[0] - 10.0).abs() < 0.01, "Running avg row 1");
assert!((avgs[1] - 15.0).abs() < 0.01, "Running avg row 2");
assert!((avgs[2] - 20.0).abs() < 0.01, "Running avg row 3");
}
#[test]
fn test_window_min_max_partitioned() {
let db = setup_window_test_db();
let results = db.query(
"SELECT name, dept, MIN(salary) OVER (PARTITION BY dept), MAX(salary) OVER (PARTITION BY dept) FROM employees",
&[],
).unwrap();
assert_eq!(results.len(), 6);
for row in &results {
if let Some(Value::String(dept)) = row.get(1) {
let min_val = row.get(2).unwrap();
let max_val = row.get(3).unwrap();
match dept.as_str() {
"Engineering" => {
assert!(
matches!(min_val, Value::Int4(110000) | Value::Int8(110000)),
"MIN for Engineering = 110000, got {:?}",
min_val
);
assert!(
matches!(max_val, Value::Int4(120000) | Value::Int8(120000)),
"MAX for Engineering = 120000, got {:?}",
max_val
);
}
"Sales" => {
assert!(
matches!(min_val, Value::Int4(90000) | Value::Int8(90000)),
"MIN for Sales = 90000, got {:?}",
min_val
);
assert!(
matches!(max_val, Value::Int4(95000) | Value::Int8(95000)),
"MAX for Sales = 95000, got {:?}",
max_val
);
}
"Marketing" => {
assert!(
matches!(min_val, Value::Int4(80000) | Value::Int8(80000)),
"MIN for Marketing = 80000, got {:?}",
min_val
);
assert!(
matches!(max_val, Value::Int4(80000) | Value::Int8(80000)),
"MAX for Marketing = 80000, got {:?}",
max_val
);
}
_ => panic!("unexpected dept"),
}
}
}
}
#[test]
fn test_window_empty_result_set() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE empty_t (id INT PRIMARY KEY, val INT)").unwrap();
let results = db
.query(
"SELECT val, ROW_NUMBER() OVER (ORDER BY val) FROM empty_t",
&[],
)
.unwrap();
assert_eq!(results.len(), 0, "Window on empty table => 0 rows");
}
#[test]
fn test_window_empty_result_set_via_where() {
let db = setup_window_test_db();
let results = db
.query(
"SELECT name, ROW_NUMBER() OVER (ORDER BY salary) FROM employees WHERE dept = 'NonExistent'",
&[],
)
.unwrap();
assert_eq!(results.len(), 0, "No matching WHERE => 0 rows");
}
#[test]
fn test_window_single_row_partition() {
let db = setup_window_test_db();
let results = db.query(
"SELECT name, dept, \
ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary), \
RANK() OVER (PARTITION BY dept ORDER BY salary), \
SUM(salary) OVER (PARTITION BY dept) \
FROM employees WHERE dept = 'Marketing'",
&[],
).unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].get(2).unwrap(), &Value::Int8(1), "ROW_NUMBER = 1");
assert_eq!(results[0].get(3).unwrap(), &Value::Int8(1), "RANK = 1");
assert!(
matches!(results[0].get(4).unwrap(), Value::Float8(v) if (*v - 80000.0).abs() < 0.01),
"SUM = 80000"
);
}
#[test]
fn test_window_all_null_values_in_windowed_column() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE null_t (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO null_t (id) VALUES (1)").unwrap();
db.execute("INSERT INTO null_t (id) VALUES (2)").unwrap();
db.execute("INSERT INTO null_t (id) VALUES (3)").unwrap();
let results = db
.query(
"SELECT id, val, SUM(val) OVER (ORDER BY id) FROM null_t",
&[],
)
.unwrap();
assert_eq!(results.len(), 3);
for row in &results {
let sum_val = row.get(2).unwrap();
assert!(
matches!(sum_val, Value::Null),
"SUM of all NULLs should be NULL (SQL standard), got {:?}",
sum_val
);
}
}
#[test]
fn test_window_null_in_windowed_column() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE mixed_nulls (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO mixed_nulls (id, val) VALUES (1, 10)").unwrap();
db.execute("INSERT INTO mixed_nulls (id) VALUES (2)").unwrap(); db.execute("INSERT INTO mixed_nulls (id, val) VALUES (3, 30)").unwrap();
let results = db
.query(
"SELECT id, val, LAG(val, 1) OVER (ORDER BY id) FROM mixed_nulls",
&[],
)
.unwrap();
assert_eq!(results.len(), 3);
assert_eq!(results[0].get(2).unwrap(), &Value::Null, "LAG for first row = NULL");
let lag2 = results[1].get(2).unwrap();
assert!(
matches!(lag2, Value::Int4(10) | Value::Int8(10)),
"LAG for id=2 should be 10, got {:?}",
lag2
);
assert_eq!(
results[2].get(2).unwrap(),
&Value::Null,
"LAG for id=3 should be NULL (previous row has NULL val)"
);
}
#[test]
fn test_window_multiple_functions_same_select() {
let db = setup_window_test_db();
let results = db.query(
"SELECT name, salary, \
ROW_NUMBER() OVER (ORDER BY salary DESC), \
SUM(salary) OVER (), \
COUNT(salary) OVER () \
FROM employees",
&[],
).unwrap();
assert_eq!(results.len(), 6);
let total: f64 =
120_000.0 + 110_000.0 + 110_000.0 + 90_000.0 + 95_000.0 + 80_000.0;
for row in &results {
let sum_val = row.get(3).unwrap();
assert!(
matches!(sum_val, Value::Float8(v) if (*v - total).abs() < 0.01),
"Total SUM should be {}, got {:?}",
total,
sum_val
);
let count_val = row.get(4).unwrap();
assert_eq!(count_val, &Value::Int8(6), "Total COUNT should be 6");
}
let row_nums: Vec<i64> = results
.iter()
.map(|r| match r.get(2).unwrap() {
Value::Int8(v) => *v,
_ => panic!("expected Int8"),
})
.collect();
let mut sorted = row_nums.clone();
sorted.sort();
assert_eq!(sorted, vec![1, 2, 3, 4, 5, 6]);
}
#[test]
fn test_window_no_partition_by_entire_table() {
let db = setup_window_test_db();
let results = db
.query(
"SELECT name, salary, COUNT(salary) OVER () FROM employees",
&[],
)
.unwrap();
assert_eq!(results.len(), 6);
for row in &results {
assert_eq!(row.get(2).unwrap(), &Value::Int8(6));
}
}
#[test]
fn test_window_partition_with_many_groups() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE big_t (id INT PRIMARY KEY, grp INT, val INT)").unwrap();
for i in 1..=200 {
let grp = (i - 1) / 2 + 1; db.execute(&format!(
"INSERT INTO big_t (id, grp, val) VALUES ({}, {}, {})",
i, grp, i * 10
))
.unwrap();
}
let results = db
.query(
"SELECT grp, COUNT(val) OVER (PARTITION BY grp) FROM big_t",
&[],
)
.unwrap();
assert_eq!(results.len(), 200);
for row in &results {
assert_eq!(
row.get(1).unwrap(),
&Value::Int8(2),
"Each of 100 groups should have COUNT = 2"
);
}
}
#[test]
fn test_window_identical_values_all_rows() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE same_vals (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO same_vals (id, val) VALUES (1, 42)").unwrap();
db.execute("INSERT INTO same_vals (id, val) VALUES (2, 42)").unwrap();
db.execute("INSERT INTO same_vals (id, val) VALUES (3, 42)").unwrap();
let results = db.query(
"SELECT id, val, ROW_NUMBER() OVER (ORDER BY val), SUM(val) OVER () FROM same_vals",
&[],
).unwrap();
assert_eq!(results.len(), 3);
let row_nums: Vec<i64> = results
.iter()
.map(|r| match r.get(2).unwrap() {
Value::Int8(v) => *v,
_ => panic!("expected Int8"),
})
.collect();
let mut sorted = row_nums.clone();
sorted.sort();
assert_eq!(sorted, vec![1, 2, 3]);
for row in &results {
assert!(
matches!(row.get(3).unwrap(), Value::Float8(v) if (*v - 126.0).abs() < 0.01),
"SUM should be 126"
);
}
}
#[test]
fn test_window_single_row_table() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE single (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO single (id, val) VALUES (1, 42)").unwrap();
let results = db.query(
"SELECT val, \
ROW_NUMBER() OVER (ORDER BY val), \
RANK() OVER (ORDER BY val), \
LAG(val, 1) OVER (ORDER BY val), \
LEAD(val, 1) OVER (ORDER BY val), \
SUM(val) OVER (), \
COUNT(val) OVER () \
FROM single",
&[],
).unwrap();
assert_eq!(results.len(), 1);
let row = &results[0];
assert_eq!(row.get(1).unwrap(), &Value::Int8(1), "ROW_NUMBER = 1");
assert_eq!(row.get(2).unwrap(), &Value::Int8(1), "RANK = 1");
assert_eq!(row.get(3).unwrap(), &Value::Null, "LAG = NULL");
assert_eq!(row.get(4).unwrap(), &Value::Null, "LEAD = NULL");
assert!(
matches!(row.get(5).unwrap(), Value::Float8(v) if (*v - 42.0).abs() < 0.01),
"SUM = 42"
);
assert_eq!(row.get(6).unwrap(), &Value::Int8(1), "COUNT = 1");
}
#[test]
fn test_window_frame_unbounded_preceding_to_current_row() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (1, 10)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (2, 20)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (3, 30)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (4, 40)").unwrap();
let results = db.query(
"SELECT val, SUM(val) OVER (ORDER BY val ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM nums",
&[],
).unwrap();
assert_eq!(results.len(), 4);
let sums: Vec<f64> = results
.iter()
.map(|r| match r.get(1).unwrap() {
Value::Float8(v) => *v,
other => panic!("expected Float8, got {:?}", other),
})
.collect();
assert!((sums[0] - 10.0).abs() < 0.01);
assert!((sums[1] - 30.0).abs() < 0.01);
assert!((sums[2] - 60.0).abs() < 0.01);
assert!((sums[3] - 100.0).abs() < 0.01);
}
#[test]
fn test_window_frame_1_preceding_to_1_following() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (1, 10)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (2, 20)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (3, 30)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (4, 40)").unwrap();
let results = db.query(
"SELECT val, SUM(val) OVER (ORDER BY val ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM nums",
&[],
).unwrap();
assert_eq!(results.len(), 4);
let sums: Vec<f64> = results
.iter()
.map(|r| match r.get(1).unwrap() {
Value::Float8(v) => *v,
other => panic!("expected Float8, got {:?}", other),
})
.collect();
assert!((sums[0] - 30.0).abs() < 0.01, "Row 1: got {}", sums[0]);
assert!((sums[1] - 60.0).abs() < 0.01, "Row 2: got {}", sums[1]);
assert!((sums[2] - 90.0).abs() < 0.01, "Row 3: got {}", sums[2]);
assert!((sums[3] - 70.0).abs() < 0.01, "Row 4: got {}", sums[3]);
}
#[test]
fn test_window_frame_unbounded_preceding_to_unbounded_following() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (1, 10)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (2, 20)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (3, 30)").unwrap();
let results = db.query(
"SELECT val, SUM(val) OVER (ORDER BY val ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM nums",
&[],
).unwrap();
assert_eq!(results.len(), 3);
for row in &results {
let sum_val = row.get(1).unwrap();
assert!(
matches!(sum_val, Value::Float8(v) if (*v - 60.0).abs() < 0.01),
"Full frame SUM should be 60, got {:?}",
sum_val
);
}
}
#[test]
fn test_window_frame_current_row_to_current_row() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (1, 10)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (2, 20)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (3, 30)").unwrap();
let results = db.query(
"SELECT val, SUM(val) OVER (ORDER BY val ROWS BETWEEN CURRENT ROW AND CURRENT ROW) FROM nums",
&[],
).unwrap();
assert_eq!(results.len(), 3);
for row in &results {
let val = match row.get(0).unwrap() {
Value::Int4(v) => *v as f64,
Value::Int8(v) => *v as f64,
_ => panic!("unexpected type"),
};
let sum = match row.get(1).unwrap() {
Value::Float8(v) => *v,
other => panic!("expected Float8, got {:?}", other),
};
assert!(
(sum - val).abs() < 0.01,
"CURRENT ROW frame SUM should equal own value"
);
}
}
#[test]
fn test_window_frame_2_preceding_to_current_row() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
for i in 1..=5 {
db.execute(&format!(
"INSERT INTO nums (id, val) VALUES ({}, {})",
i,
i * 10
))
.unwrap();
}
let results = db.query(
"SELECT val, SUM(val) OVER (ORDER BY val ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) FROM nums",
&[],
).unwrap();
assert_eq!(results.len(), 5);
let sums: Vec<f64> = results
.iter()
.map(|r| match r.get(1).unwrap() {
Value::Float8(v) => *v,
other => panic!("expected Float8, got {:?}", other),
})
.collect();
assert!((sums[0] - 10.0).abs() < 0.01);
assert!((sums[1] - 30.0).abs() < 0.01);
assert!((sums[2] - 60.0).abs() < 0.01);
assert!((sums[3] - 90.0).abs() < 0.01);
assert!((sums[4] - 120.0).abs() < 0.01);
}
#[test]
fn test_window_row_number_no_order_by() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (1, 10)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (2, 20)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (3, 30)").unwrap();
let results = db
.query("SELECT val, ROW_NUMBER() OVER () FROM nums", &[])
.unwrap();
assert_eq!(results.len(), 3);
let row_nums: Vec<i64> = results
.iter()
.map(|r| match r.get(1).unwrap() {
Value::Int8(v) => *v,
_ => panic!("expected Int8"),
})
.collect();
let mut sorted = row_nums.clone();
sorted.sort();
assert_eq!(sorted, vec![1, 2, 3]);
}
#[test]
fn test_window_descending_order() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (1, 10)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (2, 20)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (3, 30)").unwrap();
let results = db
.query(
"SELECT val, ROW_NUMBER() OVER (ORDER BY val DESC) FROM nums",
&[],
)
.unwrap();
assert_eq!(results.len(), 3);
for row in &results {
let val = match row.get(0).unwrap() {
Value::Int4(v) => *v as i64,
Value::Int8(v) => *v,
_ => panic!("unexpected type"),
};
let rn = match row.get(1).unwrap() {
Value::Int8(v) => *v,
_ => panic!("expected Int8"),
};
match val {
30 => assert_eq!(rn, 1),
20 => assert_eq!(rn, 2),
10 => assert_eq!(rn, 3),
_ => panic!("unexpected val {}", val),
}
}
}
#[test]
fn test_window_sum_with_where_clause() {
let db = setup_window_test_db();
let results = db.query(
"SELECT name, salary, SUM(salary) OVER (ORDER BY salary) FROM employees WHERE dept = 'Engineering'",
&[],
).unwrap();
assert_eq!(results.len(), 3, "Only Engineering employees");
let sums: Vec<f64> = results
.iter()
.map(|r| match r.get(2).unwrap() {
Value::Float8(v) => *v,
other => panic!("expected Float8, got {:?}", other),
})
.collect();
let max_sum = sums.iter().cloned().fold(0.0_f64, f64::max);
assert!(
(max_sum - 340_000.0).abs() < 0.01,
"Max running sum should be 340000, got {}",
max_sum
);
let min_sum = sums.iter().cloned().fold(f64::MAX, f64::min);
assert!(
min_sum >= 109_999.0,
"Min running sum should be >= 110000, got {}",
min_sum
);
}
#[test]
fn test_window_count_star_over() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE t (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO t (id, val) VALUES (1, 10)").unwrap();
db.execute("INSERT INTO t (id) VALUES (2)").unwrap(); db.execute("INSERT INTO t (id, val) VALUES (3, 30)").unwrap();
let results = db
.query("SELECT id, COUNT(*) OVER () FROM t", &[])
.unwrap();
assert_eq!(results.len(), 3);
for row in &results {
assert_eq!(
row.get(1).unwrap(),
&Value::Int8(3),
"COUNT(*) OVER() should count all rows"
);
}
}
#[test]
fn test_window_count_column_excludes_nulls() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE t (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO t (id, val) VALUES (1, 10)").unwrap();
db.execute("INSERT INTO t (id) VALUES (2)").unwrap(); db.execute("INSERT INTO t (id, val) VALUES (3, 30)").unwrap();
let results = db
.query("SELECT id, COUNT(val) OVER () FROM t", &[])
.unwrap();
assert_eq!(results.len(), 3);
for row in &results {
assert_eq!(
row.get(1).unwrap(),
&Value::Int8(2),
"COUNT(col) should exclude NULLs per SQL standard"
);
}
}
#[test]
fn test_window_multiple_partitions_multiple_functions() {
let db = setup_window_test_db();
let results = db.query(
"SELECT name, dept, salary, \
ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC), \
SUM(salary) OVER (PARTITION BY dept), \
AVG(salary) OVER (PARTITION BY dept) \
FROM employees",
&[],
).unwrap();
assert_eq!(results.len(), 6);
for row in &results {
assert_eq!(row.len(), 6, "3 original + 3 window columns");
}
}
#[test]
fn test_window_preserves_original_columns() {
let db = setup_window_test_db();
let results = db
.query(
"SELECT id, name, dept, salary, ROW_NUMBER() OVER (ORDER BY salary) FROM employees",
&[],
)
.unwrap();
assert_eq!(results.len(), 6);
for row in &results {
assert_eq!(row.len(), 5, "4 original + 1 window column");
assert!(
matches!(row.get(0).unwrap(), Value::Int4(_) | Value::Int8(_)),
"id should be integer"
);
assert!(matches!(row.get(1).unwrap(), Value::String(_)), "name should be string");
}
}
#[test]
fn test_window_lag_partitioned() {
let db = setup_window_test_db();
let results = db.query(
"SELECT name, dept, salary, LAG(salary, 1) OVER (PARTITION BY dept ORDER BY salary) FROM employees",
&[],
).unwrap();
assert_eq!(results.len(), 6);
let mut dept_null_count: std::collections::HashMap<String, usize> =
std::collections::HashMap::new();
for row in &results {
if let Some(Value::String(dept)) = row.get(1) {
if row.get(3).unwrap() == &Value::Null {
*dept_null_count.entry(dept.clone()).or_insert(0) += 1;
}
}
}
for (dept, count) in &dept_null_count {
assert_eq!(
*count, 1,
"Partition {} should have exactly 1 NULL LAG, got {}",
dept, count
);
}
}
#[test]
fn test_window_lead_partitioned() {
let db = setup_window_test_db();
let results = db.query(
"SELECT name, dept, salary, LEAD(salary, 1) OVER (PARTITION BY dept ORDER BY salary) FROM employees",
&[],
).unwrap();
assert_eq!(results.len(), 6);
let mut dept_null_count: std::collections::HashMap<String, usize> =
std::collections::HashMap::new();
for row in &results {
if let Some(Value::String(dept)) = row.get(1) {
if row.get(3).unwrap() == &Value::Null {
*dept_null_count.entry(dept.clone()).or_insert(0) += 1;
}
}
}
for (dept, count) in &dept_null_count {
assert_eq!(
*count, 1,
"Partition {} should have exactly 1 NULL LEAD, got {}",
dept, count
);
}
}
#[test]
fn test_window_large_dataset_row_number() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE large_t (id INT PRIMARY KEY, val INT)").unwrap();
for i in 1..=500 {
db.execute(&format!(
"INSERT INTO large_t (id, val) VALUES ({}, {})",
i, i
))
.unwrap();
}
let results = db
.query(
"SELECT id, ROW_NUMBER() OVER (ORDER BY val) FROM large_t",
&[],
)
.unwrap();
assert_eq!(results.len(), 500);
let row_nums: std::collections::HashSet<i64> = results
.iter()
.map(|r| match r.get(1).unwrap() {
Value::Int8(v) => *v,
_ => panic!("expected Int8"),
})
.collect();
assert_eq!(row_nums.len(), 500);
assert!(row_nums.contains(&1));
assert!(row_nums.contains(&500));
}
#[test]
fn test_window_percent_rank() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE scores (id INT PRIMARY KEY, score INT)").unwrap();
db.execute("INSERT INTO scores (id, score) VALUES (1, 100)").unwrap();
db.execute("INSERT INTO scores (id, score) VALUES (2, 200)").unwrap();
db.execute("INSERT INTO scores (id, score) VALUES (3, 300)").unwrap();
db.execute("INSERT INTO scores (id, score) VALUES (4, 400)").unwrap();
let results = db
.query(
"SELECT score, PERCENT_RANK() OVER (ORDER BY score) FROM scores",
&[],
)
.unwrap();
assert_eq!(results.len(), 4);
let pct_ranks: Vec<f64> = results
.iter()
.map(|r| match r.get(1).unwrap() {
Value::Float8(v) => *v,
other => panic!("expected Float8, got {:?}", other),
})
.collect();
assert!((pct_ranks[0] - 0.0).abs() < 0.01);
assert!((pct_ranks[1] - 1.0 / 3.0).abs() < 0.01);
assert!((pct_ranks[2] - 2.0 / 3.0).abs() < 0.01);
assert!((pct_ranks[3] - 1.0).abs() < 0.01);
}
#[test]
fn test_window_percent_rank_single_row() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE one_row (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO one_row (id, val) VALUES (1, 42)").unwrap();
let results = db
.query(
"SELECT val, PERCENT_RANK() OVER (ORDER BY val) FROM one_row",
&[],
)
.unwrap();
assert_eq!(results.len(), 1);
assert!(
matches!(results[0].get(1).unwrap(), Value::Float8(v) if v.abs() < 0.01),
"PERCENT_RANK with single row should be 0.0"
);
}
#[test]
fn test_window_ntile_single_bucket() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (1, 10)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (2, 20)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (3, 30)").unwrap();
let results = db
.query(
"SELECT val, NTILE(1) OVER (ORDER BY val) FROM nums",
&[],
)
.unwrap();
assert_eq!(results.len(), 3);
for row in &results {
assert_eq!(row.get(1).unwrap(), &Value::Int8(1));
}
}
#[test]
fn test_window_ntile_more_buckets_than_rows() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (1, 10)").unwrap();
db.execute("INSERT INTO nums (id, val) VALUES (2, 20)").unwrap();
let results = db
.query(
"SELECT val, NTILE(5) OVER (ORDER BY val) FROM nums",
&[],
)
.unwrap();
assert_eq!(results.len(), 2);
let buckets: Vec<i64> = results
.iter()
.map(|r| match r.get(1).unwrap() {
Value::Int8(v) => *v,
_ => panic!("expected Int8"),
})
.collect();
assert!(
buckets.iter().all(|b| *b >= 1 && *b <= 5),
"Buckets should be in range 1..=5"
);
}
#[test]
fn test_returning_insert_star() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE ret_test (a INT, b TEXT)").unwrap();
let (count, rows) = db.execute_returning(
"INSERT INTO ret_test (a, b) VALUES (1, 'hello') RETURNING *"
).unwrap();
assert_eq!(count, 1);
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values.len(), 2);
assert_eq!(rows[0].values[0], Value::Int4(1));
assert_eq!(rows[0].values[1], Value::String("hello".to_string()));
}
#[test]
fn test_returning_insert_specific_columns() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE ret_cols (a INT, b TEXT)").unwrap();
let (count, rows) = db.execute_returning(
"INSERT INTO ret_cols (a, b) VALUES (1, 'world') RETURNING a, b"
).unwrap();
assert_eq!(count, 1);
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values.len(), 2);
assert_eq!(rows[0].values[0], Value::Int4(1));
assert_eq!(rows[0].values[1], Value::String("world".to_string()));
}
#[test]
fn test_returning_insert_single_column() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE ret_single (a INT, b TEXT)").unwrap();
let (count, rows) = db.execute_returning(
"INSERT INTO ret_single (a, b) VALUES (42, 'test') RETURNING a"
).unwrap();
assert_eq!(count, 1);
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values.len(), 1);
assert_eq!(rows[0].values[0], Value::Int4(42));
}
#[test]
fn test_returning_insert_expression_with_alias() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE ret_expr (a INT, b INT)").unwrap();
let (count, rows) = db.execute_returning(
"INSERT INTO ret_expr (a, b) VALUES (1, 2) RETURNING a + 1 AS incremented"
).unwrap();
assert_eq!(count, 1);
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values.len(), 1);
assert_eq!(rows[0].values[0], Value::Int4(2));
}
#[test]
fn test_returning_update_star() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE ret_upd (a INT, b INT)").unwrap();
db.execute("INSERT INTO ret_upd (a, b) VALUES (1, 5)").unwrap();
let (count, rows) = db.execute_returning(
"UPDATE ret_upd SET b = 10 WHERE a = 1 RETURNING *"
).unwrap();
assert_eq!(count, 1);
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values[0], Value::Int4(1));
assert_eq!(rows[0].values[1], Value::Int4(10));
}
#[test]
fn test_returning_update_specific_column() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE ret_upd2 (a INT, b INT)").unwrap();
db.execute("INSERT INTO ret_upd2 (a, b) VALUES (1, 5)").unwrap();
db.execute("INSERT INTO ret_upd2 (a, b) VALUES (2, 6)").unwrap();
let (count, rows) = db.execute_returning(
"UPDATE ret_upd2 SET b = 99 WHERE a = 2 RETURNING b"
).unwrap();
assert_eq!(count, 1);
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values.len(), 1);
assert_eq!(rows[0].values[0], Value::Int4(99));
}
#[test]
fn test_returning_delete() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE ret_del (a INT, b TEXT)").unwrap();
db.execute("INSERT INTO ret_del (a, b) VALUES (1, 'one')").unwrap();
db.execute("INSERT INTO ret_del (a, b) VALUES (2, 'two')").unwrap();
let (count, rows) = db.execute_returning(
"DELETE FROM ret_del WHERE a = 1 RETURNING a"
).unwrap();
assert_eq!(count, 1);
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values.len(), 1);
assert_eq!(rows[0].values[0], Value::Int4(1));
}
#[test]
fn test_returning_delete_star() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE ret_del2 (a INT, b TEXT)").unwrap();
db.execute("INSERT INTO ret_del2 (a, b) VALUES (1, 'one')").unwrap();
db.execute("INSERT INTO ret_del2 (a, b) VALUES (2, 'two')").unwrap();
let (count, rows) = db.execute_returning(
"DELETE FROM ret_del2 WHERE a = 2 RETURNING *"
).unwrap();
assert_eq!(count, 1);
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values.len(), 2);
assert_eq!(rows[0].values[0], Value::Int4(2));
assert_eq!(rows[0].values[1], Value::String("two".to_string()));
}
#[test]
fn test_returning_multi_row_insert() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE ret_multi (a INT, b INT)").unwrap();
let (count, rows) = db.execute_returning(
"INSERT INTO ret_multi (a, b) VALUES (1, 10), (2, 20), (3, 30) RETURNING *"
).unwrap();
assert_eq!(count, 3);
assert_eq!(rows.len(), 3);
assert_eq!(rows[0].values[0], Value::Int4(1));
assert_eq!(rows[1].values[0], Value::Int4(2));
assert_eq!(rows[2].values[0], Value::Int4(3));
}
#[test]
fn test_returning_no_matching_rows() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE ret_empty (a INT)").unwrap();
let (count, rows) = db.execute_returning(
"DELETE FROM ret_empty WHERE a = 999 RETURNING *"
).unwrap();
assert_eq!(count, 0);
assert_eq!(rows.len(), 0);
}
#[test]
fn test_returning_via_query() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE ret_query (a INT, b TEXT)").unwrap();
let rows = db.query(
"INSERT INTO ret_query (a, b) VALUES (7, 'seven') RETURNING *",
&[]
).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values[0], Value::Int4(7));
assert_eq!(rows[0].values[1], Value::String("seven".to_string()));
}
#[test]
fn test_returning_update_no_clause() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE ret_none (a INT)").unwrap();
db.execute("INSERT INTO ret_none (a) VALUES (1)").unwrap();
let (count, rows) = db.execute_returning(
"UPDATE ret_none SET a = 2 WHERE a = 1"
).unwrap();
assert_eq!(count, 1);
assert_eq!(rows.len(), 0);
}
fn parse_json_value(val: &Value) -> serde_json::Value {
match val {
Value::Json(j) => serde_json::from_str(j).unwrap(),
Value::String(s) => serde_json::from_str(s).unwrap_or_else(|_| serde_json::json!(s)),
other => panic!("Expected Json or String, got {:?}", other),
}
}
#[test]
fn test_json_column_create_insert_select() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE json_basic (id INT PRIMARY KEY, data JSONB)").unwrap();
db.execute(r#"INSERT INTO json_basic (id, data) VALUES (1, '{"name":"Alice","age":30}')"#).unwrap();
db.execute(r#"INSERT INTO json_basic (id, data) VALUES (2, '{"name":"Bob","age":25}')"#).unwrap();
let rows = db.query("SELECT id, data FROM json_basic ORDER BY id", &[]).unwrap();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].values[0], Value::Int4(1));
let parsed = parse_json_value(&rows[0].values[1]);
assert_eq!(parsed["name"], "Alice");
assert_eq!(parsed["age"], 30);
}
#[test]
fn test_json_column_type_json_vs_jsonb() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE json_types (id INT PRIMARY KEY, j JSON, jb JSONB)").unwrap();
db.execute(r#"INSERT INTO json_types (id, j, jb) VALUES (1, '{"a":1}', '{"b":2}')"#).unwrap();
let rows = db.query("SELECT j, jb FROM json_types WHERE id = 1", &[]).unwrap();
assert_eq!(rows.len(), 1);
let j_parsed = parse_json_value(&rows[0].values[0]);
assert_eq!(j_parsed["a"], 1);
let jb_parsed = parse_json_value(&rows[0].values[1]);
assert_eq!(jb_parsed["b"], 2);
}
#[test]
fn test_json_null_column() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE json_nulls (id INT PRIMARY KEY, data JSONB)").unwrap();
db.execute("INSERT INTO json_nulls (id, data) VALUES (1, NULL)").unwrap();
let rows = db.query("SELECT data FROM json_nulls WHERE id = 1", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values[0], Value::Null);
}
#[test]
fn test_json_cast_string_to_jsonb() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(r#"SELECT CAST('{"hello":"world"}' AS JSONB)"#, &[]).unwrap();
assert_eq!(rows.len(), 1);
match &rows[0].values[0] {
Value::Json(j) => {
let parsed: serde_json::Value = serde_json::from_str(j).unwrap();
assert_eq!(parsed["hello"], "world");
}
other => panic!("Expected Json from CAST, got {:?}", other),
}
}
#[test]
fn test_json_cast_to_json_type() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(r#"SELECT CAST('{"k":"v"}' AS JSON)"#, &[]).unwrap();
assert_eq!(rows.len(), 1);
match &rows[0].values[0] {
Value::Json(j) => {
let parsed: serde_json::Value = serde_json::from_str(j).unwrap();
assert_eq!(parsed["k"], "v");
}
other => panic!("Expected Json from CAST to JSON, got {:?}", other),
}
}
#[test]
fn test_json_arrow_get_object_field_via_cast() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT CAST('{"name":"Alice","age":30}' AS JSONB)->'name'"#, &[]
).unwrap();
assert_eq!(rows.len(), 1);
match &rows[0].values[0] {
Value::Json(j) => assert_eq!(j, "\"Alice\""),
other => panic!("Expected Json from ->, got {:?}", other),
}
}
#[test]
fn test_json_double_arrow_get_text_via_cast() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT CAST('{"name":"Alice","age":30}' AS JSONB)->>'name'"#, &[]
).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values[0], Value::String("Alice".to_string()));
}
#[test]
fn test_json_arrow_get_numeric_as_text_via_cast() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT CAST('{"age":25}' AS JSONB)->>'age'"#, &[]
).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values[0], Value::String("25".to_string()));
}
#[test]
fn test_json_arrow_array_index_via_cast() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT CAST('["apple","banana","cherry"]' AS JSONB)->1"#, &[]
).unwrap();
assert_eq!(rows.len(), 1);
match &rows[0].values[0] {
Value::Json(j) => assert_eq!(j, "\"banana\""),
other => panic!("Expected Json for array index, got {:?}", other),
}
}
#[test]
fn test_json_arrow_missing_key_via_cast() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT CAST('{"name":"Alice"}' AS JSONB)->'nonexistent'"#, &[]
).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values[0], Value::Null);
}
#[test]
fn test_json_arrow_on_null_column() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE json_null_op (id INT PRIMARY KEY, data JSONB)").unwrap();
db.execute("INSERT INTO json_null_op (id, data) VALUES (1, NULL)").unwrap();
let rows = db.query("SELECT CAST(data AS JSONB)->'key' FROM json_null_op WHERE id = 1", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values[0], Value::Null);
}
#[test]
fn test_json_nested_arrow_chaining_via_cast() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT CAST('{"user":{"address":{"city":"NYC"}}}' AS JSONB)->'user'->'address'->'city'"#,
&[]
).unwrap();
assert_eq!(rows.len(), 1);
match &rows[0].values[0] {
Value::Json(j) => assert_eq!(j, "\"NYC\""),
other => panic!("Expected nested Json, got {:?}", other),
}
}
#[test]
fn test_json_nested_arrow_then_double_arrow() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT CAST('{"user":{"name":"Alice"}}' AS JSONB)->'user'->>'name'"#, &[]
).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values[0], Value::String("Alice".to_string()));
}
#[test]
fn test_json_contains_operator_via_cast() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT CAST('{"name":"Alice","city":"NYC"}' AS JSONB) @> CAST('{"city":"NYC"}' AS JSONB)"#,
&[]
).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values[0], Value::Boolean(true));
let rows = db.query(
r#"SELECT CAST('{"name":"Alice","city":"NYC"}' AS JSONB) @> CAST('{"city":"LA"}' AS JSONB)"#,
&[]
).unwrap();
assert_eq!(rows[0].values[0], Value::Boolean(false));
}
#[test]
fn test_json_contained_by_operator_via_cast() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT CAST('{"a":1}' AS JSONB) <@ CAST('{"a":1,"b":2}' AS JSONB)"#, &[]
).unwrap();
assert_eq!(rows[0].values[0], Value::Boolean(true));
let rows = db.query(
r#"SELECT CAST('{"a":1,"c":3}' AS JSONB) <@ CAST('{"a":1,"b":2}' AS JSONB)"#, &[]
).unwrap();
assert_eq!(rows[0].values[0], Value::Boolean(false));
}
#[test]
fn test_json_contains_nested_via_cast() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT CAST('{"user":{"address":{"city":"NYC"}}}' AS JSONB) @> CAST('{"user":{"address":{"city":"NYC"}}}' AS JSONB)"#,
&[]
).unwrap();
assert_eq!(rows[0].values[0], Value::Boolean(true));
let rows = db.query(
r#"SELECT CAST('{"user":{"address":{"city":"NYC"}}}' AS JSONB) @> CAST('{"user":{"address":{"city":"LA"}}}' AS JSONB)"#,
&[]
).unwrap();
assert_eq!(rows[0].values[0], Value::Boolean(false));
}
#[test]
fn test_json_contains_array_values_via_cast() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT CAST('{"tags":["rust","db","json"]}' AS JSONB) @> CAST('{"tags":["rust"]}' AS JSONB)"#,
&[]
).unwrap();
assert_eq!(rows[0].values[0], Value::Boolean(true));
let rows = db.query(
r#"SELECT CAST('{"tags":["rust","db"]}' AS JSONB) @> CAST('{"tags":["python"]}' AS JSONB)"#,
&[]
).unwrap();
assert_eq!(rows[0].values[0], Value::Boolean(false));
}
#[test]
fn test_json_complex_data_types_via_cast() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let json_str = r#"{"str":"hello","num":42,"flag":true,"arr":[1,2,3],"obj":{"x":1}}"#;
let rows = db.query(
&format!("SELECT CAST('{}' AS JSONB)->>'str'", json_str), &[]
).unwrap();
assert_eq!(rows[0].values[0], Value::String("hello".to_string()));
let rows = db.query(
&format!("SELECT CAST('{}' AS JSONB)->>'num'", json_str), &[]
).unwrap();
assert_eq!(rows[0].values[0], Value::String("42".to_string()));
let rows = db.query(
&format!("SELECT CAST('{}' AS JSONB)->>'flag'", json_str), &[]
).unwrap();
assert_eq!(rows[0].values[0], Value::String("true".to_string()));
let rows = db.query(
&format!("SELECT CAST('{}' AS JSONB)->'arr'", json_str), &[]
).unwrap();
match &rows[0].values[0] {
Value::Json(j) => {
let parsed: serde_json::Value = serde_json::from_str(j).unwrap();
assert!(parsed.is_array());
assert_eq!(parsed.as_array().unwrap().len(), 3);
}
other => panic!("Expected Json for nested array, got {:?}", other),
}
}
#[test]
fn test_json_update_column() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE json_update (id INT PRIMARY KEY, data JSONB)").unwrap();
db.execute(r#"INSERT INTO json_update (id, data) VALUES (1, '{"v":1}')"#).unwrap();
db.execute(r#"UPDATE json_update SET data = '{"v":2,"extra":"added"}' WHERE id = 1"#).unwrap();
let rows = db.query("SELECT data FROM json_update WHERE id = 1", &[]).unwrap();
let parsed = parse_json_value(&rows[0].values[0]);
assert_eq!(parsed["v"], 2);
assert_eq!(parsed["extra"], "added");
}
#[test]
fn test_json_func_jsonb_typeof() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let cases = vec![
(r#"'{"k":"v"}'"#, "object"),
(r#"'[1,2,3]'"#, "array"),
(r#"'"hello"'"#, "string"),
(r#"'42'"#, "number"),
(r#"'true'"#, "boolean"),
(r#"'null'"#, "null"),
];
for (json_literal, expected_type) in cases {
let query = format!("SELECT jsonb_typeof(CAST({} AS JSONB))", json_literal);
let rows = db.query(&query, &[]).unwrap();
assert_eq!(
rows[0].values[0],
Value::String(expected_type.to_string()),
"jsonb_typeof failed for {}",
json_literal
);
}
}
#[test]
fn test_json_func_jsonb_array_length() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query("SELECT jsonb_array_length(CAST('[10,20,30,40]' AS JSONB))", &[]).unwrap();
assert_eq!(rows[0].values[0], Value::Int4(4));
let rows = db.query("SELECT jsonb_array_length(CAST('[]' AS JSONB))", &[]).unwrap();
assert_eq!(rows[0].values[0], Value::Int4(0));
}
#[test]
fn test_json_func_jsonb_extract_path() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT jsonb_extract_path(CAST('{"user":{"address":{"city":"NYC"}}}' AS JSONB), 'user', 'address', 'city')"#,
&[]
).unwrap();
assert_eq!(rows.len(), 1);
match &rows[0].values[0] {
Value::Json(j) => assert_eq!(j, "\"NYC\""),
other => panic!("Expected Json from jsonb_extract_path, got {:?}", other),
}
}
#[test]
fn test_json_func_jsonb_extract_path_text() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT jsonb_extract_path_text(CAST('{"user":{"name":"Alice"}}' AS JSONB), 'user', 'name')"#,
&[]
).unwrap();
assert_eq!(rows[0].values[0], Value::String("Alice".to_string()));
}
#[test]
fn test_json_func_jsonb_extract_path_missing() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT jsonb_extract_path(CAST('{"a":1}' AS JSONB), 'nonexistent', 'path')"#, &[]
).unwrap();
assert_eq!(rows[0].values[0], Value::Null);
}
#[test]
fn test_json_func_jsonb_object_keys() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT jsonb_object_keys(CAST('{"name":"Alice","age":30,"city":"NYC"}' AS JSONB))"#, &[]
).unwrap();
assert_eq!(rows.len(), 1);
match &rows[0].values[0] {
Value::Array(keys) => {
let key_strings: Vec<String> = keys.iter().filter_map(|v| {
if let Value::String(s) = v { Some(s.clone()) } else { None }
}).collect();
assert!(key_strings.contains(&"name".to_string()));
assert!(key_strings.contains(&"age".to_string()));
assert!(key_strings.contains(&"city".to_string()));
assert_eq!(key_strings.len(), 3);
}
other => panic!("Expected Array, got {:?}", other),
}
}
#[test]
fn test_json_func_jsonb_build_object() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query("SELECT jsonb_build_object('name', 'Alice', 'age', 30)", &[]).unwrap();
match &rows[0].values[0] {
Value::Json(j) => {
let parsed: serde_json::Value = serde_json::from_str(j).unwrap();
assert_eq!(parsed["name"], "Alice");
assert_eq!(parsed["age"], 30);
}
other => panic!("Expected Json from jsonb_build_object, got {:?}", other),
}
}
#[test]
fn test_json_func_jsonb_build_array() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query("SELECT jsonb_build_array(1, 'two', 3, true)", &[]).unwrap();
match &rows[0].values[0] {
Value::Json(j) => {
let parsed: serde_json::Value = serde_json::from_str(j).unwrap();
let arr = parsed.as_array().unwrap();
assert_eq!(arr.len(), 4);
assert_eq!(arr[0], 1);
assert_eq!(arr[1], "two");
assert_eq!(arr[2], 3);
assert_eq!(arr[3], true);
}
other => panic!("Expected Json from jsonb_build_array, got {:?}", other),
}
}
#[test]
fn test_json_func_jsonb_strip_nulls() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT jsonb_strip_nulls(CAST('{"a":1,"b":null,"c":"hello","d":null}' AS JSONB))"#, &[]
).unwrap();
match &rows[0].values[0] {
Value::Json(j) => {
let parsed: serde_json::Value = serde_json::from_str(j).unwrap();
assert_eq!(parsed["a"], 1);
assert_eq!(parsed["c"], "hello");
assert!(parsed.get("b").is_none());
assert!(parsed.get("d").is_none());
}
other => panic!("Expected Json from jsonb_strip_nulls, got {:?}", other),
}
}
#[test]
fn test_json_func_jsonb_strip_nulls_nested() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT jsonb_strip_nulls(CAST('{"a":1,"b":{"c":null,"d":2},"e":null}' AS JSONB))"#, &[]
).unwrap();
match &rows[0].values[0] {
Value::Json(j) => {
let parsed: serde_json::Value = serde_json::from_str(j).unwrap();
assert_eq!(parsed["a"], 1);
assert!(parsed.get("e").is_none());
assert!(parsed["b"].get("c").is_none());
assert_eq!(parsed["b"]["d"], 2);
}
other => panic!("Expected Json, got {:?}", other),
}
}
#[test]
fn test_json_func_jsonb_pretty() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT jsonb_pretty(CAST('{"a":1,"b":2}' AS JSONB))"#, &[]
).unwrap();
match &rows[0].values[0] {
Value::String(s) => {
assert!(s.contains('\n'));
let parsed: serde_json::Value = serde_json::from_str(s).unwrap();
assert_eq!(parsed["a"], 1);
assert_eq!(parsed["b"], 2);
}
other => panic!("Expected String from jsonb_pretty, got {:?}", other),
}
}
#[test]
fn test_json_func_jsonb_path_query() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT jsonb_path_query(CAST('{"user":{"name":"Alice"}}' AS JSONB), 'user.name')"#, &[]
).unwrap();
match &rows[0].values[0] {
Value::Json(j) => assert_eq!(j, "\"Alice\""),
other => panic!("Expected Json from jsonb_path_query, got {:?}", other),
}
}
#[test]
fn test_json_func_jsonb_path_query_nested() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT jsonb_path_query(CAST('{"config":{"db":{"host":"localhost","port":5432}}}' AS JSONB), 'config.db.host')"#,
&[]
).unwrap();
match &rows[0].values[0] {
Value::Json(j) => assert_eq!(j, "\"localhost\""),
other => panic!("Expected Json, got {:?}", other),
}
let rows = db.query(
r#"SELECT jsonb_path_query(CAST('{"config":{"db":{"host":"localhost","port":5432}}}' AS JSONB), 'config.db.port')"#,
&[]
).unwrap();
match &rows[0].values[0] {
Value::Json(j) => assert_eq!(j, "5432"),
other => panic!("Expected Json, got {:?}", other),
}
}
#[test]
fn test_json_func_jsonb_path_exists() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT jsonb_path_exists(CAST('{"user":{"name":"Alice"}}' AS JSONB), 'user.name')"#, &[]
).unwrap();
assert_eq!(rows[0].values[0], Value::Boolean(true));
let rows = db.query(
r#"SELECT jsonb_path_exists(CAST('{"user":{"name":"Alice"}}' AS JSONB), 'user.email')"#, &[]
).unwrap();
assert_eq!(rows[0].values[0], Value::Boolean(false));
}
#[test]
fn test_json_func_jsonb_path_query_array() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT jsonb_path_query_array(CAST('{"user":{"name":"Alice"}}' AS JSONB), 'user.name')"#, &[]
).unwrap();
match &rows[0].values[0] {
Value::Array(arr) => {
assert_eq!(arr.len(), 1);
match &arr[0] {
Value::Json(j) => assert_eq!(j, "\"Alice\""),
other => panic!("Expected Json inside array, got {:?}", other),
}
}
other => panic!("Expected Array, got {:?}", other),
}
}
#[test]
fn test_json_func_jsonb_path_query_first() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT jsonb_path_query_first(CAST('{"x":{"y":42}}' AS JSONB), 'x.y')"#, &[]
).unwrap();
match &rows[0].values[0] {
Value::Json(j) => assert_eq!(j, "42"),
other => panic!("Expected Json, got {:?}", other),
}
}
#[test]
fn test_json_func_jsonb_set() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT jsonb_set(CAST('{"name":"Alice","age":30}' AS JSONB), ARRAY['age'], '31')"#, &[]
).unwrap();
match &rows[0].values[0] {
Value::Json(j) => {
let parsed: serde_json::Value = serde_json::from_str(j).unwrap();
assert_eq!(parsed["name"], "Alice");
assert_eq!(parsed["age"], "31");
}
other => panic!("Expected Json from jsonb_set, got {:?}", other),
}
}
#[test]
fn test_json_func_jsonb_set_nested() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT jsonb_set(CAST('{"user":{"name":"Alice","age":30}}' AS JSONB), ARRAY['user','name'], '"Bob"')"#,
&[]
).unwrap();
match &rows[0].values[0] {
Value::Json(j) => {
let parsed: serde_json::Value = serde_json::from_str(j).unwrap();
assert_eq!(parsed["user"]["age"], 30);
}
other => panic!("Expected Json from jsonb_set nested, got {:?}", other),
}
}
#[test]
fn test_json_func_jsonb_concat() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT jsonb_concat(CAST('{"x":1}' AS JSONB), CAST('{"y":2}' AS JSONB))"#, &[]
).unwrap();
match &rows[0].values[0] {
Value::Json(j) => {
let parsed: serde_json::Value = serde_json::from_str(j).unwrap();
assert_eq!(parsed["x"], 1);
assert_eq!(parsed["y"], 2);
}
other => panic!("Expected Json from jsonb_concat, got {:?}", other),
}
}
#[test]
fn test_json_func_jsonb_concat_overwrites() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT jsonb_concat(CAST('{"x":1,"y":2}' AS JSONB), CAST('{"y":99,"z":3}' AS JSONB))"#, &[]
).unwrap();
match &rows[0].values[0] {
Value::Json(j) => {
let parsed: serde_json::Value = serde_json::from_str(j).unwrap();
assert_eq!(parsed["x"], 1);
assert_eq!(parsed["y"], 99);
assert_eq!(parsed["z"], 3);
}
other => panic!("Expected Json from jsonb_concat, got {:?}", other),
}
}
#[test]
fn test_json_func_jsonb_delete() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT jsonb_delete(CAST('{"a":1,"b":2,"c":3}' AS JSONB), ARRAY['b'])"#, &[]
).unwrap();
match &rows[0].values[0] {
Value::Json(j) => {
let parsed: serde_json::Value = serde_json::from_str(j).unwrap();
assert_eq!(parsed["a"], 1);
assert_eq!(parsed["c"], 3);
assert!(parsed.get("b").is_none());
}
other => panic!("Expected Json from jsonb_delete, got {:?}", other),
}
}
#[test]
fn test_json_func_jsonb_each() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT jsonb_each(CAST('{"x":10,"y":20}' AS JSONB))"#, &[]
).unwrap();
match &rows[0].values[0] {
Value::Array(pairs) => {
assert_eq!(pairs.len(), 4);
let has_x = pairs.iter().any(|v| matches!(v, Value::String(s) if s == "x"));
let has_y = pairs.iter().any(|v| matches!(v, Value::String(s) if s == "y"));
assert!(has_x);
assert!(has_y);
}
other => panic!("Expected Array from jsonb_each, got {:?}", other),
}
}
#[test]
fn test_json_func_jsonb_each_text() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT jsonb_each_text(CAST('{"name":"Alice","age":30}' AS JSONB))"#, &[]
).unwrap();
match &rows[0].values[0] {
Value::Array(pairs) => {
for v in pairs {
assert!(matches!(v, Value::String(_)), "Expected text, got {:?}", v);
}
}
other => panic!("Expected Array from jsonb_each_text, got {:?}", other),
}
}
#[test]
fn test_json_func_jsonb_array_elements() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT jsonb_array_elements(CAST('["first","second","third"]' AS JSONB))"#, &[]
).unwrap();
match &rows[0].values[0] {
Value::Json(j) => assert_eq!(j, "\"first\""),
other => panic!("Expected Json from jsonb_array_elements, got {:?}", other),
}
}
#[test]
fn test_json_func_jsonb_array_elements_text() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT jsonb_array_elements_text(CAST('["hello","world"]' AS JSONB))"#, &[]
).unwrap();
assert_eq!(rows[0].values[0], Value::String("hello".to_string()));
}
#[test]
fn test_json_agg_function() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE json_agg_t (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute("INSERT INTO json_agg_t (id, name) VALUES (1, 'Alice')").unwrap();
db.execute("INSERT INTO json_agg_t (id, name) VALUES (2, 'Bob')").unwrap();
db.execute("INSERT INTO json_agg_t (id, name) VALUES (3, 'Charlie')").unwrap();
let rows = db.query("SELECT json_agg(name) FROM json_agg_t", &[]).unwrap();
match &rows[0].values[0] {
Value::Json(j) => {
let parsed: serde_json::Value = serde_json::from_str(j).unwrap();
let arr = parsed.as_array().unwrap();
assert_eq!(arr.len(), 3);
let strings: Vec<&str> = arr.iter().filter_map(|v| v.as_str()).collect();
assert!(strings.contains(&"Alice"));
assert!(strings.contains(&"Bob"));
assert!(strings.contains(&"Charlie"));
}
other => panic!("Expected Json from json_agg, got {:?}", other),
}
}
#[test]
fn test_json_func_null_handling() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query("SELECT jsonb_extract_path(NULL, 'key')", &[]).unwrap();
assert_eq!(rows[0].values[0], Value::Null);
let rows = db.query("SELECT jsonb_pretty(NULL)", &[]).unwrap();
assert_eq!(rows[0].values[0], Value::Null);
}
#[test]
fn test_json_empty_object_and_array() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query("SELECT jsonb_typeof(CAST('{}' AS JSONB))", &[]).unwrap();
assert_eq!(rows[0].values[0], Value::String("object".to_string()));
let rows = db.query("SELECT jsonb_typeof(CAST('[]' AS JSONB))", &[]).unwrap();
assert_eq!(rows[0].values[0], Value::String("array".to_string()));
let rows = db.query("SELECT jsonb_array_length(CAST('[]' AS JSONB))", &[]).unwrap();
assert_eq!(rows[0].values[0], Value::Int4(0));
let rows = db.query("SELECT jsonb_object_keys(CAST('{}' AS JSONB))", &[]).unwrap();
match &rows[0].values[0] {
Value::Array(keys) => assert_eq!(keys.len(), 0),
other => panic!("Expected empty Array, got {:?}", other),
}
}
#[test]
fn test_json_deeply_nested_via_cast() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT CAST('{"a":{"b":{"c":{"d":{"e":"deep"}}}}}' AS JSONB)->'a'->'b'->'c'->'d'->>'e'"#,
&[]
).unwrap();
assert_eq!(rows[0].values[0], Value::String("deep".to_string()));
}
#[test]
fn test_json_double_arrow_on_null_json_field() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT CAST('{"a":null}' AS JSONB)->>'a'"#, &[]
).unwrap();
assert_eq!(rows[0].values[0], Value::String("null".to_string()));
}
#[test]
fn test_json_contains_false_cases() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
r#"SELECT CAST('{"a":1,"b":2}' AS JSONB) @> CAST('{"a":99}' AS JSONB)"#, &[]
).unwrap();
assert_eq!(rows[0].values[0], Value::Boolean(false));
let rows = db.query(
r#"SELECT CAST('{"a":1,"b":2}' AS JSONB) @> CAST('{"z":1}' AS JSONB)"#, &[]
).unwrap();
assert_eq!(rows[0].values[0], Value::Boolean(false));
}
#[test]
fn test_json_storage_roundtrip_preserves_data() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE json_rt (id INT PRIMARY KEY, data JSONB)").unwrap();
let test_cases = vec![
(1, r#"{"nested":{"a":1,"b":[2,3]}}"#),
(2, r#"[1,"two",true,null]"#),
(3, r#""just a string""#),
(4, r#"42"#),
(5, r#"true"#),
];
for (id, json) in &test_cases {
db.execute(&format!("INSERT INTO json_rt (id, data) VALUES ({}, '{}')", id, json)).unwrap();
}
let rows = db.query("SELECT id, data FROM json_rt ORDER BY id", &[]).unwrap();
assert_eq!(rows.len(), 5);
for (i, (_, expected_json)) in test_cases.iter().enumerate() {
let parsed = parse_json_value(&rows[i].values[1]);
let expected: serde_json::Value = serde_json::from_str(expected_json).unwrap();
assert_eq!(parsed, expected, "Round-trip failed for row {}", i + 1);
}
}
#[test]
fn test_json_delete_rows_from_json_table() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE json_del (id INT PRIMARY KEY, data JSONB)").unwrap();
db.execute(r#"INSERT INTO json_del (id, data) VALUES (1, '{"x":1}')"#).unwrap();
db.execute(r#"INSERT INTO json_del (id, data) VALUES (2, '{"x":2}')"#).unwrap();
db.execute(r#"INSERT INTO json_del (id, data) VALUES (3, '{"x":3}')"#).unwrap();
db.execute("DELETE FROM json_del WHERE id = 2").unwrap();
let rows = db.query("SELECT id FROM json_del ORDER BY id", &[]).unwrap();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].values[0], Value::Int4(1));
assert_eq!(rows[1].values[0], Value::Int4(3));
}
#[test]
fn test_json_build_object_then_arrow() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT jsonb_build_object('name', 'Alice', 'age', 30)->>'name'", &[]
).unwrap();
assert_eq!(rows[0].values[0], Value::String("Alice".to_string()));
}
#[test]
fn test_json_build_array_then_index() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query("SELECT jsonb_build_array(10, 20, 30)->1", &[]).unwrap();
match &rows[0].values[0] {
Value::Json(j) => assert_eq!(j, "20"),
other => panic!("Expected Json, got {:?}", other),
}
}
#[test]
fn test_json_typeof_on_null() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query("SELECT jsonb_typeof(NULL)", &[]).unwrap();
assert_eq!(rows[0].values[0], Value::String("null".to_string()));
}
#[test]
fn test_json_mixed_with_regular_columns() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE json_mixed (id INT PRIMARY KEY, name TEXT, meta JSONB)").unwrap();
db.execute(r#"INSERT INTO json_mixed (id, name, meta) VALUES (1, 'Alice', '{"role":"admin"}')"#).unwrap();
db.execute(r#"INSERT INTO json_mixed (id, name, meta) VALUES (2, 'Bob', '{"role":"user"}')"#).unwrap();
let rows = db.query("SELECT name, meta FROM json_mixed ORDER BY id", &[]).unwrap();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].values[0], Value::String("Alice".to_string()));
let meta0 = parse_json_value(&rows[0].values[1]);
assert_eq!(meta0["role"], "admin");
assert_eq!(rows[1].values[0], Value::String("Bob".to_string()));
let meta1 = parse_json_value(&rows[1].values[1]);
assert_eq!(meta1["role"], "user");
}
#[test]
fn test_json_large_document() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE json_large (id INT PRIMARY KEY, data JSONB)").unwrap();
let mut json_obj = serde_json::Map::new();
for i in 0..50 {
json_obj.insert(format!("key_{}", i), serde_json::json!(i));
}
let json_str = serde_json::Value::Object(json_obj).to_string();
db.execute(&format!("INSERT INTO json_large (id, data) VALUES (1, '{}')", json_str)).unwrap();
let rows = db.query("SELECT data FROM json_large WHERE id = 1", &[]).unwrap();
let parsed = parse_json_value(&rows[0].values[0]);
assert_eq!(parsed["key_25"], 25);
assert_eq!(parsed["key_0"], 0);
assert_eq!(parsed["key_49"], 49);
}
#[test]
fn test_json_unicode_content() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE json_uni (id INT PRIMARY KEY, data JSONB)").unwrap();
db.execute(r#"INSERT INTO json_uni (id, data) VALUES (1, '{"greeting":"Bonjour"}')"#).unwrap();
let rows = db.query("SELECT data FROM json_uni WHERE id = 1", &[]).unwrap();
let parsed = parse_json_value(&rows[0].values[0]);
assert_eq!(parsed["greeting"], "Bonjour");
}
#[test]
fn test_recursive_cte_simple_counting() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let sql = "\
WITH RECURSIVE cnt(n) AS ( \
SELECT 1 \
UNION ALL \
SELECT n + 1 FROM cnt WHERE n < 10 \
) \
SELECT n FROM cnt";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 10, "Expected 10 rows for counting 1..10, got {}", rows.len());
for (i, row) in rows.iter().enumerate() {
let val = row.get(0).unwrap();
let expected = (i as i32) + 1;
assert_eq!(
val, &Value::Int4(expected),
"Row {} should be {}, got {:?}", i, expected, val
);
}
}
Err(e) => {
panic!(
"Recursive CTE simple counting failed with error: {}. \
This indicates recursive CTEs may not be supported.",
e
);
}
}
}
#[test]
fn test_recursive_cte_tree_traversal() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE rc_employees (id INT PRIMARY KEY, name TEXT, manager_id INT)").unwrap();
db.execute("INSERT INTO rc_employees VALUES (1, 'CEO', NULL)").unwrap();
db.execute("INSERT INTO rc_employees VALUES (2, 'VP', 1)").unwrap();
db.execute("INSERT INTO rc_employees VALUES (3, 'Director', 2)").unwrap();
db.execute("INSERT INTO rc_employees VALUES (4, 'Manager', 3)").unwrap();
db.execute("INSERT INTO rc_employees VALUES (5, 'Staff', 4)").unwrap();
let sql = "\
WITH RECURSIVE reports(id, name, manager_id) AS ( \
SELECT id, name, manager_id FROM rc_employees WHERE id = 2 \
UNION ALL \
SELECT e.id, e.name, e.manager_id \
FROM rc_employees e \
JOIN reports r ON e.manager_id = r.id \
) \
SELECT id, name FROM reports ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 4, "Expected 4 reports under VP, got {}", rows.len());
let ids: Vec<&Value> = rows.iter().map(|r| r.get(0).unwrap()).collect();
assert_eq!(ids[0], &Value::Int4(2), "First should be VP (id=2)");
assert_eq!(ids[1], &Value::Int4(3), "Second should be Director (id=3)");
assert_eq!(ids[2], &Value::Int4(4), "Third should be Manager (id=4)");
assert_eq!(ids[3], &Value::Int4(5), "Fourth should be Staff (id=5)");
}
Err(e) => {
let err_msg = e.to_string();
assert!(
err_msg.contains("not found") ||
err_msg.contains("not implemented") ||
err_msg.contains("not yet") ||
err_msg.contains("recursive") ||
err_msg.contains("ambiguous") ||
err_msg.contains("column"),
"Unexpected error in recursive CTE tree traversal: {}", err_msg
);
}
}
}
#[test]
fn test_recursive_cte_fibonacci() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let sql = "\
WITH RECURSIVE fib(a, b) AS ( \
SELECT 0, 1 \
UNION ALL \
SELECT b, a + b FROM fib WHERE b < 100 \
) \
SELECT a FROM fib";
match db.query(sql, &[]) {
Ok(rows) => {
let expected_deduped: Vec<i32> = vec![0, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89];
let expected_full: Vec<i32> = vec![0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89];
if rows.len() == expected_deduped.len() {
for (i, (row, exp)) in rows.iter().zip(expected_deduped.iter()).enumerate() {
let val = row.get(0).unwrap();
assert_eq!(val, &Value::Int4(*exp),
"Fibonacci (deduped) row {} should be {}, got {:?}", i, exp, val);
}
} else if rows.len() == expected_full.len() {
for (i, (row, exp)) in rows.iter().zip(expected_full.iter()).enumerate() {
let val = row.get(0).unwrap();
assert_eq!(val, &Value::Int4(*exp),
"Fibonacci row {} should be {}, got {:?}", i, exp, val);
}
} else {
panic!("Expected 11 (deduped) or 12 (full) Fibonacci numbers, got {}", rows.len());
}
}
Err(e) => {
let err_msg = e.to_string();
assert!(
err_msg.contains("not found") ||
err_msg.contains("not implemented") ||
err_msg.contains("not yet") ||
err_msg.contains("column") ||
err_msg.contains("type"),
"Unexpected error in recursive CTE Fibonacci: {}", err_msg
);
}
}
}
#[test]
fn test_recursive_cte_depth_limit_via_where() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let sql = "\
WITH RECURSIVE nums(n) AS ( \
SELECT 1 \
UNION ALL \
SELECT n + 1 FROM nums WHERE n < 5 \
) \
SELECT n FROM nums";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 5, "Expected 5 rows for counting 1..5, got {}", rows.len());
for (i, row) in rows.iter().enumerate() {
let val = row.get(0).unwrap();
let expected = (i as i32) + 1;
assert_eq!(
val, &Value::Int4(expected),
"Row {} should be {}, got {:?}", i, expected, val
);
}
}
Err(e) => {
panic!(
"Recursive CTE with WHERE depth limit failed: {}",
e
);
}
}
}
#[test]
fn test_recursive_cte_non_recursive_basic() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let sql = "WITH summary AS (SELECT 42 AS answer) SELECT answer FROM summary";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 1, "Non-recursive CTE should return 1 row");
let val = rows[0].get(0).unwrap();
assert_eq!(
val, &Value::Int4(42),
"Non-recursive CTE should return 42, got {:?}", val
);
}
Err(e) => {
panic!(
"Non-recursive CTE failed: {}. Basic WITH support should work.",
e
);
}
}
}
#[test]
fn test_recursive_cte_non_recursive_table_data() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE rc_products (id INT, name TEXT, price INT)").unwrap();
db.execute("INSERT INTO rc_products VALUES (1, 'Widget', 10)").unwrap();
db.execute("INSERT INTO rc_products VALUES (2, 'Gadget', 25)").unwrap();
db.execute("INSERT INTO rc_products VALUES (3, 'Doohickey', 5)").unwrap();
let sql = "\
WITH expensive AS ( \
SELECT id, name, price FROM rc_products WHERE price > 8 \
) \
SELECT name FROM expensive ORDER BY name";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 2, "Expected 2 expensive products, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("Gadget".to_string()));
assert_eq!(rows[1].get(0).unwrap(), &Value::String("Widget".to_string()));
}
Err(e) => {
panic!("Non-recursive CTE with table data failed: {}", e);
}
}
}
#[test]
fn test_recursive_cte_join_with_table() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE rc_items (id INT PRIMARY KEY, label TEXT)").unwrap();
db.execute("INSERT INTO rc_items VALUES (1, 'alpha')").unwrap();
db.execute("INSERT INTO rc_items VALUES (2, 'beta')").unwrap();
db.execute("INSERT INTO rc_items VALUES (3, 'gamma')").unwrap();
let sql = "\
WITH RECURSIVE nums(n) AS ( \
SELECT 1 \
UNION ALL \
SELECT n + 1 FROM nums WHERE n < 5 \
) \
SELECT nums.n, rc_items.label \
FROM nums \
JOIN rc_items ON nums.n = rc_items.id \
ORDER BY nums.n";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 3, "Expected 3 matched rows, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("alpha".to_string()));
assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(2));
assert_eq!(rows[1].get(1).unwrap(), &Value::String("beta".to_string()));
assert_eq!(rows[2].get(0).unwrap(), &Value::Int4(3));
assert_eq!(rows[2].get(1).unwrap(), &Value::String("gamma".to_string()));
}
Err(e) => {
let err_msg = e.to_string();
assert!(
err_msg.contains("not found") ||
err_msg.contains("not implemented") ||
err_msg.contains("ambiguous") ||
err_msg.contains("table") ||
err_msg.contains("column") ||
err_msg.contains("type"),
"Unexpected error in recursive CTE JOIN: {}", err_msg
);
}
}
}
#[test]
fn test_recursive_cte_empty_base_case() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let sql = "\
WITH RECURSIVE empty(n) AS ( \
SELECT 1 WHERE 1 = 0 \
UNION ALL \
SELECT n + 1 FROM empty WHERE n < 10 \
) \
SELECT n FROM empty";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(
rows.len(), 0,
"Empty base case should produce 0 rows, got {}", rows.len()
);
}
Err(e) => {
let err_msg = e.to_string();
assert!(
err_msg.contains("not found") ||
err_msg.contains("not implemented") ||
err_msg.contains("column") ||
err_msg.contains("type") ||
err_msg.contains("empty"),
"Unexpected error in empty base case CTE: {}", err_msg
);
}
}
}
#[test]
fn test_recursive_cte_union_vs_union_all() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let sql_union_all = "\
WITH RECURSIVE cnt(n) AS ( \
SELECT 1 \
UNION ALL \
SELECT n + 1 FROM cnt WHERE n < 5 \
) \
SELECT n FROM cnt";
let result_all = db.query(sql_union_all, &[]);
let sql_union = "\
WITH RECURSIVE cnt2(n) AS ( \
SELECT 1 \
UNION \
SELECT n + 1 FROM cnt2 WHERE n < 5 \
) \
SELECT n FROM cnt2";
let result_distinct = db.query(sql_union, &[]);
match (result_all, result_distinct) {
(Ok(rows_all), Ok(rows_distinct)) => {
assert_eq!(
rows_all.len(), 5,
"UNION ALL counting 1..5 should produce 5 rows, got {}", rows_all.len()
);
assert!(
rows_distinct.len() <= rows_all.len(),
"UNION should produce <= rows than UNION ALL ({} vs {})",
rows_distinct.len(), rows_all.len()
);
for (i, row) in rows_all.iter().enumerate() {
let val = row.get(0).unwrap();
let expected = (i as i32) + 1;
assert_eq!(
val, &Value::Int4(expected),
"UNION ALL row {} should be {}, got {:?}", i, expected, val
);
}
}
(Ok(_), Err(e)) => {
let err_msg = e.to_string();
assert!(
err_msg.contains("not implemented") ||
err_msg.contains("not found") ||
err_msg.contains("UNION") ||
err_msg.contains("recursive"),
"Unexpected error in UNION recursive CTE: {}", err_msg
);
}
(Err(e_all), _) => {
panic!(
"UNION ALL recursive CTE failed: {}. Basic recursive CTE should work.",
e_all
);
}
}
}
#[test]
fn test_recursive_cte_with_sum_aggregate() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let sql = "\
WITH RECURSIVE nums(n) AS ( \
SELECT 1 \
UNION ALL \
SELECT n + 1 FROM nums WHERE n < 10 \
) \
SELECT SUM(n) FROM nums";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 1, "Aggregate should return 1 row");
let val = rows[0].get(0).unwrap();
match val {
Value::Int8(v) => assert_eq!(*v, 55, "SUM(1..10) should be 55, got {}", v),
Value::Int4(v) => assert_eq!(*v, 55, "SUM(1..10) should be 55, got {}", v),
Value::Numeric(v) => assert_eq!(v, "55", "SUM(1..10) should be 55, got {}", v),
Value::Float8(v) => {
assert!((*v - 55.0).abs() < 0.001,
"SUM(1..10) should be 55.0, got {}", v);
}
other => panic!("SUM returned unexpected type: {:?}", other),
}
}
Err(e) => {
let err_msg = e.to_string();
assert!(
err_msg.contains("not found") ||
err_msg.contains("not implemented") ||
err_msg.contains("aggregate") ||
err_msg.contains("column"),
"Unexpected error in recursive CTE with aggregate: {}", err_msg
);
}
}
}
#[test]
fn test_recursive_cte_with_limit() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let sql = "\
WITH RECURSIVE nums(n) AS ( \
SELECT 1 \
UNION ALL \
SELECT n + 1 FROM nums WHERE n < 100 \
) \
SELECT n FROM nums LIMIT 5";
let rows = db.query(sql, &[]).unwrap();
assert_eq!(rows.len(), 5, "LIMIT 5 should return 5 rows, got {}", rows.len());
for (i, row) in rows.iter().enumerate() {
let val = row.get(0).unwrap();
let expected = (i as i32) + 1;
assert_eq!(
val, &Value::Int4(expected),
"LIMIT row {} should be {}, got {:?}", i, expected, val
);
}
}
#[test]
fn test_recursive_cte_single_row_termination() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let sql = "\
WITH RECURSIVE one(n) AS ( \
SELECT 100 \
UNION ALL \
SELECT n + 1 FROM one WHERE n < 100 \
) \
SELECT n FROM one";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(
rows.len(), 1,
"Should produce exactly 1 row (base case only), got {}", rows.len()
);
assert_eq!(
rows[0].get(0).unwrap(), &Value::Int4(100),
"Single row should be 100"
);
}
Err(e) => {
panic!("Recursive CTE single-row termination failed: {}", e);
}
}
}
#[test]
fn test_recursive_cte_non_recursive_multiple_ctes() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE rc_multi (id INT, category TEXT, amount INT)").unwrap();
db.execute("INSERT INTO rc_multi VALUES (1, 'A', 10)").unwrap();
db.execute("INSERT INTO rc_multi VALUES (2, 'B', 20)").unwrap();
db.execute("INSERT INTO rc_multi VALUES (3, 'A', 30)").unwrap();
let sql = "\
WITH \
cat_a AS (SELECT id, amount FROM rc_multi WHERE category = 'A'), \
cat_b AS (SELECT id, amount FROM rc_multi WHERE category = 'B') \
SELECT cat_a.id, cat_a.amount FROM cat_a ORDER BY cat_a.id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 2, "Category A has 2 items, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
assert_eq!(rows[0].get(1).unwrap(), &Value::Int4(10));
assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(3));
assert_eq!(rows[1].get(1).unwrap(), &Value::Int4(30));
}
Err(e) => {
let err_msg = e.to_string();
assert!(
err_msg.contains("not found") ||
err_msg.contains("not implemented") ||
err_msg.contains("table") ||
err_msg.contains("CTE"),
"Unexpected error in multiple CTEs: {}", err_msg
);
}
}
}
#[test]
fn test_recursive_cte_graph_path() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE rc_edges (src INT, dst INT)").unwrap();
db.execute("INSERT INTO rc_edges VALUES (1, 2)").unwrap();
db.execute("INSERT INTO rc_edges VALUES (2, 3)").unwrap();
db.execute("INSERT INTO rc_edges VALUES (3, 4)").unwrap();
db.execute("INSERT INTO rc_edges VALUES (1, 5)").unwrap();
let sql = "\
WITH RECURSIVE reachable(node) AS ( \
SELECT 1 \
UNION ALL \
SELECT e.dst FROM rc_edges e JOIN reachable r ON e.src = r.node \
) \
SELECT node FROM reachable ORDER BY node";
match db.query(sql, &[]) {
Ok(rows) => {
let nodes: Vec<i64> = rows.iter().map(|r| {
match r.get(0).unwrap() {
Value::Int8(v) => *v,
Value::Int4(v) => i64::from(*v),
other => panic!("Unexpected node type: {:?}", other),
}
}).collect();
assert!(nodes.contains(&1), "Should contain starting node 1");
assert!(nodes.contains(&2), "Should contain node 2");
assert!(nodes.contains(&3), "Should contain node 3");
assert!(nodes.contains(&4), "Should contain node 4");
assert!(nodes.contains(&5), "Should contain node 5");
assert_eq!(nodes.len(), 5,
"Should have exactly 5 distinct reachable nodes, got {:?}", nodes);
}
Err(e) => {
let err_msg = e.to_string();
assert!(
err_msg.contains("not found") ||
err_msg.contains("not implemented") ||
err_msg.contains("ambiguous") ||
err_msg.contains("column") ||
err_msg.contains("table"),
"Unexpected error in recursive CTE graph traversal: {}", err_msg
);
}
}
}
#[test]
fn test_recursive_cte_string_concatenation() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let sql = "\
WITH RECURSIVE strs(s, len) AS ( \
SELECT 'a', 1 \
UNION ALL \
SELECT s || 'a', len + 1 FROM strs WHERE len < 5 \
) \
SELECT s, len FROM strs ORDER BY len";
let rows = db.query(sql, &[]).unwrap();
assert_eq!(rows.len(), 5, "Expected 5 rows, got {}", rows.len());
let expected = ["a", "aa", "aaa", "aaaa", "aaaaa"];
for (i, row) in rows.iter().enumerate() {
let s = row.get(0).unwrap();
assert_eq!(
s, &Value::String(expected[i].to_string()),
"Row {} should be '{}', got {:?}", i, expected[i], s
);
}
}
#[test]
fn test_recursive_cte_powers_of_two() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let sql = "\
WITH RECURSIVE powers(n) AS ( \
SELECT 1 \
UNION ALL \
SELECT n * 2 FROM powers WHERE n < 512 \
) \
SELECT n FROM powers";
match db.query(sql, &[]) {
Ok(rows) => {
let expected: Vec<i32> = vec![1, 2, 4, 8, 16, 32, 64, 128, 256, 512];
assert_eq!(
rows.len(), expected.len(),
"Expected {} powers of 2, got {}", expected.len(), rows.len()
);
for (i, (row, exp)) in rows.iter().zip(expected.iter()).enumerate() {
let val = row.get(0).unwrap();
assert_eq!(
val, &Value::Int4(*exp),
"Power of 2 row {} should be {}, got {:?}", i, exp, val
);
}
}
Err(e) => {
panic!("Recursive CTE powers of two failed: {}", e);
}
}
}
#[test]
fn test_recursive_cte_with_count_aggregate() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let sql = "\
WITH RECURSIVE nums(n) AS ( \
SELECT 1 \
UNION ALL \
SELECT n + 1 FROM nums WHERE n < 20 \
) \
SELECT COUNT(*) FROM nums";
let rows = db.query(sql, &[]).unwrap();
assert_eq!(rows.len(), 1, "COUNT should return 1 row");
let val = rows[0].get(0).unwrap();
match val {
Value::Int8(v) => assert_eq!(*v, 20, "COUNT(*) should be 20, got {}", v),
Value::Int4(v) => assert_eq!(*v, 20, "COUNT(*) should be 20, got {}", v),
other => panic!("COUNT returned unexpected type: {:?}", other),
}
}
#[test]
fn test_recursive_cte_non_recursive_column_alias() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let sql = "WITH t(x, y) AS (SELECT 10, 20) SELECT x, y FROM t";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 1, "Should return 1 row");
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(10));
assert_eq!(rows[0].get(1).unwrap(), &Value::Int4(20));
}
Err(e) => {
let err_msg = e.to_string();
assert!(
err_msg.contains("not found") ||
err_msg.contains("not implemented") ||
err_msg.contains("column") ||
err_msg.contains("alias"),
"Unexpected error in CTE column alias: {}", err_msg
);
}
}
}
#[test]
fn test_recursive_cte_descending_countdown() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let sql = "\
WITH RECURSIVE countdown(n) AS ( \
SELECT 10 \
UNION ALL \
SELECT n - 1 FROM countdown WHERE n > 1 \
) \
SELECT n FROM countdown";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 10, "Expected 10 rows for 10..1, got {}", rows.len());
for (i, row) in rows.iter().enumerate() {
let val = row.get(0).unwrap();
let expected = 10 - i as i32;
assert_eq!(
val, &Value::Int4(expected),
"Countdown row {} should be {}, got {:?}", i, expected, val
);
}
}
Err(e) => {
panic!("Recursive CTE countdown failed: {}", e);
}
}
}
#[test]
fn test_set_op_union_all_basic() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 1 AS id, 'alice' AS name \
UNION ALL \
SELECT 2, 'bob'",
&[],
).unwrap();
assert_eq!(rows.len(), 2, "UNION ALL of two single-row SELECTs should produce 2 rows");
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("alice".to_string()));
assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(2));
assert_eq!(rows[1].get(1).unwrap(), &Value::String("bob".to_string()));
}
#[test]
fn test_set_op_union_all_preserves_duplicates() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 1 AS v UNION ALL SELECT 1 UNION ALL SELECT 1",
&[],
).unwrap();
assert_eq!(rows.len(), 3, "UNION ALL of three identical rows should produce 3 rows");
for row in &rows {
assert_eq!(row.get(0).unwrap(), &Value::Int4(1));
}
}
#[test]
fn test_set_op_union_distinct_removes_duplicates() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 1 AS v UNION SELECT 1 UNION SELECT 2",
&[],
).unwrap();
assert_eq!(rows.len(), 2, "UNION of (1, 1, 2) should produce 2 distinct rows");
let mut vals: Vec<i32> = rows.iter()
.map(|r| match r.get(0).unwrap() {
Value::Int4(n) => *n,
other => panic!("Expected Int4, got {:?}", other),
})
.collect();
vals.sort();
assert_eq!(vals, vec![1, 2]);
}
#[test]
fn test_set_op_union_vs_union_all_difference() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows_all = db.query(
"SELECT 1 AS v UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 2",
&[],
).unwrap();
let rows_distinct = db.query(
"SELECT 1 AS v UNION SELECT 1 UNION SELECT 2 UNION SELECT 2",
&[],
).unwrap();
assert_eq!(rows_all.len(), 4, "UNION ALL should produce 4 rows");
assert_eq!(rows_distinct.len(), 2, "UNION (distinct) should produce 2 rows");
}
#[test]
fn test_set_op_intersect_basic() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 1 AS v INTERSECT SELECT 1",
&[],
).unwrap();
assert_eq!(rows.len(), 1, "INTERSECT of (1) and (1) should produce 1 row");
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
}
#[test]
fn test_set_op_intersect_no_overlap() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 1 AS v INTERSECT SELECT 2",
&[],
).unwrap();
assert_eq!(rows.len(), 0, "INTERSECT of (1) and (2) should produce 0 rows");
}
#[test]
fn test_set_op_intersect_with_multiple_values() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let sql = "\
SELECT * FROM (SELECT 1 AS v UNION ALL SELECT 2 UNION ALL SELECT 3) AS a \
INTERSECT \
SELECT * FROM (SELECT 2 AS v UNION ALL SELECT 3 UNION ALL SELECT 4) AS b";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 2, "INTERSECT of (1,2,3) and (2,3,4) should produce 2 rows");
let mut vals: Vec<i32> = rows.iter()
.map(|r| match r.get(0).unwrap() {
Value::Int4(n) => *n,
other => panic!("Expected Int4, got {:?}", other),
})
.collect();
vals.sort();
assert_eq!(vals, vec![2, 3]);
}
Err(e) => {
println!("Subquery-based INTERSECT not supported: {}", e);
db.execute("CREATE TABLE int_left (v INT)").unwrap();
db.execute("INSERT INTO int_left VALUES (1), (2), (3)").unwrap();
db.execute("CREATE TABLE int_right (v INT)").unwrap();
db.execute("INSERT INTO int_right VALUES (2), (3), (4)").unwrap();
let rows = db.query(
"SELECT v FROM int_left INTERSECT SELECT v FROM int_right",
&[],
).unwrap();
assert_eq!(rows.len(), 2, "INTERSECT of (1,2,3) and (2,3,4) should produce 2 rows");
let mut vals: Vec<i32> = rows.iter()
.map(|r| match r.get(0).unwrap() {
Value::Int4(n) => *n,
other => panic!("Expected Int4, got {:?}", other),
})
.collect();
vals.sort();
assert_eq!(vals, vec![2, 3]);
}
}
}
#[test]
fn test_set_op_except_basic() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE exc_left (v INT)").unwrap();
db.execute("INSERT INTO exc_left VALUES (1), (2), (3)").unwrap();
db.execute("CREATE TABLE exc_right (v INT)").unwrap();
db.execute("INSERT INTO exc_right VALUES (2), (3)").unwrap();
let rows = db.query(
"SELECT v FROM exc_left EXCEPT SELECT v FROM exc_right",
&[],
).unwrap();
assert_eq!(rows.len(), 1, "EXCEPT of (1,2,3) minus (2,3) should produce 1 row");
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
}
#[test]
fn test_set_op_except_all_rows_removed() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 1 AS v EXCEPT SELECT 1",
&[],
).unwrap();
assert_eq!(rows.len(), 0, "EXCEPT of (1) minus (1) should produce 0 rows");
}
#[test]
fn test_set_op_except_no_overlap() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 1 AS v EXCEPT SELECT 2",
&[],
).unwrap();
assert_eq!(rows.len(), 1, "EXCEPT of (1) minus (2) should produce 1 row");
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
}
#[test]
fn test_set_op_except_all_with_duplicates() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE ea_left (v INT)").unwrap();
db.execute("INSERT INTO ea_left VALUES (1), (1), (1), (2)").unwrap();
db.execute("CREATE TABLE ea_right (v INT)").unwrap();
db.execute("INSERT INTO ea_right VALUES (1)").unwrap();
match db.query(
"SELECT v FROM ea_left EXCEPT ALL SELECT v FROM ea_right",
&[],
) {
Ok(rows) => {
assert_eq!(rows.len(), 3, "EXCEPT ALL should produce 3 rows (two 1s and one 2)");
let mut vals: Vec<i32> = rows.iter()
.map(|r| match r.get(0).unwrap() {
Value::Int4(n) => *n,
other => panic!("Expected Int4, got {:?}", other),
})
.collect();
vals.sort();
assert_eq!(vals, vec![1, 1, 2], "Should have two 1s and one 2");
}
Err(e) => {
println!("EXCEPT ALL not supported: {}", e);
}
}
}
#[test]
fn test_set_op_intersect_all_with_duplicates() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE ia_left (v INT)").unwrap();
db.execute("INSERT INTO ia_left VALUES (1), (1), (1), (2)").unwrap();
db.execute("CREATE TABLE ia_right (v INT)").unwrap();
db.execute("INSERT INTO ia_right VALUES (1), (1)").unwrap();
match db.query(
"SELECT v FROM ia_left INTERSECT ALL SELECT v FROM ia_right",
&[],
) {
Ok(rows) => {
assert_eq!(rows.len(), 2, "INTERSECT ALL should produce 2 rows (min of 3,2 = 2)");
for row in &rows {
assert_eq!(row.get(0).unwrap(), &Value::Int4(1),
"All INTERSECT ALL results should be 1");
}
}
Err(e) => {
println!("INTERSECT ALL not supported: {}", e);
}
}
}
#[test]
fn test_set_op_multiple_unions_chained() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 1 AS v UNION SELECT 2 UNION SELECT 3",
&[],
).unwrap();
assert_eq!(rows.len(), 3, "Three-way UNION of distinct values should produce 3 rows");
let mut vals: Vec<i32> = rows.iter()
.map(|r| match r.get(0).unwrap() {
Value::Int4(n) => *n,
other => panic!("Expected Int4, got {:?}", other),
})
.collect();
vals.sort();
assert_eq!(vals, vec![1, 2, 3]);
}
#[test]
fn test_set_op_multiple_union_all_chained() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 10 AS v UNION ALL SELECT 20 UNION ALL SELECT 30 UNION ALL SELECT 10",
&[],
).unwrap();
assert_eq!(rows.len(), 4, "Four-way UNION ALL should produce 4 rows");
let vals: Vec<i32> = rows.iter()
.map(|r| match r.get(0).unwrap() {
Value::Int4(n) => *n,
other => panic!("Expected Int4, got {:?}", other),
})
.collect();
assert_eq!(vals, vec![10, 20, 30, 10]);
}
#[test]
fn test_set_op_union_uses_first_select_column_names() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 1 AS first_col UNION ALL SELECT 2 AS second_col",
&[],
).unwrap();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(2));
}
#[test]
fn test_set_op_union_with_order_by() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 3 AS v UNION ALL SELECT 1 UNION ALL SELECT 2 ORDER BY v",
&[],
).unwrap();
assert_eq!(rows.len(), 3, "UNION ALL with ORDER BY should produce 3 rows");
let vals: Vec<i32> = rows.iter()
.map(|r| match r.get(0).unwrap() {
Value::Int4(n) => *n,
other => panic!("Expected Int4, got {:?}", other),
})
.collect();
assert_eq!(vals, vec![1, 2, 3], "ORDER BY v should sort ascending");
}
#[test]
fn test_set_op_union_with_order_by_desc() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 3 AS v UNION ALL SELECT 1 UNION ALL SELECT 2 ORDER BY v DESC",
&[],
).unwrap();
assert_eq!(rows.len(), 3);
let vals: Vec<i32> = rows.iter()
.map(|r| match r.get(0).unwrap() {
Value::Int4(n) => *n,
other => panic!("Expected Int4, got {:?}", other),
})
.collect();
assert_eq!(vals, vec![3, 2, 1], "ORDER BY v DESC should sort descending");
}
#[test]
fn test_set_op_union_with_limit() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 1 AS v UNION ALL SELECT 2 UNION ALL SELECT 3 LIMIT 2",
&[],
).unwrap();
assert_eq!(rows.len(), 2, "UNION ALL with LIMIT 2 should produce 2 rows");
}
#[test]
fn test_set_op_union_with_order_by_and_limit() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 5 AS v UNION ALL SELECT 3 UNION ALL SELECT 1 \
UNION ALL SELECT 4 UNION ALL SELECT 2 \
ORDER BY v LIMIT 3",
&[],
).unwrap();
assert_eq!(rows.len(), 3, "UNION ALL with ORDER BY + LIMIT 3 should produce 3 rows");
let vals: Vec<i32> = rows.iter()
.map(|r| match r.get(0).unwrap() {
Value::Int4(n) => *n,
other => panic!("Expected Int4, got {:?}", other),
})
.collect();
assert_eq!(vals, vec![1, 2, 3], "Should return smallest 3 values sorted");
}
#[test]
fn test_set_op_intersect_empty_result() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE ie_left (v INT)").unwrap();
db.execute("INSERT INTO ie_left VALUES (1), (2)").unwrap();
db.execute("CREATE TABLE ie_right (v INT)").unwrap();
db.execute("INSERT INTO ie_right VALUES (3), (4)").unwrap();
let rows = db.query(
"SELECT v FROM ie_left INTERSECT SELECT v FROM ie_right",
&[],
).unwrap();
assert_eq!(rows.len(), 0, "INTERSECT with no common rows should produce 0 rows");
}
#[test]
fn test_set_op_union_with_null_values() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows_all = db.query(
"SELECT NULL AS v UNION ALL SELECT NULL",
&[],
).unwrap();
assert_eq!(rows_all.len(), 2, "UNION ALL of (NULL, NULL) should produce 2 rows");
assert_eq!(rows_all[0].get(0).unwrap(), &Value::Null);
assert_eq!(rows_all[1].get(0).unwrap(), &Value::Null);
let rows_distinct = db.query(
"SELECT NULL AS v UNION SELECT NULL",
&[],
).unwrap();
assert_eq!(rows_distinct.len(), 1,
"UNION of (NULL, NULL) should dedup to 1 row (SQL standard: NULL = NULL for UNION)");
assert_eq!(rows_distinct[0].get(0).unwrap(), &Value::Null);
}
#[test]
fn test_set_op_intersect_with_null_values() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT NULL AS v INTERSECT SELECT NULL",
&[],
).unwrap();
assert_eq!(rows.len(), 1,
"INTERSECT of (NULL) and (NULL) should produce 1 row");
assert_eq!(rows[0].get(0).unwrap(), &Value::Null);
}
#[test]
fn test_set_op_except_with_null_values() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT NULL AS v EXCEPT SELECT NULL",
&[],
).unwrap();
assert_eq!(rows.len(), 0,
"EXCEPT of (NULL) minus (NULL) should produce 0 rows");
}
#[test]
fn test_set_op_union_with_table_data() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE employees (id INT, name TEXT, dept TEXT)").unwrap();
db.execute("INSERT INTO employees VALUES (1, 'Alice', 'Eng')").unwrap();
db.execute("INSERT INTO employees VALUES (2, 'Bob', 'Eng')").unwrap();
db.execute("CREATE TABLE contractors (id INT, name TEXT, dept TEXT)").unwrap();
db.execute("INSERT INTO contractors VALUES (3, 'Charlie', 'Eng')").unwrap();
db.execute("INSERT INTO contractors VALUES (4, 'Diana', 'Sales')").unwrap();
let rows = db.query(
"SELECT id, name FROM employees UNION ALL SELECT id, name FROM contractors",
&[],
).unwrap();
assert_eq!(rows.len(), 4, "UNION ALL of 2+2 rows should produce 4 rows");
let names: Vec<String> = rows.iter()
.map(|r| match r.get(1).unwrap() {
Value::String(s) => s.clone(),
other => panic!("Expected String, got {:?}", other),
})
.collect();
assert!(names.contains(&"Alice".to_string()));
assert!(names.contains(&"Bob".to_string()));
assert!(names.contains(&"Charlie".to_string()));
assert!(names.contains(&"Diana".to_string()));
}
#[test]
fn test_set_op_union_distinct_with_table_data() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE colors_a (name TEXT)").unwrap();
db.execute("INSERT INTO colors_a VALUES ('red'), ('green'), ('blue')").unwrap();
db.execute("CREATE TABLE colors_b (name TEXT)").unwrap();
db.execute("INSERT INTO colors_b VALUES ('blue'), ('green'), ('yellow')").unwrap();
let rows = db.query(
"SELECT name FROM colors_a UNION SELECT name FROM colors_b",
&[],
).unwrap();
assert_eq!(rows.len(), 4,
"UNION of (red,green,blue) and (blue,green,yellow) should produce 4 unique rows");
let mut names: Vec<String> = rows.iter()
.map(|r| match r.get(0).unwrap() {
Value::String(s) => s.clone(),
other => panic!("Expected String, got {:?}", other),
})
.collect();
names.sort();
assert_eq!(names, vec!["blue", "green", "red", "yellow"]);
}
#[test]
fn test_set_op_intersect_with_table_data() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE skills_a (skill TEXT)").unwrap();
db.execute("INSERT INTO skills_a VALUES ('rust'), ('python'), ('go')").unwrap();
db.execute("CREATE TABLE skills_b (skill TEXT)").unwrap();
db.execute("INSERT INTO skills_b VALUES ('python'), ('go'), ('java')").unwrap();
let rows = db.query(
"SELECT skill FROM skills_a INTERSECT SELECT skill FROM skills_b",
&[],
).unwrap();
assert_eq!(rows.len(), 2,
"INTERSECT of (rust,python,go) and (python,go,java) should produce 2 rows");
let mut names: Vec<String> = rows.iter()
.map(|r| match r.get(0).unwrap() {
Value::String(s) => s.clone(),
other => panic!("Expected String, got {:?}", other),
})
.collect();
names.sort();
assert_eq!(names, vec!["go", "python"]);
}
#[test]
fn test_set_op_except_with_table_data() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE all_items (item TEXT)").unwrap();
db.execute("INSERT INTO all_items VALUES ('a'), ('b'), ('c'), ('d')").unwrap();
db.execute("CREATE TABLE sold_items (item TEXT)").unwrap();
db.execute("INSERT INTO sold_items VALUES ('b'), ('d')").unwrap();
let rows = db.query(
"SELECT item FROM all_items EXCEPT SELECT item FROM sold_items",
&[],
).unwrap();
assert_eq!(rows.len(), 2,
"EXCEPT of (a,b,c,d) minus (b,d) should produce 2 rows");
let mut names: Vec<String> = rows.iter()
.map(|r| match r.get(0).unwrap() {
Value::String(s) => s.clone(),
other => panic!("Expected String, got {:?}", other),
})
.collect();
names.sort();
assert_eq!(names, vec!["a", "c"]);
}
#[test]
fn test_set_op_union_multi_column() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 1 AS a, 'x' AS b UNION ALL SELECT 2, 'y'",
&[],
).unwrap();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("x".to_string()));
assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(2));
assert_eq!(rows[1].get(1).unwrap(), &Value::String("y".to_string()));
}
#[test]
fn test_set_op_union_distinct_multi_column() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 1 AS a, 'a' AS b \
UNION SELECT 1, 'a' \
UNION SELECT 1, 'b'",
&[],
).unwrap();
assert_eq!(rows.len(), 2,
"UNION should dedup (1,'a') but keep (1,'b') as distinct");
}
#[test]
fn test_set_op_union_empty_left() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE empty_tbl (v INT)").unwrap();
db.execute("CREATE TABLE full_tbl (v INT)").unwrap();
db.execute("INSERT INTO full_tbl VALUES (1), (2)").unwrap();
let rows = db.query(
"SELECT v FROM empty_tbl UNION ALL SELECT v FROM full_tbl",
&[],
).unwrap();
assert_eq!(rows.len(), 2, "UNION ALL of empty + 2 rows should produce 2 rows");
}
#[test]
fn test_set_op_union_empty_right() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE full_tbl2 (v INT)").unwrap();
db.execute("INSERT INTO full_tbl2 VALUES (1), (2)").unwrap();
db.execute("CREATE TABLE empty_tbl2 (v INT)").unwrap();
let rows = db.query(
"SELECT v FROM full_tbl2 UNION ALL SELECT v FROM empty_tbl2",
&[],
).unwrap();
assert_eq!(rows.len(), 2, "UNION ALL of 2 rows + empty should produce 2 rows");
}
#[test]
fn test_set_op_union_both_empty() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE empty_a (v INT)").unwrap();
db.execute("CREATE TABLE empty_b (v INT)").unwrap();
let rows = db.query(
"SELECT v FROM empty_a UNION ALL SELECT v FROM empty_b",
&[],
).unwrap();
assert_eq!(rows.len(), 0, "UNION ALL of two empty tables should produce 0 rows");
}
#[test]
fn test_set_op_except_empty_right_preserves_left() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE exc_full (v INT)").unwrap();
db.execute("INSERT INTO exc_full VALUES (10), (20), (30)").unwrap();
db.execute("CREATE TABLE exc_empty (v INT)").unwrap();
let rows = db.query(
"SELECT v FROM exc_full EXCEPT SELECT v FROM exc_empty",
&[],
).unwrap();
assert_eq!(rows.len(), 3,
"EXCEPT with empty right should return all 3 left rows");
}
#[test]
fn test_set_op_intersect_empty_right_returns_empty() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE isec_full (v INT)").unwrap();
db.execute("INSERT INTO isec_full VALUES (10), (20)").unwrap();
db.execute("CREATE TABLE isec_empty (v INT)").unwrap();
let rows = db.query(
"SELECT v FROM isec_full INTERSECT SELECT v FROM isec_empty",
&[],
).unwrap();
assert_eq!(rows.len(), 0,
"INTERSECT with empty right should produce 0 rows");
}
#[test]
fn test_set_op_union_with_where_clause() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE nums (v INT)").unwrap();
db.execute("INSERT INTO nums VALUES (1), (2), (3), (4), (5)").unwrap();
let rows = db.query(
"SELECT v FROM nums WHERE v <= 2 UNION ALL SELECT v FROM nums WHERE v >= 4",
&[],
).unwrap();
assert_eq!(rows.len(), 4, "UNION ALL of (1,2) and (4,5) should produce 4 rows");
let mut vals: Vec<i32> = rows.iter()
.map(|r| match r.get(0).unwrap() {
Value::Int4(n) => *n,
other => panic!("Expected Int4, got {:?}", other),
})
.collect();
vals.sort();
assert_eq!(vals, vec![1, 2, 4, 5]);
}
#[test]
fn test_set_op_union_null_mixed_with_values() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 1 AS v UNION SELECT NULL UNION SELECT 2 UNION SELECT NULL",
&[],
).unwrap();
assert_eq!(rows.len(), 3,
"UNION of (1, NULL, 2, NULL) should produce 3 rows (dedup NULLs)");
}
#[test]
fn test_set_op_union_all_large_dataset() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE big_a (id INT, val TEXT)").unwrap();
db.execute("CREATE TABLE big_b (id INT, val TEXT)").unwrap();
for i in 0..50 {
db.execute(&format!("INSERT INTO big_a VALUES ({}, 'a{}')", i, i)).unwrap();
}
for i in 25..75 {
db.execute(&format!("INSERT INTO big_b VALUES ({}, 'b{}')", i, i)).unwrap();
}
let rows = db.query(
"SELECT id, val FROM big_a UNION ALL SELECT id, val FROM big_b",
&[],
).unwrap();
assert_eq!(rows.len(), 100,
"UNION ALL of 50+50 rows should produce 100 rows");
}
#[test]
fn test_set_op_intersect_single_common_row() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 1 AS v INTERSECT SELECT 1",
&[],
).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
}
#[test]
fn test_set_op_except_is_not_symmetric() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE set_a (v INT)").unwrap();
db.execute("INSERT INTO set_a VALUES (1), (2), (3)").unwrap();
db.execute("CREATE TABLE set_b (v INT)").unwrap();
db.execute("INSERT INTO set_b VALUES (2), (3), (4)").unwrap();
let a_except_b = db.query(
"SELECT v FROM set_a EXCEPT SELECT v FROM set_b",
&[],
).unwrap();
let b_except_a = db.query(
"SELECT v FROM set_b EXCEPT SELECT v FROM set_a",
&[],
).unwrap();
assert_eq!(a_except_b.len(), 1);
assert_eq!(b_except_a.len(), 1);
assert_eq!(a_except_b[0].get(0).unwrap(), &Value::Int4(1),
"A EXCEPT B should yield 1");
assert_eq!(b_except_a[0].get(0).unwrap(), &Value::Int4(4),
"B EXCEPT A should yield 4");
}
#[test]
fn test_set_op_union_with_string_values() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 'hello' AS greeting UNION ALL SELECT 'world'",
&[],
).unwrap();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].get(0).unwrap(), &Value::String("hello".to_string()));
assert_eq!(rows[1].get(0).unwrap(), &Value::String("world".to_string()));
}
#[test]
fn test_set_op_union_distinct_string_dedup() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 'same' AS v UNION SELECT 'same' UNION SELECT 'different'",
&[],
).unwrap();
assert_eq!(rows.len(), 2, "UNION should dedup 'same' into one row");
let mut vals: Vec<String> = rows.iter()
.map(|r| match r.get(0).unwrap() {
Value::String(s) => s.clone(),
other => panic!("Expected String, got {:?}", other),
})
.collect();
vals.sort();
assert_eq!(vals, vec!["different", "same"]);
}
#[test]
fn test_set_op_union_with_boolean_values() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT TRUE AS flag UNION SELECT FALSE UNION SELECT TRUE",
&[],
).unwrap();
assert_eq!(rows.len(), 2, "UNION of (TRUE, FALSE, TRUE) should produce 2 rows");
}
#[test]
fn test_set_op_union_all_preserves_order() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 100 AS v UNION ALL SELECT 200",
&[],
).unwrap();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(100));
assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(200));
}
#[test]
fn test_set_op_except_self_yields_empty() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE self_exc (v INT)").unwrap();
db.execute("INSERT INTO self_exc VALUES (1), (2), (3)").unwrap();
let rows = db.query(
"SELECT v FROM self_exc EXCEPT SELECT v FROM self_exc",
&[],
).unwrap();
assert_eq!(rows.len(), 0, "Table EXCEPT itself should produce 0 rows");
}
#[test]
fn test_set_op_intersect_self_yields_all() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE self_int (v INT)").unwrap();
db.execute("INSERT INTO self_int VALUES (1), (2), (3)").unwrap();
let rows = db.query(
"SELECT v FROM self_int INTERSECT SELECT v FROM self_int",
&[],
).unwrap();
assert_eq!(rows.len(), 3, "Table INTERSECT itself should return all 3 unique rows");
let mut vals: Vec<i32> = rows.iter()
.map(|r| match r.get(0).unwrap() {
Value::Int4(n) => *n,
other => panic!("Expected Int4, got {:?}", other),
})
.collect();
vals.sort();
assert_eq!(vals, vec![1, 2, 3]);
}
#[test]
fn test_set_op_union_with_expressions() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query(
"SELECT 1 + 1 AS result UNION ALL SELECT 2 * 3",
&[],
).unwrap();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(2));
assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(6));
}
#[test]
fn test_set_op_union_single_row_each() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query("SELECT 42 AS v UNION ALL SELECT 99", &[]).unwrap();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(42));
assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(99));
}
fn setup_subquery_tables() -> EmbeddedDatabase {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE customers (id INT, name TEXT, category TEXT)").unwrap();
db.execute("CREATE TABLE orders (id INT, customer_id INT, amount INT, product_id INT)").unwrap();
db.execute("CREATE TABLE products (id INT, name TEXT, price INT)").unwrap();
db.execute("INSERT INTO customers VALUES (1, 'Alice', 'premium')").unwrap();
db.execute("INSERT INTO customers VALUES (2, 'Bob', 'standard')").unwrap();
db.execute("INSERT INTO customers VALUES (3, 'Charlie', 'premium')").unwrap();
db.execute("INSERT INTO customers VALUES (4, 'Diana', 'standard')").unwrap();
db.execute("INSERT INTO orders VALUES (10, 1, 100, 1)").unwrap();
db.execute("INSERT INTO orders VALUES (11, 1, 200, 2)").unwrap();
db.execute("INSERT INTO orders VALUES (12, 2, 50, 1)").unwrap();
db.execute("INSERT INTO orders VALUES (13, 3, 300, 3)").unwrap();
db.execute("INSERT INTO products VALUES (1, 'Widget', 10)").unwrap();
db.execute("INSERT INTO products VALUES (2, 'Gadget', 25)").unwrap();
db.execute("INSERT INTO products VALUES (3, 'Gizmo', 50)").unwrap();
db.execute("INSERT INTO products VALUES (4, 'Doohickey', 5)").unwrap();
db
}
#[test]
fn test_subquery_in_basic() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM customers WHERE id IN (SELECT customer_id FROM orders) ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 3, "3 customers have orders, got {}", rows.len());
assert_eq!(rows[0].get(0), Some(&Value::Int4(1)));
assert_eq!(rows[0].get(1), Some(&Value::String("Alice".to_string())));
assert_eq!(rows[1].get(0), Some(&Value::Int4(2)));
assert_eq!(rows[1].get(1), Some(&Value::String("Bob".to_string())));
assert_eq!(rows[2].get(0), Some(&Value::Int4(3)));
assert_eq!(rows[2].get(1), Some(&Value::String("Charlie".to_string())));
}
Err(e) => {
println!("IN subquery not supported: {}", e);
}
}
}
#[test]
fn test_subquery_not_in_basic() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM customers WHERE id NOT IN (SELECT customer_id FROM orders) ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 1, "Only Diana has no orders, got {}", rows.len());
assert_eq!(rows[0].get(0), Some(&Value::Int4(4)));
assert_eq!(rows[0].get(1), Some(&Value::String("Diana".to_string())));
}
Err(e) => {
println!("NOT IN subquery not supported: {}", e);
}
}
}
#[test]
fn test_subquery_in_with_empty_result() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM customers WHERE id IN (SELECT customer_id FROM orders WHERE amount > 9999)";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 0, "No orders match amount > 9999, so IN list is empty");
}
Err(e) => {
println!("IN subquery with empty result not supported: {}", e);
}
}
}
#[test]
fn test_subquery_in_all_match() {
let db = setup_subquery_tables();
let sql = "SELECT id FROM customers WHERE id IN (SELECT id FROM customers) ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 4, "All 4 customers should match, got {}", rows.len());
for (i, row) in rows.iter().enumerate() {
let expected_id = (i as i32) + 1;
assert_eq!(row.get(0), Some(&Value::Int4(expected_id)));
}
}
Err(e) => {
println!("IN subquery self-reference not supported: {}", e);
}
}
}
#[test]
fn test_subquery_not_in_all_match() {
let db = setup_subquery_tables();
let sql = "SELECT id FROM customers WHERE id NOT IN (SELECT id FROM customers)";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 0, "NOT IN with all ids should return nothing");
}
Err(e) => {
println!("NOT IN subquery self-reference not supported: {}", e);
}
}
}
#[test]
fn test_subquery_in_products_not_ordered() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM products WHERE id NOT IN (SELECT product_id FROM orders) ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 1, "Only Doohickey has no orders, got {}", rows.len());
assert_eq!(rows[0].get(0), Some(&Value::Int4(4)));
assert_eq!(rows[0].get(1), Some(&Value::String("Doohickey".to_string())));
}
Err(e) => {
println!("NOT IN subquery for products not supported: {}", e);
}
}
}
#[test]
fn test_exists_basic_uncorrelated() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM customers WHERE EXISTS (SELECT 1 FROM orders) ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 4, "EXISTS(non-empty) should return all 4 customers, got {}", rows.len());
assert_eq!(rows[0].get(1), Some(&Value::String("Alice".to_string())));
assert_eq!(rows[3].get(1), Some(&Value::String("Diana".to_string())));
}
Err(e) => {
println!("EXISTS uncorrelated not supported: {}", e);
}
}
}
#[test]
fn test_exists_with_empty_subquery_result() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM customers WHERE EXISTS (SELECT 1 FROM orders WHERE amount > 9999)";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 0, "EXISTS on empty subquery should return 0 rows, got {}", rows.len());
}
Err(e) => {
println!("EXISTS with empty subquery not supported: {}", e);
}
}
}
#[test]
fn test_not_exists_uncorrelated() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM customers WHERE NOT EXISTS (SELECT 1 FROM orders)";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 0, "NOT EXISTS(non-empty) should return 0 rows, got {}", rows.len());
}
Err(e) => {
println!("NOT EXISTS uncorrelated not supported: {}", e);
}
}
}
#[test]
fn test_not_exists_with_empty_subquery() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM customers WHERE NOT EXISTS (SELECT 1 FROM orders WHERE amount > 9999) ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 4, "NOT EXISTS(empty) should return all 4 customers, got {}", rows.len());
}
Err(e) => {
println!("NOT EXISTS with empty subquery not supported: {}", e);
}
}
}
#[test]
fn test_exists_with_specific_filter() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM customers WHERE EXISTS (SELECT 1 FROM orders WHERE amount >= 200) ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 4, "EXISTS with matching filter should return all customers, got {}", rows.len());
}
Err(e) => {
println!("EXISTS with filter not supported: {}", e);
}
}
}
#[test]
fn test_exists_against_empty_table() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE parent (id INT, name TEXT)").unwrap();
db.execute("CREATE TABLE child (id INT, parent_id INT)").unwrap();
db.execute("INSERT INTO parent VALUES (1, 'Alice')").unwrap();
let sql = "SELECT id, name FROM parent WHERE EXISTS (SELECT 1 FROM child)";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 0, "EXISTS on empty table should return 0 rows");
}
Err(e) => {
println!("EXISTS against empty table not supported: {}", e);
}
}
}
#[test]
fn test_not_exists_against_empty_table() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE parent_ne (id INT, name TEXT)").unwrap();
db.execute("CREATE TABLE child_ne (id INT, parent_id INT)").unwrap();
db.execute("INSERT INTO parent_ne VALUES (1, 'Alice')").unwrap();
db.execute("INSERT INTO parent_ne VALUES (2, 'Bob')").unwrap();
let sql = "SELECT id, name FROM parent_ne WHERE NOT EXISTS (SELECT 1 FROM child_ne) ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 2, "NOT EXISTS on empty table should return all parent rows");
assert_eq!(rows[0].get(1), Some(&Value::String("Alice".to_string())));
assert_eq!(rows[1].get(1), Some(&Value::String("Bob".to_string())));
}
Err(e) => {
println!("NOT EXISTS against empty table not supported: {}", e);
}
}
}
#[test]
fn test_exists_correlated_subquery() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM customers WHERE EXISTS (SELECT 1 FROM orders WHERE orders.customer_id = customers.id) ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 3, "Correlated EXISTS should find 3 customers with orders, got {}", rows.len());
assert_eq!(rows[0].get(0), Some(&Value::Int4(1)));
assert_eq!(rows[1].get(0), Some(&Value::Int4(2)));
assert_eq!(rows[2].get(0), Some(&Value::Int4(3)));
}
Err(e) => {
println!("Correlated EXISTS not supported: {}", e);
}
}
}
#[test]
fn test_not_exists_correlated_subquery() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM customers WHERE NOT EXISTS (SELECT 1 FROM orders WHERE orders.customer_id = customers.id) ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 1, "Correlated NOT EXISTS should find 1 customer without orders, got {}", rows.len());
assert_eq!(rows[0].get(0), Some(&Value::Int4(4)));
assert_eq!(rows[0].get(1), Some(&Value::String("Diana".to_string())));
}
Err(e) => {
println!("Correlated NOT EXISTS not supported: {}", e);
}
}
}
#[test]
fn test_subquery_scalar_in_select() {
let db = setup_subquery_tables();
let sql = "SELECT id, (SELECT COUNT(*) FROM orders WHERE orders.customer_id = customers.id) FROM customers ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 4, "Should return all 4 customers");
println!("Scalar subquery returned {} rows - values: {:?}", rows.len(),
rows.iter().map(|r| (r.get(0), r.get(1))).collect::<Vec<_>>());
}
Err(e) => {
println!("Scalar subquery in SELECT not supported: {}", e);
}
}
}
#[test]
fn test_subquery_scalar_in_where() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM customers WHERE id > (SELECT MIN(customer_id) FROM orders) ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 3, "Customers with id > 1, got {}", rows.len());
assert_eq!(rows[0].get(0), Some(&Value::Int4(2)));
assert_eq!(rows[1].get(0), Some(&Value::Int4(3)));
assert_eq!(rows[2].get(0), Some(&Value::Int4(4)));
}
Err(e) => {
println!("Scalar subquery in WHERE not supported: {}", e);
}
}
}
#[test]
fn test_subquery_in_from_clause() {
let db = setup_subquery_tables();
let sql = "SELECT * FROM (SELECT id, name FROM customers WHERE category = 'premium') AS sub ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 2, "2 premium customers, got {}", rows.len());
assert_eq!(rows[0].get(0), Some(&Value::Int4(1)));
assert_eq!(rows[0].get(1), Some(&Value::String("Alice".to_string())));
assert_eq!(rows[1].get(0), Some(&Value::Int4(3)));
assert_eq!(rows[1].get(1), Some(&Value::String("Charlie".to_string())));
}
Err(e) => {
println!("Subquery in FROM clause not supported: {}", e);
}
}
}
#[test]
fn test_subquery_in_from_with_aggregation() {
let db = setup_subquery_tables();
let sql = "SELECT * FROM (SELECT customer_id, SUM(amount) AS total FROM orders GROUP BY customer_id) AS sub ORDER BY customer_id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 3, "3 customers have orders, got {}", rows.len());
println!("FROM subquery with aggregation returned: {:?}",
rows.iter().map(|r| (r.get(0), r.get(1))).collect::<Vec<_>>());
}
Err(e) => {
println!("Subquery in FROM with aggregation not supported: {}", e);
}
}
}
#[test]
fn test_subquery_in_from_empty_result() {
let db = setup_subquery_tables();
let sql = "SELECT * FROM (SELECT id, name FROM customers WHERE id > 999) AS sub";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 0, "No customers with id > 999");
}
Err(e) => {
println!("FROM subquery with empty result not supported: {}", e);
}
}
}
#[test]
fn test_subquery_nested_in() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM customers WHERE id IN \
(SELECT customer_id FROM orders WHERE product_id IN \
(SELECT id FROM products WHERE price > 20)) ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 2, "2 customers ordered expensive products, got {}", rows.len());
assert_eq!(rows[0].get(0), Some(&Value::Int4(1)));
assert_eq!(rows[0].get(1), Some(&Value::String("Alice".to_string())));
assert_eq!(rows[1].get(0), Some(&Value::Int4(3)));
assert_eq!(rows[1].get(1), Some(&Value::String("Charlie".to_string())));
}
Err(e) => {
println!("Nested IN subquery not supported: {}", e);
}
}
}
#[test]
fn test_subquery_nested_in_three_levels() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM products WHERE id IN \
(SELECT product_id FROM orders WHERE customer_id IN \
(SELECT id FROM customers WHERE category = 'premium')) ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 3, "3 products ordered by premium customers, got {}", rows.len());
assert_eq!(rows[0].get(1), Some(&Value::String("Widget".to_string())));
assert_eq!(rows[1].get(1), Some(&Value::String("Gadget".to_string())));
assert_eq!(rows[2].get(1), Some(&Value::String("Gizmo".to_string())));
}
Err(e) => {
println!("3-level nested IN subquery not supported: {}", e);
}
}
}
#[test]
fn test_exists_and_in_combined() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM customers \
WHERE category = 'premium' \
AND EXISTS (SELECT 1 FROM orders) \
ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 2, "2 premium customers when orders exist, got {}", rows.len());
assert_eq!(rows[0].get(1), Some(&Value::String("Alice".to_string())));
assert_eq!(rows[1].get(1), Some(&Value::String("Charlie".to_string())));
}
Err(e) => {
println!("Combined EXISTS and filter not supported: {}", e);
}
}
}
#[test]
fn test_in_subquery_with_distinct() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM products WHERE id IN (SELECT DISTINCT product_id FROM orders) ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 3, "3 distinct products in orders, got {}", rows.len());
assert_eq!(rows[0].get(1), Some(&Value::String("Widget".to_string())));
assert_eq!(rows[1].get(1), Some(&Value::String("Gadget".to_string())));
assert_eq!(rows[2].get(1), Some(&Value::String("Gizmo".to_string())));
}
Err(e) => {
println!("IN subquery with DISTINCT not supported: {}", e);
}
}
}
#[test]
fn test_subquery_in_with_expression() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM customers WHERE id + 0 IN (SELECT customer_id FROM orders) ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 3, "3 customers with orders (via expression), got {}", rows.len());
assert_eq!(rows[0].get(0), Some(&Value::Int4(1)));
assert_eq!(rows[1].get(0), Some(&Value::Int4(2)));
assert_eq!(rows[2].get(0), Some(&Value::Int4(3)));
}
Err(e) => {
println!("IN subquery with expression not supported: {}", e);
}
}
}
#[test]
fn test_subquery_in_single_value() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM customers WHERE id IN (SELECT customer_id FROM orders WHERE amount = 300)";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 1, "Only 1 customer has the 300-amount order, got {}", rows.len());
assert_eq!(rows[0].get(0), Some(&Value::Int4(3)));
assert_eq!(rows[0].get(1), Some(&Value::String("Charlie".to_string())));
}
Err(e) => {
println!("IN subquery with single result not supported: {}", e);
}
}
}
#[test]
fn test_exists_with_select_star_subquery() {
let db = setup_subquery_tables();
let sql = "SELECT id FROM customers WHERE EXISTS (SELECT * FROM orders WHERE amount > 100) ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 4, "EXISTS(SELECT *) with matches returns all outer rows, got {}", rows.len());
}
Err(e) => {
println!("EXISTS with SELECT * not supported: {}", e);
}
}
}
#[test]
fn test_subquery_in_from_with_where() {
let db = setup_subquery_tables();
let sql = "SELECT * FROM (SELECT id, name, category FROM customers) AS sub WHERE category = 'standard' ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 2, "2 standard customers via derived table, got {}", rows.len());
assert_eq!(rows[0].get(1), Some(&Value::String("Bob".to_string())));
assert_eq!(rows[1].get(1), Some(&Value::String("Diana".to_string())));
}
Err(e) => {
println!("Derived table with outer WHERE not supported: {}", e);
}
}
}
#[test]
fn test_subquery_in_from_select_subset() {
let db = setup_subquery_tables();
let sql = "SELECT name FROM (SELECT id, name FROM customers WHERE id <= 2) AS sub ORDER BY name";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 2, "2 customers with id <= 2, got {}", rows.len());
assert_eq!(rows[0].get(0), Some(&Value::String("Alice".to_string())));
assert_eq!(rows[1].get(0), Some(&Value::String("Bob".to_string())));
}
Err(e) => {
println!("Derived table with column subset not supported: {}", e);
}
}
}
#[test]
fn test_subquery_in_single_column_single_row() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE vals (v INT)").unwrap();
db.execute("INSERT INTO vals VALUES (10)").unwrap();
db.execute("INSERT INTO vals VALUES (20)").unwrap();
db.execute("INSERT INTO vals VALUES (30)").unwrap();
let sql = "SELECT v FROM vals WHERE v IN (SELECT MAX(v) FROM vals)";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 1, "Only MAX(v)=30 should match, got {}", rows.len());
assert_eq!(rows[0].get(0), Some(&Value::Int4(30)));
}
Err(e) => {
println!("IN subquery with aggregate not supported: {}", e);
}
}
}
#[test]
fn test_exists_on_single_row_table() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE singleton (v INT)").unwrap();
db.execute("INSERT INTO singleton VALUES (42)").unwrap();
db.execute("CREATE TABLE checker (id INT)").unwrap();
db.execute("INSERT INTO checker VALUES (1)").unwrap();
db.execute("INSERT INTO checker VALUES (2)").unwrap();
let sql = "SELECT id FROM checker WHERE EXISTS (SELECT 1 FROM singleton) ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 2, "EXISTS on single-row table should return all checker rows");
assert_eq!(rows[0].get(0), Some(&Value::Int4(1)));
assert_eq!(rows[1].get(0), Some(&Value::Int4(2)));
}
Err(e) => {
println!("EXISTS on single-row table not supported: {}", e);
}
}
}
#[test]
fn test_subquery_in_with_string_column() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM customers WHERE name IN (SELECT name FROM products) ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 0, "No customer names match product names, got {}", rows.len());
}
Err(e) => {
println!("IN subquery with string column not supported: {}", e);
}
}
}
#[test]
fn test_subquery_not_in_with_string_column() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM customers WHERE name NOT IN (SELECT name FROM products) ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 4, "All 4 customers have non-product names, got {}", rows.len());
}
Err(e) => {
println!("NOT IN subquery with string column not supported: {}", e);
}
}
}
#[test]
fn test_subquery_in_with_filter_on_subquery() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM customers WHERE id IN (SELECT customer_id FROM orders WHERE amount >= 100) ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 2, "2 customers with orders >= 100, got {}", rows.len());
assert_eq!(rows[0].get(0), Some(&Value::Int4(1)));
assert_eq!(rows[0].get(1), Some(&Value::String("Alice".to_string())));
assert_eq!(rows[1].get(0), Some(&Value::Int4(3)));
assert_eq!(rows[1].get(1), Some(&Value::String("Charlie".to_string())));
}
Err(e) => {
println!("IN subquery with WHERE filter not supported: {}", e);
}
}
}
#[test]
fn test_exists_combined_with_or() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM customers \
WHERE category = 'premium' \
OR EXISTS (SELECT 1 FROM orders WHERE amount > 9999) \
ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 2, "Only premium customers when EXISTS is false, got {}", rows.len());
assert_eq!(rows[0].get(1), Some(&Value::String("Alice".to_string())));
assert_eq!(rows[1].get(1), Some(&Value::String("Charlie".to_string())));
}
Err(e) => {
println!("EXISTS combined with OR not supported: {}", e);
}
}
}
#[test]
fn test_not_exists_combined_with_and() {
let db = setup_subquery_tables();
let sql = "SELECT id, name FROM customers \
WHERE category = 'standard' \
AND NOT EXISTS (SELECT 1 FROM orders WHERE amount > 9999) \
ORDER BY id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 2, "Standard customers when NOT EXISTS is true, got {}", rows.len());
assert_eq!(rows[0].get(1), Some(&Value::String("Bob".to_string())));
assert_eq!(rows[1].get(1), Some(&Value::String("Diana".to_string())));
}
Err(e) => {
println!("NOT EXISTS combined with AND not supported: {}", e);
}
}
}
fn setup_multi_table_join_db() -> EmbeddedDatabase {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE jt_customers (id INT PRIMARY KEY, name TEXT, city TEXT)")
.unwrap();
db.execute("INSERT INTO jt_customers VALUES (1, 'Alice', 'NYC')").unwrap();
db.execute("INSERT INTO jt_customers VALUES (2, 'Bob', 'LA')").unwrap();
db.execute("INSERT INTO jt_customers VALUES (3, 'Carol', 'NYC')").unwrap();
db.execute("INSERT INTO jt_customers VALUES (4, 'Diana', 'Chicago')").unwrap();
db.execute("CREATE TABLE jt_products (id INT PRIMARY KEY, name TEXT, price INT)")
.unwrap();
db.execute("INSERT INTO jt_products VALUES (10, 'Widget', 100)").unwrap();
db.execute("INSERT INTO jt_products VALUES (20, 'Gadget', 250)").unwrap();
db.execute("INSERT INTO jt_products VALUES (30, 'Doohickey', 50)").unwrap();
db.execute(
"CREATE TABLE jt_orders (id INT PRIMARY KEY, customer_id INT, product_id INT, qty INT)",
)
.unwrap();
db.execute("INSERT INTO jt_orders VALUES (100, 1, 10, 2)").unwrap(); db.execute("INSERT INTO jt_orders VALUES (101, 1, 20, 1)").unwrap(); db.execute("INSERT INTO jt_orders VALUES (102, 2, 10, 5)").unwrap(); db.execute("INSERT INTO jt_orders VALUES (103, 3, 30, 3)").unwrap();
db
}
#[test]
fn test_join_three_table_inner() {
let db = setup_multi_table_join_db();
let sql = "\
SELECT jt_customers.name, jt_products.name, jt_orders.qty \
FROM jt_orders \
JOIN jt_customers ON jt_orders.customer_id = jt_customers.id \
JOIN jt_products ON jt_orders.product_id = jt_products.id \
ORDER BY jt_orders.id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 4, "Expected 4 order rows, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("Alice".to_string()));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("Widget".to_string()));
assert_eq!(rows[0].get(2).unwrap(), &Value::Int4(2));
assert_eq!(rows[1].get(0).unwrap(), &Value::String("Alice".to_string()));
assert_eq!(rows[1].get(1).unwrap(), &Value::String("Gadget".to_string()));
assert_eq!(rows[1].get(2).unwrap(), &Value::Int4(1));
assert_eq!(rows[2].get(0).unwrap(), &Value::String("Bob".to_string()));
assert_eq!(rows[2].get(1).unwrap(), &Value::String("Widget".to_string()));
assert_eq!(rows[2].get(2).unwrap(), &Value::Int4(5));
assert_eq!(rows[3].get(0).unwrap(), &Value::String("Carol".to_string()));
assert_eq!(rows[3].get(1).unwrap(), &Value::String("Doohickey".to_string()));
assert_eq!(rows[3].get(2).unwrap(), &Value::Int4(3));
}
Err(e) => {
panic!("3-table INNER JOIN failed: {}", e);
}
}
}
#[test]
fn test_join_four_table_chain() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE jt4_cities (id INT PRIMARY KEY, city_name TEXT)").unwrap();
db.execute("INSERT INTO jt4_cities VALUES (1, 'New York')").unwrap();
db.execute("INSERT INTO jt4_cities VALUES (2, 'Los Angeles')").unwrap();
db.execute("CREATE TABLE jt4_addresses (id INT PRIMARY KEY, street TEXT, city_id INT)")
.unwrap();
db.execute("INSERT INTO jt4_addresses VALUES (10, '123 Main St', 1)").unwrap();
db.execute("INSERT INTO jt4_addresses VALUES (20, '456 Oak Ave', 2)").unwrap();
db.execute(
"CREATE TABLE jt4_customers (id INT PRIMARY KEY, name TEXT, address_id INT)",
)
.unwrap();
db.execute("INSERT INTO jt4_customers VALUES (100, 'Alice', 10)").unwrap();
db.execute("INSERT INTO jt4_customers VALUES (200, 'Bob', 20)").unwrap();
db.execute(
"CREATE TABLE jt4_orders (id INT PRIMARY KEY, customer_id INT, amount INT)",
)
.unwrap();
db.execute("INSERT INTO jt4_orders VALUES (1000, 100, 500)").unwrap();
db.execute("INSERT INTO jt4_orders VALUES (1001, 200, 300)").unwrap();
let sql = "\
SELECT jt4_orders.id, jt4_customers.name, jt4_addresses.street, jt4_cities.city_name \
FROM jt4_orders \
JOIN jt4_customers ON jt4_orders.customer_id = jt4_customers.id \
JOIN jt4_addresses ON jt4_customers.address_id = jt4_addresses.id \
JOIN jt4_cities ON jt4_addresses.city_id = jt4_cities.id \
ORDER BY jt4_orders.id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 2, "Expected 2 rows from 4-table chain, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1000));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("Alice".to_string()));
assert_eq!(rows[0].get(2).unwrap(), &Value::String("123 Main St".to_string()));
assert_eq!(rows[0].get(3).unwrap(), &Value::String("New York".to_string()));
assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(1001));
assert_eq!(rows[1].get(1).unwrap(), &Value::String("Bob".to_string()));
assert_eq!(rows[1].get(2).unwrap(), &Value::String("456 Oak Ave".to_string()));
assert_eq!(rows[1].get(3).unwrap(), &Value::String("Los Angeles".to_string()));
}
Err(e) => {
panic!("4-table JOIN chain failed: {}", e);
}
}
}
fn setup_employee_db() -> EmbeddedDatabase {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute(
"CREATE TABLE jt_employees (id INT PRIMARY KEY, name TEXT, manager_id INT, dept TEXT)",
)
.unwrap();
db.execute("INSERT INTO jt_employees VALUES (1, 'Eve', NULL, 'Exec')").unwrap();
db.execute("INSERT INTO jt_employees VALUES (2, 'Frank', 1, 'Engineering')").unwrap();
db.execute("INSERT INTO jt_employees VALUES (3, 'Grace', 1, 'Sales')").unwrap();
db.execute("INSERT INTO jt_employees VALUES (4, 'Hank', 2, 'Engineering')").unwrap();
db.execute("INSERT INTO jt_employees VALUES (5, 'Iris', 2, 'Engineering')").unwrap();
db.execute("INSERT INTO jt_employees VALUES (6, 'Jack', 3, 'Sales')").unwrap();
db
}
#[test]
fn test_join_self_join_employees() {
let db = setup_employee_db();
let sql = "\
SELECT e.name, m.name \
FROM jt_employees e \
JOIN jt_employees m ON e.manager_id = m.id \
ORDER BY e.id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 5, "5 employees have managers, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("Frank".to_string()));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("Eve".to_string()));
assert_eq!(rows[1].get(0).unwrap(), &Value::String("Grace".to_string()));
assert_eq!(rows[1].get(1).unwrap(), &Value::String("Eve".to_string()));
assert_eq!(rows[2].get(0).unwrap(), &Value::String("Hank".to_string()));
assert_eq!(rows[2].get(1).unwrap(), &Value::String("Frank".to_string()));
assert_eq!(rows[3].get(0).unwrap(), &Value::String("Iris".to_string()));
assert_eq!(rows[3].get(1).unwrap(), &Value::String("Frank".to_string()));
assert_eq!(rows[4].get(0).unwrap(), &Value::String("Jack".to_string()));
assert_eq!(rows[4].get(1).unwrap(), &Value::String("Grace".to_string()));
}
Err(e) => {
panic!("Self-join (employees->managers) failed: {}", e);
}
}
}
#[test]
fn test_join_self_join_left_with_null_manager() {
let db = setup_employee_db();
let sql = "\
SELECT e.name, m.name \
FROM jt_employees e \
LEFT JOIN jt_employees m ON e.manager_id = m.id \
ORDER BY e.id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 6, "All 6 employees should appear, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("Eve".to_string()));
assert_eq!(rows[0].get(1).unwrap(), &Value::Null);
assert_eq!(rows[1].get(0).unwrap(), &Value::String("Frank".to_string()));
assert_eq!(rows[1].get(1).unwrap(), &Value::String("Eve".to_string()));
}
Err(e) => {
panic!("LEFT JOIN self-join failed: {}", e);
}
}
}
#[test]
fn test_join_left_join_three_tables() {
let db = setup_multi_table_join_db();
let sql = "\
SELECT jt_customers.name, jt_orders.id, jt_products.name \
FROM jt_customers \
LEFT JOIN jt_orders ON jt_customers.id = jt_orders.customer_id \
LEFT JOIN jt_products ON jt_orders.product_id = jt_products.id \
ORDER BY jt_customers.id, jt_orders.id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 5, "Expected 5 rows (4 orders + 1 NULL), got {}", rows.len());
let diana_row = &rows[4];
assert_eq!(diana_row.get(0).unwrap(), &Value::String("Diana".to_string()));
assert_eq!(diana_row.get(1).unwrap(), &Value::Null);
assert_eq!(diana_row.get(2).unwrap(), &Value::Null);
}
Err(e) => {
panic!("3-table LEFT JOIN failed: {}", e);
}
}
}
#[test]
fn test_join_right_join() {
let db = setup_multi_table_join_db();
db.execute("INSERT INTO jt_products VALUES (40, 'Thingamajig', 75)").unwrap();
let sql = "\
SELECT jt_orders.id, jt_products.name \
FROM jt_orders \
RIGHT JOIN jt_products ON jt_orders.product_id = jt_products.id \
ORDER BY jt_products.id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 5, "Expected 5 rows from RIGHT JOIN, got {}", rows.len());
let last = &rows[rows.len() - 1];
assert_eq!(last.get(0).unwrap(), &Value::Null);
assert_eq!(last.get(1).unwrap(), &Value::String("Thingamajig".to_string()));
}
Err(e) => {
panic!("RIGHT JOIN failed: {}", e);
}
}
}
#[test]
fn test_join_cross_join() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE jt_colors (id INT PRIMARY KEY, color TEXT)").unwrap();
db.execute("INSERT INTO jt_colors VALUES (1, 'Red')").unwrap();
db.execute("INSERT INTO jt_colors VALUES (2, 'Blue')").unwrap();
db.execute("CREATE TABLE jt_sizes (id INT PRIMARY KEY, size TEXT)").unwrap();
db.execute("INSERT INTO jt_sizes VALUES (1, 'Small')").unwrap();
db.execute("INSERT INTO jt_sizes VALUES (2, 'Medium')").unwrap();
db.execute("INSERT INTO jt_sizes VALUES (3, 'Large')").unwrap();
let sql = "\
SELECT jt_colors.color, jt_sizes.size \
FROM jt_colors \
CROSS JOIN jt_sizes \
ORDER BY jt_colors.id, jt_sizes.id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 6, "CROSS JOIN should produce 6 rows, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("Red".to_string()));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("Small".to_string()));
assert_eq!(rows[5].get(0).unwrap(), &Value::String("Blue".to_string()));
assert_eq!(rows[5].get(1).unwrap(), &Value::String("Large".to_string()));
}
Err(e) => {
panic!("CROSS JOIN failed: {}", e);
}
}
}
#[test]
fn test_join_multiple_conditions() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE jt_left (a INT, b INT, val TEXT)").unwrap();
db.execute("INSERT INTO jt_left VALUES (1, 10, 'x')").unwrap();
db.execute("INSERT INTO jt_left VALUES (1, 20, 'y')").unwrap();
db.execute("INSERT INTO jt_left VALUES (2, 10, 'z')").unwrap();
db.execute("CREATE TABLE jt_right (a INT, b INT, info TEXT)").unwrap();
db.execute("INSERT INTO jt_right VALUES (1, 10, 'match1')").unwrap();
db.execute("INSERT INTO jt_right VALUES (1, 20, 'match2')").unwrap();
db.execute("INSERT INTO jt_right VALUES (2, 20, 'no_match')").unwrap();
let sql = "\
SELECT jt_left.val, jt_right.info \
FROM jt_left \
JOIN jt_right ON jt_left.a = jt_right.a AND jt_left.b = jt_right.b \
ORDER BY jt_left.val";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 2, "Expected 2 rows matching both conditions, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("x".to_string()));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("match1".to_string()));
assert_eq!(rows[1].get(0).unwrap(), &Value::String("y".to_string()));
assert_eq!(rows[1].get(1).unwrap(), &Value::String("match2".to_string()));
}
Err(e) => {
panic!("JOIN with multiple conditions failed: {}", e);
}
}
}
#[test]
fn test_join_with_aggregate_count() {
let db = setup_multi_table_join_db();
let sql = "\
SELECT jt_customers.name, COUNT(jt_orders.id) \
FROM jt_customers \
JOIN jt_orders ON jt_customers.id = jt_orders.customer_id \
GROUP BY jt_customers.name \
ORDER BY jt_customers.name";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 3, "3 customers have orders, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("Alice".to_string()));
assert_eq!(rows[0].get(1).unwrap(), &Value::Int8(2));
assert_eq!(rows[1].get(0).unwrap(), &Value::String("Bob".to_string()));
assert_eq!(rows[1].get(1).unwrap(), &Value::Int8(1));
assert_eq!(rows[2].get(0).unwrap(), &Value::String("Carol".to_string()));
assert_eq!(rows[2].get(1).unwrap(), &Value::Int8(1));
}
Err(e) => {
panic!("JOIN with aggregate COUNT failed: {}", e);
}
}
}
#[test]
fn test_join_with_where_filter() {
let db = setup_multi_table_join_db();
let sql = "\
SELECT jt_customers.name, jt_products.name, jt_orders.qty \
FROM jt_orders \
JOIN jt_customers ON jt_orders.customer_id = jt_customers.id \
JOIN jt_products ON jt_orders.product_id = jt_products.id \
WHERE jt_orders.qty > 2 \
ORDER BY jt_customers.name";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 2, "Expected 2 orders with qty > 2, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("Bob".to_string()));
assert_eq!(rows[0].get(2).unwrap(), &Value::Int4(5));
assert_eq!(rows[1].get(0).unwrap(), &Value::String("Carol".to_string()));
assert_eq!(rows[1].get(2).unwrap(), &Value::Int4(3));
}
Err(e) => {
panic!("JOIN with WHERE filter failed: {}", e);
}
}
}
#[test]
fn test_join_empty_result() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE jt_empty_a (id INT PRIMARY KEY, val TEXT)").unwrap();
db.execute("INSERT INTO jt_empty_a VALUES (1, 'one')").unwrap();
db.execute("INSERT INTO jt_empty_a VALUES (2, 'two')").unwrap();
db.execute("CREATE TABLE jt_empty_b (id INT PRIMARY KEY, ref_id INT, info TEXT)").unwrap();
db.execute("INSERT INTO jt_empty_b VALUES (10, 99, 'orphan1')").unwrap();
db.execute("INSERT INTO jt_empty_b VALUES (20, 98, 'orphan2')").unwrap();
let sql = "\
SELECT jt_empty_a.val, jt_empty_b.info \
FROM jt_empty_a \
JOIN jt_empty_b ON jt_empty_a.id = jt_empty_b.ref_id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 0, "No matching rows, expected 0, got {}", rows.len());
}
Err(e) => {
panic!("Empty JOIN result test failed: {}", e);
}
}
}
#[test]
fn test_join_with_null_fk_left_join() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE jt_depts (id INT PRIMARY KEY, dept_name TEXT)").unwrap();
db.execute("INSERT INTO jt_depts VALUES (1, 'Engineering')").unwrap();
db.execute("INSERT INTO jt_depts VALUES (2, 'Marketing')").unwrap();
db.execute("CREATE TABLE jt_staff (id INT PRIMARY KEY, name TEXT, dept_id INT)").unwrap();
db.execute("INSERT INTO jt_staff VALUES (1, 'Alice', 1)").unwrap();
db.execute("INSERT INTO jt_staff VALUES (2, 'Bob', NULL)").unwrap(); db.execute("INSERT INTO jt_staff VALUES (3, 'Carol', 2)").unwrap();
db.execute("INSERT INTO jt_staff VALUES (4, 'Dave', NULL)").unwrap();
let sql = "\
SELECT jt_staff.name, jt_depts.dept_name \
FROM jt_staff \
LEFT JOIN jt_depts ON jt_staff.dept_id = jt_depts.id \
ORDER BY jt_staff.id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 4, "All 4 staff should appear, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("Alice".to_string()));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("Engineering".to_string()));
assert_eq!(rows[1].get(0).unwrap(), &Value::String("Bob".to_string()));
assert_eq!(rows[1].get(1).unwrap(), &Value::Null);
assert_eq!(rows[2].get(0).unwrap(), &Value::String("Carol".to_string()));
assert_eq!(rows[2].get(1).unwrap(), &Value::String("Marketing".to_string()));
assert_eq!(rows[3].get(0).unwrap(), &Value::String("Dave".to_string()));
assert_eq!(rows[3].get(1).unwrap(), &Value::Null);
}
Err(e) => {
panic!("LEFT JOIN with NULL FK failed: {}", e);
}
}
}
#[test]
fn test_join_three_table_with_aggregate_sum() {
let db = setup_multi_table_join_db();
let sql = "\
SELECT jt_customers.name, SUM(jt_orders.qty * jt_products.price) \
FROM jt_orders \
JOIN jt_customers ON jt_orders.customer_id = jt_customers.id \
JOIN jt_products ON jt_orders.product_id = jt_products.id \
GROUP BY jt_customers.name \
ORDER BY jt_customers.name";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 3, "3 customers with orders, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("Alice".to_string()));
let alice_total = rows[0].get(1).unwrap();
let alice_val = match alice_total {
Value::Int4(v) => *v as i64,
Value::Int8(v) => *v,
Value::Float8(v) => *v as i64,
other => panic!("Unexpected type for SUM: {:?}", other),
};
assert_eq!(alice_val, 450, "Alice total revenue should be 450, got {}", alice_val);
let bob_total = rows[1].get(1).unwrap();
let bob_val = match bob_total {
Value::Int4(v) => *v as i64,
Value::Int8(v) => *v,
Value::Float8(v) => *v as i64,
other => panic!("Unexpected type for SUM: {:?}", other),
};
assert_eq!(bob_val, 500, "Bob total revenue should be 500, got {}", bob_val);
let carol_total = rows[2].get(1).unwrap();
let carol_val = match carol_total {
Value::Int4(v) => *v as i64,
Value::Int8(v) => *v,
Value::Float8(v) => *v as i64,
other => panic!("Unexpected type for SUM: {:?}", other),
};
assert_eq!(carol_val, 150, "Carol total revenue should be 150, got {}", carol_val);
}
Err(e) => {
panic!("3-table JOIN with SUM aggregate failed: {}", e);
}
}
}
#[test]
fn test_join_self_join_same_department() {
let db = setup_employee_db();
let sql = "\
SELECT e1.name, e2.name, e1.dept \
FROM jt_employees e1 \
JOIN jt_employees e2 ON e1.dept = e2.dept AND e1.id < e2.id \
ORDER BY e1.dept, e1.id, e2.id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 4, "Expected 4 same-dept pairs, got {}", rows.len());
let pairs: Vec<(String, String)> = rows
.iter()
.map(|r| {
let n1 = match r.get(0).unwrap() {
Value::String(s) => s.clone(),
other => panic!("Expected String, got {:?}", other),
};
let n2 = match r.get(1).unwrap() {
Value::String(s) => s.clone(),
other => panic!("Expected String, got {:?}", other),
};
(n1, n2)
})
.collect();
assert!(
pairs.contains(&("Frank".to_string(), "Hank".to_string())),
"Should contain (Frank, Hank)"
);
assert!(
pairs.contains(&("Grace".to_string(), "Jack".to_string())),
"Should contain (Grace, Jack)"
);
}
Err(e) => {
panic!("Self-join same department failed: {}", e);
}
}
}
#[test]
fn test_join_cross_join_with_where() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE jt_t1 (id INT PRIMARY KEY, val TEXT)").unwrap();
db.execute("INSERT INTO jt_t1 VALUES (1, 'a')").unwrap();
db.execute("INSERT INTO jt_t1 VALUES (2, 'b')").unwrap();
db.execute("CREATE TABLE jt_t2 (id INT PRIMARY KEY, t1_id INT, info TEXT)").unwrap();
db.execute("INSERT INTO jt_t2 VALUES (10, 1, 'info1')").unwrap();
db.execute("INSERT INTO jt_t2 VALUES (20, 2, 'info2')").unwrap();
db.execute("INSERT INTO jt_t2 VALUES (30, 1, 'info3')").unwrap();
let sql = "\
SELECT jt_t1.val, jt_t2.info \
FROM jt_t1 \
CROSS JOIN jt_t2 \
WHERE jt_t1.id = jt_t2.t1_id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 3, "3 matching rows expected, got {}", rows.len());
let mut pairs: Vec<(String, String)> = rows
.iter()
.map(|r| {
let val = match r.get(0).unwrap() {
Value::String(s) => s.clone(),
other => panic!("Expected String, got {:?}", other),
};
let info = match r.get(1).unwrap() {
Value::String(s) => s.clone(),
other => panic!("Expected String, got {:?}", other),
};
(val, info)
})
.collect();
pairs.sort();
assert_eq!(
pairs,
vec![
("a".to_string(), "info1".to_string()),
("a".to_string(), "info3".to_string()),
("b".to_string(), "info2".to_string()),
],
"CROSS JOIN + WHERE should produce the correct matching pairs"
);
}
Err(e) => {
panic!("CROSS JOIN with WHERE failed: {}", e);
}
}
}
#[test]
fn test_join_left_join_count_including_zero() {
let db = setup_multi_table_join_db();
let sql = "\
SELECT jt_customers.name, COUNT(jt_orders.id) \
FROM jt_customers \
LEFT JOIN jt_orders ON jt_customers.id = jt_orders.customer_id \
GROUP BY jt_customers.name \
ORDER BY jt_customers.name";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 4, "All 4 customers should appear, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("Alice".to_string()));
assert_eq!(rows[0].get(1).unwrap(), &Value::Int8(2));
assert_eq!(rows[1].get(0).unwrap(), &Value::String("Bob".to_string()));
assert_eq!(rows[1].get(1).unwrap(), &Value::Int8(1));
assert_eq!(rows[2].get(0).unwrap(), &Value::String("Carol".to_string()));
assert_eq!(rows[2].get(1).unwrap(), &Value::Int8(1));
assert_eq!(rows[3].get(0).unwrap(), &Value::String("Diana".to_string()));
let diana_count = rows[3].get(1).unwrap();
match diana_count {
Value::Int8(n) => {
assert!(
*n == 0 || *n == 1,
"Diana order count should be 0 (or 1 if NULL counted), got {}",
n
);
}
other => panic!("Expected Int8 for COUNT, got {:?}", other),
}
}
Err(e) => {
panic!("LEFT JOIN + COUNT with zero orders failed: {}", e);
}
}
}
#[test]
fn test_join_three_table_mixed_join_types() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE jt_mix_cust (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute("INSERT INTO jt_mix_cust VALUES (1, 'Alice')").unwrap();
db.execute("INSERT INTO jt_mix_cust VALUES (2, 'Bob')").unwrap();
db.execute("CREATE TABLE jt_mix_orders (id INT PRIMARY KEY, cust_id INT, product TEXT)")
.unwrap();
db.execute("INSERT INTO jt_mix_orders VALUES (10, 1, 'Widget')").unwrap();
db.execute("INSERT INTO jt_mix_orders VALUES (20, 2, 'Gadget')").unwrap();
db.execute(
"CREATE TABLE jt_mix_reviews (id INT PRIMARY KEY, order_id INT, rating INT)",
)
.unwrap();
db.execute("INSERT INTO jt_mix_reviews VALUES (100, 10, 5)").unwrap();
let sql = "\
SELECT jt_mix_cust.name, jt_mix_orders.product, jt_mix_reviews.rating \
FROM jt_mix_cust \
JOIN jt_mix_orders ON jt_mix_cust.id = jt_mix_orders.cust_id \
LEFT JOIN jt_mix_reviews ON jt_mix_orders.id = jt_mix_reviews.order_id \
ORDER BY jt_mix_cust.id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 2, "2 orders expected, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("Alice".to_string()));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("Widget".to_string()));
assert_eq!(rows[0].get(2).unwrap(), &Value::Int4(5));
assert_eq!(rows[1].get(0).unwrap(), &Value::String("Bob".to_string()));
assert_eq!(rows[1].get(1).unwrap(), &Value::String("Gadget".to_string()));
assert_eq!(rows[1].get(2).unwrap(), &Value::Null);
}
Err(e) => {
panic!("Mixed INNER+LEFT JOIN failed: {}", e);
}
}
}
#[test]
fn test_join_with_where_and_order_by() {
let db = setup_multi_table_join_db();
let sql = "\
SELECT jt_customers.name, jt_products.name, jt_orders.qty \
FROM jt_orders \
JOIN jt_customers ON jt_orders.customer_id = jt_customers.id \
JOIN jt_products ON jt_orders.product_id = jt_products.id \
WHERE jt_customers.city = 'NYC' \
ORDER BY jt_customers.name, jt_orders.qty";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 3, "3 NYC orders expected, got {}", rows.len());
let mut qtys: Vec<i32> = rows
.iter()
.map(|r| match r.get(2).unwrap() {
Value::Int4(v) => *v,
other => panic!("Expected Int4 for qty, got {:?}", other),
})
.collect();
qtys.sort();
assert_eq!(qtys, vec![1, 2, 3], "NYC orders should have qty 1, 2, 3");
for row in &rows {
let name = match row.get(0).unwrap() {
Value::String(s) => s.as_str(),
other => panic!("Expected String name, got {:?}", other),
};
assert!(
name == "Alice" || name == "Carol",
"All results should be NYC customers, got '{}'",
name
);
}
}
Err(e) => {
panic!("JOIN + WHERE + ORDER BY failed: {}", e);
}
}
}
#[test]
fn test_join_empty_table() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE jt_full (id INT PRIMARY KEY, val TEXT)").unwrap();
db.execute("INSERT INTO jt_full VALUES (1, 'one')").unwrap();
db.execute("INSERT INTO jt_full VALUES (2, 'two')").unwrap();
db.execute("CREATE TABLE jt_empty (id INT PRIMARY KEY, ref_id INT)").unwrap();
let sql = "\
SELECT jt_full.val, jt_empty.ref_id \
FROM jt_full \
JOIN jt_empty ON jt_full.id = jt_empty.ref_id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 0, "JOIN with empty table should produce 0 rows, got {}", rows.len());
}
Err(e) => {
panic!("JOIN with empty table failed: {}", e);
}
}
}
#[test]
fn test_join_left_join_empty_right() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE jt_main (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute("INSERT INTO jt_main VALUES (1, 'alpha')").unwrap();
db.execute("INSERT INTO jt_main VALUES (2, 'beta')").unwrap();
db.execute("CREATE TABLE jt_detail (id INT PRIMARY KEY, main_id INT, note TEXT)")
.unwrap();
let sql = "\
SELECT jt_main.name, jt_detail.note \
FROM jt_main \
LEFT JOIN jt_detail ON jt_main.id = jt_detail.main_id \
ORDER BY jt_main.id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 2, "All left rows should appear, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("alpha".to_string()));
assert_eq!(rows[0].get(1).unwrap(), &Value::Null);
assert_eq!(rows[1].get(0).unwrap(), &Value::String("beta".to_string()));
assert_eq!(rows[1].get(1).unwrap(), &Value::Null);
}
Err(e) => {
panic!("LEFT JOIN with empty right table failed: {}", e);
}
}
}
#[test]
fn test_join_cross_join_single_row_tables() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE jt_single_a (val TEXT)").unwrap();
db.execute("INSERT INTO jt_single_a VALUES ('hello')").unwrap();
db.execute("CREATE TABLE jt_single_b (val TEXT)").unwrap();
db.execute("INSERT INTO jt_single_b VALUES ('world')").unwrap();
let sql = "\
SELECT jt_single_a.val, jt_single_b.val \
FROM jt_single_a \
CROSS JOIN jt_single_b";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 1, "1x1 CROSS JOIN should produce 1 row, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("hello".to_string()));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("world".to_string()));
}
Err(e) => {
panic!("CROSS JOIN single-row tables failed: {}", e);
}
}
}
#[test]
fn test_join_duplicate_column_values() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE jt_parent (id INT PRIMARY KEY, label TEXT)").unwrap();
db.execute("INSERT INTO jt_parent VALUES (1, 'group_a')").unwrap();
db.execute("CREATE TABLE jt_child (id INT PRIMARY KEY, parent_id INT, name TEXT)")
.unwrap();
db.execute("INSERT INTO jt_child VALUES (10, 1, 'child1')").unwrap();
db.execute("INSERT INTO jt_child VALUES (20, 1, 'child2')").unwrap();
db.execute("INSERT INTO jt_child VALUES (30, 1, 'child3')").unwrap();
let sql = "\
SELECT jt_parent.label, jt_child.name \
FROM jt_parent \
JOIN jt_child ON jt_parent.id = jt_child.parent_id \
ORDER BY jt_child.id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 3, "1 parent x 3 children = 3 rows, got {}", rows.len());
for row in &rows {
assert_eq!(row.get(0).unwrap(), &Value::String("group_a".to_string()));
}
assert_eq!(rows[0].get(1).unwrap(), &Value::String("child1".to_string()));
assert_eq!(rows[1].get(1).unwrap(), &Value::String("child2".to_string()));
assert_eq!(rows[2].get(1).unwrap(), &Value::String("child3".to_string()));
}
Err(e) => {
panic!("One-to-many JOIN failed: {}", e);
}
}
}
#[test]
fn test_join_five_table_chain() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE jt5_a (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute("INSERT INTO jt5_a VALUES (1, 'a1')").unwrap();
db.execute("CREATE TABLE jt5_b (id INT PRIMARY KEY, a_id INT, name TEXT)").unwrap();
db.execute("INSERT INTO jt5_b VALUES (10, 1, 'b1')").unwrap();
db.execute("CREATE TABLE jt5_c (id INT PRIMARY KEY, b_id INT, name TEXT)").unwrap();
db.execute("INSERT INTO jt5_c VALUES (100, 10, 'c1')").unwrap();
db.execute("CREATE TABLE jt5_d (id INT PRIMARY KEY, c_id INT, name TEXT)").unwrap();
db.execute("INSERT INTO jt5_d VALUES (1000, 100, 'd1')").unwrap();
db.execute("CREATE TABLE jt5_e (id INT PRIMARY KEY, d_id INT, name TEXT)").unwrap();
db.execute("INSERT INTO jt5_e VALUES (10000, 1000, 'e1')").unwrap();
let sql = "\
SELECT jt5_a.name, jt5_b.name, jt5_c.name, jt5_d.name, jt5_e.name \
FROM jt5_a \
JOIN jt5_b ON jt5_a.id = jt5_b.a_id \
JOIN jt5_c ON jt5_b.id = jt5_c.b_id \
JOIN jt5_d ON jt5_c.id = jt5_d.c_id \
JOIN jt5_e ON jt5_d.id = jt5_e.d_id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 1, "5-table chain should produce 1 row, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("a1".to_string()));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("b1".to_string()));
assert_eq!(rows[0].get(2).unwrap(), &Value::String("c1".to_string()));
assert_eq!(rows[0].get(3).unwrap(), &Value::String("d1".to_string()));
assert_eq!(rows[0].get(4).unwrap(), &Value::String("e1".to_string()));
}
Err(e) => {
panic!("5-table JOIN chain failed: {}", e);
}
}
}
#[test]
fn test_join_self_join_two_levels() {
let db = setup_employee_db();
let sql = "\
SELECT e.name, m.name, gm.name \
FROM jt_employees e \
JOIN jt_employees m ON e.manager_id = m.id \
JOIN jt_employees gm ON m.manager_id = gm.id \
ORDER BY e.id";
match db.query(sql, &[]) {
Ok(rows) => {
assert_eq!(rows.len(), 3, "3 employees have grandmanagers, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("Hank".to_string()));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("Frank".to_string()));
assert_eq!(rows[0].get(2).unwrap(), &Value::String("Eve".to_string()));
assert_eq!(rows[1].get(0).unwrap(), &Value::String("Iris".to_string()));
assert_eq!(rows[1].get(1).unwrap(), &Value::String("Frank".to_string()));
assert_eq!(rows[1].get(2).unwrap(), &Value::String("Eve".to_string()));
assert_eq!(rows[2].get(0).unwrap(), &Value::String("Jack".to_string()));
assert_eq!(rows[2].get(1).unwrap(), &Value::String("Grace".to_string()));
assert_eq!(rows[2].get(2).unwrap(), &Value::String("Eve".to_string()));
}
Err(e) => {
panic!("Two-level self-join (grandmanager) failed: {}", e);
}
}
}
#[test]
fn test_join_alias_column_resolution_in_where() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE wp_term_taxonomy (term_taxonomy_id INT PRIMARY KEY, term_id INT, taxonomy TEXT)").unwrap();
db.execute("CREATE TABLE wp_terms (term_id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute("INSERT INTO wp_terms VALUES (1, 'Uncategorized')").unwrap();
db.execute("INSERT INTO wp_terms VALUES (2, 'News')").unwrap();
db.execute("INSERT INTO wp_term_taxonomy VALUES (1, 1, 'category')").unwrap();
db.execute("INSERT INTO wp_term_taxonomy VALUES (2, 2, 'category')").unwrap();
db.execute("INSERT INTO wp_term_taxonomy VALUES (3, 2, 'post_tag')").unwrap();
let rows = db.query(
"SELECT tt.term_taxonomy_id FROM wp_term_taxonomy AS tt \
INNER JOIN wp_terms AS t ON t.term_id = tt.term_id \
WHERE tt.taxonomy = 'category'",
&[],
).expect("WordPress-style JOIN with aliased WHERE column should work");
assert_eq!(rows.len(), 2, "Should find 2 category rows");
let rows = db.query(
"SELECT t.name, tt.taxonomy FROM wp_term_taxonomy AS tt \
INNER JOIN wp_terms AS t ON t.term_id = tt.term_id \
WHERE tt.taxonomy = 'category' ORDER BY t.name",
&[],
).expect("Multi-column aliased JOIN should work");
assert_eq!(rows.len(), 2);
db.execute("CREATE TABLE t1 (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute("CREATE TABLE t2 (id INT PRIMARY KEY, t1_id INT, value TEXT)").unwrap();
db.execute("INSERT INTO t1 VALUES (1, 'Alice')").unwrap();
db.execute("INSERT INTO t2 VALUES (1, 1, 'hello')").unwrap();
let rows = db.query(
"SELECT a.name, b.value FROM t1 AS a INNER JOIN t2 AS b ON a.id = b.t1_id WHERE a.name = 'Alice'",
&[],
).expect("JOIN with aliased WHERE column should work");
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get(0).unwrap(), &Value::String("Alice".to_string()));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("hello".to_string()));
db.execute("CREATE TABLE wp_term_relationships (object_id INT, term_taxonomy_id INT)").unwrap();
db.execute("INSERT INTO wp_term_relationships VALUES (10, 1)").unwrap();
db.execute("INSERT INTO wp_term_relationships VALUES (20, 2)").unwrap();
db.execute("INSERT INTO wp_term_relationships VALUES (30, 3)").unwrap();
let rows = db.query(
"SELECT tr.object_id, tt.taxonomy, t.name \
FROM wp_term_relationships AS tr \
INNER JOIN wp_term_taxonomy AS tt ON tr.term_taxonomy_id = tt.term_taxonomy_id \
INNER JOIN wp_terms AS t ON t.term_id = tt.term_id \
WHERE tt.taxonomy = 'category'",
&[],
).expect("Three-table WordPress-style JOIN should work");
assert_eq!(rows.len(), 2, "Should find 2 relationships with category taxonomy");
let rows = db.query(
"SELECT tt.term_taxonomy_id FROM wp_term_taxonomy AS tt \
INNER JOIN wp_terms AS t ON tt.term_id = t.term_id \
WHERE tt.taxonomy = 'category'",
&[],
).expect("JOIN with swapped ON column order should work");
assert_eq!(rows.len(), 2, "Should still find 2 category rows with swapped ON order");
}
#[test]
fn test_truncate_basic() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE trunc_basic (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute("INSERT INTO trunc_basic VALUES (1, 'Alice')").unwrap();
db.execute("INSERT INTO trunc_basic VALUES (2, 'Bob')").unwrap();
db.execute("INSERT INTO trunc_basic VALUES (3, 'Charlie')").unwrap();
let rows = db.query("SELECT * FROM trunc_basic", &[]).unwrap();
assert_eq!(rows.len(), 3, "Should have 3 rows before TRUNCATE");
db.execute("TRUNCATE TABLE trunc_basic").unwrap();
let rows = db.query("SELECT * FROM trunc_basic", &[]).unwrap();
assert_eq!(rows.len(), 0, "Should have 0 rows after TRUNCATE");
}
#[test]
fn test_truncate_preserves_schema() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE trunc_schema (id INT PRIMARY KEY, name TEXT, score FLOAT)").unwrap();
db.execute("INSERT INTO trunc_schema VALUES (1, 'Alice', 95.5)").unwrap();
db.execute("TRUNCATE TABLE trunc_schema").unwrap();
db.execute("INSERT INTO trunc_schema VALUES (10, 'David', 88.0)").unwrap();
let rows = db.query("SELECT id, name, score FROM trunc_schema", &[]).unwrap();
assert_eq!(rows.len(), 1, "Should have 1 row after re-insert");
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(10));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("David".to_string()));
}
#[test]
fn test_truncate_empty_table() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE trunc_empty (id INT PRIMARY KEY, val TEXT)").unwrap();
let result = db.execute("TRUNCATE TABLE trunc_empty");
assert!(result.is_ok(), "TRUNCATE on empty table should succeed, got: {:?}", result.err());
let rows = db.query("SELECT * FROM trunc_empty", &[]).unwrap();
assert_eq!(rows.len(), 0, "Empty table should remain empty after TRUNCATE");
}
#[test]
fn test_truncate_reinsert_after() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE trunc_reinsert (id INT PRIMARY KEY, label TEXT)").unwrap();
db.execute("INSERT INTO trunc_reinsert VALUES (1, 'first')").unwrap();
db.execute("INSERT INTO trunc_reinsert VALUES (2, 'second')").unwrap();
db.execute("TRUNCATE TABLE trunc_reinsert").unwrap();
db.execute("INSERT INTO trunc_reinsert VALUES (1, 'new_first')").unwrap();
db.execute("INSERT INTO trunc_reinsert VALUES (3, 'third')").unwrap();
let rows = db.query("SELECT * FROM trunc_reinsert ORDER BY id", &[]).unwrap();
assert_eq!(rows.len(), 2, "Should have 2 rows after re-insert");
assert_eq!(rows[0].get(1).unwrap(), &Value::String("new_first".to_string()));
assert_eq!(rows[1].get(1).unwrap(), &Value::String("third".to_string()));
}
#[test]
fn test_truncate_multiple_tables() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE trunc_a (id INT PRIMARY KEY, val TEXT)").unwrap();
db.execute("CREATE TABLE trunc_b (id INT PRIMARY KEY, val TEXT)").unwrap();
db.execute("INSERT INTO trunc_a VALUES (1, 'a1')").unwrap();
db.execute("INSERT INTO trunc_a VALUES (2, 'a2')").unwrap();
db.execute("INSERT INTO trunc_b VALUES (10, 'b1')").unwrap();
db.execute("INSERT INTO trunc_b VALUES (20, 'b2')").unwrap();
db.execute("INSERT INTO trunc_b VALUES (30, 'b3')").unwrap();
db.execute("TRUNCATE TABLE trunc_a").unwrap();
let rows_a = db.query("SELECT * FROM trunc_a", &[]).unwrap();
let rows_b = db.query("SELECT * FROM trunc_b", &[]).unwrap();
assert_eq!(rows_a.len(), 0, "Table A should be empty after TRUNCATE");
assert_eq!(rows_b.len(), 3, "Table B should be unaffected by TRUNCATE of A");
db.execute("TRUNCATE TABLE trunc_b").unwrap();
let rows_b = db.query("SELECT * FROM trunc_b", &[]).unwrap();
assert_eq!(rows_b.len(), 0, "Table B should be empty after TRUNCATE");
}
#[test]
fn test_truncate_with_many_rows() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE trunc_many (id INT PRIMARY KEY, val INT)").unwrap();
for i in 1..=150 {
db.execute(&format!("INSERT INTO trunc_many VALUES ({}, {})", i, i * 10)).unwrap();
}
let rows = db.query("SELECT COUNT(*) FROM trunc_many", &[]).unwrap();
match rows[0].get(0).unwrap() {
Value::Int8(n) => assert_eq!(*n, 150, "Should have 150 rows before TRUNCATE"),
other => panic!("Expected Int8 count, got {:?}", other),
}
db.execute("TRUNCATE TABLE trunc_many").unwrap();
let rows = db.query("SELECT COUNT(*) FROM trunc_many", &[]).unwrap();
match rows[0].get(0).unwrap() {
Value::Int8(n) => assert_eq!(*n, 0, "Should have 0 rows after TRUNCATE"),
other => panic!("Expected Int8 count of 0, got {:?}", other),
}
}
#[test]
fn test_truncate_preserves_indexes() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE trunc_idx (id INT PRIMARY KEY, name TEXT, score INT)").unwrap();
db.execute("INSERT INTO trunc_idx VALUES (1, 'Alice', 90)").unwrap();
db.execute("INSERT INTO trunc_idx VALUES (2, 'Bob', 85)").unwrap();
db.execute("TRUNCATE TABLE trunc_idx").unwrap();
db.execute("INSERT INTO trunc_idx VALUES (5, 'Eve', 95)").unwrap();
db.execute("INSERT INTO trunc_idx VALUES (6, 'Frank', 80)").unwrap();
let rows = db.query("SELECT * FROM trunc_idx WHERE id = 5", &[]).unwrap();
assert_eq!(rows.len(), 1, "PK lookup should work after TRUNCATE + re-insert");
assert_eq!(rows[0].get(1).unwrap(), &Value::String("Eve".to_string()));
let rows = db.query("SELECT name FROM trunc_idx ORDER BY score DESC", &[]).unwrap();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].get(0).unwrap(), &Value::String("Eve".to_string()));
assert_eq!(rows[1].get(0).unwrap(), &Value::String("Frank".to_string()));
}
#[test]
fn test_truncate_nonexistent_table() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let result = db.execute("TRUNCATE TABLE no_such_table");
assert!(result.is_err(), "TRUNCATE on non-existent table should error");
let err_msg = result.unwrap_err().to_string();
assert!(
err_msg.to_lowercase().contains("no_such_table") || err_msg.to_lowercase().contains("not exist") || err_msg.to_lowercase().contains("not found"),
"Error should mention missing table, got: {}",
err_msg
);
}
#[test]
fn test_truncate_returns_zero() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE trunc_count (id INT PRIMARY KEY)").unwrap();
db.execute("INSERT INTO trunc_count VALUES (1)").unwrap();
db.execute("INSERT INTO trunc_count VALUES (2)").unwrap();
db.execute("INSERT INTO trunc_count VALUES (3)").unwrap();
let count = db.execute("TRUNCATE TABLE trunc_count").unwrap();
assert_eq!(count, 3, "TRUNCATE returns actual row count");
let rows = db.query("SELECT * FROM trunc_count", &[]).unwrap();
assert_eq!(rows.len(), 0, "All rows should be removed");
}
#[test]
fn test_truncate_then_count() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE trunc_cnt (id INT PRIMARY KEY, x INT)").unwrap();
for i in 1..=5 {
db.execute(&format!("INSERT INTO trunc_cnt VALUES ({}, {})", i, i)).unwrap();
}
db.execute("TRUNCATE TABLE trunc_cnt").unwrap();
let rows = db.query("SELECT COUNT(*) FROM trunc_cnt", &[]).unwrap();
match rows[0].get(0).unwrap() {
Value::Int8(n) => assert_eq!(*n, 0),
other => panic!("Expected Int8(0), got {:?}", other),
}
db.execute("INSERT INTO trunc_cnt VALUES (10, 100)").unwrap();
db.execute("INSERT INTO trunc_cnt VALUES (20, 200)").unwrap();
let rows = db.query("SELECT COUNT(*) FROM trunc_cnt", &[]).unwrap();
match rows[0].get(0).unwrap() {
Value::Int8(n) => assert_eq!(*n, 2, "COUNT should be 2 after re-inserting 2 rows"),
other => panic!("Expected Int8(2), got {:?}", other),
}
}
#[test]
fn test_fk_basic_creation() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE fk_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
let result = db.execute(
"CREATE TABLE fk_child (
id INT PRIMARY KEY,
parent_id INT,
FOREIGN KEY (parent_id) REFERENCES fk_parent(id)
)"
);
assert!(result.is_ok(), "Creating table with FK constraint should succeed, got: {:?}", result.err());
}
#[test]
fn test_fk_insert_valid() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE fk_iv_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute(
"CREATE TABLE fk_iv_child (
id INT PRIMARY KEY,
parent_id INT,
FOREIGN KEY (parent_id) REFERENCES fk_iv_parent(id)
)"
).unwrap();
db.execute("INSERT INTO fk_iv_parent VALUES (1, 'Alice')").unwrap();
let result = db.execute("INSERT INTO fk_iv_child VALUES (100, 1)");
assert!(result.is_ok(), "Insert with valid FK reference should succeed, got: {:?}", result.err());
let rows = db.query("SELECT * FROM fk_iv_child WHERE parent_id = 1", &[]).unwrap();
assert_eq!(rows.len(), 1, "Child row should be inserted");
}
#[test]
fn test_fk_insert_invalid() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE fk_ii_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute(
"CREATE TABLE fk_ii_child (
id INT PRIMARY KEY,
parent_id INT,
FOREIGN KEY (parent_id) REFERENCES fk_ii_parent(id)
)"
).unwrap();
let result = db.execute("INSERT INTO fk_ii_child VALUES (1, 999)");
assert!(result.is_err(), "Insert with invalid FK reference should fail");
let err_msg = result.unwrap_err().to_string();
assert!(
err_msg.to_lowercase().contains("foreign key") || err_msg.to_lowercase().contains("constraint"),
"Error should mention foreign key constraint, got: {}",
err_msg
);
}
#[test]
fn test_fk_insert_null_fk_value() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE fk_null_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute(
"CREATE TABLE fk_null_child (
id INT PRIMARY KEY,
parent_id INT,
FOREIGN KEY (parent_id) REFERENCES fk_null_parent(id)
)"
).unwrap();
let result = db.execute("INSERT INTO fk_null_child VALUES (1, NULL)");
assert!(result.is_ok(), "Insert with NULL FK value should succeed, got: {:?}", result.err());
let rows = db.query("SELECT * FROM fk_null_child", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get(1).unwrap(), &Value::Null);
}
#[test]
fn test_fk_delete_parent_default_action() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE fk_dp_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute(
"CREATE TABLE fk_dp_child (
id INT PRIMARY KEY,
parent_id INT,
FOREIGN KEY (parent_id) REFERENCES fk_dp_parent(id)
)"
).unwrap();
db.execute("INSERT INTO fk_dp_parent VALUES (1, 'Alice')").unwrap();
db.execute("INSERT INTO fk_dp_child VALUES (100, 1)").unwrap();
let result = db.execute("DELETE FROM fk_dp_parent WHERE id = 1");
assert!(result.is_err(), "Deleting parent with referencing children should fail with default action");
let err_msg = result.unwrap_err().to_string();
assert!(
err_msg.to_lowercase().contains("foreign key") || err_msg.to_lowercase().contains("constraint") || err_msg.to_lowercase().contains("referenced"),
"Error should mention FK constraint, got: {}",
err_msg
);
let parent_rows = db.query("SELECT * FROM fk_dp_parent", &[]).unwrap();
let child_rows = db.query("SELECT * FROM fk_dp_child", &[]).unwrap();
assert_eq!(parent_rows.len(), 1, "Parent should still exist");
assert_eq!(child_rows.len(), 1, "Child should still exist");
}
#[test]
fn test_fk_cascade_delete() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE fk_cd_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute(
"CREATE TABLE fk_cd_child (
id INT PRIMARY KEY,
parent_id INT,
FOREIGN KEY (parent_id) REFERENCES fk_cd_parent(id) ON DELETE CASCADE
)"
).unwrap();
db.execute("INSERT INTO fk_cd_parent VALUES (1, 'Alice')").unwrap();
db.execute("INSERT INTO fk_cd_parent VALUES (2, 'Bob')").unwrap();
db.execute("INSERT INTO fk_cd_child VALUES (100, 1)").unwrap();
db.execute("INSERT INTO fk_cd_child VALUES (101, 1)").unwrap();
db.execute("INSERT INTO fk_cd_child VALUES (102, 2)").unwrap();
db.execute("DELETE FROM fk_cd_parent WHERE id = 1").unwrap();
let parent_rows = db.query("SELECT * FROM fk_cd_parent", &[]).unwrap();
assert_eq!(parent_rows.len(), 1, "Only parent id=2 should remain");
assert_eq!(parent_rows[0].get(0).unwrap(), &Value::Int4(2));
let child_rows = db.query("SELECT * FROM fk_cd_child", &[]).unwrap();
assert_eq!(child_rows.len(), 1, "Only child 102 (referencing parent 2) should remain");
assert_eq!(child_rows[0].get(0).unwrap(), &Value::Int4(102));
}
#[test]
fn test_fk_set_null_delete() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE fk_sn_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute(
"CREATE TABLE fk_sn_child (
id INT PRIMARY KEY,
parent_id INT,
FOREIGN KEY (parent_id) REFERENCES fk_sn_parent(id) ON DELETE SET NULL
)"
).unwrap();
db.execute("INSERT INTO fk_sn_parent VALUES (1, 'Alice')").unwrap();
db.execute("INSERT INTO fk_sn_child VALUES (100, 1)").unwrap();
db.execute("INSERT INTO fk_sn_child VALUES (101, 1)").unwrap();
db.execute("DELETE FROM fk_sn_parent WHERE id = 1").unwrap();
let child_rows = db.query("SELECT id, parent_id FROM fk_sn_child ORDER BY id", &[]).unwrap();
assert_eq!(child_rows.len(), 2, "Child rows should still exist");
assert_eq!(child_rows[0].get(1).unwrap(), &Value::Null, "parent_id should be NULL after SET NULL");
assert_eq!(child_rows[1].get(1).unwrap(), &Value::Null, "parent_id should be NULL after SET NULL");
}
#[test]
fn test_fk_restrict_delete() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE fk_rd_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute(
"CREATE TABLE fk_rd_child (
id INT PRIMARY KEY,
parent_id INT,
FOREIGN KEY (parent_id) REFERENCES fk_rd_parent(id) ON DELETE RESTRICT
)"
).unwrap();
db.execute("INSERT INTO fk_rd_parent VALUES (1, 'Alice')").unwrap();
db.execute("INSERT INTO fk_rd_child VALUES (100, 1)").unwrap();
let result = db.execute("DELETE FROM fk_rd_parent WHERE id = 1");
assert!(result.is_err(), "RESTRICT should prevent parent deletion");
let err_msg = result.unwrap_err().to_string();
assert!(
err_msg.to_lowercase().contains("foreign key") || err_msg.to_lowercase().contains("constraint"),
"Error should mention FK constraint, got: {}",
err_msg
);
let parent_rows = db.query("SELECT * FROM fk_rd_parent", &[]).unwrap();
let child_rows = db.query("SELECT * FROM fk_rd_child", &[]).unwrap();
assert_eq!(parent_rows.len(), 1);
assert_eq!(child_rows.len(), 1);
}
#[test]
fn test_fk_restrict_allows_delete_when_no_children() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE fk_ra_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute(
"CREATE TABLE fk_ra_child (
id INT PRIMARY KEY,
parent_id INT,
FOREIGN KEY (parent_id) REFERENCES fk_ra_parent(id) ON DELETE RESTRICT
)"
).unwrap();
db.execute("INSERT INTO fk_ra_parent VALUES (1, 'Alice')").unwrap();
db.execute("INSERT INTO fk_ra_parent VALUES (2, 'Bob')").unwrap();
db.execute("INSERT INTO fk_ra_child VALUES (100, 1)").unwrap();
let result = db.execute("DELETE FROM fk_ra_parent WHERE id = 2");
assert!(result.is_ok(), "Should allow deletion of unreferenced parent, got: {:?}", result.err());
let parent_rows = db.query("SELECT * FROM fk_ra_parent", &[]).unwrap();
assert_eq!(parent_rows.len(), 1, "Only parent 1 should remain");
}
#[test]
fn test_fk_no_action_delete() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE fk_na_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute(
"CREATE TABLE fk_na_child (
id INT PRIMARY KEY,
parent_id INT,
FOREIGN KEY (parent_id) REFERENCES fk_na_parent(id) ON DELETE NO ACTION
)"
).unwrap();
db.execute("INSERT INTO fk_na_parent VALUES (1, 'Alice')").unwrap();
db.execute("INSERT INTO fk_na_child VALUES (100, 1)").unwrap();
let result = db.execute("DELETE FROM fk_na_parent WHERE id = 1");
assert!(result.is_err(), "NO ACTION should prevent parent deletion when children exist");
}
#[test]
fn test_fk_self_referencing() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute(
"CREATE TABLE fk_self_emp (
id INT PRIMARY KEY,
name TEXT,
manager_id INT,
FOREIGN KEY (manager_id) REFERENCES fk_self_emp(id)
)"
).unwrap();
db.execute("INSERT INTO fk_self_emp VALUES (1, 'CEO', NULL)").unwrap();
db.execute("INSERT INTO fk_self_emp VALUES (2, 'VP', 1)").unwrap();
let rows = db.query("SELECT * FROM fk_self_emp ORDER BY id", &[]).unwrap();
assert_eq!(rows.len(), 2, "Should have 2 employees");
let result = db.execute("INSERT INTO fk_self_emp VALUES (3, 'Ghost', 999)");
assert!(result.is_err(), "Self-referencing FK with invalid ID should fail");
}
#[test]
fn test_fk_multiple_fks_on_one_table() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE fk_m_departments (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute("CREATE TABLE fk_m_managers (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute(
"CREATE TABLE fk_m_employees (
id INT PRIMARY KEY,
name TEXT,
dept_id INT,
manager_id INT,
FOREIGN KEY (dept_id) REFERENCES fk_m_departments(id),
FOREIGN KEY (manager_id) REFERENCES fk_m_managers(id)
)"
).unwrap();
db.execute("INSERT INTO fk_m_departments VALUES (1, 'Engineering')").unwrap();
db.execute("INSERT INTO fk_m_managers VALUES (10, 'Alice')").unwrap();
let result = db.execute("INSERT INTO fk_m_employees VALUES (100, 'Bob', 1, 10)");
assert!(result.is_ok(), "Insert with valid references to both FK parents should succeed, got: {:?}", result.err());
let result = db.execute("INSERT INTO fk_m_employees VALUES (101, 'Carol', 999, 10)");
assert!(result.is_err(), "Insert with invalid dept FK should fail");
let result = db.execute("INSERT INTO fk_m_employees VALUES (102, 'Dave', 1, 999)");
assert!(result.is_err(), "Insert with invalid manager FK should fail");
}
#[test]
fn test_fk_cascade_delete_multiple_children() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE fk_cm_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute(
"CREATE TABLE fk_cm_child (
id INT PRIMARY KEY,
parent_id INT,
label TEXT,
FOREIGN KEY (parent_id) REFERENCES fk_cm_parent(id) ON DELETE CASCADE
)"
).unwrap();
db.execute("INSERT INTO fk_cm_parent VALUES (1, 'Alpha')").unwrap();
for i in 1..=5 {
db.execute(&format!("INSERT INTO fk_cm_child VALUES ({}, 1, 'child_{}')", i, i)).unwrap();
}
let child_count = db.query("SELECT COUNT(*) FROM fk_cm_child", &[]).unwrap();
match child_count[0].get(0).unwrap() {
Value::Int8(n) => assert_eq!(*n, 5),
other => panic!("Expected 5 children, got {:?}", other),
}
db.execute("DELETE FROM fk_cm_parent WHERE id = 1").unwrap();
let child_count = db.query("SELECT COUNT(*) FROM fk_cm_child", &[]).unwrap();
match child_count[0].get(0).unwrap() {
Value::Int8(n) => assert_eq!(*n, 0, "All children should be cascade-deleted"),
other => panic!("Expected 0 children after cascade, got {:?}", other),
}
}
#[test]
fn test_fk_drop_parent_table() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE fk_drop_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute(
"CREATE TABLE fk_drop_child (
id INT PRIMARY KEY,
parent_id INT,
FOREIGN KEY (parent_id) REFERENCES fk_drop_parent(id)
)"
).unwrap();
db.execute("INSERT INTO fk_drop_parent VALUES (1, 'Alice')").unwrap();
db.execute("INSERT INTO fk_drop_child VALUES (100, 1)").unwrap();
match db.execute("DROP TABLE fk_drop_parent") {
Ok(_) => {
let child_rows = db.query("SELECT * FROM fk_drop_child", &[]).unwrap();
assert_eq!(child_rows.len(), 1, "Child table data should still exist after parent drop");
}
Err(e) => {
let err_msg = e.to_string();
assert!(
err_msg.to_lowercase().contains("foreign key") || err_msg.to_lowercase().contains("referenced") || err_msg.to_lowercase().contains("depends"),
"Error should mention FK dependency, got: {}",
err_msg
);
let parent_rows = db.query("SELECT * FROM fk_drop_parent", &[]).unwrap();
assert_eq!(parent_rows.len(), 1, "Parent should still exist after failed drop");
}
}
}
#[test]
fn test_fk_cascade_update() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE fk_cu_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute(
"CREATE TABLE fk_cu_child (
id INT PRIMARY KEY,
parent_id INT,
FOREIGN KEY (parent_id) REFERENCES fk_cu_parent(id) ON UPDATE CASCADE
)"
).unwrap();
db.execute("INSERT INTO fk_cu_parent VALUES (1, 'Alice')").unwrap();
db.execute("INSERT INTO fk_cu_child VALUES (100, 1)").unwrap();
match db.execute("UPDATE fk_cu_parent SET id = 10 WHERE id = 1") {
Ok(_) => {
let child_rows = db.query("SELECT parent_id FROM fk_cu_child WHERE id = 100", &[]).unwrap();
assert_eq!(child_rows.len(), 1, "Child should still exist");
match child_rows[0].get(0).unwrap() {
Value::Int4(v) => {
if *v == 10 {
} else {
assert_eq!(*v, 1, "Without cascade enforcement, child should retain old FK value");
}
}
other => panic!("Expected Int4 for parent_id, got {:?}", other),
}
}
Err(e) => {
println!("UPDATE parent PK with ON UPDATE CASCADE result: {}", e);
}
}
}
#[test]
fn test_fk_insert_then_delete_child_then_delete_parent() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE fk_idc_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute(
"CREATE TABLE fk_idc_child (
id INT PRIMARY KEY,
parent_id INT,
FOREIGN KEY (parent_id) REFERENCES fk_idc_parent(id)
)"
).unwrap();
db.execute("INSERT INTO fk_idc_parent VALUES (1, 'Alice')").unwrap();
db.execute("INSERT INTO fk_idc_child VALUES (100, 1)").unwrap();
db.execute("DELETE FROM fk_idc_child WHERE id = 100").unwrap();
let result = db.execute("DELETE FROM fk_idc_parent WHERE id = 1");
assert!(result.is_ok(), "Should be able to delete parent after all children removed, got: {:?}", result.err());
let parent_rows = db.query("SELECT * FROM fk_idc_parent", &[]).unwrap();
assert_eq!(parent_rows.len(), 0, "Parent should be deleted");
}
fn setup_group_by_db() -> EmbeddedDatabase {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE gb_sales (id INT, department TEXT, amount INT, rating FLOAT8)").unwrap();
db.execute("INSERT INTO gb_sales VALUES (1, 'Engineering', 100, 4.5)").unwrap();
db.execute("INSERT INTO gb_sales VALUES (2, 'Engineering', 200, 3.8)").unwrap();
db.execute("INSERT INTO gb_sales VALUES (3, 'Engineering', 150, 4.2)").unwrap();
db.execute("INSERT INTO gb_sales VALUES (4, 'Sales', 80, 3.0)").unwrap();
db.execute("INSERT INTO gb_sales VALUES (5, 'Sales', 120, 4.1)").unwrap();
db.execute("INSERT INTO gb_sales VALUES (6, 'Marketing', 90, 3.5)").unwrap();
db.execute("INSERT INTO gb_sales VALUES (7, 'HR', 60, 2.8)").unwrap();
db
}
#[test]
fn test_group_by_having_count() {
let db = setup_group_by_db();
let rows = db.query(
"SELECT department, COUNT(*) AS cnt FROM gb_sales GROUP BY department HAVING COUNT(*) > 1 ORDER BY department",
&[],
).unwrap();
assert_eq!(rows.len(), 2, "Expected 2 departments with count > 1, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("Engineering".to_string()));
assert_eq!(rows[0].get(1).unwrap(), &Value::Int8(3));
assert_eq!(rows[1].get(0).unwrap(), &Value::String("Sales".to_string()));
assert_eq!(rows[1].get(1).unwrap(), &Value::Int8(2));
}
#[test]
fn test_group_by_having_sum() {
let db = setup_group_by_db();
let rows = db.query(
"SELECT department, SUM(amount) AS total FROM gb_sales GROUP BY department HAVING SUM(amount) > 100 ORDER BY department",
&[],
).unwrap();
assert_eq!(rows.len(), 2, "Expected 2 departments with sum > 100, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("Engineering".to_string()));
assert_eq!(rows[0].get(1).unwrap(), &Value::Int8(450));
assert_eq!(rows[1].get(0).unwrap(), &Value::String("Sales".to_string()));
assert_eq!(rows[1].get(1).unwrap(), &Value::Int8(200));
}
#[test]
fn test_group_by_having_avg() {
let db = setup_group_by_db();
let rows = db.query(
"SELECT department, AVG(rating) FROM gb_sales GROUP BY department HAVING AVG(rating) > 3.5 ORDER BY department",
&[],
).unwrap();
assert_eq!(rows.len(), 2, "Expected 2 departments with avg > 3.5, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("Engineering".to_string()));
assert_eq!(rows[1].get(0).unwrap(), &Value::String("Sales".to_string()));
if let Value::Float8(avg) = rows[0].get(1).unwrap() {
assert!((avg - 4.1666).abs() < 0.01, "Engineering avg should be ~4.167, got {}", avg);
} else {
panic!("AVG should return Float8, got {:?}", rows[0].get(1));
}
}
#[test]
fn test_group_by_having_multiple_conditions() {
let db = setup_group_by_db();
let rows = db.query(
"SELECT department, COUNT(*), SUM(amount) FROM gb_sales GROUP BY department HAVING COUNT(*) > 1 AND SUM(amount) > 150 ORDER BY department",
&[],
).unwrap();
assert_eq!(rows.len(), 2, "Expected 2 departments matching both conditions, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("Engineering".to_string()));
assert_eq!(rows[1].get(0).unwrap(), &Value::String("Sales".to_string()));
}
#[test]
fn test_group_by_having_no_match() {
let db = setup_group_by_db();
let rows = db.query(
"SELECT department, COUNT(*) FROM gb_sales GROUP BY department HAVING COUNT(*) > 100",
&[],
).unwrap();
assert_eq!(rows.len(), 0, "No group should have count > 100");
}
#[test]
fn test_group_by_having_all_match() {
let db = setup_group_by_db();
let rows = db.query(
"SELECT department, COUNT(*) FROM gb_sales GROUP BY department HAVING COUNT(*) >= 1 ORDER BY department",
&[],
).unwrap();
assert_eq!(rows.len(), 4, "All 4 departments should match count >= 1, got {}", rows.len());
}
#[test]
fn test_group_by_multiple_columns() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE gb_multi (region TEXT, category TEXT, amount INT)").unwrap();
db.execute("INSERT INTO gb_multi VALUES ('East', 'A', 10)").unwrap();
db.execute("INSERT INTO gb_multi VALUES ('East', 'A', 20)").unwrap();
db.execute("INSERT INTO gb_multi VALUES ('East', 'B', 30)").unwrap();
db.execute("INSERT INTO gb_multi VALUES ('West', 'A', 40)").unwrap();
db.execute("INSERT INTO gb_multi VALUES ('West', 'B', 50)").unwrap();
let rows = db.query(
"SELECT region, category, SUM(amount) FROM gb_multi GROUP BY region, category ORDER BY region, category",
&[],
).unwrap();
assert_eq!(rows.len(), 4, "Expected 4 groups from 2-column GROUP BY, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("East".to_string()));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("A".to_string()));
assert_eq!(rows[0].get(2).unwrap(), &Value::Int8(30));
assert_eq!(rows[1].get(0).unwrap(), &Value::String("East".to_string()));
assert_eq!(rows[1].get(1).unwrap(), &Value::String("B".to_string()));
assert_eq!(rows[1].get(2).unwrap(), &Value::Int8(30));
assert_eq!(rows[2].get(0).unwrap(), &Value::String("West".to_string()));
assert_eq!(rows[2].get(1).unwrap(), &Value::String("A".to_string()));
assert_eq!(rows[2].get(2).unwrap(), &Value::Int8(40));
assert_eq!(rows[3].get(0).unwrap(), &Value::String("West".to_string()));
assert_eq!(rows[3].get(1).unwrap(), &Value::String("B".to_string()));
assert_eq!(rows[3].get(2).unwrap(), &Value::Int8(50));
}
#[test]
fn test_group_by_with_order_by() {
let db = setup_group_by_db();
let rows = db.query(
"SELECT department, SUM(amount) AS total FROM gb_sales GROUP BY department ORDER BY total ASC",
&[],
).unwrap();
assert_eq!(rows.len(), 4, "Expected 4 departments, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("HR".to_string()));
assert_eq!(rows[0].get(1).unwrap(), &Value::Int8(60));
assert_eq!(rows[1].get(0).unwrap(), &Value::String("Marketing".to_string()));
assert_eq!(rows[1].get(1).unwrap(), &Value::Int8(90));
assert_eq!(rows[2].get(0).unwrap(), &Value::String("Sales".to_string()));
assert_eq!(rows[2].get(1).unwrap(), &Value::Int8(200));
assert_eq!(rows[3].get(0).unwrap(), &Value::String("Engineering".to_string()));
assert_eq!(rows[3].get(1).unwrap(), &Value::Int8(450));
}
#[test]
fn test_group_by_null_values() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE gb_nulls (category TEXT, val INT)").unwrap();
db.execute("INSERT INTO gb_nulls VALUES ('A', 10)").unwrap();
db.execute("INSERT INTO gb_nulls VALUES ('A', 20)").unwrap();
db.execute("INSERT INTO gb_nulls VALUES (NULL, 30)").unwrap();
db.execute("INSERT INTO gb_nulls VALUES (NULL, 40)").unwrap();
let rows = db.query(
"SELECT category, SUM(val) FROM gb_nulls GROUP BY category ORDER BY category",
&[],
).unwrap();
assert_eq!(rows.len(), 2, "Expected 2 groups (A and NULL), got {}", rows.len());
let a_group = rows.iter().find(|r| r.get(0).unwrap() == &Value::String("A".to_string()));
let null_group = rows.iter().find(|r| r.get(0).unwrap() == &Value::Null);
assert!(a_group.is_some(), "Should have an 'A' group");
assert_eq!(a_group.unwrap().get(1).unwrap(), &Value::Int8(30));
assert!(null_group.is_some(), "NULL values should form their own group");
assert_eq!(null_group.unwrap().get(1).unwrap(), &Value::Int8(70));
}
#[test]
fn test_group_by_count_distinct() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE gb_cd (grp TEXT, val INT)").unwrap();
db.execute("INSERT INTO gb_cd VALUES ('X', 1)").unwrap();
db.execute("INSERT INTO gb_cd VALUES ('X', 2)").unwrap();
db.execute("INSERT INTO gb_cd VALUES ('X', 2)").unwrap();
db.execute("INSERT INTO gb_cd VALUES ('X', 3)").unwrap();
db.execute("INSERT INTO gb_cd VALUES ('Y', 10)").unwrap();
db.execute("INSERT INTO gb_cd VALUES ('Y', 10)").unwrap();
let result = db.query(
"SELECT grp, COUNT(DISTINCT val) FROM gb_cd GROUP BY grp ORDER BY grp",
&[],
);
match result {
Ok(rows) => {
assert_eq!(rows.len(), 2, "Expected 2 groups, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("X".to_string()));
assert_eq!(rows[0].get(1).unwrap(), &Value::Int8(3));
assert_eq!(rows[1].get(0).unwrap(), &Value::String("Y".to_string()));
assert_eq!(rows[1].get(1).unwrap(), &Value::Int8(1));
}
Err(e) => {
let err_msg = e.to_string();
assert!(
err_msg.contains("DISTINCT") || err_msg.contains("distinct") || err_msg.contains("not") || err_msg.contains("syntax"),
"COUNT(DISTINCT) unsupported or syntax error: {}", err_msg
);
}
}
}
#[test]
fn test_cast_int_to_text() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query("SELECT CAST(42 AS TEXT)", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get(0).unwrap(), &Value::String("42".to_string()));
}
#[test]
fn test_cast_text_to_int() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query("SELECT CAST('42' AS INT)", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(42));
}
#[test]
fn test_cast_int_to_float() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query("SELECT CAST(42 AS FLOAT8)", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get(0).unwrap(), &Value::Float8(42.0));
}
#[test]
fn test_cast_float_to_int() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query("SELECT CAST(3.7 AS INT)", &[]).unwrap();
assert_eq!(rows.len(), 1);
let val = rows[0].get(0).unwrap();
assert!(
val == &Value::Int4(3) || val == &Value::Int4(4),
"CAST(3.7 AS INT) should truncate to 3 (or possibly round to 4), got {:?}", val
);
}
#[test]
fn test_cast_text_to_boolean() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query("SELECT CAST('true' AS BOOLEAN)", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get(0).unwrap(), &Value::Boolean(true));
let rows2 = db.query("SELECT CAST('false' AS BOOLEAN)", &[]).unwrap();
assert_eq!(rows2.len(), 1);
assert_eq!(rows2[0].get(0).unwrap(), &Value::Boolean(false));
}
#[test]
fn test_cast_boolean_to_text() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query("SELECT CAST(TRUE AS TEXT)", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get(0).unwrap(), &Value::String("true".to_string()));
let rows2 = db.query("SELECT CAST(FALSE AS TEXT)", &[]).unwrap();
assert_eq!(rows2.len(), 1);
assert_eq!(rows2[0].get(0).unwrap(), &Value::String("false".to_string()));
}
#[test]
fn test_cast_null_cast() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query("SELECT CAST(NULL AS INT)", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get(0).unwrap(), &Value::Null);
}
#[test]
fn test_cast_invalid_text_to_int() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let result = db.query("SELECT CAST('abc' AS INT)", &[]);
assert!(result.is_err(), "CAST('abc' AS INT) should fail, but got: {:?}", result);
let err_msg = result.unwrap_err().to_string();
assert!(
err_msg.contains("Cannot cast") || err_msg.contains("cast") || err_msg.contains("invalid"),
"Error should mention cast failure, got: {}", err_msg
);
}
#[test]
fn test_cast_int_to_bigint() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query("SELECT CAST(42 AS BIGINT)", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get(0).unwrap(), &Value::Int8(42));
}
#[test]
fn test_cast_in_where() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE cast_where (id INT, code INT)").unwrap();
db.execute("INSERT INTO cast_where VALUES (1, 42)").unwrap();
db.execute("INSERT INTO cast_where VALUES (2, 99)").unwrap();
db.execute("INSERT INTO cast_where VALUES (3, 42)").unwrap();
let rows = db.query(
"SELECT id FROM cast_where WHERE CAST(code AS TEXT) = '42' ORDER BY id",
&[],
).unwrap();
assert_eq!(rows.len(), 2, "Expected 2 rows with code=42, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(3));
}
#[test]
fn test_alter_add_column_basic() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_add_basic (id INT, name TEXT)").unwrap();
db.execute("INSERT INTO alt_add_basic VALUES (1, 'Alice')").unwrap();
db.execute("ALTER TABLE alt_add_basic ADD COLUMN age INT").unwrap();
let rows = db.query("SELECT id, name, age FROM alt_add_basic", &[]).unwrap();
assert_eq!(rows.len(), 1, "Should still have 1 row");
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("Alice".to_string()));
assert_eq!(rows[0].get(2).unwrap(), &Value::Null,
"New column should be NULL for existing rows");
}
#[test]
fn test_alter_add_column_with_default() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_add_def (id INT, name TEXT)").unwrap();
db.execute("INSERT INTO alt_add_def VALUES (1, 'Alice')").unwrap();
db.execute("INSERT INTO alt_add_def VALUES (2, 'Bob')").unwrap();
db.execute("ALTER TABLE alt_add_def ADD COLUMN status TEXT DEFAULT 'active'").unwrap();
let rows = db.query("SELECT id, name, status FROM alt_add_def ORDER BY id", &[]).unwrap();
assert_eq!(rows.len(), 2, "Should still have 2 rows");
let status_0 = rows[0].get(2).unwrap();
let status_1 = rows[1].get(2).unwrap();
assert_eq!(status_0, status_1,
"Both existing rows should get same value for new column with DEFAULT");
assert!(
*status_0 == Value::Null || *status_0 == Value::String("active".to_string()),
"New column should be NULL or 'active', got: {:?}", status_0
);
}
#[test]
fn test_alter_add_column_nullable() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_add_null (id INT)").unwrap();
db.execute("INSERT INTO alt_add_null VALUES (1)").unwrap();
db.execute("INSERT INTO alt_add_null VALUES (2)").unwrap();
db.execute("INSERT INTO alt_add_null VALUES (3)").unwrap();
db.execute("ALTER TABLE alt_add_null ADD COLUMN note TEXT").unwrap();
let rows = db.query("SELECT id, note FROM alt_add_null ORDER BY id", &[]).unwrap();
assert_eq!(rows.len(), 3, "Should still have 3 rows");
for (i, row) in rows.iter().enumerate() {
assert_eq!(row.get(1).unwrap(), &Value::Null,
"Row {} new column should be NULL", i);
}
}
#[test]
fn test_alter_add_column_text_type() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_add_text (id INT)").unwrap();
db.execute("INSERT INTO alt_add_text VALUES (1)").unwrap();
db.execute("ALTER TABLE alt_add_text ADD COLUMN description TEXT").unwrap();
db.execute("UPDATE alt_add_text SET description = 'hello world' WHERE id = 1").unwrap();
let rows = db.query("SELECT id, description FROM alt_add_text", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("hello world".to_string()));
}
#[test]
fn test_alter_add_column_then_insert() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_add_ins (id INT, name TEXT)").unwrap();
db.execute("INSERT INTO alt_add_ins VALUES (1, 'Alice')").unwrap();
db.execute("ALTER TABLE alt_add_ins ADD COLUMN score INT").unwrap();
db.execute("INSERT INTO alt_add_ins VALUES (2, 'Bob', 95)").unwrap();
let rows = db.query("SELECT id, name, score FROM alt_add_ins ORDER BY id", &[]).unwrap();
assert_eq!(rows.len(), 2, "Should have 2 rows total");
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
assert_eq!(rows[0].get(2).unwrap(), &Value::Null);
assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(2));
assert_eq!(rows[1].get(2).unwrap(), &Value::Int4(95));
}
#[test]
fn test_alter_add_column_duplicate() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_add_dup (id INT, name TEXT)").unwrap();
let result = db.execute("ALTER TABLE alt_add_dup ADD COLUMN name TEXT");
assert!(result.is_err(),
"Adding a duplicate column should fail");
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("already exists"),
"Error should mention 'already exists', got: {}", err_msg);
}
#[test]
fn test_alter_add_column_if_not_exists() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_add_ine (id INT, name TEXT)").unwrap();
let result = db.execute("ALTER TABLE alt_add_ine ADD COLUMN IF NOT EXISTS name TEXT");
assert!(result.is_ok(),
"ADD COLUMN IF NOT EXISTS for existing column should succeed silently, got: {:?}",
result.err());
}
#[test]
fn test_alter_drop_column_basic() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_drop_basic (id INT, name TEXT, age INT)").unwrap();
db.execute("INSERT INTO alt_drop_basic VALUES (1, 'Alice', 30)").unwrap();
db.execute("ALTER TABLE alt_drop_basic DROP COLUMN age").unwrap();
let result = db.query("SELECT age FROM alt_drop_basic", &[]);
assert!(result.is_err(),
"Selecting a dropped column should fail");
let rows = db.query("SELECT id, name FROM alt_drop_basic", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("Alice".to_string()));
}
#[test]
fn test_alter_drop_column_with_data() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_drop_data (id INT, name TEXT, score INT, grade TEXT)").unwrap();
db.execute("INSERT INTO alt_drop_data VALUES (1, 'Alice', 90, 'A')").unwrap();
db.execute("INSERT INTO alt_drop_data VALUES (2, 'Bob', 80, 'B')").unwrap();
db.execute("INSERT INTO alt_drop_data VALUES (3, 'Carol', 70, 'C')").unwrap();
db.execute("ALTER TABLE alt_drop_data DROP COLUMN score").unwrap();
let rows = db.query("SELECT id, name, grade FROM alt_drop_data ORDER BY id", &[]).unwrap();
assert_eq!(rows.len(), 3, "All rows should still exist");
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("Alice".to_string()));
assert_eq!(rows[0].get(2).unwrap(), &Value::String("A".to_string()));
assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(2));
assert_eq!(rows[1].get(1).unwrap(), &Value::String("Bob".to_string()));
assert_eq!(rows[1].get(2).unwrap(), &Value::String("B".to_string()));
assert_eq!(rows[2].get(0).unwrap(), &Value::Int4(3));
assert_eq!(rows[2].get(1).unwrap(), &Value::String("Carol".to_string()));
assert_eq!(rows[2].get(2).unwrap(), &Value::String("C".to_string()));
}
#[test]
fn test_alter_drop_column_nonexistent() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_drop_ne (id INT, name TEXT)").unwrap();
let result = db.execute("ALTER TABLE alt_drop_ne DROP COLUMN nonexistent");
assert!(result.is_err(),
"Dropping a nonexistent column should fail");
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("does not exist"),
"Error should mention 'does not exist', got: {}", err_msg);
}
#[test]
fn test_alter_drop_column_if_exists() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_drop_ie (id INT, name TEXT)").unwrap();
let result = db.execute("ALTER TABLE alt_drop_ie DROP COLUMN IF EXISTS nonexistent");
assert!(result.is_ok(),
"DROP COLUMN IF EXISTS for nonexistent column should succeed silently, got: {:?}",
result.err());
}
#[test]
fn test_alter_drop_column_last_column() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_drop_last (only_col INT)").unwrap();
db.execute("INSERT INTO alt_drop_last VALUES (42)").unwrap();
let result = db.execute("ALTER TABLE alt_drop_last DROP COLUMN only_col");
if result.is_ok() {
let query_result = db.query("SELECT * FROM alt_drop_last", &[]);
match query_result {
Ok(rows) => {
assert!(rows.is_empty() || rows[0].values.is_empty(),
"After dropping last column, rows should be empty or have no values");
}
Err(_) => {
}
}
}
}
#[test]
fn test_alter_drop_primary_key_column_without_cascade() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_drop_pk (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute("INSERT INTO alt_drop_pk VALUES (1, 'Alice')").unwrap();
let result = db.execute("ALTER TABLE alt_drop_pk DROP COLUMN id");
assert!(result.is_err(),
"Dropping a primary key column without CASCADE should fail");
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("CASCADE") || err_msg.contains("primary key"),
"Error should mention CASCADE or primary key, got: {}", err_msg);
}
#[test]
fn test_alter_rename_column_basic() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_ren_col (id INT, old_name TEXT)").unwrap();
db.execute("INSERT INTO alt_ren_col VALUES (1, 'Alice')").unwrap();
db.execute("ALTER TABLE alt_ren_col RENAME COLUMN old_name TO new_name").unwrap();
let rows = db.query("SELECT id, new_name FROM alt_ren_col", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("Alice".to_string()));
}
#[test]
fn test_alter_rename_column_preserves_data() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_ren_data (id INT, val TEXT)").unwrap();
db.execute("INSERT INTO alt_ren_data VALUES (1, 'one')").unwrap();
db.execute("INSERT INTO alt_ren_data VALUES (2, 'two')").unwrap();
db.execute("INSERT INTO alt_ren_data VALUES (3, 'three')").unwrap();
db.execute("ALTER TABLE alt_ren_data RENAME COLUMN val TO value").unwrap();
let rows = db.query("SELECT id, value FROM alt_ren_data ORDER BY id", &[]).unwrap();
assert_eq!(rows.len(), 3, "All rows should still exist after rename");
assert_eq!(rows[0].get(1).unwrap(), &Value::String("one".to_string()));
assert_eq!(rows[1].get(1).unwrap(), &Value::String("two".to_string()));
assert_eq!(rows[2].get(1).unwrap(), &Value::String("three".to_string()));
let result = db.query("SELECT val FROM alt_ren_data", &[]);
assert!(result.is_err(),
"Old column name should no longer be valid after rename");
}
#[test]
fn test_alter_rename_column_nonexistent() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_ren_ne (id INT, name TEXT)").unwrap();
let result = db.execute("ALTER TABLE alt_ren_ne RENAME COLUMN ghost TO phantom");
assert!(result.is_err(),
"Renaming a nonexistent column should fail");
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("does not exist"),
"Error should mention 'does not exist', got: {}", err_msg);
}
#[test]
fn test_alter_rename_column_to_existing_name() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_ren_dup (id INT, name TEXT)").unwrap();
let result = db.execute("ALTER TABLE alt_ren_dup RENAME COLUMN id TO name");
assert!(result.is_err(),
"Renaming to an already-existing column name should fail");
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("already exists"),
"Error should mention 'already exists', got: {}", err_msg);
}
#[test]
fn test_alter_rename_table_basic() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_old_tbl (id INT, name TEXT)").unwrap();
db.execute("INSERT INTO alt_old_tbl VALUES (1, 'Alice')").unwrap();
db.execute("ALTER TABLE alt_old_tbl RENAME TO alt_new_tbl").unwrap();
let rows = db.query("SELECT id, name FROM alt_new_tbl", &[]).unwrap();
assert_eq!(rows.len(), 1, "Data should be accessible via new table name");
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("Alice".to_string()));
}
#[test]
fn test_alter_rename_table_old_name_fails() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_orig (id INT)").unwrap();
db.execute("INSERT INTO alt_orig VALUES (1)").unwrap();
db.execute("ALTER TABLE alt_orig RENAME TO alt_renamed").unwrap();
let result = db.query("SELECT * FROM alt_orig", &[]);
assert!(result.is_err(),
"Querying the old table name after rename should fail");
}
#[test]
fn test_alter_rename_table_to_existing() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_src (id INT)").unwrap();
db.execute("CREATE TABLE alt_dst (id INT)").unwrap();
let result = db.execute("ALTER TABLE alt_src RENAME TO alt_dst");
assert!(result.is_err(),
"Renaming to an existing table name should fail");
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("already exists"),
"Error should mention 'already exists', got: {}", err_msg);
}
#[test]
fn test_alter_add_then_drop_column() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_add_drop (id INT, name TEXT)").unwrap();
db.execute("INSERT INTO alt_add_drop VALUES (1, 'Alice')").unwrap();
db.execute("ALTER TABLE alt_add_drop ADD COLUMN temp INT").unwrap();
let rows = db.query("SELECT id, name, temp FROM alt_add_drop", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values.len(), 3, "Should have 3 columns after ADD");
db.execute("ALTER TABLE alt_add_drop DROP COLUMN temp").unwrap();
let rows = db.query("SELECT id, name FROM alt_add_drop", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("Alice".to_string()));
let result = db.query("SELECT temp FROM alt_add_drop", &[]);
assert!(result.is_err(), "Dropped column should not be queryable");
}
#[test]
fn test_alter_multiple_sequential_operations() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE alt_seq (id INT, a TEXT)").unwrap();
db.execute("INSERT INTO alt_seq VALUES (1, 'original')").unwrap();
db.execute("ALTER TABLE alt_seq ADD COLUMN b INT").unwrap();
db.execute("ALTER TABLE alt_seq RENAME COLUMN a TO alpha").unwrap();
db.execute("ALTER TABLE alt_seq ADD COLUMN c TEXT").unwrap();
db.execute("INSERT INTO alt_seq VALUES (2, 'new', 42, 'hello')").unwrap();
let rows = db.query("SELECT id, alpha, b, c FROM alt_seq ORDER BY id", &[]).unwrap();
assert_eq!(rows.len(), 2, "Should have 2 rows");
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
assert_eq!(rows[0].get(1).unwrap(), &Value::String("original".to_string()));
assert_eq!(rows[0].get(2).unwrap(), &Value::Null);
assert_eq!(rows[0].get(3).unwrap(), &Value::Null);
assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(2));
assert_eq!(rows[1].get(1).unwrap(), &Value::String("new".to_string()));
assert_eq!(rows[1].get(2).unwrap(), &Value::Int4(42));
assert_eq!(rows[1].get(3).unwrap(), &Value::String("hello".to_string()));
}
#[test]
fn test_alter_table_nonexistent_table() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let result = db.execute("ALTER TABLE no_such_table ADD COLUMN x INT");
assert!(result.is_err(),
"ALTER TABLE on nonexistent table should fail");
}
fn setup_pagination_db() -> EmbeddedDatabase {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute(
"CREATE TABLE pg_products (id INT, name TEXT, price INT, category TEXT)"
).unwrap();
db.execute("INSERT INTO pg_products VALUES (1, 'Product_01', 50, 'Electronics')").unwrap();
db.execute("INSERT INTO pg_products VALUES (2, 'Product_02', 30, 'Books')").unwrap();
db.execute("INSERT INTO pg_products VALUES (3, 'Product_03', 75, 'Electronics')").unwrap();
db.execute("INSERT INTO pg_products VALUES (4, 'Product_04', 20, 'Clothing')").unwrap();
db.execute("INSERT INTO pg_products VALUES (5, 'Product_05', 90, 'Electronics')").unwrap();
db.execute("INSERT INTO pg_products VALUES (6, 'Product_06', 15, 'Books')").unwrap();
db.execute("INSERT INTO pg_products VALUES (7, 'Product_07', 60, 'Clothing')").unwrap();
db.execute("INSERT INTO pg_products VALUES (8, 'Product_08', 45, 'Books')").unwrap();
db.execute("INSERT INTO pg_products VALUES (9, 'Product_09', 80, 'Clothing')").unwrap();
db.execute("INSERT INTO pg_products VALUES (10, 'Product_10', 35, 'Electronics')").unwrap();
db
}
#[test]
fn test_limit_basic() {
let db = setup_pagination_db();
let rows = db.query(
"SELECT id FROM pg_products ORDER BY id LIMIT 3",
&[],
).unwrap();
assert_eq!(rows.len(), 3, "LIMIT 3 should return 3 rows, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(2));
assert_eq!(rows[2].get(0).unwrap(), &Value::Int4(3));
}
#[test]
fn test_limit_zero() {
let db = setup_pagination_db();
let rows = db.query(
"SELECT id FROM pg_products LIMIT 0",
&[],
).unwrap();
assert_eq!(rows.len(), 0, "LIMIT 0 should return 0 rows, got {}", rows.len());
}
#[test]
fn test_limit_exceeds_rows() {
let db = setup_pagination_db();
let rows = db.query(
"SELECT id FROM pg_products ORDER BY id LIMIT 100",
&[],
).unwrap();
assert_eq!(rows.len(), 10, "LIMIT 100 on 10 rows should return 10, got {}", rows.len());
}
#[test]
fn test_limit_one() {
let db = setup_pagination_db();
let rows = db.query(
"SELECT id FROM pg_products ORDER BY id LIMIT 1",
&[],
).unwrap();
assert_eq!(rows.len(), 1, "LIMIT 1 should return 1 row, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
}
#[test]
fn test_limit_with_order_by() {
let db = setup_pagination_db();
let rows = db.query(
"SELECT id, price FROM pg_products ORDER BY price DESC LIMIT 3",
&[],
).unwrap();
assert_eq!(rows.len(), 3, "Top 3 by price DESC should return 3 rows");
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(5), "Most expensive is id=5 (price 90)");
assert_eq!(rows[0].get(1).unwrap(), &Value::Int4(90));
assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(9), "Second most expensive is id=9 (price 80)");
assert_eq!(rows[1].get(1).unwrap(), &Value::Int4(80));
assert_eq!(rows[2].get(0).unwrap(), &Value::Int4(3), "Third most expensive is id=3 (price 75)");
assert_eq!(rows[2].get(1).unwrap(), &Value::Int4(75));
}
#[test]
fn test_offset_basic() {
let db = setup_pagination_db();
let rows = db.query(
"SELECT id FROM pg_products ORDER BY id OFFSET 2",
&[],
).unwrap();
assert_eq!(rows.len(), 8, "OFFSET 2 on 10 rows should return 8, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(3));
assert_eq!(rows[7].get(0).unwrap(), &Value::Int4(10));
}
#[test]
fn test_offset_exceeds_rows() {
let db = setup_pagination_db();
let rows = db.query(
"SELECT id FROM pg_products ORDER BY id LIMIT 100 OFFSET 20",
&[],
).unwrap();
assert_eq!(rows.len(), 0, "OFFSET beyond row count should return 0 rows, got {}", rows.len());
}
#[test]
fn test_offset_zero() {
let db = setup_pagination_db();
let rows = db.query(
"SELECT id FROM pg_products ORDER BY id OFFSET 0",
&[],
).unwrap();
assert_eq!(rows.len(), 10, "OFFSET 0 should return all 10 rows, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
}
#[test]
fn test_limit_offset_combined() {
let db = setup_pagination_db();
let rows = db.query(
"SELECT id FROM pg_products ORDER BY id LIMIT 3 OFFSET 2",
&[],
).unwrap();
assert_eq!(rows.len(), 3, "LIMIT 3 OFFSET 2 should return 3 rows, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(3));
assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(4));
assert_eq!(rows[2].get(0).unwrap(), &Value::Int4(5));
}
#[test]
fn test_limit_offset_page_2() {
let db = setup_pagination_db();
let rows = db.query(
"SELECT id FROM pg_products ORDER BY id LIMIT 3 OFFSET 3",
&[],
).unwrap();
assert_eq!(rows.len(), 3, "Page 2 (LIMIT 3 OFFSET 3) should return 3 rows, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(4));
assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(5));
assert_eq!(rows[2].get(0).unwrap(), &Value::Int4(6));
}
#[test]
fn test_limit_offset_last_page() {
let db = setup_pagination_db();
let rows = db.query(
"SELECT id FROM pg_products ORDER BY id LIMIT 3 OFFSET 9",
&[],
).unwrap();
assert_eq!(rows.len(), 1, "Last page should return 1 remaining row, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(10));
}
#[test]
fn test_order_by_asc() {
let db = setup_pagination_db();
let rows = db.query(
"SELECT id, price FROM pg_products ORDER BY price ASC",
&[],
).unwrap();
assert_eq!(rows.len(), 10);
assert_eq!(rows[0].get(1).unwrap(), &Value::Int4(15), "Cheapest should be 15");
assert_eq!(rows[9].get(1).unwrap(), &Value::Int4(90), "Most expensive should be 90");
for i in 0..9 {
let cur = match rows[i].get(1).unwrap() { Value::Int4(v) => *v, _ => panic!("expected Int4") };
let nxt = match rows[i + 1].get(1).unwrap() { Value::Int4(v) => *v, _ => panic!("expected Int4") };
assert!(cur <= nxt, "Row {} price {} should be <= row {} price {}", i, cur, i + 1, nxt);
}
}
#[test]
fn test_order_by_desc() {
let db = setup_pagination_db();
let rows = db.query(
"SELECT id, price FROM pg_products ORDER BY price DESC",
&[],
).unwrap();
assert_eq!(rows.len(), 10);
assert_eq!(rows[0].get(1).unwrap(), &Value::Int4(90), "Most expensive should be first");
assert_eq!(rows[9].get(1).unwrap(), &Value::Int4(15), "Cheapest should be last");
for i in 0..9 {
let cur = match rows[i].get(1).unwrap() { Value::Int4(v) => *v, _ => panic!("expected Int4") };
let nxt = match rows[i + 1].get(1).unwrap() { Value::Int4(v) => *v, _ => panic!("expected Int4") };
assert!(cur >= nxt, "Row {} price {} should be >= row {} price {}", i, cur, i + 1, nxt);
}
}
#[test]
fn test_order_by_multiple_columns() {
let db = setup_pagination_db();
let rows = db.query(
"SELECT id, category, price FROM pg_products ORDER BY category, price",
&[],
).unwrap();
assert_eq!(rows.len(), 10);
assert_eq!(rows[0].get(1).unwrap(), &Value::String("Books".to_string()));
assert_eq!(rows[0].get(2).unwrap(), &Value::Int4(15));
assert_eq!(rows[1].get(1).unwrap(), &Value::String("Books".to_string()));
assert_eq!(rows[1].get(2).unwrap(), &Value::Int4(30));
assert_eq!(rows[2].get(1).unwrap(), &Value::String("Books".to_string()));
assert_eq!(rows[2].get(2).unwrap(), &Value::Int4(45));
assert_eq!(rows[3].get(1).unwrap(), &Value::String("Clothing".to_string()));
assert_eq!(rows[3].get(2).unwrap(), &Value::Int4(20));
assert_eq!(rows[4].get(1).unwrap(), &Value::String("Clothing".to_string()));
assert_eq!(rows[4].get(2).unwrap(), &Value::Int4(60));
assert_eq!(rows[5].get(1).unwrap(), &Value::String("Clothing".to_string()));
assert_eq!(rows[5].get(2).unwrap(), &Value::Int4(80));
assert_eq!(rows[6].get(1).unwrap(), &Value::String("Electronics".to_string()));
assert_eq!(rows[6].get(2).unwrap(), &Value::Int4(35));
}
#[test]
fn test_order_by_mixed_directions() {
let db = setup_pagination_db();
let rows = db.query(
"SELECT id, category, price FROM pg_products ORDER BY category ASC, price DESC",
&[],
).unwrap();
assert_eq!(rows.len(), 10);
assert_eq!(rows[0].get(1).unwrap(), &Value::String("Books".to_string()));
assert_eq!(rows[0].get(2).unwrap(), &Value::Int4(45));
assert_eq!(rows[1].get(1).unwrap(), &Value::String("Books".to_string()));
assert_eq!(rows[1].get(2).unwrap(), &Value::Int4(30));
assert_eq!(rows[2].get(1).unwrap(), &Value::String("Books".to_string()));
assert_eq!(rows[2].get(2).unwrap(), &Value::Int4(15));
assert_eq!(rows[3].get(1).unwrap(), &Value::String("Clothing".to_string()));
assert_eq!(rows[3].get(2).unwrap(), &Value::Int4(80));
assert_eq!(rows[4].get(1).unwrap(), &Value::String("Clothing".to_string()));
assert_eq!(rows[4].get(2).unwrap(), &Value::Int4(60));
assert_eq!(rows[5].get(1).unwrap(), &Value::String("Clothing".to_string()));
assert_eq!(rows[5].get(2).unwrap(), &Value::Int4(20));
assert_eq!(rows[6].get(1).unwrap(), &Value::String("Electronics".to_string()));
assert_eq!(rows[6].get(2).unwrap(), &Value::Int4(90));
assert_eq!(rows[7].get(2).unwrap(), &Value::Int4(75));
assert_eq!(rows[8].get(2).unwrap(), &Value::Int4(50));
assert_eq!(rows[9].get(2).unwrap(), &Value::Int4(35));
}
#[test]
fn test_order_by_with_nulls() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE pg_nullsort (id INT, score INT)").unwrap();
db.execute("INSERT INTO pg_nullsort VALUES (1, 50)").unwrap();
db.execute("INSERT INTO pg_nullsort VALUES (2, NULL)").unwrap();
db.execute("INSERT INTO pg_nullsort VALUES (3, 30)").unwrap();
db.execute("INSERT INTO pg_nullsort VALUES (4, NULL)").unwrap();
db.execute("INSERT INTO pg_nullsort VALUES (5, 70)").unwrap();
let rows = db.query(
"SELECT id, score FROM pg_nullsort ORDER BY score ASC, id ASC",
&[],
).unwrap();
assert_eq!(rows.len(), 5);
assert_eq!(rows[0].get(1).unwrap(), &Value::Null);
assert_eq!(rows[1].get(1).unwrap(), &Value::Null);
assert_eq!(rows[2].get(1).unwrap(), &Value::Int4(30));
assert_eq!(rows[3].get(1).unwrap(), &Value::Int4(50));
assert_eq!(rows[4].get(1).unwrap(), &Value::Int4(70));
let rows_desc = db.query(
"SELECT id, score FROM pg_nullsort ORDER BY score DESC, id ASC",
&[],
).unwrap();
assert_eq!(rows_desc.len(), 5);
assert_eq!(rows_desc[0].get(1).unwrap(), &Value::Int4(70));
assert_eq!(rows_desc[1].get(1).unwrap(), &Value::Int4(50));
assert_eq!(rows_desc[2].get(1).unwrap(), &Value::Int4(30));
assert_eq!(rows_desc[3].get(1).unwrap(), &Value::Null);
assert_eq!(rows_desc[4].get(1).unwrap(), &Value::Null);
}
#[test]
fn test_pagination_full_scan() {
let db = setup_pagination_db();
let page_size = 3;
let mut all_ids: Vec<i32> = Vec::new();
for page in 0..4 {
let offset = page * page_size;
let sql = format!(
"SELECT id FROM pg_products ORDER BY id LIMIT {} OFFSET {}",
page_size, offset
);
let rows = db.query(&sql, &[]).unwrap();
if page < 3 {
assert_eq!(rows.len(), 3, "Page {} should have 3 rows, got {}", page, rows.len());
} else {
assert_eq!(rows.len(), 1, "Last page should have 1 row, got {}", rows.len());
}
for row in &rows {
if let Value::Int4(id) = row.get(0).unwrap() {
all_ids.push(*id);
}
}
}
assert_eq!(all_ids, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
"Full pagination should yield all IDs 1..=10 in order");
}
#[test]
fn test_limit_with_where() {
let db = setup_pagination_db();
let rows = db.query(
"SELECT id, price FROM pg_products WHERE category = 'Electronics' ORDER BY price DESC LIMIT 2",
&[],
).unwrap();
assert_eq!(rows.len(), 2, "LIMIT 2 on 4 Electronics rows should return 2, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(5), "Most expensive electronics is id=5 (90)");
assert_eq!(rows[0].get(1).unwrap(), &Value::Int4(90));
assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(3), "Second most expensive is id=3 (75)");
assert_eq!(rows[1].get(1).unwrap(), &Value::Int4(75));
}
#[test]
fn test_limit_with_group_by() {
let db = setup_pagination_db();
let rows = db.query(
"SELECT category, COUNT(*) AS cnt FROM pg_products GROUP BY category ORDER BY category LIMIT 2",
&[],
).unwrap();
assert_eq!(rows.len(), 2, "LIMIT 2 on 3 groups should return 2, got {}", rows.len());
assert_eq!(rows[0].get(0).unwrap(), &Value::String("Books".to_string()));
assert_eq!(rows[0].get(1).unwrap(), &Value::Int8(3));
assert_eq!(rows[1].get(0).unwrap(), &Value::String("Clothing".to_string()));
assert_eq!(rows[1].get(1).unwrap(), &Value::Int8(3));
}
#[test]
fn test_pg_compat_version() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query("SELECT version()", &[]).unwrap();
assert_eq!(rows.len(), 1);
let val = rows[0].get(0).unwrap();
match val {
Value::String(s) => {
assert!(s.contains("PostgreSQL"), "version() should mention PostgreSQL, got: {}", s);
assert!(s.contains("HeliosDB"), "version() should mention HeliosDB, got: {}", s);
}
other => panic!("Expected String, got: {:?}", other),
}
}
#[test]
fn test_pg_compat_pg_catalog_version() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query("SELECT pg_catalog.version()", &[]).unwrap();
assert_eq!(rows.len(), 1);
let val = rows[0].get(0).unwrap();
match val {
Value::String(s) => {
assert!(s.contains("PostgreSQL"), "pg_catalog.version() should mention PostgreSQL");
}
other => panic!("Expected String, got: {:?}", other),
}
}
#[test]
fn test_pg_compat_current_schema() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query("SELECT current_schema()", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get(0).unwrap(), &Value::String("public".to_string()));
}
#[test]
fn test_pg_compat_current_database() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
let rows = db.query("SELECT current_database()", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get(0).unwrap(), &Value::String("heliosdb".to_string()));
}
#[test]
fn test_wp_bigint_eq_where_clause() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE wp_posts (ID BIGSERIAL PRIMARY KEY, title TEXT)").unwrap();
db.execute("INSERT INTO wp_posts (title) VALUES ('hello')").unwrap();
let rows_in = db.query("SELECT * FROM wp_posts WHERE ID IN (1)", &[]).unwrap();
assert_eq!(rows_in.len(), 1, "IN (1) should find the row");
let rows_eq = db.query("SELECT * FROM wp_posts WHERE ID = 1", &[]).unwrap();
assert_eq!(rows_eq.len(), 1, "fast-path WHERE ID = 1 should find the row");
let rows_order = db.query("SELECT * FROM wp_posts WHERE ID = 1 ORDER BY ID", &[]).unwrap();
assert_eq!(rows_order.len(), 1, "executor-path WHERE ID = 1 ORDER BY should find the row");
let rows_col = db.query("SELECT ID, title FROM wp_posts WHERE ID = 1", &[]).unwrap();
assert_eq!(rows_col.len(), 1, "SELECT cols WHERE ID = 1 should find the row");
db.execute("CREATE TABLE t_small (id SMALLSERIAL PRIMARY KEY, val TEXT)").unwrap();
db.execute("INSERT INTO t_small (val) VALUES ('x')").unwrap();
let rows_small = db.query("SELECT * FROM t_small WHERE id = 1", &[]).unwrap();
assert_eq!(rows_small.len(), 1, "SMALLSERIAL PK with int4 literal should work");
}
#[test]
fn test_wp_last_insert_id_serial() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE t_serial (id BIGSERIAL PRIMARY KEY, name TEXT)").unwrap();
db.execute("INSERT INTO t_serial (name) VALUES ('hello')").unwrap();
let rows = db.query("SELECT MAX(id) FROM t_serial", &[]).unwrap();
let max_id = rows[0].get(0).unwrap();
match max_id {
Value::Int8(n) => assert!(*n > 0, "SERIAL should auto-generate: got {}", n),
Value::Int4(n) => assert!(*n > 0, "SERIAL should auto-generate: got {}", n),
other => panic!("Unexpected type for MAX(id): {:?}", other),
}
}
#[test]
fn test_wp_duplicate_pk_error_message() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE t_dup (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute("INSERT INTO t_dup VALUES (1, 'a')").unwrap();
let result = db.execute("INSERT INTO t_dup VALUES (1, 'b')");
assert!(result.is_err(), "Duplicate PK insert must fail, but got Ok");
let msg = result.unwrap_err().to_string();
let lower = msg.to_lowercase();
assert!(
lower.contains("duplicate") || lower.contains("unique") || lower.contains("primary key"),
"Duplicate PK error should contain recognizable keywords, got: {}", msg
);
}
#[test]
fn test_wp_duplicate_pk_no_data_corruption() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE t_dup2 (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute("INSERT INTO t_dup2 VALUES (1, 'original')").unwrap();
let _ = db.execute("INSERT INTO t_dup2 VALUES (1, 'duplicate')");
let rows = db.query("SELECT * FROM t_dup2", &[]).unwrap();
assert_eq!(rows.len(), 1, "Only one row should exist after rejected duplicate");
assert_eq!(rows[0].get(1).unwrap(), &Value::String("original".to_string()),
"Original row must be preserved");
}
#[test]
fn test_wp_duplicate_unique_constraint() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE t_uq (id INT PRIMARY KEY, email TEXT UNIQUE)").unwrap();
db.execute("INSERT INTO t_uq VALUES (1, 'a@b.com')").unwrap();
let result = db.execute("INSERT INTO t_uq VALUES (2, 'a@b.com')");
assert!(result.is_err(), "Duplicate UNIQUE insert must fail");
}
#[test]
fn test_on_conflict_do_nothing() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE t_oc1 (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute("INSERT INTO t_oc1 VALUES (1, 'a')").unwrap();
db.execute("INSERT INTO t_oc1 VALUES (1, 'b') ON CONFLICT DO NOTHING").unwrap();
let rows = db.query("SELECT name FROM t_oc1 WHERE id = 1", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values[0], Value::String("a".to_string()));
}
#[test]
fn test_on_conflict_do_update() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE t_oc2 (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute("INSERT INTO t_oc2 VALUES (1, 'a')").unwrap();
db.execute("INSERT INTO t_oc2 VALUES (1, 'b') ON CONFLICT DO UPDATE SET name = EXCLUDED.name").unwrap();
let rows = db.query("SELECT name FROM t_oc2 WHERE id = 1", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values[0], Value::String("b".to_string()));
}
#[test]
fn test_on_conflict_do_update_multiple_columns() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE t_oc3 (id INT PRIMARY KEY, name TEXT, score INT)").unwrap();
db.execute("INSERT INTO t_oc3 VALUES (1, 'alice', 10)").unwrap();
db.execute("INSERT INTO t_oc3 VALUES (1, 'bob', 20) ON CONFLICT DO UPDATE SET name = EXCLUDED.name, score = EXCLUDED.score").unwrap();
let rows = db.query("SELECT name, score FROM t_oc3 WHERE id = 1", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values[0], Value::String("bob".to_string()));
assert_eq!(rows[0].values[1], Value::Int4(20));
}
#[test]
fn test_on_conflict_do_nothing_no_conflict() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE t_oc4 (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute("INSERT INTO t_oc4 VALUES (1, 'a') ON CONFLICT DO NOTHING").unwrap();
let rows = db.query("SELECT name FROM t_oc4 WHERE id = 1", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values[0], Value::String("a".to_string()));
}
#[test]
fn test_on_conflict_do_update_no_conflict() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE t_oc5 (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute("INSERT INTO t_oc5 VALUES (1, 'a') ON CONFLICT DO UPDATE SET name = EXCLUDED.name").unwrap();
let rows = db.query("SELECT name FROM t_oc5 WHERE id = 1", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values[0], Value::String("a".to_string()));
}
#[test]
fn test_on_conflict_do_nothing_returns_zero() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE t_oc6 (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute("INSERT INTO t_oc6 VALUES (1, 'a')").unwrap();
let affected = db.execute("INSERT INTO t_oc6 VALUES (1, 'b') ON CONFLICT DO NOTHING").unwrap();
assert_eq!(affected, 0, "DO NOTHING should report 0 affected rows");
}
#[test]
fn test_on_conflict_do_update_returns_one() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE t_oc7 (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute("INSERT INTO t_oc7 VALUES (1, 'a')").unwrap();
let affected = db.execute("INSERT INTO t_oc7 VALUES (1, 'b') ON CONFLICT DO UPDATE SET name = EXCLUDED.name").unwrap();
assert_eq!(affected, 1, "DO UPDATE should report 1 affected row");
}
#[test]
fn test_on_conflict_with_column_list() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE t_oc8 (id INT PRIMARY KEY, name TEXT, val INT)").unwrap();
db.execute("INSERT INTO t_oc8 (id, name, val) VALUES (1, 'a', 10)").unwrap();
db.execute("INSERT INTO t_oc8 (id, name, val) VALUES (1, 'b', 20) ON CONFLICT DO UPDATE SET name = EXCLUDED.name, val = EXCLUDED.val").unwrap();
let rows = db.query("SELECT name, val FROM t_oc8 WHERE id = 1", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values[0], Value::String("b".to_string()));
assert_eq!(rows[0].values[1], Value::Int4(20));
}
#[test]
fn test_on_conflict_do_update_partial() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE t_oc9 (id INT PRIMARY KEY, name TEXT, val INT)").unwrap();
db.execute("INSERT INTO t_oc9 VALUES (1, 'alice', 10)").unwrap();
db.execute("INSERT INTO t_oc9 VALUES (1, 'bob', 99) ON CONFLICT DO UPDATE SET val = EXCLUDED.val").unwrap();
let rows = db.query("SELECT name, val FROM t_oc9 WHERE id = 1", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values[0], Value::String("alice".to_string()), "name should be unchanged");
assert_eq!(rows[0].values[1], Value::Int4(99), "val should be updated");
}
#[test]
fn test_on_conflict_do_update_with_literal() {
let db = EmbeddedDatabase::new_in_memory().unwrap();
db.execute("CREATE TABLE t_oc10 (id INT PRIMARY KEY, name TEXT)").unwrap();
db.execute("INSERT INTO t_oc10 VALUES (1, 'a')").unwrap();
db.execute("INSERT INTO t_oc10 VALUES (1, 'b') ON CONFLICT DO UPDATE SET name = 'replaced'").unwrap();
let rows = db.query("SELECT name FROM t_oc10 WHERE id = 1", &[]).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].values[0], Value::String("replaced".to_string()));
}
}