Skip to main content

citadel_sql/
connection.rs

1//! Public SQL connection API.
2
3use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7
8fn generate_temp_id() -> u64 {
9    static COUNTER: AtomicU64 = AtomicU64::new(0);
10    let counter = COUNTER.fetch_add(1, Ordering::Relaxed);
11    let nanos = (crate::datetime::now_micros() as u64) & 0xFFFF_FFFF;
12    (nanos << 32) | (counter & 0xFFFF_FFFF)
13}
14
/// Storage-level name for a TEMP table: `__temp_<temp_id>_<user_name lowercased>`.
///
/// Only ASCII letters are lowercased; other characters are kept as-is.
fn temp_storage_name(temp_id: u64, user_name: &str) -> String {
    let lowered = user_name.to_ascii_lowercase();
    format!("__temp_{}_{}", temp_id, lowered)
}
18
19use lru::LruCache;
20
21use citadel::Database;
22use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
23
24use crate::error::{Result, SqlError};
25use crate::executor;
26use crate::parser;
27use crate::parser::{BeginAccessMode, QueryBody, SelectQuery, Statement};
28use crate::prepared::PreparedStatement;
29use crate::schema::{SchemaManager, SchemaSnapshot};
30use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
31
/// Number of entries in each connection's LRU statement cache (`ConnectionInner::stmt_cache`).
const DEFAULT_CACHE_CAPACITY: usize = 64;
33
34#[derive(Debug)]
35pub struct ScriptExecution {
36    pub completed: Vec<ExecutionResult>,
37    pub error: Option<SqlError>,
38}
39
/// Checks whether `s` is a fixed UTC offset accepted as a session time zone.
///
/// Leading and trailing whitespace is ignored. Accepted forms:
/// * `Z` / `UTC` (case-insensitive);
/// * a `+` or `-` sign followed by `HH:MM`, `HHMM`, or `HH`.
///
/// Hours must be at most 23 and minutes at most 59. Each component must be a non-empty run of
/// ASCII digits; the `HH:MM` form does not enforce an exact digit count, so `+5:0` is accepted.
///
/// Returns `Some(())` when the string is valid and `None` otherwise. The function never panics.
/// Previously, a 4-byte remainder containing a multi-byte character (e.g. `"+1é1"`) was sliced
/// mid-character. Signs nested inside a component (e.g. `"++5:00"`) were also accepted, because
/// `u32::from_str` tolerates a leading `+`.
fn parse_fixed_offset(s: &str) -> Option<()> {
    let s = s.trim();
    if s.eq_ignore_ascii_case("z") || s.eq_ignore_ascii_case("utc") {
        return Some(());
    }
    let rest = s.strip_prefix('+').or_else(|| s.strip_prefix('-'))?;
    let (hh, mm) = match rest.split_once(':') {
        Some(parts) => parts,
        // `is_ascii` guarantees byte offset 2 is a char boundary, so `split_at` cannot panic.
        None if rest.len() == 4 && rest.is_ascii() => rest.split_at(2),
        None if rest.len() == 2 => (rest, "00"),
        None => return None,
    };
    let all_digits = |part: &str| !part.is_empty() && part.bytes().all(|b| b.is_ascii_digit());
    if !all_digits(hh) || !all_digits(mm) {
        return None;
    }
    let h: u32 = hh.parse().ok()?;
    let m: u32 = mm.parse().ok()?;
    if h > 23 || m > 59 {
        return None;
    }
    Some(())
}
69
/// Rewrites `SHOW TRIGGERS [ON <table>]` into a query over `information_schema.triggers`.
///
/// Matching is ASCII case-insensitive and ignores surrounding whitespace and trailing `;`.
/// The optional table name is lowercased, has single quotes doubled, and is compared
/// case-insensitively. Returns `None` for anything that is not a `SHOW TRIGGERS` statement
/// of one of those two shapes.
fn rewrite_show_triggers(sql: &str) -> Option<String> {
    const COLUMNS: &str = "trigger_name, event_object_table AS table_name, action_timing, \
                           event_manipulation, action_orientation, action_statement";

    let statement = sql.trim().trim_end_matches(';').trim().to_ascii_lowercase();
    let tail = statement.strip_prefix("show triggers")?.trim_start();
    let base = format!("SELECT {COLUMNS} FROM information_schema.triggers");

    if tail.is_empty() {
        return Some(format!("{base} ORDER BY trigger_name"));
    }

    let table = tail.strip_prefix("on ")?.trim().trim_end_matches(';').trim();
    if table.is_empty() {
        return None;
    }
    // Embedded as a string literal, so quotes must be doubled to stay inside it.
    let quoted_safe = table.replace('\'', "''");
    Some(format!(
        "{base} WHERE LOWER(event_object_table) = LOWER('{quoted_safe}') ORDER BY trigger_name"
    ))
}
96
/// Rewrites `SHOW MATERIALIZED VIEWS` into a query over `pg_matviews`.
///
/// The keywords must be separated by single spaces. Matching is ASCII case-insensitive and
/// ignores surrounding whitespace and trailing `;`. Any other input yields `None`.
fn rewrite_show_matviews(sql: &str) -> Option<String> {
    let statement = sql.trim().trim_end_matches(';').trim();
    if !statement.eq_ignore_ascii_case("show materialized views") {
        return None;
    }
    Some(String::from(
        "SELECT matviewname, ispopulated, hasindexes, definition FROM pg_matviews ORDER BY matviewname",
    ))
}
110
111fn stmt_mutates(stmt: &Statement) -> bool {
112    if matches!(
113        stmt,
114        Statement::Insert(_)
115            | Statement::Update(_)
116            | Statement::Delete(_)
117            | Statement::Truncate(_)
118            | Statement::CreateTable(_)
119            | Statement::DropTable(_)
120            | Statement::AlterTable(_)
121            | Statement::CreateIndex(_)
122            | Statement::DropIndex(_)
123            | Statement::CreateView(_)
124            | Statement::DropView(_)
125            | Statement::CreateTrigger(_)
126            | Statement::DropTrigger(_)
127            | Statement::CreateMaterializedView(_)
128            | Statement::RefreshMaterializedView(_)
129            | Statement::DropMaterializedView(_)
130    ) {
131        return true;
132    }
133    if let Statement::Select(sq) = stmt {
134        if select_query_has_dml(sq) {
135            return true;
136        }
137    }
138    false
139}
140
141fn select_query_has_dml(sq: &SelectQuery) -> bool {
142    sq.ctes.iter().any(|cte| query_body_has_dml(&cte.body)) || query_body_has_dml(&sq.body)
143}
144
145fn query_body_has_dml(body: &QueryBody) -> bool {
146    match body {
147        QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => true,
148        QueryBody::Compound(c) => query_body_has_dml(&c.left) || query_body_has_dml(&c.right),
149        QueryBody::Select(_) => false,
150    }
151}
152
153fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
154    let bytes = sql.as_bytes();
155    let len = bytes.len();
156    let mut i = 0;
157
158    while i < len && bytes[i].is_ascii_whitespace() {
159        i += 1;
160    }
161    if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
162        return None;
163    }
164    i += 6;
165    if i >= len || !bytes[i].is_ascii_whitespace() {
166        return None;
167    }
168    while i < len && bytes[i].is_ascii_whitespace() {
169        i += 1;
170    }
171
172    if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
173        return None;
174    }
175    i += 4;
176    if i >= len || !bytes[i].is_ascii_whitespace() {
177        return None;
178    }
179
180    let prefix_start = 0;
181    let mut values_pos = None;
182    let mut j = i;
183    while j + 6 <= len {
184        if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
185            && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
186            && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
187        {
188            values_pos = Some(j);
189            break;
190        }
191        j += 1;
192    }
193    let values_pos = values_pos?;
194
195    let prefix = &sql[prefix_start..values_pos + 6];
196    let mut pos = values_pos + 6;
197
198    while pos < len && bytes[pos].is_ascii_whitespace() {
199        pos += 1;
200    }
201    if pos >= len || bytes[pos] != b'(' {
202        return None;
203    }
204    pos += 1;
205
206    let mut values = Vec::new();
207    let mut normalized = String::with_capacity(sql.len());
208    normalized.push_str(prefix);
209    normalized.push_str(" (");
210
211    loop {
212        while pos < len && bytes[pos].is_ascii_whitespace() {
213            pos += 1;
214        }
215        if pos >= len {
216            return None;
217        }
218
219        let param_idx = values.len() + 1;
220        if param_idx > 1 {
221            normalized.push_str(", ");
222        }
223
224        if bytes[pos] == b'\'' {
225            pos += 1;
226            let mut seg_start = pos;
227            let mut s = String::new();
228            loop {
229                if pos >= len {
230                    return None;
231                }
232                if bytes[pos] == b'\'' {
233                    s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
234                    pos += 1;
235                    if pos < len && bytes[pos] == b'\'' {
236                        s.push('\'');
237                        pos += 1;
238                        seg_start = pos;
239                    } else {
240                        break;
241                    }
242                } else {
243                    pos += 1;
244                }
245            }
246            values.push(Value::Text(s.into()));
247        } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
248            let start = pos;
249            if bytes[pos] == b'-' {
250                pos += 1;
251            }
252            while pos < len && bytes[pos].is_ascii_digit() {
253                pos += 1;
254            }
255            if pos < len && bytes[pos] == b'.' {
256                pos += 1;
257                while pos < len && bytes[pos].is_ascii_digit() {
258                    pos += 1;
259                }
260                let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
261                values.push(Value::Real(num));
262            } else {
263                let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
264                values.push(Value::Integer(num));
265            }
266        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
267            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
268            if !after.is_ascii_alphanumeric() && after != b'_' {
269                pos += 4;
270                values.push(Value::Null);
271            } else {
272                return None;
273            }
274        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
275            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
276            if !after.is_ascii_alphanumeric() && after != b'_' {
277                pos += 4;
278                values.push(Value::Boolean(true));
279            } else {
280                return None;
281            }
282        } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
283            let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
284            if !after.is_ascii_alphanumeric() && after != b'_' {
285                pos += 5;
286                values.push(Value::Boolean(false));
287            } else {
288                return None;
289            }
290        } else {
291            return None;
292        }
293
294        normalized.push('$');
295        normalized.push_str(&param_idx.to_string());
296
297        while pos < len && bytes[pos].is_ascii_whitespace() {
298            pos += 1;
299        }
300        if pos >= len {
301            return None;
302        }
303
304        if bytes[pos] == b',' {
305            pos += 1;
306        } else if bytes[pos] == b')' {
307            pos += 1;
308            break;
309        } else {
310            return None;
311        }
312    }
313
314    normalized.push(')');
315
316    while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
317        pos += 1;
318    }
319    if pos != len {
320        return None;
321    }
322
323    if values.is_empty() {
324        return None;
325    }
326
327    Some((normalized, values))
328}
329
330pub(crate) struct CacheEntry {
331    pub(crate) stmt: Arc<Statement>,
332    pub(crate) schema_gen: u64,
333    pub(crate) param_count: usize,
334    pub(crate) compiled: Option<Arc<dyn executor::CompiledPlan>>,
335}
336
337struct SavepointEntry {
338    name: String,
339    snapshot: Option<SavepointSnapshot>,
340}
341
342struct SavepointSnapshot {
343    wtx_snap: WriteTxnSnapshot,
344    schema_snap: SchemaSnapshot,
345}
346
347/// Active transaction held by a Connection. `None` outside BEGIN/COMMIT; `Write` for normal
348/// BEGIN (or BEGIN READ WRITE); `Read` for BEGIN READ ONLY.
349#[allow(clippy::large_enum_variant)]
350pub(crate) enum ActiveTxn<'a> {
351    None,
352    Write(WriteTxn<'a>),
353    Read(citadel_txn::read_txn::ReadTxn<'a>),
354}
355
356impl<'a> ActiveTxn<'a> {
357    fn is_none(&self) -> bool {
358        matches!(self, ActiveTxn::None)
359    }
360    fn is_active(&self) -> bool {
361        !self.is_none()
362    }
363    fn is_read_only(&self) -> bool {
364        matches!(self, ActiveTxn::Read(_))
365    }
366    fn as_write_mut(&mut self) -> Option<&mut WriteTxn<'a>> {
367        match self {
368            ActiveTxn::Write(w) => Some(w),
369            _ => None,
370        }
371    }
372    fn take(&mut self) -> ActiveTxn<'a> {
373        std::mem::replace(self, ActiveTxn::None)
374    }
375}
376
377pub(crate) struct ConnectionInner<'a> {
378    pub(crate) schema: SchemaManager,
379    active_txn: ActiveTxn<'a>,
380    savepoint_stack: Vec<SavepointEntry>,
381    in_place_saved: Option<bool>,
382    pub(crate) stmt_cache: LruCache<String, CacheEntry>,
383    txn_start_ts: Option<i64>,
384    session_timezone: String,
385    /// Namespaces TEMP tables as `__temp_<id>_<name>`. Cleaned up on Connection drop.
386    temp_id: u64,
387    temp_table_names: Vec<String>,
388}
389
390pub struct Connection<'a> {
391    pub(crate) db: &'a Database,
392    pub(crate) inner: RefCell<ConnectionInner<'a>>,
393}
394
395impl<'a> Connection<'a> {
396    pub fn open(db: &'a Database) -> Result<Self> {
397        let schema = SchemaManager::load(db)?;
398        let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
399        let temp_id = generate_temp_id();
400        Ok(Self {
401            db,
402            inner: RefCell::new(ConnectionInner {
403                schema,
404                active_txn: ActiveTxn::None,
405                savepoint_stack: Vec::new(),
406                in_place_saved: None,
407                stmt_cache,
408                txn_start_ts: None,
409                session_timezone: "UTC".to_string(),
410                temp_id,
411                temp_table_names: Vec::new(),
412            }),
413        })
414    }
415
416    /// Txn-start UTC µs inside BEGIN/COMMIT, else `None`.
417    pub fn txn_start_ts(&self) -> Option<i64> {
418        self.inner.borrow().txn_start_ts
419    }
420
421    /// Returns the session time-zone (IANA name or fixed offset). Default `"UTC"`.
422    pub fn session_timezone(&self) -> String {
423        self.inner.borrow().session_timezone.clone()
424    }
425
426    /// Set the session time-zone. Accepts IANA names, ISO-8601 offsets, `"UTC"`, `"Z"`.
427    pub fn set_session_timezone(&self, tz: &str) -> Result<()> {
428        self.inner.borrow_mut().set_session_timezone_impl(tz)
429    }
430
431    pub fn execute(&self, sql: &str) -> Result<ExecutionResult> {
432        self.inner.borrow_mut().execute_impl(self.db, sql)
433    }
434
435    pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
436        self.inner
437            .borrow_mut()
438            .execute_params_impl(self.db, sql, params)
439    }
440
441    /// Execute `;`-separated SQL statements. Stops at the first failure.
442    pub fn execute_script(&self, sql: &str) -> ScriptExecution {
443        let stmts = match parser::parse_sql_multi(sql) {
444            Ok(s) => s,
445            Err(e) => {
446                return ScriptExecution {
447                    completed: vec![],
448                    error: Some(e),
449                }
450            }
451        };
452        let mut completed = Vec::with_capacity(stmts.len());
453        for stmt in stmts {
454            match self.inner.borrow_mut().dispatch(self.db, &stmt, &[]) {
455                Ok(r) => completed.push(r),
456                Err(e) => {
457                    return ScriptExecution {
458                        completed,
459                        error: Some(e),
460                    }
461                }
462            }
463        }
464        ScriptExecution {
465            completed,
466            error: None,
467        }
468    }
469
470    pub fn query(&self, sql: &str) -> Result<QueryResult> {
471        self.query_params(sql, &[])
472    }
473
474    pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
475        match self.execute_params(sql, params)? {
476            ExecutionResult::Query(qr) => Ok(qr),
477            ExecutionResult::RowsAffected(n) => Ok(QueryResult {
478                columns: vec!["rows_affected".into()],
479                rows: vec![vec![Value::Integer(n as i64)]],
480            }),
481            ExecutionResult::Ok => Ok(QueryResult {
482                columns: vec![],
483                rows: vec![],
484            }),
485        }
486    }
487
488    pub fn prepare(&self, sql: &str) -> Result<PreparedStatement<'_, 'a>> {
489        if let Some(rewritten) = rewrite_show_triggers(sql) {
490            return PreparedStatement::new(self, &rewritten);
491        }
492        if let Some(rewritten) = rewrite_show_matviews(sql) {
493            return PreparedStatement::new(self, &rewritten);
494        }
495        PreparedStatement::new(self, sql)
496    }
497
498    pub fn tables(&self) -> Vec<String> {
499        self.inner
500            .borrow()
501            .schema
502            .table_names()
503            .into_iter()
504            .map(String::from)
505            .collect()
506    }
507
508    /// Returns true if an explicit transaction is active (BEGIN was issued).
509    pub fn in_transaction(&self) -> bool {
510        self.inner.borrow().active_txn.is_active()
511    }
512
513    pub fn table_schema(&self, name: &str) -> Option<TableSchema> {
514        self.inner.borrow().schema.get(name).cloned()
515    }
516
517    pub fn refresh_schema(&self) -> Result<()> {
518        let new_schema = SchemaManager::load(self.db)?;
519        self.inner.borrow_mut().schema = new_schema;
520        Ok(())
521    }
522}
523
524impl<'a> ConnectionInner<'a> {
525    pub(crate) fn active_txn_is_some(&self) -> bool {
526        self.active_txn.is_active()
527    }
528
529    fn set_session_timezone_impl(&mut self, tz: &str) -> Result<()> {
530        let upper = tz.to_ascii_uppercase();
531        if (upper.starts_with("UTC+") || upper.starts_with("UTC-")) && tz.len() > 3 {
532            return Err(SqlError::InvalidTimezone(format!(
533                "'{tz}' is ambiguous; use ISO-8601 offset (e.g. '+05:00') or named zone (e.g. 'Etc/GMT-5')"
534            )));
535        }
536        if jiff::tz::TimeZone::get(tz).is_err() && parse_fixed_offset(tz).is_none() {
537            return Err(SqlError::InvalidTimezone(format!(
538                "{tz}: not a known IANA zone or ISO-8601 offset (e.g. '+05:00', 'UTC', 'America/New_York')"
539            )));
540        }
541        self.session_timezone = tz.to_string();
542        Ok(())
543    }
544
545    fn execute_impl(&mut self, db: &'a Database, sql: &str) -> Result<ExecutionResult> {
546        if let Some(rewritten) = rewrite_show_triggers(sql) {
547            return self.execute_params_impl(db, &rewritten, &[]);
548        }
549        if let Some(rewritten) = rewrite_show_matviews(sql) {
550            return self.execute_params_impl(db, &rewritten, &[]);
551        }
552        if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
553            if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
554                let gen = self.schema.generation();
555                let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
556                    if entry.schema_gen == gen {
557                        Arc::clone(&entry.stmt)
558                    } else {
559                        self.parse_and_cache(normalized_key, gen)?
560                    }
561                } else {
562                    self.parse_and_cache(normalized_key, gen)?
563                };
564                return self.dispatch(db, &stmt, &extracted);
565            }
566        }
567        self.execute_params_impl(db, sql, &[])
568    }
569
570    fn execute_params_impl(
571        &mut self,
572        db: &'a Database,
573        sql: &str,
574        params: &[Value],
575    ) -> Result<ExecutionResult> {
576        let gen = self.schema.generation();
577        if self.active_txn.is_none() {
578            if let Some(entry) = self.stmt_cache.get(sql) {
579                if entry.schema_gen == gen && entry.param_count == params.len() {
580                    if let Some(plan) = entry.compiled.as_ref().map(Arc::clone) {
581                        let stmt = Arc::clone(&entry.stmt);
582                        return self.run_compiled(db, &plan, &stmt, params);
583                    }
584                }
585            }
586        }
587
588        let (stmt, param_count) = self.get_or_parse(sql)?;
589
590        if param_count != params.len() {
591            return Err(SqlError::ParameterCountMismatch {
592                expected: param_count,
593                got: params.len(),
594            });
595        }
596
597        if self.active_txn.is_none() {
598            if let Some(plan) = executor::compile(&self.schema, &stmt) {
599                if let Some(e) = self.stmt_cache.get_mut(sql) {
600                    e.compiled = Some(Arc::clone(&plan));
601                }
602                let stmt_owned = Arc::clone(&stmt);
603                return self.run_compiled(db, &plan, &stmt_owned, params);
604            }
605        }
606
607        self.dispatch(db, &stmt, params)
608    }
609
610    fn run_compiled(
611        &mut self,
612        db: &'a Database,
613        plan: &Arc<dyn executor::CompiledPlan>,
614        stmt: &Statement,
615        params: &[Value],
616    ) -> Result<ExecutionResult> {
617        use executor::compile::ActiveTxnRef;
618        let schema = &self.schema;
619        let exec = || {
620            if params.is_empty() {
621                plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
622            } else {
623                crate::eval::with_scoped_params(params, || {
624                    plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
625                })
626            }
627        };
628        if plan.needs_txn_clock() {
629            let cached_ts = self
630                .txn_start_ts
631                .or_else(|| Some(crate::datetime::now_micros()));
632            crate::datetime::with_txn_clock(cached_ts, exec)
633        } else {
634            exec()
635        }
636    }
637
638    pub(crate) fn parse_and_cache(
639        &mut self,
640        normalized_key: String,
641        gen: u64,
642    ) -> Result<Arc<Statement>> {
643        let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
644        let param_count = parser::count_params(&stmt);
645        self.stmt_cache.put(
646            normalized_key,
647            CacheEntry {
648                stmt: Arc::clone(&stmt),
649                schema_gen: gen,
650                param_count,
651                compiled: None,
652            },
653        );
654        Ok(stmt)
655    }
656
657    pub(crate) fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
658        let gen = self.schema.generation();
659
660        if let Some(entry) = self.stmt_cache.get(sql) {
661            if entry.schema_gen == gen {
662                return Ok((Arc::clone(&entry.stmt), entry.param_count));
663            }
664        }
665
666        let stmt = Arc::new(parser::parse_sql(sql)?);
667        let param_count = parser::count_params(&stmt);
668
669        let cacheable = !matches!(
670            *stmt,
671            Statement::CreateTable(_)
672                | Statement::DropTable(_)
673                | Statement::CreateIndex(_)
674                | Statement::DropIndex(_)
675                | Statement::CreateView(_)
676                | Statement::DropView(_)
677                | Statement::AlterTable(_)
678        );
679
680        if cacheable {
681            self.stmt_cache.put(
682                sql.to_string(),
683                CacheEntry {
684                    stmt: Arc::clone(&stmt),
685                    schema_gen: gen,
686                    param_count,
687                    compiled: None,
688                },
689            );
690        }
691
692        Ok((stmt, param_count))
693    }
694
695    pub(crate) fn execute_prepared(
696        &mut self,
697        db: &'a Database,
698        stmt: &Statement,
699        compiled: Option<&Arc<dyn executor::CompiledPlan>>,
700        params: &[Value],
701    ) -> Result<ExecutionResult> {
702        if let Some(plan) = compiled {
703            if self.active_txn.is_none() {
704                return self.run_compiled(db, plan, stmt, params);
705            }
706            if !self.savepoint_stack.is_empty() && stmt_mutates(stmt) {
707                self.capture_pending_snapshots();
708            }
709            return self.run_compiled_in_txn(db, plan, stmt, params);
710        }
711        self.dispatch(db, stmt, params)
712    }
713
714    fn run_compiled_in_txn(
715        &mut self,
716        db: &'a Database,
717        plan: &Arc<dyn executor::CompiledPlan>,
718        stmt: &Statement,
719        params: &[Value],
720    ) -> Result<ExecutionResult> {
721        use executor::compile::ActiveTxnRef;
722        let schema = &self.schema;
723        let txn = match &mut self.active_txn {
724            ActiveTxn::Write(wtx) => ActiveTxnRef::Write(wtx),
725            ActiveTxn::Read(rtx) => ActiveTxnRef::Read(rtx),
726            ActiveTxn::None => ActiveTxnRef::None,
727        };
728        if params.is_empty() || !plan.uses_scoped_params() {
729            plan.execute(db, schema, stmt, params, txn)
730        } else {
731            crate::eval::with_scoped_params(params, || plan.execute(db, schema, stmt, params, txn))
732        }
733    }
734
735    pub(crate) fn dispatch(
736        &mut self,
737        db: &'a Database,
738        stmt: &Statement,
739        params: &[Value],
740    ) -> Result<ExecutionResult> {
741        let cached_ts = self
742            .txn_start_ts
743            .or_else(|| Some(crate::datetime::now_micros()));
744        crate::datetime::with_txn_clock(cached_ts, || {
745            if params.is_empty() {
746                self.dispatch_inner(db, stmt, params)
747            } else {
748                crate::eval::with_scoped_params(params, || self.dispatch_inner(db, stmt, params))
749            }
750        })
751    }
752
753    fn dispatch_inner(
754        &mut self,
755        db: &'a Database,
756        stmt: &Statement,
757        params: &[Value],
758    ) -> Result<ExecutionResult> {
759        match stmt {
760            Statement::Begin { access_mode } => {
761                if self.active_txn.is_active() {
762                    return Err(SqlError::TransactionAlreadyActive);
763                }
764                let ts = crate::datetime::txn_or_clock_micros();
765                match access_mode {
766                    BeginAccessMode::ReadOnly => {
767                        let rtx = db.begin_read();
768                        self.active_txn = ActiveTxn::Read(rtx);
769                    }
770                    BeginAccessMode::ReadWrite | BeginAccessMode::Default => {
771                        let wtx = db.begin_write().map_err(SqlError::Storage)?;
772                        self.active_txn = ActiveTxn::Write(wtx);
773                    }
774                }
775                self.txn_start_ts = Some(ts);
776                crate::datetime::set_txn_clock(Some(ts));
777                Ok(ExecutionResult::Ok)
778            }
779            Statement::Commit => {
780                match self.active_txn.take() {
781                    ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
782                    ActiveTxn::Write(mut wtx) => {
783                        crate::executor::helpers::drain_deferred_fk_checks(&mut wtx)?;
784                        wtx.commit().map_err(SqlError::Storage)?;
785                    }
786                    ActiveTxn::Read(_rtx) => {}
787                }
788                self.clear_savepoint_state();
789                self.txn_start_ts = None;
790                crate::datetime::set_txn_clock(None);
791                Ok(ExecutionResult::Ok)
792            }
793            Statement::Rollback => {
794                match self.active_txn.take() {
795                    ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
796                    ActiveTxn::Write(wtx) => {
797                        wtx.abort();
798                        self.schema = SchemaManager::load(db)?;
799                    }
800                    ActiveTxn::Read(_rtx) => {}
801                }
802                self.clear_savepoint_state();
803                self.txn_start_ts = None;
804                crate::datetime::set_txn_clock(None);
805                Ok(ExecutionResult::Ok)
806            }
807            Statement::Savepoint(name) => self.do_savepoint(name),
808            Statement::ReleaseSavepoint(name) => self.do_release(name),
809            Statement::RollbackTo(name) => self.do_rollback_to(name),
810            Statement::SetTimezone(zone) => {
811                self.set_session_timezone_impl(zone)?;
812                Ok(ExecutionResult::Ok)
813            }
814            Statement::CreateTable(ct) if ct.temporary => {
815                if self.active_txn.is_read_only() {
816                    return Err(SqlError::Unsupported(
817                        "cannot execute mutating statement inside a read-only transaction".into(),
818                    ));
819                }
820                let user_name = ct.name.clone();
821                let prefixed = temp_storage_name(self.temp_id, &user_name);
822                if self.schema.contains(&user_name) {
823                    if ct.if_not_exists {
824                        return Ok(ExecutionResult::Ok);
825                    }
826                    return Err(SqlError::TableAlreadyExists(user_name));
827                }
828                let mut clone = ct.clone();
829                clone.name = prefixed.clone();
830                clone.temporary = false;
831                let stmt_concrete = Statement::CreateTable(clone);
832                let outcome = if let Some(wtx) = self.active_txn.as_write_mut() {
833                    executor::execute_in_txn(wtx, &mut self.schema, &stmt_concrete, params)?
834                } else {
835                    executor::execute(db, &mut self.schema, &stmt_concrete, params)?
836                };
837                self.schema
838                    .register_temp_alias(&user_name, prefixed.clone());
839                self.temp_table_names.push(prefixed);
840                Ok(outcome)
841            }
842            Statement::Insert(ins) if self.active_txn.as_write_mut().is_some() => {
843                self.capture_pending_snapshots();
844                let wtx = self.active_txn.as_write_mut().unwrap();
845                executor::exec_insert_in_txn(wtx, &self.schema, ins, params)
846            }
847            _ => {
848                if self.active_txn.is_read_only() && stmt_mutates(stmt) {
849                    return Err(SqlError::Unsupported(
850                        "cannot execute mutating statement inside a read-only transaction".into(),
851                    ));
852                }
853                if self.active_txn.as_write_mut().is_some() && stmt_mutates(stmt) {
854                    self.capture_pending_snapshots();
855                }
856                let outcome = match &mut self.active_txn {
857                    ActiveTxn::Write(wtx) => {
858                        executor::execute_in_txn(wtx, &mut self.schema, stmt, params)?
859                    }
860                    ActiveTxn::Read(rtx) => {
861                        executor::execute_with_read(rtx, &self.schema, stmt, params)?
862                    }
863                    ActiveTxn::None => executor::execute(db, &mut self.schema, stmt, params)?,
864                };
865                if let Statement::DropTable(dt) = stmt {
866                    self.schema.unregister_temp_alias(&dt.name);
867                }
868                Ok(outcome)
869            }
870        }
871    }
872
873    fn clear_savepoint_state(&mut self) {
874        self.savepoint_stack.clear();
875        self.in_place_saved = None;
876    }
877
878    fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
879        let wtx = self
880            .active_txn
881            .as_write_mut()
882            .ok_or(SqlError::NoActiveTransaction)?;
883
884        if self.savepoint_stack.is_empty() {
885            self.in_place_saved = Some(wtx.in_place());
886            wtx.set_in_place(false);
887        }
888
889        self.savepoint_stack.push(SavepointEntry {
890            name: name.to_string(),
891            snapshot: None,
892        });
893
894        Ok(ExecutionResult::Ok)
895    }
896
897    fn capture_pending_snapshots(&mut self) {
898        let last_pending = match self
899            .savepoint_stack
900            .iter()
901            .rposition(|e| e.snapshot.is_none())
902        {
903            Some(i) => i,
904            None => return,
905        };
906        let wtx = match self.active_txn.as_write_mut() {
907            Some(w) => w,
908            None => return,
909        };
910        let wtx_snap = wtx.begin_savepoint();
911        let schema_snap = self.schema.save_snapshot();
912
913        for i in 0..last_pending {
914            if self.savepoint_stack[i].snapshot.is_none() {
915                self.savepoint_stack[i].snapshot = Some(SavepointSnapshot {
916                    wtx_snap: wtx_snap.clone(),
917                    schema_snap: schema_snap.clone(),
918                });
919            }
920        }
921        self.savepoint_stack[last_pending].snapshot = Some(SavepointSnapshot {
922            wtx_snap,
923            schema_snap,
924        });
925    }
926
927    fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
928        if !self.active_txn.is_active() {
929            return Err(SqlError::NoActiveTransaction);
930        }
931
932        let idx = self
933            .savepoint_stack
934            .iter()
935            .rposition(|e| e.name == name)
936            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
937        self.savepoint_stack.truncate(idx);
938
939        if self.savepoint_stack.is_empty() {
940            if let (Some(wtx), Some(original)) =
941                (self.active_txn.as_write_mut(), self.in_place_saved.take())
942            {
943                wtx.set_in_place(original);
944            }
945        }
946
947        Ok(ExecutionResult::Ok)
948    }
949
950    fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
951        if !self.active_txn.is_active() {
952            return Err(SqlError::NoActiveTransaction);
953        }
954
955        let idx = self
956            .savepoint_stack
957            .iter()
958            .rposition(|e| e.name == name)
959            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
960
961        self.savepoint_stack.truncate(idx + 1);
962        let entry = self.savepoint_stack.last_mut().unwrap();
963        let snapshot = match entry.snapshot.take() {
964            Some(s) => s,
965            None => return Ok(ExecutionResult::Ok),
966        };
967
968        let wtx = match self.active_txn.as_write_mut() {
969            Some(w) => w,
970            None => return Err(SqlError::NoActiveTransaction),
971        };
972        wtx.restore_snapshot(snapshot.wtx_snap);
973        self.schema.restore_snapshot(snapshot.schema_snap);
974
975        Ok(ExecutionResult::Ok)
976    }
977}
978
979impl<'a> Drop for Connection<'a> {
980    fn drop(&mut self) {
981        let temp_names = std::mem::take(&mut self.inner.borrow_mut().temp_table_names);
982        if temp_names.is_empty() {
983            return;
984        }
985        if let Ok(mut wtx) = self.db.begin_write() {
986            for prefixed in &temp_names {
987                let _ = wtx.drop_table(prefixed.as_bytes());
988            }
989            let _ = wtx.commit();
990        }
991    }
992}