1use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7
8fn generate_temp_id() -> u64 {
9 static COUNTER: AtomicU64 = AtomicU64::new(0);
10 let counter = COUNTER.fetch_add(1, Ordering::Relaxed);
11 let nanos = (crate::datetime::now_micros() as u64) & 0xFFFF_FFFF;
12 (nanos << 32) | (counter & 0xFFFF_FFFF)
13}
14
/// Builds the physical table name backing a TEMP table:
/// `__temp_<temp_id>_<user_name in ASCII lowercase>`.
fn temp_storage_name(temp_id: u64, user_name: &str) -> String {
    let lowered = user_name.to_ascii_lowercase();
    format!("__temp_{temp_id}_{lowered}")
}
18
19use lru::LruCache;
20
21use citadel::Database;
22use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
23
24use crate::error::{Result, SqlError};
25use crate::executor;
26use crate::parser;
27use crate::parser::{BeginAccessMode, QueryBody, SelectQuery, Statement};
28use crate::prepared::PreparedStatement;
29use crate::schema::{SchemaManager, SchemaSnapshot};
30use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
31
/// Maximum number of entries in each connection's LRU statement cache
/// (`ConnectionInner::stmt_cache`, created in `Connection::open`). Must be non-zero.
const DEFAULT_CACHE_CAPACITY: usize = 64;
33
34fn invalidate_dml_caches(schema: &SchemaManager, db: &Database) {
39 let gen = db.manager().commit_generation();
40 for table in schema.drain_dml_dirty() {
41 db.sql_cache_invalidate_prefix(&format!("ann:{table}:"));
42 let marker: std::sync::Arc<dyn std::any::Any + Send + Sync> = std::sync::Arc::new(gen);
43 schema
44 .sql_caches
45 .lock()
46 .insert(crate::executor::ann_dml_gen_key(&table), marker);
47 }
48}
49
/// Outcome of [`Connection::execute_script`].
///
/// Execution stops at the first failing statement, so `completed` holds the results of
/// the statements that ran before it and `error` holds what stopped the script.
#[derive(Debug)]
pub struct ScriptExecution {
    /// Results of the statements that executed successfully, in script order.
    pub completed: Vec<ExecutionResult>,
    /// The parse or execution error that stopped the script; `None` if every statement ran.
    pub error: Option<SqlError>,
}
55
/// Returns `Some(())` if `s` is a fixed UTC offset accepted as a session time zone,
/// `None` otherwise.
///
/// Accepted forms (surrounding whitespace is ignored):
/// - `Z` or `UTC` (ASCII case-insensitive);
/// - `±HH:MM`, `±HHMM` or `±HH`, with hours `0..=23` and minutes `0..=59`.
///
/// The hour and minute components must consist solely of ASCII digits. In particular, the
/// four-character `±HHMM` form is split only after checking the remainder is ASCII, so
/// input such as `"+1é1"` is rejected instead of panicking on a slice that is not on a
/// UTF-8 character boundary.
fn parse_fixed_offset(s: &str) -> Option<()> {
    let s = s.trim();
    if s.eq_ignore_ascii_case("z") || s.eq_ignore_ascii_case("utc") {
        return Some(());
    }

    let rest = s.strip_prefix('+').or_else(|| s.strip_prefix('-'))?;
    let (hh, mm) = match rest.split_once(':') {
        Some(parts) => parts,
        // `split_at(2)` panics off a char boundary; all-ASCII input guarantees one.
        None if rest.len() == 4 && rest.is_ascii() => rest.split_at(2),
        None if rest.len() == 2 => (rest, "00"),
        None => return None,
    };

    // `u32::from_str` tolerates a leading `+`, so insist on plain digits explicitly.
    let component = |digits: &str| -> Option<u32> {
        if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        digits.parse().ok()
    };
    let hours = component(hh)?;
    let minutes = component(mm)?;
    if hours <= 23 && minutes <= 59 {
        Some(())
    } else {
        None
    }
}
85
/// Rewrites `SHOW TRIGGERS [ON <table>]` into a query over `information_schema.triggers`.
///
/// Matching is ASCII case-insensitive and ignores surrounding whitespace and trailing `;`.
/// With `ON <table>`, results are filtered by a case-insensitive table-name comparison and
/// single quotes in the name are doubled. Returns `None` for any other input, including a
/// bare `SHOW TRIGGERS ON` with no table.
fn rewrite_show_triggers(sql: &str) -> Option<String> {
    const KEYWORDS: &str = "show triggers";
    const BASE: &str = "SELECT trigger_name, event_object_table AS table_name, action_timing, \
                        event_manipulation, action_orientation, action_statement \
                        FROM information_schema.triggers";

    let normalized = sql.trim().trim_end_matches(';').trim().to_ascii_lowercase();
    let tail = normalized.strip_prefix(KEYWORDS)?.trim_start();
    if tail.is_empty() {
        return Some(format!("{BASE} ORDER BY trigger_name"));
    }

    let table = tail.strip_prefix("on ")?.trim().trim_end_matches(';').trim();
    if table.is_empty() {
        return None;
    }
    let quoted = table.replace('\'', "''");
    Some(format!(
        "{BASE} WHERE LOWER(event_object_table) = LOWER('{quoted}') ORDER BY trigger_name"
    ))
}
112
/// Rewrites exactly `SHOW MATERIALIZED VIEWS` (ASCII case-insensitive, surrounding
/// whitespace and trailing `;` ignored) into a query over `pg_matviews`; `None` otherwise.
fn rewrite_show_matviews(sql: &str) -> Option<String> {
    const QUERY: &str = "SELECT matviewname, ispopulated, hasindexes, definition \
                         FROM pg_matviews ORDER BY matviewname";

    let statement = sql.trim().trim_end_matches(';').trim();
    if statement.eq_ignore_ascii_case("show materialized views") {
        Some(QUERY.to_owned())
    } else {
        None
    }
}
126
127fn stmt_mutates(stmt: &Statement) -> bool {
128 if matches!(
129 stmt,
130 Statement::Insert(_)
131 | Statement::Update(_)
132 | Statement::Delete(_)
133 | Statement::Truncate(_)
134 | Statement::CreateTable(_)
135 | Statement::DropTable(_)
136 | Statement::AlterTable(_)
137 | Statement::CreateIndex(_)
138 | Statement::DropIndex(_)
139 | Statement::CreateView(_)
140 | Statement::DropView(_)
141 | Statement::CreateTrigger(_)
142 | Statement::DropTrigger(_)
143 | Statement::CreateMaterializedView(_)
144 | Statement::RefreshMaterializedView(_)
145 | Statement::DropMaterializedView(_)
146 ) {
147 return true;
148 }
149 if let Statement::Select(sq) = stmt {
150 if select_query_has_dml(sq) {
151 return true;
152 }
153 }
154 false
155}
156
157fn select_query_has_dml(sq: &SelectQuery) -> bool {
158 sq.ctes.iter().any(|cte| query_body_has_dml(&cte.body)) || query_body_has_dml(&sq.body)
159}
160
161fn query_body_has_dml(body: &QueryBody) -> bool {
162 match body {
163 QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => true,
164 QueryBody::Compound(c) => query_body_has_dml(&c.left) || query_body_has_dml(&c.right),
165 QueryBody::Select(_) => false,
166 }
167}
168
169fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
170 let bytes = sql.as_bytes();
171 let len = bytes.len();
172 let mut i = 0;
173
174 while i < len && bytes[i].is_ascii_whitespace() {
175 i += 1;
176 }
177 if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
178 return None;
179 }
180 i += 6;
181 if i >= len || !bytes[i].is_ascii_whitespace() {
182 return None;
183 }
184 while i < len && bytes[i].is_ascii_whitespace() {
185 i += 1;
186 }
187
188 if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
189 return None;
190 }
191 i += 4;
192 if i >= len || !bytes[i].is_ascii_whitespace() {
193 return None;
194 }
195
196 let prefix_start = 0;
197 let mut values_pos = None;
198 let mut j = i;
199 while j + 6 <= len {
200 if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
201 && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
202 && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
203 {
204 values_pos = Some(j);
205 break;
206 }
207 j += 1;
208 }
209 let values_pos = values_pos?;
210
211 let prefix = &sql[prefix_start..values_pos + 6];
212 let mut pos = values_pos + 6;
213
214 while pos < len && bytes[pos].is_ascii_whitespace() {
215 pos += 1;
216 }
217 if pos >= len || bytes[pos] != b'(' {
218 return None;
219 }
220 pos += 1;
221
222 let mut values = Vec::new();
223 let mut normalized = String::with_capacity(sql.len());
224 normalized.push_str(prefix);
225 normalized.push_str(" (");
226
227 loop {
228 while pos < len && bytes[pos].is_ascii_whitespace() {
229 pos += 1;
230 }
231 if pos >= len {
232 return None;
233 }
234
235 let param_idx = values.len() + 1;
236 if param_idx > 1 {
237 normalized.push_str(", ");
238 }
239
240 if bytes[pos] == b'\'' {
241 pos += 1;
242 let mut seg_start = pos;
243 let mut s = String::new();
244 loop {
245 if pos >= len {
246 return None;
247 }
248 if bytes[pos] == b'\'' {
249 s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
250 pos += 1;
251 if pos < len && bytes[pos] == b'\'' {
252 s.push('\'');
253 pos += 1;
254 seg_start = pos;
255 } else {
256 break;
257 }
258 } else {
259 pos += 1;
260 }
261 }
262 values.push(Value::Text(s.into()));
263 } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
264 let start = pos;
265 if bytes[pos] == b'-' {
266 pos += 1;
267 }
268 while pos < len && bytes[pos].is_ascii_digit() {
269 pos += 1;
270 }
271 if pos < len && bytes[pos] == b'.' {
272 pos += 1;
273 while pos < len && bytes[pos].is_ascii_digit() {
274 pos += 1;
275 }
276 let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
277 values.push(Value::Real(num));
278 } else {
279 let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
280 values.push(Value::Integer(num));
281 }
282 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
283 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
284 if !after.is_ascii_alphanumeric() && after != b'_' {
285 pos += 4;
286 values.push(Value::Null);
287 } else {
288 return None;
289 }
290 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
291 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
292 if !after.is_ascii_alphanumeric() && after != b'_' {
293 pos += 4;
294 values.push(Value::Boolean(true));
295 } else {
296 return None;
297 }
298 } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
299 let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
300 if !after.is_ascii_alphanumeric() && after != b'_' {
301 pos += 5;
302 values.push(Value::Boolean(false));
303 } else {
304 return None;
305 }
306 } else {
307 return None;
308 }
309
310 normalized.push('$');
311 normalized.push_str(¶m_idx.to_string());
312
313 while pos < len && bytes[pos].is_ascii_whitespace() {
314 pos += 1;
315 }
316 if pos >= len {
317 return None;
318 }
319
320 if bytes[pos] == b',' {
321 pos += 1;
322 } else if bytes[pos] == b')' {
323 pos += 1;
324 break;
325 } else {
326 return None;
327 }
328 }
329
330 normalized.push(')');
331
332 while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
333 pos += 1;
334 }
335 if pos != len {
336 return None;
337 }
338
339 if values.is_empty() {
340 return None;
341 }
342
343 Some((normalized, values))
344}
345
/// One entry of the per-connection LRU statement cache, keyed by SQL text (or by the
/// normalized form produced by `try_normalize_insert`).
pub(crate) struct CacheEntry {
    /// The parsed statement, shared with callers via `Arc`.
    pub(crate) stmt: Arc<Statement>,
    /// `SchemaManager::generation()` at parse time; the entry is reused only while the
    /// current generation still matches.
    pub(crate) schema_gen: u64,
    /// Number of positional parameters, as counted by `parser::count_params`.
    pub(crate) param_count: usize,
    /// Compiled fast-path plan, filled in lazily by `execute_params_impl` when
    /// `executor::compile` accepts the statement.
    pub(crate) compiled: Option<Arc<dyn executor::CompiledPlan>>,
}
352
/// One entry on a connection's savepoint stack.
struct SavepointEntry {
    /// Name from `SAVEPOINT <name>`; RELEASE / ROLLBACK TO act on the most recent entry
    /// with a matching name.
    name: String,
    /// State captured lazily, just before the first mutating statement after the savepoint.
    /// `None` means nothing has been mutated since, so ROLLBACK TO it restores nothing.
    snapshot: Option<SavepointSnapshot>,
}
357
/// State captured for a savepoint. Both halves are restored together by ROLLBACK TO.
struct SavepointSnapshot {
    /// Taken with `WriteTxn::begin_savepoint`; passed back to `WriteTxn::restore_snapshot`.
    wtx_snap: WriteTxnSnapshot,
    /// Taken with `SchemaManager::save_snapshot`; passed back to
    /// `SchemaManager::restore_snapshot`.
    schema_snap: SchemaSnapshot,
}
362
/// The explicit transaction currently open on a connection, if any.
// NOTE(review): the transactions are stored inline rather than boxed; the lint is presumably
// silenced because the variant sizes differ and boxing would cost an allocation per BEGIN.
// Confirm this is intentional.
#[allow(clippy::large_enum_variant)]
pub(crate) enum ActiveTxn<'a> {
    /// No BEGIN in effect; statements run in autocommit mode.
    None,
    /// Opened by `BEGIN` or `BEGIN READ WRITE`.
    Write(WriteTxn<'a>),
    /// Opened by `BEGIN READ ONLY`.
    Read(citadel_txn::read_txn::ReadTxn<'a>),
}
371
372impl<'a> ActiveTxn<'a> {
373 fn is_none(&self) -> bool {
374 matches!(self, ActiveTxn::None)
375 }
376 fn is_active(&self) -> bool {
377 !self.is_none()
378 }
379 fn is_read_only(&self) -> bool {
380 matches!(self, ActiveTxn::Read(_))
381 }
382 fn as_write_mut(&mut self) -> Option<&mut WriteTxn<'a>> {
383 match self {
384 ActiveTxn::Write(w) => Some(w),
385 _ => None,
386 }
387 }
388 fn take(&mut self) -> ActiveTxn<'a> {
389 std::mem::replace(self, ActiveTxn::None)
390 }
391}
392
/// Mutable per-connection state, held in the `RefCell` inside [`Connection`].
pub(crate) struct ConnectionInner<'a> {
    /// In-memory schema, including TEMP aliases registered by this connection.
    pub(crate) schema: SchemaManager,
    /// Transaction opened by BEGIN, or `ActiveTxn::None` in autocommit mode.
    active_txn: ActiveTxn<'a>,
    /// Savepoints in creation order (most recent last).
    savepoint_stack: Vec<SavepointEntry>,
    /// The write transaction's `in_place()` flag, saved when the first savepoint turns it
    /// off; restored by RELEASE once the stack is empty.
    in_place_saved: Option<bool>,
    /// LRU cache of parsed statements keyed by SQL text (or normalized INSERT shape).
    pub(crate) stmt_cache: LruCache<String, CacheEntry>,
    /// Timestamp fixed at BEGIN; reset to `None` by COMMIT / ROLLBACK.
    txn_start_ts: Option<i64>,
    /// Session time zone as last accepted (initially "UTC").
    session_timezone: String,
    /// Value from `generate_temp_id`, embedded in this connection's TEMP table storage names.
    temp_id: u64,
    /// Storage names of TEMP tables created on this connection; dropped when the
    /// `Connection` is dropped.
    temp_table_names: Vec<String>,
}
405
/// A SQL session over a borrowed [`Database`].
///
/// Mutable session state is kept in a `RefCell` so the public API can take `&self`.
pub struct Connection<'a> {
    /// Database every statement on this connection runs against.
    pub(crate) db: &'a Database,
    /// Per-connection state: schema, transaction, savepoints, caches, TEMP bookkeeping.
    pub(crate) inner: RefCell<ConnectionInner<'a>>,
}
410
411impl<'a> Connection<'a> {
412 pub fn open(db: &'a Database) -> Result<Self> {
413 let schema = SchemaManager::load(db)?;
414 let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
415 let temp_id = generate_temp_id();
416 Ok(Self {
417 db,
418 inner: RefCell::new(ConnectionInner {
419 schema,
420 active_txn: ActiveTxn::None,
421 savepoint_stack: Vec::new(),
422 in_place_saved: None,
423 stmt_cache,
424 txn_start_ts: None,
425 session_timezone: "UTC".to_string(),
426 temp_id,
427 temp_table_names: Vec::new(),
428 }),
429 })
430 }
431
432 pub fn txn_start_ts(&self) -> Option<i64> {
434 self.inner.borrow().txn_start_ts
435 }
436
437 pub fn session_timezone(&self) -> String {
439 self.inner.borrow().session_timezone.clone()
440 }
441
442 pub fn set_session_timezone(&self, tz: &str) -> Result<()> {
444 self.inner.borrow_mut().set_session_timezone_impl(tz)
445 }
446
447 pub fn execute(&self, sql: &str) -> Result<ExecutionResult> {
448 self.inner.borrow_mut().execute_impl(self.db, sql)
449 }
450
451 pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
452 self.inner
453 .borrow_mut()
454 .execute_params_impl(self.db, sql, params)
455 }
456
457 pub fn execute_script(&self, sql: &str) -> ScriptExecution {
459 let stmts = match parser::parse_sql_multi(sql) {
460 Ok(s) => s,
461 Err(e) => {
462 return ScriptExecution {
463 completed: vec![],
464 error: Some(e),
465 }
466 }
467 };
468 let mut completed = Vec::with_capacity(stmts.len());
469 for stmt in stmts {
470 match self.inner.borrow_mut().dispatch(self.db, &stmt, &[]) {
471 Ok(r) => completed.push(r),
472 Err(e) => {
473 return ScriptExecution {
474 completed,
475 error: Some(e),
476 }
477 }
478 }
479 }
480 ScriptExecution {
481 completed,
482 error: None,
483 }
484 }
485
486 pub fn query(&self, sql: &str) -> Result<QueryResult> {
487 self.query_params(sql, &[])
488 }
489
490 pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
491 match self.execute_params(sql, params)? {
492 ExecutionResult::Query(qr) => Ok(qr),
493 ExecutionResult::RowsAffected(n) => Ok(QueryResult {
494 columns: vec!["rows_affected".into()],
495 rows: vec![vec![Value::Integer(n as i64)]],
496 }),
497 ExecutionResult::Ok => Ok(QueryResult {
498 columns: vec![],
499 rows: vec![],
500 }),
501 }
502 }
503
504 pub fn prepare(&self, sql: &str) -> Result<PreparedStatement<'_, 'a>> {
505 if let Some(rewritten) = rewrite_show_triggers(sql) {
506 return PreparedStatement::new(self, &rewritten);
507 }
508 if let Some(rewritten) = rewrite_show_matviews(sql) {
509 return PreparedStatement::new(self, &rewritten);
510 }
511 PreparedStatement::new(self, sql)
512 }
513
514 pub fn tables(&self) -> Vec<String> {
515 self.inner
516 .borrow()
517 .schema
518 .table_names()
519 .into_iter()
520 .map(String::from)
521 .collect()
522 }
523
524 pub fn in_transaction(&self) -> bool {
526 self.inner.borrow().active_txn.is_active()
527 }
528
529 pub fn table_schema(&self, name: &str) -> Option<TableSchema> {
530 self.inner.borrow().schema.get(name).cloned()
531 }
532
533 pub fn refresh_schema(&self) -> Result<()> {
534 let new_schema = SchemaManager::load(self.db)?;
535 self.inner.borrow_mut().schema = new_schema;
536 Ok(())
537 }
538
539 pub fn persist_ann_index(
547 &self,
548 table: &str,
549 column: &str,
550 ) -> Result<crate::executor::AnnSegmentInfo> {
551 if self.in_transaction() {
552 return Err(SqlError::InvalidValue(
553 "persist_ann_index: not allowed inside an explicit transaction".into(),
554 ));
555 }
556 let inner = self.inner.borrow();
557 let lower = table.to_ascii_lowercase();
558 if inner.schema.resolve_temp(&lower) != lower {
559 return Err(SqlError::InvalidValue(
560 "persist_ann_index: TEMP tables are not persistable".into(),
561 ));
562 }
563 let table_schema = inner
564 .schema
565 .get(&lower)
566 .ok_or_else(|| SqlError::TableNotFound(table.to_string()))?;
567 crate::executor::persist_ann_index(self.db, &inner.schema, table_schema, column)
568 }
569
570 pub fn ann_cache_status(
575 &self,
576 table: &str,
577 column: &str,
578 ) -> Result<Option<(crate::executor::AnnIndexSource, u64)>> {
579 let inner = self.inner.borrow();
580 let table_schema = inner
581 .schema
582 .get(&table.to_ascii_lowercase())
583 .ok_or_else(|| SqlError::TableNotFound(table.to_string()))?;
584 crate::executor::ann_cache_status(&inner.schema, table_schema, column)
585 }
586}
587
588impl<'a> ConnectionInner<'a> {
589 pub(crate) fn active_txn_is_some(&self) -> bool {
590 self.active_txn.is_active()
591 }
592
593 fn set_session_timezone_impl(&mut self, tz: &str) -> Result<()> {
594 let upper = tz.to_ascii_uppercase();
595 if (upper.starts_with("UTC+") || upper.starts_with("UTC-")) && tz.len() > 3 {
596 return Err(SqlError::InvalidTimezone(format!(
597 "'{tz}' is ambiguous; use ISO-8601 offset (e.g. '+05:00') or named zone (e.g. 'Etc/GMT-5')"
598 )));
599 }
600 if jiff::tz::TimeZone::get(tz).is_err() && parse_fixed_offset(tz).is_none() {
601 return Err(SqlError::InvalidTimezone(format!(
602 "{tz}: not a known IANA zone or ISO-8601 offset (e.g. '+05:00', 'UTC', 'America/New_York')"
603 )));
604 }
605 self.session_timezone = tz.to_string();
606 Ok(())
607 }
608
609 fn execute_impl(&mut self, db: &'a Database, sql: &str) -> Result<ExecutionResult> {
610 if let Some(rewritten) = rewrite_show_triggers(sql) {
611 return self.execute_params_impl(db, &rewritten, &[]);
612 }
613 if let Some(rewritten) = rewrite_show_matviews(sql) {
614 return self.execute_params_impl(db, &rewritten, &[]);
615 }
616 if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
617 if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
618 let gen = self.schema.generation();
619 let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
620 if entry.schema_gen == gen {
621 Arc::clone(&entry.stmt)
622 } else {
623 self.parse_and_cache(normalized_key, gen)?
624 }
625 } else {
626 self.parse_and_cache(normalized_key, gen)?
627 };
628 return self.dispatch(db, &stmt, &extracted);
629 }
630 }
631 self.execute_params_impl(db, sql, &[])
632 }
633
634 fn execute_params_impl(
635 &mut self,
636 db: &'a Database,
637 sql: &str,
638 params: &[Value],
639 ) -> Result<ExecutionResult> {
640 let gen = self.schema.generation();
641 if self.active_txn.is_none() {
642 if let Some(entry) = self.stmt_cache.get(sql) {
643 if entry.schema_gen == gen && entry.param_count == params.len() {
644 if let Some(plan) = entry.compiled.as_ref().map(Arc::clone) {
645 let stmt = Arc::clone(&entry.stmt);
646 return self.run_compiled(db, &plan, &stmt, params);
647 }
648 }
649 }
650 }
651
652 let (stmt, param_count) = self.get_or_parse(sql)?;
653
654 if param_count != params.len() {
655 return Err(SqlError::ParameterCountMismatch {
656 expected: param_count,
657 got: params.len(),
658 });
659 }
660
661 if self.active_txn.is_none() {
662 if let Some(plan) = executor::compile(&self.schema, &stmt) {
663 if let Some(e) = self.stmt_cache.get_mut(sql) {
664 e.compiled = Some(Arc::clone(&plan));
665 }
666 let stmt_owned = Arc::clone(&stmt);
667 return self.run_compiled(db, &plan, &stmt_owned, params);
668 }
669 }
670
671 self.dispatch(db, &stmt, params)
672 }
673
674 fn run_compiled(
675 &mut self,
676 db: &'a Database,
677 plan: &Arc<dyn executor::CompiledPlan>,
678 stmt: &Statement,
679 params: &[Value],
680 ) -> Result<ExecutionResult> {
681 use executor::compile::ActiveTxnRef;
682 let schema = &self.schema;
683 let exec = || {
684 if params.is_empty() {
685 plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
686 } else {
687 crate::eval::with_scoped_params(params, || {
688 plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
689 })
690 }
691 };
692 let outcome = if plan.needs_txn_clock() {
693 let cached_ts = self
694 .txn_start_ts
695 .or_else(|| Some(crate::datetime::now_micros()));
696 crate::datetime::with_txn_clock(cached_ts, exec)
697 } else {
698 exec()
699 }?;
700 invalidate_dml_caches(&self.schema, db);
701 Ok(outcome)
702 }
703
704 pub(crate) fn parse_and_cache(
705 &mut self,
706 normalized_key: String,
707 gen: u64,
708 ) -> Result<Arc<Statement>> {
709 let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
710 let param_count = parser::count_params(&stmt);
711 self.stmt_cache.put(
712 normalized_key,
713 CacheEntry {
714 stmt: Arc::clone(&stmt),
715 schema_gen: gen,
716 param_count,
717 compiled: None,
718 },
719 );
720 Ok(stmt)
721 }
722
723 pub(crate) fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
724 let gen = self.schema.generation();
725
726 if let Some(entry) = self.stmt_cache.get(sql) {
727 if entry.schema_gen == gen {
728 return Ok((Arc::clone(&entry.stmt), entry.param_count));
729 }
730 }
731
732 let stmt = Arc::new(parser::parse_sql(sql)?);
733 let param_count = parser::count_params(&stmt);
734
735 let cacheable = !matches!(
736 *stmt,
737 Statement::CreateTable(_)
738 | Statement::DropTable(_)
739 | Statement::CreateIndex(_)
740 | Statement::DropIndex(_)
741 | Statement::CreateView(_)
742 | Statement::DropView(_)
743 | Statement::AlterTable(_)
744 );
745
746 if cacheable {
747 self.stmt_cache.put(
748 sql.to_string(),
749 CacheEntry {
750 stmt: Arc::clone(&stmt),
751 schema_gen: gen,
752 param_count,
753 compiled: None,
754 },
755 );
756 }
757
758 Ok((stmt, param_count))
759 }
760
761 pub(crate) fn execute_prepared(
762 &mut self,
763 db: &'a Database,
764 stmt: &Statement,
765 compiled: Option<&Arc<dyn executor::CompiledPlan>>,
766 params: &[Value],
767 ) -> Result<ExecutionResult> {
768 if let Some(plan) = compiled {
769 if self.active_txn.is_none() {
770 return self.run_compiled(db, plan, stmt, params);
771 }
772 if !self.savepoint_stack.is_empty() && stmt_mutates(stmt) {
773 self.capture_pending_snapshots();
774 }
775 return self.run_compiled_in_txn(db, plan, stmt, params);
776 }
777 self.dispatch(db, stmt, params)
778 }
779
780 fn run_compiled_in_txn(
781 &mut self,
782 db: &'a Database,
783 plan: &Arc<dyn executor::CompiledPlan>,
784 stmt: &Statement,
785 params: &[Value],
786 ) -> Result<ExecutionResult> {
787 use executor::compile::ActiveTxnRef;
788 let schema = &self.schema;
789 let txn = match &mut self.active_txn {
790 ActiveTxn::Write(wtx) => ActiveTxnRef::Write(wtx),
791 ActiveTxn::Read(rtx) => ActiveTxnRef::Read(rtx),
792 ActiveTxn::None => ActiveTxnRef::None,
793 };
794 if params.is_empty() || !plan.uses_scoped_params() {
795 plan.execute(db, schema, stmt, params, txn)
796 } else {
797 crate::eval::with_scoped_params(params, || plan.execute(db, schema, stmt, params, txn))
798 }
799 }
800
801 pub(crate) fn dispatch(
802 &mut self,
803 db: &'a Database,
804 stmt: &Statement,
805 params: &[Value],
806 ) -> Result<ExecutionResult> {
807 let cached_ts = self
808 .txn_start_ts
809 .or_else(|| Some(crate::datetime::now_micros()));
810 crate::datetime::with_txn_clock(cached_ts, || {
811 if params.is_empty() {
812 self.dispatch_inner(db, stmt, params)
813 } else {
814 crate::eval::with_scoped_params(params, || self.dispatch_inner(db, stmt, params))
815 }
816 })
817 }
818
819 fn dispatch_inner(
820 &mut self,
821 db: &'a Database,
822 stmt: &Statement,
823 params: &[Value],
824 ) -> Result<ExecutionResult> {
825 match stmt {
826 Statement::Begin { access_mode } => {
827 if self.active_txn.is_active() {
828 return Err(SqlError::TransactionAlreadyActive);
829 }
830 let ts = crate::datetime::txn_or_clock_micros();
831 match access_mode {
832 BeginAccessMode::ReadOnly => {
833 let rtx = db.begin_read();
834 self.active_txn = ActiveTxn::Read(rtx);
835 }
836 BeginAccessMode::ReadWrite | BeginAccessMode::Default => {
837 let wtx = db.begin_write().map_err(SqlError::Storage)?;
838 self.active_txn = ActiveTxn::Write(wtx);
839 }
840 }
841 self.txn_start_ts = Some(ts);
842 crate::datetime::set_txn_clock(Some(ts));
843 Ok(ExecutionResult::Ok)
844 }
845 Statement::Commit => {
846 match self.active_txn.take() {
847 ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
848 ActiveTxn::Write(mut wtx) => {
849 crate::executor::helpers::drain_deferred_fk_checks(&mut wtx)?;
850 wtx.commit().map_err(SqlError::Storage)?;
851 invalidate_dml_caches(&self.schema, db);
852 }
853 ActiveTxn::Read(_rtx) => {}
854 }
855 self.clear_savepoint_state();
856 self.txn_start_ts = None;
857 crate::datetime::set_txn_clock(None);
858 Ok(ExecutionResult::Ok)
859 }
860 Statement::Rollback => {
861 match self.active_txn.take() {
862 ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
863 ActiveTxn::Write(wtx) => {
864 wtx.abort();
865 self.schema = SchemaManager::load(db)?;
866 }
867 ActiveTxn::Read(_rtx) => {}
868 }
869 self.clear_savepoint_state();
870 self.txn_start_ts = None;
871 crate::datetime::set_txn_clock(None);
872 Ok(ExecutionResult::Ok)
873 }
874 Statement::Savepoint(name) => self.do_savepoint(name),
875 Statement::ReleaseSavepoint(name) => self.do_release(name),
876 Statement::RollbackTo(name) => self.do_rollback_to(name),
877 Statement::SetTimezone(zone) => {
878 self.set_session_timezone_impl(zone)?;
879 Ok(ExecutionResult::Ok)
880 }
881 Statement::CreateTable(ct) if ct.temporary => {
882 if self.active_txn.is_read_only() {
883 return Err(SqlError::Unsupported(
884 "cannot execute mutating statement inside a read-only transaction".into(),
885 ));
886 }
887 let user_name = ct.name.clone();
888 let prefixed = temp_storage_name(self.temp_id, &user_name);
889 if self.schema.contains(&user_name) {
890 if ct.if_not_exists {
891 return Ok(ExecutionResult::Ok);
892 }
893 return Err(SqlError::TableAlreadyExists(user_name));
894 }
895 let mut clone = ct.clone();
896 clone.name = prefixed.clone();
897 clone.temporary = false;
898 let stmt_concrete = Statement::CreateTable(clone);
899 let outcome = if let Some(wtx) = self.active_txn.as_write_mut() {
900 executor::execute_in_txn(wtx, &mut self.schema, &stmt_concrete, params)?
901 } else {
902 executor::execute(db, &mut self.schema, &stmt_concrete, params)?
903 };
904 self.schema
905 .register_temp_alias(&user_name, prefixed.clone());
906 self.temp_table_names.push(prefixed);
907 Ok(outcome)
908 }
909 Statement::Insert(ins) if self.active_txn.as_write_mut().is_some() => {
910 self.capture_pending_snapshots();
911 let wtx = self.active_txn.as_write_mut().unwrap();
912 executor::exec_insert_in_txn(wtx, &self.schema, ins, params)
913 }
914 _ => {
915 if self.active_txn.is_read_only() && stmt_mutates(stmt) {
916 return Err(SqlError::Unsupported(
917 "cannot execute mutating statement inside a read-only transaction".into(),
918 ));
919 }
920 if self.active_txn.as_write_mut().is_some() && stmt_mutates(stmt) {
921 self.capture_pending_snapshots();
922 }
923 let was_auto_commit = matches!(self.active_txn, ActiveTxn::None);
924 let outcome = match &mut self.active_txn {
925 ActiveTxn::Write(wtx) => {
926 executor::execute_in_txn(wtx, &mut self.schema, stmt, params)?
927 }
928 ActiveTxn::Read(rtx) => {
929 executor::execute_with_read(rtx, &self.schema, stmt, params)?
930 }
931 ActiveTxn::None => executor::execute(db, &mut self.schema, stmt, params)?,
932 };
933 if was_auto_commit {
934 invalidate_dml_caches(&self.schema, db);
935 }
936 if let Statement::DropTable(dt) = stmt {
937 self.schema.unregister_temp_alias(&dt.name);
938 }
939 Ok(outcome)
940 }
941 }
942 }
943
944 fn clear_savepoint_state(&mut self) {
945 self.savepoint_stack.clear();
946 self.in_place_saved = None;
947 }
948
949 fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
950 let wtx = self
951 .active_txn
952 .as_write_mut()
953 .ok_or(SqlError::NoActiveTransaction)?;
954
955 if self.savepoint_stack.is_empty() {
956 self.in_place_saved = Some(wtx.in_place());
957 wtx.set_in_place(false);
958 }
959
960 self.savepoint_stack.push(SavepointEntry {
961 name: name.to_string(),
962 snapshot: None,
963 });
964
965 Ok(ExecutionResult::Ok)
966 }
967
968 fn capture_pending_snapshots(&mut self) {
969 let last_pending = match self
970 .savepoint_stack
971 .iter()
972 .rposition(|e| e.snapshot.is_none())
973 {
974 Some(i) => i,
975 None => return,
976 };
977 let wtx = match self.active_txn.as_write_mut() {
978 Some(w) => w,
979 None => return,
980 };
981 let wtx_snap = wtx.begin_savepoint();
982 let schema_snap = self.schema.save_snapshot();
983
984 for i in 0..last_pending {
985 if self.savepoint_stack[i].snapshot.is_none() {
986 self.savepoint_stack[i].snapshot = Some(SavepointSnapshot {
987 wtx_snap: wtx_snap.clone(),
988 schema_snap: schema_snap.clone(),
989 });
990 }
991 }
992 self.savepoint_stack[last_pending].snapshot = Some(SavepointSnapshot {
993 wtx_snap,
994 schema_snap,
995 });
996 }
997
998 fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
999 if !self.active_txn.is_active() {
1000 return Err(SqlError::NoActiveTransaction);
1001 }
1002
1003 let idx = self
1004 .savepoint_stack
1005 .iter()
1006 .rposition(|e| e.name == name)
1007 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
1008 self.savepoint_stack.truncate(idx);
1009
1010 if self.savepoint_stack.is_empty() {
1011 if let (Some(wtx), Some(original)) =
1012 (self.active_txn.as_write_mut(), self.in_place_saved.take())
1013 {
1014 wtx.set_in_place(original);
1015 }
1016 }
1017
1018 Ok(ExecutionResult::Ok)
1019 }
1020
1021 fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
1022 if !self.active_txn.is_active() {
1023 return Err(SqlError::NoActiveTransaction);
1024 }
1025
1026 let idx = self
1027 .savepoint_stack
1028 .iter()
1029 .rposition(|e| e.name == name)
1030 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
1031
1032 self.savepoint_stack.truncate(idx + 1);
1033 let entry = self.savepoint_stack.last_mut().unwrap();
1034 let snapshot = match entry.snapshot.take() {
1035 Some(s) => s,
1036 None => return Ok(ExecutionResult::Ok),
1037 };
1038
1039 let wtx = match self.active_txn.as_write_mut() {
1040 Some(w) => w,
1041 None => return Err(SqlError::NoActiveTransaction),
1042 };
1043 wtx.restore_snapshot(snapshot.wtx_snap);
1044 self.schema.restore_snapshot(snapshot.schema_snap);
1045
1046 Ok(ExecutionResult::Ok)
1047 }
1048}
1049
1050impl<'a> Drop for Connection<'a> {
1051 fn drop(&mut self) {
1052 let temp_names = std::mem::take(&mut self.inner.borrow_mut().temp_table_names);
1053 if temp_names.is_empty() {
1054 return;
1055 }
1056 if let Ok(mut wtx) = self.db.begin_write() {
1057 for prefixed in &temp_names {
1058 let _ = wtx.drop_table(prefixed.as_bytes());
1059 }
1060 let _ = wtx.commit();
1061 }
1062 }
1063}