Skip to main content

citadel_sql/
connection.rs

1//! Public SQL connection API.
2
3use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::Arc;
6
7use lru::LruCache;
8
9use citadel::Database;
10use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
11
12use crate::error::{Result, SqlError};
13use crate::executor;
14use crate::parser;
15use crate::parser::{QueryBody, SelectQuery, Statement};
16use crate::prepared::PreparedStatement;
17use crate::schema::{SchemaManager, SchemaSnapshot};
18use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
19
/// Capacity (number of entries) of the LRU statement cache that
/// `Connection::open` creates for each connection.
const DEFAULT_CACHE_CAPACITY: usize = 64;
21
/// Outcome of [`Connection::execute_script`].
///
/// Statements run in order and execution stops at the first failure, so
/// `completed` holds the results of the statements that ran before it.
#[derive(Debug)]
pub struct ScriptExecution {
    /// Results of the statements that executed successfully, in script order.
    /// Empty if the script failed to parse.
    pub completed: Vec<ExecutionResult>,
    /// The parse or execution error that stopped the script, or `None` if
    /// every statement succeeded.
    pub error: Option<SqlError>,
}
27
/// Check whether `s` is a fixed UTC offset accepted as a session time zone.
///
/// Accepted forms, after trimming surrounding whitespace:
/// - `Z` or `UTC` (case-insensitive);
/// - a `+` or `-` sign followed by `hh:mm`, `hhmm` or `hh`.
///
/// Each hour/minute field must consist of one or two ASCII digits. The hour
/// must be at most 23 and the minute at most 59.
///
/// Returns `Some(())` for a valid offset and `None` otherwise. The offset
/// value itself is not returned. Never panics, whatever the input.
fn parse_fixed_offset(s: &str) -> Option<()> {
    let s = s.trim();
    if s.eq_ignore_ascii_case("z") || s.eq_ignore_ascii_case("utc") {
        return Some(());
    }
    let rest = s.strip_prefix(|c: char| c == '+' || c == '-')?;
    // The `split_at(2)` below slices by byte offset, which panics on a
    // non-char boundary. A non-ASCII tail can never be a valid offset anyway.
    if !rest.is_ascii() {
        return None;
    }
    let (hh, mm) = match rest.split_once(':') {
        Some(fields) => fields,
        None if rest.len() == 4 => rest.split_at(2),
        None if rest.len() == 2 => (rest, "00"),
        None => return None,
    };
    // `u32::from_str` also accepts a leading `+` (e.g. "++5:00"), so require
    // plain digits before parsing.
    let field = |f: &str| -> Option<u32> {
        let plain_digits = (1..=2).contains(&f.len()) && f.bytes().all(|b| b.is_ascii_digit());
        if plain_digits {
            f.parse().ok()
        } else {
            None
        }
    };
    let (h, m) = (field(hh)?, field(mm)?);
    (h <= 23 && m <= 59).then_some(())
}
59
60fn stmt_mutates(stmt: &Statement) -> bool {
61    if matches!(
62        stmt,
63        Statement::Insert(_)
64            | Statement::Update(_)
65            | Statement::Delete(_)
66            | Statement::Truncate(_)
67            | Statement::CreateTable(_)
68            | Statement::DropTable(_)
69            | Statement::AlterTable(_)
70            | Statement::CreateIndex(_)
71            | Statement::DropIndex(_)
72            | Statement::CreateView(_)
73            | Statement::DropView(_)
74    ) {
75        return true;
76    }
77    if let Statement::Select(sq) = stmt {
78        if select_query_has_dml(sq) {
79            return true;
80        }
81    }
82    false
83}
84
85fn select_query_has_dml(sq: &SelectQuery) -> bool {
86    sq.ctes.iter().any(|cte| query_body_has_dml(&cte.body)) || query_body_has_dml(&sq.body)
87}
88
89fn query_body_has_dml(body: &QueryBody) -> bool {
90    match body {
91        QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => true,
92        QueryBody::Compound(c) => query_body_has_dml(&c.left) || query_body_has_dml(&c.right),
93        QueryBody::Select(_) => false,
94    }
95}
96
/// Fast-path normaliser for single-row `INSERT INTO ... VALUES (...)` statements.
///
/// If `sql` has the shape `INSERT <ws> INTO <ws> ... VALUES (lit, lit, ...)`,
/// optionally followed only by whitespace and `;`, it returns two things:
/// - a normalised SQL string in which each literal of the VALUES list is
///   replaced by a positional placeholder (`$1`, `$2`, ...). It is used as the
///   statement-cache key, so inserts differing only in their literal values
///   share one cached parse.
/// - the extracted literals, in order, to be bound as parameters.
///
/// Recognised literals:
/// - single-quoted strings, where `''` is an escaped quote;
/// - integers and decimals with an optional leading `-`;
/// - the keywords `NULL`, `TRUE` and `FALSE` (case-insensitive).
///
/// Anything else returns `None`, and the caller falls back to the full
/// parser. That includes expressions, placeholders, multi-row VALUES,
/// trailing clauses and comments.
///
/// The scan for `VALUES` is lexical: it does not skip quoted identifiers.
fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut i = 0;

    // Optional leading whitespace, then `INSERT` followed by at least one
    // whitespace byte.
    while i < len && bytes[i].is_ascii_whitespace() {
        i += 1;
    }
    if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
        return None;
    }
    i += 6;
    if i >= len || !bytes[i].is_ascii_whitespace() {
        return None;
    }
    while i < len && bytes[i].is_ascii_whitespace() {
        i += 1;
    }

    // `INTO`, again followed by whitespace. Other forms such as
    // `INSERT OR ...` are rejected here.
    if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
        return None;
    }
    i += 4;
    if i >= len || !bytes[i].is_ascii_whitespace() {
        return None;
    }

    // Find the first standalone `VALUES` word: case-insensitive and not
    // adjacent to an identifier byte (alphanumeric or `_`). Everything up to
    // and including it is copied verbatim into the cache key. That covers the
    // table name and the optional column list.
    let prefix_start = 0;
    let mut values_pos = None;
    let mut j = i;
    while j + 6 <= len {
        if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
            && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
            && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
        {
            values_pos = Some(j);
            break;
        }
        j += 1;
    }
    let values_pos = values_pos?;

    // Both indices delimit the ASCII bytes of `VALUES`, so they are char
    // boundaries and this slice cannot panic.
    let prefix = &sql[prefix_start..values_pos + 6];
    let mut pos = values_pos + 6;

    // The literal list must open with `(`.
    while pos < len && bytes[pos].is_ascii_whitespace() {
        pos += 1;
    }
    if pos >= len || bytes[pos] != b'(' {
        return None;
    }
    pos += 1;

    let mut values = Vec::new();
    let mut normalized = String::with_capacity(sql.len());
    normalized.push_str(prefix);
    normalized.push_str(" (");

    // One iteration per literal in the VALUES list.
    loop {
        while pos < len && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        if pos >= len {
            return None;
        }

        // 1-based placeholder number for this literal.
        let param_idx = values.len() + 1;
        if param_idx > 1 {
            normalized.push_str(", ");
        }

        if bytes[pos] == b'\'' {
            // String literal. Each run between quotes is copied as a segment,
            // and a doubled quote `''` contributes a single `'`.
            pos += 1;
            let mut seg_start = pos;
            let mut s = String::new();
            loop {
                if pos >= len {
                    // Unterminated string.
                    return None;
                }
                if bytes[pos] == b'\'' {
                    s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
                    pos += 1;
                    if pos < len && bytes[pos] == b'\'' {
                        s.push('\'');
                        pos += 1;
                        seg_start = pos;
                    } else {
                        break;
                    }
                } else {
                    pos += 1;
                }
            }
            values.push(Value::Text(s.into()));
        } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
            // Numeric literal: an optional `-`, then digits. A `.` switches to
            // `f64` parsing (`Value::Real`); otherwise it is parsed as `i64`.
            // Parse failures (a lone `-`, or i64 overflow) return `None`.
            let start = pos;
            if bytes[pos] == b'-' {
                pos += 1;
            }
            while pos < len && bytes[pos].is_ascii_digit() {
                pos += 1;
            }
            if pos < len && bytes[pos] == b'.' {
                pos += 1;
                while pos < len && bytes[pos].is_ascii_digit() {
                    pos += 1;
                }
                let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
                values.push(Value::Real(num));
            } else {
                let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
                values.push(Value::Integer(num));
            }
        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
            // Keyword literals must not run into an identifier byte. At end of
            // input, the following byte is treated as `)` for this check.
            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 4;
                values.push(Value::Null);
            } else {
                return None;
            }
        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 4;
                values.push(Value::Boolean(true));
            } else {
                return None;
            }
        } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
            let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 5;
                values.push(Value::Boolean(false));
            } else {
                return None;
            }
        } else {
            // Not a recognised literal (expression, placeholder, `+n`, `()`, ...).
            return None;
        }

        normalized.push('$');
        normalized.push_str(&param_idx.to_string());

        // After a literal: `,` continues the list and `)` closes it.
        while pos < len && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        if pos >= len {
            return None;
        }

        if bytes[pos] == b',' {
            pos += 1;
        } else if bytes[pos] == b')' {
            pos += 1;
            break;
        } else {
            return None;
        }
    }

    normalized.push(')');

    // Only whitespace and `;` may follow the closing parenthesis. This rejects
    // extra rows, RETURNING clauses, comments and further statements.
    while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
        pos += 1;
    }
    if pos != len {
        return None;
    }

    // Defensive: the loop above only breaks after pushing a value.
    if values.is_empty() {
        return None;
    }

    Some((normalized, values))
}
273
/// A parsed statement stored in a connection's LRU statement cache.
pub(crate) struct CacheEntry {
    /// The parsed statement, shared with callers through `Arc`.
    pub(crate) stmt: Arc<Statement>,
    /// `SchemaManager::generation()` at parse time. In this file, an entry
    /// whose generation differs from the current one is re-parsed rather
    /// than reused.
    pub(crate) schema_gen: u64,
    /// Number of positional parameters (`$1`, `$2`, ...) reported by
    /// `parser::count_params`.
    pub(crate) param_count: usize,
    /// Compiled execution plan, if one has been attached. In this file it is
    /// set by `execute_params_impl` when `executor::compile` returns a plan.
    pub(crate) compiled: Option<Arc<dyn executor::CompiledPlan>>,
}
280
/// One `SAVEPOINT` on a connection's savepoint stack.
struct SavepointEntry {
    /// Name given in the `SAVEPOINT` statement. `RELEASE` and `ROLLBACK TO`
    /// look it up from the top of the stack, so the innermost match wins.
    name: String,
    /// State captured lazily by `capture_pending_snapshots` just before the
    /// first mutating statement after this savepoint.
    ///
    /// It is `None` until then, and becomes `None` again once `ROLLBACK TO`
    /// consumes it.
    snapshot: Option<SavepointSnapshot>,
}
285
/// Storage and schema state restored by `ROLLBACK TO`.
struct SavepointSnapshot {
    /// Write-transaction state obtained from `WriteTxn::begin_savepoint`.
    wtx_snap: WriteTxnSnapshot,
    /// In-memory schema state obtained from `SchemaManager::save_snapshot`.
    schema_snap: SchemaSnapshot,
}
290
/// Mutable per-connection state, kept inside the `RefCell` of [`Connection`].
pub(crate) struct ConnectionInner<'a> {
    /// In-memory copy of the database schema.
    pub(crate) schema: SchemaManager,
    /// Write transaction opened by `BEGIN`. `None` when no explicit
    /// transaction is active.
    active_txn: Option<WriteTxn<'a>>,
    /// Open savepoints, oldest first.
    savepoint_stack: Vec<SavepointEntry>,
    /// The transaction's `in_place` setting from before the first savepoint
    /// turned it off. It is restored when the savepoint stack empties via
    /// `RELEASE`.
    in_place_saved: Option<bool>,
    /// LRU cache of parsed statements. Keys are either the SQL text or the
    /// normalised form produced by `try_normalize_insert`.
    pub(crate) stmt_cache: LruCache<String, CacheEntry>,
    /// Transaction start time in UTC microseconds. Set by `BEGIN` and cleared
    /// by `COMMIT`/`ROLLBACK`.
    txn_start_ts: Option<i64>,
    /// Session time zone string, as accepted by `set_session_timezone_impl`.
    session_timezone: String,
}
300
/// SQL connection with LRU statement cache.
///
/// Holds a shared borrow of the [`Database`] for `'a`. All mutable state
/// lives in a `RefCell`: schema, open transaction, savepoints, statement
/// cache and session time zone. That is why the public methods take `&self`.
pub struct Connection<'a> {
    /// Database this connection executes against.
    pub(crate) db: &'a Database,
    /// Mutable connection state, borrowed mutably for each statement executed.
    pub(crate) inner: RefCell<ConnectionInner<'a>>,
}
306
307impl<'a> Connection<'a> {
308    /// Open a SQL connection to a database.
309    pub fn open(db: &'a Database) -> Result<Self> {
310        let schema = SchemaManager::load(db)?;
311        let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
312        Ok(Self {
313            db,
314            inner: RefCell::new(ConnectionInner {
315                schema,
316                active_txn: None,
317                savepoint_stack: Vec::new(),
318                in_place_saved: None,
319                stmt_cache,
320                txn_start_ts: None,
321                session_timezone: "UTC".to_string(),
322            }),
323        })
324    }
325
326    /// Txn-start UTC µs inside BEGIN/COMMIT, else `None`.
327    pub fn txn_start_ts(&self) -> Option<i64> {
328        self.inner.borrow().txn_start_ts
329    }
330
331    /// Returns the session time-zone (IANA name or fixed offset). Default `"UTC"`.
332    pub fn session_timezone(&self) -> String {
333        self.inner.borrow().session_timezone.clone()
334    }
335
336    /// Set the session time-zone. Accepts IANA names, ISO-8601 offsets, `"UTC"`, `"Z"`.
337    pub fn set_session_timezone(&self, tz: &str) -> Result<()> {
338        self.inner.borrow_mut().set_session_timezone_impl(tz)
339    }
340
341    /// Execute a SQL statement. Returns the result.
342    pub fn execute(&self, sql: &str) -> Result<ExecutionResult> {
343        self.inner.borrow_mut().execute_impl(self.db, sql)
344    }
345
346    /// Execute a SQL statement with positional parameters ($1, $2, ...).
347    pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
348        self.inner
349            .borrow_mut()
350            .execute_params_impl(self.db, sql, params)
351    }
352
353    /// Execute `;`-separated SQL statements. Stops at the first failure.
354    pub fn execute_script(&self, sql: &str) -> ScriptExecution {
355        let stmts = match parser::parse_sql_multi(sql) {
356            Ok(s) => s,
357            Err(e) => {
358                return ScriptExecution {
359                    completed: vec![],
360                    error: Some(e),
361                }
362            }
363        };
364        let mut completed = Vec::with_capacity(stmts.len());
365        for stmt in stmts {
366            match self.inner.borrow_mut().dispatch(self.db, &stmt, &[]) {
367                Ok(r) => completed.push(r),
368                Err(e) => {
369                    return ScriptExecution {
370                        completed,
371                        error: Some(e),
372                    }
373                }
374            }
375        }
376        ScriptExecution {
377            completed,
378            error: None,
379        }
380    }
381
382    /// Execute a SQL query and return the result set.
383    pub fn query(&self, sql: &str) -> Result<QueryResult> {
384        self.query_params(sql, &[])
385    }
386
387    /// Execute a SQL query with positional parameters ($1, $2, ...).
388    pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
389        match self.execute_params(sql, params)? {
390            ExecutionResult::Query(qr) => Ok(qr),
391            ExecutionResult::RowsAffected(n) => Ok(QueryResult {
392                columns: vec!["rows_affected".into()],
393                rows: vec![vec![Value::Integer(n as i64)]],
394            }),
395            ExecutionResult::Ok => Ok(QueryResult {
396                columns: vec![],
397                rows: vec![],
398            }),
399        }
400    }
401
402    /// Prepare a SQL statement for repeated execution with parameters.
403    pub fn prepare(&self, sql: &str) -> Result<PreparedStatement<'_, 'a>> {
404        PreparedStatement::new(self, sql)
405    }
406
407    /// List all table names.
408    pub fn tables(&self) -> Vec<String> {
409        self.inner
410            .borrow()
411            .schema
412            .table_names()
413            .into_iter()
414            .map(String::from)
415            .collect()
416    }
417
418    /// Returns true if an explicit transaction is active (BEGIN was issued).
419    pub fn in_transaction(&self) -> bool {
420        self.inner.borrow().active_txn.is_some()
421    }
422
423    /// Get the schema for a named table.
424    pub fn table_schema(&self, name: &str) -> Option<TableSchema> {
425        self.inner.borrow().schema.get(name).cloned()
426    }
427
428    /// Reload schemas from the database.
429    pub fn refresh_schema(&self) -> Result<()> {
430        let new_schema = SchemaManager::load(self.db)?;
431        self.inner.borrow_mut().schema = new_schema;
432        Ok(())
433    }
434}
435
436impl<'a> ConnectionInner<'a> {
437    pub(crate) fn active_txn_is_some(&self) -> bool {
438        self.active_txn.is_some()
439    }
440
441    fn set_session_timezone_impl(&mut self, tz: &str) -> Result<()> {
442        let upper = tz.to_ascii_uppercase();
443        if (upper.starts_with("UTC+") || upper.starts_with("UTC-")) && tz.len() > 3 {
444            return Err(SqlError::InvalidTimezone(format!(
445                "'{tz}' is ambiguous; use ISO-8601 offset (e.g. '+05:00') or named zone (e.g. 'Etc/GMT-5')"
446            )));
447        }
448        if jiff::tz::TimeZone::get(tz).is_err() && parse_fixed_offset(tz).is_none() {
449            return Err(SqlError::InvalidTimezone(format!(
450                "{tz}: not a known IANA zone or ISO-8601 offset (e.g. '+05:00', 'UTC', 'America/New_York')"
451            )));
452        }
453        self.session_timezone = tz.to_string();
454        Ok(())
455    }
456
457    fn execute_impl(&mut self, db: &'a Database, sql: &str) -> Result<ExecutionResult> {
458        if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
459            if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
460                let gen = self.schema.generation();
461                let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
462                    if entry.schema_gen == gen {
463                        Arc::clone(&entry.stmt)
464                    } else {
465                        self.parse_and_cache(normalized_key, gen)?
466                    }
467                } else {
468                    self.parse_and_cache(normalized_key, gen)?
469                };
470                return self.dispatch(db, &stmt, &extracted);
471            }
472        }
473        self.execute_params_impl(db, sql, &[])
474    }
475
476    fn execute_params_impl(
477        &mut self,
478        db: &'a Database,
479        sql: &str,
480        params: &[Value],
481    ) -> Result<ExecutionResult> {
482        let gen = self.schema.generation();
483        if self.active_txn.is_none() {
484            if let Some(entry) = self.stmt_cache.get(sql) {
485                if entry.schema_gen == gen && entry.param_count == params.len() {
486                    if let Some(plan) = entry.compiled.as_ref().map(Arc::clone) {
487                        let stmt = Arc::clone(&entry.stmt);
488                        return self.run_compiled(db, &plan, &stmt, params);
489                    }
490                }
491            }
492        }
493
494        let (stmt, param_count) = self.get_or_parse(sql)?;
495
496        if param_count != params.len() {
497            return Err(SqlError::ParameterCountMismatch {
498                expected: param_count,
499                got: params.len(),
500            });
501        }
502
503        if self.active_txn.is_none() {
504            if let Some(plan) = executor::compile(&self.schema, &stmt) {
505                if let Some(e) = self.stmt_cache.get_mut(sql) {
506                    e.compiled = Some(Arc::clone(&plan));
507                }
508                let stmt_owned = Arc::clone(&stmt);
509                return self.run_compiled(db, &plan, &stmt_owned, params);
510            }
511        }
512
513        self.dispatch(db, &stmt, params)
514    }
515
516    fn run_compiled(
517        &mut self,
518        db: &'a Database,
519        plan: &Arc<dyn executor::CompiledPlan>,
520        stmt: &Statement,
521        params: &[Value],
522    ) -> Result<ExecutionResult> {
523        let cached_ts = self
524            .txn_start_ts
525            .or_else(|| Some(crate::datetime::now_micros()));
526        let schema = &self.schema;
527        crate::datetime::with_txn_clock(cached_ts, || {
528            if params.is_empty() {
529                plan.execute(db, schema, stmt, params, None)
530            } else {
531                crate::eval::with_scoped_params(params, || {
532                    plan.execute(db, schema, stmt, params, None)
533                })
534            }
535        })
536    }
537
538    pub(crate) fn parse_and_cache(
539        &mut self,
540        normalized_key: String,
541        gen: u64,
542    ) -> Result<Arc<Statement>> {
543        let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
544        let param_count = parser::count_params(&stmt);
545        self.stmt_cache.put(
546            normalized_key,
547            CacheEntry {
548                stmt: Arc::clone(&stmt),
549                schema_gen: gen,
550                param_count,
551                compiled: None,
552            },
553        );
554        Ok(stmt)
555    }
556
557    pub(crate) fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
558        let gen = self.schema.generation();
559
560        if let Some(entry) = self.stmt_cache.get(sql) {
561            if entry.schema_gen == gen {
562                return Ok((Arc::clone(&entry.stmt), entry.param_count));
563            }
564        }
565
566        let stmt = Arc::new(parser::parse_sql(sql)?);
567        let param_count = parser::count_params(&stmt);
568
569        let cacheable = !matches!(
570            *stmt,
571            Statement::CreateTable(_)
572                | Statement::DropTable(_)
573                | Statement::CreateIndex(_)
574                | Statement::DropIndex(_)
575                | Statement::CreateView(_)
576                | Statement::DropView(_)
577                | Statement::AlterTable(_)
578        );
579
580        if cacheable {
581            self.stmt_cache.put(
582                sql.to_string(),
583                CacheEntry {
584                    stmt: Arc::clone(&stmt),
585                    schema_gen: gen,
586                    param_count,
587                    compiled: None,
588                },
589            );
590        }
591
592        Ok((stmt, param_count))
593    }
594
595    pub(crate) fn execute_prepared(
596        &mut self,
597        db: &'a Database,
598        stmt: &Statement,
599        compiled: Option<&Arc<dyn executor::CompiledPlan>>,
600        params: &[Value],
601    ) -> Result<ExecutionResult> {
602        if let Some(plan) = compiled {
603            if self.active_txn.is_none() {
604                return self.run_compiled(db, plan, stmt, params);
605            }
606            if !self.savepoint_stack.is_empty() && stmt_mutates(stmt) {
607                self.capture_pending_snapshots();
608            }
609            return self.run_compiled_in_txn(db, plan, stmt, params);
610        }
611        self.dispatch(db, stmt, params)
612    }
613
614    fn run_compiled_in_txn(
615        &mut self,
616        db: &'a Database,
617        plan: &Arc<dyn executor::CompiledPlan>,
618        stmt: &Statement,
619        params: &[Value],
620    ) -> Result<ExecutionResult> {
621        let schema = &self.schema;
622        let wtx = self.active_txn.as_mut();
623        if params.is_empty() || !plan.uses_scoped_params() {
624            plan.execute(db, schema, stmt, params, wtx)
625        } else {
626            crate::eval::with_scoped_params(params, || plan.execute(db, schema, stmt, params, wtx))
627        }
628    }
629
630    pub(crate) fn dispatch(
631        &mut self,
632        db: &'a Database,
633        stmt: &Statement,
634        params: &[Value],
635    ) -> Result<ExecutionResult> {
636        let cached_ts = self
637            .txn_start_ts
638            .or_else(|| Some(crate::datetime::now_micros()));
639        crate::datetime::with_txn_clock(cached_ts, || {
640            if params.is_empty() {
641                self.dispatch_inner(db, stmt, params)
642            } else {
643                crate::eval::with_scoped_params(params, || self.dispatch_inner(db, stmt, params))
644            }
645        })
646    }
647
648    fn dispatch_inner(
649        &mut self,
650        db: &'a Database,
651        stmt: &Statement,
652        params: &[Value],
653    ) -> Result<ExecutionResult> {
654        match stmt {
655            Statement::Begin => {
656                if self.active_txn.is_some() {
657                    return Err(SqlError::TransactionAlreadyActive);
658                }
659                let wtx = db.begin_write().map_err(SqlError::Storage)?;
660                self.active_txn = Some(wtx);
661                let ts = crate::datetime::txn_or_clock_micros();
662                self.txn_start_ts = Some(ts);
663                crate::datetime::set_txn_clock(Some(ts));
664                Ok(ExecutionResult::Ok)
665            }
666            Statement::Commit => {
667                let wtx = self
668                    .active_txn
669                    .take()
670                    .ok_or(SqlError::NoActiveTransaction)?;
671                wtx.commit().map_err(SqlError::Storage)?;
672                self.clear_savepoint_state();
673                self.txn_start_ts = None;
674                crate::datetime::set_txn_clock(None);
675                Ok(ExecutionResult::Ok)
676            }
677            Statement::Rollback => {
678                let wtx = self
679                    .active_txn
680                    .take()
681                    .ok_or(SqlError::NoActiveTransaction)?;
682                wtx.abort();
683                self.clear_savepoint_state();
684                self.schema = SchemaManager::load(db)?;
685                self.txn_start_ts = None;
686                crate::datetime::set_txn_clock(None);
687                Ok(ExecutionResult::Ok)
688            }
689            Statement::Savepoint(name) => self.do_savepoint(name),
690            Statement::ReleaseSavepoint(name) => self.do_release(name),
691            Statement::RollbackTo(name) => self.do_rollback_to(name),
692            Statement::SetTimezone(zone) => {
693                self.set_session_timezone_impl(zone)?;
694                Ok(ExecutionResult::Ok)
695            }
696            Statement::Insert(ins) if self.active_txn.is_some() => {
697                self.capture_pending_snapshots();
698                let wtx = self.active_txn.as_mut().unwrap();
699                executor::exec_insert_in_txn(wtx, &self.schema, ins, params)
700            }
701            _ => {
702                if self.active_txn.is_some() && stmt_mutates(stmt) {
703                    self.capture_pending_snapshots();
704                }
705                if let Some(ref mut wtx) = self.active_txn {
706                    executor::execute_in_txn(wtx, &mut self.schema, stmt, params)
707                } else {
708                    executor::execute(db, &mut self.schema, stmt, params)
709                }
710            }
711        }
712    }
713
714    fn clear_savepoint_state(&mut self) {
715        self.savepoint_stack.clear();
716        self.in_place_saved = None;
717    }
718
719    fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
720        let wtx = self
721            .active_txn
722            .as_mut()
723            .ok_or(SqlError::NoActiveTransaction)?;
724
725        if self.savepoint_stack.is_empty() {
726            self.in_place_saved = Some(wtx.in_place());
727            wtx.set_in_place(false);
728        }
729
730        self.savepoint_stack.push(SavepointEntry {
731            name: name.to_string(),
732            snapshot: None,
733        });
734
735        Ok(ExecutionResult::Ok)
736    }
737
738    fn capture_pending_snapshots(&mut self) {
739        let last_pending = match self
740            .savepoint_stack
741            .iter()
742            .rposition(|e| e.snapshot.is_none())
743        {
744            Some(i) => i,
745            None => return,
746        };
747        let wtx = match self.active_txn.as_mut() {
748            Some(w) => w,
749            None => return,
750        };
751        let wtx_snap = wtx.begin_savepoint();
752        let schema_snap = self.schema.save_snapshot();
753
754        for i in 0..last_pending {
755            if self.savepoint_stack[i].snapshot.is_none() {
756                self.savepoint_stack[i].snapshot = Some(SavepointSnapshot {
757                    wtx_snap: wtx_snap.clone(),
758                    schema_snap: schema_snap.clone(),
759                });
760            }
761        }
762        self.savepoint_stack[last_pending].snapshot = Some(SavepointSnapshot {
763            wtx_snap,
764            schema_snap,
765        });
766    }
767
768    fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
769        if self.active_txn.is_none() {
770            return Err(SqlError::NoActiveTransaction);
771        }
772
773        let idx = self
774            .savepoint_stack
775            .iter()
776            .rposition(|e| e.name == name)
777            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
778        self.savepoint_stack.truncate(idx);
779
780        if self.savepoint_stack.is_empty() {
781            if let (Some(wtx), Some(original)) =
782                (self.active_txn.as_mut(), self.in_place_saved.take())
783            {
784                wtx.set_in_place(original);
785            }
786        }
787
788        Ok(ExecutionResult::Ok)
789    }
790
791    fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
792        if self.active_txn.is_none() {
793            return Err(SqlError::NoActiveTransaction);
794        }
795
796        let idx = self
797            .savepoint_stack
798            .iter()
799            .rposition(|e| e.name == name)
800            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
801
802        self.savepoint_stack.truncate(idx + 1);
803        let entry = self.savepoint_stack.last_mut().unwrap();
804        let snapshot = match entry.snapshot.take() {
805            Some(s) => s,
806            None => return Ok(ExecutionResult::Ok),
807        };
808
809        let wtx = self.active_txn.as_mut().unwrap();
810        wtx.restore_snapshot(snapshot.wtx_snap);
811        self.schema.restore_snapshot(snapshot.schema_snap);
812
813        Ok(ExecutionResult::Ok)
814    }
815}