Skip to main content

citadel_sql/
connection.rs

1//! Public SQL connection API.
2
3use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::Arc;
6
7use lru::LruCache;
8
9use citadel::Database;
10use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
11
12use crate::error::{Result, SqlError};
13use crate::executor;
14use crate::parser;
15use crate::parser::Statement;
16use crate::prepared::PreparedStatement;
17use crate::schema::{SchemaManager, SchemaSnapshot};
18use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
19
/// Number of parsed statements each connection keeps in its LRU statement
/// cache (see `Connection::open`). Must be non-zero.
const DEFAULT_CACHE_CAPACITY: usize = 64;
21
22#[derive(Debug)]
23pub struct ScriptExecution {
24    pub completed: Vec<ExecutionResult>,
25    pub error: Option<SqlError>,
26}
27
/// Validates a fixed UTC offset such as `+05:30`, `-0800`, `+05`, `Z` or `UTC`.
///
/// Accepted forms (surrounding whitespace is ignored):
/// * `Z` / `UTC` (case-insensitive);
/// * a mandatory `+`/`-` sign followed by `hh:mm`, `hhmm` or `hh`; in the
///   colon form each field may also be a single digit (e.g. `+5:30`).
///
/// Hours must be in `0..=23` and minutes in `0..=59`. Only validity is
/// reported: `Some(())` for a valid offset, `None` otherwise.
fn parse_fixed_offset(s: &str) -> Option<()> {
    /// Parses one hour/minute field: one or two ASCII digits and nothing else.
    fn field(f: &str) -> Option<u32> {
        let digits_only = f.bytes().all(|b| b.is_ascii_digit());
        if digits_only && (1..=2).contains(&f.len()) {
            f.parse().ok()
        } else {
            None
        }
    }

    let s = s.trim();
    if s.eq_ignore_ascii_case("z") || s.eq_ignore_ascii_case("utc") {
        return Some(());
    }
    // The sign is required but does not affect validity.
    let rest = s.strip_prefix('+').or_else(|| s.strip_prefix('-'))?;
    // Only ASCII digits and ':' may follow the sign. This rejects a second
    // sign (`u32::from_str` accepts "+5", so "++5" used to pass). It also
    // guarantees that the byte split below lands on a char boundary:
    // "+aé1" previously panicked when slicing `&rest[..2]`.
    if !rest.bytes().all(|b| b.is_ascii_digit() || b == b':') {
        return None;
    }
    let (hh, mm) = match rest.split_once(':') {
        Some(parts) => parts,
        None if rest.len() == 4 => rest.split_at(2),
        None if rest.len() == 2 => (rest, "00"),
        None => return None,
    };
    let h = field(hh)?;
    let m = field(mm)?;
    (h <= 23 && m <= 59).then_some(())
}
59
60fn stmt_mutates(stmt: &Statement) -> bool {
61    matches!(
62        stmt,
63        Statement::Insert(_)
64            | Statement::Update(_)
65            | Statement::Delete(_)
66            | Statement::CreateTable(_)
67            | Statement::DropTable(_)
68            | Statement::AlterTable(_)
69            | Statement::CreateIndex(_)
70            | Statement::DropIndex(_)
71            | Statement::CreateView(_)
72            | Statement::DropView(_)
73    )
74}
75
/// Fast-path normalizer for simple single-row `INSERT INTO ... VALUES (...)`.
///
/// Every value inside the parentheses must be one of:
/// * a single-quoted string (with `''` escapes);
/// * an optionally negative integer or decimal number;
/// * `NULL`, `TRUE` or `FALSE`, matched case-insensitively.
///
/// On success, returns:
/// * the statement rewritten as the original text up to and including `VALUES`,
///   followed by ` ($1, $2, ...)`;
/// * the extracted literal values, in order.
///
/// Inserts that differ only in their literals therefore map to the same key.
///
/// Returns `None`, and the caller falls back to the regular path, for anything
/// else. Examples: multi-row `VALUES`, trailing clauses or comments, expressions,
/// `$n` placeholders, unterminated strings, and numbers that do not parse as
/// `i64`/`f64`. Only whitespace and `;` may follow the closing parenthesis.
fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut i = 0;

    // Optional leading whitespace, then `INSERT` followed by at least one
    // whitespace byte.
    while i < len && bytes[i].is_ascii_whitespace() {
        i += 1;
    }
    if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
        return None;
    }
    i += 6;
    if i >= len || !bytes[i].is_ascii_whitespace() {
        return None;
    }
    while i < len && bytes[i].is_ascii_whitespace() {
        i += 1;
    }

    // `INTO` followed by at least one whitespace byte.
    if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
        return None;
    }
    i += 4;
    if i >= len || !bytes[i].is_ascii_whitespace() {
        return None;
    }

    // Find the first `VALUES` after `INTO` that is a standalone word (not
    // adjacent to an ASCII alphanumeric or `_`). Quotes are not tracked here.
    // An earlier standalone `values` (e.g. inside a quoted identifier) is taken
    // as the keyword, and the checks below then typically reject the statement.
    let prefix_start = 0;
    let mut values_pos = None;
    let mut j = i;
    while j + 6 <= len {
        if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
            && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
            && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
        {
            values_pos = Some(j);
            break;
        }
        j += 1;
    }
    let values_pos = values_pos?;

    // Everything up to and including `VALUES` is kept verbatim as the key
    // prefix. The slice end is a char boundary because `VALUES` matched ASCII.
    let prefix = &sql[prefix_start..values_pos + 6];
    let mut pos = values_pos + 6;

    // The value list must open with `(`.
    while pos < len && bytes[pos].is_ascii_whitespace() {
        pos += 1;
    }
    if pos >= len || bytes[pos] != b'(' {
        return None;
    }
    pos += 1;

    let mut values = Vec::new();
    let mut normalized = String::with_capacity(sql.len());
    normalized.push_str(prefix);
    normalized.push_str(" (");

    // One iteration per literal: parse it, emit `$n`, then expect `,` or `)`.
    loop {
        while pos < len && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        if pos >= len {
            return None;
        }

        // Placeholders are 1-based.
        let param_idx = values.len() + 1;
        if param_idx > 1 {
            normalized.push_str(", ");
        }

        if bytes[pos] == b'\'' {
            // Single-quoted string. A doubled quote `''` is an escaped quote;
            // text is collected segment by segment between quotes.
            pos += 1;
            let mut seg_start = pos;
            let mut s = String::new();
            loop {
                if pos >= len {
                    // Unterminated string literal.
                    return None;
                }
                if bytes[pos] == b'\'' {
                    s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
                    pos += 1;
                    if pos < len && bytes[pos] == b'\'' {
                        s.push('\'');
                        pos += 1;
                        seg_start = pos;
                    } else {
                        break;
                    }
                } else {
                    pos += 1;
                }
            }
            values.push(Value::Text(s.into()));
        } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
            // Number: optional `-`, digits, optional `.` and fraction digits.
            // A fraction makes it `Real`, otherwise `Integer`; text that fails
            // to parse (e.g. a lone "-" or an i64 overflow) rejects the fast path.
            let start = pos;
            if bytes[pos] == b'-' {
                pos += 1;
            }
            while pos < len && bytes[pos].is_ascii_digit() {
                pos += 1;
            }
            if pos < len && bytes[pos] == b'.' {
                pos += 1;
                while pos < len && bytes[pos].is_ascii_digit() {
                    pos += 1;
                }
                let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
                values.push(Value::Real(num));
            } else {
                let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
                values.push(Value::Integer(num));
            }
        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
            // Keyword literals must not run into an identifier character
            // (so e.g. `NULLIF(...)` is rejected). At end of input the
            // follow byte is treated as `)`.
            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 4;
                values.push(Value::Null);
            } else {
                return None;
            }
        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 4;
                values.push(Value::Boolean(true));
            } else {
                return None;
            }
        } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
            let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
            if !after.is_ascii_alphanumeric() && after != b'_' {
                pos += 5;
                values.push(Value::Boolean(false));
            } else {
                return None;
            }
        } else {
            // Anything else (expressions, placeholders, identifiers, ...).
            return None;
        }

        normalized.push('$');
        normalized.push_str(&param_idx.to_string());

        // After a literal only `,` (next value) or `)` (end of row) may follow.
        while pos < len && bytes[pos].is_ascii_whitespace() {
            pos += 1;
        }
        if pos >= len {
            return None;
        }

        if bytes[pos] == b',' {
            pos += 1;
        } else if bytes[pos] == b')' {
            pos += 1;
            break;
        } else {
            return None;
        }
    }

    normalized.push(')');

    // Only whitespace and statement terminators may remain; anything else
    // (a second row, `RETURNING`, a comment, ...) rejects the fast path.
    while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
        pos += 1;
    }
    if pos != len {
        return None;
    }

    // Defensive: the loop above pushes a value before it can `break`.
    if values.is_empty() {
        return None;
    }

    Some((normalized, values))
}
252
253pub(crate) struct CacheEntry {
254    pub(crate) stmt: Arc<Statement>,
255    pub(crate) schema_gen: u64,
256    pub(crate) param_count: usize,
257    pub(crate) compiled: Option<Arc<dyn executor::CompiledPlan>>,
258}
259
260struct SavepointEntry {
261    name: String,
262    snapshot: Option<SavepointSnapshot>,
263}
264
265struct SavepointSnapshot {
266    wtx_snap: WriteTxnSnapshot,
267    schema_snap: SchemaSnapshot,
268}
269
270pub(crate) struct ConnectionInner<'a> {
271    pub(crate) schema: SchemaManager,
272    active_txn: Option<WriteTxn<'a>>,
273    savepoint_stack: Vec<SavepointEntry>,
274    in_place_saved: Option<bool>,
275    pub(crate) stmt_cache: LruCache<String, CacheEntry>,
276    txn_start_ts: Option<i64>,
277    session_timezone: String,
278}
279
280/// SQL connection with LRU statement cache.
281pub struct Connection<'a> {
282    pub(crate) db: &'a Database,
283    pub(crate) inner: RefCell<ConnectionInner<'a>>,
284}
285
286impl<'a> Connection<'a> {
287    /// Open a SQL connection to a database.
288    pub fn open(db: &'a Database) -> Result<Self> {
289        let schema = SchemaManager::load(db)?;
290        let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
291        Ok(Self {
292            db,
293            inner: RefCell::new(ConnectionInner {
294                schema,
295                active_txn: None,
296                savepoint_stack: Vec::new(),
297                in_place_saved: None,
298                stmt_cache,
299                txn_start_ts: None,
300                session_timezone: "UTC".to_string(),
301            }),
302        })
303    }
304
305    /// Txn-start UTC µs inside BEGIN/COMMIT, else `None`.
306    pub fn txn_start_ts(&self) -> Option<i64> {
307        self.inner.borrow().txn_start_ts
308    }
309
310    /// Returns the session time-zone (IANA name or fixed offset). Default `"UTC"`.
311    pub fn session_timezone(&self) -> String {
312        self.inner.borrow().session_timezone.clone()
313    }
314
315    /// Set the session time-zone. Accepts IANA names, ISO-8601 offsets, `"UTC"`, `"Z"`.
316    pub fn set_session_timezone(&self, tz: &str) -> Result<()> {
317        self.inner.borrow_mut().set_session_timezone_impl(tz)
318    }
319
320    /// Execute a SQL statement. Returns the result.
321    pub fn execute(&self, sql: &str) -> Result<ExecutionResult> {
322        self.inner.borrow_mut().execute_impl(self.db, sql)
323    }
324
325    /// Execute a SQL statement with positional parameters ($1, $2, ...).
326    pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
327        self.inner
328            .borrow_mut()
329            .execute_params_impl(self.db, sql, params)
330    }
331
332    /// Execute `;`-separated SQL statements. Stops at the first failure.
333    pub fn execute_script(&self, sql: &str) -> ScriptExecution {
334        let stmts = match parser::parse_sql_multi(sql) {
335            Ok(s) => s,
336            Err(e) => {
337                return ScriptExecution {
338                    completed: vec![],
339                    error: Some(e),
340                }
341            }
342        };
343        let mut completed = Vec::with_capacity(stmts.len());
344        for stmt in stmts {
345            match self.inner.borrow_mut().dispatch(self.db, &stmt, &[]) {
346                Ok(r) => completed.push(r),
347                Err(e) => {
348                    return ScriptExecution {
349                        completed,
350                        error: Some(e),
351                    }
352                }
353            }
354        }
355        ScriptExecution {
356            completed,
357            error: None,
358        }
359    }
360
361    /// Execute a SQL query and return the result set.
362    pub fn query(&self, sql: &str) -> Result<QueryResult> {
363        self.query_params(sql, &[])
364    }
365
366    /// Execute a SQL query with positional parameters ($1, $2, ...).
367    pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
368        match self.execute_params(sql, params)? {
369            ExecutionResult::Query(qr) => Ok(qr),
370            ExecutionResult::RowsAffected(n) => Ok(QueryResult {
371                columns: vec!["rows_affected".into()],
372                rows: vec![vec![Value::Integer(n as i64)]],
373            }),
374            ExecutionResult::Ok => Ok(QueryResult {
375                columns: vec![],
376                rows: vec![],
377            }),
378        }
379    }
380
381    /// Prepare a SQL statement for repeated execution with parameters.
382    pub fn prepare(&self, sql: &str) -> Result<PreparedStatement<'_, 'a>> {
383        PreparedStatement::new(self, sql)
384    }
385
386    /// List all table names.
387    pub fn tables(&self) -> Vec<String> {
388        self.inner
389            .borrow()
390            .schema
391            .table_names()
392            .into_iter()
393            .map(String::from)
394            .collect()
395    }
396
397    /// Returns true if an explicit transaction is active (BEGIN was issued).
398    pub fn in_transaction(&self) -> bool {
399        self.inner.borrow().active_txn.is_some()
400    }
401
402    /// Get the schema for a named table.
403    pub fn table_schema(&self, name: &str) -> Option<TableSchema> {
404        self.inner.borrow().schema.get(name).cloned()
405    }
406
407    /// Reload schemas from the database.
408    pub fn refresh_schema(&self) -> Result<()> {
409        let new_schema = SchemaManager::load(self.db)?;
410        self.inner.borrow_mut().schema = new_schema;
411        Ok(())
412    }
413}
414
415impl<'a> ConnectionInner<'a> {
416    pub(crate) fn active_txn_is_some(&self) -> bool {
417        self.active_txn.is_some()
418    }
419
420    fn set_session_timezone_impl(&mut self, tz: &str) -> Result<()> {
421        let upper = tz.to_ascii_uppercase();
422        if (upper.starts_with("UTC+") || upper.starts_with("UTC-")) && tz.len() > 3 {
423            return Err(SqlError::InvalidTimezone(format!(
424                "'{tz}' is ambiguous; use ISO-8601 offset (e.g. '+05:00') or named zone (e.g. 'Etc/GMT-5')"
425            )));
426        }
427        if jiff::tz::TimeZone::get(tz).is_err() && parse_fixed_offset(tz).is_none() {
428            return Err(SqlError::InvalidTimezone(format!(
429                "{tz}: not a known IANA zone or ISO-8601 offset (e.g. '+05:00', 'UTC', 'America/New_York')"
430            )));
431        }
432        self.session_timezone = tz.to_string();
433        Ok(())
434    }
435
436    fn execute_impl(&mut self, db: &'a Database, sql: &str) -> Result<ExecutionResult> {
437        if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
438            if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
439                let gen = self.schema.generation();
440                let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
441                    if entry.schema_gen == gen {
442                        Arc::clone(&entry.stmt)
443                    } else {
444                        self.parse_and_cache(normalized_key, gen)?
445                    }
446                } else {
447                    self.parse_and_cache(normalized_key, gen)?
448                };
449                return self.dispatch(db, &stmt, &extracted);
450            }
451        }
452        self.execute_params_impl(db, sql, &[])
453    }
454
455    fn execute_params_impl(
456        &mut self,
457        db: &'a Database,
458        sql: &str,
459        params: &[Value],
460    ) -> Result<ExecutionResult> {
461        let gen = self.schema.generation();
462        if self.active_txn.is_none() {
463            if let Some(entry) = self.stmt_cache.get(sql) {
464                if entry.schema_gen == gen && entry.param_count == params.len() {
465                    if let Some(plan) = entry.compiled.as_ref().map(Arc::clone) {
466                        let stmt = Arc::clone(&entry.stmt);
467                        return self.run_compiled(db, &plan, &stmt, params);
468                    }
469                }
470            }
471        }
472
473        let (stmt, param_count) = self.get_or_parse(sql)?;
474
475        if param_count != params.len() {
476            return Err(SqlError::ParameterCountMismatch {
477                expected: param_count,
478                got: params.len(),
479            });
480        }
481
482        if self.active_txn.is_none() {
483            if let Some(plan) = executor::compile(&self.schema, &stmt) {
484                if let Some(e) = self.stmt_cache.get_mut(sql) {
485                    e.compiled = Some(Arc::clone(&plan));
486                }
487                let stmt_owned = Arc::clone(&stmt);
488                return self.run_compiled(db, &plan, &stmt_owned, params);
489            }
490        }
491
492        self.dispatch(db, &stmt, params)
493    }
494
495    fn run_compiled(
496        &mut self,
497        db: &'a Database,
498        plan: &Arc<dyn executor::CompiledPlan>,
499        stmt: &Statement,
500        params: &[Value],
501    ) -> Result<ExecutionResult> {
502        let cached_ts = self
503            .txn_start_ts
504            .or_else(|| Some(crate::datetime::now_micros()));
505        let schema = &self.schema;
506        crate::datetime::with_txn_clock(cached_ts, || {
507            if params.is_empty() {
508                plan.execute(db, schema, stmt, params, None)
509            } else {
510                crate::eval::with_scoped_params(params, || {
511                    plan.execute(db, schema, stmt, params, None)
512                })
513            }
514        })
515    }
516
517    pub(crate) fn parse_and_cache(
518        &mut self,
519        normalized_key: String,
520        gen: u64,
521    ) -> Result<Arc<Statement>> {
522        let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
523        let param_count = parser::count_params(&stmt);
524        self.stmt_cache.put(
525            normalized_key,
526            CacheEntry {
527                stmt: Arc::clone(&stmt),
528                schema_gen: gen,
529                param_count,
530                compiled: None,
531            },
532        );
533        Ok(stmt)
534    }
535
536    pub(crate) fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
537        let gen = self.schema.generation();
538
539        if let Some(entry) = self.stmt_cache.get(sql) {
540            if entry.schema_gen == gen {
541                return Ok((Arc::clone(&entry.stmt), entry.param_count));
542            }
543        }
544
545        let stmt = Arc::new(parser::parse_sql(sql)?);
546        let param_count = parser::count_params(&stmt);
547
548        let cacheable = !matches!(
549            *stmt,
550            Statement::CreateTable(_)
551                | Statement::DropTable(_)
552                | Statement::CreateIndex(_)
553                | Statement::DropIndex(_)
554                | Statement::CreateView(_)
555                | Statement::DropView(_)
556                | Statement::AlterTable(_)
557        );
558
559        if cacheable {
560            self.stmt_cache.put(
561                sql.to_string(),
562                CacheEntry {
563                    stmt: Arc::clone(&stmt),
564                    schema_gen: gen,
565                    param_count,
566                    compiled: None,
567                },
568            );
569        }
570
571        Ok((stmt, param_count))
572    }
573
574    pub(crate) fn execute_prepared(
575        &mut self,
576        db: &'a Database,
577        stmt: &Statement,
578        compiled: Option<&Arc<dyn executor::CompiledPlan>>,
579        params: &[Value],
580    ) -> Result<ExecutionResult> {
581        if let Some(plan) = compiled {
582            if self.active_txn.is_none() {
583                return self.run_compiled(db, plan, stmt, params);
584            }
585            if stmt_mutates(stmt) {
586                self.capture_pending_snapshots();
587            }
588            return self.run_compiled_in_txn(db, plan, stmt, params);
589        }
590        self.dispatch(db, stmt, params)
591    }
592
593    fn run_compiled_in_txn(
594        &mut self,
595        db: &'a Database,
596        plan: &Arc<dyn executor::CompiledPlan>,
597        stmt: &Statement,
598        params: &[Value],
599    ) -> Result<ExecutionResult> {
600        let cached_ts = self
601            .txn_start_ts
602            .or_else(|| Some(crate::datetime::now_micros()));
603        let schema = &self.schema;
604        let wtx = self.active_txn.as_mut();
605        crate::datetime::with_txn_clock(cached_ts, || {
606            if params.is_empty() {
607                plan.execute(db, schema, stmt, params, wtx)
608            } else {
609                crate::eval::with_scoped_params(params, || {
610                    plan.execute(db, schema, stmt, params, wtx)
611                })
612            }
613        })
614    }
615
616    pub(crate) fn dispatch(
617        &mut self,
618        db: &'a Database,
619        stmt: &Statement,
620        params: &[Value],
621    ) -> Result<ExecutionResult> {
622        let cached_ts = self
623            .txn_start_ts
624            .or_else(|| Some(crate::datetime::now_micros()));
625        crate::datetime::with_txn_clock(cached_ts, || {
626            if params.is_empty() {
627                self.dispatch_inner(db, stmt, params)
628            } else {
629                crate::eval::with_scoped_params(params, || self.dispatch_inner(db, stmt, params))
630            }
631        })
632    }
633
634    fn dispatch_inner(
635        &mut self,
636        db: &'a Database,
637        stmt: &Statement,
638        params: &[Value],
639    ) -> Result<ExecutionResult> {
640        match stmt {
641            Statement::Begin => {
642                if self.active_txn.is_some() {
643                    return Err(SqlError::TransactionAlreadyActive);
644                }
645                let wtx = db.begin_write().map_err(SqlError::Storage)?;
646                self.active_txn = Some(wtx);
647                self.txn_start_ts = Some(crate::datetime::txn_or_clock_micros());
648                Ok(ExecutionResult::Ok)
649            }
650            Statement::Commit => {
651                let wtx = self
652                    .active_txn
653                    .take()
654                    .ok_or(SqlError::NoActiveTransaction)?;
655                wtx.commit().map_err(SqlError::Storage)?;
656                self.clear_savepoint_state();
657                self.txn_start_ts = None;
658                Ok(ExecutionResult::Ok)
659            }
660            Statement::Rollback => {
661                let wtx = self
662                    .active_txn
663                    .take()
664                    .ok_or(SqlError::NoActiveTransaction)?;
665                wtx.abort();
666                self.clear_savepoint_state();
667                self.schema = SchemaManager::load(db)?;
668                self.txn_start_ts = None;
669                Ok(ExecutionResult::Ok)
670            }
671            Statement::Savepoint(name) => self.do_savepoint(name),
672            Statement::ReleaseSavepoint(name) => self.do_release(name),
673            Statement::RollbackTo(name) => self.do_rollback_to(name),
674            Statement::SetTimezone(zone) => {
675                self.set_session_timezone_impl(zone)?;
676                Ok(ExecutionResult::Ok)
677            }
678            Statement::Insert(ins) if self.active_txn.is_some() => {
679                self.capture_pending_snapshots();
680                let wtx = self.active_txn.as_mut().unwrap();
681                executor::exec_insert_in_txn(wtx, &self.schema, ins, params)
682            }
683            _ => {
684                if self.active_txn.is_some() && stmt_mutates(stmt) {
685                    self.capture_pending_snapshots();
686                }
687                if let Some(ref mut wtx) = self.active_txn {
688                    executor::execute_in_txn(wtx, &mut self.schema, stmt, params)
689                } else {
690                    executor::execute(db, &mut self.schema, stmt, params)
691                }
692            }
693        }
694    }
695
696    fn clear_savepoint_state(&mut self) {
697        self.savepoint_stack.clear();
698        self.in_place_saved = None;
699    }
700
701    fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
702        let wtx = self
703            .active_txn
704            .as_mut()
705            .ok_or(SqlError::NoActiveTransaction)?;
706
707        if self.savepoint_stack.is_empty() {
708            self.in_place_saved = Some(wtx.in_place());
709            wtx.set_in_place(false);
710        }
711
712        self.savepoint_stack.push(SavepointEntry {
713            name: name.to_string(),
714            snapshot: None,
715        });
716
717        Ok(ExecutionResult::Ok)
718    }
719
720    fn capture_pending_snapshots(&mut self) {
721        if !self.savepoint_stack.iter().any(|e| e.snapshot.is_none()) {
722            return;
723        }
724        let wtx = match self.active_txn.as_mut() {
725            Some(w) => w,
726            None => return,
727        };
728        let wtx_snap = wtx.begin_savepoint();
729        let schema_snap = self.schema.save_snapshot();
730        let mut pending = self
731            .savepoint_stack
732            .iter_mut()
733            .filter(|e| e.snapshot.is_none());
734        if let Some(first) = pending.next() {
735            first.snapshot = Some(SavepointSnapshot {
736                wtx_snap: wtx_snap.clone(),
737                schema_snap: schema_snap.clone(),
738            });
739        }
740        for entry in pending {
741            entry.snapshot = Some(SavepointSnapshot {
742                wtx_snap: wtx_snap.clone(),
743                schema_snap: schema_snap.clone(),
744            });
745        }
746    }
747
748    fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
749        if self.active_txn.is_none() {
750            return Err(SqlError::NoActiveTransaction);
751        }
752
753        let idx = self
754            .savepoint_stack
755            .iter()
756            .rposition(|e| e.name == name)
757            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
758        self.savepoint_stack.truncate(idx);
759
760        if self.savepoint_stack.is_empty() {
761            if let (Some(wtx), Some(original)) =
762                (self.active_txn.as_mut(), self.in_place_saved.take())
763            {
764                wtx.set_in_place(original);
765            }
766        }
767
768        Ok(ExecutionResult::Ok)
769    }
770
771    fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
772        if self.active_txn.is_none() {
773            return Err(SqlError::NoActiveTransaction);
774        }
775
776        let idx = self
777            .savepoint_stack
778            .iter()
779            .rposition(|e| e.name == name)
780            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
781
782        self.savepoint_stack.truncate(idx + 1);
783        let entry = self.savepoint_stack.last_mut().unwrap();
784        let snapshot = match entry.snapshot.take() {
785            Some(s) => s,
786            None => return Ok(ExecutionResult::Ok),
787        };
788
789        let wtx = self.active_txn.as_mut().unwrap();
790        wtx.restore_snapshot(snapshot.wtx_snap);
791        self.schema.restore_snapshot(snapshot.schema_snap);
792
793        self.stmt_cache.clear();
794
795        Ok(ExecutionResult::Ok)
796    }
797}