1use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::Arc;
6
7use lru::LruCache;
8
9use citadel::Database;
10use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
11
12use crate::error::{Result, SqlError};
13use crate::executor;
14use crate::parser;
15use crate::parser::{QueryBody, SelectQuery, Statement};
16use crate::prepared::PreparedStatement;
17use crate::schema::{SchemaManager, SchemaSnapshot};
18use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
19
20const DEFAULT_CACHE_CAPACITY: usize = 64;
21
22#[derive(Debug)]
23pub struct ScriptExecution {
24 pub completed: Vec<ExecutionResult>,
25 pub error: Option<SqlError>,
26}
27
28fn parse_fixed_offset(s: &str) -> Option<()> {
29 let s = s.trim();
30 if s.eq_ignore_ascii_case("z") || s.eq_ignore_ascii_case("utc") {
31 return Some(());
32 }
33 let bytes = s.as_bytes();
34 if bytes.is_empty() {
35 return None;
36 }
37 let sign = match bytes[0] {
38 b'+' | b'-' => bytes[0] as char,
39 _ => return None,
40 };
41 let rest = &s[1..];
42 let (hh, mm) = if let Some((h, m)) = rest.split_once(':') {
43 (h, m)
44 } else if rest.len() == 4 {
45 (&rest[..2], &rest[2..])
46 } else if rest.len() == 2 {
47 (rest, "00")
48 } else {
49 return None;
50 };
51 let h: u32 = hh.parse().ok()?;
52 let m: u32 = mm.parse().ok()?;
53 if h > 23 || m > 59 {
54 return None;
55 }
56 let _ = sign;
57 Some(())
58}
59
60fn stmt_mutates(stmt: &Statement) -> bool {
61 if matches!(
62 stmt,
63 Statement::Insert(_)
64 | Statement::Update(_)
65 | Statement::Delete(_)
66 | Statement::Truncate(_)
67 | Statement::CreateTable(_)
68 | Statement::DropTable(_)
69 | Statement::AlterTable(_)
70 | Statement::CreateIndex(_)
71 | Statement::DropIndex(_)
72 | Statement::CreateView(_)
73 | Statement::DropView(_)
74 ) {
75 return true;
76 }
77 if let Statement::Select(sq) = stmt {
78 if select_query_has_dml(sq) {
79 return true;
80 }
81 }
82 false
83}
84
85fn select_query_has_dml(sq: &SelectQuery) -> bool {
86 sq.ctes.iter().any(|cte| query_body_has_dml(&cte.body)) || query_body_has_dml(&sq.body)
87}
88
89fn query_body_has_dml(body: &QueryBody) -> bool {
90 match body {
91 QueryBody::Insert(_) | QueryBody::Update(_) | QueryBody::Delete(_) => true,
92 QueryBody::Compound(c) => query_body_has_dml(&c.left) || query_body_has_dml(&c.right),
93 QueryBody::Select(_) => false,
94 }
95}
96
97fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
98 let bytes = sql.as_bytes();
99 let len = bytes.len();
100 let mut i = 0;
101
102 while i < len && bytes[i].is_ascii_whitespace() {
103 i += 1;
104 }
105 if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
106 return None;
107 }
108 i += 6;
109 if i >= len || !bytes[i].is_ascii_whitespace() {
110 return None;
111 }
112 while i < len && bytes[i].is_ascii_whitespace() {
113 i += 1;
114 }
115
116 if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
117 return None;
118 }
119 i += 4;
120 if i >= len || !bytes[i].is_ascii_whitespace() {
121 return None;
122 }
123
124 let prefix_start = 0;
125 let mut values_pos = None;
126 let mut j = i;
127 while j + 6 <= len {
128 if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
129 && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
130 && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
131 {
132 values_pos = Some(j);
133 break;
134 }
135 j += 1;
136 }
137 let values_pos = values_pos?;
138
139 let prefix = &sql[prefix_start..values_pos + 6];
140 let mut pos = values_pos + 6;
141
142 while pos < len && bytes[pos].is_ascii_whitespace() {
143 pos += 1;
144 }
145 if pos >= len || bytes[pos] != b'(' {
146 return None;
147 }
148 pos += 1;
149
150 let mut values = Vec::new();
151 let mut normalized = String::with_capacity(sql.len());
152 normalized.push_str(prefix);
153 normalized.push_str(" (");
154
155 loop {
156 while pos < len && bytes[pos].is_ascii_whitespace() {
157 pos += 1;
158 }
159 if pos >= len {
160 return None;
161 }
162
163 let param_idx = values.len() + 1;
164 if param_idx > 1 {
165 normalized.push_str(", ");
166 }
167
168 if bytes[pos] == b'\'' {
169 pos += 1;
170 let mut seg_start = pos;
171 let mut s = String::new();
172 loop {
173 if pos >= len {
174 return None;
175 }
176 if bytes[pos] == b'\'' {
177 s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
178 pos += 1;
179 if pos < len && bytes[pos] == b'\'' {
180 s.push('\'');
181 pos += 1;
182 seg_start = pos;
183 } else {
184 break;
185 }
186 } else {
187 pos += 1;
188 }
189 }
190 values.push(Value::Text(s.into()));
191 } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
192 let start = pos;
193 if bytes[pos] == b'-' {
194 pos += 1;
195 }
196 while pos < len && bytes[pos].is_ascii_digit() {
197 pos += 1;
198 }
199 if pos < len && bytes[pos] == b'.' {
200 pos += 1;
201 while pos < len && bytes[pos].is_ascii_digit() {
202 pos += 1;
203 }
204 let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
205 values.push(Value::Real(num));
206 } else {
207 let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
208 values.push(Value::Integer(num));
209 }
210 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
211 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
212 if !after.is_ascii_alphanumeric() && after != b'_' {
213 pos += 4;
214 values.push(Value::Null);
215 } else {
216 return None;
217 }
218 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
219 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
220 if !after.is_ascii_alphanumeric() && after != b'_' {
221 pos += 4;
222 values.push(Value::Boolean(true));
223 } else {
224 return None;
225 }
226 } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
227 let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
228 if !after.is_ascii_alphanumeric() && after != b'_' {
229 pos += 5;
230 values.push(Value::Boolean(false));
231 } else {
232 return None;
233 }
234 } else {
235 return None;
236 }
237
238 normalized.push('$');
239 normalized.push_str(¶m_idx.to_string());
240
241 while pos < len && bytes[pos].is_ascii_whitespace() {
242 pos += 1;
243 }
244 if pos >= len {
245 return None;
246 }
247
248 if bytes[pos] == b',' {
249 pos += 1;
250 } else if bytes[pos] == b')' {
251 pos += 1;
252 break;
253 } else {
254 return None;
255 }
256 }
257
258 normalized.push(')');
259
260 while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
261 pos += 1;
262 }
263 if pos != len {
264 return None;
265 }
266
267 if values.is_empty() {
268 return None;
269 }
270
271 Some((normalized, values))
272}
273
274pub(crate) struct CacheEntry {
275 pub(crate) stmt: Arc<Statement>,
276 pub(crate) schema_gen: u64,
277 pub(crate) param_count: usize,
278 pub(crate) compiled: Option<Arc<dyn executor::CompiledPlan>>,
279}
280
281struct SavepointEntry {
282 name: String,
283 snapshot: Option<SavepointSnapshot>,
284}
285
286struct SavepointSnapshot {
287 wtx_snap: WriteTxnSnapshot,
288 schema_snap: SchemaSnapshot,
289}
290
291pub(crate) struct ConnectionInner<'a> {
292 pub(crate) schema: SchemaManager,
293 active_txn: Option<WriteTxn<'a>>,
294 savepoint_stack: Vec<SavepointEntry>,
295 in_place_saved: Option<bool>,
296 pub(crate) stmt_cache: LruCache<String, CacheEntry>,
297 txn_start_ts: Option<i64>,
298 session_timezone: String,
299}
300
301pub struct Connection<'a> {
303 pub(crate) db: &'a Database,
304 pub(crate) inner: RefCell<ConnectionInner<'a>>,
305}
306
307impl<'a> Connection<'a> {
308 pub fn open(db: &'a Database) -> Result<Self> {
310 let schema = SchemaManager::load(db)?;
311 let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
312 Ok(Self {
313 db,
314 inner: RefCell::new(ConnectionInner {
315 schema,
316 active_txn: None,
317 savepoint_stack: Vec::new(),
318 in_place_saved: None,
319 stmt_cache,
320 txn_start_ts: None,
321 session_timezone: "UTC".to_string(),
322 }),
323 })
324 }
325
326 pub fn txn_start_ts(&self) -> Option<i64> {
328 self.inner.borrow().txn_start_ts
329 }
330
331 pub fn session_timezone(&self) -> String {
333 self.inner.borrow().session_timezone.clone()
334 }
335
336 pub fn set_session_timezone(&self, tz: &str) -> Result<()> {
338 self.inner.borrow_mut().set_session_timezone_impl(tz)
339 }
340
341 pub fn execute(&self, sql: &str) -> Result<ExecutionResult> {
343 self.inner.borrow_mut().execute_impl(self.db, sql)
344 }
345
346 pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
348 self.inner
349 .borrow_mut()
350 .execute_params_impl(self.db, sql, params)
351 }
352
353 pub fn execute_script(&self, sql: &str) -> ScriptExecution {
355 let stmts = match parser::parse_sql_multi(sql) {
356 Ok(s) => s,
357 Err(e) => {
358 return ScriptExecution {
359 completed: vec![],
360 error: Some(e),
361 }
362 }
363 };
364 let mut completed = Vec::with_capacity(stmts.len());
365 for stmt in stmts {
366 match self.inner.borrow_mut().dispatch(self.db, &stmt, &[]) {
367 Ok(r) => completed.push(r),
368 Err(e) => {
369 return ScriptExecution {
370 completed,
371 error: Some(e),
372 }
373 }
374 }
375 }
376 ScriptExecution {
377 completed,
378 error: None,
379 }
380 }
381
382 pub fn query(&self, sql: &str) -> Result<QueryResult> {
384 self.query_params(sql, &[])
385 }
386
387 pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
389 match self.execute_params(sql, params)? {
390 ExecutionResult::Query(qr) => Ok(qr),
391 ExecutionResult::RowsAffected(n) => Ok(QueryResult {
392 columns: vec!["rows_affected".into()],
393 rows: vec![vec![Value::Integer(n as i64)]],
394 }),
395 ExecutionResult::Ok => Ok(QueryResult {
396 columns: vec![],
397 rows: vec![],
398 }),
399 }
400 }
401
402 pub fn prepare(&self, sql: &str) -> Result<PreparedStatement<'_, 'a>> {
404 PreparedStatement::new(self, sql)
405 }
406
407 pub fn tables(&self) -> Vec<String> {
409 self.inner
410 .borrow()
411 .schema
412 .table_names()
413 .into_iter()
414 .map(String::from)
415 .collect()
416 }
417
418 pub fn in_transaction(&self) -> bool {
420 self.inner.borrow().active_txn.is_some()
421 }
422
423 pub fn table_schema(&self, name: &str) -> Option<TableSchema> {
425 self.inner.borrow().schema.get(name).cloned()
426 }
427
428 pub fn refresh_schema(&self) -> Result<()> {
430 let new_schema = SchemaManager::load(self.db)?;
431 self.inner.borrow_mut().schema = new_schema;
432 Ok(())
433 }
434}
435
436impl<'a> ConnectionInner<'a> {
437 pub(crate) fn active_txn_is_some(&self) -> bool {
438 self.active_txn.is_some()
439 }
440
441 fn set_session_timezone_impl(&mut self, tz: &str) -> Result<()> {
442 let upper = tz.to_ascii_uppercase();
443 if (upper.starts_with("UTC+") || upper.starts_with("UTC-")) && tz.len() > 3 {
444 return Err(SqlError::InvalidTimezone(format!(
445 "'{tz}' is ambiguous; use ISO-8601 offset (e.g. '+05:00') or named zone (e.g. 'Etc/GMT-5')"
446 )));
447 }
448 if jiff::tz::TimeZone::get(tz).is_err() && parse_fixed_offset(tz).is_none() {
449 return Err(SqlError::InvalidTimezone(format!(
450 "{tz}: not a known IANA zone or ISO-8601 offset (e.g. '+05:00', 'UTC', 'America/New_York')"
451 )));
452 }
453 self.session_timezone = tz.to_string();
454 Ok(())
455 }
456
457 fn execute_impl(&mut self, db: &'a Database, sql: &str) -> Result<ExecutionResult> {
458 if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
459 if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
460 let gen = self.schema.generation();
461 let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
462 if entry.schema_gen == gen {
463 Arc::clone(&entry.stmt)
464 } else {
465 self.parse_and_cache(normalized_key, gen)?
466 }
467 } else {
468 self.parse_and_cache(normalized_key, gen)?
469 };
470 return self.dispatch(db, &stmt, &extracted);
471 }
472 }
473 self.execute_params_impl(db, sql, &[])
474 }
475
476 fn execute_params_impl(
477 &mut self,
478 db: &'a Database,
479 sql: &str,
480 params: &[Value],
481 ) -> Result<ExecutionResult> {
482 let gen = self.schema.generation();
483 if self.active_txn.is_none() {
484 if let Some(entry) = self.stmt_cache.get(sql) {
485 if entry.schema_gen == gen && entry.param_count == params.len() {
486 if let Some(plan) = entry.compiled.as_ref().map(Arc::clone) {
487 let stmt = Arc::clone(&entry.stmt);
488 return self.run_compiled(db, &plan, &stmt, params);
489 }
490 }
491 }
492 }
493
494 let (stmt, param_count) = self.get_or_parse(sql)?;
495
496 if param_count != params.len() {
497 return Err(SqlError::ParameterCountMismatch {
498 expected: param_count,
499 got: params.len(),
500 });
501 }
502
503 if self.active_txn.is_none() {
504 if let Some(plan) = executor::compile(&self.schema, &stmt) {
505 if let Some(e) = self.stmt_cache.get_mut(sql) {
506 e.compiled = Some(Arc::clone(&plan));
507 }
508 let stmt_owned = Arc::clone(&stmt);
509 return self.run_compiled(db, &plan, &stmt_owned, params);
510 }
511 }
512
513 self.dispatch(db, &stmt, params)
514 }
515
516 fn run_compiled(
517 &mut self,
518 db: &'a Database,
519 plan: &Arc<dyn executor::CompiledPlan>,
520 stmt: &Statement,
521 params: &[Value],
522 ) -> Result<ExecutionResult> {
523 let cached_ts = self
524 .txn_start_ts
525 .or_else(|| Some(crate::datetime::now_micros()));
526 let schema = &self.schema;
527 crate::datetime::with_txn_clock(cached_ts, || {
528 if params.is_empty() {
529 plan.execute(db, schema, stmt, params, None)
530 } else {
531 crate::eval::with_scoped_params(params, || {
532 plan.execute(db, schema, stmt, params, None)
533 })
534 }
535 })
536 }
537
538 pub(crate) fn parse_and_cache(
539 &mut self,
540 normalized_key: String,
541 gen: u64,
542 ) -> Result<Arc<Statement>> {
543 let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
544 let param_count = parser::count_params(&stmt);
545 self.stmt_cache.put(
546 normalized_key,
547 CacheEntry {
548 stmt: Arc::clone(&stmt),
549 schema_gen: gen,
550 param_count,
551 compiled: None,
552 },
553 );
554 Ok(stmt)
555 }
556
557 pub(crate) fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
558 let gen = self.schema.generation();
559
560 if let Some(entry) = self.stmt_cache.get(sql) {
561 if entry.schema_gen == gen {
562 return Ok((Arc::clone(&entry.stmt), entry.param_count));
563 }
564 }
565
566 let stmt = Arc::new(parser::parse_sql(sql)?);
567 let param_count = parser::count_params(&stmt);
568
569 let cacheable = !matches!(
570 *stmt,
571 Statement::CreateTable(_)
572 | Statement::DropTable(_)
573 | Statement::CreateIndex(_)
574 | Statement::DropIndex(_)
575 | Statement::CreateView(_)
576 | Statement::DropView(_)
577 | Statement::AlterTable(_)
578 );
579
580 if cacheable {
581 self.stmt_cache.put(
582 sql.to_string(),
583 CacheEntry {
584 stmt: Arc::clone(&stmt),
585 schema_gen: gen,
586 param_count,
587 compiled: None,
588 },
589 );
590 }
591
592 Ok((stmt, param_count))
593 }
594
595 pub(crate) fn execute_prepared(
596 &mut self,
597 db: &'a Database,
598 stmt: &Statement,
599 compiled: Option<&Arc<dyn executor::CompiledPlan>>,
600 params: &[Value],
601 ) -> Result<ExecutionResult> {
602 if let Some(plan) = compiled {
603 if self.active_txn.is_none() {
604 return self.run_compiled(db, plan, stmt, params);
605 }
606 if !self.savepoint_stack.is_empty() && stmt_mutates(stmt) {
607 self.capture_pending_snapshots();
608 }
609 return self.run_compiled_in_txn(db, plan, stmt, params);
610 }
611 self.dispatch(db, stmt, params)
612 }
613
614 fn run_compiled_in_txn(
615 &mut self,
616 db: &'a Database,
617 plan: &Arc<dyn executor::CompiledPlan>,
618 stmt: &Statement,
619 params: &[Value],
620 ) -> Result<ExecutionResult> {
621 let schema = &self.schema;
622 let wtx = self.active_txn.as_mut();
623 if params.is_empty() || !plan.uses_scoped_params() {
624 plan.execute(db, schema, stmt, params, wtx)
625 } else {
626 crate::eval::with_scoped_params(params, || plan.execute(db, schema, stmt, params, wtx))
627 }
628 }
629
630 pub(crate) fn dispatch(
631 &mut self,
632 db: &'a Database,
633 stmt: &Statement,
634 params: &[Value],
635 ) -> Result<ExecutionResult> {
636 let cached_ts = self
637 .txn_start_ts
638 .or_else(|| Some(crate::datetime::now_micros()));
639 crate::datetime::with_txn_clock(cached_ts, || {
640 if params.is_empty() {
641 self.dispatch_inner(db, stmt, params)
642 } else {
643 crate::eval::with_scoped_params(params, || self.dispatch_inner(db, stmt, params))
644 }
645 })
646 }
647
648 fn dispatch_inner(
649 &mut self,
650 db: &'a Database,
651 stmt: &Statement,
652 params: &[Value],
653 ) -> Result<ExecutionResult> {
654 match stmt {
655 Statement::Begin => {
656 if self.active_txn.is_some() {
657 return Err(SqlError::TransactionAlreadyActive);
658 }
659 let wtx = db.begin_write().map_err(SqlError::Storage)?;
660 self.active_txn = Some(wtx);
661 let ts = crate::datetime::txn_or_clock_micros();
662 self.txn_start_ts = Some(ts);
663 crate::datetime::set_txn_clock(Some(ts));
664 Ok(ExecutionResult::Ok)
665 }
666 Statement::Commit => {
667 let wtx = self
668 .active_txn
669 .take()
670 .ok_or(SqlError::NoActiveTransaction)?;
671 wtx.commit().map_err(SqlError::Storage)?;
672 self.clear_savepoint_state();
673 self.txn_start_ts = None;
674 crate::datetime::set_txn_clock(None);
675 Ok(ExecutionResult::Ok)
676 }
677 Statement::Rollback => {
678 let wtx = self
679 .active_txn
680 .take()
681 .ok_or(SqlError::NoActiveTransaction)?;
682 wtx.abort();
683 self.clear_savepoint_state();
684 self.schema = SchemaManager::load(db)?;
685 self.txn_start_ts = None;
686 crate::datetime::set_txn_clock(None);
687 Ok(ExecutionResult::Ok)
688 }
689 Statement::Savepoint(name) => self.do_savepoint(name),
690 Statement::ReleaseSavepoint(name) => self.do_release(name),
691 Statement::RollbackTo(name) => self.do_rollback_to(name),
692 Statement::SetTimezone(zone) => {
693 self.set_session_timezone_impl(zone)?;
694 Ok(ExecutionResult::Ok)
695 }
696 Statement::Insert(ins) if self.active_txn.is_some() => {
697 self.capture_pending_snapshots();
698 let wtx = self.active_txn.as_mut().unwrap();
699 executor::exec_insert_in_txn(wtx, &self.schema, ins, params)
700 }
701 _ => {
702 if self.active_txn.is_some() && stmt_mutates(stmt) {
703 self.capture_pending_snapshots();
704 }
705 if let Some(ref mut wtx) = self.active_txn {
706 executor::execute_in_txn(wtx, &mut self.schema, stmt, params)
707 } else {
708 executor::execute(db, &mut self.schema, stmt, params)
709 }
710 }
711 }
712 }
713
714 fn clear_savepoint_state(&mut self) {
715 self.savepoint_stack.clear();
716 self.in_place_saved = None;
717 }
718
719 fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
720 let wtx = self
721 .active_txn
722 .as_mut()
723 .ok_or(SqlError::NoActiveTransaction)?;
724
725 if self.savepoint_stack.is_empty() {
726 self.in_place_saved = Some(wtx.in_place());
727 wtx.set_in_place(false);
728 }
729
730 self.savepoint_stack.push(SavepointEntry {
731 name: name.to_string(),
732 snapshot: None,
733 });
734
735 Ok(ExecutionResult::Ok)
736 }
737
738 fn capture_pending_snapshots(&mut self) {
739 let last_pending = match self
740 .savepoint_stack
741 .iter()
742 .rposition(|e| e.snapshot.is_none())
743 {
744 Some(i) => i,
745 None => return,
746 };
747 let wtx = match self.active_txn.as_mut() {
748 Some(w) => w,
749 None => return,
750 };
751 let wtx_snap = wtx.begin_savepoint();
752 let schema_snap = self.schema.save_snapshot();
753
754 for i in 0..last_pending {
755 if self.savepoint_stack[i].snapshot.is_none() {
756 self.savepoint_stack[i].snapshot = Some(SavepointSnapshot {
757 wtx_snap: wtx_snap.clone(),
758 schema_snap: schema_snap.clone(),
759 });
760 }
761 }
762 self.savepoint_stack[last_pending].snapshot = Some(SavepointSnapshot {
763 wtx_snap,
764 schema_snap,
765 });
766 }
767
768 fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
769 if self.active_txn.is_none() {
770 return Err(SqlError::NoActiveTransaction);
771 }
772
773 let idx = self
774 .savepoint_stack
775 .iter()
776 .rposition(|e| e.name == name)
777 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
778 self.savepoint_stack.truncate(idx);
779
780 if self.savepoint_stack.is_empty() {
781 if let (Some(wtx), Some(original)) =
782 (self.active_txn.as_mut(), self.in_place_saved.take())
783 {
784 wtx.set_in_place(original);
785 }
786 }
787
788 Ok(ExecutionResult::Ok)
789 }
790
791 fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
792 if self.active_txn.is_none() {
793 return Err(SqlError::NoActiveTransaction);
794 }
795
796 let idx = self
797 .savepoint_stack
798 .iter()
799 .rposition(|e| e.name == name)
800 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
801
802 self.savepoint_stack.truncate(idx + 1);
803 let entry = self.savepoint_stack.last_mut().unwrap();
804 let snapshot = match entry.snapshot.take() {
805 Some(s) => s,
806 None => return Ok(ExecutionResult::Ok),
807 };
808
809 let wtx = self.active_txn.as_mut().unwrap();
810 wtx.restore_snapshot(snapshot.wtx_snap);
811 self.schema.restore_snapshot(snapshot.schema_snap);
812
813 Ok(ExecutionResult::Ok)
814 }
815}