1use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7
8fn generate_temp_id() -> u64 {
9 static COUNTER: AtomicU64 = AtomicU64::new(0);
10 let counter = COUNTER.fetch_add(1, Ordering::Relaxed);
11 let nanos = (crate::datetime::now_micros() as u64) & 0xFFFF_FFFF;
12 (nanos << 32) | (counter & 0xFFFF_FFFF)
13}
14
/// Builds the physical storage name backing a connection-local temporary table.
///
/// The result is `__temp_<temp_id>_<name>`, where `<name>` is `user_name`
/// ASCII-lowercased. The storage name therefore does not depend on the ASCII
/// case the user typed, and tables of different connections (different
/// `temp_id`s) cannot collide.
fn temp_storage_name(temp_id: u64, user_name: &str) -> String {
    let mut storage_name = format!("__temp_{temp_id}_");
    storage_name.push_str(&user_name.to_ascii_lowercase());
    storage_name
}
18
19use lru::LruCache;
20
21use citadel::Database;
22use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
23
24use crate::error::{Result, SqlError};
25use crate::executor;
26use crate::parser;
27use crate::parser::{BeginAccessMode, QueryBody, SelectQuery, Statement};
28use crate::prepared::PreparedStatement;
29use crate::schema::{SchemaManager, SchemaSnapshot};
30use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
31
/// Number of parsed statements each connection keeps in its LRU statement cache.
const DEFAULT_CACHE_CAPACITY: usize = 64;
33
34fn invalidate_dml_caches(schema: &SchemaManager, db: &Database) {
36 for table in schema.drain_dml_dirty() {
37 db.sql_cache_invalidate_prefix(&format!("ann:{table}:"));
38 }
39}
40
41#[derive(Debug)]
42pub struct ScriptExecution {
43 pub completed: Vec<ExecutionResult>,
44 pub error: Option<SqlError>,
45}
46
/// Returns `Some(())` when `s` is a fixed UTC offset this engine accepts.
///
/// Surrounding whitespace is ignored. Accepted forms:
/// * `Z` or `UTC` (any ASCII case);
/// * `±<hours>:<minutes>` (e.g. `+05:30`, `-5:00`);
/// * `±HHMM` (exactly four digits, e.g. `+0530`);
/// * `±HH` (exactly two digits; minutes are taken as `00`).
///
/// Hours must be at most 23 and minutes at most 59. Both fields must be
/// non-empty runs of ASCII digits. `str::parse::<u32>` on its own would also
/// accept a leading `+`, letting `"++05:00"` through. Checking for ASCII
/// before splitting the four-character form also keeps the byte slice on a
/// char boundary, so input like `"+1é1"` is rejected instead of panicking.
fn parse_fixed_offset(s: &str) -> Option<()> {
    /// Parses a non-empty, digits-only field; rejects signs and non-ASCII.
    fn digits(field: &str) -> Option<u32> {
        if field.is_empty() || !field.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        field.parse().ok()
    }

    let s = s.trim();
    if s.eq_ignore_ascii_case("z") || s.eq_ignore_ascii_case("utc") {
        return Some(());
    }
    let rest = s.strip_prefix('+').or_else(|| s.strip_prefix('-'))?;
    let (hh, mm) = match rest.split_once(':') {
        Some(parts) => parts,
        // `split_at(2)` is only safe on a char boundary; ASCII guarantees it.
        None if rest.len() == 4 && rest.is_ascii() => rest.split_at(2),
        None if rest.len() == 2 => (rest, "00"),
        None => return None,
    };
    let hours = digits(hh)?;
    let minutes = digits(mm)?;
    if hours > 23 || minutes > 59 {
        return None;
    }
    Some(())
}
76
/// Translates `SHOW TRIGGERS [ON <table>]` into an `information_schema` query.
///
/// Matching ignores ASCII case, surrounding whitespace and trailing `;`.
/// With `ON <table>`, the table name is ASCII-lowercased, single quotes are
/// doubled, and it is compared case-insensitively through `LOWER(...)`.
/// Returns `None` for any other statement, including a bare `ON` with no table.
fn rewrite_show_triggers(sql: &str) -> Option<String> {
    let select_triggers: &str = "SELECT trigger_name, event_object_table AS table_name, action_timing, \
         event_manipulation, action_orientation, action_statement \
         FROM information_schema.triggers";

    let statement = sql.trim().trim_end_matches(';').trim();
    let lowered = statement.to_ascii_lowercase();
    let tail = lowered.strip_prefix("show triggers")?.trim_start();

    if tail.is_empty() {
        return Some(format!("{select_triggers} ORDER BY trigger_name"));
    }

    let table = tail.strip_prefix("on ")?.trim().trim_end_matches(';').trim();
    if table.is_empty() {
        return None;
    }
    let quoted_table = table.replace('\'', "''");
    Some(format!(
        "{select_triggers} WHERE LOWER(event_object_table) = LOWER('{quoted_table}') ORDER BY trigger_name"
    ))
}
103
/// Translates `SHOW MATERIALIZED VIEWS` into a `pg_matviews` query.
///
/// The match is ASCII case-insensitive and ignores surrounding whitespace
/// and trailing `;`, but the words must be separated by single spaces.
/// Any other input yields `None`.
fn rewrite_show_matviews(sql: &str) -> Option<String> {
    const MATVIEWS_QUERY: &str = "SELECT matviewname, ispopulated, hasindexes, definition \
         FROM pg_matviews ORDER BY matviewname";

    let statement = sql.trim().trim_end_matches(';').trim();
    if statement.eq_ignore_ascii_case("show materialized views") {
        Some(MATVIEWS_QUERY.to_owned())
    } else {
        None
    }
}
117
118fn stmt_mutates(stmt: &Statement) -> bool {
119 if matches!(
120 stmt,
121 Statement::Insert(_)
122 | Statement::Update(_)
123 | Statement::Delete(_)
124 | Statement::Truncate(_)
125 | Statement::CreateTable(_)
126 | Statement::DropTable(_)
127 | Statement::AlterTable(_)
128 | Statement::CreateIndex(_)
129 | Statement::DropIndex(_)
130 | Statement::CreateView(_)
131 | Statement::DropView(_)
132 | Statement::CreateTrigger(_)
133 | Statement::DropTrigger(_)
134 | Statement::CreateMaterializedView(_)
135 | Statement::RefreshMaterializedView(_)
136 | Statement::DropMaterializedView(_)
137 ) {
138 return true;
139 }
140 if let Statement::Select(sq) = stmt {
141 if select_query_has_dml(sq) {
142 return true;
143 }
144 }
145 false
146}
147
148fn select_query_has_dml(sq: &SelectQuery) -> bool {
149 sq.ctes.iter().any(|cte| query_body_has_dml(&cte.body)) || query_body_has_dml(&sq.body)
150}
151
152fn query_body_has_dml(body: &QueryBody) -> bool {
153 match body {
154 QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => true,
155 QueryBody::Compound(c) => query_body_has_dml(&c.left) || query_body_has_dml(&c.right),
156 QueryBody::Select(_) => false,
157 }
158}
159
160fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
161 let bytes = sql.as_bytes();
162 let len = bytes.len();
163 let mut i = 0;
164
165 while i < len && bytes[i].is_ascii_whitespace() {
166 i += 1;
167 }
168 if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
169 return None;
170 }
171 i += 6;
172 if i >= len || !bytes[i].is_ascii_whitespace() {
173 return None;
174 }
175 while i < len && bytes[i].is_ascii_whitespace() {
176 i += 1;
177 }
178
179 if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
180 return None;
181 }
182 i += 4;
183 if i >= len || !bytes[i].is_ascii_whitespace() {
184 return None;
185 }
186
187 let prefix_start = 0;
188 let mut values_pos = None;
189 let mut j = i;
190 while j + 6 <= len {
191 if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
192 && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
193 && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
194 {
195 values_pos = Some(j);
196 break;
197 }
198 j += 1;
199 }
200 let values_pos = values_pos?;
201
202 let prefix = &sql[prefix_start..values_pos + 6];
203 let mut pos = values_pos + 6;
204
205 while pos < len && bytes[pos].is_ascii_whitespace() {
206 pos += 1;
207 }
208 if pos >= len || bytes[pos] != b'(' {
209 return None;
210 }
211 pos += 1;
212
213 let mut values = Vec::new();
214 let mut normalized = String::with_capacity(sql.len());
215 normalized.push_str(prefix);
216 normalized.push_str(" (");
217
218 loop {
219 while pos < len && bytes[pos].is_ascii_whitespace() {
220 pos += 1;
221 }
222 if pos >= len {
223 return None;
224 }
225
226 let param_idx = values.len() + 1;
227 if param_idx > 1 {
228 normalized.push_str(", ");
229 }
230
231 if bytes[pos] == b'\'' {
232 pos += 1;
233 let mut seg_start = pos;
234 let mut s = String::new();
235 loop {
236 if pos >= len {
237 return None;
238 }
239 if bytes[pos] == b'\'' {
240 s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
241 pos += 1;
242 if pos < len && bytes[pos] == b'\'' {
243 s.push('\'');
244 pos += 1;
245 seg_start = pos;
246 } else {
247 break;
248 }
249 } else {
250 pos += 1;
251 }
252 }
253 values.push(Value::Text(s.into()));
254 } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
255 let start = pos;
256 if bytes[pos] == b'-' {
257 pos += 1;
258 }
259 while pos < len && bytes[pos].is_ascii_digit() {
260 pos += 1;
261 }
262 if pos < len && bytes[pos] == b'.' {
263 pos += 1;
264 while pos < len && bytes[pos].is_ascii_digit() {
265 pos += 1;
266 }
267 let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
268 values.push(Value::Real(num));
269 } else {
270 let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
271 values.push(Value::Integer(num));
272 }
273 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
274 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
275 if !after.is_ascii_alphanumeric() && after != b'_' {
276 pos += 4;
277 values.push(Value::Null);
278 } else {
279 return None;
280 }
281 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
282 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
283 if !after.is_ascii_alphanumeric() && after != b'_' {
284 pos += 4;
285 values.push(Value::Boolean(true));
286 } else {
287 return None;
288 }
289 } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
290 let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
291 if !after.is_ascii_alphanumeric() && after != b'_' {
292 pos += 5;
293 values.push(Value::Boolean(false));
294 } else {
295 return None;
296 }
297 } else {
298 return None;
299 }
300
301 normalized.push('$');
302 normalized.push_str(¶m_idx.to_string());
303
304 while pos < len && bytes[pos].is_ascii_whitespace() {
305 pos += 1;
306 }
307 if pos >= len {
308 return None;
309 }
310
311 if bytes[pos] == b',' {
312 pos += 1;
313 } else if bytes[pos] == b')' {
314 pos += 1;
315 break;
316 } else {
317 return None;
318 }
319 }
320
321 normalized.push(')');
322
323 while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
324 pos += 1;
325 }
326 if pos != len {
327 return None;
328 }
329
330 if values.is_empty() {
331 return None;
332 }
333
334 Some((normalized, values))
335}
336
337pub(crate) struct CacheEntry {
338 pub(crate) stmt: Arc<Statement>,
339 pub(crate) schema_gen: u64,
340 pub(crate) param_count: usize,
341 pub(crate) compiled: Option<Arc<dyn executor::CompiledPlan>>,
342}
343
344struct SavepointEntry {
345 name: String,
346 snapshot: Option<SavepointSnapshot>,
347}
348
349struct SavepointSnapshot {
350 wtx_snap: WriteTxnSnapshot,
351 schema_snap: SchemaSnapshot,
352}
353
354#[allow(clippy::large_enum_variant)]
357pub(crate) enum ActiveTxn<'a> {
358 None,
359 Write(WriteTxn<'a>),
360 Read(citadel_txn::read_txn::ReadTxn<'a>),
361}
362
363impl<'a> ActiveTxn<'a> {
364 fn is_none(&self) -> bool {
365 matches!(self, ActiveTxn::None)
366 }
367 fn is_active(&self) -> bool {
368 !self.is_none()
369 }
370 fn is_read_only(&self) -> bool {
371 matches!(self, ActiveTxn::Read(_))
372 }
373 fn as_write_mut(&mut self) -> Option<&mut WriteTxn<'a>> {
374 match self {
375 ActiveTxn::Write(w) => Some(w),
376 _ => None,
377 }
378 }
379 fn take(&mut self) -> ActiveTxn<'a> {
380 std::mem::replace(self, ActiveTxn::None)
381 }
382}
383
384pub(crate) struct ConnectionInner<'a> {
385 pub(crate) schema: SchemaManager,
386 active_txn: ActiveTxn<'a>,
387 savepoint_stack: Vec<SavepointEntry>,
388 in_place_saved: Option<bool>,
389 pub(crate) stmt_cache: LruCache<String, CacheEntry>,
390 txn_start_ts: Option<i64>,
391 session_timezone: String,
392 temp_id: u64,
394 temp_table_names: Vec<String>,
395}
396
397pub struct Connection<'a> {
398 pub(crate) db: &'a Database,
399 pub(crate) inner: RefCell<ConnectionInner<'a>>,
400}
401
402impl<'a> Connection<'a> {
403 pub fn open(db: &'a Database) -> Result<Self> {
404 let schema = SchemaManager::load(db)?;
405 let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
406 let temp_id = generate_temp_id();
407 Ok(Self {
408 db,
409 inner: RefCell::new(ConnectionInner {
410 schema,
411 active_txn: ActiveTxn::None,
412 savepoint_stack: Vec::new(),
413 in_place_saved: None,
414 stmt_cache,
415 txn_start_ts: None,
416 session_timezone: "UTC".to_string(),
417 temp_id,
418 temp_table_names: Vec::new(),
419 }),
420 })
421 }
422
423 pub fn txn_start_ts(&self) -> Option<i64> {
425 self.inner.borrow().txn_start_ts
426 }
427
428 pub fn session_timezone(&self) -> String {
430 self.inner.borrow().session_timezone.clone()
431 }
432
433 pub fn set_session_timezone(&self, tz: &str) -> Result<()> {
435 self.inner.borrow_mut().set_session_timezone_impl(tz)
436 }
437
438 pub fn execute(&self, sql: &str) -> Result<ExecutionResult> {
439 self.inner.borrow_mut().execute_impl(self.db, sql)
440 }
441
442 pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
443 self.inner
444 .borrow_mut()
445 .execute_params_impl(self.db, sql, params)
446 }
447
448 pub fn execute_script(&self, sql: &str) -> ScriptExecution {
450 let stmts = match parser::parse_sql_multi(sql) {
451 Ok(s) => s,
452 Err(e) => {
453 return ScriptExecution {
454 completed: vec![],
455 error: Some(e),
456 }
457 }
458 };
459 let mut completed = Vec::with_capacity(stmts.len());
460 for stmt in stmts {
461 match self.inner.borrow_mut().dispatch(self.db, &stmt, &[]) {
462 Ok(r) => completed.push(r),
463 Err(e) => {
464 return ScriptExecution {
465 completed,
466 error: Some(e),
467 }
468 }
469 }
470 }
471 ScriptExecution {
472 completed,
473 error: None,
474 }
475 }
476
477 pub fn query(&self, sql: &str) -> Result<QueryResult> {
478 self.query_params(sql, &[])
479 }
480
481 pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
482 match self.execute_params(sql, params)? {
483 ExecutionResult::Query(qr) => Ok(qr),
484 ExecutionResult::RowsAffected(n) => Ok(QueryResult {
485 columns: vec!["rows_affected".into()],
486 rows: vec![vec![Value::Integer(n as i64)]],
487 }),
488 ExecutionResult::Ok => Ok(QueryResult {
489 columns: vec![],
490 rows: vec![],
491 }),
492 }
493 }
494
495 pub fn prepare(&self, sql: &str) -> Result<PreparedStatement<'_, 'a>> {
496 if let Some(rewritten) = rewrite_show_triggers(sql) {
497 return PreparedStatement::new(self, &rewritten);
498 }
499 if let Some(rewritten) = rewrite_show_matviews(sql) {
500 return PreparedStatement::new(self, &rewritten);
501 }
502 PreparedStatement::new(self, sql)
503 }
504
505 pub fn tables(&self) -> Vec<String> {
506 self.inner
507 .borrow()
508 .schema
509 .table_names()
510 .into_iter()
511 .map(String::from)
512 .collect()
513 }
514
515 pub fn in_transaction(&self) -> bool {
517 self.inner.borrow().active_txn.is_active()
518 }
519
520 pub fn table_schema(&self, name: &str) -> Option<TableSchema> {
521 self.inner.borrow().schema.get(name).cloned()
522 }
523
524 pub fn refresh_schema(&self) -> Result<()> {
525 let new_schema = SchemaManager::load(self.db)?;
526 self.inner.borrow_mut().schema = new_schema;
527 Ok(())
528 }
529}
530
531impl<'a> ConnectionInner<'a> {
532 pub(crate) fn active_txn_is_some(&self) -> bool {
533 self.active_txn.is_active()
534 }
535
536 fn set_session_timezone_impl(&mut self, tz: &str) -> Result<()> {
537 let upper = tz.to_ascii_uppercase();
538 if (upper.starts_with("UTC+") || upper.starts_with("UTC-")) && tz.len() > 3 {
539 return Err(SqlError::InvalidTimezone(format!(
540 "'{tz}' is ambiguous; use ISO-8601 offset (e.g. '+05:00') or named zone (e.g. 'Etc/GMT-5')"
541 )));
542 }
543 if jiff::tz::TimeZone::get(tz).is_err() && parse_fixed_offset(tz).is_none() {
544 return Err(SqlError::InvalidTimezone(format!(
545 "{tz}: not a known IANA zone or ISO-8601 offset (e.g. '+05:00', 'UTC', 'America/New_York')"
546 )));
547 }
548 self.session_timezone = tz.to_string();
549 Ok(())
550 }
551
552 fn execute_impl(&mut self, db: &'a Database, sql: &str) -> Result<ExecutionResult> {
553 if let Some(rewritten) = rewrite_show_triggers(sql) {
554 return self.execute_params_impl(db, &rewritten, &[]);
555 }
556 if let Some(rewritten) = rewrite_show_matviews(sql) {
557 return self.execute_params_impl(db, &rewritten, &[]);
558 }
559 if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
560 if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
561 let gen = self.schema.generation();
562 let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
563 if entry.schema_gen == gen {
564 Arc::clone(&entry.stmt)
565 } else {
566 self.parse_and_cache(normalized_key, gen)?
567 }
568 } else {
569 self.parse_and_cache(normalized_key, gen)?
570 };
571 return self.dispatch(db, &stmt, &extracted);
572 }
573 }
574 self.execute_params_impl(db, sql, &[])
575 }
576
577 fn execute_params_impl(
578 &mut self,
579 db: &'a Database,
580 sql: &str,
581 params: &[Value],
582 ) -> Result<ExecutionResult> {
583 let gen = self.schema.generation();
584 if self.active_txn.is_none() {
585 if let Some(entry) = self.stmt_cache.get(sql) {
586 if entry.schema_gen == gen && entry.param_count == params.len() {
587 if let Some(plan) = entry.compiled.as_ref().map(Arc::clone) {
588 let stmt = Arc::clone(&entry.stmt);
589 return self.run_compiled(db, &plan, &stmt, params);
590 }
591 }
592 }
593 }
594
595 let (stmt, param_count) = self.get_or_parse(sql)?;
596
597 if param_count != params.len() {
598 return Err(SqlError::ParameterCountMismatch {
599 expected: param_count,
600 got: params.len(),
601 });
602 }
603
604 if self.active_txn.is_none() {
605 if let Some(plan) = executor::compile(&self.schema, &stmt) {
606 if let Some(e) = self.stmt_cache.get_mut(sql) {
607 e.compiled = Some(Arc::clone(&plan));
608 }
609 let stmt_owned = Arc::clone(&stmt);
610 return self.run_compiled(db, &plan, &stmt_owned, params);
611 }
612 }
613
614 self.dispatch(db, &stmt, params)
615 }
616
617 fn run_compiled(
618 &mut self,
619 db: &'a Database,
620 plan: &Arc<dyn executor::CompiledPlan>,
621 stmt: &Statement,
622 params: &[Value],
623 ) -> Result<ExecutionResult> {
624 use executor::compile::ActiveTxnRef;
625 let schema = &self.schema;
626 let exec = || {
627 if params.is_empty() {
628 plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
629 } else {
630 crate::eval::with_scoped_params(params, || {
631 plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
632 })
633 }
634 };
635 let outcome = if plan.needs_txn_clock() {
636 let cached_ts = self
637 .txn_start_ts
638 .or_else(|| Some(crate::datetime::now_micros()));
639 crate::datetime::with_txn_clock(cached_ts, exec)
640 } else {
641 exec()
642 }?;
643 invalidate_dml_caches(&self.schema, db);
644 Ok(outcome)
645 }
646
647 pub(crate) fn parse_and_cache(
648 &mut self,
649 normalized_key: String,
650 gen: u64,
651 ) -> Result<Arc<Statement>> {
652 let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
653 let param_count = parser::count_params(&stmt);
654 self.stmt_cache.put(
655 normalized_key,
656 CacheEntry {
657 stmt: Arc::clone(&stmt),
658 schema_gen: gen,
659 param_count,
660 compiled: None,
661 },
662 );
663 Ok(stmt)
664 }
665
666 pub(crate) fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
667 let gen = self.schema.generation();
668
669 if let Some(entry) = self.stmt_cache.get(sql) {
670 if entry.schema_gen == gen {
671 return Ok((Arc::clone(&entry.stmt), entry.param_count));
672 }
673 }
674
675 let stmt = Arc::new(parser::parse_sql(sql)?);
676 let param_count = parser::count_params(&stmt);
677
678 let cacheable = !matches!(
679 *stmt,
680 Statement::CreateTable(_)
681 | Statement::DropTable(_)
682 | Statement::CreateIndex(_)
683 | Statement::DropIndex(_)
684 | Statement::CreateView(_)
685 | Statement::DropView(_)
686 | Statement::AlterTable(_)
687 );
688
689 if cacheable {
690 self.stmt_cache.put(
691 sql.to_string(),
692 CacheEntry {
693 stmt: Arc::clone(&stmt),
694 schema_gen: gen,
695 param_count,
696 compiled: None,
697 },
698 );
699 }
700
701 Ok((stmt, param_count))
702 }
703
704 pub(crate) fn execute_prepared(
705 &mut self,
706 db: &'a Database,
707 stmt: &Statement,
708 compiled: Option<&Arc<dyn executor::CompiledPlan>>,
709 params: &[Value],
710 ) -> Result<ExecutionResult> {
711 if let Some(plan) = compiled {
712 if self.active_txn.is_none() {
713 return self.run_compiled(db, plan, stmt, params);
714 }
715 if !self.savepoint_stack.is_empty() && stmt_mutates(stmt) {
716 self.capture_pending_snapshots();
717 }
718 return self.run_compiled_in_txn(db, plan, stmt, params);
719 }
720 self.dispatch(db, stmt, params)
721 }
722
723 fn run_compiled_in_txn(
724 &mut self,
725 db: &'a Database,
726 plan: &Arc<dyn executor::CompiledPlan>,
727 stmt: &Statement,
728 params: &[Value],
729 ) -> Result<ExecutionResult> {
730 use executor::compile::ActiveTxnRef;
731 let schema = &self.schema;
732 let txn = match &mut self.active_txn {
733 ActiveTxn::Write(wtx) => ActiveTxnRef::Write(wtx),
734 ActiveTxn::Read(rtx) => ActiveTxnRef::Read(rtx),
735 ActiveTxn::None => ActiveTxnRef::None,
736 };
737 if params.is_empty() || !plan.uses_scoped_params() {
738 plan.execute(db, schema, stmt, params, txn)
739 } else {
740 crate::eval::with_scoped_params(params, || plan.execute(db, schema, stmt, params, txn))
741 }
742 }
743
744 pub(crate) fn dispatch(
745 &mut self,
746 db: &'a Database,
747 stmt: &Statement,
748 params: &[Value],
749 ) -> Result<ExecutionResult> {
750 let cached_ts = self
751 .txn_start_ts
752 .or_else(|| Some(crate::datetime::now_micros()));
753 crate::datetime::with_txn_clock(cached_ts, || {
754 if params.is_empty() {
755 self.dispatch_inner(db, stmt, params)
756 } else {
757 crate::eval::with_scoped_params(params, || self.dispatch_inner(db, stmt, params))
758 }
759 })
760 }
761
762 fn dispatch_inner(
763 &mut self,
764 db: &'a Database,
765 stmt: &Statement,
766 params: &[Value],
767 ) -> Result<ExecutionResult> {
768 match stmt {
769 Statement::Begin { access_mode } => {
770 if self.active_txn.is_active() {
771 return Err(SqlError::TransactionAlreadyActive);
772 }
773 let ts = crate::datetime::txn_or_clock_micros();
774 match access_mode {
775 BeginAccessMode::ReadOnly => {
776 let rtx = db.begin_read();
777 self.active_txn = ActiveTxn::Read(rtx);
778 }
779 BeginAccessMode::ReadWrite | BeginAccessMode::Default => {
780 let wtx = db.begin_write().map_err(SqlError::Storage)?;
781 self.active_txn = ActiveTxn::Write(wtx);
782 }
783 }
784 self.txn_start_ts = Some(ts);
785 crate::datetime::set_txn_clock(Some(ts));
786 Ok(ExecutionResult::Ok)
787 }
788 Statement::Commit => {
789 match self.active_txn.take() {
790 ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
791 ActiveTxn::Write(mut wtx) => {
792 crate::executor::helpers::drain_deferred_fk_checks(&mut wtx)?;
793 wtx.commit().map_err(SqlError::Storage)?;
794 invalidate_dml_caches(&self.schema, db);
795 }
796 ActiveTxn::Read(_rtx) => {}
797 }
798 self.clear_savepoint_state();
799 self.txn_start_ts = None;
800 crate::datetime::set_txn_clock(None);
801 Ok(ExecutionResult::Ok)
802 }
803 Statement::Rollback => {
804 match self.active_txn.take() {
805 ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
806 ActiveTxn::Write(wtx) => {
807 wtx.abort();
808 self.schema = SchemaManager::load(db)?;
809 }
810 ActiveTxn::Read(_rtx) => {}
811 }
812 self.clear_savepoint_state();
813 self.txn_start_ts = None;
814 crate::datetime::set_txn_clock(None);
815 Ok(ExecutionResult::Ok)
816 }
817 Statement::Savepoint(name) => self.do_savepoint(name),
818 Statement::ReleaseSavepoint(name) => self.do_release(name),
819 Statement::RollbackTo(name) => self.do_rollback_to(name),
820 Statement::SetTimezone(zone) => {
821 self.set_session_timezone_impl(zone)?;
822 Ok(ExecutionResult::Ok)
823 }
824 Statement::CreateTable(ct) if ct.temporary => {
825 if self.active_txn.is_read_only() {
826 return Err(SqlError::Unsupported(
827 "cannot execute mutating statement inside a read-only transaction".into(),
828 ));
829 }
830 let user_name = ct.name.clone();
831 let prefixed = temp_storage_name(self.temp_id, &user_name);
832 if self.schema.contains(&user_name) {
833 if ct.if_not_exists {
834 return Ok(ExecutionResult::Ok);
835 }
836 return Err(SqlError::TableAlreadyExists(user_name));
837 }
838 let mut clone = ct.clone();
839 clone.name = prefixed.clone();
840 clone.temporary = false;
841 let stmt_concrete = Statement::CreateTable(clone);
842 let outcome = if let Some(wtx) = self.active_txn.as_write_mut() {
843 executor::execute_in_txn(wtx, &mut self.schema, &stmt_concrete, params)?
844 } else {
845 executor::execute(db, &mut self.schema, &stmt_concrete, params)?
846 };
847 self.schema
848 .register_temp_alias(&user_name, prefixed.clone());
849 self.temp_table_names.push(prefixed);
850 Ok(outcome)
851 }
852 Statement::Insert(ins) if self.active_txn.as_write_mut().is_some() => {
853 self.capture_pending_snapshots();
854 let wtx = self.active_txn.as_write_mut().unwrap();
855 executor::exec_insert_in_txn(wtx, &self.schema, ins, params)
856 }
857 _ => {
858 if self.active_txn.is_read_only() && stmt_mutates(stmt) {
859 return Err(SqlError::Unsupported(
860 "cannot execute mutating statement inside a read-only transaction".into(),
861 ));
862 }
863 if self.active_txn.as_write_mut().is_some() && stmt_mutates(stmt) {
864 self.capture_pending_snapshots();
865 }
866 let was_auto_commit = matches!(self.active_txn, ActiveTxn::None);
867 let outcome = match &mut self.active_txn {
868 ActiveTxn::Write(wtx) => {
869 executor::execute_in_txn(wtx, &mut self.schema, stmt, params)?
870 }
871 ActiveTxn::Read(rtx) => {
872 executor::execute_with_read(rtx, &self.schema, stmt, params)?
873 }
874 ActiveTxn::None => executor::execute(db, &mut self.schema, stmt, params)?,
875 };
876 if was_auto_commit {
877 invalidate_dml_caches(&self.schema, db);
878 }
879 if let Statement::DropTable(dt) = stmt {
880 self.schema.unregister_temp_alias(&dt.name);
881 }
882 Ok(outcome)
883 }
884 }
885 }
886
887 fn clear_savepoint_state(&mut self) {
888 self.savepoint_stack.clear();
889 self.in_place_saved = None;
890 }
891
892 fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
893 let wtx = self
894 .active_txn
895 .as_write_mut()
896 .ok_or(SqlError::NoActiveTransaction)?;
897
898 if self.savepoint_stack.is_empty() {
899 self.in_place_saved = Some(wtx.in_place());
900 wtx.set_in_place(false);
901 }
902
903 self.savepoint_stack.push(SavepointEntry {
904 name: name.to_string(),
905 snapshot: None,
906 });
907
908 Ok(ExecutionResult::Ok)
909 }
910
911 fn capture_pending_snapshots(&mut self) {
912 let last_pending = match self
913 .savepoint_stack
914 .iter()
915 .rposition(|e| e.snapshot.is_none())
916 {
917 Some(i) => i,
918 None => return,
919 };
920 let wtx = match self.active_txn.as_write_mut() {
921 Some(w) => w,
922 None => return,
923 };
924 let wtx_snap = wtx.begin_savepoint();
925 let schema_snap = self.schema.save_snapshot();
926
927 for i in 0..last_pending {
928 if self.savepoint_stack[i].snapshot.is_none() {
929 self.savepoint_stack[i].snapshot = Some(SavepointSnapshot {
930 wtx_snap: wtx_snap.clone(),
931 schema_snap: schema_snap.clone(),
932 });
933 }
934 }
935 self.savepoint_stack[last_pending].snapshot = Some(SavepointSnapshot {
936 wtx_snap,
937 schema_snap,
938 });
939 }
940
941 fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
942 if !self.active_txn.is_active() {
943 return Err(SqlError::NoActiveTransaction);
944 }
945
946 let idx = self
947 .savepoint_stack
948 .iter()
949 .rposition(|e| e.name == name)
950 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
951 self.savepoint_stack.truncate(idx);
952
953 if self.savepoint_stack.is_empty() {
954 if let (Some(wtx), Some(original)) =
955 (self.active_txn.as_write_mut(), self.in_place_saved.take())
956 {
957 wtx.set_in_place(original);
958 }
959 }
960
961 Ok(ExecutionResult::Ok)
962 }
963
964 fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
965 if !self.active_txn.is_active() {
966 return Err(SqlError::NoActiveTransaction);
967 }
968
969 let idx = self
970 .savepoint_stack
971 .iter()
972 .rposition(|e| e.name == name)
973 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
974
975 self.savepoint_stack.truncate(idx + 1);
976 let entry = self.savepoint_stack.last_mut().unwrap();
977 let snapshot = match entry.snapshot.take() {
978 Some(s) => s,
979 None => return Ok(ExecutionResult::Ok),
980 };
981
982 let wtx = match self.active_txn.as_write_mut() {
983 Some(w) => w,
984 None => return Err(SqlError::NoActiveTransaction),
985 };
986 wtx.restore_snapshot(snapshot.wtx_snap);
987 self.schema.restore_snapshot(snapshot.schema_snap);
988
989 Ok(ExecutionResult::Ok)
990 }
991}
992
993impl<'a> Drop for Connection<'a> {
994 fn drop(&mut self) {
995 let temp_names = std::mem::take(&mut self.inner.borrow_mut().temp_table_names);
996 if temp_names.is_empty() {
997 return;
998 }
999 if let Ok(mut wtx) = self.db.begin_write() {
1000 for prefixed in &temp_names {
1001 let _ = wtx.drop_table(prefixed.as_bytes());
1002 }
1003 let _ = wtx.commit();
1004 }
1005 }
1006}