Skip to main content

citadel_sql/
connection.rs

1//! Public SQL connection API.
2
3use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7
/// Produces a per-connection id used to namespace TEMP tables.
///
/// The high 32 bits hold the sub-second nanoseconds of the current wall-clock time (0 if the
/// clock is before the Unix epoch); the low 32 bits hold a process-wide sequence number, so
/// ids handed out within one process differ until that sequence wraps.
fn generate_temp_id() -> u64 {
    static NEXT_SEQ: AtomicU64 = AtomicU64::new(0);
    let seq = NEXT_SEQ.fetch_add(1, Ordering::Relaxed) & 0xFFFF_FFFF;
    let clock_bits = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.subsec_nanos() as u64,
        Err(_) => 0,
    };
    (clock_bits << 32) | seq
}
17
/// Storage-level table name for a TEMP table: `__temp_<temp_id>_<user_name>`, with the user
/// name ASCII-lowercased (non-ASCII characters are left unchanged).
fn temp_storage_name(temp_id: u64, user_name: &str) -> String {
    let lowered = user_name.to_ascii_lowercase();
    format!("__temp_{temp_id}_{lowered}")
}
21
22use lru::LruCache;
23
24use citadel::Database;
25use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
26
27use crate::error::{Result, SqlError};
28use crate::executor;
29use crate::parser;
30use crate::parser::{BeginAccessMode, QueryBody, SelectQuery, Statement};
31use crate::prepared::PreparedStatement;
32use crate::schema::{SchemaManager, SchemaSnapshot};
33use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
34
/// Number of parsed statements kept in each connection's LRU statement cache.
const DEFAULT_CACHE_CAPACITY: usize = 64;
36
/// Outcome of [`Connection::execute_script`].
///
/// Statements run in script order and execution stops at the first failure.
#[derive(Debug)]
pub struct ScriptExecution {
    /// Results of the statements that ran successfully, in script order.
    pub completed: Vec<ExecutionResult>,
    /// The error that stopped the script: either a parse error for the whole script (in which
    /// case `completed` is empty) or the first failing statement. `None` if all succeeded.
    pub error: Option<SqlError>,
}
42
/// Checks whether `s` (surrounding whitespace ignored) is a fixed UTC offset.
///
/// Accepted forms: `Z`/`UTC` (any case), and a sign followed by `HH:MM`, `HHMM`, or `HH`,
/// with hours `0..=23` and minutes `0..=59`. Returns `Some(())` when valid, `None` otherwise.
/// Never panics, including on non-ASCII input.
fn parse_fixed_offset(s: &str) -> Option<()> {
    let s = s.trim();
    if s.eq_ignore_ascii_case("z") || s.eq_ignore_ascii_case("utc") {
        return Some(());
    }
    let rest = s.strip_prefix('+').or_else(|| s.strip_prefix('-'))?;
    // Only ASCII digits and ':' may follow the sign. This keeps the fixed byte-offset split
    // below on char boundaries (non-ASCII input used to panic there) and stops
    // `u32::from_str` from accepting a second sign such as "++5:00".
    if !rest.bytes().all(|b| b.is_ascii_digit() || b == b':') {
        return None;
    }
    let (hh, mm) = match rest.split_once(':') {
        Some(parts) => parts,
        None if rest.len() == 4 => rest.split_at(2),
        None if rest.len() == 2 => (rest, "00"),
        None => return None,
    };
    let h: u32 = hh.parse().ok()?;
    let m: u32 = mm.parse().ok()?;
    if h > 23 || m > 59 {
        return None;
    }
    Some(())
}
72
/// Rewrites `SHOW TRIGGERS` / `SHOW TRIGGERS ON <table>` into a query over
/// `information_schema.triggers`, ordered by trigger name.
///
/// Matching is ASCII case-insensitive and ignores surrounding whitespace and trailing `;`.
/// The table name is compared case-insensitively via `LOWER(...)`, with single quotes doubled
/// for the SQL string literal. Returns `None` for any other input.
fn rewrite_show_triggers(sql: &str) -> Option<String> {
    const BASE: &str = "SELECT trigger_name, event_object_table AS table_name, action_timing, \
                        event_manipulation, action_orientation, action_statement \
                        FROM information_schema.triggers";

    let lowered = sql.trim().trim_end_matches(';').trim().to_ascii_lowercase();
    let tail = lowered.strip_prefix("show triggers")?.trim_start();
    if tail.is_empty() {
        return Some(format!("{BASE} ORDER BY trigger_name"));
    }

    let table = tail.strip_prefix("on ")?.trim().trim_end_matches(';').trim();
    if table.is_empty() {
        return None;
    }
    let quoted = table.replace('\'', "''");
    Some(format!(
        "{BASE} WHERE LOWER(event_object_table) = LOWER('{quoted}') ORDER BY trigger_name"
    ))
}
99
/// Rewrites exactly `SHOW MATERIALIZED VIEWS` (ASCII case-insensitive, surrounding whitespace
/// and trailing `;` ignored) into a query over `pg_matviews`; returns `None` otherwise.
fn rewrite_show_matviews(sql: &str) -> Option<String> {
    let normalized = sql.trim().trim_end_matches(';').trim().to_ascii_lowercase();
    if normalized == "show materialized views" {
        let query = "SELECT matviewname, ispopulated, hasindexes, definition \
                     FROM pg_matviews ORDER BY matviewname";
        Some(query.to_string())
    } else {
        None
    }
}
113
114fn stmt_mutates(stmt: &Statement) -> bool {
115    if matches!(
116        stmt,
117        Statement::Insert(_)
118            | Statement::Update(_)
119            | Statement::Delete(_)
120            | Statement::Truncate(_)
121            | Statement::CreateTable(_)
122            | Statement::DropTable(_)
123            | Statement::AlterTable(_)
124            | Statement::CreateIndex(_)
125            | Statement::DropIndex(_)
126            | Statement::CreateView(_)
127            | Statement::DropView(_)
128            | Statement::CreateTrigger(_)
129            | Statement::DropTrigger(_)
130            | Statement::CreateMaterializedView(_)
131            | Statement::RefreshMaterializedView(_)
132            | Statement::DropMaterializedView(_)
133    ) {
134        return true;
135    }
136    if let Statement::Select(sq) = stmt {
137        if select_query_has_dml(sq) {
138            return true;
139        }
140    }
141    false
142}
143
144fn select_query_has_dml(sq: &SelectQuery) -> bool {
145    sq.ctes.iter().any(|cte| query_body_has_dml(&cte.body)) || query_body_has_dml(&sq.body)
146}
147
148fn query_body_has_dml(body: &QueryBody) -> bool {
149    match body {
150        QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => true,
151        QueryBody::Compound(c) => query_body_has_dml(&c.left) || query_body_has_dml(&c.right),
152        QueryBody::Select(_) => false,
153    }
154}
155
/// Fast-path normalizer for simple single-row `INSERT INTO ... VALUES (...)` statements.
///
/// Replaces each literal in the VALUES list with a positional placeholder (`$1`, `$2`, ...)
/// and returns `(normalized_sql, extracted_values)`, so inserts that differ only in their
/// literal values share one statement-cache key. The normalized text is the original SQL up
/// to and including `VALUES`, followed by ` ($1, $2, ...)`.
///
/// Recognized literals: single-quoted strings (with `''` escapes), integers and decimals with
/// an optional leading `-`, and `NULL` / `TRUE` / `FALSE` (ASCII case-insensitive).
///
/// Returns `None` ("use the normal parser") for anything else. That includes other statement
/// shapes, multi-row VALUES, trailing clauses, expressions, and integers that do not fit in
/// `i64`. Only whitespace and `;` may follow the closing parenthesis.
fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut i = 0;

    // Leading whitespace, then `INSERT` followed by at least one whitespace byte.
    while i < len && bytes[i].is_ascii_whitespace() {
        i += 1;
    }
    if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
        return None;
    }
    i += 6;
    if i >= len || !bytes[i].is_ascii_whitespace() {
        return None;
    }
    while i < len && bytes[i].is_ascii_whitespace() {
        i += 1;
    }

    // `INTO`, also followed by whitespace.
    if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
        return None;
    }
    i += 4;
    if i >= len || !bytes[i].is_ascii_whitespace() {
        return None;
    }

    // Find the first `VALUES` that is not part of a longer identifier. The scan does not
    // understand quoting or comments; a match inside a quoted name or comment is expected
    // to fail one of the later checks, falling back to the parser.
    // (`j == 0` cannot hold here because the scan starts after `INTO`.)
    let prefix_start = 0;
    let mut values_pos = None;
    let mut j = i;
    while j + 6 <= len {
        if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
            && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
            && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
        {
            values_pos = Some(j);
            break;
        }
        j += 1;
    }
    let values_pos = values_pos?;

    // Everything up to and including `VALUES` is kept verbatim in the cache key. The slice end
    // follows an ASCII byte, so it is a char boundary.
    let prefix = &sql[prefix_start..values_pos + 6];
    let mut pos = values_pos + 6;

    while pos < len && bytes[pos].is_ascii_whitespace() {
        pos += 1;
    }
    if pos >= len || bytes[pos] != b'(' {
        return None;
    }
    pos += 1;

    let mut values = Vec::new();
    let mut normalized = String::with_capacity(sql.len());
    normalized.push_str(prefix);
    normalized.push_str(" (");

    // One iteration per literal: read it, record it, emit `$n`, then expect `,` or `)`.
    loop {
        while pos < len && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        if pos >= len {
            return None;
        }

        let param_idx = values.len() + 1;
        if param_idx > 1 {
            normalized.push_str(", ");
        }

        if bytes[pos] == b'\'' {
            // Single-quoted string; a doubled quote `''` is an escaped quote. Segment bounds are
            // ASCII quote positions in a valid `&str`, so each segment is valid UTF-8.
            pos += 1;
            let mut seg_start = pos;
            let mut s = String::new();
            loop {
                if pos >= len {
                    return None;
                }
                if bytes[pos] == b'\'' {
                    s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
                    pos += 1;
                    if pos < len && bytes[pos] == b'\'' {
                        s.push('\'');
                        pos += 1;
                        seg_start = pos;
                    } else {
                        break;
                    }
                } else {
                    pos += 1;
                }
            }
            values.push(Value::Text(s.into()));
        } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
            // Number: optional '-', digits, optional '.' fraction. A fraction makes it `Real`,
            // otherwise `Integer`. Parse failures (a lone '-', i64 overflow) bail out.
            let start = pos;
            if bytes[pos] == b'-' {
                pos += 1;
            }
            while pos < len && bytes[pos].is_ascii_digit() {
                pos += 1;
            }
            if pos < len && bytes[pos] == b'.' {
                pos += 1;
                while pos < len && bytes[pos].is_ascii_digit() {
                    pos += 1;
                }
                let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
                values.push(Value::Real(num));
            } else {
                let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
                values.push(Value::Integer(num));
            }
        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
            // Keywords must not run into an identifier character (rejects e.g. `NULLIF`).
            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 4;
                values.push(Value::Null);
            } else {
                return None;
            }
        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 4;
                values.push(Value::Boolean(true));
            } else {
                return None;
            }
        } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
            let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 5;
                values.push(Value::Boolean(false));
            } else {
                return None;
            }
        } else {
            // Anything else (expressions, identifiers, DEFAULT, ...) needs the real parser.
            return None;
        }

        normalized.push('$');
        normalized.push_str(&param_idx.to_string());

        // After each value: ',' continues the list, ')' ends it.
        while pos < len && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        if pos >= len {
            return None;
        }

        if bytes[pos] == b',' {
            pos += 1;
        } else if bytes[pos] == b')' {
            pos += 1;
            break;
        } else {
            return None;
        }
    }

    normalized.push(')');

    // Only whitespace and statement terminators may remain (rejects multi-row VALUES,
    // RETURNING, ON CONFLICT, a second statement, ...).
    while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
        pos += 1;
    }
    if pos != len {
        return None;
    }

    // Defensive: the loop above always records a value before it can reach `)`.
    if values.is_empty() {
        return None;
    }

    Some((normalized, values))
}
332
/// A parsed statement held in a connection's LRU statement cache.
pub(crate) struct CacheEntry {
    /// The parsed statement; callers receive cheap `Arc` clones of it.
    pub(crate) stmt: Arc<Statement>,
    /// Schema generation at parse time; the entry is reused only while it still matches.
    pub(crate) schema_gen: u64,
    /// Number of placeholders in `stmt` (from `parser::count_params`).
    pub(crate) param_count: usize,
    /// Compiled execution plan for `stmt`, once one has been built; `None` until then.
    pub(crate) compiled: Option<Arc<dyn executor::CompiledPlan>>,
}
339
/// One named savepoint on a connection's savepoint stack.
struct SavepointEntry {
    /// Name given to `SAVEPOINT`; lookups use the most recent entry with this name.
    name: String,
    /// State to restore on `ROLLBACK TO`, captured lazily just before the first mutation after
    /// the savepoint was created. `None` means nothing has changed since, so rolling back to it
    /// has nothing to restore.
    snapshot: Option<SavepointSnapshot>,
}
344
/// Storage and catalog state restored by `ROLLBACK TO`.
struct SavepointSnapshot {
    /// Write-transaction state from `WriteTxn::begin_savepoint`.
    wtx_snap: WriteTxnSnapshot,
    /// In-memory schema state from `SchemaManager::save_snapshot`.
    schema_snap: SchemaSnapshot,
}
349
/// Active transaction held by a Connection. `None` outside BEGIN/COMMIT; `Write` for normal
/// BEGIN (or BEGIN READ WRITE); `Read` for BEGIN READ ONLY.
// NOTE(review): presumably `WriteTxn` is much larger than `ReadTxn`. The lint is silenced
// rather than boxing because each connection holds exactly one `ActiveTxn` — confirm.
#[allow(clippy::large_enum_variant)]
pub(crate) enum ActiveTxn<'a> {
    /// No explicit transaction is open.
    None,
    /// Opened by `BEGIN` or `BEGIN READ WRITE`.
    Write(WriteTxn<'a>),
    /// Opened by `BEGIN READ ONLY`.
    Read(citadel_txn::read_txn::ReadTxn<'a>),
}
358
359impl<'a> ActiveTxn<'a> {
360    fn is_none(&self) -> bool {
361        matches!(self, ActiveTxn::None)
362    }
363    fn is_active(&self) -> bool {
364        !self.is_none()
365    }
366    fn is_read_only(&self) -> bool {
367        matches!(self, ActiveTxn::Read(_))
368    }
369    fn as_write_mut(&mut self) -> Option<&mut WriteTxn<'a>> {
370        match self {
371            ActiveTxn::Write(w) => Some(w),
372            _ => None,
373        }
374    }
375    fn take(&mut self) -> ActiveTxn<'a> {
376        std::mem::replace(self, ActiveTxn::None)
377    }
378}
379
/// Mutable per-connection state, kept behind a `RefCell` inside [`Connection`].
pub(crate) struct ConnectionInner<'a> {
    /// In-memory schema catalog, including this connection's TEMP-table aliases.
    pub(crate) schema: SchemaManager,
    /// Explicit transaction opened by BEGIN, if any.
    active_txn: ActiveTxn<'a>,
    /// Open savepoints, innermost last.
    savepoint_stack: Vec<SavepointEntry>,
    /// The write transaction's `in_place` setting before the first SAVEPOINT turned it off;
    /// restored when the last savepoint is released.
    in_place_saved: Option<bool>,
    /// LRU cache of parsed statements keyed by SQL text (or by normalized INSERT text).
    pub(crate) stmt_cache: LruCache<String, CacheEntry>,
    /// Timestamp (µs) taken at BEGIN; `None` outside an explicit transaction.
    txn_start_ts: Option<i64>,
    /// Session time zone (IANA name or fixed offset); `"UTC"` by default.
    session_timezone: String,
    /// Namespaces TEMP tables as `__temp_<id>_<name>`. Cleaned up on Connection drop.
    temp_id: u64,
    /// Storage names of TEMP tables created by this connection; dropped on Connection drop.
    temp_table_names: Vec<String>,
}
392
/// A SQL session over a borrowed [`Database`].
///
/// Methods take `&self`; session state lives in a `RefCell`, which makes `Connection`
/// `!Sync`. A re-entrant call made while another call holds the mutable borrow panics.
pub struct Connection<'a> {
    /// The database this session runs against.
    pub(crate) db: &'a Database,
    /// Transaction, cache, schema and TEMP-table state for this session.
    pub(crate) inner: RefCell<ConnectionInner<'a>>,
}
397
398impl<'a> Connection<'a> {
399    pub fn open(db: &'a Database) -> Result<Self> {
400        let schema = SchemaManager::load(db)?;
401        let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
402        let temp_id = generate_temp_id();
403        Ok(Self {
404            db,
405            inner: RefCell::new(ConnectionInner {
406                schema,
407                active_txn: ActiveTxn::None,
408                savepoint_stack: Vec::new(),
409                in_place_saved: None,
410                stmt_cache,
411                txn_start_ts: None,
412                session_timezone: "UTC".to_string(),
413                temp_id,
414                temp_table_names: Vec::new(),
415            }),
416        })
417    }
418
419    /// Txn-start UTC µs inside BEGIN/COMMIT, else `None`.
420    pub fn txn_start_ts(&self) -> Option<i64> {
421        self.inner.borrow().txn_start_ts
422    }
423
424    /// Returns the session time-zone (IANA name or fixed offset). Default `"UTC"`.
425    pub fn session_timezone(&self) -> String {
426        self.inner.borrow().session_timezone.clone()
427    }
428
429    /// Set the session time-zone. Accepts IANA names, ISO-8601 offsets, `"UTC"`, `"Z"`.
430    pub fn set_session_timezone(&self, tz: &str) -> Result<()> {
431        self.inner.borrow_mut().set_session_timezone_impl(tz)
432    }
433
434    pub fn execute(&self, sql: &str) -> Result<ExecutionResult> {
435        self.inner.borrow_mut().execute_impl(self.db, sql)
436    }
437
438    pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
439        self.inner
440            .borrow_mut()
441            .execute_params_impl(self.db, sql, params)
442    }
443
444    /// Execute `;`-separated SQL statements. Stops at the first failure.
445    pub fn execute_script(&self, sql: &str) -> ScriptExecution {
446        let stmts = match parser::parse_sql_multi(sql) {
447            Ok(s) => s,
448            Err(e) => {
449                return ScriptExecution {
450                    completed: vec![],
451                    error: Some(e),
452                }
453            }
454        };
455        let mut completed = Vec::with_capacity(stmts.len());
456        for stmt in stmts {
457            match self.inner.borrow_mut().dispatch(self.db, &stmt, &[]) {
458                Ok(r) => completed.push(r),
459                Err(e) => {
460                    return ScriptExecution {
461                        completed,
462                        error: Some(e),
463                    }
464                }
465            }
466        }
467        ScriptExecution {
468            completed,
469            error: None,
470        }
471    }
472
473    pub fn query(&self, sql: &str) -> Result<QueryResult> {
474        self.query_params(sql, &[])
475    }
476
477    pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
478        match self.execute_params(sql, params)? {
479            ExecutionResult::Query(qr) => Ok(qr),
480            ExecutionResult::RowsAffected(n) => Ok(QueryResult {
481                columns: vec!["rows_affected".into()],
482                rows: vec![vec![Value::Integer(n as i64)]],
483            }),
484            ExecutionResult::Ok => Ok(QueryResult {
485                columns: vec![],
486                rows: vec![],
487            }),
488        }
489    }
490
491    pub fn prepare(&self, sql: &str) -> Result<PreparedStatement<'_, 'a>> {
492        if let Some(rewritten) = rewrite_show_triggers(sql) {
493            return PreparedStatement::new(self, &rewritten);
494        }
495        if let Some(rewritten) = rewrite_show_matviews(sql) {
496            return PreparedStatement::new(self, &rewritten);
497        }
498        PreparedStatement::new(self, sql)
499    }
500
501    pub fn tables(&self) -> Vec<String> {
502        self.inner
503            .borrow()
504            .schema
505            .table_names()
506            .into_iter()
507            .map(String::from)
508            .collect()
509    }
510
511    /// Returns true if an explicit transaction is active (BEGIN was issued).
512    pub fn in_transaction(&self) -> bool {
513        self.inner.borrow().active_txn.is_active()
514    }
515
516    pub fn table_schema(&self, name: &str) -> Option<TableSchema> {
517        self.inner.borrow().schema.get(name).cloned()
518    }
519
520    pub fn refresh_schema(&self) -> Result<()> {
521        let new_schema = SchemaManager::load(self.db)?;
522        self.inner.borrow_mut().schema = new_schema;
523        Ok(())
524    }
525}
526
527impl<'a> ConnectionInner<'a> {
528    pub(crate) fn active_txn_is_some(&self) -> bool {
529        self.active_txn.is_active()
530    }
531
532    fn set_session_timezone_impl(&mut self, tz: &str) -> Result<()> {
533        let upper = tz.to_ascii_uppercase();
534        if (upper.starts_with("UTC+") || upper.starts_with("UTC-")) && tz.len() > 3 {
535            return Err(SqlError::InvalidTimezone(format!(
536                "'{tz}' is ambiguous; use ISO-8601 offset (e.g. '+05:00') or named zone (e.g. 'Etc/GMT-5')"
537            )));
538        }
539        if jiff::tz::TimeZone::get(tz).is_err() && parse_fixed_offset(tz).is_none() {
540            return Err(SqlError::InvalidTimezone(format!(
541                "{tz}: not a known IANA zone or ISO-8601 offset (e.g. '+05:00', 'UTC', 'America/New_York')"
542            )));
543        }
544        self.session_timezone = tz.to_string();
545        Ok(())
546    }
547
548    fn execute_impl(&mut self, db: &'a Database, sql: &str) -> Result<ExecutionResult> {
549        if let Some(rewritten) = rewrite_show_triggers(sql) {
550            return self.execute_params_impl(db, &rewritten, &[]);
551        }
552        if let Some(rewritten) = rewrite_show_matviews(sql) {
553            return self.execute_params_impl(db, &rewritten, &[]);
554        }
555        if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
556            if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
557                let gen = self.schema.generation();
558                let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
559                    if entry.schema_gen == gen {
560                        Arc::clone(&entry.stmt)
561                    } else {
562                        self.parse_and_cache(normalized_key, gen)?
563                    }
564                } else {
565                    self.parse_and_cache(normalized_key, gen)?
566                };
567                return self.dispatch(db, &stmt, &extracted);
568            }
569        }
570        self.execute_params_impl(db, sql, &[])
571    }
572
573    fn execute_params_impl(
574        &mut self,
575        db: &'a Database,
576        sql: &str,
577        params: &[Value],
578    ) -> Result<ExecutionResult> {
579        let gen = self.schema.generation();
580        if self.active_txn.is_none() {
581            if let Some(entry) = self.stmt_cache.get(sql) {
582                if entry.schema_gen == gen && entry.param_count == params.len() {
583                    if let Some(plan) = entry.compiled.as_ref().map(Arc::clone) {
584                        let stmt = Arc::clone(&entry.stmt);
585                        return self.run_compiled(db, &plan, &stmt, params);
586                    }
587                }
588            }
589        }
590
591        let (stmt, param_count) = self.get_or_parse(sql)?;
592
593        if param_count != params.len() {
594            return Err(SqlError::ParameterCountMismatch {
595                expected: param_count,
596                got: params.len(),
597            });
598        }
599
600        if self.active_txn.is_none() {
601            if let Some(plan) = executor::compile(&self.schema, &stmt) {
602                if let Some(e) = self.stmt_cache.get_mut(sql) {
603                    e.compiled = Some(Arc::clone(&plan));
604                }
605                let stmt_owned = Arc::clone(&stmt);
606                return self.run_compiled(db, &plan, &stmt_owned, params);
607            }
608        }
609
610        self.dispatch(db, &stmt, params)
611    }
612
613    fn run_compiled(
614        &mut self,
615        db: &'a Database,
616        plan: &Arc<dyn executor::CompiledPlan>,
617        stmt: &Statement,
618        params: &[Value],
619    ) -> Result<ExecutionResult> {
620        let schema = &self.schema;
621        let exec = || {
622            if params.is_empty() {
623                plan.execute(db, schema, stmt, params, None)
624            } else {
625                crate::eval::with_scoped_params(params, || {
626                    plan.execute(db, schema, stmt, params, None)
627                })
628            }
629        };
630        if plan.needs_txn_clock() {
631            let cached_ts = self
632                .txn_start_ts
633                .or_else(|| Some(crate::datetime::now_micros()));
634            crate::datetime::with_txn_clock(cached_ts, exec)
635        } else {
636            exec()
637        }
638    }
639
640    pub(crate) fn parse_and_cache(
641        &mut self,
642        normalized_key: String,
643        gen: u64,
644    ) -> Result<Arc<Statement>> {
645        let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
646        let param_count = parser::count_params(&stmt);
647        self.stmt_cache.put(
648            normalized_key,
649            CacheEntry {
650                stmt: Arc::clone(&stmt),
651                schema_gen: gen,
652                param_count,
653                compiled: None,
654            },
655        );
656        Ok(stmt)
657    }
658
659    pub(crate) fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
660        let gen = self.schema.generation();
661
662        if let Some(entry) = self.stmt_cache.get(sql) {
663            if entry.schema_gen == gen {
664                return Ok((Arc::clone(&entry.stmt), entry.param_count));
665            }
666        }
667
668        let stmt = Arc::new(parser::parse_sql(sql)?);
669        let param_count = parser::count_params(&stmt);
670
671        let cacheable = !matches!(
672            *stmt,
673            Statement::CreateTable(_)
674                | Statement::DropTable(_)
675                | Statement::CreateIndex(_)
676                | Statement::DropIndex(_)
677                | Statement::CreateView(_)
678                | Statement::DropView(_)
679                | Statement::AlterTable(_)
680        );
681
682        if cacheable {
683            self.stmt_cache.put(
684                sql.to_string(),
685                CacheEntry {
686                    stmt: Arc::clone(&stmt),
687                    schema_gen: gen,
688                    param_count,
689                    compiled: None,
690                },
691            );
692        }
693
694        Ok((stmt, param_count))
695    }
696
697    pub(crate) fn execute_prepared(
698        &mut self,
699        db: &'a Database,
700        stmt: &Statement,
701        compiled: Option<&Arc<dyn executor::CompiledPlan>>,
702        params: &[Value],
703    ) -> Result<ExecutionResult> {
704        if let Some(plan) = compiled {
705            if self.active_txn.is_none() {
706                return self.run_compiled(db, plan, stmt, params);
707            }
708            if !self.savepoint_stack.is_empty() && stmt_mutates(stmt) {
709                self.capture_pending_snapshots();
710            }
711            return self.run_compiled_in_txn(db, plan, stmt, params);
712        }
713        self.dispatch(db, stmt, params)
714    }
715
716    fn run_compiled_in_txn(
717        &mut self,
718        db: &'a Database,
719        plan: &Arc<dyn executor::CompiledPlan>,
720        stmt: &Statement,
721        params: &[Value],
722    ) -> Result<ExecutionResult> {
723        let schema = &self.schema;
724        let wtx = self.active_txn.as_write_mut();
725        if params.is_empty() || !plan.uses_scoped_params() {
726            plan.execute(db, schema, stmt, params, wtx)
727        } else {
728            crate::eval::with_scoped_params(params, || plan.execute(db, schema, stmt, params, wtx))
729        }
730    }
731
732    pub(crate) fn dispatch(
733        &mut self,
734        db: &'a Database,
735        stmt: &Statement,
736        params: &[Value],
737    ) -> Result<ExecutionResult> {
738        let cached_ts = self
739            .txn_start_ts
740            .or_else(|| Some(crate::datetime::now_micros()));
741        crate::datetime::with_txn_clock(cached_ts, || {
742            if params.is_empty() {
743                self.dispatch_inner(db, stmt, params)
744            } else {
745                crate::eval::with_scoped_params(params, || self.dispatch_inner(db, stmt, params))
746            }
747        })
748    }
749
750    fn dispatch_inner(
751        &mut self,
752        db: &'a Database,
753        stmt: &Statement,
754        params: &[Value],
755    ) -> Result<ExecutionResult> {
756        match stmt {
757            Statement::Begin { access_mode } => {
758                if self.active_txn.is_active() {
759                    return Err(SqlError::TransactionAlreadyActive);
760                }
761                let ts = crate::datetime::txn_or_clock_micros();
762                match access_mode {
763                    BeginAccessMode::ReadOnly => {
764                        let rtx = db.begin_read();
765                        self.active_txn = ActiveTxn::Read(rtx);
766                    }
767                    BeginAccessMode::ReadWrite | BeginAccessMode::Default => {
768                        let wtx = db.begin_write().map_err(SqlError::Storage)?;
769                        self.active_txn = ActiveTxn::Write(wtx);
770                    }
771                }
772                self.txn_start_ts = Some(ts);
773                crate::datetime::set_txn_clock(Some(ts));
774                Ok(ExecutionResult::Ok)
775            }
776            Statement::Commit => {
777                match self.active_txn.take() {
778                    ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
779                    ActiveTxn::Write(mut wtx) => {
780                        crate::executor::helpers::drain_deferred_fk_checks(&mut wtx)?;
781                        wtx.commit().map_err(SqlError::Storage)?;
782                    }
783                    ActiveTxn::Read(_rtx) => {}
784                }
785                self.clear_savepoint_state();
786                self.txn_start_ts = None;
787                crate::datetime::set_txn_clock(None);
788                Ok(ExecutionResult::Ok)
789            }
790            Statement::Rollback => {
791                match self.active_txn.take() {
792                    ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
793                    ActiveTxn::Write(wtx) => {
794                        wtx.abort();
795                        self.schema = SchemaManager::load(db)?;
796                    }
797                    ActiveTxn::Read(_rtx) => {}
798                }
799                self.clear_savepoint_state();
800                self.txn_start_ts = None;
801                crate::datetime::set_txn_clock(None);
802                Ok(ExecutionResult::Ok)
803            }
804            Statement::Savepoint(name) => self.do_savepoint(name),
805            Statement::ReleaseSavepoint(name) => self.do_release(name),
806            Statement::RollbackTo(name) => self.do_rollback_to(name),
807            Statement::SetTimezone(zone) => {
808                self.set_session_timezone_impl(zone)?;
809                Ok(ExecutionResult::Ok)
810            }
811            Statement::CreateTable(ct) if ct.temporary => {
812                if self.active_txn.is_read_only() {
813                    return Err(SqlError::Unsupported(
814                        "cannot execute mutating statement inside a read-only transaction".into(),
815                    ));
816                }
817                let user_name = ct.name.clone();
818                let prefixed = temp_storage_name(self.temp_id, &user_name);
819                if self.schema.contains(&user_name) {
820                    if ct.if_not_exists {
821                        return Ok(ExecutionResult::Ok);
822                    }
823                    return Err(SqlError::TableAlreadyExists(user_name));
824                }
825                let mut clone = ct.clone();
826                clone.name = prefixed.clone();
827                clone.temporary = false;
828                let stmt_concrete = Statement::CreateTable(clone);
829                let outcome = if let Some(wtx) = self.active_txn.as_write_mut() {
830                    executor::execute_in_txn(wtx, &mut self.schema, &stmt_concrete, params)?
831                } else {
832                    executor::execute(db, &mut self.schema, &stmt_concrete, params)?
833                };
834                self.schema
835                    .register_temp_alias(&user_name, prefixed.clone());
836                self.temp_table_names.push(prefixed);
837                Ok(outcome)
838            }
839            Statement::Insert(ins) if self.active_txn.as_write_mut().is_some() => {
840                self.capture_pending_snapshots();
841                let wtx = self.active_txn.as_write_mut().unwrap();
842                executor::exec_insert_in_txn(wtx, &self.schema, ins, params)
843            }
844            _ => {
845                if self.active_txn.is_read_only() && stmt_mutates(stmt) {
846                    return Err(SqlError::Unsupported(
847                        "cannot execute mutating statement inside a read-only transaction".into(),
848                    ));
849                }
850                if self.active_txn.as_write_mut().is_some() && stmt_mutates(stmt) {
851                    self.capture_pending_snapshots();
852                }
853                let outcome = if let Some(wtx) = self.active_txn.as_write_mut() {
854                    executor::execute_in_txn(wtx, &mut self.schema, stmt, params)?
855                } else {
856                    executor::execute(db, &mut self.schema, stmt, params)?
857                };
858                if let Statement::DropTable(dt) = stmt {
859                    self.schema.unregister_temp_alias(&dt.name);
860                }
861                Ok(outcome)
862            }
863        }
864    }
865
866    fn clear_savepoint_state(&mut self) {
867        self.savepoint_stack.clear();
868        self.in_place_saved = None;
869    }
870
871    fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
872        let wtx = self
873            .active_txn
874            .as_write_mut()
875            .ok_or(SqlError::NoActiveTransaction)?;
876
877        if self.savepoint_stack.is_empty() {
878            self.in_place_saved = Some(wtx.in_place());
879            wtx.set_in_place(false);
880        }
881
882        self.savepoint_stack.push(SavepointEntry {
883            name: name.to_string(),
884            snapshot: None,
885        });
886
887        Ok(ExecutionResult::Ok)
888    }
889
890    fn capture_pending_snapshots(&mut self) {
891        let last_pending = match self
892            .savepoint_stack
893            .iter()
894            .rposition(|e| e.snapshot.is_none())
895        {
896            Some(i) => i,
897            None => return,
898        };
899        let wtx = match self.active_txn.as_write_mut() {
900            Some(w) => w,
901            None => return,
902        };
903        let wtx_snap = wtx.begin_savepoint();
904        let schema_snap = self.schema.save_snapshot();
905
906        for i in 0..last_pending {
907            if self.savepoint_stack[i].snapshot.is_none() {
908                self.savepoint_stack[i].snapshot = Some(SavepointSnapshot {
909                    wtx_snap: wtx_snap.clone(),
910                    schema_snap: schema_snap.clone(),
911                });
912            }
913        }
914        self.savepoint_stack[last_pending].snapshot = Some(SavepointSnapshot {
915            wtx_snap,
916            schema_snap,
917        });
918    }
919
920    fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
921        if !self.active_txn.is_active() {
922            return Err(SqlError::NoActiveTransaction);
923        }
924
925        let idx = self
926            .savepoint_stack
927            .iter()
928            .rposition(|e| e.name == name)
929            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
930        self.savepoint_stack.truncate(idx);
931
932        if self.savepoint_stack.is_empty() {
933            if let (Some(wtx), Some(original)) =
934                (self.active_txn.as_write_mut(), self.in_place_saved.take())
935            {
936                wtx.set_in_place(original);
937            }
938        }
939
940        Ok(ExecutionResult::Ok)
941    }
942
943    fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
944        if !self.active_txn.is_active() {
945            return Err(SqlError::NoActiveTransaction);
946        }
947
948        let idx = self
949            .savepoint_stack
950            .iter()
951            .rposition(|e| e.name == name)
952            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
953
954        self.savepoint_stack.truncate(idx + 1);
955        let entry = self.savepoint_stack.last_mut().unwrap();
956        let snapshot = match entry.snapshot.take() {
957            Some(s) => s,
958            None => return Ok(ExecutionResult::Ok),
959        };
960
961        let wtx = match self.active_txn.as_write_mut() {
962            Some(w) => w,
963            None => return Err(SqlError::NoActiveTransaction),
964        };
965        wtx.restore_snapshot(snapshot.wtx_snap);
966        self.schema.restore_snapshot(snapshot.schema_snap);
967
968        Ok(ExecutionResult::Ok)
969    }
970}
971
972impl<'a> Drop for Connection<'a> {
973    fn drop(&mut self) {
974        let temp_names = std::mem::take(&mut self.inner.borrow_mut().temp_table_names);
975        if temp_names.is_empty() {
976            return;
977        }
978        if let Ok(mut wtx) = self.db.begin_write() {
979            for prefixed in &temp_names {
980                let _ = wtx.drop_table(prefixed.as_bytes());
981            }
982            let _ = wtx.commit();
983        }
984    }
985}