// citadel_sql/connection.rs

1//! Public SQL connection API.
2
3use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::Arc;
6
7use lru::LruCache;
8
9use citadel::Database;
10use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
11
12use crate::error::{Result, SqlError};
13use crate::executor;
14use crate::parser;
15use crate::parser::{QueryBody, SelectQuery, Statement};
16use crate::prepared::PreparedStatement;
17use crate::schema::{SchemaManager, SchemaSnapshot};
18use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
19
/// Capacity of the LRU statement cache created by `Connection::open`.
const DEFAULT_CACHE_CAPACITY: usize = 64;
21
/// Outcome of `Connection::execute_script`.
#[derive(Debug)]
pub struct ScriptExecution {
    /// Results of the statements that ran successfully, in script order.
    pub completed: Vec<ExecutionResult>,
    /// The parse or execution error that stopped the script, or `None` when
    /// every statement succeeded.
    pub error: Option<SqlError>,
}
27
/// Checks whether `s` spells a fixed UTC offset.
///
/// Surrounding whitespace is ignored. Accepted spellings are:
///
/// * `Z` or `UTC` (case-insensitive);
/// * a mandatory `+` / `-` sign followed by `HH:MM`, `HHMM` or `HH`.
///
/// Hours must be at most 23 and minutes at most 59.
///
/// Returns `Some(())` when the spelling is valid and `None` otherwise; the
/// offset itself is not returned, only its validity.
fn parse_fixed_offset(s: &str) -> Option<()> {
    let s = s.trim();
    if s.eq_ignore_ascii_case("z") || s.eq_ignore_ascii_case("utc") {
        return Some(());
    }

    let rest = s.strip_prefix('+').or_else(|| s.strip_prefix('-'))?;

    let (hh, mm) = match rest.split_once(':') {
        Some(pair) => pair,
        // `split_at` works on byte offsets, so only split text known to be
        // ASCII; otherwise the cut could land inside a multi-byte character.
        None if rest.len() == 4 && rest.is_ascii() => rest.split_at(2),
        None if rest.len() == 2 => (rest, "00"),
        None => return None,
    };

    // Integer parsing tolerates a leading '+', which would let "++5:30"
    // through, so insist on plain ASCII digits before parsing.
    let digits_only = |part: &str| !part.is_empty() && part.bytes().all(|b| b.is_ascii_digit());
    if !digits_only(hh) || !digits_only(mm) {
        return None;
    }

    let h: u32 = hh.parse().ok()?;
    let m: u32 = mm.parse().ok()?;
    if h > 23 || m > 59 {
        return None;
    }
    Some(())
}
59
60fn stmt_mutates(stmt: &Statement) -> bool {
61    if matches!(
62        stmt,
63        Statement::Insert(_)
64            | Statement::Update(_)
65            | Statement::Delete(_)
66            | Statement::Truncate(_)
67            | Statement::CreateTable(_)
68            | Statement::DropTable(_)
69            | Statement::AlterTable(_)
70            | Statement::CreateIndex(_)
71            | Statement::DropIndex(_)
72            | Statement::CreateView(_)
73            | Statement::DropView(_)
74    ) {
75        return true;
76    }
77    if let Statement::Select(sq) = stmt {
78        if select_query_has_dml(sq) {
79            return true;
80        }
81    }
82    false
83}
84
85fn select_query_has_dml(sq: &SelectQuery) -> bool {
86    sq.ctes.iter().any(|cte| query_body_has_dml(&cte.body)) || query_body_has_dml(&sq.body)
87}
88
89fn query_body_has_dml(body: &QueryBody) -> bool {
90    match body {
91        QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => true,
92        QueryBody::Compound(c) => query_body_has_dml(&c.left) || query_body_has_dml(&c.right),
93        QueryBody::Select(_) => false,
94    }
95}
96
97fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
98    let bytes = sql.as_bytes();
99    let len = bytes.len();
100    let mut i = 0;
101
102    while i < len && bytes[i].is_ascii_whitespace() {
103        i += 1;
104    }
105    if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
106        return None;
107    }
108    i += 6;
109    if i >= len || !bytes[i].is_ascii_whitespace() {
110        return None;
111    }
112    while i < len && bytes[i].is_ascii_whitespace() {
113        i += 1;
114    }
115
116    if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
117        return None;
118    }
119    i += 4;
120    if i >= len || !bytes[i].is_ascii_whitespace() {
121        return None;
122    }
123
124    let prefix_start = 0;
125    let mut values_pos = None;
126    let mut j = i;
127    while j + 6 <= len {
128        if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
129            && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
130            && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
131        {
132            values_pos = Some(j);
133            break;
134        }
135        j += 1;
136    }
137    let values_pos = values_pos?;
138
139    let prefix = &sql[prefix_start..values_pos + 6];
140    let mut pos = values_pos + 6;
141
142    while pos < len && bytes[pos].is_ascii_whitespace() {
143        pos += 1;
144    }
145    if pos >= len || bytes[pos] != b'(' {
146        return None;
147    }
148    pos += 1;
149
150    let mut values = Vec::new();
151    let mut normalized = String::with_capacity(sql.len());
152    normalized.push_str(prefix);
153    normalized.push_str(" (");
154
155    loop {
156        while pos < len && bytes[pos].is_ascii_whitespace() {
157            pos += 1;
158        }
159        if pos >= len {
160            return None;
161        }
162
163        let param_idx = values.len() + 1;
164        if param_idx > 1 {
165            normalized.push_str(", ");
166        }
167
168        if bytes[pos] == b'\'' {
169            pos += 1;
170            let mut seg_start = pos;
171            let mut s = String::new();
172            loop {
173                if pos >= len {
174                    return None;
175                }
176                if bytes[pos] == b'\'' {
177                    s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
178                    pos += 1;
179                    if pos < len && bytes[pos] == b'\'' {
180                        s.push('\'');
181                        pos += 1;
182                        seg_start = pos;
183                    } else {
184                        break;
185                    }
186                } else {
187                    pos += 1;
188                }
189            }
190            values.push(Value::Text(s.into()));
191        } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
192            let start = pos;
193            if bytes[pos] == b'-' {
194                pos += 1;
195            }
196            while pos < len && bytes[pos].is_ascii_digit() {
197                pos += 1;
198            }
199            if pos < len && bytes[pos] == b'.' {
200                pos += 1;
201                while pos < len && bytes[pos].is_ascii_digit() {
202                    pos += 1;
203                }
204                let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
205                values.push(Value::Real(num));
206            } else {
207                let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
208                values.push(Value::Integer(num));
209            }
210        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
211            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
212            if !after.is_ascii_alphanumeric() && after != b'_' {
213                pos += 4;
214                values.push(Value::Null);
215            } else {
216                return None;
217            }
218        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
219            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
220            if !after.is_ascii_alphanumeric() && after != b'_' {
221                pos += 4;
222                values.push(Value::Boolean(true));
223            } else {
224                return None;
225            }
226        } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
227            let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
228            if !after.is_ascii_alphanumeric() && after != b'_' {
229                pos += 5;
230                values.push(Value::Boolean(false));
231            } else {
232                return None;
233            }
234        } else {
235            return None;
236        }
237
238        normalized.push('$');
239        normalized.push_str(&param_idx.to_string());
240
241        while pos < len && bytes[pos].is_ascii_whitespace() {
242            pos += 1;
243        }
244        if pos >= len {
245            return None;
246        }
247
248        if bytes[pos] == b',' {
249            pos += 1;
250        } else if bytes[pos] == b')' {
251            pos += 1;
252            break;
253        } else {
254            return None;
255        }
256    }
257
258    normalized.push(')');
259
260    while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
261        pos += 1;
262    }
263    if pos != len {
264        return None;
265    }
266
267    if values.is_empty() {
268        return None;
269    }
270
271    Some((normalized, values))
272}
273
/// One entry of the per-connection statement cache, keyed by SQL text.
pub(crate) struct CacheEntry {
    /// The parsed statement.
    pub(crate) stmt: Arc<Statement>,
    /// Schema generation observed when the entry was created; lookups ignore
    /// the entry when the current generation differs.
    pub(crate) schema_gen: u64,
    /// Parameter count reported by `parser::count_params` for `stmt`.
    pub(crate) param_count: usize,
    /// Compiled plan, attached once `executor::compile` produces one.
    pub(crate) compiled: Option<Arc<dyn executor::CompiledPlan>>,
}
280
/// A named savepoint on the connection's savepoint stack.
struct SavepointEntry {
    /// Name given in the `SAVEPOINT` statement.
    name: String,
    /// State to restore on `ROLLBACK TO`. Starts as `None` and is filled in
    /// by `capture_pending_snapshots`.
    snapshot: Option<SavepointSnapshot>,
}
285
/// State captured for a savepoint: write-transaction state plus schema state.
struct SavepointSnapshot {
    /// Snapshot obtained from the write transaction's `begin_savepoint`.
    wtx_snap: WriteTxnSnapshot,
    /// Snapshot obtained from `SchemaManager::save_snapshot`.
    schema_snap: SchemaSnapshot,
}
290
/// Mutable state of a `Connection`, kept behind a `RefCell`.
pub(crate) struct ConnectionInner<'a> {
    /// In-memory schema, loaded via `SchemaManager::load`.
    pub(crate) schema: SchemaManager,
    /// Write transaction opened by `BEGIN`; `None` outside an explicit transaction.
    active_txn: Option<WriteTxn<'a>>,
    /// Savepoints of the active transaction, oldest first.
    savepoint_stack: Vec<SavepointEntry>,
    /// The transaction's `in_place` flag as it was before the first savepoint
    /// switched it off; restored when `RELEASE` empties the stack.
    in_place_saved: Option<bool>,
    /// LRU cache of parsed (and possibly compiled) statements keyed by SQL text.
    pub(crate) stmt_cache: LruCache<String, CacheEntry>,
    /// Timestamp recorded by `BEGIN`; cleared by `COMMIT` / `ROLLBACK`.
    txn_start_ts: Option<i64>,
    /// Session time-zone string; starts as `"UTC"`.
    session_timezone: String,
}
300
/// SQL connection with LRU statement cache.
///
/// All mutable state lives in `inner` behind a `RefCell`, so the public
/// methods take `&self`.
pub struct Connection<'a> {
    /// Database this connection operates on.
    pub(crate) db: &'a Database,
    /// Interior-mutable connection state (schema, transaction, caches).
    pub(crate) inner: RefCell<ConnectionInner<'a>>,
}
306
307impl<'a> Connection<'a> {
308    /// Open a SQL connection to a database.
309    pub fn open(db: &'a Database) -> Result<Self> {
310        let schema = SchemaManager::load(db)?;
311        let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
312        Ok(Self {
313            db,
314            inner: RefCell::new(ConnectionInner {
315                schema,
316                active_txn: None,
317                savepoint_stack: Vec::new(),
318                in_place_saved: None,
319                stmt_cache,
320                txn_start_ts: None,
321                session_timezone: "UTC".to_string(),
322            }),
323        })
324    }
325
326    /// Txn-start UTC µs inside BEGIN/COMMIT, else `None`.
327    pub fn txn_start_ts(&self) -> Option<i64> {
328        self.inner.borrow().txn_start_ts
329    }
330
331    /// Returns the session time-zone (IANA name or fixed offset). Default `"UTC"`.
332    pub fn session_timezone(&self) -> String {
333        self.inner.borrow().session_timezone.clone()
334    }
335
336    /// Set the session time-zone. Accepts IANA names, ISO-8601 offsets, `"UTC"`, `"Z"`.
337    pub fn set_session_timezone(&self, tz: &str) -> Result<()> {
338        self.inner.borrow_mut().set_session_timezone_impl(tz)
339    }
340
341    /// Execute a SQL statement. Returns the result.
342    pub fn execute(&self, sql: &str) -> Result<ExecutionResult> {
343        self.inner.borrow_mut().execute_impl(self.db, sql)
344    }
345
346    /// Execute a SQL statement with positional parameters ($1, $2, ...).
347    pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
348        self.inner
349            .borrow_mut()
350            .execute_params_impl(self.db, sql, params)
351    }
352
353    /// Execute `;`-separated SQL statements. Stops at the first failure.
354    pub fn execute_script(&self, sql: &str) -> ScriptExecution {
355        let stmts = match parser::parse_sql_multi(sql) {
356            Ok(s) => s,
357            Err(e) => {
358                return ScriptExecution {
359                    completed: vec![],
360                    error: Some(e),
361                }
362            }
363        };
364        let mut completed = Vec::with_capacity(stmts.len());
365        for stmt in stmts {
366            match self.inner.borrow_mut().dispatch(self.db, &stmt, &[]) {
367                Ok(r) => completed.push(r),
368                Err(e) => {
369                    return ScriptExecution {
370                        completed,
371                        error: Some(e),
372                    }
373                }
374            }
375        }
376        ScriptExecution {
377            completed,
378            error: None,
379        }
380    }
381
382    /// Execute a SQL query and return the result set.
383    pub fn query(&self, sql: &str) -> Result<QueryResult> {
384        self.query_params(sql, &[])
385    }
386
387    /// Execute a SQL query with positional parameters ($1, $2, ...).
388    pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
389        match self.execute_params(sql, params)? {
390            ExecutionResult::Query(qr) => Ok(qr),
391            ExecutionResult::RowsAffected(n) => Ok(QueryResult {
392                columns: vec!["rows_affected".into()],
393                rows: vec![vec![Value::Integer(n as i64)]],
394            }),
395            ExecutionResult::Ok => Ok(QueryResult {
396                columns: vec![],
397                rows: vec![],
398            }),
399        }
400    }
401
402    /// Prepare a SQL statement for repeated execution with parameters.
403    pub fn prepare(&self, sql: &str) -> Result<PreparedStatement<'_, 'a>> {
404        PreparedStatement::new(self, sql)
405    }
406
407    /// List all table names.
408    pub fn tables(&self) -> Vec<String> {
409        self.inner
410            .borrow()
411            .schema
412            .table_names()
413            .into_iter()
414            .map(String::from)
415            .collect()
416    }
417
418    /// Returns true if an explicit transaction is active (BEGIN was issued).
419    pub fn in_transaction(&self) -> bool {
420        self.inner.borrow().active_txn.is_some()
421    }
422
423    /// Get the schema for a named table.
424    pub fn table_schema(&self, name: &str) -> Option<TableSchema> {
425        self.inner.borrow().schema.get(name).cloned()
426    }
427
428    /// Reload schemas from the database.
429    pub fn refresh_schema(&self) -> Result<()> {
430        let new_schema = SchemaManager::load(self.db)?;
431        self.inner.borrow_mut().schema = new_schema;
432        Ok(())
433    }
434}
435
436impl<'a> ConnectionInner<'a> {
437    pub(crate) fn active_txn_is_some(&self) -> bool {
438        self.active_txn.is_some()
439    }
440
441    fn set_session_timezone_impl(&mut self, tz: &str) -> Result<()> {
442        let upper = tz.to_ascii_uppercase();
443        if (upper.starts_with("UTC+") || upper.starts_with("UTC-")) && tz.len() > 3 {
444            return Err(SqlError::InvalidTimezone(format!(
445                "'{tz}' is ambiguous; use ISO-8601 offset (e.g. '+05:00') or named zone (e.g. 'Etc/GMT-5')"
446            )));
447        }
448        if jiff::tz::TimeZone::get(tz).is_err() && parse_fixed_offset(tz).is_none() {
449            return Err(SqlError::InvalidTimezone(format!(
450                "{tz}: not a known IANA zone or ISO-8601 offset (e.g. '+05:00', 'UTC', 'America/New_York')"
451            )));
452        }
453        self.session_timezone = tz.to_string();
454        Ok(())
455    }
456
457    fn execute_impl(&mut self, db: &'a Database, sql: &str) -> Result<ExecutionResult> {
458        if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
459            if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
460                let gen = self.schema.generation();
461                let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
462                    if entry.schema_gen == gen {
463                        Arc::clone(&entry.stmt)
464                    } else {
465                        self.parse_and_cache(normalized_key, gen)?
466                    }
467                } else {
468                    self.parse_and_cache(normalized_key, gen)?
469                };
470                return self.dispatch(db, &stmt, &extracted);
471            }
472        }
473        self.execute_params_impl(db, sql, &[])
474    }
475
476    fn execute_params_impl(
477        &mut self,
478        db: &'a Database,
479        sql: &str,
480        params: &[Value],
481    ) -> Result<ExecutionResult> {
482        let gen = self.schema.generation();
483        if self.active_txn.is_none() {
484            if let Some(entry) = self.stmt_cache.get(sql) {
485                if entry.schema_gen == gen && entry.param_count == params.len() {
486                    if let Some(plan) = entry.compiled.as_ref().map(Arc::clone) {
487                        let stmt = Arc::clone(&entry.stmt);
488                        return self.run_compiled(db, &plan, &stmt, params);
489                    }
490                }
491            }
492        }
493
494        let (stmt, param_count) = self.get_or_parse(sql)?;
495
496        if param_count != params.len() {
497            return Err(SqlError::ParameterCountMismatch {
498                expected: param_count,
499                got: params.len(),
500            });
501        }
502
503        if self.active_txn.is_none() {
504            if let Some(plan) = executor::compile(&self.schema, &stmt) {
505                if let Some(e) = self.stmt_cache.get_mut(sql) {
506                    e.compiled = Some(Arc::clone(&plan));
507                }
508                let stmt_owned = Arc::clone(&stmt);
509                return self.run_compiled(db, &plan, &stmt_owned, params);
510            }
511        }
512
513        self.dispatch(db, &stmt, params)
514    }
515
516    fn run_compiled(
517        &mut self,
518        db: &'a Database,
519        plan: &Arc<dyn executor::CompiledPlan>,
520        stmt: &Statement,
521        params: &[Value],
522    ) -> Result<ExecutionResult> {
523        let schema = &self.schema;
524        let exec = || {
525            if params.is_empty() {
526                plan.execute(db, schema, stmt, params, None)
527            } else {
528                crate::eval::with_scoped_params(params, || {
529                    plan.execute(db, schema, stmt, params, None)
530                })
531            }
532        };
533        if plan.needs_txn_clock() {
534            let cached_ts = self
535                .txn_start_ts
536                .or_else(|| Some(crate::datetime::now_micros()));
537            crate::datetime::with_txn_clock(cached_ts, exec)
538        } else {
539            exec()
540        }
541    }
542
543    pub(crate) fn parse_and_cache(
544        &mut self,
545        normalized_key: String,
546        gen: u64,
547    ) -> Result<Arc<Statement>> {
548        let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
549        let param_count = parser::count_params(&stmt);
550        self.stmt_cache.put(
551            normalized_key,
552            CacheEntry {
553                stmt: Arc::clone(&stmt),
554                schema_gen: gen,
555                param_count,
556                compiled: None,
557            },
558        );
559        Ok(stmt)
560    }
561
562    pub(crate) fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
563        let gen = self.schema.generation();
564
565        if let Some(entry) = self.stmt_cache.get(sql) {
566            if entry.schema_gen == gen {
567                return Ok((Arc::clone(&entry.stmt), entry.param_count));
568            }
569        }
570
571        let stmt = Arc::new(parser::parse_sql(sql)?);
572        let param_count = parser::count_params(&stmt);
573
574        let cacheable = !matches!(
575            *stmt,
576            Statement::CreateTable(_)
577                | Statement::DropTable(_)
578                | Statement::CreateIndex(_)
579                | Statement::DropIndex(_)
580                | Statement::CreateView(_)
581                | Statement::DropView(_)
582                | Statement::AlterTable(_)
583        );
584
585        if cacheable {
586            self.stmt_cache.put(
587                sql.to_string(),
588                CacheEntry {
589                    stmt: Arc::clone(&stmt),
590                    schema_gen: gen,
591                    param_count,
592                    compiled: None,
593                },
594            );
595        }
596
597        Ok((stmt, param_count))
598    }
599
600    pub(crate) fn execute_prepared(
601        &mut self,
602        db: &'a Database,
603        stmt: &Statement,
604        compiled: Option<&Arc<dyn executor::CompiledPlan>>,
605        params: &[Value],
606    ) -> Result<ExecutionResult> {
607        if let Some(plan) = compiled {
608            if self.active_txn.is_none() {
609                return self.run_compiled(db, plan, stmt, params);
610            }
611            if !self.savepoint_stack.is_empty() && stmt_mutates(stmt) {
612                self.capture_pending_snapshots();
613            }
614            return self.run_compiled_in_txn(db, plan, stmt, params);
615        }
616        self.dispatch(db, stmt, params)
617    }
618
619    fn run_compiled_in_txn(
620        &mut self,
621        db: &'a Database,
622        plan: &Arc<dyn executor::CompiledPlan>,
623        stmt: &Statement,
624        params: &[Value],
625    ) -> Result<ExecutionResult> {
626        let schema = &self.schema;
627        let wtx = self.active_txn.as_mut();
628        if params.is_empty() || !plan.uses_scoped_params() {
629            plan.execute(db, schema, stmt, params, wtx)
630        } else {
631            crate::eval::with_scoped_params(params, || plan.execute(db, schema, stmt, params, wtx))
632        }
633    }
634
635    pub(crate) fn dispatch(
636        &mut self,
637        db: &'a Database,
638        stmt: &Statement,
639        params: &[Value],
640    ) -> Result<ExecutionResult> {
641        let cached_ts = self
642            .txn_start_ts
643            .or_else(|| Some(crate::datetime::now_micros()));
644        crate::datetime::with_txn_clock(cached_ts, || {
645            if params.is_empty() {
646                self.dispatch_inner(db, stmt, params)
647            } else {
648                crate::eval::with_scoped_params(params, || self.dispatch_inner(db, stmt, params))
649            }
650        })
651    }
652
653    fn dispatch_inner(
654        &mut self,
655        db: &'a Database,
656        stmt: &Statement,
657        params: &[Value],
658    ) -> Result<ExecutionResult> {
659        match stmt {
660            Statement::Begin => {
661                if self.active_txn.is_some() {
662                    return Err(SqlError::TransactionAlreadyActive);
663                }
664                let wtx = db.begin_write().map_err(SqlError::Storage)?;
665                self.active_txn = Some(wtx);
666                let ts = crate::datetime::txn_or_clock_micros();
667                self.txn_start_ts = Some(ts);
668                crate::datetime::set_txn_clock(Some(ts));
669                Ok(ExecutionResult::Ok)
670            }
671            Statement::Commit => {
672                let mut wtx = self
673                    .active_txn
674                    .take()
675                    .ok_or(SqlError::NoActiveTransaction)?;
676                crate::executor::helpers::drain_deferred_fk_checks(&mut wtx)?;
677                wtx.commit().map_err(SqlError::Storage)?;
678                self.clear_savepoint_state();
679                self.txn_start_ts = None;
680                crate::datetime::set_txn_clock(None);
681                Ok(ExecutionResult::Ok)
682            }
683            Statement::Rollback => {
684                let wtx = self
685                    .active_txn
686                    .take()
687                    .ok_or(SqlError::NoActiveTransaction)?;
688                wtx.abort();
689                self.clear_savepoint_state();
690                self.schema = SchemaManager::load(db)?;
691                self.txn_start_ts = None;
692                crate::datetime::set_txn_clock(None);
693                Ok(ExecutionResult::Ok)
694            }
695            Statement::Savepoint(name) => self.do_savepoint(name),
696            Statement::ReleaseSavepoint(name) => self.do_release(name),
697            Statement::RollbackTo(name) => self.do_rollback_to(name),
698            Statement::SetTimezone(zone) => {
699                self.set_session_timezone_impl(zone)?;
700                Ok(ExecutionResult::Ok)
701            }
702            Statement::Insert(ins) if self.active_txn.is_some() => {
703                self.capture_pending_snapshots();
704                let wtx = self.active_txn.as_mut().unwrap();
705                executor::exec_insert_in_txn(wtx, &self.schema, ins, params)
706            }
707            _ => {
708                if self.active_txn.is_some() && stmt_mutates(stmt) {
709                    self.capture_pending_snapshots();
710                }
711                if let Some(ref mut wtx) = self.active_txn {
712                    executor::execute_in_txn(wtx, &mut self.schema, stmt, params)
713                } else {
714                    executor::execute(db, &mut self.schema, stmt, params)
715                }
716            }
717        }
718    }
719
720    fn clear_savepoint_state(&mut self) {
721        self.savepoint_stack.clear();
722        self.in_place_saved = None;
723    }
724
725    fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
726        let wtx = self
727            .active_txn
728            .as_mut()
729            .ok_or(SqlError::NoActiveTransaction)?;
730
731        if self.savepoint_stack.is_empty() {
732            self.in_place_saved = Some(wtx.in_place());
733            wtx.set_in_place(false);
734        }
735
736        self.savepoint_stack.push(SavepointEntry {
737            name: name.to_string(),
738            snapshot: None,
739        });
740
741        Ok(ExecutionResult::Ok)
742    }
743
744    fn capture_pending_snapshots(&mut self) {
745        let last_pending = match self
746            .savepoint_stack
747            .iter()
748            .rposition(|e| e.snapshot.is_none())
749        {
750            Some(i) => i,
751            None => return,
752        };
753        let wtx = match self.active_txn.as_mut() {
754            Some(w) => w,
755            None => return,
756        };
757        let wtx_snap = wtx.begin_savepoint();
758        let schema_snap = self.schema.save_snapshot();
759
760        for i in 0..last_pending {
761            if self.savepoint_stack[i].snapshot.is_none() {
762                self.savepoint_stack[i].snapshot = Some(SavepointSnapshot {
763                    wtx_snap: wtx_snap.clone(),
764                    schema_snap: schema_snap.clone(),
765                });
766            }
767        }
768        self.savepoint_stack[last_pending].snapshot = Some(SavepointSnapshot {
769            wtx_snap,
770            schema_snap,
771        });
772    }
773
774    fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
775        if self.active_txn.is_none() {
776            return Err(SqlError::NoActiveTransaction);
777        }
778
779        let idx = self
780            .savepoint_stack
781            .iter()
782            .rposition(|e| e.name == name)
783            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
784        self.savepoint_stack.truncate(idx);
785
786        if self.savepoint_stack.is_empty() {
787            if let (Some(wtx), Some(original)) =
788                (self.active_txn.as_mut(), self.in_place_saved.take())
789            {
790                wtx.set_in_place(original);
791            }
792        }
793
794        Ok(ExecutionResult::Ok)
795    }
796
797    fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
798        if self.active_txn.is_none() {
799            return Err(SqlError::NoActiveTransaction);
800        }
801
802        let idx = self
803            .savepoint_stack
804            .iter()
805            .rposition(|e| e.name == name)
806            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
807
808        self.savepoint_stack.truncate(idx + 1);
809        let entry = self.savepoint_stack.last_mut().unwrap();
810        let snapshot = match entry.snapshot.take() {
811            Some(s) => s,
812            None => return Ok(ExecutionResult::Ok),
813        };
814
815        let wtx = self.active_txn.as_mut().unwrap();
816        wtx.restore_snapshot(snapshot.wtx_snap);
817        self.schema.restore_snapshot(snapshot.schema_snap);
818
819        Ok(ExecutionResult::Ok)
820    }
821}