Skip to main content

citadel_sql/
connection.rs

1//! Public SQL connection API.
2
3use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7
8fn generate_temp_id() -> u64 {
9    static COUNTER: AtomicU64 = AtomicU64::new(0);
10    let counter = COUNTER.fetch_add(1, Ordering::Relaxed);
11    let nanos = (crate::datetime::now_micros() as u64) & 0xFFFF_FFFF;
12    (nanos << 32) | (counter & 0xFFFF_FFFF)
13}
14
/// Returns the storage-level name of a TEMP table:
/// `__temp_<temp_id>_<user_name, ASCII-lowercased>`.
fn temp_storage_name(temp_id: u64, user_name: &str) -> String {
    let folded = user_name.to_ascii_lowercase();
    format!("__temp_{temp_id}_{folded}")
}
18
19use lru::LruCache;
20
21use citadel::Database;
22use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
23
24use crate::error::{Result, SqlError};
25use crate::executor;
26use crate::parser;
27use crate::parser::{BeginAccessMode, QueryBody, SelectQuery, Statement};
28use crate::prepared::PreparedStatement;
29use crate::schema::{SchemaManager, SchemaSnapshot};
30use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
31
32const DEFAULT_CACHE_CAPACITY: usize = 64;
33
34/// On commit, evict shared caches (e.g. ANN indexes) for DML-touched tables,
35/// and stamp each table's last-DML generation marker: an index whose snapshot
36/// predates the marker is refused at lookup AND at insert, closing the
37/// build-races-a-commit window that prefix eviction alone leaves open.
38fn invalidate_dml_caches(schema: &SchemaManager, db: &Database) {
39    let gen = db.manager().commit_generation();
40    for table in schema.drain_dml_dirty() {
41        db.sql_cache_invalidate_prefix(&format!("ann:{table}:"));
42        let marker: std::sync::Arc<dyn std::any::Any + Send + Sync> = std::sync::Arc::new(gen);
43        schema
44            .sql_caches
45            .lock()
46            .insert(crate::executor::ann_dml_gen_key(&table), marker);
47    }
48}
49
/// Outcome of running a `;`-separated script with `Connection::execute_script`.
#[derive(Debug)]
pub struct ScriptExecution {
    /// Results of the statements that ran successfully, in execution order.
    pub completed: Vec<ExecutionResult>,
    /// The parse or execution error that stopped the script; `None` when every
    /// statement succeeded.
    pub error: Option<SqlError>,
}
55
/// Checks whether `s` is a syntactically valid fixed UTC offset.
///
/// Accepted forms (surrounding whitespace is ignored):
/// * `Z` or `UTC`, case-insensitively;
/// * a `+`/`-` sign followed by `HH:MM`, `HHMM`, or `HH`.
///
/// Both fields must consist solely of ASCII digits, with hours in `0..=23`
/// and minutes in `0..=59`.
///
/// Returns `Some(())` when `s` is valid and `None` otherwise. Never panics,
/// including on non-ASCII input.
fn parse_fixed_offset(s: &str) -> Option<()> {
    let s = s.trim();
    if s.eq_ignore_ascii_case("z") || s.eq_ignore_ascii_case("utc") {
        return Some(());
    }

    let rest = s.strip_prefix('+').or_else(|| s.strip_prefix('-'))?;

    // The compact `HHMM` form is split by byte offset, which would panic in
    // the middle of a multi-byte character. No valid offset is non-ASCII.
    if !rest.is_ascii() {
        return None;
    }

    let (hh, mm) = match rest.split_once(':') {
        Some(pair) => pair,
        None => match rest.len() {
            4 => rest.split_at(2),
            2 => (rest, "00"),
            _ => return None,
        },
    };

    // `u32::from_str` tolerates a leading '+', so "++5:+30" would otherwise
    // slip through; insist on bare digits.
    let all_digits = |field: &str| !field.is_empty() && field.bytes().all(|b| b.is_ascii_digit());
    if !all_digits(hh) || !all_digits(mm) {
        return None;
    }

    let h: u32 = hh.parse().ok()?;
    let m: u32 = mm.parse().ok()?;
    if h > 23 || m > 59 {
        return None;
    }
    Some(())
}
85
/// Rewrites `SHOW TRIGGERS [ON <table>]` into a query over
/// `information_schema.triggers`.
///
/// Matching is ASCII case-insensitive and ignores surrounding whitespace and
/// trailing semicolons. Returns `None` when `sql` is not a `SHOW TRIGGERS`
/// statement, so the caller can fall through to the regular parser.
///
/// * `SHOW TRIGGERS` lists every trigger, ordered by name.
/// * `SHOW TRIGGERS ON t` restricts the listing to table `t`. The comparison
///   is case-insensitive and single quotes in the name are doubled before it
///   is embedded in the generated SQL literal.
///
/// `TRIGGERS`, `ON` and the table name must be separated by whitespace, so
/// `SHOW TRIGGERSON t` is NOT treated as `SHOW TRIGGERS ON t`. Any whitespace
/// (space, tab, newline) is accepted as a separator.
fn rewrite_show_triggers(sql: &str) -> Option<String> {
    const BASE: &str = "SELECT trigger_name, event_object_table AS table_name, action_timing, \
                        event_manipulation, action_orientation, action_statement \
                        FROM information_schema.triggers";

    let trimmed = sql.trim().trim_end_matches(';').trim();
    // ASCII lowercasing keeps byte offsets identical to `trimmed`.
    let lower = trimmed.to_ascii_lowercase();

    let tail = lower.strip_prefix("show triggers")?;
    if tail.is_empty() {
        return Some(format!("{BASE} ORDER BY trigger_name"));
    }
    // Word boundary after TRIGGERS: rejects e.g. "show triggersx" / "show triggerson t".
    if !tail.starts_with(char::is_whitespace) {
        return None;
    }

    let after_on = tail.trim_start().strip_prefix("on")?;
    // Word boundary after ON: rejects e.g. "show triggers only".
    if !after_on.starts_with(char::is_whitespace) {
        return None;
    }

    // `lower` was already stripped of trailing ';' and whitespace above.
    let table = after_on.trim();
    if table.is_empty() {
        return None;
    }
    let escaped = table.replace('\'', "''");
    Some(format!(
        "{BASE} WHERE LOWER(event_object_table) = LOWER('{escaped}') ORDER BY trigger_name"
    ))
}
112
/// Rewrites `SHOW MATERIALIZED VIEWS` into a query over `pg_matviews`.
///
/// The statement is matched ASCII case-insensitively after stripping
/// surrounding whitespace and trailing semicolons; anything else yields `None`.
fn rewrite_show_matviews(sql: &str) -> Option<String> {
    const QUERY: &str = "SELECT matviewname, ispopulated, hasindexes, definition \
                         FROM pg_matviews ORDER BY matviewname";

    let statement = sql.trim().trim_end_matches(';').trim();
    if statement.eq_ignore_ascii_case("show materialized views") {
        Some(QUERY.to_string())
    } else {
        None
    }
}
126
127fn stmt_mutates(stmt: &Statement) -> bool {
128    if matches!(
129        stmt,
130        Statement::Insert(_)
131            | Statement::Update(_)
132            | Statement::Delete(_)
133            | Statement::Truncate(_)
134            | Statement::CreateTable(_)
135            | Statement::DropTable(_)
136            | Statement::AlterTable(_)
137            | Statement::CreateIndex(_)
138            | Statement::DropIndex(_)
139            | Statement::CreateView(_)
140            | Statement::DropView(_)
141            | Statement::CreateTrigger(_)
142            | Statement::DropTrigger(_)
143            | Statement::CreateMaterializedView(_)
144            | Statement::RefreshMaterializedView(_)
145            | Statement::DropMaterializedView(_)
146    ) {
147        return true;
148    }
149    if let Statement::Select(sq) = stmt {
150        if select_query_has_dml(sq) {
151            return true;
152        }
153    }
154    false
155}
156
157fn select_query_has_dml(sq: &SelectQuery) -> bool {
158    sq.ctes.iter().any(|cte| query_body_has_dml(&cte.body)) || query_body_has_dml(&sq.body)
159}
160
161fn query_body_has_dml(body: &QueryBody) -> bool {
162    match body {
163        QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => true,
164        QueryBody::Compound(c) => query_body_has_dml(&c.left) || query_body_has_dml(&c.right),
165        QueryBody::Select(_) => false,
166    }
167}
168
/// Fast-path normalizer for simple single-row `INSERT INTO ... VALUES (...)`
/// statements made only of literals.
///
/// On success returns the statement text with every literal replaced by a
/// positional placeholder (`$1`, `$2`, ...) together with the extracted
/// literal values in order, so that inserts differing only in their literals
/// share one normalized text.
///
/// Returns `None` (caller should fall back to the regular path) when the text
/// is not of the supported shape: anything other than `INSERT INTO`, no
/// `VALUES` keyword, a value that is not a single-quoted string / plain
/// decimal number / `NULL` / `TRUE` / `FALSE`, an empty value list, a number
/// that does not fit `i64`/`f64` parsing, or any trailing content other than
/// whitespace and `;` after the closing parenthesis (e.g. a second row).
fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut i = 0;

    // Leading whitespace, then the keyword INSERT followed by whitespace.
    while i < len && bytes[i].is_ascii_whitespace() {
        i += 1;
    }
    if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
        return None;
    }
    i += 6;
    if i >= len || !bytes[i].is_ascii_whitespace() {
        return None;
    }
    while i < len && bytes[i].is_ascii_whitespace() {
        i += 1;
    }

    // The keyword INTO followed by whitespace.
    if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
        return None;
    }
    i += 4;
    if i >= len || !bytes[i].is_ascii_whitespace() {
        return None;
    }

    // Find the first standalone VALUES keyword (not part of a longer
    // identifier such as `values_col`). The scan is purely textual: it does
    // not skip quoted identifiers or string literals.
    let prefix_start = 0;
    let mut values_pos = None;
    let mut j = i;
    while j + 6 <= len {
        if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
            && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
            && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
        {
            values_pos = Some(j);
            break;
        }
        j += 1;
    }
    let values_pos = values_pos?;

    // Everything up to and including VALUES is copied verbatim (including any
    // leading whitespace and the original keyword casing).
    let prefix = &sql[prefix_start..values_pos + 6];
    let mut pos = values_pos + 6;

    while pos < len && bytes[pos].is_ascii_whitespace() {
        pos += 1;
    }
    if pos >= len || bytes[pos] != b'(' {
        return None;
    }
    pos += 1;

    let mut values = Vec::new();
    let mut normalized = String::with_capacity(sql.len());
    normalized.push_str(prefix);
    normalized.push_str(" (");

    // One iteration per literal in the value list.
    loop {
        while pos < len && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        if pos >= len {
            return None;
        }

        let param_idx = values.len() + 1;
        if param_idx > 1 {
            normalized.push_str(", ");
        }

        if bytes[pos] == b'\'' {
            // Single-quoted string; `''` inside is an escaped quote. Segments
            // are cut at ASCII quote bytes, so they are always valid UTF-8
            // boundaries of the input.
            pos += 1;
            let mut seg_start = pos;
            let mut s = String::new();
            loop {
                if pos >= len {
                    // Unterminated string literal.
                    return None;
                }
                if bytes[pos] == b'\'' {
                    s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
                    pos += 1;
                    if pos < len && bytes[pos] == b'\'' {
                        s.push('\'');
                        pos += 1;
                        seg_start = pos;
                    } else {
                        break;
                    }
                } else {
                    pos += 1;
                }
            }
            values.push(Value::Text(s.into()));
        } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
            // Optional '-', digits, optional '.' + digits. No exponent form;
            // a parse failure (lone '-', overflow) bails out with `None`.
            let start = pos;
            if bytes[pos] == b'-' {
                pos += 1;
            }
            while pos < len && bytes[pos].is_ascii_digit() {
                pos += 1;
            }
            if pos < len && bytes[pos] == b'.' {
                pos += 1;
                while pos < len && bytes[pos].is_ascii_digit() {
                    pos += 1;
                }
                let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
                values.push(Value::Real(num));
            } else {
                let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
                values.push(Value::Integer(num));
            }
        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
            // Keyword must not be the prefix of a longer identifier; end of
            // input is treated like a delimiter here and rejected below.
            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 4;
                values.push(Value::Null);
            } else {
                return None;
            }
        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 4;
                values.push(Value::Boolean(true));
            } else {
                return None;
            }
        } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
            let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 5;
                values.push(Value::Boolean(false));
            } else {
                return None;
            }
        } else {
            // Expressions, placeholders, function calls, etc. are not handled.
            return None;
        }

        normalized.push('$');
        normalized.push_str(&param_idx.to_string());

        while pos < len && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        if pos >= len {
            return None;
        }

        // After a literal only ',' (next value) or ')' (end of row) is valid.
        if bytes[pos] == b',' {
            pos += 1;
        } else if bytes[pos] == b')' {
            pos += 1;
            break;
        } else {
            return None;
        }
    }

    normalized.push(')');

    // Only whitespace and semicolons may follow the row; this rejects
    // multi-row inserts and any trailing clause.
    while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
        pos += 1;
    }
    if pos != len {
        return None;
    }

    if values.is_empty() {
        return None;
    }

    Some((normalized, values))
}
345
/// One entry of the per-connection statement cache, keyed by SQL text.
pub(crate) struct CacheEntry {
    /// The parsed statement.
    pub(crate) stmt: Arc<Statement>,
    /// Schema generation at the time the statement was parsed; lookups compare
    /// it with the current generation and treat a mismatch as a miss.
    pub(crate) schema_gen: u64,
    /// Number of parameters the statement expects, as reported by
    /// `parser::count_params`.
    pub(crate) param_count: usize,
    /// Compiled plan, if one has been produced for this statement; starts as
    /// `None` when the entry is inserted.
    pub(crate) compiled: Option<Arc<dyn executor::CompiledPlan>>,
}
352
/// One element of a connection's savepoint stack.
struct SavepointEntry {
    /// Savepoint name.
    name: String,
    /// State captured for this savepoint, if any.
    // NOTE(review): presumably left `None` until the first mutating statement
    // after SAVEPOINT (see the `capture_pending_snapshots` call sites) — confirm.
    snapshot: Option<SavepointSnapshot>,
}
357
/// State captured for a savepoint: a write-transaction snapshot paired with a
/// schema snapshot.
// NOTE(review): presumably both are restored together on ROLLBACK TO — confirm.
struct SavepointSnapshot {
    /// Snapshot of the write transaction.
    wtx_snap: WriteTxnSnapshot,
    /// Snapshot of the schema manager.
    schema_snap: SchemaSnapshot,
}
362
/// Active transaction held by a Connection. `None` outside BEGIN/COMMIT; `Write` for normal
/// BEGIN (or BEGIN READ WRITE); `Read` for BEGIN READ ONLY.
// NOTE(review): the lint is silenced rather than boxing a variant; presumably
// the transaction handles differ considerably in size — confirm this is intentional.
#[allow(clippy::large_enum_variant)]
pub(crate) enum ActiveTxn<'a> {
    /// No explicit transaction is open.
    None,
    /// A read-write transaction.
    Write(WriteTxn<'a>),
    /// A read-only transaction.
    Read(citadel_txn::read_txn::ReadTxn<'a>),
}
371
372impl<'a> ActiveTxn<'a> {
373    fn is_none(&self) -> bool {
374        matches!(self, ActiveTxn::None)
375    }
376    fn is_active(&self) -> bool {
377        !self.is_none()
378    }
379    fn is_read_only(&self) -> bool {
380        matches!(self, ActiveTxn::Read(_))
381    }
382    fn as_write_mut(&mut self) -> Option<&mut WriteTxn<'a>> {
383        match self {
384            ActiveTxn::Write(w) => Some(w),
385            _ => None,
386        }
387    }
388    fn take(&mut self) -> ActiveTxn<'a> {
389        std::mem::replace(self, ActiveTxn::None)
390    }
391}
392
/// Mutable per-connection state; `Connection` keeps it behind a `RefCell`.
pub(crate) struct ConnectionInner<'a> {
    /// This connection's view of the schema.
    pub(crate) schema: SchemaManager,
    /// The explicit transaction currently open, if any.
    active_txn: ActiveTxn<'a>,
    /// Savepoints of the active transaction.
    // NOTE(review): presumably innermost last — confirm against do_savepoint.
    savepoint_stack: Vec<SavepointEntry>,
    // NOTE(review): purpose not visible in this part of the file; initialised
    // to `None` when the connection is opened.
    in_place_saved: Option<bool>,
    /// LRU cache of parsed (and optionally compiled) statements keyed by SQL text.
    pub(crate) stmt_cache: LruCache<String, CacheEntry>,
    /// Start timestamp of the explicit transaction; `None` outside BEGIN/COMMIT.
    txn_start_ts: Option<i64>,
    /// Session time-zone; `"UTC"` on a freshly opened connection.
    session_timezone: String,
    /// Namespaces TEMP tables as `__temp_<id>_<name>`. Cleaned up on Connection drop.
    temp_id: u64,
    /// Storage names (`__temp_<id>_<name>`) of the TEMP tables this connection created.
    temp_table_names: Vec<String>,
}
405
/// A SQL connection over a borrowed [`Database`].
///
/// All methods take `&self`; the mutable state lives in a `RefCell`, which
/// makes this type `!Sync`.
pub struct Connection<'a> {
    /// The database this connection operates on.
    pub(crate) db: &'a Database,
    /// Interior-mutable connection state (schema, transaction, caches).
    pub(crate) inner: RefCell<ConnectionInner<'a>>,
}
410
411impl<'a> Connection<'a> {
412    pub fn open(db: &'a Database) -> Result<Self> {
413        let schema = SchemaManager::load(db)?;
414        let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
415        let temp_id = generate_temp_id();
416        Ok(Self {
417            db,
418            inner: RefCell::new(ConnectionInner {
419                schema,
420                active_txn: ActiveTxn::None,
421                savepoint_stack: Vec::new(),
422                in_place_saved: None,
423                stmt_cache,
424                txn_start_ts: None,
425                session_timezone: "UTC".to_string(),
426                temp_id,
427                temp_table_names: Vec::new(),
428            }),
429        })
430    }
431
432    /// Txn-start UTC µs inside BEGIN/COMMIT, else `None`.
433    pub fn txn_start_ts(&self) -> Option<i64> {
434        self.inner.borrow().txn_start_ts
435    }
436
437    /// Returns the session time-zone (IANA name or fixed offset). Default `"UTC"`.
438    pub fn session_timezone(&self) -> String {
439        self.inner.borrow().session_timezone.clone()
440    }
441
442    /// Set the session time-zone. Accepts IANA names, ISO-8601 offsets, `"UTC"`, `"Z"`.
443    pub fn set_session_timezone(&self, tz: &str) -> Result<()> {
444        self.inner.borrow_mut().set_session_timezone_impl(tz)
445    }
446
447    pub fn execute(&self, sql: &str) -> Result<ExecutionResult> {
448        self.inner.borrow_mut().execute_impl(self.db, sql)
449    }
450
451    pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
452        self.inner
453            .borrow_mut()
454            .execute_params_impl(self.db, sql, params)
455    }
456
457    /// Execute `;`-separated SQL statements. Stops at the first failure.
458    pub fn execute_script(&self, sql: &str) -> ScriptExecution {
459        let stmts = match parser::parse_sql_multi(sql) {
460            Ok(s) => s,
461            Err(e) => {
462                return ScriptExecution {
463                    completed: vec![],
464                    error: Some(e),
465                }
466            }
467        };
468        let mut completed = Vec::with_capacity(stmts.len());
469        for stmt in stmts {
470            match self.inner.borrow_mut().dispatch(self.db, &stmt, &[]) {
471                Ok(r) => completed.push(r),
472                Err(e) => {
473                    return ScriptExecution {
474                        completed,
475                        error: Some(e),
476                    }
477                }
478            }
479        }
480        ScriptExecution {
481            completed,
482            error: None,
483        }
484    }
485
486    pub fn query(&self, sql: &str) -> Result<QueryResult> {
487        self.query_params(sql, &[])
488    }
489
490    pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
491        match self.execute_params(sql, params)? {
492            ExecutionResult::Query(qr) => Ok(qr),
493            ExecutionResult::RowsAffected(n) => Ok(QueryResult {
494                columns: vec!["rows_affected".into()],
495                rows: vec![vec![Value::Integer(n as i64)]],
496            }),
497            ExecutionResult::Ok => Ok(QueryResult {
498                columns: vec![],
499                rows: vec![],
500            }),
501        }
502    }
503
504    pub fn prepare(&self, sql: &str) -> Result<PreparedStatement<'_, 'a>> {
505        if let Some(rewritten) = rewrite_show_triggers(sql) {
506            return PreparedStatement::new(self, &rewritten);
507        }
508        if let Some(rewritten) = rewrite_show_matviews(sql) {
509            return PreparedStatement::new(self, &rewritten);
510        }
511        PreparedStatement::new(self, sql)
512    }
513
514    pub fn tables(&self) -> Vec<String> {
515        self.inner
516            .borrow()
517            .schema
518            .table_names()
519            .into_iter()
520            .map(String::from)
521            .collect()
522    }
523
524    /// Returns true if an explicit transaction is active (BEGIN was issued).
525    pub fn in_transaction(&self) -> bool {
526        self.inner.borrow().active_txn.is_active()
527    }
528
529    pub fn table_schema(&self, name: &str) -> Option<TableSchema> {
530        self.inner.borrow().schema.get(name).cloned()
531    }
532
533    pub fn refresh_schema(&self) -> Result<()> {
534        let new_schema = SchemaManager::load(self.db)?;
535        self.inner.borrow_mut().schema = new_schema;
536        Ok(())
537    }
538
539    /// Freeze the ANN index for `table.column` into a persisted segment: one
540    /// write txn scans, builds, serializes, and commits atomically; subsequent
541    /// cold attaches LOAD it (seconds) instead of rebuilding (minutes), with
542    /// the load-time scan re-proving freshness by content. The single writer
543    /// lock is held for the whole build - an offline/builder operation.
544    /// Refused inside an explicit transaction (it owns its own txn), and for
545    /// TEMP tables (their storage bypasses the DDL paths that purge segments).
546    pub fn persist_ann_index(
547        &self,
548        table: &str,
549        column: &str,
550    ) -> Result<crate::executor::AnnSegmentInfo> {
551        if self.in_transaction() {
552            return Err(SqlError::InvalidValue(
553                "persist_ann_index: not allowed inside an explicit transaction".into(),
554            ));
555        }
556        let inner = self.inner.borrow();
557        let lower = table.to_ascii_lowercase();
558        if inner.schema.resolve_temp(&lower) != lower {
559            return Err(SqlError::InvalidValue(
560                "persist_ann_index: TEMP tables are not persistable".into(),
561            ));
562        }
563        let table_schema = inner
564            .schema
565            .get(&lower)
566            .ok_or_else(|| SqlError::TableNotFound(table.to_string()))?;
567        crate::executor::persist_ann_index(self.db, &inner.schema, table_schema, column)
568    }
569
570    /// The identity of the index currently cached for `table.column`:
571    /// `(source, snapshot generation)` - `Loaded{segment_b3}` means queries are
572    /// served by the persisted segment; `Built{refusal}` carries why a segment
573    /// was rejected, if one was.
574    pub fn ann_cache_status(
575        &self,
576        table: &str,
577        column: &str,
578    ) -> Result<Option<(crate::executor::AnnIndexSource, u64)>> {
579        let inner = self.inner.borrow();
580        let table_schema = inner
581            .schema
582            .get(&table.to_ascii_lowercase())
583            .ok_or_else(|| SqlError::TableNotFound(table.to_string()))?;
584        crate::executor::ann_cache_status(&inner.schema, table_schema, column)
585    }
586}
587
588impl<'a> ConnectionInner<'a> {
589    pub(crate) fn active_txn_is_some(&self) -> bool {
590        self.active_txn.is_active()
591    }
592
593    fn set_session_timezone_impl(&mut self, tz: &str) -> Result<()> {
594        let upper = tz.to_ascii_uppercase();
595        if (upper.starts_with("UTC+") || upper.starts_with("UTC-")) && tz.len() > 3 {
596            return Err(SqlError::InvalidTimezone(format!(
597                "'{tz}' is ambiguous; use ISO-8601 offset (e.g. '+05:00') or named zone (e.g. 'Etc/GMT-5')"
598            )));
599        }
600        if jiff::tz::TimeZone::get(tz).is_err() && parse_fixed_offset(tz).is_none() {
601            return Err(SqlError::InvalidTimezone(format!(
602                "{tz}: not a known IANA zone or ISO-8601 offset (e.g. '+05:00', 'UTC', 'America/New_York')"
603            )));
604        }
605        self.session_timezone = tz.to_string();
606        Ok(())
607    }
608
609    fn execute_impl(&mut self, db: &'a Database, sql: &str) -> Result<ExecutionResult> {
610        if let Some(rewritten) = rewrite_show_triggers(sql) {
611            return self.execute_params_impl(db, &rewritten, &[]);
612        }
613        if let Some(rewritten) = rewrite_show_matviews(sql) {
614            return self.execute_params_impl(db, &rewritten, &[]);
615        }
616        if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
617            if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
618                let gen = self.schema.generation();
619                let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
620                    if entry.schema_gen == gen {
621                        Arc::clone(&entry.stmt)
622                    } else {
623                        self.parse_and_cache(normalized_key, gen)?
624                    }
625                } else {
626                    self.parse_and_cache(normalized_key, gen)?
627                };
628                return self.dispatch(db, &stmt, &extracted);
629            }
630        }
631        self.execute_params_impl(db, sql, &[])
632    }
633
634    fn execute_params_impl(
635        &mut self,
636        db: &'a Database,
637        sql: &str,
638        params: &[Value],
639    ) -> Result<ExecutionResult> {
640        let gen = self.schema.generation();
641        if self.active_txn.is_none() {
642            if let Some(entry) = self.stmt_cache.get(sql) {
643                if entry.schema_gen == gen && entry.param_count == params.len() {
644                    if let Some(plan) = entry.compiled.as_ref().map(Arc::clone) {
645                        let stmt = Arc::clone(&entry.stmt);
646                        return self.run_compiled(db, &plan, &stmt, params);
647                    }
648                }
649            }
650        }
651
652        let (stmt, param_count) = self.get_or_parse(sql)?;
653
654        if param_count != params.len() {
655            return Err(SqlError::ParameterCountMismatch {
656                expected: param_count,
657                got: params.len(),
658            });
659        }
660
661        if self.active_txn.is_none() {
662            if let Some(plan) = executor::compile(&self.schema, &stmt) {
663                if let Some(e) = self.stmt_cache.get_mut(sql) {
664                    e.compiled = Some(Arc::clone(&plan));
665                }
666                let stmt_owned = Arc::clone(&stmt);
667                return self.run_compiled(db, &plan, &stmt_owned, params);
668            }
669        }
670
671        self.dispatch(db, &stmt, params)
672    }
673
674    fn run_compiled(
675        &mut self,
676        db: &'a Database,
677        plan: &Arc<dyn executor::CompiledPlan>,
678        stmt: &Statement,
679        params: &[Value],
680    ) -> Result<ExecutionResult> {
681        use executor::compile::ActiveTxnRef;
682        let schema = &self.schema;
683        let exec = || {
684            if params.is_empty() {
685                plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
686            } else {
687                crate::eval::with_scoped_params(params, || {
688                    plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
689                })
690            }
691        };
692        let outcome = if plan.needs_txn_clock() {
693            let cached_ts = self
694                .txn_start_ts
695                .or_else(|| Some(crate::datetime::now_micros()));
696            crate::datetime::with_txn_clock(cached_ts, exec)
697        } else {
698            exec()
699        }?;
700        invalidate_dml_caches(&self.schema, db);
701        Ok(outcome)
702    }
703
704    pub(crate) fn parse_and_cache(
705        &mut self,
706        normalized_key: String,
707        gen: u64,
708    ) -> Result<Arc<Statement>> {
709        let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
710        let param_count = parser::count_params(&stmt);
711        self.stmt_cache.put(
712            normalized_key,
713            CacheEntry {
714                stmt: Arc::clone(&stmt),
715                schema_gen: gen,
716                param_count,
717                compiled: None,
718            },
719        );
720        Ok(stmt)
721    }
722
723    pub(crate) fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
724        let gen = self.schema.generation();
725
726        if let Some(entry) = self.stmt_cache.get(sql) {
727            if entry.schema_gen == gen {
728                return Ok((Arc::clone(&entry.stmt), entry.param_count));
729            }
730        }
731
732        let stmt = Arc::new(parser::parse_sql(sql)?);
733        let param_count = parser::count_params(&stmt);
734
735        let cacheable = !matches!(
736            *stmt,
737            Statement::CreateTable(_)
738                | Statement::DropTable(_)
739                | Statement::CreateIndex(_)
740                | Statement::DropIndex(_)
741                | Statement::CreateView(_)
742                | Statement::DropView(_)
743                | Statement::AlterTable(_)
744        );
745
746        if cacheable {
747            self.stmt_cache.put(
748                sql.to_string(),
749                CacheEntry {
750                    stmt: Arc::clone(&stmt),
751                    schema_gen: gen,
752                    param_count,
753                    compiled: None,
754                },
755            );
756        }
757
758        Ok((stmt, param_count))
759    }
760
761    pub(crate) fn execute_prepared(
762        &mut self,
763        db: &'a Database,
764        stmt: &Statement,
765        compiled: Option<&Arc<dyn executor::CompiledPlan>>,
766        params: &[Value],
767    ) -> Result<ExecutionResult> {
768        if let Some(plan) = compiled {
769            if self.active_txn.is_none() {
770                return self.run_compiled(db, plan, stmt, params);
771            }
772            if !self.savepoint_stack.is_empty() && stmt_mutates(stmt) {
773                self.capture_pending_snapshots();
774            }
775            return self.run_compiled_in_txn(db, plan, stmt, params);
776        }
777        self.dispatch(db, stmt, params)
778    }
779
780    fn run_compiled_in_txn(
781        &mut self,
782        db: &'a Database,
783        plan: &Arc<dyn executor::CompiledPlan>,
784        stmt: &Statement,
785        params: &[Value],
786    ) -> Result<ExecutionResult> {
787        use executor::compile::ActiveTxnRef;
788        let schema = &self.schema;
789        let txn = match &mut self.active_txn {
790            ActiveTxn::Write(wtx) => ActiveTxnRef::Write(wtx),
791            ActiveTxn::Read(rtx) => ActiveTxnRef::Read(rtx),
792            ActiveTxn::None => ActiveTxnRef::None,
793        };
794        if params.is_empty() || !plan.uses_scoped_params() {
795            plan.execute(db, schema, stmt, params, txn)
796        } else {
797            crate::eval::with_scoped_params(params, || plan.execute(db, schema, stmt, params, txn))
798        }
799    }
800
801    pub(crate) fn dispatch(
802        &mut self,
803        db: &'a Database,
804        stmt: &Statement,
805        params: &[Value],
806    ) -> Result<ExecutionResult> {
807        let cached_ts = self
808            .txn_start_ts
809            .or_else(|| Some(crate::datetime::now_micros()));
810        crate::datetime::with_txn_clock(cached_ts, || {
811            if params.is_empty() {
812                self.dispatch_inner(db, stmt, params)
813            } else {
814                crate::eval::with_scoped_params(params, || self.dispatch_inner(db, stmt, params))
815            }
816        })
817    }
818
819    fn dispatch_inner(
820        &mut self,
821        db: &'a Database,
822        stmt: &Statement,
823        params: &[Value],
824    ) -> Result<ExecutionResult> {
825        match stmt {
826            Statement::Begin { access_mode } => {
827                if self.active_txn.is_active() {
828                    return Err(SqlError::TransactionAlreadyActive);
829                }
830                let ts = crate::datetime::txn_or_clock_micros();
831                match access_mode {
832                    BeginAccessMode::ReadOnly => {
833                        let rtx = db.begin_read();
834                        self.active_txn = ActiveTxn::Read(rtx);
835                    }
836                    BeginAccessMode::ReadWrite | BeginAccessMode::Default => {
837                        let wtx = db.begin_write().map_err(SqlError::Storage)?;
838                        self.active_txn = ActiveTxn::Write(wtx);
839                    }
840                }
841                self.txn_start_ts = Some(ts);
842                crate::datetime::set_txn_clock(Some(ts));
843                Ok(ExecutionResult::Ok)
844            }
845            Statement::Commit => {
846                match self.active_txn.take() {
847                    ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
848                    ActiveTxn::Write(mut wtx) => {
849                        crate::executor::helpers::drain_deferred_fk_checks(&mut wtx)?;
850                        wtx.commit().map_err(SqlError::Storage)?;
851                        invalidate_dml_caches(&self.schema, db);
852                    }
853                    ActiveTxn::Read(_rtx) => {}
854                }
855                self.clear_savepoint_state();
856                self.txn_start_ts = None;
857                crate::datetime::set_txn_clock(None);
858                Ok(ExecutionResult::Ok)
859            }
860            Statement::Rollback => {
861                match self.active_txn.take() {
862                    ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
863                    ActiveTxn::Write(wtx) => {
864                        wtx.abort();
865                        self.schema = SchemaManager::load(db)?;
866                    }
867                    ActiveTxn::Read(_rtx) => {}
868                }
869                self.clear_savepoint_state();
870                self.txn_start_ts = None;
871                crate::datetime::set_txn_clock(None);
872                Ok(ExecutionResult::Ok)
873            }
874            Statement::Savepoint(name) => self.do_savepoint(name),
875            Statement::ReleaseSavepoint(name) => self.do_release(name),
876            Statement::RollbackTo(name) => self.do_rollback_to(name),
877            Statement::SetTimezone(zone) => {
878                self.set_session_timezone_impl(zone)?;
879                Ok(ExecutionResult::Ok)
880            }
881            Statement::CreateTable(ct) if ct.temporary => {
882                if self.active_txn.is_read_only() {
883                    return Err(SqlError::Unsupported(
884                        "cannot execute mutating statement inside a read-only transaction".into(),
885                    ));
886                }
887                let user_name = ct.name.clone();
888                let prefixed = temp_storage_name(self.temp_id, &user_name);
889                if self.schema.contains(&user_name) {
890                    if ct.if_not_exists {
891                        return Ok(ExecutionResult::Ok);
892                    }
893                    return Err(SqlError::TableAlreadyExists(user_name));
894                }
895                let mut clone = ct.clone();
896                clone.name = prefixed.clone();
897                clone.temporary = false;
898                let stmt_concrete = Statement::CreateTable(clone);
899                let outcome = if let Some(wtx) = self.active_txn.as_write_mut() {
900                    executor::execute_in_txn(wtx, &mut self.schema, &stmt_concrete, params)?
901                } else {
902                    executor::execute(db, &mut self.schema, &stmt_concrete, params)?
903                };
904                self.schema
905                    .register_temp_alias(&user_name, prefixed.clone());
906                self.temp_table_names.push(prefixed);
907                Ok(outcome)
908            }
909            Statement::Insert(ins) if self.active_txn.as_write_mut().is_some() => {
910                self.capture_pending_snapshots();
911                let wtx = self.active_txn.as_write_mut().unwrap();
912                executor::exec_insert_in_txn(wtx, &self.schema, ins, params)
913            }
914            _ => {
915                if self.active_txn.is_read_only() && stmt_mutates(stmt) {
916                    return Err(SqlError::Unsupported(
917                        "cannot execute mutating statement inside a read-only transaction".into(),
918                    ));
919                }
920                if self.active_txn.as_write_mut().is_some() && stmt_mutates(stmt) {
921                    self.capture_pending_snapshots();
922                }
923                let was_auto_commit = matches!(self.active_txn, ActiveTxn::None);
924                let outcome = match &mut self.active_txn {
925                    ActiveTxn::Write(wtx) => {
926                        executor::execute_in_txn(wtx, &mut self.schema, stmt, params)?
927                    }
928                    ActiveTxn::Read(rtx) => {
929                        executor::execute_with_read(rtx, &self.schema, stmt, params)?
930                    }
931                    ActiveTxn::None => executor::execute(db, &mut self.schema, stmt, params)?,
932                };
933                if was_auto_commit {
934                    invalidate_dml_caches(&self.schema, db);
935                }
936                if let Statement::DropTable(dt) = stmt {
937                    self.schema.unregister_temp_alias(&dt.name);
938                }
939                Ok(outcome)
940            }
941        }
942    }
943
944    fn clear_savepoint_state(&mut self) {
945        self.savepoint_stack.clear();
946        self.in_place_saved = None;
947    }
948
949    fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
950        let wtx = self
951            .active_txn
952            .as_write_mut()
953            .ok_or(SqlError::NoActiveTransaction)?;
954
955        if self.savepoint_stack.is_empty() {
956            self.in_place_saved = Some(wtx.in_place());
957            wtx.set_in_place(false);
958        }
959
960        self.savepoint_stack.push(SavepointEntry {
961            name: name.to_string(),
962            snapshot: None,
963        });
964
965        Ok(ExecutionResult::Ok)
966    }
967
968    fn capture_pending_snapshots(&mut self) {
969        let last_pending = match self
970            .savepoint_stack
971            .iter()
972            .rposition(|e| e.snapshot.is_none())
973        {
974            Some(i) => i,
975            None => return,
976        };
977        let wtx = match self.active_txn.as_write_mut() {
978            Some(w) => w,
979            None => return,
980        };
981        let wtx_snap = wtx.begin_savepoint();
982        let schema_snap = self.schema.save_snapshot();
983
984        for i in 0..last_pending {
985            if self.savepoint_stack[i].snapshot.is_none() {
986                self.savepoint_stack[i].snapshot = Some(SavepointSnapshot {
987                    wtx_snap: wtx_snap.clone(),
988                    schema_snap: schema_snap.clone(),
989                });
990            }
991        }
992        self.savepoint_stack[last_pending].snapshot = Some(SavepointSnapshot {
993            wtx_snap,
994            schema_snap,
995        });
996    }
997
998    fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
999        if !self.active_txn.is_active() {
1000            return Err(SqlError::NoActiveTransaction);
1001        }
1002
1003        let idx = self
1004            .savepoint_stack
1005            .iter()
1006            .rposition(|e| e.name == name)
1007            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
1008        self.savepoint_stack.truncate(idx);
1009
1010        if self.savepoint_stack.is_empty() {
1011            if let (Some(wtx), Some(original)) =
1012                (self.active_txn.as_write_mut(), self.in_place_saved.take())
1013            {
1014                wtx.set_in_place(original);
1015            }
1016        }
1017
1018        Ok(ExecutionResult::Ok)
1019    }
1020
1021    fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
1022        if !self.active_txn.is_active() {
1023            return Err(SqlError::NoActiveTransaction);
1024        }
1025
1026        let idx = self
1027            .savepoint_stack
1028            .iter()
1029            .rposition(|e| e.name == name)
1030            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
1031
1032        self.savepoint_stack.truncate(idx + 1);
1033        let entry = self.savepoint_stack.last_mut().unwrap();
1034        let snapshot = match entry.snapshot.take() {
1035            Some(s) => s,
1036            None => return Ok(ExecutionResult::Ok),
1037        };
1038
1039        let wtx = match self.active_txn.as_write_mut() {
1040            Some(w) => w,
1041            None => return Err(SqlError::NoActiveTransaction),
1042        };
1043        wtx.restore_snapshot(snapshot.wtx_snap);
1044        self.schema.restore_snapshot(snapshot.schema_snap);
1045
1046        Ok(ExecutionResult::Ok)
1047    }
1048}
1049
1050impl<'a> Drop for Connection<'a> {
1051    fn drop(&mut self) {
1052        let temp_names = std::mem::take(&mut self.inner.borrow_mut().temp_table_names);
1053        if temp_names.is_empty() {
1054            return;
1055        }
1056        if let Ok(mut wtx) = self.db.begin_write() {
1057            for prefixed in &temp_names {
1058                let _ = wtx.drop_table(prefixed.as_bytes());
1059            }
1060            let _ = wtx.commit();
1061        }
1062    }
1063}