1use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7
/// Builds an identifier used to namespace this connection's temporary tables.
///
/// Layout: the upper 32 bits hold the sub-second nanosecond component of the
/// wall clock (0 if the clock reads earlier than the Unix epoch); the lower
/// 32 bits hold the low half of a process-wide counter that advances on every
/// call.
fn generate_temp_id() -> u64 {
    static COUNTER: AtomicU64 = AtomicU64::new(0);
    let sequence = COUNTER.fetch_add(1, Ordering::Relaxed) & 0xFFFF_FFFF;
    let clock_bits = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => u64::from(elapsed.subsec_nanos()),
        Err(_) => 0,
    };
    (clock_bits << 32) | sequence
}
17
/// Returns the storage-level table name backing a session temporary table:
/// `__temp_<temp_id>_<user_name with ASCII letters lowercased>`.
fn temp_storage_name(temp_id: u64, user_name: &str) -> String {
    let lowered = user_name.to_ascii_lowercase();
    ["__temp_", &temp_id.to_string(), "_", &lowered].concat()
}
21
22use lru::LruCache;
23
24use citadel::Database;
25use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
26
27use crate::error::{Result, SqlError};
28use crate::executor;
29use crate::parser;
30use crate::parser::{BeginAccessMode, QueryBody, SelectQuery, Statement};
31use crate::prepared::PreparedStatement;
32use crate::schema::{SchemaManager, SchemaSnapshot};
33use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
34
/// Number of entries in each connection's LRU statement cache
/// (`ConnectionInner::stmt_cache`, created in `Connection::open`).
/// Must be non-zero: it is passed through `NonZeroUsize::new(..).unwrap()`.
const DEFAULT_CACHE_CAPACITY: usize = 64;
36
/// Outcome of [`Connection::execute_script`].
///
/// Statements run in script order and execution stops at the first failure;
/// `execute_script` itself does not undo the statements that already completed.
#[derive(Debug)]
pub struct ScriptExecution {
    /// Results of the statements that completed successfully, in script order.
    pub completed: Vec<ExecutionResult>,
    /// The error that stopped the script (a parse error, or the first failing
    /// statement), or `None` if every statement succeeded.
    pub error: Option<SqlError>,
}
42
/// Returns `Some(())` if `s` (after trimming) is a fixed UTC offset that is
/// accepted as a session time zone, `None` otherwise.
///
/// Accepted forms: `Z` / `UTC` (case-insensitive), `±HH:MM`, `±HHMM` and
/// `±HH`. The hour and minute fields must consist solely of ASCII digits,
/// with hours in `0..=23` and minutes in `0..=59`. Never panics, including on
/// non-ASCII input.
fn parse_fixed_offset(s: &str) -> Option<()> {
    let s = s.trim();
    if s.eq_ignore_ascii_case("z") || s.eq_ignore_ascii_case("utc") {
        return Some(());
    }
    let rest = s.strip_prefix(|c: char| c == '+' || c == '-')?;
    let (hh, mm) = if let Some(parts) = rest.split_once(':') {
        parts
    } else if rest.len() == 4 {
        // `str::get` yields `None` instead of panicking when byte 2 is not a
        // char boundary (e.g. "+1é1").
        (rest.get(..2)?, rest.get(2..)?)
    } else if rest.len() == 2 {
        (rest, "00")
    } else {
        return None;
    };
    // Plain ASCII digits only: `u32::from_str` alone would also accept a
    // leading '+', letting inputs such as "++5:00" through.
    let field = |part: &str| -> Option<u32> {
        if part.is_empty() || !part.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        part.parse().ok()
    };
    let h = field(hh)?;
    let m = field(mm)?;
    if h > 23 || m > 59 {
        return None;
    }
    Some(())
}
72
/// Rewrites `SHOW TRIGGERS [ON <table>]` (ASCII case-insensitive, optional
/// trailing `;`) into a query over `information_schema.triggers`.
///
/// Returns `None` when `sql` is not a `SHOW TRIGGERS` statement or the `ON`
/// clause is malformed. The table name is lowercased and compared with
/// `LOWER(...)`; single quotes in it are doubled before it is embedded in the
/// string literal.
fn rewrite_show_triggers(sql: &str) -> Option<String> {
    const SELECT_TRIGGERS: &str = "SELECT trigger_name, event_object_table AS table_name, \
         action_timing, event_manipulation, action_orientation, action_statement \
         FROM information_schema.triggers";

    let lowered = sql.trim().trim_end_matches(';').trim().to_ascii_lowercase();
    let tail = lowered.strip_prefix("show triggers")?.trim_start();
    if tail.is_empty() {
        return Some(format!("{SELECT_TRIGGERS} ORDER BY trigger_name"));
    }

    let table = tail.strip_prefix("on ")?.trim().trim_end_matches(';').trim();
    if table.is_empty() {
        return None;
    }
    let quoted = table.replace('\'', "''");
    Some(format!(
        "{SELECT_TRIGGERS} WHERE LOWER(event_object_table) = LOWER('{quoted}') ORDER BY trigger_name"
    ))
}
99
/// Rewrites exactly `SHOW MATERIALIZED VIEWS` (ASCII case-insensitive, single
/// spaces, optional trailing `;`) into a query over `pg_matviews`.
/// Returns `None` for anything else.
fn rewrite_show_matviews(sql: &str) -> Option<String> {
    let statement = sql.trim().trim_end_matches(';').trim();
    if !statement.eq_ignore_ascii_case("show materialized views") {
        return None;
    }
    Some(String::from(
        "SELECT matviewname, ispopulated, hasindexes, definition FROM pg_matviews ORDER BY matviewname",
    ))
}
113
114fn stmt_mutates(stmt: &Statement) -> bool {
115 if matches!(
116 stmt,
117 Statement::Insert(_)
118 | Statement::Update(_)
119 | Statement::Delete(_)
120 | Statement::Truncate(_)
121 | Statement::CreateTable(_)
122 | Statement::DropTable(_)
123 | Statement::AlterTable(_)
124 | Statement::CreateIndex(_)
125 | Statement::DropIndex(_)
126 | Statement::CreateView(_)
127 | Statement::DropView(_)
128 | Statement::CreateTrigger(_)
129 | Statement::DropTrigger(_)
130 | Statement::CreateMaterializedView(_)
131 | Statement::RefreshMaterializedView(_)
132 | Statement::DropMaterializedView(_)
133 ) {
134 return true;
135 }
136 if let Statement::Select(sq) = stmt {
137 if select_query_has_dml(sq) {
138 return true;
139 }
140 }
141 false
142}
143
144fn select_query_has_dml(sq: &SelectQuery) -> bool {
145 sq.ctes.iter().any(|cte| query_body_has_dml(&cte.body)) || query_body_has_dml(&sq.body)
146}
147
148fn query_body_has_dml(body: &QueryBody) -> bool {
149 match body {
150 QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => true,
151 QueryBody::Compound(c) => query_body_has_dml(&c.left) || query_body_has_dml(&c.right),
152 QueryBody::Select(_) => false,
153 }
154}
155
156fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
157 let bytes = sql.as_bytes();
158 let len = bytes.len();
159 let mut i = 0;
160
161 while i < len && bytes[i].is_ascii_whitespace() {
162 i += 1;
163 }
164 if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
165 return None;
166 }
167 i += 6;
168 if i >= len || !bytes[i].is_ascii_whitespace() {
169 return None;
170 }
171 while i < len && bytes[i].is_ascii_whitespace() {
172 i += 1;
173 }
174
175 if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
176 return None;
177 }
178 i += 4;
179 if i >= len || !bytes[i].is_ascii_whitespace() {
180 return None;
181 }
182
183 let prefix_start = 0;
184 let mut values_pos = None;
185 let mut j = i;
186 while j + 6 <= len {
187 if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
188 && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
189 && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
190 {
191 values_pos = Some(j);
192 break;
193 }
194 j += 1;
195 }
196 let values_pos = values_pos?;
197
198 let prefix = &sql[prefix_start..values_pos + 6];
199 let mut pos = values_pos + 6;
200
201 while pos < len && bytes[pos].is_ascii_whitespace() {
202 pos += 1;
203 }
204 if pos >= len || bytes[pos] != b'(' {
205 return None;
206 }
207 pos += 1;
208
209 let mut values = Vec::new();
210 let mut normalized = String::with_capacity(sql.len());
211 normalized.push_str(prefix);
212 normalized.push_str(" (");
213
214 loop {
215 while pos < len && bytes[pos].is_ascii_whitespace() {
216 pos += 1;
217 }
218 if pos >= len {
219 return None;
220 }
221
222 let param_idx = values.len() + 1;
223 if param_idx > 1 {
224 normalized.push_str(", ");
225 }
226
227 if bytes[pos] == b'\'' {
228 pos += 1;
229 let mut seg_start = pos;
230 let mut s = String::new();
231 loop {
232 if pos >= len {
233 return None;
234 }
235 if bytes[pos] == b'\'' {
236 s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
237 pos += 1;
238 if pos < len && bytes[pos] == b'\'' {
239 s.push('\'');
240 pos += 1;
241 seg_start = pos;
242 } else {
243 break;
244 }
245 } else {
246 pos += 1;
247 }
248 }
249 values.push(Value::Text(s.into()));
250 } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
251 let start = pos;
252 if bytes[pos] == b'-' {
253 pos += 1;
254 }
255 while pos < len && bytes[pos].is_ascii_digit() {
256 pos += 1;
257 }
258 if pos < len && bytes[pos] == b'.' {
259 pos += 1;
260 while pos < len && bytes[pos].is_ascii_digit() {
261 pos += 1;
262 }
263 let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
264 values.push(Value::Real(num));
265 } else {
266 let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
267 values.push(Value::Integer(num));
268 }
269 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
270 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
271 if !after.is_ascii_alphanumeric() && after != b'_' {
272 pos += 4;
273 values.push(Value::Null);
274 } else {
275 return None;
276 }
277 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
278 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
279 if !after.is_ascii_alphanumeric() && after != b'_' {
280 pos += 4;
281 values.push(Value::Boolean(true));
282 } else {
283 return None;
284 }
285 } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
286 let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
287 if !after.is_ascii_alphanumeric() && after != b'_' {
288 pos += 5;
289 values.push(Value::Boolean(false));
290 } else {
291 return None;
292 }
293 } else {
294 return None;
295 }
296
297 normalized.push('$');
298 normalized.push_str(¶m_idx.to_string());
299
300 while pos < len && bytes[pos].is_ascii_whitespace() {
301 pos += 1;
302 }
303 if pos >= len {
304 return None;
305 }
306
307 if bytes[pos] == b',' {
308 pos += 1;
309 } else if bytes[pos] == b')' {
310 pos += 1;
311 break;
312 } else {
313 return None;
314 }
315 }
316
317 normalized.push(')');
318
319 while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
320 pos += 1;
321 }
322 if pos != len {
323 return None;
324 }
325
326 if values.is_empty() {
327 return None;
328 }
329
330 Some((normalized, values))
331}
332
/// One entry of the per-connection statement cache, keyed by SQL text (or by
/// the normalized key produced for single-row INSERTs).
pub(crate) struct CacheEntry {
    /// The parsed statement, shared with callers through `Arc`.
    pub(crate) stmt: Arc<Statement>,
    /// `SchemaManager::generation()` at parse time; an entry whose generation
    /// differs from the current one is re-parsed instead of reused.
    pub(crate) schema_gen: u64,
    /// Number of parameters the statement expects (`parser::count_params`).
    pub(crate) param_count: usize,
    /// Compiled plan, filled in lazily once `executor::compile` returns one.
    pub(crate) compiled: Option<Arc<dyn executor::CompiledPlan>>,
}
339
/// One open savepoint on the connection's savepoint stack.
struct SavepointEntry {
    /// Name given to `SAVEPOINT`; looked up with `==` by `RELEASE` / `ROLLBACK TO`.
    name: String,
    /// State to restore on `ROLLBACK TO`. `None` while no mutating statement
    /// has run since the savepoint was created (or last rolled back to): the
    /// snapshot is captured lazily, just before the next mutation.
    snapshot: Option<SavepointSnapshot>,
}
344
/// Storage and schema state captured together for a savepoint.
struct SavepointSnapshot {
    /// Write-transaction state from `WriteTxn::begin_savepoint`.
    wtx_snap: WriteTxnSnapshot,
    /// In-memory schema state from `SchemaManager::save_snapshot`.
    schema_snap: SchemaSnapshot,
}
349
// NOTE(review): presumably allowed because `WriteTxn` is much larger than the
// other variants and boxing it would cost an allocation per transaction — confirm.
#[allow(clippy::large_enum_variant)]
/// The explicit transaction (if any) opened by `BEGIN` on a connection.
pub(crate) enum ActiveTxn<'a> {
    /// No explicit transaction; statements run directly against the database.
    None,
    /// Opened by `BEGIN` or `BEGIN READ WRITE`.
    Write(WriteTxn<'a>),
    /// Opened by `BEGIN READ ONLY`.
    Read(citadel_txn::read_txn::ReadTxn<'a>),
}
358
359impl<'a> ActiveTxn<'a> {
360 fn is_none(&self) -> bool {
361 matches!(self, ActiveTxn::None)
362 }
363 fn is_active(&self) -> bool {
364 !self.is_none()
365 }
366 fn is_read_only(&self) -> bool {
367 matches!(self, ActiveTxn::Read(_))
368 }
369 fn as_write_mut(&mut self) -> Option<&mut WriteTxn<'a>> {
370 match self {
371 ActiveTxn::Write(w) => Some(w),
372 _ => None,
373 }
374 }
375 fn take(&mut self) -> ActiveTxn<'a> {
376 std::mem::replace(self, ActiveTxn::None)
377 }
378}
379
/// Mutable per-connection state, kept behind a `RefCell` in [`Connection`].
pub(crate) struct ConnectionInner<'a> {
    /// In-memory schema for this connection (including temp-table aliases).
    pub(crate) schema: SchemaManager,
    /// Transaction opened by `BEGIN`, if any.
    active_txn: ActiveTxn<'a>,
    /// Open savepoints, innermost last.
    savepoint_stack: Vec<SavepointEntry>,
    /// The write transaction's `in_place()` setting, saved when the first
    /// savepoint is created (savepoints switch it off). It is restored when
    /// `RELEASE` empties the stack, and discarded on COMMIT/ROLLBACK.
    in_place_saved: Option<bool>,
    /// LRU cache of parsed (and possibly compiled) statements.
    pub(crate) stmt_cache: LruCache<String, CacheEntry>,
    /// Timestamp from `crate::datetime::txn_or_clock_micros` taken at `BEGIN`;
    /// `None` outside an explicit transaction.
    txn_start_ts: Option<i64>,
    /// Session time zone exactly as set by the user; starts as "UTC".
    session_timezone: String,
    /// Identifier embedded in this connection's temporary-table storage names.
    temp_id: u64,
    /// Storage names of temporary tables created on this connection. They are
    /// dropped when the connection is dropped; `DROP TABLE` does not remove
    /// entries from this list.
    temp_table_names: Vec<String>,
}
392
/// A session on a [`Database`].
///
/// It owns the session's statement cache, explicit transaction, savepoints,
/// session time zone and temporary tables. All mutable state lives in a
/// `RefCell`, so methods take `&self`.
pub struct Connection<'a> {
    /// Database this connection operates on.
    pub(crate) db: &'a Database,
    /// Per-connection mutable state.
    pub(crate) inner: RefCell<ConnectionInner<'a>>,
}
397
398impl<'a> Connection<'a> {
399 pub fn open(db: &'a Database) -> Result<Self> {
400 let schema = SchemaManager::load(db)?;
401 let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
402 let temp_id = generate_temp_id();
403 Ok(Self {
404 db,
405 inner: RefCell::new(ConnectionInner {
406 schema,
407 active_txn: ActiveTxn::None,
408 savepoint_stack: Vec::new(),
409 in_place_saved: None,
410 stmt_cache,
411 txn_start_ts: None,
412 session_timezone: "UTC".to_string(),
413 temp_id,
414 temp_table_names: Vec::new(),
415 }),
416 })
417 }
418
419 pub fn txn_start_ts(&self) -> Option<i64> {
421 self.inner.borrow().txn_start_ts
422 }
423
424 pub fn session_timezone(&self) -> String {
426 self.inner.borrow().session_timezone.clone()
427 }
428
429 pub fn set_session_timezone(&self, tz: &str) -> Result<()> {
431 self.inner.borrow_mut().set_session_timezone_impl(tz)
432 }
433
434 pub fn execute(&self, sql: &str) -> Result<ExecutionResult> {
435 self.inner.borrow_mut().execute_impl(self.db, sql)
436 }
437
438 pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
439 self.inner
440 .borrow_mut()
441 .execute_params_impl(self.db, sql, params)
442 }
443
444 pub fn execute_script(&self, sql: &str) -> ScriptExecution {
446 let stmts = match parser::parse_sql_multi(sql) {
447 Ok(s) => s,
448 Err(e) => {
449 return ScriptExecution {
450 completed: vec![],
451 error: Some(e),
452 }
453 }
454 };
455 let mut completed = Vec::with_capacity(stmts.len());
456 for stmt in stmts {
457 match self.inner.borrow_mut().dispatch(self.db, &stmt, &[]) {
458 Ok(r) => completed.push(r),
459 Err(e) => {
460 return ScriptExecution {
461 completed,
462 error: Some(e),
463 }
464 }
465 }
466 }
467 ScriptExecution {
468 completed,
469 error: None,
470 }
471 }
472
473 pub fn query(&self, sql: &str) -> Result<QueryResult> {
474 self.query_params(sql, &[])
475 }
476
477 pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
478 match self.execute_params(sql, params)? {
479 ExecutionResult::Query(qr) => Ok(qr),
480 ExecutionResult::RowsAffected(n) => Ok(QueryResult {
481 columns: vec!["rows_affected".into()],
482 rows: vec![vec![Value::Integer(n as i64)]],
483 }),
484 ExecutionResult::Ok => Ok(QueryResult {
485 columns: vec![],
486 rows: vec![],
487 }),
488 }
489 }
490
491 pub fn prepare(&self, sql: &str) -> Result<PreparedStatement<'_, 'a>> {
492 if let Some(rewritten) = rewrite_show_triggers(sql) {
493 return PreparedStatement::new(self, &rewritten);
494 }
495 if let Some(rewritten) = rewrite_show_matviews(sql) {
496 return PreparedStatement::new(self, &rewritten);
497 }
498 PreparedStatement::new(self, sql)
499 }
500
501 pub fn tables(&self) -> Vec<String> {
502 self.inner
503 .borrow()
504 .schema
505 .table_names()
506 .into_iter()
507 .map(String::from)
508 .collect()
509 }
510
511 pub fn in_transaction(&self) -> bool {
513 self.inner.borrow().active_txn.is_active()
514 }
515
516 pub fn table_schema(&self, name: &str) -> Option<TableSchema> {
517 self.inner.borrow().schema.get(name).cloned()
518 }
519
520 pub fn refresh_schema(&self) -> Result<()> {
521 let new_schema = SchemaManager::load(self.db)?;
522 self.inner.borrow_mut().schema = new_schema;
523 Ok(())
524 }
525}
526
527impl<'a> ConnectionInner<'a> {
528 pub(crate) fn active_txn_is_some(&self) -> bool {
529 self.active_txn.is_active()
530 }
531
532 fn set_session_timezone_impl(&mut self, tz: &str) -> Result<()> {
533 let upper = tz.to_ascii_uppercase();
534 if (upper.starts_with("UTC+") || upper.starts_with("UTC-")) && tz.len() > 3 {
535 return Err(SqlError::InvalidTimezone(format!(
536 "'{tz}' is ambiguous; use ISO-8601 offset (e.g. '+05:00') or named zone (e.g. 'Etc/GMT-5')"
537 )));
538 }
539 if jiff::tz::TimeZone::get(tz).is_err() && parse_fixed_offset(tz).is_none() {
540 return Err(SqlError::InvalidTimezone(format!(
541 "{tz}: not a known IANA zone or ISO-8601 offset (e.g. '+05:00', 'UTC', 'America/New_York')"
542 )));
543 }
544 self.session_timezone = tz.to_string();
545 Ok(())
546 }
547
548 fn execute_impl(&mut self, db: &'a Database, sql: &str) -> Result<ExecutionResult> {
549 if let Some(rewritten) = rewrite_show_triggers(sql) {
550 return self.execute_params_impl(db, &rewritten, &[]);
551 }
552 if let Some(rewritten) = rewrite_show_matviews(sql) {
553 return self.execute_params_impl(db, &rewritten, &[]);
554 }
555 if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
556 if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
557 let gen = self.schema.generation();
558 let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
559 if entry.schema_gen == gen {
560 Arc::clone(&entry.stmt)
561 } else {
562 self.parse_and_cache(normalized_key, gen)?
563 }
564 } else {
565 self.parse_and_cache(normalized_key, gen)?
566 };
567 return self.dispatch(db, &stmt, &extracted);
568 }
569 }
570 self.execute_params_impl(db, sql, &[])
571 }
572
573 fn execute_params_impl(
574 &mut self,
575 db: &'a Database,
576 sql: &str,
577 params: &[Value],
578 ) -> Result<ExecutionResult> {
579 let gen = self.schema.generation();
580 if self.active_txn.is_none() {
581 if let Some(entry) = self.stmt_cache.get(sql) {
582 if entry.schema_gen == gen && entry.param_count == params.len() {
583 if let Some(plan) = entry.compiled.as_ref().map(Arc::clone) {
584 let stmt = Arc::clone(&entry.stmt);
585 return self.run_compiled(db, &plan, &stmt, params);
586 }
587 }
588 }
589 }
590
591 let (stmt, param_count) = self.get_or_parse(sql)?;
592
593 if param_count != params.len() {
594 return Err(SqlError::ParameterCountMismatch {
595 expected: param_count,
596 got: params.len(),
597 });
598 }
599
600 if self.active_txn.is_none() {
601 if let Some(plan) = executor::compile(&self.schema, &stmt) {
602 if let Some(e) = self.stmt_cache.get_mut(sql) {
603 e.compiled = Some(Arc::clone(&plan));
604 }
605 let stmt_owned = Arc::clone(&stmt);
606 return self.run_compiled(db, &plan, &stmt_owned, params);
607 }
608 }
609
610 self.dispatch(db, &stmt, params)
611 }
612
613 fn run_compiled(
614 &mut self,
615 db: &'a Database,
616 plan: &Arc<dyn executor::CompiledPlan>,
617 stmt: &Statement,
618 params: &[Value],
619 ) -> Result<ExecutionResult> {
620 let schema = &self.schema;
621 let exec = || {
622 if params.is_empty() {
623 plan.execute(db, schema, stmt, params, None)
624 } else {
625 crate::eval::with_scoped_params(params, || {
626 plan.execute(db, schema, stmt, params, None)
627 })
628 }
629 };
630 if plan.needs_txn_clock() {
631 let cached_ts = self
632 .txn_start_ts
633 .or_else(|| Some(crate::datetime::now_micros()));
634 crate::datetime::with_txn_clock(cached_ts, exec)
635 } else {
636 exec()
637 }
638 }
639
640 pub(crate) fn parse_and_cache(
641 &mut self,
642 normalized_key: String,
643 gen: u64,
644 ) -> Result<Arc<Statement>> {
645 let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
646 let param_count = parser::count_params(&stmt);
647 self.stmt_cache.put(
648 normalized_key,
649 CacheEntry {
650 stmt: Arc::clone(&stmt),
651 schema_gen: gen,
652 param_count,
653 compiled: None,
654 },
655 );
656 Ok(stmt)
657 }
658
659 pub(crate) fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
660 let gen = self.schema.generation();
661
662 if let Some(entry) = self.stmt_cache.get(sql) {
663 if entry.schema_gen == gen {
664 return Ok((Arc::clone(&entry.stmt), entry.param_count));
665 }
666 }
667
668 let stmt = Arc::new(parser::parse_sql(sql)?);
669 let param_count = parser::count_params(&stmt);
670
671 let cacheable = !matches!(
672 *stmt,
673 Statement::CreateTable(_)
674 | Statement::DropTable(_)
675 | Statement::CreateIndex(_)
676 | Statement::DropIndex(_)
677 | Statement::CreateView(_)
678 | Statement::DropView(_)
679 | Statement::AlterTable(_)
680 );
681
682 if cacheable {
683 self.stmt_cache.put(
684 sql.to_string(),
685 CacheEntry {
686 stmt: Arc::clone(&stmt),
687 schema_gen: gen,
688 param_count,
689 compiled: None,
690 },
691 );
692 }
693
694 Ok((stmt, param_count))
695 }
696
697 pub(crate) fn execute_prepared(
698 &mut self,
699 db: &'a Database,
700 stmt: &Statement,
701 compiled: Option<&Arc<dyn executor::CompiledPlan>>,
702 params: &[Value],
703 ) -> Result<ExecutionResult> {
704 if let Some(plan) = compiled {
705 if self.active_txn.is_none() {
706 return self.run_compiled(db, plan, stmt, params);
707 }
708 if !self.savepoint_stack.is_empty() && stmt_mutates(stmt) {
709 self.capture_pending_snapshots();
710 }
711 return self.run_compiled_in_txn(db, plan, stmt, params);
712 }
713 self.dispatch(db, stmt, params)
714 }
715
716 fn run_compiled_in_txn(
717 &mut self,
718 db: &'a Database,
719 plan: &Arc<dyn executor::CompiledPlan>,
720 stmt: &Statement,
721 params: &[Value],
722 ) -> Result<ExecutionResult> {
723 let schema = &self.schema;
724 let wtx = self.active_txn.as_write_mut();
725 if params.is_empty() || !plan.uses_scoped_params() {
726 plan.execute(db, schema, stmt, params, wtx)
727 } else {
728 crate::eval::with_scoped_params(params, || plan.execute(db, schema, stmt, params, wtx))
729 }
730 }
731
732 pub(crate) fn dispatch(
733 &mut self,
734 db: &'a Database,
735 stmt: &Statement,
736 params: &[Value],
737 ) -> Result<ExecutionResult> {
738 let cached_ts = self
739 .txn_start_ts
740 .or_else(|| Some(crate::datetime::now_micros()));
741 crate::datetime::with_txn_clock(cached_ts, || {
742 if params.is_empty() {
743 self.dispatch_inner(db, stmt, params)
744 } else {
745 crate::eval::with_scoped_params(params, || self.dispatch_inner(db, stmt, params))
746 }
747 })
748 }
749
750 fn dispatch_inner(
751 &mut self,
752 db: &'a Database,
753 stmt: &Statement,
754 params: &[Value],
755 ) -> Result<ExecutionResult> {
756 match stmt {
757 Statement::Begin { access_mode } => {
758 if self.active_txn.is_active() {
759 return Err(SqlError::TransactionAlreadyActive);
760 }
761 let ts = crate::datetime::txn_or_clock_micros();
762 match access_mode {
763 BeginAccessMode::ReadOnly => {
764 let rtx = db.begin_read();
765 self.active_txn = ActiveTxn::Read(rtx);
766 }
767 BeginAccessMode::ReadWrite | BeginAccessMode::Default => {
768 let wtx = db.begin_write().map_err(SqlError::Storage)?;
769 self.active_txn = ActiveTxn::Write(wtx);
770 }
771 }
772 self.txn_start_ts = Some(ts);
773 crate::datetime::set_txn_clock(Some(ts));
774 Ok(ExecutionResult::Ok)
775 }
776 Statement::Commit => {
777 match self.active_txn.take() {
778 ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
779 ActiveTxn::Write(mut wtx) => {
780 crate::executor::helpers::drain_deferred_fk_checks(&mut wtx)?;
781 wtx.commit().map_err(SqlError::Storage)?;
782 }
783 ActiveTxn::Read(_rtx) => {}
784 }
785 self.clear_savepoint_state();
786 self.txn_start_ts = None;
787 crate::datetime::set_txn_clock(None);
788 Ok(ExecutionResult::Ok)
789 }
790 Statement::Rollback => {
791 match self.active_txn.take() {
792 ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
793 ActiveTxn::Write(wtx) => {
794 wtx.abort();
795 self.schema = SchemaManager::load(db)?;
796 }
797 ActiveTxn::Read(_rtx) => {}
798 }
799 self.clear_savepoint_state();
800 self.txn_start_ts = None;
801 crate::datetime::set_txn_clock(None);
802 Ok(ExecutionResult::Ok)
803 }
804 Statement::Savepoint(name) => self.do_savepoint(name),
805 Statement::ReleaseSavepoint(name) => self.do_release(name),
806 Statement::RollbackTo(name) => self.do_rollback_to(name),
807 Statement::SetTimezone(zone) => {
808 self.set_session_timezone_impl(zone)?;
809 Ok(ExecutionResult::Ok)
810 }
811 Statement::CreateTable(ct) if ct.temporary => {
812 if self.active_txn.is_read_only() {
813 return Err(SqlError::Unsupported(
814 "cannot execute mutating statement inside a read-only transaction".into(),
815 ));
816 }
817 let user_name = ct.name.clone();
818 let prefixed = temp_storage_name(self.temp_id, &user_name);
819 if self.schema.contains(&user_name) {
820 if ct.if_not_exists {
821 return Ok(ExecutionResult::Ok);
822 }
823 return Err(SqlError::TableAlreadyExists(user_name));
824 }
825 let mut clone = ct.clone();
826 clone.name = prefixed.clone();
827 clone.temporary = false;
828 let stmt_concrete = Statement::CreateTable(clone);
829 let outcome = if let Some(wtx) = self.active_txn.as_write_mut() {
830 executor::execute_in_txn(wtx, &mut self.schema, &stmt_concrete, params)?
831 } else {
832 executor::execute(db, &mut self.schema, &stmt_concrete, params)?
833 };
834 self.schema
835 .register_temp_alias(&user_name, prefixed.clone());
836 self.temp_table_names.push(prefixed);
837 Ok(outcome)
838 }
839 Statement::Insert(ins) if self.active_txn.as_write_mut().is_some() => {
840 self.capture_pending_snapshots();
841 let wtx = self.active_txn.as_write_mut().unwrap();
842 executor::exec_insert_in_txn(wtx, &self.schema, ins, params)
843 }
844 _ => {
845 if self.active_txn.is_read_only() && stmt_mutates(stmt) {
846 return Err(SqlError::Unsupported(
847 "cannot execute mutating statement inside a read-only transaction".into(),
848 ));
849 }
850 if self.active_txn.as_write_mut().is_some() && stmt_mutates(stmt) {
851 self.capture_pending_snapshots();
852 }
853 let outcome = if let Some(wtx) = self.active_txn.as_write_mut() {
854 executor::execute_in_txn(wtx, &mut self.schema, stmt, params)?
855 } else {
856 executor::execute(db, &mut self.schema, stmt, params)?
857 };
858 if let Statement::DropTable(dt) = stmt {
859 self.schema.unregister_temp_alias(&dt.name);
860 }
861 Ok(outcome)
862 }
863 }
864 }
865
866 fn clear_savepoint_state(&mut self) {
867 self.savepoint_stack.clear();
868 self.in_place_saved = None;
869 }
870
871 fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
872 let wtx = self
873 .active_txn
874 .as_write_mut()
875 .ok_or(SqlError::NoActiveTransaction)?;
876
877 if self.savepoint_stack.is_empty() {
878 self.in_place_saved = Some(wtx.in_place());
879 wtx.set_in_place(false);
880 }
881
882 self.savepoint_stack.push(SavepointEntry {
883 name: name.to_string(),
884 snapshot: None,
885 });
886
887 Ok(ExecutionResult::Ok)
888 }
889
890 fn capture_pending_snapshots(&mut self) {
891 let last_pending = match self
892 .savepoint_stack
893 .iter()
894 .rposition(|e| e.snapshot.is_none())
895 {
896 Some(i) => i,
897 None => return,
898 };
899 let wtx = match self.active_txn.as_write_mut() {
900 Some(w) => w,
901 None => return,
902 };
903 let wtx_snap = wtx.begin_savepoint();
904 let schema_snap = self.schema.save_snapshot();
905
906 for i in 0..last_pending {
907 if self.savepoint_stack[i].snapshot.is_none() {
908 self.savepoint_stack[i].snapshot = Some(SavepointSnapshot {
909 wtx_snap: wtx_snap.clone(),
910 schema_snap: schema_snap.clone(),
911 });
912 }
913 }
914 self.savepoint_stack[last_pending].snapshot = Some(SavepointSnapshot {
915 wtx_snap,
916 schema_snap,
917 });
918 }
919
920 fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
921 if !self.active_txn.is_active() {
922 return Err(SqlError::NoActiveTransaction);
923 }
924
925 let idx = self
926 .savepoint_stack
927 .iter()
928 .rposition(|e| e.name == name)
929 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
930 self.savepoint_stack.truncate(idx);
931
932 if self.savepoint_stack.is_empty() {
933 if let (Some(wtx), Some(original)) =
934 (self.active_txn.as_write_mut(), self.in_place_saved.take())
935 {
936 wtx.set_in_place(original);
937 }
938 }
939
940 Ok(ExecutionResult::Ok)
941 }
942
943 fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
944 if !self.active_txn.is_active() {
945 return Err(SqlError::NoActiveTransaction);
946 }
947
948 let idx = self
949 .savepoint_stack
950 .iter()
951 .rposition(|e| e.name == name)
952 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
953
954 self.savepoint_stack.truncate(idx + 1);
955 let entry = self.savepoint_stack.last_mut().unwrap();
956 let snapshot = match entry.snapshot.take() {
957 Some(s) => s,
958 None => return Ok(ExecutionResult::Ok),
959 };
960
961 let wtx = match self.active_txn.as_write_mut() {
962 Some(w) => w,
963 None => return Err(SqlError::NoActiveTransaction),
964 };
965 wtx.restore_snapshot(snapshot.wtx_snap);
966 self.schema.restore_snapshot(snapshot.schema_snap);
967
968 Ok(ExecutionResult::Ok)
969 }
970}
971
972impl<'a> Drop for Connection<'a> {
973 fn drop(&mut self) {
974 let temp_names = std::mem::take(&mut self.inner.borrow_mut().temp_table_names);
975 if temp_names.is_empty() {
976 return;
977 }
978 if let Ok(mut wtx) = self.db.begin_write() {
979 for prefixed in &temp_names {
980 let _ = wtx.drop_table(prefixed.as_bytes());
981 }
982 let _ = wtx.commit();
983 }
984 }
985}