Skip to main content

citadel_sql/
connection.rs

1//! Public SQL connection API.
2
3use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7
8fn generate_temp_id() -> u64 {
9    static COUNTER: AtomicU64 = AtomicU64::new(0);
10    let counter = COUNTER.fetch_add(1, Ordering::Relaxed);
11    let nanos = (crate::datetime::now_micros() as u64) & 0xFFFF_FFFF;
12    (nanos << 32) | (counter & 0xFFFF_FFFF)
13}
14
/// Builds the internal storage name of a TEMP table: `__temp_<temp_id>_<name>`,
/// where `<name>` is `user_name` with ASCII letters lowercased.
fn temp_storage_name(temp_id: u64, user_name: &str) -> String {
    let mut storage = String::from("__temp_");
    storage.push_str(&temp_id.to_string());
    storage.push('_');
    storage.push_str(&user_name.to_ascii_lowercase());
    storage
}
18
19use lru::LruCache;
20
21use citadel::Database;
22use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
23
24use crate::error::{Result, SqlError};
25use crate::executor;
26use crate::parser;
27use crate::parser::{BeginAccessMode, QueryBody, SelectQuery, Statement};
28use crate::prepared::PreparedStatement;
29use crate::schema::{SchemaManager, SchemaSnapshot};
30use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
31
32const DEFAULT_CACHE_CAPACITY: usize = 64;
33
34/// On commit, evict shared caches (e.g. ANN indexes) for DML-touched tables,
35/// and stamp each table's last-DML generation marker: an index whose snapshot
36/// predates the marker is refused at lookup AND at insert, closing the
37/// build-races-a-commit window that prefix eviction alone leaves open.
38fn invalidate_dml_caches(schema: &SchemaManager, db: &Database) {
39    if !schema.has_dml_dirty() {
40        return;
41    }
42    let gen = db.manager().commit_generation();
43    let hard_invalidate = |table: &str| {
44        db.sql_cache_invalidate_prefix(&format!("ann:{table}:"));
45        let marker: std::sync::Arc<dyn std::any::Any + Send + Sync> = std::sync::Arc::new(gen);
46        schema
47            .sql_caches
48            .lock()
49            .insert(crate::executor::ann_dml_gen_key(table), marker);
50    };
51
52    let dirty = schema.drain_dml_dirty();
53    for table in &dirty.mutating {
54        hard_invalidate(table);
55    }
56    // A pure append keeps the cached index (recall tail-merges the new rows)
57    // unless it landed at/below an index snapshot.
58    for (table, min_pk) in &dirty.appends {
59        if crate::executor::ann_appends_safe(schema, table, *min_pk) {
60            continue;
61        }
62        hard_invalidate(table);
63    }
64}
65
/// Outcome of [`Connection::execute_script`].
///
/// The script's statements run one at a time, in order, with no implicit
/// transaction around them; execution stops at the first failure. Statements that
/// already ran are not undone by `execute_script` itself.
#[derive(Debug)]
pub struct ScriptExecution {
    /// Results of the statements that succeeded, in script order (empty if the
    /// script failed to parse).
    pub completed: Vec<ExecutionResult>,
    /// The parse or execution error that stopped the script; `None` if every
    /// statement succeeded.
    pub error: Option<SqlError>,
}
71
/// Validates a fixed UTC offset string, returning `Some(())` if it is acceptable.
///
/// Accepted forms (surrounding whitespace ignored):
/// - `Z` or `UTC`, in any ASCII case;
/// - a `+`/`-` sign followed by `HH:MM` (each field 1-2 digits, e.g. `+5:30`),
///   `HHMM`, or `HH`.
///
/// Hours must be 0-23 and minutes 0-59. Fields must consist of plain ASCII digits.
/// This matters because `u32::from_str` alone would accept a leading `+`, which
/// would let `"++5:00"` through. The function never panics: the 4-character form is
/// only split by byte index when the text is ASCII.
fn parse_fixed_offset(s: &str) -> Option<()> {
    let s = s.trim();
    if s.eq_ignore_ascii_case("z") || s.eq_ignore_ascii_case("utc") {
        return Some(());
    }
    let rest = s.strip_prefix(['+', '-'])?;
    let (hh, mm) = match rest.split_once(':') {
        Some(parts) => parts,
        // Splitting at byte 2 is only safe when every byte is a char boundary.
        None if rest.len() == 4 && rest.is_ascii() => rest.split_at(2),
        None if rest.len() == 2 => (rest, "00"),
        None => return None,
    };
    let field = |text: &str| -> Option<u32> {
        if (1..=2).contains(&text.len()) && text.bytes().all(|b| b.is_ascii_digit()) {
            text.parse().ok()
        } else {
            None
        }
    };
    let (h, m) = (field(hh)?, field(mm)?);
    (h <= 23 && m <= 59).then_some(())
}
101
/// Rewrites `SHOW TRIGGERS [ON <table>]` into a query over
/// `information_schema.triggers`; returns `None` if `sql` is not that statement.
///
/// Matching ignores ASCII case, surrounding whitespace and trailing `;`. Any
/// whitespace (spaces, tabs, newlines) may separate `ON` from the table name.
/// The table is compared case-insensitively. Single quotes in it are doubled so it
/// cannot terminate the generated string literal early.
fn rewrite_show_triggers(sql: &str) -> Option<String> {
    const KEYWORD: &str = "show triggers";
    let lower = sql.trim().trim_end_matches(';').trim().to_ascii_lowercase();
    let after = lower.strip_prefix(KEYWORD)?.trim_start();
    let base = "SELECT trigger_name, event_object_table AS table_name, action_timing, \
                event_manipulation, action_orientation, action_statement \
                FROM information_schema.triggers";
    if after.is_empty() {
        return Some(format!("{base} ORDER BY trigger_name"));
    }
    let rest = after.strip_prefix("on")?;
    // `ON` must be a whole word: reject e.g. `SHOW TRIGGERS ONLINE`.
    if !rest.starts_with(char::is_whitespace) {
        return None;
    }
    let table = rest.trim().trim_end_matches(';').trim();
    if table.is_empty() {
        return None;
    }
    let escaped = table.replace('\'', "''");
    Some(format!(
        "{base} WHERE LOWER(event_object_table) = LOWER('{escaped}') ORDER BY trigger_name"
    ))
}
128
/// Rewrites `SHOW MATERIALIZED VIEWS` into a query over `pg_matviews`; returns
/// `None` for any other statement.
///
/// Keywords are matched ASCII-case-insensitively and may be separated by any
/// whitespace (not only single spaces). Surrounding whitespace and trailing `;`
/// are ignored.
fn rewrite_show_matviews(sql: &str) -> Option<String> {
    let stmt = sql.trim().trim_end_matches(';');
    let is_show_matviews = stmt
        .split_whitespace()
        .map(str::to_ascii_lowercase)
        .eq(["show", "materialized", "views"]);
    if !is_show_matviews {
        return None;
    }
    Some(
        "SELECT matviewname, ispopulated, hasindexes, definition \
         FROM pg_matviews ORDER BY matviewname"
            .to_string(),
    )
}
142
143fn stmt_mutates(stmt: &Statement) -> bool {
144    if matches!(
145        stmt,
146        Statement::Insert(_)
147            | Statement::Update(_)
148            | Statement::Delete(_)
149            | Statement::Truncate(_)
150            | Statement::CreateTable(_)
151            | Statement::DropTable(_)
152            | Statement::AlterTable(_)
153            | Statement::CreateIndex(_)
154            | Statement::DropIndex(_)
155            | Statement::CreateView(_)
156            | Statement::DropView(_)
157            | Statement::CreateTrigger(_)
158            | Statement::DropTrigger(_)
159            | Statement::CreateMaterializedView(_)
160            | Statement::RefreshMaterializedView(_)
161            | Statement::DropMaterializedView(_)
162    ) {
163        return true;
164    }
165    if let Statement::Select(sq) = stmt {
166        if select_query_has_dml(sq) {
167            return true;
168        }
169    }
170    false
171}
172
173fn is_txn_control(stmt: &Statement) -> bool {
174    matches!(
175        stmt,
176        Statement::Begin { .. }
177            | Statement::Commit
178            | Statement::Rollback
179            | Statement::Savepoint(_)
180            | Statement::ReleaseSavepoint(_)
181            | Statement::RollbackTo(_)
182    )
183}
184
185fn select_query_has_dml(sq: &SelectQuery) -> bool {
186    sq.ctes.iter().any(|cte| query_body_has_dml(&cte.body)) || query_body_has_dml(&sq.body)
187}
188
189fn query_body_has_dml(body: &QueryBody) -> bool {
190    match body {
191        QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => true,
192        QueryBody::Compound(c) => query_body_has_dml(&c.left) || query_body_has_dml(&c.right),
193        QueryBody::Select(_) => false,
194    }
195}
196
197fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
198    let bytes = sql.as_bytes();
199    let len = bytes.len();
200    let mut i = 0;
201
202    while i < len && bytes[i].is_ascii_whitespace() {
203        i += 1;
204    }
205    if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
206        return None;
207    }
208    i += 6;
209    if i >= len || !bytes[i].is_ascii_whitespace() {
210        return None;
211    }
212    while i < len && bytes[i].is_ascii_whitespace() {
213        i += 1;
214    }
215
216    if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
217        return None;
218    }
219    i += 4;
220    if i >= len || !bytes[i].is_ascii_whitespace() {
221        return None;
222    }
223
224    let prefix_start = 0;
225    let mut values_pos = None;
226    let mut j = i;
227    while j + 6 <= len {
228        if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
229            && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
230            && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
231        {
232            values_pos = Some(j);
233            break;
234        }
235        j += 1;
236    }
237    let values_pos = values_pos?;
238
239    let prefix = &sql[prefix_start..values_pos + 6];
240    let mut pos = values_pos + 6;
241
242    while pos < len && bytes[pos].is_ascii_whitespace() {
243        pos += 1;
244    }
245    if pos >= len || bytes[pos] != b'(' {
246        return None;
247    }
248    pos += 1;
249
250    let mut values = Vec::new();
251    let mut normalized = String::with_capacity(sql.len());
252    normalized.push_str(prefix);
253    normalized.push_str(" (");
254
255    loop {
256        while pos < len && bytes[pos].is_ascii_whitespace() {
257            pos += 1;
258        }
259        if pos >= len {
260            return None;
261        }
262
263        let param_idx = values.len() + 1;
264        if param_idx > 1 {
265            normalized.push_str(", ");
266        }
267
268        if bytes[pos] == b'\'' {
269            pos += 1;
270            let mut seg_start = pos;
271            let mut s = String::new();
272            loop {
273                if pos >= len {
274                    return None;
275                }
276                if bytes[pos] == b'\'' {
277                    s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
278                    pos += 1;
279                    if pos < len && bytes[pos] == b'\'' {
280                        s.push('\'');
281                        pos += 1;
282                        seg_start = pos;
283                    } else {
284                        break;
285                    }
286                } else {
287                    pos += 1;
288                }
289            }
290            values.push(Value::Text(s.into()));
291        } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
292            let start = pos;
293            if bytes[pos] == b'-' {
294                pos += 1;
295            }
296            while pos < len && bytes[pos].is_ascii_digit() {
297                pos += 1;
298            }
299            if pos < len && bytes[pos] == b'.' {
300                pos += 1;
301                while pos < len && bytes[pos].is_ascii_digit() {
302                    pos += 1;
303                }
304                let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
305                values.push(Value::Real(num));
306            } else {
307                let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
308                values.push(Value::Integer(num));
309            }
310        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
311            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
312            if !after.is_ascii_alphanumeric() && after != b'_' {
313                pos += 4;
314                values.push(Value::Null);
315            } else {
316                return None;
317            }
318        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
319            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
320            if !after.is_ascii_alphanumeric() && after != b'_' {
321                pos += 4;
322                values.push(Value::Boolean(true));
323            } else {
324                return None;
325            }
326        } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
327            let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
328            if !after.is_ascii_alphanumeric() && after != b'_' {
329                pos += 5;
330                values.push(Value::Boolean(false));
331            } else {
332                return None;
333            }
334        } else {
335            return None;
336        }
337
338        normalized.push('$');
339        normalized.push_str(&param_idx.to_string());
340
341        while pos < len && bytes[pos].is_ascii_whitespace() {
342            pos += 1;
343        }
344        if pos >= len {
345            return None;
346        }
347
348        if bytes[pos] == b',' {
349            pos += 1;
350        } else if bytes[pos] == b')' {
351            pos += 1;
352            break;
353        } else {
354            return None;
355        }
356    }
357
358    normalized.push(')');
359
360    while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
361        pos += 1;
362    }
363    if pos != len {
364        return None;
365    }
366
367    if values.is_empty() {
368        return None;
369    }
370
371    Some((normalized, values))
372}
373
/// One statement-cache entry. The cache key is the SQL text, or the `$n`-normalized
/// text for literal inserts taking the fast path.
pub(crate) struct CacheEntry {
    /// The parsed statement, shared with callers.
    pub(crate) stmt: Arc<Statement>,
    /// Schema generation at parse time; the entry is only used while this still
    /// equals the current `SchemaManager::generation()`.
    pub(crate) schema_gen: u64,
    /// Number of `$n` placeholders, as counted by `parser::count_params`.
    pub(crate) param_count: usize,
    /// Compiled plan, filled in lazily when `executor::compile` succeeds for this
    /// statement outside an explicit transaction.
    pub(crate) compiled: Option<Arc<dyn executor::CompiledPlan>>,
}
380
/// One entry of a connection's savepoint stack.
struct SavepointEntry {
    /// The savepoint's name.
    name: String,
    /// State captured for this savepoint, or `None` if not captured yet.
    // NOTE(review): presumably filled lazily by `capture_pending_snapshots` before
    // the first mutating statement after the SAVEPOINT - confirm in the rest of the file.
    snapshot: Option<SavepointSnapshot>,
}
385
/// Transaction and schema state captured for a savepoint.
// NOTE(review): presumably restored by ROLLBACK TO - the restore path is not
// visible in this part of the file.
struct SavepointSnapshot {
    /// Snapshot of the write transaction.
    wtx_snap: WriteTxnSnapshot,
    /// Snapshot of the in-memory schema.
    schema_snap: SchemaSnapshot,
}
390
/// Active transaction held by a Connection. `None` outside BEGIN/COMMIT; `Write` for normal
/// BEGIN (or BEGIN READ WRITE); `Read` for BEGIN READ ONLY.
///
/// `execute_batch` also installs a `Write` transaction for the duration of a batch.
// Clippy flags `Write` as much larger than the other variants. The lint is silenced,
// presumably because a connection holds at most one of these at a time, so boxing
// would buy little (NOTE(review): rationale not recorded in the source - confirm).
#[allow(clippy::large_enum_variant)]
pub(crate) enum ActiveTxn<'a> {
    None,
    Write(WriteTxn<'a>),
    Read(citadel_txn::read_txn::ReadTxn<'a>),
}
399
400impl<'a> ActiveTxn<'a> {
401    fn is_none(&self) -> bool {
402        matches!(self, ActiveTxn::None)
403    }
404    fn is_active(&self) -> bool {
405        !self.is_none()
406    }
407    fn is_read_only(&self) -> bool {
408        matches!(self, ActiveTxn::Read(_))
409    }
410    fn as_write_mut(&mut self) -> Option<&mut WriteTxn<'a>> {
411        match self {
412            ActiveTxn::Write(w) => Some(w),
413            _ => None,
414        }
415    }
416    fn take(&mut self) -> ActiveTxn<'a> {
417        std::mem::replace(self, ActiveTxn::None)
418    }
419}
420
/// Mutable per-connection state, kept behind `Connection`'s `RefCell`.
pub(crate) struct ConnectionInner<'a> {
    /// In-memory schema; its generation versions the `stmt_cache` entries.
    pub(crate) schema: SchemaManager,
    /// Transaction opened by BEGIN or by `execute_batch`; `ActiveTxn::None` otherwise.
    active_txn: ActiveTxn<'a>,
    /// Savepoints of the active transaction (a stack; presumably innermost last - confirm).
    savepoint_stack: Vec<SavepointEntry>,
    // NOTE(review): the meaning of `in_place_saved` is not visible in this part of
    // the file; only its initial value (`None`) is.
    in_place_saved: Option<bool>,
    /// LRU cache of parsed (and possibly compiled) statements, keyed by SQL text.
    pub(crate) stmt_cache: LruCache<String, CacheEntry>,
    /// Transaction-start timestamp (UTC µs) while a transaction is active, else `None`.
    txn_start_ts: Option<i64>,
    /// Session time zone (IANA name or fixed offset); starts as `"UTC"`.
    session_timezone: String,
    /// Namespaces TEMP tables as `__temp_<id>_<name>`. Cleaned up on Connection drop.
    temp_id: u64,
    /// Names of TEMP tables created on this connection (presumably the ones dropped
    /// during cleanup - confirm against the Drop impl).
    temp_table_names: Vec<String>,
}
433
/// A SQL session over a [`Database`].
///
/// All mutable session state lives in a `RefCell`, so the public methods take
/// `&self`. Because of the `RefCell`, a `Connection` is not `Sync`.
pub struct Connection<'a> {
    /// The database this connection operates on.
    pub(crate) db: &'a Database,
    /// Schema, transaction, statement cache and session settings.
    pub(crate) inner: RefCell<ConnectionInner<'a>>,
}
438
439impl<'a> Connection<'a> {
440    pub fn open(db: &'a Database) -> Result<Self> {
441        let schema = SchemaManager::load(db)?;
442        let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
443        let temp_id = generate_temp_id();
444        Ok(Self {
445            db,
446            inner: RefCell::new(ConnectionInner {
447                schema,
448                active_txn: ActiveTxn::None,
449                savepoint_stack: Vec::new(),
450                in_place_saved: None,
451                stmt_cache,
452                txn_start_ts: None,
453                session_timezone: "UTC".to_string(),
454                temp_id,
455                temp_table_names: Vec::new(),
456            }),
457        })
458    }
459
460    /// Txn-start UTC µs inside BEGIN/COMMIT, else `None`.
461    pub fn txn_start_ts(&self) -> Option<i64> {
462        self.inner.borrow().txn_start_ts
463    }
464
465    /// Returns the session time-zone (IANA name or fixed offset). Default `"UTC"`.
466    pub fn session_timezone(&self) -> String {
467        self.inner.borrow().session_timezone.clone()
468    }
469
470    /// Set the session time-zone. Accepts IANA names, ISO-8601 offsets, `"UTC"`, `"Z"`.
471    pub fn set_session_timezone(&self, tz: &str) -> Result<()> {
472        self.inner.borrow_mut().set_session_timezone_impl(tz)
473    }
474
475    pub fn execute(&self, sql: &str) -> Result<ExecutionResult> {
476        self.inner.borrow_mut().execute_impl(self.db, sql)
477    }
478
479    pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
480        self.inner
481            .borrow_mut()
482            .execute_params_impl(self.db, sql, params)
483    }
484
485    /// Execute `;`-separated SQL statements. Stops at the first failure.
486    pub fn execute_script(&self, sql: &str) -> ScriptExecution {
487        let stmts = match parser::parse_sql_multi(sql) {
488            Ok(s) => s,
489            Err(e) => {
490                return ScriptExecution {
491                    completed: vec![],
492                    error: Some(e),
493                }
494            }
495        };
496        let mut completed = Vec::with_capacity(stmts.len());
497        for stmt in stmts {
498            match self.inner.borrow_mut().dispatch(self.db, &stmt, &[]) {
499                Ok(r) => completed.push(r),
500                Err(e) => {
501                    return ScriptExecution {
502                        completed,
503                        error: Some(e),
504                    }
505                }
506            }
507        }
508        ScriptExecution {
509            completed,
510            error: None,
511        }
512    }
513
514    pub fn execute_batch(&self, sql: &str) -> Result<Vec<ExecutionResult>> {
515        self.inner.borrow_mut().execute_batch_impl(self.db, sql)
516    }
517
518    pub fn query(&self, sql: &str) -> Result<QueryResult> {
519        self.query_params(sql, &[])
520    }
521
522    pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
523        match self.execute_params(sql, params)? {
524            ExecutionResult::Query(qr) => Ok(qr),
525            ExecutionResult::RowsAffected(n) => Ok(QueryResult {
526                columns: vec!["rows_affected".into()],
527                rows: vec![vec![Value::Integer(n as i64)]],
528            }),
529            ExecutionResult::Ok => Ok(QueryResult {
530                columns: vec![],
531                rows: vec![],
532            }),
533        }
534    }
535
536    pub fn prepare(&self, sql: &str) -> Result<PreparedStatement<'_, 'a>> {
537        if let Some(rewritten) = rewrite_show_triggers(sql) {
538            return PreparedStatement::new(self, &rewritten);
539        }
540        if let Some(rewritten) = rewrite_show_matviews(sql) {
541            return PreparedStatement::new(self, &rewritten);
542        }
543        PreparedStatement::new(self, sql)
544    }
545
546    pub fn tables(&self) -> Vec<String> {
547        self.inner
548            .borrow()
549            .schema
550            .table_names()
551            .into_iter()
552            .map(String::from)
553            .collect()
554    }
555
556    /// Returns true if an explicit transaction is active (BEGIN was issued).
557    pub fn in_transaction(&self) -> bool {
558        self.inner.borrow().active_txn.is_active()
559    }
560
561    pub fn table_schema(&self, name: &str) -> Option<TableSchema> {
562        self.inner.borrow().schema.get(name).cloned()
563    }
564
565    pub fn refresh_schema(&self) -> Result<()> {
566        let new_schema = SchemaManager::load(self.db)?;
567        self.inner.borrow_mut().schema = new_schema;
568        Ok(())
569    }
570
571    /// Freeze the ANN index for `table.column` into a persisted segment: one
572    /// write txn scans, builds, serializes, and commits atomically; subsequent
573    /// cold attaches LOAD it (seconds) instead of rebuilding (minutes), with
574    /// the load-time scan re-proving freshness by content. The single writer
575    /// lock is held for the whole build - an offline/builder operation.
576    /// Refused inside an explicit transaction (it owns its own txn), and for
577    /// TEMP tables (their storage bypasses the DDL paths that purge segments).
578    pub fn persist_ann_index(
579        &self,
580        table: &str,
581        column: &str,
582    ) -> Result<crate::executor::AnnSegmentInfo> {
583        if self.in_transaction() {
584            return Err(SqlError::InvalidValue(
585                "persist_ann_index: not allowed inside an explicit transaction".into(),
586            ));
587        }
588        let inner = self.inner.borrow();
589        let lower = table.to_ascii_lowercase();
590        if inner.schema.resolve_temp(&lower) != lower {
591            return Err(SqlError::InvalidValue(
592                "persist_ann_index: TEMP tables are not persistable".into(),
593            ));
594        }
595        let table_schema = inner
596            .schema
597            .get(&lower)
598            .ok_or_else(|| SqlError::TableNotFound(table.to_string()))?;
599        crate::executor::persist_ann_index(self.db, &inner.schema, table_schema, column)
600    }
601
602    /// The identity of the index currently cached for `table.column`:
603    /// `(source, snapshot generation)` - `Loaded{segment_b3}` means queries are
604    /// served by the persisted segment; `Built{refusal}` carries why a segment
605    /// was rejected, if one was.
606    pub fn ann_cache_status(
607        &self,
608        table: &str,
609        column: &str,
610    ) -> Result<Option<(crate::executor::AnnIndexSource, u64)>> {
611        let inner = self.inner.borrow();
612        let table_schema = inner
613            .schema
614            .get(&table.to_ascii_lowercase())
615            .ok_or_else(|| SqlError::TableNotFound(table.to_string()))?;
616        crate::executor::ann_cache_status(&inner.schema, table_schema, column)
617    }
618}
619
620impl<'a> ConnectionInner<'a> {
621    pub(crate) fn active_txn_is_some(&self) -> bool {
622        self.active_txn.is_active()
623    }
624
625    fn set_session_timezone_impl(&mut self, tz: &str) -> Result<()> {
626        let upper = tz.to_ascii_uppercase();
627        if (upper.starts_with("UTC+") || upper.starts_with("UTC-")) && tz.len() > 3 {
628            return Err(SqlError::InvalidTimezone(format!(
629                "'{tz}' is ambiguous; use ISO-8601 offset (e.g. '+05:00') or named zone (e.g. 'Etc/GMT-5')"
630            )));
631        }
632        if jiff::tz::TimeZone::get(tz).is_err() && parse_fixed_offset(tz).is_none() {
633            return Err(SqlError::InvalidTimezone(format!(
634                "{tz}: not a known IANA zone or ISO-8601 offset (e.g. '+05:00', 'UTC', 'America/New_York')"
635            )));
636        }
637        self.session_timezone = tz.to_string();
638        Ok(())
639    }
640
641    fn execute_impl(&mut self, db: &'a Database, sql: &str) -> Result<ExecutionResult> {
642        if let Some(rewritten) = rewrite_show_triggers(sql) {
643            return self.execute_params_impl(db, &rewritten, &[]);
644        }
645        if let Some(rewritten) = rewrite_show_matviews(sql) {
646            return self.execute_params_impl(db, &rewritten, &[]);
647        }
648        if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
649            if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
650                let gen = self.schema.generation();
651                let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
652                    if entry.schema_gen == gen {
653                        Arc::clone(&entry.stmt)
654                    } else {
655                        self.parse_and_cache(normalized_key, gen)?
656                    }
657                } else {
658                    self.parse_and_cache(normalized_key, gen)?
659                };
660                return self.dispatch(db, &stmt, &extracted);
661            }
662        }
663        self.execute_params_impl(db, sql, &[])
664    }
665
666    fn execute_batch_impl(&mut self, db: &'a Database, sql: &str) -> Result<Vec<ExecutionResult>> {
667        if self.active_txn.is_active() {
668            return Err(SqlError::TransactionAlreadyActive);
669        }
670        let stmts = parser::parse_sql_multi(sql)?;
671        if stmts.iter().any(is_txn_control) {
672            return Err(SqlError::Unsupported(
673                "transaction-control statements are not allowed in execute_batch".into(),
674            ));
675        }
676
677        let wtx = db.begin_write().map_err(SqlError::Storage)?;
678        let ts = crate::datetime::txn_or_clock_micros();
679        self.active_txn = ActiveTxn::Write(wtx);
680        self.txn_start_ts = Some(ts);
681        crate::datetime::set_txn_clock(Some(ts));
682
683        let mut results = Vec::with_capacity(stmts.len());
684        for stmt in &stmts {
685            match self.dispatch(db, stmt, &[]) {
686                Ok(r) => results.push(r),
687                Err(e) => {
688                    self.abort_active_txn(db);
689                    return Err(e);
690                }
691            }
692        }
693
694        let commit = match self.active_txn.take() {
695            ActiveTxn::Write(mut wtx) => {
696                match crate::executor::helpers::drain_deferred_fk_checks(&mut wtx) {
697                    Ok(()) => wtx.commit().map_err(SqlError::Storage),
698                    Err(e) => {
699                        wtx.abort();
700                        Err(e)
701                    }
702                }
703            }
704            _ => Err(SqlError::NoActiveTransaction),
705        };
706        self.reset_txn_state();
707        match commit {
708            Ok(()) => {
709                invalidate_dml_caches(&self.schema, db);
710                Ok(results)
711            }
712            Err(e) => {
713                self.schema = SchemaManager::load(db)?;
714                Err(e)
715            }
716        }
717    }
718
719    fn reset_txn_state(&mut self) {
720        self.clear_savepoint_state();
721        self.txn_start_ts = None;
722        crate::datetime::set_txn_clock(None);
723    }
724
725    fn abort_active_txn(&mut self, db: &'a Database) {
726        if let ActiveTxn::Write(wtx) = self.active_txn.take() {
727            wtx.abort();
728        }
729        if let Ok(mut fresh) = SchemaManager::load(db) {
730            fresh.bump_generation_past(self.schema.generation());
731            self.schema = fresh;
732        }
733        self.reset_txn_state();
734    }
735
736    fn execute_params_impl(
737        &mut self,
738        db: &'a Database,
739        sql: &str,
740        params: &[Value],
741    ) -> Result<ExecutionResult> {
742        let gen = self.schema.generation();
743        if self.active_txn.is_none() {
744            if let Some(entry) = self.stmt_cache.get(sql) {
745                if entry.schema_gen == gen && entry.param_count == params.len() {
746                    if let Some(plan) = entry.compiled.as_ref().map(Arc::clone) {
747                        let stmt = Arc::clone(&entry.stmt);
748                        return self.run_compiled(db, &plan, &stmt, params);
749                    }
750                }
751            }
752        }
753
754        let (stmt, param_count) = self.get_or_parse(sql)?;
755
756        if param_count != params.len() {
757            return Err(SqlError::ParameterCountMismatch {
758                expected: param_count,
759                got: params.len(),
760            });
761        }
762
763        if self.active_txn.is_none() {
764            if let Some(plan) = executor::compile(&self.schema, &stmt) {
765                if let Some(e) = self.stmt_cache.get_mut(sql) {
766                    e.compiled = Some(Arc::clone(&plan));
767                }
768                let stmt_owned = Arc::clone(&stmt);
769                return self.run_compiled(db, &plan, &stmt_owned, params);
770            }
771        }
772
773        self.dispatch(db, &stmt, params)
774    }
775
776    fn run_compiled(
777        &mut self,
778        db: &'a Database,
779        plan: &Arc<dyn executor::CompiledPlan>,
780        stmt: &Statement,
781        params: &[Value],
782    ) -> Result<ExecutionResult> {
783        use executor::compile::ActiveTxnRef;
784        let schema = &self.schema;
785        let exec = || {
786            if params.is_empty() {
787                plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
788            } else {
789                crate::eval::with_scoped_params(params, || {
790                    plan.execute(db, schema, stmt, params, ActiveTxnRef::None)
791                })
792            }
793        };
794        let outcome = if plan.needs_txn_clock() {
795            let cached_ts = self
796                .txn_start_ts
797                .or_else(|| Some(crate::datetime::now_micros()));
798            crate::datetime::with_txn_clock(cached_ts, exec)
799        } else {
800            exec()
801        }?;
802        invalidate_dml_caches(&self.schema, db);
803        Ok(outcome)
804    }
805
806    pub(crate) fn parse_and_cache(
807        &mut self,
808        normalized_key: String,
809        gen: u64,
810    ) -> Result<Arc<Statement>> {
811        let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
812        let param_count = parser::count_params(&stmt);
813        self.stmt_cache.put(
814            normalized_key,
815            CacheEntry {
816                stmt: Arc::clone(&stmt),
817                schema_gen: gen,
818                param_count,
819                compiled: None,
820            },
821        );
822        Ok(stmt)
823    }
824
825    pub(crate) fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
826        let gen = self.schema.generation();
827
828        if let Some(entry) = self.stmt_cache.get(sql) {
829            if entry.schema_gen == gen {
830                return Ok((Arc::clone(&entry.stmt), entry.param_count));
831            }
832        }
833
834        let stmt = Arc::new(parser::parse_sql(sql)?);
835        let param_count = parser::count_params(&stmt);
836
837        let cacheable = !matches!(
838            *stmt,
839            Statement::CreateTable(_)
840                | Statement::DropTable(_)
841                | Statement::CreateIndex(_)
842                | Statement::DropIndex(_)
843                | Statement::CreateView(_)
844                | Statement::DropView(_)
845                | Statement::AlterTable(_)
846        );
847
848        if cacheable {
849            self.stmt_cache.put(
850                sql.to_string(),
851                CacheEntry {
852                    stmt: Arc::clone(&stmt),
853                    schema_gen: gen,
854                    param_count,
855                    compiled: None,
856                },
857            );
858        }
859
860        Ok((stmt, param_count))
861    }
862
863    pub(crate) fn execute_prepared(
864        &mut self,
865        db: &'a Database,
866        stmt: &Statement,
867        compiled: Option<&Arc<dyn executor::CompiledPlan>>,
868        params: &[Value],
869    ) -> Result<ExecutionResult> {
870        if let Some(plan) = compiled {
871            if self.active_txn.is_none() {
872                return self.run_compiled(db, plan, stmt, params);
873            }
874            if !self.savepoint_stack.is_empty() && stmt_mutates(stmt) {
875                self.capture_pending_snapshots();
876            }
877            return self.run_compiled_in_txn(db, plan, stmt, params);
878        }
879        self.dispatch(db, stmt, params)
880    }
881
882    fn run_compiled_in_txn(
883        &mut self,
884        db: &'a Database,
885        plan: &Arc<dyn executor::CompiledPlan>,
886        stmt: &Statement,
887        params: &[Value],
888    ) -> Result<ExecutionResult> {
889        use executor::compile::ActiveTxnRef;
890        let schema = &self.schema;
891        let txn = match &mut self.active_txn {
892            ActiveTxn::Write(wtx) => ActiveTxnRef::Write(wtx),
893            ActiveTxn::Read(rtx) => ActiveTxnRef::Read(rtx),
894            ActiveTxn::None => ActiveTxnRef::None,
895        };
896        if params.is_empty() || !plan.uses_scoped_params() {
897            plan.execute(db, schema, stmt, params, txn)
898        } else {
899            crate::eval::with_scoped_params(params, || plan.execute(db, schema, stmt, params, txn))
900        }
901    }
902
903    pub(crate) fn dispatch(
904        &mut self,
905        db: &'a Database,
906        stmt: &Statement,
907        params: &[Value],
908    ) -> Result<ExecutionResult> {
909        let cached_ts = self
910            .txn_start_ts
911            .or_else(|| Some(crate::datetime::now_micros()));
912        crate::datetime::with_txn_clock(cached_ts, || {
913            if params.is_empty() {
914                self.dispatch_inner(db, stmt, params)
915            } else {
916                crate::eval::with_scoped_params(params, || self.dispatch_inner(db, stmt, params))
917            }
918        })
919    }
920
921    fn dispatch_inner(
922        &mut self,
923        db: &'a Database,
924        stmt: &Statement,
925        params: &[Value],
926    ) -> Result<ExecutionResult> {
927        match stmt {
928            Statement::Begin { access_mode } => {
929                if self.active_txn.is_active() {
930                    return Err(SqlError::TransactionAlreadyActive);
931                }
932                let ts = crate::datetime::txn_or_clock_micros();
933                match access_mode {
934                    BeginAccessMode::ReadOnly => {
935                        let rtx = db.begin_read();
936                        self.active_txn = ActiveTxn::Read(rtx);
937                    }
938                    BeginAccessMode::ReadWrite | BeginAccessMode::Default => {
939                        let wtx = db.begin_write().map_err(SqlError::Storage)?;
940                        self.active_txn = ActiveTxn::Write(wtx);
941                    }
942                }
943                self.txn_start_ts = Some(ts);
944                crate::datetime::set_txn_clock(Some(ts));
945                Ok(ExecutionResult::Ok)
946            }
947            Statement::Commit => {
948                match self.active_txn.take() {
949                    ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
950                    ActiveTxn::Write(mut wtx) => {
951                        crate::executor::helpers::drain_deferred_fk_checks(&mut wtx)?;
952                        wtx.commit().map_err(SqlError::Storage)?;
953                        invalidate_dml_caches(&self.schema, db);
954                    }
955                    ActiveTxn::Read(_rtx) => {}
956                }
957                self.reset_txn_state();
958                Ok(ExecutionResult::Ok)
959            }
960            Statement::Rollback => {
961                match self.active_txn.take() {
962                    ActiveTxn::None => return Err(SqlError::NoActiveTransaction),
963                    ActiveTxn::Write(wtx) => {
964                        wtx.abort();
965                        let mut fresh = SchemaManager::load(db)?;
966                        fresh.bump_generation_past(self.schema.generation());
967                        self.schema = fresh;
968                    }
969                    ActiveTxn::Read(_rtx) => {}
970                }
971                self.reset_txn_state();
972                Ok(ExecutionResult::Ok)
973            }
974            Statement::Savepoint(name) => self.do_savepoint(name),
975            Statement::ReleaseSavepoint(name) => self.do_release(name),
976            Statement::RollbackTo(name) => self.do_rollback_to(name),
977            Statement::SetTimezone(zone) => {
978                self.set_session_timezone_impl(zone)?;
979                Ok(ExecutionResult::Ok)
980            }
981            Statement::CreateTable(ct) if ct.temporary => {
982                if self.active_txn.is_read_only() {
983                    return Err(SqlError::Unsupported(
984                        "cannot execute mutating statement inside a read-only transaction".into(),
985                    ));
986                }
987                let user_name = ct.name.clone();
988                let prefixed = temp_storage_name(self.temp_id, &user_name);
989                if self.schema.contains(&user_name) {
990                    if ct.if_not_exists {
991                        return Ok(ExecutionResult::Ok);
992                    }
993                    return Err(SqlError::TableAlreadyExists(user_name));
994                }
995                let mut clone = ct.clone();
996                clone.name = prefixed.clone();
997                clone.temporary = false;
998                let stmt_concrete = Statement::CreateTable(clone);
999                let outcome = if let Some(wtx) = self.active_txn.as_write_mut() {
1000                    executor::execute_in_txn(wtx, &mut self.schema, &stmt_concrete, params)?
1001                } else {
1002                    executor::execute(db, &mut self.schema, &stmt_concrete, params)?
1003                };
1004                self.schema
1005                    .register_temp_alias(&user_name, prefixed.clone());
1006                self.temp_table_names.push(prefixed);
1007                Ok(outcome)
1008            }
1009            Statement::Insert(ins) if self.active_txn.as_write_mut().is_some() => {
1010                self.capture_pending_snapshots();
1011                let wtx = self.active_txn.as_write_mut().unwrap();
1012                executor::exec_insert_in_txn(wtx, &self.schema, ins, params)
1013            }
1014            _ => {
1015                if self.active_txn.is_read_only() && stmt_mutates(stmt) {
1016                    return Err(SqlError::Unsupported(
1017                        "cannot execute mutating statement inside a read-only transaction".into(),
1018                    ));
1019                }
1020                if self.active_txn.as_write_mut().is_some() && stmt_mutates(stmt) {
1021                    self.capture_pending_snapshots();
1022                }
1023                let was_auto_commit = matches!(self.active_txn, ActiveTxn::None);
1024                let outcome = match &mut self.active_txn {
1025                    ActiveTxn::Write(wtx) => {
1026                        executor::execute_in_txn(wtx, &mut self.schema, stmt, params)?
1027                    }
1028                    ActiveTxn::Read(rtx) => {
1029                        executor::execute_with_read(rtx, &self.schema, stmt, params)?
1030                    }
1031                    ActiveTxn::None => executor::execute(db, &mut self.schema, stmt, params)?,
1032                };
1033                if was_auto_commit {
1034                    invalidate_dml_caches(&self.schema, db);
1035                }
1036                if let Statement::DropTable(dt) = stmt {
1037                    self.schema.unregister_temp_alias(&dt.name);
1038                }
1039                Ok(outcome)
1040            }
1041        }
1042    }
1043
1044    fn clear_savepoint_state(&mut self) {
1045        self.savepoint_stack.clear();
1046        self.in_place_saved = None;
1047    }
1048
1049    fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
1050        let wtx = self
1051            .active_txn
1052            .as_write_mut()
1053            .ok_or(SqlError::NoActiveTransaction)?;
1054
1055        if self.savepoint_stack.is_empty() {
1056            self.in_place_saved = Some(wtx.in_place());
1057            wtx.set_in_place(false);
1058        }
1059
1060        self.savepoint_stack.push(SavepointEntry {
1061            name: name.to_string(),
1062            snapshot: None,
1063        });
1064
1065        Ok(ExecutionResult::Ok)
1066    }
1067
1068    fn capture_pending_snapshots(&mut self) {
1069        let last_pending = match self
1070            .savepoint_stack
1071            .iter()
1072            .rposition(|e| e.snapshot.is_none())
1073        {
1074            Some(i) => i,
1075            None => return,
1076        };
1077        let wtx = match self.active_txn.as_write_mut() {
1078            Some(w) => w,
1079            None => return,
1080        };
1081        let wtx_snap = wtx.begin_savepoint();
1082        let schema_snap = self.schema.save_snapshot();
1083
1084        for i in 0..last_pending {
1085            if self.savepoint_stack[i].snapshot.is_none() {
1086                self.savepoint_stack[i].snapshot = Some(SavepointSnapshot {
1087                    wtx_snap: wtx_snap.clone(),
1088                    schema_snap: schema_snap.clone(),
1089                });
1090            }
1091        }
1092        self.savepoint_stack[last_pending].snapshot = Some(SavepointSnapshot {
1093            wtx_snap,
1094            schema_snap,
1095        });
1096    }
1097
1098    fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
1099        if !self.active_txn.is_active() {
1100            return Err(SqlError::NoActiveTransaction);
1101        }
1102
1103        let idx = self
1104            .savepoint_stack
1105            .iter()
1106            .rposition(|e| e.name == name)
1107            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
1108        self.savepoint_stack.truncate(idx);
1109
1110        if self.savepoint_stack.is_empty() {
1111            if let (Some(wtx), Some(original)) =
1112                (self.active_txn.as_write_mut(), self.in_place_saved.take())
1113            {
1114                wtx.set_in_place(original);
1115            }
1116        }
1117
1118        Ok(ExecutionResult::Ok)
1119    }
1120
1121    fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
1122        if !self.active_txn.is_active() {
1123            return Err(SqlError::NoActiveTransaction);
1124        }
1125
1126        let idx = self
1127            .savepoint_stack
1128            .iter()
1129            .rposition(|e| e.name == name)
1130            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
1131
1132        self.savepoint_stack.truncate(idx + 1);
1133        let entry = self.savepoint_stack.last_mut().unwrap();
1134        let snapshot = match entry.snapshot.take() {
1135            Some(s) => s,
1136            None => return Ok(ExecutionResult::Ok),
1137        };
1138
1139        let wtx = match self.active_txn.as_write_mut() {
1140            Some(w) => w,
1141            None => return Err(SqlError::NoActiveTransaction),
1142        };
1143        wtx.restore_snapshot(snapshot.wtx_snap);
1144        self.schema.restore_snapshot(snapshot.schema_snap);
1145
1146        Ok(ExecutionResult::Ok)
1147    }
1148}
1149
1150impl<'a> Drop for Connection<'a> {
1151    fn drop(&mut self) {
1152        let temp_names = std::mem::take(&mut self.inner.borrow_mut().temp_table_names);
1153        if temp_names.is_empty() {
1154            return;
1155        }
1156        if let Ok(mut wtx) = self.db.begin_write() {
1157            for prefixed in &temp_names {
1158                let _ = wtx.drop_table(prefixed.as_bytes());
1159            }
1160            let _ = wtx.commit();
1161        }
1162    }
1163}
1164
#[cfg(test)]
mod tests {
    use super::*;
    use citadel::{Argon2Profile, DatabaseBuilder};

    /// Creates a new passphrase-protected database file `t.db` inside `dir`,
    /// using the `Iot` Argon2 profile.
    fn fresh_db(dir: &std::path::Path) -> citadel::Database {
        let db_path = dir.join("t.db");
        DatabaseBuilder::new(db_path)
            .passphrase(b"test-passphrase")
            .argon2_profile(Argon2Profile::Iot)
            .create()
            .unwrap()
    }

    /// Opens a connection on `db` and runs one setup statement on it.
    fn open_with(db: &citadel::Database, setup_sql: &str) -> Connection<'_> {
        let conn = Connection::open(db).unwrap();
        conn.execute(setup_sql).unwrap();
        conn
    }

    #[test]
    fn execute_batch_commits_all_statements() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = open_with(&db, "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER)");

        let batch_sql = "INSERT INTO t VALUES (1, 10); \
                         INSERT INTO t VALUES (2, 20); \
                         UPDATE t SET n = n + 1 WHERE id = 1;";
        let per_statement = conn.execute_batch(batch_sql).unwrap();
        assert_eq!(per_statement.len(), 3);

        let table = conn.query("SELECT id, n FROM t ORDER BY id").unwrap();
        let rows = &table.rows;
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0], vec![Value::Integer(1), Value::Integer(11)]);
        assert_eq!(rows[1], vec![Value::Integer(2), Value::Integer(20)]);
    }

    #[test]
    fn execute_batch_rolls_back_whole_batch_on_error() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = open_with(&db, "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER)");
        conn.execute("INSERT INTO t VALUES (1, 100)").unwrap();

        // The second insert collides on the primary key, so nothing may stick.
        let failed =
            conn.execute_batch("INSERT INTO t VALUES (2, 20); INSERT INTO t VALUES (1, 999);");
        assert!(failed.is_err());

        let table = conn.query("SELECT id, n FROM t ORDER BY id").unwrap();
        let rows = &table.rows;
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0], vec![Value::Integer(1), Value::Integer(100)]);

        // The connection must be usable again and back in auto-commit mode.
        conn.execute("INSERT INTO t VALUES (3, 30)").unwrap();
        assert!(!conn.in_transaction());
    }

    #[test]
    fn execute_batch_rejects_txn_control() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = open_with(&db, "CREATE TABLE t (id INTEGER PRIMARY KEY)");

        let outcome = conn.execute_batch("INSERT INTO t VALUES (1); COMMIT;");
        assert!(outcome.is_err());
        assert!(!conn.in_transaction());
    }

    #[test]
    fn execute_batch_rejected_inside_transaction() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = open_with(&db, "CREATE TABLE t (id INTEGER PRIMARY KEY)");

        conn.execute("BEGIN").unwrap();
        let outcome = conn.execute_batch("INSERT INTO t VALUES (1);");
        assert!(outcome.is_err());
        conn.execute("ROLLBACK").unwrap();
    }

    #[test]
    fn streamed_expression_projection_correct() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = open_with(&db, "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER)");
        conn.execute_batch(
            "INSERT INTO t VALUES (1, 10); INSERT INTO t VALUES (2, 20); INSERT INTO t VALUES (3, 30);",
        )
        .unwrap();

        let projection = conn.prepare("SELECT id + 1, n * 2 FROM t").unwrap();
        let table = projection.query_collect(&[]).unwrap();
        let rows = &table.rows;
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], vec![Value::Integer(2), Value::Integer(20)]);
        assert_eq!(rows[1], vec![Value::Integer(3), Value::Integer(40)]);
        assert_eq!(rows[2], vec![Value::Integer(4), Value::Integer(60)]);
    }

    #[test]
    fn jsonb_contains_raw_predicate_correct() {
        let dir = tempfile::tempdir().unwrap();
        let db = fresh_db(dir.path());
        let conn = open_with(&db, "CREATE TABLE u (id INTEGER PRIMARY KEY, data JSONB)");
        conn.execute_batch(
            "INSERT INTO u VALUES (1, '{\"role\":\"admin\",\"x\":1}'); \
             INSERT INTO u VALUES (2, '{\"role\":\"user\"}'); \
             INSERT INTO u VALUES (3, NULL); \
             INSERT INTO u VALUES (4, '{\"role\":\"admin\"}');",
        )
        .unwrap();

        let contains_admin = conn
            .prepare("SELECT id FROM u WHERE data @> '{\"role\":\"admin\"}'::jsonb")
            .unwrap();
        let matched = contains_admin.query_collect(&[]).unwrap();
        let mut ids: Vec<i64> = matched
            .rows
            .iter()
            .map(|row| if let Value::Integer(id) = row[0] { id } else { -1 })
            .collect();
        ids.sort_unstable();
        assert_eq!(ids, vec![1, 4]);
    }
}