Skip to main content

citadel_sql/
connection.rs

1//! Public SQL connection API.
2
3use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::Arc;
6
7use lru::LruCache;
8
9use citadel::Database;
10use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
11
12use crate::error::{Result, SqlError};
13use crate::executor;
14use crate::parser;
15use crate::parser::Statement;
16use crate::prepared::PreparedStatement;
17use crate::schema::{SchemaManager, SchemaSnapshot};
18use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
19
/// Maximum number of entries held by each connection's LRU statement cache
/// (the capacity passed to `LruCache::new` in `Connection::open`).
const DEFAULT_CACHE_CAPACITY: usize = 64;
21
/// Checks whether `s` is a fixed UTC offset usable as a session time zone.
///
/// After trimming surrounding whitespace, accepts `Z` or `UTC` (case-insensitive),
/// or a mandatory `+`/`-` sign followed by `hh:mm`, `hhmm`, or `hh`. In the colon
/// form each field may have one or two digits. Hours must be in `0..=23` and
/// minutes in `0..=59`.
///
/// Returns `Some(())` when valid and `None` otherwise. It never panics, including
/// on non-ASCII input.
fn parse_fixed_offset(s: &str) -> Option<()> {
    let s = s.trim();
    if s.eq_ignore_ascii_case("z") || s.eq_ignore_ascii_case("utc") {
        return Some(());
    }
    // The sign is mandatory. Stripping one ASCII char keeps `rest` a valid slice.
    let rest = s.strip_prefix(|c: char| c == '+' || c == '-')?;

    // Each field must be 1-2 ASCII digits. This rejects extra signs that
    // `u32::from_str` would otherwise accept (e.g. `++5:00`) and over-long fields.
    let is_field =
        |part: &str| (1..=2).contains(&part.len()) && part.bytes().all(|b| b.is_ascii_digit());

    let (hh, mm) = match rest.split_once(':') {
        Some(parts) => parts,
        // Without a colon, require all-digit input before splitting so byte
        // offset 2 is always a char boundary (non-ASCII input used to panic here).
        None if !rest.bytes().all(|b| b.is_ascii_digit()) => return None,
        None if rest.len() == 4 => rest.split_at(2),
        None if rest.len() == 2 => (rest, "00"),
        None => return None,
    };
    if !is_field(hh) || !is_field(mm) {
        return None;
    }

    let h: u32 = hh.parse().ok()?;
    let m: u32 = mm.parse().ok()?;
    (h <= 23 && m <= 59).then_some(())
}
53
54fn stmt_mutates(stmt: &Statement) -> bool {
55    matches!(
56        stmt,
57        Statement::Insert(_)
58            | Statement::Update(_)
59            | Statement::Delete(_)
60            | Statement::CreateTable(_)
61            | Statement::DropTable(_)
62            | Statement::AlterTable(_)
63            | Statement::CreateIndex(_)
64            | Statement::DropIndex(_)
65            | Statement::CreateView(_)
66            | Statement::DropView(_)
67    )
68}
69
/// Fast-path normalizer for simple single-row `INSERT INTO ... VALUES (...)` statements.
///
/// Every literal in the single `VALUES` tuple is replaced by a positional
/// placeholder (`$1`, `$2`, ...). The function returns the rewritten SQL together
/// with the extracted literal values, so that inserts differing only in their
/// literals share one statement-cache key. The text up to and including `VALUES`
/// is kept verbatim (original case and spacing) and is followed by ` (`.
///
/// Recognized literals:
/// - single-quoted strings (`''` is an escaped quote);
/// - optionally negative integers (`Value::Integer`);
/// - decimals containing a `.` (`Value::Real`);
/// - the keywords `NULL`, `TRUE`, `FALSE` (case-insensitive).
///
/// Returns `None` (the caller then uses the general parser) when:
/// - the `INSERT INTO` prefix is missing;
/// - no standalone `VALUES` keyword is followed by `(`;
/// - any element is a non-literal expression;
/// - a numeric token fails to parse;
/// - there is a second tuple or any trailing text other than whitespace/`;`;
/// - the tuple is empty.
///
/// NOTE(review): the `VALUES` search is not quote- or comment-aware. The first
/// standalone match is used even inside a quoted identifier or comment; such
/// inputs appear to fall out via the later `(`/literal checks. Confirm for
/// unusual SQL.
fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut i = 0;

    // Leading whitespace, then `INSERT` followed by at least one whitespace byte.
    while i < len && bytes[i].is_ascii_whitespace() {
        i += 1;
    }
    if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
        return None;
    }
    i += 6;
    if i >= len || !bytes[i].is_ascii_whitespace() {
        return None;
    }
    while i < len && bytes[i].is_ascii_whitespace() {
        i += 1;
    }

    // `INTO`, also followed by whitespace.
    if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
        return None;
    }
    i += 4;
    if i >= len || !bytes[i].is_ascii_whitespace() {
        return None;
    }

    // Find the first `VALUES` whose neighbours are not ASCII alphanumerics or `_`,
    // i.e. it is not part of a longer ASCII identifier.
    let prefix_start = 0;
    let mut values_pos = None;
    let mut j = i;
    while j + 6 <= len {
        if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
            && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
            && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
        {
            values_pos = Some(j);
            break;
        }
        j += 1;
    }
    let values_pos = values_pos?;

    // Everything up to and including `VALUES` becomes the verbatim key prefix.
    // `values_pos + 6` follows an ASCII byte, so it is a valid char boundary.
    let prefix = &sql[prefix_start..values_pos + 6];
    let mut pos = values_pos + 6;

    while pos < len && bytes[pos].is_ascii_whitespace() {
        pos += 1;
    }
    if pos >= len || bytes[pos] != b'(' {
        return None;
    }
    pos += 1;

    let mut values = Vec::new();
    let mut normalized = String::with_capacity(sql.len());
    normalized.push_str(prefix);
    normalized.push_str(" (");

    // One iteration per tuple element: parse a literal, emit `$n`, then expect `,` or `)`.
    loop {
        while pos < len && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        if pos >= len {
            return None;
        }

        let param_idx = values.len() + 1;
        if param_idx > 1 {
            normalized.push_str(", ");
        }

        if bytes[pos] == b'\'' {
            // Quoted string: copy segments between quotes; `''` appends one `'`.
            pos += 1;
            let mut seg_start = pos;
            let mut s = String::new();
            loop {
                if pos >= len {
                    // Unterminated string literal.
                    return None;
                }
                if bytes[pos] == b'\'' {
                    s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
                    pos += 1;
                    if pos < len && bytes[pos] == b'\'' {
                        s.push('\'');
                        pos += 1;
                        seg_start = pos;
                    } else {
                        break;
                    }
                } else {
                    pos += 1;
                }
            }
            values.push(Value::Text(s.into()));
        } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
            // Number: optional `-`, digits, optional `.` + digits. A `.` makes it
            // Real; otherwise Integer. A parse failure (bare `-`, i64 overflow)
            // abandons the fast path.
            let start = pos;
            if bytes[pos] == b'-' {
                pos += 1;
            }
            while pos < len && bytes[pos].is_ascii_digit() {
                pos += 1;
            }
            if pos < len && bytes[pos] == b'.' {
                pos += 1;
                while pos < len && bytes[pos].is_ascii_digit() {
                    pos += 1;
                }
                let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
                values.push(Value::Real(num));
            } else {
                let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
                values.push(Value::Integer(num));
            }
        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
            // Keyword literals must not be followed by an identifier character.
            // At end of input a `)` is assumed here; the end-of-input check after
            // the literal still rejects the statement.
            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 4;
                values.push(Value::Null);
            } else {
                return None;
            }
        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 4;
                values.push(Value::Boolean(true));
            } else {
                return None;
            }
        } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
            let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 5;
                values.push(Value::Boolean(false));
            } else {
                return None;
            }
        } else {
            // Anything else (placeholders, expressions, identifiers) is not handled.
            return None;
        }

        normalized.push('$');
        normalized.push_str(&param_idx.to_string());

        while pos < len && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        if pos >= len {
            return None;
        }

        if bytes[pos] == b',' {
            pos += 1;
        } else if bytes[pos] == b')' {
            pos += 1;
            break;
        } else {
            return None;
        }
    }

    normalized.push(')');

    // Only whitespace and `;` may follow the tuple (rejects multi-row inserts,
    // RETURNING, comments, chained statements, ...).
    while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
        pos += 1;
    }
    if pos != len {
        return None;
    }

    if values.is_empty() {
        return None;
    }

    Some((normalized, values))
}
246
/// A parsed statement stored in the connection's LRU statement cache.
pub(crate) struct CacheEntry {
    /// The parsed statement, shared via `Arc` so cache hits avoid re-parsing.
    pub(crate) stmt: Arc<Statement>,
    /// Schema generation at parse time; the entry is reused only while the
    /// current generation matches.
    pub(crate) schema_gen: u64,
    /// Number of positional parameters (`$n`) the statement expects.
    pub(crate) param_count: usize,
    /// Compiled plan, filled in lazily when the statement is compiled outside an
    /// explicit transaction; `None` until then (or if it cannot be compiled).
    pub(crate) compiled: Option<Arc<dyn executor::CompiledPlan>>,
}
253
/// One `SAVEPOINT` on the connection's savepoint stack (innermost last).
struct SavepointEntry {
    /// Savepoint name as given in `SAVEPOINT <name>`; matched exactly by
    /// `RELEASE` / `ROLLBACK TO`.
    name: String,
    /// State to restore on `ROLLBACK TO`. Captured lazily: `None` until the first
    /// mutating statement after the savepoint (or after a `ROLLBACK TO` consumed it).
    snapshot: Option<SavepointSnapshot>,
}
258
/// State captured for a savepoint; both parts are restored together on `ROLLBACK TO`.
struct SavepointSnapshot {
    /// Snapshot of the write transaction taken via `WriteTxn::begin_savepoint`.
    wtx_snap: WriteTxnSnapshot,
    /// Matching snapshot of the in-memory schema.
    schema_snap: SchemaSnapshot,
}
263
/// Mutable per-connection state, kept behind the `RefCell` in [`Connection`].
pub(crate) struct ConnectionInner<'a> {
    /// In-memory table/view schema used for parsing, compiling and executing.
    pub(crate) schema: SchemaManager,
    /// Write transaction opened by `BEGIN`; `None` in autocommit mode.
    active_txn: Option<WriteTxn<'a>>,
    /// Active savepoints, innermost last.
    savepoint_stack: Vec<SavepointEntry>,
    /// The transaction's `in_place` flag saved when the first savepoint is
    /// created (savepoints force it off); restored once the stack is released.
    in_place_saved: Option<bool>,
    /// LRU cache keyed by SQL text (or the normalized fast-path INSERT text).
    pub(crate) stmt_cache: LruCache<String, CacheEntry>,
    /// Transaction start time (UTC µs) set at `BEGIN`, cleared at COMMIT/ROLLBACK;
    /// used as the statement clock while set.
    txn_start_ts: Option<i64>,
    /// Session time zone (IANA name or fixed offset); defaults to `"UTC"`.
    session_timezone: String,
}
273
/// SQL connection with LRU statement cache.
///
/// Borrows the underlying [`Database`] for `'a`. All mutable session state
/// (schema, open transaction, savepoints, statement cache, session time zone)
/// lives in a `RefCell`, so methods take `&self`. `RefCell` is not `Sync`, so a
/// `Connection` cannot be shared between threads.
pub struct Connection<'a> {
    pub(crate) db: &'a Database,
    pub(crate) inner: RefCell<ConnectionInner<'a>>,
}
279
280impl<'a> Connection<'a> {
281    /// Open a SQL connection to a database.
282    pub fn open(db: &'a Database) -> Result<Self> {
283        let schema = SchemaManager::load(db)?;
284        let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
285        Ok(Self {
286            db,
287            inner: RefCell::new(ConnectionInner {
288                schema,
289                active_txn: None,
290                savepoint_stack: Vec::new(),
291                in_place_saved: None,
292                stmt_cache,
293                txn_start_ts: None,
294                session_timezone: "UTC".to_string(),
295            }),
296        })
297    }
298
299    /// Txn-start UTC µs inside BEGIN/COMMIT, else `None`.
300    pub fn txn_start_ts(&self) -> Option<i64> {
301        self.inner.borrow().txn_start_ts
302    }
303
304    /// Returns the session time-zone (IANA name or fixed offset). Default `"UTC"`.
305    pub fn session_timezone(&self) -> String {
306        self.inner.borrow().session_timezone.clone()
307    }
308
309    /// Set the session time-zone. Accepts IANA names, ISO-8601 offsets, `"UTC"`, `"Z"`.
310    pub fn set_session_timezone(&self, tz: &str) -> Result<()> {
311        self.inner.borrow_mut().set_session_timezone_impl(tz)
312    }
313
314    /// Execute a SQL statement. Returns the result.
315    pub fn execute(&self, sql: &str) -> Result<ExecutionResult> {
316        self.inner.borrow_mut().execute_impl(self.db, sql)
317    }
318
319    /// Execute a SQL statement with positional parameters ($1, $2, ...).
320    pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
321        self.inner
322            .borrow_mut()
323            .execute_params_impl(self.db, sql, params)
324    }
325
326    /// Execute a SQL query and return the result set.
327    pub fn query(&self, sql: &str) -> Result<QueryResult> {
328        self.query_params(sql, &[])
329    }
330
331    /// Execute a SQL query with positional parameters ($1, $2, ...).
332    pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
333        match self.execute_params(sql, params)? {
334            ExecutionResult::Query(qr) => Ok(qr),
335            ExecutionResult::RowsAffected(n) => Ok(QueryResult {
336                columns: vec!["rows_affected".into()],
337                rows: vec![vec![Value::Integer(n as i64)]],
338            }),
339            ExecutionResult::Ok => Ok(QueryResult {
340                columns: vec![],
341                rows: vec![],
342            }),
343        }
344    }
345
346    /// Prepare a SQL statement for repeated execution with parameters.
347    pub fn prepare(&self, sql: &str) -> Result<PreparedStatement<'_, 'a>> {
348        PreparedStatement::new(self, sql)
349    }
350
351    /// List all table names.
352    pub fn tables(&self) -> Vec<String> {
353        self.inner
354            .borrow()
355            .schema
356            .table_names()
357            .into_iter()
358            .map(String::from)
359            .collect()
360    }
361
362    /// Returns true if an explicit transaction is active (BEGIN was issued).
363    pub fn in_transaction(&self) -> bool {
364        self.inner.borrow().active_txn.is_some()
365    }
366
367    /// Get the schema for a named table.
368    pub fn table_schema(&self, name: &str) -> Option<TableSchema> {
369        self.inner.borrow().schema.get(name).cloned()
370    }
371
372    /// Reload schemas from the database.
373    pub fn refresh_schema(&self) -> Result<()> {
374        let new_schema = SchemaManager::load(self.db)?;
375        self.inner.borrow_mut().schema = new_schema;
376        Ok(())
377    }
378}
379
380impl<'a> ConnectionInner<'a> {
381    pub(crate) fn active_txn_is_some(&self) -> bool {
382        self.active_txn.is_some()
383    }
384
385    fn set_session_timezone_impl(&mut self, tz: &str) -> Result<()> {
386        let upper = tz.to_ascii_uppercase();
387        if (upper.starts_with("UTC+") || upper.starts_with("UTC-")) && tz.len() > 3 {
388            return Err(SqlError::InvalidTimezone(format!(
389                "'{tz}' is ambiguous; use ISO-8601 offset (e.g. '+05:00') or named zone (e.g. 'Etc/GMT-5')"
390            )));
391        }
392        if jiff::tz::TimeZone::get(tz).is_err() && parse_fixed_offset(tz).is_none() {
393            return Err(SqlError::InvalidTimezone(format!(
394                "{tz}: not a known IANA zone or ISO-8601 offset (e.g. '+05:00', 'UTC', 'America/New_York')"
395            )));
396        }
397        self.session_timezone = tz.to_string();
398        Ok(())
399    }
400
401    fn execute_impl(&mut self, db: &'a Database, sql: &str) -> Result<ExecutionResult> {
402        if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
403            if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
404                let gen = self.schema.generation();
405                let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
406                    if entry.schema_gen == gen {
407                        Arc::clone(&entry.stmt)
408                    } else {
409                        self.parse_and_cache(normalized_key, gen)?
410                    }
411                } else {
412                    self.parse_and_cache(normalized_key, gen)?
413                };
414                return self.dispatch(db, &stmt, &extracted);
415            }
416        }
417        self.execute_params_impl(db, sql, &[])
418    }
419
420    fn execute_params_impl(
421        &mut self,
422        db: &'a Database,
423        sql: &str,
424        params: &[Value],
425    ) -> Result<ExecutionResult> {
426        let gen = self.schema.generation();
427        if self.active_txn.is_none() {
428            if let Some(entry) = self.stmt_cache.get(sql) {
429                if entry.schema_gen == gen && entry.param_count == params.len() {
430                    if let Some(plan) = entry.compiled.as_ref().map(Arc::clone) {
431                        let stmt = Arc::clone(&entry.stmt);
432                        return self.run_compiled(db, &plan, &stmt, params);
433                    }
434                }
435            }
436        }
437
438        let (stmt, param_count) = self.get_or_parse(sql)?;
439
440        if param_count != params.len() {
441            return Err(SqlError::ParameterCountMismatch {
442                expected: param_count,
443                got: params.len(),
444            });
445        }
446
447        if self.active_txn.is_none() {
448            if let Some(plan) = executor::compile(&self.schema, &stmt) {
449                if let Some(e) = self.stmt_cache.get_mut(sql) {
450                    e.compiled = Some(Arc::clone(&plan));
451                }
452                let stmt_owned = Arc::clone(&stmt);
453                return self.run_compiled(db, &plan, &stmt_owned, params);
454            }
455        }
456
457        self.dispatch(db, &stmt, params)
458    }
459
460    fn run_compiled(
461        &mut self,
462        db: &'a Database,
463        plan: &Arc<dyn executor::CompiledPlan>,
464        stmt: &Statement,
465        params: &[Value],
466    ) -> Result<ExecutionResult> {
467        let cached_ts = self
468            .txn_start_ts
469            .or_else(|| Some(crate::datetime::now_micros()));
470        let schema = &self.schema;
471        crate::datetime::with_txn_clock(cached_ts, || {
472            if params.is_empty() {
473                plan.execute(db, schema, stmt, params, None)
474            } else {
475                crate::eval::with_scoped_params(params, || {
476                    plan.execute(db, schema, stmt, params, None)
477                })
478            }
479        })
480    }
481
482    pub(crate) fn parse_and_cache(
483        &mut self,
484        normalized_key: String,
485        gen: u64,
486    ) -> Result<Arc<Statement>> {
487        let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
488        let param_count = parser::count_params(&stmt);
489        self.stmt_cache.put(
490            normalized_key,
491            CacheEntry {
492                stmt: Arc::clone(&stmt),
493                schema_gen: gen,
494                param_count,
495                compiled: None,
496            },
497        );
498        Ok(stmt)
499    }
500
501    pub(crate) fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
502        let gen = self.schema.generation();
503
504        if let Some(entry) = self.stmt_cache.get(sql) {
505            if entry.schema_gen == gen {
506                return Ok((Arc::clone(&entry.stmt), entry.param_count));
507            }
508        }
509
510        let stmt = Arc::new(parser::parse_sql(sql)?);
511        let param_count = parser::count_params(&stmt);
512
513        let cacheable = !matches!(
514            *stmt,
515            Statement::CreateTable(_)
516                | Statement::DropTable(_)
517                | Statement::CreateIndex(_)
518                | Statement::DropIndex(_)
519                | Statement::CreateView(_)
520                | Statement::DropView(_)
521                | Statement::AlterTable(_)
522        );
523
524        if cacheable {
525            self.stmt_cache.put(
526                sql.to_string(),
527                CacheEntry {
528                    stmt: Arc::clone(&stmt),
529                    schema_gen: gen,
530                    param_count,
531                    compiled: None,
532                },
533            );
534        }
535
536        Ok((stmt, param_count))
537    }
538
539    pub(crate) fn execute_prepared(
540        &mut self,
541        db: &'a Database,
542        stmt: &Statement,
543        compiled: Option<&Arc<dyn executor::CompiledPlan>>,
544        params: &[Value],
545    ) -> Result<ExecutionResult> {
546        if let Some(plan) = compiled {
547            if self.active_txn.is_none() {
548                return self.run_compiled(db, plan, stmt, params);
549            }
550            if stmt_mutates(stmt) {
551                self.capture_pending_snapshots();
552            }
553            return self.run_compiled_in_txn(db, plan, stmt, params);
554        }
555        self.dispatch(db, stmt, params)
556    }
557
558    fn run_compiled_in_txn(
559        &mut self,
560        db: &'a Database,
561        plan: &Arc<dyn executor::CompiledPlan>,
562        stmt: &Statement,
563        params: &[Value],
564    ) -> Result<ExecutionResult> {
565        let cached_ts = self
566            .txn_start_ts
567            .or_else(|| Some(crate::datetime::now_micros()));
568        let schema = &self.schema;
569        let wtx = self.active_txn.as_mut();
570        crate::datetime::with_txn_clock(cached_ts, || {
571            if params.is_empty() {
572                plan.execute(db, schema, stmt, params, wtx)
573            } else {
574                crate::eval::with_scoped_params(params, || {
575                    plan.execute(db, schema, stmt, params, wtx)
576                })
577            }
578        })
579    }
580
581    pub(crate) fn dispatch(
582        &mut self,
583        db: &'a Database,
584        stmt: &Statement,
585        params: &[Value],
586    ) -> Result<ExecutionResult> {
587        let cached_ts = self
588            .txn_start_ts
589            .or_else(|| Some(crate::datetime::now_micros()));
590        crate::datetime::with_txn_clock(cached_ts, || {
591            if params.is_empty() {
592                self.dispatch_inner(db, stmt, params)
593            } else {
594                crate::eval::with_scoped_params(params, || self.dispatch_inner(db, stmt, params))
595            }
596        })
597    }
598
599    fn dispatch_inner(
600        &mut self,
601        db: &'a Database,
602        stmt: &Statement,
603        params: &[Value],
604    ) -> Result<ExecutionResult> {
605        match stmt {
606            Statement::Begin => {
607                if self.active_txn.is_some() {
608                    return Err(SqlError::TransactionAlreadyActive);
609                }
610                let wtx = db.begin_write().map_err(SqlError::Storage)?;
611                self.active_txn = Some(wtx);
612                self.txn_start_ts = Some(crate::datetime::txn_or_clock_micros());
613                Ok(ExecutionResult::Ok)
614            }
615            Statement::Commit => {
616                let wtx = self
617                    .active_txn
618                    .take()
619                    .ok_or(SqlError::NoActiveTransaction)?;
620                wtx.commit().map_err(SqlError::Storage)?;
621                self.clear_savepoint_state();
622                self.txn_start_ts = None;
623                Ok(ExecutionResult::Ok)
624            }
625            Statement::Rollback => {
626                let wtx = self
627                    .active_txn
628                    .take()
629                    .ok_or(SqlError::NoActiveTransaction)?;
630                wtx.abort();
631                self.clear_savepoint_state();
632                self.schema = SchemaManager::load(db)?;
633                self.txn_start_ts = None;
634                Ok(ExecutionResult::Ok)
635            }
636            Statement::Savepoint(name) => self.do_savepoint(name),
637            Statement::ReleaseSavepoint(name) => self.do_release(name),
638            Statement::RollbackTo(name) => self.do_rollback_to(name),
639            Statement::SetTimezone(zone) => {
640                self.set_session_timezone_impl(zone)?;
641                Ok(ExecutionResult::Ok)
642            }
643            Statement::Insert(ins) if self.active_txn.is_some() => {
644                self.capture_pending_snapshots();
645                let wtx = self.active_txn.as_mut().unwrap();
646                executor::exec_insert_in_txn(wtx, &self.schema, ins, params)
647            }
648            _ => {
649                if self.active_txn.is_some() && stmt_mutates(stmt) {
650                    self.capture_pending_snapshots();
651                }
652                if let Some(ref mut wtx) = self.active_txn {
653                    executor::execute_in_txn(wtx, &mut self.schema, stmt, params)
654                } else {
655                    executor::execute(db, &mut self.schema, stmt, params)
656                }
657            }
658        }
659    }
660
661    fn clear_savepoint_state(&mut self) {
662        self.savepoint_stack.clear();
663        self.in_place_saved = None;
664    }
665
666    fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
667        let wtx = self
668            .active_txn
669            .as_mut()
670            .ok_or(SqlError::NoActiveTransaction)?;
671
672        if self.savepoint_stack.is_empty() {
673            self.in_place_saved = Some(wtx.in_place());
674            wtx.set_in_place(false);
675        }
676
677        self.savepoint_stack.push(SavepointEntry {
678            name: name.to_string(),
679            snapshot: None,
680        });
681
682        Ok(ExecutionResult::Ok)
683    }
684
685    fn capture_pending_snapshots(&mut self) {
686        if !self.savepoint_stack.iter().any(|e| e.snapshot.is_none()) {
687            return;
688        }
689        let wtx = match self.active_txn.as_mut() {
690            Some(w) => w,
691            None => return,
692        };
693        let wtx_snap = wtx.begin_savepoint();
694        let schema_snap = self.schema.save_snapshot();
695        let mut pending = self
696            .savepoint_stack
697            .iter_mut()
698            .filter(|e| e.snapshot.is_none());
699        if let Some(first) = pending.next() {
700            first.snapshot = Some(SavepointSnapshot {
701                wtx_snap: wtx_snap.clone(),
702                schema_snap: schema_snap.clone(),
703            });
704        }
705        for entry in pending {
706            entry.snapshot = Some(SavepointSnapshot {
707                wtx_snap: wtx_snap.clone(),
708                schema_snap: schema_snap.clone(),
709            });
710        }
711    }
712
713    fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
714        if self.active_txn.is_none() {
715            return Err(SqlError::NoActiveTransaction);
716        }
717
718        let idx = self
719            .savepoint_stack
720            .iter()
721            .rposition(|e| e.name == name)
722            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
723        self.savepoint_stack.truncate(idx);
724
725        if self.savepoint_stack.is_empty() {
726            if let (Some(wtx), Some(original)) =
727                (self.active_txn.as_mut(), self.in_place_saved.take())
728            {
729                wtx.set_in_place(original);
730            }
731        }
732
733        Ok(ExecutionResult::Ok)
734    }
735
736    fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
737        if self.active_txn.is_none() {
738            return Err(SqlError::NoActiveTransaction);
739        }
740
741        let idx = self
742            .savepoint_stack
743            .iter()
744            .rposition(|e| e.name == name)
745            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
746
747        self.savepoint_stack.truncate(idx + 1);
748        let entry = self.savepoint_stack.last_mut().unwrap();
749        let snapshot = match entry.snapshot.take() {
750            Some(s) => s,
751            None => return Ok(ExecutionResult::Ok),
752        };
753
754        let wtx = self.active_txn.as_mut().unwrap();
755        wtx.restore_snapshot(snapshot.wtx_snap);
756        self.schema.restore_snapshot(snapshot.schema_snap);
757
758        self.stmt_cache.clear();
759
760        Ok(ExecutionResult::Ok)
761    }
762}