1use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7
8fn generate_temp_id() -> u64 {
9 static COUNTER: AtomicU64 = AtomicU64::new(0);
10 let counter = COUNTER.fetch_add(1, Ordering::Relaxed);
11 let nanos = (crate::datetime::now_micros() as u64) & 0xFFFF_FFFF;
12 (nanos << 32) | (counter & 0xFFFF_FFFF)
13}
14
/// Storage-level table name for a TEMP table: `__temp_<temp_id>_<lowercased name>`.
///
/// Only ASCII letters are lowercased; every other character is kept as-is.
fn temp_storage_name(temp_id: u64, user_name: &str) -> String {
    let folded = user_name.to_ascii_lowercase();
    format!("__temp_{}_{}", temp_id, folded)
}
18
19use lru::LruCache;
20
21use citadel::Database;
22use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
23
24use crate::error::{Result, SqlError};
25use crate::executor;
26use crate::parser;
27use crate::parser::{BeginAccessMode, QueryBody, SelectQuery, Statement};
28use crate::prepared::PreparedStatement;
29use crate::schema::{SchemaManager, SchemaSnapshot};
30use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
31
/// Number of parsed statements each connection's LRU statement cache holds.
const DEFAULT_CACHE_CAPACITY: usize = 64;
33
34fn invalidate_dml_caches(schema: &SchemaManager, db: &Database) {
39 let gen = db.manager().commit_generation();
40 let hard_invalidate = |table: &str| {
41 db.sql_cache_invalidate_prefix(&format!("ann:{table}:"));
42 let marker: std::sync::Arc<dyn std::any::Any + Send + Sync> = std::sync::Arc::new(gen);
43 schema
44 .sql_caches
45 .lock()
46 .insert(crate::executor::ann_dml_gen_key(table), marker);
47 };
48
49 let dirty = schema.drain_dml_dirty();
50 for table in &dirty.mutating {
51 hard_invalidate(table);
52 }
53 for (table, min_pk) in &dirty.appends {
56 if crate::executor::ann_appends_safe(schema, table, *min_pk) {
57 continue;
58 }
59 hard_invalidate(table);
60 }
61}
62
/// Outcome of [`Connection::execute_script`].
///
/// Statements run in order and execution stops at the first failure, so
/// `completed` holds the results of every statement that succeeded before it.
#[derive(Debug)]
pub struct ScriptExecution {
    /// Results of the statements that ran successfully, in script order.
    pub completed: Vec<ExecutionResult>,
    /// The parse error, or the error of the first failing statement.
    /// `None` means the whole script ran.
    pub error: Option<SqlError>,
}
68
/// Checks whether `s` is a fixed UTC offset accepted as a session time zone.
///
/// Accepted forms (surrounding whitespace ignored):
/// - `Z` or `UTC`, in any ASCII case;
/// - a sign (`+` or `-`) followed by `HH:MM`, `HHMM` or `HH`.
///
/// In the colon form each field may be one or two ASCII digits. Hours must be
/// `0..=23` and minutes `0..=59`.
///
/// Returns `Some(())` when accepted and `None` otherwise. The function never
/// panics: non-ASCII input after the sign is rejected before any fixed-width
/// slicing.
fn parse_fixed_offset(s: &str) -> Option<()> {
    let s = s.trim();
    if s.eq_ignore_ascii_case("z") || s.eq_ignore_ascii_case("utc") {
        return Some(());
    }
    let rest = s.strip_prefix('+').or_else(|| s.strip_prefix('-'))?;
    // Offsets are pure ASCII. Rejecting anything else here also guarantees that
    // `split_at(2)` below lands on a char boundary. Byte slicing used to panic
    // on input such as "+1é1".
    if !rest.is_ascii() {
        return None;
    }
    let (hh, mm) = if let Some(parts) = rest.split_once(':') {
        parts
    } else if rest.len() == 4 {
        rest.split_at(2)
    } else if rest.len() == 2 {
        (rest, "00")
    } else {
        return None;
    };
    // Require plain digits. `str::parse::<u32>` alone would also accept a
    // leading `+` (e.g. "++5:00") and arbitrarily many leading zeros.
    let is_field = |field: &str| {
        (1..=2).contains(&field.len()) && field.bytes().all(|b| b.is_ascii_digit())
    };
    if !is_field(hh) || !is_field(mm) {
        return None;
    }
    let h: u32 = hh.parse().ok()?;
    let m: u32 = mm.parse().ok()?;
    if h > 23 || m > 59 {
        None
    } else {
        Some(())
    }
}
98
/// Rewrites `SHOW TRIGGERS [ON <table>]` into a query on `information_schema.triggers`.
///
/// Matching ignores ASCII case, surrounding whitespace and trailing semicolons.
/// The table name is lowercased and its single quotes are doubled before it is
/// embedded in a string literal. Returns `None` for anything else, including
/// `SHOW TRIGGERS ON` with no table.
fn rewrite_show_triggers(sql: &str) -> Option<String> {
    const BASE: &str = "SELECT trigger_name, event_object_table AS table_name, action_timing, \
                        event_manipulation, action_orientation, action_statement \
                        FROM information_schema.triggers";

    let lower = sql.trim().trim_end_matches(';').trim().to_ascii_lowercase();
    let tail = lower.strip_prefix("show triggers")?.trim_start();
    if tail.is_empty() {
        return Some(format!("{} ORDER BY trigger_name", BASE));
    }

    let table = tail.strip_prefix("on ")?.trim().trim_end_matches(';').trim();
    if table.is_empty() {
        return None;
    }
    let quoted = table.replace('\'', "''");
    Some(format!(
        "{} WHERE LOWER(event_object_table) = LOWER('{}') ORDER BY trigger_name",
        BASE, quoted
    ))
}
125
/// Rewrites `SHOW MATERIALIZED VIEWS` into a query on `pg_matviews`.
///
/// Matching ignores ASCII case, surrounding whitespace and trailing semicolons.
/// The words must be separated by exactly one space. Returns `None` otherwise.
fn rewrite_show_matviews(sql: &str) -> Option<String> {
    let statement = sql.trim().trim_end_matches(';').trim();
    statement
        .eq_ignore_ascii_case("show materialized views")
        .then(|| {
            "SELECT matviewname, ispopulated, hasindexes, definition \
             FROM pg_matviews ORDER BY matviewname"
                .to_string()
        })
}
139
140fn stmt_mutates(stmt: &Statement) -> bool {
141 if matches!(
142 stmt,
143 Statement::Insert(_)
144 | Statement::Update(_)
145 | Statement::Delete(_)
146 | Statement::Truncate(_)
147 | Statement::CreateTable(_)
148 | Statement::DropTable(_)
149 | Statement::AlterTable(_)
150 | Statement::CreateIndex(_)
151 | Statement::DropIndex(_)
152 | Statement::CreateView(_)
153 | Statement::DropView(_)
154 | Statement::CreateTrigger(_)
155 | Statement::DropTrigger(_)
156 | Statement::CreateMaterializedView(_)
157 | Statement::RefreshMaterializedView(_)
158 | Statement::DropMaterializedView(_)
159 ) {
160 return true;
161 }
162 if let Statement::Select(sq) = stmt {
163 if select_query_has_dml(sq) {
164 return true;
165 }
166 }
167 false
168}
169
170fn select_query_has_dml(sq: &SelectQuery) -> bool {
171 sq.ctes.iter().any(|cte| query_body_has_dml(&cte.body)) || query_body_has_dml(&sq.body)
172}
173
174fn query_body_has_dml(body: &QueryBody) -> bool {
175 match body {
176 QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => true,
177 QueryBody::Compound(c) => query_body_has_dml(&c.left) || query_body_has_dml(&c.right),
178 QueryBody::Select(_) => false,
179 }
180}
181
182fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
183 let bytes = sql.as_bytes();
184 let len = bytes.len();
185 let mut i = 0;
186
187 while i < len && bytes[i].is_ascii_whitespace() {
188 i += 1;
189 }
190 if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
191 return None;
192 }
193 i += 6;
194 if i >= len || !bytes[i].is_ascii_whitespace() {
195 return None;
196 }
197 while i < len && bytes[i].is_ascii_whitespace() {
198 i += 1;
199 }
200
201 if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
202 return None;
203 }
204 i += 4;
205 if i >= len || !bytes[i].is_ascii_whitespace() {
206 return None;
207 }
208
209 let prefix_start = 0;
210 let mut values_pos = None;
211 let mut j = i;
212 while j + 6 <= len {
213 if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
214 && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
215 && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
216 {
217 values_pos = Some(j);
218 break;
219 }
220 j += 1;
221 }
222 let values_pos = values_pos?;
223
224 let prefix = &sql[prefix_start..values_pos + 6];
225 let mut pos = values_pos + 6;
226
227 while pos < len && bytes[pos].is_ascii_whitespace() {
228 pos += 1;
229 }
230 if pos >= len || bytes[pos] != b'(' {
231 return None;
232 }
233 pos += 1;
234
235 let mut values = Vec::new();
236 let mut normalized = String::with_capacity(sql.len());
237 normalized.push_str(prefix);
238 normalized.push_str(" (");
239
240 loop {
241 while pos < len && bytes[pos].is_ascii_whitespace() {
242 pos += 1;
243 }
244 if pos >= len {
245 return None;
246 }
247
248 let param_idx = values.len() + 1;
249 if param_idx > 1 {
250 normalized.push_str(", ");
251 }
252
253 if bytes[pos] == b'\'' {
254 pos += 1;
255 let mut seg_start = pos;
256 let mut s = String::new();
257 loop {
258 if pos >= len {
259 return None;
260 }
261 if bytes[pos] == b'\'' {
262 s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
263 pos += 1;
264 if pos < len && bytes[pos] == b'\'' {
265 s.push('\'');
266 pos += 1;
267 seg_start = pos;
268 } else {
269 break;
270 }
271 } else {
272 pos += 1;
273 }
274 }
275 values.push(Value::Text(s.into()));
276 } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
277 let start = pos;
278 if bytes[pos] == b'-' {
279 pos += 1;
280 }
281 while pos < len && bytes[pos].is_ascii_digit() {
282 pos += 1;
283 }
284 if pos < len && bytes[pos] == b'.' {
285 pos += 1;
286 while pos < len && bytes[pos].is_ascii_digit() {
287 pos += 1;
288 }
289 let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
290 values.push(Value::Real(num));
291 } else {
292 let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
293 values.push(Value::Integer(num));
294 }
295 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
296 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
297 if !after.is_ascii_alphanumeric() && after != b'_' {
298 pos += 4;
299 values.push(Value::Null);
300 } else {
301 return None;
302 }
303 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
304 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
305 if !after.is_ascii_alphanumeric() && after != b'_' {
306 pos += 4;
307 values.push(Value::Boolean(true));
308 } else {
309 return None;
310 }
311 } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
312 let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
313 if !after.is_ascii_alphanumeric() && after != b'_' {
314 pos += 5;
315 values.push(Value::Boolean(false));
316 } else {
317 return None;
318 }
319 } else {
320 return None;
321 }
322
323 normalized.push('$');
324 normalized.push_str(¶m_idx.to_string());
325
326 while pos < len && bytes[pos].is_ascii_whitespace() {
327 pos += 1;
328 }
329 if pos >= len {
330 return None;
331 }
332
333 if bytes[pos] == b',' {
334 pos += 1;
335 } else if bytes[pos] == b')' {
336 pos += 1;
337 break;
338 } else {
339 return None;
340 }
341 }
342
343 normalized.push(')');
344
345 while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
346 pos += 1;
347 }
348 if pos != len {
349 return None;
350 }
351
352 if values.is_empty() {
353 return None;
354 }
355
356 Some((normalized, values))
357}
358
/// One entry of the per-connection statement cache.
pub(crate) struct CacheEntry {
    /// The parsed statement, shared with callers via `Arc`.
    pub(crate) stmt: Arc<Statement>,
    /// Schema generation at parse time. Lookups re-parse entries whose
    /// generation differs from the current `SchemaManager::generation()`.
    pub(crate) schema_gen: u64,
    /// Number of bind parameters the statement expects (`parser::count_params`).
    pub(crate) param_count: usize,
    /// Compiled fast-path plan. It is filled in lazily by the autocommit
    /// execution path when `executor::compile` succeeds.
    pub(crate) compiled: Option<Arc<dyn executor::CompiledPlan>>,
}
365
/// A named savepoint on a connection's savepoint stack.
struct SavepointEntry {
    /// Name given in `SAVEPOINT <name>`.
    name: String,
    /// State to restore on `ROLLBACK TO`. It stays `None` until the first
    /// mutation after the savepoint is created, because snapshots are captured
    /// lazily.
    snapshot: Option<SavepointSnapshot>,
}
370
/// Write-transaction and schema state captured together for a savepoint.
struct SavepointSnapshot {
    /// Storage state from `WriteTxn::begin_savepoint`.
    wtx_snap: WriteTxnSnapshot,
    /// In-memory schema state from `SchemaManager::save_snapshot`.
    schema_snap: SchemaSnapshot,
}
375
/// The explicit transaction currently open on a connection, if any.
// NOTE(review): the variant sizes presumably differ a lot (WriteTxn vs ReadTxn);
// the lint is silenced rather than boxing the large variant. Confirm this is intentional.
#[allow(clippy::large_enum_variant)]
pub(crate) enum ActiveTxn<'a> {
    /// Autocommit mode: no `BEGIN` is in effect.
    None,
    /// Opened by `BEGIN` or `BEGIN READ WRITE`.
    Write(WriteTxn<'a>),
    /// Opened by `BEGIN READ ONLY`; mutating statements are rejected.
    Read(citadel_txn::read_txn::ReadTxn<'a>),
}
384
385impl<'a> ActiveTxn<'a> {
386 fn is_none(&self) -> bool {
387 matches!(self, ActiveTxn::None)
388 }
389 fn is_active(&self) -> bool {
390 !self.is_none()
391 }
392 fn is_read_only(&self) -> bool {
393 matches!(self, ActiveTxn::Read(_))
394 }
395 fn as_write_mut(&mut self) -> Option<&mut WriteTxn<'a>> {
396 match self {
397 ActiveTxn::Write(w) => Some(w),
398 _ => None,
399 }
400 }
401 fn take(&mut self) -> ActiveTxn<'a> {
402 std::mem::replace(self, ActiveTxn::None)
403 }
404}
405
/// Mutable per-connection state, kept behind the `RefCell` in `Connection`.
pub(crate) struct ConnectionInner<'a> {
    /// In-memory schema. It is reloaded from the database after `ROLLBACK` of
    /// a write transaction.
    pub(crate) schema: SchemaManager,
    /// Transaction opened by `BEGIN`, or `ActiveTxn::None` in autocommit mode.
    active_txn: ActiveTxn<'a>,
    /// Open savepoints, innermost last.
    savepoint_stack: Vec<SavepointEntry>,
    /// The write transaction's `in_place` flag from before the first savepoint
    /// turned it off. It is restored when the last savepoint is released.
    in_place_saved: Option<bool>,
    /// LRU cache of parsed statements. Keys are the SQL text, or the
    /// normalized `$n` form for literal single-row INSERTs.
    pub(crate) stmt_cache: LruCache<String, CacheEntry>,
    /// Timestamp captured at `BEGIN`; used as the statement clock until the
    /// transaction ends.
    txn_start_ts: Option<i64>,
    /// Session time zone, validated by `set_session_timezone_impl`; starts as "UTC".
    session_timezone: String,
    /// Per-connection id embedded in TEMP table storage names.
    temp_id: u64,
    /// Storage names of the TEMP tables this connection created. They are
    /// dropped when the `Connection` is dropped.
    temp_table_names: Vec<String>,
}
418
/// A session on a `Database`: statement cache, transaction state, TEMP tables
/// and session settings.
///
/// All public methods take `&self`, and the mutable state lives in a
/// `RefCell`. Re-entering the connection while one of its calls holds a
/// mutable borrow would therefore panic.
pub struct Connection<'a> {
    /// The database this session runs against.
    pub(crate) db: &'a Database,
    /// Per-session mutable state.
    pub(crate) inner: RefCell<ConnectionInner<'a>>,
}
423
424impl<'a> Connection<'a> {
425 pub fn open(db: &'a Database) -> Result<Self> {
426 let schema = SchemaManager::load(db)?;
427 let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
428 let temp_id = generate_temp_id();
429 Ok(Self {
430 db,
431 inner: RefCell::new(ConnectionInner {
432 schema,
433 active_txn: ActiveTxn::None,
434 savepoint_stack: Vec::new(),
435 in_place_saved: None,
436 stmt_cache,
437 txn_start_ts: None,
438 session_timezone: "UTC".to_string(),
439 temp_id,
440 temp_table_names: Vec::new(),
441 }),
442 })
443 }
444
445 pub fn txn_start_ts(&self) -> Option<i64> {
447 self.inner.borrow().txn_start_ts
448 }
449
450 pub fn session_timezone(&self) -> String {
452 self.inner.borrow().session_timezone.clone()
453 }
454
455 pub fn set_session_timezone(&self, tz: &str) -> Result<()> {
457 self.inner.borrow_mut().set_session_timezone_impl(tz)
458 }
459
460 pub fn execute(&self, sql: &str) -> Result<ExecutionResult> {
461 self.inner.borrow_mut().execute_impl(self.db, sql)
462 }
463
464 pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
465 self.inner
466 .borrow_mut()
467 .execute_params_impl(self.db, sql, params)
468 }
469
470 pub fn execute_script(&self, sql: &str) -> ScriptExecution {
472 let stmts = match parser::parse_sql_multi(sql) {
473 Ok(s) => s,
474 Err(e) => {
475 return ScriptExecution {
476 completed: vec![],
477 error: Some(e),
478 }
479 }
480 };
481 let mut completed = Vec::with_capacity(stmts.len());
482 for stmt in stmts {
483 match self.inner.borrow_mut().dispatch(self.db, &stmt, &[]) {
484 Ok(r) => completed.push(r),
485 Err(e) => {
486 return ScriptExecution {
487 completed,
488 error: Some(e),
489 }
490 }
491 }
492 }
493 ScriptExecution {
494 completed,
495 error: None,
496 }
497 }
498
499 pub fn query(&self, sql: &str) -> Result<QueryResult> {
500 self.query_params(sql, &[])
501 }
502
503 pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
504 match self.execute_params(sql, params)? {
505 ExecutionResult::Query(qr) => Ok(qr),
506 ExecutionResult::RowsAffected(n) => Ok(QueryResult {
507 columns: vec!["rows_affected".into()],
508 rows: vec![vec![Value::Integer(n as i64)]],
509 }),
510 ExecutionResult::Ok => Ok(QueryResult {
511 columns: vec![],
512 rows: vec![],
513 }),
514 }
515 }
516
517 pub fn prepare(&self, sql: &str) -> Result<PreparedStatement<'_, 'a>> {
518 if let Some(rewritten) = rewrite_show_triggers(sql) {
519 return PreparedStatement::new(self, &rewritten);
520 }
521 if let Some(rewritten) = rewrite_show_matviews(sql) {
522 return PreparedStatement::new(self, &rewritten);
523 }
524 PreparedStatement::new(self, sql)
525 }
526
527 pub fn tables(&self) -> Vec<String> {
528 self.inner
529 .borrow()
530 .schema
531 .table_names()
532 .into_iter()
533 .map(String::from)
534 .collect()
535 }
536
537 pub fn in_transaction(&self) -> bool {
539 self.inner.borrow().active_txn.is_active()
540 }
541
542 pub fn table_schema(&self, name: &str) -> Option<TableSchema> {
543 self.inner.borrow().schema.get(name).cloned()
544 }
545
546 pub fn refresh_schema(&self) -> Result<()> {
547 let new_schema = SchemaManager::load(self.db)?;
548 self.inner.borrow_mut().schema = new_schema;
549 Ok(())
550 }
551
552 pub fn persist_ann_index(
560 &self,
561 table: &str,
562 column: &str,
563 ) -> Result<crate::executor::AnnSegmentInfo> {
564 if self.in_transaction() {
565 return Err(SqlError::InvalidValue(
566 "persist_ann_index: not allowed inside an explicit transaction".into(),
567 ));
568 }
569 let inner = self.inner.borrow();
570 let lower = table.to_ascii_lowercase();
571 if inner.schema.resolve_temp(&lower) != lower {
572 return Err(SqlError::InvalidValue(
573 "persist_ann_index: TEMP tables are not persistable".into(),
574 ));
575 }
576 let table_schema = inner
577 .schema
578 .get(&lower)
579 .ok_or_else(|| SqlError::TableNotFound(table.to_string()))?;
580 crate::executor::persist_ann_index(self.db, &inner.schema, table_schema, column)
581 }
582
583 pub fn ann_cache_status(
588 &self,
589 table: &str,
590 column: &str,
591 ) -> Result<Option<(crate::executor::AnnIndexSource, u64)>> {
592 let inner = self.inner.borrow();
593 let table_schema = inner
594 .schema
595 .get(&table.to_ascii_lowercase())
596 .ok_or_else(|| SqlError::TableNotFound(table.to_string()))?;
597 crate::executor::ann_cache_status(&inner.schema, table_schema, column)
598 }
599}
600
601impl<'a> ConnectionInner<'a> {
602 pub(crate) fn active_txn_is_some(&self) -> bool {
603 self.active_txn.is_active()
604 }
605
606 fn set_session_timezone_impl(&mut self, tz: &str) -> Result<()> {
607 let upper = tz.to_ascii_uppercase();
608 if (upper.starts_with("UTC+") || upper.starts_with("UTC-")) && tz.len() > 3 {
609 return Err(SqlError::InvalidTimezone(format!(
610 "'{tz}' is ambiguous; use ISO-8601 offset (e.g. '+05:00') or named zone (e.g. 'Etc/GMT-5')"
611 )));
612 }
613 if jiff::tz::TimeZone::get(tz).is_err() && parse_fixed_offset(tz).is_none() {
614 return Err(SqlError::InvalidTimezone(format!(
615 "{tz}: not a known IANA zone or ISO-8601 offset (e.g. '+05:00', 'UTC', 'America/New_York')"
616 )));
617 }
618 self.session_timezone = tz.to_string();
619 Ok(())
620 }
621
622 fn execute_impl(&mut self, db: &'a Database, sql: &str) -> Result<ExecutionResult> {
623 if let Some(rewritten) = rewrite_show_triggers(sql) {
624 return self.execute_params_impl(db, &rewritten, &[]);
625 }
626 if let Some(rewritten) = rewrite_show_matviews(sql) {
627 return self.execute_params_impl(db, &rewritten, &[]);
628 }
629 if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
630 if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
631 let gen = self.schema.generation();
632 let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
633 if entry.schema_gen == gen {
634 Arc::clone(&entry.stmt)
635 } else {
636 self.parse_and_cache(normalized_key, gen)?
637 }
638 } else {
639 self.parse_and_cache(normalized_key, gen)?
640 };
641 return self.dispatch(db, &stmt, &extracted);
642 }
643 }
644 self.execute_params_impl(db, sql, &[])
645 }
646
647 fn execute_params_impl(
648 &mut self,
649 db: &'a Database,
650 sql: &str,
651 params: &[Value],
652 ) -> Result<ExecutionResult> {
653 let gen = self.schema.generation();
654 if self.active_txn.is_none() {
655 if let Some(entry) = self.stmt_cache.get(sql) {
656 if entry.schema_gen == gen && entry.param_count == params.len() {
657 if let Some(plan) = entry.compiled.as_ref().map(Arc::clone) {
658 let stmt = Arc::clone(&entry.stmt);
659 return self.run_compiled(db, &plan, &stmt, params);
660 }
661 }
662 }
663 }
664
665 let (stmt, param_count) = self.get_or_parse(sql)?;
666
667 if param_count != params.len() {
668 return Err(SqlError::ParameterCountMismatch {
669 expected: param_count,
670 got: params.len(),
671 });
672 }
673
674 if self.active_txn.is_none() {
675 if let Some(plan) = executor::compile(&self.schema, &stmt) {
676 if let Some(e) = self.stmt_cache.get_mut(sql) {
677 e.compiled = Some(Arc::clone(&plan));
678 }
679 let stmt_owned = Arc::clone(&stmt);
680 return self.run_compiled(db, &plan, &stmt_owned, params);
681 }
682 }
683
684 self.dispatch(db, &stmt, params)
685 }
686
687 fn run_compiled(
688 &mut self,
689 db: &'a Database,
690 plan: &Arc<dyn executor::CompiledPlan>,
691 stmt: &Statement,
692 params: &[Value],
693 ) -> Result<ExecutionResult> {
694 use executor::compile::ActiveTxnRef;
695 let schema = &self.schema;
696 let exec = || {
697 if params.is_empty() {
698 plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
699 } else {
700 crate::eval::with_scoped_params(params, || {
701 plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
702 })
703 }
704 };
705 let outcome = if plan.needs_txn_clock() {
706 let cached_ts = self
707 .txn_start_ts
708 .or_else(|| Some(crate::datetime::now_micros()));
709 crate::datetime::with_txn_clock(cached_ts, exec)
710 } else {
711 exec()
712 }?;
713 invalidate_dml_caches(&self.schema, db);
714 Ok(outcome)
715 }
716
717 pub(crate) fn parse_and_cache(
718 &mut self,
719 normalized_key: String,
720 gen: u64,
721 ) -> Result<Arc<Statement>> {
722 let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
723 let param_count = parser::count_params(&stmt);
724 self.stmt_cache.put(
725 normalized_key,
726 CacheEntry {
727 stmt: Arc::clone(&stmt),
728 schema_gen: gen,
729 param_count,
730 compiled: None,
731 },
732 );
733 Ok(stmt)
734 }
735
736 pub(crate) fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
737 let gen = self.schema.generation();
738
739 if let Some(entry) = self.stmt_cache.get(sql) {
740 if entry.schema_gen == gen {
741 return Ok((Arc::clone(&entry.stmt), entry.param_count));
742 }
743 }
744
745 let stmt = Arc::new(parser::parse_sql(sql)?);
746 let param_count = parser::count_params(&stmt);
747
748 let cacheable = !matches!(
749 *stmt,
750 Statement::CreateTable(_)
751 | Statement::DropTable(_)
752 | Statement::CreateIndex(_)
753 | Statement::DropIndex(_)
754 | Statement::CreateView(_)
755 | Statement::DropView(_)
756 | Statement::AlterTable(_)
757 );
758
759 if cacheable {
760 self.stmt_cache.put(
761 sql.to_string(),
762 CacheEntry {
763 stmt: Arc::clone(&stmt),
764 schema_gen: gen,
765 param_count,
766 compiled: None,
767 },
768 );
769 }
770
771 Ok((stmt, param_count))
772 }
773
774 pub(crate) fn execute_prepared(
775 &mut self,
776 db: &'a Database,
777 stmt: &Statement,
778 compiled: Option<&Arc<dyn executor::CompiledPlan>>,
779 params: &[Value],
780 ) -> Result<ExecutionResult> {
781 if let Some(plan) = compiled {
782 if self.active_txn.is_none() {
783 return self.run_compiled(db, plan, stmt, params);
784 }
785 if !self.savepoint_stack.is_empty() && stmt_mutates(stmt) {
786 self.capture_pending_snapshots();
787 }
788 return self.run_compiled_in_txn(db, plan, stmt, params);
789 }
790 self.dispatch(db, stmt, params)
791 }
792
793 fn run_compiled_in_txn(
794 &mut self,
795 db: &'a Database,
796 plan: &Arc<dyn executor::CompiledPlan>,
797 stmt: &Statement,
798 params: &[Value],
799 ) -> Result<ExecutionResult> {
800 use executor::compile::ActiveTxnRef;
801 let schema = &self.schema;
802 let txn = match &mut self.active_txn {
803 ActiveTxn::Write(wtx) => ActiveTxnRef::Write(wtx),
804 ActiveTxn::Read(rtx) => ActiveTxnRef::Read(rtx),
805 ActiveTxn::None => ActiveTxnRef::None,
806 };
807 if params.is_empty() || !plan.uses_scoped_params() {
808 plan.execute(db, schema, stmt, params, txn)
809 } else {
810 crate::eval::with_scoped_params(params, || plan.execute(db, schema, stmt, params, txn))
811 }
812 }
813
814 pub(crate) fn dispatch(
815 &mut self,
816 db: &'a Database,
817 stmt: &Statement,
818 params: &[Value],
819 ) -> Result<ExecutionResult> {
820 let cached_ts = self
821 .txn_start_ts
822 .or_else(|| Some(crate::datetime::now_micros()));
823 crate::datetime::with_txn_clock(cached_ts, || {
824 if params.is_empty() {
825 self.dispatch_inner(db, stmt, params)
826 } else {
827 crate::eval::with_scoped_params(params, || self.dispatch_inner(db, stmt, params))
828 }
829 })
830 }
831
832 fn dispatch_inner(
833 &mut self,
834 db: &'a Database,
835 stmt: &Statement,
836 params: &[Value],
837 ) -> Result<ExecutionResult> {
838 match stmt {
839 Statement::Begin { access_mode } => {
840 if self.active_txn.is_active() {
841 return Err(SqlError::TransactionAlreadyActive);
842 }
843 let ts = crate::datetime::txn_or_clock_micros();
844 match access_mode {
845 BeginAccessMode::ReadOnly => {
846 let rtx = db.begin_read();
847 self.active_txn = ActiveTxn::Read(rtx);
848 }
849 BeginAccessMode::ReadWrite | BeginAccessMode::Default => {
850 let wtx = db.begin_write().map_err(SqlError::Storage)?;
851 self.active_txn = ActiveTxn::Write(wtx);
852 }
853 }
854 self.txn_start_ts = Some(ts);
855 crate::datetime::set_txn_clock(Some(ts));
856 Ok(ExecutionResult::Ok)
857 }
858 Statement::Commit => {
859 match self.active_txn.take() {
860 ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
861 ActiveTxn::Write(mut wtx) => {
862 crate::executor::helpers::drain_deferred_fk_checks(&mut wtx)?;
863 wtx.commit().map_err(SqlError::Storage)?;
864 invalidate_dml_caches(&self.schema, db);
865 }
866 ActiveTxn::Read(_rtx) => {}
867 }
868 self.clear_savepoint_state();
869 self.txn_start_ts = None;
870 crate::datetime::set_txn_clock(None);
871 Ok(ExecutionResult::Ok)
872 }
873 Statement::Rollback => {
874 match self.active_txn.take() {
875 ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
876 ActiveTxn::Write(wtx) => {
877 wtx.abort();
878 self.schema = SchemaManager::load(db)?;
879 }
880 ActiveTxn::Read(_rtx) => {}
881 }
882 self.clear_savepoint_state();
883 self.txn_start_ts = None;
884 crate::datetime::set_txn_clock(None);
885 Ok(ExecutionResult::Ok)
886 }
887 Statement::Savepoint(name) => self.do_savepoint(name),
888 Statement::ReleaseSavepoint(name) => self.do_release(name),
889 Statement::RollbackTo(name) => self.do_rollback_to(name),
890 Statement::SetTimezone(zone) => {
891 self.set_session_timezone_impl(zone)?;
892 Ok(ExecutionResult::Ok)
893 }
894 Statement::CreateTable(ct) if ct.temporary => {
895 if self.active_txn.is_read_only() {
896 return Err(SqlError::Unsupported(
897 "cannot execute mutating statement inside a read-only transaction".into(),
898 ));
899 }
900 let user_name = ct.name.clone();
901 let prefixed = temp_storage_name(self.temp_id, &user_name);
902 if self.schema.contains(&user_name) {
903 if ct.if_not_exists {
904 return Ok(ExecutionResult::Ok);
905 }
906 return Err(SqlError::TableAlreadyExists(user_name));
907 }
908 let mut clone = ct.clone();
909 clone.name = prefixed.clone();
910 clone.temporary = false;
911 let stmt_concrete = Statement::CreateTable(clone);
912 let outcome = if let Some(wtx) = self.active_txn.as_write_mut() {
913 executor::execute_in_txn(wtx, &mut self.schema, &stmt_concrete, params)?
914 } else {
915 executor::execute(db, &mut self.schema, &stmt_concrete, params)?
916 };
917 self.schema
918 .register_temp_alias(&user_name, prefixed.clone());
919 self.temp_table_names.push(prefixed);
920 Ok(outcome)
921 }
922 Statement::Insert(ins) if self.active_txn.as_write_mut().is_some() => {
923 self.capture_pending_snapshots();
924 let wtx = self.active_txn.as_write_mut().unwrap();
925 executor::exec_insert_in_txn(wtx, &self.schema, ins, params)
926 }
927 _ => {
928 if self.active_txn.is_read_only() && stmt_mutates(stmt) {
929 return Err(SqlError::Unsupported(
930 "cannot execute mutating statement inside a read-only transaction".into(),
931 ));
932 }
933 if self.active_txn.as_write_mut().is_some() && stmt_mutates(stmt) {
934 self.capture_pending_snapshots();
935 }
936 let was_auto_commit = matches!(self.active_txn, ActiveTxn::None);
937 let outcome = match &mut self.active_txn {
938 ActiveTxn::Write(wtx) => {
939 executor::execute_in_txn(wtx, &mut self.schema, stmt, params)?
940 }
941 ActiveTxn::Read(rtx) => {
942 executor::execute_with_read(rtx, &self.schema, stmt, params)?
943 }
944 ActiveTxn::None => executor::execute(db, &mut self.schema, stmt, params)?,
945 };
946 if was_auto_commit {
947 invalidate_dml_caches(&self.schema, db);
948 }
949 if let Statement::DropTable(dt) = stmt {
950 self.schema.unregister_temp_alias(&dt.name);
951 }
952 Ok(outcome)
953 }
954 }
955 }
956
957 fn clear_savepoint_state(&mut self) {
958 self.savepoint_stack.clear();
959 self.in_place_saved = None;
960 }
961
962 fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
963 let wtx = self
964 .active_txn
965 .as_write_mut()
966 .ok_or(SqlError::NoActiveTransaction)?;
967
968 if self.savepoint_stack.is_empty() {
969 self.in_place_saved = Some(wtx.in_place());
970 wtx.set_in_place(false);
971 }
972
973 self.savepoint_stack.push(SavepointEntry {
974 name: name.to_string(),
975 snapshot: None,
976 });
977
978 Ok(ExecutionResult::Ok)
979 }
980
981 fn capture_pending_snapshots(&mut self) {
982 let last_pending = match self
983 .savepoint_stack
984 .iter()
985 .rposition(|e| e.snapshot.is_none())
986 {
987 Some(i) => i,
988 None => return,
989 };
990 let wtx = match self.active_txn.as_write_mut() {
991 Some(w) => w,
992 None => return,
993 };
994 let wtx_snap = wtx.begin_savepoint();
995 let schema_snap = self.schema.save_snapshot();
996
997 for i in 0..last_pending {
998 if self.savepoint_stack[i].snapshot.is_none() {
999 self.savepoint_stack[i].snapshot = Some(SavepointSnapshot {
1000 wtx_snap: wtx_snap.clone(),
1001 schema_snap: schema_snap.clone(),
1002 });
1003 }
1004 }
1005 self.savepoint_stack[last_pending].snapshot = Some(SavepointSnapshot {
1006 wtx_snap,
1007 schema_snap,
1008 });
1009 }
1010
1011 fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
1012 if !self.active_txn.is_active() {
1013 return Err(SqlError::NoActiveTransaction);
1014 }
1015
1016 let idx = self
1017 .savepoint_stack
1018 .iter()
1019 .rposition(|e| e.name == name)
1020 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
1021 self.savepoint_stack.truncate(idx);
1022
1023 if self.savepoint_stack.is_empty() {
1024 if let (Some(wtx), Some(original)) =
1025 (self.active_txn.as_write_mut(), self.in_place_saved.take())
1026 {
1027 wtx.set_in_place(original);
1028 }
1029 }
1030
1031 Ok(ExecutionResult::Ok)
1032 }
1033
1034 fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
1035 if !self.active_txn.is_active() {
1036 return Err(SqlError::NoActiveTransaction);
1037 }
1038
1039 let idx = self
1040 .savepoint_stack
1041 .iter()
1042 .rposition(|e| e.name == name)
1043 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
1044
1045 self.savepoint_stack.truncate(idx + 1);
1046 let entry = self.savepoint_stack.last_mut().unwrap();
1047 let snapshot = match entry.snapshot.take() {
1048 Some(s) => s,
1049 None => return Ok(ExecutionResult::Ok),
1050 };
1051
1052 let wtx = match self.active_txn.as_write_mut() {
1053 Some(w) => w,
1054 None => return Err(SqlError::NoActiveTransaction),
1055 };
1056 wtx.restore_snapshot(snapshot.wtx_snap);
1057 self.schema.restore_snapshot(snapshot.schema_snap);
1058
1059 Ok(ExecutionResult::Ok)
1060 }
1061}
1062
1063impl<'a> Drop for Connection<'a> {
1064 fn drop(&mut self) {
1065 let temp_names = std::mem::take(&mut self.inner.borrow_mut().temp_table_names);
1066 if temp_names.is_empty() {
1067 return;
1068 }
1069 if let Ok(mut wtx) = self.db.begin_write() {
1070 for prefixed in &temp_names {
1071 let _ = wtx.drop_table(prefixed.as_bytes());
1072 }
1073 let _ = wtx.commit();
1074 }
1075 }
1076}