Skip to main content

citadel_sql/
connection.rs

1//! Public SQL connection API.
2
3use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7
8fn generate_temp_id() -> u64 {
9    static COUNTER: AtomicU64 = AtomicU64::new(0);
10    let counter = COUNTER.fetch_add(1, Ordering::Relaxed);
11    let nanos = (crate::datetime::now_micros() as u64) & 0xFFFF_FFFF;
12    (nanos << 32) | (counter & 0xFFFF_FFFF)
13}
14
/// Builds the physical storage name of a TEMP table: `__temp_<id>_<name>`,
/// with the user-visible name ASCII-lowercased.
fn temp_storage_name(temp_id: u64, user_name: &str) -> String {
    let lowered = user_name.to_ascii_lowercase();
    format!("__temp_{temp_id}_{lowered}")
}
18
19use lru::LruCache;
20
21use citadel::Database;
22use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
23
24use crate::error::{Result, SqlError};
25use crate::executor;
26use crate::parser;
27use crate::parser::{BeginAccessMode, QueryBody, SelectQuery, Statement};
28use crate::prepared::PreparedStatement;
29use crate::schema::{SchemaManager, SchemaSnapshot};
30use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
31
/// Number of parsed statements each connection keeps in its LRU statement
/// cache (`ConnectionInner::stmt_cache`). Must be non-zero.
const DEFAULT_CACHE_CAPACITY: usize = 64;
33
34/// On commit, evict shared caches (e.g. ANN indexes) for DML-touched tables,
35/// and stamp each table's last-DML generation marker: an index whose snapshot
36/// predates the marker is refused at lookup AND at insert, closing the
37/// build-races-a-commit window that prefix eviction alone leaves open.
38fn invalidate_dml_caches(schema: &SchemaManager, db: &Database) {
39    if !schema.has_dml_dirty() {
40        return;
41    }
42    let gen = db.manager().commit_generation();
43    let hard_invalidate = |table: &str| {
44        db.sql_cache_invalidate_prefix(&format!("ann:{table}:"));
45        let marker: std::sync::Arc<dyn std::any::Any + Send + Sync> = std::sync::Arc::new(gen);
46        schema
47            .sql_caches
48            .lock()
49            .insert(crate::executor::ann_dml_gen_key(table), marker);
50    };
51
52    let dirty = schema.drain_dml_dirty();
53    for table in &dirty.mutating {
54        hard_invalidate(table);
55    }
56    // A pure append keeps the cached index (recall tail-merges the new rows)
57    // unless it landed at/below an index snapshot.
58    for (table, min_pk) in &dirty.appends {
59        if crate::executor::ann_appends_safe(schema, table, *min_pk) {
60            continue;
61        }
62        hard_invalidate(table);
63    }
64}
65
/// Outcome of `Connection::execute_script`.
///
/// Statements run in order until one fails; a parse error of the script means
/// nothing ran at all.
#[derive(Debug)]
pub struct ScriptExecution {
    /// Results of the statements that succeeded, in execution order.
    pub completed: Vec<ExecutionResult>,
    /// The parse or execution error that stopped the script; `None` if every
    /// statement succeeded.
    pub error: Option<SqlError>,
}
71
/// Validates the shape of a fixed UTC offset.
///
/// Accepts, after trimming surrounding whitespace:
/// * `Z` or `UTC` (case-insensitive);
/// * a `+`/`-` sign followed by `HH:MM`, `HHMM`, or `HH`.
///
/// The hour must be in `0..=23` and the minute in `0..=59`. Each field must
/// consist of ASCII digits only.
///
/// Returns `Some(())` when the input is a valid offset and `None` otherwise.
/// It never panics, including on non-ASCII input. Only validity is reported;
/// the offset value is not returned.
fn parse_fixed_offset(s: &str) -> Option<()> {
    let s = s.trim();
    if s.eq_ignore_ascii_case("z") || s.eq_ignore_ascii_case("utc") {
        return Some(());
    }
    let rest = s.strip_prefix('+').or_else(|| s.strip_prefix('-'))?;

    // Offsets are pure ASCII. Rejecting anything else up front also keeps the
    // byte-index split below on a char boundary: previously `&rest[..2]` on
    // e.g. "a€" split a multi-byte char and panicked.
    if !rest.is_ascii() {
        return None;
    }
    let (hh, mm) = match rest.split_once(':') {
        Some(parts) => parts,
        None if rest.len() == 4 => rest.split_at(2),
        None if rest.len() == 2 => (rest, "00"),
        None => return None,
    };

    // `u32::from_str` tolerates a leading '+', which would let "++5:00" or
    // "-+5" through, so require plain digits explicitly.
    let field = |part: &str| -> Option<u32> {
        if part.is_empty() || !part.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        part.parse().ok()
    };
    let h = field(hh)?;
    let m = field(mm)?;
    (h <= 23 && m <= 59).then_some(())
}
101
/// Rewrites `SHOW TRIGGERS [ON <table>]` into a query over
/// `information_schema.triggers`, ordered by trigger name.
///
/// Matching is ASCII case-insensitive and ignores surrounding whitespace and
/// trailing semicolons. The table name is lowercased and single quotes are
/// doubled before it is embedded in a string literal; the comparison itself is
/// case-insensitive via `LOWER(...)`. Returns `None` for any other input.
fn rewrite_show_triggers(sql: &str) -> Option<String> {
    let lowered = sql.trim().trim_end_matches(';').trim().to_ascii_lowercase();
    let tail = lowered.strip_prefix("show triggers")?.trim_start();

    let base = "SELECT trigger_name, event_object_table AS table_name, action_timing, \
                event_manipulation, action_orientation, action_statement \
                FROM information_schema.triggers";

    if tail.is_empty() {
        return Some(format!("{base} ORDER BY trigger_name"));
    }

    let table = tail.strip_prefix("on ")?.trim().trim_end_matches(';').trim();
    if table.is_empty() {
        return None;
    }
    let quoted = table.replace('\'', "''");
    Some(format!(
        "{base} WHERE LOWER(event_object_table) = LOWER('{quoted}') ORDER BY trigger_name"
    ))
}
128
/// Rewrites `SHOW MATERIALIZED VIEWS` (ASCII case-insensitive, surrounding
/// whitespace and trailing semicolons ignored) into a `pg_matviews` query.
/// The words must be separated by single spaces. Returns `None` otherwise.
fn rewrite_show_matviews(sql: &str) -> Option<String> {
    let normalized = sql.trim().trim_end_matches(';').trim().to_ascii_lowercase();
    let is_show_matviews = normalized == "show materialized views";
    is_show_matviews.then(|| {
        "SELECT matviewname, ispopulated, hasindexes, definition \
         FROM pg_matviews ORDER BY matviewname"
            .to_string()
    })
}
142
143fn stmt_mutates(stmt: &Statement) -> bool {
144    if matches!(
145        stmt,
146        Statement::Insert(_)
147            | Statement::Update(_)
148            | Statement::Delete(_)
149            | Statement::Truncate(_)
150            | Statement::CreateTable(_)
151            | Statement::DropTable(_)
152            | Statement::AlterTable(_)
153            | Statement::CreateIndex(_)
154            | Statement::DropIndex(_)
155            | Statement::CreateView(_)
156            | Statement::DropView(_)
157            | Statement::CreateTrigger(_)
158            | Statement::DropTrigger(_)
159            | Statement::CreateMaterializedView(_)
160            | Statement::RefreshMaterializedView(_)
161            | Statement::DropMaterializedView(_)
162    ) {
163        return true;
164    }
165    if let Statement::Select(sq) = stmt {
166        if select_query_has_dml(sq) {
167            return true;
168        }
169    }
170    false
171}
172
173fn is_txn_control(stmt: &Statement) -> bool {
174    matches!(
175        stmt,
176        Statement::Begin { .. }
177            | Statement::Commit
178            | Statement::Rollback
179            | Statement::Savepoint(_)
180            | Statement::ReleaseSavepoint(_)
181            | Statement::RollbackTo(_)
182    )
183}
184
185fn select_query_has_dml(sq: &SelectQuery) -> bool {
186    sq.ctes.iter().any(|cte| query_body_has_dml(&cte.body)) || query_body_has_dml(&sq.body)
187}
188
189fn query_body_has_dml(body: &QueryBody) -> bool {
190    match body {
191        QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => true,
192        QueryBody::Compound(c) => query_body_has_dml(&c.left) || query_body_has_dml(&c.right),
193        QueryBody::Select(_) => false,
194    }
195}
196
/// Fast-path normalizer for single-row `INSERT INTO ... VALUES (...)` whose
/// values are all simple literals.
///
/// On success returns `(normalized_sql, values)`:
/// * `normalized_sql` is the input text up to and including the `VALUES`
///   keyword (kept verbatim), followed by ` ($1, $2, ...)`;
/// * `values` holds the extracted literals in order.
///
/// Identical statements that differ only in their literal values therefore
/// share one normalized text.
///
/// Recognized literals:
/// * single-quoted strings with `''` escapes → `Value::Text`;
/// * optionally negative integers → `Value::Integer` (must fit in `i64`);
/// * the same with a `.` fraction → `Value::Real`;
/// * `NULL`, `TRUE`, `FALSE` (ASCII case-insensitive, not followed by an
///   identifier character).
///
/// Returns `None` for anything else. That includes multi-row VALUES, trailing
/// clauses such as RETURNING, expressions, comments inside the list, exponent
/// or hex numbers, and integers out of `i64` range. The caller can then fall
/// back to the regular parser.
fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut i = 0;

    // Expect: [ws] INSERT ws+ INTO ws ...
    while i < len && bytes[i].is_ascii_whitespace() {
        i += 1;
    }
    if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
        return None;
    }
    i += 6;
    if i >= len || !bytes[i].is_ascii_whitespace() {
        return None;
    }
    while i < len && bytes[i].is_ascii_whitespace() {
        i += 1;
    }

    if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
        return None;
    }
    i += 4;
    if i >= len || !bytes[i].is_ascii_whitespace() {
        return None;
    }

    // Find the first standalone `VALUES` word after INTO. This is a plain
    // byte scan: a match inside a quoted identifier or string literal is not
    // excluded here. It is only rejected if the `(` / literal / terminator
    // checks below fail.
    let prefix_start = 0;
    let mut values_pos = None;
    let mut j = i;
    while j + 6 <= len {
        if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
            && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
            && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
        {
            values_pos = Some(j);
            break;
        }
        j += 1;
    }
    let values_pos = values_pos?;

    // `VALUES` is ASCII, so `values_pos + 6` is a char boundary and slicing
    // `sql` here cannot panic.
    let prefix = &sql[prefix_start..values_pos + 6];
    let mut pos = values_pos + 6;

    while pos < len && bytes[pos].is_ascii_whitespace() {
        pos += 1;
    }
    if pos >= len || bytes[pos] != b'(' {
        return None;
    }
    pos += 1;

    let mut values = Vec::new();
    let mut normalized = String::with_capacity(sql.len());
    normalized.push_str(prefix);
    normalized.push_str(" (");

    // One iteration per literal; exits after the closing `)`.
    loop {
        while pos < len && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        if pos >= len {
            return None;
        }

        // Placeholders are 1-based: $1, $2, ...
        let param_idx = values.len() + 1;
        if param_idx > 1 {
            normalized.push_str(", ");
        }

        if bytes[pos] == b'\'' {
            // String literal: copy segments between quotes, turning `''` into `'`.
            // Segment ends are at ASCII `'` bytes, so each slice is valid UTF-8
            // whenever `sql` is.
            pos += 1;
            let mut seg_start = pos;
            let mut s = String::new();
            loop {
                if pos >= len {
                    return None;
                }
                if bytes[pos] == b'\'' {
                    s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
                    pos += 1;
                    if pos < len && bytes[pos] == b'\'' {
                        s.push('\'');
                        pos += 1;
                        seg_start = pos;
                    } else {
                        break;
                    }
                } else {
                    pos += 1;
                }
            }
            values.push(Value::Text(s.into()));
        } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
            // Numeric literal: [-]digits[.digits]. A lone `-` or an out-of-range
            // integer fails to parse and yields None.
            let start = pos;
            if bytes[pos] == b'-' {
                pos += 1;
            }
            while pos < len && bytes[pos].is_ascii_digit() {
                pos += 1;
            }
            if pos < len && bytes[pos] == b'.' {
                pos += 1;
                while pos < len && bytes[pos].is_ascii_digit() {
                    pos += 1;
                }
                let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
                values.push(Value::Real(num));
            } else {
                let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
                values.push(Value::Integer(num));
            }
        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
            // Keyword must not be a prefix of a longer identifier (e.g. NULLX).
            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 4;
                values.push(Value::Null);
            } else {
                return None;
            }
        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 4;
                values.push(Value::Boolean(true));
            } else {
                return None;
            }
        } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
            let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 5;
                values.push(Value::Boolean(false));
            } else {
                return None;
            }
        } else {
            return None;
        }

        normalized.push('$');
        normalized.push_str(&param_idx.to_string());

        while pos < len && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        if pos >= len {
            return None;
        }

        if bytes[pos] == b',' {
            pos += 1;
        } else if bytes[pos] == b')' {
            pos += 1;
            break;
        } else {
            return None;
        }
    }

    normalized.push(')');

    // Only whitespace and semicolons may follow the row. Anything else
    // (another row, RETURNING, a second statement) defeats the fast path.
    while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
        pos += 1;
    }
    if pos != len {
        return None;
    }

    // Defensive: the loop above pushes at least one value before it can break.
    if values.is_empty() {
        return None;
    }

    Some((normalized, values))
}
373
/// One entry of the per-connection statement cache.
pub(crate) struct CacheEntry {
    /// The parsed statement, shared with callers via `Arc`.
    pub(crate) stmt: Arc<Statement>,
    /// Schema generation the entry was created under; entries whose
    /// generation differs from the current schema's are re-parsed.
    pub(crate) schema_gen: u64,
    /// Number of `$n` placeholders in `stmt` (from `parser::count_params`).
    pub(crate) param_count: usize,
    /// Compiled execution plan, attached lazily the first time the statement
    /// compiles outside an explicit transaction; `None` until then.
    pub(crate) compiled: Option<Arc<dyn executor::CompiledPlan>>,
}
380
/// A SAVEPOINT on the connection's savepoint stack.
struct SavepointEntry {
    /// Savepoint name as given in SQL.
    name: String,
    // NOTE(review): presumably `None` until the state is captured lazily
    // (see `capture_pending_snapshots`, called before mutating statements) —
    // confirm in the savepoint handling code.
    snapshot: Option<SavepointSnapshot>,
}
385
/// State captured for a savepoint: the write transaction's snapshot plus the
/// in-memory schema snapshot taken alongside it.
struct SavepointSnapshot {
    wtx_snap: WriteTxnSnapshot,
    schema_snap: SchemaSnapshot,
}
390
/// Active transaction held by a Connection. `None` outside BEGIN/COMMIT; `Write` for normal
/// BEGIN (or BEGIN READ WRITE); `Read` for BEGIN READ ONLY.
///
/// `execute_batch` also installs a `Write` transaction for the duration of the batch.
// NOTE(review): the lint is presumably allowed because `WriteTxn` is much
// larger than `ReadTxn` and boxing it would cost an allocation per
// transaction — confirm.
#[allow(clippy::large_enum_variant)]
pub(crate) enum ActiveTxn<'a> {
    None,
    Write(WriteTxn<'a>),
    Read(citadel_txn::read_txn::ReadTxn<'a>),
}
399
400impl<'a> ActiveTxn<'a> {
401    fn is_none(&self) -> bool {
402        matches!(self, ActiveTxn::None)
403    }
404    fn is_active(&self) -> bool {
405        !self.is_none()
406    }
407    fn is_read_only(&self) -> bool {
408        matches!(self, ActiveTxn::Read(_))
409    }
410    fn as_write_mut(&mut self) -> Option<&mut WriteTxn<'a>> {
411        match self {
412            ActiveTxn::Write(w) => Some(w),
413            _ => None,
414        }
415    }
416    fn take(&mut self) -> ActiveTxn<'a> {
417        std::mem::replace(self, ActiveTxn::None)
418    }
419}
420
/// Mutable per-connection state. `Connection` keeps it in a `RefCell` so its
/// public methods can take `&self`.
pub(crate) struct ConnectionInner<'a> {
    /// In-memory schema; reloaded from the database after aborted or failed
    /// transactions.
    pub(crate) schema: SchemaManager,
    /// Transaction opened by BEGIN (or by `execute_batch`), if any.
    active_txn: ActiveTxn<'a>,
    /// Open savepoints.
    // NOTE(review): presumably innermost last — confirm in the SAVEPOINT handling.
    savepoint_stack: Vec<SavepointEntry>,
    // NOTE(review): meaning not visible in this chunk; presumably a flag saved
    // at transaction start and restored at its end — confirm.
    in_place_saved: Option<bool>,
    /// Parsed statements keyed by SQL text (or by normalized INSERT text).
    pub(crate) stmt_cache: LruCache<String, CacheEntry>,
    /// Transaction start time in UTC microseconds while a transaction is open.
    txn_start_ts: Option<i64>,
    /// Session time zone (IANA name or fixed offset); starts as "UTC".
    session_timezone: String,
    /// Namespaces TEMP tables as `__temp_<id>_<name>`. Cleaned up on Connection drop.
    temp_id: u64,
    // NOTE(review): presumably the TEMP tables created on this connection,
    // dropped together with it — confirm.
    temp_table_names: Vec<String>,
}
433
/// A SQL session over a borrowed `Database`.
///
/// All public methods take `&self`. Session state (schema, open transaction,
/// statement cache, time zone) lives in `inner` and is borrowed per call.
pub struct Connection<'a> {
    pub(crate) db: &'a Database,
    pub(crate) inner: RefCell<ConnectionInner<'a>>,
}
438
439impl<'a> Connection<'a> {
440    pub fn open(db: &'a Database) -> Result<Self> {
441        let schema = SchemaManager::load(db)?;
442        let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
443        let temp_id = generate_temp_id();
444        Ok(Self {
445            db,
446            inner: RefCell::new(ConnectionInner {
447                schema,
448                active_txn: ActiveTxn::None,
449                savepoint_stack: Vec::new(),
450                in_place_saved: None,
451                stmt_cache,
452                txn_start_ts: None,
453                session_timezone: "UTC".to_string(),
454                temp_id,
455                temp_table_names: Vec::new(),
456            }),
457        })
458    }
459
460    /// Txn-start UTC µs inside BEGIN/COMMIT, else `None`.
461    pub fn txn_start_ts(&self) -> Option<i64> {
462        self.inner.borrow().txn_start_ts
463    }
464
465    /// Returns the session time-zone (IANA name or fixed offset). Default `"UTC"`.
466    pub fn session_timezone(&self) -> String {
467        self.inner.borrow().session_timezone.clone()
468    }
469
470    /// Set the session time-zone. Accepts IANA names, ISO-8601 offsets, `"UTC"`, `"Z"`.
471    pub fn set_session_timezone(&self, tz: &str) -> Result<()> {
472        self.inner.borrow_mut().set_session_timezone_impl(tz)
473    }
474
475    pub fn execute(&self, sql: &str) -> Result<ExecutionResult> {
476        self.inner.borrow_mut().execute_impl(self.db, sql)
477    }
478
479    pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
480        self.inner
481            .borrow_mut()
482            .execute_params_impl(self.db, sql, params)
483    }
484
485    /// Execute `;`-separated SQL statements. Stops at the first failure.
486    pub fn execute_script(&self, sql: &str) -> ScriptExecution {
487        let stmts = match parser::parse_sql_multi(sql) {
488            Ok(s) => s,
489            Err(e) => {
490                return ScriptExecution {
491                    completed: vec![],
492                    error: Some(e),
493                }
494            }
495        };
496        let mut completed = Vec::with_capacity(stmts.len());
497        for stmt in stmts {
498            match self.inner.borrow_mut().dispatch(self.db, &stmt, &[]) {
499                Ok(r) => completed.push(r),
500                Err(e) => {
501                    return ScriptExecution {
502                        completed,
503                        error: Some(e),
504                    }
505                }
506            }
507        }
508        ScriptExecution {
509            completed,
510            error: None,
511        }
512    }
513
514    pub fn execute_batch(&self, sql: &str) -> Result<Vec<ExecutionResult>> {
515        self.inner.borrow_mut().execute_batch_impl(self.db, sql)
516    }
517
518    pub fn query(&self, sql: &str) -> Result<QueryResult> {
519        self.query_params(sql, &[])
520    }
521
522    pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
523        match self.execute_params(sql, params)? {
524            ExecutionResult::Query(qr) => Ok(qr),
525            ExecutionResult::RowsAffected(n) => Ok(QueryResult {
526                columns: vec!["rows_affected".into()],
527                rows: vec![vec![Value::Integer(n as i64)]],
528            }),
529            ExecutionResult::Ok => Ok(QueryResult {
530                columns: vec![],
531                rows: vec![],
532            }),
533        }
534    }
535
536    pub fn prepare(&self, sql: &str) -> Result<PreparedStatement<'_, 'a>> {
537        if let Some(rewritten) = rewrite_show_triggers(sql) {
538            return PreparedStatement::new(self, &rewritten);
539        }
540        if let Some(rewritten) = rewrite_show_matviews(sql) {
541            return PreparedStatement::new(self, &rewritten);
542        }
543        PreparedStatement::new(self, sql)
544    }
545
546    pub fn tables(&self) -> Vec<String> {
547        self.inner
548            .borrow()
549            .schema
550            .table_names()
551            .into_iter()
552            .map(String::from)
553            .collect()
554    }
555
556    /// Returns true if an explicit transaction is active (BEGIN was issued).
557    pub fn in_transaction(&self) -> bool {
558        self.inner.borrow().active_txn.is_active()
559    }
560
561    pub fn table_schema(&self, name: &str) -> Option<TableSchema> {
562        self.inner.borrow().schema.get(name).cloned()
563    }
564
565    pub fn refresh_schema(&self) -> Result<()> {
566        let new_schema = SchemaManager::load(self.db)?;
567        self.inner.borrow_mut().schema = new_schema;
568        Ok(())
569    }
570
571    /// Freeze the ANN index for `table.column` into a persisted segment: one
572    /// write txn scans, builds, serializes, and commits atomically; subsequent
573    /// cold attaches LOAD it (seconds) instead of rebuilding (minutes), with
574    /// the load-time scan re-proving freshness by content. The single writer
575    /// lock is held for the whole build - an offline/builder operation.
576    /// Refused inside an explicit transaction (it owns its own txn), and for
577    /// TEMP tables (their storage bypasses the DDL paths that purge segments).
578    pub fn persist_ann_index(
579        &self,
580        table: &str,
581        column: &str,
582    ) -> Result<crate::executor::AnnSegmentInfo> {
583        if self.in_transaction() {
584            return Err(SqlError::InvalidValue(
585                "persist_ann_index: not allowed inside an explicit transaction".into(),
586            ));
587        }
588        let inner = self.inner.borrow();
589        let lower = table.to_ascii_lowercase();
590        if inner.schema.resolve_temp(&lower) != lower {
591            return Err(SqlError::InvalidValue(
592                "persist_ann_index: TEMP tables are not persistable".into(),
593            ));
594        }
595        let table_schema = inner
596            .schema
597            .get(&lower)
598            .ok_or_else(|| SqlError::TableNotFound(table.to_string()))?;
599        crate::executor::persist_ann_index(self.db, &inner.schema, table_schema, column)
600    }
601
602    /// The identity of the index currently cached for `table.column`:
603    /// `(source, snapshot generation)` - `Loaded{segment_b3}` means queries are
604    /// served by the persisted segment; `Built{refusal}` carries why a segment
605    /// was rejected, if one was.
606    pub fn ann_cache_status(
607        &self,
608        table: &str,
609        column: &str,
610    ) -> Result<Option<(crate::executor::AnnIndexSource, u64)>> {
611        let inner = self.inner.borrow();
612        let table_schema = inner
613            .schema
614            .get(&table.to_ascii_lowercase())
615            .ok_or_else(|| SqlError::TableNotFound(table.to_string()))?;
616        crate::executor::ann_cache_status(&inner.schema, table_schema, column)
617    }
618}
619
620impl<'a> ConnectionInner<'a> {
621    pub(crate) fn active_txn_is_some(&self) -> bool {
622        self.active_txn.is_active()
623    }
624
625    fn set_session_timezone_impl(&mut self, tz: &str) -> Result<()> {
626        let upper = tz.to_ascii_uppercase();
627        if (upper.starts_with("UTC+") || upper.starts_with("UTC-")) && tz.len() > 3 {
628            return Err(SqlError::InvalidTimezone(format!(
629                "'{tz}' is ambiguous; use ISO-8601 offset (e.g. '+05:00') or named zone (e.g. 'Etc/GMT-5')"
630            )));
631        }
632        if jiff::tz::TimeZone::get(tz).is_err() && parse_fixed_offset(tz).is_none() {
633            return Err(SqlError::InvalidTimezone(format!(
634                "{tz}: not a known IANA zone or ISO-8601 offset (e.g. '+05:00', 'UTC', 'America/New_York')"
635            )));
636        }
637        self.session_timezone = tz.to_string();
638        Ok(())
639    }
640
641    fn execute_impl(&mut self, db: &'a Database, sql: &str) -> Result<ExecutionResult> {
642        if let Some(rewritten) = rewrite_show_triggers(sql) {
643            return self.execute_params_impl(db, &rewritten, &[]);
644        }
645        if let Some(rewritten) = rewrite_show_matviews(sql) {
646            return self.execute_params_impl(db, &rewritten, &[]);
647        }
648        if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
649            if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
650                let gen = self.schema.generation();
651                let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
652                    if entry.schema_gen == gen {
653                        Arc::clone(&entry.stmt)
654                    } else {
655                        self.parse_and_cache(normalized_key, gen)?
656                    }
657                } else {
658                    self.parse_and_cache(normalized_key, gen)?
659                };
660                return self.dispatch(db, &stmt, &extracted);
661            }
662        }
663        self.execute_params_impl(db, sql, &[])
664    }
665
666    fn execute_batch_impl(&mut self, db: &'a Database, sql: &str) -> Result<Vec<ExecutionResult>> {
667        if self.active_txn.is_active() {
668            return Err(SqlError::TransactionAlreadyActive);
669        }
670        let stmts = parser::parse_sql_multi(sql)?;
671        if stmts.iter().any(is_txn_control) {
672            return Err(SqlError::Unsupported(
673                "transaction-control statements are not allowed in execute_batch".into(),
674            ));
675        }
676
677        let wtx = db.begin_write().map_err(SqlError::Storage)?;
678        let ts = crate::datetime::txn_or_clock_micros();
679        self.active_txn = ActiveTxn::Write(wtx);
680        self.txn_start_ts = Some(ts);
681        crate::datetime::set_txn_clock(Some(ts));
682
683        let mut results = Vec::with_capacity(stmts.len());
684        for stmt in &stmts {
685            match self.dispatch(db, stmt, &[]) {
686                Ok(r) => results.push(r),
687                Err(e) => {
688                    self.abort_active_txn(db);
689                    return Err(e);
690                }
691            }
692        }
693
694        let commit = match self.active_txn.take() {
695            ActiveTxn::Write(mut wtx) => {
696                match crate::executor::helpers::drain_deferred_fk_checks(&mut wtx) {
697                    Ok(()) => wtx.commit().map_err(SqlError::Storage),
698                    Err(e) => {
699                        wtx.abort();
700                        Err(e)
701                    }
702                }
703            }
704            _ => Err(SqlError::NoActiveTransaction),
705        };
706        self.reset_txn_state();
707        match commit {
708            Ok(()) => {
709                invalidate_dml_caches(&self.schema, db);
710                Ok(results)
711            }
712            Err(e) => {
713                self.schema = SchemaManager::load(db)?;
714                Err(e)
715            }
716        }
717    }
718
    /// Clears per-transaction bookkeeping once a transaction has been
    /// committed or aborted. It clears the savepoint state (via
    /// `clear_savepoint_state`) and the recorded start timestamp, and unpins
    /// the transaction clock.
    fn reset_txn_state(&mut self) {
        self.clear_savepoint_state();
        self.txn_start_ts = None;
        crate::datetime::set_txn_clock(None);
    }
724
725    fn abort_active_txn(&mut self, db: &'a Database) {
726        if let ActiveTxn::Write(wtx) = self.active_txn.take() {
727            wtx.abort();
728        }
729        if let Ok(fresh) = SchemaManager::load(db) {
730            self.schema = fresh;
731        }
732        self.reset_txn_state();
733    }
734
735    fn execute_params_impl(
736        &mut self,
737        db: &'a Database,
738        sql: &str,
739        params: &[Value],
740    ) -> Result<ExecutionResult> {
741        let gen = self.schema.generation();
742        if self.active_txn.is_none() {
743            if let Some(entry) = self.stmt_cache.get(sql) {
744                if entry.schema_gen == gen && entry.param_count == params.len() {
745                    if let Some(plan) = entry.compiled.as_ref().map(Arc::clone) {
746                        let stmt = Arc::clone(&entry.stmt);
747                        return self.run_compiled(db, &plan, &stmt, params);
748                    }
749                }
750            }
751        }
752
753        let (stmt, param_count) = self.get_or_parse(sql)?;
754
755        if param_count != params.len() {
756            return Err(SqlError::ParameterCountMismatch {
757                expected: param_count,
758                got: params.len(),
759            });
760        }
761
762        if self.active_txn.is_none() {
763            if let Some(plan) = executor::compile(&self.schema, &stmt) {
764                if let Some(e) = self.stmt_cache.get_mut(sql) {
765                    e.compiled = Some(Arc::clone(&plan));
766                }
767                let stmt_owned = Arc::clone(&stmt);
768                return self.run_compiled(db, &plan, &stmt_owned, params);
769            }
770        }
771
772        self.dispatch(db, &stmt, params)
773    }
774
775    fn run_compiled(
776        &mut self,
777        db: &'a Database,
778        plan: &Arc<dyn executor::CompiledPlan>,
779        stmt: &Statement,
780        params: &[Value],
781    ) -> Result<ExecutionResult> {
782        use executor::compile::ActiveTxnRef;
783        let schema = &self.schema;
784        let exec = || {
785            if params.is_empty() {
786                plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
787            } else {
788                crate::eval::with_scoped_params(params, || {
789                    plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
790                })
791            }
792        };
793        let outcome = if plan.needs_txn_clock() {
794            let cached_ts = self
795                .txn_start_ts
796                .or_else(|| Some(crate::datetime::now_micros()));
797            crate::datetime::with_txn_clock(cached_ts, exec)
798        } else {
799            exec()
800        }?;
801        invalidate_dml_caches(&self.schema, db);
802        Ok(outcome)
803    }
804
805    pub(crate) fn parse_and_cache(
806        &mut self,
807        normalized_key: String,
808        gen: u64,
809    ) -> Result<Arc<Statement>> {
810        let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
811        let param_count = parser::count_params(&stmt);
812        self.stmt_cache.put(
813            normalized_key,
814            CacheEntry {
815                stmt: Arc::clone(&stmt),
816                schema_gen: gen,
817                param_count,
818                compiled: None,
819            },
820        );
821        Ok(stmt)
822    }
823
824    pub(crate) fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
825        let gen = self.schema.generation();
826
827        if let Some(entry) = self.stmt_cache.get(sql) {
828            if entry.schema_gen == gen {
829                return Ok((Arc::clone(&entry.stmt), entry.param_count));
830            }
831        }
832
833        let stmt = Arc::new(parser::parse_sql(sql)?);
834        let param_count = parser::count_params(&stmt);
835
836        let cacheable = !matches!(
837            *stmt,
838            Statement::CreateTable(_)
839                | Statement::DropTable(_)
840                | Statement::CreateIndex(_)
841                | Statement::DropIndex(_)
842                | Statement::CreateView(_)
843                | Statement::DropView(_)
844                | Statement::AlterTable(_)
845        );
846
847        if cacheable {
848            self.stmt_cache.put(
849                sql.to_string(),
850                CacheEntry {
851                    stmt: Arc::clone(&stmt),
852                    schema_gen: gen,
853                    param_count,
854                    compiled: None,
855                },
856            );
857        }
858
859        Ok((stmt, param_count))
860    }
861
862    pub(crate) fn execute_prepared(
863        &mut self,
864        db: &'a Database,
865        stmt: &Statement,
866        compiled: Option<&Arc<dyn executor::CompiledPlan>>,
867        params: &[Value],
868    ) -> Result<ExecutionResult> {
869        if let Some(plan) = compiled {
870            if self.active_txn.is_none() {
871                return self.run_compiled(db, plan, stmt, params);
872            }
873            if !self.savepoint_stack.is_empty() && stmt_mutates(stmt) {
874                self.capture_pending_snapshots();
875            }
876            return self.run_compiled_in_txn(db, plan, stmt, params);
877        }
878        self.dispatch(db, stmt, params)
879    }
880
881    fn run_compiled_in_txn(
882        &mut self,
883        db: &'a Database,
884        plan: &Arc<dyn executor::CompiledPlan>,
885        stmt: &Statement,
886        params: &[Value],
887    ) -> Result<ExecutionResult> {
888        use executor::compile::ActiveTxnRef;
889        let schema = &self.schema;
890        let txn = match &mut self.active_txn {
891            ActiveTxn::Write(wtx) => ActiveTxnRef::Write(wtx),
892            ActiveTxn::Read(rtx) => ActiveTxnRef::Read(rtx),
893            ActiveTxn::None => ActiveTxnRef::None,
894        };
895        if params.is_empty() || !plan.uses_scoped_params() {
896            plan.execute(db, schema, stmt, params, txn)
897        } else {
898            crate::eval::with_scoped_params(params, || plan.execute(db, schema, stmt, params, txn))
899        }
900    }
901
902    pub(crate) fn dispatch(
903        &mut self,
904        db: &'a Database,
905        stmt: &Statement,
906        params: &[Value],
907    ) -> Result<ExecutionResult> {
908        let cached_ts = self
909            .txn_start_ts
910            .or_else(|| Some(crate::datetime::now_micros()));
911        crate::datetime::with_txn_clock(cached_ts, || {
912            if params.is_empty() {
913                self.dispatch_inner(db, stmt, params)
914            } else {
915                crate::eval::with_scoped_params(params, || self.dispatch_inner(db, stmt, params))
916            }
917        })
918    }
919
920    fn dispatch_inner(
921        &mut self,
922        db: &'a Database,
923        stmt: &Statement,
924        params: &[Value],
925    ) -> Result<ExecutionResult> {
926        match stmt {
927            Statement::Begin { access_mode } => {
928                if self.active_txn.is_active() {
929                    return Err(SqlError::TransactionAlreadyActive);
930                }
931                let ts = crate::datetime::txn_or_clock_micros();
932                match access_mode {
933                    BeginAccessMode::ReadOnly => {
934                        let rtx = db.begin_read();
935                        self.active_txn = ActiveTxn::Read(rtx);
936                    }
937                    BeginAccessMode::ReadWrite | BeginAccessMode::Default => {
938                        let wtx = db.begin_write().map_err(SqlError::Storage)?;
939                        self.active_txn = ActiveTxn::Write(wtx);
940                    }
941                }
942                self.txn_start_ts = Some(ts);
943                crate::datetime::set_txn_clock(Some(ts));
944                Ok(ExecutionResult::Ok)
945            }
946            Statement::Commit => {
947                match self.active_txn.take() {
948                    ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
949                    ActiveTxn::Write(mut wtx) => {
950                        crate::executor::helpers::drain_deferred_fk_checks(&mut wtx)?;
951                        wtx.commit().map_err(SqlError::Storage)?;
952                        invalidate_dml_caches(&self.schema, db);
953                    }
954                    ActiveTxn::Read(_rtx) => {}
955                }
956                self.reset_txn_state();
957                Ok(ExecutionResult::Ok)
958            }
959            Statement::Rollback => {
960                match self.active_txn.take() {
961                    ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
962                    ActiveTxn::Write(wtx) => {
963                        wtx.abort();
964                        self.schema = SchemaManager::load(db)?;
965                    }
966                    ActiveTxn::Read(_rtx) => {}
967                }
968                self.reset_txn_state();
969                Ok(ExecutionResult::Ok)
970            }
971            Statement::Savepoint(name) => self.do_savepoint(name),
972            Statement::ReleaseSavepoint(name) => self.do_release(name),
973            Statement::RollbackTo(name) => self.do_rollback_to(name),
974            Statement::SetTimezone(zone) => {
975                self.set_session_timezone_impl(zone)?;
976                Ok(ExecutionResult::Ok)
977            }
978            Statement::CreateTable(ct) if ct.temporary => {
979                if self.active_txn.is_read_only() {
980                    return Err(SqlError::Unsupported(
981                        "cannot execute mutating statement inside a read-only transaction".into(),
982                    ));
983                }
984                let user_name = ct.name.clone();
985                let prefixed = temp_storage_name(self.temp_id, &user_name);
986                if self.schema.contains(&user_name) {
987                    if ct.if_not_exists {
988                        return Ok(ExecutionResult::Ok);
989                    }
990                    return Err(SqlError::TableAlreadyExists(user_name));
991                }
992                let mut clone = ct.clone();
993                clone.name = prefixed.clone();
994                clone.temporary = false;
995                let stmt_concrete = Statement::CreateTable(clone);
996                let outcome = if let Some(wtx) = self.active_txn.as_write_mut() {
997                    executor::execute_in_txn(wtx, &mut self.schema, &stmt_concrete, params)?
998                } else {
999                    executor::execute(db, &mut self.schema, &stmt_concrete, params)?
1000                };
1001                self.schema
1002                    .register_temp_alias(&user_name, prefixed.clone());
1003                self.temp_table_names.push(prefixed);
1004                Ok(outcome)
1005            }
1006            Statement::Insert(ins) if self.active_txn.as_write_mut().is_some() => {
1007                self.capture_pending_snapshots();
1008                let wtx = self.active_txn.as_write_mut().unwrap();
1009                executor::exec_insert_in_txn(wtx, &self.schema, ins, params)
1010            }
1011            _ => {
1012                if self.active_txn.is_read_only() && stmt_mutates(stmt) {
1013                    return Err(SqlError::Unsupported(
1014                        "cannot execute mutating statement inside a read-only transaction".into(),
1015                    ));
1016                }
1017                if self.active_txn.as_write_mut().is_some() && stmt_mutates(stmt) {
1018                    self.capture_pending_snapshots();
1019                }
1020                let was_auto_commit = matches!(self.active_txn, ActiveTxn::None);
1021                let outcome = match &mut self.active_txn {
1022                    ActiveTxn::Write(wtx) => {
1023                        executor::execute_in_txn(wtx, &mut self.schema, stmt, params)?
1024                    }
1025                    ActiveTxn::Read(rtx) => {
1026                        executor::execute_with_read(rtx, &self.schema, stmt, params)?
1027                    }
1028                    ActiveTxn::None => executor::execute(db, &mut self.schema, stmt, params)?,
1029                };
1030                if was_auto_commit {
1031                    invalidate_dml_caches(&self.schema, db);
1032                }
1033                if let Statement::DropTable(dt) = stmt {
1034                    self.schema.unregister_temp_alias(&dt.name);
1035                }
1036                Ok(outcome)
1037            }
1038        }
1039    }
1040
1041    fn clear_savepoint_state(&mut self) {
1042        self.savepoint_stack.clear();
1043        self.in_place_saved = None;
1044    }
1045
1046    fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
1047        let wtx = self
1048            .active_txn
1049            .as_write_mut()
1050            .ok_or(SqlError::NoActiveTransaction)?;
1051
1052        if self.savepoint_stack.is_empty() {
1053            self.in_place_saved = Some(wtx.in_place());
1054            wtx.set_in_place(false);
1055        }
1056
1057        self.savepoint_stack.push(SavepointEntry {
1058            name: name.to_string(),
1059            snapshot: None,
1060        });
1061
1062        Ok(ExecutionResult::Ok)
1063    }
1064
1065    fn capture_pending_snapshots(&mut self) {
1066        let last_pending = match self
1067            .savepoint_stack
1068            .iter()
1069            .rposition(|e| e.snapshot.is_none())
1070        {
1071            Some(i) => i,
1072            None => return,
1073        };
1074        let wtx = match self.active_txn.as_write_mut() {
1075            Some(w) => w,
1076            None => return,
1077        };
1078        let wtx_snap = wtx.begin_savepoint();
1079        let schema_snap = self.schema.save_snapshot();
1080
1081        for i in 0..last_pending {
1082            if self.savepoint_stack[i].snapshot.is_none() {
1083                self.savepoint_stack[i].snapshot = Some(SavepointSnapshot {
1084                    wtx_snap: wtx_snap.clone(),
1085                    schema_snap: schema_snap.clone(),
1086                });
1087            }
1088        }
1089        self.savepoint_stack[last_pending].snapshot = Some(SavepointSnapshot {
1090            wtx_snap,
1091            schema_snap,
1092        });
1093    }
1094
1095    fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
1096        if !self.active_txn.is_active() {
1097            return Err(SqlError::NoActiveTransaction);
1098        }
1099
1100        let idx = self
1101            .savepoint_stack
1102            .iter()
1103            .rposition(|e| e.name == name)
1104            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
1105        self.savepoint_stack.truncate(idx);
1106
1107        if self.savepoint_stack.is_empty() {
1108            if let (Some(wtx), Some(original)) =
1109                (self.active_txn.as_write_mut(), self.in_place_saved.take())
1110            {
1111                wtx.set_in_place(original);
1112            }
1113        }
1114
1115        Ok(ExecutionResult::Ok)
1116    }
1117
1118    fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
1119        if !self.active_txn.is_active() {
1120            return Err(SqlError::NoActiveTransaction);
1121        }
1122
1123        let idx = self
1124            .savepoint_stack
1125            .iter()
1126            .rposition(|e| e.name == name)
1127            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
1128
1129        self.savepoint_stack.truncate(idx + 1);
1130        let entry = self.savepoint_stack.last_mut().unwrap();
1131        let snapshot = match entry.snapshot.take() {
1132            Some(s) => s,
1133            None => return Ok(ExecutionResult::Ok),
1134        };
1135
1136        let wtx = match self.active_txn.as_write_mut() {
1137            Some(w) => w,
1138            None => return Err(SqlError::NoActiveTransaction),
1139        };
1140        wtx.restore_snapshot(snapshot.wtx_snap);
1141        self.schema.restore_snapshot(snapshot.schema_snap);
1142
1143        Ok(ExecutionResult::Ok)
1144    }
1145}
1146
1147impl<'a> Drop for Connection<'a> {
1148    fn drop(&mut self) {
1149        let temp_names = std::mem::take(&mut self.inner.borrow_mut().temp_table_names);
1150        if temp_names.is_empty() {
1151            return;
1152        }
1153        if let Ok(mut wtx) = self.db.begin_write() {
1154            for prefixed in &temp_names {
1155                let _ = wtx.drop_table(prefixed.as_bytes());
1156            }
1157            let _ = wtx.commit();
1158        }
1159    }
1160}
1161
#[cfg(test)]
mod tests {
    use super::*;
    use citadel::{Argon2Profile, DatabaseBuilder};

    /// Create a brand-new encrypted database file `t.db` inside `dir`.
    ///
    /// Uses the `Iot` Argon2 profile, presumably the lightest preset, which
    /// keeps tests fast.
    fn fresh_db(dir: &std::path::Path) -> citadel::Database {
        let path = dir.join("t.db");
        DatabaseBuilder::new(path)
            .passphrase(b"test-passphrase")
            .argon2_profile(Argon2Profile::Iot)
            .create()
            .unwrap()
    }

    #[test]
    fn execute_batch_commits_all_statements() {
        let tmp = tempfile::tempdir().unwrap();
        let db = fresh_db(tmp.path());
        let conn = Connection::open(&db).unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER)")
            .unwrap();

        let batch = "INSERT INTO t VALUES (1, 10); \
                     INSERT INTO t VALUES (2, 20); \
                     UPDATE t SET n = n + 1 WHERE id = 1;";
        let per_statement = conn.execute_batch(batch).unwrap();
        assert_eq!(per_statement.len(), 3);

        let out = conn.query("SELECT id, n FROM t ORDER BY id").unwrap();
        assert_eq!(out.rows.len(), 2);
        assert_eq!(out.rows[0], vec![Value::Integer(1), Value::Integer(11)]);
        assert_eq!(out.rows[1], vec![Value::Integer(2), Value::Integer(20)]);
    }

    #[test]
    fn execute_batch_rolls_back_whole_batch_on_error() {
        let tmp = tempfile::tempdir().unwrap();
        let db = fresh_db(tmp.path());
        let conn = Connection::open(&db).unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER)")
            .unwrap();
        conn.execute("INSERT INTO t VALUES (1, 100)").unwrap();

        // The second INSERT collides with the existing primary key 1.
        let outcome =
            conn.execute_batch("INSERT INTO t VALUES (2, 20); INSERT INTO t VALUES (1, 999);");
        assert!(outcome.is_err());

        let out = conn.query("SELECT id, n FROM t ORDER BY id").unwrap();
        assert_eq!(out.rows.len(), 1);
        assert_eq!(out.rows[0], vec![Value::Integer(1), Value::Integer(100)]);

        // The connection must be usable again and back in auto-commit mode.
        conn.execute("INSERT INTO t VALUES (3, 30)").unwrap();
        assert!(!conn.in_transaction());
    }

    #[test]
    fn execute_batch_rejects_txn_control() {
        let tmp = tempfile::tempdir().unwrap();
        let db = fresh_db(tmp.path());
        let conn = Connection::open(&db).unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)")
            .unwrap();

        let outcome = conn.execute_batch("INSERT INTO t VALUES (1); COMMIT;");
        assert!(outcome.is_err());
        assert!(!conn.in_transaction());
    }

    #[test]
    fn execute_batch_rejected_inside_transaction() {
        let tmp = tempfile::tempdir().unwrap();
        let db = fresh_db(tmp.path());
        let conn = Connection::open(&db).unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)")
            .unwrap();

        conn.execute("BEGIN").unwrap();
        let outcome = conn.execute_batch("INSERT INTO t VALUES (1);");
        assert!(outcome.is_err());
        conn.execute("ROLLBACK").unwrap();
    }

    #[test]
    fn streamed_expression_projection_correct() {
        let tmp = tempfile::tempdir().unwrap();
        let db = fresh_db(tmp.path());
        let conn = Connection::open(&db).unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER)")
            .unwrap();
        conn.execute_batch(
            "INSERT INTO t VALUES (1, 10); \
             INSERT INTO t VALUES (2, 20); \
             INSERT INTO t VALUES (3, 30);",
        )
        .unwrap();

        let prepared = conn.prepare("SELECT id + 1, n * 2 FROM t").unwrap();
        let out = prepared.query_collect(&[]).unwrap();
        assert_eq!(out.rows.len(), 3);
        assert_eq!(out.rows[0], vec![Value::Integer(2), Value::Integer(20)]);
        assert_eq!(out.rows[1], vec![Value::Integer(3), Value::Integer(40)]);
        assert_eq!(out.rows[2], vec![Value::Integer(4), Value::Integer(60)]);
    }

    #[test]
    fn jsonb_contains_raw_predicate_correct() {
        let tmp = tempfile::tempdir().unwrap();
        let db = fresh_db(tmp.path());
        let conn = Connection::open(&db).unwrap();
        conn.execute("CREATE TABLE u (id INTEGER PRIMARY KEY, data JSONB)")
            .unwrap();
        conn.execute_batch(
            "INSERT INTO u VALUES (1, '{\"role\":\"admin\",\"x\":1}'); \
             INSERT INTO u VALUES (2, '{\"role\":\"user\"}'); \
             INSERT INTO u VALUES (3, NULL); \
             INSERT INTO u VALUES (4, '{\"role\":\"admin\"}');",
        )
        .unwrap();

        let prepared = conn
            .prepare("SELECT id FROM u WHERE data @> '{\"role\":\"admin\"}'::jsonb")
            .unwrap();
        let out = prepared.query_collect(&[]).unwrap();

        // Non-integer ids map to -1 so an unexpected type shows up in the diff.
        let mut matched: Vec<i64> = out
            .rows
            .iter()
            .map(|row| {
                if let Value::Integer(id) = row[0] {
                    id
                } else {
                    -1
                }
            })
            .collect();
        matched.sort_unstable();
        assert_eq!(matched, vec![1, 4]);
    }
}