1use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::Arc;
6
7use lru::LruCache;
8
9use citadel::Database;
10use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
11
12use crate::error::{Result, SqlError};
13use crate::executor;
14use crate::parser;
15use crate::parser::{QueryBody, SelectQuery, Statement};
16use crate::prepared::PreparedStatement;
17use crate::schema::{SchemaManager, SchemaSnapshot};
18use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
19
20const DEFAULT_CACHE_CAPACITY: usize = 64;
21
22#[derive(Debug)]
23pub struct ScriptExecution {
24 pub completed: Vec<ExecutionResult>,
25 pub error: Option<SqlError>,
26}
27
/// Checks whether `s` is a fixed UTC offset this module understands.
///
/// Accepted forms (surrounding whitespace ignored):
/// - `Z` or `UTC` (case-insensitive)
/// - `±HH:MM` or `±H:MM` (each field one or two ASCII digits)
/// - `±HHMM`
/// - `±HH`
///
/// Hours must be `0..=23` and minutes `0..=59`. Returns `Some(())` when valid,
/// `None` otherwise. Only validation is done; the offset value is not returned.
///
/// Fields must consist solely of ASCII digits. This rejects inputs such as
/// `"++5:00"`, which `u32::from_str` would accept because it tolerates a
/// leading `+`. It also guarantees that the `±HHMM` split lands on a UTF-8 char
/// boundary, so non-ASCII input can no longer cause a slicing panic.
fn parse_fixed_offset(s: &str) -> Option<()> {
    let s = s.trim();
    if s.eq_ignore_ascii_case("z") || s.eq_ignore_ascii_case("utc") {
        return Some(());
    }

    // The sign is mandatory but the valid range does not depend on it.
    let rest = s.strip_prefix('+').or_else(|| s.strip_prefix('-'))?;

    let (hh, mm) = match rest.split_once(':') {
        Some(parts) => parts,
        // `is_ascii` makes byte index 2 a valid char boundary for `split_at`.
        None if rest.len() == 4 && rest.is_ascii() => rest.split_at(2),
        None if rest.len() == 2 => (rest, "00"),
        None => return None,
    };

    let is_field = |f: &str| (1..=2).contains(&f.len()) && f.bytes().all(|b| b.is_ascii_digit());
    if !is_field(hh) || !is_field(mm) {
        return None;
    }

    let h: u32 = hh.parse().ok()?;
    let m: u32 = mm.parse().ok()?;
    (h <= 23 && m <= 59).then_some(())
}
59
60fn stmt_mutates(stmt: &Statement) -> bool {
61 if matches!(
62 stmt,
63 Statement::Insert(_)
64 | Statement::Update(_)
65 | Statement::Delete(_)
66 | Statement::Truncate(_)
67 | Statement::CreateTable(_)
68 | Statement::DropTable(_)
69 | Statement::AlterTable(_)
70 | Statement::CreateIndex(_)
71 | Statement::DropIndex(_)
72 | Statement::CreateView(_)
73 | Statement::DropView(_)
74 ) {
75 return true;
76 }
77 if let Statement::Select(sq) = stmt {
78 if select_query_has_dml(sq) {
79 return true;
80 }
81 }
82 false
83}
84
85fn select_query_has_dml(sq: &SelectQuery) -> bool {
86 sq.ctes.iter().any(|cte| query_body_has_dml(&cte.body)) || query_body_has_dml(&sq.body)
87}
88
89fn query_body_has_dml(body: &QueryBody) -> bool {
90 match body {
91 QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => true,
92 QueryBody::Compound(c) => query_body_has_dml(&c.left) || query_body_has_dml(&c.right),
93 QueryBody::Select(_) => false,
94 }
95}
96
97fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
98 let bytes = sql.as_bytes();
99 let len = bytes.len();
100 let mut i = 0;
101
102 while i < len && bytes[i].is_ascii_whitespace() {
103 i += 1;
104 }
105 if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
106 return None;
107 }
108 i += 6;
109 if i >= len || !bytes[i].is_ascii_whitespace() {
110 return None;
111 }
112 while i < len && bytes[i].is_ascii_whitespace() {
113 i += 1;
114 }
115
116 if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
117 return None;
118 }
119 i += 4;
120 if i >= len || !bytes[i].is_ascii_whitespace() {
121 return None;
122 }
123
124 let prefix_start = 0;
125 let mut values_pos = None;
126 let mut j = i;
127 while j + 6 <= len {
128 if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
129 && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
130 && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
131 {
132 values_pos = Some(j);
133 break;
134 }
135 j += 1;
136 }
137 let values_pos = values_pos?;
138
139 let prefix = &sql[prefix_start..values_pos + 6];
140 let mut pos = values_pos + 6;
141
142 while pos < len && bytes[pos].is_ascii_whitespace() {
143 pos += 1;
144 }
145 if pos >= len || bytes[pos] != b'(' {
146 return None;
147 }
148 pos += 1;
149
150 let mut values = Vec::new();
151 let mut normalized = String::with_capacity(sql.len());
152 normalized.push_str(prefix);
153 normalized.push_str(" (");
154
155 loop {
156 while pos < len && bytes[pos].is_ascii_whitespace() {
157 pos += 1;
158 }
159 if pos >= len {
160 return None;
161 }
162
163 let param_idx = values.len() + 1;
164 if param_idx > 1 {
165 normalized.push_str(", ");
166 }
167
168 if bytes[pos] == b'\'' {
169 pos += 1;
170 let mut seg_start = pos;
171 let mut s = String::new();
172 loop {
173 if pos >= len {
174 return None;
175 }
176 if bytes[pos] == b'\'' {
177 s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
178 pos += 1;
179 if pos < len && bytes[pos] == b'\'' {
180 s.push('\'');
181 pos += 1;
182 seg_start = pos;
183 } else {
184 break;
185 }
186 } else {
187 pos += 1;
188 }
189 }
190 values.push(Value::Text(s.into()));
191 } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
192 let start = pos;
193 if bytes[pos] == b'-' {
194 pos += 1;
195 }
196 while pos < len && bytes[pos].is_ascii_digit() {
197 pos += 1;
198 }
199 if pos < len && bytes[pos] == b'.' {
200 pos += 1;
201 while pos < len && bytes[pos].is_ascii_digit() {
202 pos += 1;
203 }
204 let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
205 values.push(Value::Real(num));
206 } else {
207 let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
208 values.push(Value::Integer(num));
209 }
210 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
211 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
212 if !after.is_ascii_alphanumeric() && after != b'_' {
213 pos += 4;
214 values.push(Value::Null);
215 } else {
216 return None;
217 }
218 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
219 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
220 if !after.is_ascii_alphanumeric() && after != b'_' {
221 pos += 4;
222 values.push(Value::Boolean(true));
223 } else {
224 return None;
225 }
226 } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
227 let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
228 if !after.is_ascii_alphanumeric() && after != b'_' {
229 pos += 5;
230 values.push(Value::Boolean(false));
231 } else {
232 return None;
233 }
234 } else {
235 return None;
236 }
237
238 normalized.push('$');
239 normalized.push_str(¶m_idx.to_string());
240
241 while pos < len && bytes[pos].is_ascii_whitespace() {
242 pos += 1;
243 }
244 if pos >= len {
245 return None;
246 }
247
248 if bytes[pos] == b',' {
249 pos += 1;
250 } else if bytes[pos] == b')' {
251 pos += 1;
252 break;
253 } else {
254 return None;
255 }
256 }
257
258 normalized.push(')');
259
260 while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
261 pos += 1;
262 }
263 if pos != len {
264 return None;
265 }
266
267 if values.is_empty() {
268 return None;
269 }
270
271 Some((normalized, values))
272}
273
274pub(crate) struct CacheEntry {
275 pub(crate) stmt: Arc<Statement>,
276 pub(crate) schema_gen: u64,
277 pub(crate) param_count: usize,
278 pub(crate) compiled: Option<Arc<dyn executor::CompiledPlan>>,
279}
280
281struct SavepointEntry {
282 name: String,
283 snapshot: Option<SavepointSnapshot>,
284}
285
286struct SavepointSnapshot {
287 wtx_snap: WriteTxnSnapshot,
288 schema_snap: SchemaSnapshot,
289}
290
291pub(crate) struct ConnectionInner<'a> {
292 pub(crate) schema: SchemaManager,
293 active_txn: Option<WriteTxn<'a>>,
294 savepoint_stack: Vec<SavepointEntry>,
295 in_place_saved: Option<bool>,
296 pub(crate) stmt_cache: LruCache<String, CacheEntry>,
297 txn_start_ts: Option<i64>,
298 session_timezone: String,
299}
300
301pub struct Connection<'a> {
303 pub(crate) db: &'a Database,
304 pub(crate) inner: RefCell<ConnectionInner<'a>>,
305}
306
307impl<'a> Connection<'a> {
308 pub fn open(db: &'a Database) -> Result<Self> {
310 let schema = SchemaManager::load(db)?;
311 let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
312 Ok(Self {
313 db,
314 inner: RefCell::new(ConnectionInner {
315 schema,
316 active_txn: None,
317 savepoint_stack: Vec::new(),
318 in_place_saved: None,
319 stmt_cache,
320 txn_start_ts: None,
321 session_timezone: "UTC".to_string(),
322 }),
323 })
324 }
325
326 pub fn txn_start_ts(&self) -> Option<i64> {
328 self.inner.borrow().txn_start_ts
329 }
330
331 pub fn session_timezone(&self) -> String {
333 self.inner.borrow().session_timezone.clone()
334 }
335
336 pub fn set_session_timezone(&self, tz: &str) -> Result<()> {
338 self.inner.borrow_mut().set_session_timezone_impl(tz)
339 }
340
341 pub fn execute(&self, sql: &str) -> Result<ExecutionResult> {
343 self.inner.borrow_mut().execute_impl(self.db, sql)
344 }
345
346 pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
348 self.inner
349 .borrow_mut()
350 .execute_params_impl(self.db, sql, params)
351 }
352
353 pub fn execute_script(&self, sql: &str) -> ScriptExecution {
355 let stmts = match parser::parse_sql_multi(sql) {
356 Ok(s) => s,
357 Err(e) => {
358 return ScriptExecution {
359 completed: vec![],
360 error: Some(e),
361 }
362 }
363 };
364 let mut completed = Vec::with_capacity(stmts.len());
365 for stmt in stmts {
366 match self.inner.borrow_mut().dispatch(self.db, &stmt, &[]) {
367 Ok(r) => completed.push(r),
368 Err(e) => {
369 return ScriptExecution {
370 completed,
371 error: Some(e),
372 }
373 }
374 }
375 }
376 ScriptExecution {
377 completed,
378 error: None,
379 }
380 }
381
382 pub fn query(&self, sql: &str) -> Result<QueryResult> {
384 self.query_params(sql, &[])
385 }
386
387 pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
389 match self.execute_params(sql, params)? {
390 ExecutionResult::Query(qr) => Ok(qr),
391 ExecutionResult::RowsAffected(n) => Ok(QueryResult {
392 columns: vec!["rows_affected".into()],
393 rows: vec![vec![Value::Integer(n as i64)]],
394 }),
395 ExecutionResult::Ok => Ok(QueryResult {
396 columns: vec![],
397 rows: vec![],
398 }),
399 }
400 }
401
402 pub fn prepare(&self, sql: &str) -> Result<PreparedStatement<'_, 'a>> {
404 PreparedStatement::new(self, sql)
405 }
406
407 pub fn tables(&self) -> Vec<String> {
409 self.inner
410 .borrow()
411 .schema
412 .table_names()
413 .into_iter()
414 .map(String::from)
415 .collect()
416 }
417
418 pub fn in_transaction(&self) -> bool {
420 self.inner.borrow().active_txn.is_some()
421 }
422
423 pub fn table_schema(&self, name: &str) -> Option<TableSchema> {
425 self.inner.borrow().schema.get(name).cloned()
426 }
427
428 pub fn refresh_schema(&self) -> Result<()> {
430 let new_schema = SchemaManager::load(self.db)?;
431 self.inner.borrow_mut().schema = new_schema;
432 Ok(())
433 }
434}
435
436impl<'a> ConnectionInner<'a> {
437 pub(crate) fn active_txn_is_some(&self) -> bool {
438 self.active_txn.is_some()
439 }
440
441 fn set_session_timezone_impl(&mut self, tz: &str) -> Result<()> {
442 let upper = tz.to_ascii_uppercase();
443 if (upper.starts_with("UTC+") || upper.starts_with("UTC-")) && tz.len() > 3 {
444 return Err(SqlError::InvalidTimezone(format!(
445 "'{tz}' is ambiguous; use ISO-8601 offset (e.g. '+05:00') or named zone (e.g. 'Etc/GMT-5')"
446 )));
447 }
448 if jiff::tz::TimeZone::get(tz).is_err() && parse_fixed_offset(tz).is_none() {
449 return Err(SqlError::InvalidTimezone(format!(
450 "{tz}: not a known IANA zone or ISO-8601 offset (e.g. '+05:00', 'UTC', 'America/New_York')"
451 )));
452 }
453 self.session_timezone = tz.to_string();
454 Ok(())
455 }
456
457 fn execute_impl(&mut self, db: &'a Database, sql: &str) -> Result<ExecutionResult> {
458 if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
459 if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
460 let gen = self.schema.generation();
461 let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
462 if entry.schema_gen == gen {
463 Arc::clone(&entry.stmt)
464 } else {
465 self.parse_and_cache(normalized_key, gen)?
466 }
467 } else {
468 self.parse_and_cache(normalized_key, gen)?
469 };
470 return self.dispatch(db, &stmt, &extracted);
471 }
472 }
473 self.execute_params_impl(db, sql, &[])
474 }
475
476 fn execute_params_impl(
477 &mut self,
478 db: &'a Database,
479 sql: &str,
480 params: &[Value],
481 ) -> Result<ExecutionResult> {
482 let gen = self.schema.generation();
483 if self.active_txn.is_none() {
484 if let Some(entry) = self.stmt_cache.get(sql) {
485 if entry.schema_gen == gen && entry.param_count == params.len() {
486 if let Some(plan) = entry.compiled.as_ref().map(Arc::clone) {
487 let stmt = Arc::clone(&entry.stmt);
488 return self.run_compiled(db, &plan, &stmt, params);
489 }
490 }
491 }
492 }
493
494 let (stmt, param_count) = self.get_or_parse(sql)?;
495
496 if param_count != params.len() {
497 return Err(SqlError::ParameterCountMismatch {
498 expected: param_count,
499 got: params.len(),
500 });
501 }
502
503 if self.active_txn.is_none() {
504 if let Some(plan) = executor::compile(&self.schema, &stmt) {
505 if let Some(e) = self.stmt_cache.get_mut(sql) {
506 e.compiled = Some(Arc::clone(&plan));
507 }
508 let stmt_owned = Arc::clone(&stmt);
509 return self.run_compiled(db, &plan, &stmt_owned, params);
510 }
511 }
512
513 self.dispatch(db, &stmt, params)
514 }
515
516 fn run_compiled(
517 &mut self,
518 db: &'a Database,
519 plan: &Arc<dyn executor::CompiledPlan>,
520 stmt: &Statement,
521 params: &[Value],
522 ) -> Result<ExecutionResult> {
523 let schema = &self.schema;
524 let exec = || {
525 if params.is_empty() {
526 plan.execute(db, schema, stmt, params, None)
527 } else {
528 crate::eval::with_scoped_params(params, || {
529 plan.execute(db, schema, stmt, params, None)
530 })
531 }
532 };
533 if plan.needs_txn_clock() {
534 let cached_ts = self
535 .txn_start_ts
536 .or_else(|| Some(crate::datetime::now_micros()));
537 crate::datetime::with_txn_clock(cached_ts, exec)
538 } else {
539 exec()
540 }
541 }
542
543 pub(crate) fn parse_and_cache(
544 &mut self,
545 normalized_key: String,
546 gen: u64,
547 ) -> Result<Arc<Statement>> {
548 let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
549 let param_count = parser::count_params(&stmt);
550 self.stmt_cache.put(
551 normalized_key,
552 CacheEntry {
553 stmt: Arc::clone(&stmt),
554 schema_gen: gen,
555 param_count,
556 compiled: None,
557 },
558 );
559 Ok(stmt)
560 }
561
562 pub(crate) fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
563 let gen = self.schema.generation();
564
565 if let Some(entry) = self.stmt_cache.get(sql) {
566 if entry.schema_gen == gen {
567 return Ok((Arc::clone(&entry.stmt), entry.param_count));
568 }
569 }
570
571 let stmt = Arc::new(parser::parse_sql(sql)?);
572 let param_count = parser::count_params(&stmt);
573
574 let cacheable = !matches!(
575 *stmt,
576 Statement::CreateTable(_)
577 | Statement::DropTable(_)
578 | Statement::CreateIndex(_)
579 | Statement::DropIndex(_)
580 | Statement::CreateView(_)
581 | Statement::DropView(_)
582 | Statement::AlterTable(_)
583 );
584
585 if cacheable {
586 self.stmt_cache.put(
587 sql.to_string(),
588 CacheEntry {
589 stmt: Arc::clone(&stmt),
590 schema_gen: gen,
591 param_count,
592 compiled: None,
593 },
594 );
595 }
596
597 Ok((stmt, param_count))
598 }
599
600 pub(crate) fn execute_prepared(
601 &mut self,
602 db: &'a Database,
603 stmt: &Statement,
604 compiled: Option<&Arc<dyn executor::CompiledPlan>>,
605 params: &[Value],
606 ) -> Result<ExecutionResult> {
607 if let Some(plan) = compiled {
608 if self.active_txn.is_none() {
609 return self.run_compiled(db, plan, stmt, params);
610 }
611 if !self.savepoint_stack.is_empty() && stmt_mutates(stmt) {
612 self.capture_pending_snapshots();
613 }
614 return self.run_compiled_in_txn(db, plan, stmt, params);
615 }
616 self.dispatch(db, stmt, params)
617 }
618
619 fn run_compiled_in_txn(
620 &mut self,
621 db: &'a Database,
622 plan: &Arc<dyn executor::CompiledPlan>,
623 stmt: &Statement,
624 params: &[Value],
625 ) -> Result<ExecutionResult> {
626 let schema = &self.schema;
627 let wtx = self.active_txn.as_mut();
628 if params.is_empty() || !plan.uses_scoped_params() {
629 plan.execute(db, schema, stmt, params, wtx)
630 } else {
631 crate::eval::with_scoped_params(params, || plan.execute(db, schema, stmt, params, wtx))
632 }
633 }
634
635 pub(crate) fn dispatch(
636 &mut self,
637 db: &'a Database,
638 stmt: &Statement,
639 params: &[Value],
640 ) -> Result<ExecutionResult> {
641 let cached_ts = self
642 .txn_start_ts
643 .or_else(|| Some(crate::datetime::now_micros()));
644 crate::datetime::with_txn_clock(cached_ts, || {
645 if params.is_empty() {
646 self.dispatch_inner(db, stmt, params)
647 } else {
648 crate::eval::with_scoped_params(params, || self.dispatch_inner(db, stmt, params))
649 }
650 })
651 }
652
653 fn dispatch_inner(
654 &mut self,
655 db: &'a Database,
656 stmt: &Statement,
657 params: &[Value],
658 ) -> Result<ExecutionResult> {
659 match stmt {
660 Statement::Begin => {
661 if self.active_txn.is_some() {
662 return Err(SqlError::TransactionAlreadyActive);
663 }
664 let wtx = db.begin_write().map_err(SqlError::Storage)?;
665 self.active_txn = Some(wtx);
666 let ts = crate::datetime::txn_or_clock_micros();
667 self.txn_start_ts = Some(ts);
668 crate::datetime::set_txn_clock(Some(ts));
669 Ok(ExecutionResult::Ok)
670 }
671 Statement::Commit => {
672 let mut wtx = self
673 .active_txn
674 .take()
675 .ok_or(SqlError::NoActiveTransaction)?;
676 crate::executor::helpers::drain_deferred_fk_checks(&mut wtx)?;
677 wtx.commit().map_err(SqlError::Storage)?;
678 self.clear_savepoint_state();
679 self.txn_start_ts = None;
680 crate::datetime::set_txn_clock(None);
681 Ok(ExecutionResult::Ok)
682 }
683 Statement::Rollback => {
684 let wtx = self
685 .active_txn
686 .take()
687 .ok_or(SqlError::NoActiveTransaction)?;
688 wtx.abort();
689 self.clear_savepoint_state();
690 self.schema = SchemaManager::load(db)?;
691 self.txn_start_ts = None;
692 crate::datetime::set_txn_clock(None);
693 Ok(ExecutionResult::Ok)
694 }
695 Statement::Savepoint(name) => self.do_savepoint(name),
696 Statement::ReleaseSavepoint(name) => self.do_release(name),
697 Statement::RollbackTo(name) => self.do_rollback_to(name),
698 Statement::SetTimezone(zone) => {
699 self.set_session_timezone_impl(zone)?;
700 Ok(ExecutionResult::Ok)
701 }
702 Statement::Insert(ins) if self.active_txn.is_some() => {
703 self.capture_pending_snapshots();
704 let wtx = self.active_txn.as_mut().unwrap();
705 executor::exec_insert_in_txn(wtx, &self.schema, ins, params)
706 }
707 _ => {
708 if self.active_txn.is_some() && stmt_mutates(stmt) {
709 self.capture_pending_snapshots();
710 }
711 if let Some(ref mut wtx) = self.active_txn {
712 executor::execute_in_txn(wtx, &mut self.schema, stmt, params)
713 } else {
714 executor::execute(db, &mut self.schema, stmt, params)
715 }
716 }
717 }
718 }
719
720 fn clear_savepoint_state(&mut self) {
721 self.savepoint_stack.clear();
722 self.in_place_saved = None;
723 }
724
725 fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
726 let wtx = self
727 .active_txn
728 .as_mut()
729 .ok_or(SqlError::NoActiveTransaction)?;
730
731 if self.savepoint_stack.is_empty() {
732 self.in_place_saved = Some(wtx.in_place());
733 wtx.set_in_place(false);
734 }
735
736 self.savepoint_stack.push(SavepointEntry {
737 name: name.to_string(),
738 snapshot: None,
739 });
740
741 Ok(ExecutionResult::Ok)
742 }
743
744 fn capture_pending_snapshots(&mut self) {
745 let last_pending = match self
746 .savepoint_stack
747 .iter()
748 .rposition(|e| e.snapshot.is_none())
749 {
750 Some(i) => i,
751 None => return,
752 };
753 let wtx = match self.active_txn.as_mut() {
754 Some(w) => w,
755 None => return,
756 };
757 let wtx_snap = wtx.begin_savepoint();
758 let schema_snap = self.schema.save_snapshot();
759
760 for i in 0..last_pending {
761 if self.savepoint_stack[i].snapshot.is_none() {
762 self.savepoint_stack[i].snapshot = Some(SavepointSnapshot {
763 wtx_snap: wtx_snap.clone(),
764 schema_snap: schema_snap.clone(),
765 });
766 }
767 }
768 self.savepoint_stack[last_pending].snapshot = Some(SavepointSnapshot {
769 wtx_snap,
770 schema_snap,
771 });
772 }
773
774 fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
775 if self.active_txn.is_none() {
776 return Err(SqlError::NoActiveTransaction);
777 }
778
779 let idx = self
780 .savepoint_stack
781 .iter()
782 .rposition(|e| e.name == name)
783 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
784 self.savepoint_stack.truncate(idx);
785
786 if self.savepoint_stack.is_empty() {
787 if let (Some(wtx), Some(original)) =
788 (self.active_txn.as_mut(), self.in_place_saved.take())
789 {
790 wtx.set_in_place(original);
791 }
792 }
793
794 Ok(ExecutionResult::Ok)
795 }
796
797 fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
798 if self.active_txn.is_none() {
799 return Err(SqlError::NoActiveTransaction);
800 }
801
802 let idx = self
803 .savepoint_stack
804 .iter()
805 .rposition(|e| e.name == name)
806 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
807
808 self.savepoint_stack.truncate(idx + 1);
809 let entry = self.savepoint_stack.last_mut().unwrap();
810 let snapshot = match entry.snapshot.take() {
811 Some(s) => s,
812 None => return Ok(ExecutionResult::Ok),
813 };
814
815 let wtx = self.active_txn.as_mut().unwrap();
816 wtx.restore_snapshot(snapshot.wtx_snap);
817 self.schema.restore_snapshot(snapshot.schema_snap);
818
819 Ok(ExecutionResult::Ok)
820 }
821}