1use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7
8fn generate_temp_id() -> u64 {
9 static COUNTER: AtomicU64 = AtomicU64::new(0);
10 let counter = COUNTER.fetch_add(1, Ordering::Relaxed);
11 let nanos = (crate::datetime::now_micros() as u64) & 0xFFFF_FFFF;
12 (nanos << 32) | (counter & 0xFFFF_FFFF)
13}
14
/// Returns the storage-level table name for a TEMP table.
///
/// The name is the user-visible `user_name`, ASCII-lowercased and prefixed with
/// `__temp_{temp_id}_`, so each connection's TEMP tables live in their own namespace.
fn temp_storage_name(temp_id: u64, user_name: &str) -> String {
    let lowered = user_name.to_ascii_lowercase();
    format!("__temp_{}_{}", temp_id, lowered)
}
18
19use lru::LruCache;
20
21use citadel::Database;
22use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
23
24use crate::error::{Result, SqlError};
25use crate::executor;
26use crate::parser;
27use crate::parser::{BeginAccessMode, QueryBody, SelectQuery, Statement};
28use crate::prepared::PreparedStatement;
29use crate::schema::{SchemaManager, SchemaSnapshot};
30use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
31
/// Number of parsed statements kept in each connection's LRU statement cache.
const DEFAULT_CACHE_CAPACITY: usize = 64;
33
34fn invalidate_dml_caches(schema: &SchemaManager, db: &Database) {
39 if !schema.has_dml_dirty() {
40 return;
41 }
42 let gen = db.manager().commit_generation();
43 let hard_invalidate = |table: &str| {
44 db.sql_cache_invalidate_prefix(&format!("ann:{table}:"));
45 let marker: std::sync::Arc<dyn std::any::Any + Send + Sync> = std::sync::Arc::new(gen);
46 schema
47 .sql_caches
48 .lock()
49 .insert(crate::executor::ann_dml_gen_key(table), marker);
50 };
51
52 let dirty = schema.drain_dml_dirty();
53 for table in &dirty.mutating {
54 hard_invalidate(table);
55 }
56 for (table, min_pk) in &dirty.appends {
59 if crate::executor::ann_appends_safe(schema, table, *min_pk) {
60 continue;
61 }
62 hard_invalidate(table);
63 }
64}
65
/// Outcome of `Connection::execute_script`.
///
/// Statements run in order and execution stops at the first failure; the
/// statements that already ran are not rolled back by `execute_script` itself.
#[derive(Debug)]
pub struct ScriptExecution {
    /// Results of the statements that completed successfully, in script order.
    pub completed: Vec<ExecutionResult>,
    /// The parse error or first statement error that stopped the script, if any.
    pub error: Option<SqlError>,
}
71
/// Validates a fixed UTC offset such as `Z`, `UTC`, `+05`, `-0530` or `+05:30`.
///
/// Surrounding whitespace is ignored. An offset is a `+`/`-` sign followed by one
/// of these forms:
/// - `HH:MM`;
/// - exactly four digits (`HHMM`);
/// - exactly two digits (`HH`).
///
/// Hours must be in `0..=23` and minutes in `0..=59`. Each component must consist
/// solely of ASCII digits, so inputs such as `++05:00` or non-ASCII text are rejected
/// instead of slipping through `u32::parse` or panicking on a non-char-boundary slice.
///
/// Returns `Some(())` when `s` is an accepted offset and `None` otherwise.
fn parse_fixed_offset(s: &str) -> Option<()> {
    let s = s.trim();
    if s.eq_ignore_ascii_case("z") || s.eq_ignore_ascii_case("utc") {
        return Some(());
    }
    let rest = s.strip_prefix('+').or_else(|| s.strip_prefix('-'))?;
    // The fixed-width split below slices by byte index; non-ASCII input could put
    // that index inside a multi-byte character and panic, so reject it up front.
    if !rest.is_ascii() {
        return None;
    }
    let (hh, mm) = if let Some((h, m)) = rest.split_once(':') {
        (h, m)
    } else {
        match rest.len() {
            4 => rest.split_at(2),
            2 => (rest, "00"),
            _ => return None,
        }
    };
    // `u32::from_str` accepts a leading `+`, so require plain digits explicitly.
    let all_digits = |part: &str| !part.is_empty() && part.bytes().all(|b| b.is_ascii_digit());
    if !all_digits(hh) || !all_digits(mm) {
        return None;
    }
    let h: u32 = hh.parse().ok()?;
    let m: u32 = mm.parse().ok()?;
    (h <= 23 && m <= 59).then_some(())
}
101
/// Rewrites `SHOW TRIGGERS [ON <table>]` into a query over `information_schema.triggers`.
///
/// Matching is case-insensitive and ignores surrounding whitespace and trailing
/// semicolons. The table name is taken from the lowercased input, has single quotes
/// doubled, and is compared with `LOWER(...)` on both sides. Returns `None` when
/// `sql` is not a `SHOW TRIGGERS` command, or when `ON` is followed by no table name.
fn rewrite_show_triggers(sql: &str) -> Option<String> {
    const BASE: &str = concat!(
        "SELECT trigger_name, event_object_table AS table_name, action_timing, ",
        "event_manipulation, action_orientation, action_statement ",
        "FROM information_schema.triggers"
    );
    let lowered = sql.trim().trim_end_matches(';').trim().to_ascii_lowercase();
    let tail = lowered.strip_prefix("show triggers")?.trim_start();
    if tail.is_empty() {
        return Some(format!("{} ORDER BY trigger_name", BASE));
    }
    let table = tail.strip_prefix("on ")?.trim().trim_end_matches(';').trim();
    if table.is_empty() {
        return None;
    }
    let quoted = table.replace('\'', "''");
    Some(format!(
        "{} WHERE LOWER(event_object_table) = LOWER('{}') ORDER BY trigger_name",
        BASE, quoted
    ))
}
128
/// Rewrites `SHOW MATERIALIZED VIEWS` into a query over `pg_matviews`.
///
/// Matching is case-insensitive and ignores surrounding whitespace and trailing
/// semicolons. The words must be separated by exactly one space. Any other input
/// yields `None`.
fn rewrite_show_matviews(sql: &str) -> Option<String> {
    let command = sql.trim().trim_end_matches(';').trim();
    let is_show_matviews = command.eq_ignore_ascii_case("show materialized views");
    is_show_matviews.then(|| {
        String::from(
            "SELECT matviewname, ispopulated, hasindexes, definition \
             FROM pg_matviews ORDER BY matviewname",
        )
    })
}
142
143fn stmt_mutates(stmt: &Statement) -> bool {
144 if matches!(
145 stmt,
146 Statement::Insert(_)
147 | Statement::Update(_)
148 | Statement::Delete(_)
149 | Statement::Truncate(_)
150 | Statement::CreateTable(_)
151 | Statement::DropTable(_)
152 | Statement::AlterTable(_)
153 | Statement::CreateIndex(_)
154 | Statement::DropIndex(_)
155 | Statement::CreateView(_)
156 | Statement::DropView(_)
157 | Statement::CreateTrigger(_)
158 | Statement::DropTrigger(_)
159 | Statement::CreateMaterializedView(_)
160 | Statement::RefreshMaterializedView(_)
161 | Statement::DropMaterializedView(_)
162 ) {
163 return true;
164 }
165 if let Statement::Select(sq) = stmt {
166 if select_query_has_dml(sq) {
167 return true;
168 }
169 }
170 false
171}
172
173fn is_txn_control(stmt: &Statement) -> bool {
174 matches!(
175 stmt,
176 Statement::Begin { .. }
177 | Statement::Commit
178 | Statement::Rollback
179 | Statement::Savepoint(_)
180 | Statement::ReleaseSavepoint(_)
181 | Statement::RollbackTo(_)
182 )
183}
184
185fn select_query_has_dml(sq: &SelectQuery) -> bool {
186 sq.ctes.iter().any(|cte| query_body_has_dml(&cte.body)) || query_body_has_dml(&sq.body)
187}
188
189fn query_body_has_dml(body: &QueryBody) -> bool {
190 match body {
191 QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => true,
192 QueryBody::Compound(c) => query_body_has_dml(&c.left) || query_body_has_dml(&c.right),
193 QueryBody::Select(_) => false,
194 }
195}
196
/// Fast-path normaliser for single-row `INSERT INTO … VALUES (…)` statements.
///
/// Lifts every literal in the `VALUES` list into a positional parameter so that
/// statements differing only in their literals share one statement-cache key.
/// On success it returns the normalised SQL (the original text up to and including
/// `VALUES`, followed by ` ($1, $2, …)`) together with the extracted values.
/// It returns `None` for anything this scanner does not fully handle; the caller
/// then falls back to the regular parse path.
///
/// Recognised literals (keywords case-insensitive):
/// - single-quoted strings, with `''` escapes;
/// - optionally negative integers, parsed as `i64`;
/// - decimals containing a `.`, parsed as `f64`;
/// - `NULL`, `TRUE` and `FALSE`.
///
/// Only whitespace and `;` may follow the closing parenthesis. Multi-row `VALUES`,
/// `RETURNING`, trailing comments and similar forms therefore yield `None`.
fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut i = 0;

    // Optional leading whitespace, then `INSERT` followed by at least one whitespace byte.
    while i < len && bytes[i].is_ascii_whitespace() {
        i += 1;
    }
    if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
        return None;
    }
    i += 6;
    if i >= len || !bytes[i].is_ascii_whitespace() {
        return None;
    }
    while i < len && bytes[i].is_ascii_whitespace() {
        i += 1;
    }

    // `INTO`, which must also be followed by whitespace.
    if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
        return None;
    }
    i += 4;
    if i >= len || !bytes[i].is_ascii_whitespace() {
        return None;
    }

    // Find the first `VALUES` that is not part of a longer identifier.
    // NOTE(review): this search is purely lexical — it does not skip quoted
    // identifiers, string literals or comments, and relies on the `(` check and the
    // end-of-input check below to reject such inputs; confirm for exotic cases.
    let prefix_start = 0;
    let mut values_pos = None;
    let mut j = i;
    while j + 6 <= len {
        if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
            && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
            && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
        {
            values_pos = Some(j);
            break;
        }
        j += 1;
    }
    let values_pos = values_pos?;

    // Everything up to and including `VALUES` is kept verbatim in the cache key.
    // The slice end lies right after ASCII bytes, so it is a valid char boundary.
    let prefix = &sql[prefix_start..values_pos + 6];
    let mut pos = values_pos + 6;

    while pos < len && bytes[pos].is_ascii_whitespace() {
        pos += 1;
    }
    if pos >= len || bytes[pos] != b'(' {
        return None;
    }
    pos += 1;

    let mut values = Vec::new();
    let mut normalized = String::with_capacity(sql.len());
    normalized.push_str(prefix);
    normalized.push_str(" (");

    // One iteration per literal: parse it, emit `$n`, then expect `,` or `)`.
    loop {
        while pos < len && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        if pos >= len {
            return None;
        }

        let param_idx = values.len() + 1;
        if param_idx > 1 {
            normalized.push_str(", ");
        }

        if bytes[pos] == b'\'' {
            // Single-quoted string; `''` inside it stands for one quote character.
            pos += 1;
            let mut seg_start = pos;
            let mut s = String::new();
            loop {
                if pos >= len {
                    return None;
                }
                if bytes[pos] == b'\'' {
                    s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
                    pos += 1;
                    if pos < len && bytes[pos] == b'\'' {
                        s.push('\'');
                        pos += 1;
                        seg_start = pos;
                    } else {
                        break;
                    }
                } else {
                    pos += 1;
                }
            }
            values.push(Value::Text(s.into()));
        } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
            // Optional `-`, digits, optional `.digits`. Anything the std parsers
            // reject (e.g. a lone `-`, or an integer overflowing i64) yields `None`.
            let start = pos;
            if bytes[pos] == b'-' {
                pos += 1;
            }
            while pos < len && bytes[pos].is_ascii_digit() {
                pos += 1;
            }
            if pos < len && bytes[pos] == b'.' {
                pos += 1;
                while pos < len && bytes[pos].is_ascii_digit() {
                    pos += 1;
                }
                let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
                values.push(Value::Real(num));
            } else {
                let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
                values.push(Value::Integer(num));
            }
        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
            // Keyword literals must end at a non-identifier byte (end of input
            // counts as `)`), so e.g. `NULLX` is not taken for `NULL`.
            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 4;
                values.push(Value::Null);
            } else {
                return None;
            }
        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 4;
                values.push(Value::Boolean(true));
            } else {
                return None;
            }
        } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
            let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 5;
                values.push(Value::Boolean(false));
            } else {
                return None;
            }
        } else {
            // Expressions, identifiers, DEFAULT, casts, etc. are not handled here.
            return None;
        }

        normalized.push('$');
        normalized.push_str(&param_idx.to_string());

        while pos < len && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        if pos >= len {
            return None;
        }

        if bytes[pos] == b',' {
            pos += 1;
        } else if bytes[pos] == b')' {
            pos += 1;
            break;
        } else {
            return None;
        }
    }

    normalized.push(')');

    // Only whitespace and statement terminators may follow the single row.
    while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
        pos += 1;
    }
    if pos != len {
        return None;
    }

    if values.is_empty() {
        return None;
    }

    Some((normalized, values))
}
373
/// A parsed statement held in a connection's LRU statement cache.
pub(crate) struct CacheEntry {
    /// The parsed statement, shared with callers via `Arc`.
    pub(crate) stmt: Arc<Statement>,
    /// Schema generation at parse time; the entry is only reused while the
    /// connection's schema generation still equals this value.
    pub(crate) schema_gen: u64,
    /// Number of placeholders in `stmt`, as reported by `parser::count_params`.
    pub(crate) param_count: usize,
    /// Compiled plan, filled in lazily when `executor::compile` succeeds for this
    /// statement in auto-commit mode.
    pub(crate) compiled: Option<Arc<dyn executor::CompiledPlan>>,
}
380
/// One `SAVEPOINT` on the connection's savepoint stack (innermost last).
struct SavepointEntry {
    /// Name given in the `SAVEPOINT` statement.
    name: String,
    /// State to restore on `ROLLBACK TO`. It stays `None` until the first mutating
    /// statement after the savepoint was created, because snapshots are captured
    /// lazily by `capture_pending_snapshots`.
    snapshot: Option<SavepointSnapshot>,
}
385
/// Captured transaction and schema state for a savepoint.
struct SavepointSnapshot {
    /// Write-transaction state from `WriteTxn::begin_savepoint`.
    wtx_snap: WriteTxnSnapshot,
    /// In-memory schema state from `SchemaManager::save_snapshot`.
    schema_snap: SchemaSnapshot,
}
390
// NOTE(review): the lint is presumably allowed because `WriteTxn` is much larger
// than the other variants and boxing it would add an allocation per transaction —
// confirm and keep this justification next to the attribute.
/// The transaction currently open on a connection, if any.
#[allow(clippy::large_enum_variant)]
pub(crate) enum ActiveTxn<'a> {
    /// No open transaction: each statement runs in auto-commit mode.
    None,
    /// Read-write transaction (default/`READ WRITE` BEGIN, or `execute_batch`).
    Write(WriteTxn<'a>),
    /// Read-only transaction opened with `BeginAccessMode::ReadOnly`.
    Read(citadel_txn::read_txn::ReadTxn<'a>),
}
399
400impl<'a> ActiveTxn<'a> {
401 fn is_none(&self) -> bool {
402 matches!(self, ActiveTxn::None)
403 }
404 fn is_active(&self) -> bool {
405 !self.is_none()
406 }
407 fn is_read_only(&self) -> bool {
408 matches!(self, ActiveTxn::Read(_))
409 }
410 fn as_write_mut(&mut self) -> Option<&mut WriteTxn<'a>> {
411 match self {
412 ActiveTxn::Write(w) => Some(w),
413 _ => None,
414 }
415 }
416 fn take(&mut self) -> ActiveTxn<'a> {
417 std::mem::replace(self, ActiveTxn::None)
418 }
419}
420
/// Mutable per-connection state, kept behind a `RefCell` in `Connection`.
pub(crate) struct ConnectionInner<'a> {
    /// This connection's view of the schema, including uncommitted DDL and
    /// TEMP-table aliases.
    pub(crate) schema: SchemaManager,
    /// Transaction opened by `BEGIN` or `execute_batch`, if any.
    active_txn: ActiveTxn<'a>,
    /// Open savepoints, innermost last.
    savepoint_stack: Vec<SavepointEntry>,
    /// Write transaction's `in_place` flag, saved when the first savepoint is
    /// created and restored when the stack becomes empty again via RELEASE.
    in_place_saved: Option<bool>,
    /// Parsed statements keyed by SQL text; normalised INSERTs use their
    /// normalised text as key.
    pub(crate) stmt_cache: LruCache<String, CacheEntry>,
    /// Timestamp captured when the current transaction began.
    txn_start_ts: Option<i64>,
    /// Session time zone as set by the user (initially `"UTC"`).
    session_timezone: String,
    /// Id used to namespace this connection's TEMP tables in storage.
    temp_id: u64,
    /// Storage names of TEMP tables created on this connection; the `Drop` impl
    /// of `Connection` tries to drop them.
    temp_table_names: Vec<String>,
}
433
/// A SQL session on a `citadel::Database`.
///
/// Methods take `&self`; mutable session state lives in `inner` and is borrowed
/// through the `RefCell` for the duration of each call.
pub struct Connection<'a> {
    /// Database this connection executes against.
    pub(crate) db: &'a Database,
    /// Session state (schema, transaction, caches, TEMP tables).
    pub(crate) inner: RefCell<ConnectionInner<'a>>,
}
438
439impl<'a> Connection<'a> {
440 pub fn open(db: &'a Database) -> Result<Self> {
441 let schema = SchemaManager::load(db)?;
442 let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
443 let temp_id = generate_temp_id();
444 Ok(Self {
445 db,
446 inner: RefCell::new(ConnectionInner {
447 schema,
448 active_txn: ActiveTxn::None,
449 savepoint_stack: Vec::new(),
450 in_place_saved: None,
451 stmt_cache,
452 txn_start_ts: None,
453 session_timezone: "UTC".to_string(),
454 temp_id,
455 temp_table_names: Vec::new(),
456 }),
457 })
458 }
459
460 pub fn txn_start_ts(&self) -> Option<i64> {
462 self.inner.borrow().txn_start_ts
463 }
464
465 pub fn session_timezone(&self) -> String {
467 self.inner.borrow().session_timezone.clone()
468 }
469
470 pub fn set_session_timezone(&self, tz: &str) -> Result<()> {
472 self.inner.borrow_mut().set_session_timezone_impl(tz)
473 }
474
475 pub fn execute(&self, sql: &str) -> Result<ExecutionResult> {
476 self.inner.borrow_mut().execute_impl(self.db, sql)
477 }
478
479 pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
480 self.inner
481 .borrow_mut()
482 .execute_params_impl(self.db, sql, params)
483 }
484
485 pub fn execute_script(&self, sql: &str) -> ScriptExecution {
487 let stmts = match parser::parse_sql_multi(sql) {
488 Ok(s) => s,
489 Err(e) => {
490 return ScriptExecution {
491 completed: vec![],
492 error: Some(e),
493 }
494 }
495 };
496 let mut completed = Vec::with_capacity(stmts.len());
497 for stmt in stmts {
498 match self.inner.borrow_mut().dispatch(self.db, &stmt, &[]) {
499 Ok(r) => completed.push(r),
500 Err(e) => {
501 return ScriptExecution {
502 completed,
503 error: Some(e),
504 }
505 }
506 }
507 }
508 ScriptExecution {
509 completed,
510 error: None,
511 }
512 }
513
514 pub fn execute_batch(&self, sql: &str) -> Result<Vec<ExecutionResult>> {
515 self.inner.borrow_mut().execute_batch_impl(self.db, sql)
516 }
517
518 pub fn query(&self, sql: &str) -> Result<QueryResult> {
519 self.query_params(sql, &[])
520 }
521
522 pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
523 match self.execute_params(sql, params)? {
524 ExecutionResult::Query(qr) => Ok(qr),
525 ExecutionResult::RowsAffected(n) => Ok(QueryResult {
526 columns: vec!["rows_affected".into()],
527 rows: vec![vec![Value::Integer(n as i64)]],
528 }),
529 ExecutionResult::Ok => Ok(QueryResult {
530 columns: vec![],
531 rows: vec![],
532 }),
533 }
534 }
535
536 pub fn prepare(&self, sql: &str) -> Result<PreparedStatement<'_, 'a>> {
537 if let Some(rewritten) = rewrite_show_triggers(sql) {
538 return PreparedStatement::new(self, &rewritten);
539 }
540 if let Some(rewritten) = rewrite_show_matviews(sql) {
541 return PreparedStatement::new(self, &rewritten);
542 }
543 PreparedStatement::new(self, sql)
544 }
545
546 pub fn tables(&self) -> Vec<String> {
547 self.inner
548 .borrow()
549 .schema
550 .table_names()
551 .into_iter()
552 .map(String::from)
553 .collect()
554 }
555
556 pub fn in_transaction(&self) -> bool {
558 self.inner.borrow().active_txn.is_active()
559 }
560
561 pub fn table_schema(&self, name: &str) -> Option<TableSchema> {
562 self.inner.borrow().schema.get(name).cloned()
563 }
564
565 pub fn refresh_schema(&self) -> Result<()> {
566 let new_schema = SchemaManager::load(self.db)?;
567 self.inner.borrow_mut().schema = new_schema;
568 Ok(())
569 }
570
571 pub fn persist_ann_index(
579 &self,
580 table: &str,
581 column: &str,
582 ) -> Result<crate::executor::AnnSegmentInfo> {
583 if self.in_transaction() {
584 return Err(SqlError::InvalidValue(
585 "persist_ann_index: not allowed inside an explicit transaction".into(),
586 ));
587 }
588 let inner = self.inner.borrow();
589 let lower = table.to_ascii_lowercase();
590 if inner.schema.resolve_temp(&lower) != lower {
591 return Err(SqlError::InvalidValue(
592 "persist_ann_index: TEMP tables are not persistable".into(),
593 ));
594 }
595 let table_schema = inner
596 .schema
597 .get(&lower)
598 .ok_or_else(|| SqlError::TableNotFound(table.to_string()))?;
599 crate::executor::persist_ann_index(self.db, &inner.schema, table_schema, column)
600 }
601
602 pub fn ann_cache_status(
607 &self,
608 table: &str,
609 column: &str,
610 ) -> Result<Option<(crate::executor::AnnIndexSource, u64)>> {
611 let inner = self.inner.borrow();
612 let table_schema = inner
613 .schema
614 .get(&table.to_ascii_lowercase())
615 .ok_or_else(|| SqlError::TableNotFound(table.to_string()))?;
616 crate::executor::ann_cache_status(&inner.schema, table_schema, column)
617 }
618}
619
620impl<'a> ConnectionInner<'a> {
621 pub(crate) fn active_txn_is_some(&self) -> bool {
622 self.active_txn.is_active()
623 }
624
625 fn set_session_timezone_impl(&mut self, tz: &str) -> Result<()> {
626 let upper = tz.to_ascii_uppercase();
627 if (upper.starts_with("UTC+") || upper.starts_with("UTC-")) && tz.len() > 3 {
628 return Err(SqlError::InvalidTimezone(format!(
629 "'{tz}' is ambiguous; use ISO-8601 offset (e.g. '+05:00') or named zone (e.g. 'Etc/GMT-5')"
630 )));
631 }
632 if jiff::tz::TimeZone::get(tz).is_err() && parse_fixed_offset(tz).is_none() {
633 return Err(SqlError::InvalidTimezone(format!(
634 "{tz}: not a known IANA zone or ISO-8601 offset (e.g. '+05:00', 'UTC', 'America/New_York')"
635 )));
636 }
637 self.session_timezone = tz.to_string();
638 Ok(())
639 }
640
641 fn execute_impl(&mut self, db: &'a Database, sql: &str) -> Result<ExecutionResult> {
642 if let Some(rewritten) = rewrite_show_triggers(sql) {
643 return self.execute_params_impl(db, &rewritten, &[]);
644 }
645 if let Some(rewritten) = rewrite_show_matviews(sql) {
646 return self.execute_params_impl(db, &rewritten, &[]);
647 }
648 if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
649 if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
650 let gen = self.schema.generation();
651 let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
652 if entry.schema_gen == gen {
653 Arc::clone(&entry.stmt)
654 } else {
655 self.parse_and_cache(normalized_key, gen)?
656 }
657 } else {
658 self.parse_and_cache(normalized_key, gen)?
659 };
660 return self.dispatch(db, &stmt, &extracted);
661 }
662 }
663 self.execute_params_impl(db, sql, &[])
664 }
665
666 fn execute_batch_impl(&mut self, db: &'a Database, sql: &str) -> Result<Vec<ExecutionResult>> {
667 if self.active_txn.is_active() {
668 return Err(SqlError::TransactionAlreadyActive);
669 }
670 let stmts = parser::parse_sql_multi(sql)?;
671 if stmts.iter().any(is_txn_control) {
672 return Err(SqlError::Unsupported(
673 "transaction-control statements are not allowed in execute_batch".into(),
674 ));
675 }
676
677 let wtx = db.begin_write().map_err(SqlError::Storage)?;
678 let ts = crate::datetime::txn_or_clock_micros();
679 self.active_txn = ActiveTxn::Write(wtx);
680 self.txn_start_ts = Some(ts);
681 crate::datetime::set_txn_clock(Some(ts));
682
683 let mut results = Vec::with_capacity(stmts.len());
684 for stmt in &stmts {
685 match self.dispatch(db, stmt, &[]) {
686 Ok(r) => results.push(r),
687 Err(e) => {
688 self.abort_active_txn(db);
689 return Err(e);
690 }
691 }
692 }
693
694 let commit = match self.active_txn.take() {
695 ActiveTxn::Write(mut wtx) => {
696 match crate::executor::helpers::drain_deferred_fk_checks(&mut wtx) {
697 Ok(()) => wtx.commit().map_err(SqlError::Storage),
698 Err(e) => {
699 wtx.abort();
700 Err(e)
701 }
702 }
703 }
704 _ => Err(SqlError::NoActiveTransaction),
705 };
706 self.reset_txn_state();
707 match commit {
708 Ok(()) => {
709 invalidate_dml_caches(&self.schema, db);
710 Ok(results)
711 }
712 Err(e) => {
713 self.schema = SchemaManager::load(db)?;
714 Err(e)
715 }
716 }
717 }
718
719 fn reset_txn_state(&mut self) {
720 self.clear_savepoint_state();
721 self.txn_start_ts = None;
722 crate::datetime::set_txn_clock(None);
723 }
724
725 fn abort_active_txn(&mut self, db: &'a Database) {
726 if let ActiveTxn::Write(wtx) = self.active_txn.take() {
727 wtx.abort();
728 }
729 if let Ok(mut fresh) = SchemaManager::load(db) {
730 fresh.bump_generation_past(self.schema.generation());
731 self.schema = fresh;
732 }
733 self.reset_txn_state();
734 }
735
736 fn execute_params_impl(
737 &mut self,
738 db: &'a Database,
739 sql: &str,
740 params: &[Value],
741 ) -> Result<ExecutionResult> {
742 let gen = self.schema.generation();
743 if self.active_txn.is_none() {
744 if let Some(entry) = self.stmt_cache.get(sql) {
745 if entry.schema_gen == gen && entry.param_count == params.len() {
746 if let Some(plan) = entry.compiled.as_ref().map(Arc::clone) {
747 let stmt = Arc::clone(&entry.stmt);
748 return self.run_compiled(db, &plan, &stmt, params);
749 }
750 }
751 }
752 }
753
754 let (stmt, param_count) = self.get_or_parse(sql)?;
755
756 if param_count != params.len() {
757 return Err(SqlError::ParameterCountMismatch {
758 expected: param_count,
759 got: params.len(),
760 });
761 }
762
763 if self.active_txn.is_none() {
764 if let Some(plan) = executor::compile(&self.schema, &stmt) {
765 if let Some(e) = self.stmt_cache.get_mut(sql) {
766 e.compiled = Some(Arc::clone(&plan));
767 }
768 let stmt_owned = Arc::clone(&stmt);
769 return self.run_compiled(db, &plan, &stmt_owned, params);
770 }
771 }
772
773 self.dispatch(db, &stmt, params)
774 }
775
776 fn run_compiled(
777 &mut self,
778 db: &'a Database,
779 plan: &Arc<dyn executor::CompiledPlan>,
780 stmt: &Statement,
781 params: &[Value],
782 ) -> Result<ExecutionResult> {
783 use executor::compile::ActiveTxnRef;
784 let schema = &self.schema;
785 let exec = || {
786 if params.is_empty() {
787 plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
788 } else {
789 crate::eval::with_scoped_params(params, || {
790 plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
791 })
792 }
793 };
794 let outcome = if plan.needs_txn_clock() {
795 let cached_ts = self
796 .txn_start_ts
797 .or_else(|| Some(crate::datetime::now_micros()));
798 crate::datetime::with_txn_clock(cached_ts, exec)
799 } else {
800 exec()
801 }?;
802 invalidate_dml_caches(&self.schema, db);
803 Ok(outcome)
804 }
805
806 pub(crate) fn parse_and_cache(
807 &mut self,
808 normalized_key: String,
809 gen: u64,
810 ) -> Result<Arc<Statement>> {
811 let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
812 let param_count = parser::count_params(&stmt);
813 self.stmt_cache.put(
814 normalized_key,
815 CacheEntry {
816 stmt: Arc::clone(&stmt),
817 schema_gen: gen,
818 param_count,
819 compiled: None,
820 },
821 );
822 Ok(stmt)
823 }
824
825 pub(crate) fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
826 let gen = self.schema.generation();
827
828 if let Some(entry) = self.stmt_cache.get(sql) {
829 if entry.schema_gen == gen {
830 return Ok((Arc::clone(&entry.stmt), entry.param_count));
831 }
832 }
833
834 let stmt = Arc::new(parser::parse_sql(sql)?);
835 let param_count = parser::count_params(&stmt);
836
837 let cacheable = !matches!(
838 *stmt,
839 Statement::CreateTable(_)
840 | Statement::DropTable(_)
841 | Statement::CreateIndex(_)
842 | Statement::DropIndex(_)
843 | Statement::CreateView(_)
844 | Statement::DropView(_)
845 | Statement::AlterTable(_)
846 );
847
848 if cacheable {
849 self.stmt_cache.put(
850 sql.to_string(),
851 CacheEntry {
852 stmt: Arc::clone(&stmt),
853 schema_gen: gen,
854 param_count,
855 compiled: None,
856 },
857 );
858 }
859
860 Ok((stmt, param_count))
861 }
862
863 pub(crate) fn execute_prepared(
864 &mut self,
865 db: &'a Database,
866 stmt: &Statement,
867 compiled: Option<&Arc<dyn executor::CompiledPlan>>,
868 params: &[Value],
869 ) -> Result<ExecutionResult> {
870 if let Some(plan) = compiled {
871 if self.active_txn.is_none() {
872 return self.run_compiled(db, plan, stmt, params);
873 }
874 if !self.savepoint_stack.is_empty() && stmt_mutates(stmt) {
875 self.capture_pending_snapshots();
876 }
877 return self.run_compiled_in_txn(db, plan, stmt, params);
878 }
879 self.dispatch(db, stmt, params)
880 }
881
882 fn run_compiled_in_txn(
883 &mut self,
884 db: &'a Database,
885 plan: &Arc<dyn executor::CompiledPlan>,
886 stmt: &Statement,
887 params: &[Value],
888 ) -> Result<ExecutionResult> {
889 use executor::compile::ActiveTxnRef;
890 let schema = &self.schema;
891 let txn = match &mut self.active_txn {
892 ActiveTxn::Write(wtx) => ActiveTxnRef::Write(wtx),
893 ActiveTxn::Read(rtx) => ActiveTxnRef::Read(rtx),
894 ActiveTxn::None => ActiveTxnRef::None,
895 };
896 if params.is_empty() || !plan.uses_scoped_params() {
897 plan.execute(db, schema, stmt, params, txn)
898 } else {
899 crate::eval::with_scoped_params(params, || plan.execute(db, schema, stmt, params, txn))
900 }
901 }
902
903 pub(crate) fn dispatch(
904 &mut self,
905 db: &'a Database,
906 stmt: &Statement,
907 params: &[Value],
908 ) -> Result<ExecutionResult> {
909 let cached_ts = self
910 .txn_start_ts
911 .or_else(|| Some(crate::datetime::now_micros()));
912 crate::datetime::with_txn_clock(cached_ts, || {
913 if params.is_empty() {
914 self.dispatch_inner(db, stmt, params)
915 } else {
916 crate::eval::with_scoped_params(params, || self.dispatch_inner(db, stmt, params))
917 }
918 })
919 }
920
921 fn dispatch_inner(
922 &mut self,
923 db: &'a Database,
924 stmt: &Statement,
925 params: &[Value],
926 ) -> Result<ExecutionResult> {
927 match stmt {
928 Statement::Begin { access_mode } => {
929 if self.active_txn.is_active() {
930 return Err(SqlError::TransactionAlreadyActive);
931 }
932 let ts = crate::datetime::txn_or_clock_micros();
933 match access_mode {
934 BeginAccessMode::ReadOnly => {
935 let rtx = db.begin_read();
936 self.active_txn = ActiveTxn::Read(rtx);
937 }
938 BeginAccessMode::ReadWrite | BeginAccessMode::Default => {
939 let wtx = db.begin_write().map_err(SqlError::Storage)?;
940 self.active_txn = ActiveTxn::Write(wtx);
941 }
942 }
943 self.txn_start_ts = Some(ts);
944 crate::datetime::set_txn_clock(Some(ts));
945 Ok(ExecutionResult::Ok)
946 }
947 Statement::Commit => {
948 match self.active_txn.take() {
949 ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
950 ActiveTxn::Write(mut wtx) => {
951 crate::executor::helpers::drain_deferred_fk_checks(&mut wtx)?;
952 wtx.commit().map_err(SqlError::Storage)?;
953 invalidate_dml_caches(&self.schema, db);
954 }
955 ActiveTxn::Read(_rtx) => {}
956 }
957 self.reset_txn_state();
958 Ok(ExecutionResult::Ok)
959 }
960 Statement::Rollback => {
961 match self.active_txn.take() {
962 ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
963 ActiveTxn::Write(wtx) => {
964 wtx.abort();
965 let mut fresh = SchemaManager::load(db)?;
966 fresh.bump_generation_past(self.schema.generation());
967 self.schema = fresh;
968 }
969 ActiveTxn::Read(_rtx) => {}
970 }
971 self.reset_txn_state();
972 Ok(ExecutionResult::Ok)
973 }
974 Statement::Savepoint(name) => self.do_savepoint(name),
975 Statement::ReleaseSavepoint(name) => self.do_release(name),
976 Statement::RollbackTo(name) => self.do_rollback_to(name),
977 Statement::SetTimezone(zone) => {
978 self.set_session_timezone_impl(zone)?;
979 Ok(ExecutionResult::Ok)
980 }
981 Statement::CreateTable(ct) if ct.temporary => {
982 if self.active_txn.is_read_only() {
983 return Err(SqlError::Unsupported(
984 "cannot execute mutating statement inside a read-only transaction".into(),
985 ));
986 }
987 let user_name = ct.name.clone();
988 let prefixed = temp_storage_name(self.temp_id, &user_name);
989 if self.schema.contains(&user_name) {
990 if ct.if_not_exists {
991 return Ok(ExecutionResult::Ok);
992 }
993 return Err(SqlError::TableAlreadyExists(user_name));
994 }
995 let mut clone = ct.clone();
996 clone.name = prefixed.clone();
997 clone.temporary = false;
998 let stmt_concrete = Statement::CreateTable(clone);
999 let outcome = if let Some(wtx) = self.active_txn.as_write_mut() {
1000 executor::execute_in_txn(wtx, &mut self.schema, &stmt_concrete, params)?
1001 } else {
1002 executor::execute(db, &mut self.schema, &stmt_concrete, params)?
1003 };
1004 self.schema
1005 .register_temp_alias(&user_name, prefixed.clone());
1006 self.temp_table_names.push(prefixed);
1007 Ok(outcome)
1008 }
1009 Statement::Insert(ins) if self.active_txn.as_write_mut().is_some() => {
1010 self.capture_pending_snapshots();
1011 let wtx = self.active_txn.as_write_mut().unwrap();
1012 executor::exec_insert_in_txn(wtx, &self.schema, ins, params)
1013 }
1014 _ => {
1015 if self.active_txn.is_read_only() && stmt_mutates(stmt) {
1016 return Err(SqlError::Unsupported(
1017 "cannot execute mutating statement inside a read-only transaction".into(),
1018 ));
1019 }
1020 if self.active_txn.as_write_mut().is_some() && stmt_mutates(stmt) {
1021 self.capture_pending_snapshots();
1022 }
1023 let was_auto_commit = matches!(self.active_txn, ActiveTxn::None);
1024 let outcome = match &mut self.active_txn {
1025 ActiveTxn::Write(wtx) => {
1026 executor::execute_in_txn(wtx, &mut self.schema, stmt, params)?
1027 }
1028 ActiveTxn::Read(rtx) => {
1029 executor::execute_with_read(rtx, &self.schema, stmt, params)?
1030 }
1031 ActiveTxn::None => executor::execute(db, &mut self.schema, stmt, params)?,
1032 };
1033 if was_auto_commit {
1034 invalidate_dml_caches(&self.schema, db);
1035 }
1036 if let Statement::DropTable(dt) = stmt {
1037 self.schema.unregister_temp_alias(&dt.name);
1038 }
1039 Ok(outcome)
1040 }
1041 }
1042 }
1043
1044 fn clear_savepoint_state(&mut self) {
1045 self.savepoint_stack.clear();
1046 self.in_place_saved = None;
1047 }
1048
1049 fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
1050 let wtx = self
1051 .active_txn
1052 .as_write_mut()
1053 .ok_or(SqlError::NoActiveTransaction)?;
1054
1055 if self.savepoint_stack.is_empty() {
1056 self.in_place_saved = Some(wtx.in_place());
1057 wtx.set_in_place(false);
1058 }
1059
1060 self.savepoint_stack.push(SavepointEntry {
1061 name: name.to_string(),
1062 snapshot: None,
1063 });
1064
1065 Ok(ExecutionResult::Ok)
1066 }
1067
1068 fn capture_pending_snapshots(&mut self) {
1069 let last_pending = match self
1070 .savepoint_stack
1071 .iter()
1072 .rposition(|e| e.snapshot.is_none())
1073 {
1074 Some(i) => i,
1075 None => return,
1076 };
1077 let wtx = match self.active_txn.as_write_mut() {
1078 Some(w) => w,
1079 None => return,
1080 };
1081 let wtx_snap = wtx.begin_savepoint();
1082 let schema_snap = self.schema.save_snapshot();
1083
1084 for i in 0..last_pending {
1085 if self.savepoint_stack[i].snapshot.is_none() {
1086 self.savepoint_stack[i].snapshot = Some(SavepointSnapshot {
1087 wtx_snap: wtx_snap.clone(),
1088 schema_snap: schema_snap.clone(),
1089 });
1090 }
1091 }
1092 self.savepoint_stack[last_pending].snapshot = Some(SavepointSnapshot {
1093 wtx_snap,
1094 schema_snap,
1095 });
1096 }
1097
1098 fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
1099 if !self.active_txn.is_active() {
1100 return Err(SqlError::NoActiveTransaction);
1101 }
1102
1103 let idx = self
1104 .savepoint_stack
1105 .iter()
1106 .rposition(|e| e.name == name)
1107 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
1108 self.savepoint_stack.truncate(idx);
1109
1110 if self.savepoint_stack.is_empty() {
1111 if let (Some(wtx), Some(original)) =
1112 (self.active_txn.as_write_mut(), self.in_place_saved.take())
1113 {
1114 wtx.set_in_place(original);
1115 }
1116 }
1117
1118 Ok(ExecutionResult::Ok)
1119 }
1120
1121 fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
1122 if !self.active_txn.is_active() {
1123 return Err(SqlError::NoActiveTransaction);
1124 }
1125
1126 let idx = self
1127 .savepoint_stack
1128 .iter()
1129 .rposition(|e| e.name == name)
1130 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
1131
1132 self.savepoint_stack.truncate(idx + 1);
1133 let entry = self.savepoint_stack.last_mut().unwrap();
1134 let snapshot = match entry.snapshot.take() {
1135 Some(s) => s,
1136 None => return Ok(ExecutionResult::Ok),
1137 };
1138
1139 let wtx = match self.active_txn.as_write_mut() {
1140 Some(w) => w,
1141 None => return Err(SqlError::NoActiveTransaction),
1142 };
1143 wtx.restore_snapshot(snapshot.wtx_snap);
1144 self.schema.restore_snapshot(snapshot.schema_snap);
1145
1146 Ok(ExecutionResult::Ok)
1147 }
1148}
1149
1150impl<'a> Drop for Connection<'a> {
1151 fn drop(&mut self) {
1152 let temp_names = std::mem::take(&mut self.inner.borrow_mut().temp_table_names);
1153 if temp_names.is_empty() {
1154 return;
1155 }
1156 if let Ok(mut wtx) = self.db.begin_write() {
1157 for prefixed in &temp_names {
1158 let _ = wtx.drop_table(prefixed.as_bytes());
1159 }
1160 let _ = wtx.commit();
1161 }
1162 }
1163}
1164
#[cfg(test)]
mod tests {
    //! Connection-level tests: batch execution, streamed projection,
    //! JSONB containment and savepoint handling.

    use super::*;
    use citadel::{Argon2Profile, DatabaseBuilder};

    /// Creates a fresh passphrase-protected database file `t.db` inside `dir`,
    /// using the `Iot` Argon2 profile.
    fn fresh_db(dir: &std::path::Path) -> citadel::Database {
        DatabaseBuilder::new(dir.join("t.db"))
            .passphrase(b"test-passphrase")
            .argon2_profile(Argon2Profile::Iot)
            .create()
            .unwrap()
    }

    #[test]
    fn execute_batch_commits_all_statements() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER)")
            .unwrap();

        let results = conn
            .execute_batch(
                "INSERT INTO t VALUES (1, 10); \
                 INSERT INTO t VALUES (2, 20); \
                 UPDATE t SET n = n + 1 WHERE id = 1;",
            )
            .unwrap();
        assert_eq!(results.len(), 3);

        let qr = conn.query("SELECT id, n FROM t ORDER BY id").unwrap();
        assert_eq!(qr.rows.len(), 2);
        assert_eq!(qr.rows[0], vec![Value::Integer(1), Value::Integer(11)]);
        assert_eq!(qr.rows[1], vec![Value::Integer(2), Value::Integer(20)]);
    }

    #[test]
    fn execute_batch_rolls_back_whole_batch_on_error() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER)")
            .unwrap();
        conn.execute("INSERT INTO t VALUES (1, 100)").unwrap();

        let res =
            conn.execute_batch("INSERT INTO t VALUES (2, 20); INSERT INTO t VALUES (1, 999);");
        assert!(res.is_err());

        let qr = conn.query("SELECT id, n FROM t ORDER BY id").unwrap();
        assert_eq!(qr.rows.len(), 1);
        assert_eq!(qr.rows[0], vec![Value::Integer(1), Value::Integer(100)]);

        conn.execute("INSERT INTO t VALUES (3, 30)").unwrap();
        assert!(!conn.in_transaction());
    }

    #[test]
    fn execute_batch_rejects_txn_control() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)")
            .unwrap();
        assert!(conn
            .execute_batch("INSERT INTO t VALUES (1); COMMIT;")
            .is_err());
        assert!(!conn.in_transaction());
    }

    #[test]
    fn execute_batch_rejected_inside_transaction() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)")
            .unwrap();
        conn.execute("BEGIN").unwrap();
        assert!(conn.execute_batch("INSERT INTO t VALUES (1);").is_err());
        conn.execute("ROLLBACK").unwrap();
    }

    #[test]
    fn streamed_expression_projection_correct() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER)")
            .unwrap();
        conn.execute_batch(
            "INSERT INTO t VALUES (1, 10); INSERT INTO t VALUES (2, 20); INSERT INTO t VALUES (3, 30);",
        )
        .unwrap();
        let stmt = conn.prepare("SELECT id + 1, n * 2 FROM t").unwrap();
        let qr = stmt.query_collect(&[]).unwrap();
        assert_eq!(qr.rows.len(), 3);
        assert_eq!(qr.rows[0], vec![Value::Integer(2), Value::Integer(20)]);
        assert_eq!(qr.rows[1], vec![Value::Integer(3), Value::Integer(40)]);
        assert_eq!(qr.rows[2], vec![Value::Integer(4), Value::Integer(60)]);
    }

    #[test]
    fn jsonb_contains_raw_predicate_correct() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        conn.execute("CREATE TABLE u (id INTEGER PRIMARY KEY, data JSONB)")
            .unwrap();
        conn.execute_batch(
            "INSERT INTO u VALUES (1, '{\"role\":\"admin\",\"x\":1}'); \
             INSERT INTO u VALUES (2, '{\"role\":\"user\"}'); \
             INSERT INTO u VALUES (3, NULL); \
             INSERT INTO u VALUES (4, '{\"role\":\"admin\"}');",
        )
        .unwrap();
        let stmt = conn
            .prepare("SELECT id FROM u WHERE data @> '{\"role\":\"admin\"}'::jsonb")
            .unwrap();
        let qr = stmt.query_collect(&[]).unwrap();
        // A non-integer id is a bug in the engine; fail loudly with the offending
        // value instead of smuggling a sentinel into the comparison.
        let mut ids: Vec<i64> = qr
            .rows
            .iter()
            .map(|r| match &r[0] {
                Value::Integer(i) => *i,
                other => panic!("expected an integer id, got {other:?}"),
            })
            .collect();
        ids.sort_unstable();
        assert_eq!(ids, vec![1, 4]);
    }

    #[test]
    fn release_savepoint_keeps_writes() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)")
            .unwrap();

        conn.execute("BEGIN").unwrap();
        conn.execute("SAVEPOINT sp").unwrap();
        conn.execute("INSERT INTO t VALUES (1)").unwrap();
        conn.execute("RELEASE SAVEPOINT sp").unwrap();
        conn.execute("COMMIT").unwrap();

        let qr = conn.query("SELECT id FROM t ORDER BY id").unwrap();
        assert_eq!(qr.rows.len(), 1);
        assert_eq!(qr.rows[0], vec![Value::Integer(1)]);
    }

    #[test]
    fn rollback_to_savepoint_discards_later_writes() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)")
            .unwrap();

        conn.execute("BEGIN").unwrap();
        conn.execute("INSERT INTO t VALUES (1)").unwrap();
        conn.execute("SAVEPOINT sp").unwrap();
        conn.execute("INSERT INTO t VALUES (2)").unwrap();
        conn.execute("ROLLBACK TO SAVEPOINT sp").unwrap();
        // The transaction stays usable after rolling back to the savepoint.
        conn.execute("INSERT INTO t VALUES (3)").unwrap();
        conn.execute("COMMIT").unwrap();

        let qr = conn.query("SELECT id FROM t ORDER BY id").unwrap();
        assert_eq!(qr.rows.len(), 2);
        assert_eq!(qr.rows[0], vec![Value::Integer(1)]);
        assert_eq!(qr.rows[1], vec![Value::Integer(3)]);
    }

    #[test]
    fn unknown_savepoint_is_an_error() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)")
            .unwrap();

        // Outside any transaction there is nothing to release.
        assert!(conn.execute("RELEASE SAVEPOINT missing").is_err());

        conn.execute("BEGIN").unwrap();
        assert!(conn.execute("RELEASE SAVEPOINT missing").is_err());
        assert!(conn.execute("ROLLBACK TO SAVEPOINT missing").is_err());
        // Best-effort cleanup: the transaction state after a failed savepoint
        // command is not what this test is about.
        let _ = conn.execute("ROLLBACK");
    }
}