1#![deny(
50 clippy::unwrap_used,
51 clippy::todo,
52 clippy::unimplemented,
53)]
54
55#![warn(
57 clippy::expect_used,
58 clippy::panic,
59 clippy::indexing_slicing,
60)]
61
62#![allow(missing_docs)]
65#![warn(rust_2018_idioms)]
66
67#![warn(
69 clippy::pedantic,
70 clippy::nursery,
71 clippy::cargo,
72)]
73#![allow(clippy::cargo_common_metadata)] #![allow(
77 clippy::module_name_repetitions,
78 clippy::missing_errors_doc,
79 clippy::missing_panics_doc,
80 clippy::must_use_candidate,
81 clippy::similar_names,
83 clippy::redundant_else,
84 clippy::needless_continue,
85 clippy::needless_pass_by_ref_mut,
86 clippy::uninlined_format_args,
87 clippy::redundant_closure_for_method_calls,
88 clippy::match_same_arms,
89 clippy::doc_markdown,
90 clippy::items_after_statements,
91 clippy::option_if_let_else,
92 clippy::struct_excessive_bools,
93 clippy::unused_self,
94 clippy::unused_async,
95 clippy::return_self_not_must_use,
96 clippy::if_not_else,
97 clippy::manual_let_else,
98 clippy::single_char_add_str,
99 clippy::unreadable_literal,
100 clippy::needless_raw_string_hashes,
101 clippy::or_fun_call,
102 clippy::derive_partial_eq_without_eq,
103 clippy::redundant_clone,
104 clippy::map_unwrap_or,
105 clippy::needless_borrow,
106 clippy::format_push_string,
107 clippy::default_trait_access,
108 clippy::empty_line_after_doc_comments,
109 clippy::needless_pass_by_value,
110 clippy::wildcard_enum_match_arm,
111 clippy::match_wildcard_for_single_variants,
112 clippy::suboptimal_flops,
113 clippy::wildcard_imports,
114 clippy::ref_option,
115 clippy::needless_collect,
116 clippy::bool_to_int_with_if,
117 clippy::useless_format,
118 clippy::used_underscore_binding,
119 clippy::str_to_string,
120 clippy::implicit_hasher,
121 clippy::string_add_assign,
122 clippy::explicit_iter_loop,
123 clippy::single_match_else,
124 clippy::manual_string_new,
125 clippy::derivable_impls,
126 clippy::too_many_arguments,
127 clippy::too_many_lines,
128 clippy::branches_sharing_code,
129 clippy::manual_strip,
130 clippy::upper_case_acronyms,
131 clippy::struct_field_names,
132 clippy::assigning_clones,
133 clippy::should_implement_trait,
134 clippy::boxed_local,
135 clippy::collapsible_if,
136 clippy::field_reassign_with_default,
137 clippy::unnecessary_cast,
138 clippy::type_complexity,
139 clippy::manual_is_ascii_check,
140 clippy::borrow_as_ptr,
141 clippy::cognitive_complexity,
142 clippy::fn_params_excessive_bools,
143 clippy::iter_without_into_iter,
144 clippy::unit_cmp,
145 clippy::ptr_arg,
146 clippy::use_debug,
147 clippy::redundant_closure,
148 clippy::clone_on_copy,
149 clippy::new_without_default,
150 clippy::manual_range_contains,
151 clippy::manual_range_patterns,
152 clippy::if_then_some_else_none,
153 clippy::match_like_matches_macro,
154 clippy::option_as_ref_cloned,
155 clippy::collapsible_match,
156 clippy::filter_map_identity,
157 clippy::get_first,
158 clippy::implicit_clone,
159 clippy::len_zero,
160 clippy::write_with_newline,
161 clippy::single_char_pattern,
162 clippy::let_and_return,
163 clippy::redundant_pattern_matching,
164 clippy::match_ref_pats,
165 clippy::if_same_then_else,
166 clippy::semicolon_if_nothing_returned,
167 clippy::iter_over_hash_type,
168 clippy::iter_on_single_items,
169 clippy::iter_on_empty_collections,
170 clippy::useless_vec,
171 clippy::vec_init_then_push,
172 clippy::iter_nth_zero,
173 clippy::unwrap_or_default,
174 clippy::trivial_regex,
175 clippy::map_entry,
176 clippy::enum_glob_use,
177 clippy::unnested_or_patterns,
178 clippy::manual_clamp,
179 clippy::cast_ptr_alignment,
180 clippy::ptr_as_ptr,
181 clippy::imprecise_flops,
182 clippy::future_not_send,
183 clippy::significant_drop_in_scrutinee,
184 clippy::collection_is_never_read,
185 clippy::manual_div_ceil,
186 clippy::checked_conversions,
187 clippy::as_underscore,
188 clippy::as_ptr_cast_mut,
189 clippy::trim_split_whitespace,
190 clippy::string_lit_chars_any,
191 clippy::large_enum_variant,
192 clippy::doc_lazy_continuation,
193 clippy::too_long_first_doc_paragraph,
194 clippy::useless_conversion,
195 clippy::multiple_crate_versions,
196 clippy::unit_arg,
197 clippy::inherent_to_string,
198 clippy::to_string_trait_impl,
199 clippy::borrow_deref_ref,
200 clippy::manual_map,
201 clippy::manual_filter_map,
202 clippy::option_map_unit_fn,
203 clippy::result_map_unit_fn,
204 clippy::manual_is_multiple_of,
205 clippy::print_literal,
206 clippy::iter_kv_map,
207 clippy::manual_find,
208 clippy::write_literal,
209 clippy::explicit_into_iter_loop,
210 clippy::manual_ok_or,
211 clippy::bind_instead_of_map,
212 clippy::manual_retain,
213 clippy::io_other_error,
214 clippy::clone_on_ref_ptr,
215 clippy::bool_comparison,
216 clippy::single_match,
217 clippy::iter_next_loop,
218 clippy::str_split_at_newline,
219 clippy::option_as_ref_deref,
220 clippy::arithmetic_side_effects,
221 clippy::cloned_instead_of_copied,
222 clippy::string_slice,
223 clippy::inconsistent_struct_constructor,
224 clippy::unnecessary_literal_unwrap,
225 clippy::ref_binding_to_reference,
226 clippy::match_bool,
227 clippy::partialeq_to_none,
228 clippy::redundant_static_lifetimes,
229 clippy::char_lit_as_u8,
230 clippy::manual_is_power_of_two,
231 clippy::filter_map_bool_then,
232 clippy::manual_flatten,
233 clippy::manual_next_back,
234 clippy::maybe_infinite_iter,
235 clippy::needless_option_as_deref,
236 clippy::suspicious_else_formatting,
237 clippy::useless_transmute,
238 clippy::cast_precision_loss,
240 clippy::cast_possible_truncation,
241 clippy::cast_possible_wrap,
242 clippy::cast_sign_loss,
243 clippy::cast_lossless,
244 clippy::trivially_copy_pass_by_ref,
246 clippy::significant_drop_tightening,
247 clippy::unnecessary_wraps,
248 clippy::missing_const_for_fn,
249 clippy::use_self,
250 clippy::non_std_lazy_statics,
252)]
253
254#![allow(dead_code)]
256#![allow(unused_variables)]
257
258pub mod storage;
260pub mod compute;
261pub mod optimizer;
262pub mod vector;
263pub mod protocol;
264pub mod protocols; pub mod crypto;
266pub mod tenant;
267pub mod sql;
268pub mod audit;
269pub mod network;
270pub mod repl;
271pub mod api; pub mod cli; pub mod session; pub mod ai; pub mod multi_tenant; pub mod git_integration; pub mod runtime; pub mod graph; pub mod search; #[cfg(feature = "code-graph")]
284pub mod code_graph;
285
286#[cfg(feature = "graph-rag")]
289pub mod graph_rag;
290
291#[cfg(feature = "mcp-endpoint")]
298pub mod mcp;
299
300#[cfg(any(
310 feature = "ha-tier1",
311 feature = "ha-tier2",
312 feature = "ha-tier3",
313 feature = "ha-dedup",
314 feature = "ha-branch-replication"
315))]
316pub mod replication;
317
318#[cfg(feature = "ha-ab-testing")]
323pub mod ab_testing;
324
325mod error;
327mod types;
328pub mod config;
333mod embedded_db_dump;
334
335pub use error::{Error, Result};
337pub use types::{DataType, Value, Tuple, Schema, Column, ColumnStorageMode, VectorStoreInfo, AgentSession, AgentMessage, DocumentData, DocumentMetadata};
338pub use config::{Config, KeySource, ZkeMode, ZkeEncryptionConfig};
339pub use storage::StorageEngine;
340pub use crypto::{
341 ZkeConfig, ZkeDerivedKeys, ZkeKeyDerivation, ZkeRequestContext,
342 ZeroKnowledgeSession, NonceTracker, TimestampValidator,
343};
344
345fn convert_logical_referential_action(action: &sql::logical_plan::ReferentialAction) -> sql::constraints::ReferentialAction {
347 match action {
348 sql::logical_plan::ReferentialAction::NoAction => sql::constraints::ReferentialAction::NoAction,
349 sql::logical_plan::ReferentialAction::Restrict => sql::constraints::ReferentialAction::Restrict,
350 sql::logical_plan::ReferentialAction::Cascade => sql::constraints::ReferentialAction::Cascade,
351 sql::logical_plan::ReferentialAction::SetNull => sql::constraints::ReferentialAction::SetNull,
352 sql::logical_plan::ReferentialAction::SetDefault => sql::constraints::ReferentialAction::SetDefault,
353 }
354}
355
/// In-process database handle: the storage engine plus the SQL-layer
/// registries, caches and transaction bookkeeping used when executing
/// statements against it.
///
/// NOTE(review): field declaration order is also drop order (after the
/// `Drop` impl runs); keep that in mind before reordering fields.
pub struct EmbeddedDatabase {
    /// Shared storage engine: catalog, transactions, ART indexes and branch state.
    pub storage: std::sync::Arc<storage::StorageEngine>,
    /// Database configuration.
    config: Config,
    /// Transaction opened by an explicit `BEGIN` / `START TRANSACTION`;
    /// `None` when no such transaction is active.
    current_transaction: std::sync::Arc<std::sync::Mutex<Option<storage::Transaction>>>,
    /// Tenant context lookup and per-tenant query accounting.
    pub tenant_manager: std::sync::Arc<crate::tenant::TenantManager>,
    /// Registered triggers, consulted per table when executing writes.
    pub trigger_registry: std::sync::Arc<sql::TriggerRegistry>,
    /// Function registry (presumably user-defined functions — confirm).
    pub function_registry: std::sync::Arc<sql::FunctionRegistry>,
    /// Materialized-view scheduler (semantics not visible here).
    mv_scheduler: std::sync::Arc<storage::MVScheduler>,
    /// Background auto-refresh worker, if one was started; asked to stop on drop.
    auto_refresh_worker: std::sync::Arc<parking_lot::RwLock<Option<storage::AutoRefreshWorker>>>,
    /// Dump manager (semantics not visible here).
    pub dump_manager: std::sync::Arc<storage::DumpManager>,
    /// Session manager (semantics not visible here).
    pub session_manager: std::sync::Arc<crate::session::SessionManager>,
    /// Lock manager (semantics not visible here).
    pub lock_manager: std::sync::Arc<storage::LockManager>,
    /// Dirty tracker (semantics not visible here).
    pub dirty_tracker: std::sync::Arc<storage::DirtyTracker>,
    /// Per-session transactions keyed by session id; while this map is
    /// non-empty the INSERT/UPDATE fast paths are skipped.
    session_transactions: std::sync::Arc<dashmap::DashMap<crate::session::SessionId, storage::Transaction>>,
    /// Prepared statements mapped to their logical plans (keys presumably
    /// statement names — verify against the PREPARE handler).
    prepared_statements: std::sync::Arc<parking_lot::RwLock<std::collections::HashMap<String, sql::LogicalPlan>>>,
    /// Active savepoints; while non-empty the INSERT/UPDATE fast paths are skipped.
    savepoints: std::sync::Arc<parking_lot::RwLock<Vec<SavepointState>>>,
    /// LRU cache of logical plans (presumably keyed by SQL text); DDL such as
    /// CREATE/DROP TABLE triggers `invalidate_plan_cache`.
    plan_cache: std::sync::Arc<std::sync::Mutex<lru::LruCache<String, std::sync::Arc<sql::LogicalPlan>>>>,
    /// LRU cache of parsed statements (presumably read through `parse_cached`).
    parse_cache: std::sync::Arc<std::sync::Mutex<lru::LruCache<String, sqlparser::ast::Statement>>>,
    /// LRU cache of shared tuple vectors (presumably query results; key format not visible here).
    result_cache: std::sync::Arc<std::sync::Mutex<lru::LruCache<String, std::sync::Arc<Vec<Tuple>>>>>,
    /// `(table, row_id, column values)` for ART index entries to undo: cleared
    /// on COMMIT, replayed through `art_indexes().on_delete` on ROLLBACK.
    art_undo_log: std::sync::Arc<parking_lot::RwLock<Vec<(String, u64, std::collections::HashMap<String, Value>)>>>,
}
411
412impl Drop for EmbeddedDatabase {
413 fn drop(&mut self) {
414 if let Some(ref worker) = *self.auto_refresh_worker.read() {
416 worker.request_stop();
417 }
418
419 self.session_transactions.clear();
421
422 self.prepared_statements.write().clear();
424
425 if let Ok(mut cache) = self.plan_cache.lock() {
427 cache.clear();
428 }
429
430 if let Ok(mut cache) = self.parse_cache.lock() {
432 cache.clear();
433 }
434
435 self.savepoints.write().clear();
437
438 tracing::debug!("EmbeddedDatabase dropped, resources cleaned up");
439 }
440}
441
/// State captured for one savepoint.
///
/// The snapshot is presumably what a rollback-to-savepoint restores the
/// transaction's write set from — confirm against the SAVEPOINT handlers.
#[derive(Clone)]
struct SavepointState {
    /// Write set at the time the savepoint was created: `(key, value)` pairs,
    /// where `None` as the value presumably marks a deletion.
    write_set_snapshot: Vec<(Vec<u8>, Option<Vec<u8>>)>,
    /// Savepoint name as given in the SQL statement.
    name: String,
}
452
/// Returns `true` when `s` begins with `prefix`, comparing ASCII letters
/// case-insensitively (all other bytes must match exactly).
///
/// The comparison is byte-wise, so a `prefix` whose length lands inside a
/// multi-byte UTF-8 character of `s` is simply compared rather than causing a
/// char-boundary panic. Using `get` also removes the indexing panic path that
/// previously needed a lint suppression.
#[inline]
fn starts_with_icase(s: &str, prefix: &str) -> bool {
    s.as_bytes()
        .get(..prefix.len())
        .is_some_and(|head| head.eq_ignore_ascii_case(prefix.as_bytes()))
}
461
/// RAII guard that temporarily switches a database to another branch and,
/// when dropped, switches back to the branch that was current before.
///
/// A guard whose `db` is `None` does nothing when dropped.
#[cfg(feature = "code-graph")]
struct CodeGraphBranchGuard<'a> {
    /// Database whose branch was switched; `None` for a no-op guard.
    db: Option<&'a EmbeddedDatabase>,
    /// Branch the guard switched to.
    target: Option<String>,
    /// Branch reported as current before the switch; `None` is treated as `"main"`.
    previous: Option<String>,
}
473
#[cfg(feature = "code-graph")]
impl<'a> CodeGraphBranchGuard<'a> {
    /// A guard that restores nothing when dropped.
    fn noop() -> Self {
        Self { db: None, previous: None, target: None }
    }

    /// Switches `db` to `target` and returns a guard that restores the
    /// previously current branch on drop.
    ///
    /// Returns a no-op guard when `db` is already on `target` (a missing
    /// current branch counts as `"main"`). It also returns a no-op guard when
    /// the switch fails: the database never left its branch, so there is
    /// nothing to restore. Previously an armed guard was returned on failure,
    /// and its drop performed a spurious switch (to `"main"` when no current
    /// branch was recorded).
    fn switch_to(db: &'a EmbeddedDatabase, target: String) -> Self {
        let previous = db.storage.get_current_branch();
        let current = previous.as_deref().unwrap_or("main");
        if current == target {
            return Self::noop();
        }
        if let Err(e) = db.switch_branch(&target) {
            tracing::warn!("code-graph: failed to switch to branch '{}': {}", target, e);
            return Self::noop();
        }
        Self { db: Some(db), previous, target: Some(target) }
    }
}
494
#[cfg(feature = "code-graph")]
impl<'a> Drop for CodeGraphBranchGuard<'a> {
    /// Switches back to the branch that was current before the guard was
    /// armed. A recorded `previous` of `None` means `"main"`. Nothing happens
    /// for a no-op guard or when the target already equals that branch.
    /// Switch errors are ignored, because `drop` cannot propagate them.
    fn drop(&mut self) {
        let (Some(db), Some(target)) = (self.db, self.target.as_deref()) else {
            return;
        };

        let restore_to = match self.previous.as_deref() {
            Some(prev) => (prev != target).then_some(prev),
            None => (target != "main").then_some("main"),
        };

        if let Some(branch) = restore_to {
            let _ = db.switch_branch(branch);
        }
    }
}
511
512impl EmbeddedDatabase {
513 fn is_transaction_control(sql: &str) -> bool {
515 let trimmed = sql.trim().trim_end_matches(';').trim();
516 starts_with_icase(trimmed, "BEGIN") ||
517 starts_with_icase(trimmed, "START TRANSACTION") ||
518 trimmed.eq_ignore_ascii_case("COMMIT") ||
519 trimmed.eq_ignore_ascii_case("ROLLBACK")
520 }
521
522 fn handle_transaction_control(&self, sql: &str) -> Result<u64> {
524 let trimmed = sql.trim().trim_end_matches(';').trim();
525
526 if starts_with_icase(trimmed, "BEGIN") || starts_with_icase(trimmed, "START TRANSACTION") {
527 self.begin_transaction_internal()?;
528 Ok(0)
529 } else if trimmed.eq_ignore_ascii_case("COMMIT") {
530 self.commit_internal()?;
531 Ok(0)
532 } else if trimmed.eq_ignore_ascii_case("ROLLBACK") {
533 self.rollback_internal()?;
534 Ok(0)
535 } else {
536 Err(Error::query_execution("Unknown transaction control statement"))
537 }
538 }
539
540 fn begin_transaction_internal(&self) -> Result<()> {
542 use crate::error::LockResultExt;
543 let mut txn_ref = self.current_transaction.lock()
544 .map_lock_err("Failed to acquire transaction lock for begin")?;
545 if txn_ref.is_some() {
546 return Err(Error::transaction("Transaction already active"));
547 }
548 let txn = self.storage.begin_transaction()?;
549 *txn_ref = Some(txn);
550 Ok(())
551 }
552
553 fn commit_internal(&self) -> Result<()> {
555 use crate::error::LockResultExt;
556 let mut txn_ref = self.current_transaction.lock()
557 .map_lock_err("Failed to acquire transaction lock for commit")?;
558 if let Some(txn) = txn_ref.take() {
559 txn.commit()?;
560 self.art_undo_log.write().clear();
562 self.storage.increment_lsn();
564 Ok(())
565 } else {
566 Err(Error::transaction("No active transaction to commit"))
567 }
568 }
569
570 fn rollback_internal(&self) -> Result<()> {
572 use crate::error::LockResultExt;
573 let mut txn_ref = self.current_transaction.lock()
574 .map_lock_err("Failed to acquire transaction lock for rollback")?;
575 if let Some(txn) = txn_ref.take() {
576 txn.rollback()?;
577 let undo_entries: Vec<_> = self.art_undo_log.write().drain(..).collect();
579 for (table_name, row_id, col_values) in undo_entries {
580 if let Err(e) = self.storage.art_indexes().on_delete(&table_name, row_id, &col_values) {
581 tracing::debug!("ART rollback for '{}' row {}: {}", table_name, row_id, e);
582 }
583 }
584 Ok(())
585 } else {
586 Err(Error::transaction("No active transaction to rollback"))
587 }
588 }
589
590 #[cfg(feature = "ha-tier1")]
594 fn try_parse_ha_command(sql: &str) -> Result<Option<sql::LogicalPlan>> {
595 if sql::Parser::is_switchover(sql) {
596 let target_node = sql::Parser::parse_switchover_sql(sql)?;
597 Ok(Some(sql::LogicalPlan::Switchover { target_node }))
598 } else if sql::Parser::is_switchover_check(sql) {
599 let target_node = sql::Parser::parse_switchover_check_sql(sql)?;
600 Ok(Some(sql::LogicalPlan::SwitchoverCheck { target_node }))
601 } else if sql::Parser::is_cluster_status(sql) {
602 Ok(Some(sql::LogicalPlan::ClusterStatus))
603 } else if sql::Parser::is_set_node_alias(sql) {
604 let (node_id, alias) = sql::Parser::parse_set_node_alias_sql(sql)?;
605 Ok(Some(sql::LogicalPlan::SetNodeAlias { node_id, alias }))
606 } else if sql::Parser::is_show_topology(sql) {
607 Ok(Some(sql::LogicalPlan::ShowTopology))
608 } else {
609 Ok(None)
610 }
611 }
612
613 #[cfg(not(feature = "ha-tier1"))]
615 fn try_parse_ha_command(_sql: &str) -> Result<Option<sql::LogicalPlan>> {
616 Ok(None)
617 }
618
619 fn execute_in_transaction(&self, sql: &str, txn: &storage::Transaction) -> Result<u64> {
651 self.execute_in_transaction_inner(sql, txn, false)
652 }
653
654 fn execute_in_transaction_no_fast_path(&self, sql: &str, txn: &storage::Transaction) -> Result<u64> {
658 self.execute_in_transaction_inner(sql, txn, true)
659 }
660
661 fn execute_in_transaction_inner(&self, sql: &str, txn: &storage::Transaction, skip_fast_paths: bool) -> Result<u64> {
662 if let Some(context) = self.tenant_manager.get_current_context() {
664 self.tenant_manager.record_query(context.tenant_id)
665 .map_err(|e| Error::query_execution(format!("Quota exceeded: {}", e)))?;
666 }
667
668 let has_savepoints = !self.savepoints.read().is_empty();
674 let has_session_txns = !self.session_transactions.is_empty();
675 let use_fast_paths = !skip_fast_paths && !has_savepoints && !has_session_txns;
676
677 if use_fast_paths {
679 if let Some(result) = self.try_fast_insert(sql) {
680 return result;
681 }
682 }
683
684 if use_fast_paths {
686 if let Some(result) = self.try_fast_update(sql) {
687 return result;
688 }
689 }
690
691 let plan = if sql::Parser::is_create_branch(sql) {
693 let (branch_name, parent, as_of_clause, with_options) = sql::Parser::parse_create_branch_sql(sql)?;
695 sql::phase3::branching::BranchingParser::parse_create_branch(
696 branch_name,
697 parent,
698 &as_of_clause,
699 with_options.as_deref(),
700 )?
701 } else if sql::Parser::is_drop_branch(sql) {
702 let (branch_name, if_exists) = sql::Parser::parse_drop_branch_sql(sql)?;
704 sql::phase3::branching::BranchingParser::parse_drop_branch(branch_name, if_exists)?
705 } else if sql::Parser::is_merge_branch(sql) {
706 let (source, target, with_options) = sql::Parser::parse_merge_branch_sql(sql)?;
708 sql::phase3::branching::BranchingParser::parse_merge_branch(
709 source,
710 target,
711 with_options.as_deref(),
712 )?
713 } else if sql::Parser::is_use_branch(sql) {
714 let branch_name = sql::Parser::parse_use_branch_sql(sql)?;
716 sql::LogicalPlan::UseBranch { branch_name }
717 } else if sql::Parser::is_show_branches(sql) {
718 sql::LogicalPlan::ShowBranches
720 } else if sql::Parser::is_refresh_materialized_view(sql) {
721 let (view_name, concurrent, incremental) = sql::Parser::parse_refresh_materialized_view_sql(sql)?;
723 sql::LogicalPlan::RefreshMaterializedView {
724 name: view_name,
725 concurrent,
726 incremental,
727 }
728 } else if sql::Parser::is_drop_materialized_view(sql) {
729 let (view_name, if_exists) = sql::Parser::parse_drop_materialized_view_sql(sql)?;
731 sql::LogicalPlan::DropMaterializedView {
732 name: view_name,
733 if_exists,
734 }
735 } else if sql::Parser::is_alter_materialized_view(sql) {
736 let (view_name, options) = sql::Parser::parse_alter_materialized_view_sql(sql)?;
738 sql::LogicalPlan::AlterMaterializedView {
739 name: view_name,
740 options,
741 }
742 } else if sql::Parser::is_alter_column_storage(sql) {
743 let (table_name, column_name, storage_mode) = sql::Parser::parse_alter_column_storage(sql)?;
745 sql::LogicalPlan::AlterColumnStorage {
746 table_name,
747 column_name,
748 storage_mode,
749 }
750 } else if sql::Parser::is_pg_create_procedure(sql) || sql::Parser::is_pg_create_or_replace_procedure(sql) {
751 let (name, or_replace, params, language, body) = sql::Parser::parse_pg_create_procedure(sql)?;
753 let param_list: Vec<sql::logical_plan::FunctionParam> = params.into_iter().map(|(pname, ptype)| {
754 sql::logical_plan::FunctionParam {
755 name: pname,
756 data_type: sql::Planner::parse_data_type_string(&ptype).unwrap_or(DataType::Text),
757 mode: sql::logical_plan::ParamMode::In,
758 default: None,
759 }
760 }).collect();
761 sql::LogicalPlan::CreateProcedure {
762 name,
763 or_replace,
764 params: param_list,
765 body,
766 language,
767 }
768 } else if let Some(plan) = Self::try_parse_ha_command(sql)? {
769 plan
771 } else {
772 let (statement, _) = self.parse_cached(sql)?;
774
775 let catalog = self.storage.catalog();
777 let planner = sql::Planner::with_catalog(&catalog)
778 .with_sql(sql.to_string());
779 planner.statement_to_plan(statement)?
780 };
781
782 if matches!(&plan,
784 sql::LogicalPlan::CreateTable { .. } |
785 sql::LogicalPlan::DropTable { .. } |
786 sql::LogicalPlan::CreateMaterializedView { .. } |
787 sql::LogicalPlan::DropMaterializedView { .. } |
788 sql::LogicalPlan::Truncate { .. }
789 ) {
790 self.invalidate_plan_cache();
791 }
792
793 match &plan {
795 sql::LogicalPlan::CreateTable { name, columns, constraints, if_not_exists, .. } => {
796 if *if_not_exists && self.storage.catalog().table_exists(name).unwrap_or(false) {
798 return Ok(0);
799 }
800
801 let schema_columns: Vec<Column> = columns.iter().map(|col_def| {
802 let default_expr = col_def.default.as_ref().map(|expr| {
804 serde_json::to_string(expr).unwrap_or_default()
805 });
806
807 Column {
808 name: col_def.name.clone(),
809 data_type: col_def.data_type.clone(),
810 nullable: !col_def.not_null,
811 primary_key: col_def.primary_key,
812 source_table: None,
813 source_table_name: None,
814 default_expr,
815 unique: col_def.unique,
816 storage_mode: col_def.storage_mode,
817 }
818 }).collect();
819
820 let schema = Schema::new(schema_columns);
821 let catalog = self.storage.catalog();
822
823 if let Err(e) = self.storage.log_create_table(name, &schema) {
825 tracing::warn!("Failed to log CREATE TABLE to WAL: {}", e);
826 }
827
828 catalog.create_table(name, schema)?;
829
830 if !constraints.is_empty() {
832 let mut table_constraints = sql::TableConstraints::new();
833 for constraint in constraints {
834 match constraint {
835 sql::logical_plan::TableConstraint::ForeignKey {
836 name: fk_name,
837 columns: fk_cols,
838 references_table,
839 references_columns,
840 on_delete,
841 on_update,
842 deferrable,
843 initially_deferred,
844 } => {
845 let fk = sql::ForeignKeyConstraint::new(
846 fk_name.clone().unwrap_or_else(|| {
847 sql::ForeignKeyConstraint::generate_name(name, fk_cols, references_table)
848 }),
849 name.clone(),
850 fk_cols.clone(),
851 references_table.clone(),
852 references_columns.clone(),
853 );
854 let fk = if let Some(action) = on_delete {
855 fk.on_delete(convert_logical_referential_action(action))
856 } else {
857 fk
858 };
859 let fk = if let Some(action) = on_update {
860 fk.on_update(convert_logical_referential_action(action))
861 } else {
862 fk
863 };
864 let fk = if *deferrable {
865 fk.deferrable(*initially_deferred)
866 } else {
867 fk
868 };
869 table_constraints.add_foreign_key(fk);
870 }
871 sql::logical_plan::TableConstraint::PrimaryKey { name: pk_name, columns: pk_cols } => {
872 table_constraints.add_unique(sql::UniqueConstraint::new(
873 pk_name.clone().unwrap_or_else(|| format!("{}_pkey", name)),
874 name.clone(),
875 pk_cols.clone(),
876 true,
877 ));
878 }
879 sql::logical_plan::TableConstraint::Unique { name: uq_name, columns: uq_cols } => {
880 table_constraints.add_unique(sql::UniqueConstraint::new(
881 uq_name.clone().unwrap_or_else(|| format!("{}_unique", name)),
882 name.clone(),
883 uq_cols.clone(),
884 false,
885 ));
886 }
887 sql::logical_plan::TableConstraint::Check { name: ck_name, expression } => {
888 table_constraints.add_check(sql::CheckConstraint::new(
889 ck_name.clone().unwrap_or_else(|| format!("{}_check", name)),
890 name.clone(),
891 serde_json::to_string(expression).unwrap_or_default(),
892 ));
893 }
894 }
895 }
896 catalog.save_table_constraints(name, &table_constraints)?;
897 }
898
899 let catalog = self.storage.catalog();
902 let mut col_constraints = sql::TableConstraints::new();
903 let mut has_col_constraints = false;
904
905 for col_def in columns {
906 if col_def.primary_key {
907 col_constraints.add_unique(sql::UniqueConstraint::new(
908 format!("{}_{}_pkey", name, col_def.name),
909 name.clone(),
910 vec![col_def.name.clone()],
911 true, ));
913 has_col_constraints = true;
914 } else if col_def.unique {
915 col_constraints.add_unique(sql::UniqueConstraint::new(
916 format!("{}_{}_unique", name, col_def.name),
917 name.clone(),
918 vec![col_def.name.clone()],
919 false, ));
921 has_col_constraints = true;
922 }
923 }
924
925 if has_col_constraints {
926 if let Ok(existing) = catalog.load_table_constraints(name) {
928 for fk in existing.foreign_keys {
929 col_constraints.foreign_keys.push(fk);
930 }
931 for check in existing.check_constraints {
932 col_constraints.check_constraints.push(check);
933 }
934 for unique in existing.unique_constraints {
935 col_constraints.unique_constraints.push(unique);
936 }
937 }
938 catalog.save_table_constraints(name, &col_constraints)?;
939 }
940
941 Ok(1)
942 }
943 sql::LogicalPlan::Insert { table_name, columns, values, returning, on_conflict } => {
944 let catalog = self.storage.catalog();
945 let schema = catalog.get_table_schema(table_name)?;
946 let evaluator = sql::Evaluator::new(std::sync::Arc::new(Schema {
947 columns: vec![],
948 }));
949 let empty_tuple = Tuple::new(vec![]);
950
951 let bulk_threshold = self.storage.smfi_bulk_load_threshold();
955 let _smfi_guard = if values.len() >= bulk_threshold {
956 Some(self.storage.suspend_smfi_for_bulk_load(
957 table_name,
958 storage::BulkLoadReason::MultiRowInsert,
959 ))
960 } else {
961 None
962 };
963
964 let mut trigger_context = sql::TriggerContext::new();
966 let trigger_event = sql::logical_plan::TriggerEvent::Insert;
967 let has_triggers = self.trigger_registry.has_triggers_for_table(table_name);
968
969 let mut returned_tuples: Vec<Tuple> = Vec::new();
971 let has_returning = returning.is_some();
972
973 let default_exprs: Vec<Option<sql::LogicalExpr>> = schema.columns.iter()
975 .map(|col| {
976 col.default_expr.as_ref().and_then(|json| {
977 serde_json::from_str(json).ok()
978 })
979 })
980 .collect();
981
982 let column_indices: Option<Vec<usize>> = columns.as_ref().map(|cols| {
984 cols.iter()
985 .filter_map(|col_name| schema.get_column_index(col_name))
986 .collect()
987 });
988
989 let mut count = 0;
990 for value_row in values {
991 let mut tuple_values: Vec<Option<Value>> = vec![None; schema.columns.len()];
993
994 for (val_idx, expr) in value_row.iter().enumerate() {
996 let target_col_idx = if let Some(ref indices) = column_indices {
997 if val_idx >= indices.len() {
998 return Err(Error::query_execution(
999 "More values than columns specified"
1000 ));
1001 }
1002 *indices.get(val_idx).ok_or_else(|| Error::internal("column index out of bounds"))?
1003 } else {
1004 val_idx
1005 };
1006
1007 if matches!(expr, sql::LogicalExpr::DefaultValue) {
1013 continue;
1014 }
1015
1016 let target_col = schema.get_column_at(target_col_idx)
1017 .ok_or_else(|| Error::query_execution(format!(
1018 "Too many values for INSERT: table has {} columns",
1019 schema.columns.len()
1020 )))?;
1021
1022 let target_type = &target_col.data_type;
1023 let mut value = evaluator.evaluate(expr, &empty_tuple)?;
1024
1025 let needs_cast = match (&value, target_type) {
1026 (Value::Null, _) => false,
1027 (Value::Vector(_), DataType::Vector(_)) => false,
1028 (Value::String(_), DataType::Vector(_)) => true,
1029 (Value::String(_), DataType::Json | DataType::Jsonb) => true,
1030 (Value::Int4(_), DataType::Int4) => false,
1031 (Value::Int8(_), DataType::Int8) => false,
1032 (Value::Float4(_), DataType::Float4) => false,
1033 (Value::Float8(_), DataType::Float8) => false,
1034 (Value::String(_), DataType::Text | DataType::Varchar(_)) => false,
1035 (Value::Boolean(_), DataType::Boolean) => false,
1036 (Value::Json(_), DataType::Json | DataType::Jsonb) => false,
1037 _ => true,
1038 };
1039
1040 if needs_cast {
1041 value = evaluator.cast_value(value, target_type)?;
1042 }
1043
1044 if let Some(target_col_ref) = schema.get_column_at(target_col_idx) {
1046 if matches!(value, Value::Null) && !target_col_ref.nullable {
1047 return Err(Error::constraint_violation(format!(
1048 "NOT NULL constraint violated: cannot insert NULL into column '{}'",
1049 target_col_ref.name
1050 )));
1051 }
1052 }
1053
1054 let tv = tuple_values.get_mut(target_col_idx)
1055 .ok_or_else(|| Error::internal("column index out of bounds"))?;
1056 *tv = Some(value);
1057 }
1058
1059 let final_values: Result<Vec<Value>> = tuple_values
1061 .into_iter()
1062 .enumerate()
1063 .map(|(idx, opt_val)| {
1064 if let Some(val) = opt_val {
1065 Ok(val)
1066 } else {
1067 let col = schema.get_column_at(idx)
1069 .ok_or_else(|| Error::internal("column index out of bounds"))?;
1070 if let Some(ref default_expr) = default_exprs.get(idx).and_then(|d| d.as_ref()) {
1071 let mut value = evaluator.evaluate(default_expr, &empty_tuple)?;
1073 if value.data_type() != col.data_type {
1075 value = evaluator.cast_value(value, &col.data_type)?;
1076 }
1077 Ok(value)
1078 } else if col.primary_key {
1079 Ok(Value::Null)
1082 } else if col.nullable {
1083 Ok(Value::Null)
1084 } else {
1085 Err(Error::query_execution(format!(
1086 "Column '{}' does not have a default value and is not nullable",
1087 col.name
1088 )))
1089 }
1090 }
1091 })
1092 .collect();
1093
1094 let final_values_vec = final_values?;
1095
1096 for (idx, col) in schema.columns.iter().enumerate() {
1102 if !col.nullable && !col.primary_key {
1103 if matches!(final_values_vec.get(idx), Some(Value::Null)) {
1104 return Err(Error::constraint_violation(format!(
1105 "NOT NULL constraint violated: cannot insert NULL into column '{}'",
1106 col.name
1107 )));
1108 }
1109 }
1110 }
1111
1112 let mut tuple = Tuple::new(final_values_vec.clone());
1113
1114 let table_constraints = catalog.load_table_constraints(table_name)?;
1116 for fk in &table_constraints.foreign_keys {
1117 if fk.enforcement == sql::ConstraintEnforcement::Immediate {
1118 let fk_values: Vec<Value> = fk.columns.iter()
1119 .map(|col_name| {
1120 schema.columns.iter()
1121 .position(|c| &c.name == col_name)
1122 .and_then(|idx| final_values_vec.get(idx).cloned())
1123 .unwrap_or(Value::Null)
1124 })
1125 .collect();
1126 if fk_values.iter().any(|v| matches!(v, Value::Null)) {
1127 continue;
1128 }
1129 let key = crate::storage::ArtIndexManager::encode_key(&fk_values);
1131 let exists = if let Some(found) = self.storage.art_indexes().pk_index_contains(&fk.references_table, &key) {
1132 found
1133 } else {
1134 self.check_foreign_key_exists(
1136 &fk.references_table,
1137 &fk.references_columns,
1138 &fk_values,
1139 )?
1140 };
1141 if !exists {
1142 return Err(Error::constraint_violation(format!(
1143 "Foreign key constraint '{}' violated: referenced row in table '{}' does not exist",
1144 fk.name, fk.references_table
1145 )));
1146 }
1147 }
1148 }
1149
1150 for check in &table_constraints.check_constraints {
1152 let check_result = self.evaluate_check_constraint(
1154 &check.expression,
1155 &schema,
1156 &final_values_vec,
1157 )?;
1158
1159 if !check_result {
1160 return Err(Error::constraint_violation(format!(
1161 "CHECK constraint '{}' violated: expression '{}' evaluated to false",
1162 check.name, check.expression
1163 )));
1164 }
1165 }
1166
1167 {
1170 let mut col_values_map = std::collections::HashMap::new();
1171 for (i, col) in schema.columns.iter().enumerate() {
1172 if let Some(v) = final_values_vec.get(i) {
1173 col_values_map.insert(col.name.clone(), v.clone());
1174 }
1175 }
1176 if let Err(e) = self.storage.art_indexes().check_unique_constraints(table_name, &col_values_map) {
1177 match on_conflict {
1178 Some(sql::logical_plan::OnConflictAction::DoNothing) => {
1179 continue;
1181 }
1182 Some(sql::logical_plan::OnConflictAction::DoUpdate { assignments }) => {
1183 let err_msg = e.to_string();
1186
1187 let mut excluded_map = std::collections::HashMap::new();
1189 for (i, col) in schema.columns.iter().enumerate() {
1190 if let Some(v) = final_values_vec.get(i) {
1191 excluded_map.insert(col.name.to_lowercase(), v.clone());
1192 }
1193 }
1194
1195 let existing_row_id = {
1200 let mut found_row_id: Option<u64> = None;
1201
1202 for (i, col) in schema.columns.iter().enumerate() {
1204 if (col.unique || col.primary_key) && !col.primary_key {
1205 if let Some(val) = final_values_vec.get(i) {
1207 if !matches!(val, Value::Null) {
1208 let scan_sql = format!(
1210 "SELECT {} FROM {} WHERE {} = '{}'",
1211 schema.columns.iter().find(|c| c.primary_key).map(|c| c.name.as_str()).unwrap_or("rowid"),
1212 table_name,
1213 col.name,
1214 val.to_string().trim_matches('\'')
1215 );
1216 if let Ok(rows) = self.query(&scan_sql, &[]) {
1217 if let Some(row) = rows.first() {
1218 if let Some(pk_val) = row.values.first() {
1219 match pk_val {
1220 Value::Int8(id) => { found_row_id = Some(*id as u64); }
1221 Value::Int4(id) => { found_row_id = Some(*id as u64); }
1222 _ => {}
1223 }
1224 }
1225 }
1226 }
1227 if found_row_id.is_some() { break; }
1228 }
1229 }
1230 }
1231 }
1232
1233 if found_row_id.is_none() {
1235 let pk_cols: Vec<(usize, &crate::Column)> = schema.columns.iter().enumerate()
1236 .filter(|(_, c)| c.primary_key)
1237 .collect();
1238 let pk_values: Vec<Value> = pk_cols.iter()
1239 .filter_map(|(idx, _)| final_values_vec.get(*idx).cloned())
1240 .collect();
1241 if !pk_values.is_empty() && !pk_values.iter().any(|v| matches!(v, Value::Null)) {
1242 let pk_key = crate::storage::ArtIndexManager::encode_key(&pk_values);
1243 found_row_id = self.storage.art_indexes().pk_index_lookup(table_name, &pk_key);
1244 }
1245 }
1246
1247 found_row_id.ok_or_else(|| Error::query_execution(
1248 format!("ON CONFLICT DO UPDATE: could not find existing row ({})", err_msg)
1249 ))?
1250 };
1251
1252 let existing_key = self.storage.branch_aware_data_key(table_name, existing_row_id);
1258 let existing_raw = match txn.get(&existing_key)? {
1259 Some(raw) => raw,
1260 None => self.storage.get(&existing_key)?
1261 .ok_or_else(|| Error::query_execution(
1262 "ON CONFLICT DO UPDATE: existing row not found in storage"
1263 ))?,
1264 };
1265 let mut existing_tuple: Tuple = bincode::deserialize(&existing_raw)
1266 .map_err(|err| Error::storage(format!("Failed to deserialize tuple: {}", err)))?;
1267 existing_tuple.row_id = Some(existing_row_id);
1268
1269 let update_evaluator = sql::Evaluator::new(std::sync::Arc::new(schema.clone()));
1271 for (col_name, expr) in assignments {
1272 let target_idx = schema.columns.iter()
1273 .position(|c| c.name.eq_ignore_ascii_case(col_name))
1274 .ok_or_else(|| Error::query_execution(format!(
1275 "ON CONFLICT DO UPDATE: column '{}' not found", col_name
1276 )))?;
1277
1278 let resolved_expr = Self::resolve_excluded_refs(expr, &excluded_map);
1280 let mut new_val = update_evaluator.evaluate(&resolved_expr, &existing_tuple)?;
1281 let target_type = &schema.columns.get(target_idx)
1283 .ok_or_else(|| Error::internal("column index out of bounds"))?
1284 .data_type;
1285 if new_val.data_type() != *target_type && !matches!(new_val, Value::Null) {
1286 new_val = update_evaluator.cast_value(new_val, target_type)?;
1287 }
1288 if target_idx < existing_tuple.values.len() {
1289 #[allow(clippy::indexing_slicing)]
1290 { existing_tuple.values[target_idx] = new_val; }
1291 }
1292 }
1293
1294 let updated_val = bincode::serialize(&existing_tuple)
1296 .map_err(|err| Error::storage(err.to_string()))?;
1297 txn.put(existing_key.clone(), updated_val.clone())?;
1298
1299 {
1301 let mut updated_col_values = std::collections::HashMap::new();
1302 for (i, col) in schema.columns.iter().enumerate() {
1303 if let Some(v) = existing_tuple.values.get(i) {
1304 updated_col_values.insert(col.name.clone(), v.clone());
1305 }
1306 }
1307 let _ = self.storage.art_indexes().on_delete(table_name, existing_row_id, &col_values_map);
1309 let _ = self.storage.art_indexes().on_insert(table_name, existing_row_id, &updated_col_values);
1310 }
1311
1312 if !skip_fast_paths && self.storage.is_wal_enabled() {
1314 self.storage.log_data_insert(table_name, &existing_key, &updated_val)?;
1315 }
1316
1317 self.invalidate_result_cache();
1319
1320 count += 1;
1321
1322 if has_returning {
1324 if let Some(projected) = Self::project_returning_columns(&existing_tuple, &schema, returning) {
1325 returned_tuples.push(projected);
1326 }
1327 }
1328 continue;
1329 }
1330 None => {
1331 return Err(Error::constraint_violation(e.to_string()));
1332 }
1333 }
1334 }
1335 }
1336
1337 if has_triggers {
1339 let row_context = sql::triggers::TriggerRowContext::for_insert(tuple.clone());
1340 let db_ref = self.clone_for_trigger();
1341 let mut executor_fn = |stmt: &sql::LogicalPlan, _ctx: &sql::triggers::TriggerRowContext| -> Result<()> {
1342 db_ref.execute_plan_internal(stmt)?;
1343 Ok(())
1344 };
1345
1346 let action = self.trigger_registry.execute_triggers(
1347 table_name,
1348 &trigger_event,
1349 &sql::logical_plan::TriggerTiming::Before,
1350 &row_context,
1351 &mut trigger_context,
1352 Some(std::sync::Arc::new(schema.clone())),
1353 &mut executor_fn,
1354 )?;
1355
1356 match action {
1358 sql::triggers::TriggerAction::Abort(msg) => {
1359 return Err(Error::query_execution(format!("INSERT aborted by trigger: {}", msg)));
1360 }
1361 sql::triggers::TriggerAction::Skip => {
1362 continue;
1364 }
1365 sql::triggers::TriggerAction::Continue => {
1366 }
1368 }
1369 }
1370
1371 let row_id = catalog.next_row_id(table_name)?;
1373 let key = self.storage.branch_aware_data_key(table_name, row_id);
1374
1375 for (i, col) in schema.columns.iter().enumerate() {
1378 if col.primary_key {
1379 if let Some(v) = tuple.values.get(i) {
1380 if matches!(v, Value::Null) {
1381 if i < tuple.values.len() {
1382 #[allow(clippy::indexing_slicing)]
1383 match col.data_type {
1384 DataType::Int2 => { tuple.values[i] = Value::Int2(row_id as i16); }
1385 DataType::Int4 => { tuple.values[i] = Value::Int4(row_id as i32); }
1386 _ => { tuple.values[i] = Value::Int8(row_id as i64); }
1387 }
1388 }
1389 }
1390 }
1391 }
1392 }
1393
1394 let mut col_values = std::collections::HashMap::new();
1397 for (i, col) in schema.columns.iter().enumerate() {
1398 if let Some(v) = tuple.values.get(i) {
1399 col_values.insert(col.name.clone(), v.clone());
1400 }
1401 }
1402
1403 self.check_fk_constraints_on_write(table_name, &col_values, Some(txn))?;
1408
1409 let val = bincode::serialize(&tuple).map_err(|e| Error::storage(e.to_string()))?;
1411 txn.put(key.clone(), val.clone())?;
1412
1413 if !skip_fast_paths && self.storage.is_wal_enabled() {
1416 self.storage.log_data_insert(table_name, &key, &val)?;
1417 }
1418
1419 {
1421 if let Err(e) = self.storage.art_indexes().on_insert(table_name, row_id, &col_values) {
1422 tracing::debug!("ART index insert for '{}': {}", table_name, e);
1423 }
1424 if skip_fast_paths {
1426 self.art_undo_log.write().push((table_name.clone(), row_id, col_values));
1427 }
1428 }
1429
1430 count += 1;
1431
1432 if has_returning {
1434 let mut returned_tuple = tuple.clone();
1436 returned_tuple.row_id = Some(row_id);
1437 if let Some(projected) = Self::project_returning_columns(&returned_tuple, &schema, returning) {
1438 returned_tuples.push(projected);
1439 }
1440 }
1441
1442 if let Some(context) = self.tenant_manager.get_current_context() {
1444 let tuple_size = val.len() as u64;
1446
1447 if let Some(current_quota) = self.tenant_manager.get_quota_tracking(context.tenant_id) {
1449 let new_storage = current_quota.storage_bytes_used + tuple_size;
1450 if let Err(e) = self.tenant_manager.update_storage_usage(context.tenant_id, new_storage) {
1451 return Err(Error::query_execution(format!("Storage quota exceeded: {}", e)));
1453 }
1454 }
1455
1456 let new_values = serde_json::to_string(&tuple.values)
1458 .unwrap_or_else(|_| "[]".to_string());
1459
1460 self.tenant_manager.record_change_event(
1461 crate::tenant::ChangeType::Insert,
1462 table_name.to_string(),
1463 row_id.to_string(),
1464 None, Some(new_values),
1466 context.tenant_id,
1467 None, );
1469 }
1470
1471 if has_triggers {
1473 let row_context = sql::triggers::TriggerRowContext::for_insert(tuple.clone());
1474 let db_ref = self.clone_for_trigger();
1475 let mut executor_fn = |stmt: &sql::LogicalPlan, _ctx: &sql::triggers::TriggerRowContext| -> Result<()> {
1476 db_ref.execute_plan_internal(stmt)?;
1477 Ok(())
1478 };
1479 let action = self.trigger_registry.execute_triggers(
1480 table_name,
1481 &trigger_event,
1482 &sql::logical_plan::TriggerTiming::After,
1483 &row_context,
1484 &mut trigger_context,
1485 Some(std::sync::Arc::new(schema.clone())),
1486 &mut executor_fn,
1487 )?;
1488 if let sql::triggers::TriggerAction::Abort(msg) = action {
1489 return Err(Error::query_execution(format!("INSERT aborted by AFTER trigger: {}", msg)));
1490 }
1491 }
1492 }
1493 Ok(count)
1495 }
1496 sql::LogicalPlan::InsertSelect { table_name, columns, source, returning } => {
1497 let mut executor = sql::Executor::with_storage(&self.storage)
1499 .with_timeout(self.config.storage.query_timeout_ms);
1500 let source_rows = executor.execute(source)?;
1501
1502 let catalog = self.storage.catalog();
1503 let schema = catalog.get_table_schema(table_name)?;
1504 let evaluator = sql::Evaluator::new(std::sync::Arc::new(Schema {
1505 columns: vec![],
1506 }));
1507 let empty_tuple = Tuple::new(vec![]);
1508
1509 let column_indices: Option<Vec<usize>> = columns.as_ref().map(|cols| {
1511 cols.iter()
1512 .filter_map(|col_name| schema.get_column_index(col_name))
1513 .collect()
1514 });
1515
1516 let default_exprs: Vec<Option<sql::LogicalExpr>> = schema.columns.iter()
1518 .map(|col| {
1519 col.default_expr.as_ref().and_then(|json| {
1520 serde_json::from_str(json).ok()
1521 })
1522 })
1523 .collect();
1524
1525 let mut trigger_context = sql::TriggerContext::new();
1527 let trigger_event = sql::logical_plan::TriggerEvent::Insert;
1528 let has_triggers = self.trigger_registry.has_triggers_for_table(table_name);
1529
1530 let has_returning = returning.is_some();
1531 let mut returned_tuples: Vec<Tuple> = Vec::new();
1532
1533 let bulk_threshold = self.storage.smfi_bulk_load_threshold();
1535 let _smfi_guard = if source_rows.len() >= bulk_threshold {
1536 Some(self.storage.suspend_smfi_for_bulk_load(
1537 table_name,
1538 storage::BulkLoadReason::MultiRowInsert,
1539 ))
1540 } else {
1541 None
1542 };
1543
1544 let mut count = 0u64;
1545 for source_row in &source_rows {
1546 let mut tuple_values: Vec<Option<Value>> = vec![None; schema.columns.len()];
1548
1549 for (val_idx, value) in source_row.values.iter().enumerate() {
1551 let target_col_idx = if let Some(ref indices) = column_indices {
1552 if val_idx >= indices.len() {
1553 return Err(Error::query_execution(
1554 "More values than columns specified"
1555 ));
1556 }
1557 *indices.get(val_idx).ok_or_else(|| Error::internal("column index out of bounds"))?
1558 } else {
1559 val_idx
1560 };
1561
1562 let target_col = schema.get_column_at(target_col_idx)
1563 .ok_or_else(|| Error::query_execution(format!(
1564 "Too many values for INSERT: table has {} columns",
1565 schema.columns.len()
1566 )))?;
1567
1568 let target_type = &target_col.data_type;
1569 let mut val = value.clone();
1570
1571 let needs_cast = match (&val, target_type) {
1573 (Value::Null, _) => false,
1574 (Value::Vector(_), DataType::Vector(_)) => false,
1575 (Value::String(_), DataType::Vector(_)) => true,
1576 (Value::String(_), DataType::Json | DataType::Jsonb) => true,
1577 (Value::Int4(_), DataType::Int4) => false,
1578 (Value::Int8(_), DataType::Int8) => false,
1579 (Value::Float4(_), DataType::Float4) => false,
1580 (Value::Float8(_), DataType::Float8) => false,
1581 (Value::String(_), DataType::Text | DataType::Varchar(_)) => false,
1582 (Value::Boolean(_), DataType::Boolean) => false,
1583 (Value::Json(_), DataType::Json | DataType::Jsonb) => false,
1584 _ => true,
1585 };
1586
1587 if needs_cast {
1588 val = evaluator.cast_value(val, target_type)?;
1589 }
1590
1591 if let Some(target_col_ref) = schema.get_column_at(target_col_idx) {
1593 if matches!(val, Value::Null) && !target_col_ref.nullable {
1594 return Err(Error::constraint_violation(format!(
1595 "NOT NULL constraint violated: cannot insert NULL into column '{}'",
1596 target_col_ref.name
1597 )));
1598 }
1599 }
1600
1601 let tv = tuple_values.get_mut(target_col_idx)
1602 .ok_or_else(|| Error::internal("column index out of bounds"))?;
1603 *tv = Some(val);
1604 }
1605
1606 let final_values: Result<Vec<Value>> = tuple_values
1608 .into_iter()
1609 .enumerate()
1610 .map(|(idx, opt_val)| {
1611 if let Some(val) = opt_val {
1612 Ok(val)
1613 } else {
1614 let col = schema.get_column_at(idx)
1615 .ok_or_else(|| Error::internal("column index out of bounds"))?;
1616 if let Some(ref default_expr) = default_exprs.get(idx).and_then(|d| d.as_ref()) {
1617 let mut value = evaluator.evaluate(default_expr, &empty_tuple)?;
1618 if value.data_type() != col.data_type {
1619 value = evaluator.cast_value(value, &col.data_type)?;
1620 }
1621 Ok(value)
1622 } else if col.primary_key {
1623 Ok(Value::Null)
1626 } else if col.nullable {
1627 Ok(Value::Null)
1628 } else {
1629 Err(Error::query_execution(format!(
1630 "Column '{}' does not have a default value and is not nullable",
1631 col.name
1632 )))
1633 }
1634 }
1635 })
1636 .collect();
1637
1638 let final_values_vec = final_values?;
1639 let tuple = Tuple::new(final_values_vec.clone());
1640
1641 let table_constraints = catalog.load_table_constraints(table_name)?;
1643 for fk in &table_constraints.foreign_keys {
1644 if fk.enforcement == sql::ConstraintEnforcement::Immediate {
1645 let fk_values: Vec<Value> = fk.columns.iter()
1646 .map(|col_name| {
1647 schema.columns.iter()
1648 .position(|c| &c.name == col_name)
1649 .and_then(|idx| final_values_vec.get(idx).cloned())
1650 .unwrap_or(Value::Null)
1651 })
1652 .collect();
1653 if fk_values.iter().any(|v| matches!(v, Value::Null)) {
1654 continue;
1655 }
1656 let key = crate::storage::ArtIndexManager::encode_key(&fk_values);
1657 let exists = if let Some(found) = self.storage.art_indexes().pk_index_contains(&fk.references_table, &key) {
1658 found
1659 } else {
1660 self.check_foreign_key_exists(
1661 &fk.references_table,
1662 &fk.references_columns,
1663 &fk_values,
1664 )?
1665 };
1666 if !exists {
1667 return Err(Error::constraint_violation(format!(
1668 "Foreign key constraint '{}' violated: referenced row in table '{}' does not exist",
1669 fk.name, fk.references_table
1670 )));
1671 }
1672 }
1673 }
1674
1675 for check in &table_constraints.check_constraints {
1677 let check_result = self.evaluate_check_constraint(
1678 &check.expression,
1679 &schema,
1680 &final_values_vec,
1681 )?;
1682
1683 if !check_result {
1684 return Err(Error::constraint_violation(format!(
1685 "CHECK constraint '{}' violated: expression '{}' evaluated to false",
1686 check.name, check.expression
1687 )));
1688 }
1689 }
1690
1691 if !table_constraints.unique_constraints.is_empty() {
1693 for uc in &table_constraints.unique_constraints {
1694 let uc_values: Vec<Value> = uc.columns.iter()
1695 .map(|col_name| {
1696 schema.columns.iter()
1697 .position(|c| &c.name == col_name)
1698 .and_then(|idx| final_values_vec.get(idx).cloned())
1699 .unwrap_or(Value::Null)
1700 })
1701 .collect();
1702 if uc_values.iter().any(|v| matches!(v, Value::Null)) {
1703 continue;
1704 }
1705 let key = crate::storage::ArtIndexManager::encode_key(&uc_values);
1706 if self.storage.art_indexes().pk_index_contains(table_name, &key) == Some(true) {
1707 return Err(Error::constraint_violation(format!(
1708 "UNIQUE constraint '{}' violated: duplicate value for columns ({})",
1709 uc.name,
1710 uc.columns.join(", ")
1711 )));
1712 }
1713 }
1714 }
1715
1716 if has_triggers {
1718 let row_context = sql::triggers::TriggerRowContext::for_insert(tuple.clone());
1719 let db_ref = self.clone_for_trigger();
1720 let mut executor_fn = |stmt: &sql::LogicalPlan, _ctx: &sql::triggers::TriggerRowContext| -> Result<()> {
1721 db_ref.execute_plan_internal(stmt)?;
1722 Ok(())
1723 };
1724 let action = self.trigger_registry.execute_triggers(
1725 table_name,
1726 &trigger_event,
1727 &sql::logical_plan::TriggerTiming::Before,
1728 &row_context,
1729 &mut trigger_context,
1730 Some(std::sync::Arc::new(schema.clone())),
1731 &mut executor_fn,
1732 )?;
1733 match action {
1735 sql::triggers::TriggerAction::Abort(msg) => {
1736 return Err(Error::query_execution(format!("INSERT aborted by trigger: {}", msg)));
1737 }
1738 sql::triggers::TriggerAction::Skip => {
1739 continue;
1740 }
1741 sql::triggers::TriggerAction::Continue => {}
1742 }
1743 }
1744
1745 let row_id = self.storage.insert_tuple_branch_aware_with_schema(table_name, tuple.clone(), &schema)?;
1747
1748 {
1750 let mut col_values = std::collections::HashMap::new();
1751 for (i, col) in schema.columns.iter().enumerate() {
1752 if let Some(v) = final_values_vec.get(i) {
1753 col_values.insert(col.name.clone(), v.clone());
1754 }
1755 }
1756 if let Err(e) = self.storage.art_indexes().on_insert(table_name, row_id, &col_values) {
1757 tracing::debug!("ART index insert for '{}': {}", table_name, e);
1758 }
1759 }
1760
1761 count += 1;
1762
1763 if has_returning {
1765 let mut returned_tuple = tuple.clone();
1766 returned_tuple.row_id = Some(row_id);
1767 if let Some(projected) = Self::project_returning_columns(&returned_tuple, &schema, returning) {
1768 returned_tuples.push(projected);
1769 }
1770 }
1771
1772 if has_triggers {
1774 let row_context = sql::triggers::TriggerRowContext::for_insert(tuple.clone());
1775 let db_ref = self.clone_for_trigger();
1776 let mut executor_fn = |stmt: &sql::LogicalPlan, _ctx: &sql::triggers::TriggerRowContext| -> Result<()> {
1777 db_ref.execute_plan_internal(stmt)?;
1778 Ok(())
1779 };
1780 let action = self.trigger_registry.execute_triggers(
1781 table_name,
1782 &trigger_event,
1783 &sql::logical_plan::TriggerTiming::After,
1784 &row_context,
1785 &mut trigger_context,
1786 Some(std::sync::Arc::new(schema.clone())),
1787 &mut executor_fn,
1788 )?;
1789 if let sql::triggers::TriggerAction::Abort(msg) = action {
1790 return Err(Error::query_execution(format!("INSERT aborted by AFTER trigger: {}", msg)));
1791 }
1792 }
1793 }
1794 Ok(count)
1795 }
1796 sql::LogicalPlan::Update { table_name, assignments, selection, returning } => {
1797 let catalog = self.storage.catalog();
1798 let schema = catalog.get_table_schema(table_name)?;
1799 let eval_schema = schema.clone().with_source_table_name(table_name);
1802 let evaluator = sql::Evaluator::with_parameters(
1803 std::sync::Arc::new(eval_schema),
1804 vec![],
1805 );
1806
1807 let mut trigger_context = sql::TriggerContext::new();
1809 let updated_columns: Vec<String> = assignments.iter().map(|(col, _)| col.clone()).collect();
1810 let trigger_event = sql::logical_plan::TriggerEvent::Update(Some(updated_columns));
1811 let has_triggers = self.trigger_registry.has_triggers_for_table(table_name);
1812
1813 let on_branch = self.storage.get_current_branch().is_some();
1818 let tuples = if !on_branch {
1819 if let Some(pk_value) = Self::try_extract_pk_value(selection.as_ref(), &schema) {
1820 match self.storage.get_row_by_pk(table_name, &pk_value)? {
1821 Some(tuple) => vec![tuple],
1822 None => vec![],
1823 }
1824 } else {
1825 self.storage.scan_table_branch_aware(table_name)?
1826 }
1827 } else {
1828 self.storage.scan_table_branch_aware(table_name)?
1829 };
1830 let mut updates: Vec<(u64, Tuple)> = Vec::new();
1831
1832 for old_tuple in tuples {
1833 let matches = if let Some(predicate) = selection {
1834 let result = evaluator.evaluate(predicate, &old_tuple)?;
1835 match result {
1836 Value::Boolean(b) => b,
1837 _ => false,
1838 }
1839 } else {
1840 true
1841 };
1842
1843 if matches {
1844 let mut new_tuple = old_tuple.clone();
1846 for (col_name, value_expr) in assignments {
1847 let bound = self.materialize_scalar_subqueries_for_row(
1854 value_expr, &old_tuple, &schema, table_name,
1855 )?;
1856 let mut new_value = evaluator.evaluate(&bound, &old_tuple)?;
1857 let col_index = evaluator.schema().get_column_index(col_name)
1858 .ok_or_else(|| Error::query_execution(format!("Column '{}' not found", col_name)))?;
1859 let target_col = schema.get_column_at(col_index)
1861 .ok_or_else(|| Error::query_execution(format!("Column '{}' not found", col_name)))?;
1862 let target_type = &target_col.data_type;
1863 let needs_cast = !matches!(&new_value, Value::Null)
1864 && !matches!(
1865 (&new_value, target_type),
1866 (Value::Vector(_), DataType::Vector(_))
1867 | (Value::Int2(_), DataType::Int2)
1868 | (Value::Int4(_), DataType::Int4)
1869 | (Value::Int8(_), DataType::Int8)
1870 | (Value::Float4(_), DataType::Float4)
1871 | (Value::Float8(_), DataType::Float8)
1872 | (Value::String(_), DataType::Text | DataType::Varchar(_))
1873 | (Value::Boolean(_), DataType::Boolean)
1874 | (Value::Json(_), DataType::Json | DataType::Jsonb)
1875 | (Value::Timestamp(_), DataType::Timestamp | DataType::Timestamptz)
1876 | (Value::Date(_), DataType::Date)
1877 );
1878 if needs_cast {
1879 new_value = evaluator.cast_value(new_value, target_type)?;
1880 }
1881 *new_tuple.values.get_mut(col_index)
1882 .ok_or_else(|| Error::internal("column index out of bounds"))? = new_value;
1883 }
1884
1885 if has_triggers {
1887 let row_context = sql::triggers::TriggerRowContext::for_update(old_tuple.clone(), new_tuple.clone());
1888 let db_ref = self.clone_for_trigger();
1889 let mut executor_fn = |stmt: &sql::LogicalPlan, _ctx: &sql::triggers::TriggerRowContext| -> Result<()> {
1890 db_ref.execute_plan_internal(stmt)?;
1891 Ok(())
1892 };
1893
1894 let action = self.trigger_registry.execute_triggers(
1895 table_name,
1896 &trigger_event,
1897 &sql::logical_plan::TriggerTiming::Before,
1898 &row_context,
1899 &mut trigger_context,
1900 Some(evaluator.schema().clone()),
1901 &mut executor_fn,
1902 )?;
1903
1904 match action {
1906 sql::triggers::TriggerAction::Abort(msg) => {
1907 return Err(Error::query_execution(format!("UPDATE aborted by trigger: {}", msg)));
1908 }
1909 sql::triggers::TriggerAction::Skip => {
1910 continue;
1912 }
1913 sql::triggers::TriggerAction::Continue => {
1914 }
1916 }
1917 }
1918
1919 let mut new_col_values = std::collections::HashMap::with_capacity(schema.columns.len());
1923 for (i, col) in schema.columns.iter().enumerate() {
1924 if let Some(v) = new_tuple.values.get(i) {
1925 new_col_values.insert(col.name.clone(), v.clone());
1926 }
1927 }
1928 self.check_fk_constraints_on_write(table_name, &new_col_values, Some(txn))?;
1929
1930 let row_id = new_tuple.row_id.unwrap_or(0);
1931 updates.push((row_id, new_tuple.clone()));
1932
1933 if let Some(context) = self.tenant_manager.get_current_context() {
1935 let old_values = serde_json::to_string(&old_tuple.values)
1936 .unwrap_or_else(|_| "[]".to_string());
1937 let new_values = serde_json::to_string(&new_tuple.values)
1938 .unwrap_or_else(|_| "[]".to_string());
1939
1940 self.tenant_manager.record_change_event(
1941 crate::tenant::ChangeType::Update,
1942 table_name.to_string(),
1943 row_id.to_string(),
1944 Some(old_values),
1945 Some(new_values),
1946 context.tenant_id,
1947 None,
1948 );
1949 }
1950
1951 if has_triggers {
1953 let row_context = sql::triggers::TriggerRowContext::for_update(old_tuple.clone(), new_tuple.clone());
1954 let db_ref = self.clone_for_trigger();
1955 let mut executor_fn = |stmt: &sql::LogicalPlan, _ctx: &sql::triggers::TriggerRowContext| -> Result<()> {
1956 db_ref.execute_plan_internal(stmt)?;
1957 Ok(())
1958 };
1959 let action = self.trigger_registry.execute_triggers(
1960 table_name,
1961 &trigger_event,
1962 &sql::logical_plan::TriggerTiming::After,
1963 &row_context,
1964 &mut trigger_context,
1965 Some(evaluator.schema().clone()),
1966 &mut executor_fn,
1967 )?;
1968
1969 if let sql::triggers::TriggerAction::Abort(msg) = action {
1971 return Err(Error::query_execution(format!("UPDATE aborted by AFTER trigger: {}", msg)));
1972 }
1973 }
1974 }
1975 }
1976
1977 let update_count = updates.len() as u64;
1978 for (row_id, tuple) in &updates {
1982 let key = self.storage.branch_aware_data_key(table_name, *row_id);
1983 let value = bincode::serialize(tuple)
1984 .map_err(|e| Error::storage(format!("Failed to serialize tuple: {}", e)))?;
1985 txn.put(key.clone(), value.clone())?;
1986
1987 if !skip_fast_paths && self.storage.is_wal_enabled() {
1990 self.storage.log_data_update(table_name, &key, &value)?;
1991 }
1992
1993 self.storage.row_cache().invalidate(table_name, *row_id);
1995 }
1996
1997 if let Some(context) = self.tenant_manager.get_current_context() {
1999 let mut storage_delta: i64 = 0;
2001 for (_row_id, new_tuple) in &updates {
2002 let new_size = bincode::serialize(new_tuple)
2003 .map(|bytes| bytes.len() as i64)
2004 .unwrap_or(256);
2005 storage_delta += new_size;
2008 }
2009
2010 if let Some(current_quota) = self.tenant_manager.get_quota_tracking(context.tenant_id) {
2011 let new_storage = (current_quota.storage_bytes_used as i64 + storage_delta).max(0) as u64;
2012 if let Err(e) = self.tenant_manager.update_storage_usage(context.tenant_id, new_storage) {
2013 return Err(Error::query_execution(format!("Storage quota exceeded: {}", e)));
2014 }
2015 }
2016 }
2017
2018 let returned_tuples: Vec<Tuple> = if returning.is_some() {
2020 updates.iter()
2021 .filter_map(|(_, tuple)| Self::project_returning_columns(tuple, &schema, returning))
2022 .collect()
2023 } else {
2024 Vec::new()
2025 };
2026 let _ = returned_tuples; Ok(update_count)
2029 }
2030 sql::LogicalPlan::Delete { table_name, selection, returning } => {
2031 let catalog = self.storage.catalog();
2032 let schema = catalog.get_table_schema(table_name)?;
2033 let schema_arc = std::sync::Arc::new(schema);
2034 let eval_schema = std::sync::Arc::new(
2037 (*schema_arc).clone().with_source_table_name(table_name),
2038 );
2039 let evaluator = sql::Evaluator::with_parameters(
2040 eval_schema,
2041 vec![],
2042 );
2043
2044 let mut trigger_context = sql::TriggerContext::new();
2046 let trigger_event = sql::logical_plan::TriggerEvent::Delete;
2047 let has_triggers = self.trigger_registry.has_triggers_for_table(table_name);
2048
2049 let on_branch = self.storage.get_current_branch().is_some();
2054 let tuples = if !on_branch {
2055 if let Some(pk_value) = Self::try_extract_pk_value(selection.as_ref(), &schema_arc) {
2056 match self.storage.get_row_by_pk(table_name, &pk_value)? {
2057 Some(tuple) => vec![tuple],
2058 None => vec![],
2059 }
2060 } else {
2061 self.storage.scan_table_branch_aware(table_name)?
2062 }
2063 } else {
2064 self.storage.scan_table_branch_aware(table_name)?
2065 };
2066 let mut row_ids_to_delete: Vec<u64> = Vec::new();
2067 let mut deleted_tuples: Vec<(u64, Tuple)> = Vec::new();
2069
2070 let mut returned_tuples: Vec<Tuple> = Vec::new();
2072 let has_returning = returning.is_some();
2073
2074 for tuple in tuples {
2075 let matches = if let Some(predicate) = selection {
2076 let result = evaluator.evaluate(predicate, &tuple)?;
2077 match result {
2078 Value::Boolean(b) => b,
2079 _ => false,
2080 }
2081 } else {
2082 true
2083 };
2084
2085 if matches {
2086 if has_triggers {
2088 let row_context = sql::triggers::TriggerRowContext::for_delete(tuple.clone());
2089 let db_ref = self.clone_for_trigger();
2090 let mut executor_fn = |stmt: &sql::LogicalPlan, _ctx: &sql::triggers::TriggerRowContext| -> Result<()> {
2091 db_ref.execute_plan_internal(stmt)?;
2092 Ok(())
2093 };
2094
2095 let action = self.trigger_registry.execute_triggers(
2096 table_name,
2097 &trigger_event,
2098 &sql::logical_plan::TriggerTiming::Before,
2099 &row_context,
2100 &mut trigger_context,
2101 Some(evaluator.schema().clone()),
2102 &mut executor_fn,
2103 )?;
2104
2105 match action {
2107 sql::triggers::TriggerAction::Abort(msg) => {
2108 return Err(Error::query_execution(format!("DELETE aborted by trigger: {}", msg)));
2109 }
2110 sql::triggers::TriggerAction::Skip => {
2111 continue;
2113 }
2114 sql::triggers::TriggerAction::Continue => {
2115 }
2117 }
2118 }
2119
2120 if let Some(row_id) = tuple.row_id {
2121 let referencing_fks = catalog.get_referencing_fks(table_name)?;
2123 for fk in &referencing_fks {
2124 if fk.enforcement == sql::ConstraintEnforcement::Immediate {
2125 let ref_values: Vec<Value> = fk.references_columns.iter()
2127 .map(|col_name| {
2128 schema_arc.columns.iter()
2129 .position(|c| &c.name == col_name)
2130 .and_then(|idx| tuple.values.get(idx).cloned())
2131 .unwrap_or(Value::Null)
2132 })
2133 .collect();
2134
2135 let has_refs = self.check_referencing_rows_exist(
2142 &fk.table_name,
2143 &fk.columns,
2144 &ref_values,
2145 Some(txn),
2146 )?;
2147
2148 if has_refs {
2149 match fk.on_delete {
2150 sql::constraints::ReferentialAction::NoAction |
2151 sql::constraints::ReferentialAction::Restrict => {
2152 return Err(Error::constraint_violation(format!(
2153 "Foreign key constraint '{}' violated: cannot delete row from '{}' - referenced by '{}'",
2154 fk.name, table_name, fk.table_name
2155 )));
2156 }
2157 sql::constraints::ReferentialAction::Cascade => {
2158 self.cascade_delete_referencing_rows(
2160 &fk.table_name,
2161 &fk.columns,
2162 &ref_values,
2163 )?;
2164 }
2165 sql::constraints::ReferentialAction::SetNull => {
2166 self.set_null_referencing_rows(
2168 &fk.table_name,
2169 &fk.columns,
2170 &ref_values,
2171 )?;
2172 }
2173 sql::constraints::ReferentialAction::SetDefault => {
2174 return Err(Error::constraint_violation(format!(
2176 "Foreign key constraint '{}' with SET DEFAULT action: not implemented",
2177 fk.name
2178 )));
2179 }
2180 }
2181 }
2182 }
2183 }
2184
2185 row_ids_to_delete.push(row_id);
2186 deleted_tuples.push((row_id, tuple.clone()));
2187
2188 if has_returning {
2190 if let Some(projected) = Self::project_returning_columns(&tuple, &schema_arc, returning) {
2191 returned_tuples.push(projected);
2192 }
2193 }
2194
2195 if let Some(context) = self.tenant_manager.get_current_context() {
2197 let old_values = serde_json::to_string(&tuple.values)
2198 .unwrap_or_else(|_| "[]".to_string());
2199
2200 self.tenant_manager.record_change_event(
2201 crate::tenant::ChangeType::Delete,
2202 table_name.to_string(),
2203 row_id.to_string(),
2204 Some(old_values),
2205 None, context.tenant_id,
2207 None,
2208 );
2209 }
2210 }
2211
2212 if has_triggers {
2214 let row_context = sql::triggers::TriggerRowContext::for_delete(tuple.clone());
2215 let db_ref = self.clone_for_trigger();
2216 let mut executor_fn = |stmt: &sql::LogicalPlan, _ctx: &sql::triggers::TriggerRowContext| -> Result<()> {
2217 db_ref.execute_plan_internal(stmt)?;
2218 Ok(())
2219 };
2220 let action = self.trigger_registry.execute_triggers(
2221 table_name,
2222 &trigger_event,
2223 &sql::logical_plan::TriggerTiming::After,
2224 &row_context,
2225 &mut trigger_context,
2226 Some(evaluator.schema().clone()),
2227 &mut executor_fn,
2228 )?;
2229
2230 if let sql::triggers::TriggerAction::Abort(msg) = action {
2232 return Err(Error::query_execution(format!("DELETE aborted by AFTER trigger: {}", msg)));
2233 }
2234 }
2235 }
2236 }
2237
2238 let storage_reclaimed: u64 = if self.tenant_manager.get_current_context().is_some() {
2240 (row_ids_to_delete.len() as u64) * 256
2241 } else {
2242 0
2243 };
2244
2245 let delete_count = row_ids_to_delete.len() as u64;
2246 if let Some(branch_id) = self.storage.get_current_branch_id() {
2250 for row_id in &row_ids_to_delete {
2252 let delete_key = format!("bdel:{}:{}:{}", branch_id, table_name, row_id).into_bytes();
2253 txn.put(delete_key, vec![])?;
2254
2255 self.storage.row_cache().invalidate(table_name, *row_id);
2257 }
2258 } else {
2259 for row_id in &row_ids_to_delete {
2261 let key = format!("data:{}:{}", table_name, row_id).into_bytes();
2262 txn.delete(key.clone())?;
2263
2264 if !skip_fast_paths && self.storage.is_wal_enabled() {
2267 self.storage.log_data_delete(table_name, &key)?;
2268 }
2269
2270 self.storage.row_cache().invalidate(table_name, *row_id);
2272 }
2273 }
2274
2275 if let Some(context) = self.tenant_manager.get_current_context() {
2277 if let Some(current_quota) = self.tenant_manager.get_quota_tracking(context.tenant_id) {
2278 let new_storage = current_quota.storage_bytes_used.saturating_sub(storage_reclaimed);
2279 let _ = self.tenant_manager.update_storage_usage(context.tenant_id, new_storage);
2281 }
2282 }
2283 for (row_id, tuple) in &deleted_tuples {
2285 let mut col_values = std::collections::HashMap::new();
2286 for (i, col) in schema_arc.columns.iter().enumerate() {
2287 if let Some(v) = tuple.values.get(i) {
2288 col_values.insert(col.name.clone(), v.clone());
2289 }
2290 }
2291 if let Err(e) = self.storage.art_indexes().on_delete(table_name, *row_id, &col_values) {
2292 tracing::debug!("ART index delete for table '{}': {}", table_name, e);
2293 }
2294 }
2295
2296 let _ = returned_tuples; Ok(delete_count)
2299 }
2300 sql::LogicalPlan::CreateFunction { name, or_replace, params, return_type, body, language, volatility } => {
2301 let stored_func = sql::StoredFunction {
2303 name: name.clone(),
2304 or_replace: *or_replace,
2305 params: params.clone(),
2306 return_type: return_type.clone(),
2307 body: body.clone(),
2308 language: language.clone(),
2309 volatility: volatility.clone(),
2310 created_at: std::time::SystemTime::now()
2311 .duration_since(std::time::UNIX_EPOCH)
2312 .map(|d| d.as_millis() as u64)
2313 .unwrap_or(0),
2314 };
2315 self.function_registry.register_function(stored_func.clone())?;
2316
2317 if let Ok(definition) = bincode::serialize(&stored_func) {
2319 if let Err(e) = self.storage.log_create_function(name, &definition) {
2320 tracing::warn!("Failed to log CREATE FUNCTION to WAL: {}", e);
2321 }
2322 }
2323 Ok(0)
2324 }
2325 sql::LogicalPlan::CreateProcedure { name, or_replace, params, body, language } => {
2326 let stored_proc = sql::StoredProcedure {
2328 name: name.clone(),
2329 or_replace: *or_replace,
2330 params: params.clone(),
2331 body: body.clone(),
2332 language: language.clone(),
2333 created_at: std::time::SystemTime::now()
2334 .duration_since(std::time::UNIX_EPOCH)
2335 .map(|d| d.as_millis() as u64)
2336 .unwrap_or(0),
2337 };
2338 self.function_registry.register_procedure(stored_proc.clone())?;
2339
2340 if let Ok(definition) = bincode::serialize(&stored_proc) {
2342 if let Err(e) = self.storage.log_create_procedure(name, &definition) {
2343 tracing::warn!("Failed to log CREATE PROCEDURE to WAL: {}", e);
2344 }
2345 }
2346 Ok(0)
2347 }
2348 sql::LogicalPlan::DropFunction { name, if_exists } => {
2349 self.function_registry.drop_function(name, *if_exists)?;
2350
2351 if let Err(e) = self.storage.log_drop_function(name) {
2353 tracing::warn!("Failed to log DROP FUNCTION to WAL: {}", e);
2354 }
2355 Ok(0)
2356 }
2357 sql::LogicalPlan::DropProcedure { name, if_exists } => {
2358 self.function_registry.drop_procedure(name, *if_exists)?;
2359
2360 if let Err(e) = self.storage.log_drop_procedure(name) {
2362 tracing::warn!("Failed to log DROP PROCEDURE to WAL: {}", e);
2363 }
2364 Ok(0)
2365 }
2366 sql::LogicalPlan::CreateTrigger {
2367 name,
2368 table_name,
2369 timing,
2370 events,
2371 for_each,
2372 when_condition,
2373 body,
2374 if_not_exists,
2375 referencing,
2376 characteristics,
2377 trigger_type,
2378 from_constraint,
2379 } => {
2380 if let Ok(Some(_)) = self.trigger_registry.get_trigger(table_name, name) {
2382 if *if_not_exists {
2383 return Ok(0);
2384 } else {
2385 return Err(Error::query_execution(format!(
2386 "Trigger '{}' already exists on table '{}'",
2387 name, table_name
2388 )));
2389 }
2390 }
2391
2392 let definition = sql::triggers::TriggerDefinition {
2394 name: name.clone(),
2395 table_name: table_name.clone(),
2396 timing: timing.clone(),
2397 events: events.clone(),
2398 for_each: for_each.clone(),
2399 when_condition: when_condition.clone(),
2400 body: body.clone(),
2401 enabled: true,
2402 created_at: std::time::SystemTime::now()
2403 .duration_since(std::time::UNIX_EPOCH)
2404 .unwrap_or_default()
2405 .as_millis() as u64,
2406 referencing: referencing.clone(),
2407 characteristics: characteristics.clone(),
2408 trigger_type: trigger_type.clone(),
2409 from_constraint: from_constraint.clone(),
2410 };
2411
2412 self.trigger_registry.register_trigger(definition.clone())?;
2414
2415 if let Ok(serialized) = bincode::serialize(&definition) {
2417 if let Err(e) = self.storage.log_create_trigger(name, table_name, &serialized) {
2418 tracing::warn!("Failed to log CREATE TRIGGER to WAL: {}", e);
2419 }
2420 }
2421
2422 Ok(0)
2423 }
2424 sql::LogicalPlan::DropTrigger { name, table_name, if_exists } => {
2425 let tbl = table_name.as_ref().ok_or_else(|| {
2427 Error::query_execution("DROP TRIGGER requires ON <table_name> clause".to_string())
2428 })?;
2429
2430 let dropped = self.trigger_registry.drop_trigger(tbl, name)?;
2431
2432 if !dropped && !*if_exists {
2433 return Err(Error::query_execution(format!(
2434 "Trigger '{}' does not exist on table '{}'",
2435 name, tbl
2436 )));
2437 }
2438
2439 if let Err(e) = self.storage.log_drop_trigger(name, table_name.as_deref()) {
2441 tracing::warn!("Failed to log DROP TRIGGER to WAL: {}", e);
2442 }
2443
2444 Ok(0)
2445 }
2446 sql::LogicalPlan::Call { name, args } => {
2447 let schema = std::sync::Arc::new(Schema { columns: vec![] });
2449 let evaluator = sql::Evaluator::new(schema);
2450
2451 let arg_values: Vec<Value> = args.iter()
2453 .map(|expr| evaluator.evaluate(expr, &Tuple::new(vec![])))
2454 .collect::<Result<Vec<_>>>()?;
2455
2456 let db_clone = self.clone_for_trigger();
2458 let sql_executor = |sql: &str| -> Result<Vec<Vec<Value>>> {
2459 let sql_trimmed = sql.trim();
2461 if starts_with_icase(sql_trimmed, "SELECT") || starts_with_icase(sql_trimmed, "WITH") {
2462 let tuples = db_clone.query(sql, &[])?;
2463 Ok(tuples.iter().map(|t| t.values.clone()).collect())
2464 } else {
2465 db_clone.execute(sql)?;
2467 Ok(vec![])
2468 }
2469 };
2470
2471 self.function_registry.execute_procedure(name, &arg_values, sql_executor)?;
2472 Ok(0)
2473 }
2474 sql::LogicalPlan::AlterColumnStorage { table_name, column_name, storage_mode } => {
2475 let catalog = self.storage.catalog();
2479 let mut schema = catalog.get_table_schema(table_name)?;
2480
2481 let col_idx = schema.columns.iter()
2483 .position(|c| c.name == *column_name)
2484 .ok_or_else(|| Error::query_execution(format!(
2485 "Column '{}' not found in table '{}'", column_name, table_name
2486 )))?;
2487
2488 let col_ref = schema.get_column_at(col_idx)
2489 .ok_or_else(|| Error::internal("column index out of bounds"))?;
2490 let old_mode = col_ref.storage_mode;
2491 if old_mode == *storage_mode {
2492 return Ok(0);
2494 }
2495
2496 let column = col_ref.clone();
2498 let rows_migrated = self.storage.migrate_column_storage(
2499 table_name,
2500 col_idx,
2501 &column,
2502 old_mode,
2503 *storage_mode,
2504 )?;
2505
2506 schema.get_column_at_mut(col_idx)
2508 .ok_or_else(|| Error::internal("column index out of bounds"))?
2509 .storage_mode = *storage_mode;
2510 catalog.update_table_schema(table_name, &schema)?;
2511
2512 if let Err(e) = self.storage.log_alter_column_storage(table_name, column_name, storage_mode) {
2514 tracing::warn!("Failed to log ALTER COLUMN STORAGE to WAL: {}", e);
2515 }
2516
2517 tracing::info!(
2518 "Altered {}.{} storage from {:?} to {:?}, migrated {} rows",
2519 table_name, column_name, old_mode, storage_mode, rows_migrated
2520 );
2521
2522 Ok(rows_migrated as u64)
2523 }
2524 sql::LogicalPlan::AlterTableAddColumn { table_name, column_def, if_not_exists } => {
2525 let catalog = self.storage.catalog();
2526 let mut schema = catalog.get_table_schema(table_name)?;
2527
2528 if schema.columns.iter().any(|c| c.name == column_def.name) {
2530 if *if_not_exists {
2531 return Ok(0);
2532 }
2533 return Err(Error::query_execution(format!(
2534 "Column '{}' already exists in table '{}'", column_def.name, table_name
2535 )));
2536 }
2537
2538 let new_column = Column {
2540 name: column_def.name.clone(),
2541 data_type: column_def.data_type.clone(),
2542 nullable: !column_def.not_null,
2543 primary_key: column_def.primary_key,
2544 source_table: None,
2545 source_table_name: Some(table_name.clone()),
2546 default_expr: column_def.default.as_ref().map(|e| format!("{:?}", e)),
2547 unique: column_def.unique,
2548 storage_mode: column_def.storage_mode,
2549 };
2550
2551 schema.columns.push(new_column);
2553 catalog.update_table_schema(table_name, &schema)?;
2554
2555 let rows_updated = self.storage.add_column_to_rows(
2557 table_name,
2558 &column_def.default,
2559 )?;
2560
2561 tracing::info!(
2562 "Added column '{}' to table '{}', updated {} rows",
2563 column_def.name, table_name, rows_updated
2564 );
2565
2566 Ok(rows_updated as u64)
2567 }
2568 sql::LogicalPlan::AlterTableDropColumn { table_name, column_name, if_exists, cascade } => {
2569 let catalog = self.storage.catalog();
2570 let mut schema = catalog.get_table_schema(table_name)?;
2571
2572 let col_idx = schema.columns.iter()
2574 .position(|c| c.name == *column_name);
2575
2576 match col_idx {
2577 Some(idx) => {
2578 let is_pk = schema.get_column_at(idx)
2580 .ok_or_else(|| Error::internal("column index out of bounds"))?
2581 .primary_key;
2582 if is_pk && !cascade {
2583 return Err(Error::query_execution(format!(
2584 "Cannot drop primary key column '{}' without CASCADE", column_name
2585 )));
2586 }
2587
2588 schema.columns.remove(idx);
2590 catalog.update_table_schema(table_name, &schema)?;
2591
2592 let rows_updated = self.storage.drop_column_from_rows(table_name, idx)?;
2594
2595 tracing::info!(
2596 "Dropped column '{}' from table '{}', updated {} rows",
2597 column_name, table_name, rows_updated
2598 );
2599
2600 Ok(rows_updated as u64)
2601 }
2602 None => {
2603 if *if_exists {
2604 Ok(0)
2605 } else {
2606 Err(Error::query_execution(format!(
2607 "Column '{}' does not exist in table '{}'", column_name, table_name
2608 )))
2609 }
2610 }
2611 }
2612 }
2613 sql::LogicalPlan::AlterTableRenameColumn { table_name, old_column_name, new_column_name } => {
2614 let catalog = self.storage.catalog();
2615 let mut schema = catalog.get_table_schema(table_name)?;
2616
2617 if schema.columns.iter().any(|c| c.name == *new_column_name) {
2619 return Err(Error::query_execution(format!(
2620 "Column '{}' already exists in table '{}'", new_column_name, table_name
2621 )));
2622 }
2623
2624 let col_idx = schema.columns.iter()
2626 .position(|c| c.name == *old_column_name)
2627 .ok_or_else(|| Error::query_execution(format!(
2628 "Column '{}' does not exist in table '{}'", old_column_name, table_name
2629 )))?;
2630
2631 schema.get_column_at_mut(col_idx)
2632 .ok_or_else(|| Error::internal("column index out of bounds"))?
2633 .name = new_column_name.clone();
2634 catalog.update_table_schema(table_name, &schema)?;
2635
2636 tracing::info!(
2637 "Renamed column '{}' to '{}' in table '{}'",
2638 old_column_name, new_column_name, table_name
2639 );
2640
2641 Ok(0)
2642 }
2643 sql::LogicalPlan::AlterTableRename { table_name, new_table_name } => {
2644 let catalog = self.storage.catalog();
2645
2646 if catalog.get_table_schema(new_table_name).is_ok() {
2648 return Err(Error::query_execution(format!(
2649 "Table '{}' already exists", new_table_name
2650 )));
2651 }
2652
2653 self.storage.rename_table(table_name, new_table_name)?;
2655
2656 tracing::info!(
2657 "Renamed table '{}' to '{}'",
2658 table_name, new_table_name
2659 );
2660
2661 Ok(0)
2662 }
2663 sql::LogicalPlan::AlterTableAddForeignKey {
2664 table_name,
2665 constraint_name,
2666 columns,
2667 references_table,
2668 references_columns,
2669 on_delete,
2670 on_update,
2671 deferrable,
2672 initially_deferred,
2673 } => {
2674 let catalog = self.storage.catalog();
2676 catalog.get_table_schema(table_name)?;
2677 catalog.get_table_schema(references_table)?;
2678
2679 let fk_name = constraint_name
2680 .clone()
2681 .unwrap_or_else(|| sql::ForeignKeyConstraint::generate_name(
2682 table_name, columns, references_table,
2683 ));
2684 let mut fk = sql::ForeignKeyConstraint::new(
2685 fk_name,
2686 table_name.clone(),
2687 columns.clone(),
2688 references_table.clone(),
2689 references_columns.clone(),
2690 );
2691 if let Some(action) = on_delete {
2692 fk = fk.on_delete(convert_logical_referential_action(action));
2693 }
2694 if let Some(action) = on_update {
2695 fk = fk.on_update(convert_logical_referential_action(action));
2696 }
2697 if *deferrable {
2698 fk = fk.deferrable(*initially_deferred);
2699 }
2700 catalog.add_foreign_key(fk)?;
2701 Ok(0)
2702 }
2703 sql::LogicalPlan::AlterTableMulti { operations } => {
2704 let mut total_rows = 0u64;
2705 for sub_plan in operations {
2706 total_rows += self.execute_alter_table_op(sub_plan)?;
2707 }
2708 Ok(total_rows)
2709 }
2710 sql::LogicalPlan::Savepoint { ref name } => {
2711 let write_set_snapshot = txn.savepoint_snapshot();
2712 let savepoint = SavepointState {
2713 name: name.clone(),
2714 write_set_snapshot,
2715 };
2716 self.savepoints.write().push(savepoint);
2717 Ok(0)
2718 }
2719 sql::LogicalPlan::ReleaseSavepoint { ref name } => {
2720 let mut savepoints = self.savepoints.write();
2721 if let Some(pos) = savepoints.iter().rposition(|s| &s.name == name) {
2722 savepoints.truncate(pos);
2723 Ok(0)
2724 } else {
2725 Err(Error::query_execution(format!("Savepoint '{}' does not exist", name)))
2726 }
2727 }
2728 sql::LogicalPlan::RollbackToSavepoint { ref name } => {
2729 let savepoints = self.savepoints.read();
2730 if let Some(pos) = savepoints.iter().rposition(|s| &s.name == name) {
2731 let snapshot = savepoints.get(pos)
2732 .map(|s| s.write_set_snapshot.clone());
2733 drop(savepoints);
2734 if let Some(snapshot) = snapshot {
2735 txn.rollback_to_savepoint(&snapshot);
2736 }
2737 let mut savepoints = self.savepoints.write();
2738 savepoints.truncate(pos + 1);
2739 Ok(0)
2740 } else {
2741 Err(Error::query_execution(format!("Savepoint '{}' does not exist", name)))
2742 }
2743 }
2744 sql::LogicalPlan::Truncate { ref table_name } => {
2745 let catalog = self.storage.catalog();
2748 let _schema = catalog.get_table_schema(table_name)?;
2749 let rows = self.storage.scan_table(table_name)?;
2750 let mut count = 0u64;
2751 for tuple in &rows {
2752 if let Some(row_id) = tuple.row_id {
2753 let key = format!("data:{}:{}", table_name, row_id).into_bytes();
2754 txn.delete(key)?;
2755 self.storage.row_cache().invalidate(table_name, row_id);
2757 count += 1;
2758 }
2759 }
2760 self.storage.art_indexes().clear_table_indexes(table_name);
2762 Ok(count)
2763 }
2764 sql::LogicalPlan::CreateDatabase { name, if_not_exists } => {
2765 let (count, _) = self.handle_create_database(name, *if_not_exists)?;
2766 Ok(count)
2767 }
2768 sql::LogicalPlan::DropDatabase { name, if_exists } => {
2769 let (count, _) = self.handle_drop_database(name, *if_exists)?;
2770 Ok(count)
2771 }
2772 _ => {
2773 let mut executor = sql::Executor::with_storage(&self.storage)
2776 .with_timeout(self.config.storage.query_timeout_ms)
2777 .with_transaction(txn);
2778 let results = executor.execute(&plan)?;
2779 let is_select = matches!(plan,
2781 sql::LogicalPlan::Scan { .. } |
2782 sql::LogicalPlan::Filter { .. } |
2783 sql::LogicalPlan::Project { .. } |
2784 sql::LogicalPlan::Aggregate { .. } |
2785 sql::LogicalPlan::Join { .. } |
2786 sql::LogicalPlan::Sort { .. } |
2787 sql::LogicalPlan::Limit { .. } |
2788 sql::LogicalPlan::With { .. } |
2789 sql::LogicalPlan::TableFunction { .. } |
2790 sql::LogicalPlan::SystemView { .. }
2791 );
2792 let _ = is_select; Ok(results.len() as u64)
2794 }
2795 }
2796 }
2797
2798 #[allow(clippy::expect_used)] pub fn new(path: impl AsRef<std::path::Path>) -> Result<Self> {
2816 let config = Config::default();
2817 let storage = std::sync::Arc::new(storage::StorageEngine::open(path.as_ref(), &config)?);
2818 let mv_scheduler = std::sync::Arc::new(storage::MVScheduler::new(
2819 storage::SchedulerConfig::default(),
2820 std::sync::Arc::clone(&storage),
2821 ));
2822
2823 let dump_manager = std::sync::Arc::new(storage::DumpManager::new(
2824 path.as_ref().to_path_buf(),
2825 storage::DumpCompressionType::Zstd,
2826 ));
2827
2828 let session_manager = std::sync::Arc::new(crate::session::SessionManager::new());
2829 let lock_manager = std::sync::Arc::new(storage::LockManager::with_default_timeout());
2830 let dirty_tracker = std::sync::Arc::new(storage::DirtyTracker::new());
2831
2832 let catalog = storage::Catalog::new(&storage);
2839 if let Err(e) = catalog.rebuild_all_indexes() {
2840 tracing::warn!("ART rebuild on open failed: {} — falling back to scan paths", e);
2841 }
2842
2843 Ok(Self {
2844 storage,
2845 config,
2846 current_transaction: std::sync::Arc::new(std::sync::Mutex::new(None)),
2847 tenant_manager: std::sync::Arc::new(crate::tenant::TenantManager::new()),
2848 trigger_registry: std::sync::Arc::new(sql::TriggerRegistry::new()),
2849 function_registry: std::sync::Arc::new(sql::FunctionRegistry::new()),
2850 mv_scheduler,
2851 auto_refresh_worker: std::sync::Arc::new(parking_lot::RwLock::new(None)),
2852 dump_manager,
2853 session_manager,
2854 lock_manager,
2855 dirty_tracker,
2856 session_transactions: std::sync::Arc::new(dashmap::DashMap::new()),
2857 prepared_statements: std::sync::Arc::new(parking_lot::RwLock::new(std::collections::HashMap::new())),
2858 savepoints: std::sync::Arc::new(parking_lot::RwLock::new(Vec::new())),
2859 plan_cache: std::sync::Arc::new(std::sync::Mutex::new(lru::LruCache::new(std::num::NonZeroUsize::new(256).expect("256 is non-zero")))),
2860 parse_cache: std::sync::Arc::new(std::sync::Mutex::new(lru::LruCache::new(std::num::NonZeroUsize::new(512).expect("512 is non-zero")))),
2861 result_cache: std::sync::Arc::new(std::sync::Mutex::new(lru::LruCache::new(std::num::NonZeroUsize::new(128).expect("128 is non-zero")))),
2862 art_undo_log: std::sync::Arc::new(parking_lot::RwLock::new(Vec::new())),
2863 })
2864 }
2865
2866 #[allow(clippy::expect_used)] pub fn new_in_memory() -> Result<Self> {
2882 let config = Config::in_memory();
2883 let storage = std::sync::Arc::new(storage::StorageEngine::open_in_memory(&config)?);
2884 let mv_scheduler = std::sync::Arc::new(storage::MVScheduler::new(
2885 storage::SchedulerConfig::default(),
2886 std::sync::Arc::clone(&storage),
2887 ));
2888
2889 let dump_path = std::env::temp_dir().join("heliosdb_dumps");
2891 let dump_manager = std::sync::Arc::new(storage::DumpManager::new(
2892 dump_path,
2893 storage::DumpCompressionType::Zstd,
2894 ));
2895
2896 let session_manager = std::sync::Arc::new(crate::session::SessionManager::new());
2897 let lock_manager = std::sync::Arc::new(storage::LockManager::with_default_timeout());
2898 let dirty_tracker = std::sync::Arc::new(storage::DirtyTracker::new());
2899
2900 Ok(Self {
2901 storage,
2902 config,
2903 current_transaction: std::sync::Arc::new(std::sync::Mutex::new(None)),
2904 tenant_manager: std::sync::Arc::new(crate::tenant::TenantManager::new()),
2905 trigger_registry: std::sync::Arc::new(sql::TriggerRegistry::new()),
2906 function_registry: std::sync::Arc::new(sql::FunctionRegistry::new()),
2907 mv_scheduler,
2908 auto_refresh_worker: std::sync::Arc::new(parking_lot::RwLock::new(None)),
2909 dump_manager,
2910 session_manager,
2911 lock_manager,
2912 dirty_tracker,
2913 session_transactions: std::sync::Arc::new(dashmap::DashMap::new()),
2914 prepared_statements: std::sync::Arc::new(parking_lot::RwLock::new(std::collections::HashMap::new())),
2915 savepoints: std::sync::Arc::new(parking_lot::RwLock::new(Vec::new())),
2916 plan_cache: std::sync::Arc::new(std::sync::Mutex::new(lru::LruCache::new(std::num::NonZeroUsize::new(256).expect("256 is non-zero")))),
2917 parse_cache: std::sync::Arc::new(std::sync::Mutex::new(lru::LruCache::new(std::num::NonZeroUsize::new(512).expect("512 is non-zero")))),
2918 result_cache: std::sync::Arc::new(std::sync::Mutex::new(lru::LruCache::new(std::num::NonZeroUsize::new(128).expect("128 is non-zero")))),
2919 art_undo_log: std::sync::Arc::new(parking_lot::RwLock::new(Vec::new())),
2920 })
2921 }
2922
2923 #[allow(clippy::expect_used)] pub fn with_config(config: Config) -> Result<Self> {
2939 let memory_only = config.storage.memory_only;
2940 let storage = std::sync::Arc::new(if memory_only {
2941 storage::StorageEngine::open_in_memory(&config)?
2942 } else {
2943 let path = config.storage.path.as_ref()
2944 .ok_or_else(|| Error::config("Storage path not specified for non-memory database".to_string()))?;
2945 storage::StorageEngine::open(path, &config)?
2946 });
2947 let mv_scheduler = std::sync::Arc::new(storage::MVScheduler::new(
2948 storage::SchedulerConfig::default(),
2949 std::sync::Arc::clone(&storage),
2950 ));
2951
2952 let dump_path = if let Some(ref p) = config.storage.path {
2953 p.clone()
2954 } else {
2955 std::env::temp_dir().join("heliosdb_dumps")
2956 };
2957
2958 let dump_manager = std::sync::Arc::new(storage::DumpManager::new(
2959 dump_path,
2960 storage::DumpCompressionType::Zstd,
2961 ));
2962
2963 let session_manager = std::sync::Arc::new(crate::session::SessionManager::new());
2964 let lock_manager = std::sync::Arc::new(storage::LockManager::with_default_timeout());
2965 let dirty_tracker = std::sync::Arc::new(storage::DirtyTracker::new());
2966
2967 if !memory_only {
2971 let catalog = storage::Catalog::new(&storage);
2972 if let Err(e) = catalog.rebuild_all_indexes() {
2973 tracing::warn!("ART rebuild on open failed: {} — falling back to scan paths", e);
2974 }
2975 }
2976
2977 Ok(Self {
2978 storage,
2979 config,
2980 current_transaction: std::sync::Arc::new(std::sync::Mutex::new(None)),
2981 tenant_manager: std::sync::Arc::new(crate::tenant::TenantManager::new()),
2982 trigger_registry: std::sync::Arc::new(sql::TriggerRegistry::new()),
2983 function_registry: std::sync::Arc::new(sql::FunctionRegistry::new()),
2984 mv_scheduler,
2985 auto_refresh_worker: std::sync::Arc::new(parking_lot::RwLock::new(None)),
2986 dump_manager,
2987 session_manager,
2988 lock_manager,
2989 dirty_tracker,
2990 session_transactions: std::sync::Arc::new(dashmap::DashMap::new()),
2991 prepared_statements: std::sync::Arc::new(parking_lot::RwLock::new(std::collections::HashMap::new())),
2992 savepoints: std::sync::Arc::new(parking_lot::RwLock::new(Vec::new())),
2993 plan_cache: std::sync::Arc::new(std::sync::Mutex::new(lru::LruCache::new(std::num::NonZeroUsize::new(256).expect("256 is non-zero")))),
2994 parse_cache: std::sync::Arc::new(std::sync::Mutex::new(lru::LruCache::new(std::num::NonZeroUsize::new(512).expect("512 is non-zero")))),
2995 result_cache: std::sync::Arc::new(std::sync::Mutex::new(lru::LruCache::new(std::num::NonZeroUsize::new(128).expect("128 is non-zero")))),
2996 art_undo_log: std::sync::Arc::new(parking_lot::RwLock::new(Vec::new())),
2997 })
2998 }
2999
/// Returns the per-query timeout from `config.storage`, in milliseconds.
///
/// This is the same value the generic plan path hands to
/// `sql::Executor::with_timeout`. NOTE(review): `None` presumably means
/// "no timeout" — confirm in the executor.
pub fn query_timeout_ms(&self) -> Option<u64> {
    self.config.storage.query_timeout_ms
}
3048
3049 fn plan_contains_join(plan: &sql::LogicalPlan) -> bool {
3051 match plan {
3052 sql::LogicalPlan::Join { .. } => true,
3053 sql::LogicalPlan::Filter { input, .. }
3054 | sql::LogicalPlan::Project { input, .. }
3055 | sql::LogicalPlan::Sort { input, .. }
3056 | sql::LogicalPlan::Limit { input, .. }
3057 | sql::LogicalPlan::Aggregate { input, .. } => Self::plan_contains_join(input),
3058 _ => false,
3059 }
3060 }
3061
3062 fn log_slow_query(&self, sql: &str, elapsed: std::time::Duration, rows: u64) {
3064 if let Some(threshold) = self.config.storage.slow_query_threshold_ms {
3065 let elapsed_ms = elapsed.as_millis() as u64;
3066 if elapsed_ms >= threshold {
3067 tracing::warn!(
3068 duration_ms = elapsed_ms,
3069 rows = rows,
3070 "Slow query ({}ms, {} rows): {:.200}",
3071 elapsed_ms,
3072 rows,
3073 sql
3074 );
3075 }
3076 }
3077 }
3078
3079 pub fn execute_batch(&self, statements: &[&str]) -> Result<u64> {
3089 let start = std::time::Instant::now();
3090
3091 let txn_start = std::time::Instant::now();
3092 let txn = self.storage.begin_transaction()?;
3093 tracing::trace!(phase = "txn_begin", duration_us = txn_start.elapsed().as_micros() as u64, "Batch transaction started");
3094
3095 let mut total_rows = 0u64;
3096 for sql in statements {
3097 match self.execute_in_transaction(sql, &txn) {
3098 Ok(count) => total_rows += count,
3099 Err(e) => {
3100 let _ = txn.rollback();
3101 return Err(e);
3102 }
3103 }
3104 }
3105
3106 let commit_start = std::time::Instant::now();
3107 txn.commit()?;
3108 self.storage.increment_lsn();
3109 tracing::debug!(phase = "txn_commit", duration_us = commit_start.elapsed().as_micros() as u64, rows = total_rows, "Batch transaction committed");
3110
3111 let elapsed = start.elapsed();
3112 tracing::debug!(phase = "execute", duration_us = elapsed.as_micros() as u64, "Batch executed ({} statements)", statements.len());
3113
3114 Ok(total_rows)
3115 }
3116
/// Registers a tree-sitter `grammar` under `name` for code-graph parsing.
///
/// Thin delegate to `code_graph::parse::register_grammar`. No database
/// handle is passed, so the registry is not tied to this instance. The
/// returned `Option` is passed through unchanged. NOTE(review): it
/// presumably holds the grammar previously registered under `name` — confirm.
#[cfg(feature = "code-graph")]
pub fn register_grammar(
    &self,
    name: impl Into<String>,
    grammar: tree_sitter::Language,
) -> Option<tree_sitter::Language> {
    code_graph::parse::register_grammar(name, grammar)
}

/// Removes the grammar registered under `name`.
///
/// Thin delegate to `code_graph::parse::unregister_grammar`; its result is
/// returned unchanged (presumably the removed grammar, if any — confirm).
#[cfg(feature = "code-graph")]
pub fn unregister_grammar(&self, name: &str) -> Option<tree_sitter::Language> {
    code_graph::parse::unregister_grammar(name)
}

/// Lists the grammar names reported by `code_graph::parse::registered_grammars`.
#[cfg(feature = "code-graph")]
pub fn registered_grammars(&self) -> Vec<String> {
    code_graph::parse::registered_grammars()
}

/// Registers a symbol `extractor` under `name`.
///
/// Thin delegate to `code_graph::register_extractor`. No database handle is
/// passed, so the registry is not tied to this instance. The returned
/// `Option` is passed through unchanged (presumably the extractor previously
/// registered under `name` — confirm).
#[cfg(feature = "code-graph")]
pub fn register_extractor(
    &self,
    name: impl Into<String>,
    extractor: std::sync::Arc<dyn code_graph::SymbolExtractor>,
) -> Option<std::sync::Arc<dyn code_graph::SymbolExtractor>> {
    code_graph::register_extractor(name, extractor)
}

/// Removes the symbol extractor registered under `name`.
///
/// Thin delegate to `code_graph::unregister_extractor`; its result is
/// returned unchanged.
#[cfg(feature = "code-graph")]
pub fn unregister_extractor(
    &self,
    name: &str,
) -> Option<std::sync::Arc<dyn code_graph::SymbolExtractor>> {
    code_graph::unregister_extractor(name)
}

/// Lists the extractor names reported by `code_graph::registered_extractors`.
#[cfg(feature = "code-graph")]
pub fn registered_extractors(&self) -> Vec<String> {
    code_graph::registered_extractors()
}
3210
/// Indexes source rows into the code graph via `code_graph::storage::code_index`
/// and returns its statistics.
///
/// When the `graph-rag` feature is also enabled, code symbols are then
/// projected with `graph_rag_project_symbols`. That step is best-effort:
/// a failure is logged and does not fail the indexing call.
///
/// # Errors
/// Returns any error from `code_graph::storage::code_index`.
#[cfg(feature = "code-graph")]
pub fn code_index(
    &self,
    opts: code_graph::CodeIndexOptions,
) -> Result<code_graph::CodeIndexStats> {
    let stats = code_graph::storage::code_index(self, opts)?;
    #[cfg(feature = "graph-rag")]
    {
        // Previously discarded with `let _ =`; keep it best-effort but visible.
        if let Err(e) = self.graph_rag_project_symbols() {
            tracing::warn!("Graph-RAG symbol projection after code_index failed: {}", e);
        }
    }
    Ok(stats)
}
3227
/// Looks up definition rows for `name`, passing `hint` through.
///
/// Thin delegate to `code_graph::lsp::lsp_definition`.
#[cfg(feature = "code-graph")]
pub fn lsp_definition(
    &self,
    name: &str,
    hint: &code_graph::DefinitionHint,
) -> Result<Vec<code_graph::DefinitionRow>> {
    code_graph::lsp::lsp_definition(self, name, hint)
}

/// Returns reference rows for the symbol identified by `symbol_id`.
///
/// Thin delegate to `code_graph::lsp::lsp_references`.
#[cfg(feature = "code-graph")]
pub fn lsp_references(
    &self,
    symbol_id: i64,
) -> Result<Vec<code_graph::ReferenceRow>> {
    code_graph::lsp::lsp_references(self, symbol_id)
}

/// Returns call-hierarchy rows for `symbol_id` in the given `direction`,
/// up to `depth` levels.
///
/// Thin delegate to `code_graph::lsp::lsp_call_hierarchy`. NOTE(review):
/// how `depth == 0` is treated is decided there — confirm.
#[cfg(feature = "code-graph")]
pub fn lsp_call_hierarchy(
    &self,
    symbol_id: i64,
    direction: code_graph::lsp::CallDirection,
    depth: u32,
) -> Result<Vec<code_graph::lsp::CallHierarchyRow>> {
    code_graph::lsp::lsp_call_hierarchy(self, symbol_id, direction, depth)
}

/// Returns hover information for `symbol_id`, or `None` when the delegate
/// finds nothing.
///
/// Thin delegate to `code_graph::lsp::lsp_hover`.
#[cfg(feature = "code-graph")]
pub fn lsp_hover(&self, symbol_id: i64) -> Result<Option<code_graph::HoverRow>> {
    code_graph::lsp::lsp_hover(self, symbol_id)
}

/// Compares the references of `symbol_id` between two points in history
/// (`at_a`, `at_b`).
///
/// Thin delegate to `code_graph::diff::lsp_references_diff`.
#[cfg(feature = "code-graph")]
pub fn lsp_references_diff(
    &self,
    symbol_id: i64,
    at_a: &code_graph::AsOfRef,
    at_b: &code_graph::AsOfRef,
) -> Result<Vec<code_graph::RefDiffRow>> {
    code_graph::diff::lsp_references_diff(self, symbol_id, at_a, at_b)
}

/// Produces a line-level body diff for `symbol_id` between `at_a` and `at_b`.
///
/// Thin delegate to `code_graph::diff::lsp_body_diff`.
#[cfg(feature = "code-graph")]
pub fn lsp_body_diff(
    &self,
    symbol_id: i64,
    at_a: &code_graph::AsOfRef,
    at_b: &code_graph::AsOfRef,
) -> Result<Vec<code_graph::BodyDiffLine>> {
    code_graph::diff::lsp_body_diff(self, symbol_id, at_a, at_b)
}

/// Produces AST diff rows for `file_path` between `at_a` and `at_b`.
///
/// Thin delegate to `code_graph::diff::ast_diff`.
#[cfg(feature = "code-graph")]
pub fn ast_diff(
    &self,
    file_path: &str,
    at_a: &code_graph::AsOfRef,
    at_b: &code_graph::AsOfRef,
) -> Result<Vec<code_graph::AstDiffRow>> {
    code_graph::diff::ast_diff(self, file_path, at_a, at_b)
}

/// Refreshes the code-graph Merkle hashes and returns the refresh statistics.
///
/// Thin delegate to `code_graph::merkle_refresh`; `CREATE SEMANTIC HASH
/// INDEX` uses the same call.
#[cfg(feature = "code-graph")]
pub fn code_graph_merkle_refresh(
    &self,
) -> Result<code_graph::MerkleStats> {
    code_graph::merkle_refresh(self)
}

/// Renames the symbol `symbol_id` to `new_name`, controlled by `opts`, and
/// returns the resulting statistics.
///
/// Thin delegate to `code_graph::rename_apply`. NOTE(review): whether this
/// rewrites stored source rows or only reports edits depends on `opts` and
/// the delegate — confirm before relying on side effects.
#[cfg(feature = "code-graph")]
pub fn lsp_rename_apply(
    &self,
    symbol_id: i64,
    new_name: &str,
    opts: &code_graph::RenameApplyOptions,
) -> Result<code_graph::RenameApplyStats> {
    code_graph::rename_apply(self, symbol_id, new_name, opts)
}
3329
/// Expands `inner_sql` with Graph-RAG context and turns each hit into a row.
///
/// Each row has six columns:
/// `(node_id, node_kind, title, text, source_ref, hop_distance)`.
/// Missing optional fields become `Value::Null`, and `hop_distance` is cast
/// to `Int4`.
#[cfg(feature = "graph-rag")]
fn run_with_context(
    &self,
    inner_sql: &str,
    opts: &graph_rag::WithContextOptions,
) -> Result<Vec<Tuple>> {
    let hits = graph_rag::graph_rag_expand_with_context(self, inner_sql, opts)?;
    let mut rows = Vec::new();
    for hit in hits {
        let title = hit.title.map_or(Value::Null, Value::String);
        let text = hit.text.map_or(Value::Null, Value::String);
        let source_ref = hit.source_ref.map_or(Value::Null, Value::String);
        rows.push(Tuple::new(vec![
            Value::Int8(hit.node_id),
            Value::String(hit.node_kind),
            title,
            text,
            source_ref,
            Value::Int4(hit.hop_distance as i32),
        ]));
    }
    Ok(rows)
}
3355
3356 #[inline]
3361 fn maybe_rewrite_code_graph<'a>(&self, sql: &'a str) -> std::borrow::Cow<'a, str> {
3362 #[cfg(feature = "code-graph")]
3363 {
3364 let rewritten = code_graph::rewrite_lsp_calls(sql);
3365 if rewritten != sql {
3366 return std::borrow::Cow::Owned(rewritten);
3367 }
3368 }
3369 std::borrow::Cow::Borrowed(sql)
3370 }
3371
/// Runs the full code-graph SQL rewrite on `sql`.
///
/// Returns the rewritten SQL together with a branch guard. If the rewrite
/// requested a branch override, the guard comes from
/// `CodeGraphBranchGuard::switch_to(self, target)`; otherwise it is a no-op
/// guard. NOTE(review): the guard presumably restores the previous branch on
/// drop — confirm in `CodeGraphBranchGuard`.
#[cfg(feature = "code-graph")]
fn rewrite_and_scope(&self, sql: &str) -> (String, CodeGraphBranchGuard<'_>) {
    let rewrite = code_graph::rewrite_lsp_calls_full(sql);
    let guard = if let Some(target) = rewrite.branch_override {
        CodeGraphBranchGuard::switch_to(self, target)
    } else {
        CodeGraphBranchGuard::noop()
    };
    (rewrite.sql, guard)
}
3390
/// Handles `CREATE AST INDEX`: registers the index metadata, then runs an
/// initial `code_index` pass over the source table.
///
/// If an index with the same name already exists:
/// - with `IF NOT EXISTS`, the statement is a no-op. The existing
///   registration, including its paused state, is left untouched and no
///   re-index runs.
/// - otherwise it is an error.
///
/// # Errors
/// Returns a query-execution error for a duplicate name without
/// `IF NOT EXISTS`, or any error from the initial `code_index` pass.
#[cfg(feature = "code-graph")]
fn handle_create_ast_index(
    &self,
    ddl: code_graph::AstIndexDdl,
) -> Result<u64> {
    if code_graph::storage::get_ast_index(&ddl.index_name).is_some() {
        if ddl.if_not_exists {
            // Previously this fell through and re-registered the index,
            // resetting `paused` and triggering a full re-index.
            return Ok(0);
        }
        return Err(Error::query_execution(format!(
            "AST index '{}' already exists",
            ddl.index_name
        )));
    }
    let meta = code_graph::AstIndexMeta {
        index_name: ddl.index_name.clone(),
        table: ddl.table.clone(),
        content_col: ddl.content_col,
        lang_col: ddl.lang_col,
        embed_endpoint: ddl.embed_endpoint.clone(),
        embed_bearer: ddl.embed_bearer.clone(),
        embed_bodies: ddl.embed_bodies,
        auto_reparse: ddl.auto_reparse,
        resolve_cross_file: ddl.resolve_cross_file,
        paused: false,
    };
    code_graph::register_ast_index(meta.clone());
    let opts = code_graph::CodeIndexOptions {
        source_table: ddl.table,
        embed_bodies: meta.embed_bodies,
        embed_endpoint: meta.embed_endpoint,
        embed_bearer: meta.embed_bearer,
        force_reparse: false,
        parallelism: None,
        chunk_size: None,
    };
    self.code_index(opts)?;
    Ok(0)
}
3432
/// Handles `CREATE SEMANTIC HASH INDEX` by running a code-graph Merkle
/// refresh and logging its statistics at `info` level.
///
/// Returns the refresh's `files_hashed` count.
///
/// # Errors
/// Returns any error from `code_graph::merkle_refresh`.
#[cfg(feature = "code-graph")]
fn handle_create_semantic_hash_index(
    &self,
    ddl: code_graph::SemanticHashIndexDdl,
) -> Result<u64> {
    // `IF NOT EXISTS` is not consulted: every CREATE performs the refresh.
    let _ = ddl.if_not_exists;
    let refreshed = code_graph::merkle_refresh(self)?;
    tracing::info!(
        index = %ddl.index_name,
        files_hashed = refreshed.files_hashed,
        files_unchanged = refreshed.files_unchanged,
        symbols_hashed = refreshed.symbols_hashed,
        "CREATE SEMANTIC HASH INDEX completed"
    );
    Ok(refreshed.files_hashed)
}
3460
/// Handles `PAUSE` / `RESUME` for an AST index by flipping its paused flag.
///
/// # Errors
/// Returns a query-execution error when no AST index with that name is
/// registered.
#[cfg(feature = "code-graph")]
fn handle_pause_resume(
    &self,
    pr: code_graph::PauseResume,
) -> Result<u64> {
    let paused = matches!(pr, code_graph::PauseResume::Pause(_));
    let (code_graph::PauseResume::Pause(name) | code_graph::PauseResume::Resume(name)) = pr;
    if code_graph::storage::set_ast_index_paused(&name, paused) {
        Ok(0)
    } else {
        Err(Error::query_execution(format!(
            "AST index '{name}' is not registered"
        )))
    }
}
3480
/// After a DML statement touches `touched_table`, re-runs `code_index` for
/// every AST index on that table that has `auto_reparse` enabled and is not
/// paused.
///
/// The step is best-effort. A failed re-index is logged and never returned,
/// so the triggering statement is not affected. Does nothing when
/// `touched_table` is `None`.
#[cfg(feature = "code-graph")]
fn maybe_auto_reparse(&self, touched_table: Option<&str>) {
    let Some(tbl) = touched_table else { return };
    for idx in code_graph::storage::ast_indexes_for_table(tbl) {
        // A PAUSEd index must not be re-indexed until it is RESUMEd.
        if !idx.auto_reparse || idx.paused {
            continue;
        }
        let opts = code_graph::CodeIndexOptions {
            source_table: idx.table.clone(),
            embed_bodies: idx.embed_bodies,
            embed_endpoint: idx.embed_endpoint.clone(),
            embed_bearer: idx.embed_bearer.clone(),
            force_reparse: false,
            parallelism: None,
            chunk_size: None,
        };
        if let Err(e) = self.code_index(opts) {
            tracing::warn!("AST auto-reparse of table '{}' failed: {}", idx.table, e);
        }
    }
}
3505
/// Best-effort extraction of the target table of an `INSERT INTO`, `UPDATE`
/// or `DELETE FROM` statement, used to decide which AST indexes to re-parse.
///
/// Keywords are matched case-insensitively. Any amount of whitespace
/// (including newlines) may separate them, and each keyword must end at a
/// word boundary. Returns `None` for any other statement shape. The table
/// name comes from `take_ident`, so for `schema.table` only the leading
/// `schema` segment is returned.
#[cfg(feature = "code-graph")]
fn touched_table_from_sql(sql: &str) -> Option<String> {
    let s = sql.trim_start();
    let rest = if let Some(after) = Self::strip_sql_keyword(s, "insert") {
        Self::strip_sql_keyword(after, "into")?
    } else if let Some(after) = Self::strip_sql_keyword(s, "update") {
        after
    } else if let Some(after) = Self::strip_sql_keyword(s, "delete") {
        Self::strip_sql_keyword(after, "from")?
    } else {
        return None;
    };
    Some(Self::take_ident(rest))
}

/// If `s` starts with `keyword` (ASCII case-insensitive) followed by a word
/// boundary, returns the remainder with leading whitespace removed.
/// Otherwise returns `None`.
#[cfg(feature = "code-graph")]
fn strip_sql_keyword<'a>(s: &'a str, keyword: &str) -> Option<&'a str> {
    let head = s.get(..keyword.len())?;
    if !head.eq_ignore_ascii_case(keyword) {
        return None;
    }
    let rest = s.get(keyword.len()..)?;
    // Reject prefixes of longer words, e.g. `updated_at` for `update`.
    if rest.starts_with(|c: char| c.is_alphanumeric() || c == '_') {
        return None;
    }
    Some(rest.trim_start())
}
3527
/// Extracts the leading SQL identifier from `s`.
///
/// - A double-quoted identifier yields the text between the opening quote
///   and the next `"`, or the end of input. `""` escapes are not interpreted.
/// - Otherwise the longest prefix of alphanumeric / `_` characters is
///   returned, which may be empty.
///
/// Leading whitespace is not skipped.
#[cfg(feature = "code-graph")]
fn take_ident(s: &str) -> String {
    if let Some(quoted) = s.strip_prefix('"') {
        return quoted.split('"').next().unwrap_or("").to_owned();
    }
    s.chars()
        .take_while(|c| c.is_alphanumeric() || *c == '_')
        .collect()
}
3551
/// Projects code symbols into the GraphRAG store.
///
/// Starts from a fresh `GraphRagStats::default()`, lets
/// `graph_rag::project_code_symbols` accumulate into it, and returns the
/// resulting counters. Errors from the projection are propagated.
#[cfg(feature = "graph-rag")]
pub fn graph_rag_project_symbols(&self) -> Result<graph_rag::GraphRagStats> {
    let mut projected = graph_rag::GraphRagStats::default();
    graph_rag::project_code_symbols(self, &mut projected)?;
    Ok(projected)
}
3565
/// Runs a GraphRAG search with the given options.
///
/// Thin wrapper over `graph_rag::graph_rag_search`. The hits and any error are
/// passed through unchanged.
#[cfg(feature = "graph-rag")]
pub fn graph_rag_search(
    &self,
    opts: &graph_rag::GraphRagOptions,
) -> Result<Vec<graph_rag::GraphRagHit>> {
    let hits = graph_rag::graph_rag_search(self, opts)?;
    Ok(hits)
}
3578
/// Links text to symbols by exact qualified-name match.
///
/// Delegates to `graph_rag::link_exact_qualified`. `extra_kinds` is forwarded
/// as-is; NOTE(review): presumably additional node kinds to consider — confirm
/// in `graph_rag`.
#[cfg(feature = "graph-rag")]
pub fn graph_rag_link_exact(
    &self,
    extra_kinds: &[&str],
) -> Result<graph_rag::LinkerStats> {
    let stats = graph_rag::link_exact_qualified(self, extra_kinds)?;
    Ok(stats)
}
3591
/// Links text embeddings to symbol embeddings by vector similarity.
///
/// Delegates to `graph_rag::link_vector_similar`, forwarding `top_k` and
/// `threshold` unchanged. Their exact semantics (per-query limit, similarity
/// cut-off) are defined there.
#[cfg(feature = "graph-rag")]
pub fn graph_rag_link_vector(
    &self,
    text_queries: &[graph_rag::TextEmbedding],
    symbol_targets: &[graph_rag::SymbolEmbedding],
    top_k: usize,
    threshold: f32,
) -> Result<graph_rag::LinkerStats> {
    let stats =
        graph_rag::link_vector_similar(self, text_queries, symbol_targets, top_k, threshold)?;
    Ok(stats)
}
3609
/// Ingests documents into GraphRAG according to `opts`.
///
/// Thin wrapper over `graph_rag::ingest_docs`. Statistics and errors are
/// passed through.
#[cfg(feature = "graph-rag")]
pub fn graph_rag_ingest_docs(
    &self,
    opts: &graph_rag::IngestDocsOptions,
) -> Result<graph_rag::IngestStats> {
    let stats = graph_rag::ingest_docs(self, opts)?;
    Ok(stats)
}
3619
/// Ingests PDF input through the Docling pipeline.
///
/// Thin wrapper over `graph_rag::docling_ingest_pdf`.
#[cfg(feature = "graph-rag")]
pub fn graph_rag_ingest_pdf(
    &self,
    opts: &graph_rag::DoclingIngestOptions,
) -> Result<graph_rag::IngestStats> {
    let stats = graph_rag::docling_ingest_pdf(self, opts)?;
    Ok(stats)
}
3629
/// Ingests office-document input through the Docling pipeline.
///
/// Thin wrapper over `graph_rag::docling_ingest_office`.
#[cfg(feature = "graph-rag")]
pub fn graph_rag_ingest_office(
    &self,
    opts: &graph_rag::DoclingIngestOptions,
) -> Result<graph_rag::IngestStats> {
    let stats = graph_rag::docling_ingest_office(self, opts)?;
    Ok(stats)
}
3637
/// Ingests audio input through the Docling pipeline.
///
/// Thin wrapper over `graph_rag::docling_ingest_audio`.
#[cfg(feature = "graph-rag")]
pub fn graph_rag_ingest_audio(
    &self,
    opts: &graph_rag::DoclingIngestOptions,
) -> Result<graph_rag::IngestStats> {
    let stats = graph_rag::docling_ingest_audio(self, opts)?;
    Ok(stats)
}
3645
/// Ingests image input through the Docling pipeline.
///
/// Thin wrapper over `graph_rag::docling_ingest_image`.
#[cfg(feature = "graph-rag")]
pub fn graph_rag_ingest_image(
    &self,
    opts: &graph_rag::DoclingIngestOptions,
) -> Result<graph_rag::IngestStats> {
    let stats = graph_rag::docling_ingest_image(self, opts)?;
    Ok(stats)
}
3653
/// Ingests email input into GraphRAG.
///
/// Thin wrapper over `graph_rag::ingest_email`.
#[cfg(feature = "graph-rag")]
pub fn graph_rag_ingest_email(
    &self,
    opts: &graph_rag::IngestEmailOptions,
) -> Result<graph_rag::IngestStats> {
    let stats = graph_rag::ingest_email(self, opts)?;
    Ok(stats)
}
3662
/// Ingests issue-tracker input into GraphRAG.
///
/// Thin wrapper over `graph_rag::ingest_issues`.
#[cfg(feature = "graph-rag")]
pub fn graph_rag_ingest_issues(
    &self,
    opts: &graph_rag::IngestIssuesOptions,
) -> Result<graph_rag::IngestStats> {
    let stats = graph_rag::ingest_issues(self, opts)?;
    Ok(stats)
}
3671
/// Ingests question/answer input into GraphRAG.
///
/// Thin wrapper over `graph_rag::ingest_qa`.
#[cfg(feature = "graph-rag")]
pub fn graph_rag_ingest_qa(
    &self,
    opts: &graph_rag::IngestQaOptions,
) -> Result<graph_rag::IngestStats> {
    let stats = graph_rag::ingest_qa(self, opts)?;
    Ok(stats)
}
3680
3681 pub fn execute(&self, sql: &str) -> Result<u64> {
3682 use crate::error::LockResultExt;
3683
3684 if let Some((_, _)) = crate::sql::sqlite_compat::parse_pragma(sql) {
3687 tracing::debug!("PRAGMA stubbed via execute(): {}", sql.trim());
3688 return Ok(0);
3689 }
3690
3691 #[cfg(feature = "code-graph")]
3694 if let Some(ddl) = code_graph::detect_create_ast_index(sql) {
3695 return self.handle_create_ast_index(ddl);
3696 }
3697 #[cfg(feature = "code-graph")]
3698 if let Some(pr) = code_graph::detect_pause_resume(sql) {
3699 return self.handle_pause_resume(pr);
3700 }
3701 #[cfg(feature = "code-graph")]
3702 if let Some(ddl) = code_graph::detect_create_semantic_hash_index(sql) {
3703 return self.handle_create_semantic_hash_index(ddl);
3704 }
3705 #[cfg(feature = "code-graph")]
3710 let (rewritten_owned, _branch_guard) = self.rewrite_and_scope(sql);
3711 #[cfg(feature = "code-graph")]
3712 let sql: &str = &rewritten_owned;
3713 #[cfg(not(feature = "code-graph"))]
3714 let sql: &str = sql;
3715
3716 let start = std::time::Instant::now();
3717
3718 if Self::is_transaction_control(sql) {
3720 return self.handle_transaction_control(sql);
3721 }
3722
3723 let has_active_txn = {
3725 let txn_lock = self.current_transaction.lock()
3726 .map_lock_err("Failed to acquire transaction lock for execute")?;
3727 txn_lock.is_some()
3728 };
3729
3730 let result = if has_active_txn {
3731 let txn_lock = self.current_transaction.lock()
3733 .map_lock_err("Failed to acquire transaction lock for execute")?;
3734 let txn_ref = txn_lock.as_ref()
3735 .ok_or_else(|| Error::transaction("Transaction lock in invalid state"))?;
3736 self.execute_in_transaction_no_fast_path(sql, txn_ref)
3737 } else {
3738 self.execute_with_implicit_transaction(sql)
3740 };
3741
3742 if result.is_ok() {
3744 self.invalidate_result_cache();
3745 #[cfg(feature = "code-graph")]
3746 {
3747 let touched = Self::touched_table_from_sql(sql);
3748 self.maybe_auto_reparse(touched.as_deref());
3749 }
3750 }
3751
3752 let rows = result.as_ref().copied().unwrap_or(0);
3753 self.log_slow_query(sql, start.elapsed(), rows);
3754 result
3755 }
3756
3757 pub fn execute_returning(&self, sql: &str) -> Result<(u64, Vec<Tuple>)> {
3791 self.execute_params_returning(sql, &[])
3792 }
3793
3794 fn execute_with_implicit_transaction(&self, sql: &str) -> Result<u64> {
3796 let txn_start = std::time::Instant::now();
3798 let txn = self.storage.begin_transaction()?;
3799 tracing::trace!(phase = "txn_begin", duration_us = txn_start.elapsed().as_micros() as u64, "Transaction started");
3800
3801 let exec_start = std::time::Instant::now();
3803 let result = self.execute_in_transaction(sql, &txn);
3804 tracing::debug!(phase = "execute", duration_us = exec_start.elapsed().as_micros() as u64, "Query executed");
3805
3806 match result {
3808 Ok(count) => {
3809 let commit_start = std::time::Instant::now();
3810 txn.commit()?;
3811 self.storage.increment_lsn();
3813 tracing::debug!(phase = "txn_commit", duration_us = commit_start.elapsed().as_micros() as u64, rows = count, "Transaction committed");
3814 Ok(count)
3815 }
3816 Err(e) => {
3817 let _ = txn.rollback(); Err(e)
3819 }
3820 }
3821 }
3822
3823 fn invalidate_plan_cache(&self) {
3825 if let Ok(mut cache) = self.plan_cache.lock() {
3826 cache.clear();
3827 }
3828 if let Ok(mut cache) = self.parse_cache.lock() {
3830 cache.clear();
3831 }
3832 self.invalidate_result_cache();
3834 }
3835
3836 fn invalidate_result_cache(&self) {
3838 if let Ok(mut cache) = self.result_cache.lock() {
3839 cache.clear();
3840 }
3841 }
3842
3843 fn handle_pragma_query(&self, name: &str, arg: Option<&str>) -> Result<Vec<Tuple>> {
3849 match name.to_lowercase().as_str() {
3850 "table_info" => {
3851 let table = arg
3852 .unwrap_or("")
3853 .trim()
3854 .trim_matches(|c| c == '\'' || c == '"' || c == '`');
3855 if table.is_empty() {
3856 return Ok(vec![]);
3857 }
3858 let catalog = self.storage.catalog();
3859 let schema = catalog.get_table_schema(table)?;
3860 let mut rows = Vec::with_capacity(schema.columns.len());
3861 for (idx, col) in schema.columns.iter().enumerate() {
3862 rows.push(Tuple::new(vec![
3863 Value::Int4(idx as i32),
3864 Value::String(col.name.clone()),
3865 Value::String(format!("{:?}", col.data_type).to_uppercase()),
3866 Value::Int4(if col.nullable { 0 } else { 1 }),
3867 col.default_expr
3868 .as_ref()
3869 .map(|d| Value::String(d.clone()))
3870 .unwrap_or(Value::Null),
3871 Value::Int4(if col.primary_key { 1 } else { 0 }),
3872 ]));
3873 }
3874 Ok(rows)
3875 }
3876 _ => {
3877 tracing::debug!("PRAGMA stubbed (no-op rows): {} = {:?}", name, arg);
3878 Ok(vec![])
3879 }
3880 }
3881 }
3882
/// Executes an `ALTER TABLE` logical plan directly against the catalog and storage.
///
/// Supported variants are `AlterTableAddColumn`, `AlterTableDropColumn`,
/// `AlterTableRenameColumn` and `AlterTableRename`. Any other plan yields an
/// internal error naming the plan type.
///
/// Returns the row count reported by storage for ADD/DROP COLUMN. Returns `0`
/// for renames and for `IF NOT EXISTS` / `IF EXISTS` no-ops.
///
/// NOTE(review): for ADD/DROP COLUMN the catalog schema is written *before* the
/// stored rows are rewritten. If the row rewrite fails, nothing here reverts the
/// schema change — confirm whether callers wrap this in a transaction.
fn execute_alter_table_op(&self, plan: &sql::LogicalPlan) -> Result<u64> {
    match plan {
        sql::LogicalPlan::AlterTableAddColumn { table_name, column_def, if_not_exists } => {
            let catalog = self.storage.catalog();
            let mut schema = catalog.get_table_schema(table_name)?;

            // Adding an existing column: no-op under IF NOT EXISTS, error otherwise.
            if schema.columns.iter().any(|c| c.name == column_def.name) {
                if *if_not_exists {
                    return Ok(0);
                }
                return Err(Error::query_execution(format!(
                    "Column '{}' already exists in table '{}'", column_def.name, table_name
                )));
            }

            let new_column = Column {
                name: column_def.name.clone(),
                data_type: column_def.data_type.clone(),
                nullable: !column_def.not_null,
                primary_key: column_def.primary_key,
                source_table: None,
                source_table_name: Some(table_name.clone()),
                // The default is recorded as the Debug rendering of the parsed expression.
                default_expr: column_def.default.as_ref().map(|e| format!("{:?}", e)),
                unique: column_def.unique,
                storage_mode: column_def.storage_mode,
            };

            // Schema first (appended as the last column), then existing rows are
            // rewritten by storage, which is given the column's default expression.
            schema.columns.push(new_column);
            catalog.update_table_schema(table_name, &schema)?;

            let rows_updated = self.storage.add_column_to_rows(
                table_name,
                &column_def.default,
            )?;

            tracing::info!(
                "Added column '{}' to table '{}', updated {} rows",
                column_def.name, table_name, rows_updated
            );

            Ok(rows_updated as u64)
        }
        sql::LogicalPlan::AlterTableDropColumn { table_name, column_name, if_exists, cascade } => {
            let catalog = self.storage.catalog();
            let mut schema = catalog.get_table_schema(table_name)?;

            let col_idx = schema.columns.iter()
                .position(|c| c.name == *column_name);

            match col_idx {
                Some(idx) => {
                    // Dropping a primary-key column requires CASCADE.
                    let is_pk = schema.get_column_at(idx)
                        .ok_or_else(|| Error::internal("column index out of bounds"))?
                        .primary_key;
                    if is_pk && !cascade {
                        return Err(Error::query_execution(format!(
                            "Cannot drop primary key column '{}' without CASCADE", column_name
                        )));
                    }

                    // Schema first, then strip the value at position `idx` from stored rows.
                    schema.columns.remove(idx);
                    catalog.update_table_schema(table_name, &schema)?;

                    let rows_updated = self.storage.drop_column_from_rows(table_name, idx)?;

                    tracing::info!(
                        "Dropped column '{}' from table '{}', updated {} rows",
                        column_name, table_name, rows_updated
                    );

                    Ok(rows_updated as u64)
                }
                None => {
                    // Missing column: no-op under IF EXISTS, error otherwise.
                    if *if_exists {
                        Ok(0)
                    } else {
                        Err(Error::query_execution(format!(
                            "Column '{}' does not exist in table '{}'", column_name, table_name
                        )))
                    }
                }
            }
        }
        sql::LogicalPlan::AlterTableRenameColumn { table_name, old_column_name, new_column_name } => {
            let catalog = self.storage.catalog();
            let mut schema = catalog.get_table_schema(table_name)?;

            // The target name is checked first, so renaming a column to its own
            // name is reported as "already exists".
            if schema.columns.iter().any(|c| c.name == *new_column_name) {
                return Err(Error::query_execution(format!(
                    "Column '{}' already exists in table '{}'", new_column_name, table_name
                )));
            }

            let col_idx = schema.columns.iter()
                .position(|c| c.name == *old_column_name)
                .ok_or_else(|| Error::query_execution(format!(
                    "Column '{}' does not exist in table '{}'", old_column_name, table_name
                )))?;

            // Metadata-only change: stored rows are not touched, hence `Ok(0)`.
            schema.get_column_at_mut(col_idx)
                .ok_or_else(|| Error::internal("column index out of bounds"))?
                .name = new_column_name.clone();
            catalog.update_table_schema(table_name, &schema)?;

            tracing::info!(
                "Renamed column '{}' to '{}' in table '{}'",
                old_column_name, new_column_name, table_name
            );

            Ok(0)
        }
        sql::LogicalPlan::AlterTableRename { table_name, new_table_name } => {
            let catalog = self.storage.catalog();

            // A successful schema lookup means the target name is already taken.
            if catalog.get_table_schema(new_table_name).is_ok() {
                return Err(Error::query_execution(format!(
                    "Table '{}' already exists", new_table_name
                )));
            }

            self.storage.rename_table(table_name, new_table_name)?;

            tracing::info!(
                "Renamed table '{}' to '{}'",
                table_name, new_table_name
            );

            Ok(0)
        }
        _ => Err(Error::internal(format!(
            "execute_alter_table_op called with non-ALTER plan: {:?}",
            plan.plan_type_name()
        ))),
    }
}
4023
4024 #[allow(clippy::indexing_slicing)] fn try_fast_insert(&self, sql: &str) -> Option<Result<u64>> {
4036 let trimmed = sql.trim();
4037
4038 if trimmed.len() < 20 || !trimmed.as_bytes().get(..6)?.eq_ignore_ascii_case(b"INSERT") {
4040 return None;
4041 }
4042
4043 let upper = trimmed.to_ascii_uppercase();
4045 if upper.contains("RETURNING") || upper.contains("ON CONFLICT")
4046 || upper.contains("DEFAULT") || upper.contains("SELECT") {
4047 return None;
4048 }
4049
4050 let after_insert = trimmed.get(6..)?.trim_start();
4053 if !after_insert.as_bytes().get(..4)?.eq_ignore_ascii_case(b"INTO") {
4054 return None;
4055 }
4056 let after_into = after_insert.get(4..)?.trim_start();
4057
4058 let table_end = after_into.find(|c: char| c == '(' || c.is_whitespace())?;
4064 let table_name = after_into.get(..table_end)?.trim().trim_matches('"');
4065 if table_name.is_empty() {
4066 return None;
4067 }
4068 let rest = after_into.get(table_end..)?.trim_start();
4069
4070 if !rest.starts_with('(') {
4072 return None;
4073 }
4074 let col_end = rest.find(')')?;
4075 let col_list_str = rest.get(1..col_end)?;
4076 let columns: Vec<&str> = col_list_str.split(',').map(|s| s.trim()).collect();
4077 if columns.is_empty() || columns.iter().any(|c| c.is_empty()) {
4078 return None;
4079 }
4080
4081 let after_cols = rest.get(col_end + 1..)?.trim_start();
4083 if after_cols.len() < 6 || !after_cols.as_bytes().get(..6)?.eq_ignore_ascii_case(b"VALUES") {
4084 return None;
4085 }
4086 let values_rest = after_cols.get(6..)?.trim_start();
4087
4088 if !values_rest.starts_with('(') {
4090 return None;
4091 }
4092 let values_inner = values_rest.get(1..)?;
4094 let close_idx = Self::find_closing_paren(values_inner)?;
4095 let values_str = values_inner.get(..close_idx)?;
4096
4097 let after_values = values_inner.get(close_idx + 1..)?.trim();
4099 if !after_values.is_empty() && after_values != ";" {
4100 return None; }
4102
4103 if self.tenant_manager.should_apply_rls(table_name, "INSERT") {
4105 return None;
4106 }
4107
4108 if self.trigger_registry.has_triggers_for_table(table_name) {
4110 return None;
4111 }
4112
4113 let catalog = self.storage.catalog();
4115 let schema = match catalog.get_table_schema(table_name) {
4116 Ok(s) => s,
4117 Err(_) => return None, };
4119
4120 if let Ok(tc) = catalog.load_table_constraints(table_name) {
4127 if !tc.foreign_keys.is_empty() {
4128 return None;
4129 }
4130 }
4131
4132 if columns.len() != Self::fast_parse_value_count(values_str) {
4134 return None; }
4136
4137 let mut target_types = Vec::with_capacity(columns.len());
4138 let mut col_indices = Vec::with_capacity(columns.len());
4139 for col_name in &columns {
4140 match schema.get_column_index(col_name) {
4141 Some(idx) => {
4142 col_indices.push(idx);
4143 match schema.get_column_at(idx) {
4144 Some(col) => target_types.push(col.data_type.clone()),
4145 None => return None,
4146 }
4147 }
4148 None => return None, }
4150 }
4151
4152 let values = Self::fast_parse_values(values_str, &target_types)?;
4154
4155 let mut tuple_values = vec![Value::Null; schema.columns.len()];
4157 let mut user_provided = vec![false; schema.columns.len()];
4158 for (i, &col_idx) in col_indices.iter().enumerate() {
4159 if let Some(val) = values.get(i) {
4160 if col_idx < tuple_values.len() {
4161 tuple_values[col_idx] = val.clone();
4162 user_provided[col_idx] = true;
4163 }
4164 }
4165 }
4166
4167 if let Err(e) = Self::apply_defaults_and_check_not_null(
4171 &mut tuple_values, &schema, &user_provided,
4172 ) {
4173 return Some(Err(e));
4174 }
4175
4176 let tuple = Tuple::new(tuple_values);
4177 if self.storage.get_current_branch_id().is_none() {
4179 Some(self.storage.insert_tuple_fast(table_name, tuple, &schema).map(|_| 1))
4180 } else {
4181 Some(self.storage.insert_tuple_branch_aware_with_schema(table_name, tuple, &schema).map(|_| 1))
4182 }
4183 }
4184
4185 fn try_fast_update(&self, sql: &str) -> Option<Result<u64>> {
4188 let trimmed = sql.trim();
4189
4190 if trimmed.len() < 20 || !trimmed.as_bytes().get(..6)?.eq_ignore_ascii_case(b"UPDATE") {
4192 return None;
4193 }
4194
4195 let upper = trimmed.to_ascii_uppercase();
4197 if upper.contains("RETURNING") || upper.contains("JOIN")
4198 || upper.contains("FROM") || upper.contains("SELECT")
4199 || upper.contains("CASE") || upper.contains("COALESCE") {
4200 return None;
4201 }
4202
4203 let after_update = trimmed.get(6..)?.trim_start();
4205
4206 let table_end = after_update.find(|c: char| c.is_whitespace())?;
4208 let table_name = after_update.get(..table_end)?.trim();
4209 if table_name.is_empty() {
4210 return None;
4211 }
4212 let rest = after_update.get(table_end..)?.trim_start();
4213
4214 if rest.len() < 3 || !rest.as_bytes().get(..3)?.eq_ignore_ascii_case(b"SET") {
4216 return None;
4217 }
4218 let after_set = rest.get(3..)?.trim_start();
4219
4220 let where_pos = {
4222 let upper_rest = after_set.to_ascii_uppercase();
4223 let pos = upper_rest.find("WHERE")?;
4224 if pos == 0 { return None; }
4226 let prev = after_set.as_bytes().get(pos - 1)?;
4227 if !prev.is_ascii_whitespace() { return None; }
4228 pos
4229 };
4230
4231 let set_clause = after_set.get(..where_pos)?.trim();
4232 let where_clause = after_set.get(where_pos + 5..)?.trim();
4233
4234 if set_clause.contains(',') {
4236 return None; }
4238 let eq_pos = set_clause.find('=')?;
4239 let set_col = set_clause.get(..eq_pos)?.trim();
4240 let set_val_str = set_clause.get(eq_pos + 1..)?.trim();
4241 if set_col.is_empty() || set_val_str.is_empty() {
4242 return None;
4243 }
4244
4245 let where_clause = where_clause.strip_suffix(';').unwrap_or(where_clause).trim();
4248 let where_upper = where_clause.to_ascii_uppercase();
4250 if where_upper.contains("AND") || where_upper.contains("OR")
4251 || where_upper.contains("IN") || where_upper.contains("BETWEEN") {
4252 return None;
4253 }
4254 let weq_pos = where_clause.find('=')?;
4255 let pk_col = where_clause.get(..weq_pos)?.trim();
4256 let pk_val_str = where_clause.get(weq_pos + 1..)?.trim();
4257 if pk_col.is_empty() || pk_val_str.is_empty() {
4258 return None;
4259 }
4260
4261 if self.tenant_manager.should_apply_rls(table_name, "UPDATE") {
4263 return None;
4264 }
4265
4266 if self.trigger_registry.has_triggers_for_table(table_name) {
4268 return None;
4269 }
4270
4271 if self.storage.get_current_branch_id().is_some() {
4273 return None;
4274 }
4275
4276 let catalog = self.storage.catalog();
4278 let schema = match catalog.get_table_schema(table_name) {
4279 Ok(s) => s,
4280 Err(_) => return None,
4281 };
4282
4283 let pk_col_idx = schema.get_column_index(pk_col)?;
4285 let pk_column = schema.get_column_at(pk_col_idx)?;
4286 if !pk_column.primary_key {
4287 return None; }
4289
4290 let set_col_idx = schema.get_column_index(set_col)?;
4292 let set_column = schema.get_column_at(set_col_idx)?;
4293
4294 if self.storage.art_indexes().has_fk(table_name) {
4302 return None;
4303 }
4304 if let Ok(constraints) = self.storage.catalog().load_table_constraints(table_name) {
4305 if !constraints.foreign_keys.is_empty() {
4306 return None; }
4308 }
4309
4310 let (pk_value, _) = Self::fast_parse_one_value(pk_val_str, &pk_column.data_type)?;
4312
4313 let existing_row = match self.storage.get_row_by_pk_with_schema(table_name, &pk_value, &schema) {
4315 Ok(Some(row)) => row,
4316 Ok(None) => return Some(Ok(0)), Err(e) => return Some(Err(e)),
4318 };
4319
4320 let row_id = existing_row.row_id.unwrap_or(0);
4321 if row_id == 0 {
4322 return None; }
4324
4325 let new_value = if let Some((val, _)) = Self::fast_parse_one_value(set_val_str, &set_column.data_type) {
4327 val
4328 } else if let Some(val) = Self::fast_eval_simple_expr(set_val_str, set_col, set_col_idx, &existing_row) {
4329 val
4330 } else {
4331 return None; };
4333
4334 if !set_column.nullable && matches!(new_value, Value::Null) {
4336 return Some(Err(Error::constraint_violation(format!(
4337 "Column '{}' cannot be null", set_col
4338 ))));
4339 }
4340
4341 let mut new_values = existing_row.values.clone();
4343 if set_col_idx < new_values.len() {
4344 #[allow(clippy::indexing_slicing)]
4346 { new_values[set_col_idx] = new_value; }
4347 } else {
4348 return None;
4349 }
4350
4351 let new_tuple = Tuple::new(new_values);
4352
4353 Some(self.storage.update_tuple_fast(table_name, row_id, new_tuple, &existing_row, &schema))
4355 }
4356
4357 fn try_fast_select(&self, sql: &str) -> Option<Result<Vec<Tuple>>> {
4360 let trimmed = sql.trim();
4361
4362 if trimmed.len() < 20 || !trimmed.as_bytes().get(..6)?.eq_ignore_ascii_case(b"SELECT") {
4364 return None;
4365 }
4366
4367 let after_select = trimmed.get(6..)?.trim_start();
4368
4369 if !after_select.starts_with('*') {
4371 return None;
4372 }
4373 let after_star = after_select.get(1..)?.trim_start();
4374
4375 if after_star.len() < 4 || !after_star.as_bytes().get(..4)?.eq_ignore_ascii_case(b"FROM") {
4377 return None;
4378 }
4379 let after_from = after_star.get(4..)?.trim_start();
4380
4381 let table_end = after_from.find(|c: char| c.is_whitespace())?;
4383 let table_name = after_from.get(..table_end)?.trim();
4384 if table_name.is_empty() {
4385 return None;
4386 }
4387 let rest = after_from.get(table_end..)?.trim_start();
4388
4389 if rest.len() < 5 || !rest.as_bytes().get(..5)?.eq_ignore_ascii_case(b"WHERE") {
4391 return None;
4392 }
4393 let where_clause = rest.get(5..)?.trim_start();
4394
4395 let upper = where_clause.to_ascii_uppercase();
4397 if upper.contains("AND") || upper.contains("OR")
4398 || upper.contains("JOIN") || upper.contains("ORDER")
4399 || upper.contains("GROUP") || upper.contains("LIMIT") {
4400 return None;
4401 }
4402
4403 let where_clause = where_clause.strip_suffix(';').unwrap_or(where_clause).trim();
4405 let eq_pos = where_clause.find('=')?;
4406 let pk_col = where_clause.get(..eq_pos)?.trim();
4407 let pk_val_str = where_clause.get(eq_pos + 1..)?.trim();
4408 if pk_col.is_empty() || pk_val_str.is_empty() {
4409 return None;
4410 }
4411
4412 if self.tenant_manager.should_apply_rls(table_name, "SELECT") {
4414 return None;
4415 }
4416
4417 let catalog = self.storage.catalog();
4419 let schema = match catalog.get_table_schema(table_name) {
4420 Ok(s) => s,
4421 Err(_) => return None,
4422 };
4423
4424 let pk_col_idx = schema.get_column_index(pk_col)?;
4426 let pk_column = schema.get_column_at(pk_col_idx)?;
4427 if !pk_column.primary_key {
4428 return None; }
4430
4431 let (pk_value, _) = Self::fast_parse_one_value(pk_val_str, &pk_column.data_type)?;
4433
4434 match self.storage.get_row_by_pk_with_schema(table_name, &pk_value, &schema) {
4436 Ok(Some(row)) => Some(Ok(vec![row])),
4437 Ok(None) => Some(Ok(vec![])),
4438 Err(e) => Some(Err(e)),
4439 }
4440 }
4441
4442 fn fast_eval_simple_expr(expr: &str, col_name: &str, col_idx: usize, row: &Tuple) -> Option<Value> {
4445 let expr = expr.trim();
4446
4447 if !expr.starts_with(col_name) {
4450 return None;
4451 }
4452 let after_col = expr.get(col_name.len()..)?.trim_start();
4453 if after_col.is_empty() {
4454 return None;
4455 }
4456
4457 let current = row.values.get(col_idx)?;
4459
4460 let (op, operand_str) = if let Some(rest) = after_col.strip_prefix('+') {
4462 ('+', rest.trim())
4463 } else if let Some(rest) = after_col.strip_prefix('-') {
4464 ('-', rest.trim())
4465 } else if let Some(rest) = after_col.strip_prefix('*') {
4466 ('*', rest.trim())
4467 } else {
4468 return None;
4469 };
4470
4471 match (current, op) {
4473 (Value::Int2(v), '+') => { let n: i16 = operand_str.parse().ok()?; Some(Value::Int2(v.checked_add(n)?)) }
4474 (Value::Int2(v), '-') => { let n: i16 = operand_str.parse().ok()?; Some(Value::Int2(v.checked_sub(n)?)) }
4475 (Value::Int2(v), '*') => { let n: i16 = operand_str.parse().ok()?; Some(Value::Int2(v.checked_mul(n)?)) }
4476 (Value::Int4(v), '+') => { let n: i32 = operand_str.parse().ok()?; Some(Value::Int4(v.checked_add(n)?)) }
4477 (Value::Int4(v), '-') => { let n: i32 = operand_str.parse().ok()?; Some(Value::Int4(v.checked_sub(n)?)) }
4478 (Value::Int4(v), '*') => { let n: i32 = operand_str.parse().ok()?; Some(Value::Int4(v.checked_mul(n)?)) }
4479 (Value::Int8(v), '+') => { let n: i64 = operand_str.parse().ok()?; Some(Value::Int8(v.checked_add(n)?)) }
4480 (Value::Int8(v), '-') => { let n: i64 = operand_str.parse().ok()?; Some(Value::Int8(v.checked_sub(n)?)) }
4481 (Value::Int8(v), '*') => { let n: i64 = operand_str.parse().ok()?; Some(Value::Int8(v.checked_mul(n)?)) }
4482 (Value::Float4(v), '+') => { let n: f32 = operand_str.parse().ok()?; Some(Value::Float4(v + n)) }
4483 (Value::Float4(v), '-') => { let n: f32 = operand_str.parse().ok()?; Some(Value::Float4(v - n)) }
4484 (Value::Float4(v), '*') => { let n: f32 = operand_str.parse().ok()?; Some(Value::Float4(v * n)) }
4485 (Value::Float8(v), '+') => { let n: f64 = operand_str.parse().ok()?; Some(Value::Float8(v + n)) }
4486 (Value::Float8(v), '-') => { let n: f64 = operand_str.parse().ok()?; Some(Value::Float8(v - n)) }
4487 (Value::Float8(v), '*') => { let n: f64 = operand_str.parse().ok()?; Some(Value::Float8(v * n)) }
4488 _ => None,
4489 }
4490 }
4491
/// Returns the byte offset of the first `)` that is not inside a single-quoted
/// SQL string literal, or `None` if there is none.
///
/// A doubled `''` inside a literal is an escaped quote. Since it flips the
/// "inside literal" state twice with nothing in between, a simple parity
/// toggle on every `'` tracks literal boundaries exactly.
fn find_closing_paren(s: &str) -> Option<usize> {
    let mut quoted = false;
    for (pos, byte) in s.bytes().enumerate() {
        match byte {
            b'\'' => quoted = !quoted,
            b')' if !quoted => return Some(pos),
            _ => {}
        }
    }
    None
}
4518
/// Counts the comma-separated values in a VALUES list body.
///
/// Returns 0 for blank input. Otherwise returns one more than the number of
/// commas outside single-quoted literals (doubled `''` escapes keep the
/// literal open, which the quote-parity toggle handles).
fn fast_parse_value_count(s: &str) -> usize {
    if s.trim().is_empty() {
        return 0;
    }
    let (separators, _) = s.bytes().fold((0usize, false), |(commas, quoted), byte| match byte {
        b'\'' => (commas, !quoted),
        b',' if !quoted => (commas + 1, quoted),
        _ => (commas, quoted),
    });
    separators + 1
}
4548
4549 #[allow(clippy::indexing_slicing)] fn fast_parse_values(s: &str, target_types: &[DataType]) -> Option<Vec<Value>> {
4553 let mut values = Vec::with_capacity(target_types.len());
4554 let mut remaining = s;
4555 let mut type_idx = 0;
4556
4557 while !remaining.is_empty() && type_idx < target_types.len() {
4558 remaining = remaining.trim_start();
4559 if remaining.is_empty() {
4560 break;
4561 }
4562
4563 let (value, rest) = Self::fast_parse_one_value(remaining, &target_types[type_idx])?;
4564 values.push(value);
4565 type_idx += 1;
4566
4567 let rest = rest.trim_start();
4569 if rest.starts_with(',') {
4570 remaining = rest.get(1..)?;
4571 } else {
4572 remaining = rest;
4573 }
4574 }
4575
4576 if values.len() == target_types.len() {
4577 Some(values)
4578 } else {
4579 None
4580 }
4581 }
4582
4583 #[allow(clippy::indexing_slicing)] fn fast_parse_one_value<'a>(s: &'a str, target_type: &DataType) -> Option<(Value, &'a str)> {
4586 let s = s.trim_start();
4587 if s.is_empty() {
4588 return None;
4589 }
4590
4591 let first = s.as_bytes().first()?;
4592
4593 if *first == b'\'' {
4595 match target_type {
4598 DataType::Int2 | DataType::Int4 | DataType::Int8
4599 | DataType::Float4 | DataType::Float8
4600 | DataType::Boolean => return None,
4601 _ => {}
4602 }
4603
4604 let inner = s.get(1..)?;
4605 let mut end = 0;
4606 let bytes = inner.as_bytes();
4607 let mut result = String::new();
4608 let mut seg_start = 0; while end < bytes.len() {
4610 if bytes[end] == b'\'' {
4611 if seg_start < end {
4613 result.push_str(inner.get(seg_start..end)?);
4614 }
4615 if end + 1 < bytes.len() && bytes[end + 1] == b'\'' {
4616 result.push('\'');
4617 end += 2;
4618 seg_start = end;
4619 continue;
4620 }
4621 let rest = inner.get(end + 1..)?;
4623 let typed = match target_type {
4631 DataType::Uuid => uuid::Uuid::parse_str(&result)
4632 .map(Value::Uuid)
4633 .unwrap_or(Value::String(result)),
4634 DataType::Date => result
4635 .parse::<chrono::NaiveDate>()
4636 .map(Value::Date)
4637 .unwrap_or(Value::String(result)),
4638 DataType::Timestamp => chrono::DateTime::parse_from_rfc3339(&result)
4639 .map(|t| Value::Timestamp(t.to_utc()))
4640 .unwrap_or(Value::String(result)),
4641 _ => Value::String(result),
4642 };
4643 return Some((typed, rest));
4644 }
4645 end += 1;
4646 }
4647 return None; }
4649
4650 if s.len() >= 4 && s.as_bytes().get(..4)?.eq_ignore_ascii_case(b"NULL") {
4652 let rest = s.get(4..)?;
4653 if rest.is_empty() || rest.starts_with(',') || rest.starts_with(')') || rest.starts_with(' ') {
4655 return Some((Value::Null, rest));
4656 }
4657 }
4658
4659 if s.len() >= 4 && s.as_bytes().get(..4)?.eq_ignore_ascii_case(b"TRUE") {
4661 let rest = s.get(4..)?;
4662 if rest.is_empty() || rest.starts_with(',') || rest.starts_with(')') || rest.starts_with(' ') {
4663 return Some((Value::Boolean(true), rest));
4664 }
4665 }
4666 if s.len() >= 5 && s.as_bytes().get(..5)?.eq_ignore_ascii_case(b"FALSE") {
4667 let rest = s.get(5..)?;
4668 if rest.is_empty() || rest.starts_with(',') || rest.starts_with(')') || rest.starts_with(' ') {
4669 return Some((Value::Boolean(false), rest));
4670 }
4671 }
4672
4673 if first.is_ascii_digit() || *first == b'-' || *first == b'+' || *first == b'.' {
4675 let end = s.find([',', ')', ' '])
4676 .unwrap_or(s.len());
4677 let num_str = s.get(..end)?.trim();
4678 let rest = s.get(end..)?;
4679
4680 let value = match target_type {
4682 DataType::Int4 => {
4683 let n: i32 = num_str.parse().ok()?;
4684 Value::Int4(n)
4685 }
4686 DataType::Int8 => {
4687 let n: i64 = num_str.parse().ok()?;
4688 Value::Int8(n)
4689 }
4690 DataType::Float4 => {
4691 let f: f32 = num_str.parse().ok()?;
4692 Value::Float4(f)
4693 }
4694 DataType::Float8 => {
4695 let f: f64 = num_str.parse().ok()?;
4696 Value::Float8(f)
4697 }
4698 DataType::Numeric => {
4699 if let Ok(n) = num_str.parse::<i64>() {
4701 Value::Int8(n)
4702 } else if let Ok(f) = num_str.parse::<f64>() {
4703 Value::Float8(f)
4704 } else {
4705 return None;
4706 }
4707 }
4708 _ => {
4709 if num_str.contains('.') {
4711 let f: f64 = num_str.parse().ok()?;
4712 Value::Float8(f)
4713 } else if let Ok(n) = num_str.parse::<i32>() {
4714 Value::Int4(n)
4715 } else if let Ok(n) = num_str.parse::<i64>() {
4716 Value::Int8(n)
4717 } else {
4718 return None;
4719 }
4720 }
4721 };
4722 return Some((value, rest));
4723 }
4724
4725 None
4727 }
4728
4729 pub(crate) fn parse_cached(&self, sql: &str) -> Result<(sqlparser::ast::Statement, bool)> {
4731 if let Ok(mut cache) = self.parse_cache.lock() {
4733 if let Some(stmt) = cache.get(sql) {
4734 return Ok((stmt.clone(), true));
4735 }
4736 }
4737 let parser = sql::Parser::new();
4739 let statement = parser.parse_one(sql)?;
4740 if let Ok(mut cache) = self.parse_cache.lock() {
4741 cache.put(sql.to_string(), statement.clone());
4742 }
4743 Ok((statement, false))
4744 }
4745
4746 fn execute_internal(&self, sql: &str) -> Result<u64> {
4748 if let Some(context) = self.tenant_manager.get_current_context() {
4750 self.tenant_manager.record_query(context.tenant_id)
4751 .map_err(|e| Error::query_execution(format!("Quota exceeded: {}", e)))?;
4752 }
4753
4754 let parse_start = std::time::Instant::now();
4756 let (statement, parse_cached) = self.parse_cached(sql)?;
4757 let parse_elapsed = parse_start.elapsed();
4758 if parse_cached {
4759 tracing::debug!(phase = "parse", duration_us = parse_elapsed.as_micros() as u64, "SQL parsed (AST cached)");
4760 } else {
4761 tracing::debug!(phase = "parse", duration_us = parse_elapsed.as_micros() as u64, "SQL parsed");
4762 }
4763
4764 let plan_start = std::time::Instant::now();
4766 let catalog = self.storage.catalog();
4767 let planner = sql::Planner::with_catalog(&catalog);
4768 let plan = planner.statement_to_plan(statement)?;
4769 let plan_elapsed = plan_start.elapsed();
4770 tracing::debug!(phase = "plan", duration_us = plan_elapsed.as_micros() as u64, "Logical plan created");
4771
4772 if matches!(&plan,
4774 sql::LogicalPlan::CreateTable { .. } |
4775 sql::LogicalPlan::DropTable { .. } |
4776 sql::LogicalPlan::AlterTableAddColumn { .. } |
4777 sql::LogicalPlan::AlterTableDropColumn { .. } |
4778 sql::LogicalPlan::AlterTableRename { .. } |
4779 sql::LogicalPlan::AlterTableRenameColumn { .. } |
4780 sql::LogicalPlan::AlterTableAddForeignKey { .. } |
4781 sql::LogicalPlan::AlterTableMulti { .. } |
4782 sql::LogicalPlan::CreateIndex { .. } |
4783 sql::LogicalPlan::Truncate { .. } |
4784 sql::LogicalPlan::CreateMaterializedView { .. } |
4785 sql::LogicalPlan::DropMaterializedView { .. }
4786 ) {
4787 self.invalidate_plan_cache();
4788 }
4789
4790 match &plan {
4792 sql::LogicalPlan::CreateTable { name, columns, if_not_exists, .. } => {
4793 if *if_not_exists && self.storage.catalog().table_exists(name).unwrap_or(false) {
4795 return Ok(0);
4796 }
4797
4798 let schema_columns: Vec<Column> = columns.iter().map(|col_def| {
4800 Column {
4801 name: col_def.name.clone(),
4802 data_type: col_def.data_type.clone(),
4803 nullable: !col_def.not_null, primary_key: col_def.primary_key,
4805 source_table: None,
4806 source_table_name: None,
4807 default_expr: None,
4808 unique: false,
4809 storage_mode: col_def.storage_mode,
4810 }
4811 }).collect();
4812
4813 let schema = Schema::new(schema_columns);
4814 let catalog = self.storage.catalog();
4815
4816 if let Err(e) = self.storage.log_create_table(name, &schema) {
4818 tracing::warn!("Failed to log CREATE TABLE to WAL: {}", e);
4819 }
4820
4821 catalog.create_table(name, schema)?;
4822 Ok(1) }
4824 sql::LogicalPlan::Insert { table_name, columns, values, returning, on_conflict: _ } => {
4825 let rls_enforced = self.tenant_manager.should_apply_rls(table_name, "INSERT");
4827 let rls_check = if rls_enforced {
4828 self.tenant_manager.get_rls_conditions(table_name, "INSERT")
4829 } else {
4830 None
4831 };
4832
4833 let catalog = self.storage.catalog();
4835 let schema = catalog.get_table_schema(table_name)?;
4836
4837 let evaluator = sql::Evaluator::new(std::sync::Arc::new(Schema {
4840 columns: vec![], }));
4842 let empty_tuple = Tuple::new(vec![]);
4843
4844 let mut count = 0;
4845 for value_row in values {
4846 let mut tuple_values: Vec<Value> = Vec::new();
4849
4850 for (col_idx, expr) in value_row.iter().enumerate() {
4851 let target_col_idx = if let Some(ref cols) = columns {
4853 let col_name = cols.get(col_idx)
4855 .ok_or_else(|| Error::internal("column index out of bounds"))?;
4856 schema.get_column_index(col_name)
4857 .ok_or_else(|| Error::query_execution(format!("Column '{}' not found", col_name)))?
4858 } else {
4859 col_idx
4861 };
4862
4863 let target_col = schema.get_column_at(target_col_idx)
4864 .ok_or_else(|| Error::query_execution(format!(
4865 "Too many values for INSERT: table has {} columns",
4866 schema.columns.len()
4867 )))?;
4868
4869 let target_type = &target_col.data_type;
4870
4871 let mut value = evaluator.evaluate(expr, &empty_tuple)?;
4873
4874 let needs_cast = match (&value, target_type) {
4876 (Value::Null, _) => false, (Value::Vector(_), DataType::Vector(_)) => false,
4878 (Value::String(_), DataType::Vector(_)) => true, (Value::String(_), DataType::Json | DataType::Jsonb) => true, (Value::Int4(_), DataType::Int4) => false,
4881 (Value::Int8(_), DataType::Int8) => false,
4882 (Value::Float4(_), DataType::Float4) => false,
4883 (Value::Float8(_), DataType::Float8) => false,
4884 (Value::String(_), DataType::Text | DataType::Varchar(_)) => false,
4885 (Value::Boolean(_), DataType::Boolean) => false,
4886 (Value::Json(_), DataType::Json | DataType::Jsonb) => false,
4887 _ => true, };
4889
4890 if needs_cast {
4891 value = evaluator.cast_value(value, target_type)?;
4892 }
4893
4894 tuple_values.push(value);
4895 }
4896
4897 let tuple = Tuple::new(tuple_values);
4898
4899 if let Some((_, with_check)) = &rls_check {
4901 if let Some(ref with_check_expr) = with_check {
4902 let tenant_context = self.tenant_manager.get_current_context();
4904 let rls_evaluator = tenant::RLSExpressionEvaluator::new(
4905 std::sync::Arc::new(schema.clone()),
4906 tenant_context
4907 );
4908 let expr = rls_evaluator.parse(with_check_expr)?;
4909 let satisfies_policy = rls_evaluator.evaluate(&expr, &tuple)?;
4910
4911 if !satisfies_policy {
4912 return Err(Error::query_execution(format!(
4913 "Row-Level Security policy violation: inserted row does not satisfy WITH CHECK expression"
4914 )));
4915 }
4916 }
4917 }
4918
4919 self.storage.insert_tuple_branch_aware_with_schema(table_name, tuple, &schema)?;
4920 count += 1;
4921 }
4922 Ok(count)
4923 }
4924 sql::LogicalPlan::InsertSelect { table_name, columns, source, returning: _ } => {
4925 let mut executor = sql::Executor::with_storage(&self.storage)
4927 .with_timeout(self.config.storage.query_timeout_ms);
4928 let source_rows = executor.execute(source)?;
4929
4930 let rls_enforced = self.tenant_manager.should_apply_rls(table_name, "INSERT");
4932 let rls_check = if rls_enforced {
4933 self.tenant_manager.get_rls_conditions(table_name, "INSERT")
4934 } else {
4935 None
4936 };
4937
4938 let catalog = self.storage.catalog();
4939 let schema = catalog.get_table_schema(table_name)?;
4940 let evaluator = sql::Evaluator::new(std::sync::Arc::new(Schema {
4941 columns: vec![],
4942 }));
4943 let empty_tuple = Tuple::new(vec![]);
4944
4945 let column_indices: Option<Vec<usize>> = columns.as_ref().map(|cols| {
4946 cols.iter()
4947 .filter_map(|col_name| schema.get_column_index(col_name))
4948 .collect()
4949 });
4950
4951 let default_exprs: Vec<Option<sql::LogicalExpr>> = schema.columns.iter()
4952 .map(|col| {
4953 col.default_expr.as_ref().and_then(|json| {
4954 serde_json::from_str(json).ok()
4955 })
4956 })
4957 .collect();
4958
4959 let mut count = 0u64;
4960 for source_row in &source_rows {
4961 let mut tuple_values: Vec<Option<Value>> = vec![None; schema.columns.len()];
4962
4963 for (val_idx, value) in source_row.values.iter().enumerate() {
4964 let target_col_idx = if let Some(ref indices) = column_indices {
4965 if val_idx >= indices.len() {
4966 return Err(Error::query_execution("More values than columns specified"));
4967 }
4968 *indices.get(val_idx).ok_or_else(|| Error::internal("column index out of bounds"))?
4969 } else {
4970 val_idx
4971 };
4972
4973 let target_col = schema.get_column_at(target_col_idx)
4974 .ok_or_else(|| Error::query_execution(format!(
4975 "Too many values for INSERT: table has {} columns",
4976 schema.columns.len()
4977 )))?;
4978
4979 let target_type = &target_col.data_type;
4980 let mut val = value.clone();
4981
4982 let needs_cast = match (&val, target_type) {
4983 (Value::Null, _) => false,
4984 (Value::Vector(_), DataType::Vector(_)) => false,
4985 (Value::String(_), DataType::Vector(_)) => true,
4986 (Value::String(_), DataType::Json | DataType::Jsonb) => true,
4987 (Value::Int4(_), DataType::Int4) => false,
4988 (Value::Int8(_), DataType::Int8) => false,
4989 (Value::Float4(_), DataType::Float4) => false,
4990 (Value::Float8(_), DataType::Float8) => false,
4991 (Value::String(_), DataType::Text | DataType::Varchar(_)) => false,
4992 (Value::Boolean(_), DataType::Boolean) => false,
4993 (Value::Json(_), DataType::Json | DataType::Jsonb) => false,
4994 _ => true,
4995 };
4996
4997 if needs_cast {
4998 val = evaluator.cast_value(val, target_type)?;
4999 }
5000
5001 let tv = tuple_values.get_mut(target_col_idx)
5002 .ok_or_else(|| Error::internal("column index out of bounds"))?;
5003 *tv = Some(val);
5004 }
5005
5006 let final_values: Result<Vec<Value>> = tuple_values
5007 .into_iter()
5008 .enumerate()
5009 .map(|(idx, opt_val)| {
5010 if let Some(val) = opt_val {
5011 Ok(val)
5012 } else {
5013 let col = schema.get_column_at(idx)
5014 .ok_or_else(|| Error::internal("column index out of bounds"))?;
5015 if let Some(ref default_expr) = default_exprs.get(idx).and_then(|d| d.as_ref()) {
5016 let mut value = evaluator.evaluate(default_expr, &empty_tuple)?;
5017 if value.data_type() != col.data_type {
5018 value = evaluator.cast_value(value, &col.data_type)?;
5019 }
5020 Ok(value)
5021 } else if col.primary_key {
5022 Ok(Value::Null)
5025 } else if col.nullable {
5026 Ok(Value::Null)
5027 } else {
5028 Err(Error::query_execution(format!(
5029 "Column '{}' does not have a default value and is not nullable",
5030 col.name
5031 )))
5032 }
5033 }
5034 })
5035 .collect();
5036
5037 let tuple = Tuple::new(final_values?);
5038
5039 if let Some((_, with_check)) = &rls_check {
5041 if let Some(ref with_check_expr) = with_check {
5042 let tenant_context = self.tenant_manager.get_current_context();
5043 let rls_evaluator = tenant::RLSExpressionEvaluator::new(
5044 std::sync::Arc::new(schema.clone()),
5045 tenant_context
5046 );
5047 let expr = rls_evaluator.parse(with_check_expr)?;
5048 let satisfies_policy = rls_evaluator.evaluate(&expr, &tuple)?;
5049 if !satisfies_policy {
5050 return Err(Error::query_execution(
5051 "Row-Level Security policy violation: inserted row does not satisfy WITH CHECK expression"
5052 ));
5053 }
5054 }
5055 }
5056
5057 self.storage.insert_tuple_branch_aware_with_schema(table_name, tuple, &schema)?;
5058 count += 1;
5059 }
5060 Ok(count)
5061 }
5062 sql::LogicalPlan::Update { table_name, assignments, selection, returning } => {
5063 let rls_enforced = self.tenant_manager.should_apply_rls(table_name, "UPDATE");
5065 let rls_condition = if rls_enforced {
5066 self.tenant_manager.get_rls_conditions(table_name, "UPDATE")
5067 } else {
5068 None
5069 };
5070
5071 let catalog = self.storage.catalog();
5073 let schema = catalog.get_table_schema(table_name)?;
5074 let eval_schema = schema.clone().with_source_table_name(table_name);
5076 let evaluator = sql::Evaluator::new(std::sync::Arc::new(eval_schema));
5077
5078 let tuples = self.storage.scan_table_branch_aware(table_name)?;
5080 let mut updates: Vec<(u64, Tuple)> = Vec::new();
5081
5082 for mut tuple in tuples {
5083 let where_matches = if let Some(predicate) = selection {
5085 let result = evaluator.evaluate(predicate, &tuple)?;
5086 match result {
5087 Value::Boolean(b) => b,
5088 _ => false,
5089 }
5090 } else {
5091 true };
5093
5094 let rls_matches = if let Some((using_expr, _)) = &rls_condition {
5096 let tenant_context = self.tenant_manager.get_current_context();
5098 let rls_evaluator = tenant::RLSExpressionEvaluator::new(
5099 std::sync::Arc::new(schema.clone()),
5100 tenant_context
5101 );
5102 let expr = rls_evaluator.parse(using_expr)?;
5103 rls_evaluator.evaluate(&expr, &tuple)?
5104 } else {
5105 true };
5107
5108 if where_matches && rls_matches {
5109 for (col_name, value_expr) in assignments {
5114 let bound = self.materialize_scalar_subqueries_for_row(
5115 value_expr, &tuple, &schema, table_name,
5116 )?;
5117 let mut new_value = evaluator.evaluate(&bound, &tuple)?;
5118 let col_index = evaluator.schema().get_column_index(col_name)
5120 .ok_or_else(|| Error::query_execution(format!("Column '{}' not found", col_name)))?;
5121 let target_col = schema.get_column_at(col_index)
5123 .ok_or_else(|| Error::query_execution(format!("Column '{}' not found", col_name)))?;
5124 let target_type = &target_col.data_type;
5125 let needs_cast = !matches!(&new_value, Value::Null)
5126 && !matches!(
5127 (&new_value, target_type),
5128 (Value::Vector(_), DataType::Vector(_))
5129 | (Value::Int2(_), DataType::Int2)
5130 | (Value::Int4(_), DataType::Int4)
5131 | (Value::Int8(_), DataType::Int8)
5132 | (Value::Float4(_), DataType::Float4)
5133 | (Value::Float8(_), DataType::Float8)
5134 | (Value::String(_), DataType::Text | DataType::Varchar(_))
5135 | (Value::Boolean(_), DataType::Boolean)
5136 | (Value::Json(_), DataType::Json | DataType::Jsonb)
5137 | (Value::Timestamp(_), DataType::Timestamp | DataType::Timestamptz)
5138 | (Value::Date(_), DataType::Date)
5139 );
5140 if needs_cast {
5141 new_value = evaluator.cast_value(new_value, target_type)?;
5142 }
5143 if let Some(slot) = tuple.values.get_mut(col_index) {
5144 *slot = new_value;
5145 }
5146 }
5147
5148 let mut new_col_values = std::collections::HashMap::with_capacity(schema.columns.len());
5152 for (i, col) in schema.columns.iter().enumerate() {
5153 if let Some(v) = tuple.values.get(i) {
5154 new_col_values.insert(col.name.clone(), v.clone());
5155 }
5156 }
5157 self.check_fk_constraints_on_write(table_name, &new_col_values, None)?;
5158
5159 let row_id = tuple.row_id.unwrap_or(0);
5160 updates.push((row_id, tuple));
5161 }
5162 }
5163
5164 let update_count = self.storage.update_tuples_branch_aware(table_name, updates)?;
5168 Ok(update_count)
5169 }
5170 sql::LogicalPlan::Delete { table_name, selection, returning } => {
5171 let rls_enforced = self.tenant_manager.should_apply_rls(table_name, "DELETE");
5173 let rls_condition = if rls_enforced {
5174 self.tenant_manager.get_rls_conditions(table_name, "DELETE")
5175 } else {
5176 None
5177 };
5178
5179 let catalog = self.storage.catalog();
5181 let schema = catalog.get_table_schema(table_name)?;
5182 let eval_schema = schema.clone().with_source_table_name(table_name);
5184 let evaluator = sql::Evaluator::new(std::sync::Arc::new(eval_schema));
5185
5186 let tuples = self.storage.scan_table_branch_aware(table_name)?;
5188 let mut row_ids_to_delete: Vec<u64> = Vec::new();
5189 let mut deleted_tuples: Vec<(u64, Tuple)> = Vec::new();
5190
5191 for tuple in tuples {
5192 let where_matches = if let Some(predicate) = selection {
5194 let result = evaluator.evaluate(predicate, &tuple)?;
5195 match result {
5196 Value::Boolean(b) => b,
5197 _ => false,
5198 }
5199 } else {
5200 true };
5202
5203 let rls_matches = if let Some((using_expr, _)) = &rls_condition {
5205 let tenant_context = self.tenant_manager.get_current_context();
5207 let rls_evaluator = tenant::RLSExpressionEvaluator::new(
5208 std::sync::Arc::new(schema.clone()),
5209 tenant_context
5210 );
5211 let expr = rls_evaluator.parse(using_expr)?;
5212 rls_evaluator.evaluate(&expr, &tuple)?
5213 } else {
5214 true };
5216
5217 if where_matches && rls_matches {
5218 if let Some(row_id) = tuple.row_id {
5219 row_ids_to_delete.push(row_id);
5220 deleted_tuples.push((row_id, tuple.clone()));
5221 }
5222 }
5223 }
5224
5225 for (row_id, tuple) in &deleted_tuples {
5227 let mut col_values = std::collections::HashMap::new();
5228 for (i, col) in schema.columns.iter().enumerate() {
5229 if let Some(v) = tuple.values.get(i) {
5230 col_values.insert(col.name.clone(), v.clone());
5231 }
5232 }
5233 if let Err(e) = self.storage.art_indexes().on_delete(table_name, *row_id, &col_values) {
5234 tracing::debug!("ART index delete for table '{}': {}", table_name, e);
5235 }
5236 }
5237
5238 let delete_count = self.storage.delete_tuples_branch_aware(table_name, row_ids_to_delete)?;
5242 Ok(delete_count)
5243 }
5244 sql::LogicalPlan::DropTable { name, if_exists } => {
5245 let catalog = self.storage.catalog();
5247 let exists = catalog.table_exists(name)?;
5248
5249 if exists {
5250 let mv_catalog = self.storage.mv_catalog();
5252 if let Ok(mv_names) = mv_catalog.list_views() {
5253 let mut dependent_mvs = Vec::new();
5254 for mv_name in &mv_names {
5255 if let Ok(metadata) = mv_catalog.get_view(mv_name) {
5256 if metadata.base_tables.iter().any(|t| t == name) {
5257 dependent_mvs.push(mv_name.clone());
5258 }
5259 }
5260 }
5261 if !dependent_mvs.is_empty() {
5262 tracing::warn!(
5263 "Dropping table '{}' which is used by materialized view(s): {}. Those views will be stale.",
5264 name,
5265 dependent_mvs.join(", ")
5266 );
5267 }
5268 }
5269
5270 catalog.drop_table(name)?;
5272
5273 if let Err(e) = self.trigger_registry.drop_table_triggers(name) {
5275 tracing::warn!("Failed to clean up triggers for dropped table '{}': {}", name, e);
5276 }
5277
5278 self.storage.predicate_pushdown().remove_table(name);
5280
5281 self.storage.row_cache().invalidate_table(name);
5283
5284 if let Err(e) = self.storage.log_drop_table(name) {
5286 tracing::warn!("Failed to log DROP TABLE to WAL: {}", e);
5287 }
5288
5289 Ok(0) } else if *if_exists {
5291 Ok(0)
5293 } else {
5294 Err(Error::query_execution(format!("Table '{}' does not exist", name)))
5296 }
5297 }
5298 sql::LogicalPlan::Truncate { table_name } => {
5299 let mut trigger_context = sql::TriggerContext::new();
5304 let trigger_event = sql::logical_plan::TriggerEvent::Truncate;
5305
5306 let row_context = sql::triggers::TriggerRowContext {
5308 old_tuple: None,
5309 new_tuple: None,
5310 transition_tables: None,
5311 };
5312
5313 let db_ref = self.clone_for_trigger();
5315 let mut executor_fn = |stmt: &sql::LogicalPlan, _ctx: &sql::triggers::TriggerRowContext| -> Result<()> {
5316 db_ref.execute_plan_internal(stmt)?;
5317 Ok(())
5318 };
5319
5320 let action = self.trigger_registry.execute_triggers(
5321 table_name,
5322 &trigger_event,
5323 &sql::logical_plan::TriggerTiming::Before,
5324 &row_context,
5325 &mut trigger_context,
5326 None, &mut executor_fn,
5328 )?;
5329
5330 match action {
5332 sql::triggers::TriggerAction::Abort(msg) => {
5333 return Err(Error::query_execution(format!("TRUNCATE aborted by trigger: {}", msg)));
5334 }
5335 sql::triggers::TriggerAction::Skip => {
5336 return Ok(0);
5338 }
5339 sql::triggers::TriggerAction::Continue => {
5340 }
5342 }
5343
5344 let prefix = format!("data:{}:", table_name);
5345 let prefix_bytes = prefix.as_bytes();
5346 let mut keys_to_delete = Vec::new();
5347
5348 let iter = self.storage.db.iterator(rocksdb::IteratorMode::Start);
5350 for item in iter {
5351 let (key, _) = item.map_err(|e| Error::storage(format!("Iterator error: {}", e)))?;
5352
5353 if !key.starts_with(prefix_bytes) {
5354 if !key.is_empty() && key.first() > prefix_bytes.first() {
5355 break;
5356 }
5357 continue;
5358 }
5359
5360 keys_to_delete.push(key.to_vec());
5361 }
5362
5363 for key in &keys_to_delete {
5365 self.storage.delete(key)?;
5366 }
5367
5368 self.storage.row_cache().invalidate_table(table_name);
5370
5371 let has_user_branches = self.storage.list_branches()
5378 .map(|b| b.iter().any(|br| br.name != "main"))
5379 .unwrap_or(false);
5380 if !has_user_branches {
5381 self.storage.art_indexes().clear_table_indexes(table_name);
5382 }
5383
5384 let action = self.trigger_registry.execute_triggers(
5386 table_name,
5387 &trigger_event,
5388 &sql::logical_plan::TriggerTiming::After,
5389 &row_context,
5390 &mut trigger_context,
5391 None, &mut executor_fn,
5393 )?;
5394
5395 if let sql::triggers::TriggerAction::Abort(msg) = action {
5397 return Err(Error::query_execution(format!("TRUNCATE failed in AFTER trigger: {}", msg)));
5398 }
5399
5400 if let Err(e) = self.storage.log_truncate(table_name) {
5402 tracing::warn!("Failed to log TRUNCATE to WAL: {}", e);
5403 }
5404
5405 Ok(keys_to_delete.len() as u64) }
5407 sql::LogicalPlan::AlterColumnStorage { table_name, column_name, storage_mode } => {
5408 let catalog = self.storage.catalog();
5412 let mut schema = catalog.get_table_schema(table_name)?;
5413
5414 let col_idx = schema.columns.iter()
5416 .position(|c| c.name == *column_name)
5417 .ok_or_else(|| Error::query_execution(format!(
5418 "Column '{}' not found in table '{}'", column_name, table_name
5419 )))?;
5420
5421 let col_ref = schema.get_column_at(col_idx)
5422 .ok_or_else(|| Error::internal("column index out of bounds"))?;
5423 let old_mode = col_ref.storage_mode;
5424 if old_mode == *storage_mode {
5425 return Ok(0);
5427 }
5428
5429 let column = col_ref.clone();
5431 let rows_migrated = self.storage.migrate_column_storage(
5432 table_name,
5433 col_idx,
5434 &column,
5435 old_mode,
5436 *storage_mode,
5437 )?;
5438
5439 schema.get_column_at_mut(col_idx)
5441 .ok_or_else(|| Error::internal("column index out of bounds"))?
5442 .storage_mode = *storage_mode;
5443 catalog.update_table_schema(table_name, &schema)?;
5444
5445 if let Err(e) = self.storage.log_alter_column_storage(table_name, column_name, storage_mode) {
5447 tracing::warn!("Failed to log ALTER COLUMN STORAGE to WAL: {}", e);
5448 }
5449
5450 tracing::info!(
5451 "Altered {}.{} storage from {:?} to {:?}, migrated {} rows",
5452 table_name, column_name, old_mode, storage_mode, rows_migrated
5453 );
5454
5455 Ok(rows_migrated as u64)
5456 }
5457 sql::LogicalPlan::AlterTableAddColumn { table_name, column_def, if_not_exists } => {
5458 let catalog = self.storage.catalog();
5459 let mut schema = catalog.get_table_schema(table_name)?;
5460
5461 if schema.columns.iter().any(|c| c.name == column_def.name) {
5462 if *if_not_exists {
5463 return Ok(0);
5464 }
5465 return Err(Error::query_execution(format!(
5466 "Column '{}' already exists in table '{}'", column_def.name, table_name
5467 )));
5468 }
5469
5470 let new_column = Column {
5471 name: column_def.name.clone(),
5472 data_type: column_def.data_type.clone(),
5473 nullable: !column_def.not_null,
5474 primary_key: column_def.primary_key,
5475 source_table: None,
5476 source_table_name: Some(table_name.clone()),
5477 default_expr: column_def.default.as_ref().map(|e| format!("{:?}", e)),
5478 unique: column_def.unique,
5479 storage_mode: column_def.storage_mode,
5480 };
5481
5482 schema.columns.push(new_column);
5483 catalog.update_table_schema(table_name, &schema)?;
5484
5485 let rows_updated = self.storage.add_column_to_rows(table_name, &column_def.default)?;
5486 Ok(rows_updated as u64)
5487 }
5488 sql::LogicalPlan::AlterTableDropColumn { table_name, column_name, if_exists, cascade } => {
5489 let catalog = self.storage.catalog();
5490 let mut schema = catalog.get_table_schema(table_name)?;
5491
5492 let col_idx = schema.columns.iter().position(|c| c.name == *column_name);
5493
5494 match col_idx {
5495 Some(idx) => {
5496 if schema.get_column_at(idx).is_some_and(|c| c.primary_key) && !cascade {
5497 return Err(Error::query_execution(format!(
5498 "Cannot drop primary key column '{}' without CASCADE", column_name
5499 )));
5500 }
5501
5502 schema.columns.remove(idx);
5503 catalog.update_table_schema(table_name, &schema)?;
5504 let rows_updated = self.storage.drop_column_from_rows(table_name, idx)?;
5505 Ok(rows_updated as u64)
5506 }
5507 None => {
5508 if *if_exists {
5509 Ok(0)
5510 } else {
5511 Err(Error::query_execution(format!(
5512 "Column '{}' does not exist in table '{}'", column_name, table_name
5513 )))
5514 }
5515 }
5516 }
5517 }
5518 sql::LogicalPlan::AlterTableRenameColumn { table_name, old_column_name, new_column_name } => {
5519 let catalog = self.storage.catalog();
5520 let mut schema = catalog.get_table_schema(table_name)?;
5521
5522 if schema.columns.iter().any(|c| c.name == *new_column_name) {
5523 return Err(Error::query_execution(format!(
5524 "Column '{}' already exists in table '{}'", new_column_name, table_name
5525 )));
5526 }
5527
5528 let col_idx = schema.columns.iter()
5529 .position(|c| c.name == *old_column_name)
5530 .ok_or_else(|| Error::query_execution(format!(
5531 "Column '{}' does not exist in table '{}'", old_column_name, table_name
5532 )))?;
5533
5534 schema.get_column_at_mut(col_idx)
5535 .ok_or_else(|| Error::internal("column index out of bounds"))?
5536 .name = new_column_name.clone();
5537 catalog.update_table_schema(table_name, &schema)?;
5538 Ok(0)
5539 }
5540 sql::LogicalPlan::AlterTableRename { table_name, new_table_name } => {
5541 let catalog = self.storage.catalog();
5542
5543 if catalog.get_table_schema(new_table_name).is_ok() {
5544 return Err(Error::query_execution(format!(
5545 "Table '{}' already exists", new_table_name
5546 )));
5547 }
5548
5549 self.storage.rename_table(table_name, new_table_name)?;
5550 Ok(0)
5551 }
5552 sql::LogicalPlan::AlterTableMulti { operations } => {
5553 let mut total_rows = 0u64;
5554 for sub_plan in operations {
5555 total_rows += self.execute_alter_table_op(sub_plan)?;
5556 }
5557 Ok(total_rows)
5558 }
5559 sql::LogicalPlan::Savepoint { ref name } => {
5560 let txn = self.current_transaction.lock()
5561 .map_err(|_| Error::query_execution("Failed to lock transaction"))?;
5562 let write_set_snapshot = match txn.as_ref() {
5563 Some(t) => t.savepoint_snapshot(),
5564 None => return Err(Error::query_execution("SAVEPOINT can only be used within a transaction")),
5565 };
5566 drop(txn);
5567 let savepoint = SavepointState {
5568 name: name.clone(),
5569 write_set_snapshot,
5570 };
5571 self.savepoints.write().push(savepoint);
5572 Ok(0)
5573 }
5574 sql::LogicalPlan::ReleaseSavepoint { ref name } => {
5575 let mut savepoints = self.savepoints.write();
5576 if let Some(pos) = savepoints.iter().rposition(|s| &s.name == name) {
5577 savepoints.truncate(pos);
5578 Ok(0)
5579 } else {
5580 Err(Error::query_execution(format!("Savepoint '{}' does not exist", name)))
5581 }
5582 }
5583 sql::LogicalPlan::RollbackToSavepoint { ref name } => {
5584 let savepoints = self.savepoints.read();
5585 if let Some(pos) = savepoints.iter().rposition(|s| &s.name == name) {
5586 let snapshot = savepoints.get(pos)
5587 .map(|s| s.write_set_snapshot.clone());
5588 drop(savepoints);
5589
5590 if let Some(snapshot) = snapshot {
5591 let txn = self.current_transaction.lock()
5592 .map_err(|_| Error::query_execution("Failed to lock transaction"))?;
5593 if let Some(t) = txn.as_ref() {
5594 t.rollback_to_savepoint(&snapshot);
5595 }
5596 drop(txn);
5597 }
5598
5599 let mut savepoints = self.savepoints.write();
5600 savepoints.truncate(pos + 1);
5601 Ok(0)
5602 } else {
5603 Err(Error::query_execution(format!("Savepoint '{}' does not exist", name)))
5604 }
5605 }
5606 sql::LogicalPlan::CreateDatabase { name, if_not_exists } => {
5607 let (count, _) = self.handle_create_database(name, *if_not_exists)?;
5608 Ok(count)
5609 }
5610 sql::LogicalPlan::DropDatabase { name, if_exists } => {
5611 let (count, _) = self.handle_drop_database(name, *if_exists)?;
5612 Ok(count)
5613 }
5614 _ => {
5615 let mut executor = sql::Executor::with_storage(&self.storage)
5617 .with_timeout(self.config.storage.query_timeout_ms);
5618 let results = executor.execute(&plan)?;
5619 Ok(results.len() as u64)
5620 }
5621 }
5622 }
5623
5624 pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<u64> {
5680 let (statement, _) = self.parse_cached(sql)?;
5682
5683 let catalog = self.storage.catalog();
5685 let planner = sql::Planner::with_catalog(&catalog);
5686 let plan = planner.statement_to_plan(statement)?;
5687
5688 let (count, _tuples) = self.execute_plan_with_params(&plan, params)?;
5690 Ok(count)
5691 }
5692
5693 pub fn execute_params_returning(&self, sql: &str, params: &[Value]) -> Result<(u64, Vec<Tuple>)> {
5728 let (statement, _) = self.parse_cached(sql)?;
5730
5731 let catalog = self.storage.catalog();
5733 let planner = sql::Planner::with_catalog(&catalog);
5734 let plan = planner.statement_to_plan(statement)?;
5735
5736 let out = self.execute_plan_with_params(&plan, params);
5738
5739 #[cfg(feature = "code-graph")]
5743 if out.is_ok() {
5744 let touched = Self::touched_table_from_sql(sql);
5745 self.maybe_auto_reparse(touched.as_deref());
5746 }
5747
5748 out
5749 }
5750
5751 fn project_returning_columns(
5762 tuple: &Tuple,
5763 schema: &Schema,
5764 returning_items: &Option<Vec<sql::logical_plan::ReturningItem>>,
5765 ) -> Option<Tuple> {
5766 let items = returning_items.as_ref()?;
5767
5768 let evaluator = sql::Evaluator::new(std::sync::Arc::new(schema.clone()));
5769 let mut projected_values = Vec::with_capacity(items.len());
5770
5771 for item in items {
5772 match item {
5773 sql::logical_plan::ReturningItem::Wildcard => {
5774 return Some(tuple.clone());
5776 }
5777 sql::logical_plan::ReturningItem::Column(col_name) => {
5778 if let Some(col_idx) = schema.get_column_index(col_name) {
5779 if let Some(val) = tuple.values.get(col_idx) {
5780 projected_values.push(val.clone());
5781 } else {
5782 projected_values.push(Value::Null);
5783 }
5784 } else {
5785 projected_values.push(Value::Null);
5787 }
5788 }
5789 sql::logical_plan::ReturningItem::Expression { expr, .. } => {
5790 match evaluator.evaluate(expr, tuple) {
5792 Ok(val) => projected_values.push(val),
5793 Err(_) => projected_values.push(Value::Null),
5794 }
5795 }
5796 }
5797 }
5798
5799 Some(Tuple::new(projected_values))
5800 }
5801
5802 fn resolve_excluded_refs(
5808 expr: &sql::logical_plan::LogicalExpr,
5809 excluded_map: &std::collections::HashMap<String, Value>,
5810 ) -> sql::logical_plan::LogicalExpr {
5811 match expr {
5812 sql::logical_plan::LogicalExpr::Column { table: Some(tbl), name }
5813 if tbl.eq_ignore_ascii_case("excluded") =>
5814 {
5815 if let Some(val) = excluded_map.get(&name.to_lowercase()) {
5817 sql::logical_plan::LogicalExpr::Literal(val.clone())
5818 } else {
5819 expr.clone()
5820 }
5821 }
5822 sql::logical_plan::LogicalExpr::BinaryExpr { left, op, right } => {
5823 sql::logical_plan::LogicalExpr::BinaryExpr {
5824 left: Box::new(Self::resolve_excluded_refs(left, excluded_map)),
5825 op: op.clone(),
5826 right: Box::new(Self::resolve_excluded_refs(right, excluded_map)),
5827 }
5828 }
5829 sql::logical_plan::LogicalExpr::UnaryExpr { op, expr: inner } => {
5830 sql::logical_plan::LogicalExpr::UnaryExpr {
5831 op: op.clone(),
5832 expr: Box::new(Self::resolve_excluded_refs(inner, excluded_map)),
5833 }
5834 }
5835 sql::logical_plan::LogicalExpr::Cast { expr: inner, data_type } => {
5836 sql::logical_plan::LogicalExpr::Cast {
5837 expr: Box::new(Self::resolve_excluded_refs(inner, excluded_map)),
5838 data_type: data_type.clone(),
5839 }
5840 }
5841 sql::logical_plan::LogicalExpr::ScalarFunction { fun, args } => {
5842 sql::logical_plan::LogicalExpr::ScalarFunction {
5843 fun: fun.clone(),
5844 args: args.iter().map(|a| Self::resolve_excluded_refs(a, excluded_map)).collect(),
5845 }
5846 }
5847 other => other.clone(),
5849 }
5850 }
5851
5852 pub(crate) fn returning_schema(
5854 table_schema: &Schema,
5855 returning_items: &[sql::logical_plan::ReturningItem],
5856 ) -> Schema {
5857 let columns = returning_items.iter()
5858 .flat_map(|item| {
5859 match item {
5860 sql::logical_plan::ReturningItem::Wildcard => {
5861 table_schema.columns.clone()
5862 }
5863 sql::logical_plan::ReturningItem::Column(col_name) => {
5864 if let Some(col) = table_schema.columns.iter().find(|c| &c.name == col_name) {
5865 vec![col.clone()]
5866 } else {
5867 vec![Column {
5868 name: col_name.clone(),
5869 data_type: DataType::Text,
5870 nullable: true,
5871 primary_key: false,
5872 source_table: None,
5873 source_table_name: None,
5874 default_expr: None,
5875 unique: false,
5876 storage_mode: crate::ColumnStorageMode::Default,
5877 }]
5878 }
5879 }
5880 sql::logical_plan::ReturningItem::Expression { alias, .. } => {
5881 vec![Column {
5882 name: alias.clone(),
5883 data_type: DataType::Text,
5884 nullable: true,
5885 primary_key: false,
5886 source_table: None,
5887 source_table_name: None,
5888 default_expr: None,
5889 unique: false,
5890 storage_mode: crate::ColumnStorageMode::Default,
5891 }]
5892 }
5893 }
5894 })
5895 .collect();
5896 Schema { columns }
5897 }
5898
5899 fn execute_plan_with_params(&self, plan: &sql::LogicalPlan, params: &[Value]) -> Result<(u64, Vec<Tuple>)> {
5904 let result = self.execute_plan_with_params_inner(plan, params);
5905 if result.is_ok()
5912 && matches!(
5913 plan,
5914 sql::LogicalPlan::Insert { .. }
5915 | sql::LogicalPlan::InsertSelect { .. }
5916 | sql::LogicalPlan::Update { .. }
5917 | sql::LogicalPlan::Delete { .. }
5918 )
5919 {
5920 self.invalidate_result_cache();
5921 }
5922 result
5923 }
5924
5925 fn execute_plan_with_params_inner(&self, plan: &sql::LogicalPlan, params: &[Value]) -> Result<(u64, Vec<Tuple>)> {
5926 if let sql::LogicalPlan::CreateDatabase { name, if_not_exists } = plan {
5932 return self.handle_create_database(name, *if_not_exists);
5933 }
5934 if let sql::LogicalPlan::DropDatabase { name, if_exists } = plan {
5935 return self.handle_drop_database(name, *if_exists);
5936 }
5937 match plan {
5938 sql::LogicalPlan::Insert { table_name, columns, values, returning, on_conflict } => {
5939 let catalog = self.storage.catalog();
5941 let schema = catalog.get_table_schema(table_name)?;
5942
5943 let evaluator = sql::Evaluator::with_parameters(
5945 std::sync::Arc::new(Schema { columns: vec![] }),
5946 params.to_vec(),
5947 );
5948 let empty_tuple = Tuple::new(vec![]);
5949
5950 let has_returning = returning.is_some();
5951 let mut returned_tuples: Vec<Tuple> = Vec::new();
5952 let mut count = 0;
5953 for value_row in values {
5954 let mut tuple_values: Vec<Value> = vec![Value::Null; schema.columns.len()];
5962 let mut user_provided: Vec<bool> = vec![false; schema.columns.len()];
5963
5964 for (col_idx, expr) in value_row.iter().enumerate() {
5965 let target_col_idx = if let Some(ref cols) = columns {
5966 let col_name = cols.get(col_idx)
5967 .ok_or_else(|| Error::internal("column index out of bounds"))?;
5968 schema.get_column_index(col_name)
5969 .ok_or_else(|| Error::query_execution(format!("Column '{}' not found", col_name)))?
5970 } else {
5971 col_idx
5972 };
5973
5974 if matches!(expr, sql::LogicalExpr::DefaultValue) {
5979 continue;
5980 }
5981
5982 let target_col = schema.get_column_at(target_col_idx)
5983 .ok_or_else(|| Error::query_execution(format!(
5984 "Too many values for INSERT: table has {} columns",
5985 schema.columns.len()
5986 )))?;
5987
5988 let target_type = &target_col.data_type;
5989 let mut value = evaluator.evaluate(expr, &empty_tuple)?;
5990
5991 let needs_cast = match (&value, target_type) {
5993 (Value::Null, _) => false,
5994 (Value::Vector(_), DataType::Vector(_)) => false,
5995 (Value::String(_), DataType::Vector(_)) => true,
5996 (Value::String(_), DataType::Json | DataType::Jsonb) => true,
5997 (Value::Int4(_), DataType::Int4) => false,
5998 (Value::Int8(_), DataType::Int8) => false,
5999 (Value::Float4(_), DataType::Float4) => false,
6000 (Value::Float8(_), DataType::Float8) => false,
6001 (Value::String(_), DataType::Text | DataType::Varchar(_)) => false,
6002 (Value::Boolean(_), DataType::Boolean) => false,
6003 (Value::Json(_), DataType::Json | DataType::Jsonb) => false,
6004 _ => true,
6005 };
6006
6007 if needs_cast {
6008 value = evaluator.cast_value(value, target_type)?;
6009 }
6010
6011 if let Some(slot) = tuple_values.get_mut(target_col_idx) {
6012 *slot = value;
6013 }
6014 if let Some(flag) = user_provided.get_mut(target_col_idx) {
6015 *flag = true;
6016 }
6017 }
6018
6019 Self::apply_defaults_and_check_not_null(
6025 &mut tuple_values, &schema, &user_provided,
6026 )?;
6027
6028 let tuple = Tuple::new(tuple_values);
6029
6030 let mut col_values_map = std::collections::HashMap::with_capacity(schema.columns.len());
6039 for (i, col) in schema.columns.iter().enumerate() {
6040 if let Some(v) = tuple.values.get(i) {
6041 col_values_map.insert(col.name.clone(), v.clone());
6042 }
6043 }
6044
6045 let conflict = self.storage.art_indexes()
6046 .check_unique_constraints(table_name, &col_values_map);
6047
6048 match (conflict, on_conflict) {
6049 (Ok(()), _) => {
6050 self.check_fk_constraints_on_write(table_name, &col_values_map, None)?;
6063 let row_id = self.storage.insert_tuple_branch_aware_with_schema(
6064 table_name, tuple.clone(), &schema,
6065 )?;
6066 if has_returning {
6067 let mut filled = tuple;
6068 for (i, col) in schema.columns.iter().enumerate() {
6069 if col.primary_key {
6070 if let Some(v) = filled.values.get(i) {
6071 if matches!(v, Value::Null) {
6072 if let Some(slot) = filled.values.get_mut(i) {
6073 *slot = match col.data_type {
6074 DataType::Int2 => Value::Int2(row_id as i16),
6075 DataType::Int4 => Value::Int4(row_id as i32),
6076 _ => Value::Int8(row_id as i64),
6077 };
6078 }
6079 }
6080 }
6081 }
6082 }
6083 filled.row_id = Some(row_id);
6084 if let Some(projected) = Self::project_returning_columns(&filled, &schema, returning) {
6085 returned_tuples.push(projected);
6086 }
6087 }
6088 count += 1;
6089 }
6090 (Err(_), Some(sql::logical_plan::OnConflictAction::DoNothing)) => {
6091 }
6093 (Err(e), Some(sql::logical_plan::OnConflictAction::DoUpdate { assignments })) => {
6094 let mut found_row_id: Option<u64> = None;
6110 let art = self.storage.art_indexes();
6111 for (i, col) in schema.columns.iter().enumerate() {
6112 if (col.unique || col.primary_key) && !col.primary_key {
6113 if let Some(val) = tuple.values.get(i) {
6114 if !matches!(val, Value::Null) {
6115 if let Some(name) = art.find_column_index(table_name, &col.name) {
6119 let key = storage::ArtIndexManager::encode_key(&[val.clone()]);
6120 let row_ids = art.index_get_all(&name, &key);
6121 if let Some(rid) = row_ids.first() {
6122 found_row_id = Some(*rid);
6123 }
6124 }
6125 if found_row_id.is_some() { break; }
6126 }
6127 }
6128 }
6129 }
6130 if found_row_id.is_none() {
6131 let pk_values: Vec<Value> = schema.columns.iter().enumerate()
6132 .filter(|(_, c)| c.primary_key)
6133 .filter_map(|(i, _)| tuple.values.get(i).cloned())
6134 .collect();
6135 if !pk_values.is_empty()
6136 && !pk_values.iter().any(|v| matches!(v, Value::Null))
6137 {
6138 let pk_key = crate::storage::ArtIndexManager::encode_key(&pk_values);
6139 found_row_id = self.storage.art_indexes()
6140 .pk_index_lookup(table_name, &pk_key);
6141 }
6142 }
6143 let existing_row_id = found_row_id.ok_or_else(|| Error::query_execution(
6144 format!("ON CONFLICT DO UPDATE: could not find existing row ({})", e)
6145 ))?;
6146
6147 let existing_key = self.storage.branch_aware_data_key(
6149 table_name, existing_row_id,
6150 );
6151 let existing_raw = self.storage.get(&existing_key)?
6152 .ok_or_else(|| Error::query_execution(
6153 "ON CONFLICT DO UPDATE: existing row not found in storage"
6154 ))?;
6155 let mut updated_tuple: Tuple = bincode::deserialize(&existing_raw)
6156 .map_err(|err| Error::storage(format!("Failed to deserialize tuple: {}", err)))?;
6157 updated_tuple.row_id = Some(existing_row_id);
6158
6159 let old_tuple_for_art: Tuple = bincode::deserialize(&existing_raw)
6161 .map_err(|err| Error::storage(format!("Failed to deserialize tuple: {}", err)))?;
6162
6163 let mut excluded_map = std::collections::HashMap::new();
6169 for (i, col) in schema.columns.iter().enumerate() {
6170 if let Some(v) = tuple.values.get(i) {
6171 excluded_map.insert(col.name.to_lowercase(), v.clone());
6172 }
6173 }
6174 let update_eval = sql::Evaluator::with_parameters(
6175 std::sync::Arc::new(schema.clone()),
6176 params.to_vec(),
6177 );
6178 for (col_name, expr) in assignments {
6179 let target_idx = schema.columns.iter()
6180 .position(|c| c.name.eq_ignore_ascii_case(col_name))
6181 .ok_or_else(|| Error::query_execution(format!(
6182 "ON CONFLICT DO UPDATE: column '{}' not found", col_name
6183 )))?;
6184 let resolved_expr = Self::resolve_excluded_refs(expr, &excluded_map);
6185 let mut new_val = update_eval.evaluate(&resolved_expr, &updated_tuple)?;
6186 let target_type = &schema.columns.get(target_idx)
6187 .ok_or_else(|| Error::internal("column index out of bounds"))?
6188 .data_type;
6189 if new_val.data_type() != *target_type
6190 && !matches!(new_val, Value::Null)
6191 {
6192 new_val = update_eval.cast_value(new_val, target_type)?;
6193 }
6194 if target_idx < updated_tuple.values.len() {
6195 #[allow(clippy::indexing_slicing)]
6196 { updated_tuple.values[target_idx] = new_val; }
6197 }
6198 }
6199
6200 self.storage.update_tuple_fast(
6204 table_name,
6205 existing_row_id,
6206 updated_tuple.clone(),
6207 &old_tuple_for_art,
6208 &schema,
6209 )?;
6210
6211 if has_returning {
6212 if let Some(projected) = Self::project_returning_columns(
6213 &updated_tuple, &schema, returning,
6214 ) {
6215 returned_tuples.push(projected);
6216 }
6217 }
6218 count += 1;
6219 }
6220 (Err(e), None) => {
6221 return Err(Error::constraint_violation(e.to_string()));
6223 }
6224 }
6225 }
6226 Ok((count, returned_tuples))
6227 }
6228 sql::LogicalPlan::InsertSelect { table_name, columns, source, returning } => {
6229 let mut executor = sql::Executor::with_storage(&self.storage)
6231 .with_timeout(self.config.storage.query_timeout_ms);
6232 let source_rows = executor.execute(source)?;
6233
6234 let catalog = self.storage.catalog();
6235 let schema = catalog.get_table_schema(table_name)?;
6236 let evaluator = sql::Evaluator::new(std::sync::Arc::new(Schema { columns: vec![] }));
6237
6238 let column_indices: Option<Vec<usize>> = columns.as_ref().map(|cols| {
6239 cols.iter()
6240 .filter_map(|col_name| schema.get_column_index(col_name))
6241 .collect()
6242 });
6243
6244 let has_returning = returning.is_some();
6245 let mut returned_tuples: Vec<Tuple> = Vec::new();
6246 let mut count = 0u64;
6247
6248 for source_row in &source_rows {
6249 let mut tuple_values: Vec<Value> = Vec::new();
6250
6251 for (val_idx, value) in source_row.values.iter().enumerate() {
6252 let target_col_idx = if let Some(ref indices) = column_indices {
6253 *indices.get(val_idx).ok_or_else(|| Error::internal("column index out of bounds"))?
6254 } else {
6255 val_idx
6256 };
6257
6258 let target_col = schema.get_column_at(target_col_idx)
6259 .ok_or_else(|| Error::query_execution(format!(
6260 "Too many values for INSERT: table has {} columns",
6261 schema.columns.len()
6262 )))?;
6263
6264 let target_type = &target_col.data_type;
6265 let mut val = value.clone();
6266
6267 let needs_cast = match (&val, target_type) {
6268 (Value::Null, _) => false,
6269 (Value::Vector(_), DataType::Vector(_)) => false,
6270 (Value::String(_), DataType::Vector(_)) => true,
6271 (Value::String(_), DataType::Json | DataType::Jsonb) => true,
6272 (Value::Int4(_), DataType::Int4) => false,
6273 (Value::Int8(_), DataType::Int8) => false,
6274 (Value::Float4(_), DataType::Float4) => false,
6275 (Value::Float8(_), DataType::Float8) => false,
6276 (Value::String(_), DataType::Text | DataType::Varchar(_)) => false,
6277 (Value::Boolean(_), DataType::Boolean) => false,
6278 (Value::Json(_), DataType::Json | DataType::Jsonb) => false,
6279 _ => true,
6280 };
6281
6282 if needs_cast {
6283 val = evaluator.cast_value(val, target_type)?;
6284 }
6285
6286 tuple_values.push(val);
6287 }
6288
6289 let tuple = Tuple::new(tuple_values);
6290 if has_returning {
6291 if let Some(projected) = Self::project_returning_columns(&tuple, &schema, returning) {
6292 returned_tuples.push(projected);
6293 }
6294 }
6295 self.storage.insert_tuple_branch_aware_with_schema(table_name, tuple, &schema)?;
6296 count += 1;
6297 }
6298 Ok((count, returned_tuples))
6299 }
6300 sql::LogicalPlan::Update { table_name, assignments, selection, returning } => {
6301 let catalog = self.storage.catalog();
6302 let schema = catalog.get_table_schema(table_name)?;
6303 let eval_schema = schema.clone().with_source_table_name(table_name);
6307 let evaluator = sql::Evaluator::with_parameters(
6308 std::sync::Arc::new(eval_schema),
6309 params.to_vec(),
6310 );
6311
6312 let tuples = self.storage.scan_table_branch_aware(table_name)?;
6314 let mut updates: Vec<(u64, Tuple)> = Vec::new();
6315
6316 for mut tuple in tuples {
6317 let matches = if let Some(predicate) = selection {
6318 let result = evaluator.evaluate(predicate, &tuple)?;
6319 match result {
6320 Value::Boolean(b) => b,
6321 _ => false,
6322 }
6323 } else {
6324 true
6325 };
6326
6327 if matches {
6328 for (col_name, value_expr) in assignments {
6329 let bound = self.materialize_scalar_subqueries_for_row(
6330 value_expr, &tuple, &schema, table_name,
6331 )?;
6332 let mut new_value = evaluator.evaluate(&bound, &tuple)?;
6333 let col_index = evaluator.schema().get_column_index(col_name)
6334 .ok_or_else(|| Error::query_execution(format!("Column '{}' not found", col_name)))?;
6335 let target_col = schema.get_column_at(col_index)
6343 .ok_or_else(|| Error::query_execution(format!("Column '{}' not found", col_name)))?;
6344 let target_type = &target_col.data_type;
6345 let needs_cast = match (&new_value, target_type) {
6346 (Value::Null, _) => false,
6347 (Value::Vector(_), DataType::Vector(_)) => false,
6348 (Value::String(_), DataType::Vector(_)) => true,
6349 (Value::String(_), DataType::Json | DataType::Jsonb) => true,
6350 (Value::Int2(_), DataType::Int2) => false,
6351 (Value::Int4(_), DataType::Int4) => false,
6352 (Value::Int8(_), DataType::Int8) => false,
6353 (Value::Float4(_), DataType::Float4) => false,
6354 (Value::Float8(_), DataType::Float8) => false,
6355 (Value::String(_), DataType::Text | DataType::Varchar(_)) => false,
6356 (Value::Boolean(_), DataType::Boolean) => false,
6357 (Value::Json(_), DataType::Json | DataType::Jsonb) => false,
6358 (Value::Timestamp(_), DataType::Timestamp | DataType::Timestamptz) => false,
6359 (Value::Date(_), DataType::Date) => false,
6360 _ => true,
6361 };
6362 if needs_cast {
6363 new_value = evaluator.cast_value(new_value, target_type)?;
6364 }
6365 if let Some(slot) = tuple.values.get_mut(col_index) {
6366 *slot = new_value;
6367 }
6368 }
6369
6370 let mut new_col_values = std::collections::HashMap::with_capacity(schema.columns.len());
6378 for (i, col) in schema.columns.iter().enumerate() {
6379 if let Some(v) = tuple.values.get(i) {
6380 new_col_values.insert(col.name.clone(), v.clone());
6381 }
6382 }
6383 self.check_fk_constraints_on_write(table_name, &new_col_values, None)?;
6384
6385 let row_id = tuple.row_id.unwrap_or(0);
6386 updates.push((row_id, tuple));
6387 }
6388 }
6389
6390 let returned_tuples: Vec<Tuple> = if returning.is_some() {
6392 updates.iter()
6393 .filter_map(|(_, tuple)| Self::project_returning_columns(tuple, &schema, returning))
6394 .collect()
6395 } else {
6396 Vec::new()
6397 };
6398
6399 let count = self.storage.update_tuples_branch_aware(table_name, updates)?;
6401 Ok((count, returned_tuples))
6402 }
6403 sql::LogicalPlan::Delete { table_name, selection, returning } => {
6404 let catalog = self.storage.catalog();
6405 let schema = catalog.get_table_schema(table_name)?;
6406 let eval_schema = schema.clone().with_source_table_name(table_name);
6407 let evaluator = sql::Evaluator::with_parameters(
6408 std::sync::Arc::new(eval_schema),
6409 params.to_vec(),
6410 );
6411
6412 let tuples = self.storage.scan_table_branch_aware(table_name)?;
6414 let mut row_ids_to_delete: Vec<u64> = Vec::new();
6415 let mut deleted_tuples: Vec<(u64, Tuple)> = Vec::new();
6416 let mut returned_tuples: Vec<Tuple> = Vec::new();
6417 let has_returning = returning.is_some();
6418
6419 for tuple in tuples {
6420 let matches = if let Some(predicate) = selection {
6421 let result = evaluator.evaluate(predicate, &tuple)?;
6422 match result {
6423 Value::Boolean(b) => b,
6424 _ => false,
6425 }
6426 } else {
6427 true
6428 };
6429
6430 if matches {
6431 if has_returning {
6433 if let Some(projected) = Self::project_returning_columns(&tuple, &schema, returning) {
6434 returned_tuples.push(projected);
6435 }
6436 }
6437
6438 if let Some(row_id) = tuple.row_id {
6439 row_ids_to_delete.push(row_id);
6440 deleted_tuples.push((row_id, tuple.clone()));
6441 }
6442 }
6443 }
6444
6445 for (row_id, tuple) in &deleted_tuples {
6447 let mut col_values = std::collections::HashMap::new();
6448 for (i, col) in schema.columns.iter().enumerate() {
6449 if let Some(v) = tuple.values.get(i) {
6450 col_values.insert(col.name.clone(), v.clone());
6451 }
6452 }
6453 if let Err(e) = self.storage.art_indexes().on_delete(table_name, *row_id, &col_values) {
6454 tracing::debug!("ART index delete for table '{}': {}", table_name, e);
6455 }
6456 }
6457
6458 let count = self.storage.delete_tuples_branch_aware(table_name, row_ids_to_delete)?;
6460 Ok((count, returned_tuples))
6461 }
6462 sql::LogicalPlan::StartTransaction => {
6464 self.begin_transaction_internal()?;
6465 Ok((0, Vec::new()))
6466 }
6467 sql::LogicalPlan::Commit => {
6468 self.commit_internal()?;
6469 Ok((0, Vec::new()))
6470 }
6471 sql::LogicalPlan::Rollback => {
6472 self.rollback_internal()?;
6473 Ok((0, Vec::new()))
6474 }
6475 sql::LogicalPlan::Savepoint { name } => {
6477 let txn = self.current_transaction.lock()
6479 .map_err(|_| Error::query_execution("Failed to lock transaction"))?;
6480 let write_set_snapshot = match txn.as_ref() {
6481 Some(t) => t.savepoint_snapshot(),
6482 None => return Err(Error::query_execution("SAVEPOINT can only be used within a transaction")),
6483 };
6484 drop(txn);
6485
6486 let savepoint = SavepointState {
6487 name: name.clone(),
6488 write_set_snapshot,
6489 };
6490 self.savepoints.write().push(savepoint);
6491 Ok((0, Vec::new()))
6492 }
6493 sql::LogicalPlan::ReleaseSavepoint { name } => {
6494 let mut savepoints = self.savepoints.write();
6495 if let Some(pos) = savepoints.iter().rposition(|s| &s.name == name) {
6497 savepoints.truncate(pos);
6498 Ok((0, Vec::new()))
6499 } else {
6500 Err(Error::query_execution(format!("Savepoint '{}' does not exist", name)))
6501 }
6502 }
6503 sql::LogicalPlan::RollbackToSavepoint { name } => {
6504 let savepoints = self.savepoints.read();
6505 if let Some(pos) = savepoints.iter().rposition(|s| &s.name == name) {
6507 let snapshot = savepoints.get(pos)
6508 .map(|s| s.write_set_snapshot.clone());
6509 drop(savepoints);
6510
6511 if let Some(snapshot) = snapshot {
6513 let txn = self.current_transaction.lock()
6514 .map_err(|_| Error::query_execution("Failed to lock transaction"))?;
6515 if let Some(t) = txn.as_ref() {
6516 t.rollback_to_savepoint(&snapshot);
6517 }
6518 drop(txn);
6519 }
6520
6521 let mut savepoints = self.savepoints.write();
6523 savepoints.truncate(pos + 1);
6524 Ok((0, Vec::new()))
6525 } else {
6526 Err(Error::query_execution(format!("Savepoint '{}' does not exist", name)))
6527 }
6528 }
6529 sql::LogicalPlan::Prepare { name, statement, .. } => {
6531 self.prepared_statements.write().insert(name.clone(), *statement.clone());
6533 Ok((0, Vec::new()))
6534 }
6535 sql::LogicalPlan::Execute { name, parameters } => {
6536 let stmt = {
6538 let stmts = self.prepared_statements.read();
6539 stmts.get(name).cloned()
6540 };
6541 if let Some(plan) = stmt {
6542 let empty_tuple = Tuple::new(vec![]);
6544 let empty_schema = std::sync::Arc::new(Schema { columns: vec![] });
6545 let evaluator = sql::Evaluator::new(empty_schema);
6546 let param_values: Result<Vec<Value>> = parameters.iter()
6547 .map(|expr| evaluator.evaluate(expr, &empty_tuple))
6548 .collect();
6549 self.execute_plan_with_params(&plan, ¶m_values?)
6551 } else {
6552 Err(Error::query_execution(format!("Prepared statement '{}' does not exist", name)))
6553 }
6554 }
6555 sql::LogicalPlan::Deallocate { name } => {
6556 if let Some(ref stmt_name) = name {
6557 let removed = self.prepared_statements.write().remove(stmt_name);
6559 if removed.is_none() {
6560 return Err(Error::query_execution(format!("Prepared statement '{}' does not exist", stmt_name)));
6561 }
6562 } else {
6563 self.prepared_statements.write().clear();
6565 }
6566 Ok((0, Vec::new()))
6567 }
6568 _ => {
6569 let mut executor = sql::Executor::with_storage(&self.storage)
6571 .with_timeout(self.config.storage.query_timeout_ms)
6572 .with_parameters(params.to_vec());
6573 let results = executor.execute(plan)?;
6574 Ok((results.len() as u64, Vec::new()))
6575 }
6576 }
6577 }
6578
6579 pub fn query(&self, sql: &str, _params: &[&dyn std::fmt::Display]) -> Result<Vec<Tuple>> {
6601 if let Some((name, arg)) = crate::sql::sqlite_compat::parse_pragma(sql) {
6605 return self.handle_pragma_query(&name, arg.as_deref());
6606 }
6607
6608 #[cfg(feature = "graph-rag")]
6613 if let Some((inner, opts)) = graph_rag::detect_with_context(sql) {
6614 return self.run_with_context(&inner, &opts);
6615 }
6616 #[cfg(feature = "code-graph")]
6622 let (rewritten_owned, _branch_guard) = self.rewrite_and_scope(sql);
6623 #[cfg(feature = "code-graph")]
6624 let sql: &str = &rewritten_owned;
6625 #[cfg(not(feature = "code-graph"))]
6626 let sql: &str = sql;
6627 let start = std::time::Instant::now();
6628
6629 {
6632 let upper = sql.trim().to_uppercase();
6633 let is_dml = upper.starts_with("INSERT")
6634 || upper.starts_with("UPDATE")
6635 || upper.starts_with("DELETE");
6636 if is_dml && upper.contains("RETURNING") {
6637 let (_count, tuples) = self.execute_returning(sql)?;
6638 self.invalidate_result_cache();
6640 self.log_slow_query(sql, start.elapsed(), tuples.len() as u64);
6641 return Ok(tuples);
6642 }
6643 }
6644
6645 {
6648 use crate::error::LockResultExt;
6649 let has_active_txn = {
6650 let txn_lock = self.current_transaction.lock()
6651 .map_lock_err("Failed to acquire transaction lock for query")?;
6652 txn_lock.is_some()
6653 };
6654 if has_active_txn {
6655 let txn_lock = self.current_transaction.lock()
6656 .map_lock_err("Failed to acquire transaction lock for query")?;
6657 let txn_ref = txn_lock.as_ref()
6658 .ok_or_else(|| Error::transaction("Transaction lock in invalid state"))?;
6659 let (statement, _) = self.parse_cached(sql)?;
6661 let catalog = self.storage.catalog();
6662 let planner = sql::Planner::with_catalog(&catalog)
6663 .with_sql(sql.to_string());
6664 let plan = planner.statement_to_plan(statement)?;
6665 let mut executor = sql::Executor::with_storage(&self.storage)
6666 .with_timeout(self.config.storage.query_timeout_ms)
6667 .with_transaction(txn_ref);
6668 let results = executor.execute(&plan)?;
6669 self.log_slow_query(sql, start.elapsed(), results.len() as u64);
6670 return Ok(results);
6671 }
6672 }
6673
6674 let is_non_deterministic = {
6681 let up = sql.to_ascii_uppercase();
6682 ["NEXTVAL", "SETVAL", "CURRVAL", "GEN_RANDOM_UUID",
6683 "UUID_GENERATE_V4", "RANDOM(", "NOW(", "CLOCK_TIMESTAMP"]
6684 .iter()
6685 .any(|m| up.contains(m))
6686 };
6687 if !is_non_deterministic {
6688 if let Some(cached_results) = self.result_cache.lock().ok()
6689 .and_then(|mut cache| cache.get(sql).map(std::sync::Arc::clone))
6690 {
6691 tracing::debug!(phase = "result_cache", "Result cache hit");
6692 self.log_slow_query(sql, start.elapsed(), cached_results.len() as u64);
6693 return Ok((*cached_results).clone());
6694 }
6695 }
6696
6697 if let Some(result) = self.try_fast_select(sql) {
6699 let results = result?;
6700 self.log_slow_query(sql, start.elapsed(), results.len() as u64);
6701 return Ok(results);
6702 }
6703
6704 let cached_plan = self.plan_cache.lock().ok().and_then(|mut cache| cache.get(sql).map(std::sync::Arc::clone));
6706
6707 if let Some(arc_plan) = cached_plan {
6708 tracing::debug!(phase = "parse", duration_us = 0_u64, "SQL parsed (cached)");
6709 tracing::debug!(phase = "plan", duration_us = 0_u64, "Logical plan created (cached)");
6710
6711 if self.tenant_manager.get_current_context().is_none() {
6713 let exec_start = std::time::Instant::now();
6714 let mut executor = sql::Executor::with_storage(&self.storage)
6715 .with_timeout(self.config.storage.query_timeout_ms);
6716 let results = executor.execute(&arc_plan)?;
6717 tracing::debug!(phase = "execute", duration_us = exec_start.elapsed().as_micros() as u64, rows = results.len() as u64, "Query executed");
6718 self.log_slow_query(sql, start.elapsed(), results.len() as u64);
6719 if let Ok(mut cache) = self.result_cache.lock() {
6721 cache.put(sql.to_string(), std::sync::Arc::new(results.clone()));
6722 }
6723 return Ok(results);
6724 }
6725
6726 let plan = self.apply_rls_to_plan((*arc_plan).clone())?;
6728 let exec_start = std::time::Instant::now();
6729 let mut executor = sql::Executor::with_storage(&self.storage)
6730 .with_timeout(self.config.storage.query_timeout_ms);
6731 let results = executor.execute(&plan)?;
6732 tracing::debug!(phase = "execute", duration_us = exec_start.elapsed().as_micros() as u64, rows = results.len() as u64, "Query executed");
6733 self.log_slow_query(sql, start.elapsed(), results.len() as u64);
6734 return Ok(results);
6735 }
6736
6737 let parse_start = std::time::Instant::now();
6740 let (statement, _parse_cached) = self.parse_cached(sql)?;
6741 tracing::debug!(phase = "parse", duration_us = parse_start.elapsed().as_micros() as u64, "SQL parsed");
6742
6743 let plan_start = std::time::Instant::now();
6745 let catalog = self.storage.catalog();
6746 let planner = sql::Planner::with_catalog(&catalog)
6747 .with_sql(sql.to_string());
6748 let plan = planner.statement_to_plan(statement)?;
6749 tracing::debug!(phase = "plan", duration_us = plan_start.elapsed().as_micros() as u64, "Logical plan created");
6750
6751 let plan = {
6753 let opt_start = std::time::Instant::now();
6754 let stats = optimizer::cost::StatsCatalog::new();
6755 let rules: Vec<Box<dyn optimizer::rules::OptimizationRule>> = vec![
6756 Box::new(optimizer::rules::ConstantFoldingRule::new()),
6757 Box::new(optimizer::rules::SelectionPushdownRule::new()),
6758 Box::new(optimizer::rules::JoinPredicatePushdownRule::new()),
6759 Box::new(optimizer::rules::ProjectionPruningRule::new()),
6760 ];
6761 let opt = optimizer::Optimizer::with_rules(
6762 stats,
6763 rules,
6764 optimizer::OptimizerConfig::default(),
6765 );
6766 let optimized = opt.optimize_recursive(plan)?;
6767 tracing::debug!(phase = "optimize", duration_us = opt_start.elapsed().as_micros() as u64, "Plan optimized");
6768 optimized
6769 };
6770
6771 if let Ok(mut cache) = self.plan_cache.lock() {
6773 cache.put(sql.to_string(), std::sync::Arc::new(plan.clone()));
6774 }
6775
6776 let plan = self.apply_rls_to_plan(plan)?;
6778
6779 let exec_start = std::time::Instant::now();
6781 let mut executor = sql::Executor::with_storage(&self.storage)
6782 .with_timeout(self.config.storage.query_timeout_ms);
6783 let results = executor.execute(&plan)?;
6784 tracing::debug!(phase = "execute", duration_us = exec_start.elapsed().as_micros() as u64, rows = results.len() as u64, "Query executed");
6785
6786 self.log_slow_query(sql, start.elapsed(), results.len() as u64);
6787
6788 if let Ok(mut cache) = self.result_cache.lock() {
6790 cache.put(sql.to_string(), std::sync::Arc::new(results.clone()));
6791 }
6792
6793 Ok(results)
6794 }
6795
6796 pub fn query_with_columns(&self, sql: &str) -> Result<(Vec<Tuple>, Vec<String>)> {
6802 let plan = if sql::Parser::is_show_branches(sql) {
6807 sql::LogicalPlan::ShowBranches
6808 } else {
6809 let (statement, _) = self.parse_cached(sql)?;
6810 let catalog = self.storage.catalog();
6811 let planner = sql::Planner::with_catalog(&catalog)
6812 .with_sql(sql.to_string());
6813 planner.statement_to_plan(statement)?
6814 };
6815
6816 let plan = {
6817 let stats = optimizer::cost::StatsCatalog::new();
6818 let rules: Vec<Box<dyn optimizer::rules::OptimizationRule>> = vec![
6819 Box::new(optimizer::rules::ConstantFoldingRule::new()),
6820 Box::new(optimizer::rules::SelectionPushdownRule::new()),
6821 Box::new(optimizer::rules::JoinPredicatePushdownRule::new()),
6822 Box::new(optimizer::rules::ProjectionPruningRule::new()),
6823 ];
6824 let opt = optimizer::Optimizer::with_rules(
6825 stats,
6826 rules,
6827 optimizer::OptimizerConfig::default(),
6828 );
6829 opt.optimize_recursive(plan)?
6830 };
6831
6832 let mut executor = sql::Executor::with_storage(&self.storage)
6833 .with_timeout(self.config.storage.query_timeout_ms);
6834 executor.execute_with_columns(&plan)
6835 }
6836
6837 pub fn dump_full(&self, path: &std::path::Path) -> Result<storage::DumpMetadata> {
6850 self.dump_manager.create_full_dump(path, self)
6851 }
6852
6853 pub fn dump_sql(&self, path: &std::path::Path) -> Result<storage::DumpMetadata> {
6863 self.dump_manager.create_sql_dump(path, self)
6864 }
6865
6866 pub fn dump_incremental(&self, path: &std::path::Path) -> Result<storage::DumpMetadata> {
6875 self.dump_manager.create_incremental_dump(path, self, false)
6876 }
6877
6878 pub fn dump_incremental_append(&self, path: &std::path::Path) -> Result<storage::DumpMetadata> {
6886 self.dump_manager.create_incremental_dump(path, self, true)
6887 }
6888
6889 pub fn restore_from_dump(&mut self, path: &std::path::Path) -> Result<()> {
6898 let dump_manager = self.dump_manager.clone();
6899 dump_manager.restore_from_dump(path, self)
6900 }
6901
6902 pub fn dump_full_compressed(&self, path: &std::path::Path, compression: storage::DumpCompressionType) -> Result<storage::DumpMetadata> {
6909 let manager = storage::DumpManager::new(
6911 path.parent().unwrap_or(std::path::Path::new(".")).to_path_buf(),
6912 compression
6913 );
6914 manager.create_full_dump(path, self)
6915 }
6916
6917 pub fn dump_full_uncompressed(&self, path: &std::path::Path) -> Result<storage::DumpMetadata> {
6925 self.dump_full_compressed(path, storage::DumpCompressionType::None)
6926 }
6927
6928 pub fn dump_tables(&self, path: &std::path::Path, tables: Vec<&str>) -> Result<storage::DumpMetadata> {
6937 self.dump_full(path)
6941 }
6942
6943 pub fn restore_tables(&mut self, path: &std::path::Path, _tables: Vec<&str>) -> Result<()> {
6950 self.restore_from_dump(path)
6952 }
6953
6954 pub fn read_dump_metadata(&self, path: &std::path::Path) -> Result<storage::DumpMetadata> {
6963 use std::io::{Read, Seek, SeekFrom};
6964 let file = std::fs::File::open(path).map_err(|e| Error::io(e.to_string()))?;
6965 let mut reader = std::io::BufReader::new(file);
6966
6967 reader.seek(SeekFrom::Start(12)).map_err(|e| Error::io(e.to_string()))?;
6969
6970 let mut len_bytes = [0u8; 4];
6972 reader.read_exact(&mut len_bytes).map_err(|e| Error::io(e.to_string()))?;
6973 let len = u32::from_le_bytes(len_bytes) as usize;
6974
6975 if len == 0 || len > 8192 {
6976 return Err(Error::io("Invalid metadata length".to_string()));
6977 }
6978
6979 let mut json_bytes = vec![0u8; len];
6981 reader.read_exact(&mut json_bytes).map_err(|e| Error::io(e.to_string()))?;
6982
6983 let metadata: storage::DumpMetadata = serde_json::from_slice(&json_bytes)
6984 .map_err(|e| Error::io(format!("Failed to deserialize metadata: {}", e)))?;
6985
6986 Ok(metadata)
6987 }
6988
6989 pub fn create_session(&self, user_name: &str, isolation: crate::session::IsolationLevel) -> Result<crate::session::SessionId> {
7026 let user = crate::session::User::new_passwordless(user_name);
7027 self.session_manager.create_session(&user, isolation)
7028 }
7029
7030 pub fn destroy_session(&self, session_id: crate::session::SessionId) -> Result<()> {
7040 self.session_manager.destroy_session(session_id)
7041 }
7042
7043 pub fn begin_transaction_for_session(&self, session_id: crate::session::SessionId) -> Result<()> {
7056 let session_lock = self.session_manager.get_session(session_id)?;
7057 let mut session = session_lock.write();
7058
7059 if session.active_txn.is_some() {
7060 return Err(Error::transaction("Session already has an active transaction"));
7061 }
7062
7063 let txn = storage::Transaction::new_with_session(
7065 self.storage.db.clone(),
7066 self.storage.next_timestamp(),
7067 self.storage.snapshot_manager_arc(),
7068 session_id,
7069 session.isolation_level,
7070 self.lock_manager.clone(),
7071 self.dirty_tracker.clone(),
7072 )?;
7073
7074 let txn_id = txn.snapshot_id();
7075 session.active_txn = Some(txn_id);
7076 session.stats.transactions_started += 1;
7077
7078 self.session_transactions.insert(session_id, txn);
7080
7081 Ok(())
7082 }
7083
7084 pub fn commit_transaction_for_session(&self, session_id: crate::session::SessionId) -> Result<()> {
7097 let session_lock = self.session_manager.get_session(session_id)?;
7098 let mut session = session_lock.write();
7099
7100 if session.active_txn.is_none() {
7101 return Err(Error::transaction("Session has no active transaction to commit"));
7102 }
7103
7104 if let Some((_, txn)) = self.session_transactions.remove(&session_id) {
7106 txn.commit_with_timestamp(self.storage.next_timestamp())?;
7107 self.storage.increment_lsn();
7108 }
7109
7110 self.invalidate_result_cache();
7112
7113 session.active_txn = None;
7114 session.stats.transactions_committed += 1;
7115 Ok(())
7116 }
7117
7118 pub fn rollback_transaction_for_session(&self, session_id: crate::session::SessionId) -> Result<()> {
7132 let session_lock = self.session_manager.get_session(session_id)?;
7133 let mut session = session_lock.write();
7134
7135 if session.active_txn.is_none() {
7136 return Err(Error::transaction("Session has no active transaction to rollback"));
7137 }
7138
7139 if let Some((_, txn)) = self.session_transactions.remove(&session_id) {
7141 txn.rollback()?;
7142 }
7143
7144 self.invalidate_result_cache();
7146
7147 session.active_txn = None;
7148 session.stats.transactions_aborted += 1;
7149 Ok(())
7150 }
7151
    /// Executes a data-modifying statement `sql` on behalf of `session_id` and
    /// returns the count produced by `execute_in_transaction_no_fast_path`.
    ///
    /// The session's write lock (`session`) is taken first and, being a local
    /// guard, stays held until this function returns. Each call bumps
    /// `queries_executed` and touches the session.
    ///
    /// Two modes:
    /// - **Explicit transaction** (an entry exists in `session_transactions`):
    ///   under `ReadCommitted` the transaction's snapshot is refreshed to the
    ///   current timestamp before the statement runs; the statement then runs
    ///   inside that transaction and nothing is committed here.
    /// - **Autocommit** (no entry): a new transaction is created for the
    ///   session; on success it is committed with a fresh timestamp and the LSN
    ///   is incremented; on failure it is rolled back (rollback errors are
    ///   ignored) and the statement's error is returned.
    ///
    /// # Errors
    /// Session lookup failure, transaction creation/commit failure, statement
    /// execution failure, or the session's transaction vanishing between the
    /// `contains_key` check and the `get`.
    pub fn execute_in_session(&self, session_id: crate::session::SessionId, sql: &str) -> Result<u64> {
        let session_lock = self.session_manager.get_session(session_id)?;
        let mut session = session_lock.write();
        session.touch();
        session.stats.queries_executed += 1;

        if self.session_transactions.contains_key(&session_id) {
            // ReadCommitted: each statement sees data committed before it starts.
            // The mutable map guard is scoped to this `if let` so it is released
            // before the shared `get` below.
            if session.isolation_level == crate::session::IsolationLevel::ReadCommitted {
                if let Some(mut txn) = self.session_transactions.get_mut(&session_id) {
                    txn.refresh_snapshot(self.storage.current_timestamp());
                }
            }

            // NOTE(review): the map guard (presumably a DashMap `Ref`) is held
            // while the statement executes — confirm execution never needs to
            // write `session_transactions` on the same shard.
            let txn = self.session_transactions.get(&session_id)
                .ok_or_else(|| Error::transaction("Session transaction disappeared during execute"))?;

            self.execute_in_transaction_no_fast_path(sql, &txn)
        } else {
            // Autocommit: single-statement transaction bound to this session.
            let txn = storage::Transaction::new_with_session(
                self.storage.db.clone(),
                self.storage.next_timestamp(),
                self.storage.snapshot_manager_arc(),
                session_id,
                session.isolation_level,
                self.lock_manager.clone(),
                self.dirty_tracker.clone(),
            )?;

            let result = self.execute_in_transaction_no_fast_path(sql, &txn);

            match result {
                Ok(count) => {
                    txn.commit_with_timestamp(self.storage.next_timestamp())?;
                    self.storage.increment_lsn();
                    Ok(count)
                }
                Err(e) => {
                    // Best-effort rollback; the statement error is what the caller needs.
                    let _ = txn.rollback();
                    Err(e)
                }
            }
        }
    }
7222
    /// Runs a read query `sql` on behalf of `session_id` and returns its rows.
    ///
    /// The session's write lock is held for the whole call; each call bumps
    /// `queries_executed` and touches the session.
    ///
    /// - With an explicit session transaction, the query is parsed, planned
    ///   against the current catalog and executed with that transaction
    ///   attached (snapshot refreshed first under `ReadCommitted`).
    /// - Without one, it delegates to `self.query(sql, _params)`.
    ///
    /// NOTE(review): in the transaction branch `_params` is never used, and
    /// unlike `query_params` no `apply_rls_to_plan` step is performed. Confirm
    /// whether `query` applies RLS and whether skipping it here is intended.
    pub fn query_in_session(&self, session_id: crate::session::SessionId, sql: &str, _params: &[&dyn std::fmt::Display]) -> Result<Vec<Tuple>> {
        let session_lock = self.session_manager.get_session(session_id)?;
        let mut session = session_lock.write();
        session.touch();
        session.stats.queries_executed += 1;

        if self.session_transactions.contains_key(&session_id) {
            // ReadCommitted: refresh the snapshot before the statement; the
            // mutable guard is dropped at the end of this `if let`.
            if session.isolation_level == crate::session::IsolationLevel::ReadCommitted {
                if let Some(mut txn) = self.session_transactions.get_mut(&session_id) {
                    txn.refresh_snapshot(self.storage.current_timestamp());
                }
            }

            let txn = self.session_transactions.get(&session_id)
                .ok_or_else(|| Error::transaction("Session transaction disappeared during query"))?;

            let (statement, _) = self.parse_cached(sql)?;

            let catalog = self.storage.catalog();
            let planner = sql::Planner::with_catalog(&catalog)
                .with_sql(sql.to_string());
            let plan = planner.statement_to_plan(statement)?;

            let mut executor = sql::Executor::with_storage(&self.storage)
                .with_timeout(self.config.storage.query_timeout_ms)
                .with_transaction(&txn);

            executor.execute(&plan)
        } else {
            self.query(sql, _params)
        }
    }
7278
    /// Presumably meant to cap the number of sessions for `_user_name`.
    ///
    /// NOTE(review): currently a no-op that always returns `Ok(())`; nothing is
    /// recorded or enforced, so callers must not rely on the quota.
    pub fn set_session_quota(&self, _user_name: &str, _max_sessions: usize) -> Result<()> {
        Ok(())
    }
7284
    /// Presumably meant to cap memory usage (in bytes) for `_user_name`.
    ///
    /// NOTE(review): currently a no-op that always returns `Ok(())`; nothing is
    /// recorded or enforced.
    pub fn set_memory_quota(&self, _user_name: &str, _max_bytes: usize) -> Result<()> {
        Ok(())
    }
7290
    /// Returns the dirty tracker's `is_dirty()` flag.
    pub fn is_dirty(&self) -> bool {
        self.dirty_tracker.is_dirty()
    }
7295
    /// Flags `table` in the dirty tracker.
    ///
    /// There is no dedicated API for this, so it records a synthetic insert
    /// with key `"dummy_key"` and an empty value. Any error returned by the
    /// tracker is deliberately ignored.
    pub fn mark_table_dirty(&self, table: &str) {
        let _ = self.dirty_tracker.track_insert(table, "dummy_key", &[]);
    }
7301
7302 pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<Vec<Tuple>> {
7359 #[cfg(feature = "code-graph")]
7363 let (rewritten_owned, _branch_guard) = self.rewrite_and_scope(sql);
7364 #[cfg(feature = "code-graph")]
7365 let sql: &str = &rewritten_owned;
7366 #[cfg(not(feature = "code-graph"))]
7367 let sql: &str = sql;
7368 let start = std::time::Instant::now();
7369
7370 let parse_start = std::time::Instant::now();
7372 let (statement, _) = self.parse_cached(sql)?;
7373 tracing::debug!(phase = "parse", duration_us = parse_start.elapsed().as_micros() as u64, "SQL parsed");
7374
7375 let plan_start = std::time::Instant::now();
7377 let catalog = self.storage.catalog();
7378 let planner = sql::Planner::with_catalog(&catalog)
7379 .with_sql(sql.to_string());
7380 let mut plan = planner.statement_to_plan(statement)?;
7381 tracing::debug!(phase = "plan", duration_us = plan_start.elapsed().as_micros() as u64, "Logical plan created");
7382
7383 plan = self.apply_rls_to_plan(plan)?;
7385
7386 let exec_start = std::time::Instant::now();
7388 let results = self.query_plan_with_params(&plan, params)?;
7389 tracing::debug!(phase = "execute", duration_us = exec_start.elapsed().as_micros() as u64, rows = results.len() as u64, "Query executed");
7390
7391 self.log_slow_query(sql, start.elapsed(), results.len() as u64);
7392 Ok(results)
7393 }
7394
7395 fn query_plan_with_params(&self, plan: &sql::LogicalPlan, params: &[Value]) -> Result<Vec<Tuple>> {
7397 let mut executor = sql::Executor::with_storage(&self.storage)
7399 .with_timeout(self.config.storage.query_timeout_ms)
7400 .with_parameters(params.to_vec());
7401
7402 executor.execute(plan)
7403 }
7404
    /// Starts an explicit transaction on this handle.
    ///
    /// Thin public wrapper over `begin_transaction_internal`.
    pub fn begin(&self) -> Result<()> {
        self.begin_transaction_internal()
    }
7437
    /// Commits the explicit transaction started with `begin()`.
    ///
    /// Thin public wrapper over `commit_internal`.
    pub fn commit(&self) -> Result<()> {
        self.commit_internal()
    }
7463
    /// Rolls back the explicit transaction started with `begin()`.
    ///
    /// Thin public wrapper over `rollback_internal`.
    pub fn rollback(&self) -> Result<()> {
        self.rollback_internal()
    }
7490
7491 pub fn in_transaction(&self) -> bool {
7515 self.current_transaction.lock()
7516 .map(|txn| txn.is_some())
7517 .unwrap_or(false)
7518 }
7519
7520 #[allow(dead_code)]
7555 pub(crate) fn bulk_insert_tuples(
7556 &self,
7557 table_name: &str,
7558 tuples: Vec<Tuple>,
7559 ) -> Result<Vec<u64>> {
7560 let catalog = self.storage.catalog();
7561 let schema = catalog.get_table_schema(table_name)?;
7562
7563 let mut row_ids: Vec<u64> = Vec::with_capacity(tuples.len());
7564 let mut art_updates: Vec<(u64, std::collections::HashMap<String, Value>)> =
7565 Vec::with_capacity(tuples.len());
7566
7567 for mut tuple in tuples {
7568 let row_id = catalog.next_row_id(table_name)?;
7569
7570 for (i, col) in schema.columns.iter().enumerate() {
7573 if col.primary_key {
7574 if let Some(v) = tuple.values.get(i) {
7575 if matches!(v, Value::Null) && i < tuple.values.len() {
7576 #[allow(clippy::indexing_slicing)]
7577 match col.data_type {
7578 DataType::Int2 => { tuple.values[i] = Value::Int2(row_id as i16); }
7579 DataType::Int4 => { tuple.values[i] = Value::Int4(row_id as i32); }
7580 _ => { tuple.values[i] = Value::Int8(row_id as i64); }
7581 }
7582 }
7583 }
7584 }
7585 }
7586
7587 let mut col_values = std::collections::HashMap::with_capacity(schema.columns.len());
7590 for (i, col) in schema.columns.iter().enumerate() {
7591 if let Some(v) = tuple.values.get(i) {
7592 col_values.insert(col.name.clone(), v.clone());
7593 }
7594 }
7595
7596 if let Err(e) = self.storage.art_indexes()
7598 .check_unique_constraints(table_name, &col_values)
7599 {
7600 return Err(Error::constraint_violation(e.to_string()));
7601 }
7602
7603 let val = bincode::serialize(&tuple)
7607 .map_err(|e| Error::storage(format!("Failed to serialize tuple: {}", e)))?;
7608 let key = self.storage.branch_aware_data_key(table_name, row_id);
7609 self.storage.put(&key, &val)?;
7610
7611 art_updates.push((row_id, col_values));
7612 row_ids.push(row_id);
7613 }
7614
7615 for (row_id, col_values) in art_updates {
7616 if let Err(e) = self.storage.art_indexes()
7617 .on_insert(table_name, row_id, &col_values)
7618 {
7619 tracing::debug!("ART on_insert {}: {}", table_name, e);
7620 }
7621 }
7622
7623 self.invalidate_result_cache();
7628 if let Ok(mut cache) = self.plan_cache.lock() {
7629 cache.clear();
7630 }
7631
7632 Ok(row_ids)
7633 }
7634
7635 #[deprecated(since = "2.1.0", note = "Use `begin()`, `commit()`, and `rollback()` instead")]
7655 pub fn begin_transaction(&self) -> Result<Transaction<'_>> {
7656 let tx = self.storage.begin_transaction()?;
7657 Ok(Transaction { tx, db: self })
7658 }
7659
7660 pub fn current_lsn(&self) -> Option<u64> {
7665 let txn_id = self.storage.snapshot_manager().current_transaction_id();
7668 if txn_id > 1 {
7670 Some(txn_id - 1)
7671 } else {
7672 Some(0)
7674 }
7675 }
7676
    /// Consumes the handle and returns `Ok(())`.
    ///
    /// No explicit shutdown work happens here; `self` is simply dropped when
    /// this call returns.
    pub fn close(self) -> Result<()> {
        Ok(())
    }
7686
7687 pub fn list_vector_stores(&self) -> Result<Vec<VectorStoreInfo>> {
7691 use crate::vector::DistanceMetric;
7692
7693 let vector_mgr = self.storage.vector_indexes();
7694 let metadata_list = vector_mgr.list_all_metadata();
7695
7696 Ok(metadata_list.iter().map(|meta| {
7697 let (vector_count, metric, index_type) = match vector_mgr.get_index_stats(&meta.name) {
7699 Ok(stats) => (
7700 stats.num_vectors as u64,
7701 match &meta.index_type {
7702 storage::VectorIndexType::Standard(cfg) => match cfg.distance_metric {
7703 DistanceMetric::L2 => "l2".to_string(),
7704 DistanceMetric::Cosine => "cosine".to_string(),
7705 DistanceMetric::InnerProduct => "inner_product".to_string(),
7706 },
7707 storage::VectorIndexType::Quantized(cfg) => match cfg.distance_metric {
7708 DistanceMetric::L2 => "l2".to_string(),
7709 DistanceMetric::Cosine => "cosine".to_string(),
7710 DistanceMetric::InnerProduct => "inner_product".to_string(),
7711 },
7712 },
7713 match &meta.index_type {
7714 storage::VectorIndexType::Standard(_) => "hnsw".to_string(),
7715 storage::VectorIndexType::Quantized(_) => "hnsw_pq".to_string(),
7716 },
7717 ),
7718 Err(_) => (0, "cosine".to_string(), "hnsw".to_string()),
7719 };
7720
7721 let dimensions = match &meta.index_type {
7722 storage::VectorIndexType::Standard(cfg) => cfg.dimension as u32,
7723 storage::VectorIndexType::Quantized(cfg) => cfg.dimension as u32,
7724 };
7725
7726 VectorStoreInfo {
7727 name: meta.name.clone(),
7728 dimensions,
7729 vector_count,
7730 created_at: "N/A".to_string(),
7731 metric,
7732 index_type,
7733 }
7734 }).collect())
7735 }
7736
7737 pub fn create_vector_store(&self, name: &str, dimensions: u32) -> Result<VectorStoreInfo> {
7739 use crate::vector::DistanceMetric;
7740
7741 let vector_mgr = self.storage.vector_indexes();
7742
7743 vector_mgr.create_index(
7745 name.to_string(),
7746 name.to_string(), "embedding".to_string(), dimensions as usize,
7749 DistanceMetric::Cosine, )?;
7751
7752 Ok(VectorStoreInfo {
7753 name: name.to_string(),
7754 dimensions,
7755 vector_count: 0,
7756 created_at: chrono::Utc::now().to_rfc3339(),
7757 metric: "cosine".to_string(),
7758 index_type: "hnsw".to_string(),
7759 })
7760 }
7761
7762 pub fn get_vector_store(&self, name: &str) -> Result<VectorStoreInfo> {
7764 use crate::vector::DistanceMetric;
7765
7766 let vector_mgr = self.storage.vector_indexes();
7767
7768 let meta = vector_mgr.get_metadata(name)?;
7769 let stats = vector_mgr.get_index_stats(name)?;
7770
7771 let metric = match &meta.index_type {
7772 storage::VectorIndexType::Standard(cfg) => match cfg.distance_metric {
7773 DistanceMetric::L2 => "l2".to_string(),
7774 DistanceMetric::Cosine => "cosine".to_string(),
7775 DistanceMetric::InnerProduct => "inner_product".to_string(),
7776 },
7777 storage::VectorIndexType::Quantized(cfg) => match cfg.distance_metric {
7778 DistanceMetric::L2 => "l2".to_string(),
7779 DistanceMetric::Cosine => "cosine".to_string(),
7780 DistanceMetric::InnerProduct => "inner_product".to_string(),
7781 },
7782 };
7783
7784 let index_type = match &meta.index_type {
7785 storage::VectorIndexType::Standard(_) => "hnsw".to_string(),
7786 storage::VectorIndexType::Quantized(_) => "hnsw_pq".to_string(),
7787 };
7788
7789 Ok(VectorStoreInfo {
7790 name: name.to_string(),
7791 dimensions: stats.dimensions as u32,
7792 vector_count: stats.num_vectors as u64,
7793 created_at: "N/A".to_string(),
7794 metric,
7795 index_type,
7796 })
7797 }
7798
7799 pub fn delete_vector_store(&self, name: &str) -> Result<()> {
7801 let vector_mgr = self.storage.vector_indexes();
7802 vector_mgr.drop_index(name)
7803 }
7804
7805 pub fn insert_vectors(&self, store: &str, vectors: Vec<Vec<f32>>) -> Result<Vec<String>> {
7809 let vector_mgr = self.storage.vector_indexes();
7810
7811 let _ = vector_mgr.get_metadata(store)?;
7813
7814 let mut ids = Vec::with_capacity(vectors.len());
7815
7816 for vector in vectors {
7817 let id = self.storage.next_timestamp();
7819 let id_str = format!("vec_{}", id);
7820
7821 vector_mgr.insert_vector(store, id, &vector)?;
7823
7824 ids.push(id_str);
7825 }
7826
7827 Ok(ids)
7828 }
7829
7830 pub fn upsert_vectors(&self, store: &str, vectors: Vec<(String, Vec<f32>)>) -> Result<()> {
7832 let vector_mgr = self.storage.vector_indexes();
7833
7834 let _ = vector_mgr.get_metadata(store)?;
7836
7837 for (id_str, vector) in vectors {
7838 let id = id_str.strip_prefix("vec_")
7840 .and_then(|s| s.parse::<u64>().ok())
7841 .unwrap_or_else(|| {
7842 self.storage.next_timestamp()
7844 });
7845
7846 let _ = vector_mgr.delete_vector(store, id);
7848
7849 vector_mgr.insert_vector(store, id, &vector)?;
7851 }
7852
7853 Ok(())
7854 }
7855
7856 pub fn search_vectors(&self, store: &str, query: Vec<f32>, k: usize) -> Result<Vec<(String, f32)>> {
7860 let vector_mgr = self.storage.vector_indexes();
7861
7862 let _ = vector_mgr.get_metadata(store)?;
7864
7865 let results = vector_mgr.search(store, &query, k)?;
7867
7868 Ok(results.into_iter()
7870 .map(|(row_id, distance)| (format!("vec_{}", row_id), distance))
7871 .collect())
7872 }
7873
7874 pub fn text_search(&self, _query: &str) -> Result<Vec<String>> {
7876 Err(Error::Generic("Text search requires embedding model - not yet implemented".to_string()))
7877 }
7878
7879 pub fn store_texts(&self, _store: &str, _texts: Vec<String>) -> Result<Vec<String>> {
7881 Err(Error::Generic("Text storage requires embedding model - not yet implemented".to_string()))
7882 }
7883
7884 pub fn hybrid_search(&self, _store: &str, _query: &str, _k: usize) -> Result<Vec<(String, f32)>> {
7886 Err(Error::Generic("Hybrid search requires embedding model - not yet implemented".to_string()))
7887 }
7888
7889 pub fn delete_vectors(&self, store: &str, ids: Vec<String>) -> Result<()> {
7891 let vector_mgr = self.storage.vector_indexes();
7892
7893 let _ = vector_mgr.get_metadata(store)?;
7895
7896 for id_str in ids {
7897 if let Some(id) = id_str.strip_prefix("vec_").and_then(|s| s.parse::<u64>().ok()) {
7899 vector_mgr.delete_vector(store, id)?;
7900 }
7901 }
7902
7903 Ok(())
7904 }
7905
7906 pub fn fetch_vectors(&self, _store: &str, _ids: Vec<String>) -> Result<Vec<(String, Vec<f32>)>> {
7908 Err(Error::Generic("Vector fetch not yet implemented - HNSW index doesn't store raw vectors".to_string()))
7909 }
7910
7911 pub fn list_agent_sessions(&self) -> Result<Vec<AgentSession>> {
7915 Ok(vec![])
7916 }
7917
7918 pub fn create_agent_session(&self, _name: &str) -> Result<AgentSession> {
7920 Err(Error::Generic("Agent sessions not yet implemented".to_string()))
7921 }
7922
7923 pub fn get_agent_session(&self, _id: &str) -> Result<AgentSession> {
7925 Err(Error::Generic("Agent sessions not yet implemented".to_string()))
7926 }
7927
7928 pub fn delete_agent_session(&self, _id: &str) -> Result<()> {
7930 Err(Error::Generic("Agent sessions not yet implemented".to_string()))
7931 }
7932
7933 pub fn add_agent_message(&self, _session_id: &str, _role: &str, _content: &str) -> Result<AgentMessage> {
7935 Err(Error::Generic("Agent messages not yet implemented".to_string()))
7936 }
7937
7938 pub fn get_agent_messages(&self, _session_id: &str) -> Result<Vec<AgentMessage>> {
7940 Ok(vec![])
7941 }
7942
7943 pub fn clear_agent_messages(&self, _session_id: &str) -> Result<()> {
7945 Err(Error::Generic("Agent messages not yet implemented".to_string()))
7946 }
7947
7948 pub fn generate_schema(&self, _table_name: &str) -> Result<String> {
7950 Err(Error::Generic("Schema generation not yet implemented".to_string()))
7951 }
7952
7953 pub fn chat_completion(&self, _messages: Vec<(String, String)>) -> Result<String> {
7955 Err(Error::Generic("Chat completions not yet implemented".to_string()))
7956 }
7957
7958 pub fn nl_to_sql(&self, _query: &str) -> Result<String> {
7960 Err(Error::Generic("Natural language to SQL not yet implemented".to_string()))
7961 }
7962
7963 pub fn store_document(&self, _collection: &str, _id: &str, _content: &str, _metadata: Option<serde_json::Value>) -> Result<()> {
7965 Err(Error::Generic("Document storage not yet implemented".to_string()))
7966 }
7967
7968 pub fn get_document(&self, _collection: &str, _id: &str) -> Result<DocumentData> {
7970 Err(Error::Generic("Document storage not yet implemented".to_string()))
7971 }
7972
7973 pub fn delete_document(&self, _collection: &str, _id: &str) -> Result<()> {
7975 Err(Error::Generic("Document storage not yet implemented".to_string()))
7976 }
7977
7978 pub fn update_document(&self, _collection: &str, _id: &str, _content: &str, _metadata: Option<serde_json::Value>) -> Result<()> {
7980 Err(Error::Generic("Document storage not yet implemented".to_string()))
7981 }
7982
7983 pub fn list_documents(&self, _collection: &str) -> Result<Vec<DocumentMetadata>> {
7985 Ok(vec![])
7986 }
7987
7988 pub fn search_documents(&self, _collection: &str, _query: &str) -> Result<Vec<DocumentData>> {
7990 Ok(vec![])
7991 }
7992
7993 pub fn create_collection(&self, _name: &str) -> Result<()> {
7995 Err(Error::Generic("Collections not yet implemented".to_string()))
7996 }
7997
7998 pub fn delete_collection(&self, _name: &str) -> Result<()> {
8000 Err(Error::Generic("Collections not yet implemented".to_string()))
8001 }
8002
8003 pub fn list_collections(&self) -> Result<Vec<String>> {
8005 Ok(vec![])
8006 }
8007
8008 pub fn batch_create_documents(&self, _collection: &str, _docs: Vec<DocumentData>) -> Result<Vec<String>> {
8010 Err(Error::Generic("Batch document creation not yet implemented".to_string()))
8011 }
8012
8013 pub fn batch_infer_schema(&self, _data: Vec<Vec<Value>>) -> Result<Schema> {
8015 Err(Error::Generic("Batch schema inference not yet implemented".to_string()))
8016 }
8017
8018 pub fn chat_completion_stream(&self, _messages: Vec<(String, String)>) -> Result<String> {
8020 Err(Error::Generic("Chat completion streaming not yet implemented".to_string()))
8021 }
8022
8023 pub fn compare_schemas(&self, _schema1: &Schema, _schema2: &Schema) -> Result<serde_json::Value> {
8025 Err(Error::Generic("Schema comparison not yet implemented".to_string()))
8026 }
8027
8028 pub fn create_embeddings(&self, _texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
8030 Err(Error::Generic("Embeddings not yet implemented".to_string()))
8031 }
8032
8033 pub fn create_document(&self, _collection: &str, _id: &str, _content: &str, _metadata: Option<serde_json::Value>) -> Result<String> {
8035 Ok("document_id".to_string())
8036 }
8037
8038 pub fn find_similar_documents(&self, _collection: &str, _query: &str, _limit: usize) -> Result<Vec<(DocumentData, f32)>> {
8040 Err(Error::Generic("Similar document search not yet implemented".to_string()))
8041 }
8042
8043 pub fn fork_agent_session(&self, _session_id: &str, _new_name: &str) -> Result<AgentSession> {
8045 Err(Error::Generic("Agent session forking not yet implemented".to_string()))
8046 }
8047
8048 pub fn generate_schema_from_description(&self, _description: &str) -> Result<Schema> {
8050 Err(Error::Generic("Schema generation from description not yet implemented".to_string()))
8051 }
8052
8053 pub fn get_agent_context(&self, _session_id: &str) -> Result<serde_json::Value> {
8055 Err(Error::Generic("Agent context retrieval not yet implemented".to_string()))
8056 }
8057
8058 pub fn get_chat_model(&self, _model_id: &str) -> Result<serde_json::Value> {
8060 Err(Error::Generic("Chat model retrieval not yet implemented".to_string()))
8061 }
8062
8063 pub fn get_document_chunks(&self, _collection: &str, _id: &str) -> Result<Vec<(String, f32)>> {
8065 Err(Error::Generic("Document chunking not yet implemented".to_string()))
8066 }
8067
8068 pub fn infer_schema(&self, _data: Vec<Vec<Value>>) -> Result<Schema> {
8070 Err(Error::Generic("Schema inference not yet implemented".to_string()))
8071 }
8072
8073 pub fn infer_schema_from_file(&self, _path: &str) -> Result<Schema> {
8075 Err(Error::Generic("Schema inference from file not yet implemented".to_string()))
8076 }
8077
8078 pub fn instantiate_schema_template(&self, _template_name: &str, _params: serde_json::Value) -> Result<Schema> {
8080 Err(Error::Generic("Schema template instantiation not yet implemented".to_string()))
8081 }
8082
8083 pub fn list_chat_models(&self) -> Result<Vec<serde_json::Value>> {
8085 Ok(vec![])
8086 }
8087
8088 pub fn list_schema_templates(&self) -> Result<Vec<serde_json::Value>> {
8090 Ok(vec![])
8091 }
8092
8093 pub fn optimize_schema(&self, _schema: &Schema) -> Result<Schema> {
8095 Err(Error::Generic("Schema optimization not yet implemented".to_string()))
8096 }
8097
8098 pub fn validate_schema(&self, _schema: &Schema) -> Result<bool> {
8100 Err(Error::Generic("Schema validation not yet implemented".to_string()))
8101 }
8102
8103 pub fn rag_search(&self, _collection: &str, _query: &str, _k: usize) -> Result<Vec<(DocumentData, f32, String)>> {
8105 Err(Error::Generic("RAG search not yet implemented".to_string()))
8106 }
8107
8108 pub fn rechunk_document(&self, _collection: &str, _id: &str, _chunk_size: usize) -> Result<Vec<String>> {
8110 Err(Error::Generic("Document rechunking not yet implemented".to_string()))
8111 }
8112
8113 pub fn search_agent_memory(&self, _session_id: &str, _query: &str) -> Result<Vec<(AgentMessage, f32)>> {
8115 Err(Error::Generic("Agent memory search not yet implemented".to_string()))
8116 }
8117
8118 pub fn summarize_agent_memory(&self, _session_id: &str) -> Result<String> {
8120 Err(Error::Generic("Agent memory summarization not yet implemented".to_string()))
8121 }
8122
8123 pub fn create_branch(&self, name: &str) -> Result<u64> {
8127 self.execute(&format!("CREATE BRANCH {name}"))
8128 }
8129
8130 pub fn switch_branch(&self, name: &str) -> Result<u64> {
8132 self.execute(&format!("USE BRANCH {name}"))
8133 }
8134
8135 pub fn merge_branch(&self, source: &str) -> Result<u64> {
8137 self.execute(&format!("MERGE BRANCH {source}"))
8138 }
8139
8140 pub fn drop_branch(&self, name: &str) -> Result<u64> {
8142 self.execute(&format!("DROP BRANCH {name}"))
8143 }
8144
8145 pub fn list_branches(&self) -> Result<Vec<Tuple>> {
8147 self.query("LIST BRANCHES", &[])
8148 }
8149
8150 pub fn explain(&self, sql: &str) -> Result<Vec<Tuple>> {
8152 self.query(&format!("EXPLAIN {sql}"), &[])
8153 }
8154
8155 pub fn explain_analyze(&self, sql: &str) -> Result<Vec<Tuple>> {
8157 self.query(&format!("EXPLAIN ANALYZE {sql}"), &[])
8158 }
8159
8160 pub fn refresh_materialized_view(&self, name: &str) -> Result<u64> {
8162 self.execute(&format!("REFRESH MATERIALIZED VIEW {name}"))
8163 }
8164
8165 fn clone_for_trigger(&self) -> Self {
8170 Self {
8171 storage: self.storage.clone(),
8172 config: self.config.clone(),
8173 current_transaction: self.current_transaction.clone(),
8174 tenant_manager: self.tenant_manager.clone(),
8175 trigger_registry: self.trigger_registry.clone(),
8176 function_registry: self.function_registry.clone(),
8177 mv_scheduler: self.mv_scheduler.clone(),
8178 auto_refresh_worker: self.auto_refresh_worker.clone(),
8179 dump_manager: self.dump_manager.clone(),
8180 session_manager: self.session_manager.clone(),
8181 lock_manager: self.lock_manager.clone(),
8182 dirty_tracker: self.dirty_tracker.clone(),
8183 session_transactions: self.session_transactions.clone(),
8184 prepared_statements: self.prepared_statements.clone(),
8185 savepoints: self.savepoints.clone(),
8186 plan_cache: self.plan_cache.clone(),
8187 parse_cache: self.parse_cache.clone(),
8188 result_cache: self.result_cache.clone(),
8189 art_undo_log: self.art_undo_log.clone(),
8190 }
8191 }
8192
8193 fn check_foreign_key_exists(
8197 &self,
8198 table_name: &str,
8199 column_names: &[String],
8200 values: &[Value],
8201 ) -> Result<bool> {
8202 let catalog = self.storage.catalog();
8204 let schema = catalog.get_table_schema(table_name)?;
8205
8206 let tuples = self.storage.scan_table(table_name)?;
8208
8209 for tuple in tuples {
8210 let mut matches = true;
8211 for (col_name, expected_value) in column_names.iter().zip(values.iter()) {
8212 let col_idx = schema.columns.iter()
8214 .position(|c| &c.name == col_name);
8215
8216 if let Some(idx) = col_idx {
8217 match tuple.values.get(idx) {
8218 Some(actual_value) if actual_value == expected_value => {}
8219 _ => { matches = false; break; }
8220 }
8221 } else {
8222 matches = false;
8223 break;
8224 }
8225 }
8226
8227 if matches {
8228 return Ok(true);
8229 }
8230 }
8231
8232 Ok(false)
8233 }
8234
8235 fn check_unique_violation(
8240 &self,
8241 table_name: &str,
8242 column_names: &[String],
8243 values: &[Value],
8244 ) -> Result<bool> {
8245 let catalog = self.storage.catalog();
8246 let schema = catalog.get_table_schema(table_name)?;
8247
8248 let tuples = self.storage.scan_table(table_name)?;
8250
8251 for tuple in tuples {
8252 let mut matches = true;
8253 for (col_name, expected_value) in column_names.iter().zip(values.iter()) {
8254 let col_idx = schema.columns.iter()
8256 .position(|c| &c.name == col_name);
8257
8258 if let Some(idx) = col_idx {
8259 match tuple.values.get(idx) {
8260 Some(actual_value) if actual_value == expected_value => {}
8261 _ => { matches = false; break; }
8262 }
8263 } else {
8264 matches = false;
8265 break;
8266 }
8267 }
8268
8269 if matches {
8270 return Ok(true); }
8272 }
8273
8274 Ok(false) }
8276
8277 fn cascade_delete_referencing_rows(
8281 &self,
8282 table_name: &str,
8283 fk_columns: &[String],
8284 parent_values: &[Value],
8285 ) -> Result<()> {
8286 let catalog = self.storage.catalog();
8287 let schema = catalog.get_table_schema(table_name)?;
8288
8289 let tuples = self.storage.scan_table(table_name)?;
8291 let mut row_ids_to_delete: Vec<u64> = Vec::new();
8292
8293 for tuple in tuples {
8294 let mut matches = true;
8295 for (fk_col, parent_val) in fk_columns.iter().zip(parent_values.iter()) {
8296 let col_idx = schema.columns.iter().position(|c| &c.name == fk_col);
8297 if let Some(idx) = col_idx {
8298 match tuple.values.get(idx) {
8299 Some(val) if val == parent_val => {}
8300 _ => { matches = false; break; }
8301 }
8302 } else {
8303 matches = false;
8304 break;
8305 }
8306 }
8307
8308 if matches {
8309 if let Some(row_id) = tuple.row_id {
8310 row_ids_to_delete.push(row_id);
8311 }
8312 }
8313 }
8314
8315 let txn = self.storage.begin_transaction()?;
8317 for row_id in row_ids_to_delete {
8318 let key = self.storage.branch_aware_data_key(table_name, row_id);
8319 txn.delete(key.clone())?;
8320
8321 self.storage.log_data_delete(table_name, &key)?;
8323 }
8324 txn.commit()?;
8325
8326 Ok(())
8327 }
8328
8329 fn set_null_referencing_rows(
8333 &self,
8334 table_name: &str,
8335 fk_columns: &[String],
8336 parent_values: &[Value],
8337 ) -> Result<()> {
8338 let catalog = self.storage.catalog();
8339 let schema = catalog.get_table_schema(table_name)?;
8340
8341 let tuples = self.storage.scan_table(table_name)?;
8343 let mut rows_to_update: Vec<(u64, Tuple)> = Vec::new();
8344
8345 for tuple in tuples {
8346 let mut matches = true;
8347 for (fk_col, parent_val) in fk_columns.iter().zip(parent_values.iter()) {
8348 let col_idx = schema.columns.iter().position(|c| &c.name == fk_col);
8349 if let Some(idx) = col_idx {
8350 match tuple.values.get(idx) {
8351 Some(val) if val == parent_val => {}
8352 _ => { matches = false; break; }
8353 }
8354 } else {
8355 matches = false;
8356 break;
8357 }
8358 }
8359
8360 if matches {
8361 if let Some(row_id) = tuple.row_id {
8362 let mut new_values = tuple.values.clone();
8364 for fk_col in fk_columns {
8365 if let Some(idx) = schema.columns.iter().position(|c| &c.name == fk_col) {
8366 if let Some(slot) = new_values.get_mut(idx) {
8367 *slot = Value::Null;
8368 }
8369 }
8370 }
8371 let new_tuple = Tuple::new(new_values);
8372 rows_to_update.push((row_id, new_tuple));
8373 }
8374 }
8375 }
8376
8377 let txn = self.storage.begin_transaction()?;
8379 for (row_id, new_tuple) in rows_to_update {
8380 let key = self.storage.branch_aware_data_key(table_name, row_id);
8381 let val = bincode::serialize(&new_tuple).map_err(|e| Error::storage(e.to_string()))?;
8382 txn.put(key.clone(), val.clone())?;
8383
8384 self.storage.log_data_update(table_name, &key, &val)?;
8386 }
8387 txn.commit()?;
8388
8389 Ok(())
8390 }
8391
    /// Evaluates the CHECK constraint `expression` against one row (`values`,
    /// laid out per `schema`).
    ///
    /// `expression` is either a serialized `sql::LogicalExpr` (when it starts
    /// with `{` or `[`, decoded with `serde_json`) or SQL expression text.
    /// The text form is parsed by wrapping it as
    /// `SELECT * FROM dummy WHERE <expression>` under the PostgreSQL dialect,
    /// taking the `WHERE` clause, and converting it to a logical expression
    /// with the planner.
    ///
    /// Returns `Ok(b)` for a boolean result and `Ok(true)` for `NULL` (an
    /// unknown CHECK result does not reject the row).
    ///
    /// # Errors
    /// `query_execution` for decode/parse/extraction failures, propagated
    /// planner/evaluator errors, and `constraint_violation` when the expression
    /// yields a non-boolean, non-NULL value.
    ///
    /// NOTE(review): text-form expressions are re-parsed on every call.
    fn evaluate_check_constraint(
        &self,
        expression: &str,
        schema: &Schema,
        values: &[Value],
    ) -> Result<bool> {
        let tuple = Tuple::new(values.to_vec());

        let logical_expr = if expression.starts_with('{') || expression.starts_with('[') {
            // Pre-serialized logical expression.
            serde_json::from_str::<sql::LogicalExpr>(expression)
                .map_err(|e| Error::query_execution(format!(
                    "Failed to deserialize CHECK constraint expression '{}': {}",
                    expression, e
                )))?
        } else {
            use sqlparser::dialect::PostgreSqlDialect;
            use sqlparser::parser::Parser as SqlParser;

            // Wrap the bare expression in a query so sqlparser can parse it.
            let sql = format!("SELECT * FROM dummy WHERE {}", expression);
            let dialect = PostgreSqlDialect {};

            let mut statements = SqlParser::parse_sql(&dialect, &sql)
                .map_err(|e| Error::query_execution(format!(
                    "Failed to parse CHECK constraint expression '{}': {}",
                    expression, e
                )))?;

            // More than one statement means the expression smuggled in extra SQL.
            if statements.len() != 1 {
                return Err(Error::query_execution(
                    "Invalid CHECK constraint expression: expected single statement"
                ));
            }

            let statement = statements.remove(0);

            // Dig out the WHERE clause of the wrapper SELECT.
            let selection = if let sqlparser::ast::Statement::Query(query) = statement {
                if let sqlparser::ast::SetExpr::Select(select) = *query.body {
                    select.selection
                } else {
                    None
                }
            } else {
                None
            };

            let selection = selection.ok_or_else(|| Error::query_execution(format!(
                "Failed to extract expression from CHECK constraint: {}",
                expression
            )))?;

            let catalog = self.storage.catalog();
            let planner = sql::Planner::with_catalog(&catalog);

            planner.convert_expr_to_logical(&selection, Some(schema))?
        };

        let evaluator = sql::Evaluator::new(std::sync::Arc::new(schema.clone()));
        let result = evaluator.evaluate(&logical_expr, &tuple)?;

        match result {
            Value::Boolean(b) => Ok(b),
            // SQL semantics: a CHECK that evaluates to NULL is satisfied.
            Value::Null => Ok(true),
            _ => Err(Error::constraint_violation(format!(
                "CHECK constraint expression '{}' did not evaluate to boolean",
                expression
            ))),
        }
    }
8474
8475 fn check_fk_constraints_on_write(
8493 &self,
8494 table_name: &str,
8495 col_values: &std::collections::HashMap<String, Value>,
8496 active_txn: Option<&storage::Transaction>,
8497 ) -> Result<()> {
8498 let catalog = self.storage.catalog();
8499 let constraints = match catalog.load_table_constraints(table_name) {
8500 Ok(c) => c,
8501 Err(_) => return Ok(()), };
8503 if constraints.foreign_keys.is_empty() {
8504 return Ok(());
8505 }
8506 for fk in &constraints.foreign_keys {
8507 let mut parent_values: Vec<Value> = Vec::with_capacity(fk.columns.len());
8509 let mut any_null = false;
8510 for col in &fk.columns {
8511 match col_values.get(col) {
8512 Some(v) if matches!(v, Value::Null) => {
8513 any_null = true;
8514 break;
8515 }
8516 Some(v) => parent_values.push(v.clone()),
8517 None => {
8518 any_null = true;
8519 break;
8520 }
8521 }
8522 }
8523 if any_null {
8526 continue;
8527 }
8528 let parent_exists = self.check_referencing_rows_exist(
8529 &fk.references_table,
8530 &fk.references_columns,
8531 &parent_values,
8532 active_txn,
8533 )?;
8534 if !parent_exists {
8535 let parent_repr: Vec<String> = parent_values.iter()
8536 .map(|v| format!("{v}"))
8537 .collect();
8538 return Err(Error::constraint_violation(format!(
8539 "Foreign key constraint '{}' violated: row references \
8540 non-existent {}({}) = ({})",
8541 fk.name,
8542 fk.references_table,
8543 fk.references_columns.join(", "),
8544 parent_repr.join(", "),
8545 )));
8546 }
8547 }
8548 Ok(())
8549 }
8550
8551 fn check_referencing_rows_exist(
8552 &self,
8553 table_name: &str,
8554 column_names: &[String],
8555 values: &[Value],
8556 active_txn: Option<&storage::Transaction>,
8557 ) -> Result<bool> {
8558 if active_txn.is_none() && column_names.len() == values.len() && !column_names.is_empty() {
8575 let art = self.storage.art_indexes();
8576 let index_name = if column_names.len() == 1 {
8578 #[allow(clippy::indexing_slicing)]
8579 art.find_column_index(table_name, &column_names[0])
8580 } else {
8581 None
8582 };
8583 if let Some(name) = index_name {
8584 let key = storage::ArtIndexManager::encode_key(values);
8585 let row_ids = art.index_get_all(&name, &key);
8586 if !row_ids.is_empty() {
8587 return Ok(true);
8588 }
8589 return Ok(false);
8595 }
8596 }
8597
8598 let catalog = self.storage.catalog();
8608 let schema = catalog.get_table_schema(table_name)?;
8609 let base = self.storage.scan_table(table_name)?;
8610 let tuples = if let Some(txn) = active_txn {
8611 txn.merge_with_write_set(table_name, base)?
8612 } else {
8613 base
8614 };
8615
8616 for tuple in tuples {
8617 let mut matches = true;
8618 for (col_name, expected_value) in column_names.iter().zip(values.iter()) {
8619 let col_idx = schema.columns.iter()
8620 .position(|c| &c.name == col_name);
8621
8622 if let Some(idx) = col_idx {
8623 match tuple.values.get(idx) {
8624 Some(actual_value) if actual_value == expected_value => {}
8625 _ => { matches = false; break; }
8626 }
8627 } else {
8628 matches = false;
8629 break;
8630 }
8631 }
8632
8633 if matches {
8634 return Ok(true);
8635 }
8636 }
8637
8638 Ok(false)
8639 }
8640
8641 #[cfg(feature = "server")]
8670 pub fn start_qps_reset_task(&self) -> tokio::task::JoinHandle<()> {
8671 let tenant_manager = self.tenant_manager.clone();
8672
8673 tokio::spawn(async move {
8674 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(1));
8675 loop {
8676 interval.tick().await;
8677
8678 let tenants = tenant_manager.list_tenants();
8680 for tenant in tenants {
8681 let _ = tenant_manager.reset_qps_window(tenant.id);
8682 }
8683 }
8684 })
8685 }
8686
8687 pub fn reset_all_qps_windows(&self) {
8706 let tenants = self.tenant_manager.list_tenants();
8707 for tenant in tenants {
8708 let _ = self.tenant_manager.reset_qps_window(tenant.id);
8709 }
8710 }
8711
8712 fn execute_plan_internal(&self, plan: &sql::LogicalPlan) -> Result<u64> {
8717 let (count, _tuples) = self.execute_plan_with_params(plan, &[])?;
8719 Ok(count)
8720 }
8721
8722 pub(crate) const RESERVED_DATABASE_NAMES: &'static [&'static str] = &["heliosdb", "postgres"];
8727
8728 fn database_name_is_reserved(name: &str) -> bool {
8729 let lower = name.to_ascii_lowercase();
8730 Self::RESERVED_DATABASE_NAMES.iter().any(|r| *r == lower)
8731 }
8732
8733 fn handle_create_database(&self, name: &str, if_not_exists: bool) -> Result<(u64, Vec<Tuple>)> {
8737 let trimmed = name.trim();
8738 if trimmed.is_empty() {
8739 return Err(Error::query_execution("CREATE DATABASE requires a non-empty name"));
8740 }
8741 if Self::database_name_is_reserved(trimmed) {
8742 if if_not_exists {
8743 return Ok((0, vec![]));
8744 }
8745 return Err(Error::query_execution(format!(
8746 "database \"{trimmed}\" is a reserved system database and cannot be created"
8747 )));
8748 }
8749 let already_exists = self
8751 .tenant_manager
8752 .list_tenants()
8753 .iter()
8754 .any(|t| t.name.eq_ignore_ascii_case(trimmed));
8755 if already_exists {
8756 if if_not_exists {
8757 return Ok((0, vec![]));
8758 }
8759 return Err(Error::query_execution(format!(
8760 "database \"{trimmed}\" already exists"
8761 )));
8762 }
8763 self.tenant_manager.register_tenant_with_plan(
8764 trimmed.to_string(),
8765 crate::tenant::IsolationMode::DatabasePerTenant,
8766 "free",
8767 );
8768 Ok((0, vec![]))
8769 }
8770
8771 fn handle_drop_database(&self, name: &str, if_exists: bool) -> Result<(u64, Vec<Tuple>)> {
8773 let trimmed = name.trim();
8774 if trimmed.is_empty() {
8775 return Err(Error::query_execution("DROP DATABASE requires a non-empty name"));
8776 }
8777 if Self::database_name_is_reserved(trimmed) {
8778 return Err(Error::query_execution(format!(
8779 "database \"{trimmed}\" is a reserved system database and cannot be dropped"
8780 )));
8781 }
8782 let target = self
8783 .tenant_manager
8784 .list_tenants()
8785 .into_iter()
8786 .find(|t| t.name.eq_ignore_ascii_case(trimmed));
8787 let Some(tenant) = target else {
8788 if if_exists {
8789 return Ok((0, vec![]));
8790 }
8791 return Err(Error::query_execution(format!(
8792 "database \"{trimmed}\" does not exist"
8793 )));
8794 };
8795 self.tenant_manager
8796 .delete_tenant(tenant.id)
8797 .map_err(Error::query_execution)?;
8798 Ok((0, vec![]))
8799 }
8800
8801 pub fn database_name_is_valid(&self, name: &str) -> bool {
8806 let trimmed = name.trim();
8807 if trimmed.is_empty() {
8808 return false;
8809 }
8810 if Self::database_name_is_reserved(trimmed) {
8811 return true;
8812 }
8813 self.tenant_manager
8814 .list_tenants()
8815 .iter()
8816 .any(|t| t.name.eq_ignore_ascii_case(trimmed))
8817 }
8818
8819 pub(crate) fn apply_defaults_and_check_not_null(
8835 tuple_values: &mut [Value],
8836 schema: &Schema,
8837 user_provided: &[bool],
8838 ) -> Result<()> {
8839 let evaluator = sql::Evaluator::new(std::sync::Arc::new(schema.clone()));
8840 let empty_tuple = Tuple::new(vec![]);
8841
8842 for (idx, col) in schema.columns.iter().enumerate() {
8843 let slot = match tuple_values.get_mut(idx) {
8844 Some(s) => s,
8845 None => continue,
8846 };
8847 let is_null = matches!(slot, Value::Null);
8848 let was_omitted = user_provided.get(idx).copied().unwrap_or(false) == false;
8849
8850 if is_null && was_omitted {
8851 if let Some(json) = &col.default_expr {
8853 if let Ok(default_expr) = serde_json::from_str::<sql::LogicalExpr>(json) {
8854 if let Ok(v) = evaluator.evaluate(&default_expr, &empty_tuple) {
8855 let casted = if v.data_type() != col.data_type {
8856 evaluator.cast_value(v, &col.data_type).unwrap_or(Value::Null)
8857 } else {
8858 v
8859 };
8860 *slot = casted;
8861 }
8862 }
8863 }
8864 }
8865
8866 if !col.nullable && !col.primary_key {
8870 if matches!(slot, Value::Null) {
8871 return Err(Error::constraint_violation(format!(
8872 "NOT NULL constraint violated: cannot insert NULL into column '{}'",
8873 col.name
8874 )));
8875 }
8876 }
8877 }
8878 Ok(())
8879 }
8880
8881 fn materialize_scalar_subqueries_for_row(
8897 &self,
8898 expr: &sql::LogicalExpr,
8899 outer_row: &Tuple,
8900 outer_schema: &Schema,
8901 outer_table: &str,
8902 ) -> Result<sql::LogicalExpr> {
8903 use sql::LogicalExpr;
8904 match expr {
8905 LogicalExpr::ScalarSubquery { subquery } => {
8906 let bound_plan = Self::bind_outer_refs_in_plan(
8908 subquery.as_ref(),
8909 outer_row,
8910 outer_schema,
8911 outer_table,
8912 );
8913 let mut executor = sql::Executor::with_storage(&self.storage)
8915 .with_timeout(self.config.storage.query_timeout_ms);
8916 let rows = executor.execute(&bound_plan)?;
8917 let value = rows.first()
8918 .and_then(|t| t.values.first().cloned())
8919 .unwrap_or(Value::Null);
8920 Ok(LogicalExpr::Literal(value))
8921 }
8922 LogicalExpr::BinaryExpr { left, op, right } => Ok(LogicalExpr::BinaryExpr {
8923 left: Box::new(self.materialize_scalar_subqueries_for_row(left, outer_row, outer_schema, outer_table)?),
8924 op: *op,
8925 right: Box::new(self.materialize_scalar_subqueries_for_row(right, outer_row, outer_schema, outer_table)?),
8926 }),
8927 LogicalExpr::UnaryExpr { op, expr: inner } => Ok(LogicalExpr::UnaryExpr {
8928 op: *op,
8929 expr: Box::new(self.materialize_scalar_subqueries_for_row(inner, outer_row, outer_schema, outer_table)?),
8930 }),
8931 LogicalExpr::Cast { expr: inner, data_type } => Ok(LogicalExpr::Cast {
8932 expr: Box::new(self.materialize_scalar_subqueries_for_row(inner, outer_row, outer_schema, outer_table)?),
8933 data_type: data_type.clone(),
8934 }),
8935 other => Ok(other.clone()),
8937 }
8938 }
8939
8940 fn bind_outer_refs_in_plan(
8945 plan: &sql::LogicalPlan,
8946 outer_row: &Tuple,
8947 outer_schema: &Schema,
8948 outer_table: &str,
8949 ) -> sql::LogicalPlan {
8950 use sql::LogicalPlan;
8951 match plan {
8952 LogicalPlan::Filter { input, predicate } => LogicalPlan::Filter {
8953 input: Box::new(Self::bind_outer_refs_in_plan(input, outer_row, outer_schema, outer_table)),
8954 predicate: Self::bind_outer_refs_in_expr(predicate, outer_row, outer_schema, outer_table),
8955 },
8956 LogicalPlan::Project { input, exprs, aliases, distinct, distinct_on } => LogicalPlan::Project {
8957 input: Box::new(Self::bind_outer_refs_in_plan(input, outer_row, outer_schema, outer_table)),
8958 exprs: exprs.iter()
8959 .map(|e| Self::bind_outer_refs_in_expr(e, outer_row, outer_schema, outer_table))
8960 .collect(),
8961 aliases: aliases.clone(),
8962 distinct: *distinct,
8963 distinct_on: distinct_on.clone(),
8964 },
8965 LogicalPlan::Limit { input, limit, offset, limit_param, offset_param } => LogicalPlan::Limit {
8966 input: Box::new(Self::bind_outer_refs_in_plan(input, outer_row, outer_schema, outer_table)),
8967 limit: *limit,
8968 offset: *offset,
8969 limit_param: *limit_param,
8970 offset_param: *offset_param,
8971 },
8972 LogicalPlan::Sort { input, exprs, asc } => LogicalPlan::Sort {
8973 input: Box::new(Self::bind_outer_refs_in_plan(input, outer_row, outer_schema, outer_table)),
8974 exprs: exprs.iter()
8975 .map(|e| Self::bind_outer_refs_in_expr(e, outer_row, outer_schema, outer_table))
8976 .collect(),
8977 asc: asc.clone(),
8978 },
8979 other => other.clone(),
8980 }
8981 }
8982
8983 fn bind_outer_refs_in_expr(
8984 expr: &sql::LogicalExpr,
8985 outer_row: &Tuple,
8986 outer_schema: &Schema,
8987 outer_table: &str,
8988 ) -> sql::LogicalExpr {
8989 use sql::LogicalExpr;
8990 match expr {
8991 LogicalExpr::Column { table: Some(tbl), name } if tbl.eq_ignore_ascii_case(outer_table) => {
8992 if let Some(idx) = outer_schema.columns.iter().position(|c| c.name.eq_ignore_ascii_case(name)) {
8993 if let Some(v) = outer_row.values.get(idx) {
8994 return LogicalExpr::Literal(v.clone());
8995 }
8996 }
8997 expr.clone()
8998 }
8999 LogicalExpr::BinaryExpr { left, op, right } => LogicalExpr::BinaryExpr {
9000 left: Box::new(Self::bind_outer_refs_in_expr(left, outer_row, outer_schema, outer_table)),
9001 op: *op,
9002 right: Box::new(Self::bind_outer_refs_in_expr(right, outer_row, outer_schema, outer_table)),
9003 },
9004 LogicalExpr::UnaryExpr { op, expr: inner } => LogicalExpr::UnaryExpr {
9005 op: *op,
9006 expr: Box::new(Self::bind_outer_refs_in_expr(inner, outer_row, outer_schema, outer_table)),
9007 },
9008 LogicalExpr::IsNull { expr: inner, is_null } => LogicalExpr::IsNull {
9009 expr: Box::new(Self::bind_outer_refs_in_expr(inner, outer_row, outer_schema, outer_table)),
9010 is_null: *is_null,
9011 },
9012 LogicalExpr::Between { expr: inner, low, high, negated } => LogicalExpr::Between {
9013 expr: Box::new(Self::bind_outer_refs_in_expr(inner, outer_row, outer_schema, outer_table)),
9014 low: Box::new(Self::bind_outer_refs_in_expr(low, outer_row, outer_schema, outer_table)),
9015 high: Box::new(Self::bind_outer_refs_in_expr(high, outer_row, outer_schema, outer_table)),
9016 negated: *negated,
9017 },
9018 LogicalExpr::ScalarSubquery { subquery } => LogicalExpr::ScalarSubquery {
9019 subquery: Box::new(Self::bind_outer_refs_in_plan(subquery, outer_row, outer_schema, outer_table)),
9020 },
9021 other => other.clone(),
9022 }
9023 }
9024
9025 fn try_extract_pk_value(selection: Option<&sql::LogicalExpr>, schema: &Schema) -> Option<Value> {
9028 let predicate = selection?;
9029 let pk_col = schema.columns.iter().find(|c| c.primary_key)?;
9030
9031 if let sql::LogicalExpr::BinaryExpr { left, op: sql::BinaryOperator::Eq, right } = predicate {
9032 match (left.as_ref(), right.as_ref()) {
9033 (sql::LogicalExpr::Column { name, .. }, sql::LogicalExpr::Literal(val))
9034 if name == &pk_col.name => Some(val.clone()),
9035 (sql::LogicalExpr::Literal(val), sql::LogicalExpr::Column { name, .. })
9036 if name == &pk_col.name => Some(val.clone()),
9037 _ => None,
9038 }
9039 } else {
9040 None
9041 }
9042 }
9043
9044 fn apply_rls_to_plan(&self, plan: sql::LogicalPlan) -> Result<sql::LogicalPlan> {
9046 if self.tenant_manager.get_current_context().is_none() {
9048 return Ok(plan);
9049 }
9050 self.apply_rls_to_plan_recursive(plan)
9051 }
9052
9053 fn apply_rls_to_plan_recursive(&self, plan: sql::LogicalPlan) -> Result<sql::LogicalPlan> {
9055 match plan {
9056 sql::LogicalPlan::Scan { table_name, alias, schema, projection, as_of } => {
9057 if self.tenant_manager.should_apply_rls(&table_name, "SELECT") {
9059 if let Some((using_expr, _)) = self.tenant_manager.get_rls_conditions(&table_name, "SELECT") {
9060 let tenant_context = self.tenant_manager.get_current_context();
9062 let rls_evaluator = tenant::RLSExpressionEvaluator::new(
9063 schema.clone(),
9064 tenant_context
9065 );
9066 let filter_expr = rls_evaluator.parse(&using_expr)?;
9067
9068 let scan_plan = sql::LogicalPlan::Scan {
9070 table_name,
9071 alias: alias.clone(),
9072 schema,
9073 projection,
9074 as_of,
9075 };
9076
9077 return Ok(sql::LogicalPlan::Filter {
9078 input: Box::new(scan_plan),
9079 predicate: filter_expr,
9080 });
9081 }
9082 }
9083
9084 Ok(sql::LogicalPlan::Scan { table_name, alias, schema, projection, as_of })
9086 }
9087
9088 sql::LogicalPlan::Filter { input, predicate } => {
9089 Ok(sql::LogicalPlan::Filter {
9090 input: Box::new(self.apply_rls_to_plan_recursive(*input)?),
9091 predicate,
9092 })
9093 }
9094
9095 sql::LogicalPlan::Project { input, exprs, aliases, distinct, distinct_on } => {
9096 Ok(sql::LogicalPlan::Project {
9097 input: Box::new(self.apply_rls_to_plan_recursive(*input)?),
9098 exprs,
9099 aliases,
9100 distinct,
9101 distinct_on,
9102 })
9103 }
9104
9105 sql::LogicalPlan::Aggregate { input, group_by, aggr_exprs, having } => {
9106 Ok(sql::LogicalPlan::Aggregate {
9107 input: Box::new(self.apply_rls_to_plan_recursive(*input)?),
9108 group_by,
9109 aggr_exprs,
9110 having,
9111 })
9112 }
9113
9114 sql::LogicalPlan::Join { left, right, join_type, on, lateral } => {
9115 Ok(sql::LogicalPlan::Join {
9116 left: Box::new(self.apply_rls_to_plan_recursive(*left)?),
9117 right: Box::new(self.apply_rls_to_plan_recursive(*right)?),
9118 join_type,
9119 on,
9120 lateral,
9121 })
9122 }
9123
9124 sql::LogicalPlan::Sort { input, exprs, asc } => {
9125 Ok(sql::LogicalPlan::Sort {
9126 input: Box::new(self.apply_rls_to_plan_recursive(*input)?),
9127 exprs,
9128 asc,
9129 })
9130 }
9131
9132 sql::LogicalPlan::Limit { input, limit, offset, limit_param, offset_param } => {
9133 Ok(sql::LogicalPlan::Limit {
9134 input: Box::new(self.apply_rls_to_plan_recursive(*input)?),
9135 limit,
9136 offset,
9137 limit_param,
9138 offset_param,
9139 })
9140 }
9141
9142 sql::LogicalPlan::FilteredScan { table_name, alias, schema, projection, predicate, as_of } => {
9144 if self.tenant_manager.should_apply_rls(&table_name, "SELECT") {
9146 if let Some((using_expr, _)) = self.tenant_manager.get_rls_conditions(&table_name, "SELECT") {
9147 let tenant_context = self.tenant_manager.get_current_context();
9149 let rls_evaluator = tenant::RLSExpressionEvaluator::new(
9150 schema.clone(),
9151 tenant_context
9152 );
9153 let rls_predicate = rls_evaluator.parse(&using_expr)?;
9154
9155 let combined_predicate = if let Some(existing) = predicate {
9157 Some(sql::LogicalExpr::BinaryExpr {
9158 left: Box::new(existing),
9159 op: sql::BinaryOperator::And,
9160 right: Box::new(rls_predicate),
9161 })
9162 } else {
9163 Some(rls_predicate)
9164 };
9165
9166 return Ok(sql::LogicalPlan::FilteredScan {
9167 table_name,
9168 alias,
9169 schema,
9170 projection,
9171 predicate: combined_predicate,
9172 as_of,
9173 });
9174 }
9175 }
9176
9177 Ok(sql::LogicalPlan::FilteredScan { table_name, alias, schema, projection, predicate, as_of })
9179 }
9180
9181 other => Ok(other),
9183 }
9184 }
9185
9186 pub async fn start_auto_refresh(
9219 &self,
9220 config: Option<storage::AutoRefreshConfig>,
9221 ) -> Result<()> {
9222 let worker_config = config.unwrap_or_else(|| {
9223 storage::AutoRefreshConfig::default()
9224 .with_enabled(true)
9225 .with_interval_seconds(self.config.materialized_views.refresh_check_interval_secs)
9226 .with_staleness_threshold(300) .with_max_cpu_percent(self.config.materialized_views.default_max_cpu_percent as f64)
9228 .with_max_concurrent(self.config.materialized_views.max_concurrent_refreshes)
9229 });
9230
9231 let mut worker = storage::AutoRefreshWorker::new(
9232 worker_config,
9233 std::sync::Arc::clone(&self.storage),
9234 std::sync::Arc::clone(&self.mv_scheduler),
9235 );
9236
9237 worker.start().await?;
9238
9239 *self.auto_refresh_worker.write() = Some(worker);
9241
9242 tracing::info!("Materialized view auto-refresh worker started");
9243 Ok(())
9244 }
9245
9246 pub async fn stop_auto_refresh(&self) -> Result<()> {
9250 let worker = {
9251 let mut worker_guard = self.auto_refresh_worker.write();
9252 worker_guard.take()
9253 };
9254 if let Some(mut worker) = worker {
9255 worker.stop().await?;
9256 tracing::info!("Materialized view auto-refresh worker stopped");
9257 }
9258 Ok(())
9259 }
9260
9261 pub fn is_auto_refresh_running(&self) -> bool {
9263 self.auto_refresh_worker.read().as_ref()
9264 .map(|w| w.is_running())
9265 .unwrap_or(false)
9266 }
9267
9268 pub fn mv_scheduler(&self) -> &std::sync::Arc<storage::MVScheduler> {
9270 &self.mv_scheduler
9271 }
9272
9273 pub fn check_mv_staleness_now(&self) -> Result<()> {
9278 let worker_guard = self.auto_refresh_worker.read();
9279 if let Some(ref worker) = *worker_guard {
9280 worker.check_now()?;
9281 Ok(())
9282 } else {
9283 Err(Error::query_execution("Auto-refresh worker is not running"))
9284 }
9285 }
9286}
9287
9288pub struct Transaction<'a> {
9295 tx: storage::Transaction,
9296 db: &'a EmbeddedDatabase,
9298}
9299
9300impl Transaction<'_> {
9301 pub fn commit(self) -> Result<()> {
9306 self.tx.commit()
9307 }
9308
9309 pub fn rollback(self) -> Result<()> {
9314 self.tx.rollback()
9315 }
9316
9317 pub fn execute(&self, sql: &str) -> Result<u64> {
9348 self.db.execute_in_transaction_no_fast_path(sql, &self.tx)
9352 }
9353
9354 pub fn query(&self, sql: &str, _params: &[&dyn std::fmt::Display]) -> Result<Vec<Tuple>> {
9391 let (statement, _) = self.db.parse_cached(sql)?;
9393
9394 let catalog = self.db.storage.catalog();
9396 let planner = sql::Planner::with_catalog(&catalog)
9397 .with_sql(sql.to_string());
9398 let plan = planner.statement_to_plan(statement)?;
9399
9400 let mut executor = sql::Executor::with_storage(&self.db.storage)
9404 .with_timeout(self.db.config.storage.query_timeout_ms)
9405 .with_transaction(&self.tx);
9406
9407 executor.execute(&plan)
9408 }
9409}
9410
9411#[cfg(test)]
9412#[allow(clippy::unwrap_used, clippy::expect_used)]
9413#[allow(
9415 clippy::unwrap_used,
9416 clippy::expect_used,
9417 clippy::panic,
9418 clippy::indexing_slicing,
9419)]
9420mod tests {
9421 use super::*;
9422
9423 #[test]
9424 fn test_embedded_database_creation() {
9425 let db = EmbeddedDatabase::new_in_memory();
9426 assert!(db.is_ok());
9427 }
9428
9429 #[test]
9434 fn test_savepoint_basic_via_execute_works_in_transaction() {
9435 let db = EmbeddedDatabase::new_in_memory().unwrap();
9437 db.execute("CREATE TABLE sp_basic (id INT, val TEXT)").unwrap();
9438
9439 db.execute("BEGIN").unwrap();
9440 let result = db.execute("SAVEPOINT s1");
9441 assert!(result.is_ok(),
9442 "SAVEPOINT via execute() in BEGIN block should succeed, got: {:?}", result.err());
9443 db.execute("ROLLBACK").unwrap();
9444 }
9445
9446 #[test]
9447 fn test_savepoint_outside_transaction_succeeds_in_implicit_txn() {
9448 let db = EmbeddedDatabase::new_in_memory().unwrap();
9453 let result = db.execute("SAVEPOINT s1");
9454 assert!(result.is_ok(),
9455 "SAVEPOINT in implicit transaction should succeed, got: {:?}", result.err());
9456 }
9457
9458 #[test]
9459 fn test_savepoint_via_execute_returning_path() {
9460 let db = EmbeddedDatabase::new_in_memory().unwrap();
9463 db.execute("CREATE TABLE sp_ret (id INT, val TEXT)").unwrap();
9464
9465 db.execute_returning("BEGIN").unwrap();
9467 let result = db.execute_returning("SAVEPOINT s1");
9468 if result.is_ok() {
9470 let _ = db.execute_returning("INSERT INTO sp_ret VALUES (1, 'test')");
9471 let release_result = db.execute_returning("RELEASE SAVEPOINT s1");
9472 assert!(release_result.is_ok(), "RELEASE SAVEPOINT should work via returning path");
9473 let _ = db.execute_returning("COMMIT");
9474 } else {
9475 let err = result.unwrap_err().to_string();
9477 assert!(err.contains("not yet implemented") || err.contains("SAVEPOINT"),
9478 "Unexpected error: {}", err);
9479 let _ = db.execute_returning("ROLLBACK");
9480 }
9481 }
9482
9483 #[test]
9484 fn test_savepoint_nonexistent_rollback_errors() {
9485 let db = EmbeddedDatabase::new_in_memory().unwrap();
9487 db.execute("CREATE TABLE sp_noexist (id INT)").unwrap();
9488
9489 db.execute("BEGIN").unwrap();
9490 let result = db.execute("ROLLBACK TO SAVEPOINT nonexistent");
9491 assert!(result.is_err(), "ROLLBACK TO nonexistent savepoint should fail");
9492 db.execute("ROLLBACK").unwrap();
9493 }
9494
9495 #[test]
9496 fn test_savepoint_nonexistent_release_errors() {
9497 let db = EmbeddedDatabase::new_in_memory().unwrap();
9499
9500 db.execute("BEGIN").unwrap();
9501 let result = db.execute("RELEASE SAVEPOINT nonexistent");
9502 assert!(result.is_err(), "RELEASE nonexistent savepoint should fail");
9503 db.execute("ROLLBACK").unwrap();
9504 }
9505
9506 #[test]
9507 fn test_savepoint_nested_release_via_returning() {
9508 let db = EmbeddedDatabase::new_in_memory().unwrap();
9512 db.execute("CREATE TABLE sp_nested_rel (id INT, val TEXT)").unwrap();
9513
9514 db.execute("BEGIN").unwrap();
9516 let sp1 = db.execute_returning("SAVEPOINT s1");
9517 if sp1.is_err() {
9518 db.execute("ROLLBACK").unwrap();
9520 return;
9521 }
9522 db.execute("INSERT INTO sp_nested_rel VALUES (1, 'A')").unwrap();
9523 db.execute_returning("SAVEPOINT s2").unwrap();
9524 db.execute("INSERT INTO sp_nested_rel VALUES (2, 'B')").unwrap();
9525 db.execute_returning("RELEASE SAVEPOINT s2").unwrap();
9526 db.execute_returning("RELEASE SAVEPOINT s1").unwrap();
9527 db.execute("COMMIT").unwrap();
9528
9529 let rows = db.query("SELECT * FROM sp_nested_rel", &[]).unwrap();
9530 assert_eq!(rows.len(), 2, "Both A and B should be preserved after nested RELEASE + COMMIT");
9531 }
9532
9533 #[test]
9534 fn test_savepoint_rollback_to_undoes_inserts() {
9535 let db = EmbeddedDatabase::new_in_memory().unwrap();
9537 db.execute("CREATE TABLE sp_stub (id INT, val TEXT)").unwrap();
9538
9539 db.execute("BEGIN").unwrap();
9540 let sp1 = db.execute_returning("SAVEPOINT s1");
9541 if sp1.is_err() {
9542 db.execute("ROLLBACK").unwrap();
9543 return;
9544 }
9545 db.execute("INSERT INTO sp_stub VALUES (1, 'should_vanish')").unwrap();
9546 let rb = db.execute_returning("ROLLBACK TO SAVEPOINT s1");
9547 if rb.is_err() {
9548 db.execute("ROLLBACK").unwrap();
9549 return;
9550 }
9551 db.execute("COMMIT").unwrap();
9552
9553 let rows = db.query("SELECT * FROM sp_stub", &[]).unwrap();
9554 assert_eq!(rows.len(), 0,
9558 "ROLLBACK TO SAVEPOINT should undo INSERTs via transaction write set");
9559 }
9560
9561 #[test]
9562 fn test_savepoint_reuse_name_after_release_via_returning() {
9563 let db = EmbeddedDatabase::new_in_memory().unwrap();
9565 db.execute("CREATE TABLE sp_reuse (id INT)").unwrap();
9566
9567 db.execute("BEGIN").unwrap();
9568 let sp1 = db.execute_returning("SAVEPOINT s1");
9569 if sp1.is_err() {
9570 db.execute("ROLLBACK").unwrap();
9571 return;
9572 }
9573 db.execute("INSERT INTO sp_reuse VALUES (1)").unwrap();
9574 db.execute_returning("RELEASE SAVEPOINT s1").unwrap();
9575
9576 db.execute_returning("SAVEPOINT s1").unwrap();
9578 db.execute("INSERT INTO sp_reuse VALUES (2)").unwrap();
9579 db.execute_returning("RELEASE SAVEPOINT s1").unwrap();
9580
9581 db.execute("COMMIT").unwrap();
9582
9583 let rows = db.query("SELECT * FROM sp_reuse", &[]).unwrap();
9584 assert_eq!(rows.len(), 2, "Both inserts should persist after reuse of savepoint name");
9585 }
9586
9587 #[test]
9592 fn test_transaction_read_committed() {
9593 let db = EmbeddedDatabase::new_in_memory().unwrap();
9595 db.execute("CREATE TABLE iso_rc (id INT, val TEXT)").unwrap();
9596
9597 let s1 = db.create_session("user1", crate::session::IsolationLevel::ReadCommitted).unwrap();
9598 let s2 = db.create_session("user2", crate::session::IsolationLevel::ReadCommitted).unwrap();
9599
9600 db.begin_transaction_for_session(s1).unwrap();
9602 db.execute_in_session(s1, "INSERT INTO iso_rc VALUES (1, 'uncommitted')").unwrap();
9603
9604 let rows = db.query_in_session(s2, "SELECT * FROM iso_rc", &[]).unwrap();
9606 assert_eq!(rows.len(), 0,
9607 "Uncommitted writes from S1 should be invisible to S2 (read committed)");
9608
9609 db.commit_transaction_for_session(s1).unwrap();
9611
9612 let rows = db.query_in_session(s2, "SELECT * FROM iso_rc WHERE 1=1", &[]).unwrap();
9616 assert_eq!(rows.len(), 1, "After S1 commits, S2 should see the row");
9617
9618 db.destroy_session(s1).unwrap();
9619 db.destroy_session(s2).unwrap();
9620 }
9621
9622 #[test]
9623 fn test_transaction_dirty_read_prevented() {
9624 let db = EmbeddedDatabase::new_in_memory().unwrap();
9626 db.execute("CREATE TABLE iso_dirty (id INT, val TEXT)").unwrap();
9627 db.execute("INSERT INTO iso_dirty VALUES (1, 'visible')").unwrap();
9628
9629 let s1 = db.create_session("writer", crate::session::IsolationLevel::ReadCommitted).unwrap();
9630 let s2 = db.create_session("reader", crate::session::IsolationLevel::ReadCommitted).unwrap();
9631
9632 db.begin_transaction_for_session(s1).unwrap();
9634 db.execute_in_session(s1, "INSERT INTO iso_dirty VALUES (2, 'dirty')").unwrap();
9635
9636 let rows = db.query_in_session(s2, "SELECT * FROM iso_dirty", &[]).unwrap();
9638 assert_eq!(rows.len(), 1, "S2 should only see committed data, not dirty writes");
9639 assert_eq!(rows[0].get(1), Some(&Value::String("visible".to_string())));
9640
9641 db.rollback_transaction_for_session(s1).unwrap();
9642 db.destroy_session(s1).unwrap();
9643 db.destroy_session(s2).unwrap();
9644 }
9645
9646 #[test]
9647 fn test_transaction_rollback_visibility() {
9648 let db = EmbeddedDatabase::new_in_memory().unwrap();
9650 db.execute("CREATE TABLE iso_rb_vis (id INT, val TEXT)").unwrap();
9651
9652 let s1 = db.create_session("writer", crate::session::IsolationLevel::ReadCommitted).unwrap();
9653 let s2 = db.create_session("reader", crate::session::IsolationLevel::ReadCommitted).unwrap();
9654
9655 db.begin_transaction_for_session(s1).unwrap();
9657 db.execute_in_session(s1, "INSERT INTO iso_rb_vis VALUES (1, 'rolled_back')").unwrap();
9658 db.rollback_transaction_for_session(s1).unwrap();
9659
9660 let rows = db.query_in_session(s2, "SELECT * FROM iso_rb_vis", &[]).unwrap();
9662 assert_eq!(rows.len(), 0, "Rolled-back data should never be visible");
9663
9664 let rows = db.query("SELECT * FROM iso_rb_vis", &[]).unwrap();
9666 assert_eq!(rows.len(), 0, "Rolled-back data should be invisible via default query path too");
9667
9668 db.destroy_session(s1).unwrap();
9669 db.destroy_session(s2).unwrap();
9670 }
9671
9672 #[test]
9677 fn test_concurrent_inserts_different_rows() {
9678 use std::sync::Arc;
9680
9681 let db = Arc::new(EmbeddedDatabase::new_in_memory().unwrap());
9682 db.execute("CREATE TABLE conc_ins (id INT, thread_id INT)").unwrap();
9683
9684 let num_threads = 4;
9685 let rows_per_thread = 25;
9686 let mut handles = Vec::new();
9687
9688 for t in 0..num_threads {
9689 let db_clone = Arc::clone(&db);
9690 let handle = std::thread::spawn(move || {
9691 for i in 0..rows_per_thread {
9692 let id = t * rows_per_thread + i;
9693 db_clone.execute(
9694 &format!("INSERT INTO conc_ins VALUES ({}, {})", id, t)
9695 ).unwrap();
9696 }
9697 });
9698 handles.push(handle);
9699 }
9700
9701 for h in handles {
9702 h.join().expect("Thread panicked");
9703 }
9704
9705 let rows = db.query("SELECT * FROM conc_ins", &[]).unwrap();
9706 assert_eq!(rows.len(), (num_threads * rows_per_thread) as usize,
9707 "All inserts from all threads should be visible");
9708 }
9709
9710 #[test]
9711 fn test_concurrent_reads_during_write() {
9712 use std::sync::Arc;
9715 use std::sync::atomic::{AtomicBool, Ordering};
9716
9717 let db = Arc::new(EmbeddedDatabase::new_in_memory().unwrap());
9718 db.execute("CREATE TABLE conc_rw (id INT, val TEXT)").unwrap();
9719
9720 let done = Arc::new(AtomicBool::new(false));
9721
9722 let db_w = Arc::clone(&db);
9724 let done_w = Arc::clone(&done);
9725 let writer = std::thread::spawn(move || {
9726 for i in 0..50 {
9727 db_w.execute(&format!("INSERT INTO conc_rw VALUES ({}, 'row_{}')", i, i)).unwrap();
9728 }
9729 done_w.store(true, Ordering::Release);
9730 });
9731
9732 let mut readers = Vec::new();
9734 for t in 0..3_usize {
9735 let db_r = Arc::clone(&db);
9736 let done_r = Arc::clone(&done);
9737 let reader = std::thread::spawn(move || {
9738 let mut query_count = 0_usize;
9739 while !done_r.load(Ordering::Acquire) {
9740 let sql = format!(
9743 "SELECT * FROM conc_rw WHERE 1=1 /* t{}q{} */", t, query_count
9744 );
9745 let rows = db_r.query(&sql, &[]).unwrap();
9746 assert!(rows.len() <= 50, "Should never exceed 50 rows");
9748 query_count += 1;
9749 std::thread::yield_now();
9750 }
9751 assert!(query_count > 0, "Reader should have executed at least one query");
9752 });
9753 readers.push(reader);
9754 }
9755
9756 writer.join().expect("Writer panicked");
9757 for r in readers {
9758 r.join().expect("Reader panicked");
9759 }
9760
9761 let final_rows = db.query("SELECT * FROM conc_rw WHERE 1=1 /* final */", &[]).unwrap();
9763 assert_eq!(final_rows.len(), 50, "All 50 rows should be visible after writer completes");
9764 }
9765
9766 #[test]
9767 fn test_concurrent_counter_increment() {
9768 use std::sync::Arc;
9772
9773 let db = Arc::new(EmbeddedDatabase::new_in_memory().unwrap());
9774 db.execute("CREATE TABLE conc_counter (id INT, cnt INT)").unwrap();
9775 db.execute("INSERT INTO conc_counter VALUES (1, 0)").unwrap();
9776
9777 let num_threads = 4;
9778 let increments_per_thread = 10;
9779 let mut handles = Vec::new();
9780
9781 for _ in 0..num_threads {
9782 let db_clone = Arc::clone(&db);
9783 let handle = std::thread::spawn(move || {
9784 for _ in 0..increments_per_thread {
9785 let rows = db_clone.query("SELECT cnt FROM conc_counter WHERE id = 1", &[]).unwrap();
9787 if let Some(row) = rows.first() {
9788 if let Some(Value::Int4(current)) = row.get(0) {
9789 let new_val = current + 1;
9790 db_clone.execute(
9791 &format!("UPDATE conc_counter SET cnt = {} WHERE id = 1", new_val)
9792 ).unwrap();
9793 }
9794 }
9795 }
9796 });
9797 handles.push(handle);
9798 }
9799
9800 for h in handles {
9801 h.join().expect("Thread panicked");
9802 }
9803
9804 let rows = db.query("SELECT cnt FROM conc_counter WHERE id = 1", &[]).unwrap();
9805 assert_eq!(rows.len(), 1, "Counter row should still exist");
9806 if let Some(Value::Int4(final_val)) = rows[0].get(0) {
9807 let max_expected = (num_threads * increments_per_thread) as i32;
9811 assert!(*final_val > 0, "Counter should have been incremented at least once");
9812 assert!(*final_val <= max_expected,
9813 "Counter {} should not exceed {}", final_val, max_expected);
9814 if *final_val < max_expected {
9816 }
9818 } else {
9819 panic!("Counter value should be Int4");
9820 }
9821 }
9822
#[test]
fn test_concurrent_transactions_different_tables() {
    use std::sync::Arc;

    // Several threads, each with its own session and its own table, commit concurrently.
    let db = Arc::new(EmbeddedDatabase::new_in_memory().unwrap());
    let worker_count = 4;

    // One table per worker so the transactions never contend for the same rows.
    for t in 0..worker_count {
        let ddl = format!("CREATE TABLE conc_tbl_{} (id INT, val TEXT)", t);
        db.execute(&ddl).unwrap();
    }

    // Spawn every worker before joining any of them so they actually overlap.
    let workers: Vec<_> = (0..worker_count)
        .map(|t| {
            let shared = Arc::clone(&db);
            std::thread::spawn(move || {
                let user = format!("user{}", t);
                let session = shared
                    .create_session(&user, crate::session::IsolationLevel::ReadCommitted)
                    .unwrap();

                shared.begin_transaction_for_session(session).unwrap();
                for i in 0..10 {
                    let insert = format!("INSERT INTO conc_tbl_{} VALUES ({}, 'val_{}')", t, i, i);
                    shared.execute_in_session(session, &insert).unwrap();
                }
                shared.commit_transaction_for_session(session).unwrap();
                shared.destroy_session(session).unwrap();
            })
        })
        .collect();

    for worker in workers {
        worker.join().expect("Thread panicked");
    }

    // Every worker's ten rows must be durable after its commit.
    for t in 0..worker_count {
        let select = format!("SELECT * FROM conc_tbl_{}", t);
        let rows = db.query(&select, &[]).unwrap();
        assert_eq!(rows.len(), 10,
            "Table conc_tbl_{} should have 10 rows, got {}", t, rows.len());
    }
}
9869
#[test]
fn test_transaction_double_commit() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE dbl_commit (id INT)").unwrap();

    // A complete, successful transaction.
    for statement in ["BEGIN", "INSERT INTO dbl_commit VALUES (1)", "COMMIT"] {
        db.execute(statement).unwrap();
    }

    // Nothing is open any more, so another COMMIT has nothing to end.
    let second_commit = db.execute("COMMIT");
    assert!(second_commit.is_err(), "Second COMMIT without active transaction should fail");
}
9887
#[test]
fn test_transaction_double_rollback() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    // Open and immediately abort an empty transaction.
    for statement in ["BEGIN", "ROLLBACK"] {
        db.execute(statement).unwrap();
    }

    // The first ROLLBACK already closed it; a repeat must be rejected.
    let second_rollback = db.execute("ROLLBACK");
    assert!(second_rollback.is_err(), "Second ROLLBACK without active transaction should fail");
}
9899
#[test]
fn test_autocommit_mode() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE autocommit (id INT, val TEXT)").unwrap();

    // No BEGIN anywhere: each statement must commit on its own.
    let seed = [
        "INSERT INTO autocommit VALUES (1, 'a')",
        "INSERT INTO autocommit VALUES (2, 'b')",
        "INSERT INTO autocommit VALUES (3, 'c')",
    ];
    for insert in seed {
        db.execute(insert).unwrap();
    }
    let all_rows = db.query("SELECT * FROM autocommit", &[]).unwrap();
    assert_eq!(all_rows.len(), 3, "All auto-committed inserts should be visible");

    db.execute("UPDATE autocommit SET val = 'updated' WHERE id = 2").unwrap();
    let updated = db.query("SELECT val FROM autocommit WHERE id = 2", &[]).unwrap();
    assert_eq!(updated.len(), 1);
    let expected = Value::String("updated".to_string());
    assert_eq!(updated[0].get(0), Some(&expected));

    db.execute("DELETE FROM autocommit WHERE id = 3").unwrap();
    let remaining = db.query("SELECT * FROM autocommit", &[]).unwrap();
    assert_eq!(remaining.len(), 2, "Delete should be auto-committed");
}
9925
#[test]
fn test_ddl_in_transaction_commit() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    // Table creation and the first insert share a single committed transaction.
    let script = [
        "BEGIN",
        "CREATE TABLE ddl_txn (id INT, val TEXT)",
        "INSERT INTO ddl_txn VALUES (1, 'hello')",
        "COMMIT",
    ];
    for statement in script {
        db.execute(statement).unwrap();
    }

    let persisted = db.query("SELECT * FROM ddl_txn", &[]).unwrap();
    assert_eq!(persisted.len(), 1, "DDL + DML in committed transaction should persist");
}
9939
#[test]
fn test_ddl_in_transaction_rollback() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    let script = [
        "BEGIN",
        "CREATE TABLE ddl_rb (id INT, val TEXT)",
        "INSERT INTO ddl_rb VALUES (1, 'hello')",
        "ROLLBACK",
    ];
    for statement in script {
        db.execute(statement).unwrap();
    }

    // This test deliberately accepts several outcomes. If the table is gone, the query
    // errors and nothing is asserted. If the table survived, it may hold 0 or 1 rows.
    // Only a larger row count fails the test.
    if let Ok(rows) = db.query("SELECT * FROM ddl_rb", &[]) {
        let row_count = rows.len();
        assert!(row_count <= 1,
            "DDL rollback behavior: table exists with {} rows", row_count);
    }
}
9965
#[test]
fn test_empty_transaction_commit() {
    let database = EmbeddedDatabase::new_in_memory().unwrap();

    // BEGIN must flip the transaction flag on even with no work queued...
    database.execute("BEGIN").unwrap();
    assert!(database.in_transaction());

    // ...and an immediate COMMIT must flip it back off.
    database.execute("COMMIT").unwrap();
    assert!(!database.in_transaction());
}
9976
#[test]
fn test_empty_transaction_rollback() {
    let database = EmbeddedDatabase::new_in_memory().unwrap();

    // An empty transaction is still a transaction while it is open...
    database.execute("BEGIN").unwrap();
    assert!(database.in_transaction());

    // ...and ROLLBACK with nothing to undo still closes it.
    database.execute("ROLLBACK").unwrap();
    assert!(!database.in_transaction());
}
9987
#[test]
fn test_transaction_after_error() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE txn_err (id INT, val TEXT)").unwrap();
    db.execute("BEGIN").unwrap();

    // A failing statement must not abort the surrounding transaction.
    let bad_insert = db.execute("INSERT INTO nonexistent_table VALUES (1)");
    assert!(bad_insert.is_err(), "Insert into nonexistent table should fail");
    assert!(db.in_transaction(), "Transaction should still be active after statement error");

    // Later valid work in the same transaction still commits normally.
    db.execute("INSERT INTO txn_err VALUES (1, 'after_error')").unwrap();
    db.execute("COMMIT").unwrap();

    let committed = db.query("SELECT * FROM txn_err", &[]).unwrap();
    assert_eq!(committed.len(), 1, "Valid insert after error should be committed");
    let expected = Value::String("after_error".to_string());
    assert_eq!(committed[0].get(1), Some(&expected));
}
10012
#[test]
fn test_begin_while_in_transaction_errors() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("BEGIN").unwrap();

    // Nested transactions are not supported; the error text should say why.
    let nested = db.execute("BEGIN");
    assert!(nested.is_err(), "Nested BEGIN should fail");
    if let Err(error) = nested {
        let message = error.to_string();
        assert!(message.contains("already active"),
            "Error should mention transaction already active");
    }

    db.execute("ROLLBACK").unwrap();
}
10026
#[test]
fn test_transaction_commit_then_new_transaction() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE txn_seq (id INT)").unwrap();

    // Each id is inserted in its own transaction, ended by the paired statement.
    // Only the transaction for id 3 is rolled back.
    let rounds = [(1, "COMMIT"), (2, "COMMIT"), (3, "ROLLBACK"), (4, "COMMIT")];
    for (id, terminator) in rounds {
        db.execute("BEGIN").unwrap();
        db.execute(&format!("INSERT INTO txn_seq VALUES ({})", id)).unwrap();
        db.execute(terminator).unwrap();
    }

    let survivors = db.query("SELECT * FROM txn_seq", &[]).unwrap();
    assert_eq!(survivors.len(), 3, "Rows from txn 1, 2, 4 should exist (txn 3 rolled back)");
}
10056
#[test]
fn test_insert_rollback_pk_reuse() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE pk_reuse (id INT PRIMARY KEY, val TEXT)").unwrap();

    // Claim primary key 1 inside a transaction, then discard it...
    for statement in ["BEGIN", "INSERT INTO pk_reuse VALUES (1, 'rolled_back')", "ROLLBACK"] {
        db.execute(statement).unwrap();
    }

    // ...so the key must be free for an ordinary auto-committed insert.
    db.execute("INSERT INTO pk_reuse VALUES (1, 'final')").unwrap();
    let found = db.query("SELECT val FROM pk_reuse WHERE id = 1", &[]).unwrap();
    assert_eq!(found.len(), 1);
    let expected = Value::String("final".to_string());
    assert_eq!(found[0].get(0), Some(&expected));
}
10078
#[test]
fn test_update_rollback_preserves_original() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE upd_rb (id INT, val TEXT)").unwrap();
    db.execute("INSERT INTO upd_rb VALUES (1, 'original')").unwrap();

    // Modify the row inside a transaction that is then rolled back.
    db.execute("BEGIN").unwrap();
    db.execute("UPDATE upd_rb SET val = 'changed' WHERE id = 1").unwrap();
    db.execute("ROLLBACK").unwrap();

    // The committed value from before the transaction must be what remains.
    let rows = db.query("SELECT val FROM upd_rb WHERE id = 1", &[]).unwrap();
    assert_eq!(rows.len(), 1);
    let val = rows[0].get(0);
    assert_eq!(val, Some(&Value::String("original".to_string())),
        "ROLLBACK should undo the UPDATE");
}
10099
#[test]
fn test_delete_rollback_preserves_row() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE del_rb (id INT, val TEXT)").unwrap();
    db.execute("INSERT INTO del_rb VALUES (1, 'keep_me')").unwrap();

    // Delete the only row, then undo the deletion.
    for statement in ["BEGIN", "DELETE FROM del_rb WHERE id = 1", "ROLLBACK"] {
        db.execute(statement).unwrap();
    }

    let remaining = db.query("SELECT * FROM del_rb", &[]).unwrap();
    assert_eq!(remaining.len(), 1, "ROLLBACK should undo the DELETE");
    let expected = Value::String("keep_me".to_string());
    assert_eq!(remaining[0].get(1), Some(&expected));
}
10115
#[test]
// `3.14` below is arbitrary test data, not an approximation of PI. Without this allow,
// clippy's deny-by-default `approx_constant` lint rejects the float literals.
#[allow(clippy::approx_constant)]
fn test_insert_commit_data_integrity() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE integrity (id INT, name TEXT, score FLOAT, active BOOLEAN)").unwrap();

    db.execute("BEGIN").unwrap();
    db.execute("INSERT INTO integrity VALUES (42, 'test_name', 3.14, true)").unwrap();
    db.execute("COMMIT").unwrap();

    // Every column type must round-trip through a committed transaction unchanged.
    let rows = db.query("SELECT * FROM integrity WHERE id = 42", &[]).unwrap();
    assert_eq!(rows.len(), 1);
    assert_eq!(rows[0].get(0), Some(&Value::Int4(42)));
    assert_eq!(rows[0].get(1), Some(&Value::String("test_name".to_string())));

    // FLOAT may be stored as either single or double precision, so compare with a
    // tolerance suited to each width.
    match rows[0].get(2) {
        Some(Value::Float8(f)) => {
            assert!((f - 3.14).abs() < 0.001, "Float should be ~3.14, got {}", f);
        }
        Some(Value::Float4(f)) => {
            assert!((f - 3.14_f32).abs() < 0.01, "Float should be ~3.14, got {}", f);
        }
        other => panic!("Score should be a float type, got {:?}", other),
    }
    assert_eq!(rows[0].get(3), Some(&Value::Boolean(true)));
}
10140
#[test]
fn test_multiple_inserts_rollback_clears_all() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE multi_rb (id INT, val TEXT)").unwrap();

    // Ten inserts in one transaction, all discarded together.
    db.execute("BEGIN").unwrap();
    (1..=10).for_each(|n| {
        let insert = format!("INSERT INTO multi_rb VALUES ({}, 'row_{}')", n, n);
        db.execute(&insert).unwrap();
    });
    db.execute("ROLLBACK").unwrap();

    let leftover = db.query("SELECT * FROM multi_rb", &[]).unwrap();
    assert_eq!(leftover.len(), 0, "All 10 inserts should be rolled back");
}
10156
#[test]
fn test_transaction_with_multiple_tables() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE multi_a (id INT, val TEXT)").unwrap();
    db.execute("CREATE TABLE multi_b (id INT, ref_id INT)").unwrap();

    // One parent row and two child rows, written in a single transaction across both tables.
    db.execute("BEGIN").unwrap();
    db.execute("INSERT INTO multi_a VALUES (1, 'parent')").unwrap();
    for child_id in [100, 101] {
        db.execute(&format!("INSERT INTO multi_b VALUES ({}, 1)", child_id)).unwrap();
    }
    db.execute("COMMIT").unwrap();

    let parent_count = db.query("SELECT * FROM multi_a", &[]).unwrap().len();
    let child_count = db.query("SELECT * FROM multi_b", &[]).unwrap().len();
    assert_eq!(parent_count, 1, "Parent table should have 1 row");
    assert_eq!(child_count, 2, "Child table should have 2 rows");
}
10175
#[test]
fn test_transaction_with_multiple_tables_rollback() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE multi_rb_a (id INT, val TEXT)").unwrap();
    db.execute("CREATE TABLE multi_rb_b (id INT, ref_id INT)").unwrap();

    // The same cross-table writes as the commit case, but abandoned at the end.
    db.execute("BEGIN").unwrap();
    db.execute("INSERT INTO multi_rb_a VALUES (1, 'parent')").unwrap();
    for child_id in [100, 101] {
        db.execute(&format!("INSERT INTO multi_rb_b VALUES ({}, 1)", child_id)).unwrap();
    }
    db.execute("ROLLBACK").unwrap();

    let parent_count = db.query("SELECT * FROM multi_rb_a", &[]).unwrap().len();
    let child_count = db.query("SELECT * FROM multi_rb_b", &[]).unwrap().len();
    assert_eq!(parent_count, 0, "Parent table should be empty after rollback");
    assert_eq!(child_count, 0, "Child table should be empty after rollback");
}
10194
#[test]
fn test_transaction_handle_commit() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE txn_handle (id INT, val TEXT)").unwrap();

    // Drive the transaction through the handle API rather than BEGIN/COMMIT statements.
    let tx = db.begin_transaction().unwrap();
    for insert in [
        "INSERT INTO txn_handle VALUES (1, 'via_handle')",
        "INSERT INTO txn_handle VALUES (2, 'via_handle')",
    ] {
        tx.execute(insert).unwrap();
    }
    tx.commit().unwrap();

    let committed = db.query("SELECT * FROM txn_handle", &[]).unwrap();
    assert_eq!(committed.len(), 2, "Both inserts via Transaction handle should be committed");
}
10213
#[test]
fn test_transaction_handle_rollback() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE txn_h_rb (id INT, val TEXT)").unwrap();
    db.execute("INSERT INTO txn_h_rb VALUES (0, 'pre_existing')").unwrap();

    // Work done through the handle is discarded by an explicit rollback.
    let tx = db.begin_transaction().unwrap();
    tx.execute("INSERT INTO txn_h_rb VALUES (1, 'will_rollback')").unwrap();
    tx.rollback().unwrap();

    let remaining = db.query("SELECT * FROM txn_h_rb", &[]).unwrap();
    assert_eq!(remaining.len(), 1, "Only pre-existing row should remain after rollback");
    let expected = Value::String("pre_existing".to_string());
    assert_eq!(remaining[0].get(1), Some(&expected));
}
10229
#[test]
fn test_transaction_handle_query() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE txn_h_q (id INT, val TEXT)").unwrap();
    db.execute("INSERT INTO txn_h_q VALUES (1, 'committed')").unwrap();

    let tx = db.begin_transaction().unwrap();
    tx.execute("INSERT INTO txn_h_q VALUES (2, 'in_txn')").unwrap();

    // Only the previously committed row is required to be visible through the handle;
    // whether the handle also sees its own uncommitted insert is not asserted here.
    let seen_in_txn = tx.query("SELECT * FROM txn_h_q", &[]).unwrap();
    assert!(!seen_in_txn.is_empty(), "Should see at least the committed row");
    tx.commit().unwrap();

    let seen_after = db.query("SELECT * FROM txn_h_q", &[]).unwrap();
    assert_eq!(seen_after.len(), 2, "Both rows should be visible after commit");
}
10248
#[test]
fn test_session_sequential_transactions() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE sess_seq (id INT, val TEXT)").unwrap();

    let session = db.create_session("user1", crate::session::IsolationLevel::ReadCommitted).unwrap();

    // One session runs two back-to-back transactions, one insert each.
    for insert in [
        "INSERT INTO sess_seq VALUES (1, 'first')",
        "INSERT INTO sess_seq VALUES (2, 'second')",
    ] {
        db.begin_transaction_for_session(session).unwrap();
        db.execute_in_session(session, insert).unwrap();
        db.commit_transaction_for_session(session).unwrap();
    }

    let committed = db.query("SELECT * FROM sess_seq", &[]).unwrap();
    assert_eq!(committed.len(), 2, "Both sequential transactions should have committed");

    db.destroy_session(session).unwrap();
}
10276
#[test]
fn test_session_rollback_then_new_transaction() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE sess_rb_new (id INT, val TEXT)").unwrap();

    let session = db.create_session("user1", crate::session::IsolationLevel::ReadCommitted).unwrap();

    // First transaction is abandoned...
    db.begin_transaction_for_session(session).unwrap();
    db.execute_in_session(session, "INSERT INTO sess_rb_new VALUES (1, 'rolled_back')").unwrap();
    db.rollback_transaction_for_session(session).unwrap();

    // ...and the session must still be able to start and commit a fresh one.
    db.begin_transaction_for_session(session).unwrap();
    db.execute_in_session(session, "INSERT INTO sess_rb_new VALUES (2, 'committed')").unwrap();
    db.commit_transaction_for_session(session).unwrap();

    let survivors = db.query("SELECT * FROM sess_rb_new", &[]).unwrap();
    assert_eq!(survivors.len(), 1, "Only the committed transaction's data should exist");
    let expected = Value::String("committed".to_string());
    assert_eq!(survivors[0].get(1), Some(&expected));

    db.destroy_session(session).unwrap();
}
10301
#[test]
fn test_session_double_begin_errors() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    let session = db.create_session("user1", crate::session::IsolationLevel::ReadCommitted).unwrap();
    db.begin_transaction_for_session(session).unwrap();

    // A session may hold at most one open transaction.
    let nested = db.begin_transaction_for_session(session);
    assert!(nested.is_err(), "Double BEGIN on same session should fail");

    // Clean up the transaction that did open.
    db.rollback_transaction_for_session(session).unwrap();
    db.destroy_session(session).unwrap();
}
10316
#[test]
fn test_session_commit_without_transaction_errors() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    // A freshly created session has no transaction to commit.
    let session = db.create_session("user1", crate::session::IsolationLevel::ReadCommitted).unwrap();
    let outcome = db.commit_transaction_for_session(session);
    assert!(outcome.is_err(), "COMMIT without active transaction should fail");

    db.destroy_session(session).unwrap();
}
10328
#[test]
fn test_session_rollback_without_transaction_errors() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    // A freshly created session has no transaction to roll back.
    let session = db.create_session("user1", crate::session::IsolationLevel::ReadCommitted).unwrap();
    let outcome = db.rollback_transaction_for_session(session);
    assert!(outcome.is_err(), "ROLLBACK without active transaction should fail");

    db.destroy_session(session).unwrap();
}
10340
#[test]
fn test_insert_visible_in_same_transaction() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE t_ryow (id INT PRIMARY KEY, v TEXT)").unwrap();

    db.begin().unwrap();
    db.execute("INSERT INTO t_ryow VALUES (1, 'hello')").unwrap();

    // Read-your-own-writes: the uncommitted insert is visible before COMMIT.
    let visible = db.query("SELECT * FROM t_ryow", &[]).unwrap();
    assert_eq!(visible.len(), 1, "INSERT must be visible to SELECT within the same transaction");

    db.commit().unwrap();
}
10356
#[test]
fn test_update_visible_in_same_transaction() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE t_ryow2 (id INT PRIMARY KEY, v TEXT)").unwrap();
    db.execute("INSERT INTO t_ryow2 VALUES (1, 'before')").unwrap();

    db.begin().unwrap();
    db.execute("UPDATE t_ryow2 SET v = 'after' WHERE id = 1").unwrap();

    // The new value must be readable before the transaction commits.
    let visible = db.query("SELECT * FROM t_ryow2 WHERE id = 1", &[]).unwrap();
    assert_eq!(visible.len(), 1);
    let expected = Value::String("after".to_string());
    assert_eq!(&visible[0].values[1], &expected,
        "UPDATE must be visible to SELECT within the same transaction");

    db.commit().unwrap();
}
10372
#[test]
fn test_delete_visible_in_same_transaction() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE t_ryow3 (id INT PRIMARY KEY, v TEXT)").unwrap();
    db.execute("INSERT INTO t_ryow3 VALUES (1, 'gone')").unwrap();

    db.begin().unwrap();
    db.execute("DELETE FROM t_ryow3 WHERE id = 1").unwrap();

    // The deleted row must already be hidden from this transaction's reads.
    let visible = db.query("SELECT * FROM t_ryow3", &[]).unwrap();
    assert_eq!(visible.len(), 0,
        "DELETE must be reflected in SELECT within the same transaction");

    db.commit().unwrap();
}
10386
#[test]
fn test_multiple_inserts_visible_in_same_transaction() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE t_ryow4 (id INT PRIMARY KEY, v TEXT)").unwrap();

    db.begin().unwrap();
    for (id, text) in [(1, "a"), (2, "b"), (3, "c")] {
        db.execute(&format!("INSERT INTO t_ryow4 VALUES ({}, '{}')", id, text)).unwrap();
    }

    // All three pending inserts must be readable before COMMIT.
    let visible = db.query("SELECT * FROM t_ryow4", &[]).unwrap();
    assert_eq!(visible.len(), 3,
        "All INSERTs must be visible to SELECT within the same transaction");

    db.commit().unwrap();
}
10401
#[test]
fn test_rollback_hides_inserts() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE t_ryow5 (id INT PRIMARY KEY, v TEXT)").unwrap();

    db.begin().unwrap();
    db.execute("INSERT INTO t_ryow5 VALUES (1, 'temp')").unwrap();

    // Visible while the transaction is open...
    let during = db.query("SELECT * FROM t_ryow5", &[]).unwrap();
    assert_eq!(during.len(), 1);

    // ...and gone once it has been rolled back.
    db.rollback().unwrap();
    let after = db.query("SELECT * FROM t_ryow5", &[]).unwrap();
    assert_eq!(after.len(), 0,
        "After ROLLBACK, inserted data must not be visible");
}
10418
10419 fn setup_window_test_db() -> EmbeddedDatabase {
10427 let db = EmbeddedDatabase::new_in_memory().unwrap();
10428 db.execute(
10429 "CREATE TABLE employees (id INT PRIMARY KEY, name TEXT, dept TEXT, salary INT, age INT)",
10430 )
10431 .unwrap();
10432 db.execute("INSERT INTO employees (id, name, dept, salary, age) VALUES (1, 'Alice', 'Engineering', 120000, 35)").unwrap();
10433 db.execute("INSERT INTO employees (id, name, dept, salary, age) VALUES (2, 'Bob', 'Engineering', 110000, 28)").unwrap();
10434 db.execute("INSERT INTO employees (id, name, dept, salary, age) VALUES (3, 'Charlie', 'Engineering', 110000, 32)").unwrap();
10435 db.execute("INSERT INTO employees (id, name, dept, salary, age) VALUES (4, 'Dave', 'Sales', 90000, 40)").unwrap();
10436 db.execute("INSERT INTO employees (id, name, dept, salary, age) VALUES (5, 'Eve', 'Sales', 95000, 25)").unwrap();
10437 db.execute("INSERT INTO employees (id, name, dept, salary, age) VALUES (6, 'Frank', 'Marketing', 80000, 45)").unwrap();
10438 db
10439 }
10440
#[test]
fn test_window_row_number_basic() {
    let db = setup_window_test_db();
    let results = db
        .query(
            "SELECT name, salary, ROW_NUMBER() OVER (ORDER BY salary DESC) FROM employees",
            &[],
        )
        .unwrap();
    assert_eq!(results.len(), 6, "Should return all 6 employees");

    // ROW_NUMBER is reported as a 64-bit integer.
    let as_row_number = |value: &Value| match value {
        Value::Int8(n) => *n,
        _ => panic!("expected Int8"),
    };

    // Row numbers form the permutation 1..=6 with no repeats.
    let seen: std::collections::HashSet<i64> = results
        .iter()
        .map(|row| as_row_number(row.get(2).unwrap()))
        .collect();
    assert_eq!(seen.len(), 6);
    for n in 1..=6 {
        assert!(seen.contains(&n), "Should contain row_number {}", n);
    }

    // The unique top salary has to be numbered first.
    for row in &results {
        let salary = match row.get(1).unwrap() {
            Value::Int8(v) => *v,
            Value::Int4(v) => i64::from(*v),
            _ => panic!("unexpected type"),
        };
        if salary == 120000 {
            assert_eq!(as_row_number(row.get(2).unwrap()), 1,
                "Highest salary should have row_number 1");
        }
    }
}
10484
#[test]
fn test_window_row_number_partitioned() {
    let db = setup_window_test_db();
    let results = db
        .query(
            "SELECT name, dept, ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC) FROM employees",
            &[],
        )
        .unwrap();
    assert_eq!(results.len(), 6);

    // Group row numbers by department. Each row must carry a TEXT dept and an Int8
    // row number: skipping mismatched rows would let the test pass with nothing checked.
    let mut dept_row_nums: std::collections::HashMap<String, Vec<i64>> =
        std::collections::HashMap::new();
    for row in &results {
        match (row.get(1), row.get(2)) {
            (Some(Value::String(dept)), Some(Value::Int8(rn))) => {
                dept_row_nums.entry(dept.clone()).or_default().push(*rn);
            }
            other => panic!("expected (String dept, Int8 row_number), got {:?}", other),
        }
    }

    // Within each partition ROW_NUMBER restarts at 1 and has no gaps. The fixture has
    // 3 Engineering, 2 Sales and 1 Marketing employee. Every partition is required.
    let expected: [(&str, Vec<i64>); 3] = [
        ("Engineering", vec![1, 2, 3]),
        ("Sales", vec![1, 2]),
        ("Marketing", vec![1]),
    ];
    assert_eq!(dept_row_nums.len(), expected.len(),
        "Unexpected set of partitions: {:?}", dept_row_nums);
    for (dept, want) in expected {
        let mut got = dept_row_nums
            .get(dept)
            .cloned()
            .unwrap_or_else(|| panic!("Missing partition {}", dept));
        got.sort_unstable();
        assert_eq!(got, want, "Row numbers for partition {}", dept);
    }
}
10517
#[test]
fn test_window_rank_basic() {
    let db = setup_window_test_db();
    let results = db
        .query(
            "SELECT name, salary, RANK() OVER (ORDER BY salary DESC) FROM employees",
            &[],
        )
        .unwrap();
    assert_eq!(results.len(), 6);

    // Pair each salary with its rank; salary may come back as either integer width.
    let salary_and_rank: Vec<(i64, i64)> = results
        .iter()
        .map(|row| {
            let salary = match row.get(1).unwrap() {
                Value::Int8(v) => *v,
                Value::Int4(v) => i64::from(*v),
                _ => panic!("unexpected salary type"),
            };
            let rank = if let Value::Int8(v) = row.get(2).unwrap() {
                *v
            } else {
                panic!("unexpected rank type")
            };
            (salary, rank)
        })
        .collect();

    let top_is_first = salary_and_rank
        .iter()
        .filter(|&&(salary, _)| salary == 120000)
        .all(|&(_, rank)| rank == 1);
    assert!(top_is_first, "120000 should have rank 1");

    // With the 110000 tie, the ranks are 1,2,2,4,5,6, i.e. five distinct values.
    let distinct: std::collections::HashSet<i64> =
        salary_and_rank.iter().map(|&(_, rank)| rank).collect();
    assert_eq!(distinct.len(), 5, "RANK correctly detects ties on ORDER BY columns");
}
10552
#[test]
fn test_window_rank_with_ties() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE scores (id INT PRIMARY KEY, score INT)").unwrap();

    // Ids 2 and 3 share a score, so RANK must tie them and then skip a position.
    for (id, score) in [(1, 100), (2, 90), (3, 90), (4, 80)] {
        db.execute(&format!("INSERT INTO scores (id, score) VALUES ({}, {})", id, score))
            .unwrap();
    }

    let results = db
        .query(
            "SELECT id, score, RANK() OVER (ORDER BY score DESC) FROM scores",
            &[],
        )
        .unwrap();
    assert_eq!(results.len(), 4);

    let mut ranks: Vec<i64> = results
        .iter()
        .map(|row| match row.get(2).unwrap() {
            Value::Int8(v) => *v,
            _ => panic!("expected Int8"),
        })
        .collect();
    // Sort so the check does not depend on the output row order.
    ranks.sort_unstable();
    assert_eq!(
        ranks,
        vec![1, 2, 2, 4],
        "RANK correctly detects ties per SQL standard"
    );
}
10585
#[test]
fn test_window_dense_rank_basic() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE scores (id INT PRIMARY KEY, score INT)").unwrap();

    // Same tied data as the RANK test; DENSE_RANK must not leave a gap after the tie.
    for (id, score) in [(1, 100), (2, 90), (3, 90), (4, 80)] {
        db.execute(&format!("INSERT INTO scores (id, score) VALUES ({}, {})", id, score))
            .unwrap();
    }

    let results = db
        .query(
            "SELECT id, score, DENSE_RANK() OVER (ORDER BY score DESC) FROM scores",
            &[],
        )
        .unwrap();
    assert_eq!(results.len(), 4);

    let mut ranks: Vec<i64> = results
        .iter()
        .map(|row| match row.get(2).unwrap() {
            Value::Int8(v) => *v,
            _ => panic!("expected Int8"),
        })
        .collect();
    // Sort so the check does not depend on the output row order.
    ranks.sort_unstable();
    assert_eq!(
        ranks,
        vec![1, 2, 2, 3],
        "DENSE_RANK correctly detects ties per SQL standard (no gaps)"
    );
}
10618
#[test]
fn test_window_ntile_basic() {
    let db = setup_window_test_db();
    let results = db
        .query(
            "SELECT name, NTILE(3) OVER (ORDER BY salary) FROM employees",
            &[],
        )
        .unwrap();
    assert_eq!(results.len(), 6);

    let tiles: Vec<i64> = results
        .iter()
        .map(|row| {
            if let Value::Int8(v) = row.get(1).unwrap() {
                *v
            } else {
                panic!("expected Int8")
            }
        })
        .collect();
    assert!(
        tiles.iter().all(|t| (1..=3).contains(t)),
        "All buckets should be 1..=3"
    );

    // Six rows split evenly into three tiles: two rows each.
    for bucket in 1..=3 {
        let members = tiles.iter().filter(|&&t| t == bucket).count();
        assert_eq!(members, 2, "Bucket {} should have 2 rows", bucket);
    }
}
10646
#[test]
fn test_window_ntile_uneven() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    for i in 1..=7 {
        db.execute(&format!(
            "INSERT INTO nums (id, val) VALUES ({}, {})",
            i,
            i * 10
        ))
        .unwrap();
    }
    let results = db
        .query("SELECT val, NTILE(3) OVER (ORDER BY val) FROM nums", &[])
        .unwrap();
    assert_eq!(results.len(), 7);
    let buckets: Vec<i64> = results
        .iter()
        .map(|r| match r.get(1).unwrap() {
            Value::Int8(v) => *v,
            _ => panic!("expected Int8"),
        })
        .collect();
    assert!(
        buckets.iter().all(|b| *b >= 1 && *b <= 3),
        "All buckets should be 1..=3"
    );

    // Per the SQL standard (and PostgreSQL), NTILE(n) over r rows gives the first
    // r % n buckets one extra row. So 7 rows in 3 buckets have sizes 3, 2, 2.
    // Counting is order-independent, so the result row order does not matter.
    for (bucket, expected) in [(1_i64, 3_usize), (2, 2), (3, 2)] {
        let count = buckets.iter().filter(|&&b| b == bucket).count();
        assert_eq!(count, expected, "Bucket {} should have {} rows", bucket, expected);
    }
}
10675
#[test]
fn test_window_lag_basic() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    for id in 1..=4 {
        db.execute(&format!("INSERT INTO nums (id, val) VALUES ({}, {})", id, id * 10))
            .unwrap();
    }

    let results = db
        .query(
            "SELECT val, LAG(val, 1) OVER (ORDER BY val) FROM nums",
            &[],
        )
        .unwrap();
    assert_eq!(results.len(), 4);

    // The first row has no predecessor, so LAG yields NULL.
    assert_eq!(results[0].get(1).unwrap(), &Value::Null);

    // The second row looks back one row to val = 10 (either integer width is accepted).
    let previous = results[1].get(1).unwrap();
    assert!(
        matches!(previous, Value::Int4(10) | Value::Int8(10)),
        "LAG of second row should be 10, got {:?}",
        previous
    );
}
10705
#[test]
fn test_window_lag_offset_2() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    (1..=5).for_each(|id| {
        let insert = format!("INSERT INTO nums (id, val) VALUES ({}, {})", id, id * 10);
        db.execute(&insert).unwrap();
    });

    let results = db
        .query(
            "SELECT val, LAG(val, 2) OVER (ORDER BY val) FROM nums",
            &[],
        )
        .unwrap();
    assert_eq!(results.len(), 5);

    // Looking two rows back leaves the first two rows with nothing to read.
    for leading in &results[..2] {
        assert_eq!(leading.get(1).unwrap(), &Value::Null);
    }

    let two_back = results[2].get(1).unwrap();
    assert!(
        matches!(two_back, Value::Int4(10) | Value::Int8(10)),
        "LAG(val,2) of third row should be 10, got {:?}",
        two_back
    );
}
10734
#[test]
fn test_window_lag_default_offset() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    for id in 1..=3 {
        db.execute(&format!("INSERT INTO nums (id, val) VALUES ({}, {})", id, id * 100))
            .unwrap();
    }

    // LAG without an explicit offset should behave like LAG(val, 1).
    let results = db
        .query(
            "SELECT val, LAG(val) OVER (ORDER BY val) FROM nums",
            &[],
        )
        .unwrap();
    assert_eq!(results.len(), 3);
    assert_eq!(results[0].get(1).unwrap(), &Value::Null);

    let previous = results[1].get(1).unwrap();
    assert!(
        matches!(previous, Value::Int4(100) | Value::Int8(100)),
        "LAG with default offset should be 100 for second row, got {:?}",
        previous
    );
}
10758
#[test]
fn test_window_lead_basic() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    for id in 1..=4 {
        db.execute(&format!("INSERT INTO nums (id, val) VALUES ({}, {})", id, id * 10))
            .unwrap();
    }

    let results = db
        .query(
            "SELECT val, LEAD(val, 1) OVER (ORDER BY val) FROM nums",
            &[],
        )
        .unwrap();
    assert_eq!(results.len(), 4);

    // The last row has no successor, so LEAD yields NULL.
    assert_eq!(results[3].get(1).unwrap(), &Value::Null);

    // The first row looks ahead one row to val = 20 (either integer width is accepted).
    let next = results[0].get(1).unwrap();
    assert!(
        matches!(next, Value::Int4(20) | Value::Int8(20)),
        "LEAD of first row should be 20, got {:?}",
        next
    );
}
10784
#[test]
fn test_window_lead_offset_2() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    (1..=5).for_each(|id| {
        let insert = format!("INSERT INTO nums (id, val) VALUES ({}, {})", id, id * 10);
        db.execute(&insert).unwrap();
    });

    let results = db
        .query(
            "SELECT val, LEAD(val, 2) OVER (ORDER BY val) FROM nums",
            &[],
        )
        .unwrap();
    assert_eq!(results.len(), 5);

    // Looking two rows ahead leaves the last two rows with nothing to read.
    for trailing in &results[3..] {
        assert_eq!(trailing.get(1).unwrap(), &Value::Null);
    }

    let two_ahead = results[0].get(1).unwrap();
    assert!(
        matches!(two_ahead, Value::Int4(30) | Value::Int8(30)),
        "LEAD(val,2) of first row should be 30, got {:?}",
        two_ahead
    );
}
10813
/// With no explicit offset, LEAD looks exactly one row ahead.
#[test]
fn test_window_lead_default_offset() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    for stmt in [
        "INSERT INTO nums (id, val) VALUES (1, 100)",
        "INSERT INTO nums (id, val) VALUES (2, 200)",
        "INSERT INTO nums (id, val) VALUES (3, 300)",
    ] {
        db.execute(stmt).unwrap();
    }
    let rows = db
        .query("SELECT val, LEAD(val) OVER (ORDER BY val) FROM nums", &[])
        .unwrap();
    assert_eq!(rows.len(), 3);
    assert_eq!(rows[2].get(1).unwrap(), &Value::Null);
    let next_val = rows[0].get(1).unwrap();
    assert!(
        matches!(next_val, Value::Int4(200) | Value::Int8(200)),
        "LEAD with default offset should be 200 for first row, got {:?}",
        next_val
    );
}
10836
/// FIRST_VALUE over an ascending order is the smallest value (10) on every row.
#[test]
fn test_window_first_value_basic() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    for stmt in [
        "INSERT INTO nums (id, val) VALUES (1, 10)",
        "INSERT INTO nums (id, val) VALUES (2, 20)",
        "INSERT INTO nums (id, val) VALUES (3, 30)",
    ] {
        db.execute(stmt).unwrap();
    }
    let rows = db
        .query("SELECT val, FIRST_VALUE(val) OVER (ORDER BY val) FROM nums", &[])
        .unwrap();
    assert_eq!(rows.len(), 3);
    for fv in rows.iter().map(|row| row.get(1).unwrap()) {
        assert!(
            matches!(fv, Value::Int4(10) | Value::Int8(10)),
            "FIRST_VALUE should be 10, got {:?}",
            fv
        );
    }
}
10860
/// FIRST_VALUE over a descending salary order must report each department's
/// top salary on every row of that department.
#[test]
fn test_window_first_value_partitioned() {
    let db = setup_window_test_db();
    let results = db.query(
        "SELECT name, dept, salary, FIRST_VALUE(salary) OVER (PARTITION BY dept ORDER BY salary DESC) FROM employees",
        &[],
    ).unwrap();
    assert_eq!(results.len(), 6);
    for row in &results {
        // A non-string dept used to be skipped silently, which could let the test
        // pass without checking a single row; treat it as a failure instead.
        let Some(Value::String(dept)) = row.get(1) else {
            panic!("dept column should be a string, got {:?}", row.get(1));
        };
        let fv = row.get(3).unwrap();
        let expected = match dept.as_str() {
            "Engineering" => 120000,
            "Sales" => 95000,
            "Marketing" => 80000,
            other => panic!("unexpected dept {:?}", other),
        };
        assert!(
            matches!(fv, Value::Int4(v) if *v == expected)
                || matches!(fv, Value::Int8(v) if *v == expected as i64),
            "FIRST_VALUE for {} should be {}, got {:?}",
            dept,
            expected,
            fv
        );
    }
}
10890
/// FIRST_VALUE does not skip NULLs: when the first row (by id) has no val,
/// every row reports NULL.
#[test]
fn test_window_first_value_with_nulls() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE null_first (id INT PRIMARY KEY, val INT)").unwrap();
    // id 1 is inserted without a val; the test expects that column to read as NULL.
    db.execute("INSERT INTO null_first (id) VALUES (1)").unwrap();
    db.execute("INSERT INTO null_first (id, val) VALUES (2, 20)").unwrap();
    db.execute("INSERT INTO null_first (id, val) VALUES (3, 30)").unwrap();
    let rows = db
        .query(
            "SELECT id, val, FIRST_VALUE(val) OVER (ORDER BY id) FROM null_first",
            &[],
        )
        .unwrap();
    assert_eq!(rows.len(), 3);
    rows.iter().for_each(|row| {
        assert_eq!(
            row.get(2).unwrap(),
            &Value::Null,
            "FIRST_VALUE should be NULL when first row has NULL val"
        );
    });
}
10914
/// With ORDER BY (and no ties), LAST_VALUE under the default frame echoes the
/// current row's own val.
#[test]
fn test_window_last_value_with_order_by() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    for stmt in [
        "INSERT INTO nums (id, val) VALUES (1, 10)",
        "INSERT INTO nums (id, val) VALUES (2, 20)",
        "INSERT INTO nums (id, val) VALUES (3, 30)",
    ] {
        db.execute(stmt).unwrap();
    }
    let rows = db
        .query("SELECT val, LAST_VALUE(val) OVER (ORDER BY val) FROM nums", &[])
        .unwrap();
    assert_eq!(rows.len(), 3);
    for row in &rows {
        let (own, last) = (row.get(0).unwrap(), row.get(1).unwrap());
        assert_eq!(
            own, last,
            "LAST_VALUE with default frame (ORDER BY) should equal current row value"
        );
    }
}
10940
/// Without ORDER BY the frame is the whole partition, so every row must
/// report the same LAST_VALUE.
#[test]
fn test_window_last_value_no_order_by() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    for stmt in [
        "INSERT INTO nums (id, val) VALUES (1, 10)",
        "INSERT INTO nums (id, val) VALUES (2, 20)",
        "INSERT INTO nums (id, val) VALUES (3, 30)",
    ] {
        db.execute(stmt).unwrap();
    }
    let rows = db
        .query("SELECT val, LAST_VALUE(val) OVER () FROM nums", &[])
        .unwrap();
    assert_eq!(rows.len(), 3);
    let last_vals: Vec<&Value> = rows.iter().map(|row| row.get(1).unwrap()).collect();
    // Consecutive pairs compared; windows(2) always yields two-element slices.
    let all_equal = last_vals.windows(2).all(|pair| match pair {
        [prev, next] => prev == next,
        _ => true,
    });
    assert!(
        all_equal,
        "All LAST_VALUE results without ORDER BY should be equal"
    );
}
10960
/// SUM partitioned by department (no ORDER BY) must report the full
/// department total on every row of that department.
#[test]
fn test_window_sum_partitioned() {
    let db = setup_window_test_db();
    let results = db
        .query(
            "SELECT name, dept, salary, SUM(salary) OVER (PARTITION BY dept) FROM employees",
            &[],
        )
        .unwrap();
    assert_eq!(results.len(), 6);
    for row in &results {
        // Fail instead of silently skipping a row whose dept is not a string;
        // otherwise the test could pass without verifying any sums.
        let Some(Value::String(dept)) = row.get(1) else {
            panic!("dept column should be a string, got {:?}", row.get(1));
        };
        let sum_val = row.get(3).unwrap();
        let expected: f64 = match dept.as_str() {
            "Engineering" => 340_000.0,
            "Sales" => 185_000.0,
            "Marketing" => 80_000.0,
            other => panic!("unexpected dept {:?}", other),
        };
        assert!(
            matches!(sum_val, Value::Float8(v) if (*v - expected).abs() < 0.01),
            "SUM for {} should be {}, got {:?}",
            dept,
            expected,
            sum_val
        );
    }
}
10994
/// SUM with ORDER BY accumulates row by row: 10, 30, 60.
#[test]
fn test_window_sum_running_total() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    for stmt in [
        "INSERT INTO nums (id, val) VALUES (1, 10)",
        "INSERT INTO nums (id, val) VALUES (2, 20)",
        "INSERT INTO nums (id, val) VALUES (3, 30)",
    ] {
        db.execute(stmt).unwrap();
    }
    let rows = db
        .query("SELECT val, SUM(val) OVER (ORDER BY val) FROM nums", &[])
        .unwrap();
    assert_eq!(rows.len(), 3);
    let sums: Vec<f64> = rows
        .iter()
        .map(|row| {
            let cell = row.get(1).unwrap();
            if let Value::Float8(v) = cell {
                *v
            } else {
                panic!("expected Float8, got {:?}", cell)
            }
        })
        .collect();
    assert!((sums[0] - 10.0).abs() < 0.01, "Running sum row 1 = 10");
    assert!((sums[1] - 30.0).abs() < 0.01, "Running sum row 2 = 30");
    assert!((sums[2] - 60.0).abs() < 0.01, "Running sum row 3 = 60");
}
11021
/// COUNT partitioned by department must report the department's head count
/// (as Int8) on every row of that department.
#[test]
fn test_window_count_partitioned() {
    let db = setup_window_test_db();
    let results = db
        .query(
            "SELECT name, dept, COUNT(salary) OVER (PARTITION BY dept) FROM employees",
            &[],
        )
        .unwrap();
    assert_eq!(results.len(), 6);
    for row in &results {
        // Reject rather than skip non-string depts so every row is actually checked.
        let Some(Value::String(dept)) = row.get(1) else {
            panic!("dept column should be a string, got {:?}", row.get(1));
        };
        let count = row.get(2).unwrap();
        let expected = match dept.as_str() {
            "Engineering" => 3,
            "Sales" => 2,
            "Marketing" => 1,
            other => panic!("unexpected dept {:?}", other),
        };
        assert_eq!(
            count,
            &Value::Int8(expected),
            "COUNT for {} should be {}",
            dept,
            expected
        );
    }
}
11052
/// COUNT with ORDER BY grows by one per row: 1, 2, 3, 4, 5.
#[test]
fn test_window_count_running() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    for i in 1..=5 {
        let val = i * 10;
        db.execute(&format!(
            "INSERT INTO nums (id, val) VALUES ({}, {})",
            i, val
        ))
        .unwrap();
    }
    let rows = db
        .query("SELECT val, COUNT(val) OVER (ORDER BY val) FROM nums", &[])
        .unwrap();
    assert_eq!(rows.len(), 5);
    let counts: Vec<i64> = rows
        .iter()
        .map(|row| {
            if let Value::Int8(n) = row.get(1).unwrap() {
                *n
            } else {
                panic!("expected Int8")
            }
        })
        .collect();
    assert_eq!(counts, vec![1, 2, 3, 4, 5], "Running count should be 1..=5");
}
11082
/// AVG partitioned by department must report the department's mean salary
/// (as Float8) on every row of that department.
#[test]
fn test_window_avg_partitioned() {
    let db = setup_window_test_db();
    let results = db
        .query(
            "SELECT name, dept, salary, AVG(salary) OVER (PARTITION BY dept) FROM employees",
            &[],
        )
        .unwrap();
    assert_eq!(results.len(), 6);
    for row in &results {
        // Reject rather than skip non-string depts so every row is actually checked.
        let Some(Value::String(dept)) = row.get(1) else {
            panic!("dept column should be a string, got {:?}", row.get(1));
        };
        let avg_val = row.get(3).unwrap();
        let expected: f64 = match dept.as_str() {
            "Engineering" => 340_000.0 / 3.0,
            "Sales" => 92_500.0,
            "Marketing" => 80_000.0,
            other => panic!("unexpected dept {:?}", other),
        };
        assert!(
            matches!(avg_val, Value::Float8(v) if (*v - expected).abs() < 1.0),
            "AVG for {} should be ~{}, got {:?}",
            dept,
            expected,
            avg_val
        );
    }
}
11112
/// AVG with ORDER BY is a running mean: 10, 15, 20.
#[test]
fn test_window_avg_running() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    for stmt in [
        "INSERT INTO nums (id, val) VALUES (1, 10)",
        "INSERT INTO nums (id, val) VALUES (2, 20)",
        "INSERT INTO nums (id, val) VALUES (3, 30)",
    ] {
        db.execute(stmt).unwrap();
    }
    let rows = db
        .query("SELECT val, AVG(val) OVER (ORDER BY val) FROM nums", &[])
        .unwrap();
    assert_eq!(rows.len(), 3);
    let avgs: Vec<f64> = rows
        .iter()
        .map(|row| {
            let cell = row.get(1).unwrap();
            if let Value::Float8(v) = cell {
                *v
            } else {
                panic!("expected Float8, got {:?}", cell)
            }
        })
        .collect();
    assert!((avgs[0] - 10.0).abs() < 0.01, "Running avg row 1");
    assert!((avgs[1] - 15.0).abs() < 0.01, "Running avg row 2");
    assert!((avgs[2] - 20.0).abs() < 0.01, "Running avg row 3");
}
11138
/// MIN/MAX partitioned by department must report each department's salary
/// bounds on every one of its rows.
#[test]
fn test_window_min_max_partitioned() {
    let db = setup_window_test_db();
    let results = db.query(
        "SELECT name, dept, MIN(salary) OVER (PARTITION BY dept), MAX(salary) OVER (PARTITION BY dept) FROM employees",
        &[],
    ).unwrap();
    assert_eq!(results.len(), 6);
    for row in &results {
        // Reject (rather than silently skip) rows whose dept is not a string, so
        // the per-department assertions below run for all 6 rows.
        let Some(Value::String(dept)) = row.get(1) else {
            panic!("dept column should be a string, got {:?}", row.get(1));
        };
        let min_val = row.get(2).unwrap();
        let max_val = row.get(3).unwrap();
        match dept.as_str() {
            "Engineering" => {
                assert!(
                    matches!(min_val, Value::Int4(110000) | Value::Int8(110000)),
                    "MIN for Engineering = 110000, got {:?}",
                    min_val
                );
                assert!(
                    matches!(max_val, Value::Int4(120000) | Value::Int8(120000)),
                    "MAX for Engineering = 120000, got {:?}",
                    max_val
                );
            }
            "Sales" => {
                assert!(
                    matches!(min_val, Value::Int4(90000) | Value::Int8(90000)),
                    "MIN for Sales = 90000, got {:?}",
                    min_val
                );
                assert!(
                    matches!(max_val, Value::Int4(95000) | Value::Int8(95000)),
                    "MAX for Sales = 95000, got {:?}",
                    max_val
                );
            }
            "Marketing" => {
                assert!(
                    matches!(min_val, Value::Int4(80000) | Value::Int8(80000)),
                    "MIN for Marketing = 80000, got {:?}",
                    min_val
                );
                assert!(
                    matches!(max_val, Value::Int4(80000) | Value::Int8(80000)),
                    "MAX for Marketing = 80000, got {:?}",
                    max_val
                );
            }
            other => panic!("unexpected dept {:?}", other),
        }
    }
}
11193
/// A window function over an empty table yields zero rows without erroring.
#[test]
fn test_window_empty_result_set() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE empty_t (id INT PRIMARY KEY, val INT)").unwrap();
    let sql = "SELECT val, ROW_NUMBER() OVER (ORDER BY val) FROM empty_t";
    let rows = db.query(sql, &[]).unwrap();
    assert_eq!(rows.len(), 0, "Window on empty table => 0 rows");
}
11210
/// A WHERE clause that filters out every row leaves nothing for the window
/// function, so the result is empty.
#[test]
fn test_window_empty_result_set_via_where() {
    let db = setup_window_test_db();
    let sql = "SELECT name, ROW_NUMBER() OVER (ORDER BY salary) FROM employees WHERE dept = 'NonExistent'";
    let rows = db.query(sql, &[]).unwrap();
    assert_eq!(rows.len(), 0, "No matching WHERE => 0 rows");
}
11222
/// The Marketing partition holds one row: ROW_NUMBER and RANK are both 1, and
/// SUM equals that row's salary.
#[test]
fn test_window_single_row_partition() {
    let db = setup_window_test_db();
    let rows = db.query(
        "SELECT name, dept, \
         ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary), \
         RANK() OVER (PARTITION BY dept ORDER BY salary), \
         SUM(salary) OVER (PARTITION BY dept) \
         FROM employees WHERE dept = 'Marketing'",
        &[],
    ).unwrap();
    assert_eq!(rows.len(), 1);
    let only = &rows[0];
    assert_eq!(only.get(2).unwrap(), &Value::Int8(1), "ROW_NUMBER = 1");
    assert_eq!(only.get(3).unwrap(), &Value::Int8(1), "RANK = 1");
    let total = only.get(4).unwrap();
    assert!(
        matches!(total, Value::Float8(v) if (*v - 80000.0).abs() < 0.01),
        "SUM = 80000"
    );
}
11242
/// SUM over a column that is NULL on every row stays NULL instead of
/// collapsing to 0.
#[test]
fn test_window_all_null_values_in_windowed_column() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE null_t (id INT PRIMARY KEY, val INT)").unwrap();
    for stmt in [
        "INSERT INTO null_t (id) VALUES (1)",
        "INSERT INTO null_t (id) VALUES (2)",
        "INSERT INTO null_t (id) VALUES (3)",
    ] {
        db.execute(stmt).unwrap();
    }
    let rows = db
        .query("SELECT id, val, SUM(val) OVER (ORDER BY id) FROM null_t", &[])
        .unwrap();
    assert_eq!(rows.len(), 3);
    for sum_val in rows.iter().map(|row| row.get(2).unwrap()) {
        assert!(
            matches!(sum_val, Value::Null),
            "SUM of all NULLs should be NULL (SQL standard), got {:?}",
            sum_val
        );
    }
}
11267
/// LAG returns the previous row's value verbatim, NULL included: id=3 follows
/// the row whose val was left unset.
#[test]
fn test_window_null_in_windowed_column() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE mixed_nulls (id INT PRIMARY KEY, val INT)").unwrap();
    db.execute("INSERT INTO mixed_nulls (id, val) VALUES (1, 10)").unwrap();
    db.execute("INSERT INTO mixed_nulls (id) VALUES (2)").unwrap();
    db.execute("INSERT INTO mixed_nulls (id, val) VALUES (3, 30)").unwrap();
    let rows = db
        .query(
            "SELECT id, val, LAG(val, 1) OVER (ORDER BY id) FROM mixed_nulls",
            &[],
        )
        .unwrap();
    assert_eq!(rows.len(), 3);
    let (first, second, third) = (&rows[0], &rows[1], &rows[2]);
    assert_eq!(first.get(2).unwrap(), &Value::Null, "LAG for first row = NULL");
    let lag2 = second.get(2).unwrap();
    assert!(
        matches!(lag2, Value::Int4(10) | Value::Int8(10)),
        "LAG for id=2 should be 10, got {:?}",
        lag2
    );
    assert_eq!(
        third.get(2).unwrap(),
        &Value::Null,
        "LAG for id=3 should be NULL (previous row has NULL val)"
    );
}
11295
/// Several window functions with different OVER clauses in one SELECT are each
/// evaluated independently: whole-table SUM/COUNT on every row, and a
/// ROW_NUMBER that is a permutation of 1..=6.
#[test]
fn test_window_multiple_functions_same_select() {
    let db = setup_window_test_db();
    let rows = db.query(
        "SELECT name, salary, \
         ROW_NUMBER() OVER (ORDER BY salary DESC), \
         SUM(salary) OVER (), \
         COUNT(salary) OVER () \
         FROM employees",
        &[],
    ).unwrap();
    assert_eq!(rows.len(), 6);
    let total: f64 = [
        120_000.0_f64,
        110_000.0,
        110_000.0,
        90_000.0,
        95_000.0,
        80_000.0,
    ]
    .iter()
    .sum();
    for row in &rows {
        let sum_val = row.get(3).unwrap();
        assert!(
            matches!(sum_val, Value::Float8(v) if (*v - total).abs() < 0.01),
            "Total SUM should be {}, got {:?}",
            total,
            sum_val
        );
        assert_eq!(row.get(4).unwrap(), &Value::Int8(6), "Total COUNT should be 6");
    }
    // Row output order is not pinned, so compare the numbers as a sorted set.
    let mut row_nums: Vec<i64> = rows
        .iter()
        .map(|row| match row.get(2).unwrap() {
            Value::Int8(n) => *n,
            _ => panic!("expected Int8"),
        })
        .collect();
    row_nums.sort_unstable();
    assert_eq!(row_nums, vec![1, 2, 3, 4, 5, 6]);
}
11334
/// `OVER ()` treats the whole table as one partition, so COUNT sees all 6 rows.
#[test]
fn test_window_no_partition_by_entire_table() {
    let db = setup_window_test_db();
    let rows = db
        .query("SELECT name, salary, COUNT(salary) OVER () FROM employees", &[])
        .unwrap();
    assert_eq!(rows.len(), 6);
    for cnt in rows.iter().map(|row| row.get(2).unwrap()) {
        assert_eq!(cnt, &Value::Int8(6));
    }
}
11350
/// 200 rows spread over 100 groups of two: every row's partition COUNT is 2.
#[test]
fn test_window_partition_with_many_groups() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE big_t (id INT PRIMARY KEY, grp INT, val INT)").unwrap();
    for id in 1..=200 {
        // ids (1, 2) -> group 1, (3, 4) -> group 2, and so on.
        let grp = (id - 1) / 2 + 1;
        let sql = format!(
            "INSERT INTO big_t (id, grp, val) VALUES ({}, {}, {})",
            id,
            grp,
            id * 10
        );
        db.execute(&sql).unwrap();
    }
    let rows = db
        .query("SELECT grp, COUNT(val) OVER (PARTITION BY grp) FROM big_t", &[])
        .unwrap();
    assert_eq!(rows.len(), 200);
    for cnt in rows.iter().map(|row| row.get(1).unwrap()) {
        assert_eq!(
            cnt,
            &Value::Int8(2),
            "Each of 100 groups should have COUNT = 2"
        );
    }
}
11379
/// All rows share one value: ROW_NUMBER must still hand out the distinct
/// numbers 1..=3 (in some order), and SUM(val) OVER () is 3 * 42.
#[test]
fn test_window_identical_values_all_rows() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE same_vals (id INT PRIMARY KEY, val INT)").unwrap();
    for stmt in [
        "INSERT INTO same_vals (id, val) VALUES (1, 42)",
        "INSERT INTO same_vals (id, val) VALUES (2, 42)",
        "INSERT INTO same_vals (id, val) VALUES (3, 42)",
    ] {
        db.execute(stmt).unwrap();
    }
    let rows = db.query(
        "SELECT id, val, ROW_NUMBER() OVER (ORDER BY val), SUM(val) OVER () FROM same_vals",
        &[],
    ).unwrap();
    assert_eq!(rows.len(), 3);
    let mut row_nums: Vec<i64> = rows
        .iter()
        .map(|row| match row.get(2).unwrap() {
            Value::Int8(n) => *n,
            _ => panic!("expected Int8"),
        })
        .collect();
    row_nums.sort_unstable();
    assert_eq!(row_nums, vec![1, 2, 3]);
    for sum_val in rows.iter().map(|row| row.get(3).unwrap()) {
        assert!(
            matches!(sum_val, Value::Float8(v) if (*v - 126.0).abs() < 0.01),
            "SUM should be 126"
        );
    }
}
11409
/// Every window function on a one-row table: ranks are 1, LAG/LEAD have no
/// neighbour (NULL), and SUM/COUNT cover just that row.
#[test]
fn test_window_single_row_table() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE single (id INT PRIMARY KEY, val INT)").unwrap();
    db.execute("INSERT INTO single (id, val) VALUES (1, 42)").unwrap();
    let rows = db.query(
        "SELECT val, \
         ROW_NUMBER() OVER (ORDER BY val), \
         RANK() OVER (ORDER BY val), \
         LAG(val, 1) OVER (ORDER BY val), \
         LEAD(val, 1) OVER (ORDER BY val), \
         SUM(val) OVER (), \
         COUNT(val) OVER () \
         FROM single",
        &[],
    ).unwrap();
    assert_eq!(rows.len(), 1);
    let only = &rows[0];
    for (col, expected, label) in [
        (1, Value::Int8(1), "ROW_NUMBER = 1"),
        (2, Value::Int8(1), "RANK = 1"),
        (3, Value::Null, "LAG = NULL"),
        (4, Value::Null, "LEAD = NULL"),
    ] {
        assert_eq!(only.get(col).unwrap(), &expected, "{}", label);
    }
    assert!(
        matches!(only.get(5).unwrap(), Value::Float8(v) if (*v - 42.0).abs() < 0.01),
        "SUM = 42"
    );
    assert_eq!(only.get(6).unwrap(), &Value::Int8(1), "COUNT = 1");
}
11439
/// An explicit ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW frame gives a
/// running total: 10, 30, 60, 100.
#[test]
fn test_window_frame_unbounded_preceding_to_current_row() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    for stmt in [
        "INSERT INTO nums (id, val) VALUES (1, 10)",
        "INSERT INTO nums (id, val) VALUES (2, 20)",
        "INSERT INTO nums (id, val) VALUES (3, 30)",
        "INSERT INTO nums (id, val) VALUES (4, 40)",
    ] {
        db.execute(stmt).unwrap();
    }
    let rows = db.query(
        "SELECT val, SUM(val) OVER (ORDER BY val ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM nums",
        &[],
    ).unwrap();
    assert_eq!(rows.len(), 4);
    let sums: Vec<f64> = rows
        .iter()
        .map(|row| {
            let cell = row.get(1).unwrap();
            if let Value::Float8(v) = cell {
                *v
            } else {
                panic!("expected Float8, got {:?}", cell)
            }
        })
        .collect();
    assert!((sums[0] - 10.0).abs() < 0.01);
    assert!((sums[1] - 30.0).abs() < 0.01);
    assert!((sums[2] - 60.0).abs() < 0.01);
    assert!((sums[3] - 100.0).abs() < 0.01);
}
11469
/// A sliding three-row frame (1 PRECEDING .. 1 FOLLOWING); the edge rows only
/// have two neighbours in range: 30, 60, 90, 70.
#[test]
fn test_window_frame_1_preceding_to_1_following() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    for stmt in [
        "INSERT INTO nums (id, val) VALUES (1, 10)",
        "INSERT INTO nums (id, val) VALUES (2, 20)",
        "INSERT INTO nums (id, val) VALUES (3, 30)",
        "INSERT INTO nums (id, val) VALUES (4, 40)",
    ] {
        db.execute(stmt).unwrap();
    }
    let rows = db.query(
        "SELECT val, SUM(val) OVER (ORDER BY val ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM nums",
        &[],
    ).unwrap();
    assert_eq!(rows.len(), 4);
    let sums: Vec<f64> = rows
        .iter()
        .map(|row| {
            let cell = row.get(1).unwrap();
            if let Value::Float8(v) = cell {
                *v
            } else {
                panic!("expected Float8, got {:?}", cell)
            }
        })
        .collect();
    assert!((sums[0] - 30.0).abs() < 0.01, "Row 1: got {}", sums[0]);
    assert!((sums[1] - 60.0).abs() < 0.01, "Row 2: got {}", sums[1]);
    assert!((sums[2] - 90.0).abs() < 0.01, "Row 3: got {}", sums[2]);
    assert!((sums[3] - 70.0).abs() < 0.01, "Row 4: got {}", sums[3]);
}
11496
/// A frame spanning the whole partition gives every row the grand total (60).
#[test]
fn test_window_frame_unbounded_preceding_to_unbounded_following() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    for stmt in [
        "INSERT INTO nums (id, val) VALUES (1, 10)",
        "INSERT INTO nums (id, val) VALUES (2, 20)",
        "INSERT INTO nums (id, val) VALUES (3, 30)",
    ] {
        db.execute(stmt).unwrap();
    }
    let rows = db.query(
        "SELECT val, SUM(val) OVER (ORDER BY val ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM nums",
        &[],
    ).unwrap();
    assert_eq!(rows.len(), 3);
    for sum_val in rows.iter().map(|row| row.get(1).unwrap()) {
        assert!(
            matches!(sum_val, Value::Float8(v) if (*v - 60.0).abs() < 0.01),
            "Full frame SUM should be 60, got {:?}",
            sum_val
        );
    }
}
11518
/// A CURRENT ROW .. CURRENT ROW frame contains only the row itself, so SUM
/// equals the row's own value.
#[test]
fn test_window_frame_current_row_to_current_row() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    for stmt in [
        "INSERT INTO nums (id, val) VALUES (1, 10)",
        "INSERT INTO nums (id, val) VALUES (2, 20)",
        "INSERT INTO nums (id, val) VALUES (3, 30)",
    ] {
        db.execute(stmt).unwrap();
    }
    let rows = db.query(
        "SELECT val, SUM(val) OVER (ORDER BY val ROWS BETWEEN CURRENT ROW AND CURRENT ROW) FROM nums",
        &[],
    ).unwrap();
    assert_eq!(rows.len(), 3);
    for row in &rows {
        let own = match row.get(0).unwrap() {
            Value::Int4(n) => *n as f64,
            Value::Int8(n) => *n as f64,
            _ => panic!("unexpected type"),
        };
        let framed = match row.get(1).unwrap() {
            Value::Float8(s) => *s,
            other => panic!("expected Float8, got {:?}", other),
        };
        assert!(
            (framed - own).abs() < 0.01,
            "CURRENT ROW frame SUM should equal own value"
        );
    }
}
11547
/// A trailing three-row frame (2 PRECEDING .. CURRENT ROW) over 10..50 gives
/// 10, 30, 60, 90, 120.
#[test]
fn test_window_frame_2_preceding_to_current_row() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    (1..=5).for_each(|i| {
        db.execute(&format!(
            "INSERT INTO nums (id, val) VALUES ({}, {})",
            i,
            i * 10
        ))
        .unwrap();
    });
    let rows = db.query(
        "SELECT val, SUM(val) OVER (ORDER BY val ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) FROM nums",
        &[],
    ).unwrap();
    assert_eq!(rows.len(), 5);
    let sums: Vec<f64> = rows
        .iter()
        .map(|row| {
            let cell = row.get(1).unwrap();
            if let Value::Float8(v) = cell {
                *v
            } else {
                panic!("expected Float8, got {:?}", cell)
            }
        })
        .collect();
    assert!((sums[0] - 10.0).abs() < 0.01);
    assert!((sums[1] - 30.0).abs() < 0.01);
    assert!((sums[2] - 60.0).abs() < 0.01);
    assert!((sums[3] - 90.0).abs() < 0.01);
    assert!((sums[4] - 120.0).abs() < 0.01);
}
11579
/// ROW_NUMBER without ORDER BY assigns an unspecified but complete
/// permutation of 1..=3, so only the sorted values are compared.
#[test]
fn test_window_row_number_no_order_by() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    for stmt in [
        "INSERT INTO nums (id, val) VALUES (1, 10)",
        "INSERT INTO nums (id, val) VALUES (2, 20)",
        "INSERT INTO nums (id, val) VALUES (3, 30)",
    ] {
        db.execute(stmt).unwrap();
    }
    let rows = db
        .query("SELECT val, ROW_NUMBER() OVER () FROM nums", &[])
        .unwrap();
    assert_eq!(rows.len(), 3);
    let mut row_nums: Vec<i64> = rows
        .iter()
        .map(|row| match row.get(1).unwrap() {
            Value::Int8(n) => *n,
            _ => panic!("expected Int8"),
        })
        .collect();
    row_nums.sort_unstable();
    assert_eq!(row_nums, vec![1, 2, 3]);
}
11606
/// ROW_NUMBER over `ORDER BY val DESC` numbers the largest value first:
/// 30 -> 1, 20 -> 2, 10 -> 3.
#[test]
fn test_window_descending_order() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    for stmt in [
        "INSERT INTO nums (id, val) VALUES (1, 10)",
        "INSERT INTO nums (id, val) VALUES (2, 20)",
        "INSERT INTO nums (id, val) VALUES (3, 30)",
    ] {
        db.execute(stmt).unwrap();
    }
    let rows = db
        .query("SELECT val, ROW_NUMBER() OVER (ORDER BY val DESC) FROM nums", &[])
        .unwrap();
    assert_eq!(rows.len(), 3);
    for row in &rows {
        let val = match row.get(0).unwrap() {
            Value::Int4(n) => *n as i64,
            Value::Int8(n) => *n,
            _ => panic!("unexpected type"),
        };
        let rn = match row.get(1).unwrap() {
            Value::Int8(n) => *n,
            _ => panic!("expected Int8"),
        };
        let expected_rn = match val {
            30 => 1,
            20 => 2,
            10 => 3,
            _ => panic!("unexpected val {}", val),
        };
        assert_eq!(rn, expected_rn);
    }
}
11639
/// WHERE runs before the window: the running SUM only ever sees the three
/// Engineering salaries. Engineering contains tied salaries, so the smallest
/// running sum is only bounded below (this tolerates either way of handling
/// peer rows).
#[test]
fn test_window_sum_with_where_clause() {
    let db = setup_window_test_db();
    let rows = db.query(
        "SELECT name, salary, SUM(salary) OVER (ORDER BY salary) FROM employees WHERE dept = 'Engineering'",
        &[],
    ).unwrap();
    assert_eq!(rows.len(), 3, "Only Engineering employees");
    let sums: Vec<f64> = rows
        .iter()
        .map(|row| {
            let cell = row.get(2).unwrap();
            if let Value::Float8(v) = cell {
                *v
            } else {
                panic!("expected Float8, got {:?}", cell)
            }
        })
        .collect();
    let max_sum = sums.iter().copied().fold(0.0_f64, f64::max);
    assert!(
        (max_sum - 340_000.0).abs() < 0.01,
        "Max running sum should be 340000, got {}",
        max_sum
    );
    let min_sum = sums.iter().copied().fold(f64::MAX, f64::min);
    assert!(
        min_sum >= 109_999.0,
        "Min running sum should be >= 110000, got {}",
        min_sum
    );
}
11671
/// COUNT(*) counts rows regardless of NULLs, so all three rows are counted
/// even though id=2 has no val.
#[test]
fn test_window_count_star_over() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE t (id INT PRIMARY KEY, val INT)").unwrap();
    for stmt in [
        "INSERT INTO t (id, val) VALUES (1, 10)",
        "INSERT INTO t (id) VALUES (2)",
        "INSERT INTO t (id, val) VALUES (3, 30)",
    ] {
        db.execute(stmt).unwrap();
    }
    let rows = db
        .query("SELECT id, COUNT(*) OVER () FROM t", &[])
        .unwrap();
    assert_eq!(rows.len(), 3);
    for cnt in rows.iter().map(|row| row.get(1).unwrap()) {
        assert_eq!(
            cnt,
            &Value::Int8(3),
            "COUNT(*) OVER() should count all rows"
        );
    }
}
11693
/// COUNT(val) skips the row whose val was left unset, giving 2 of 3 rows.
#[test]
fn test_window_count_column_excludes_nulls() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE t (id INT PRIMARY KEY, val INT)").unwrap();
    for stmt in [
        "INSERT INTO t (id, val) VALUES (1, 10)",
        "INSERT INTO t (id) VALUES (2)",
        "INSERT INTO t (id, val) VALUES (3, 30)",
    ] {
        db.execute(stmt).unwrap();
    }
    let rows = db
        .query("SELECT id, COUNT(val) OVER () FROM t", &[])
        .unwrap();
    assert_eq!(rows.len(), 3);
    for cnt in rows.iter().map(|row| row.get(1).unwrap()) {
        assert_eq!(
            cnt,
            &Value::Int8(2),
            "COUNT(col) should exclude NULLs per SQL standard"
        );
    }
}
11715
/// Three partitioned window functions in one SELECT each add exactly one
/// output column next to the three selected base columns.
#[test]
fn test_window_multiple_partitions_multiple_functions() {
    let db = setup_window_test_db();
    let rows = db.query(
        "SELECT name, dept, salary, \
         ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC), \
         SUM(salary) OVER (PARTITION BY dept), \
         AVG(salary) OVER (PARTITION BY dept) \
         FROM employees",
        &[],
    ).unwrap();
    assert_eq!(rows.len(), 6);
    rows.iter().for_each(|row| {
        assert_eq!(row.len(), 6, "3 original + 3 window columns");
    });
}
11732
/// Adding a window column must not disturb the selected base columns: id
/// stays an integer and name stays a string.
#[test]
fn test_window_preserves_original_columns() {
    let db = setup_window_test_db();
    let rows = db
        .query(
            "SELECT id, name, dept, salary, ROW_NUMBER() OVER (ORDER BY salary) FROM employees",
            &[],
        )
        .unwrap();
    assert_eq!(rows.len(), 6);
    rows.iter().for_each(|row| {
        assert_eq!(row.len(), 5, "4 original + 1 window column");
        let id_cell = row.get(0).unwrap();
        assert!(
            matches!(id_cell, Value::Int4(_) | Value::Int8(_)),
            "id should be integer"
        );
        let name_cell = row.get(1).unwrap();
        assert!(matches!(name_cell, Value::String(_)), "name should be string");
    });
}
11752
/// LAG(salary, 1) restarts in every department partition, so each of the three
/// departments must contribute exactly one NULL: its first row.
#[test]
fn test_window_lag_partitioned() {
    let db = setup_window_test_db();
    let results = db.query(
        "SELECT name, dept, salary, LAG(salary, 1) OVER (PARTITION BY dept ORDER BY salary) FROM employees",
        &[],
    ).unwrap();
    assert_eq!(results.len(), 6);
    let mut dept_null_count: std::collections::HashMap<String, usize> =
        std::collections::HashMap::new();
    for row in &results {
        // Reject rather than skip non-string depts so no row escapes the tally.
        let Some(Value::String(dept)) = row.get(1) else {
            panic!("dept column should be a string, got {:?}", row.get(1));
        };
        if row.get(3).unwrap() == &Value::Null {
            *dept_null_count.entry(dept.clone()).or_insert(0) += 1;
        }
    }
    // Without this check, an engine that never emitted NULL would leave the map
    // empty and the per-partition loop below would pass vacuously.
    assert_eq!(
        dept_null_count.len(),
        3,
        "Expected a NULL LAG in each of the 3 partitions, got {:?}",
        dept_null_count
    );
    for (dept, count) in &dept_null_count {
        assert_eq!(
            *count, 1,
            "Partition {} should have exactly 1 NULL LAG, got {}",
            dept, count
        );
    }
}
11781
/// LEAD(salary, 1) restarts in every department partition, so each of the
/// three departments must contribute exactly one NULL: its last row.
#[test]
fn test_window_lead_partitioned() {
    let db = setup_window_test_db();
    let results = db.query(
        "SELECT name, dept, salary, LEAD(salary, 1) OVER (PARTITION BY dept ORDER BY salary) FROM employees",
        &[],
    ).unwrap();
    assert_eq!(results.len(), 6);
    let mut dept_null_count: std::collections::HashMap<String, usize> =
        std::collections::HashMap::new();
    for row in &results {
        // Reject rather than skip non-string depts so no row escapes the tally.
        let Some(Value::String(dept)) = row.get(1) else {
            panic!("dept column should be a string, got {:?}", row.get(1));
        };
        if row.get(3).unwrap() == &Value::Null {
            *dept_null_count.entry(dept.clone()).or_insert(0) += 1;
        }
    }
    // Without this check, an engine that never emitted NULL would leave the map
    // empty and the per-partition loop below would pass vacuously.
    assert_eq!(
        dept_null_count.len(),
        3,
        "Expected a NULL LEAD in each of the 3 partitions, got {:?}",
        dept_null_count
    );
    for (dept, count) in &dept_null_count {
        assert_eq!(
            *count, 1,
            "Partition {} should have exactly 1 NULL LEAD, got {}",
            dept, count
        );
    }
}
11809
/// ROW_NUMBER over a 500-row table must assign each row a distinct number, and
/// together those numbers must be exactly the contiguous range 1..=500.
#[test]
fn test_window_large_dataset_row_number() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE large_t (id INT PRIMARY KEY, val INT)").unwrap();
    for i in 1..=500 {
        db.execute(&format!(
            "INSERT INTO large_t (id, val) VALUES ({}, {})",
            i, i
        ))
        .unwrap();
    }
    let results = db
        .query(
            "SELECT id, ROW_NUMBER() OVER (ORDER BY val) FROM large_t",
            &[],
        )
        .unwrap();
    assert_eq!(results.len(), 500);
    let row_nums: std::collections::HashSet<i64> = results
        .iter()
        .map(|r| match r.get(1).unwrap() {
            Value::Int8(v) => *v,
            _ => panic!("expected Int8"),
        })
        .collect();
    assert_eq!(row_nums.len(), 500);
    // Comparing against the whole expected set (not just probing 1 and 500)
    // catches gaps and out-of-range numbers that a count check would miss.
    let expected: std::collections::HashSet<i64> = (1..=500).collect();
    assert_eq!(row_nums, expected);
}
11839
/// PERCENT_RANK over four distinct scores must produce 0, 1/3, 2/3 and 1.
///
/// The outer query has no ORDER BY, so the order in which rows are emitted is
/// unspecified. The scores are distinct, so PERCENT_RANK is strictly increasing
/// in score. Sorting the ranks therefore recovers the expected sequence
/// regardless of output order.
#[test]
fn test_window_percent_rank() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE scores (id INT PRIMARY KEY, score INT)").unwrap();
    db.execute("INSERT INTO scores (id, score) VALUES (1, 100)").unwrap();
    db.execute("INSERT INTO scores (id, score) VALUES (2, 200)").unwrap();
    db.execute("INSERT INTO scores (id, score) VALUES (3, 300)").unwrap();
    db.execute("INSERT INTO scores (id, score) VALUES (4, 400)").unwrap();
    let results = db
        .query(
            "SELECT score, PERCENT_RANK() OVER (ORDER BY score) FROM scores",
            &[],
        )
        .unwrap();
    assert_eq!(results.len(), 4);
    let mut pct_ranks: Vec<f64> = results
        .iter()
        .map(|r| match r.get(1).unwrap() {
            Value::Float8(v) => *v,
            other => panic!("expected Float8, got {:?}", other),
        })
        .collect();
    pct_ranks.sort_by(f64::total_cmp);
    let expected = [0.0, 1.0 / 3.0, 2.0 / 3.0, 1.0];
    assert_eq!(pct_ranks.len(), expected.len());
    for (got, want) in pct_ranks.iter().zip(expected) {
        assert!(
            (got - want).abs() < 0.01,
            "expected PERCENT_RANK {}, got {}",
            want, got
        );
    }
}
11868
/// A one-row table must report a PERCENT_RANK of 0.0, checked within a 0.01
/// tolerance.
#[test]
fn test_window_percent_rank_single_row() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    conn.execute("CREATE TABLE one_row (id INT PRIMARY KEY, val INT)").unwrap();
    conn.execute("INSERT INTO one_row (id, val) VALUES (1, 42)").unwrap();
    let out = conn
        .query("SELECT val, PERCENT_RANK() OVER (ORDER BY val) FROM one_row", &[])
        .unwrap();
    assert_eq!(out.len(), 1);
    let rank_cell = out[0].get(1).unwrap();
    let is_zero = matches!(rank_cell, Value::Float8(v) if v.abs() < 0.01);
    assert!(is_zero, "PERCENT_RANK with single row should be 0.0");
}
11886
/// NTILE(1) must place every row of the partition into bucket 1.
#[test]
fn test_window_ntile_single_bucket() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    conn.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    conn.execute("INSERT INTO nums (id, val) VALUES (1, 10)").unwrap();
    conn.execute("INSERT INTO nums (id, val) VALUES (2, 20)").unwrap();
    conn.execute("INSERT INTO nums (id, val) VALUES (3, 30)").unwrap();
    let out = conn
        .query("SELECT val, NTILE(1) OVER (ORDER BY val) FROM nums", &[])
        .unwrap();
    assert_eq!(out.len(), 3);
    out.iter()
        .for_each(|row| assert_eq!(row.get(1).unwrap(), &Value::Int8(1)));
}
11905
/// NTILE(5) over only two rows: with fewer rows than buckets, each row gets its
/// own bucket, numbered from 1, so the buckets must be exactly {1, 2}.
///
/// A plain range check (1..=5) would also accept wrong answers such as both
/// rows landing in bucket 5. The values are sorted first because the outer
/// query has no ORDER BY.
#[test]
fn test_window_ntile_more_buckets_than_rows() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE nums (id INT PRIMARY KEY, val INT)").unwrap();
    db.execute("INSERT INTO nums (id, val) VALUES (1, 10)").unwrap();
    db.execute("INSERT INTO nums (id, val) VALUES (2, 20)").unwrap();
    let results = db
        .query(
            "SELECT val, NTILE(5) OVER (ORDER BY val) FROM nums",
            &[],
        )
        .unwrap();
    assert_eq!(results.len(), 2);
    let mut buckets: Vec<i64> = results
        .iter()
        .map(|r| match r.get(1).unwrap() {
            Value::Int8(v) => *v,
            _ => panic!("expected Int8"),
        })
        .collect();
    buckets.sort_unstable();
    assert_eq!(
        buckets,
        [1, 2],
        "Each row should get its own bucket, numbered from 1"
    );
}
11931
/// `INSERT ... RETURNING *` hands back the inserted row with both table columns.
#[test]
fn test_returning_insert_star() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    conn.execute("CREATE TABLE ret_test (a INT, b TEXT)").unwrap();
    let sql = "INSERT INTO ret_test (a, b) VALUES (1, 'hello') RETURNING *";
    let (affected, returned) = conn.execute_returning(sql).unwrap();
    assert_eq!(affected, 1);
    assert_eq!(returned.len(), 1);
    let cols = &returned[0].values;
    assert_eq!(cols.len(), 2);
    assert_eq!(cols[0], Value::Int4(1));
    assert_eq!(cols[1], Value::String(String::from("hello")));
}
11949
/// `RETURNING a, b` yields exactly the listed columns, in the listed order.
#[test]
fn test_returning_insert_specific_columns() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    conn.execute("CREATE TABLE ret_cols (a INT, b TEXT)").unwrap();
    let sql = "INSERT INTO ret_cols (a, b) VALUES (1, 'world') RETURNING a, b";
    let (affected, returned) = conn.execute_returning(sql).unwrap();
    assert_eq!(affected, 1);
    assert_eq!(returned.len(), 1);
    let cols = &returned[0].values;
    assert_eq!(cols.len(), 2);
    assert_eq!(cols[0], Value::Int4(1));
    assert_eq!(cols[1], Value::String(String::from("world")));
}
11963
/// Returning a single column yields a one-value row holding that column.
#[test]
fn test_returning_insert_single_column() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    conn.execute("CREATE TABLE ret_single (a INT, b TEXT)").unwrap();
    let sql = "INSERT INTO ret_single (a, b) VALUES (42, 'test') RETURNING a";
    let (affected, returned) = conn.execute_returning(sql).unwrap();
    assert_eq!(affected, 1);
    assert_eq!(returned.len(), 1);
    let cols = &returned[0].values;
    assert_eq!(cols.len(), 1);
    assert_eq!(cols[0], Value::Int4(42));
}
11976
/// An aliased expression in RETURNING (`a + 1 AS incremented`) is evaluated
/// against the inserted row: a = 1, so the result is 2.
#[test]
fn test_returning_insert_expression_with_alias() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    conn.execute("CREATE TABLE ret_expr (a INT, b INT)").unwrap();
    let sql = "INSERT INTO ret_expr (a, b) VALUES (1, 2) RETURNING a + 1 AS incremented";
    let (affected, returned) = conn.execute_returning(sql).unwrap();
    assert_eq!(affected, 1);
    assert_eq!(returned.len(), 1);
    let cols = &returned[0].values;
    assert_eq!(cols.len(), 1);
    assert_eq!(cols[0], Value::Int4(2));
}
11990
/// `UPDATE ... RETURNING *` returns the post-update row: the untouched `a` and
/// the new value of `b`.
///
/// The column-count check matches the other `RETURNING *` tests, so a row with
/// extra columns is caught rather than passing on its first two values alone.
#[test]
fn test_returning_update_star() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE ret_upd (a INT, b INT)").unwrap();
    db.execute("INSERT INTO ret_upd (a, b) VALUES (1, 5)").unwrap();
    let (count, rows) = db.execute_returning(
        "UPDATE ret_upd SET b = 10 WHERE a = 1 RETURNING *"
    ).unwrap();
    assert_eq!(count, 1);
    assert_eq!(rows.len(), 1);
    assert_eq!(rows[0].values.len(), 2);
    assert_eq!(rows[0].values[0], Value::Int4(1));
    assert_eq!(rows[0].values[1], Value::Int4(10));
}
12004
/// Only the row matched by the WHERE clause is updated and returned, and only
/// the requested column comes back.
#[test]
fn test_returning_update_specific_column() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    for stmt in [
        "CREATE TABLE ret_upd2 (a INT, b INT)",
        "INSERT INTO ret_upd2 (a, b) VALUES (1, 5)",
        "INSERT INTO ret_upd2 (a, b) VALUES (2, 6)",
    ] {
        conn.execute(stmt).unwrap();
    }
    let sql = "UPDATE ret_upd2 SET b = 99 WHERE a = 2 RETURNING b";
    let (affected, returned) = conn.execute_returning(sql).unwrap();
    assert_eq!(affected, 1);
    assert_eq!(returned.len(), 1);
    let cols = &returned[0].values;
    assert_eq!(cols.len(), 1);
    assert_eq!(cols[0], Value::Int4(99));
}
12019
/// `DELETE ... RETURNING a` reports the key of the single deleted row.
#[test]
fn test_returning_delete() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    for stmt in [
        "CREATE TABLE ret_del (a INT, b TEXT)",
        "INSERT INTO ret_del (a, b) VALUES (1, 'one')",
        "INSERT INTO ret_del (a, b) VALUES (2, 'two')",
    ] {
        conn.execute(stmt).unwrap();
    }
    let sql = "DELETE FROM ret_del WHERE a = 1 RETURNING a";
    let (affected, returned) = conn.execute_returning(sql).unwrap();
    assert_eq!(affected, 1);
    assert_eq!(returned.len(), 1);
    let cols = &returned[0].values;
    assert_eq!(cols.len(), 1);
    assert_eq!(cols[0], Value::Int4(1));
}
12034
/// `DELETE ... RETURNING *` returns every column of the row that was removed.
#[test]
fn test_returning_delete_star() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    for stmt in [
        "CREATE TABLE ret_del2 (a INT, b TEXT)",
        "INSERT INTO ret_del2 (a, b) VALUES (1, 'one')",
        "INSERT INTO ret_del2 (a, b) VALUES (2, 'two')",
    ] {
        conn.execute(stmt).unwrap();
    }
    let sql = "DELETE FROM ret_del2 WHERE a = 2 RETURNING *";
    let (affected, returned) = conn.execute_returning(sql).unwrap();
    assert_eq!(affected, 1);
    assert_eq!(returned.len(), 1);
    let cols = &returned[0].values;
    assert_eq!(cols.len(), 2);
    assert_eq!(cols[0], Value::Int4(2));
    assert_eq!(cols[1], Value::String(String::from("two")));
}
12050
/// A multi-row VALUES insert returns one row per inserted tuple, in insertion
/// order (checked on the first column).
#[test]
fn test_returning_multi_row_insert() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    conn.execute("CREATE TABLE ret_multi (a INT, b INT)").unwrap();
    let sql = "INSERT INTO ret_multi (a, b) VALUES (1, 10), (2, 20), (3, 30) RETURNING *";
    let (affected, returned) = conn.execute_returning(sql).unwrap();
    assert_eq!(affected, 3);
    assert_eq!(returned.len(), 3);
    for (row, expected_a) in returned.iter().zip(1..=3) {
        assert_eq!(row.values[0], Value::Int4(expected_a));
    }
}
12064
/// A DELETE that matches nothing reports zero affected rows and returns no rows.
#[test]
fn test_returning_no_matching_rows() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    conn.execute("CREATE TABLE ret_empty (a INT)").unwrap();
    let sql = "DELETE FROM ret_empty WHERE a = 999 RETURNING *";
    let (affected, returned) = conn.execute_returning(sql).unwrap();
    assert_eq!(affected, 0);
    assert!(returned.is_empty());
}
12075
/// A RETURNING statement issued through `query` (rather than
/// `execute_returning`) still yields the returned row.
#[test]
fn test_returning_via_query() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    conn.execute("CREATE TABLE ret_query (a INT, b TEXT)").unwrap();
    let sql = "INSERT INTO ret_query (a, b) VALUES (7, 'seven') RETURNING *";
    let returned = conn.query(sql, &[]).unwrap();
    assert_eq!(returned.len(), 1);
    let cols = &returned[0].values;
    assert_eq!(cols[0], Value::Int4(7));
    assert_eq!(cols[1], Value::String(String::from("seven")));
}
12089
/// `execute_returning` on a statement with no RETURNING clause still reports
/// the affected count but returns no rows.
#[test]
fn test_returning_update_no_clause() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    conn.execute("CREATE TABLE ret_none (a INT)").unwrap();
    conn.execute("INSERT INTO ret_none (a) VALUES (1)").unwrap();
    let sql = "UPDATE ret_none SET a = 2 WHERE a = 1";
    let (affected, returned) = conn.execute_returning(sql).unwrap();
    assert_eq!(affected, 1);
    assert_eq!(returned.len(), 0);
}
12102
12103 fn parse_json_value(val: &Value) -> serde_json::Value {
12117 match val {
12118 Value::Json(j) => serde_json::from_str(j).unwrap(),
12119 Value::String(s) => serde_json::from_str(s).unwrap_or_else(|_| serde_json::json!(s)),
12120 other => panic!("Expected Json or String, got {:?}", other),
12121 }
12122 }
12123
/// A JSONB column round-trips: the first inserted object is read back with its
/// fields intact (rows are ordered by id).
#[test]
fn test_json_column_create_insert_select() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    for stmt in [
        "CREATE TABLE json_basic (id INT PRIMARY KEY, data JSONB)",
        r#"INSERT INTO json_basic (id, data) VALUES (1, '{"name":"Alice","age":30}')"#,
        r#"INSERT INTO json_basic (id, data) VALUES (2, '{"name":"Bob","age":25}')"#,
    ] {
        conn.execute(stmt).unwrap();
    }

    let out = conn.query("SELECT id, data FROM json_basic ORDER BY id", &[]).unwrap();
    assert_eq!(out.len(), 2);
    let first = &out[0].values;
    assert_eq!(first[0], Value::Int4(1));
    let doc = parse_json_value(&first[1]);
    assert_eq!(doc["name"], "Alice");
    assert_eq!(doc["age"], 30);
}
12139
/// Columns declared as JSON and as JSONB can coexist in one table, and both
/// store and return their documents.
#[test]
fn test_json_column_type_json_vs_jsonb() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    for stmt in [
        "CREATE TABLE json_types (id INT PRIMARY KEY, j JSON, jb JSONB)",
        r#"INSERT INTO json_types (id, j, jb) VALUES (1, '{"a":1}', '{"b":2}')"#,
    ] {
        conn.execute(stmt).unwrap();
    }

    let out = conn.query("SELECT j, jb FROM json_types WHERE id = 1", &[]).unwrap();
    assert_eq!(out.len(), 1);
    let cols = &out[0].values;
    assert_eq!(parse_json_value(&cols[0])["a"], 1);
    assert_eq!(parse_json_value(&cols[1])["b"], 2);
}
12154
/// SQL NULL stored in a JSONB column reads back as `Value::Null`, not as JSON `null`.
#[test]
fn test_json_null_column() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    for stmt in [
        "CREATE TABLE json_nulls (id INT PRIMARY KEY, data JSONB)",
        "INSERT INTO json_nulls (id, data) VALUES (1, NULL)",
    ] {
        conn.execute(stmt).unwrap();
    }

    let out = conn.query("SELECT data FROM json_nulls WHERE id = 1", &[]).unwrap();
    assert_eq!(out.len(), 1);
    assert_eq!(out[0].values[0], Value::Null);
}
12166
/// CASTing a string literal to JSONB yields a `Value::Json` holding the parsed
/// document.
#[test]
fn test_json_cast_string_to_jsonb() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    let out = conn.query(r#"SELECT CAST('{"hello":"world"}' AS JSONB)"#, &[]).unwrap();
    assert_eq!(out.len(), 1);
    let cell = &out[0].values[0];
    let Value::Json(raw) = cell else {
        panic!("Expected Json from CAST, got {:?}", cell);
    };
    let doc: serde_json::Value = serde_json::from_str(raw).unwrap();
    assert_eq!(doc["hello"], "world");
}
12181
/// CASTing to plain JSON (not JSONB) also produces a `Value::Json`.
#[test]
fn test_json_cast_to_json_type() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    let out = conn.query(r#"SELECT CAST('{"k":"v"}' AS JSON)"#, &[]).unwrap();
    assert_eq!(out.len(), 1);
    let cell = &out[0].values[0];
    let Value::Json(raw) = cell else {
        panic!("Expected Json from CAST to JSON, got {:?}", cell);
    };
    let doc: serde_json::Value = serde_json::from_str(raw).unwrap();
    assert_eq!(doc["k"], "v");
}
12196
/// `->` on an object key returns the member as JSON, so a string member keeps
/// its quotes.
#[test]
fn test_json_arrow_get_object_field_via_cast() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let sql = r#"SELECT CAST('{"name":"Alice","age":30}' AS JSONB)->'name'"#;
    let out = conn.query(sql, &[]).unwrap();
    assert_eq!(out.len(), 1);
    let cell = &out[0].values[0];
    if let Value::Json(raw) = cell {
        assert_eq!(raw, "\"Alice\"");
    } else {
        panic!("Expected Json from ->, got {:?}", cell);
    }
}
12211
/// `->>` on an object key returns the member as unquoted text.
#[test]
fn test_json_double_arrow_get_text_via_cast() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let sql = r#"SELECT CAST('{"name":"Alice","age":30}' AS JSONB)->>'name'"#;
    let out = conn.query(sql, &[]).unwrap();
    assert_eq!(out.len(), 1);
    let expected = Value::String(String::from("Alice"));
    assert_eq!(out[0].values[0], expected);
}
12223
/// `->>` on a numeric member renders the number as text.
#[test]
fn test_json_arrow_get_numeric_as_text_via_cast() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let sql = r#"SELECT CAST('{"age":25}' AS JSONB)->>'age'"#;
    let out = conn.query(sql, &[]).unwrap();
    assert_eq!(out.len(), 1);
    let expected = Value::String(String::from("25"));
    assert_eq!(out[0].values[0], expected);
}
12235
/// `->` with an integer index selects an array element (zero-based: index 1 is
/// "banana") and returns it as JSON.
#[test]
fn test_json_arrow_array_index_via_cast() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let sql = r#"SELECT CAST('["apple","banana","cherry"]' AS JSONB)->1"#;
    let out = conn.query(sql, &[]).unwrap();
    assert_eq!(out.len(), 1);
    let cell = &out[0].values[0];
    let Value::Json(raw) = cell else {
        panic!("Expected Json for array index, got {:?}", cell);
    };
    assert_eq!(raw, "\"banana\"");
}
12250
/// `->` on a key that is not present yields SQL NULL rather than an error.
#[test]
fn test_json_arrow_missing_key_via_cast() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let sql = r#"SELECT CAST('{"name":"Alice"}' AS JSONB)->'nonexistent'"#;
    let out = conn.query(sql, &[]).unwrap();
    assert_eq!(out.len(), 1);
    assert_eq!(out[0].values[0], Value::Null);
}
12262
/// Applying `->` to a NULL JSONB column value propagates NULL.
#[test]
fn test_json_arrow_on_null_column() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    for stmt in [
        "CREATE TABLE json_null_op (id INT PRIMARY KEY, data JSONB)",
        "INSERT INTO json_null_op (id, data) VALUES (1, NULL)",
    ] {
        conn.execute(stmt).unwrap();
    }

    let sql = "SELECT CAST(data AS JSONB)->'key' FROM json_null_op WHERE id = 1";
    let out = conn.query(sql, &[]).unwrap();
    assert_eq!(out.len(), 1);
    assert_eq!(out[0].values[0], Value::Null);
}
12275
/// Chained `->` operators walk nested objects down to a leaf, which is returned
/// as JSON.
#[test]
fn test_json_nested_arrow_chaining_via_cast() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let sql = r#"SELECT CAST('{"user":{"address":{"city":"NYC"}}}' AS JSONB)->'user'->'address'->'city'"#;
    let out = conn.query(sql, &[]).unwrap();
    assert_eq!(out.len(), 1);
    let cell = &out[0].values[0];
    if let Value::Json(raw) = cell {
        assert_eq!(raw, "\"NYC\"");
    } else {
        panic!("Expected nested Json, got {:?}", cell);
    }
}
12291
/// Mixing operators: `->` descends into an object, then `->>` extracts a
/// member as text.
#[test]
fn test_json_nested_arrow_then_double_arrow() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let sql = r#"SELECT CAST('{"user":{"name":"Alice"}}' AS JSONB)->'user'->>'name'"#;
    let out = conn.query(sql, &[]).unwrap();
    assert_eq!(out.len(), 1);
    let expected = Value::String(String::from("Alice"));
    assert_eq!(out[0].values[0], expected);
}
12303
/// `@>` (contains) is true when the right-hand object's pairs all appear on the
/// left, and false when a value differs.
#[test]
fn test_json_contains_operator_via_cast() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    let run = |sql: &str| conn.query(sql, &[]).unwrap();

    let matching = run(
        r#"SELECT CAST('{"name":"Alice","city":"NYC"}' AS JSONB) @> CAST('{"city":"NYC"}' AS JSONB)"#,
    );
    assert_eq!(matching.len(), 1);
    assert_eq!(matching[0].values[0], Value::Boolean(true));

    let mismatched = run(
        r#"SELECT CAST('{"name":"Alice","city":"NYC"}' AS JSONB) @> CAST('{"city":"LA"}' AS JSONB)"#,
    );
    assert_eq!(mismatched[0].values[0], Value::Boolean(false));
}
12322
/// `<@` (contained by) is true for a subset object and false when the left side
/// has a key the right side lacks.
#[test]
fn test_json_contained_by_operator_via_cast() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let cases = [
        (r#"SELECT CAST('{"a":1}' AS JSONB) <@ CAST('{"a":1,"b":2}' AS JSONB)"#, true),
        (r#"SELECT CAST('{"a":1,"c":3}' AS JSONB) <@ CAST('{"a":1,"b":2}' AS JSONB)"#, false),
    ];
    for (sql, expected) in cases {
        let out = conn.query(sql, &[]).unwrap();
        assert_eq!(out[0].values[0], Value::Boolean(expected));
    }
}
12338
/// `@>` compares nested objects recursively: an identical nested path matches,
/// and a differing leaf does not.
#[test]
fn test_json_contains_nested_via_cast() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let cases = [
        (
            r#"SELECT CAST('{"user":{"address":{"city":"NYC"}}}' AS JSONB) @> CAST('{"user":{"address":{"city":"NYC"}}}' AS JSONB)"#,
            true,
        ),
        (
            r#"SELECT CAST('{"user":{"address":{"city":"NYC"}}}' AS JSONB) @> CAST('{"user":{"address":{"city":"LA"}}}' AS JSONB)"#,
            false,
        ),
    ];
    for (sql, expected) in cases {
        let out = conn.query(sql, &[]).unwrap();
        assert_eq!(out[0].values[0], Value::Boolean(expected));
    }
}
12356
/// `@>` on array members: a subset of the elements is contained, and an absent
/// element is not.
#[test]
fn test_json_contains_array_values_via_cast() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let cases = [
        (
            r#"SELECT CAST('{"tags":["rust","db","json"]}' AS JSONB) @> CAST('{"tags":["rust"]}' AS JSONB)"#,
            true,
        ),
        (
            r#"SELECT CAST('{"tags":["rust","db"]}' AS JSONB) @> CAST('{"tags":["python"]}' AS JSONB)"#,
            false,
        ),
    ];
    for (sql, expected) in cases {
        let out = conn.query(sql, &[]).unwrap();
        assert_eq!(out[0].values[0], Value::Boolean(expected));
    }
}
12374
/// One document holding several JSON value kinds. `->>` renders string,
/// number and boolean members as text; `->` returns the nested array as JSON.
#[test]
fn test_json_complex_data_types_via_cast() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    let json_str = r#"{"str":"hello","num":42,"flag":true,"arr":[1,2,3],"obj":{"x":1}}"#;

    for (key, text) in [("str", "hello"), ("num", "42"), ("flag", "true")] {
        let sql = format!("SELECT CAST('{}' AS JSONB)->>'{}'", json_str, key);
        let out = conn.query(&sql, &[]).unwrap();
        assert_eq!(out[0].values[0], Value::String(text.to_string()));
    }

    let sql = format!("SELECT CAST('{}' AS JSONB)->'arr'", json_str);
    let out = conn.query(&sql, &[]).unwrap();
    let cell = &out[0].values[0];
    let Value::Json(raw) = cell else {
        panic!("Expected Json for nested array, got {:?}", cell);
    };
    let doc: serde_json::Value = serde_json::from_str(raw).unwrap();
    assert!(doc.is_array());
    assert_eq!(doc.as_array().unwrap().len(), 3);
}
12408
/// UPDATE replaces a JSONB column's whole document, including adding new keys.
#[test]
fn test_json_update_column() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    for stmt in [
        "CREATE TABLE json_update (id INT PRIMARY KEY, data JSONB)",
        r#"INSERT INTO json_update (id, data) VALUES (1, '{"v":1}')"#,
        r#"UPDATE json_update SET data = '{"v":2,"extra":"added"}' WHERE id = 1"#,
    ] {
        conn.execute(stmt).unwrap();
    }

    let out = conn.query("SELECT data FROM json_update WHERE id = 1", &[]).unwrap();
    let doc = parse_json_value(&out[0].values[0]);
    assert_eq!(doc["v"], 2);
    assert_eq!(doc["extra"], "added");
}
12422
/// `jsonb_typeof` names the top-level JSON kind for each literal form: object,
/// array, string, number, boolean and null.
#[test]
fn test_json_func_jsonb_typeof() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let cases = [
        (r#"'{"k":"v"}'"#, "object"),
        (r#"'[1,2,3]'"#, "array"),
        (r#"'"hello"'"#, "string"),
        (r#"'42'"#, "number"),
        (r#"'true'"#, "boolean"),
        (r#"'null'"#, "null"),
    ];

    for (literal, kind) in cases {
        let sql = format!("SELECT jsonb_typeof(CAST({} AS JSONB))", literal);
        let out = conn.query(&sql, &[]).unwrap();
        assert_eq!(
            out[0].values[0],
            Value::String(kind.to_string()),
            "jsonb_typeof failed for {}",
            literal
        );
    }
}
12448
/// `jsonb_array_length` counts top-level elements, including the empty-array
/// case.
#[test]
fn test_json_func_jsonb_array_length() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let cases = [
        ("SELECT jsonb_array_length(CAST('[10,20,30,40]' AS JSONB))", 4),
        ("SELECT jsonb_array_length(CAST('[]' AS JSONB))", 0),
    ];
    for (sql, expected_len) in cases {
        let out = conn.query(sql, &[]).unwrap();
        assert_eq!(out[0].values[0], Value::Int4(expected_len));
    }
}
12460
/// `jsonb_extract_path` follows a variadic key path and returns the leaf as JSON.
#[test]
fn test_json_func_jsonb_extract_path() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let sql = r#"SELECT jsonb_extract_path(CAST('{"user":{"address":{"city":"NYC"}}}' AS JSONB), 'user', 'address', 'city')"#;
    let out = conn.query(sql, &[]).unwrap();
    assert_eq!(out.len(), 1);
    let cell = &out[0].values[0];
    let Value::Json(raw) = cell else {
        panic!("Expected Json from jsonb_extract_path, got {:?}", cell);
    };
    assert_eq!(raw, "\"NYC\"");
}
12476
/// `jsonb_extract_path_text` follows a key path and returns the leaf as plain text.
#[test]
fn test_json_func_jsonb_extract_path_text() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let sql = r#"SELECT jsonb_extract_path_text(CAST('{"user":{"name":"Alice"}}' AS JSONB), 'user', 'name')"#;
    let out = conn.query(sql, &[]).unwrap();
    let expected = Value::String(String::from("Alice"));
    assert_eq!(out[0].values[0], expected);
}
12488
/// `jsonb_extract_path` with a path that does not exist yields SQL NULL.
#[test]
fn test_json_func_jsonb_extract_path_missing() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let sql = r#"SELECT jsonb_extract_path(CAST('{"a":1}' AS JSONB), 'nonexistent', 'path')"#;
    let out = conn.query(sql, &[]).unwrap();
    assert_eq!(out[0].values[0], Value::Null);
}
12499
/// In this engine, `jsonb_object_keys` returns one row holding an array of the
/// object's key strings (order is not checked).
#[test]
fn test_json_func_jsonb_object_keys() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let sql = r#"SELECT jsonb_object_keys(CAST('{"name":"Alice","age":30,"city":"NYC"}' AS JSONB))"#;
    let out = conn.query(sql, &[]).unwrap();
    assert_eq!(out.len(), 1);
    let cell = &out[0].values[0];
    let Value::Array(items) = cell else {
        panic!("Expected Array, got {:?}", cell);
    };
    // Non-string entries are ignored here and would show up as a length mismatch.
    let names: Vec<String> = items
        .iter()
        .filter_map(|item| match item {
            Value::String(s) => Some(s.clone()),
            _ => None,
        })
        .collect();
    for want in ["name", "age", "city"] {
        assert!(names.iter().any(|k| k == want));
    }
    assert_eq!(names.len(), 3);
}
12521
/// `jsonb_build_object` pairs alternating key/value arguments into a JSON object.
#[test]
fn test_json_func_jsonb_build_object() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    let out = conn
        .query("SELECT jsonb_build_object('name', 'Alice', 'age', 30)", &[])
        .unwrap();
    let cell = &out[0].values[0];
    let Value::Json(raw) = cell else {
        panic!("Expected Json from jsonb_build_object, got {:?}", cell);
    };
    let doc: serde_json::Value = serde_json::from_str(raw).unwrap();
    assert_eq!(doc["name"], "Alice");
    assert_eq!(doc["age"], 30);
}
12535
/// `jsonb_build_array` builds a JSON array from mixed-type SQL arguments,
/// keeping argument order and types.
#[test]
fn test_json_func_jsonb_build_array() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();
    let out = conn
        .query("SELECT jsonb_build_array(1, 'two', 3, true)", &[])
        .unwrap();
    let cell = &out[0].values[0];
    let Value::Json(raw) = cell else {
        panic!("Expected Json from jsonb_build_array, got {:?}", cell);
    };
    let doc: serde_json::Value = serde_json::from_str(raw).unwrap();
    let elems = doc.as_array().unwrap();
    assert_eq!(elems.len(), 4);
    assert_eq!(elems[0], 1);
    assert_eq!(elems[1], "two");
    assert_eq!(elems[2], 3);
    assert_eq!(elems[3], true);
}
12553
/// `jsonb_strip_nulls` removes null-valued members from an object and keeps
/// the rest.
#[test]
fn test_json_func_jsonb_strip_nulls() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let sql = r#"SELECT jsonb_strip_nulls(CAST('{"a":1,"b":null,"c":"hello","d":null}' AS JSONB))"#;
    let out = conn.query(sql, &[]).unwrap();
    let cell = &out[0].values[0];
    let Value::Json(raw) = cell else {
        panic!("Expected Json from jsonb_strip_nulls, got {:?}", cell);
    };
    let doc: serde_json::Value = serde_json::from_str(raw).unwrap();
    assert_eq!(doc["a"], 1);
    assert_eq!(doc["c"], "hello");
    assert!(doc.get("b").is_none());
    assert!(doc.get("d").is_none());
}
12572
/// `jsonb_strip_nulls` also removes null members inside nested objects.
#[test]
fn test_json_func_jsonb_strip_nulls_nested() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let sql = r#"SELECT jsonb_strip_nulls(CAST('{"a":1,"b":{"c":null,"d":2},"e":null}' AS JSONB))"#;
    let out = conn.query(sql, &[]).unwrap();
    let cell = &out[0].values[0];
    let Value::Json(raw) = cell else {
        panic!("Expected Json, got {:?}", cell);
    };
    let doc: serde_json::Value = serde_json::from_str(raw).unwrap();
    assert_eq!(doc["a"], 1);
    assert!(doc.get("e").is_none());
    let inner = &doc["b"];
    assert!(inner.get("c").is_none());
    assert_eq!(inner["d"], 2);
}
12591
/// `jsonb_pretty` returns a multi-line text rendering that still parses back to
/// the same document.
#[test]
fn test_json_func_jsonb_pretty() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let sql = r#"SELECT jsonb_pretty(CAST('{"a":1,"b":2}' AS JSONB))"#;
    let out = conn.query(sql, &[]).unwrap();
    let cell = &out[0].values[0];
    let Value::String(text) = cell else {
        panic!("Expected String from jsonb_pretty, got {:?}", cell);
    };
    assert!(text.contains('\n'));
    let doc: serde_json::Value = serde_json::from_str(text).unwrap();
    assert_eq!(doc["a"], 1);
    assert_eq!(doc["b"], 2);
}
12609
/// `jsonb_path_query` with a dotted path (`user.name`) returns the matched value
/// as JSON.
#[test]
fn test_json_func_jsonb_path_query() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let sql = r#"SELECT jsonb_path_query(CAST('{"user":{"name":"Alice"}}' AS JSONB), 'user.name')"#;
    let out = conn.query(sql, &[]).unwrap();
    let cell = &out[0].values[0];
    if let Value::Json(raw) = cell {
        assert_eq!(raw, "\"Alice\"");
    } else {
        panic!("Expected Json from jsonb_path_query, got {:?}", cell);
    }
}
12622
/// `jsonb_path_query` over a three-level path, once for a string leaf and once
/// for a numeric leaf.
#[test]
fn test_json_func_jsonb_path_query_nested() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let cases = [
        (
            r#"SELECT jsonb_path_query(CAST('{"config":{"db":{"host":"localhost","port":5432}}}' AS JSONB), 'config.db.host')"#,
            "\"localhost\"",
        ),
        (
            r#"SELECT jsonb_path_query(CAST('{"config":{"db":{"host":"localhost","port":5432}}}' AS JSONB), 'config.db.port')"#,
            "5432",
        ),
    ];
    for (sql, expected_json) in cases {
        let out = conn.query(sql, &[]).unwrap();
        let cell = &out[0].values[0];
        let Value::Json(raw) = cell else {
            panic!("Expected Json, got {:?}", cell);
        };
        assert_eq!(raw, expected_json);
    }
}
12645
/// `jsonb_path_exists` is true for a path that resolves and false for a missing
/// member.
#[test]
fn test_json_func_jsonb_path_exists() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let cases = [
        (r#"SELECT jsonb_path_exists(CAST('{"user":{"name":"Alice"}}' AS JSONB), 'user.name')"#, true),
        (r#"SELECT jsonb_path_exists(CAST('{"user":{"name":"Alice"}}' AS JSONB), 'user.email')"#, false),
    ];
    for (sql, expected) in cases {
        let out = conn.query(sql, &[]).unwrap();
        assert_eq!(out[0].values[0], Value::Boolean(expected));
    }
}
12660
/// `jsonb_path_query_array` wraps all matches in an array. A single match
/// yields a one-element array whose item is a JSON value.
#[test]
fn test_json_func_jsonb_path_query_array() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let sql = r#"SELECT jsonb_path_query_array(CAST('{"user":{"name":"Alice"}}' AS JSONB), 'user.name')"#;
    let out = conn.query(sql, &[]).unwrap();
    let cell = &out[0].values[0];
    let Value::Array(matches) = cell else {
        panic!("Expected Array, got {:?}", cell);
    };
    assert_eq!(matches.len(), 1);
    let only = &matches[0];
    let Value::Json(raw) = only else {
        panic!("Expected Json inside array, got {:?}", only);
    };
    assert_eq!(raw, "\"Alice\"");
}
12679
/// `jsonb_path_query_first` returns the first match as JSON (here a number).
#[test]
fn test_json_func_jsonb_path_query_first() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let sql = r#"SELECT jsonb_path_query_first(CAST('{"x":{"y":42}}' AS JSONB), 'x.y')"#;
    let out = conn.query(sql, &[]).unwrap();
    let cell = &out[0].values[0];
    if let Value::Json(raw) = cell {
        assert_eq!(raw, "42");
    } else {
        panic!("Expected Json, got {:?}", cell);
    }
}
12692
/// `jsonb_set` replaces the value at a top-level key and leaves sibling keys alone.
///
/// NOTE(review): the replacement argument `'31'` is asserted to come back as
/// the JSON *string* "31". PostgreSQL declares this argument as `jsonb`, where
/// it would become the number 31. Confirm this divergence is intended for this
/// engine.
#[test]
fn test_json_func_jsonb_set() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let sql = r#"SELECT jsonb_set(CAST('{"name":"Alice","age":30}' AS JSONB), ARRAY['age'], '31')"#;
    let out = conn.query(sql, &[]).unwrap();
    let cell = &out[0].values[0];
    let Value::Json(raw) = cell else {
        panic!("Expected Json from jsonb_set, got {:?}", cell);
    };
    let doc: serde_json::Value = serde_json::from_str(raw).unwrap();
    assert_eq!(doc["name"], "Alice");
    assert_eq!(doc["age"], "31");
}
12709
/// `jsonb_set` with a two-element path edits inside a nested object without
/// disturbing the object's other members.
///
/// NOTE(review): only the untouched sibling (`age`) is asserted. The replaced
/// `user.name` value itself is not checked — confirm whether it should be.
#[test]
fn test_json_func_jsonb_set_nested() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let sql = r#"SELECT jsonb_set(CAST('{"user":{"name":"Alice","age":30}}' AS JSONB), ARRAY['user','name'], '"Bob"')"#;
    let out = conn.query(sql, &[]).unwrap();
    let cell = &out[0].values[0];
    let Value::Json(raw) = cell else {
        panic!("Expected Json from jsonb_set nested, got {:?}", cell);
    };
    let doc: serde_json::Value = serde_json::from_str(raw).unwrap();
    assert_eq!(doc["user"]["age"], 30);
}
12726
/// `jsonb_concat` merges two objects with disjoint keys into one.
#[test]
fn test_json_func_jsonb_concat() {
    let conn = EmbeddedDatabase::new_in_memory().unwrap();

    let sql = r#"SELECT jsonb_concat(CAST('{"x":1}' AS JSONB), CAST('{"y":2}' AS JSONB))"#;
    let out = conn.query(sql, &[]).unwrap();
    let cell = &out[0].values[0];
    let Value::Json(raw) = cell else {
        panic!("Expected Json from jsonb_concat, got {:?}", cell);
    };
    let merged: serde_json::Value = serde_json::from_str(raw).unwrap();
    assert_eq!(merged["x"], 1);
    assert_eq!(merged["y"], 2);
}
12743
12744 #[test]
12745 fn test_json_func_jsonb_concat_overwrites() {
12746 let db = EmbeddedDatabase::new_in_memory().unwrap();
12748
12749 let rows = db.query(
12750 r#"SELECT jsonb_concat(CAST('{"x":1,"y":2}' AS JSONB), CAST('{"y":99,"z":3}' AS JSONB))"#, &[]
12751 ).unwrap();
12752 match &rows[0].values[0] {
12753 Value::Json(j) => {
12754 let parsed: serde_json::Value = serde_json::from_str(j).unwrap();
12755 assert_eq!(parsed["x"], 1);
12756 assert_eq!(parsed["y"], 99);
12757 assert_eq!(parsed["z"], 3);
12758 }
12759 other => panic!("Expected Json from jsonb_concat, got {:?}", other),
12760 }
12761 }
12762
12763 #[test]
12764 fn test_json_func_jsonb_delete() {
12765 let db = EmbeddedDatabase::new_in_memory().unwrap();
12766
12767 let rows = db.query(
12768 r#"SELECT jsonb_delete(CAST('{"a":1,"b":2,"c":3}' AS JSONB), ARRAY['b'])"#, &[]
12769 ).unwrap();
12770 match &rows[0].values[0] {
12771 Value::Json(j) => {
12772 let parsed: serde_json::Value = serde_json::from_str(j).unwrap();
12773 assert_eq!(parsed["a"], 1);
12774 assert_eq!(parsed["c"], 3);
12775 assert!(parsed.get("b").is_none());
12776 }
12777 other => panic!("Expected Json from jsonb_delete, got {:?}", other),
12778 }
12779 }
12780
12781 #[test]
12782 fn test_json_func_jsonb_each() {
12783 let db = EmbeddedDatabase::new_in_memory().unwrap();
12784
12785 let rows = db.query(
12786 r#"SELECT jsonb_each(CAST('{"x":10,"y":20}' AS JSONB))"#, &[]
12787 ).unwrap();
12788 match &rows[0].values[0] {
12789 Value::Array(pairs) => {
12790 assert_eq!(pairs.len(), 4);
12791 let has_x = pairs.iter().any(|v| matches!(v, Value::String(s) if s == "x"));
12792 let has_y = pairs.iter().any(|v| matches!(v, Value::String(s) if s == "y"));
12793 assert!(has_x);
12794 assert!(has_y);
12795 }
12796 other => panic!("Expected Array from jsonb_each, got {:?}", other),
12797 }
12798 }
12799
12800 #[test]
12801 fn test_json_func_jsonb_each_text() {
12802 let db = EmbeddedDatabase::new_in_memory().unwrap();
12803
12804 let rows = db.query(
12805 r#"SELECT jsonb_each_text(CAST('{"name":"Alice","age":30}' AS JSONB))"#, &[]
12806 ).unwrap();
12807 match &rows[0].values[0] {
12808 Value::Array(pairs) => {
12809 for v in pairs {
12810 assert!(matches!(v, Value::String(_)), "Expected text, got {:?}", v);
12811 }
12812 }
12813 other => panic!("Expected Array from jsonb_each_text, got {:?}", other),
12814 }
12815 }
12816
12817 #[test]
12818 fn test_json_func_jsonb_array_elements() {
12819 let db = EmbeddedDatabase::new_in_memory().unwrap();
12821
12822 let rows = db.query(
12823 r#"SELECT jsonb_array_elements(CAST('["first","second","third"]' AS JSONB))"#, &[]
12824 ).unwrap();
12825 match &rows[0].values[0] {
12826 Value::Json(j) => assert_eq!(j, "\"first\""),
12827 other => panic!("Expected Json from jsonb_array_elements, got {:?}", other),
12828 }
12829 }
12830
12831 #[test]
12832 fn test_json_func_jsonb_array_elements_text() {
12833 let db = EmbeddedDatabase::new_in_memory().unwrap();
12834
12835 let rows = db.query(
12836 r#"SELECT jsonb_array_elements_text(CAST('["hello","world"]' AS JSONB))"#, &[]
12837 ).unwrap();
12838 assert_eq!(rows[0].values[0], Value::String("hello".to_string()));
12839 }
12840
12841 #[test]
12842 fn test_json_agg_function() {
12843 let db = EmbeddedDatabase::new_in_memory().unwrap();
12844 db.execute("CREATE TABLE json_agg_t (id INT PRIMARY KEY, name TEXT)").unwrap();
12845 db.execute("INSERT INTO json_agg_t (id, name) VALUES (1, 'Alice')").unwrap();
12846 db.execute("INSERT INTO json_agg_t (id, name) VALUES (2, 'Bob')").unwrap();
12847 db.execute("INSERT INTO json_agg_t (id, name) VALUES (3, 'Charlie')").unwrap();
12848
12849 let rows = db.query("SELECT json_agg(name) FROM json_agg_t", &[]).unwrap();
12850 match &rows[0].values[0] {
12851 Value::Json(j) => {
12852 let parsed: serde_json::Value = serde_json::from_str(j).unwrap();
12853 let arr = parsed.as_array().unwrap();
12854 assert_eq!(arr.len(), 3);
12855 let strings: Vec<&str> = arr.iter().filter_map(|v| v.as_str()).collect();
12856 assert!(strings.contains(&"Alice"));
12857 assert!(strings.contains(&"Bob"));
12858 assert!(strings.contains(&"Charlie"));
12859 }
12860 other => panic!("Expected Json from json_agg, got {:?}", other),
12861 }
12862 }
12863
12864 #[test]
12865 fn test_json_func_null_handling() {
12866 let db = EmbeddedDatabase::new_in_memory().unwrap();
12868
12869 let rows = db.query("SELECT jsonb_extract_path(NULL, 'key')", &[]).unwrap();
12871 assert_eq!(rows[0].values[0], Value::Null);
12872
12873 let rows = db.query("SELECT jsonb_pretty(NULL)", &[]).unwrap();
12875 assert_eq!(rows[0].values[0], Value::Null);
12876 }
12877
12878 #[test]
12879 fn test_json_empty_object_and_array() {
12880 let db = EmbeddedDatabase::new_in_memory().unwrap();
12881
12882 let rows = db.query("SELECT jsonb_typeof(CAST('{}' AS JSONB))", &[]).unwrap();
12883 assert_eq!(rows[0].values[0], Value::String("object".to_string()));
12884
12885 let rows = db.query("SELECT jsonb_typeof(CAST('[]' AS JSONB))", &[]).unwrap();
12886 assert_eq!(rows[0].values[0], Value::String("array".to_string()));
12887
12888 let rows = db.query("SELECT jsonb_array_length(CAST('[]' AS JSONB))", &[]).unwrap();
12889 assert_eq!(rows[0].values[0], Value::Int4(0));
12890
12891 let rows = db.query("SELECT jsonb_object_keys(CAST('{}' AS JSONB))", &[]).unwrap();
12892 match &rows[0].values[0] {
12893 Value::Array(keys) => assert_eq!(keys.len(), 0),
12894 other => panic!("Expected empty Array, got {:?}", other),
12895 }
12896 }
12897
12898 #[test]
12899 fn test_json_deeply_nested_via_cast() {
12900 let db = EmbeddedDatabase::new_in_memory().unwrap();
12901
12902 let rows = db.query(
12903 r#"SELECT CAST('{"a":{"b":{"c":{"d":{"e":"deep"}}}}}' AS JSONB)->'a'->'b'->'c'->'d'->>'e'"#,
12904 &[]
12905 ).unwrap();
12906 assert_eq!(rows[0].values[0], Value::String("deep".to_string()));
12907 }
12908
12909 #[test]
12910 fn test_json_double_arrow_on_null_json_field() {
12911 let db = EmbeddedDatabase::new_in_memory().unwrap();
12913
12914 let rows = db.query(
12915 r#"SELECT CAST('{"a":null}' AS JSONB)->>'a'"#, &[]
12916 ).unwrap();
12917 assert_eq!(rows[0].values[0], Value::String("null".to_string()));
12918 }
12919
12920 #[test]
12921 fn test_json_contains_false_cases() {
12922 let db = EmbeddedDatabase::new_in_memory().unwrap();
12923
12924 let rows = db.query(
12926 r#"SELECT CAST('{"a":1,"b":2}' AS JSONB) @> CAST('{"a":99}' AS JSONB)"#, &[]
12927 ).unwrap();
12928 assert_eq!(rows[0].values[0], Value::Boolean(false));
12929
12930 let rows = db.query(
12932 r#"SELECT CAST('{"a":1,"b":2}' AS JSONB) @> CAST('{"z":1}' AS JSONB)"#, &[]
12933 ).unwrap();
12934 assert_eq!(rows[0].values[0], Value::Boolean(false));
12935 }
12936
12937 #[test]
12938 fn test_json_storage_roundtrip_preserves_data() {
12939 let db = EmbeddedDatabase::new_in_memory().unwrap();
12941 db.execute("CREATE TABLE json_rt (id INT PRIMARY KEY, data JSONB)").unwrap();
12942
12943 let test_cases = vec![
12944 (1, r#"{"nested":{"a":1,"b":[2,3]}}"#),
12945 (2, r#"[1,"two",true,null]"#),
12946 (3, r#""just a string""#),
12947 (4, r#"42"#),
12948 (5, r#"true"#),
12949 ];
12950
12951 for (id, json) in &test_cases {
12952 db.execute(&format!("INSERT INTO json_rt (id, data) VALUES ({}, '{}')", id, json)).unwrap();
12953 }
12954
12955 let rows = db.query("SELECT id, data FROM json_rt ORDER BY id", &[]).unwrap();
12956 assert_eq!(rows.len(), 5);
12957
12958 for (i, (_, expected_json)) in test_cases.iter().enumerate() {
12959 let parsed = parse_json_value(&rows[i].values[1]);
12960 let expected: serde_json::Value = serde_json::from_str(expected_json).unwrap();
12961 assert_eq!(parsed, expected, "Round-trip failed for row {}", i + 1);
12962 }
12963 }
12964
12965 #[test]
12966 fn test_json_delete_rows_from_json_table() {
12967 let db = EmbeddedDatabase::new_in_memory().unwrap();
12969 db.execute("CREATE TABLE json_del (id INT PRIMARY KEY, data JSONB)").unwrap();
12970 db.execute(r#"INSERT INTO json_del (id, data) VALUES (1, '{"x":1}')"#).unwrap();
12971 db.execute(r#"INSERT INTO json_del (id, data) VALUES (2, '{"x":2}')"#).unwrap();
12972 db.execute(r#"INSERT INTO json_del (id, data) VALUES (3, '{"x":3}')"#).unwrap();
12973
12974 db.execute("DELETE FROM json_del WHERE id = 2").unwrap();
12975
12976 let rows = db.query("SELECT id FROM json_del ORDER BY id", &[]).unwrap();
12977 assert_eq!(rows.len(), 2);
12978 assert_eq!(rows[0].values[0], Value::Int4(1));
12979 assert_eq!(rows[1].values[0], Value::Int4(3));
12980 }
12981
12982 #[test]
12983 fn test_json_build_object_then_arrow() {
12984 let db = EmbeddedDatabase::new_in_memory().unwrap();
12986
12987 let rows = db.query(
12988 "SELECT jsonb_build_object('name', 'Alice', 'age', 30)->>'name'", &[]
12989 ).unwrap();
12990 assert_eq!(rows[0].values[0], Value::String("Alice".to_string()));
12991 }
12992
12993 #[test]
12994 fn test_json_build_array_then_index() {
12995 let db = EmbeddedDatabase::new_in_memory().unwrap();
12997
12998 let rows = db.query("SELECT jsonb_build_array(10, 20, 30)->1", &[]).unwrap();
12999 match &rows[0].values[0] {
13000 Value::Json(j) => assert_eq!(j, "20"),
13001 other => panic!("Expected Json, got {:?}", other),
13002 }
13003 }
13004
13005 #[test]
13006 fn test_json_typeof_on_null() {
13007 let db = EmbeddedDatabase::new_in_memory().unwrap();
13008 let rows = db.query("SELECT jsonb_typeof(NULL)", &[]).unwrap();
13009 assert_eq!(rows[0].values[0], Value::String("null".to_string()));
13010 }
13011
13012 #[test]
13013 fn test_json_mixed_with_regular_columns() {
13014 let db = EmbeddedDatabase::new_in_memory().unwrap();
13016 db.execute("CREATE TABLE json_mixed (id INT PRIMARY KEY, name TEXT, meta JSONB)").unwrap();
13017 db.execute(r#"INSERT INTO json_mixed (id, name, meta) VALUES (1, 'Alice', '{"role":"admin"}')"#).unwrap();
13018 db.execute(r#"INSERT INTO json_mixed (id, name, meta) VALUES (2, 'Bob', '{"role":"user"}')"#).unwrap();
13019
13020 let rows = db.query("SELECT name, meta FROM json_mixed ORDER BY id", &[]).unwrap();
13021 assert_eq!(rows.len(), 2);
13022 assert_eq!(rows[0].values[0], Value::String("Alice".to_string()));
13023 let meta0 = parse_json_value(&rows[0].values[1]);
13024 assert_eq!(meta0["role"], "admin");
13025 assert_eq!(rows[1].values[0], Value::String("Bob".to_string()));
13026 let meta1 = parse_json_value(&rows[1].values[1]);
13027 assert_eq!(meta1["role"], "user");
13028 }
13029
13030 #[test]
13031 fn test_json_large_document() {
13032 let db = EmbeddedDatabase::new_in_memory().unwrap();
13034 db.execute("CREATE TABLE json_large (id INT PRIMARY KEY, data JSONB)").unwrap();
13035
13036 let mut json_obj = serde_json::Map::new();
13037 for i in 0..50 {
13038 json_obj.insert(format!("key_{}", i), serde_json::json!(i));
13039 }
13040 let json_str = serde_json::Value::Object(json_obj).to_string();
13041 db.execute(&format!("INSERT INTO json_large (id, data) VALUES (1, '{}')", json_str)).unwrap();
13042
13043 let rows = db.query("SELECT data FROM json_large WHERE id = 1", &[]).unwrap();
13044 let parsed = parse_json_value(&rows[0].values[0]);
13045 assert_eq!(parsed["key_25"], 25);
13046 assert_eq!(parsed["key_0"], 0);
13047 assert_eq!(parsed["key_49"], 49);
13048 }
13049
13050 #[test]
13051 fn test_json_unicode_content() {
13052 let db = EmbeddedDatabase::new_in_memory().unwrap();
13053 db.execute("CREATE TABLE json_uni (id INT PRIMARY KEY, data JSONB)").unwrap();
13054 db.execute(r#"INSERT INTO json_uni (id, data) VALUES (1, '{"greeting":"Bonjour"}')"#).unwrap();
13055
13056 let rows = db.query("SELECT data FROM json_uni WHERE id = 1", &[]).unwrap();
13057 let parsed = parse_json_value(&rows[0].values[0]);
13058 assert_eq!(parsed["greeting"], "Bonjour");
13059 }
13060
13061 #[test]
13082 fn test_recursive_cte_simple_counting() {
13083 let db = EmbeddedDatabase::new_in_memory().unwrap();
13093
13094 let sql = "\
13095 WITH RECURSIVE cnt(n) AS ( \
13096 SELECT 1 \
13097 UNION ALL \
13098 SELECT n + 1 FROM cnt WHERE n < 10 \
13099 ) \
13100 SELECT n FROM cnt";
13101
13102 match db.query(sql, &[]) {
13103 Ok(rows) => {
13104 assert_eq!(rows.len(), 10, "Expected 10 rows for counting 1..10, got {}", rows.len());
13105 for (i, row) in rows.iter().enumerate() {
13106 let val = row.get(0).unwrap();
13107 let expected = (i as i32) + 1;
13108 assert_eq!(
13109 val, &Value::Int4(expected),
13110 "Row {} should be {}, got {:?}", i, expected, val
13111 );
13112 }
13113 }
13114 Err(e) => {
13115 panic!(
13116 "Recursive CTE simple counting failed with error: {}. \
13117 This indicates recursive CTEs may not be supported.",
13118 e
13119 );
13120 }
13121 }
13122 }
13123
13124 #[test]
13125 fn test_recursive_cte_tree_traversal() {
13126 let db = EmbeddedDatabase::new_in_memory().unwrap();
13129
13130 db.execute("CREATE TABLE rc_employees (id INT PRIMARY KEY, name TEXT, manager_id INT)").unwrap();
13131 db.execute("INSERT INTO rc_employees VALUES (1, 'CEO', NULL)").unwrap();
13132 db.execute("INSERT INTO rc_employees VALUES (2, 'VP', 1)").unwrap();
13133 db.execute("INSERT INTO rc_employees VALUES (3, 'Director', 2)").unwrap();
13134 db.execute("INSERT INTO rc_employees VALUES (4, 'Manager', 3)").unwrap();
13135 db.execute("INSERT INTO rc_employees VALUES (5, 'Staff', 4)").unwrap();
13136
13137 let sql = "\
13139 WITH RECURSIVE reports(id, name, manager_id) AS ( \
13140 SELECT id, name, manager_id FROM rc_employees WHERE id = 2 \
13141 UNION ALL \
13142 SELECT e.id, e.name, e.manager_id \
13143 FROM rc_employees e \
13144 JOIN reports r ON e.manager_id = r.id \
13145 ) \
13146 SELECT id, name FROM reports ORDER BY id";
13147
13148 match db.query(sql, &[]) {
13149 Ok(rows) => {
13150 assert_eq!(rows.len(), 4, "Expected 4 reports under VP, got {}", rows.len());
13152 let ids: Vec<&Value> = rows.iter().map(|r| r.get(0).unwrap()).collect();
13153 assert_eq!(ids[0], &Value::Int4(2), "First should be VP (id=2)");
13154 assert_eq!(ids[1], &Value::Int4(3), "Second should be Director (id=3)");
13155 assert_eq!(ids[2], &Value::Int4(4), "Third should be Manager (id=4)");
13156 assert_eq!(ids[3], &Value::Int4(5), "Fourth should be Staff (id=5)");
13157 }
13158 Err(e) => {
13159 let err_msg = e.to_string();
13162 assert!(
13163 err_msg.contains("not found") ||
13164 err_msg.contains("not implemented") ||
13165 err_msg.contains("not yet") ||
13166 err_msg.contains("recursive") ||
13167 err_msg.contains("ambiguous") ||
13168 err_msg.contains("column"),
13169 "Unexpected error in recursive CTE tree traversal: {}", err_msg
13170 );
13171 }
13172 }
13173 }
13174
13175 #[test]
13176 fn test_recursive_cte_fibonacci() {
13177 let db = EmbeddedDatabase::new_in_memory().unwrap();
13182
13183 let sql = "\
13184 WITH RECURSIVE fib(a, b) AS ( \
13185 SELECT 0, 1 \
13186 UNION ALL \
13187 SELECT b, a + b FROM fib WHERE b < 100 \
13188 ) \
13189 SELECT a FROM fib";
13190
13191 match db.query(sql, &[]) {
13192 Ok(rows) => {
13193 let expected_deduped: Vec<i32> = vec![0, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89];
13194 let expected_full: Vec<i32> = vec![0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89];
13195
13196 if rows.len() == expected_deduped.len() {
13197 for (i, (row, exp)) in rows.iter().zip(expected_deduped.iter()).enumerate() {
13198 let val = row.get(0).unwrap();
13199 assert_eq!(val, &Value::Int4(*exp),
13200 "Fibonacci (deduped) row {} should be {}, got {:?}", i, exp, val);
13201 }
13202 } else if rows.len() == expected_full.len() {
13203 for (i, (row, exp)) in rows.iter().zip(expected_full.iter()).enumerate() {
13204 let val = row.get(0).unwrap();
13205 assert_eq!(val, &Value::Int4(*exp),
13206 "Fibonacci row {} should be {}, got {:?}", i, exp, val);
13207 }
13208 } else {
13209 panic!("Expected 11 (deduped) or 12 (full) Fibonacci numbers, got {}", rows.len());
13210 }
13211 }
13212 Err(e) => {
13213 let err_msg = e.to_string();
13214 assert!(
13215 err_msg.contains("not found") ||
13216 err_msg.contains("not implemented") ||
13217 err_msg.contains("not yet") ||
13218 err_msg.contains("column") ||
13219 err_msg.contains("type"),
13220 "Unexpected error in recursive CTE Fibonacci: {}", err_msg
13221 );
13222 }
13223 }
13224 }
13225
13226 #[test]
13227 fn test_recursive_cte_depth_limit_via_where() {
13228 let db = EmbeddedDatabase::new_in_memory().unwrap();
13231
13232 let sql = "\
13233 WITH RECURSIVE nums(n) AS ( \
13234 SELECT 1 \
13235 UNION ALL \
13236 SELECT n + 1 FROM nums WHERE n < 5 \
13237 ) \
13238 SELECT n FROM nums";
13239
13240 match db.query(sql, &[]) {
13241 Ok(rows) => {
13242 assert_eq!(rows.len(), 5, "Expected 5 rows for counting 1..5, got {}", rows.len());
13243 for (i, row) in rows.iter().enumerate() {
13244 let val = row.get(0).unwrap();
13245 let expected = (i as i32) + 1;
13246 assert_eq!(
13247 val, &Value::Int4(expected),
13248 "Row {} should be {}, got {:?}", i, expected, val
13249 );
13250 }
13251 }
13252 Err(e) => {
13253 panic!(
13254 "Recursive CTE with WHERE depth limit failed: {}",
13255 e
13256 );
13257 }
13258 }
13259 }
13260
13261 #[test]
13262 fn test_recursive_cte_non_recursive_basic() {
13263 let db = EmbeddedDatabase::new_in_memory().unwrap();
13266
13267 let sql = "WITH summary AS (SELECT 42 AS answer) SELECT answer FROM summary";
13268
13269 match db.query(sql, &[]) {
13270 Ok(rows) => {
13271 assert_eq!(rows.len(), 1, "Non-recursive CTE should return 1 row");
13272 let val = rows[0].get(0).unwrap();
13273 assert_eq!(
13274 val, &Value::Int4(42),
13275 "Non-recursive CTE should return 42, got {:?}", val
13276 );
13277 }
13278 Err(e) => {
13279 panic!(
13280 "Non-recursive CTE failed: {}. Basic WITH support should work.",
13281 e
13282 );
13283 }
13284 }
13285 }
13286
13287 #[test]
13288 fn test_recursive_cte_non_recursive_table_data() {
13289 let db = EmbeddedDatabase::new_in_memory().unwrap();
13291
13292 db.execute("CREATE TABLE rc_products (id INT, name TEXT, price INT)").unwrap();
13293 db.execute("INSERT INTO rc_products VALUES (1, 'Widget', 10)").unwrap();
13294 db.execute("INSERT INTO rc_products VALUES (2, 'Gadget', 25)").unwrap();
13295 db.execute("INSERT INTO rc_products VALUES (3, 'Doohickey', 5)").unwrap();
13296
13297 let sql = "\
13298 WITH expensive AS ( \
13299 SELECT id, name, price FROM rc_products WHERE price > 8 \
13300 ) \
13301 SELECT name FROM expensive ORDER BY name";
13302
13303 match db.query(sql, &[]) {
13304 Ok(rows) => {
13305 assert_eq!(rows.len(), 2, "Expected 2 expensive products, got {}", rows.len());
13306 assert_eq!(rows[0].get(0).unwrap(), &Value::String("Gadget".to_string()));
13307 assert_eq!(rows[1].get(0).unwrap(), &Value::String("Widget".to_string()));
13308 }
13309 Err(e) => {
13310 panic!("Non-recursive CTE with table data failed: {}", e);
13311 }
13312 }
13313 }
13314
13315 #[test]
13316 fn test_recursive_cte_join_with_table() {
13317 let db = EmbeddedDatabase::new_in_memory().unwrap();
13320
13321 db.execute("CREATE TABLE rc_items (id INT PRIMARY KEY, label TEXT)").unwrap();
13322 db.execute("INSERT INTO rc_items VALUES (1, 'alpha')").unwrap();
13323 db.execute("INSERT INTO rc_items VALUES (2, 'beta')").unwrap();
13324 db.execute("INSERT INTO rc_items VALUES (3, 'gamma')").unwrap();
13325
13326 let sql = "\
13327 WITH RECURSIVE nums(n) AS ( \
13328 SELECT 1 \
13329 UNION ALL \
13330 SELECT n + 1 FROM nums WHERE n < 5 \
13331 ) \
13332 SELECT nums.n, rc_items.label \
13333 FROM nums \
13334 JOIN rc_items ON nums.n = rc_items.id \
13335 ORDER BY nums.n";
13336
13337 match db.query(sql, &[]) {
13338 Ok(rows) => {
13339 assert_eq!(rows.len(), 3, "Expected 3 matched rows, got {}", rows.len());
13342 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
13343 assert_eq!(rows[0].get(1).unwrap(), &Value::String("alpha".to_string()));
13344 assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(2));
13345 assert_eq!(rows[1].get(1).unwrap(), &Value::String("beta".to_string()));
13346 assert_eq!(rows[2].get(0).unwrap(), &Value::Int4(3));
13347 assert_eq!(rows[2].get(1).unwrap(), &Value::String("gamma".to_string()));
13348 }
13349 Err(e) => {
13350 let err_msg = e.to_string();
13352 assert!(
13353 err_msg.contains("not found") ||
13354 err_msg.contains("not implemented") ||
13355 err_msg.contains("ambiguous") ||
13356 err_msg.contains("table") ||
13357 err_msg.contains("column") ||
13358 err_msg.contains("type"),
13359 "Unexpected error in recursive CTE JOIN: {}", err_msg
13360 );
13361 }
13362 }
13363 }
13364
13365 #[test]
13366 fn test_recursive_cte_empty_base_case() {
13367 let db = EmbeddedDatabase::new_in_memory().unwrap();
13370
13371 let sql = "\
13372 WITH RECURSIVE empty(n) AS ( \
13373 SELECT 1 WHERE 1 = 0 \
13374 UNION ALL \
13375 SELECT n + 1 FROM empty WHERE n < 10 \
13376 ) \
13377 SELECT n FROM empty";
13378
13379 match db.query(sql, &[]) {
13380 Ok(rows) => {
13381 assert_eq!(
13382 rows.len(), 0,
13383 "Empty base case should produce 0 rows, got {}", rows.len()
13384 );
13385 }
13386 Err(e) => {
13387 let err_msg = e.to_string();
13388 assert!(
13389 err_msg.contains("not found") ||
13390 err_msg.contains("not implemented") ||
13391 err_msg.contains("column") ||
13392 err_msg.contains("type") ||
13393 err_msg.contains("empty"),
13394 "Unexpected error in empty base case CTE: {}", err_msg
13395 );
13396 }
13397 }
13398 }
13399
13400 #[test]
13401 fn test_recursive_cte_union_vs_union_all() {
13402 let db = EmbeddedDatabase::new_in_memory().unwrap();
13405
13406 let sql_union_all = "\
13408 WITH RECURSIVE cnt(n) AS ( \
13409 SELECT 1 \
13410 UNION ALL \
13411 SELECT n + 1 FROM cnt WHERE n < 5 \
13412 ) \
13413 SELECT n FROM cnt";
13414
13415 let result_all = db.query(sql_union_all, &[]);
13416
13417 let sql_union = "\
13419 WITH RECURSIVE cnt2(n) AS ( \
13420 SELECT 1 \
13421 UNION \
13422 SELECT n + 1 FROM cnt2 WHERE n < 5 \
13423 ) \
13424 SELECT n FROM cnt2";
13425
13426 let result_distinct = db.query(sql_union, &[]);
13427
13428 match (result_all, result_distinct) {
13429 (Ok(rows_all), Ok(rows_distinct)) => {
13430 assert_eq!(
13433 rows_all.len(), 5,
13434 "UNION ALL counting 1..5 should produce 5 rows, got {}", rows_all.len()
13435 );
13436 assert!(
13437 rows_distinct.len() <= rows_all.len(),
13438 "UNION should produce <= rows than UNION ALL ({} vs {})",
13439 rows_distinct.len(), rows_all.len()
13440 );
13441 for (i, row) in rows_all.iter().enumerate() {
13443 let val = row.get(0).unwrap();
13444 let expected = (i as i32) + 1;
13445 assert_eq!(
13446 val, &Value::Int4(expected),
13447 "UNION ALL row {} should be {}, got {:?}", i, expected, val
13448 );
13449 }
13450 }
13451 (Ok(_), Err(e)) => {
13452 let err_msg = e.to_string();
13454 assert!(
13455 err_msg.contains("not implemented") ||
13456 err_msg.contains("not found") ||
13457 err_msg.contains("UNION") ||
13458 err_msg.contains("recursive"),
13459 "Unexpected error in UNION recursive CTE: {}", err_msg
13460 );
13461 }
13462 (Err(e_all), _) => {
13463 panic!(
13464 "UNION ALL recursive CTE failed: {}. Basic recursive CTE should work.",
13465 e_all
13466 );
13467 }
13468 }
13469 }
13470
13471 #[test]
13472 fn test_recursive_cte_with_sum_aggregate() {
13473 let db = EmbeddedDatabase::new_in_memory().unwrap();
13476
13477 let sql = "\
13478 WITH RECURSIVE nums(n) AS ( \
13479 SELECT 1 \
13480 UNION ALL \
13481 SELECT n + 1 FROM nums WHERE n < 10 \
13482 ) \
13483 SELECT SUM(n) FROM nums";
13484
13485 match db.query(sql, &[]) {
13486 Ok(rows) => {
13487 assert_eq!(rows.len(), 1, "Aggregate should return 1 row");
13488 let val = rows[0].get(0).unwrap();
13489 match val {
13491 Value::Int8(v) => assert_eq!(*v, 55, "SUM(1..10) should be 55, got {}", v),
13492 Value::Int4(v) => assert_eq!(*v, 55, "SUM(1..10) should be 55, got {}", v),
13493 Value::Numeric(v) => assert_eq!(v, "55", "SUM(1..10) should be 55, got {}", v),
13494 Value::Float8(v) => {
13495 assert!((*v - 55.0).abs() < 0.001,
13496 "SUM(1..10) should be 55.0, got {}", v);
13497 }
13498 other => panic!("SUM returned unexpected type: {:?}", other),
13499 }
13500 }
13501 Err(e) => {
13502 let err_msg = e.to_string();
13503 assert!(
13504 err_msg.contains("not found") ||
13505 err_msg.contains("not implemented") ||
13506 err_msg.contains("aggregate") ||
13507 err_msg.contains("column"),
13508 "Unexpected error in recursive CTE with aggregate: {}", err_msg
13509 );
13510 }
13511 }
13512 }
13513
13514 #[test]
13515 fn test_recursive_cte_with_limit() {
13516 let db = EmbeddedDatabase::new_in_memory().unwrap();
13519
13520 let sql = "\
13521 WITH RECURSIVE nums(n) AS ( \
13522 SELECT 1 \
13523 UNION ALL \
13524 SELECT n + 1 FROM nums WHERE n < 100 \
13525 ) \
13526 SELECT n FROM nums LIMIT 5";
13527
13528 let rows = db.query(sql, &[]).unwrap();
13529 assert_eq!(rows.len(), 5, "LIMIT 5 should return 5 rows, got {}", rows.len());
13530 for (i, row) in rows.iter().enumerate() {
13531 let val = row.get(0).unwrap();
13532 let expected = (i as i32) + 1;
13533 assert_eq!(
13534 val, &Value::Int4(expected),
13535 "LIMIT row {} should be {}, got {:?}", i, expected, val
13536 );
13537 }
13538 }
13539
13540 #[test]
13541 fn test_recursive_cte_single_row_termination() {
13542 let db = EmbeddedDatabase::new_in_memory().unwrap();
13545
13546 let sql = "\
13547 WITH RECURSIVE one(n) AS ( \
13548 SELECT 100 \
13549 UNION ALL \
13550 SELECT n + 1 FROM one WHERE n < 100 \
13551 ) \
13552 SELECT n FROM one";
13553
13554 match db.query(sql, &[]) {
13555 Ok(rows) => {
13556 assert_eq!(
13557 rows.len(), 1,
13558 "Should produce exactly 1 row (base case only), got {}", rows.len()
13559 );
13560 assert_eq!(
13561 rows[0].get(0).unwrap(), &Value::Int4(100),
13562 "Single row should be 100"
13563 );
13564 }
13565 Err(e) => {
13566 panic!("Recursive CTE single-row termination failed: {}", e);
13567 }
13568 }
13569 }
13570
13571 #[test]
13572 fn test_recursive_cte_non_recursive_multiple_ctes() {
13573 let db = EmbeddedDatabase::new_in_memory().unwrap();
13575
13576 db.execute("CREATE TABLE rc_multi (id INT, category TEXT, amount INT)").unwrap();
13577 db.execute("INSERT INTO rc_multi VALUES (1, 'A', 10)").unwrap();
13578 db.execute("INSERT INTO rc_multi VALUES (2, 'B', 20)").unwrap();
13579 db.execute("INSERT INTO rc_multi VALUES (3, 'A', 30)").unwrap();
13580
13581 let sql = "\
13582 WITH \
13583 cat_a AS (SELECT id, amount FROM rc_multi WHERE category = 'A'), \
13584 cat_b AS (SELECT id, amount FROM rc_multi WHERE category = 'B') \
13585 SELECT cat_a.id, cat_a.amount FROM cat_a ORDER BY cat_a.id";
13586
13587 match db.query(sql, &[]) {
13588 Ok(rows) => {
13589 assert_eq!(rows.len(), 2, "Category A has 2 items, got {}", rows.len());
13590 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
13591 assert_eq!(rows[0].get(1).unwrap(), &Value::Int4(10));
13592 assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(3));
13593 assert_eq!(rows[1].get(1).unwrap(), &Value::Int4(30));
13594 }
13595 Err(e) => {
13596 let err_msg = e.to_string();
13597 assert!(
13598 err_msg.contains("not found") ||
13599 err_msg.contains("not implemented") ||
13600 err_msg.contains("table") ||
13601 err_msg.contains("CTE"),
13602 "Unexpected error in multiple CTEs: {}", err_msg
13603 );
13604 }
13605 }
13606 }
13607
13608 #[test]
13609 fn test_recursive_cte_graph_path() {
13610 let db = EmbeddedDatabase::new_in_memory().unwrap();
13613
13614 db.execute("CREATE TABLE rc_edges (src INT, dst INT)").unwrap();
13615 db.execute("INSERT INTO rc_edges VALUES (1, 2)").unwrap();
13616 db.execute("INSERT INTO rc_edges VALUES (2, 3)").unwrap();
13617 db.execute("INSERT INTO rc_edges VALUES (3, 4)").unwrap();
13618 db.execute("INSERT INTO rc_edges VALUES (1, 5)").unwrap();
13619
13620 let sql = "\
13621 WITH RECURSIVE reachable(node) AS ( \
13622 SELECT 1 \
13623 UNION ALL \
13624 SELECT e.dst FROM rc_edges e JOIN reachable r ON e.src = r.node \
13625 ) \
13626 SELECT node FROM reachable ORDER BY node";
13627
13628 match db.query(sql, &[]) {
13629 Ok(rows) => {
13630 let nodes: Vec<i64> = rows.iter().map(|r| {
13633 match r.get(0).unwrap() {
13634 Value::Int8(v) => *v,
13635 Value::Int4(v) => i64::from(*v),
13636 other => panic!("Unexpected node type: {:?}", other),
13637 }
13638 }).collect();
13639
13640 assert!(nodes.contains(&1), "Should contain starting node 1");
13641 assert!(nodes.contains(&2), "Should contain node 2");
13642 assert!(nodes.contains(&3), "Should contain node 3");
13643 assert!(nodes.contains(&4), "Should contain node 4");
13644 assert!(nodes.contains(&5), "Should contain node 5");
13645 assert_eq!(nodes.len(), 5,
13646 "Should have exactly 5 distinct reachable nodes, got {:?}", nodes);
13647 }
13648 Err(e) => {
13649 let err_msg = e.to_string();
13650 assert!(
13651 err_msg.contains("not found") ||
13652 err_msg.contains("not implemented") ||
13653 err_msg.contains("ambiguous") ||
13654 err_msg.contains("column") ||
13655 err_msg.contains("table"),
13656 "Unexpected error in recursive CTE graph traversal: {}", err_msg
13657 );
13658 }
13659 }
13660 }
13661
13662 #[test]
13663 fn test_recursive_cte_string_concatenation() {
13664 let db = EmbeddedDatabase::new_in_memory().unwrap();
13666
13667 let sql = "\
13668 WITH RECURSIVE strs(s, len) AS ( \
13669 SELECT 'a', 1 \
13670 UNION ALL \
13671 SELECT s || 'a', len + 1 FROM strs WHERE len < 5 \
13672 ) \
13673 SELECT s, len FROM strs ORDER BY len";
13674
13675 let rows = db.query(sql, &[]).unwrap();
13676 assert_eq!(rows.len(), 5, "Expected 5 rows, got {}", rows.len());
13677 let expected = ["a", "aa", "aaa", "aaaa", "aaaaa"];
13678 for (i, row) in rows.iter().enumerate() {
13679 let s = row.get(0).unwrap();
13680 assert_eq!(
13681 s, &Value::String(expected[i].to_string()),
13682 "Row {} should be '{}', got {:?}", i, expected[i], s
13683 );
13684 }
13685 }
13686
13687 #[test]
13688 fn test_recursive_cte_powers_of_two() {
13689 let db = EmbeddedDatabase::new_in_memory().unwrap();
13691
13692 let sql = "\
13693 WITH RECURSIVE powers(n) AS ( \
13694 SELECT 1 \
13695 UNION ALL \
13696 SELECT n * 2 FROM powers WHERE n < 512 \
13697 ) \
13698 SELECT n FROM powers";
13699
13700 match db.query(sql, &[]) {
13701 Ok(rows) => {
13702 let expected: Vec<i32> = vec![1, 2, 4, 8, 16, 32, 64, 128, 256, 512];
13703 assert_eq!(
13704 rows.len(), expected.len(),
13705 "Expected {} powers of 2, got {}", expected.len(), rows.len()
13706 );
13707 for (i, (row, exp)) in rows.iter().zip(expected.iter()).enumerate() {
13708 let val = row.get(0).unwrap();
13709 assert_eq!(
13710 val, &Value::Int4(*exp),
13711 "Power of 2 row {} should be {}, got {:?}", i, exp, val
13712 );
13713 }
13714 }
13715 Err(e) => {
13716 panic!("Recursive CTE powers of two failed: {}", e);
13717 }
13718 }
13719 }
13720
13721 #[test]
13722 fn test_recursive_cte_with_count_aggregate() {
13723 let db = EmbeddedDatabase::new_in_memory().unwrap();
13726
13727 let sql = "\
13728 WITH RECURSIVE nums(n) AS ( \
13729 SELECT 1 \
13730 UNION ALL \
13731 SELECT n + 1 FROM nums WHERE n < 20 \
13732 ) \
13733 SELECT COUNT(*) FROM nums";
13734
13735 let rows = db.query(sql, &[]).unwrap();
13736 assert_eq!(rows.len(), 1, "COUNT should return 1 row");
13737 let val = rows[0].get(0).unwrap();
13738 match val {
13739 Value::Int8(v) => assert_eq!(*v, 20, "COUNT(*) should be 20, got {}", v),
13740 Value::Int4(v) => assert_eq!(*v, 20, "COUNT(*) should be 20, got {}", v),
13741 other => panic!("COUNT returned unexpected type: {:?}", other),
13742 }
13743 }
13744
13745 #[test]
13746 fn test_recursive_cte_non_recursive_column_alias() {
13747 let db = EmbeddedDatabase::new_in_memory().unwrap();
13750
13751 let sql = "WITH t(x, y) AS (SELECT 10, 20) SELECT x, y FROM t";
13752
13753 match db.query(sql, &[]) {
13754 Ok(rows) => {
13755 assert_eq!(rows.len(), 1, "Should return 1 row");
13756 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(10));
13757 assert_eq!(rows[0].get(1).unwrap(), &Value::Int4(20));
13758 }
13759 Err(e) => {
13760 let err_msg = e.to_string();
13761 assert!(
13762 err_msg.contains("not found") ||
13763 err_msg.contains("not implemented") ||
13764 err_msg.contains("column") ||
13765 err_msg.contains("alias"),
13766 "Unexpected error in CTE column alias: {}", err_msg
13767 );
13768 }
13769 }
13770 }
13771
13772 #[test]
13773 fn test_recursive_cte_descending_countdown() {
13774 let db = EmbeddedDatabase::new_in_memory().unwrap();
13776
13777 let sql = "\
13778 WITH RECURSIVE countdown(n) AS ( \
13779 SELECT 10 \
13780 UNION ALL \
13781 SELECT n - 1 FROM countdown WHERE n > 1 \
13782 ) \
13783 SELECT n FROM countdown";
13784
13785 match db.query(sql, &[]) {
13786 Ok(rows) => {
13787 assert_eq!(rows.len(), 10, "Expected 10 rows for 10..1, got {}", rows.len());
13788 for (i, row) in rows.iter().enumerate() {
13789 let val = row.get(0).unwrap();
13790 let expected = 10 - i as i32;
13791 assert_eq!(
13792 val, &Value::Int4(expected),
13793 "Countdown row {} should be {}, got {:?}", i, expected, val
13794 );
13795 }
13796 }
13797 Err(e) => {
13798 panic!("Recursive CTE countdown failed: {}", e);
13799 }
13800 }
13801 }
13802
13803 #[test]
13812 fn test_set_op_union_all_basic() {
13813 let db = EmbeddedDatabase::new_in_memory().unwrap();
13815
13816 let rows = db.query(
13817 "SELECT 1 AS id, 'alice' AS name \
13818 UNION ALL \
13819 SELECT 2, 'bob'",
13820 &[],
13821 ).unwrap();
13822
13823 assert_eq!(rows.len(), 2, "UNION ALL of two single-row SELECTs should produce 2 rows");
13824 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
13825 assert_eq!(rows[0].get(1).unwrap(), &Value::String("alice".to_string()));
13826 assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(2));
13827 assert_eq!(rows[1].get(1).unwrap(), &Value::String("bob".to_string()));
13828 }
13829
13830 #[test]
13831 fn test_set_op_union_all_preserves_duplicates() {
13832 let db = EmbeddedDatabase::new_in_memory().unwrap();
13834
13835 let rows = db.query(
13836 "SELECT 1 AS v UNION ALL SELECT 1 UNION ALL SELECT 1",
13837 &[],
13838 ).unwrap();
13839
13840 assert_eq!(rows.len(), 3, "UNION ALL of three identical rows should produce 3 rows");
13841 for row in &rows {
13842 assert_eq!(row.get(0).unwrap(), &Value::Int4(1));
13843 }
13844 }
13845
13846 #[test]
13847 fn test_set_op_union_distinct_removes_duplicates() {
13848 let db = EmbeddedDatabase::new_in_memory().unwrap();
13850
13851 let rows = db.query(
13852 "SELECT 1 AS v UNION SELECT 1 UNION SELECT 2",
13853 &[],
13854 ).unwrap();
13855
13856 assert_eq!(rows.len(), 2, "UNION of (1, 1, 2) should produce 2 distinct rows");
13857 let mut vals: Vec<i32> = rows.iter()
13858 .map(|r| match r.get(0).unwrap() {
13859 Value::Int4(n) => *n,
13860 other => panic!("Expected Int4, got {:?}", other),
13861 })
13862 .collect();
13863 vals.sort();
13864 assert_eq!(vals, vec![1, 2]);
13865 }
13866
13867 #[test]
13868 fn test_set_op_union_vs_union_all_difference() {
13869 let db = EmbeddedDatabase::new_in_memory().unwrap();
13871
13872 let rows_all = db.query(
13874 "SELECT 1 AS v UNION ALL SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 2",
13875 &[],
13876 ).unwrap();
13877
13878 let rows_distinct = db.query(
13880 "SELECT 1 AS v UNION SELECT 1 UNION SELECT 2 UNION SELECT 2",
13881 &[],
13882 ).unwrap();
13883
13884 assert_eq!(rows_all.len(), 4, "UNION ALL should produce 4 rows");
13885 assert_eq!(rows_distinct.len(), 2, "UNION (distinct) should produce 2 rows");
13886 }
13887
13888 #[test]
13889 fn test_set_op_intersect_basic() {
13890 let db = EmbeddedDatabase::new_in_memory().unwrap();
13892
13893 let rows = db.query(
13894 "SELECT 1 AS v INTERSECT SELECT 1",
13895 &[],
13896 ).unwrap();
13897
13898 assert_eq!(rows.len(), 1, "INTERSECT of (1) and (1) should produce 1 row");
13899 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
13900 }
13901
13902 #[test]
13903 fn test_set_op_intersect_no_overlap() {
13904 let db = EmbeddedDatabase::new_in_memory().unwrap();
13906
13907 let rows = db.query(
13908 "SELECT 1 AS v INTERSECT SELECT 2",
13909 &[],
13910 ).unwrap();
13911
13912 assert_eq!(rows.len(), 0, "INTERSECT of (1) and (2) should produce 0 rows");
13913 }
13914
13915 #[test]
13916 fn test_set_op_intersect_with_multiple_values() {
13917 let db = EmbeddedDatabase::new_in_memory().unwrap();
13919
13920 let sql = "\
13922 SELECT * FROM (SELECT 1 AS v UNION ALL SELECT 2 UNION ALL SELECT 3) AS a \
13923 INTERSECT \
13924 SELECT * FROM (SELECT 2 AS v UNION ALL SELECT 3 UNION ALL SELECT 4) AS b";
13925
13926 match db.query(sql, &[]) {
13927 Ok(rows) => {
13928 assert_eq!(rows.len(), 2, "INTERSECT of (1,2,3) and (2,3,4) should produce 2 rows");
13929 let mut vals: Vec<i32> = rows.iter()
13930 .map(|r| match r.get(0).unwrap() {
13931 Value::Int4(n) => *n,
13932 other => panic!("Expected Int4, got {:?}", other),
13933 })
13934 .collect();
13935 vals.sort();
13936 assert_eq!(vals, vec![2, 3]);
13937 }
13938 Err(e) => {
13939 println!("Subquery-based INTERSECT not supported: {}", e);
13941 db.execute("CREATE TABLE int_left (v INT)").unwrap();
13943 db.execute("INSERT INTO int_left VALUES (1), (2), (3)").unwrap();
13944 db.execute("CREATE TABLE int_right (v INT)").unwrap();
13945 db.execute("INSERT INTO int_right VALUES (2), (3), (4)").unwrap();
13946
13947 let rows = db.query(
13948 "SELECT v FROM int_left INTERSECT SELECT v FROM int_right",
13949 &[],
13950 ).unwrap();
13951
13952 assert_eq!(rows.len(), 2, "INTERSECT of (1,2,3) and (2,3,4) should produce 2 rows");
13953 let mut vals: Vec<i32> = rows.iter()
13954 .map(|r| match r.get(0).unwrap() {
13955 Value::Int4(n) => *n,
13956 other => panic!("Expected Int4, got {:?}", other),
13957 })
13958 .collect();
13959 vals.sort();
13960 assert_eq!(vals, vec![2, 3]);
13961 }
13962 }
13963 }
13964
13965 #[test]
13966 fn test_set_op_except_basic() {
13967 let db = EmbeddedDatabase::new_in_memory().unwrap();
13969
13970 db.execute("CREATE TABLE exc_left (v INT)").unwrap();
13972 db.execute("INSERT INTO exc_left VALUES (1), (2), (3)").unwrap();
13973 db.execute("CREATE TABLE exc_right (v INT)").unwrap();
13974 db.execute("INSERT INTO exc_right VALUES (2), (3)").unwrap();
13975
13976 let rows = db.query(
13977 "SELECT v FROM exc_left EXCEPT SELECT v FROM exc_right",
13978 &[],
13979 ).unwrap();
13980
13981 assert_eq!(rows.len(), 1, "EXCEPT of (1,2,3) minus (2,3) should produce 1 row");
13982 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
13983 }
13984
13985 #[test]
13986 fn test_set_op_except_all_rows_removed() {
13987 let db = EmbeddedDatabase::new_in_memory().unwrap();
13989
13990 let rows = db.query(
13991 "SELECT 1 AS v EXCEPT SELECT 1",
13992 &[],
13993 ).unwrap();
13994
13995 assert_eq!(rows.len(), 0, "EXCEPT of (1) minus (1) should produce 0 rows");
13996 }
13997
13998 #[test]
13999 fn test_set_op_except_no_overlap() {
14000 let db = EmbeddedDatabase::new_in_memory().unwrap();
14002
14003 let rows = db.query(
14004 "SELECT 1 AS v EXCEPT SELECT 2",
14005 &[],
14006 ).unwrap();
14007
14008 assert_eq!(rows.len(), 1, "EXCEPT of (1) minus (2) should produce 1 row");
14009 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
14010 }
14011
14012 #[test]
14013 fn test_set_op_except_all_with_duplicates() {
14014 let db = EmbeddedDatabase::new_in_memory().unwrap();
14017
14018 db.execute("CREATE TABLE ea_left (v INT)").unwrap();
14019 db.execute("INSERT INTO ea_left VALUES (1), (1), (1), (2)").unwrap();
14020 db.execute("CREATE TABLE ea_right (v INT)").unwrap();
14021 db.execute("INSERT INTO ea_right VALUES (1)").unwrap();
14022
14023 match db.query(
14024 "SELECT v FROM ea_left EXCEPT ALL SELECT v FROM ea_right",
14025 &[],
14026 ) {
14027 Ok(rows) => {
14028 assert_eq!(rows.len(), 3, "EXCEPT ALL should produce 3 rows (two 1s and one 2)");
14030 let mut vals: Vec<i32> = rows.iter()
14031 .map(|r| match r.get(0).unwrap() {
14032 Value::Int4(n) => *n,
14033 other => panic!("Expected Int4, got {:?}", other),
14034 })
14035 .collect();
14036 vals.sort();
14037 assert_eq!(vals, vec![1, 1, 2], "Should have two 1s and one 2");
14038 }
14039 Err(e) => {
14040 println!("EXCEPT ALL not supported: {}", e);
14041 }
14042 }
14043 }
14044
14045 #[test]
14046 fn test_set_op_intersect_all_with_duplicates() {
14047 let db = EmbeddedDatabase::new_in_memory().unwrap();
14050
14051 db.execute("CREATE TABLE ia_left (v INT)").unwrap();
14052 db.execute("INSERT INTO ia_left VALUES (1), (1), (1), (2)").unwrap();
14053 db.execute("CREATE TABLE ia_right (v INT)").unwrap();
14054 db.execute("INSERT INTO ia_right VALUES (1), (1)").unwrap();
14055
14056 match db.query(
14057 "SELECT v FROM ia_left INTERSECT ALL SELECT v FROM ia_right",
14058 &[],
14059 ) {
14060 Ok(rows) => {
14061 assert_eq!(rows.len(), 2, "INTERSECT ALL should produce 2 rows (min of 3,2 = 2)");
14062 for row in &rows {
14063 assert_eq!(row.get(0).unwrap(), &Value::Int4(1),
14064 "All INTERSECT ALL results should be 1");
14065 }
14066 }
14067 Err(e) => {
14068 println!("INTERSECT ALL not supported: {}", e);
14069 }
14070 }
14071 }
14072
14073 #[test]
14074 fn test_set_op_multiple_unions_chained() {
14075 let db = EmbeddedDatabase::new_in_memory().unwrap();
14077
14078 let rows = db.query(
14079 "SELECT 1 AS v UNION SELECT 2 UNION SELECT 3",
14080 &[],
14081 ).unwrap();
14082
14083 assert_eq!(rows.len(), 3, "Three-way UNION of distinct values should produce 3 rows");
14084 let mut vals: Vec<i32> = rows.iter()
14085 .map(|r| match r.get(0).unwrap() {
14086 Value::Int4(n) => *n,
14087 other => panic!("Expected Int4, got {:?}", other),
14088 })
14089 .collect();
14090 vals.sort();
14091 assert_eq!(vals, vec![1, 2, 3]);
14092 }
14093
14094 #[test]
14095 fn test_set_op_multiple_union_all_chained() {
14096 let db = EmbeddedDatabase::new_in_memory().unwrap();
14098
14099 let rows = db.query(
14100 "SELECT 10 AS v UNION ALL SELECT 20 UNION ALL SELECT 30 UNION ALL SELECT 10",
14101 &[],
14102 ).unwrap();
14103
14104 assert_eq!(rows.len(), 4, "Four-way UNION ALL should produce 4 rows");
14105 let vals: Vec<i32> = rows.iter()
14106 .map(|r| match r.get(0).unwrap() {
14107 Value::Int4(n) => *n,
14108 other => panic!("Expected Int4, got {:?}", other),
14109 })
14110 .collect();
14111 assert_eq!(vals, vec![10, 20, 30, 10]);
14112 }
14113
14114 #[test]
14115 fn test_set_op_union_uses_first_select_column_names() {
14116 let db = EmbeddedDatabase::new_in_memory().unwrap();
14118
14119 let rows = db.query(
14122 "SELECT 1 AS first_col UNION ALL SELECT 2 AS second_col",
14123 &[],
14124 ).unwrap();
14125
14126 assert_eq!(rows.len(), 2);
14127 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
14129 assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(2));
14130 }
14131
14132 #[test]
14133 fn test_set_op_union_with_order_by() {
14134 let db = EmbeddedDatabase::new_in_memory().unwrap();
14136
14137 let rows = db.query(
14138 "SELECT 3 AS v UNION ALL SELECT 1 UNION ALL SELECT 2 ORDER BY v",
14139 &[],
14140 ).unwrap();
14141
14142 assert_eq!(rows.len(), 3, "UNION ALL with ORDER BY should produce 3 rows");
14143 let vals: Vec<i32> = rows.iter()
14144 .map(|r| match r.get(0).unwrap() {
14145 Value::Int4(n) => *n,
14146 other => panic!("Expected Int4, got {:?}", other),
14147 })
14148 .collect();
14149 assert_eq!(vals, vec![1, 2, 3], "ORDER BY v should sort ascending");
14150 }
14151
14152 #[test]
14153 fn test_set_op_union_with_order_by_desc() {
14154 let db = EmbeddedDatabase::new_in_memory().unwrap();
14156
14157 let rows = db.query(
14158 "SELECT 3 AS v UNION ALL SELECT 1 UNION ALL SELECT 2 ORDER BY v DESC",
14159 &[],
14160 ).unwrap();
14161
14162 assert_eq!(rows.len(), 3);
14163 let vals: Vec<i32> = rows.iter()
14164 .map(|r| match r.get(0).unwrap() {
14165 Value::Int4(n) => *n,
14166 other => panic!("Expected Int4, got {:?}", other),
14167 })
14168 .collect();
14169 assert_eq!(vals, vec![3, 2, 1], "ORDER BY v DESC should sort descending");
14170 }
14171
14172 #[test]
14173 fn test_set_op_union_with_limit() {
14174 let db = EmbeddedDatabase::new_in_memory().unwrap();
14176
14177 let rows = db.query(
14178 "SELECT 1 AS v UNION ALL SELECT 2 UNION ALL SELECT 3 LIMIT 2",
14179 &[],
14180 ).unwrap();
14181
14182 assert_eq!(rows.len(), 2, "UNION ALL with LIMIT 2 should produce 2 rows");
14183 }
14184
14185 #[test]
14186 fn test_set_op_union_with_order_by_and_limit() {
14187 let db = EmbeddedDatabase::new_in_memory().unwrap();
14189
14190 let rows = db.query(
14191 "SELECT 5 AS v UNION ALL SELECT 3 UNION ALL SELECT 1 \
14192 UNION ALL SELECT 4 UNION ALL SELECT 2 \
14193 ORDER BY v LIMIT 3",
14194 &[],
14195 ).unwrap();
14196
14197 assert_eq!(rows.len(), 3, "UNION ALL with ORDER BY + LIMIT 3 should produce 3 rows");
14198 let vals: Vec<i32> = rows.iter()
14199 .map(|r| match r.get(0).unwrap() {
14200 Value::Int4(n) => *n,
14201 other => panic!("Expected Int4, got {:?}", other),
14202 })
14203 .collect();
14204 assert_eq!(vals, vec![1, 2, 3], "Should return smallest 3 values sorted");
14205 }
14206
14207 #[test]
14208 fn test_set_op_intersect_empty_result() {
14209 let db = EmbeddedDatabase::new_in_memory().unwrap();
14211
14212 db.execute("CREATE TABLE ie_left (v INT)").unwrap();
14213 db.execute("INSERT INTO ie_left VALUES (1), (2)").unwrap();
14214 db.execute("CREATE TABLE ie_right (v INT)").unwrap();
14215 db.execute("INSERT INTO ie_right VALUES (3), (4)").unwrap();
14216
14217 let rows = db.query(
14218 "SELECT v FROM ie_left INTERSECT SELECT v FROM ie_right",
14219 &[],
14220 ).unwrap();
14221
14222 assert_eq!(rows.len(), 0, "INTERSECT with no common rows should produce 0 rows");
14223 }
14224
14225 #[test]
14226 fn test_set_op_union_with_null_values() {
14227 let db = EmbeddedDatabase::new_in_memory().unwrap();
14229
14230 let rows_all = db.query(
14232 "SELECT NULL AS v UNION ALL SELECT NULL",
14233 &[],
14234 ).unwrap();
14235 assert_eq!(rows_all.len(), 2, "UNION ALL of (NULL, NULL) should produce 2 rows");
14236 assert_eq!(rows_all[0].get(0).unwrap(), &Value::Null);
14237 assert_eq!(rows_all[1].get(0).unwrap(), &Value::Null);
14238
14239 let rows_distinct = db.query(
14241 "SELECT NULL AS v UNION SELECT NULL",
14242 &[],
14243 ).unwrap();
14244 assert_eq!(rows_distinct.len(), 1,
14245 "UNION of (NULL, NULL) should dedup to 1 row (SQL standard: NULL = NULL for UNION)");
14246 assert_eq!(rows_distinct[0].get(0).unwrap(), &Value::Null);
14247 }
14248
14249 #[test]
14250 fn test_set_op_intersect_with_null_values() {
14251 let db = EmbeddedDatabase::new_in_memory().unwrap();
14253
14254 let rows = db.query(
14255 "SELECT NULL AS v INTERSECT SELECT NULL",
14256 &[],
14257 ).unwrap();
14258
14259 assert_eq!(rows.len(), 1,
14260 "INTERSECT of (NULL) and (NULL) should produce 1 row");
14261 assert_eq!(rows[0].get(0).unwrap(), &Value::Null);
14262 }
14263
14264 #[test]
14265 fn test_set_op_except_with_null_values() {
14266 let db = EmbeddedDatabase::new_in_memory().unwrap();
14268
14269 let rows = db.query(
14270 "SELECT NULL AS v EXCEPT SELECT NULL",
14271 &[],
14272 ).unwrap();
14273
14274 assert_eq!(rows.len(), 0,
14275 "EXCEPT of (NULL) minus (NULL) should produce 0 rows");
14276 }
14277
14278 #[test]
14279 fn test_set_op_union_with_table_data() {
14280 let db = EmbeddedDatabase::new_in_memory().unwrap();
14282
14283 db.execute("CREATE TABLE employees (id INT, name TEXT, dept TEXT)").unwrap();
14284 db.execute("INSERT INTO employees VALUES (1, 'Alice', 'Eng')").unwrap();
14285 db.execute("INSERT INTO employees VALUES (2, 'Bob', 'Eng')").unwrap();
14286
14287 db.execute("CREATE TABLE contractors (id INT, name TEXT, dept TEXT)").unwrap();
14288 db.execute("INSERT INTO contractors VALUES (3, 'Charlie', 'Eng')").unwrap();
14289 db.execute("INSERT INTO contractors VALUES (4, 'Diana', 'Sales')").unwrap();
14290
14291 let rows = db.query(
14292 "SELECT id, name FROM employees UNION ALL SELECT id, name FROM contractors",
14293 &[],
14294 ).unwrap();
14295
14296 assert_eq!(rows.len(), 4, "UNION ALL of 2+2 rows should produce 4 rows");
14297
14298 let names: Vec<String> = rows.iter()
14299 .map(|r| match r.get(1).unwrap() {
14300 Value::String(s) => s.clone(),
14301 other => panic!("Expected String, got {:?}", other),
14302 })
14303 .collect();
14304 assert!(names.contains(&"Alice".to_string()));
14305 assert!(names.contains(&"Bob".to_string()));
14306 assert!(names.contains(&"Charlie".to_string()));
14307 assert!(names.contains(&"Diana".to_string()));
14308 }
14309
14310 #[test]
14311 fn test_set_op_union_distinct_with_table_data() {
14312 let db = EmbeddedDatabase::new_in_memory().unwrap();
14314
14315 db.execute("CREATE TABLE colors_a (name TEXT)").unwrap();
14316 db.execute("INSERT INTO colors_a VALUES ('red'), ('green'), ('blue')").unwrap();
14317
14318 db.execute("CREATE TABLE colors_b (name TEXT)").unwrap();
14319 db.execute("INSERT INTO colors_b VALUES ('blue'), ('green'), ('yellow')").unwrap();
14320
14321 let rows = db.query(
14322 "SELECT name FROM colors_a UNION SELECT name FROM colors_b",
14323 &[],
14324 ).unwrap();
14325
14326 assert_eq!(rows.len(), 4,
14328 "UNION of (red,green,blue) and (blue,green,yellow) should produce 4 unique rows");
14329
14330 let mut names: Vec<String> = rows.iter()
14331 .map(|r| match r.get(0).unwrap() {
14332 Value::String(s) => s.clone(),
14333 other => panic!("Expected String, got {:?}", other),
14334 })
14335 .collect();
14336 names.sort();
14337 assert_eq!(names, vec!["blue", "green", "red", "yellow"]);
14338 }
14339
14340 #[test]
14341 fn test_set_op_intersect_with_table_data() {
14342 let db = EmbeddedDatabase::new_in_memory().unwrap();
14344
14345 db.execute("CREATE TABLE skills_a (skill TEXT)").unwrap();
14346 db.execute("INSERT INTO skills_a VALUES ('rust'), ('python'), ('go')").unwrap();
14347
14348 db.execute("CREATE TABLE skills_b (skill TEXT)").unwrap();
14349 db.execute("INSERT INTO skills_b VALUES ('python'), ('go'), ('java')").unwrap();
14350
14351 let rows = db.query(
14352 "SELECT skill FROM skills_a INTERSECT SELECT skill FROM skills_b",
14353 &[],
14354 ).unwrap();
14355
14356 assert_eq!(rows.len(), 2,
14357 "INTERSECT of (rust,python,go) and (python,go,java) should produce 2 rows");
14358
14359 let mut names: Vec<String> = rows.iter()
14360 .map(|r| match r.get(0).unwrap() {
14361 Value::String(s) => s.clone(),
14362 other => panic!("Expected String, got {:?}", other),
14363 })
14364 .collect();
14365 names.sort();
14366 assert_eq!(names, vec!["go", "python"]);
14367 }
14368
14369 #[test]
14370 fn test_set_op_except_with_table_data() {
14371 let db = EmbeddedDatabase::new_in_memory().unwrap();
14373
14374 db.execute("CREATE TABLE all_items (item TEXT)").unwrap();
14375 db.execute("INSERT INTO all_items VALUES ('a'), ('b'), ('c'), ('d')").unwrap();
14376
14377 db.execute("CREATE TABLE sold_items (item TEXT)").unwrap();
14378 db.execute("INSERT INTO sold_items VALUES ('b'), ('d')").unwrap();
14379
14380 let rows = db.query(
14381 "SELECT item FROM all_items EXCEPT SELECT item FROM sold_items",
14382 &[],
14383 ).unwrap();
14384
14385 assert_eq!(rows.len(), 2,
14386 "EXCEPT of (a,b,c,d) minus (b,d) should produce 2 rows");
14387
14388 let mut names: Vec<String> = rows.iter()
14389 .map(|r| match r.get(0).unwrap() {
14390 Value::String(s) => s.clone(),
14391 other => panic!("Expected String, got {:?}", other),
14392 })
14393 .collect();
14394 names.sort();
14395 assert_eq!(names, vec!["a", "c"]);
14396 }
14397
14398 #[test]
14399 fn test_set_op_union_multi_column() {
14400 let db = EmbeddedDatabase::new_in_memory().unwrap();
14402
14403 let rows = db.query(
14404 "SELECT 1 AS a, 'x' AS b UNION ALL SELECT 2, 'y'",
14405 &[],
14406 ).unwrap();
14407
14408 assert_eq!(rows.len(), 2);
14409 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
14410 assert_eq!(rows[0].get(1).unwrap(), &Value::String("x".to_string()));
14411 assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(2));
14412 assert_eq!(rows[1].get(1).unwrap(), &Value::String("y".to_string()));
14413 }
14414
14415 #[test]
14416 fn test_set_op_union_distinct_multi_column() {
14417 let db = EmbeddedDatabase::new_in_memory().unwrap();
14419
14420 let rows = db.query(
14422 "SELECT 1 AS a, 'a' AS b \
14423 UNION SELECT 1, 'a' \
14424 UNION SELECT 1, 'b'",
14425 &[],
14426 ).unwrap();
14427
14428 assert_eq!(rows.len(), 2,
14429 "UNION should dedup (1,'a') but keep (1,'b') as distinct");
14430 }
14431
14432 #[test]
14433 fn test_set_op_union_empty_left() {
14434 let db = EmbeddedDatabase::new_in_memory().unwrap();
14436
14437 db.execute("CREATE TABLE empty_tbl (v INT)").unwrap();
14438 db.execute("CREATE TABLE full_tbl (v INT)").unwrap();
14439 db.execute("INSERT INTO full_tbl VALUES (1), (2)").unwrap();
14440
14441 let rows = db.query(
14442 "SELECT v FROM empty_tbl UNION ALL SELECT v FROM full_tbl",
14443 &[],
14444 ).unwrap();
14445
14446 assert_eq!(rows.len(), 2, "UNION ALL of empty + 2 rows should produce 2 rows");
14447 }
14448
14449 #[test]
14450 fn test_set_op_union_empty_right() {
14451 let db = EmbeddedDatabase::new_in_memory().unwrap();
14453
14454 db.execute("CREATE TABLE full_tbl2 (v INT)").unwrap();
14455 db.execute("INSERT INTO full_tbl2 VALUES (1), (2)").unwrap();
14456 db.execute("CREATE TABLE empty_tbl2 (v INT)").unwrap();
14457
14458 let rows = db.query(
14459 "SELECT v FROM full_tbl2 UNION ALL SELECT v FROM empty_tbl2",
14460 &[],
14461 ).unwrap();
14462
14463 assert_eq!(rows.len(), 2, "UNION ALL of 2 rows + empty should produce 2 rows");
14464 }
14465
14466 #[test]
14467 fn test_set_op_union_both_empty() {
14468 let db = EmbeddedDatabase::new_in_memory().unwrap();
14470
14471 db.execute("CREATE TABLE empty_a (v INT)").unwrap();
14472 db.execute("CREATE TABLE empty_b (v INT)").unwrap();
14473
14474 let rows = db.query(
14475 "SELECT v FROM empty_a UNION ALL SELECT v FROM empty_b",
14476 &[],
14477 ).unwrap();
14478
14479 assert_eq!(rows.len(), 0, "UNION ALL of two empty tables should produce 0 rows");
14480 }
14481
14482 #[test]
14483 fn test_set_op_except_empty_right_preserves_left() {
14484 let db = EmbeddedDatabase::new_in_memory().unwrap();
14486
14487 db.execute("CREATE TABLE exc_full (v INT)").unwrap();
14488 db.execute("INSERT INTO exc_full VALUES (10), (20), (30)").unwrap();
14489 db.execute("CREATE TABLE exc_empty (v INT)").unwrap();
14490
14491 let rows = db.query(
14492 "SELECT v FROM exc_full EXCEPT SELECT v FROM exc_empty",
14493 &[],
14494 ).unwrap();
14495
14496 assert_eq!(rows.len(), 3,
14497 "EXCEPT with empty right should return all 3 left rows");
14498 }
14499
14500 #[test]
14501 fn test_set_op_intersect_empty_right_returns_empty() {
14502 let db = EmbeddedDatabase::new_in_memory().unwrap();
14504
14505 db.execute("CREATE TABLE isec_full (v INT)").unwrap();
14506 db.execute("INSERT INTO isec_full VALUES (10), (20)").unwrap();
14507 db.execute("CREATE TABLE isec_empty (v INT)").unwrap();
14508
14509 let rows = db.query(
14510 "SELECT v FROM isec_full INTERSECT SELECT v FROM isec_empty",
14511 &[],
14512 ).unwrap();
14513
14514 assert_eq!(rows.len(), 0,
14515 "INTERSECT with empty right should produce 0 rows");
14516 }
14517
14518 #[test]
14519 fn test_set_op_union_with_where_clause() {
14520 let db = EmbeddedDatabase::new_in_memory().unwrap();
14522
14523 db.execute("CREATE TABLE nums (v INT)").unwrap();
14524 db.execute("INSERT INTO nums VALUES (1), (2), (3), (4), (5)").unwrap();
14525
14526 let rows = db.query(
14527 "SELECT v FROM nums WHERE v <= 2 UNION ALL SELECT v FROM nums WHERE v >= 4",
14528 &[],
14529 ).unwrap();
14530
14531 assert_eq!(rows.len(), 4, "UNION ALL of (1,2) and (4,5) should produce 4 rows");
14532 let mut vals: Vec<i32> = rows.iter()
14533 .map(|r| match r.get(0).unwrap() {
14534 Value::Int4(n) => *n,
14535 other => panic!("Expected Int4, got {:?}", other),
14536 })
14537 .collect();
14538 vals.sort();
14539 assert_eq!(vals, vec![1, 2, 4, 5]);
14540 }
14541
14542 #[test]
14543 fn test_set_op_union_null_mixed_with_values() {
14544 let db = EmbeddedDatabase::new_in_memory().unwrap();
14546
14547 let rows = db.query(
14548 "SELECT 1 AS v UNION SELECT NULL UNION SELECT 2 UNION SELECT NULL",
14549 &[],
14550 ).unwrap();
14551
14552 assert_eq!(rows.len(), 3,
14554 "UNION of (1, NULL, 2, NULL) should produce 3 rows (dedup NULLs)");
14555 }
14556
14557 #[test]
14558 fn test_set_op_union_all_large_dataset() {
14559 let db = EmbeddedDatabase::new_in_memory().unwrap();
14561
14562 db.execute("CREATE TABLE big_a (id INT, val TEXT)").unwrap();
14563 db.execute("CREATE TABLE big_b (id INT, val TEXT)").unwrap();
14564
14565 for i in 0..50 {
14566 db.execute(&format!("INSERT INTO big_a VALUES ({}, 'a{}')", i, i)).unwrap();
14567 }
14568 for i in 25..75 {
14569 db.execute(&format!("INSERT INTO big_b VALUES ({}, 'b{}')", i, i)).unwrap();
14570 }
14571
14572 let rows = db.query(
14573 "SELECT id, val FROM big_a UNION ALL SELECT id, val FROM big_b",
14574 &[],
14575 ).unwrap();
14576
14577 assert_eq!(rows.len(), 100,
14579 "UNION ALL of 50+50 rows should produce 100 rows");
14580 }
14581
14582 #[test]
14583 fn test_set_op_intersect_single_common_row() {
14584 let db = EmbeddedDatabase::new_in_memory().unwrap();
14586
14587 let rows = db.query(
14588 "SELECT 1 AS v INTERSECT SELECT 1",
14589 &[],
14590 ).unwrap();
14591
14592 assert_eq!(rows.len(), 1);
14593 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
14594 }
14595
14596 #[test]
14597 fn test_set_op_except_is_not_symmetric() {
14598 let db = EmbeddedDatabase::new_in_memory().unwrap();
14600
14601 db.execute("CREATE TABLE set_a (v INT)").unwrap();
14602 db.execute("INSERT INTO set_a VALUES (1), (2), (3)").unwrap();
14603 db.execute("CREATE TABLE set_b (v INT)").unwrap();
14604 db.execute("INSERT INTO set_b VALUES (2), (3), (4)").unwrap();
14605
14606 let a_except_b = db.query(
14607 "SELECT v FROM set_a EXCEPT SELECT v FROM set_b",
14608 &[],
14609 ).unwrap();
14610
14611 let b_except_a = db.query(
14612 "SELECT v FROM set_b EXCEPT SELECT v FROM set_a",
14613 &[],
14614 ).unwrap();
14615
14616 assert_eq!(a_except_b.len(), 1);
14618 assert_eq!(b_except_a.len(), 1);
14619 assert_eq!(a_except_b[0].get(0).unwrap(), &Value::Int4(1),
14620 "A EXCEPT B should yield 1");
14621 assert_eq!(b_except_a[0].get(0).unwrap(), &Value::Int4(4),
14622 "B EXCEPT A should yield 4");
14623 }
14624
14625 #[test]
14626 fn test_set_op_union_with_string_values() {
14627 let db = EmbeddedDatabase::new_in_memory().unwrap();
14629
14630 let rows = db.query(
14631 "SELECT 'hello' AS greeting UNION ALL SELECT 'world'",
14632 &[],
14633 ).unwrap();
14634
14635 assert_eq!(rows.len(), 2);
14636 assert_eq!(rows[0].get(0).unwrap(), &Value::String("hello".to_string()));
14637 assert_eq!(rows[1].get(0).unwrap(), &Value::String("world".to_string()));
14638 }
14639
14640 #[test]
14641 fn test_set_op_union_distinct_string_dedup() {
14642 let db = EmbeddedDatabase::new_in_memory().unwrap();
14644
14645 let rows = db.query(
14646 "SELECT 'same' AS v UNION SELECT 'same' UNION SELECT 'different'",
14647 &[],
14648 ).unwrap();
14649
14650 assert_eq!(rows.len(), 2, "UNION should dedup 'same' into one row");
14651 let mut vals: Vec<String> = rows.iter()
14652 .map(|r| match r.get(0).unwrap() {
14653 Value::String(s) => s.clone(),
14654 other => panic!("Expected String, got {:?}", other),
14655 })
14656 .collect();
14657 vals.sort();
14658 assert_eq!(vals, vec!["different", "same"]);
14659 }
14660
14661 #[test]
14662 fn test_set_op_union_with_boolean_values() {
14663 let db = EmbeddedDatabase::new_in_memory().unwrap();
14665
14666 let rows = db.query(
14667 "SELECT TRUE AS flag UNION SELECT FALSE UNION SELECT TRUE",
14668 &[],
14669 ).unwrap();
14670
14671 assert_eq!(rows.len(), 2, "UNION of (TRUE, FALSE, TRUE) should produce 2 rows");
14672 }
14673
14674 #[test]
14675 fn test_set_op_union_all_preserves_order() {
14676 let db = EmbeddedDatabase::new_in_memory().unwrap();
14678
14679 let rows = db.query(
14680 "SELECT 100 AS v UNION ALL SELECT 200",
14681 &[],
14682 ).unwrap();
14683
14684 assert_eq!(rows.len(), 2);
14685 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(100));
14687 assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(200));
14688 }
14689
14690 #[test]
14691 fn test_set_op_except_self_yields_empty() {
14692 let db = EmbeddedDatabase::new_in_memory().unwrap();
14694
14695 db.execute("CREATE TABLE self_exc (v INT)").unwrap();
14696 db.execute("INSERT INTO self_exc VALUES (1), (2), (3)").unwrap();
14697
14698 let rows = db.query(
14699 "SELECT v FROM self_exc EXCEPT SELECT v FROM self_exc",
14700 &[],
14701 ).unwrap();
14702
14703 assert_eq!(rows.len(), 0, "Table EXCEPT itself should produce 0 rows");
14704 }
14705
14706 #[test]
14707 fn test_set_op_intersect_self_yields_all() {
14708 let db = EmbeddedDatabase::new_in_memory().unwrap();
14710
14711 db.execute("CREATE TABLE self_int (v INT)").unwrap();
14712 db.execute("INSERT INTO self_int VALUES (1), (2), (3)").unwrap();
14713
14714 let rows = db.query(
14715 "SELECT v FROM self_int INTERSECT SELECT v FROM self_int",
14716 &[],
14717 ).unwrap();
14718
14719 assert_eq!(rows.len(), 3, "Table INTERSECT itself should return all 3 unique rows");
14720 let mut vals: Vec<i32> = rows.iter()
14721 .map(|r| match r.get(0).unwrap() {
14722 Value::Int4(n) => *n,
14723 other => panic!("Expected Int4, got {:?}", other),
14724 })
14725 .collect();
14726 vals.sort();
14727 assert_eq!(vals, vec![1, 2, 3]);
14728 }
14729
14730 #[test]
14731 fn test_set_op_union_with_expressions() {
14732 let db = EmbeddedDatabase::new_in_memory().unwrap();
14734
14735 let rows = db.query(
14736 "SELECT 1 + 1 AS result UNION ALL SELECT 2 * 3",
14737 &[],
14738 ).unwrap();
14739
14740 assert_eq!(rows.len(), 2);
14741 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(2));
14742 assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(6));
14743 }
14744
14745 #[test]
14746 fn test_set_op_union_single_row_each() {
14747 let db = EmbeddedDatabase::new_in_memory().unwrap();
14749
14750 let rows = db.query("SELECT 42 AS v UNION ALL SELECT 99", &[]).unwrap();
14751
14752 assert_eq!(rows.len(), 2);
14753 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(42));
14754 assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(99));
14755 }
14756
14757 fn setup_subquery_tables() -> EmbeddedDatabase {
14779 let db = EmbeddedDatabase::new_in_memory().unwrap();
14780
14781 db.execute("CREATE TABLE customers (id INT, name TEXT, category TEXT)").unwrap();
14782 db.execute("CREATE TABLE orders (id INT, customer_id INT, amount INT, product_id INT)").unwrap();
14783 db.execute("CREATE TABLE products (id INT, name TEXT, price INT)").unwrap();
14784
14785 db.execute("INSERT INTO customers VALUES (1, 'Alice', 'premium')").unwrap();
14787 db.execute("INSERT INTO customers VALUES (2, 'Bob', 'standard')").unwrap();
14788 db.execute("INSERT INTO customers VALUES (3, 'Charlie', 'premium')").unwrap();
14789 db.execute("INSERT INTO customers VALUES (4, 'Diana', 'standard')").unwrap();
14790
14791 db.execute("INSERT INTO orders VALUES (10, 1, 100, 1)").unwrap();
14793 db.execute("INSERT INTO orders VALUES (11, 1, 200, 2)").unwrap();
14794 db.execute("INSERT INTO orders VALUES (12, 2, 50, 1)").unwrap();
14795 db.execute("INSERT INTO orders VALUES (13, 3, 300, 3)").unwrap();
14796
14797 db.execute("INSERT INTO products VALUES (1, 'Widget', 10)").unwrap();
14799 db.execute("INSERT INTO products VALUES (2, 'Gadget', 25)").unwrap();
14800 db.execute("INSERT INTO products VALUES (3, 'Gizmo', 50)").unwrap();
14801 db.execute("INSERT INTO products VALUES (4, 'Doohickey', 5)").unwrap();
14802
14803 db
14804 }
14805
/// Uncorrelated `IN (subquery)`: customers 1, 2 and 3 appear in `orders.customer_id`,
/// so exactly Alice, Bob and Charlie come back, ordered by id.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_subquery_in_basic() {
    let db = setup_subquery_tables();

    let sql = "SELECT id, name FROM customers WHERE id IN (SELECT customer_id FROM orders) ORDER BY id";
    let rows = match db.query(sql, &[]) {
        Ok(rows) => rows,
        Err(e) => {
            println!("IN subquery not supported: {}", e);
            return;
        }
    };

    assert_eq!(rows.len(), 3, "3 customers have orders, got {}", rows.len());
    let expected = [(1, "Alice"), (2, "Bob"), (3, "Charlie")];
    for (row, &(id, name)) in rows.iter().zip(expected.iter()) {
        assert_eq!(row.get(0), Some(&Value::Int4(id)));
        assert_eq!(row.get(1), Some(&Value::String(name.to_string())));
    }
}
14831
/// Uncorrelated `NOT IN (subquery)`: Diana (id 4) is the only customer with no order.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_subquery_not_in_basic() {
    let db = setup_subquery_tables();

    let sql = "SELECT id, name FROM customers WHERE id NOT IN (SELECT customer_id FROM orders) ORDER BY id";
    db.query(sql, &[]).map_or_else(
        |e| println!("NOT IN subquery not supported: {}", e),
        |rows| {
            assert_eq!(rows.len(), 1, "Only Diana has no orders, got {}", rows.len());
            let only = &rows[0];
            assert_eq!(only.get(0), Some(&Value::Int4(4)));
            assert_eq!(only.get(1), Some(&Value::String("Diana".to_string())));
        },
    );
}
14851
/// `IN` against a subquery that yields no rows (no order has amount > 9999) must filter
/// out every customer.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_subquery_in_with_empty_result() {
    let db = setup_subquery_tables();
    let sql = "SELECT id, name FROM customers WHERE id IN (SELECT customer_id FROM orders WHERE amount > 9999)";

    match db.query(sql, &[]) {
        Err(e) => println!("IN subquery with empty result not supported: {}", e),
        Ok(rows) => assert_eq!(rows.len(), 0, "No orders match amount > 9999, so IN list is empty"),
    }
}
14869
/// Self-referencing `IN`: every `customers.id` appears in `SELECT id FROM customers`,
/// so all four customers (ids 1..=4) must come back in ascending id order.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_subquery_in_all_match() {
    let db = setup_subquery_tables();

    let sql = "SELECT id FROM customers WHERE id IN (SELECT id FROM customers) ORDER BY id";
    match db.query(sql, &[]) {
        Ok(rows) => {
            assert_eq!(rows.len(), 4, "All 4 customers should match, got {}", rows.len());
            // Pair each row with its expected id directly. This avoids deriving the id from
            // a `usize` index with `as i32`, which the crate's pedantic cast lints flag.
            for (row, expected_id) in rows.iter().zip(1..=4) {
                assert_eq!(row.get(0), Some(&Value::Int4(expected_id)));
            }
        }
        Err(e) => {
            println!("IN subquery self-reference not supported: {}", e);
        }
    }
}
14890
/// Self-referencing `NOT IN`: every id is in the subquery result, so nothing survives.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_subquery_not_in_all_match() {
    let db = setup_subquery_tables();

    let sql = "SELECT id FROM customers WHERE id NOT IN (SELECT id FROM customers)";
    db.query(sql, &[]).map_or_else(
        |e| println!("NOT IN subquery self-reference not supported: {}", e),
        |rows| assert_eq!(rows.len(), 0, "NOT IN with all ids should return nothing"),
    );
}
14907
/// `NOT IN` over products: product ids 1-3 are referenced by orders, so only
/// Doohickey (id 4) is returned.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_subquery_in_products_not_ordered() {
    let db = setup_subquery_tables();

    let sql = "SELECT id, name FROM products WHERE id NOT IN (SELECT product_id FROM orders) ORDER BY id";
    let rows = match db.query(sql, &[]) {
        Ok(found) => found,
        Err(e) => {
            println!("NOT IN subquery for products not supported: {}", e);
            return;
        }
    };

    assert_eq!(rows.len(), 1, "Only Doohickey has no orders, got {}", rows.len());
    let unordered = &rows[0];
    assert_eq!(unordered.get(0), Some(&Value::Int4(4)));
    assert_eq!(unordered.get(1), Some(&Value::String("Doohickey".to_string())));
}
14926
/// Uncorrelated `EXISTS` over a non-empty `orders` table is true for every outer row,
/// so all four customers come back (Alice first, Diana last by id).
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_exists_basic_uncorrelated() {
    let db = setup_subquery_tables();

    let sql = "SELECT id, name FROM customers WHERE EXISTS (SELECT 1 FROM orders) ORDER BY id";
    let rows = match db.query(sql, &[]) {
        Ok(rows) => rows,
        Err(e) => {
            println!("EXISTS uncorrelated not supported: {}", e);
            return;
        }
    };

    assert_eq!(rows.len(), 4, "EXISTS(non-empty) should return all 4 customers, got {}", rows.len());
    let (first, last) = (&rows[0], &rows[3]);
    assert_eq!(first.get(1), Some(&Value::String("Alice".to_string())));
    assert_eq!(last.get(1), Some(&Value::String("Diana".to_string())));
}
14947
/// `EXISTS` over a subquery with no matches (amount > 9999) is false for every outer row.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_exists_with_empty_subquery_result() {
    let db = setup_subquery_tables();
    let sql = "SELECT id, name FROM customers WHERE EXISTS (SELECT 1 FROM orders WHERE amount > 9999)";

    match db.query(sql, &[]) {
        Err(e) => println!("EXISTS with empty subquery not supported: {}", e),
        Ok(rows) => {
            assert_eq!(rows.len(), 0, "EXISTS on empty subquery should return 0 rows, got {}", rows.len());
        }
    }
}
14964
/// `NOT EXISTS` over the non-empty `orders` table is false for every outer row.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_not_exists_uncorrelated() {
    let db = setup_subquery_tables();

    let sql = "SELECT id, name FROM customers WHERE NOT EXISTS (SELECT 1 FROM orders)";
    db.query(sql, &[]).map_or_else(
        |e| println!("NOT EXISTS uncorrelated not supported: {}", e),
        |rows| {
            assert_eq!(rows.len(), 0, "NOT EXISTS(non-empty) should return 0 rows, got {}", rows.len());
        },
    );
}
14981
/// `NOT EXISTS` over a subquery with no matches is true for every outer row, so all
/// four customers are returned.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_not_exists_with_empty_subquery() {
    let db = setup_subquery_tables();

    let sql = "SELECT id, name FROM customers WHERE NOT EXISTS (SELECT 1 FROM orders WHERE amount > 9999) ORDER BY id";
    let rows = match db.query(sql, &[]) {
        Ok(rows) => rows,
        Err(e) => {
            println!("NOT EXISTS with empty subquery not supported: {}", e);
            return;
        }
    };

    assert_eq!(rows.len(), 4, "NOT EXISTS(empty) should return all 4 customers, got {}", rows.len());
}
14998
/// `EXISTS` with a filter that some orders satisfy (amounts 200 and 300) is true for
/// every outer row, so all customers are returned.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_exists_with_specific_filter() {
    let db = setup_subquery_tables();
    let sql = "SELECT id, name FROM customers WHERE EXISTS (SELECT 1 FROM orders WHERE amount >= 200) ORDER BY id";

    match db.query(sql, &[]) {
        Err(e) => println!("EXISTS with filter not supported: {}", e),
        Ok(rows) => {
            assert_eq!(rows.len(), 4, "EXISTS with matching filter should return all customers, got {}", rows.len());
        }
    }
}
15016
/// `EXISTS` over a table that has no rows at all (`child`) must reject the single
/// `parent` row.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_exists_against_empty_table() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    for statement in [
        "CREATE TABLE parent (id INT, name TEXT)",
        "CREATE TABLE child (id INT, parent_id INT)",
        "INSERT INTO parent VALUES (1, 'Alice')",
    ]
    .iter()
    .copied()
    {
        db.execute(statement).unwrap();
    }

    let sql = "SELECT id, name FROM parent WHERE EXISTS (SELECT 1 FROM child)";
    let rows = match db.query(sql, &[]) {
        Ok(rows) => rows,
        Err(e) => {
            println!("EXISTS against empty table not supported: {}", e);
            return;
        }
    };
    assert_eq!(rows.len(), 0, "EXISTS on empty table should return 0 rows");
}
15036
/// `NOT EXISTS` over an empty table (`child_ne`) keeps every `parent_ne` row
/// (Alice, then Bob).
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_not_exists_against_empty_table() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    for statement in [
        "CREATE TABLE parent_ne (id INT, name TEXT)",
        "CREATE TABLE child_ne (id INT, parent_id INT)",
        "INSERT INTO parent_ne VALUES (1, 'Alice')",
        "INSERT INTO parent_ne VALUES (2, 'Bob')",
    ]
    .iter()
    .copied()
    {
        db.execute(statement).unwrap();
    }

    let sql = "SELECT id, name FROM parent_ne WHERE NOT EXISTS (SELECT 1 FROM child_ne) ORDER BY id";
    db.query(sql, &[]).map_or_else(
        |e| println!("NOT EXISTS against empty table not supported: {}", e),
        |rows| {
            assert_eq!(rows.len(), 2, "NOT EXISTS on empty table should return all parent rows");
            for (row, name) in rows.iter().zip(["Alice", "Bob"].iter()) {
                assert_eq!(row.get(1), Some(&Value::String((*name).to_string())));
            }
        },
    );
}
15059
/// Correlated `EXISTS`: a customer qualifies only if some order references its id, so
/// customers 1, 2 and 3 are returned in id order.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_exists_correlated_subquery() {
    let db = setup_subquery_tables();

    let sql = "SELECT id, name FROM customers WHERE EXISTS (SELECT 1 FROM orders WHERE orders.customer_id = customers.id) ORDER BY id";
    let rows = match db.query(sql, &[]) {
        Ok(rows) => rows,
        Err(e) => {
            println!("Correlated EXISTS not supported: {}", e);
            return;
        }
    };

    assert_eq!(rows.len(), 3, "Correlated EXISTS should find 3 customers with orders, got {}", rows.len());
    for (row, id) in rows.iter().zip([1, 2, 3].iter()) {
        assert_eq!(row.get(0), Some(&Value::Int4(*id)));
    }
}
15086
/// Correlated `NOT EXISTS`: only Diana (id 4) has no referencing order.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_not_exists_correlated_subquery() {
    let db = setup_subquery_tables();
    let sql = "SELECT id, name FROM customers WHERE NOT EXISTS (SELECT 1 FROM orders WHERE orders.customer_id = customers.id) ORDER BY id";

    match db.query(sql, &[]) {
        Err(e) => println!("Correlated NOT EXISTS not supported: {}", e),
        Ok(rows) => {
            assert_eq!(rows.len(), 1, "Correlated NOT EXISTS should find 1 customer without orders, got {}", rows.len());
            let diana = &rows[0];
            assert_eq!(diana.get(0), Some(&Value::Int4(4)));
            assert_eq!(diana.get(1), Some(&Value::String("Diana".to_string())));
        }
    }
}
15107
/// Correlated scalar subquery in the SELECT list: one output row per customer.
/// Only the row count is asserted; the per-customer counts are printed for inspection.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_subquery_scalar_in_select() {
    let db = setup_subquery_tables();

    let sql = "SELECT id, (SELECT COUNT(*) FROM orders WHERE orders.customer_id = customers.id) FROM customers ORDER BY id";
    let rows = match db.query(sql, &[]) {
        Ok(rows) => rows,
        Err(e) => {
            println!("Scalar subquery in SELECT not supported: {}", e);
            return;
        }
    };

    assert_eq!(rows.len(), 4, "Should return all 4 customers");
    let pairs: Vec<_> = rows.iter().map(|r| (r.get(0), r.get(1))).collect();
    println!("Scalar subquery returned {} rows - values: {:?}", rows.len(), pairs);
}
15132
/// Scalar subquery in WHERE: `MIN(customer_id)` over orders is 1, so customers with
/// id > 1 (2, 3, 4) are returned.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_subquery_scalar_in_where() {
    let db = setup_subquery_tables();

    let sql = "SELECT id, name FROM customers WHERE id > (SELECT MIN(customer_id) FROM orders) ORDER BY id";
    db.query(sql, &[]).map_or_else(
        |e| println!("Scalar subquery in WHERE not supported: {}", e),
        |rows| {
            assert_eq!(rows.len(), 3, "Customers with id > 1, got {}", rows.len());
            for (row, id) in rows.iter().zip(2..=4) {
                assert_eq!(row.get(0), Some(&Value::Int4(id)));
            }
        },
    );
}
15154
/// Derived table in FROM: the inner query keeps the two premium customers
/// (Alice, Charlie), and the outer `SELECT *` passes them through in id order.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_subquery_in_from_clause() {
    let db = setup_subquery_tables();

    let sql = "SELECT * FROM (SELECT id, name FROM customers WHERE category = 'premium') AS sub ORDER BY id";
    let rows = match db.query(sql, &[]) {
        Ok(rows) => rows,
        Err(e) => {
            println!("Subquery in FROM clause not supported: {}", e);
            return;
        }
    };

    assert_eq!(rows.len(), 2, "2 premium customers, got {}", rows.len());
    let expected = [(1, "Alice"), (3, "Charlie")];
    for (row, &(id, name)) in rows.iter().zip(expected.iter()) {
        assert_eq!(row.get(0), Some(&Value::Int4(id)));
        assert_eq!(row.get(1), Some(&Value::String(name.to_string())));
    }
}
15177
/// Derived table with GROUP BY: orders reference customers 1, 2 and 3, so three groups
/// come back. Only the row count is asserted; the (customer_id, total) pairs are printed.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_subquery_in_from_with_aggregation() {
    let db = setup_subquery_tables();
    let sql = "SELECT * FROM (SELECT customer_id, SUM(amount) AS total FROM orders GROUP BY customer_id) AS sub ORDER BY customer_id";

    match db.query(sql, &[]) {
        Err(e) => println!("Subquery in FROM with aggregation not supported: {}", e),
        Ok(rows) => {
            assert_eq!(rows.len(), 3, "3 customers have orders, got {}", rows.len());
            let groups: Vec<_> = rows.iter().map(|r| (r.get(0), r.get(1))).collect();
            println!("FROM subquery with aggregation returned: {:?}", groups);
        }
    }
}
15196
/// Derived table whose inner query matches nothing (no id > 999) yields zero rows.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_subquery_in_from_empty_result() {
    let db = setup_subquery_tables();

    let sql = "SELECT * FROM (SELECT id, name FROM customers WHERE id > 999) AS sub";
    db.query(sql, &[]).map_or_else(
        |e| println!("FROM subquery with empty result not supported: {}", e),
        |rows| assert_eq!(rows.len(), 0, "No customers with id > 999"),
    );
}
15212
/// Two-level nested `IN`: products priced above 20 are Gadget (2) and Gizmo (3). They
/// were ordered by customers 1 and 3, so Alice and Charlie are returned.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_subquery_nested_in() {
    let db = setup_subquery_tables();

    let sql = "SELECT id, name FROM customers WHERE id IN \
               (SELECT customer_id FROM orders WHERE product_id IN \
               (SELECT id FROM products WHERE price > 20)) ORDER BY id";
    let rows = match db.query(sql, &[]) {
        Ok(rows) => rows,
        Err(e) => {
            println!("Nested IN subquery not supported: {}", e);
            return;
        }
    };

    assert_eq!(rows.len(), 2, "2 customers ordered expensive products, got {}", rows.len());
    let expected = [(1, "Alice"), (3, "Charlie")];
    for (row, &(id, name)) in rows.iter().zip(expected.iter()) {
        assert_eq!(row.get(0), Some(&Value::Int4(id)));
        assert_eq!(row.get(1), Some(&Value::String(name.to_string())));
    }
}
15240
/// Three-level nested `IN`: premium customers (1, 3) ordered products 1, 2 and 3, so
/// Widget, Gadget and Gizmo are returned in id order.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_subquery_nested_in_three_levels() {
    let db = setup_subquery_tables();
    let sql = "SELECT id, name FROM products WHERE id IN \
               (SELECT product_id FROM orders WHERE customer_id IN \
               (SELECT id FROM customers WHERE category = 'premium')) ORDER BY id";

    match db.query(sql, &[]) {
        Err(e) => println!("3-level nested IN subquery not supported: {}", e),
        Ok(rows) => {
            assert_eq!(rows.len(), 3, "3 products ordered by premium customers, got {}", rows.len());
            for (row, product) in rows.iter().zip(["Widget", "Gadget", "Gizmo"].iter()) {
                assert_eq!(row.get(1), Some(&Value::String((*product).to_string())));
            }
        }
    }
}
15267
/// Combines a plain column filter, an uncorrelated `EXISTS` and an `IN` subquery in one
/// WHERE clause. Both premium customers (Alice, Charlie) have orders, so all three
/// predicates hold for them and for no one else.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_exists_and_in_combined() {
    let db = setup_subquery_tables();

    // Exercise both subquery forms the test name advertises: EXISTS *and* IN.
    let sql = "SELECT id, name FROM customers \
               WHERE category = 'premium' \
               AND EXISTS (SELECT 1 FROM orders) \
               AND id IN (SELECT customer_id FROM orders) \
               ORDER BY id";
    match db.query(sql, &[]) {
        Ok(rows) => {
            assert_eq!(rows.len(), 2, "2 premium customers when orders exist, got {}", rows.len());
            assert_eq!(rows[0].get(1), Some(&Value::String("Alice".to_string())));
            assert_eq!(rows[1].get(1), Some(&Value::String("Charlie".to_string())));
        }
        Err(e) => {
            println!("Combined EXISTS and IN not supported: {}", e);
        }
    }
}
15293
/// `IN (SELECT DISTINCT ...)`: orders reference products 1, 2 and 3, so Widget, Gadget
/// and Gizmo are returned.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_in_subquery_with_distinct() {
    let db = setup_subquery_tables();

    let sql = "SELECT id, name FROM products WHERE id IN (SELECT DISTINCT product_id FROM orders) ORDER BY id";
    db.query(sql, &[]).map_or_else(
        |e| println!("IN subquery with DISTINCT not supported: {}", e),
        |rows| {
            assert_eq!(rows.len(), 3, "3 distinct products in orders, got {}", rows.len());
            for (row, product) in rows.iter().zip(["Widget", "Gadget", "Gizmo"].iter()) {
                assert_eq!(row.get(1), Some(&Value::String((*product).to_string())));
            }
        },
    );
}
15314
/// `IN` with a computed left-hand side (`id + 0`) must behave like plain `id IN`:
/// customers 1, 2 and 3.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_subquery_in_with_expression() {
    let db = setup_subquery_tables();

    let sql = "SELECT id, name FROM customers WHERE id + 0 IN (SELECT customer_id FROM orders) ORDER BY id";
    let rows = match db.query(sql, &[]) {
        Ok(rows) => rows,
        Err(e) => {
            println!("IN subquery with expression not supported: {}", e);
            return;
        }
    };

    assert_eq!(rows.len(), 3, "3 customers with orders (via expression), got {}", rows.len());
    for (row, id) in rows.iter().zip(1..=3) {
        assert_eq!(row.get(0), Some(&Value::Int4(id)));
    }
}
15334
/// `IN` against a one-row subquery: only order 13 has amount 300, and it belongs to
/// Charlie (id 3).
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_subquery_in_single_value() {
    let db = setup_subquery_tables();
    let sql = "SELECT id, name FROM customers WHERE id IN (SELECT customer_id FROM orders WHERE amount = 300)";

    match db.query(sql, &[]) {
        Err(e) => println!("IN subquery with single result not supported: {}", e),
        Ok(rows) => {
            assert_eq!(rows.len(), 1, "Only 1 customer has the 300-amount order, got {}", rows.len());
            let charlie = &rows[0];
            assert_eq!(charlie.get(0), Some(&Value::Int4(3)));
            assert_eq!(charlie.get(1), Some(&Value::String("Charlie".to_string())));
        }
    }
}
15353
/// `EXISTS (SELECT * ...)`: orders with amount > 100 exist (200 and 300), so every
/// customer row is returned.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_exists_with_select_star_subquery() {
    let db = setup_subquery_tables();

    let sql = "SELECT id FROM customers WHERE EXISTS (SELECT * FROM orders WHERE amount > 100) ORDER BY id";
    db.query(sql, &[]).map_or_else(
        |e| println!("EXISTS with SELECT * not supported: {}", e),
        |rows| {
            assert_eq!(rows.len(), 4, "EXISTS(SELECT *) with matches returns all outer rows, got {}", rows.len());
        },
    );
}
15370
/// Outer WHERE applied to a derived table: the standard customers are Bob and Diana.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_subquery_in_from_with_where() {
    let db = setup_subquery_tables();

    let sql = "SELECT * FROM (SELECT id, name, category FROM customers) AS sub WHERE category = 'standard' ORDER BY id";
    let rows = match db.query(sql, &[]) {
        Ok(rows) => rows,
        Err(e) => {
            println!("Derived table with outer WHERE not supported: {}", e);
            return;
        }
    };

    assert_eq!(rows.len(), 2, "2 standard customers via derived table, got {}", rows.len());
    for (row, name) in rows.iter().zip(["Bob", "Diana"].iter()) {
        assert_eq!(row.get(1), Some(&Value::String((*name).to_string())));
    }
}
15389
/// Projecting a single column out of a derived table: customers with id <= 2 are Alice
/// and Bob, sorted by name.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_subquery_in_from_select_subset() {
    let db = setup_subquery_tables();
    let sql = "SELECT name FROM (SELECT id, name FROM customers WHERE id <= 2) AS sub ORDER BY name";

    match db.query(sql, &[]) {
        Err(e) => println!("Derived table with column subset not supported: {}", e),
        Ok(rows) => {
            assert_eq!(rows.len(), 2, "2 customers with id <= 2, got {}", rows.len());
            for (row, name) in rows.iter().zip(["Alice", "Bob"].iter()) {
                assert_eq!(row.get(0), Some(&Value::String((*name).to_string())));
            }
        }
    }
}
15408
/// `IN` against an aggregate subquery: `MAX(v)` over {10, 20, 30} is 30, so only that
/// row matches.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_subquery_in_single_column_single_row() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    for statement in [
        "CREATE TABLE vals (v INT)",
        "INSERT INTO vals VALUES (10)",
        "INSERT INTO vals VALUES (20)",
        "INSERT INTO vals VALUES (30)",
    ]
    .iter()
    .copied()
    {
        db.execute(statement).unwrap();
    }

    let sql = "SELECT v FROM vals WHERE v IN (SELECT MAX(v) FROM vals)";
    let rows = match db.query(sql, &[]) {
        Ok(rows) => rows,
        Err(e) => {
            println!("IN subquery with aggregate not supported: {}", e);
            return;
        }
    };

    assert_eq!(rows.len(), 1, "Only MAX(v)=30 should match, got {}", rows.len());
    assert_eq!(rows[0].get(0), Some(&Value::Int4(30)));
}
15431
/// `EXISTS` over a one-row table is true, so both `checker` rows (ids 1 and 2) survive.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_exists_on_single_row_table() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    for statement in [
        "CREATE TABLE singleton (v INT)",
        "INSERT INTO singleton VALUES (42)",
        "CREATE TABLE checker (id INT)",
        "INSERT INTO checker VALUES (1)",
        "INSERT INTO checker VALUES (2)",
    ]
    .iter()
    .copied()
    {
        db.execute(statement).unwrap();
    }

    let sql = "SELECT id FROM checker WHERE EXISTS (SELECT 1 FROM singleton) ORDER BY id";
    db.query(sql, &[]).map_or_else(
        |e| println!("EXISTS on single-row table not supported: {}", e),
        |rows| {
            assert_eq!(rows.len(), 2, "EXISTS on single-row table should return all checker rows");
            for (row, id) in rows.iter().zip(1..=2) {
                assert_eq!(row.get(0), Some(&Value::Int4(id)));
            }
        },
    );
}
15454
/// `IN` on TEXT columns: no customer name equals any product name, so nothing matches.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_subquery_in_with_string_column() {
    let db = setup_subquery_tables();
    let sql = "SELECT id, name FROM customers WHERE name IN (SELECT name FROM products) ORDER BY id";

    match db.query(sql, &[]) {
        Err(e) => println!("IN subquery with string column not supported: {}", e),
        Ok(rows) => {
            assert_eq!(rows.len(), 0, "No customer names match product names, got {}", rows.len());
        }
    }
}
15473
/// `NOT IN` on TEXT columns: no customer name is a product name, so all four customers
/// are returned.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_subquery_not_in_with_string_column() {
    let db = setup_subquery_tables();

    let sql = "SELECT id, name FROM customers WHERE name NOT IN (SELECT name FROM products) ORDER BY id";
    let rows = match db.query(sql, &[]) {
        Ok(rows) => rows,
        Err(e) => {
            println!("NOT IN subquery with string column not supported: {}", e);
            return;
        }
    };
    assert_eq!(rows.len(), 4, "All 4 customers have non-product names, got {}", rows.len());
}
15489
/// `IN` with a filtered subquery: orders with amount >= 100 belong to customers 1 and 3,
/// so Alice and Charlie are returned.
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_subquery_in_with_filter_on_subquery() {
    let db = setup_subquery_tables();

    let sql = "SELECT id, name FROM customers WHERE id IN (SELECT customer_id FROM orders WHERE amount >= 100) ORDER BY id";
    db.query(sql, &[]).map_or_else(
        |e| println!("IN subquery with WHERE filter not supported: {}", e),
        |rows| {
            assert_eq!(rows.len(), 2, "2 customers with orders >= 100, got {}", rows.len());
            let expected = [(1, "Alice"), (3, "Charlie")];
            for (row, &(id, name)) in rows.iter().zip(expected.iter()) {
                assert_eq!(row.get(0), Some(&Value::Int4(id)));
                assert_eq!(row.get(1), Some(&Value::String(name.to_string())));
            }
        },
    );
}
15512
/// `EXISTS` as one side of an OR: the EXISTS is false (no amount > 9999), so only the
/// premium-category disjunct selects rows (Alice, Charlie).
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_exists_combined_with_or() {
    let db = setup_subquery_tables();

    let sql = "SELECT id, name FROM customers \
               WHERE category = 'premium' \
               OR EXISTS (SELECT 1 FROM orders WHERE amount > 9999) \
               ORDER BY id";
    let rows = match db.query(sql, &[]) {
        Ok(rows) => rows,
        Err(e) => {
            println!("EXISTS combined with OR not supported: {}", e);
            return;
        }
    };

    assert_eq!(rows.len(), 2, "Only premium customers when EXISTS is false, got {}", rows.len());
    for (row, name) in rows.iter().zip(["Alice", "Charlie"].iter()) {
        assert_eq!(row.get(1), Some(&Value::String((*name).to_string())));
    }
}
15536
/// `NOT EXISTS` as one side of an AND: the NOT EXISTS is true (no amount > 9999), so the
/// standard-category filter alone decides the result (Bob, Diana).
/// An engine error is reported on stdout instead of failing the test.
#[test]
fn test_not_exists_combined_with_and() {
    let db = setup_subquery_tables();
    let sql = "SELECT id, name FROM customers \
               WHERE category = 'standard' \
               AND NOT EXISTS (SELECT 1 FROM orders WHERE amount > 9999) \
               ORDER BY id";

    match db.query(sql, &[]) {
        Err(e) => println!("NOT EXISTS combined with AND not supported: {}", e),
        Ok(rows) => {
            assert_eq!(rows.len(), 2, "Standard customers when NOT EXISTS is true, got {}", rows.len());
            for (row, name) in rows.iter().zip(["Bob", "Diana"].iter()) {
                assert_eq!(row.get(1), Some(&Value::String((*name).to_string())));
            }
        }
    }
}
15560
15561 fn setup_multi_table_join_db() -> EmbeddedDatabase {
15572 let db = EmbeddedDatabase::new_in_memory().unwrap();
15573
15574 db.execute("CREATE TABLE jt_customers (id INT PRIMARY KEY, name TEXT, city TEXT)")
15576 .unwrap();
15577 db.execute("INSERT INTO jt_customers VALUES (1, 'Alice', 'NYC')").unwrap();
15578 db.execute("INSERT INTO jt_customers VALUES (2, 'Bob', 'LA')").unwrap();
15579 db.execute("INSERT INTO jt_customers VALUES (3, 'Carol', 'NYC')").unwrap();
15580 db.execute("INSERT INTO jt_customers VALUES (4, 'Diana', 'Chicago')").unwrap();
15581
15582 db.execute("CREATE TABLE jt_products (id INT PRIMARY KEY, name TEXT, price INT)")
15584 .unwrap();
15585 db.execute("INSERT INTO jt_products VALUES (10, 'Widget', 100)").unwrap();
15586 db.execute("INSERT INTO jt_products VALUES (20, 'Gadget', 250)").unwrap();
15587 db.execute("INSERT INTO jt_products VALUES (30, 'Doohickey', 50)").unwrap();
15588
15589 db.execute(
15591 "CREATE TABLE jt_orders (id INT PRIMARY KEY, customer_id INT, product_id INT, qty INT)",
15592 )
15593 .unwrap();
15594 db.execute("INSERT INTO jt_orders VALUES (100, 1, 10, 2)").unwrap(); db.execute("INSERT INTO jt_orders VALUES (101, 1, 20, 1)").unwrap(); db.execute("INSERT INTO jt_orders VALUES (102, 2, 10, 5)").unwrap(); db.execute("INSERT INTO jt_orders VALUES (103, 3, 30, 3)").unwrap(); db
15602 }
15603
/// Chains two INNER JOINs (orders → customers, orders → products) and checks every
/// (customer, product, qty) triple in order-id order.
#[test]
fn test_join_three_table_inner() {
    let db = setup_multi_table_join_db();

    let sql = "\
        SELECT jt_customers.name, jt_products.name, jt_orders.qty \
        FROM jt_orders \
        JOIN jt_customers ON jt_orders.customer_id = jt_customers.id \
        JOIN jt_products ON jt_orders.product_id = jt_products.id \
        ORDER BY jt_orders.id";

    let rows = db
        .query(sql, &[])
        .unwrap_or_else(|e| panic!("3-table INNER JOIN failed: {}", e));

    assert_eq!(rows.len(), 4, "Expected 4 order rows, got {}", rows.len());
    let expected = [
        ("Alice", "Widget", 2),
        ("Alice", "Gadget", 1),
        ("Bob", "Widget", 5),
        ("Carol", "Doohickey", 3),
    ];
    for (row, &(customer, product, qty)) in rows.iter().zip(expected.iter()) {
        assert_eq!(row.get(0), Some(&Value::String(customer.to_string())));
        assert_eq!(row.get(1), Some(&Value::String(product.to_string())));
        assert_eq!(row.get(2), Some(&Value::Int4(qty)));
    }
}
15641
/// Follows a four-table foreign-key chain (orders → customers → addresses → cities)
/// and checks both resulting rows column by column.
#[test]
fn test_join_four_table_chain() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    for statement in [
        "CREATE TABLE jt4_cities (id INT PRIMARY KEY, city_name TEXT)",
        "INSERT INTO jt4_cities VALUES (1, 'New York')",
        "INSERT INTO jt4_cities VALUES (2, 'Los Angeles')",
        "CREATE TABLE jt4_addresses (id INT PRIMARY KEY, street TEXT, city_id INT)",
        "INSERT INTO jt4_addresses VALUES (10, '123 Main St', 1)",
        "INSERT INTO jt4_addresses VALUES (20, '456 Oak Ave', 2)",
        "CREATE TABLE jt4_customers (id INT PRIMARY KEY, name TEXT, address_id INT)",
        "INSERT INTO jt4_customers VALUES (100, 'Alice', 10)",
        "INSERT INTO jt4_customers VALUES (200, 'Bob', 20)",
        "CREATE TABLE jt4_orders (id INT PRIMARY KEY, customer_id INT, amount INT)",
        "INSERT INTO jt4_orders VALUES (1000, 100, 500)",
        "INSERT INTO jt4_orders VALUES (1001, 200, 300)",
    ]
    .iter()
    .copied()
    {
        db.execute(statement).unwrap();
    }

    let sql = "\
        SELECT jt4_orders.id, jt4_customers.name, jt4_addresses.street, jt4_cities.city_name \
        FROM jt4_orders \
        JOIN jt4_customers ON jt4_orders.customer_id = jt4_customers.id \
        JOIN jt4_addresses ON jt4_customers.address_id = jt4_addresses.id \
        JOIN jt4_cities ON jt4_addresses.city_id = jt4_cities.id \
        ORDER BY jt4_orders.id";

    let rows = db
        .query(sql, &[])
        .unwrap_or_else(|e| panic!("4-table JOIN chain failed: {}", e));

    assert_eq!(rows.len(), 2, "Expected 2 rows from 4-table chain, got {}", rows.len());
    let expected = [
        (1000, "Alice", "123 Main St", "New York"),
        (1001, "Bob", "456 Oak Ave", "Los Angeles"),
    ];
    for (row, &(order_id, name, street, city)) in rows.iter().zip(expected.iter()) {
        assert_eq!(row.get(0), Some(&Value::Int4(order_id)));
        assert_eq!(row.get(1), Some(&Value::String(name.to_string())));
        assert_eq!(row.get(2), Some(&Value::String(street.to_string())));
        assert_eq!(row.get(3), Some(&Value::String(city.to_string())));
    }
}
15697
15698 fn setup_employee_db() -> EmbeddedDatabase {
15700 let db = EmbeddedDatabase::new_in_memory().unwrap();
15701
15702 db.execute(
15703 "CREATE TABLE jt_employees (id INT PRIMARY KEY, name TEXT, manager_id INT, dept TEXT)",
15704 )
15705 .unwrap();
15706 db.execute("INSERT INTO jt_employees VALUES (1, 'Eve', NULL, 'Exec')").unwrap();
15708 db.execute("INSERT INTO jt_employees VALUES (2, 'Frank', 1, 'Engineering')").unwrap();
15710 db.execute("INSERT INTO jt_employees VALUES (3, 'Grace', 1, 'Sales')").unwrap();
15711 db.execute("INSERT INTO jt_employees VALUES (4, 'Hank', 2, 'Engineering')").unwrap();
15713 db.execute("INSERT INTO jt_employees VALUES (5, 'Iris', 2, 'Engineering')").unwrap();
15714 db.execute("INSERT INTO jt_employees VALUES (6, 'Jack', 3, 'Sales')").unwrap();
15716
15717 db
15718 }
15719
/// INNER self-join of employees to their managers. Eve has no manager and is dropped,
/// leaving five (employee, manager) pairs in employee-id order.
#[test]
fn test_join_self_join_employees() {
    let db = setup_employee_db();

    let sql = "\
        SELECT e.name, m.name \
        FROM jt_employees e \
        JOIN jt_employees m ON e.manager_id = m.id \
        ORDER BY e.id";

    let rows = db
        .query(sql, &[])
        .unwrap_or_else(|e| panic!("Self-join (employees->managers) failed: {}", e));

    assert_eq!(rows.len(), 5, "5 employees have managers, got {}", rows.len());
    let expected = [
        ("Frank", "Eve"),
        ("Grace", "Eve"),
        ("Hank", "Frank"),
        ("Iris", "Frank"),
        ("Jack", "Grace"),
    ];
    for (row, &(employee, manager)) in rows.iter().zip(expected.iter()) {
        assert_eq!(row.get(0), Some(&Value::String(employee.to_string())));
        assert_eq!(row.get(1), Some(&Value::String(manager.to_string())));
    }
}
15757
/// LEFT self-join of employees to their managers. All six employees must appear, in
/// employee-id order. Eve's `manager_id` is NULL, so her manager column must be NULL
/// rather than her row being dropped.
#[test]
fn test_join_self_join_left_with_null_manager() {
    let db = setup_employee_db();

    let sql = "\
        SELECT e.name, m.name \
        FROM jt_employees e \
        LEFT JOIN jt_employees m ON e.manager_id = m.id \
        ORDER BY e.id";

    let rows = match db.query(sql, &[]) {
        Ok(rows) => rows,
        Err(e) => panic!("LEFT JOIN self-join failed: {}", e),
    };

    assert_eq!(rows.len(), 6, "All 6 employees should appear, got {}", rows.len());

    // `ORDER BY e.id` makes every row deterministic, so check all of them rather than
    // only the first two.
    let expected: [(&str, Option<&str>); 6] = [
        ("Eve", None),
        ("Frank", Some("Eve")),
        ("Grace", Some("Eve")),
        ("Hank", Some("Frank")),
        ("Iris", Some("Frank")),
        ("Jack", Some("Grace")),
    ];
    for (row, &(employee, manager)) in rows.iter().zip(expected.iter()) {
        assert_eq!(row.get(0), Some(&Value::String(employee.to_string())));
        let want_manager = manager.map_or(Value::Null, |m| Value::String(m.to_string()));
        assert_eq!(row.get(1), Some(&want_manager), "manager of {}", employee);
    }
}
15785
/// Two chained LEFT JOINs (customers → orders → products). Alice has two orders, Bob
/// and Carol one each, and Diana none, giving 5 rows. Diana's order and product columns
/// must be NULL-padded.
#[test]
fn test_join_left_join_three_tables() {
    let db = setup_multi_table_join_db();

    let sql = "\
        SELECT jt_customers.name, jt_orders.id, jt_products.name \
        FROM jt_customers \
        LEFT JOIN jt_orders ON jt_customers.id = jt_orders.customer_id \
        LEFT JOIN jt_products ON jt_orders.product_id = jt_products.id \
        ORDER BY jt_customers.id, jt_orders.id";

    let rows = match db.query(sql, &[]) {
        Ok(rows) => rows,
        Err(e) => panic!("3-table LEFT JOIN failed: {}", e),
    };

    assert_eq!(rows.len(), 5, "Expected 5 rows (4 orders + 1 NULL), got {}", rows.len());

    // The two-key ORDER BY fixes every row's position, so the matched rows are checked
    // too, not just the NULL-padded one.
    let expected: [(&str, Option<i32>, Option<&str>); 5] = [
        ("Alice", Some(100), Some("Widget")),
        ("Alice", Some(101), Some("Gadget")),
        ("Bob", Some(102), Some("Widget")),
        ("Carol", Some(103), Some("Doohickey")),
        ("Diana", None, None),
    ];
    for (row, &(customer, order_id, product)) in rows.iter().zip(expected.iter()) {
        assert_eq!(row.get(0), Some(&Value::String(customer.to_string())));
        let want_order = order_id.map_or(Value::Null, Value::Int4);
        assert_eq!(row.get(1), Some(&want_order), "order id for {}", customer);
        let want_product = product.map_or(Value::Null, |p| Value::String(p.to_string()));
        assert_eq!(row.get(2), Some(&want_product), "product for {}", customer);
    }
}
15815
/// RIGHT JOIN from orders to products. Adding Thingamajig (40), which has no orders,
/// must still yield a row for it, with a NULL order id.
#[test]
fn test_join_right_join() {
    let db = setup_multi_table_join_db();

    db.execute("INSERT INTO jt_products VALUES (40, 'Thingamajig', 75)").unwrap();

    let sql = "\
        SELECT jt_orders.id, jt_products.name \
        FROM jt_orders \
        RIGHT JOIN jt_products ON jt_orders.product_id = jt_products.id \
        ORDER BY jt_products.id";

    let rows = match db.query(sql, &[]) {
        Ok(rows) => rows,
        Err(e) => panic!("RIGHT JOIN failed: {}", e),
    };

    assert_eq!(rows.len(), 5, "Expected 5 rows from RIGHT JOIN, got {}", rows.len());

    // Sorted by product id: Widget (10) is ordered twice, Gadget (20) and Doohickey (30)
    // once each, and Thingamajig (40) never. The name column is therefore fully determined.
    let expected_names = ["Widget", "Widget", "Gadget", "Doohickey", "Thingamajig"];
    for (row, &name) in rows.iter().zip(expected_names.iter()) {
        assert_eq!(row.get(1), Some(&Value::String(name.to_string())));
    }

    // The two Widget rows (orders 100 and 102) tie on the ORDER BY key, so their relative
    // order is unspecified. The Gadget and Doohickey rows each map to exactly one order.
    assert_eq!(rows[2].get(0), Some(&Value::Int4(101)));
    assert_eq!(rows[3].get(0), Some(&Value::Int4(103)));

    let Some(last) = rows.iter().last() else {
        panic!("RIGHT JOIN returned no rows");
    };
    assert_eq!(last.get(0), Some(&Value::Null));
    assert_eq!(last.get(1), Some(&Value::String("Thingamajig".to_string())));
}
15847
/// CROSS JOIN of 2 colours x 3 sizes yields the full Cartesian product in
/// `ORDER BY jt_colors.id, jt_sizes.id` order.
#[test]
fn test_join_cross_join() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    db.execute("CREATE TABLE jt_colors (id INT PRIMARY KEY, color TEXT)").unwrap();
    db.execute("INSERT INTO jt_colors VALUES (1, 'Red')").unwrap();
    db.execute("INSERT INTO jt_colors VALUES (2, 'Blue')").unwrap();

    db.execute("CREATE TABLE jt_sizes (id INT PRIMARY KEY, size TEXT)").unwrap();
    db.execute("INSERT INTO jt_sizes VALUES (1, 'Small')").unwrap();
    db.execute("INSERT INTO jt_sizes VALUES (2, 'Medium')").unwrap();
    db.execute("INSERT INTO jt_sizes VALUES (3, 'Large')").unwrap();

    let sql = "\
        SELECT jt_colors.color, jt_sizes.size \
        FROM jt_colors \
        CROSS JOIN jt_sizes \
        ORDER BY jt_colors.id, jt_sizes.id";

    // Every (colour, size) pair, in the order the ORDER BY dictates. Checking
    // all six rows (not just the first and last) also catches mis-ordered,
    // duplicated or missing middle rows.
    let expected = [
        ("Red", "Small"),
        ("Red", "Medium"),
        ("Red", "Large"),
        ("Blue", "Small"),
        ("Blue", "Medium"),
        ("Blue", "Large"),
    ];

    match db.query(sql, &[]) {
        Ok(rows) => {
            assert_eq!(rows.len(), 6, "CROSS JOIN should produce 6 rows, got {}", rows.len());
            for (i, (row, &(color, size))) in rows.iter().zip(expected.iter()).enumerate() {
                assert_eq!(
                    row.get(0).unwrap(),
                    &Value::String(color.to_string()),
                    "color mismatch at row {}",
                    i
                );
                assert_eq!(
                    row.get(1).unwrap(),
                    &Value::String(size.to_string()),
                    "size mismatch at row {}",
                    i
                );
            }
        }
        Err(e) => {
            panic!("CROSS JOIN failed: {}", e);
        }
    }
}
15884
/// An ON clause with two ANDed equality conditions only matches rows where
/// both columns agree.
#[test]
fn test_join_multiple_conditions() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    let setup = [
        "CREATE TABLE jt_left (a INT, b INT, val TEXT)",
        "INSERT INTO jt_left VALUES (1, 10, 'x')",
        "INSERT INTO jt_left VALUES (1, 20, 'y')",
        "INSERT INTO jt_left VALUES (2, 10, 'z')",
        "CREATE TABLE jt_right (a INT, b INT, info TEXT)",
        "INSERT INTO jt_right VALUES (1, 10, 'match1')",
        "INSERT INTO jt_right VALUES (1, 20, 'match2')",
        "INSERT INTO jt_right VALUES (2, 20, 'no_match')",
    ];
    for &stmt in setup.iter() {
        db.execute(stmt).unwrap();
    }

    let sql = "\
        SELECT jt_left.val, jt_right.info \
        FROM jt_left \
        JOIN jt_right ON jt_left.a = jt_right.a AND jt_left.b = jt_right.b \
        ORDER BY jt_left.val";

    let rows = db
        .query(sql, &[])
        .unwrap_or_else(|e| panic!("JOIN with multiple conditions failed: {}", e));

    assert_eq!(rows.len(), 2, "Expected 2 rows matching both conditions, got {}", rows.len());

    // (2, 10, 'z') has no partner with a = 2 AND b = 10, so only x and y remain.
    let expected = [("x", "match1"), ("y", "match2")];
    for (row, &(val, info)) in rows.iter().zip(expected.iter()) {
        assert_eq!(row.get(0).unwrap(), &Value::String(val.to_string()));
        assert_eq!(row.get(1).unwrap(), &Value::String(info.to_string()));
    }
}
15920
/// INNER JOIN + GROUP BY + COUNT: one row per customer that has orders.
#[test]
fn test_join_with_aggregate_count() {
    let db = setup_multi_table_join_db();

    let sql = "\
        SELECT jt_customers.name, COUNT(jt_orders.id) \
        FROM jt_customers \
        JOIN jt_orders ON jt_customers.id = jt_orders.customer_id \
        GROUP BY jt_customers.name \
        ORDER BY jt_customers.name";

    let rows = db
        .query(sql, &[])
        .unwrap_or_else(|e| panic!("JOIN with aggregate COUNT failed: {}", e));

    assert_eq!(rows.len(), 3, "3 customers have orders, got {}", rows.len());

    // COUNT is expected to come back as Int8.
    let expected = [("Alice", 2_i64), ("Bob", 1), ("Carol", 1)];
    for (row, &(name, orders)) in rows.iter().zip(expected.iter()) {
        assert_eq!(row.get(0).unwrap(), &Value::String(name.to_string()));
        assert_eq!(row.get(1).unwrap(), &Value::Int8(orders));
    }
}
15949
/// A WHERE predicate on a joined column filters the three-table join result.
#[test]
fn test_join_with_where_filter() {
    let db = setup_multi_table_join_db();

    let sql = "\
        SELECT jt_customers.name, jt_products.name, jt_orders.qty \
        FROM jt_orders \
        JOIN jt_customers ON jt_orders.customer_id = jt_customers.id \
        JOIN jt_products ON jt_orders.product_id = jt_products.id \
        WHERE jt_orders.qty > 2 \
        ORDER BY jt_customers.name";

    let rows = db
        .query(sql, &[])
        .unwrap_or_else(|e| panic!("JOIN with WHERE filter failed: {}", e));

    assert_eq!(rows.len(), 2, "Expected 2 orders with qty > 2, got {}", rows.len());

    // Column 2 (qty) is an INT column and is expected as Int4.
    let expected = [("Bob", 5_i32), ("Carol", 3)];
    for (row, &(name, qty)) in rows.iter().zip(expected.iter()) {
        assert_eq!(row.get(0).unwrap(), &Value::String(name.to_string()));
        assert_eq!(row.get(2).unwrap(), &Value::Int4(qty));
    }
}
15977
/// An INNER JOIN whose ON condition never matches returns zero rows (not an
/// error).
#[test]
fn test_join_empty_result() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    // Every jt_empty_b.ref_id (99, 98) points at a non-existent jt_empty_a.id.
    let setup = [
        "CREATE TABLE jt_empty_a (id INT PRIMARY KEY, val TEXT)",
        "INSERT INTO jt_empty_a VALUES (1, 'one')",
        "INSERT INTO jt_empty_a VALUES (2, 'two')",
        "CREATE TABLE jt_empty_b (id INT PRIMARY KEY, ref_id INT, info TEXT)",
        "INSERT INTO jt_empty_b VALUES (10, 99, 'orphan1')",
        "INSERT INTO jt_empty_b VALUES (20, 98, 'orphan2')",
    ];
    for &stmt in setup.iter() {
        db.execute(stmt).unwrap();
    }

    let sql = "\
        SELECT jt_empty_a.val, jt_empty_b.info \
        FROM jt_empty_a \
        JOIN jt_empty_b ON jt_empty_a.id = jt_empty_b.ref_id";

    let rows = db
        .query(sql, &[])
        .unwrap_or_else(|e| panic!("Empty JOIN result test failed: {}", e));

    assert_eq!(rows.len(), 0, "No matching rows, expected 0, got {}", rows.len());
}
16006
/// LEFT JOIN on a nullable foreign key: rows whose FK is NULL are kept with a
/// NULL department name.
#[test]
fn test_join_with_null_fk_left_join() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    let setup = [
        "CREATE TABLE jt_depts (id INT PRIMARY KEY, dept_name TEXT)",
        "INSERT INTO jt_depts VALUES (1, 'Engineering')",
        "INSERT INTO jt_depts VALUES (2, 'Marketing')",
        "CREATE TABLE jt_staff (id INT PRIMARY KEY, name TEXT, dept_id INT)",
        "INSERT INTO jt_staff VALUES (1, 'Alice', 1)",
        "INSERT INTO jt_staff VALUES (2, 'Bob', NULL)",
        "INSERT INTO jt_staff VALUES (3, 'Carol', 2)",
        "INSERT INTO jt_staff VALUES (4, 'Dave', NULL)",
    ];
    for &stmt in setup.iter() {
        db.execute(stmt).unwrap();
    }

    let sql = "\
        SELECT jt_staff.name, jt_depts.dept_name \
        FROM jt_staff \
        LEFT JOIN jt_depts ON jt_staff.dept_id = jt_depts.id \
        ORDER BY jt_staff.id";

    let rows = db
        .query(sql, &[])
        .unwrap_or_else(|e| panic!("LEFT JOIN with NULL FK failed: {}", e));

    assert_eq!(rows.len(), 4, "All 4 staff should appear, got {}", rows.len());

    let text = |s: &str| Value::String(s.to_string());
    // `None` marks a staff member whose NULL dept_id cannot match any department.
    let expected: [(&str, Option<&str>); 4] = [
        ("Alice", Some("Engineering")),
        ("Bob", None),
        ("Carol", Some("Marketing")),
        ("Dave", None),
    ];
    for (row, &(name, dept)) in rows.iter().zip(expected.iter()) {
        assert_eq!(row.get(0).unwrap(), &text(name));
        let want = match dept {
            Some(d) => text(d),
            None => Value::Null,
        };
        assert_eq!(row.get(1).unwrap(), &want);
    }
}
16049
/// Three-table INNER JOIN feeding `SUM(qty * price)` grouped per customer.
#[test]
fn test_join_three_table_with_aggregate_sum() {
    let db = setup_multi_table_join_db();

    let sql = "\
        SELECT jt_customers.name, SUM(jt_orders.qty * jt_products.price) \
        FROM jt_orders \
        JOIN jt_customers ON jt_orders.customer_id = jt_customers.id \
        JOIN jt_products ON jt_orders.product_id = jt_products.id \
        GROUP BY jt_customers.name \
        ORDER BY jt_customers.name";

    // The engine may widen the SUM result to Int4, Int8 or Float8; normalize to
    // i64. Floats are rounded rather than truncated so that a result such as
    // 449.9999 is not misreported as 449.
    let as_i64 = |v: &Value| -> i64 {
        match v {
            Value::Int4(n) => i64::from(*n),
            Value::Int8(n) => *n,
            Value::Float8(f) => f.round() as i64,
            other => panic!("Unexpected type for SUM: {:?}", other),
        }
    };

    // Each expected total is paired with its customer name, so a mis-ordered
    // result cannot pass by matching a total against the wrong row.
    let expected = [("Alice", 450_i64), ("Bob", 500), ("Carol", 150)];

    match db.query(sql, &[]) {
        Ok(rows) => {
            assert_eq!(rows.len(), 3, "3 customers with orders, got {}", rows.len());
            for (row, &(name, total)) in rows.iter().zip(expected.iter()) {
                assert_eq!(row.get(0).unwrap(), &Value::String(name.to_string()));
                let got = as_i64(row.get(1).unwrap());
                assert_eq!(
                    got, total,
                    "{} total revenue should be {}, got {}",
                    name, total, got
                );
            }
        }
        Err(e) => {
            panic!("3-table JOIN with SUM aggregate failed: {}", e);
        }
    }
}
16103
/// Self-join on department with `e1.id < e2.id` produces each same-department
/// pair exactly once.
#[test]
fn test_join_self_join_same_department() {
    let db = setup_employee_db();

    let sql = "\
        SELECT e1.name, e2.name, e1.dept \
        FROM jt_employees e1 \
        JOIN jt_employees e2 ON e1.dept = e2.dept AND e1.id < e2.id \
        ORDER BY e1.dept, e1.id, e2.id";

    let rows = db
        .query(sql, &[])
        .unwrap_or_else(|e| panic!("Self-join same department failed: {}", e));

    assert_eq!(rows.len(), 4, "Expected 4 same-dept pairs, got {}", rows.len());

    // Both name columns must be String values; anything else is a failure.
    let as_string = |v: &Value| -> String {
        match v {
            Value::String(s) => s.clone(),
            other => panic!("Expected String, got {:?}", other),
        }
    };

    let pairs: Vec<(String, String)> = rows
        .iter()
        .map(|r| (as_string(r.get(0).unwrap()), as_string(r.get(1).unwrap())))
        .collect();

    let has_pair = |a: &str, b: &str| pairs.contains(&(a.to_string(), b.to_string()));
    assert!(has_pair("Frank", "Hank"), "Should contain (Frank, Hank)");
    assert!(has_pair("Grace", "Jack"), "Should contain (Grace, Jack)");
}
16154
/// CROSS JOIN restricted by a WHERE equality behaves like an inner join.
/// There is no ORDER BY, so results are compared after sorting.
#[test]
fn test_join_cross_join_with_where() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    let setup = [
        "CREATE TABLE jt_t1 (id INT PRIMARY KEY, val TEXT)",
        "INSERT INTO jt_t1 VALUES (1, 'a')",
        "INSERT INTO jt_t1 VALUES (2, 'b')",
        "CREATE TABLE jt_t2 (id INT PRIMARY KEY, t1_id INT, info TEXT)",
        "INSERT INTO jt_t2 VALUES (10, 1, 'info1')",
        "INSERT INTO jt_t2 VALUES (20, 2, 'info2')",
        "INSERT INTO jt_t2 VALUES (30, 1, 'info3')",
    ];
    for &stmt in setup.iter() {
        db.execute(stmt).unwrap();
    }

    let sql = "\
        SELECT jt_t1.val, jt_t2.info \
        FROM jt_t1 \
        CROSS JOIN jt_t2 \
        WHERE jt_t1.id = jt_t2.t1_id";

    let rows = db
        .query(sql, &[])
        .unwrap_or_else(|e| panic!("CROSS JOIN with WHERE failed: {}", e));

    assert_eq!(rows.len(), 3, "3 matching rows expected, got {}", rows.len());

    let as_string = |v: &Value| -> String {
        match v {
            Value::String(s) => s.clone(),
            other => panic!("Expected String, got {:?}", other),
        }
    };

    let mut pairs: Vec<(String, String)> = rows
        .iter()
        .map(|r| (as_string(r.get(0).unwrap()), as_string(r.get(1).unwrap())))
        .collect();
    // Equal tuples are indistinguishable, so an unstable sort gives the same result.
    pairs.sort_unstable();

    let expected: Vec<(String, String)> = [("a", "info1"), ("a", "info3"), ("b", "info2")]
        .iter()
        .map(|&(v, i)| (v.to_string(), i.to_string()))
        .collect();
    assert_eq!(
        pairs, expected,
        "CROSS JOIN + WHERE should produce the correct matching pairs"
    );
}
16213
/// LEFT JOIN + COUNT(column): customers without orders appear with a count
/// of 0.
#[test]
fn test_join_left_join_count_including_zero() {
    let db = setup_multi_table_join_db();

    let sql = "\
        SELECT jt_customers.name, COUNT(jt_orders.id) \
        FROM jt_customers \
        LEFT JOIN jt_orders ON jt_customers.id = jt_orders.customer_id \
        GROUP BY jt_customers.name \
        ORDER BY jt_customers.name";

    match db.query(sql, &[]) {
        Ok(rows) => {
            assert_eq!(rows.len(), 4, "All 4 customers should appear, got {}", rows.len());

            assert_eq!(rows[0].get(0).unwrap(), &Value::String("Alice".to_string()));
            assert_eq!(rows[0].get(1).unwrap(), &Value::Int8(2));
            assert_eq!(rows[1].get(0).unwrap(), &Value::String("Bob".to_string()));
            assert_eq!(rows[1].get(1).unwrap(), &Value::Int8(1));
            assert_eq!(rows[2].get(0).unwrap(), &Value::String("Carol".to_string()));
            assert_eq!(rows[2].get(1).unwrap(), &Value::Int8(1));
            assert_eq!(rows[3].get(0).unwrap(), &Value::String("Diana".to_string()));

            // Diana's only joined row is NULL-extended. COUNT(expr) skips NULLs,
            // so the count must be exactly 0. A result of 1 would mean the
            // engine is treating COUNT(col) like COUNT(*), which is a bug this
            // test must not tolerate.
            assert_eq!(
                rows[3].get(1).unwrap(),
                &Value::Int8(0),
                "COUNT(jt_orders.id) must ignore the NULL-extended row for Diana"
            );
        }
        Err(e) => {
            panic!("LEFT JOIN + COUNT with zero orders failed: {}", e);
        }
    }
}
16259
/// INNER JOIN followed by LEFT JOIN: every order is kept, and an order without
/// a review gets a NULL rating.
#[test]
fn test_join_three_table_mixed_join_types() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    // Only order 10 receives a review; order 20 is left unreviewed.
    let setup = [
        "CREATE TABLE jt_mix_cust (id INT PRIMARY KEY, name TEXT)",
        "INSERT INTO jt_mix_cust VALUES (1, 'Alice')",
        "INSERT INTO jt_mix_cust VALUES (2, 'Bob')",
        "CREATE TABLE jt_mix_orders (id INT PRIMARY KEY, cust_id INT, product TEXT)",
        "INSERT INTO jt_mix_orders VALUES (10, 1, 'Widget')",
        "INSERT INTO jt_mix_orders VALUES (20, 2, 'Gadget')",
        "CREATE TABLE jt_mix_reviews (id INT PRIMARY KEY, order_id INT, rating INT)",
        "INSERT INTO jt_mix_reviews VALUES (100, 10, 5)",
    ];
    for &stmt in setup.iter() {
        db.execute(stmt).unwrap();
    }

    let sql = "\
        SELECT jt_mix_cust.name, jt_mix_orders.product, jt_mix_reviews.rating \
        FROM jt_mix_cust \
        JOIN jt_mix_orders ON jt_mix_cust.id = jt_mix_orders.cust_id \
        LEFT JOIN jt_mix_reviews ON jt_mix_orders.id = jt_mix_reviews.order_id \
        ORDER BY jt_mix_cust.id";

    let rows = db
        .query(sql, &[])
        .unwrap_or_else(|e| panic!("Mixed INNER+LEFT JOIN failed: {}", e));

    assert_eq!(rows.len(), 2, "2 orders expected, got {}", rows.len());

    let expected = [
        ("Alice", "Widget", Value::Int4(5)),
        ("Bob", "Gadget", Value::Null),
    ];
    for (row, (name, product, rating)) in rows.iter().zip(expected.iter()) {
        assert_eq!(row.get(0).unwrap(), &Value::String(name.to_string()));
        assert_eq!(row.get(1).unwrap(), &Value::String(product.to_string()));
        assert_eq!(row.get(2).unwrap(), rating);
    }
}
16306
/// A three-table join filtered on a joined column (`city = 'NYC'`) only yields
/// orders from NYC customers.
#[test]
fn test_join_with_where_and_order_by() {
    let db = setup_multi_table_join_db();

    let sql = "\
        SELECT jt_customers.name, jt_products.name, jt_orders.qty \
        FROM jt_orders \
        JOIN jt_customers ON jt_orders.customer_id = jt_customers.id \
        JOIN jt_products ON jt_orders.product_id = jt_products.id \
        WHERE jt_customers.city = 'NYC' \
        ORDER BY jt_customers.name, jt_orders.qty";

    let rows = db
        .query(sql, &[])
        .unwrap_or_else(|e| panic!("JOIN + WHERE + ORDER BY failed: {}", e));

    assert_eq!(rows.len(), 3, "3 NYC orders expected, got {}", rows.len());

    // Quantities are compared as a sorted multiset, independent of row order.
    let mut qtys: Vec<i32> = rows
        .iter()
        .map(|r| match r.get(2).unwrap() {
            Value::Int4(v) => *v,
            other => panic!("Expected Int4 for qty, got {:?}", other),
        })
        .collect();
    qtys.sort_unstable();
    assert_eq!(qtys, vec![1, 2, 3], "NYC orders should have qty 1, 2, 3");

    rows.iter().for_each(|row| {
        let name = match row.get(0).unwrap() {
            Value::String(s) => s.as_str(),
            other => panic!("Expected String name, got {:?}", other),
        };
        assert!(
            matches!(name, "Alice" | "Carol"),
            "All results should be NYC customers, got '{}'",
            name
        );
    });
}
16355
/// INNER JOIN against a table with no rows returns an empty result.
#[test]
fn test_join_empty_table() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    let setup = [
        "CREATE TABLE jt_full (id INT PRIMARY KEY, val TEXT)",
        "INSERT INTO jt_full VALUES (1, 'one')",
        "INSERT INTO jt_full VALUES (2, 'two')",
        // Created but intentionally left empty.
        "CREATE TABLE jt_empty (id INT PRIMARY KEY, ref_id INT)",
    ];
    for &stmt in setup.iter() {
        db.execute(stmt).unwrap();
    }

    let sql = "\
        SELECT jt_full.val, jt_empty.ref_id \
        FROM jt_full \
        JOIN jt_empty ON jt_full.id = jt_empty.ref_id";

    let rows = db
        .query(sql, &[])
        .unwrap_or_else(|e| panic!("JOIN with empty table failed: {}", e));

    assert_eq!(rows.len(), 0, "JOIN with empty table should produce 0 rows, got {}", rows.len());
}
16382
/// LEFT JOIN against an empty right table keeps every left row and fills the
/// right columns with NULL.
#[test]
fn test_join_left_join_empty_right() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    let setup = [
        "CREATE TABLE jt_main (id INT PRIMARY KEY, name TEXT)",
        "INSERT INTO jt_main VALUES (1, 'alpha')",
        "INSERT INTO jt_main VALUES (2, 'beta')",
        // Right-hand table stays empty.
        "CREATE TABLE jt_detail (id INT PRIMARY KEY, main_id INT, note TEXT)",
    ];
    for &stmt in setup.iter() {
        db.execute(stmt).unwrap();
    }

    let sql = "\
        SELECT jt_main.name, jt_detail.note \
        FROM jt_main \
        LEFT JOIN jt_detail ON jt_main.id = jt_detail.main_id \
        ORDER BY jt_main.id";

    let rows = db
        .query(sql, &[])
        .unwrap_or_else(|e| panic!("LEFT JOIN with empty right table failed: {}", e));

    assert_eq!(rows.len(), 2, "All left rows should appear, got {}", rows.len());

    for (row, &name) in rows.iter().zip(["alpha", "beta"].iter()) {
        assert_eq!(row.get(0).unwrap(), &Value::String(name.to_string()));
        assert_eq!(row.get(1).unwrap(), &Value::Null);
    }
}
16415
/// CROSS JOIN of two one-row tables yields a single combined row.
#[test]
fn test_join_cross_join_single_row_tables() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    let setup = [
        "CREATE TABLE jt_single_a (val TEXT)",
        "INSERT INTO jt_single_a VALUES ('hello')",
        "CREATE TABLE jt_single_b (val TEXT)",
        "INSERT INTO jt_single_b VALUES ('world')",
    ];
    for &stmt in setup.iter() {
        db.execute(stmt).unwrap();
    }

    // Both tables have a column named `val`; qualification disambiguates them.
    let sql = "\
        SELECT jt_single_a.val, jt_single_b.val \
        FROM jt_single_a \
        CROSS JOIN jt_single_b";

    let rows = db
        .query(sql, &[])
        .unwrap_or_else(|e| panic!("CROSS JOIN single-row tables failed: {}", e));

    assert_eq!(rows.len(), 1, "1x1 CROSS JOIN should produce 1 row, got {}", rows.len());

    let only = &rows[0];
    assert_eq!(only.get(0).unwrap(), &Value::String("hello".to_string()));
    assert_eq!(only.get(1).unwrap(), &Value::String("world".to_string()));
}
16443
/// One-to-many join: the single parent row is repeated once per child.
#[test]
fn test_join_duplicate_column_values() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    let setup = [
        "CREATE TABLE jt_parent (id INT PRIMARY KEY, label TEXT)",
        "INSERT INTO jt_parent VALUES (1, 'group_a')",
        "CREATE TABLE jt_child (id INT PRIMARY KEY, parent_id INT, name TEXT)",
        "INSERT INTO jt_child VALUES (10, 1, 'child1')",
        "INSERT INTO jt_child VALUES (20, 1, 'child2')",
        "INSERT INTO jt_child VALUES (30, 1, 'child3')",
    ];
    for &stmt in setup.iter() {
        db.execute(stmt).unwrap();
    }

    let sql = "\
        SELECT jt_parent.label, jt_child.name \
        FROM jt_parent \
        JOIN jt_child ON jt_parent.id = jt_child.parent_id \
        ORDER BY jt_child.id";

    let rows = db
        .query(sql, &[])
        .unwrap_or_else(|e| panic!("One-to-many JOIN failed: {}", e));

    assert_eq!(rows.len(), 3, "1 parent x 3 children = 3 rows, got {}", rows.len());

    let text = |s: &str| Value::String(s.to_string());

    // First: the parent label is duplicated on every row.
    let label = text("group_a");
    rows.iter()
        .for_each(|row| assert_eq!(row.get(0).unwrap(), &label));

    // Then: children appear in jt_child.id order.
    for (row, &child) in rows.iter().zip(["child1", "child2", "child3"].iter()) {
        assert_eq!(row.get(1).unwrap(), &text(child));
    }
}
16480
/// Four consecutive INNER JOINs along a linear FK chain a -> b -> c -> d -> e.
#[test]
fn test_join_five_table_chain() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    // One row per table, each pointing at the single row of its predecessor.
    let setup = [
        "CREATE TABLE jt5_a (id INT PRIMARY KEY, name TEXT)",
        "INSERT INTO jt5_a VALUES (1, 'a1')",
        "CREATE TABLE jt5_b (id INT PRIMARY KEY, a_id INT, name TEXT)",
        "INSERT INTO jt5_b VALUES (10, 1, 'b1')",
        "CREATE TABLE jt5_c (id INT PRIMARY KEY, b_id INT, name TEXT)",
        "INSERT INTO jt5_c VALUES (100, 10, 'c1')",
        "CREATE TABLE jt5_d (id INT PRIMARY KEY, c_id INT, name TEXT)",
        "INSERT INTO jt5_d VALUES (1000, 100, 'd1')",
        "CREATE TABLE jt5_e (id INT PRIMARY KEY, d_id INT, name TEXT)",
        "INSERT INTO jt5_e VALUES (10000, 1000, 'e1')",
    ];
    for &stmt in setup.iter() {
        db.execute(stmt).unwrap();
    }

    let sql = "\
        SELECT jt5_a.name, jt5_b.name, jt5_c.name, jt5_d.name, jt5_e.name \
        FROM jt5_a \
        JOIN jt5_b ON jt5_a.id = jt5_b.a_id \
        JOIN jt5_c ON jt5_b.id = jt5_c.b_id \
        JOIN jt5_d ON jt5_c.id = jt5_d.c_id \
        JOIN jt5_e ON jt5_d.id = jt5_e.d_id";

    let rows = db
        .query(sql, &[])
        .unwrap_or_else(|e| panic!("5-table JOIN chain failed: {}", e));

    assert_eq!(rows.len(), 1, "5-table chain should produce 1 row, got {}", rows.len());

    let row = &rows[0];
    for (col, &name) in ["a1", "b1", "c1", "d1", "e1"].iter().enumerate() {
        assert_eq!(row.get(col).unwrap(), &Value::String(name.to_string()));
    }
}
16523
/// Double self-join (employee -> manager -> grand-manager). Employees lacking
/// a grand-manager are dropped by the INNER JOINs.
#[test]
fn test_join_self_join_two_levels() {
    let db = setup_employee_db();

    let sql = "\
        SELECT e.name, m.name, gm.name \
        FROM jt_employees e \
        JOIN jt_employees m ON e.manager_id = m.id \
        JOIN jt_employees gm ON m.manager_id = gm.id \
        ORDER BY e.id";

    let rows = db
        .query(sql, &[])
        .unwrap_or_else(|e| panic!("Two-level self-join (grandmanager) failed: {}", e));

    assert_eq!(rows.len(), 3, "3 employees have grandmanagers, got {}", rows.len());

    // (employee, manager, grand-manager) in e.id order.
    let chains = [
        ["Hank", "Frank", "Eve"],
        ["Iris", "Frank", "Eve"],
        ["Jack", "Grace", "Eve"],
    ];
    for (row, chain) in rows.iter().zip(chains.iter()) {
        for (col, &name) in chain.iter().enumerate() {
            assert_eq!(row.get(col).unwrap(), &Value::String(name.to_string()));
        }
    }
}
16558
/// Regression coverage for table-alias resolution in WHERE/ON clauses, modelled
/// on WordPress term/taxonomy queries.
#[test]
fn test_join_alias_column_resolution_in_where() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    db.execute("CREATE TABLE wp_term_taxonomy (term_taxonomy_id INT PRIMARY KEY, term_id INT, taxonomy TEXT)").unwrap();
    db.execute("CREATE TABLE wp_terms (term_id INT PRIMARY KEY, name TEXT)").unwrap();
    db.execute("INSERT INTO wp_terms VALUES (1, 'Uncategorized')").unwrap();
    db.execute("INSERT INTO wp_terms VALUES (2, 'News')").unwrap();
    db.execute("INSERT INTO wp_term_taxonomy VALUES (1, 1, 'category')").unwrap();
    db.execute("INSERT INTO wp_term_taxonomy VALUES (2, 2, 'category')").unwrap();
    db.execute("INSERT INTO wp_term_taxonomy VALUES (3, 2, 'post_tag')").unwrap();

    // Aliased column in WHERE (tt.taxonomy) must resolve against the alias.
    let rows = db.query(
        "SELECT tt.term_taxonomy_id FROM wp_term_taxonomy AS tt \
         INNER JOIN wp_terms AS t ON t.term_id = tt.term_id \
         WHERE tt.taxonomy = 'category'",
        &[],
    ).expect("WordPress-style JOIN with aliased WHERE column should work");
    assert_eq!(rows.len(), 2, "Should find 2 category rows");

    // Columns selected from both aliases, with ORDER BY on an aliased column.
    let rows = db.query(
        "SELECT t.name, tt.taxonomy FROM wp_term_taxonomy AS tt \
         INNER JOIN wp_terms AS t ON t.term_id = tt.term_id \
         WHERE tt.taxonomy = 'category' ORDER BY t.name",
        &[],
    ).expect("Multi-column aliased JOIN should work");
    assert_eq!(rows.len(), 2);
    // Verify the ORDER BY t.name as well, not just the row count:
    // 'News' sorts before 'Uncategorized'.
    let expected = [("News", "category"), ("Uncategorized", "category")];
    for (row, &(name, taxonomy)) in rows.iter().zip(expected.iter()) {
        assert_eq!(row.get(0).unwrap(), &Value::String(name.to_string()));
        assert_eq!(row.get(1).unwrap(), &Value::String(taxonomy.to_string()));
    }

    // Generic single-letter aliases.
    db.execute("CREATE TABLE t1 (id INT PRIMARY KEY, name TEXT)").unwrap();
    db.execute("CREATE TABLE t2 (id INT PRIMARY KEY, t1_id INT, value TEXT)").unwrap();
    db.execute("INSERT INTO t1 VALUES (1, 'Alice')").unwrap();
    db.execute("INSERT INTO t2 VALUES (1, 1, 'hello')").unwrap();

    let rows = db.query(
        "SELECT a.name, b.value FROM t1 AS a INNER JOIN t2 AS b ON a.id = b.t1_id WHERE a.name = 'Alice'",
        &[],
    ).expect("JOIN with aliased WHERE column should work");
    assert_eq!(rows.len(), 1);
    assert_eq!(rows[0].get(0).unwrap(), &Value::String("Alice".to_string()));
    assert_eq!(rows[0].get(1).unwrap(), &Value::String("hello".to_string()));

    // Three aliased tables chained together.
    db.execute("CREATE TABLE wp_term_relationships (object_id INT, term_taxonomy_id INT)").unwrap();
    db.execute("INSERT INTO wp_term_relationships VALUES (10, 1)").unwrap();
    db.execute("INSERT INTO wp_term_relationships VALUES (20, 2)").unwrap();
    db.execute("INSERT INTO wp_term_relationships VALUES (30, 3)").unwrap();

    let rows = db.query(
        "SELECT tr.object_id, tt.taxonomy, t.name \
         FROM wp_term_relationships AS tr \
         INNER JOIN wp_term_taxonomy AS tt ON tr.term_taxonomy_id = tt.term_taxonomy_id \
         INNER JOIN wp_terms AS t ON t.term_id = tt.term_id \
         WHERE tt.taxonomy = 'category'",
        &[],
    ).expect("Three-table WordPress-style JOIN should work");
    assert_eq!(rows.len(), 2, "Should find 2 relationships with category taxonomy");

    // Same as the first query but with the ON operands swapped.
    let rows = db.query(
        "SELECT tt.term_taxonomy_id FROM wp_term_taxonomy AS tt \
         INNER JOIN wp_terms AS t ON tt.term_id = t.term_id \
         WHERE tt.taxonomy = 'category'",
        &[],
    ).expect("JOIN with swapped ON column order should work");
    assert_eq!(rows.len(), 2, "Should still find 2 category rows with swapped ON order");
}
16632
/// TRUNCATE TABLE removes every row.
#[test]
fn test_truncate_basic() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    let setup = [
        "CREATE TABLE trunc_basic (id INT PRIMARY KEY, name TEXT)",
        "INSERT INTO trunc_basic VALUES (1, 'Alice')",
        "INSERT INTO trunc_basic VALUES (2, 'Bob')",
        "INSERT INTO trunc_basic VALUES (3, 'Charlie')",
    ];
    for &stmt in setup.iter() {
        db.execute(stmt).unwrap();
    }

    let before = db.query("SELECT * FROM trunc_basic", &[]).unwrap();
    assert_eq!(before.len(), 3, "Should have 3 rows before TRUNCATE");

    db.execute("TRUNCATE TABLE trunc_basic").unwrap();

    let after = db.query("SELECT * FROM trunc_basic", &[]).unwrap();
    assert_eq!(after.len(), 0, "Should have 0 rows after TRUNCATE");
}
16654
/// After TRUNCATE the table definition survives and accepts new inserts.
#[test]
fn test_truncate_preserves_schema() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    db.execute("CREATE TABLE trunc_schema (id INT PRIMARY KEY, name TEXT, score FLOAT)").unwrap();
    db.execute("INSERT INTO trunc_schema VALUES (1, 'Alice', 95.5)").unwrap();

    db.execute("TRUNCATE TABLE trunc_schema").unwrap();

    // Inserting with the full column list proves the schema is still intact.
    db.execute("INSERT INTO trunc_schema VALUES (10, 'David', 88.0)").unwrap();

    let rows = db.query("SELECT id, name, score FROM trunc_schema", &[]).unwrap();
    assert_eq!(rows.len(), 1, "Should have 1 row after re-insert");

    let row = &rows[0];
    assert_eq!(row.get(0).unwrap(), &Value::Int4(10));
    assert_eq!(row.get(1).unwrap(), &Value::String("David".to_string()));
}
16671
/// TRUNCATE on a table that is already empty is a successful no-op.
#[test]
fn test_truncate_empty_table() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    db.execute("CREATE TABLE trunc_empty (id INT PRIMARY KEY, val TEXT)").unwrap();

    let outcome = db.execute("TRUNCATE TABLE trunc_empty");
    assert!(
        outcome.is_ok(),
        "TRUNCATE on empty table should succeed, got: {:?}",
        outcome.err()
    );

    let remaining = db.query("SELECT * FROM trunc_empty", &[]).unwrap().len();
    assert_eq!(remaining, 0, "Empty table should remain empty after TRUNCATE");
}
16685
/// After TRUNCATE, primary-key values that were previously used (id 1) can be
/// inserted again.
#[test]
fn test_truncate_reinsert_after() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    let before = [
        "CREATE TABLE trunc_reinsert (id INT PRIMARY KEY, label TEXT)",
        "INSERT INTO trunc_reinsert VALUES (1, 'first')",
        "INSERT INTO trunc_reinsert VALUES (2, 'second')",
        "TRUNCATE TABLE trunc_reinsert",
        // Re-using id 1 must not hit a stale primary-key entry.
        "INSERT INTO trunc_reinsert VALUES (1, 'new_first')",
        "INSERT INTO trunc_reinsert VALUES (3, 'third')",
    ];
    for &stmt in before.iter() {
        db.execute(stmt).unwrap();
    }

    let rows = db.query("SELECT * FROM trunc_reinsert ORDER BY id", &[]).unwrap();
    assert_eq!(rows.len(), 2, "Should have 2 rows after re-insert");

    for (row, &label) in rows.iter().zip(["new_first", "third"].iter()) {
        assert_eq!(row.get(1).unwrap(), &Value::String(label.to_string()));
    }
}
16705
/// TRUNCATE affects only the named table; a sibling table keeps its rows
/// until it is truncated too.
#[test]
fn test_truncate_multiple_tables() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    let setup = [
        "CREATE TABLE trunc_a (id INT PRIMARY KEY, val TEXT)",
        "CREATE TABLE trunc_b (id INT PRIMARY KEY, val TEXT)",
        "INSERT INTO trunc_a VALUES (1, 'a1')",
        "INSERT INTO trunc_a VALUES (2, 'a2')",
        "INSERT INTO trunc_b VALUES (10, 'b1')",
        "INSERT INTO trunc_b VALUES (20, 'b2')",
        "INSERT INTO trunc_b VALUES (30, 'b3')",
    ];
    for &stmt in setup.iter() {
        db.execute(stmt).unwrap();
    }

    db.execute("TRUNCATE TABLE trunc_a").unwrap();

    let a_len = db.query("SELECT * FROM trunc_a", &[]).unwrap().len();
    let b_len = db.query("SELECT * FROM trunc_b", &[]).unwrap().len();
    assert_eq!(a_len, 0, "Table A should be empty after TRUNCATE");
    assert_eq!(b_len, 3, "Table B should be unaffected by TRUNCATE of A");

    db.execute("TRUNCATE TABLE trunc_b").unwrap();
    let b_len = db.query("SELECT * FROM trunc_b", &[]).unwrap().len();
    assert_eq!(b_len, 0, "Table B should be empty after TRUNCATE");
}
16732
/// TRUNCATE removes all 150 rows of a populated table. `COUNT(*)` is expected
/// to be reported as `Value::Int8` both before and after.
#[test]
fn test_truncate_with_many_rows() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE trunc_many (id INT PRIMARY KEY, val INT)").unwrap();

    (1..=150).for_each(|i| {
        let sql = format!("INSERT INTO trunc_many VALUES ({}, {})", i, i * 10);
        db.execute(&sql).unwrap();
    });

    let before = db.query("SELECT COUNT(*) FROM trunc_many", &[]).unwrap();
    let before_cell = before[0].get(0).unwrap();
    if let Value::Int8(n) = before_cell {
        assert_eq!(*n, 150, "Should have 150 rows before TRUNCATE");
    } else {
        panic!("Expected Int8 count, got {:?}", before_cell);
    }

    db.execute("TRUNCATE TABLE trunc_many").unwrap();

    let after = db.query("SELECT COUNT(*) FROM trunc_many", &[]).unwrap();
    let after_cell = after[0].get(0).unwrap();
    if let Value::Int8(n) = after_cell {
        assert_eq!(*n, 0, "Should have 0 rows after TRUNCATE");
    } else {
        panic!("Expected Int8 count of 0, got {:?}", after_cell);
    }
}
16758
/// After TRUNCATE and re-insertion, two queries see exactly the new rows:
/// a primary-key equality lookup, and an ORDER BY on a non-key column.
#[test]
fn test_truncate_preserves_indexes() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    let script = [
        "CREATE TABLE trunc_idx (id INT PRIMARY KEY, name TEXT, score INT)",
        "INSERT INTO trunc_idx VALUES (1, 'Alice', 90)",
        "INSERT INTO trunc_idx VALUES (2, 'Bob', 85)",
        "TRUNCATE TABLE trunc_idx",
        "INSERT INTO trunc_idx VALUES (5, 'Eve', 95)",
        "INSERT INTO trunc_idx VALUES (6, 'Frank', 80)",
    ];
    for stmt in script {
        db.execute(stmt).unwrap();
    }

    let by_pk = db.query("SELECT * FROM trunc_idx WHERE id = 5", &[]).unwrap();
    assert_eq!(by_pk.len(), 1, "PK lookup should work after TRUNCATE + re-insert");
    assert_eq!(by_pk[0].get(1).unwrap(), &Value::String("Eve".to_string()));

    let by_score = db.query("SELECT name FROM trunc_idx ORDER BY score DESC", &[]).unwrap();
    assert_eq!(by_score.len(), 2);
    for (row, name) in by_score.iter().zip(["Eve", "Frank"]) {
        assert_eq!(row.get(0).unwrap(), &Value::String(name.to_string()));
    }
}
16783
/// TRUNCATE of an unknown table must fail. The (case-insensitive) error text
/// must name the table or say it does not exist / was not found.
#[test]
fn test_truncate_nonexistent_table() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();

    let result = db.execute("TRUNCATE TABLE no_such_table");
    assert!(result.is_err(), "TRUNCATE on non-existent table should error");

    let err_msg = result.unwrap_err().to_string();
    let lowered = err_msg.to_lowercase();
    let mentions_missing_table = ["no_such_table", "not exist", "not found"]
        .iter()
        .any(|&needle| lowered.contains(needle));
    assert!(mentions_missing_table, "Error should mention missing table, got: {}", err_msg);
}
16799
/// TRUNCATE on a table holding three rows returns 3 from `execute`, i.e. the
/// number of rows removed, and leaves the table empty.
///
/// NOTE(review): despite the `_returns_zero` suffix in the name, the assertion
/// expects the removed-row count (3), not zero.
#[test]
fn test_truncate_returns_zero() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE trunc_count (id INT PRIMARY KEY)").unwrap();
    for insert in [
        "INSERT INTO trunc_count VALUES (1)",
        "INSERT INTO trunc_count VALUES (2)",
        "INSERT INTO trunc_count VALUES (3)",
    ] {
        db.execute(insert).unwrap();
    }

    let removed = db.execute("TRUNCATE TABLE trunc_count").unwrap();
    assert_eq!(removed, 3, "TRUNCATE returns actual row count");

    let leftover = db.query("SELECT * FROM trunc_count", &[]).unwrap();
    assert_eq!(leftover.len(), 0, "All rows should be removed");
}
16819
/// `COUNT(*)` reports 0 immediately after TRUNCATE, then 2 after two fresh
/// inserts. The count is expected as `Value::Int8`.
#[test]
fn test_truncate_then_count() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE trunc_cnt (id INT PRIMARY KEY, x INT)").unwrap();
    for i in 1..=5 {
        let insert = format!("INSERT INTO trunc_cnt VALUES ({}, {})", i, i);
        db.execute(&insert).unwrap();
    }

    db.execute("TRUNCATE TABLE trunc_cnt").unwrap();

    let after_truncate = db.query("SELECT COUNT(*) FROM trunc_cnt", &[]).unwrap();
    let cell = after_truncate[0].get(0).unwrap();
    let Value::Int8(n) = cell else {
        panic!("Expected Int8(0), got {:?}", cell);
    };
    assert_eq!(*n, 0);

    db.execute("INSERT INTO trunc_cnt VALUES (10, 100)").unwrap();
    db.execute("INSERT INTO trunc_cnt VALUES (20, 200)").unwrap();

    let after_reinsert = db.query("SELECT COUNT(*) FROM trunc_cnt", &[]).unwrap();
    let cell = after_reinsert[0].get(0).unwrap();
    let Value::Int8(n) = cell else {
        panic!("Expected Int8(2), got {:?}", cell);
    };
    assert_eq!(*n, 2, "COUNT should be 2 after re-inserting 2 rows");
}
16847
/// A table declaring a FOREIGN KEY that references an existing parent's
/// primary key can be created.
#[test]
fn test_fk_basic_creation() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE fk_parent (id INT PRIMARY KEY, name TEXT)").unwrap();

    let child_ddl = "CREATE TABLE fk_child (
            id INT PRIMARY KEY,
            parent_id INT,
            FOREIGN KEY (parent_id) REFERENCES fk_parent(id)
        )";
    let created = db.execute(child_ddl);
    assert!(created.is_ok(), "Creating table with FK constraint should succeed, got: {:?}", created.err());
}
16867
/// A child row whose FK value matches an existing parent key is accepted and
/// is then visible through a query.
#[test]
fn test_fk_insert_valid() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE fk_iv_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
    db.execute(
        "CREATE TABLE fk_iv_child (
            id INT PRIMARY KEY,
            parent_id INT,
            FOREIGN KEY (parent_id) REFERENCES fk_iv_parent(id)
        )",
    )
    .unwrap();
    db.execute("INSERT INTO fk_iv_parent VALUES (1, 'Alice')").unwrap();

    let inserted = db.execute("INSERT INTO fk_iv_child VALUES (100, 1)");
    assert!(inserted.is_ok(), "Insert with valid FK reference should succeed, got: {:?}", inserted.err());

    let matching = db.query("SELECT * FROM fk_iv_child WHERE parent_id = 1", &[]).unwrap();
    assert_eq!(matching.len(), 1, "Child row should be inserted");
}
16888
/// Inserting a child that references a missing parent key (999) is rejected.
/// The error text must mention "foreign key" or "constraint" (case-insensitive).
#[test]
fn test_fk_insert_invalid() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE fk_ii_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
    db.execute(
        "CREATE TABLE fk_ii_child (
            id INT PRIMARY KEY,
            parent_id INT,
            FOREIGN KEY (parent_id) REFERENCES fk_ii_parent(id)
        )",
    )
    .unwrap();

    let Err(error) = db.execute("INSERT INTO fk_ii_child VALUES (1, 999)") else {
        panic!("Insert with invalid FK reference should fail");
    };
    let err_msg = error.to_string();
    let lowered = err_msg.to_lowercase();
    assert!(
        lowered.contains("foreign key") || lowered.contains("constraint"),
        "Error should mention foreign key constraint, got: {}",
        err_msg
    );
}
16912
/// A NULL FK value is not checked against the parent table: the insert
/// succeeds and NULL is stored.
#[test]
fn test_fk_insert_null_fk_value() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE fk_null_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
    db.execute(
        "CREATE TABLE fk_null_child (
            id INT PRIMARY KEY,
            parent_id INT,
            FOREIGN KEY (parent_id) REFERENCES fk_null_parent(id)
        )",
    )
    .unwrap();

    let outcome = db.execute("INSERT INTO fk_null_child VALUES (1, NULL)");
    assert!(outcome.is_ok(), "Insert with NULL FK value should succeed, got: {:?}", outcome.err());

    let stored = db.query("SELECT * FROM fk_null_child", &[]).unwrap();
    assert_eq!(stored.len(), 1);
    let parent_ref = stored[0].get(1).unwrap();
    assert_eq!(parent_ref, &Value::Null);
}
16934
/// With no ON DELETE clause, deleting a referenced parent must fail.
/// The error must mention the FK, and neither the parent nor the child row
/// may be removed.
#[test]
fn test_fk_delete_parent_default_action() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE fk_dp_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
    db.execute(
        "CREATE TABLE fk_dp_child (
            id INT PRIMARY KEY,
            parent_id INT,
            FOREIGN KEY (parent_id) REFERENCES fk_dp_parent(id)
        )",
    )
    .unwrap();
    db.execute("INSERT INTO fk_dp_parent VALUES (1, 'Alice')").unwrap();
    db.execute("INSERT INTO fk_dp_child VALUES (100, 1)").unwrap();

    let Err(error) = db.execute("DELETE FROM fk_dp_parent WHERE id = 1") else {
        panic!("Deleting parent with referencing children should fail with default action");
    };
    let err_msg = error.to_string();
    let lowered = err_msg.to_lowercase();
    let mentions_fk = ["foreign key", "constraint", "referenced"]
        .iter()
        .any(|&needle| lowered.contains(needle));
    assert!(mentions_fk, "Error should mention FK constraint, got: {}", err_msg);

    let parent_count = db.query("SELECT * FROM fk_dp_parent", &[]).unwrap().len();
    let child_count = db.query("SELECT * FROM fk_dp_child", &[]).unwrap().len();
    assert_eq!(parent_count, 1, "Parent should still exist");
    assert_eq!(child_count, 1, "Child should still exist");
}
16967
/// With `ON DELETE CASCADE`, deleting parent 1 also removes its two children.
/// Parent 2 and its child 102 are left untouched.
#[test]
fn test_fk_cascade_delete() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE fk_cd_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
    db.execute(
        "CREATE TABLE fk_cd_child (
            id INT PRIMARY KEY,
            parent_id INT,
            FOREIGN KEY (parent_id) REFERENCES fk_cd_parent(id) ON DELETE CASCADE
        )",
    )
    .unwrap();

    for stmt in [
        "INSERT INTO fk_cd_parent VALUES (1, 'Alice')",
        "INSERT INTO fk_cd_parent VALUES (2, 'Bob')",
        "INSERT INTO fk_cd_child VALUES (100, 1)",
        "INSERT INTO fk_cd_child VALUES (101, 1)",
        "INSERT INTO fk_cd_child VALUES (102, 2)",
        "DELETE FROM fk_cd_parent WHERE id = 1",
    ] {
        db.execute(stmt).unwrap();
    }

    let parents = db.query("SELECT * FROM fk_cd_parent", &[]).unwrap();
    assert_eq!(parents.len(), 1, "Only parent id=2 should remain");
    assert_eq!(parents[0].get(0).unwrap(), &Value::Int4(2));

    let children = db.query("SELECT * FROM fk_cd_child", &[]).unwrap();
    assert_eq!(children.len(), 1, "Only child 102 (referencing parent 2) should remain");
    assert_eq!(children[0].get(0).unwrap(), &Value::Int4(102));
}
16998
/// With `ON DELETE SET NULL`, deleting the parent keeps both child rows.
/// Their `parent_id` becomes NULL.
#[test]
fn test_fk_set_null_delete() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE fk_sn_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
    db.execute(
        "CREATE TABLE fk_sn_child (
            id INT PRIMARY KEY,
            parent_id INT,
            FOREIGN KEY (parent_id) REFERENCES fk_sn_parent(id) ON DELETE SET NULL
        )",
    )
    .unwrap();

    for stmt in [
        "INSERT INTO fk_sn_parent VALUES (1, 'Alice')",
        "INSERT INTO fk_sn_child VALUES (100, 1)",
        "INSERT INTO fk_sn_child VALUES (101, 1)",
        "DELETE FROM fk_sn_parent WHERE id = 1",
    ] {
        db.execute(stmt).unwrap();
    }

    let children = db.query("SELECT id, parent_id FROM fk_sn_child ORDER BY id", &[]).unwrap();
    assert_eq!(children.len(), 2, "Child rows should still exist");
    for child in &children {
        assert_eq!(child.get(1).unwrap(), &Value::Null, "parent_id should be NULL after SET NULL");
    }
}
17025
/// `ON DELETE RESTRICT` blocks deletion of a referenced parent. The error must
/// mention the FK, and neither the parent nor the child row may be removed.
#[test]
fn test_fk_restrict_delete() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE fk_rd_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
    db.execute(
        "CREATE TABLE fk_rd_child (
            id INT PRIMARY KEY,
            parent_id INT,
            FOREIGN KEY (parent_id) REFERENCES fk_rd_parent(id) ON DELETE RESTRICT
        )",
    )
    .unwrap();
    db.execute("INSERT INTO fk_rd_parent VALUES (1, 'Alice')").unwrap();
    db.execute("INSERT INTO fk_rd_child VALUES (100, 1)").unwrap();

    let Err(error) = db.execute("DELETE FROM fk_rd_parent WHERE id = 1") else {
        panic!("RESTRICT should prevent parent deletion");
    };
    let err_msg = error.to_string();
    let lowered = err_msg.to_lowercase();
    assert!(
        lowered.contains("foreign key") || lowered.contains("constraint"),
        "Error should mention FK constraint, got: {}",
        err_msg
    );

    let parent_count = db.query("SELECT * FROM fk_rd_parent", &[]).unwrap().len();
    let child_count = db.query("SELECT * FROM fk_rd_child", &[]).unwrap().len();
    assert_eq!(parent_count, 1);
    assert_eq!(child_count, 1);
}
17057
/// `ON DELETE RESTRICT` applies only to referenced parents. An unreferenced
/// parent (id 2) can be deleted while referenced parent 1 stays.
#[test]
fn test_fk_restrict_allows_delete_when_no_children() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE fk_ra_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
    db.execute(
        "CREATE TABLE fk_ra_child (
            id INT PRIMARY KEY,
            parent_id INT,
            FOREIGN KEY (parent_id) REFERENCES fk_ra_parent(id) ON DELETE RESTRICT
        )",
    )
    .unwrap();

    for stmt in [
        "INSERT INTO fk_ra_parent VALUES (1, 'Alice')",
        "INSERT INTO fk_ra_parent VALUES (2, 'Bob')",
        "INSERT INTO fk_ra_child VALUES (100, 1)",
    ] {
        db.execute(stmt).unwrap();
    }

    let deleted = db.execute("DELETE FROM fk_ra_parent WHERE id = 2");
    assert!(deleted.is_ok(), "Should allow deletion of unreferenced parent, got: {:?}", deleted.err());

    let survivors = db.query("SELECT * FROM fk_ra_parent", &[]).unwrap().len();
    assert_eq!(survivors, 1, "Only parent 1 should remain");
}
17083
/// `ON DELETE NO ACTION` rejects deleting a parent that still has children.
#[test]
fn test_fk_no_action_delete() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    let setup = [
        "CREATE TABLE fk_na_parent (id INT PRIMARY KEY, name TEXT)",
        "CREATE TABLE fk_na_child (
            id INT PRIMARY KEY,
            parent_id INT,
            FOREIGN KEY (parent_id) REFERENCES fk_na_parent(id) ON DELETE NO ACTION
        )",
        "INSERT INTO fk_na_parent VALUES (1, 'Alice')",
        "INSERT INTO fk_na_child VALUES (100, 1)",
    ];
    for stmt in setup {
        db.execute(stmt).unwrap();
    }

    let rejected = db.execute("DELETE FROM fk_na_parent WHERE id = 1").is_err();
    assert!(rejected, "NO ACTION should prevent parent deletion when children exist");
}
17103
/// A table can reference its own primary key. A NULL reference and a
/// reference to an existing row are accepted; a dangling reference (999) is not.
#[test]
fn test_fk_self_referencing() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute(
        "CREATE TABLE fk_self_emp (
            id INT PRIMARY KEY,
            name TEXT,
            manager_id INT,
            FOREIGN KEY (manager_id) REFERENCES fk_self_emp(id)
        )",
    )
    .unwrap();

    // The root row has no manager; the second row points at it.
    for insert in [
        "INSERT INTO fk_self_emp VALUES (1, 'CEO', NULL)",
        "INSERT INTO fk_self_emp VALUES (2, 'VP', 1)",
    ] {
        db.execute(insert).unwrap();
    }

    let employees = db.query("SELECT * FROM fk_self_emp ORDER BY id", &[]).unwrap();
    assert_eq!(employees.len(), 2, "Should have 2 employees");

    let dangling = db.execute("INSERT INTO fk_self_emp VALUES (3, 'Ghost', 999)");
    assert!(dangling.is_err(), "Self-referencing FK with invalid ID should fail");
}
17130
/// Each of two FKs on the same table is enforced independently. A row is
/// accepted only when both references resolve.
#[test]
fn test_fk_multiple_fks_on_one_table() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE fk_m_departments (id INT PRIMARY KEY, name TEXT)").unwrap();
    db.execute("CREATE TABLE fk_m_managers (id INT PRIMARY KEY, name TEXT)").unwrap();
    db.execute(
        "CREATE TABLE fk_m_employees (
            id INT PRIMARY KEY,
            name TEXT,
            dept_id INT,
            manager_id INT,
            FOREIGN KEY (dept_id) REFERENCES fk_m_departments(id),
            FOREIGN KEY (manager_id) REFERENCES fk_m_managers(id)
        )",
    )
    .unwrap();
    db.execute("INSERT INTO fk_m_departments VALUES (1, 'Engineering')").unwrap();
    db.execute("INSERT INTO fk_m_managers VALUES (10, 'Alice')").unwrap();

    let valid = db.execute("INSERT INTO fk_m_employees VALUES (100, 'Bob', 1, 10)");
    assert!(valid.is_ok(), "Insert with valid references to both FK parents should succeed, got: {:?}", valid.err());

    // Each statement breaks exactly one of the two references.
    let broken = [
        ("INSERT INTO fk_m_employees VALUES (101, 'Carol', 999, 10)", "Insert with invalid dept FK should fail"),
        ("INSERT INTO fk_m_employees VALUES (102, 'Dave', 1, 999)", "Insert with invalid manager FK should fail"),
    ];
    for (sql, why) in broken {
        assert!(db.execute(sql).is_err(), "{}", why);
    }
}
17163
/// With `ON DELETE CASCADE`, deleting one parent removes all five of its
/// children. Counts are expected as `Value::Int8`.
#[test]
fn test_fk_cascade_delete_multiple_children() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE fk_cm_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
    db.execute(
        "CREATE TABLE fk_cm_child (
            id INT PRIMARY KEY,
            parent_id INT,
            label TEXT,
            FOREIGN KEY (parent_id) REFERENCES fk_cm_parent(id) ON DELETE CASCADE
        )",
    )
    .unwrap();

    db.execute("INSERT INTO fk_cm_parent VALUES (1, 'Alpha')").unwrap();
    for i in 1..=5 {
        let insert = format!("INSERT INTO fk_cm_child VALUES ({}, 1, 'child_{}')", i, i);
        db.execute(&insert).unwrap();
    }

    let count_children = || db.query("SELECT COUNT(*) FROM fk_cm_child", &[]).unwrap();

    let before = count_children();
    let cell = before[0].get(0).unwrap();
    let Value::Int8(n) = cell else {
        panic!("Expected 5 children, got {:?}", cell);
    };
    assert_eq!(*n, 5);

    db.execute("DELETE FROM fk_cm_parent WHERE id = 1").unwrap();

    let after = count_children();
    let cell = after[0].get(0).unwrap();
    let Value::Int8(n) = cell else {
        panic!("Expected 0 children after cascade, got {:?}", cell);
    };
    assert_eq!(*n, 0, "All children should be cascade-deleted");
}
17198
/// Dropping a still-referenced parent table is checked for either outcome:
/// - if DROP succeeds, the child table and its row must survive;
/// - if DROP fails, the error must mention the FK dependency and the parent
///   row must still be present.
#[test]
fn test_fk_drop_parent_table() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE fk_drop_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
    db.execute(
        "CREATE TABLE fk_drop_child (
            id INT PRIMARY KEY,
            parent_id INT,
            FOREIGN KEY (parent_id) REFERENCES fk_drop_parent(id)
        )",
    )
    .unwrap();
    db.execute("INSERT INTO fk_drop_parent VALUES (1, 'Alice')").unwrap();
    db.execute("INSERT INTO fk_drop_child VALUES (100, 1)").unwrap();

    match db.execute("DROP TABLE fk_drop_parent") {
        Err(e) => {
            let err_msg = e.to_string();
            let lowered = err_msg.to_lowercase();
            let mentions_dependency = ["foreign key", "referenced", "depends"]
                .iter()
                .any(|&needle| lowered.contains(needle));
            assert!(mentions_dependency, "Error should mention FK dependency, got: {}", err_msg);

            let parents = db.query("SELECT * FROM fk_drop_parent", &[]).unwrap();
            assert_eq!(parents.len(), 1, "Parent should still exist after failed drop");
        }
        Ok(_) => {
            let children = db.query("SELECT * FROM fk_drop_child", &[]).unwrap();
            assert_eq!(children.len(), 1, "Child table data should still exist after parent drop");
        }
    }
}
17239
/// Updates a referenced parent key under `ON UPDATE CASCADE`.
///
/// The test deliberately tolerates partial engine support:
/// - If the UPDATE returns an error, the error is only printed and the test passes.
/// - If the UPDATE succeeds, child 100 must still exist.
/// - Its `parent_id` must then be either 10 (cascade applied) or 1 (cascade
///   not enforced). Any other value, or a non-`Int4` cell, fails the test.
#[test]
fn test_fk_cascade_update() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE fk_cu_parent (id INT PRIMARY KEY, name TEXT)").unwrap();
    db.execute(
        "CREATE TABLE fk_cu_child (
            id INT PRIMARY KEY,
            parent_id INT,
            FOREIGN KEY (parent_id) REFERENCES fk_cu_parent(id) ON UPDATE CASCADE
        )",
    )
    .unwrap();

    db.execute("INSERT INTO fk_cu_parent VALUES (1, 'Alice')").unwrap();
    db.execute("INSERT INTO fk_cu_child VALUES (100, 1)").unwrap();

    match db.execute("UPDATE fk_cu_parent SET id = 10 WHERE id = 1") {
        Ok(_) => {
            let child_rows = db.query("SELECT parent_id FROM fk_cu_child WHERE id = 100", &[]).unwrap();
            assert_eq!(child_rows.len(), 1, "Child should still exist");
            match child_rows[0].get(0).unwrap() {
                Value::Int4(v) => {
                    // 10 means the cascade was applied; any other value must be
                    // the untouched original key.
                    if *v != 10 {
                        assert_eq!(*v, 1, "Without cascade enforcement, child should retain old FK value");
                    }
                }
                other => panic!("Expected Int4 for parent_id, got {:?}", other),
            }
        }
        Err(e) => {
            // Best-effort: an engine that rejects updating a referenced PK is accepted.
            println!("UPDATE parent PK with ON UPDATE CASCADE result: {}", e);
        }
    }
}
17283
/// Once its only child has been deleted, a parent can be deleted under the
/// default FK action.
#[test]
fn test_fk_insert_then_delete_child_then_delete_parent() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    let prelude = [
        "CREATE TABLE fk_idc_parent (id INT PRIMARY KEY, name TEXT)",
        "CREATE TABLE fk_idc_child (
            id INT PRIMARY KEY,
            parent_id INT,
            FOREIGN KEY (parent_id) REFERENCES fk_idc_parent(id)
        )",
        "INSERT INTO fk_idc_parent VALUES (1, 'Alice')",
        "INSERT INTO fk_idc_child VALUES (100, 1)",
        "DELETE FROM fk_idc_child WHERE id = 100",
    ];
    for stmt in prelude {
        db.execute(stmt).unwrap();
    }

    let deleted = db.execute("DELETE FROM fk_idc_parent WHERE id = 1");
    assert!(deleted.is_ok(), "Should be able to delete parent after all children removed, got: {:?}", deleted.err());

    let remaining = db.query("SELECT * FROM fk_idc_parent", &[]).unwrap().len();
    assert_eq!(remaining, 0, "Parent should be deleted");
}
17310
17311 fn setup_group_by_db() -> EmbeddedDatabase {
17319 let db = EmbeddedDatabase::new_in_memory().unwrap();
17320 db.execute("CREATE TABLE gb_sales (id INT, department TEXT, amount INT, rating FLOAT8)").unwrap();
17321 db.execute("INSERT INTO gb_sales VALUES (1, 'Engineering', 100, 4.5)").unwrap();
17322 db.execute("INSERT INTO gb_sales VALUES (2, 'Engineering', 200, 3.8)").unwrap();
17323 db.execute("INSERT INTO gb_sales VALUES (3, 'Engineering', 150, 4.2)").unwrap();
17324 db.execute("INSERT INTO gb_sales VALUES (4, 'Sales', 80, 3.0)").unwrap();
17325 db.execute("INSERT INTO gb_sales VALUES (5, 'Sales', 120, 4.1)").unwrap();
17326 db.execute("INSERT INTO gb_sales VALUES (6, 'Marketing', 90, 3.5)").unwrap();
17327 db.execute("INSERT INTO gb_sales VALUES (7, 'HR', 60, 2.8)").unwrap();
17328 db
17329 }
17330
/// `HAVING COUNT(*) > 1` keeps only Engineering (3) and Sales (2).
/// Counts are expected as `Value::Int8`.
#[test]
fn test_group_by_having_count() {
    let db = setup_group_by_db();
    let sql = "SELECT department, COUNT(*) AS cnt FROM gb_sales GROUP BY department HAVING COUNT(*) > 1 ORDER BY department";
    let rows = db.query(sql, &[]).unwrap();
    assert_eq!(rows.len(), 2, "Expected 2 departments with count > 1, got {}", rows.len());

    let expected = [("Engineering", 3), ("Sales", 2)];
    for (row, (department, count)) in rows.iter().zip(expected) {
        assert_eq!(row.get(0).unwrap(), &Value::String(department.to_string()));
        assert_eq!(row.get(1).unwrap(), &Value::Int8(count));
    }
}
17346
/// `HAVING SUM(amount) > 100` keeps Engineering (450) and Sales (200).
/// Sums are expected as `Value::Int8`.
#[test]
fn test_group_by_having_sum() {
    let db = setup_group_by_db();
    let sql = "SELECT department, SUM(amount) AS total FROM gb_sales GROUP BY department HAVING SUM(amount) > 100 ORDER BY department";
    let rows = db.query(sql, &[]).unwrap();
    assert_eq!(rows.len(), 2, "Expected 2 departments with sum > 100, got {}", rows.len());

    let expected = [("Engineering", 450), ("Sales", 200)];
    for (row, (department, total)) in rows.iter().zip(expected) {
        assert_eq!(row.get(0).unwrap(), &Value::String(department.to_string()));
        assert_eq!(row.get(1).unwrap(), &Value::Int8(total));
    }
}
17363
/// `HAVING AVG(rating) > 3.5` keeps Engineering (~4.167) and Sales (3.55).
/// Marketing's average of exactly 3.5 is excluded. AVG is expected as `Value::Float8`.
#[test]
fn test_group_by_having_avg() {
    let db = setup_group_by_db();
    let sql = "SELECT department, AVG(rating) FROM gb_sales GROUP BY department HAVING AVG(rating) > 3.5 ORDER BY department";
    let rows = db.query(sql, &[]).unwrap();
    assert_eq!(rows.len(), 2, "Expected 2 departments with avg > 3.5, got {}", rows.len());

    for (row, department) in rows.iter().zip(["Engineering", "Sales"]) {
        assert_eq!(row.get(0).unwrap(), &Value::String(department.to_string()));
    }

    let Value::Float8(avg) = rows[0].get(1).unwrap() else {
        panic!("AVG should return Float8, got {:?}", rows[0].get(1));
    };
    // (4.5 + 3.8 + 4.2) / 3 = 4.1666…
    assert!((avg - 4.1666).abs() < 0.01, "Engineering avg should be ~4.167, got {}", avg);
}
17384
/// HAVING with `COUNT(*) > 1 AND SUM(amount) > 150` keeps Engineering and Sales.
#[test]
fn test_group_by_having_multiple_conditions() {
    let db = setup_group_by_db();
    let sql = "SELECT department, COUNT(*), SUM(amount) FROM gb_sales GROUP BY department HAVING COUNT(*) > 1 AND SUM(amount) > 150 ORDER BY department";
    let rows = db.query(sql, &[]).unwrap();
    assert_eq!(rows.len(), 2, "Expected 2 departments matching both conditions, got {}", rows.len());

    for (row, department) in rows.iter().zip(["Engineering", "Sales"]) {
        assert_eq!(row.get(0).unwrap(), &Value::String(department.to_string()));
    }
}
17401
/// A HAVING predicate that no group satisfies produces an empty result.
#[test]
fn test_group_by_having_no_match() {
    let db = setup_group_by_db();
    let sql = "SELECT department, COUNT(*) FROM gb_sales GROUP BY department HAVING COUNT(*) > 100";
    let matched = db.query(sql, &[]).unwrap().len();
    assert_eq!(matched, 0, "No group should have count > 100");
}
17412
/// A HAVING predicate that every group satisfies returns all four departments.
#[test]
fn test_group_by_having_all_match() {
    let db = setup_group_by_db();
    let sql = "SELECT department, COUNT(*) FROM gb_sales GROUP BY department HAVING COUNT(*) >= 1 ORDER BY department";
    let rows = db.query(sql, &[]).unwrap();
    let group_total = rows.len();
    assert_eq!(group_total, 4, "All 4 departments should match count >= 1, got {}", group_total);
}
17423
/// GROUP BY on two columns produces one row per (region, category) pair with
/// the summed amount, ordered by both keys.
#[test]
fn test_group_by_multiple_columns() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE gb_multi (region TEXT, category TEXT, amount INT)").unwrap();
    for insert in [
        "INSERT INTO gb_multi VALUES ('East', 'A', 10)",
        "INSERT INTO gb_multi VALUES ('East', 'A', 20)",
        "INSERT INTO gb_multi VALUES ('East', 'B', 30)",
        "INSERT INTO gb_multi VALUES ('West', 'A', 40)",
        "INSERT INTO gb_multi VALUES ('West', 'B', 50)",
    ] {
        db.execute(insert).unwrap();
    }

    let rows = db
        .query(
            "SELECT region, category, SUM(amount) FROM gb_multi GROUP BY region, category ORDER BY region, category",
            &[],
        )
        .unwrap();
    assert_eq!(rows.len(), 4, "Expected 4 groups from 2-column GROUP BY, got {}", rows.len());

    let expected = [("East", "A", 30), ("East", "B", 30), ("West", "A", 40), ("West", "B", 50)];
    for (row, (region, category, total)) in rows.iter().zip(expected) {
        assert_eq!(row.get(0).unwrap(), &Value::String(region.to_string()));
        assert_eq!(row.get(1).unwrap(), &Value::String(category.to_string()));
        assert_eq!(row.get(2).unwrap(), &Value::Int8(total));
    }
}
17454
/// ORDER BY on an aggregate alias sorts groups by their summed amount, ascending.
#[test]
fn test_group_by_with_order_by() {
    let db = setup_group_by_db();
    let sql = "SELECT department, SUM(amount) AS total FROM gb_sales GROUP BY department ORDER BY total ASC";
    let rows = db.query(sql, &[]).unwrap();
    assert_eq!(rows.len(), 4, "Expected 4 departments, got {}", rows.len());

    let ascending = [("HR", 60), ("Marketing", 90), ("Sales", 200), ("Engineering", 450)];
    for (row, (department, total)) in rows.iter().zip(ascending) {
        assert_eq!(row.get(0).unwrap(), &Value::String(department.to_string()));
        assert_eq!(row.get(1).unwrap(), &Value::Int8(total));
    }
}
17474
/// NULL grouping keys form a single group of their own. Group 'A' sums to 30
/// and the NULL group sums to 70.
#[test]
fn test_group_by_null_values() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE gb_nulls (category TEXT, val INT)").unwrap();
    db.execute("INSERT INTO gb_nulls VALUES ('A', 10)").unwrap();
    db.execute("INSERT INTO gb_nulls VALUES ('A', 20)").unwrap();
    db.execute("INSERT INTO gb_nulls VALUES (NULL, 30)").unwrap();
    db.execute("INSERT INTO gb_nulls VALUES (NULL, 40)").unwrap();

    let rows = db.query(
        "SELECT category, SUM(val) FROM gb_nulls GROUP BY category ORDER BY category",
        &[],
    ).unwrap();
    assert_eq!(rows.len(), 2, "Expected 2 groups (A and NULL), got {}", rows.len());

    // Groups are located by key rather than by position, so the test does not
    // depend on where ORDER BY places NULL relative to 'A'.
    let a_group = rows.iter().find(|r| r.get(0).unwrap() == &Value::String("A".to_string()));
    let null_group = rows.iter().find(|r| r.get(0).unwrap() == &Value::Null);

    let Some(a_group) = a_group else {
        panic!("Should have an 'A' group");
    };
    assert_eq!(a_group.get(1).unwrap(), &Value::Int8(30));

    let Some(null_group) = null_group else {
        panic!("NULL values should form their own group");
    };
    assert_eq!(null_group.get(1).unwrap(), &Value::Int8(70));
}
17502
/// `COUNT(DISTINCT val)` per group: 'X' has 3 distinct values and 'Y' has 1.
///
/// If the engine rejects the query instead, the test passes as long as the
/// error text contains one of "DISTINCT", "distinct", "not" or "syntax"
/// (case-sensitive).
/// NOTE(review): matching on "not" makes the error branch accept almost any
/// message; consider tightening once COUNT(DISTINCT) support is settled.
#[test]
fn test_group_by_count_distinct() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    db.execute("CREATE TABLE gb_cd (grp TEXT, val INT)").unwrap();
    for insert in [
        "INSERT INTO gb_cd VALUES ('X', 1)",
        "INSERT INTO gb_cd VALUES ('X', 2)",
        "INSERT INTO gb_cd VALUES ('X', 2)",
        "INSERT INTO gb_cd VALUES ('X', 3)",
        "INSERT INTO gb_cd VALUES ('Y', 10)",
        "INSERT INTO gb_cd VALUES ('Y', 10)",
    ] {
        db.execute(insert).unwrap();
    }

    let outcome = db.query("SELECT grp, COUNT(DISTINCT val) FROM gb_cd GROUP BY grp ORDER BY grp", &[]);
    if let Err(e) = outcome {
        let err_msg = e.to_string();
        let recognised = ["DISTINCT", "distinct", "not", "syntax"]
            .iter()
            .any(|&needle| err_msg.contains(needle));
        assert!(recognised, "COUNT(DISTINCT) unsupported or syntax error: {}", err_msg);
    } else if let Ok(rows) = outcome {
        assert_eq!(rows.len(), 2, "Expected 2 groups, got {}", rows.len());
        for (row, (grp, distinct)) in rows.iter().zip([("X", 3), ("Y", 1)]) {
            assert_eq!(row.get(0).unwrap(), &Value::String(grp.to_string()));
            assert_eq!(row.get(1).unwrap(), &Value::Int8(distinct));
        }
    }
}
17539
/// `CAST(42 AS TEXT)` yields the string "42".
#[test]
fn test_cast_int_to_text() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    let result = db.query("SELECT CAST(42 AS TEXT)", &[]).unwrap();
    assert_eq!(result.len(), 1);
    let cell = result[0].get(0).unwrap();
    assert_eq!(cell, &Value::String("42".to_string()));
}
17553
/// `CAST('42' AS INT)` parses the text into `Value::Int4(42)`.
#[test]
fn test_cast_text_to_int() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    let result = db.query("SELECT CAST('42' AS INT)", &[]).unwrap();
    assert_eq!(result.len(), 1);
    let cell = result[0].get(0).unwrap();
    assert_eq!(cell, &Value::Int4(42));
}
17561
/// `CAST(42 AS FLOAT8)` yields `Value::Float8(42.0)`.
#[test]
fn test_cast_int_to_float() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    let result = db.query("SELECT CAST(42 AS FLOAT8)", &[]).unwrap();
    assert_eq!(result.len(), 1);
    let cell = result[0].get(0).unwrap();
    assert_eq!(cell, &Value::Float8(42.0));
}
17569
/// `CAST(3.7 AS INT)` may either truncate (3) or round (4).
/// Both results are accepted; anything else fails.
#[test]
fn test_cast_float_to_int() {
    let db = EmbeddedDatabase::new_in_memory().unwrap();
    let result = db.query("SELECT CAST(3.7 AS INT)", &[]).unwrap();
    assert_eq!(result.len(), 1);

    let val = result[0].get(0).unwrap();
    let acceptable = [Value::Int4(3), Value::Int4(4)];
    assert!(
        acceptable.iter().any(|candidate| val == candidate),
        "CAST(3.7 AS INT) should truncate to 3 (or possibly round to 4), got {:?}", val
    );
}
17583
17584 #[test]
17585 fn test_cast_text_to_boolean() {
17586 let db = EmbeddedDatabase::new_in_memory().unwrap();
17587 let rows = db.query("SELECT CAST('true' AS BOOLEAN)", &[]).unwrap();
17588 assert_eq!(rows.len(), 1);
17589 assert_eq!(rows[0].get(0).unwrap(), &Value::Boolean(true));
17590
17591 let rows2 = db.query("SELECT CAST('false' AS BOOLEAN)", &[]).unwrap();
17592 assert_eq!(rows2.len(), 1);
17593 assert_eq!(rows2[0].get(0).unwrap(), &Value::Boolean(false));
17594 }
17595
17596 #[test]
17597 fn test_cast_boolean_to_text() {
17598 let db = EmbeddedDatabase::new_in_memory().unwrap();
17599 let rows = db.query("SELECT CAST(TRUE AS TEXT)", &[]).unwrap();
17600 assert_eq!(rows.len(), 1);
17601 assert_eq!(rows[0].get(0).unwrap(), &Value::String("true".to_string()));
17602
17603 let rows2 = db.query("SELECT CAST(FALSE AS TEXT)", &[]).unwrap();
17604 assert_eq!(rows2.len(), 1);
17605 assert_eq!(rows2[0].get(0).unwrap(), &Value::String("false".to_string()));
17606 }
17607
17608 #[test]
17609 fn test_cast_null_cast() {
17610 let db = EmbeddedDatabase::new_in_memory().unwrap();
17612 let rows = db.query("SELECT CAST(NULL AS INT)", &[]).unwrap();
17613 assert_eq!(rows.len(), 1);
17614 assert_eq!(rows[0].get(0).unwrap(), &Value::Null);
17615 }
17616
17617 #[test]
17618 fn test_cast_invalid_text_to_int() {
17619 let db = EmbeddedDatabase::new_in_memory().unwrap();
17621 let result = db.query("SELECT CAST('abc' AS INT)", &[]);
17622 assert!(result.is_err(), "CAST('abc' AS INT) should fail, but got: {:?}", result);
17623 let err_msg = result.unwrap_err().to_string();
17624 assert!(
17625 err_msg.contains("Cannot cast") || err_msg.contains("cast") || err_msg.contains("invalid"),
17626 "Error should mention cast failure, got: {}", err_msg
17627 );
17628 }
17629
17630 #[test]
17631 fn test_cast_int_to_bigint() {
17632 let db = EmbeddedDatabase::new_in_memory().unwrap();
17633 let rows = db.query("SELECT CAST(42 AS BIGINT)", &[]).unwrap();
17634 assert_eq!(rows.len(), 1);
17635 assert_eq!(rows[0].get(0).unwrap(), &Value::Int8(42));
17636 }
17637
17638 #[test]
17639 fn test_cast_in_where() {
17640 let db = EmbeddedDatabase::new_in_memory().unwrap();
17642 db.execute("CREATE TABLE cast_where (id INT, code INT)").unwrap();
17643 db.execute("INSERT INTO cast_where VALUES (1, 42)").unwrap();
17644 db.execute("INSERT INTO cast_where VALUES (2, 99)").unwrap();
17645 db.execute("INSERT INTO cast_where VALUES (3, 42)").unwrap();
17646
17647 let rows = db.query(
17648 "SELECT id FROM cast_where WHERE CAST(code AS TEXT) = '42' ORDER BY id",
17649 &[],
17650 ).unwrap();
17651 assert_eq!(rows.len(), 2, "Expected 2 rows with code=42, got {}", rows.len());
17652 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
17653 assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(3));
17654 }
17655
17656 #[test]
17670 fn test_alter_add_column_basic() {
17671 let db = EmbeddedDatabase::new_in_memory().unwrap();
17673 db.execute("CREATE TABLE alt_add_basic (id INT, name TEXT)").unwrap();
17674 db.execute("INSERT INTO alt_add_basic VALUES (1, 'Alice')").unwrap();
17675
17676 db.execute("ALTER TABLE alt_add_basic ADD COLUMN age INT").unwrap();
17677
17678 let rows = db.query("SELECT id, name, age FROM alt_add_basic", &[]).unwrap();
17680 assert_eq!(rows.len(), 1, "Should still have 1 row");
17681 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
17682 assert_eq!(rows[0].get(1).unwrap(), &Value::String("Alice".to_string()));
17683 assert_eq!(rows[0].get(2).unwrap(), &Value::Null,
17684 "New column should be NULL for existing rows");
17685 }
17686
17687 #[test]
17688 fn test_alter_add_column_with_default() {
17689 let db = EmbeddedDatabase::new_in_memory().unwrap();
17691 db.execute("CREATE TABLE alt_add_def (id INT, name TEXT)").unwrap();
17692 db.execute("INSERT INTO alt_add_def VALUES (1, 'Alice')").unwrap();
17693 db.execute("INSERT INTO alt_add_def VALUES (2, 'Bob')").unwrap();
17694
17695 db.execute("ALTER TABLE alt_add_def ADD COLUMN status TEXT DEFAULT 'active'").unwrap();
17696
17697 let rows = db.query("SELECT id, name, status FROM alt_add_def ORDER BY id", &[]).unwrap();
17699 assert_eq!(rows.len(), 2, "Should still have 2 rows");
17700
17701 let status_0 = rows[0].get(2).unwrap();
17704 let status_1 = rows[1].get(2).unwrap();
17705
17706 assert_eq!(status_0, status_1,
17708 "Both existing rows should get same value for new column with DEFAULT");
17709
17710 assert!(
17712 *status_0 == Value::Null || *status_0 == Value::String("active".to_string()),
17713 "New column should be NULL or 'active', got: {:?}", status_0
17714 );
17715 }
17716
17717 #[test]
17718 fn test_alter_add_column_nullable() {
17719 let db = EmbeddedDatabase::new_in_memory().unwrap();
17721 db.execute("CREATE TABLE alt_add_null (id INT)").unwrap();
17722 db.execute("INSERT INTO alt_add_null VALUES (1)").unwrap();
17723 db.execute("INSERT INTO alt_add_null VALUES (2)").unwrap();
17724 db.execute("INSERT INTO alt_add_null VALUES (3)").unwrap();
17725
17726 db.execute("ALTER TABLE alt_add_null ADD COLUMN note TEXT").unwrap();
17727
17728 let rows = db.query("SELECT id, note FROM alt_add_null ORDER BY id", &[]).unwrap();
17729 assert_eq!(rows.len(), 3, "Should still have 3 rows");
17730
17731 for (i, row) in rows.iter().enumerate() {
17732 assert_eq!(row.get(1).unwrap(), &Value::Null,
17733 "Row {} new column should be NULL", i);
17734 }
17735 }
17736
17737 #[test]
17738 fn test_alter_add_column_text_type() {
17739 let db = EmbeddedDatabase::new_in_memory().unwrap();
17741 db.execute("CREATE TABLE alt_add_text (id INT)").unwrap();
17742 db.execute("INSERT INTO alt_add_text VALUES (1)").unwrap();
17743
17744 db.execute("ALTER TABLE alt_add_text ADD COLUMN description TEXT").unwrap();
17745
17746 db.execute("UPDATE alt_add_text SET description = 'hello world' WHERE id = 1").unwrap();
17748
17749 let rows = db.query("SELECT id, description FROM alt_add_text", &[]).unwrap();
17750 assert_eq!(rows.len(), 1);
17751 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
17752 assert_eq!(rows[0].get(1).unwrap(), &Value::String("hello world".to_string()));
17753 }
17754
17755 #[test]
17756 fn test_alter_add_column_then_insert() {
17757 let db = EmbeddedDatabase::new_in_memory().unwrap();
17759 db.execute("CREATE TABLE alt_add_ins (id INT, name TEXT)").unwrap();
17760 db.execute("INSERT INTO alt_add_ins VALUES (1, 'Alice')").unwrap();
17761
17762 db.execute("ALTER TABLE alt_add_ins ADD COLUMN score INT").unwrap();
17763
17764 db.execute("INSERT INTO alt_add_ins VALUES (2, 'Bob', 95)").unwrap();
17766
17767 let rows = db.query("SELECT id, name, score FROM alt_add_ins ORDER BY id", &[]).unwrap();
17768 assert_eq!(rows.len(), 2, "Should have 2 rows total");
17769
17770 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
17772 assert_eq!(rows[0].get(2).unwrap(), &Value::Null);
17773
17774 assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(2));
17776 assert_eq!(rows[1].get(2).unwrap(), &Value::Int4(95));
17777 }
17778
17779 #[test]
17780 fn test_alter_add_column_duplicate() {
17781 let db = EmbeddedDatabase::new_in_memory().unwrap();
17783 db.execute("CREATE TABLE alt_add_dup (id INT, name TEXT)").unwrap();
17784
17785 let result = db.execute("ALTER TABLE alt_add_dup ADD COLUMN name TEXT");
17786 assert!(result.is_err(),
17787 "Adding a duplicate column should fail");
17788 let err_msg = result.unwrap_err().to_string();
17789 assert!(err_msg.contains("already exists"),
17790 "Error should mention 'already exists', got: {}", err_msg);
17791 }
17792
17793 #[test]
17794 fn test_alter_add_column_if_not_exists() {
17795 let db = EmbeddedDatabase::new_in_memory().unwrap();
17797 db.execute("CREATE TABLE alt_add_ine (id INT, name TEXT)").unwrap();
17798
17799 let result = db.execute("ALTER TABLE alt_add_ine ADD COLUMN IF NOT EXISTS name TEXT");
17801 assert!(result.is_ok(),
17802 "ADD COLUMN IF NOT EXISTS for existing column should succeed silently, got: {:?}",
17803 result.err());
17804 }
17805
17806 #[test]
17809 fn test_alter_drop_column_basic() {
17810 let db = EmbeddedDatabase::new_in_memory().unwrap();
17812 db.execute("CREATE TABLE alt_drop_basic (id INT, name TEXT, age INT)").unwrap();
17813 db.execute("INSERT INTO alt_drop_basic VALUES (1, 'Alice', 30)").unwrap();
17814
17815 db.execute("ALTER TABLE alt_drop_basic DROP COLUMN age").unwrap();
17816
17817 let result = db.query("SELECT age FROM alt_drop_basic", &[]);
17819 assert!(result.is_err(),
17820 "Selecting a dropped column should fail");
17821
17822 let rows = db.query("SELECT id, name FROM alt_drop_basic", &[]).unwrap();
17824 assert_eq!(rows.len(), 1);
17825 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
17826 assert_eq!(rows[0].get(1).unwrap(), &Value::String("Alice".to_string()));
17827 }
17828
17829 #[test]
17830 fn test_alter_drop_column_with_data() {
17831 let db = EmbeddedDatabase::new_in_memory().unwrap();
17833 db.execute("CREATE TABLE alt_drop_data (id INT, name TEXT, score INT, grade TEXT)").unwrap();
17834 db.execute("INSERT INTO alt_drop_data VALUES (1, 'Alice', 90, 'A')").unwrap();
17835 db.execute("INSERT INTO alt_drop_data VALUES (2, 'Bob', 80, 'B')").unwrap();
17836 db.execute("INSERT INTO alt_drop_data VALUES (3, 'Carol', 70, 'C')").unwrap();
17837
17838 db.execute("ALTER TABLE alt_drop_data DROP COLUMN score").unwrap();
17839
17840 let rows = db.query("SELECT id, name, grade FROM alt_drop_data ORDER BY id", &[]).unwrap();
17842 assert_eq!(rows.len(), 3, "All rows should still exist");
17843 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
17844 assert_eq!(rows[0].get(1).unwrap(), &Value::String("Alice".to_string()));
17845 assert_eq!(rows[0].get(2).unwrap(), &Value::String("A".to_string()));
17846 assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(2));
17847 assert_eq!(rows[1].get(1).unwrap(), &Value::String("Bob".to_string()));
17848 assert_eq!(rows[1].get(2).unwrap(), &Value::String("B".to_string()));
17849 assert_eq!(rows[2].get(0).unwrap(), &Value::Int4(3));
17850 assert_eq!(rows[2].get(1).unwrap(), &Value::String("Carol".to_string()));
17851 assert_eq!(rows[2].get(2).unwrap(), &Value::String("C".to_string()));
17852 }
17853
17854 #[test]
17855 fn test_alter_drop_column_nonexistent() {
17856 let db = EmbeddedDatabase::new_in_memory().unwrap();
17858 db.execute("CREATE TABLE alt_drop_ne (id INT, name TEXT)").unwrap();
17859
17860 let result = db.execute("ALTER TABLE alt_drop_ne DROP COLUMN nonexistent");
17861 assert!(result.is_err(),
17862 "Dropping a nonexistent column should fail");
17863 let err_msg = result.unwrap_err().to_string();
17864 assert!(err_msg.contains("does not exist"),
17865 "Error should mention 'does not exist', got: {}", err_msg);
17866 }
17867
17868 #[test]
17869 fn test_alter_drop_column_if_exists() {
17870 let db = EmbeddedDatabase::new_in_memory().unwrap();
17872 db.execute("CREATE TABLE alt_drop_ie (id INT, name TEXT)").unwrap();
17873
17874 let result = db.execute("ALTER TABLE alt_drop_ie DROP COLUMN IF EXISTS nonexistent");
17875 assert!(result.is_ok(),
17876 "DROP COLUMN IF EXISTS for nonexistent column should succeed silently, got: {:?}",
17877 result.err());
17878 }
17879
17880 #[test]
17881 fn test_alter_drop_column_last_column() {
17882 let db = EmbeddedDatabase::new_in_memory().unwrap();
17885 db.execute("CREATE TABLE alt_drop_last (only_col INT)").unwrap();
17886 db.execute("INSERT INTO alt_drop_last VALUES (42)").unwrap();
17887
17888 let result = db.execute("ALTER TABLE alt_drop_last DROP COLUMN only_col");
17889 if result.is_ok() {
17891 let query_result = db.query("SELECT * FROM alt_drop_last", &[]);
17893 match query_result {
17895 Ok(rows) => {
17896 assert!(rows.is_empty() || rows[0].values.is_empty(),
17898 "After dropping last column, rows should be empty or have no values");
17899 }
17900 Err(_) => {
17901 }
17903 }
17904 }
17905 }
17907
17908 #[test]
17909 fn test_alter_drop_primary_key_column_without_cascade() {
17910 let db = EmbeddedDatabase::new_in_memory().unwrap();
17912 db.execute("CREATE TABLE alt_drop_pk (id INT PRIMARY KEY, name TEXT)").unwrap();
17913 db.execute("INSERT INTO alt_drop_pk VALUES (1, 'Alice')").unwrap();
17914
17915 let result = db.execute("ALTER TABLE alt_drop_pk DROP COLUMN id");
17916 assert!(result.is_err(),
17917 "Dropping a primary key column without CASCADE should fail");
17918 let err_msg = result.unwrap_err().to_string();
17919 assert!(err_msg.contains("CASCADE") || err_msg.contains("primary key"),
17920 "Error should mention CASCADE or primary key, got: {}", err_msg);
17921 }
17922
17923 #[test]
17926 fn test_alter_rename_column_basic() {
17927 let db = EmbeddedDatabase::new_in_memory().unwrap();
17929 db.execute("CREATE TABLE alt_ren_col (id INT, old_name TEXT)").unwrap();
17930 db.execute("INSERT INTO alt_ren_col VALUES (1, 'Alice')").unwrap();
17931
17932 db.execute("ALTER TABLE alt_ren_col RENAME COLUMN old_name TO new_name").unwrap();
17933
17934 let rows = db.query("SELECT id, new_name FROM alt_ren_col", &[]).unwrap();
17936 assert_eq!(rows.len(), 1);
17937 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
17938 assert_eq!(rows[0].get(1).unwrap(), &Value::String("Alice".to_string()));
17939 }
17940
17941 #[test]
17942 fn test_alter_rename_column_preserves_data() {
17943 let db = EmbeddedDatabase::new_in_memory().unwrap();
17945 db.execute("CREATE TABLE alt_ren_data (id INT, val TEXT)").unwrap();
17946 db.execute("INSERT INTO alt_ren_data VALUES (1, 'one')").unwrap();
17947 db.execute("INSERT INTO alt_ren_data VALUES (2, 'two')").unwrap();
17948 db.execute("INSERT INTO alt_ren_data VALUES (3, 'three')").unwrap();
17949
17950 db.execute("ALTER TABLE alt_ren_data RENAME COLUMN val TO value").unwrap();
17951
17952 let rows = db.query("SELECT id, value FROM alt_ren_data ORDER BY id", &[]).unwrap();
17953 assert_eq!(rows.len(), 3, "All rows should still exist after rename");
17954 assert_eq!(rows[0].get(1).unwrap(), &Value::String("one".to_string()));
17955 assert_eq!(rows[1].get(1).unwrap(), &Value::String("two".to_string()));
17956 assert_eq!(rows[2].get(1).unwrap(), &Value::String("three".to_string()));
17957
17958 let result = db.query("SELECT val FROM alt_ren_data", &[]);
17960 assert!(result.is_err(),
17961 "Old column name should no longer be valid after rename");
17962 }
17963
17964 #[test]
17965 fn test_alter_rename_column_nonexistent() {
17966 let db = EmbeddedDatabase::new_in_memory().unwrap();
17968 db.execute("CREATE TABLE alt_ren_ne (id INT, name TEXT)").unwrap();
17969
17970 let result = db.execute("ALTER TABLE alt_ren_ne RENAME COLUMN ghost TO phantom");
17971 assert!(result.is_err(),
17972 "Renaming a nonexistent column should fail");
17973 let err_msg = result.unwrap_err().to_string();
17974 assert!(err_msg.contains("does not exist"),
17975 "Error should mention 'does not exist', got: {}", err_msg);
17976 }
17977
17978 #[test]
17979 fn test_alter_rename_column_to_existing_name() {
17980 let db = EmbeddedDatabase::new_in_memory().unwrap();
17982 db.execute("CREATE TABLE alt_ren_dup (id INT, name TEXT)").unwrap();
17983
17984 let result = db.execute("ALTER TABLE alt_ren_dup RENAME COLUMN id TO name");
17985 assert!(result.is_err(),
17986 "Renaming to an already-existing column name should fail");
17987 let err_msg = result.unwrap_err().to_string();
17988 assert!(err_msg.contains("already exists"),
17989 "Error should mention 'already exists', got: {}", err_msg);
17990 }
17991
17992 #[test]
17995 fn test_alter_rename_table_basic() {
17996 let db = EmbeddedDatabase::new_in_memory().unwrap();
17998 db.execute("CREATE TABLE alt_old_tbl (id INT, name TEXT)").unwrap();
17999 db.execute("INSERT INTO alt_old_tbl VALUES (1, 'Alice')").unwrap();
18000
18001 db.execute("ALTER TABLE alt_old_tbl RENAME TO alt_new_tbl").unwrap();
18002
18003 let rows = db.query("SELECT id, name FROM alt_new_tbl", &[]).unwrap();
18004 assert_eq!(rows.len(), 1, "Data should be accessible via new table name");
18005 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
18006 assert_eq!(rows[0].get(1).unwrap(), &Value::String("Alice".to_string()));
18007 }
18008
18009 #[test]
18010 fn test_alter_rename_table_old_name_fails() {
18011 let db = EmbeddedDatabase::new_in_memory().unwrap();
18013 db.execute("CREATE TABLE alt_orig (id INT)").unwrap();
18014 db.execute("INSERT INTO alt_orig VALUES (1)").unwrap();
18015
18016 db.execute("ALTER TABLE alt_orig RENAME TO alt_renamed").unwrap();
18017
18018 let result = db.query("SELECT * FROM alt_orig", &[]);
18019 assert!(result.is_err(),
18020 "Querying the old table name after rename should fail");
18021 }
18022
18023 #[test]
18024 fn test_alter_rename_table_to_existing() {
18025 let db = EmbeddedDatabase::new_in_memory().unwrap();
18027 db.execute("CREATE TABLE alt_src (id INT)").unwrap();
18028 db.execute("CREATE TABLE alt_dst (id INT)").unwrap();
18029
18030 let result = db.execute("ALTER TABLE alt_src RENAME TO alt_dst");
18031 assert!(result.is_err(),
18032 "Renaming to an existing table name should fail");
18033 let err_msg = result.unwrap_err().to_string();
18034 assert!(err_msg.contains("already exists"),
18035 "Error should mention 'already exists', got: {}", err_msg);
18036 }
18037
18038 #[test]
18041 fn test_alter_add_then_drop_column() {
18042 let db = EmbeddedDatabase::new_in_memory().unwrap();
18044 db.execute("CREATE TABLE alt_add_drop (id INT, name TEXT)").unwrap();
18045 db.execute("INSERT INTO alt_add_drop VALUES (1, 'Alice')").unwrap();
18046
18047 db.execute("ALTER TABLE alt_add_drop ADD COLUMN temp INT").unwrap();
18049 let rows = db.query("SELECT id, name, temp FROM alt_add_drop", &[]).unwrap();
18050 assert_eq!(rows.len(), 1);
18051 assert_eq!(rows[0].values.len(), 3, "Should have 3 columns after ADD");
18052
18053 db.execute("ALTER TABLE alt_add_drop DROP COLUMN temp").unwrap();
18055 let rows = db.query("SELECT id, name FROM alt_add_drop", &[]).unwrap();
18056 assert_eq!(rows.len(), 1);
18057 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
18058 assert_eq!(rows[0].get(1).unwrap(), &Value::String("Alice".to_string()));
18059
18060 let result = db.query("SELECT temp FROM alt_add_drop", &[]);
18062 assert!(result.is_err(), "Dropped column should not be queryable");
18063 }
18064
18065 #[test]
18066 fn test_alter_multiple_sequential_operations() {
18067 let db = EmbeddedDatabase::new_in_memory().unwrap();
18069 db.execute("CREATE TABLE alt_seq (id INT, a TEXT)").unwrap();
18070 db.execute("INSERT INTO alt_seq VALUES (1, 'original')").unwrap();
18071
18072 db.execute("ALTER TABLE alt_seq ADD COLUMN b INT").unwrap();
18074 db.execute("ALTER TABLE alt_seq RENAME COLUMN a TO alpha").unwrap();
18076 db.execute("ALTER TABLE alt_seq ADD COLUMN c TEXT").unwrap();
18078
18079 db.execute("INSERT INTO alt_seq VALUES (2, 'new', 42, 'hello')").unwrap();
18081
18082 let rows = db.query("SELECT id, alpha, b, c FROM alt_seq ORDER BY id", &[]).unwrap();
18083 assert_eq!(rows.len(), 2, "Should have 2 rows");
18084
18085 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
18087 assert_eq!(rows[0].get(1).unwrap(), &Value::String("original".to_string()));
18088 assert_eq!(rows[0].get(2).unwrap(), &Value::Null);
18089 assert_eq!(rows[0].get(3).unwrap(), &Value::Null);
18090
18091 assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(2));
18093 assert_eq!(rows[1].get(1).unwrap(), &Value::String("new".to_string()));
18094 assert_eq!(rows[1].get(2).unwrap(), &Value::Int4(42));
18095 assert_eq!(rows[1].get(3).unwrap(), &Value::String("hello".to_string()));
18096 }
18097
18098 #[test]
18099 fn test_alter_table_nonexistent_table() {
18100 let db = EmbeddedDatabase::new_in_memory().unwrap();
18102
18103 let result = db.execute("ALTER TABLE no_such_table ADD COLUMN x INT");
18104 assert!(result.is_err(),
18105 "ALTER TABLE on nonexistent table should fail");
18106 }
18107
18108 fn setup_pagination_db() -> EmbeddedDatabase {
18119 let db = EmbeddedDatabase::new_in_memory().unwrap();
18120 db.execute(
18121 "CREATE TABLE pg_products (id INT, name TEXT, price INT, category TEXT)"
18122 ).unwrap();
18123 db.execute("INSERT INTO pg_products VALUES (1, 'Product_01', 50, 'Electronics')").unwrap();
18125 db.execute("INSERT INTO pg_products VALUES (2, 'Product_02', 30, 'Books')").unwrap();
18126 db.execute("INSERT INTO pg_products VALUES (3, 'Product_03', 75, 'Electronics')").unwrap();
18127 db.execute("INSERT INTO pg_products VALUES (4, 'Product_04', 20, 'Clothing')").unwrap();
18128 db.execute("INSERT INTO pg_products VALUES (5, 'Product_05', 90, 'Electronics')").unwrap();
18129 db.execute("INSERT INTO pg_products VALUES (6, 'Product_06', 15, 'Books')").unwrap();
18130 db.execute("INSERT INTO pg_products VALUES (7, 'Product_07', 60, 'Clothing')").unwrap();
18131 db.execute("INSERT INTO pg_products VALUES (8, 'Product_08', 45, 'Books')").unwrap();
18132 db.execute("INSERT INTO pg_products VALUES (9, 'Product_09', 80, 'Clothing')").unwrap();
18133 db.execute("INSERT INTO pg_products VALUES (10, 'Product_10', 35, 'Electronics')").unwrap();
18134 db
18135 }
18136
18137 #[test]
18140 fn test_limit_basic() {
18141 let db = setup_pagination_db();
18143 let rows = db.query(
18144 "SELECT id FROM pg_products ORDER BY id LIMIT 3",
18145 &[],
18146 ).unwrap();
18147 assert_eq!(rows.len(), 3, "LIMIT 3 should return 3 rows, got {}", rows.len());
18148 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
18149 assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(2));
18150 assert_eq!(rows[2].get(0).unwrap(), &Value::Int4(3));
18151 }
18152
18153 #[test]
18154 fn test_limit_zero() {
18155 let db = setup_pagination_db();
18157 let rows = db.query(
18158 "SELECT id FROM pg_products LIMIT 0",
18159 &[],
18160 ).unwrap();
18161 assert_eq!(rows.len(), 0, "LIMIT 0 should return 0 rows, got {}", rows.len());
18162 }
18163
18164 #[test]
18165 fn test_limit_exceeds_rows() {
18166 let db = setup_pagination_db();
18168 let rows = db.query(
18169 "SELECT id FROM pg_products ORDER BY id LIMIT 100",
18170 &[],
18171 ).unwrap();
18172 assert_eq!(rows.len(), 10, "LIMIT 100 on 10 rows should return 10, got {}", rows.len());
18173 }
18174
18175 #[test]
18176 fn test_limit_one() {
18177 let db = setup_pagination_db();
18179 let rows = db.query(
18180 "SELECT id FROM pg_products ORDER BY id LIMIT 1",
18181 &[],
18182 ).unwrap();
18183 assert_eq!(rows.len(), 1, "LIMIT 1 should return 1 row, got {}", rows.len());
18184 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
18185 }
18186
18187 #[test]
18188 fn test_limit_with_order_by() {
18189 let db = setup_pagination_db();
18192 let rows = db.query(
18193 "SELECT id, price FROM pg_products ORDER BY price DESC LIMIT 3",
18194 &[],
18195 ).unwrap();
18196 assert_eq!(rows.len(), 3, "Top 3 by price DESC should return 3 rows");
18197 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(5), "Most expensive is id=5 (price 90)");
18198 assert_eq!(rows[0].get(1).unwrap(), &Value::Int4(90));
18199 assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(9), "Second most expensive is id=9 (price 80)");
18200 assert_eq!(rows[1].get(1).unwrap(), &Value::Int4(80));
18201 assert_eq!(rows[2].get(0).unwrap(), &Value::Int4(3), "Third most expensive is id=3 (price 75)");
18202 assert_eq!(rows[2].get(1).unwrap(), &Value::Int4(75));
18203 }
18204
18205 #[test]
18208 fn test_offset_basic() {
18209 let db = setup_pagination_db();
18211 let rows = db.query(
18212 "SELECT id FROM pg_products ORDER BY id OFFSET 2",
18213 &[],
18214 ).unwrap();
18215 assert_eq!(rows.len(), 8, "OFFSET 2 on 10 rows should return 8, got {}", rows.len());
18216 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(3));
18218 assert_eq!(rows[7].get(0).unwrap(), &Value::Int4(10));
18219 }
18220
18221 #[test]
18222 fn test_offset_exceeds_rows() {
18223 let db = setup_pagination_db();
18227 let rows = db.query(
18228 "SELECT id FROM pg_products ORDER BY id LIMIT 100 OFFSET 20",
18229 &[],
18230 ).unwrap();
18231 assert_eq!(rows.len(), 0, "OFFSET beyond row count should return 0 rows, got {}", rows.len());
18232 }
18233
18234 #[test]
18235 fn test_offset_zero() {
18236 let db = setup_pagination_db();
18238 let rows = db.query(
18239 "SELECT id FROM pg_products ORDER BY id OFFSET 0",
18240 &[],
18241 ).unwrap();
18242 assert_eq!(rows.len(), 10, "OFFSET 0 should return all 10 rows, got {}", rows.len());
18243 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(1));
18244 }
18245
18246 #[test]
18247 fn test_limit_offset_combined() {
18248 let db = setup_pagination_db();
18250 let rows = db.query(
18251 "SELECT id FROM pg_products ORDER BY id LIMIT 3 OFFSET 2",
18252 &[],
18253 ).unwrap();
18254 assert_eq!(rows.len(), 3, "LIMIT 3 OFFSET 2 should return 3 rows, got {}", rows.len());
18255 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(3));
18256 assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(4));
18257 assert_eq!(rows[2].get(0).unwrap(), &Value::Int4(5));
18258 }
18259
18260 #[test]
18261 fn test_limit_offset_page_2() {
18262 let db = setup_pagination_db();
18264 let rows = db.query(
18265 "SELECT id FROM pg_products ORDER BY id LIMIT 3 OFFSET 3",
18266 &[],
18267 ).unwrap();
18268 assert_eq!(rows.len(), 3, "Page 2 (LIMIT 3 OFFSET 3) should return 3 rows, got {}", rows.len());
18269 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(4));
18270 assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(5));
18271 assert_eq!(rows[2].get(0).unwrap(), &Value::Int4(6));
18272 }
18273
18274 #[test]
18275 fn test_limit_offset_last_page() {
18276 let db = setup_pagination_db();
18279 let rows = db.query(
18280 "SELECT id FROM pg_products ORDER BY id LIMIT 3 OFFSET 9",
18281 &[],
18282 ).unwrap();
18283 assert_eq!(rows.len(), 1, "Last page should return 1 remaining row, got {}", rows.len());
18284 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(10));
18285 }
18286
18287 #[test]
18290 fn test_order_by_asc() {
18291 let db = setup_pagination_db();
18294 let rows = db.query(
18295 "SELECT id, price FROM pg_products ORDER BY price ASC",
18296 &[],
18297 ).unwrap();
18298 assert_eq!(rows.len(), 10);
18299 assert_eq!(rows[0].get(1).unwrap(), &Value::Int4(15), "Cheapest should be 15");
18300 assert_eq!(rows[9].get(1).unwrap(), &Value::Int4(90), "Most expensive should be 90");
18301 for i in 0..9 {
18303 let cur = match rows[i].get(1).unwrap() { Value::Int4(v) => *v, _ => panic!("expected Int4") };
18304 let nxt = match rows[i + 1].get(1).unwrap() { Value::Int4(v) => *v, _ => panic!("expected Int4") };
18305 assert!(cur <= nxt, "Row {} price {} should be <= row {} price {}", i, cur, i + 1, nxt);
18306 }
18307 }
18308
18309 #[test]
18310 fn test_order_by_desc() {
18311 let db = setup_pagination_db();
18313 let rows = db.query(
18314 "SELECT id, price FROM pg_products ORDER BY price DESC",
18315 &[],
18316 ).unwrap();
18317 assert_eq!(rows.len(), 10);
18318 assert_eq!(rows[0].get(1).unwrap(), &Value::Int4(90), "Most expensive should be first");
18319 assert_eq!(rows[9].get(1).unwrap(), &Value::Int4(15), "Cheapest should be last");
18320 for i in 0..9 {
18322 let cur = match rows[i].get(1).unwrap() { Value::Int4(v) => *v, _ => panic!("expected Int4") };
18323 let nxt = match rows[i + 1].get(1).unwrap() { Value::Int4(v) => *v, _ => panic!("expected Int4") };
18324 assert!(cur >= nxt, "Row {} price {} should be >= row {} price {}", i, cur, i + 1, nxt);
18325 }
18326 }
18327
18328 #[test]
18329 fn test_order_by_multiple_columns() {
18330 let db = setup_pagination_db();
18333 let rows = db.query(
18334 "SELECT id, category, price FROM pg_products ORDER BY category, price",
18335 &[],
18336 ).unwrap();
18337 assert_eq!(rows.len(), 10);
18338 assert_eq!(rows[0].get(1).unwrap(), &Value::String("Books".to_string()));
18340 assert_eq!(rows[0].get(2).unwrap(), &Value::Int4(15));
18341 assert_eq!(rows[1].get(1).unwrap(), &Value::String("Books".to_string()));
18342 assert_eq!(rows[1].get(2).unwrap(), &Value::Int4(30));
18343 assert_eq!(rows[2].get(1).unwrap(), &Value::String("Books".to_string()));
18344 assert_eq!(rows[2].get(2).unwrap(), &Value::Int4(45));
18345 assert_eq!(rows[3].get(1).unwrap(), &Value::String("Clothing".to_string()));
18347 assert_eq!(rows[3].get(2).unwrap(), &Value::Int4(20));
18348 assert_eq!(rows[4].get(1).unwrap(), &Value::String("Clothing".to_string()));
18349 assert_eq!(rows[4].get(2).unwrap(), &Value::Int4(60));
18350 assert_eq!(rows[5].get(1).unwrap(), &Value::String("Clothing".to_string()));
18351 assert_eq!(rows[5].get(2).unwrap(), &Value::Int4(80));
18352 assert_eq!(rows[6].get(1).unwrap(), &Value::String("Electronics".to_string()));
18354 assert_eq!(rows[6].get(2).unwrap(), &Value::Int4(35));
18355 }
18356
18357 #[test]
18358 fn test_order_by_mixed_directions() {
18359 let db = setup_pagination_db();
18361 let rows = db.query(
18362 "SELECT id, category, price FROM pg_products ORDER BY category ASC, price DESC",
18363 &[],
18364 ).unwrap();
18365 assert_eq!(rows.len(), 10);
18366 assert_eq!(rows[0].get(1).unwrap(), &Value::String("Books".to_string()));
18368 assert_eq!(rows[0].get(2).unwrap(), &Value::Int4(45));
18369 assert_eq!(rows[1].get(1).unwrap(), &Value::String("Books".to_string()));
18370 assert_eq!(rows[1].get(2).unwrap(), &Value::Int4(30));
18371 assert_eq!(rows[2].get(1).unwrap(), &Value::String("Books".to_string()));
18372 assert_eq!(rows[2].get(2).unwrap(), &Value::Int4(15));
18373 assert_eq!(rows[3].get(1).unwrap(), &Value::String("Clothing".to_string()));
18375 assert_eq!(rows[3].get(2).unwrap(), &Value::Int4(80));
18376 assert_eq!(rows[4].get(1).unwrap(), &Value::String("Clothing".to_string()));
18377 assert_eq!(rows[4].get(2).unwrap(), &Value::Int4(60));
18378 assert_eq!(rows[5].get(1).unwrap(), &Value::String("Clothing".to_string()));
18379 assert_eq!(rows[5].get(2).unwrap(), &Value::Int4(20));
18380 assert_eq!(rows[6].get(1).unwrap(), &Value::String("Electronics".to_string()));
18382 assert_eq!(rows[6].get(2).unwrap(), &Value::Int4(90));
18383 assert_eq!(rows[7].get(2).unwrap(), &Value::Int4(75));
18384 assert_eq!(rows[8].get(2).unwrap(), &Value::Int4(50));
18385 assert_eq!(rows[9].get(2).unwrap(), &Value::Int4(35));
18386 }
18387
18388 #[test]
18389 fn test_order_by_with_nulls() {
18390 let db = EmbeddedDatabase::new_in_memory().unwrap();
18393 db.execute("CREATE TABLE pg_nullsort (id INT, score INT)").unwrap();
18394 db.execute("INSERT INTO pg_nullsort VALUES (1, 50)").unwrap();
18395 db.execute("INSERT INTO pg_nullsort VALUES (2, NULL)").unwrap();
18396 db.execute("INSERT INTO pg_nullsort VALUES (3, 30)").unwrap();
18397 db.execute("INSERT INTO pg_nullsort VALUES (4, NULL)").unwrap();
18398 db.execute("INSERT INTO pg_nullsort VALUES (5, 70)").unwrap();
18399
18400 let rows = db.query(
18402 "SELECT id, score FROM pg_nullsort ORDER BY score ASC, id ASC",
18403 &[],
18404 ).unwrap();
18405 assert_eq!(rows.len(), 5);
18406 assert_eq!(rows[0].get(1).unwrap(), &Value::Null);
18408 assert_eq!(rows[1].get(1).unwrap(), &Value::Null);
18409 assert_eq!(rows[2].get(1).unwrap(), &Value::Int4(30));
18411 assert_eq!(rows[3].get(1).unwrap(), &Value::Int4(50));
18412 assert_eq!(rows[4].get(1).unwrap(), &Value::Int4(70));
18413
18414 let rows_desc = db.query(
18416 "SELECT id, score FROM pg_nullsort ORDER BY score DESC, id ASC",
18417 &[],
18418 ).unwrap();
18419 assert_eq!(rows_desc.len(), 5);
18420 assert_eq!(rows_desc[0].get(1).unwrap(), &Value::Int4(70));
18421 assert_eq!(rows_desc[1].get(1).unwrap(), &Value::Int4(50));
18422 assert_eq!(rows_desc[2].get(1).unwrap(), &Value::Int4(30));
18423 assert_eq!(rows_desc[3].get(1).unwrap(), &Value::Null);
18425 assert_eq!(rows_desc[4].get(1).unwrap(), &Value::Null);
18426 }
18427
18428 #[test]
18431 fn test_pagination_full_scan() {
18432 let db = setup_pagination_db();
18435 let page_size = 3;
18436 let mut all_ids: Vec<i32> = Vec::new();
18437
18438 for page in 0..4 {
18439 let offset = page * page_size;
18440 let sql = format!(
18441 "SELECT id FROM pg_products ORDER BY id LIMIT {} OFFSET {}",
18442 page_size, offset
18443 );
18444 let rows = db.query(&sql, &[]).unwrap();
18445
18446 if page < 3 {
18447 assert_eq!(rows.len(), 3, "Page {} should have 3 rows, got {}", page, rows.len());
18448 } else {
18449 assert_eq!(rows.len(), 1, "Last page should have 1 row, got {}", rows.len());
18450 }
18451
18452 for row in &rows {
18453 if let Value::Int4(id) = row.get(0).unwrap() {
18454 all_ids.push(*id);
18455 }
18456 }
18457 }
18458
18459 assert_eq!(all_ids, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
18461 "Full pagination should yield all IDs 1..=10 in order");
18462 }
18463
18464 #[test]
18465 fn test_limit_with_where() {
18466 let db = setup_pagination_db();
18470 let rows = db.query(
18471 "SELECT id, price FROM pg_products WHERE category = 'Electronics' ORDER BY price DESC LIMIT 2",
18472 &[],
18473 ).unwrap();
18474 assert_eq!(rows.len(), 2, "LIMIT 2 on 4 Electronics rows should return 2, got {}", rows.len());
18475 assert_eq!(rows[0].get(0).unwrap(), &Value::Int4(5), "Most expensive electronics is id=5 (90)");
18476 assert_eq!(rows[0].get(1).unwrap(), &Value::Int4(90));
18477 assert_eq!(rows[1].get(0).unwrap(), &Value::Int4(3), "Second most expensive is id=3 (75)");
18478 assert_eq!(rows[1].get(1).unwrap(), &Value::Int4(75));
18479 }
18480
18481 #[test]
18482 fn test_limit_with_group_by() {
18483 let db = setup_pagination_db();
18487 let rows = db.query(
18488 "SELECT category, COUNT(*) AS cnt FROM pg_products GROUP BY category ORDER BY category LIMIT 2",
18489 &[],
18490 ).unwrap();
18491 assert_eq!(rows.len(), 2, "LIMIT 2 on 3 groups should return 2, got {}", rows.len());
18492 assert_eq!(rows[0].get(0).unwrap(), &Value::String("Books".to_string()));
18494 assert_eq!(rows[0].get(1).unwrap(), &Value::Int8(3));
18495 assert_eq!(rows[1].get(0).unwrap(), &Value::String("Clothing".to_string()));
18496 assert_eq!(rows[1].get(1).unwrap(), &Value::Int8(3));
18497 }
18498
18499 #[test]
18504 fn test_pg_compat_version() {
18505 let db = EmbeddedDatabase::new_in_memory().unwrap();
18506 let rows = db.query("SELECT version()", &[]).unwrap();
18507 assert_eq!(rows.len(), 1);
18508 let val = rows[0].get(0).unwrap();
18509 match val {
18510 Value::String(s) => {
18511 assert!(s.contains("PostgreSQL"), "version() should mention PostgreSQL, got: {}", s);
18512 assert!(s.contains("HeliosDB"), "version() should mention HeliosDB, got: {}", s);
18513 }
18514 other => panic!("Expected String, got: {:?}", other),
18515 }
18516 }
18517
18518 #[test]
18519 fn test_pg_compat_pg_catalog_version() {
18520 let db = EmbeddedDatabase::new_in_memory().unwrap();
18521 let rows = db.query("SELECT pg_catalog.version()", &[]).unwrap();
18522 assert_eq!(rows.len(), 1);
18523 let val = rows[0].get(0).unwrap();
18524 match val {
18525 Value::String(s) => {
18526 assert!(s.contains("PostgreSQL"), "pg_catalog.version() should mention PostgreSQL");
18527 }
18528 other => panic!("Expected String, got: {:?}", other),
18529 }
18530 }
18531
18532 #[test]
18533 fn test_pg_compat_current_schema() {
18534 let db = EmbeddedDatabase::new_in_memory().unwrap();
18535 let rows = db.query("SELECT current_schema()", &[]).unwrap();
18536 assert_eq!(rows.len(), 1);
18537 assert_eq!(rows[0].get(0).unwrap(), &Value::String("public".to_string()));
18538 }
18539
18540 #[test]
18541 fn test_pg_compat_current_database() {
18542 let db = EmbeddedDatabase::new_in_memory().unwrap();
18543 let rows = db.query("SELECT current_database()", &[]).unwrap();
18544 assert_eq!(rows.len(), 1);
18545 assert_eq!(rows[0].get(0).unwrap(), &Value::String("heliosdb".to_string()));
18546 }
18547
18548 #[test]
18553 fn test_wp_bigint_eq_where_clause() {
18554 let db = EmbeddedDatabase::new_in_memory().unwrap();
18557 db.execute("CREATE TABLE wp_posts (ID BIGSERIAL PRIMARY KEY, title TEXT)").unwrap();
18558 db.execute("INSERT INTO wp_posts (title) VALUES ('hello')").unwrap();
18559
18560 let rows_in = db.query("SELECT * FROM wp_posts WHERE ID IN (1)", &[]).unwrap();
18562 assert_eq!(rows_in.len(), 1, "IN (1) should find the row");
18563
18564 let rows_eq = db.query("SELECT * FROM wp_posts WHERE ID = 1", &[]).unwrap();
18566 assert_eq!(rows_eq.len(), 1, "fast-path WHERE ID = 1 should find the row");
18567
18568 let rows_order = db.query("SELECT * FROM wp_posts WHERE ID = 1 ORDER BY ID", &[]).unwrap();
18570 assert_eq!(rows_order.len(), 1, "executor-path WHERE ID = 1 ORDER BY should find the row");
18571
18572 let rows_col = db.query("SELECT ID, title FROM wp_posts WHERE ID = 1", &[]).unwrap();
18574 assert_eq!(rows_col.len(), 1, "SELECT cols WHERE ID = 1 should find the row");
18575
18576 db.execute("CREATE TABLE t_small (id SMALLSERIAL PRIMARY KEY, val TEXT)").unwrap();
18578 db.execute("INSERT INTO t_small (val) VALUES ('x')").unwrap();
18579 let rows_small = db.query("SELECT * FROM t_small WHERE id = 1", &[]).unwrap();
18580 assert_eq!(rows_small.len(), 1, "SMALLSERIAL PK with int4 literal should work");
18581 }
18582
18583 #[test]
18584 fn test_wp_last_insert_id_serial() {
18585 let db = EmbeddedDatabase::new_in_memory().unwrap();
18587 db.execute("CREATE TABLE t_serial (id BIGSERIAL PRIMARY KEY, name TEXT)").unwrap();
18588 db.execute("INSERT INTO t_serial (name) VALUES ('hello')").unwrap();
18589 let rows = db.query("SELECT MAX(id) FROM t_serial", &[]).unwrap();
18590 let max_id = rows[0].get(0).unwrap();
18591 match max_id {
18592 Value::Int8(n) => assert!(*n > 0, "SERIAL should auto-generate: got {}", n),
18593 Value::Int4(n) => assert!(*n > 0, "SERIAL should auto-generate: got {}", n),
18594 other => panic!("Unexpected type for MAX(id): {:?}", other),
18595 }
18596 }
18597
18598 #[test]
18599 fn test_wp_duplicate_pk_error_message() {
18600 let db = EmbeddedDatabase::new_in_memory().unwrap();
18603 db.execute("CREATE TABLE t_dup (id INT PRIMARY KEY, name TEXT)").unwrap();
18604 db.execute("INSERT INTO t_dup VALUES (1, 'a')").unwrap();
18605 let result = db.execute("INSERT INTO t_dup VALUES (1, 'b')");
18606 assert!(result.is_err(), "Duplicate PK insert must fail, but got Ok");
18607 let msg = result.unwrap_err().to_string();
18608 let lower = msg.to_lowercase();
18610 assert!(
18611 lower.contains("duplicate") || lower.contains("unique") || lower.contains("primary key"),
18612 "Duplicate PK error should contain recognizable keywords, got: {}", msg
18613 );
18614 }
18615
18616 #[test]
18617 fn test_wp_duplicate_pk_no_data_corruption() {
18618 let db = EmbeddedDatabase::new_in_memory().unwrap();
18620 db.execute("CREATE TABLE t_dup2 (id INT PRIMARY KEY, name TEXT)").unwrap();
18621 db.execute("INSERT INTO t_dup2 VALUES (1, 'original')").unwrap();
18622 let _ = db.execute("INSERT INTO t_dup2 VALUES (1, 'duplicate')");
18623 let rows = db.query("SELECT * FROM t_dup2", &[]).unwrap();
18624 assert_eq!(rows.len(), 1, "Only one row should exist after rejected duplicate");
18625 assert_eq!(rows[0].get(1).unwrap(), &Value::String("original".to_string()),
18626 "Original row must be preserved");
18627 }
18628
18629 #[test]
18630 fn test_wp_duplicate_unique_constraint() {
18631 let db = EmbeddedDatabase::new_in_memory().unwrap();
18633 db.execute("CREATE TABLE t_uq (id INT PRIMARY KEY, email TEXT UNIQUE)").unwrap();
18634 db.execute("INSERT INTO t_uq VALUES (1, 'a@b.com')").unwrap();
18635 let result = db.execute("INSERT INTO t_uq VALUES (2, 'a@b.com')");
18636 assert!(result.is_err(), "Duplicate UNIQUE insert must fail");
18637 }
18638
18639 #[test]
18644 fn test_on_conflict_do_nothing() {
18645 let db = EmbeddedDatabase::new_in_memory().unwrap();
18646 db.execute("CREATE TABLE t_oc1 (id INT PRIMARY KEY, name TEXT)").unwrap();
18647 db.execute("INSERT INTO t_oc1 VALUES (1, 'a')").unwrap();
18648 db.execute("INSERT INTO t_oc1 VALUES (1, 'b') ON CONFLICT DO NOTHING").unwrap();
18650 let rows = db.query("SELECT name FROM t_oc1 WHERE id = 1", &[]).unwrap();
18651 assert_eq!(rows.len(), 1);
18652 assert_eq!(rows[0].values[0], Value::String("a".to_string()));
18653 }
18654
18655 #[test]
18656 fn test_on_conflict_do_update() {
18657 let db = EmbeddedDatabase::new_in_memory().unwrap();
18658 db.execute("CREATE TABLE t_oc2 (id INT PRIMARY KEY, name TEXT)").unwrap();
18659 db.execute("INSERT INTO t_oc2 VALUES (1, 'a')").unwrap();
18660 db.execute("INSERT INTO t_oc2 VALUES (1, 'b') ON CONFLICT DO UPDATE SET name = EXCLUDED.name").unwrap();
18662 let rows = db.query("SELECT name FROM t_oc2 WHERE id = 1", &[]).unwrap();
18663 assert_eq!(rows.len(), 1);
18664 assert_eq!(rows[0].values[0], Value::String("b".to_string()));
18665 }
18666
18667 #[test]
18668 fn test_on_conflict_do_update_multiple_columns() {
18669 let db = EmbeddedDatabase::new_in_memory().unwrap();
18670 db.execute("CREATE TABLE t_oc3 (id INT PRIMARY KEY, name TEXT, score INT)").unwrap();
18671 db.execute("INSERT INTO t_oc3 VALUES (1, 'alice', 10)").unwrap();
18672 db.execute("INSERT INTO t_oc3 VALUES (1, 'bob', 20) ON CONFLICT DO UPDATE SET name = EXCLUDED.name, score = EXCLUDED.score").unwrap();
18673 let rows = db.query("SELECT name, score FROM t_oc3 WHERE id = 1", &[]).unwrap();
18674 assert_eq!(rows.len(), 1);
18675 assert_eq!(rows[0].values[0], Value::String("bob".to_string()));
18676 assert_eq!(rows[0].values[1], Value::Int4(20));
18677 }
18678
18679 #[test]
18680 fn test_on_conflict_do_nothing_no_conflict() {
18681 let db = EmbeddedDatabase::new_in_memory().unwrap();
18683 db.execute("CREATE TABLE t_oc4 (id INT PRIMARY KEY, name TEXT)").unwrap();
18684 db.execute("INSERT INTO t_oc4 VALUES (1, 'a') ON CONFLICT DO NOTHING").unwrap();
18685 let rows = db.query("SELECT name FROM t_oc4 WHERE id = 1", &[]).unwrap();
18686 assert_eq!(rows.len(), 1);
18687 assert_eq!(rows[0].values[0], Value::String("a".to_string()));
18688 }
18689
18690 #[test]
18691 fn test_on_conflict_do_update_no_conflict() {
18692 let db = EmbeddedDatabase::new_in_memory().unwrap();
18694 db.execute("CREATE TABLE t_oc5 (id INT PRIMARY KEY, name TEXT)").unwrap();
18695 db.execute("INSERT INTO t_oc5 VALUES (1, 'a') ON CONFLICT DO UPDATE SET name = EXCLUDED.name").unwrap();
18696 let rows = db.query("SELECT name FROM t_oc5 WHERE id = 1", &[]).unwrap();
18697 assert_eq!(rows.len(), 1);
18698 assert_eq!(rows[0].values[0], Value::String("a".to_string()));
18699 }
18700
18701 #[test]
18702 fn test_on_conflict_do_nothing_returns_zero() {
18703 let db = EmbeddedDatabase::new_in_memory().unwrap();
18705 db.execute("CREATE TABLE t_oc6 (id INT PRIMARY KEY, name TEXT)").unwrap();
18706 db.execute("INSERT INTO t_oc6 VALUES (1, 'a')").unwrap();
18707 let affected = db.execute("INSERT INTO t_oc6 VALUES (1, 'b') ON CONFLICT DO NOTHING").unwrap();
18708 assert_eq!(affected, 0, "DO NOTHING should report 0 affected rows");
18709 }
18710
18711 #[test]
18712 fn test_on_conflict_do_update_returns_one() {
18713 let db = EmbeddedDatabase::new_in_memory().unwrap();
18715 db.execute("CREATE TABLE t_oc7 (id INT PRIMARY KEY, name TEXT)").unwrap();
18716 db.execute("INSERT INTO t_oc7 VALUES (1, 'a')").unwrap();
18717 let affected = db.execute("INSERT INTO t_oc7 VALUES (1, 'b') ON CONFLICT DO UPDATE SET name = EXCLUDED.name").unwrap();
18718 assert_eq!(affected, 1, "DO UPDATE should report 1 affected row");
18719 }
18720
18721 #[test]
18722 fn test_on_conflict_with_column_list() {
18723 let db = EmbeddedDatabase::new_in_memory().unwrap();
18725 db.execute("CREATE TABLE t_oc8 (id INT PRIMARY KEY, name TEXT, val INT)").unwrap();
18726 db.execute("INSERT INTO t_oc8 (id, name, val) VALUES (1, 'a', 10)").unwrap();
18727 db.execute("INSERT INTO t_oc8 (id, name, val) VALUES (1, 'b', 20) ON CONFLICT DO UPDATE SET name = EXCLUDED.name, val = EXCLUDED.val").unwrap();
18728 let rows = db.query("SELECT name, val FROM t_oc8 WHERE id = 1", &[]).unwrap();
18729 assert_eq!(rows.len(), 1);
18730 assert_eq!(rows[0].values[0], Value::String("b".to_string()));
18731 assert_eq!(rows[0].values[1], Value::Int4(20));
18732 }
18733
18734 #[test]
18735 fn test_on_conflict_do_update_partial() {
18736 let db = EmbeddedDatabase::new_in_memory().unwrap();
18738 db.execute("CREATE TABLE t_oc9 (id INT PRIMARY KEY, name TEXT, val INT)").unwrap();
18739 db.execute("INSERT INTO t_oc9 VALUES (1, 'alice', 10)").unwrap();
18740 db.execute("INSERT INTO t_oc9 VALUES (1, 'bob', 99) ON CONFLICT DO UPDATE SET val = EXCLUDED.val").unwrap();
18742 let rows = db.query("SELECT name, val FROM t_oc9 WHERE id = 1", &[]).unwrap();
18743 assert_eq!(rows.len(), 1);
18744 assert_eq!(rows[0].values[0], Value::String("alice".to_string()), "name should be unchanged");
18745 assert_eq!(rows[0].values[1], Value::Int4(99), "val should be updated");
18746 }
18747
18748 #[test]
18749 fn test_on_conflict_do_update_with_literal() {
18750 let db = EmbeddedDatabase::new_in_memory().unwrap();
18752 db.execute("CREATE TABLE t_oc10 (id INT PRIMARY KEY, name TEXT)").unwrap();
18753 db.execute("INSERT INTO t_oc10 VALUES (1, 'a')").unwrap();
18754 db.execute("INSERT INTO t_oc10 VALUES (1, 'b') ON CONFLICT DO UPDATE SET name = 'replaced'").unwrap();
18755 let rows = db.query("SELECT name FROM t_oc10 WHERE id = 1", &[]).unwrap();
18756 assert_eq!(rows.len(), 1);
18757 assert_eq!(rows[0].values[0], Value::String("replaced".to_string()));
18758 }
18759}