Skip to main content

citadel_sql/
connection.rs

1//! Public SQL connection API.
2
3use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7
8fn generate_temp_id() -> u64 {
9    static COUNTER: AtomicU64 = AtomicU64::new(0);
10    let counter = COUNTER.fetch_add(1, Ordering::Relaxed);
11    let nanos = (crate::datetime::now_micros() as u64) & 0xFFFF_FFFF;
12    (nanos << 32) | (counter & 0xFFFF_FFFF)
13}
14
/// Physical storage name backing a TEMP table: `__temp_<temp_id>_<name>`.
///
/// Only ASCII letters in `user_name` are lowercased; other characters are kept as-is.
fn temp_storage_name(temp_id: u64, user_name: &str) -> String {
    let lowered = user_name.to_ascii_lowercase();
    format!("__temp_{temp_id}_{lowered}")
}
18
19use lru::LruCache;
20
21use citadel::Database;
22use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
23
24use crate::error::{Result, SqlError};
25use crate::executor;
26use crate::parser;
27use crate::parser::{BeginAccessMode, QueryBody, SelectQuery, Statement};
28use crate::prepared::PreparedStatement;
29use crate::schema::{SchemaManager, SchemaSnapshot};
30use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
31
/// Number of entries in each connection's LRU statement cache (`stmt_cache`), which maps
/// SQL text to its parsed (and possibly compiled) statement.
const DEFAULT_CACHE_CAPACITY: usize = 64;
33
34/// On commit, evict shared caches (e.g. ANN indexes) for DML-touched tables.
35fn invalidate_dml_caches(schema: &SchemaManager, db: &Database) {
36    for table in schema.drain_dml_dirty() {
37        db.sql_cache_invalidate_prefix(&format!("ann:{table}:"));
38    }
39}
40
/// Outcome of [`Connection::execute_script`].
///
/// Execution stops at the first failing statement, so `completed` holds the results of
/// the statements that ran successfully before it, in order.
#[derive(Debug)]
pub struct ScriptExecution {
    /// Results of the statements that executed successfully, in script order.
    pub completed: Vec<ExecutionResult>,
    /// The parse error (in which case `completed` is empty) or the error of the first
    /// failing statement; `None` when the whole script succeeded.
    pub error: Option<SqlError>,
}
46
/// Validates a fixed UTC offset: `Z`, `UTC` (any case), `±HH`, `±HHMM` or `±H[H]:M[M]`.
///
/// Surrounding whitespace is ignored. Returns `Some(())` when `s` is accepted, otherwise
/// `None`. Hours must be `0..=23` and minutes `0..=59`. Only the syntax is checked; the
/// offset itself is not returned.
///
/// This function never panics. Non-ASCII input such as `"+1€"` is rejected instead of
/// being sliced on a non-UTF-8 boundary.
fn parse_fixed_offset(s: &str) -> Option<()> {
    let s = s.trim();
    if s.eq_ignore_ascii_case("z") || s.eq_ignore_ascii_case("utc") {
        return Some(());
    }
    let rest = s.strip_prefix('+').or_else(|| s.strip_prefix('-'))?;
    // The `HHMM` form is split by byte offset below, which would panic inside a
    // multi-byte character. Every valid offset is plain ASCII, so reject anything else.
    if !rest.is_ascii() {
        return None;
    }
    let (hh, mm) = match rest.split_once(':') {
        Some(parts) => parts,
        None if rest.len() == 4 => rest.split_at(2),
        None if rest.len() == 2 => (rest, "00"),
        None => return None,
    };
    // `u32::from_str` also accepts a leading '+', which would let "++5:00" through.
    // Insist on bare digits.
    let field = |digits: &str| -> Option<u32> {
        if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        digits.parse().ok()
    };
    let h = field(hh)?;
    let m = field(mm)?;
    (h <= 23 && m <= 59).then_some(())
}
76
/// Rewrites `SHOW TRIGGERS` / `SHOW TRIGGERS ON <table>` into a query over
/// `information_schema.triggers`.
///
/// Matching is ASCII case-insensitive, and surrounding whitespace and trailing `;` are
/// ignored. The table name is compared case-insensitively with `LOWER(...)`, and single
/// quotes in it are doubled so it stays inside the string literal. Any other input
/// returns `None`, including `SHOW TRIGGERS` followed by something other than `ON <name>`.
fn rewrite_show_triggers(sql: &str) -> Option<String> {
    const BASE: &str = "SELECT trigger_name, event_object_table AS table_name, action_timing, \
                        event_manipulation, action_orientation, action_statement \
                        FROM information_schema.triggers";

    let lowered = sql.trim().trim_end_matches(';').trim().to_ascii_lowercase();
    let tail = lowered.strip_prefix("show triggers")?.trim_start();

    if tail.is_empty() {
        return Some(format!("{BASE} ORDER BY trigger_name"));
    }

    let table = tail
        .strip_prefix("on ")?
        .trim()
        .trim_end_matches(';')
        .trim();
    if table.is_empty() {
        return None;
    }

    let quoted = table.replace('\'', "''");
    Some(format!(
        "{BASE} WHERE LOWER(event_object_table) = LOWER('{quoted}') ORDER BY trigger_name"
    ))
}
103
/// Rewrites `SHOW MATERIALIZED VIEWS` (ASCII case-insensitive; surrounding whitespace and
/// trailing `;` ignored) into a query over `pg_matviews`. Any other input returns `None`.
fn rewrite_show_matviews(sql: &str) -> Option<String> {
    let command = sql.trim().trim_end_matches(';').trim();
    if !command.eq_ignore_ascii_case("show materialized views") {
        return None;
    }
    let query = "SELECT matviewname, ispopulated, hasindexes, definition \
                 FROM pg_matviews ORDER BY matviewname";
    Some(String::from(query))
}
117
118fn stmt_mutates(stmt: &Statement) -> bool {
119    if matches!(
120        stmt,
121        Statement::Insert(_)
122            | Statement::Update(_)
123            | Statement::Delete(_)
124            | Statement::Truncate(_)
125            | Statement::CreateTable(_)
126            | Statement::DropTable(_)
127            | Statement::AlterTable(_)
128            | Statement::CreateIndex(_)
129            | Statement::DropIndex(_)
130            | Statement::CreateView(_)
131            | Statement::DropView(_)
132            | Statement::CreateTrigger(_)
133            | Statement::DropTrigger(_)
134            | Statement::CreateMaterializedView(_)
135            | Statement::RefreshMaterializedView(_)
136            | Statement::DropMaterializedView(_)
137    ) {
138        return true;
139    }
140    if let Statement::Select(sq) = stmt {
141        if select_query_has_dml(sq) {
142            return true;
143        }
144    }
145    false
146}
147
148fn select_query_has_dml(sq: &SelectQuery) -> bool {
149    sq.ctes.iter().any(|cte| query_body_has_dml(&cte.body)) || query_body_has_dml(&sq.body)
150}
151
152fn query_body_has_dml(body: &QueryBody) -> bool {
153    match body {
154        QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => true,
155        QueryBody::Compound(c) => query_body_has_dml(&c.left) || query_body_has_dml(&c.right),
156        QueryBody::Select(_) => false,
157    }
158}
159
/// Fast path for simple single-row `INSERT INTO … VALUES (…)` statements.
///
/// Recognises SQL of the shape `INSERT INTO <anything> VALUES (<lit>, <lit>, …)` where
/// each literal is one of:
/// * a single-quoted string (with `''` as an escaped quote) → `Value::Text`;
/// * an optionally negative integer → `Value::Integer`, or decimal with a `.` →
///   `Value::Real`;
/// * `NULL`, `TRUE` or `FALSE` (ASCII case-insensitive).
///
/// On success it returns two things:
/// * The normalized SQL: the original text up to and including `VALUES`, followed by
///   ` ($1, $2, …)`. The caller uses it as the statement-cache key and parses it as a
///   parameterized statement.
/// * The extracted literal values, in order.
///
/// Returns `None`, so the caller uses the regular parse path, for anything else. This
/// includes a different leading keyword, a missing standalone `VALUES`, expressions,
/// multi-row `VALUES`, trailing clauses such as `RETURNING`, unterminated strings, and
/// numbers that fail to parse (e.g. i64 overflow).
fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut i = 0;

    // Optional leading whitespace, then `INSERT` followed by at least one whitespace byte.
    while i < len && bytes[i].is_ascii_whitespace() {
        i += 1;
    }
    if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
        return None;
    }
    i += 6;
    if i >= len || !bytes[i].is_ascii_whitespace() {
        return None;
    }
    while i < len && bytes[i].is_ascii_whitespace() {
        i += 1;
    }

    // `INTO`, which must also be followed by whitespace.
    if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
        return None;
    }
    i += 4;
    if i >= len || !bytes[i].is_ascii_whitespace() {
        return None;
    }

    // The cache key keeps the original text from the very start of `sql` (including any
    // leading whitespace) through the `VALUES` keyword.
    let prefix_start = 0;
    // Find the first standalone `VALUES` keyword, i.e. not preceded or followed by an
    // identifier byte. This is a plain byte scan that does not skip quoted text. If a
    // match is not followed by `(`, the function returns `None` just below.
    let mut values_pos = None;
    let mut j = i;
    while j + 6 <= len {
        if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
            && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
            && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
        {
            values_pos = Some(j);
            break;
        }
        j += 1;
    }
    let values_pos = values_pos?;

    // `VALUES` is ASCII, so `values_pos + 6` is always a UTF-8 char boundary.
    let prefix = &sql[prefix_start..values_pos + 6];
    let mut pos = values_pos + 6;

    // The value list must open right after `VALUES` (whitespace allowed in between).
    while pos < len && bytes[pos].is_ascii_whitespace() {
        pos += 1;
    }
    if pos >= len || bytes[pos] != b'(' {
        return None;
    }
    pos += 1;

    let mut values = Vec::new();
    let mut normalized = String::with_capacity(sql.len());
    normalized.push_str(prefix);
    normalized.push_str(" (");

    // One literal per iteration; `param_idx` is its 1-based `$n` placeholder number.
    loop {
        while pos < len && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        if pos >= len {
            return None;
        }

        let param_idx = values.len() + 1;
        if param_idx > 1 {
            normalized.push_str(", ");
        }

        if bytes[pos] == b'\'' {
            // Single-quoted string. `''` is an escaped quote. Text between quotes is copied
            // segment by segment; the segment bounds sit next to ASCII quote bytes, so they
            // are UTF-8 boundaries of the original `&str`.
            pos += 1;
            let mut seg_start = pos;
            let mut s = String::new();
            loop {
                if pos >= len {
                    // Unterminated string literal.
                    return None;
                }
                if bytes[pos] == b'\'' {
                    s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
                    pos += 1;
                    if pos < len && bytes[pos] == b'\'' {
                        s.push('\'');
                        pos += 1;
                        seg_start = pos;
                    } else {
                        break;
                    }
                } else {
                    pos += 1;
                }
            }
            values.push(Value::Text(s.into()));
        } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
            // Optional '-' then digits. A '.' with optional fraction digits makes it a Real;
            // otherwise it is an Integer. Exponents and a leading '+' or '.' are not
            // recognised (the next-token check below rejects them). Overflowing
            // integers fail `parse` and return `None`.
            let start = pos;
            if bytes[pos] == b'-' {
                pos += 1;
            }
            while pos < len && bytes[pos].is_ascii_digit() {
                pos += 1;
            }
            if pos < len && bytes[pos] == b'.' {
                pos += 1;
                while pos < len && bytes[pos].is_ascii_digit() {
                    pos += 1;
                }
                let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
                values.push(Value::Real(num));
            } else {
                let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
                values.push(Value::Integer(num));
            }
        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
            // Keyword literals must not run into an identifier byte. At end of input the
            // next byte is treated as ')'; the missing real ')' is caught after the literal.
            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 4;
                values.push(Value::Null);
            } else {
                return None;
            }
        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 4;
                values.push(Value::Boolean(true));
            } else {
                return None;
            }
        } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
            let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 5;
                values.push(Value::Boolean(false));
            } else {
                return None;
            }
        } else {
            // Any other value (expression, placeholder, function call, …) is not handled.
            return None;
        }

        normalized.push('$');
        normalized.push_str(&param_idx.to_string());

        // Each literal must be followed by ',' (another value) or ')' (end of the row).
        while pos < len && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        if pos >= len {
            return None;
        }

        if bytes[pos] == b',' {
            pos += 1;
        } else if bytes[pos] == b')' {
            pos += 1;
            break;
        } else {
            return None;
        }
    }

    normalized.push(')');

    // Only whitespace and ';' may follow the row. This rejects multi-row VALUES,
    // RETURNING, ON CONFLICT, a second statement, and so on.
    while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
        pos += 1;
    }
    if pos != len {
        return None;
    }

    // Defensive: every loop iteration either returns or pushes a value before `break`.
    if values.is_empty() {
        return None;
    }

    Some((normalized, values))
}
336
/// One entry of a connection's LRU statement cache (`ConnectionInner::stmt_cache`).
///
/// Keyed by SQL text, or by the normalized text produced by `try_normalize_insert`.
pub(crate) struct CacheEntry {
    /// The parsed statement, shared with callers via `Arc`.
    pub(crate) stmt: Arc<Statement>,
    /// `SchemaManager::generation()` at parse time; callers treat the entry as stale when
    /// the current generation differs.
    pub(crate) schema_gen: u64,
    /// Number of parameter placeholders, from `parser::count_params`.
    pub(crate) param_count: usize,
    /// Compiled execution plan, filled in lazily by `execute_params_impl` outside
    /// transactions; `None` until then (or if the statement cannot be compiled).
    pub(crate) compiled: Option<Arc<dyn executor::CompiledPlan>>,
}
343
/// A named SAVEPOINT on the connection's savepoint stack.
struct SavepointEntry {
    /// Savepoint name as given in the SAVEPOINT statement.
    name: String,
    /// State to restore on `ROLLBACK TO`. Pushed as `None` by SAVEPOINT and filled in
    /// by `capture_pending_snapshots` before the next mutating statement runs.
    snapshot: Option<SavepointSnapshot>,
}
348
/// Captured transaction and schema state for a savepoint.
struct SavepointSnapshot {
    /// Write-transaction state from `WriteTxn::begin_savepoint`.
    wtx_snap: WriteTxnSnapshot,
    /// In-memory schema state from `SchemaManager::save_snapshot`.
    schema_snap: SchemaSnapshot,
}
353
/// Active transaction held by a Connection. `None` outside BEGIN/COMMIT; `Write` for normal
/// BEGIN (or BEGIN READ WRITE); `Read` for BEGIN READ ONLY.
// NOTE(review): the lint is silenced without a stated reason; presumably `WriteTxn` is much
// larger than `ReadTxn` and boxing was judged not worth an extra allocation per BEGIN —
// confirm and record the rationale here.
#[allow(clippy::large_enum_variant)]
pub(crate) enum ActiveTxn<'a> {
    /// No explicit transaction; statements run in autocommit mode.
    None,
    /// Transaction opened by BEGIN / BEGIN READ WRITE.
    Write(WriteTxn<'a>),
    /// Transaction opened by BEGIN READ ONLY.
    Read(citadel_txn::read_txn::ReadTxn<'a>),
}
362
363impl<'a> ActiveTxn<'a> {
364    fn is_none(&self) -> bool {
365        matches!(self, ActiveTxn::None)
366    }
367    fn is_active(&self) -> bool {
368        !self.is_none()
369    }
370    fn is_read_only(&self) -> bool {
371        matches!(self, ActiveTxn::Read(_))
372    }
373    fn as_write_mut(&mut self) -> Option<&mut WriteTxn<'a>> {
374        match self {
375            ActiveTxn::Write(w) => Some(w),
376            _ => None,
377        }
378    }
379    fn take(&mut self) -> ActiveTxn<'a> {
380        std::mem::replace(self, ActiveTxn::None)
381    }
382}
383
/// Mutable per-connection state, kept behind `Connection::inner`'s `RefCell`.
pub(crate) struct ConnectionInner<'a> {
    /// In-memory view of the database schema (also holds TEMP-table aliases).
    pub(crate) schema: SchemaManager,
    /// Explicit transaction opened by BEGIN, if any.
    active_txn: ActiveTxn<'a>,
    /// Open SAVEPOINTs, innermost last.
    savepoint_stack: Vec<SavepointEntry>,
    /// The write transaction's `in_place()` setting saved when the first SAVEPOINT
    /// switched it off; `None` when no savepoint is active.
    in_place_saved: Option<bool>,
    /// LRU cache of parsed statements keyed by SQL text (or by the normalized text of
    /// fast-path INSERTs).
    pub(crate) stmt_cache: LruCache<String, CacheEntry>,
    /// Transaction start time in UTC microseconds: set by BEGIN, cleared by
    /// COMMIT/ROLLBACK.
    txn_start_ts: Option<i64>,
    /// Session time zone (IANA name or fixed offset); `"UTC"` on open.
    session_timezone: String,
    /// Namespaces TEMP tables as `__temp_<id>_<name>`. Cleaned up on Connection drop.
    temp_id: u64,
    /// Storage names of TEMP tables created on this connection.
    temp_table_names: Vec<String>,
}
396
/// A SQL session over a [`Database`].
///
/// All public methods take `&self`; mutable session state lives in a `RefCell`, so the
/// connection is meant for single-threaded use and must not be re-entered while a call
/// is in progress.
pub struct Connection<'a> {
    /// Database this connection executes against.
    pub(crate) db: &'a Database,
    /// Session state (schema, transaction, caches, time zone, TEMP tables).
    pub(crate) inner: RefCell<ConnectionInner<'a>>,
}
401
402impl<'a> Connection<'a> {
403    pub fn open(db: &'a Database) -> Result<Self> {
404        let schema = SchemaManager::load(db)?;
405        let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
406        let temp_id = generate_temp_id();
407        Ok(Self {
408            db,
409            inner: RefCell::new(ConnectionInner {
410                schema,
411                active_txn: ActiveTxn::None,
412                savepoint_stack: Vec::new(),
413                in_place_saved: None,
414                stmt_cache,
415                txn_start_ts: None,
416                session_timezone: "UTC".to_string(),
417                temp_id,
418                temp_table_names: Vec::new(),
419            }),
420        })
421    }
422
423    /// Txn-start UTC µs inside BEGIN/COMMIT, else `None`.
424    pub fn txn_start_ts(&self) -> Option<i64> {
425        self.inner.borrow().txn_start_ts
426    }
427
428    /// Returns the session time-zone (IANA name or fixed offset). Default `"UTC"`.
429    pub fn session_timezone(&self) -> String {
430        self.inner.borrow().session_timezone.clone()
431    }
432
433    /// Set the session time-zone. Accepts IANA names, ISO-8601 offsets, `"UTC"`, `"Z"`.
434    pub fn set_session_timezone(&self, tz: &str) -> Result<()> {
435        self.inner.borrow_mut().set_session_timezone_impl(tz)
436    }
437
438    pub fn execute(&self, sql: &str) -> Result<ExecutionResult> {
439        self.inner.borrow_mut().execute_impl(self.db, sql)
440    }
441
442    pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
443        self.inner
444            .borrow_mut()
445            .execute_params_impl(self.db, sql, params)
446    }
447
448    /// Execute `;`-separated SQL statements. Stops at the first failure.
449    pub fn execute_script(&self, sql: &str) -> ScriptExecution {
450        let stmts = match parser::parse_sql_multi(sql) {
451            Ok(s) => s,
452            Err(e) => {
453                return ScriptExecution {
454                    completed: vec![],
455                    error: Some(e),
456                }
457            }
458        };
459        let mut completed = Vec::with_capacity(stmts.len());
460        for stmt in stmts {
461            match self.inner.borrow_mut().dispatch(self.db, &stmt, &[]) {
462                Ok(r) => completed.push(r),
463                Err(e) => {
464                    return ScriptExecution {
465                        completed,
466                        error: Some(e),
467                    }
468                }
469            }
470        }
471        ScriptExecution {
472            completed,
473            error: None,
474        }
475    }
476
477    pub fn query(&self, sql: &str) -> Result<QueryResult> {
478        self.query_params(sql, &[])
479    }
480
481    pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
482        match self.execute_params(sql, params)? {
483            ExecutionResult::Query(qr) => Ok(qr),
484            ExecutionResult::RowsAffected(n) => Ok(QueryResult {
485                columns: vec!["rows_affected".into()],
486                rows: vec![vec![Value::Integer(n as i64)]],
487            }),
488            ExecutionResult::Ok => Ok(QueryResult {
489                columns: vec![],
490                rows: vec![],
491            }),
492        }
493    }
494
495    pub fn prepare(&self, sql: &str) -> Result<PreparedStatement<'_, 'a>> {
496        if let Some(rewritten) = rewrite_show_triggers(sql) {
497            return PreparedStatement::new(self, &rewritten);
498        }
499        if let Some(rewritten) = rewrite_show_matviews(sql) {
500            return PreparedStatement::new(self, &rewritten);
501        }
502        PreparedStatement::new(self, sql)
503    }
504
505    pub fn tables(&self) -> Vec<String> {
506        self.inner
507            .borrow()
508            .schema
509            .table_names()
510            .into_iter()
511            .map(String::from)
512            .collect()
513    }
514
515    /// Returns true if an explicit transaction is active (BEGIN was issued).
516    pub fn in_transaction(&self) -> bool {
517        self.inner.borrow().active_txn.is_active()
518    }
519
520    pub fn table_schema(&self, name: &str) -> Option<TableSchema> {
521        self.inner.borrow().schema.get(name).cloned()
522    }
523
524    pub fn refresh_schema(&self) -> Result<()> {
525        let new_schema = SchemaManager::load(self.db)?;
526        self.inner.borrow_mut().schema = new_schema;
527        Ok(())
528    }
529}
530
531impl<'a> ConnectionInner<'a> {
532    pub(crate) fn active_txn_is_some(&self) -> bool {
533        self.active_txn.is_active()
534    }
535
536    fn set_session_timezone_impl(&mut self, tz: &str) -> Result<()> {
537        let upper = tz.to_ascii_uppercase();
538        if (upper.starts_with("UTC+") || upper.starts_with("UTC-")) && tz.len() > 3 {
539            return Err(SqlError::InvalidTimezone(format!(
540                "'{tz}' is ambiguous; use ISO-8601 offset (e.g. '+05:00') or named zone (e.g. 'Etc/GMT-5')"
541            )));
542        }
543        if jiff::tz::TimeZone::get(tz).is_err() && parse_fixed_offset(tz).is_none() {
544            return Err(SqlError::InvalidTimezone(format!(
545                "{tz}: not a known IANA zone or ISO-8601 offset (e.g. '+05:00', 'UTC', 'America/New_York')"
546            )));
547        }
548        self.session_timezone = tz.to_string();
549        Ok(())
550    }
551
552    fn execute_impl(&mut self, db: &'a Database, sql: &str) -> Result<ExecutionResult> {
553        if let Some(rewritten) = rewrite_show_triggers(sql) {
554            return self.execute_params_impl(db, &rewritten, &[]);
555        }
556        if let Some(rewritten) = rewrite_show_matviews(sql) {
557            return self.execute_params_impl(db, &rewritten, &[]);
558        }
559        if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
560            if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
561                let gen = self.schema.generation();
562                let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
563                    if entry.schema_gen == gen {
564                        Arc::clone(&entry.stmt)
565                    } else {
566                        self.parse_and_cache(normalized_key, gen)?
567                    }
568                } else {
569                    self.parse_and_cache(normalized_key, gen)?
570                };
571                return self.dispatch(db, &stmt, &extracted);
572            }
573        }
574        self.execute_params_impl(db, sql, &[])
575    }
576
577    fn execute_params_impl(
578        &mut self,
579        db: &'a Database,
580        sql: &str,
581        params: &[Value],
582    ) -> Result<ExecutionResult> {
583        let gen = self.schema.generation();
584        if self.active_txn.is_none() {
585            if let Some(entry) = self.stmt_cache.get(sql) {
586                if entry.schema_gen == gen && entry.param_count == params.len() {
587                    if let Some(plan) = entry.compiled.as_ref().map(Arc::clone) {
588                        let stmt = Arc::clone(&entry.stmt);
589                        return self.run_compiled(db, &plan, &stmt, params);
590                    }
591                }
592            }
593        }
594
595        let (stmt, param_count) = self.get_or_parse(sql)?;
596
597        if param_count != params.len() {
598            return Err(SqlError::ParameterCountMismatch {
599                expected: param_count,
600                got: params.len(),
601            });
602        }
603
604        if self.active_txn.is_none() {
605            if let Some(plan) = executor::compile(&self.schema, &stmt) {
606                if let Some(e) = self.stmt_cache.get_mut(sql) {
607                    e.compiled = Some(Arc::clone(&plan));
608                }
609                let stmt_owned = Arc::clone(&stmt);
610                return self.run_compiled(db, &plan, &stmt_owned, params);
611            }
612        }
613
614        self.dispatch(db, &stmt, params)
615    }
616
617    fn run_compiled(
618        &mut self,
619        db: &'a Database,
620        plan: &Arc<dyn executor::CompiledPlan>,
621        stmt: &Statement,
622        params: &[Value],
623    ) -> Result<ExecutionResult> {
624        use executor::compile::ActiveTxnRef;
625        let schema = &self.schema;
626        let exec = || {
627            if params.is_empty() {
628                plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
629            } else {
630                crate::eval::with_scoped_params(params, || {
631                    plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
632                })
633            }
634        };
635        let outcome = if plan.needs_txn_clock() {
636            let cached_ts = self
637                .txn_start_ts
638                .or_else(|| Some(crate::datetime::now_micros()));
639            crate::datetime::with_txn_clock(cached_ts, exec)
640        } else {
641            exec()
642        }?;
643        invalidate_dml_caches(&self.schema, db);
644        Ok(outcome)
645    }
646
647    pub(crate) fn parse_and_cache(
648        &mut self,
649        normalized_key: String,
650        gen: u64,
651    ) -> Result<Arc<Statement>> {
652        let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
653        let param_count = parser::count_params(&stmt);
654        self.stmt_cache.put(
655            normalized_key,
656            CacheEntry {
657                stmt: Arc::clone(&stmt),
658                schema_gen: gen,
659                param_count,
660                compiled: None,
661            },
662        );
663        Ok(stmt)
664    }
665
666    pub(crate) fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
667        let gen = self.schema.generation();
668
669        if let Some(entry) = self.stmt_cache.get(sql) {
670            if entry.schema_gen == gen {
671                return Ok((Arc::clone(&entry.stmt), entry.param_count));
672            }
673        }
674
675        let stmt = Arc::new(parser::parse_sql(sql)?);
676        let param_count = parser::count_params(&stmt);
677
678        let cacheable = !matches!(
679            *stmt,
680            Statement::CreateTable(_)
681                | Statement::DropTable(_)
682                | Statement::CreateIndex(_)
683                | Statement::DropIndex(_)
684                | Statement::CreateView(_)
685                | Statement::DropView(_)
686                | Statement::AlterTable(_)
687        );
688
689        if cacheable {
690            self.stmt_cache.put(
691                sql.to_string(),
692                CacheEntry {
693                    stmt: Arc::clone(&stmt),
694                    schema_gen: gen,
695                    param_count,
696                    compiled: None,
697                },
698            );
699        }
700
701        Ok((stmt, param_count))
702    }
703
704    pub(crate) fn execute_prepared(
705        &mut self,
706        db: &'a Database,
707        stmt: &Statement,
708        compiled: Option<&Arc<dyn executor::CompiledPlan>>,
709        params: &[Value],
710    ) -> Result<ExecutionResult> {
711        if let Some(plan) = compiled {
712            if self.active_txn.is_none() {
713                return self.run_compiled(db, plan, stmt, params);
714            }
715            if !self.savepoint_stack.is_empty() && stmt_mutates(stmt) {
716                self.capture_pending_snapshots();
717            }
718            return self.run_compiled_in_txn(db, plan, stmt, params);
719        }
720        self.dispatch(db, stmt, params)
721    }
722
723    fn run_compiled_in_txn(
724        &mut self,
725        db: &'a Database,
726        plan: &Arc<dyn executor::CompiledPlan>,
727        stmt: &Statement,
728        params: &[Value],
729    ) -> Result<ExecutionResult> {
730        use executor::compile::ActiveTxnRef;
731        let schema = &self.schema;
732        let txn = match &mut self.active_txn {
733            ActiveTxn::Write(wtx) => ActiveTxnRef::Write(wtx),
734            ActiveTxn::Read(rtx) => ActiveTxnRef::Read(rtx),
735            ActiveTxn::None => ActiveTxnRef::None,
736        };
737        if params.is_empty() || !plan.uses_scoped_params() {
738            plan.execute(db, schema, stmt, params, txn)
739        } else {
740            crate::eval::with_scoped_params(params, || plan.execute(db, schema, stmt, params, txn))
741        }
742    }
743
744    pub(crate) fn dispatch(
745        &mut self,
746        db: &'a Database,
747        stmt: &Statement,
748        params: &[Value],
749    ) -> Result<ExecutionResult> {
750        let cached_ts = self
751            .txn_start_ts
752            .or_else(|| Some(crate::datetime::now_micros()));
753        crate::datetime::with_txn_clock(cached_ts, || {
754            if params.is_empty() {
755                self.dispatch_inner(db, stmt, params)
756            } else {
757                crate::eval::with_scoped_params(params, || self.dispatch_inner(db, stmt, params))
758            }
759        })
760    }
761
762    fn dispatch_inner(
763        &mut self,
764        db: &'a Database,
765        stmt: &Statement,
766        params: &[Value],
767    ) -> Result<ExecutionResult> {
768        match stmt {
769            Statement::Begin { access_mode } => {
770                if self.active_txn.is_active() {
771                    return Err(SqlError::TransactionAlreadyActive);
772                }
773                let ts = crate::datetime::txn_or_clock_micros();
774                match access_mode {
775                    BeginAccessMode::ReadOnly => {
776                        let rtx = db.begin_read();
777                        self.active_txn = ActiveTxn::Read(rtx);
778                    }
779                    BeginAccessMode::ReadWrite | BeginAccessMode::Default => {
780                        let wtx = db.begin_write().map_err(SqlError::Storage)?;
781                        self.active_txn = ActiveTxn::Write(wtx);
782                    }
783                }
784                self.txn_start_ts = Some(ts);
785                crate::datetime::set_txn_clock(Some(ts));
786                Ok(ExecutionResult::Ok)
787            }
788            Statement::Commit => {
789                match self.active_txn.take() {
790                    ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
791                    ActiveTxn::Write(mut wtx) => {
792                        crate::executor::helpers::drain_deferred_fk_checks(&mut wtx)?;
793                        wtx.commit().map_err(SqlError::Storage)?;
794                        invalidate_dml_caches(&self.schema, db);
795                    }
796                    ActiveTxn::Read(_rtx) => {}
797                }
798                self.clear_savepoint_state();
799                self.txn_start_ts = None;
800                crate::datetime::set_txn_clock(None);
801                Ok(ExecutionResult::Ok)
802            }
803            Statement::Rollback => {
804                match self.active_txn.take() {
805                    ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
806                    ActiveTxn::Write(wtx) => {
807                        wtx.abort();
808                        self.schema = SchemaManager::load(db)?;
809                    }
810                    ActiveTxn::Read(_rtx) => {}
811                }
812                self.clear_savepoint_state();
813                self.txn_start_ts = None;
814                crate::datetime::set_txn_clock(None);
815                Ok(ExecutionResult::Ok)
816            }
817            Statement::Savepoint(name) => self.do_savepoint(name),
818            Statement::ReleaseSavepoint(name) => self.do_release(name),
819            Statement::RollbackTo(name) => self.do_rollback_to(name),
820            Statement::SetTimezone(zone) => {
821                self.set_session_timezone_impl(zone)?;
822                Ok(ExecutionResult::Ok)
823            }
824            Statement::CreateTable(ct) if ct.temporary => {
825                if self.active_txn.is_read_only() {
826                    return Err(SqlError::Unsupported(
827                        "cannot execute mutating statement inside a read-only transaction".into(),
828                    ));
829                }
830                let user_name = ct.name.clone();
831                let prefixed = temp_storage_name(self.temp_id, &user_name);
832                if self.schema.contains(&user_name) {
833                    if ct.if_not_exists {
834                        return Ok(ExecutionResult::Ok);
835                    }
836                    return Err(SqlError::TableAlreadyExists(user_name));
837                }
838                let mut clone = ct.clone();
839                clone.name = prefixed.clone();
840                clone.temporary = false;
841                let stmt_concrete = Statement::CreateTable(clone);
842                let outcome = if let Some(wtx) = self.active_txn.as_write_mut() {
843                    executor::execute_in_txn(wtx, &mut self.schema, &stmt_concrete, params)?
844                } else {
845                    executor::execute(db, &mut self.schema, &stmt_concrete, params)?
846                };
847                self.schema
848                    .register_temp_alias(&user_name, prefixed.clone());
849                self.temp_table_names.push(prefixed);
850                Ok(outcome)
851            }
852            Statement::Insert(ins) if self.active_txn.as_write_mut().is_some() => {
853                self.capture_pending_snapshots();
854                let wtx = self.active_txn.as_write_mut().unwrap();
855                executor::exec_insert_in_txn(wtx, &self.schema, ins, params)
856            }
857            _ => {
858                if self.active_txn.is_read_only() && stmt_mutates(stmt) {
859                    return Err(SqlError::Unsupported(
860                        "cannot execute mutating statement inside a read-only transaction".into(),
861                    ));
862                }
863                if self.active_txn.as_write_mut().is_some() && stmt_mutates(stmt) {
864                    self.capture_pending_snapshots();
865                }
866                let was_auto_commit = matches!(self.active_txn, ActiveTxn::None);
867                let outcome = match &mut self.active_txn {
868                    ActiveTxn::Write(wtx) => {
869                        executor::execute_in_txn(wtx, &mut self.schema, stmt, params)?
870                    }
871                    ActiveTxn::Read(rtx) => {
872                        executor::execute_with_read(rtx, &self.schema, stmt, params)?
873                    }
874                    ActiveTxn::None => executor::execute(db, &mut self.schema, stmt, params)?,
875                };
876                if was_auto_commit {
877                    invalidate_dml_caches(&self.schema, db);
878                }
879                if let Statement::DropTable(dt) = stmt {
880                    self.schema.unregister_temp_alias(&dt.name);
881                }
882                Ok(outcome)
883            }
884        }
885    }
886
887    fn clear_savepoint_state(&mut self) {
888        self.savepoint_stack.clear();
889        self.in_place_saved = None;
890    }
891
892    fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
893        let wtx = self
894            .active_txn
895            .as_write_mut()
896            .ok_or(SqlError::NoActiveTransaction)?;
897
898        if self.savepoint_stack.is_empty() {
899            self.in_place_saved = Some(wtx.in_place());
900            wtx.set_in_place(false);
901        }
902
903        self.savepoint_stack.push(SavepointEntry {
904            name: name.to_string(),
905            snapshot: None,
906        });
907
908        Ok(ExecutionResult::Ok)
909    }
910
911    fn capture_pending_snapshots(&mut self) {
912        let last_pending = match self
913            .savepoint_stack
914            .iter()
915            .rposition(|e| e.snapshot.is_none())
916        {
917            Some(i) => i,
918            None => return,
919        };
920        let wtx = match self.active_txn.as_write_mut() {
921            Some(w) => w,
922            None => return,
923        };
924        let wtx_snap = wtx.begin_savepoint();
925        let schema_snap = self.schema.save_snapshot();
926
927        for i in 0..last_pending {
928            if self.savepoint_stack[i].snapshot.is_none() {
929                self.savepoint_stack[i].snapshot = Some(SavepointSnapshot {
930                    wtx_snap: wtx_snap.clone(),
931                    schema_snap: schema_snap.clone(),
932                });
933            }
934        }
935        self.savepoint_stack[last_pending].snapshot = Some(SavepointSnapshot {
936            wtx_snap,
937            schema_snap,
938        });
939    }
940
941    fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
942        if !self.active_txn.is_active() {
943            return Err(SqlError::NoActiveTransaction);
944        }
945
946        let idx = self
947            .savepoint_stack
948            .iter()
949            .rposition(|e| e.name == name)
950            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
951        self.savepoint_stack.truncate(idx);
952
953        if self.savepoint_stack.is_empty() {
954            if let (Some(wtx), Some(original)) =
955                (self.active_txn.as_write_mut(), self.in_place_saved.take())
956            {
957                wtx.set_in_place(original);
958            }
959        }
960
961        Ok(ExecutionResult::Ok)
962    }
963
964    fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
965        if !self.active_txn.is_active() {
966            return Err(SqlError::NoActiveTransaction);
967        }
968
969        let idx = self
970            .savepoint_stack
971            .iter()
972            .rposition(|e| e.name == name)
973            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
974
975        self.savepoint_stack.truncate(idx + 1);
976        let entry = self.savepoint_stack.last_mut().unwrap();
977        let snapshot = match entry.snapshot.take() {
978            Some(s) => s,
979            None => return Ok(ExecutionResult::Ok),
980        };
981
982        let wtx = match self.active_txn.as_write_mut() {
983            Some(w) => w,
984            None => return Err(SqlError::NoActiveTransaction),
985        };
986        wtx.restore_snapshot(snapshot.wtx_snap);
987        self.schema.restore_snapshot(snapshot.schema_snap);
988
989        Ok(ExecutionResult::Ok)
990    }
991}
992
993impl<'a> Drop for Connection<'a> {
994    fn drop(&mut self) {
995        let temp_names = std::mem::take(&mut self.inner.borrow_mut().temp_table_names);
996        if temp_names.is_empty() {
997            return;
998        }
999        if let Ok(mut wtx) = self.db.begin_write() {
1000            for prefixed in &temp_names {
1001                let _ = wtx.drop_table(prefixed.as_bytes());
1002            }
1003            let _ = wtx.commit();
1004        }
1005    }
1006}