Skip to main content

citadel_sql/
connection.rs

1//! Public SQL connection API.
2
3use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7
8fn generate_temp_id() -> u64 {
9    static COUNTER: AtomicU64 = AtomicU64::new(0);
10    let counter = COUNTER.fetch_add(1, Ordering::Relaxed);
11    let nanos = (crate::datetime::now_micros() as u64) & 0xFFFF_FFFF;
12    (nanos << 32) | (counter & 0xFFFF_FFFF)
13}
14
/// Storage-level name for a TEMP table: `__temp_<temp_id>_<name>`, with the
/// user-visible name ASCII-lowercased.
fn temp_storage_name(temp_id: u64, user_name: &str) -> String {
    let folded = user_name.to_ascii_lowercase();
    format!("__temp_{temp_id}_{folded}")
}
18
19use lru::LruCache;
20
21use citadel::Database;
22use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
23
24use crate::error::{Result, SqlError};
25use crate::executor;
26use crate::parser;
27use crate::parser::{BeginAccessMode, QueryBody, SelectQuery, Statement};
28use crate::prepared::PreparedStatement;
29use crate::schema::{SchemaManager, SchemaSnapshot};
30use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
31
32const DEFAULT_CACHE_CAPACITY: usize = 64;
33
34/// On commit, evict shared caches (e.g. ANN indexes) for DML-touched tables,
35/// and stamp each table's last-DML generation marker: an index whose snapshot
36/// predates the marker is refused at lookup AND at insert, closing the
37/// build-races-a-commit window that prefix eviction alone leaves open.
38fn invalidate_dml_caches(schema: &SchemaManager, db: &Database) {
39    let gen = db.manager().commit_generation();
40    let hard_invalidate = |table: &str| {
41        db.sql_cache_invalidate_prefix(&format!("ann:{table}:"));
42        let marker: std::sync::Arc<dyn std::any::Any + Send + Sync> = std::sync::Arc::new(gen);
43        schema
44            .sql_caches
45            .lock()
46            .insert(crate::executor::ann_dml_gen_key(table), marker);
47    };
48
49    let dirty = schema.drain_dml_dirty();
50    for table in &dirty.mutating {
51        hard_invalidate(table);
52    }
53    // A pure append keeps the cached index (recall tail-merges the new rows)
54    // unless it landed at/below an index snapshot.
55    for (table, min_pk) in &dirty.appends {
56        if crate::executor::ann_appends_safe(schema, table, *min_pk) {
57            continue;
58        }
59        hard_invalidate(table);
60    }
61}
62
63#[derive(Debug)]
64pub struct ScriptExecution {
65    pub completed: Vec<ExecutionResult>,
66    pub error: Option<SqlError>,
67}
68
/// Validates a fixed UTC offset such as `+05:00`, `-0530`, `+09`, `Z` or `UTC`.
///
/// Accepted forms (after trimming surrounding whitespace):
/// - `Z` / `UTC` (ASCII case-insensitive);
/// - a sign (`+`/`-`) followed by `HH:MM` (each field 1–2 ASCII digits),
///   `HHMM`, or `HH`.
///
/// Hours must be at most 23 and minutes at most 59. Returns `Some(())` when `s`
/// is a valid offset and `None` otherwise; the offset value itself is not
/// returned. Never panics, including on non-ASCII input.
fn parse_fixed_offset(s: &str) -> Option<()> {
    /// Parses one offset field of 1–2 ASCII digits. Checked explicitly because
    /// `u32::from_str` would also accept a leading `+` (e.g. `++05:00`).
    fn field(part: &str) -> Option<u32> {
        if part.is_empty() || part.len() > 2 || !part.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        part.parse().ok()
    }

    let s = s.trim();
    if s.eq_ignore_ascii_case("z") || s.eq_ignore_ascii_case("utc") {
        return Some(());
    }
    let rest = s.strip_prefix('+').or_else(|| s.strip_prefix('-'))?;
    let (hh, mm) = match rest.split_once(':') {
        Some(pair) => pair,
        // `is_ascii` guarantees byte index 2 is a char boundary (no slicing panic).
        None if rest.len() == 4 && rest.is_ascii() => rest.split_at(2),
        None if rest.len() == 2 => (rest, "00"),
        None => return None,
    };
    let hours = field(hh)?;
    let minutes = field(mm)?;
    (hours <= 23 && minutes <= 59).then_some(())
}
98
/// Rewrites `SHOW TRIGGERS [ON <table>]` into a query over
/// `information_schema.triggers`, ordered by trigger name.
///
/// Matching is ASCII case-insensitive and ignores surrounding whitespace and
/// trailing semicolons. `TRIGGERS` must end at a word boundary, and `ON` may be
/// followed by any ASCII whitespace. The table name may be a double-quoted
/// identifier (`""` inside the quotes is a literal `"`). It is compared
/// case-insensitively via `LOWER(...)` and embedded as a string literal with
/// single quotes doubled.
///
/// Returns `None` when `sql` is not a `SHOW TRIGGERS` statement or the `ON`
/// clause names no table.
fn rewrite_show_triggers(sql: &str) -> Option<String> {
    const BASE: &str = "SELECT trigger_name, event_object_table AS table_name, action_timing, \
                        event_manipulation, action_orientation, action_statement \
                        FROM information_schema.triggers";
    let is_ws = |c: char| c.is_ascii_whitespace();

    let lower = sql.trim().trim_end_matches(';').trim().to_ascii_lowercase();
    let after = lower.strip_prefix("show triggers")?;
    // Word boundary: reject e.g. `show triggerson x`.
    if !after.is_empty() && !after.starts_with(is_ws) {
        return None;
    }
    let after = after.trim_start();
    if after.is_empty() {
        return Some(format!("{BASE} ORDER BY trigger_name"));
    }

    let rest = after.strip_prefix("on")?;
    if !rest.starts_with(is_ws) {
        return None;
    }
    let raw = rest.trim().trim_end_matches(';').trim();
    // A quoted identifier would otherwise be compared with its quotes and never match.
    let table = match raw.strip_prefix('"').and_then(|t| t.strip_suffix('"')) {
        Some(inner) => inner.replace("\"\"", "\""),
        None => raw.to_string(),
    };
    if table.is_empty() {
        return None;
    }
    let escaped = table.replace('\'', "''");
    Some(format!(
        "{BASE} WHERE LOWER(event_object_table) = LOWER('{escaped}') ORDER BY trigger_name"
    ))
}
125
/// Rewrites `SHOW MATERIALIZED VIEWS` into a query over `pg_matviews` ordered
/// by view name.
///
/// The keyword sequence must match exactly (single spaces), ASCII
/// case-insensitively; surrounding whitespace and trailing `;` are ignored.
/// Any other input yields `None`.
fn rewrite_show_matviews(sql: &str) -> Option<String> {
    let statement = sql.trim().trim_end_matches(';').trim();
    if !statement.eq_ignore_ascii_case("show materialized views") {
        return None;
    }
    let query = concat!(
        "SELECT matviewname, ispopulated, hasindexes, definition ",
        "FROM pg_matviews ORDER BY matviewname"
    );
    Some(query.to_owned())
}
139
140fn stmt_mutates(stmt: &Statement) -> bool {
141    if matches!(
142        stmt,
143        Statement::Insert(_)
144            | Statement::Update(_)
145            | Statement::Delete(_)
146            | Statement::Truncate(_)
147            | Statement::CreateTable(_)
148            | Statement::DropTable(_)
149            | Statement::AlterTable(_)
150            | Statement::CreateIndex(_)
151            | Statement::DropIndex(_)
152            | Statement::CreateView(_)
153            | Statement::DropView(_)
154            | Statement::CreateTrigger(_)
155            | Statement::DropTrigger(_)
156            | Statement::CreateMaterializedView(_)
157            | Statement::RefreshMaterializedView(_)
158            | Statement::DropMaterializedView(_)
159    ) {
160        return true;
161    }
162    if let Statement::Select(sq) = stmt {
163        if select_query_has_dml(sq) {
164            return true;
165        }
166    }
167    false
168}
169
170fn select_query_has_dml(sq: &SelectQuery) -> bool {
171    sq.ctes.iter().any(|cte| query_body_has_dml(&cte.body)) || query_body_has_dml(&sq.body)
172}
173
174fn query_body_has_dml(body: &QueryBody) -> bool {
175    match body {
176        QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => true,
177        QueryBody::Compound(c) => query_body_has_dml(&c.left) || query_body_has_dml(&c.right),
178        QueryBody::Select(_) => false,
179    }
180}
181
182fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
183    let bytes = sql.as_bytes();
184    let len = bytes.len();
185    let mut i = 0;
186
187    while i < len && bytes[i].is_ascii_whitespace() {
188        i += 1;
189    }
190    if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
191        return None;
192    }
193    i += 6;
194    if i >= len || !bytes[i].is_ascii_whitespace() {
195        return None;
196    }
197    while i < len && bytes[i].is_ascii_whitespace() {
198        i += 1;
199    }
200
201    if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
202        return None;
203    }
204    i += 4;
205    if i >= len || !bytes[i].is_ascii_whitespace() {
206        return None;
207    }
208
209    let prefix_start = 0;
210    let mut values_pos = None;
211    let mut j = i;
212    while j + 6 <= len {
213        if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
214            && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
215            && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
216        {
217            values_pos = Some(j);
218            break;
219        }
220        j += 1;
221    }
222    let values_pos = values_pos?;
223
224    let prefix = &sql[prefix_start..values_pos + 6];
225    let mut pos = values_pos + 6;
226
227    while pos < len && bytes[pos].is_ascii_whitespace() {
228        pos += 1;
229    }
230    if pos >= len || bytes[pos] != b'(' {
231        return None;
232    }
233    pos += 1;
234
235    let mut values = Vec::new();
236    let mut normalized = String::with_capacity(sql.len());
237    normalized.push_str(prefix);
238    normalized.push_str(" (");
239
240    loop {
241        while pos < len && bytes[pos].is_ascii_whitespace() {
242            pos += 1;
243        }
244        if pos >= len {
245            return None;
246        }
247
248        let param_idx = values.len() + 1;
249        if param_idx > 1 {
250            normalized.push_str(", ");
251        }
252
253        if bytes[pos] == b'\'' {
254            pos += 1;
255            let mut seg_start = pos;
256            let mut s = String::new();
257            loop {
258                if pos >= len {
259                    return None;
260                }
261                if bytes[pos] == b'\'' {
262                    s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
263                    pos += 1;
264                    if pos < len && bytes[pos] == b'\'' {
265                        s.push('\'');
266                        pos += 1;
267                        seg_start = pos;
268                    } else {
269                        break;
270                    }
271                } else {
272                    pos += 1;
273                }
274            }
275            values.push(Value::Text(s.into()));
276        } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
277            let start = pos;
278            if bytes[pos] == b'-' {
279                pos += 1;
280            }
281            while pos < len && bytes[pos].is_ascii_digit() {
282                pos += 1;
283            }
284            if pos < len && bytes[pos] == b'.' {
285                pos += 1;
286                while pos < len && bytes[pos].is_ascii_digit() {
287                    pos += 1;
288                }
289                let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
290                values.push(Value::Real(num));
291            } else {
292                let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
293                values.push(Value::Integer(num));
294            }
295        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
296            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
297            if !after.is_ascii_alphanumeric() && after != b'_' {
298                pos += 4;
299                values.push(Value::Null);
300            } else {
301                return None;
302            }
303        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
304            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
305            if !after.is_ascii_alphanumeric() && after != b'_' {
306                pos += 4;
307                values.push(Value::Boolean(true));
308            } else {
309                return None;
310            }
311        } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
312            let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
313            if !after.is_ascii_alphanumeric() && after != b'_' {
314                pos += 5;
315                values.push(Value::Boolean(false));
316            } else {
317                return None;
318            }
319        } else {
320            return None;
321        }
322
323        normalized.push('$');
324        normalized.push_str(&param_idx.to_string());
325
326        while pos < len && bytes[pos].is_ascii_whitespace() {
327            pos += 1;
328        }
329        if pos >= len {
330            return None;
331        }
332
333        if bytes[pos] == b',' {
334            pos += 1;
335        } else if bytes[pos] == b')' {
336            pos += 1;
337            break;
338        } else {
339            return None;
340        }
341    }
342
343    normalized.push(')');
344
345    while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
346        pos += 1;
347    }
348    if pos != len {
349        return None;
350    }
351
352    if values.is_empty() {
353        return None;
354    }
355
356    Some((normalized, values))
357}
358
359pub(crate) struct CacheEntry {
360    pub(crate) stmt: Arc<Statement>,
361    pub(crate) schema_gen: u64,
362    pub(crate) param_count: usize,
363    pub(crate) compiled: Option<Arc<dyn executor::CompiledPlan>>,
364}
365
366struct SavepointEntry {
367    name: String,
368    snapshot: Option<SavepointSnapshot>,
369}
370
371struct SavepointSnapshot {
372    wtx_snap: WriteTxnSnapshot,
373    schema_snap: SchemaSnapshot,
374}
375
376/// Active transaction held by a Connection. `None` outside BEGIN/COMMIT; `Write` for normal
377/// BEGIN (or BEGIN READ WRITE); `Read` for BEGIN READ ONLY.
378#[allow(clippy::large_enum_variant)]
379pub(crate) enum ActiveTxn<'a> {
380    None,
381    Write(WriteTxn<'a>),
382    Read(citadel_txn::read_txn::ReadTxn<'a>),
383}
384
385impl<'a> ActiveTxn<'a> {
386    fn is_none(&self) -> bool {
387        matches!(self, ActiveTxn::None)
388    }
389    fn is_active(&self) -> bool {
390        !self.is_none()
391    }
392    fn is_read_only(&self) -> bool {
393        matches!(self, ActiveTxn::Read(_))
394    }
395    fn as_write_mut(&mut self) -> Option<&mut WriteTxn<'a>> {
396        match self {
397            ActiveTxn::Write(w) => Some(w),
398            _ => None,
399        }
400    }
401    fn take(&mut self) -> ActiveTxn<'a> {
402        std::mem::replace(self, ActiveTxn::None)
403    }
404}
405
406pub(crate) struct ConnectionInner<'a> {
407    pub(crate) schema: SchemaManager,
408    active_txn: ActiveTxn<'a>,
409    savepoint_stack: Vec<SavepointEntry>,
410    in_place_saved: Option<bool>,
411    pub(crate) stmt_cache: LruCache<String, CacheEntry>,
412    txn_start_ts: Option<i64>,
413    session_timezone: String,
414    /// Namespaces TEMP tables as `__temp_<id>_<name>`. Cleaned up on Connection drop.
415    temp_id: u64,
416    temp_table_names: Vec<String>,
417}
418
419pub struct Connection<'a> {
420    pub(crate) db: &'a Database,
421    pub(crate) inner: RefCell<ConnectionInner<'a>>,
422}
423
424impl<'a> Connection<'a> {
425    pub fn open(db: &'a Database) -> Result<Self> {
426        let schema = SchemaManager::load(db)?;
427        let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
428        let temp_id = generate_temp_id();
429        Ok(Self {
430            db,
431            inner: RefCell::new(ConnectionInner {
432                schema,
433                active_txn: ActiveTxn::None,
434                savepoint_stack: Vec::new(),
435                in_place_saved: None,
436                stmt_cache,
437                txn_start_ts: None,
438                session_timezone: "UTC".to_string(),
439                temp_id,
440                temp_table_names: Vec::new(),
441            }),
442        })
443    }
444
445    /// Txn-start UTC µs inside BEGIN/COMMIT, else `None`.
446    pub fn txn_start_ts(&self) -> Option<i64> {
447        self.inner.borrow().txn_start_ts
448    }
449
450    /// Returns the session time-zone (IANA name or fixed offset). Default `"UTC"`.
451    pub fn session_timezone(&self) -> String {
452        self.inner.borrow().session_timezone.clone()
453    }
454
455    /// Set the session time-zone. Accepts IANA names, ISO-8601 offsets, `"UTC"`, `"Z"`.
456    pub fn set_session_timezone(&self, tz: &str) -> Result<()> {
457        self.inner.borrow_mut().set_session_timezone_impl(tz)
458    }
459
460    pub fn execute(&self, sql: &str) -> Result<ExecutionResult> {
461        self.inner.borrow_mut().execute_impl(self.db, sql)
462    }
463
464    pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
465        self.inner
466            .borrow_mut()
467            .execute_params_impl(self.db, sql, params)
468    }
469
470    /// Execute `;`-separated SQL statements. Stops at the first failure.
471    pub fn execute_script(&self, sql: &str) -> ScriptExecution {
472        let stmts = match parser::parse_sql_multi(sql) {
473            Ok(s) => s,
474            Err(e) => {
475                return ScriptExecution {
476                    completed: vec![],
477                    error: Some(e),
478                }
479            }
480        };
481        let mut completed = Vec::with_capacity(stmts.len());
482        for stmt in stmts {
483            match self.inner.borrow_mut().dispatch(self.db, &stmt, &[]) {
484                Ok(r) => completed.push(r),
485                Err(e) => {
486                    return ScriptExecution {
487                        completed,
488                        error: Some(e),
489                    }
490                }
491            }
492        }
493        ScriptExecution {
494            completed,
495            error: None,
496        }
497    }
498
499    pub fn query(&self, sql: &str) -> Result<QueryResult> {
500        self.query_params(sql, &[])
501    }
502
503    pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
504        match self.execute_params(sql, params)? {
505            ExecutionResult::Query(qr) => Ok(qr),
506            ExecutionResult::RowsAffected(n) => Ok(QueryResult {
507                columns: vec!["rows_affected".into()],
508                rows: vec![vec![Value::Integer(n as i64)]],
509            }),
510            ExecutionResult::Ok => Ok(QueryResult {
511                columns: vec![],
512                rows: vec![],
513            }),
514        }
515    }
516
517    pub fn prepare(&self, sql: &str) -> Result<PreparedStatement<'_, 'a>> {
518        if let Some(rewritten) = rewrite_show_triggers(sql) {
519            return PreparedStatement::new(self, &rewritten);
520        }
521        if let Some(rewritten) = rewrite_show_matviews(sql) {
522            return PreparedStatement::new(self, &rewritten);
523        }
524        PreparedStatement::new(self, sql)
525    }
526
527    pub fn tables(&self) -> Vec<String> {
528        self.inner
529            .borrow()
530            .schema
531            .table_names()
532            .into_iter()
533            .map(String::from)
534            .collect()
535    }
536
537    /// Returns true if an explicit transaction is active (BEGIN was issued).
538    pub fn in_transaction(&self) -> bool {
539        self.inner.borrow().active_txn.is_active()
540    }
541
542    pub fn table_schema(&self, name: &str) -> Option<TableSchema> {
543        self.inner.borrow().schema.get(name).cloned()
544    }
545
546    pub fn refresh_schema(&self) -> Result<()> {
547        let new_schema = SchemaManager::load(self.db)?;
548        self.inner.borrow_mut().schema = new_schema;
549        Ok(())
550    }
551
552    /// Freeze the ANN index for `table.column` into a persisted segment: one
553    /// write txn scans, builds, serializes, and commits atomically; subsequent
554    /// cold attaches LOAD it (seconds) instead of rebuilding (minutes), with
555    /// the load-time scan re-proving freshness by content. The single writer
556    /// lock is held for the whole build - an offline/builder operation.
557    /// Refused inside an explicit transaction (it owns its own txn), and for
558    /// TEMP tables (their storage bypasses the DDL paths that purge segments).
559    pub fn persist_ann_index(
560        &self,
561        table: &str,
562        column: &str,
563    ) -> Result<crate::executor::AnnSegmentInfo> {
564        if self.in_transaction() {
565            return Err(SqlError::InvalidValue(
566                "persist_ann_index: not allowed inside an explicit transaction".into(),
567            ));
568        }
569        let inner = self.inner.borrow();
570        let lower = table.to_ascii_lowercase();
571        if inner.schema.resolve_temp(&lower) != lower {
572            return Err(SqlError::InvalidValue(
573                "persist_ann_index: TEMP tables are not persistable".into(),
574            ));
575        }
576        let table_schema = inner
577            .schema
578            .get(&lower)
579            .ok_or_else(|| SqlError::TableNotFound(table.to_string()))?;
580        crate::executor::persist_ann_index(self.db, &inner.schema, table_schema, column)
581    }
582
583    /// The identity of the index currently cached for `table.column`:
584    /// `(source, snapshot generation)` - `Loaded{segment_b3}` means queries are
585    /// served by the persisted segment; `Built{refusal}` carries why a segment
586    /// was rejected, if one was.
587    pub fn ann_cache_status(
588        &self,
589        table: &str,
590        column: &str,
591    ) -> Result<Option<(crate::executor::AnnIndexSource, u64)>> {
592        let inner = self.inner.borrow();
593        let table_schema = inner
594            .schema
595            .get(&table.to_ascii_lowercase())
596            .ok_or_else(|| SqlError::TableNotFound(table.to_string()))?;
597        crate::executor::ann_cache_status(&inner.schema, table_schema, column)
598    }
599}
600
601impl<'a> ConnectionInner<'a> {
602    pub(crate) fn active_txn_is_some(&self) -> bool {
603        self.active_txn.is_active()
604    }
605
606    fn set_session_timezone_impl(&mut self, tz: &str) -> Result<()> {
607        let upper = tz.to_ascii_uppercase();
608        if (upper.starts_with("UTC+") || upper.starts_with("UTC-")) && tz.len() > 3 {
609            return Err(SqlError::InvalidTimezone(format!(
610                "'{tz}' is ambiguous; use ISO-8601 offset (e.g. '+05:00') or named zone (e.g. 'Etc/GMT-5')"
611            )));
612        }
613        if jiff::tz::TimeZone::get(tz).is_err() && parse_fixed_offset(tz).is_none() {
614            return Err(SqlError::InvalidTimezone(format!(
615                "{tz}: not a known IANA zone or ISO-8601 offset (e.g. '+05:00', 'UTC', 'America/New_York')"
616            )));
617        }
618        self.session_timezone = tz.to_string();
619        Ok(())
620    }
621
622    fn execute_impl(&mut self, db: &'a Database, sql: &str) -> Result<ExecutionResult> {
623        if let Some(rewritten) = rewrite_show_triggers(sql) {
624            return self.execute_params_impl(db, &rewritten, &[]);
625        }
626        if let Some(rewritten) = rewrite_show_matviews(sql) {
627            return self.execute_params_impl(db, &rewritten, &[]);
628        }
629        if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
630            if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
631                let gen = self.schema.generation();
632                let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
633                    if entry.schema_gen == gen {
634                        Arc::clone(&entry.stmt)
635                    } else {
636                        self.parse_and_cache(normalized_key, gen)?
637                    }
638                } else {
639                    self.parse_and_cache(normalized_key, gen)?
640                };
641                return self.dispatch(db, &stmt, &extracted);
642            }
643        }
644        self.execute_params_impl(db, sql, &[])
645    }
646
647    fn execute_params_impl(
648        &mut self,
649        db: &'a Database,
650        sql: &str,
651        params: &[Value],
652    ) -> Result<ExecutionResult> {
653        let gen = self.schema.generation();
654        if self.active_txn.is_none() {
655            if let Some(entry) = self.stmt_cache.get(sql) {
656                if entry.schema_gen == gen && entry.param_count == params.len() {
657                    if let Some(plan) = entry.compiled.as_ref().map(Arc::clone) {
658                        let stmt = Arc::clone(&entry.stmt);
659                        return self.run_compiled(db, &plan, &stmt, params);
660                    }
661                }
662            }
663        }
664
665        let (stmt, param_count) = self.get_or_parse(sql)?;
666
667        if param_count != params.len() {
668            return Err(SqlError::ParameterCountMismatch {
669                expected: param_count,
670                got: params.len(),
671            });
672        }
673
674        if self.active_txn.is_none() {
675            if let Some(plan) = executor::compile(&self.schema, &stmt) {
676                if let Some(e) = self.stmt_cache.get_mut(sql) {
677                    e.compiled = Some(Arc::clone(&plan));
678                }
679                let stmt_owned = Arc::clone(&stmt);
680                return self.run_compiled(db, &plan, &stmt_owned, params);
681            }
682        }
683
684        self.dispatch(db, &stmt, params)
685    }
686
687    fn run_compiled(
688        &mut self,
689        db: &'a Database,
690        plan: &Arc<dyn executor::CompiledPlan>,
691        stmt: &Statement,
692        params: &[Value],
693    ) -> Result<ExecutionResult> {
694        use executor::compile::ActiveTxnRef;
695        let schema = &self.schema;
696        let exec = || {
697            if params.is_empty() {
698                plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
699            } else {
700                crate::eval::with_scoped_params(params, || {
701                    plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
702                })
703            }
704        };
705        let outcome = if plan.needs_txn_clock() {
706            let cached_ts = self
707                .txn_start_ts
708                .or_else(|| Some(crate::datetime::now_micros()));
709            crate::datetime::with_txn_clock(cached_ts, exec)
710        } else {
711            exec()
712        }?;
713        invalidate_dml_caches(&self.schema, db);
714        Ok(outcome)
715    }
716
717    pub(crate) fn parse_and_cache(
718        &mut self,
719        normalized_key: String,
720        gen: u64,
721    ) -> Result<Arc<Statement>> {
722        let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
723        let param_count = parser::count_params(&stmt);
724        self.stmt_cache.put(
725            normalized_key,
726            CacheEntry {
727                stmt: Arc::clone(&stmt),
728                schema_gen: gen,
729                param_count,
730                compiled: None,
731            },
732        );
733        Ok(stmt)
734    }
735
736    pub(crate) fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
737        let gen = self.schema.generation();
738
739        if let Some(entry) = self.stmt_cache.get(sql) {
740            if entry.schema_gen == gen {
741                return Ok((Arc::clone(&entry.stmt), entry.param_count));
742            }
743        }
744
745        let stmt = Arc::new(parser::parse_sql(sql)?);
746        let param_count = parser::count_params(&stmt);
747
748        let cacheable = !matches!(
749            *stmt,
750            Statement::CreateTable(_)
751                | Statement::DropTable(_)
752                | Statement::CreateIndex(_)
753                | Statement::DropIndex(_)
754                | Statement::CreateView(_)
755                | Statement::DropView(_)
756                | Statement::AlterTable(_)
757        );
758
759        if cacheable {
760            self.stmt_cache.put(
761                sql.to_string(),
762                CacheEntry {
763                    stmt: Arc::clone(&stmt),
764                    schema_gen: gen,
765                    param_count,
766                    compiled: None,
767                },
768            );
769        }
770
771        Ok((stmt, param_count))
772    }
773
774    pub(crate) fn execute_prepared(
775        &mut self,
776        db: &'a Database,
777        stmt: &Statement,
778        compiled: Option<&Arc<dyn executor::CompiledPlan>>,
779        params: &[Value],
780    ) -> Result<ExecutionResult> {
781        if let Some(plan) = compiled {
782            if self.active_txn.is_none() {
783                return self.run_compiled(db, plan, stmt, params);
784            }
785            if !self.savepoint_stack.is_empty() && stmt_mutates(stmt) {
786                self.capture_pending_snapshots();
787            }
788            return self.run_compiled_in_txn(db, plan, stmt, params);
789        }
790        self.dispatch(db, stmt, params)
791    }
792
793    fn run_compiled_in_txn(
794        &mut self,
795        db: &'a Database,
796        plan: &Arc<dyn executor::CompiledPlan>,
797        stmt: &Statement,
798        params: &[Value],
799    ) -> Result<ExecutionResult> {
800        use executor::compile::ActiveTxnRef;
801        let schema = &self.schema;
802        let txn = match &mut self.active_txn {
803            ActiveTxn::Write(wtx) => ActiveTxnRef::Write(wtx),
804            ActiveTxn::Read(rtx) => ActiveTxnRef::Read(rtx),
805            ActiveTxn::None => ActiveTxnRef::None,
806        };
807        if params.is_empty() || !plan.uses_scoped_params() {
808            plan.execute(db, schema, stmt, params, txn)
809        } else {
810            crate::eval::with_scoped_params(params, || plan.execute(db, schema, stmt, params, txn))
811        }
812    }
813
814    pub(crate) fn dispatch(
815        &mut self,
816        db: &'a Database,
817        stmt: &Statement,
818        params: &[Value],
819    ) -> Result<ExecutionResult> {
820        let cached_ts = self
821            .txn_start_ts
822            .or_else(|| Some(crate::datetime::now_micros()));
823        crate::datetime::with_txn_clock(cached_ts, || {
824            if params.is_empty() {
825                self.dispatch_inner(db, stmt, params)
826            } else {
827                crate::eval::with_scoped_params(params, || self.dispatch_inner(db, stmt, params))
828            }
829        })
830    }
831
832    fn dispatch_inner(
833        &mut self,
834        db: &'a Database,
835        stmt: &Statement,
836        params: &[Value],
837    ) -> Result<ExecutionResult> {
838        match stmt {
839            Statement::Begin { access_mode } => {
840                if self.active_txn.is_active() {
841                    return Err(SqlError::TransactionAlreadyActive);
842                }
843                let ts = crate::datetime::txn_or_clock_micros();
844                match access_mode {
845                    BeginAccessMode::ReadOnly => {
846                        let rtx = db.begin_read();
847                        self.active_txn = ActiveTxn::Read(rtx);
848                    }
849                    BeginAccessMode::ReadWrite | BeginAccessMode::Default => {
850                        let wtx = db.begin_write().map_err(SqlError::Storage)?;
851                        self.active_txn = ActiveTxn::Write(wtx);
852                    }
853                }
854                self.txn_start_ts = Some(ts);
855                crate::datetime::set_txn_clock(Some(ts));
856                Ok(ExecutionResult::Ok)
857            }
858            Statement::Commit => {
859                match self.active_txn.take() {
860                    ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
861                    ActiveTxn::Write(mut wtx) => {
862                        crate::executor::helpers::drain_deferred_fk_checks(&mut wtx)?;
863                        wtx.commit().map_err(SqlError::Storage)?;
864                        invalidate_dml_caches(&self.schema, db);
865                    }
866                    ActiveTxn::Read(_rtx) => {}
867                }
868                self.clear_savepoint_state();
869                self.txn_start_ts = None;
870                crate::datetime::set_txn_clock(None);
871                Ok(ExecutionResult::Ok)
872            }
873            Statement::Rollback => {
874                match self.active_txn.take() {
875                    ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
876                    ActiveTxn::Write(wtx) => {
877                        wtx.abort();
878                        self.schema = SchemaManager::load(db)?;
879                    }
880                    ActiveTxn::Read(_rtx) => {}
881                }
882                self.clear_savepoint_state();
883                self.txn_start_ts = None;
884                crate::datetime::set_txn_clock(None);
885                Ok(ExecutionResult::Ok)
886            }
887            Statement::Savepoint(name) => self.do_savepoint(name),
888            Statement::ReleaseSavepoint(name) => self.do_release(name),
889            Statement::RollbackTo(name) => self.do_rollback_to(name),
890            Statement::SetTimezone(zone) => {
891                self.set_session_timezone_impl(zone)?;
892                Ok(ExecutionResult::Ok)
893            }
894            Statement::CreateTable(ct) if ct.temporary => {
895                if self.active_txn.is_read_only() {
896                    return Err(SqlError::Unsupported(
897                        "cannot execute mutating statement inside a read-only transaction".into(),
898                    ));
899                }
900                let user_name = ct.name.clone();
901                let prefixed = temp_storage_name(self.temp_id, &user_name);
902                if self.schema.contains(&user_name) {
903                    if ct.if_not_exists {
904                        return Ok(ExecutionResult::Ok);
905                    }
906                    return Err(SqlError::TableAlreadyExists(user_name));
907                }
908                let mut clone = ct.clone();
909                clone.name = prefixed.clone();
910                clone.temporary = false;
911                let stmt_concrete = Statement::CreateTable(clone);
912                let outcome = if let Some(wtx) = self.active_txn.as_write_mut() {
913                    executor::execute_in_txn(wtx, &mut self.schema, &stmt_concrete, params)?
914                } else {
915                    executor::execute(db, &mut self.schema, &stmt_concrete, params)?
916                };
917                self.schema
918                    .register_temp_alias(&user_name, prefixed.clone());
919                self.temp_table_names.push(prefixed);
920                Ok(outcome)
921            }
922            Statement::Insert(ins) if self.active_txn.as_write_mut().is_some() => {
923                self.capture_pending_snapshots();
924                let wtx = self.active_txn.as_write_mut().unwrap();
925                executor::exec_insert_in_txn(wtx, &self.schema, ins, params)
926            }
927            _ => {
928                if self.active_txn.is_read_only() && stmt_mutates(stmt) {
929                    return Err(SqlError::Unsupported(
930                        "cannot execute mutating statement inside a read-only transaction".into(),
931                    ));
932                }
933                if self.active_txn.as_write_mut().is_some() && stmt_mutates(stmt) {
934                    self.capture_pending_snapshots();
935                }
936                let was_auto_commit = matches!(self.active_txn, ActiveTxn::None);
937                let outcome = match &mut self.active_txn {
938                    ActiveTxn::Write(wtx) => {
939                        executor::execute_in_txn(wtx, &mut self.schema, stmt, params)?
940                    }
941                    ActiveTxn::Read(rtx) => {
942                        executor::execute_with_read(rtx, &self.schema, stmt, params)?
943                    }
944                    ActiveTxn::None => executor::execute(db, &mut self.schema, stmt, params)?,
945                };
946                if was_auto_commit {
947                    invalidate_dml_caches(&self.schema, db);
948                }
949                if let Statement::DropTable(dt) = stmt {
950                    self.schema.unregister_temp_alias(&dt.name);
951                }
952                Ok(outcome)
953            }
954        }
955    }
956
957    fn clear_savepoint_state(&mut self) {
958        self.savepoint_stack.clear();
959        self.in_place_saved = None;
960    }
961
962    fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
963        let wtx = self
964            .active_txn
965            .as_write_mut()
966            .ok_or(SqlError::NoActiveTransaction)?;
967
968        if self.savepoint_stack.is_empty() {
969            self.in_place_saved = Some(wtx.in_place());
970            wtx.set_in_place(false);
971        }
972
973        self.savepoint_stack.push(SavepointEntry {
974            name: name.to_string(),
975            snapshot: None,
976        });
977
978        Ok(ExecutionResult::Ok)
979    }
980
981    fn capture_pending_snapshots(&mut self) {
982        let last_pending = match self
983            .savepoint_stack
984            .iter()
985            .rposition(|e| e.snapshot.is_none())
986        {
987            Some(i) => i,
988            None => return,
989        };
990        let wtx = match self.active_txn.as_write_mut() {
991            Some(w) => w,
992            None => return,
993        };
994        let wtx_snap = wtx.begin_savepoint();
995        let schema_snap = self.schema.save_snapshot();
996
997        for i in 0..last_pending {
998            if self.savepoint_stack[i].snapshot.is_none() {
999                self.savepoint_stack[i].snapshot = Some(SavepointSnapshot {
1000                    wtx_snap: wtx_snap.clone(),
1001                    schema_snap: schema_snap.clone(),
1002                });
1003            }
1004        }
1005        self.savepoint_stack[last_pending].snapshot = Some(SavepointSnapshot {
1006            wtx_snap,
1007            schema_snap,
1008        });
1009    }
1010
1011    fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
1012        if !self.active_txn.is_active() {
1013            return Err(SqlError::NoActiveTransaction);
1014        }
1015
1016        let idx = self
1017            .savepoint_stack
1018            .iter()
1019            .rposition(|e| e.name == name)
1020            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
1021        self.savepoint_stack.truncate(idx);
1022
1023        if self.savepoint_stack.is_empty() {
1024            if let (Some(wtx), Some(original)) =
1025                (self.active_txn.as_write_mut(), self.in_place_saved.take())
1026            {
1027                wtx.set_in_place(original);
1028            }
1029        }
1030
1031        Ok(ExecutionResult::Ok)
1032    }
1033
1034    fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
1035        if !self.active_txn.is_active() {
1036            return Err(SqlError::NoActiveTransaction);
1037        }
1038
1039        let idx = self
1040            .savepoint_stack
1041            .iter()
1042            .rposition(|e| e.name == name)
1043            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
1044
1045        self.savepoint_stack.truncate(idx + 1);
1046        let entry = self.savepoint_stack.last_mut().unwrap();
1047        let snapshot = match entry.snapshot.take() {
1048            Some(s) => s,
1049            None => return Ok(ExecutionResult::Ok),
1050        };
1051
1052        let wtx = match self.active_txn.as_write_mut() {
1053            Some(w) => w,
1054            None => return Err(SqlError::NoActiveTransaction),
1055        };
1056        wtx.restore_snapshot(snapshot.wtx_snap);
1057        self.schema.restore_snapshot(snapshot.schema_snap);
1058
1059        Ok(ExecutionResult::Ok)
1060    }
1061}
1062
1063impl<'a> Drop for Connection<'a> {
1064    fn drop(&mut self) {
1065        let temp_names = std::mem::take(&mut self.inner.borrow_mut().temp_table_names);
1066        if temp_names.is_empty() {
1067            return;
1068        }
1069        if let Ok(mut wtx) = self.db.begin_write() {
1070            for prefixed in &temp_names {
1071                let _ = wtx.drop_table(prefixed.as_bytes());
1072            }
1073            let _ = wtx.commit();
1074        }
1075    }
1076}