1use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7
8fn generate_temp_id() -> u64 {
9 static COUNTER: AtomicU64 = AtomicU64::new(0);
10 let counter = COUNTER.fetch_add(1, Ordering::Relaxed);
11 let nanos = (crate::datetime::now_micros() as u64) & 0xFFFF_FFFF;
12 (nanos << 32) | (counter & 0xFFFF_FFFF)
13}
14
/// Storage-level table name for a temporary table.
///
/// The connection's `temp_id` keeps names from different connections apart;
/// the user-visible name is ASCII-lower-cased.
fn temp_storage_name(temp_id: u64, user_name: &str) -> String {
    let lowered = user_name.to_ascii_lowercase();
    format!("__temp_{temp_id}_{lowered}")
}
18
19use lru::LruCache;
20
21use citadel::Database;
22use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
23
24use crate::error::{Result, SqlError};
25use crate::executor;
26use crate::parser;
27use crate::parser::{BeginAccessMode, QueryBody, SelectQuery, Statement};
28use crate::prepared::PreparedStatement;
29use crate::schema::{SchemaManager, SchemaSnapshot};
30use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
31
/// Number of entries held by each connection's LRU statement cache.
const DEFAULT_CACHE_CAPACITY: usize = 64;
33
/// Outcome of `Connection::execute_script`.
///
/// Execution stops at the first failing statement, so `completed` holds the
/// results of the statements that ran before it, in script order.
#[derive(Debug)]
pub struct ScriptExecution {
    /// Results of the successfully executed statements, in script order.
    pub completed: Vec<ExecutionResult>,
    /// The parse error, or the error of the first failing statement; `None`
    /// when every statement ran.
    pub error: Option<SqlError>,
}
39
/// Validates a fixed UTC offset such as `Z`, `UTC`, `+05:30`, `-0800` or `+05`.
///
/// Surrounding whitespace is ignored. After the mandatory `+`/`-` sign, the
/// accepted forms are `H:M` (each part one or more ASCII digits), exactly four
/// digits `HHMM`, or exactly two digits `HH`. The hour must be at most 23 and
/// the minute at most 59.
///
/// Returns `Some(())` when `s` is a valid offset and `None` otherwise. It
/// never panics, including on non-ASCII input.
fn parse_fixed_offset(s: &str) -> Option<()> {
    let s = s.trim();
    if s.eq_ignore_ascii_case("z") || s.eq_ignore_ascii_case("utc") {
        return Some(());
    }
    let rest = s.strip_prefix(|c: char| c == '+' || c == '-')?;
    let (hh, mm) = match rest.split_once(':') {
        Some(parts) => parts,
        // `is_ascii` guarantees byte index 2 is a char boundary, so the split
        // cannot panic on multi-byte input such as "+1é1".
        None if rest.len() == 4 && rest.is_ascii() => rest.split_at(2),
        None if rest.len() == 2 => (rest, "00"),
        None => return None,
    };
    // Digits only: `u32::from_str` would otherwise also accept a leading '+',
    // letting inputs like "++5:00" through.
    let all_digits = |part: &str| !part.is_empty() && part.bytes().all(|b| b.is_ascii_digit());
    if !all_digits(hh) || !all_digits(mm) {
        return None;
    }
    let h: u32 = hh.parse().ok()?;
    let m: u32 = mm.parse().ok()?;
    if h > 23 || m > 59 {
        return None;
    }
    Some(())
}
69
/// Rewrites `SHOW TRIGGERS [ON <table>]` into a query over
/// `information_schema.triggers`.
///
/// Matching is ASCII-case-insensitive and ignores surrounding whitespace and
/// trailing semicolons. The table name is matched with `LOWER(...)` on both
/// sides; embedded single quotes are doubled. Returns `None` for any other
/// statement, or when `ON` is not followed by a table name.
fn rewrite_show_triggers(sql: &str) -> Option<String> {
    const BASE: &str = "SELECT trigger_name, event_object_table AS table_name, action_timing, \
                        event_manipulation, action_orientation, action_statement \
                        FROM information_schema.triggers";

    let lowered = sql.trim().trim_end_matches(';').trim().to_ascii_lowercase();
    let tail = lowered.strip_prefix("show triggers")?.trim_start();
    if tail.is_empty() {
        return Some(format!("{BASE} ORDER BY trigger_name"));
    }

    let table = tail.strip_prefix("on ")?.trim().trim_end_matches(';').trim();
    if table.is_empty() {
        return None;
    }
    let quoted = table.replace('\'', "''");
    Some(format!(
        "{BASE} WHERE LOWER(event_object_table) = LOWER('{quoted}') ORDER BY trigger_name"
    ))
}
96
/// Rewrites `SHOW MATERIALIZED VIEWS` into a query over `pg_matviews`.
///
/// Matching is ASCII-case-insensitive and ignores surrounding whitespace and
/// trailing semicolons; any other statement yields `None`.
fn rewrite_show_matviews(sql: &str) -> Option<String> {
    let cleaned = sql.trim().trim_end_matches(';').trim();
    cleaned.eq_ignore_ascii_case("show materialized views").then(|| {
        "SELECT matviewname, ispopulated, hasindexes, definition \
         FROM pg_matviews ORDER BY matviewname"
            .to_string()
    })
}
110
111fn stmt_mutates(stmt: &Statement) -> bool {
112 if matches!(
113 stmt,
114 Statement::Insert(_)
115 | Statement::Update(_)
116 | Statement::Delete(_)
117 | Statement::Truncate(_)
118 | Statement::CreateTable(_)
119 | Statement::DropTable(_)
120 | Statement::AlterTable(_)
121 | Statement::CreateIndex(_)
122 | Statement::DropIndex(_)
123 | Statement::CreateView(_)
124 | Statement::DropView(_)
125 | Statement::CreateTrigger(_)
126 | Statement::DropTrigger(_)
127 | Statement::CreateMaterializedView(_)
128 | Statement::RefreshMaterializedView(_)
129 | Statement::DropMaterializedView(_)
130 ) {
131 return true;
132 }
133 if let Statement::Select(sq) = stmt {
134 if select_query_has_dml(sq) {
135 return true;
136 }
137 }
138 false
139}
140
141fn select_query_has_dml(sq: &SelectQuery) -> bool {
142 sq.ctes.iter().any(|cte| query_body_has_dml(&cte.body)) || query_body_has_dml(&sq.body)
143}
144
145fn query_body_has_dml(body: &QueryBody) -> bool {
146 match body {
147 QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => true,
148 QueryBody::Compound(c) => query_body_has_dml(&c.left) || query_body_has_dml(&c.right),
149 QueryBody::Select(_) => false,
150 }
151}
152
153fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
154 let bytes = sql.as_bytes();
155 let len = bytes.len();
156 let mut i = 0;
157
158 while i < len && bytes[i].is_ascii_whitespace() {
159 i += 1;
160 }
161 if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
162 return None;
163 }
164 i += 6;
165 if i >= len || !bytes[i].is_ascii_whitespace() {
166 return None;
167 }
168 while i < len && bytes[i].is_ascii_whitespace() {
169 i += 1;
170 }
171
172 if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
173 return None;
174 }
175 i += 4;
176 if i >= len || !bytes[i].is_ascii_whitespace() {
177 return None;
178 }
179
180 let prefix_start = 0;
181 let mut values_pos = None;
182 let mut j = i;
183 while j + 6 <= len {
184 if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
185 && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
186 && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
187 {
188 values_pos = Some(j);
189 break;
190 }
191 j += 1;
192 }
193 let values_pos = values_pos?;
194
195 let prefix = &sql[prefix_start..values_pos + 6];
196 let mut pos = values_pos + 6;
197
198 while pos < len && bytes[pos].is_ascii_whitespace() {
199 pos += 1;
200 }
201 if pos >= len || bytes[pos] != b'(' {
202 return None;
203 }
204 pos += 1;
205
206 let mut values = Vec::new();
207 let mut normalized = String::with_capacity(sql.len());
208 normalized.push_str(prefix);
209 normalized.push_str(" (");
210
211 loop {
212 while pos < len && bytes[pos].is_ascii_whitespace() {
213 pos += 1;
214 }
215 if pos >= len {
216 return None;
217 }
218
219 let param_idx = values.len() + 1;
220 if param_idx > 1 {
221 normalized.push_str(", ");
222 }
223
224 if bytes[pos] == b'\'' {
225 pos += 1;
226 let mut seg_start = pos;
227 let mut s = String::new();
228 loop {
229 if pos >= len {
230 return None;
231 }
232 if bytes[pos] == b'\'' {
233 s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
234 pos += 1;
235 if pos < len && bytes[pos] == b'\'' {
236 s.push('\'');
237 pos += 1;
238 seg_start = pos;
239 } else {
240 break;
241 }
242 } else {
243 pos += 1;
244 }
245 }
246 values.push(Value::Text(s.into()));
247 } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
248 let start = pos;
249 if bytes[pos] == b'-' {
250 pos += 1;
251 }
252 while pos < len && bytes[pos].is_ascii_digit() {
253 pos += 1;
254 }
255 if pos < len && bytes[pos] == b'.' {
256 pos += 1;
257 while pos < len && bytes[pos].is_ascii_digit() {
258 pos += 1;
259 }
260 let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
261 values.push(Value::Real(num));
262 } else {
263 let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
264 values.push(Value::Integer(num));
265 }
266 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
267 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
268 if !after.is_ascii_alphanumeric() && after != b'_' {
269 pos += 4;
270 values.push(Value::Null);
271 } else {
272 return None;
273 }
274 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
275 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
276 if !after.is_ascii_alphanumeric() && after != b'_' {
277 pos += 4;
278 values.push(Value::Boolean(true));
279 } else {
280 return None;
281 }
282 } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
283 let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
284 if !after.is_ascii_alphanumeric() && after != b'_' {
285 pos += 5;
286 values.push(Value::Boolean(false));
287 } else {
288 return None;
289 }
290 } else {
291 return None;
292 }
293
294 normalized.push('$');
295 normalized.push_str(¶m_idx.to_string());
296
297 while pos < len && bytes[pos].is_ascii_whitespace() {
298 pos += 1;
299 }
300 if pos >= len {
301 return None;
302 }
303
304 if bytes[pos] == b',' {
305 pos += 1;
306 } else if bytes[pos] == b')' {
307 pos += 1;
308 break;
309 } else {
310 return None;
311 }
312 }
313
314 normalized.push(')');
315
316 while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
317 pos += 1;
318 }
319 if pos != len {
320 return None;
321 }
322
323 if values.is_empty() {
324 return None;
325 }
326
327 Some((normalized, values))
328}
329
/// A statement-cache entry, keyed in `ConnectionInner::stmt_cache` by SQL text.
pub(crate) struct CacheEntry {
    /// The parsed statement, shared via `Arc` with callers.
    pub(crate) stmt: Arc<Statement>,
    /// `SchemaManager::generation()` when the entry was created.
    ///
    /// Lookups treat the entry as stale when the current generation differs.
    pub(crate) schema_gen: u64,
    /// Parameter count as reported by `parser::count_params`.
    pub(crate) param_count: usize,
    /// Compiled plan from `executor::compile`, if one has been produced.
    ///
    /// `execute_params_impl` fills this in lazily when the statement runs
    /// outside a transaction.
    pub(crate) compiled: Option<Arc<dyn executor::CompiledPlan>>,
}
336
/// One `SAVEPOINT` on the connection's savepoint stack.
struct SavepointEntry {
    /// Savepoint name, compared exactly by `RELEASE` and `ROLLBACK TO`.
    name: String,
    /// State to restore on `ROLLBACK TO`.
    ///
    /// Captured lazily, right before the first mutation after the savepoint.
    /// `None` means nothing has been modified since the savepoint was set.
    snapshot: Option<SavepointSnapshot>,
}

/// Transaction and schema state captured for a savepoint.
struct SavepointSnapshot {
    /// Storage-level snapshot obtained from `WriteTxn::begin_savepoint`.
    wtx_snap: WriteTxnSnapshot,
    /// In-memory schema snapshot obtained from `SchemaManager::save_snapshot`.
    schema_snap: SchemaSnapshot,
}
346
/// The explicit transaction, if any, currently open on a connection.
///
/// `Write` is opened by `BEGIN` / `BEGIN READ WRITE`; `Read` is opened by
/// `BEGIN READ ONLY`.
// NOTE(review): the lint is presumably allowed because `WriteTxn` is much
// larger than the other variants and boxing it was not considered worth the
// indirection — confirm.
#[allow(clippy::large_enum_variant)]
pub(crate) enum ActiveTxn<'a> {
    /// No explicit transaction; statements are executed directly against the database.
    None,
    /// An open read-write transaction.
    Write(WriteTxn<'a>),
    /// An open read-only transaction; mutating statements are rejected.
    Read(citadel_txn::read_txn::ReadTxn<'a>),
}
355
356impl<'a> ActiveTxn<'a> {
357 fn is_none(&self) -> bool {
358 matches!(self, ActiveTxn::None)
359 }
360 fn is_active(&self) -> bool {
361 !self.is_none()
362 }
363 fn is_read_only(&self) -> bool {
364 matches!(self, ActiveTxn::Read(_))
365 }
366 fn as_write_mut(&mut self) -> Option<&mut WriteTxn<'a>> {
367 match self {
368 ActiveTxn::Write(w) => Some(w),
369 _ => None,
370 }
371 }
372 fn take(&mut self) -> ActiveTxn<'a> {
373 std::mem::replace(self, ActiveTxn::None)
374 }
375}
376
/// Mutable per-connection state, kept behind the `RefCell` in `Connection`.
pub(crate) struct ConnectionInner<'a> {
    /// In-memory schema, including this connection's temp-table aliases.
    pub(crate) schema: SchemaManager,
    /// Transaction opened by `BEGIN`, if any.
    active_txn: ActiveTxn<'a>,
    /// Open savepoints, innermost last.
    savepoint_stack: Vec<SavepointEntry>,
    /// The write transaction's `in_place()` flag, saved when the first
    /// savepoint is created; it is restored once `RELEASE` empties the stack.
    in_place_saved: Option<bool>,
    /// LRU cache of parsed (and possibly compiled) statements.
    ///
    /// Keyed by SQL text, or by the placeholder form from `try_normalize_insert`.
    pub(crate) stmt_cache: LruCache<String, CacheEntry>,
    /// Clock value captured at `BEGIN`; `None` outside an explicit transaction.
    txn_start_ts: Option<i64>,
    /// Session time zone, as validated by `set_session_timezone_impl`.
    session_timezone: String,
    /// Per-connection id embedded in temp-table storage names.
    temp_id: u64,
    /// Storage names of temp tables created by this connection.
    ///
    /// These are dropped when the `Connection` is dropped.
    temp_table_names: Vec<String>,
}
389
/// A SQL session over a `citadel::Database`.
///
/// All mutable session state lives in `inner`. Public methods take `&self`
/// and borrow it dynamically through the `RefCell`.
pub struct Connection<'a> {
    /// The database this session runs against.
    pub(crate) db: &'a Database,
    /// Session state: schema, open transaction, savepoints, caches and temp tables.
    pub(crate) inner: RefCell<ConnectionInner<'a>>,
}
394
395impl<'a> Connection<'a> {
396 pub fn open(db: &'a Database) -> Result<Self> {
397 let schema = SchemaManager::load(db)?;
398 let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
399 let temp_id = generate_temp_id();
400 Ok(Self {
401 db,
402 inner: RefCell::new(ConnectionInner {
403 schema,
404 active_txn: ActiveTxn::None,
405 savepoint_stack: Vec::new(),
406 in_place_saved: None,
407 stmt_cache,
408 txn_start_ts: None,
409 session_timezone: "UTC".to_string(),
410 temp_id,
411 temp_table_names: Vec::new(),
412 }),
413 })
414 }
415
416 pub fn txn_start_ts(&self) -> Option<i64> {
418 self.inner.borrow().txn_start_ts
419 }
420
421 pub fn session_timezone(&self) -> String {
423 self.inner.borrow().session_timezone.clone()
424 }
425
426 pub fn set_session_timezone(&self, tz: &str) -> Result<()> {
428 self.inner.borrow_mut().set_session_timezone_impl(tz)
429 }
430
431 pub fn execute(&self, sql: &str) -> Result<ExecutionResult> {
432 self.inner.borrow_mut().execute_impl(self.db, sql)
433 }
434
435 pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
436 self.inner
437 .borrow_mut()
438 .execute_params_impl(self.db, sql, params)
439 }
440
441 pub fn execute_script(&self, sql: &str) -> ScriptExecution {
443 let stmts = match parser::parse_sql_multi(sql) {
444 Ok(s) => s,
445 Err(e) => {
446 return ScriptExecution {
447 completed: vec![],
448 error: Some(e),
449 }
450 }
451 };
452 let mut completed = Vec::with_capacity(stmts.len());
453 for stmt in stmts {
454 match self.inner.borrow_mut().dispatch(self.db, &stmt, &[]) {
455 Ok(r) => completed.push(r),
456 Err(e) => {
457 return ScriptExecution {
458 completed,
459 error: Some(e),
460 }
461 }
462 }
463 }
464 ScriptExecution {
465 completed,
466 error: None,
467 }
468 }
469
470 pub fn query(&self, sql: &str) -> Result<QueryResult> {
471 self.query_params(sql, &[])
472 }
473
474 pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
475 match self.execute_params(sql, params)? {
476 ExecutionResult::Query(qr) => Ok(qr),
477 ExecutionResult::RowsAffected(n) => Ok(QueryResult {
478 columns: vec!["rows_affected".into()],
479 rows: vec![vec![Value::Integer(n as i64)]],
480 }),
481 ExecutionResult::Ok => Ok(QueryResult {
482 columns: vec![],
483 rows: vec![],
484 }),
485 }
486 }
487
488 pub fn prepare(&self, sql: &str) -> Result<PreparedStatement<'_, 'a>> {
489 if let Some(rewritten) = rewrite_show_triggers(sql) {
490 return PreparedStatement::new(self, &rewritten);
491 }
492 if let Some(rewritten) = rewrite_show_matviews(sql) {
493 return PreparedStatement::new(self, &rewritten);
494 }
495 PreparedStatement::new(self, sql)
496 }
497
498 pub fn tables(&self) -> Vec<String> {
499 self.inner
500 .borrow()
501 .schema
502 .table_names()
503 .into_iter()
504 .map(String::from)
505 .collect()
506 }
507
508 pub fn in_transaction(&self) -> bool {
510 self.inner.borrow().active_txn.is_active()
511 }
512
513 pub fn table_schema(&self, name: &str) -> Option<TableSchema> {
514 self.inner.borrow().schema.get(name).cloned()
515 }
516
517 pub fn refresh_schema(&self) -> Result<()> {
518 let new_schema = SchemaManager::load(self.db)?;
519 self.inner.borrow_mut().schema = new_schema;
520 Ok(())
521 }
522}
523
524impl<'a> ConnectionInner<'a> {
525 pub(crate) fn active_txn_is_some(&self) -> bool {
526 self.active_txn.is_active()
527 }
528
529 fn set_session_timezone_impl(&mut self, tz: &str) -> Result<()> {
530 let upper = tz.to_ascii_uppercase();
531 if (upper.starts_with("UTC+") || upper.starts_with("UTC-")) && tz.len() > 3 {
532 return Err(SqlError::InvalidTimezone(format!(
533 "'{tz}' is ambiguous; use ISO-8601 offset (e.g. '+05:00') or named zone (e.g. 'Etc/GMT-5')"
534 )));
535 }
536 if jiff::tz::TimeZone::get(tz).is_err() && parse_fixed_offset(tz).is_none() {
537 return Err(SqlError::InvalidTimezone(format!(
538 "{tz}: not a known IANA zone or ISO-8601 offset (e.g. '+05:00', 'UTC', 'America/New_York')"
539 )));
540 }
541 self.session_timezone = tz.to_string();
542 Ok(())
543 }
544
545 fn execute_impl(&mut self, db: &'a Database, sql: &str) -> Result<ExecutionResult> {
546 if let Some(rewritten) = rewrite_show_triggers(sql) {
547 return self.execute_params_impl(db, &rewritten, &[]);
548 }
549 if let Some(rewritten) = rewrite_show_matviews(sql) {
550 return self.execute_params_impl(db, &rewritten, &[]);
551 }
552 if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
553 if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
554 let gen = self.schema.generation();
555 let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
556 if entry.schema_gen == gen {
557 Arc::clone(&entry.stmt)
558 } else {
559 self.parse_and_cache(normalized_key, gen)?
560 }
561 } else {
562 self.parse_and_cache(normalized_key, gen)?
563 };
564 return self.dispatch(db, &stmt, &extracted);
565 }
566 }
567 self.execute_params_impl(db, sql, &[])
568 }
569
570 fn execute_params_impl(
571 &mut self,
572 db: &'a Database,
573 sql: &str,
574 params: &[Value],
575 ) -> Result<ExecutionResult> {
576 let gen = self.schema.generation();
577 if self.active_txn.is_none() {
578 if let Some(entry) = self.stmt_cache.get(sql) {
579 if entry.schema_gen == gen && entry.param_count == params.len() {
580 if let Some(plan) = entry.compiled.as_ref().map(Arc::clone) {
581 let stmt = Arc::clone(&entry.stmt);
582 return self.run_compiled(db, &plan, &stmt, params);
583 }
584 }
585 }
586 }
587
588 let (stmt, param_count) = self.get_or_parse(sql)?;
589
590 if param_count != params.len() {
591 return Err(SqlError::ParameterCountMismatch {
592 expected: param_count,
593 got: params.len(),
594 });
595 }
596
597 if self.active_txn.is_none() {
598 if let Some(plan) = executor::compile(&self.schema, &stmt) {
599 if let Some(e) = self.stmt_cache.get_mut(sql) {
600 e.compiled = Some(Arc::clone(&plan));
601 }
602 let stmt_owned = Arc::clone(&stmt);
603 return self.run_compiled(db, &plan, &stmt_owned, params);
604 }
605 }
606
607 self.dispatch(db, &stmt, params)
608 }
609
610 fn run_compiled(
611 &mut self,
612 db: &'a Database,
613 plan: &Arc<dyn executor::CompiledPlan>,
614 stmt: &Statement,
615 params: &[Value],
616 ) -> Result<ExecutionResult> {
617 use executor::compile::ActiveTxnRef;
618 let schema = &self.schema;
619 let exec = || {
620 if params.is_empty() {
621 plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
622 } else {
623 crate::eval::with_scoped_params(params, || {
624 plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
625 })
626 }
627 };
628 if plan.needs_txn_clock() {
629 let cached_ts = self
630 .txn_start_ts
631 .or_else(|| Some(crate::datetime::now_micros()));
632 crate::datetime::with_txn_clock(cached_ts, exec)
633 } else {
634 exec()
635 }
636 }
637
638 pub(crate) fn parse_and_cache(
639 &mut self,
640 normalized_key: String,
641 gen: u64,
642 ) -> Result<Arc<Statement>> {
643 let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
644 let param_count = parser::count_params(&stmt);
645 self.stmt_cache.put(
646 normalized_key,
647 CacheEntry {
648 stmt: Arc::clone(&stmt),
649 schema_gen: gen,
650 param_count,
651 compiled: None,
652 },
653 );
654 Ok(stmt)
655 }
656
657 pub(crate) fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
658 let gen = self.schema.generation();
659
660 if let Some(entry) = self.stmt_cache.get(sql) {
661 if entry.schema_gen == gen {
662 return Ok((Arc::clone(&entry.stmt), entry.param_count));
663 }
664 }
665
666 let stmt = Arc::new(parser::parse_sql(sql)?);
667 let param_count = parser::count_params(&stmt);
668
669 let cacheable = !matches!(
670 *stmt,
671 Statement::CreateTable(_)
672 | Statement::DropTable(_)
673 | Statement::CreateIndex(_)
674 | Statement::DropIndex(_)
675 | Statement::CreateView(_)
676 | Statement::DropView(_)
677 | Statement::AlterTable(_)
678 );
679
680 if cacheable {
681 self.stmt_cache.put(
682 sql.to_string(),
683 CacheEntry {
684 stmt: Arc::clone(&stmt),
685 schema_gen: gen,
686 param_count,
687 compiled: None,
688 },
689 );
690 }
691
692 Ok((stmt, param_count))
693 }
694
695 pub(crate) fn execute_prepared(
696 &mut self,
697 db: &'a Database,
698 stmt: &Statement,
699 compiled: Option<&Arc<dyn executor::CompiledPlan>>,
700 params: &[Value],
701 ) -> Result<ExecutionResult> {
702 if let Some(plan) = compiled {
703 if self.active_txn.is_none() {
704 return self.run_compiled(db, plan, stmt, params);
705 }
706 if !self.savepoint_stack.is_empty() && stmt_mutates(stmt) {
707 self.capture_pending_snapshots();
708 }
709 return self.run_compiled_in_txn(db, plan, stmt, params);
710 }
711 self.dispatch(db, stmt, params)
712 }
713
714 fn run_compiled_in_txn(
715 &mut self,
716 db: &'a Database,
717 plan: &Arc<dyn executor::CompiledPlan>,
718 stmt: &Statement,
719 params: &[Value],
720 ) -> Result<ExecutionResult> {
721 use executor::compile::ActiveTxnRef;
722 let schema = &self.schema;
723 let txn = match &mut self.active_txn {
724 ActiveTxn::Write(wtx) => ActiveTxnRef::Write(wtx),
725 ActiveTxn::Read(rtx) => ActiveTxnRef::Read(rtx),
726 ActiveTxn::None => ActiveTxnRef::None,
727 };
728 if params.is_empty() || !plan.uses_scoped_params() {
729 plan.execute(db, schema, stmt, params, txn)
730 } else {
731 crate::eval::with_scoped_params(params, || plan.execute(db, schema, stmt, params, txn))
732 }
733 }
734
735 pub(crate) fn dispatch(
736 &mut self,
737 db: &'a Database,
738 stmt: &Statement,
739 params: &[Value],
740 ) -> Result<ExecutionResult> {
741 let cached_ts = self
742 .txn_start_ts
743 .or_else(|| Some(crate::datetime::now_micros()));
744 crate::datetime::with_txn_clock(cached_ts, || {
745 if params.is_empty() {
746 self.dispatch_inner(db, stmt, params)
747 } else {
748 crate::eval::with_scoped_params(params, || self.dispatch_inner(db, stmt, params))
749 }
750 })
751 }
752
753 fn dispatch_inner(
754 &mut self,
755 db: &'a Database,
756 stmt: &Statement,
757 params: &[Value],
758 ) -> Result<ExecutionResult> {
759 match stmt {
760 Statement::Begin { access_mode } => {
761 if self.active_txn.is_active() {
762 return Err(SqlError::TransactionAlreadyActive);
763 }
764 let ts = crate::datetime::txn_or_clock_micros();
765 match access_mode {
766 BeginAccessMode::ReadOnly => {
767 let rtx = db.begin_read();
768 self.active_txn = ActiveTxn::Read(rtx);
769 }
770 BeginAccessMode::ReadWrite | BeginAccessMode::Default => {
771 let wtx = db.begin_write().map_err(SqlError::Storage)?;
772 self.active_txn = ActiveTxn::Write(wtx);
773 }
774 }
775 self.txn_start_ts = Some(ts);
776 crate::datetime::set_txn_clock(Some(ts));
777 Ok(ExecutionResult::Ok)
778 }
779 Statement::Commit => {
780 match self.active_txn.take() {
781 ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
782 ActiveTxn::Write(mut wtx) => {
783 crate::executor::helpers::drain_deferred_fk_checks(&mut wtx)?;
784 wtx.commit().map_err(SqlError::Storage)?;
785 }
786 ActiveTxn::Read(_rtx) => {}
787 }
788 self.clear_savepoint_state();
789 self.txn_start_ts = None;
790 crate::datetime::set_txn_clock(None);
791 Ok(ExecutionResult::Ok)
792 }
793 Statement::Rollback => {
794 match self.active_txn.take() {
795 ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
796 ActiveTxn::Write(wtx) => {
797 wtx.abort();
798 self.schema = SchemaManager::load(db)?;
799 }
800 ActiveTxn::Read(_rtx) => {}
801 }
802 self.clear_savepoint_state();
803 self.txn_start_ts = None;
804 crate::datetime::set_txn_clock(None);
805 Ok(ExecutionResult::Ok)
806 }
807 Statement::Savepoint(name) => self.do_savepoint(name),
808 Statement::ReleaseSavepoint(name) => self.do_release(name),
809 Statement::RollbackTo(name) => self.do_rollback_to(name),
810 Statement::SetTimezone(zone) => {
811 self.set_session_timezone_impl(zone)?;
812 Ok(ExecutionResult::Ok)
813 }
814 Statement::CreateTable(ct) if ct.temporary => {
815 if self.active_txn.is_read_only() {
816 return Err(SqlError::Unsupported(
817 "cannot execute mutating statement inside a read-only transaction".into(),
818 ));
819 }
820 let user_name = ct.name.clone();
821 let prefixed = temp_storage_name(self.temp_id, &user_name);
822 if self.schema.contains(&user_name) {
823 if ct.if_not_exists {
824 return Ok(ExecutionResult::Ok);
825 }
826 return Err(SqlError::TableAlreadyExists(user_name));
827 }
828 let mut clone = ct.clone();
829 clone.name = prefixed.clone();
830 clone.temporary = false;
831 let stmt_concrete = Statement::CreateTable(clone);
832 let outcome = if let Some(wtx) = self.active_txn.as_write_mut() {
833 executor::execute_in_txn(wtx, &mut self.schema, &stmt_concrete, params)?
834 } else {
835 executor::execute(db, &mut self.schema, &stmt_concrete, params)?
836 };
837 self.schema
838 .register_temp_alias(&user_name, prefixed.clone());
839 self.temp_table_names.push(prefixed);
840 Ok(outcome)
841 }
842 Statement::Insert(ins) if self.active_txn.as_write_mut().is_some() => {
843 self.capture_pending_snapshots();
844 let wtx = self.active_txn.as_write_mut().unwrap();
845 executor::exec_insert_in_txn(wtx, &self.schema, ins, params)
846 }
847 _ => {
848 if self.active_txn.is_read_only() && stmt_mutates(stmt) {
849 return Err(SqlError::Unsupported(
850 "cannot execute mutating statement inside a read-only transaction".into(),
851 ));
852 }
853 if self.active_txn.as_write_mut().is_some() && stmt_mutates(stmt) {
854 self.capture_pending_snapshots();
855 }
856 let outcome = match &mut self.active_txn {
857 ActiveTxn::Write(wtx) => {
858 executor::execute_in_txn(wtx, &mut self.schema, stmt, params)?
859 }
860 ActiveTxn::Read(rtx) => {
861 executor::execute_with_read(rtx, &self.schema, stmt, params)?
862 }
863 ActiveTxn::None => executor::execute(db, &mut self.schema, stmt, params)?,
864 };
865 if let Statement::DropTable(dt) = stmt {
866 self.schema.unregister_temp_alias(&dt.name);
867 }
868 Ok(outcome)
869 }
870 }
871 }
872
873 fn clear_savepoint_state(&mut self) {
874 self.savepoint_stack.clear();
875 self.in_place_saved = None;
876 }
877
878 fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
879 let wtx = self
880 .active_txn
881 .as_write_mut()
882 .ok_or(SqlError::NoActiveTransaction)?;
883
884 if self.savepoint_stack.is_empty() {
885 self.in_place_saved = Some(wtx.in_place());
886 wtx.set_in_place(false);
887 }
888
889 self.savepoint_stack.push(SavepointEntry {
890 name: name.to_string(),
891 snapshot: None,
892 });
893
894 Ok(ExecutionResult::Ok)
895 }
896
897 fn capture_pending_snapshots(&mut self) {
898 let last_pending = match self
899 .savepoint_stack
900 .iter()
901 .rposition(|e| e.snapshot.is_none())
902 {
903 Some(i) => i,
904 None => return,
905 };
906 let wtx = match self.active_txn.as_write_mut() {
907 Some(w) => w,
908 None => return,
909 };
910 let wtx_snap = wtx.begin_savepoint();
911 let schema_snap = self.schema.save_snapshot();
912
913 for i in 0..last_pending {
914 if self.savepoint_stack[i].snapshot.is_none() {
915 self.savepoint_stack[i].snapshot = Some(SavepointSnapshot {
916 wtx_snap: wtx_snap.clone(),
917 schema_snap: schema_snap.clone(),
918 });
919 }
920 }
921 self.savepoint_stack[last_pending].snapshot = Some(SavepointSnapshot {
922 wtx_snap,
923 schema_snap,
924 });
925 }
926
927 fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
928 if !self.active_txn.is_active() {
929 return Err(SqlError::NoActiveTransaction);
930 }
931
932 let idx = self
933 .savepoint_stack
934 .iter()
935 .rposition(|e| e.name == name)
936 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
937 self.savepoint_stack.truncate(idx);
938
939 if self.savepoint_stack.is_empty() {
940 if let (Some(wtx), Some(original)) =
941 (self.active_txn.as_write_mut(), self.in_place_saved.take())
942 {
943 wtx.set_in_place(original);
944 }
945 }
946
947 Ok(ExecutionResult::Ok)
948 }
949
950 fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
951 if !self.active_txn.is_active() {
952 return Err(SqlError::NoActiveTransaction);
953 }
954
955 let idx = self
956 .savepoint_stack
957 .iter()
958 .rposition(|e| e.name == name)
959 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
960
961 self.savepoint_stack.truncate(idx + 1);
962 let entry = self.savepoint_stack.last_mut().unwrap();
963 let snapshot = match entry.snapshot.take() {
964 Some(s) => s,
965 None => return Ok(ExecutionResult::Ok),
966 };
967
968 let wtx = match self.active_txn.as_write_mut() {
969 Some(w) => w,
970 None => return Err(SqlError::NoActiveTransaction),
971 };
972 wtx.restore_snapshot(snapshot.wtx_snap);
973 self.schema.restore_snapshot(snapshot.schema_snap);
974
975 Ok(ExecutionResult::Ok)
976 }
977}
978
979impl<'a> Drop for Connection<'a> {
980 fn drop(&mut self) {
981 let temp_names = std::mem::take(&mut self.inner.borrow_mut().temp_table_names);
982 if temp_names.is_empty() {
983 return;
984 }
985 if let Ok(mut wtx) = self.db.begin_write() {
986 for prefixed in &temp_names {
987 let _ = wtx.drop_table(prefixed.as_bytes());
988 }
989 let _ = wtx.commit();
990 }
991 }
992}