1use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7
8fn generate_temp_id() -> u64 {
9 static COUNTER: AtomicU64 = AtomicU64::new(0);
10 let counter = COUNTER.fetch_add(1, Ordering::Relaxed);
11 let nanos = (crate::datetime::now_micros() as u64) & 0xFFFF_FFFF;
12 (nanos << 32) | (counter & 0xFFFF_FFFF)
13}
14
/// Returns the storage-level table name for a TEMP table: `__temp_<temp_id>_<name>`,
/// with the user-visible name ASCII-lowercased.
fn temp_storage_name(temp_id: u64, user_name: &str) -> String {
    let lowered = user_name.to_ascii_lowercase();
    format!("__temp_{temp_id}_{lowered}")
}
18
19use lru::LruCache;
20
21use citadel::Database;
22use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
23
24use crate::error::{Result, SqlError};
25use crate::executor;
26use crate::parser;
27use crate::parser::{BeginAccessMode, QueryBody, SelectQuery, Statement};
28use crate::prepared::PreparedStatement;
29use crate::schema::{SchemaManager, SchemaSnapshot};
30use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
31
/// Capacity of each connection's LRU statement cache (`ConnectionInner::stmt_cache`).
const DEFAULT_CACHE_CAPACITY: usize = 64;
33
34fn invalidate_dml_caches(schema: &SchemaManager, db: &Database) {
39 if !schema.has_dml_dirty() {
40 return;
41 }
42 let gen = db.manager().commit_generation();
43 let hard_invalidate = |table: &str| {
44 db.sql_cache_invalidate_prefix(&format!("ann:{table}:"));
45 let marker: std::sync::Arc<dyn std::any::Any + Send + Sync> = std::sync::Arc::new(gen);
46 schema
47 .sql_caches
48 .lock()
49 .insert(crate::executor::ann_dml_gen_key(table), marker);
50 };
51
52 let dirty = schema.drain_dml_dirty();
53 for table in &dirty.mutating {
54 hard_invalidate(table);
55 }
56 for (table, min_pk) in &dirty.appends {
59 if crate::executor::ann_appends_safe(schema, table, *min_pk) {
60 continue;
61 }
62 hard_invalidate(table);
63 }
64}
65
/// Outcome of [`Connection::execute_script`]: results of the statements that ran, in order,
/// plus the error that stopped the script, if any.
#[derive(Debug)]
pub struct ScriptExecution {
    /// Results of every statement that completed before the first failure.
    pub completed: Vec<ExecutionResult>,
    /// The parse or execution error that halted the script; `None` if every statement ran.
    pub error: Option<SqlError>,
}
71
/// Returns `Some(())` when `s` is a fixed UTC offset this module accepts.
///
/// Accepted forms (surrounding whitespace is ignored):
/// - `Z` or `UTC`, in any case;
/// - a `+`/`-` sign followed by `HH:MM` (either field may have any number of digits), `HHMM`
///   or `HH`.
///
/// Hours must be 0-23 and minutes 0-59.
///
/// The hour and minute fields must consist only of ASCII digits. This rejects embedded signs
/// such as `"++5:00"`, which `u32::from_str` would otherwise accept. The 4-byte `HHMM` form is
/// only split when byte 2 is a char boundary, so multi-byte input returns `None` instead of
/// panicking.
fn parse_fixed_offset(s: &str) -> Option<()> {
    let s = s.trim();
    if s.eq_ignore_ascii_case("z") || s.eq_ignore_ascii_case("utc") {
        return Some(());
    }
    let rest = s.strip_prefix(|c: char| c == '+' || c == '-')?;

    let (hh, mm) = match rest.split_once(':') {
        Some(parts) => parts,
        None if rest.len() == 4 && rest.is_char_boundary(2) => rest.split_at(2),
        None if rest.len() == 2 => (rest, "00"),
        None => return None,
    };

    // A non-empty run of ASCII digits, parsed as u32 (overflow -> None).
    let field = |f: &str| -> Option<u32> {
        if !f.is_empty() && f.bytes().all(|b| b.is_ascii_digit()) {
            f.parse().ok()
        } else {
            None
        }
    };
    let h = field(hh)?;
    let m = field(mm)?;
    (h <= 23 && m <= 59).then_some(())
}
101
/// Rewrites `SHOW TRIGGERS [ON <table>]` into a query over `information_schema.triggers`.
///
/// Matching is ASCII-case-insensitive, and surrounding whitespace and trailing `;` are
/// ignored. The table name is matched with `LOWER(...)` on both sides, and single quotes in
/// it are doubled. Returns `None` for any other statement shape.
fn rewrite_show_triggers(sql: &str) -> Option<String> {
    const BASE: &str = "SELECT trigger_name, event_object_table AS table_name, action_timing, \
                        event_manipulation, action_orientation, action_statement \
                        FROM information_schema.triggers";

    let lowered = sql.trim().trim_end_matches(';').trim().to_ascii_lowercase();
    let tail = lowered.strip_prefix("show triggers")?.trim_start();

    if tail.is_empty() {
        return Some(format!("{BASE} ORDER BY trigger_name"));
    }

    let table = tail.strip_prefix("on ")?.trim().trim_end_matches(';').trim();
    if table.is_empty() {
        return None;
    }
    // Double embedded single quotes so the name stays inside the SQL string literal.
    let literal = table.replace('\'', "''");
    Some(format!(
        "{BASE} WHERE LOWER(event_object_table) = LOWER('{literal}') ORDER BY trigger_name"
    ))
}
128
/// Rewrites exactly `SHOW MATERIALIZED VIEWS` into a `pg_matviews` catalog query.
///
/// Matching is ASCII-case-insensitive, and surrounding whitespace and trailing `;` are
/// ignored. Any other text yields `None`.
fn rewrite_show_matviews(sql: &str) -> Option<String> {
    let statement = sql.trim().trim_end_matches(';').trim();
    statement
        .eq_ignore_ascii_case("show materialized views")
        .then(|| {
            "SELECT matviewname, ispopulated, hasindexes, definition FROM pg_matviews ORDER BY matviewname"
                .to_string()
        })
}
142
143fn stmt_mutates(stmt: &Statement) -> bool {
144 if matches!(
145 stmt,
146 Statement::Insert(_)
147 | Statement::Update(_)
148 | Statement::Delete(_)
149 | Statement::Truncate(_)
150 | Statement::CreateTable(_)
151 | Statement::DropTable(_)
152 | Statement::AlterTable(_)
153 | Statement::CreateIndex(_)
154 | Statement::DropIndex(_)
155 | Statement::CreateView(_)
156 | Statement::DropView(_)
157 | Statement::CreateTrigger(_)
158 | Statement::DropTrigger(_)
159 | Statement::CreateMaterializedView(_)
160 | Statement::RefreshMaterializedView(_)
161 | Statement::DropMaterializedView(_)
162 ) {
163 return true;
164 }
165 if let Statement::Select(sq) = stmt {
166 if select_query_has_dml(sq) {
167 return true;
168 }
169 }
170 false
171}
172
173fn is_txn_control(stmt: &Statement) -> bool {
174 matches!(
175 stmt,
176 Statement::Begin { .. }
177 | Statement::Commit
178 | Statement::Rollback
179 | Statement::Savepoint(_)
180 | Statement::ReleaseSavepoint(_)
181 | Statement::RollbackTo(_)
182 )
183}
184
185fn select_query_has_dml(sq: &SelectQuery) -> bool {
186 sq.ctes.iter().any(|cte| query_body_has_dml(&cte.body)) || query_body_has_dml(&sq.body)
187}
188
189fn query_body_has_dml(body: &QueryBody) -> bool {
190 match body {
191 QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => true,
192 QueryBody::Compound(c) => query_body_has_dml(&c.left) || query_body_has_dml(&c.right),
193 QueryBody::Select(_) => false,
194 }
195}
196
197fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
198 let bytes = sql.as_bytes();
199 let len = bytes.len();
200 let mut i = 0;
201
202 while i < len && bytes[i].is_ascii_whitespace() {
203 i += 1;
204 }
205 if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
206 return None;
207 }
208 i += 6;
209 if i >= len || !bytes[i].is_ascii_whitespace() {
210 return None;
211 }
212 while i < len && bytes[i].is_ascii_whitespace() {
213 i += 1;
214 }
215
216 if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
217 return None;
218 }
219 i += 4;
220 if i >= len || !bytes[i].is_ascii_whitespace() {
221 return None;
222 }
223
224 let prefix_start = 0;
225 let mut values_pos = None;
226 let mut j = i;
227 while j + 6 <= len {
228 if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
229 && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
230 && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
231 {
232 values_pos = Some(j);
233 break;
234 }
235 j += 1;
236 }
237 let values_pos = values_pos?;
238
239 let prefix = &sql[prefix_start..values_pos + 6];
240 let mut pos = values_pos + 6;
241
242 while pos < len && bytes[pos].is_ascii_whitespace() {
243 pos += 1;
244 }
245 if pos >= len || bytes[pos] != b'(' {
246 return None;
247 }
248 pos += 1;
249
250 let mut values = Vec::new();
251 let mut normalized = String::with_capacity(sql.len());
252 normalized.push_str(prefix);
253 normalized.push_str(" (");
254
255 loop {
256 while pos < len && bytes[pos].is_ascii_whitespace() {
257 pos += 1;
258 }
259 if pos >= len {
260 return None;
261 }
262
263 let param_idx = values.len() + 1;
264 if param_idx > 1 {
265 normalized.push_str(", ");
266 }
267
268 if bytes[pos] == b'\'' {
269 pos += 1;
270 let mut seg_start = pos;
271 let mut s = String::new();
272 loop {
273 if pos >= len {
274 return None;
275 }
276 if bytes[pos] == b'\'' {
277 s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
278 pos += 1;
279 if pos < len && bytes[pos] == b'\'' {
280 s.push('\'');
281 pos += 1;
282 seg_start = pos;
283 } else {
284 break;
285 }
286 } else {
287 pos += 1;
288 }
289 }
290 values.push(Value::Text(s.into()));
291 } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
292 let start = pos;
293 if bytes[pos] == b'-' {
294 pos += 1;
295 }
296 while pos < len && bytes[pos].is_ascii_digit() {
297 pos += 1;
298 }
299 if pos < len && bytes[pos] == b'.' {
300 pos += 1;
301 while pos < len && bytes[pos].is_ascii_digit() {
302 pos += 1;
303 }
304 let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
305 values.push(Value::Real(num));
306 } else {
307 let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
308 values.push(Value::Integer(num));
309 }
310 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
311 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
312 if !after.is_ascii_alphanumeric() && after != b'_' {
313 pos += 4;
314 values.push(Value::Null);
315 } else {
316 return None;
317 }
318 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
319 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
320 if !after.is_ascii_alphanumeric() && after != b'_' {
321 pos += 4;
322 values.push(Value::Boolean(true));
323 } else {
324 return None;
325 }
326 } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
327 let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
328 if !after.is_ascii_alphanumeric() && after != b'_' {
329 pos += 5;
330 values.push(Value::Boolean(false));
331 } else {
332 return None;
333 }
334 } else {
335 return None;
336 }
337
338 normalized.push('$');
339 normalized.push_str(¶m_idx.to_string());
340
341 while pos < len && bytes[pos].is_ascii_whitespace() {
342 pos += 1;
343 }
344 if pos >= len {
345 return None;
346 }
347
348 if bytes[pos] == b',' {
349 pos += 1;
350 } else if bytes[pos] == b')' {
351 pos += 1;
352 break;
353 } else {
354 return None;
355 }
356 }
357
358 normalized.push(')');
359
360 while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
361 pos += 1;
362 }
363 if pos != len {
364 return None;
365 }
366
367 if values.is_empty() {
368 return None;
369 }
370
371 Some((normalized, values))
372}
373
/// A parsed statement held in `ConnectionInner::stmt_cache`.
pub(crate) struct CacheEntry {
    /// The parsed statement, shared with the code that executes it.
    pub(crate) stmt: Arc<Statement>,
    /// `SchemaManager::generation()` when the entry was parsed; entries from another
    /// generation are re-parsed instead of reused.
    pub(crate) schema_gen: u64,
    /// Parameter count reported by `parser::count_params` for `stmt`.
    pub(crate) param_count: usize,
    /// Compiled plan, filled in lazily by `execute_params_impl` when `executor::compile`
    /// succeeds; `None` until then.
    pub(crate) compiled: Option<Arc<dyn executor::CompiledPlan>>,
}
380
/// One `SAVEPOINT` on a connection's savepoint stack.
struct SavepointEntry {
    /// Savepoint name as written in the statement; lookups compare with `==`.
    name: String,
    /// State to restore on `ROLLBACK TO`. It stays `None` until the first mutating statement
    /// after the savepoint is created (see `capture_pending_snapshots`).
    snapshot: Option<SavepointSnapshot>,
}
385
/// Transaction and schema state captured for rolling back to a savepoint.
struct SavepointSnapshot {
    /// Write-transaction state from `WriteTxn::begin_savepoint`.
    wtx_snap: WriteTxnSnapshot,
    /// In-memory schema state from `SchemaManager::save_snapshot`.
    schema_snap: SchemaSnapshot,
}
390
/// The explicit transaction (if any) currently open on a connection.
// NOTE(review): `large_enum_variant` is presumably allowed because `WriteTxn` is much larger
// than the other variants and boxing it was judged not worthwhile — confirm.
#[allow(clippy::large_enum_variant)]
pub(crate) enum ActiveTxn<'a> {
    /// No explicit transaction; statements run in autocommit mode.
    None,
    /// Read-write transaction, opened by `BEGIN` (read-write or default access mode) or by
    /// `execute_batch`.
    Write(WriteTxn<'a>),
    /// Read-only transaction, opened by `BEGIN` in read-only access mode.
    Read(citadel_txn::read_txn::ReadTxn<'a>),
}
399
400impl<'a> ActiveTxn<'a> {
401 fn is_none(&self) -> bool {
402 matches!(self, ActiveTxn::None)
403 }
404 fn is_active(&self) -> bool {
405 !self.is_none()
406 }
407 fn is_read_only(&self) -> bool {
408 matches!(self, ActiveTxn::Read(_))
409 }
410 fn as_write_mut(&mut self) -> Option<&mut WriteTxn<'a>> {
411 match self {
412 ActiveTxn::Write(w) => Some(w),
413 _ => None,
414 }
415 }
416 fn take(&mut self) -> ActiveTxn<'a> {
417 std::mem::replace(self, ActiveTxn::None)
418 }
419}
420
/// Mutable per-connection state, kept behind a `RefCell` inside [`Connection`].
pub(crate) struct ConnectionInner<'a> {
    /// In-memory schema for this connection; reloaded from the database after aborted writes.
    pub(crate) schema: SchemaManager,
    /// Explicit transaction currently open, if any.
    active_txn: ActiveTxn<'a>,
    /// Open savepoints, innermost last.
    savepoint_stack: Vec<SavepointEntry>,
    /// The write transaction's `in_place()` flag, saved when the first savepoint is created
    /// and restored when `RELEASE` empties the stack.
    in_place_saved: Option<bool>,
    /// LRU cache of parsed statements, keyed by SQL text or by normalized INSERT text.
    pub(crate) stmt_cache: LruCache<String, CacheEntry>,
    /// Timestamp captured when the current transaction began; `None` outside transactions.
    txn_start_ts: Option<i64>,
    /// Session time zone (zone name or fixed offset); `"UTC"` for new connections.
    session_timezone: String,
    /// Id from `generate_temp_id`, used to prefix this connection's TEMP table names.
    temp_id: u64,
    /// Storage names of TEMP tables created here; dropped when the `Connection` is dropped.
    temp_table_names: Vec<String>,
}
433
/// A SQL session over a [`Database`]: schema, statement cache and open transaction.
///
/// Methods take `&self` and borrow the state through a `RefCell`. Re-entering the same
/// connection while one of its methods is running therefore panics with a borrow error.
pub struct Connection<'a> {
    /// The database this connection runs against.
    pub(crate) db: &'a Database,
    /// Mutable session state.
    pub(crate) inner: RefCell<ConnectionInner<'a>>,
}
438
439impl<'a> Connection<'a> {
440 pub fn open(db: &'a Database) -> Result<Self> {
441 let schema = SchemaManager::load(db)?;
442 let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
443 let temp_id = generate_temp_id();
444 Ok(Self {
445 db,
446 inner: RefCell::new(ConnectionInner {
447 schema,
448 active_txn: ActiveTxn::None,
449 savepoint_stack: Vec::new(),
450 in_place_saved: None,
451 stmt_cache,
452 txn_start_ts: None,
453 session_timezone: "UTC".to_string(),
454 temp_id,
455 temp_table_names: Vec::new(),
456 }),
457 })
458 }
459
460 pub fn txn_start_ts(&self) -> Option<i64> {
462 self.inner.borrow().txn_start_ts
463 }
464
465 pub fn session_timezone(&self) -> String {
467 self.inner.borrow().session_timezone.clone()
468 }
469
470 pub fn set_session_timezone(&self, tz: &str) -> Result<()> {
472 self.inner.borrow_mut().set_session_timezone_impl(tz)
473 }
474
475 pub fn execute(&self, sql: &str) -> Result<ExecutionResult> {
476 self.inner.borrow_mut().execute_impl(self.db, sql)
477 }
478
479 pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
480 self.inner
481 .borrow_mut()
482 .execute_params_impl(self.db, sql, params)
483 }
484
485 pub fn execute_script(&self, sql: &str) -> ScriptExecution {
487 let stmts = match parser::parse_sql_multi(sql) {
488 Ok(s) => s,
489 Err(e) => {
490 return ScriptExecution {
491 completed: vec![],
492 error: Some(e),
493 }
494 }
495 };
496 let mut completed = Vec::with_capacity(stmts.len());
497 for stmt in stmts {
498 match self.inner.borrow_mut().dispatch(self.db, &stmt, &[]) {
499 Ok(r) => completed.push(r),
500 Err(e) => {
501 return ScriptExecution {
502 completed,
503 error: Some(e),
504 }
505 }
506 }
507 }
508 ScriptExecution {
509 completed,
510 error: None,
511 }
512 }
513
514 pub fn execute_batch(&self, sql: &str) -> Result<Vec<ExecutionResult>> {
515 self.inner.borrow_mut().execute_batch_impl(self.db, sql)
516 }
517
518 pub fn query(&self, sql: &str) -> Result<QueryResult> {
519 self.query_params(sql, &[])
520 }
521
522 pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
523 match self.execute_params(sql, params)? {
524 ExecutionResult::Query(qr) => Ok(qr),
525 ExecutionResult::RowsAffected(n) => Ok(QueryResult {
526 columns: vec!["rows_affected".into()],
527 rows: vec![vec![Value::Integer(n as i64)]],
528 }),
529 ExecutionResult::Ok => Ok(QueryResult {
530 columns: vec![],
531 rows: vec![],
532 }),
533 }
534 }
535
536 pub fn prepare(&self, sql: &str) -> Result<PreparedStatement<'_, 'a>> {
537 if let Some(rewritten) = rewrite_show_triggers(sql) {
538 return PreparedStatement::new(self, &rewritten);
539 }
540 if let Some(rewritten) = rewrite_show_matviews(sql) {
541 return PreparedStatement::new(self, &rewritten);
542 }
543 PreparedStatement::new(self, sql)
544 }
545
546 pub fn tables(&self) -> Vec<String> {
547 self.inner
548 .borrow()
549 .schema
550 .table_names()
551 .into_iter()
552 .map(String::from)
553 .collect()
554 }
555
556 pub fn in_transaction(&self) -> bool {
558 self.inner.borrow().active_txn.is_active()
559 }
560
561 pub fn table_schema(&self, name: &str) -> Option<TableSchema> {
562 self.inner.borrow().schema.get(name).cloned()
563 }
564
565 pub fn refresh_schema(&self) -> Result<()> {
566 let new_schema = SchemaManager::load(self.db)?;
567 self.inner.borrow_mut().schema = new_schema;
568 Ok(())
569 }
570
571 pub fn persist_ann_index(
579 &self,
580 table: &str,
581 column: &str,
582 ) -> Result<crate::executor::AnnSegmentInfo> {
583 if self.in_transaction() {
584 return Err(SqlError::InvalidValue(
585 "persist_ann_index: not allowed inside an explicit transaction".into(),
586 ));
587 }
588 let inner = self.inner.borrow();
589 let lower = table.to_ascii_lowercase();
590 if inner.schema.resolve_temp(&lower) != lower {
591 return Err(SqlError::InvalidValue(
592 "persist_ann_index: TEMP tables are not persistable".into(),
593 ));
594 }
595 let table_schema = inner
596 .schema
597 .get(&lower)
598 .ok_or_else(|| SqlError::TableNotFound(table.to_string()))?;
599 crate::executor::persist_ann_index(self.db, &inner.schema, table_schema, column)
600 }
601
602 pub fn ann_cache_status(
607 &self,
608 table: &str,
609 column: &str,
610 ) -> Result<Option<(crate::executor::AnnIndexSource, u64)>> {
611 let inner = self.inner.borrow();
612 let table_schema = inner
613 .schema
614 .get(&table.to_ascii_lowercase())
615 .ok_or_else(|| SqlError::TableNotFound(table.to_string()))?;
616 crate::executor::ann_cache_status(&inner.schema, table_schema, column)
617 }
618}
619
620impl<'a> ConnectionInner<'a> {
621 pub(crate) fn active_txn_is_some(&self) -> bool {
622 self.active_txn.is_active()
623 }
624
625 fn set_session_timezone_impl(&mut self, tz: &str) -> Result<()> {
626 let upper = tz.to_ascii_uppercase();
627 if (upper.starts_with("UTC+") || upper.starts_with("UTC-")) && tz.len() > 3 {
628 return Err(SqlError::InvalidTimezone(format!(
629 "'{tz}' is ambiguous; use ISO-8601 offset (e.g. '+05:00') or named zone (e.g. 'Etc/GMT-5')"
630 )));
631 }
632 if jiff::tz::TimeZone::get(tz).is_err() && parse_fixed_offset(tz).is_none() {
633 return Err(SqlError::InvalidTimezone(format!(
634 "{tz}: not a known IANA zone or ISO-8601 offset (e.g. '+05:00', 'UTC', 'America/New_York')"
635 )));
636 }
637 self.session_timezone = tz.to_string();
638 Ok(())
639 }
640
641 fn execute_impl(&mut self, db: &'a Database, sql: &str) -> Result<ExecutionResult> {
642 if let Some(rewritten) = rewrite_show_triggers(sql) {
643 return self.execute_params_impl(db, &rewritten, &[]);
644 }
645 if let Some(rewritten) = rewrite_show_matviews(sql) {
646 return self.execute_params_impl(db, &rewritten, &[]);
647 }
648 if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
649 if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
650 let gen = self.schema.generation();
651 let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
652 if entry.schema_gen == gen {
653 Arc::clone(&entry.stmt)
654 } else {
655 self.parse_and_cache(normalized_key, gen)?
656 }
657 } else {
658 self.parse_and_cache(normalized_key, gen)?
659 };
660 return self.dispatch(db, &stmt, &extracted);
661 }
662 }
663 self.execute_params_impl(db, sql, &[])
664 }
665
666 fn execute_batch_impl(&mut self, db: &'a Database, sql: &str) -> Result<Vec<ExecutionResult>> {
667 if self.active_txn.is_active() {
668 return Err(SqlError::TransactionAlreadyActive);
669 }
670 let stmts = parser::parse_sql_multi(sql)?;
671 if stmts.iter().any(is_txn_control) {
672 return Err(SqlError::Unsupported(
673 "transaction-control statements are not allowed in execute_batch".into(),
674 ));
675 }
676
677 let wtx = db.begin_write().map_err(SqlError::Storage)?;
678 let ts = crate::datetime::txn_or_clock_micros();
679 self.active_txn = ActiveTxn::Write(wtx);
680 self.txn_start_ts = Some(ts);
681 crate::datetime::set_txn_clock(Some(ts));
682
683 let mut results = Vec::with_capacity(stmts.len());
684 for stmt in &stmts {
685 match self.dispatch(db, stmt, &[]) {
686 Ok(r) => results.push(r),
687 Err(e) => {
688 self.abort_active_txn(db);
689 return Err(e);
690 }
691 }
692 }
693
694 let commit = match self.active_txn.take() {
695 ActiveTxn::Write(mut wtx) => {
696 match crate::executor::helpers::drain_deferred_fk_checks(&mut wtx) {
697 Ok(()) => wtx.commit().map_err(SqlError::Storage),
698 Err(e) => {
699 wtx.abort();
700 Err(e)
701 }
702 }
703 }
704 _ => Err(SqlError::NoActiveTransaction),
705 };
706 self.reset_txn_state();
707 match commit {
708 Ok(()) => {
709 invalidate_dml_caches(&self.schema, db);
710 Ok(results)
711 }
712 Err(e) => {
713 self.schema = SchemaManager::load(db)?;
714 Err(e)
715 }
716 }
717 }
718
719 fn reset_txn_state(&mut self) {
720 self.clear_savepoint_state();
721 self.txn_start_ts = None;
722 crate::datetime::set_txn_clock(None);
723 }
724
725 fn abort_active_txn(&mut self, db: &'a Database) {
726 if let ActiveTxn::Write(wtx) = self.active_txn.take() {
727 wtx.abort();
728 }
729 if let Ok(fresh) = SchemaManager::load(db) {
730 self.schema = fresh;
731 }
732 self.reset_txn_state();
733 }
734
735 fn execute_params_impl(
736 &mut self,
737 db: &'a Database,
738 sql: &str,
739 params: &[Value],
740 ) -> Result<ExecutionResult> {
741 let gen = self.schema.generation();
742 if self.active_txn.is_none() {
743 if let Some(entry) = self.stmt_cache.get(sql) {
744 if entry.schema_gen == gen && entry.param_count == params.len() {
745 if let Some(plan) = entry.compiled.as_ref().map(Arc::clone) {
746 let stmt = Arc::clone(&entry.stmt);
747 return self.run_compiled(db, &plan, &stmt, params);
748 }
749 }
750 }
751 }
752
753 let (stmt, param_count) = self.get_or_parse(sql)?;
754
755 if param_count != params.len() {
756 return Err(SqlError::ParameterCountMismatch {
757 expected: param_count,
758 got: params.len(),
759 });
760 }
761
762 if self.active_txn.is_none() {
763 if let Some(plan) = executor::compile(&self.schema, &stmt) {
764 if let Some(e) = self.stmt_cache.get_mut(sql) {
765 e.compiled = Some(Arc::clone(&plan));
766 }
767 let stmt_owned = Arc::clone(&stmt);
768 return self.run_compiled(db, &plan, &stmt_owned, params);
769 }
770 }
771
772 self.dispatch(db, &stmt, params)
773 }
774
775 fn run_compiled(
776 &mut self,
777 db: &'a Database,
778 plan: &Arc<dyn executor::CompiledPlan>,
779 stmt: &Statement,
780 params: &[Value],
781 ) -> Result<ExecutionResult> {
782 use executor::compile::ActiveTxnRef;
783 let schema = &self.schema;
784 let exec = || {
785 if params.is_empty() {
786 plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
787 } else {
788 crate::eval::with_scoped_params(params, || {
789 plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
790 })
791 }
792 };
793 let outcome = if plan.needs_txn_clock() {
794 let cached_ts = self
795 .txn_start_ts
796 .or_else(|| Some(crate::datetime::now_micros()));
797 crate::datetime::with_txn_clock(cached_ts, exec)
798 } else {
799 exec()
800 }?;
801 invalidate_dml_caches(&self.schema, db);
802 Ok(outcome)
803 }
804
805 pub(crate) fn parse_and_cache(
806 &mut self,
807 normalized_key: String,
808 gen: u64,
809 ) -> Result<Arc<Statement>> {
810 let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
811 let param_count = parser::count_params(&stmt);
812 self.stmt_cache.put(
813 normalized_key,
814 CacheEntry {
815 stmt: Arc::clone(&stmt),
816 schema_gen: gen,
817 param_count,
818 compiled: None,
819 },
820 );
821 Ok(stmt)
822 }
823
824 pub(crate) fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
825 let gen = self.schema.generation();
826
827 if let Some(entry) = self.stmt_cache.get(sql) {
828 if entry.schema_gen == gen {
829 return Ok((Arc::clone(&entry.stmt), entry.param_count));
830 }
831 }
832
833 let stmt = Arc::new(parser::parse_sql(sql)?);
834 let param_count = parser::count_params(&stmt);
835
836 let cacheable = !matches!(
837 *stmt,
838 Statement::CreateTable(_)
839 | Statement::DropTable(_)
840 | Statement::CreateIndex(_)
841 | Statement::DropIndex(_)
842 | Statement::CreateView(_)
843 | Statement::DropView(_)
844 | Statement::AlterTable(_)
845 );
846
847 if cacheable {
848 self.stmt_cache.put(
849 sql.to_string(),
850 CacheEntry {
851 stmt: Arc::clone(&stmt),
852 schema_gen: gen,
853 param_count,
854 compiled: None,
855 },
856 );
857 }
858
859 Ok((stmt, param_count))
860 }
861
862 pub(crate) fn execute_prepared(
863 &mut self,
864 db: &'a Database,
865 stmt: &Statement,
866 compiled: Option<&Arc<dyn executor::CompiledPlan>>,
867 params: &[Value],
868 ) -> Result<ExecutionResult> {
869 if let Some(plan) = compiled {
870 if self.active_txn.is_none() {
871 return self.run_compiled(db, plan, stmt, params);
872 }
873 if !self.savepoint_stack.is_empty() && stmt_mutates(stmt) {
874 self.capture_pending_snapshots();
875 }
876 return self.run_compiled_in_txn(db, plan, stmt, params);
877 }
878 self.dispatch(db, stmt, params)
879 }
880
881 fn run_compiled_in_txn(
882 &mut self,
883 db: &'a Database,
884 plan: &Arc<dyn executor::CompiledPlan>,
885 stmt: &Statement,
886 params: &[Value],
887 ) -> Result<ExecutionResult> {
888 use executor::compile::ActiveTxnRef;
889 let schema = &self.schema;
890 let txn = match &mut self.active_txn {
891 ActiveTxn::Write(wtx) => ActiveTxnRef::Write(wtx),
892 ActiveTxn::Read(rtx) => ActiveTxnRef::Read(rtx),
893 ActiveTxn::None => ActiveTxnRef::None,
894 };
895 if params.is_empty() || !plan.uses_scoped_params() {
896 plan.execute(db, schema, stmt, params, txn)
897 } else {
898 crate::eval::with_scoped_params(params, || plan.execute(db, schema, stmt, params, txn))
899 }
900 }
901
902 pub(crate) fn dispatch(
903 &mut self,
904 db: &'a Database,
905 stmt: &Statement,
906 params: &[Value],
907 ) -> Result<ExecutionResult> {
908 let cached_ts = self
909 .txn_start_ts
910 .or_else(|| Some(crate::datetime::now_micros()));
911 crate::datetime::with_txn_clock(cached_ts, || {
912 if params.is_empty() {
913 self.dispatch_inner(db, stmt, params)
914 } else {
915 crate::eval::with_scoped_params(params, || self.dispatch_inner(db, stmt, params))
916 }
917 })
918 }
919
920 fn dispatch_inner(
921 &mut self,
922 db: &'a Database,
923 stmt: &Statement,
924 params: &[Value],
925 ) -> Result<ExecutionResult> {
926 match stmt {
927 Statement::Begin { access_mode } => {
928 if self.active_txn.is_active() {
929 return Err(SqlError::TransactionAlreadyActive);
930 }
931 let ts = crate::datetime::txn_or_clock_micros();
932 match access_mode {
933 BeginAccessMode::ReadOnly => {
934 let rtx = db.begin_read();
935 self.active_txn = ActiveTxn::Read(rtx);
936 }
937 BeginAccessMode::ReadWrite | BeginAccessMode::Default => {
938 let wtx = db.begin_write().map_err(SqlError::Storage)?;
939 self.active_txn = ActiveTxn::Write(wtx);
940 }
941 }
942 self.txn_start_ts = Some(ts);
943 crate::datetime::set_txn_clock(Some(ts));
944 Ok(ExecutionResult::Ok)
945 }
946 Statement::Commit => {
947 match self.active_txn.take() {
948 ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
949 ActiveTxn::Write(mut wtx) => {
950 crate::executor::helpers::drain_deferred_fk_checks(&mut wtx)?;
951 wtx.commit().map_err(SqlError::Storage)?;
952 invalidate_dml_caches(&self.schema, db);
953 }
954 ActiveTxn::Read(_rtx) => {}
955 }
956 self.reset_txn_state();
957 Ok(ExecutionResult::Ok)
958 }
959 Statement::Rollback => {
960 match self.active_txn.take() {
961 ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
962 ActiveTxn::Write(wtx) => {
963 wtx.abort();
964 self.schema = SchemaManager::load(db)?;
965 }
966 ActiveTxn::Read(_rtx) => {}
967 }
968 self.reset_txn_state();
969 Ok(ExecutionResult::Ok)
970 }
971 Statement::Savepoint(name) => self.do_savepoint(name),
972 Statement::ReleaseSavepoint(name) => self.do_release(name),
973 Statement::RollbackTo(name) => self.do_rollback_to(name),
974 Statement::SetTimezone(zone) => {
975 self.set_session_timezone_impl(zone)?;
976 Ok(ExecutionResult::Ok)
977 }
978 Statement::CreateTable(ct) if ct.temporary => {
979 if self.active_txn.is_read_only() {
980 return Err(SqlError::Unsupported(
981 "cannot execute mutating statement inside a read-only transaction".into(),
982 ));
983 }
984 let user_name = ct.name.clone();
985 let prefixed = temp_storage_name(self.temp_id, &user_name);
986 if self.schema.contains(&user_name) {
987 if ct.if_not_exists {
988 return Ok(ExecutionResult::Ok);
989 }
990 return Err(SqlError::TableAlreadyExists(user_name));
991 }
992 let mut clone = ct.clone();
993 clone.name = prefixed.clone();
994 clone.temporary = false;
995 let stmt_concrete = Statement::CreateTable(clone);
996 let outcome = if let Some(wtx) = self.active_txn.as_write_mut() {
997 executor::execute_in_txn(wtx, &mut self.schema, &stmt_concrete, params)?
998 } else {
999 executor::execute(db, &mut self.schema, &stmt_concrete, params)?
1000 };
1001 self.schema
1002 .register_temp_alias(&user_name, prefixed.clone());
1003 self.temp_table_names.push(prefixed);
1004 Ok(outcome)
1005 }
1006 Statement::Insert(ins) if self.active_txn.as_write_mut().is_some() => {
1007 self.capture_pending_snapshots();
1008 let wtx = self.active_txn.as_write_mut().unwrap();
1009 executor::exec_insert_in_txn(wtx, &self.schema, ins, params)
1010 }
1011 _ => {
1012 if self.active_txn.is_read_only() && stmt_mutates(stmt) {
1013 return Err(SqlError::Unsupported(
1014 "cannot execute mutating statement inside a read-only transaction".into(),
1015 ));
1016 }
1017 if self.active_txn.as_write_mut().is_some() && stmt_mutates(stmt) {
1018 self.capture_pending_snapshots();
1019 }
1020 let was_auto_commit = matches!(self.active_txn, ActiveTxn::None);
1021 let outcome = match &mut self.active_txn {
1022 ActiveTxn::Write(wtx) => {
1023 executor::execute_in_txn(wtx, &mut self.schema, stmt, params)?
1024 }
1025 ActiveTxn::Read(rtx) => {
1026 executor::execute_with_read(rtx, &self.schema, stmt, params)?
1027 }
1028 ActiveTxn::None => executor::execute(db, &mut self.schema, stmt, params)?,
1029 };
1030 if was_auto_commit {
1031 invalidate_dml_caches(&self.schema, db);
1032 }
1033 if let Statement::DropTable(dt) = stmt {
1034 self.schema.unregister_temp_alias(&dt.name);
1035 }
1036 Ok(outcome)
1037 }
1038 }
1039 }
1040
1041 fn clear_savepoint_state(&mut self) {
1042 self.savepoint_stack.clear();
1043 self.in_place_saved = None;
1044 }
1045
1046 fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
1047 let wtx = self
1048 .active_txn
1049 .as_write_mut()
1050 .ok_or(SqlError::NoActiveTransaction)?;
1051
1052 if self.savepoint_stack.is_empty() {
1053 self.in_place_saved = Some(wtx.in_place());
1054 wtx.set_in_place(false);
1055 }
1056
1057 self.savepoint_stack.push(SavepointEntry {
1058 name: name.to_string(),
1059 snapshot: None,
1060 });
1061
1062 Ok(ExecutionResult::Ok)
1063 }
1064
1065 fn capture_pending_snapshots(&mut self) {
1066 let last_pending = match self
1067 .savepoint_stack
1068 .iter()
1069 .rposition(|e| e.snapshot.is_none())
1070 {
1071 Some(i) => i,
1072 None => return,
1073 };
1074 let wtx = match self.active_txn.as_write_mut() {
1075 Some(w) => w,
1076 None => return,
1077 };
1078 let wtx_snap = wtx.begin_savepoint();
1079 let schema_snap = self.schema.save_snapshot();
1080
1081 for i in 0..last_pending {
1082 if self.savepoint_stack[i].snapshot.is_none() {
1083 self.savepoint_stack[i].snapshot = Some(SavepointSnapshot {
1084 wtx_snap: wtx_snap.clone(),
1085 schema_snap: schema_snap.clone(),
1086 });
1087 }
1088 }
1089 self.savepoint_stack[last_pending].snapshot = Some(SavepointSnapshot {
1090 wtx_snap,
1091 schema_snap,
1092 });
1093 }
1094
1095 fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
1096 if !self.active_txn.is_active() {
1097 return Err(SqlError::NoActiveTransaction);
1098 }
1099
1100 let idx = self
1101 .savepoint_stack
1102 .iter()
1103 .rposition(|e| e.name == name)
1104 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
1105 self.savepoint_stack.truncate(idx);
1106
1107 if self.savepoint_stack.is_empty() {
1108 if let (Some(wtx), Some(original)) =
1109 (self.active_txn.as_write_mut(), self.in_place_saved.take())
1110 {
1111 wtx.set_in_place(original);
1112 }
1113 }
1114
1115 Ok(ExecutionResult::Ok)
1116 }
1117
1118 fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
1119 if !self.active_txn.is_active() {
1120 return Err(SqlError::NoActiveTransaction);
1121 }
1122
1123 let idx = self
1124 .savepoint_stack
1125 .iter()
1126 .rposition(|e| e.name == name)
1127 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
1128
1129 self.savepoint_stack.truncate(idx + 1);
1130 let entry = self.savepoint_stack.last_mut().unwrap();
1131 let snapshot = match entry.snapshot.take() {
1132 Some(s) => s,
1133 None => return Ok(ExecutionResult::Ok),
1134 };
1135
1136 let wtx = match self.active_txn.as_write_mut() {
1137 Some(w) => w,
1138 None => return Err(SqlError::NoActiveTransaction),
1139 };
1140 wtx.restore_snapshot(snapshot.wtx_snap);
1141 self.schema.restore_snapshot(snapshot.schema_snap);
1142
1143 Ok(ExecutionResult::Ok)
1144 }
1145}
1146
1147impl<'a> Drop for Connection<'a> {
1148 fn drop(&mut self) {
1149 let temp_names = std::mem::take(&mut self.inner.borrow_mut().temp_table_names);
1150 if temp_names.is_empty() {
1151 return;
1152 }
1153 if let Ok(mut wtx) = self.db.begin_write() {
1154 for prefixed in &temp_names {
1155 let _ = wtx.drop_table(prefixed.as_bytes());
1156 }
1157 let _ = wtx.commit();
1158 }
1159 }
1160}
1161
#[cfg(test)]
mod tests {
    use super::*;
    use citadel::{Argon2Profile, DatabaseBuilder};

    /// Creates a fresh passphrase-protected database at `dir/t.db`.
    ///
    /// Uses the `Iot` Argon2 profile (presumably the cheapest key-derivation
    /// setting) so that test setup stays fast.
    fn fresh_db(dir: &std::path::Path) -> citadel::Database {
        DatabaseBuilder::new(dir.join("t.db"))
            .passphrase(b"test-passphrase")
            .argon2_profile(Argon2Profile::Iot)
            .create()
            .unwrap()
    }

    /// A successful batch applies every statement and yields one result each.
    #[test]
    fn execute_batch_commits_all_statements() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER)")
            .unwrap();

        let results = conn
            .execute_batch(
                "INSERT INTO t VALUES (1, 10); \
                 INSERT INTO t VALUES (2, 20); \
                 UPDATE t SET n = n + 1 WHERE id = 1;",
            )
            .unwrap();
        assert_eq!(results.len(), 3);

        let qr = conn.query("SELECT id, n FROM t ORDER BY id").unwrap();
        assert_eq!(qr.rows.len(), 2);
        assert_eq!(qr.rows[0], vec![Value::Integer(1), Value::Integer(11)]);
        assert_eq!(qr.rows[1], vec![Value::Integer(2), Value::Integer(20)]);
    }

    /// A failing statement undoes the earlier statements of the same batch and
    /// leaves the connection outside any transaction.
    #[test]
    fn execute_batch_rolls_back_whole_batch_on_error() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER)")
            .unwrap();
        conn.execute("INSERT INTO t VALUES (1, 100)").unwrap();

        // The second INSERT collides with the existing primary key 1.
        let res =
            conn.execute_batch("INSERT INTO t VALUES (2, 20); INSERT INTO t VALUES (1, 999);");
        assert!(res.is_err());

        // Row 2 must have been rolled back and row 1 left untouched.
        let qr = conn.query("SELECT id, n FROM t ORDER BY id").unwrap();
        assert_eq!(qr.rows.len(), 1);
        assert_eq!(qr.rows[0], vec![Value::Integer(1), Value::Integer(100)]);

        // The connection is still usable in autocommit mode.
        conn.execute("INSERT INTO t VALUES (3, 30)").unwrap();
        assert!(!conn.in_transaction());
    }

    /// Transaction-control statements inside a batch are rejected.
    #[test]
    fn execute_batch_rejects_txn_control() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)")
            .unwrap();
        assert!(conn
            .execute_batch("INSERT INTO t VALUES (1); COMMIT;")
            .is_err());
        assert!(!conn.in_transaction());
    }

    /// A batch cannot run inside an explicit transaction.
    #[test]
    fn execute_batch_rejected_inside_transaction() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)")
            .unwrap();
        conn.execute("BEGIN").unwrap();
        assert!(conn.execute_batch("INSERT INTO t VALUES (1);").is_err());
        conn.execute("ROLLBACK").unwrap();
    }

    /// Arithmetic projections in a prepared SELECT evaluate per row.
    #[test]
    fn streamed_expression_projection_correct() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER)")
            .unwrap();
        conn.execute_batch(
            "INSERT INTO t VALUES (1, 10); INSERT INTO t VALUES (2, 20); INSERT INTO t VALUES (3, 30);",
        )
        .unwrap();
        let stmt = conn.prepare("SELECT id + 1, n * 2 FROM t").unwrap();
        let qr = stmt.query_collect(&[]).unwrap();
        assert_eq!(qr.rows.len(), 3);
        assert_eq!(qr.rows[0], vec![Value::Integer(2), Value::Integer(20)]);
        assert_eq!(qr.rows[1], vec![Value::Integer(3), Value::Integer(40)]);
        assert_eq!(qr.rows[2], vec![Value::Integer(4), Value::Integer(60)]);
    }

    /// `@>` containment on a JSONB column matches only rows whose document
    /// contains the probe; the NULL row (id 3) and the non-admin row (id 2)
    /// must not match.
    #[test]
    fn jsonb_contains_raw_predicate_correct() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = Connection::open(&db).unwrap();
        conn.execute("CREATE TABLE u (id INTEGER PRIMARY KEY, data JSONB)")
            .unwrap();
        conn.execute_batch(
            "INSERT INTO u VALUES (1, '{\"role\":\"admin\",\"x\":1}'); \
             INSERT INTO u VALUES (2, '{\"role\":\"user\"}'); \
             INSERT INTO u VALUES (3, NULL); \
             INSERT INTO u VALUES (4, '{\"role\":\"admin\"}');",
        )
        .unwrap();
        let stmt = conn
            .prepare("SELECT id FROM u WHERE data @> '{\"role\":\"admin\"}'::jsonb")
            .unwrap();
        let qr = stmt.query_collect(&[]).unwrap();
        // Fail loudly on a non-integer id instead of mapping it to a sentinel,
        // so a regression reports the value actually returned.
        let mut ids: Vec<i64> = qr
            .rows
            .iter()
            .map(|r| match &r[0] {
                Value::Integer(i) => *i,
                other => panic!("expected an integer id, got {other:?}"),
            })
            .collect();
        // Row order is not guaranteed without ORDER BY.
        ids.sort_unstable();
        assert_eq!(ids, vec![1, 4]);
    }
}