1use std::cell::RefCell;
4use std::num::NonZeroUsize;
5use std::sync::Arc;
6
7use lru::LruCache;
8
9use citadel::Database;
10use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
11
12use crate::error::{Result, SqlError};
13use crate::executor;
14use crate::parser;
15use crate::parser::Statement;
16use crate::prepared::PreparedStatement;
17use crate::schema::{SchemaManager, SchemaSnapshot};
18use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
19
/// Number of entries in each connection's LRU statement cache (`ConnectionInner::stmt_cache`).
const DEFAULT_CACHE_CAPACITY: usize = 64;
21
/// Checks whether `s` is a fixed UTC offset this connection accepts as a session time zone.
///
/// Accepted forms, after trimming surrounding whitespace:
/// - `Z` / `UTC` (case-insensitive);
/// - a `+` or `-` sign followed by `HH:MM`, `HHMM`, or `HH`, with hours `0..=23` and
///   minutes `0..=59`.
///
/// Each hour/minute component must consist solely of ASCII digits. That rejects inputs like
/// `++5:00` (which `u32::from_str` would otherwise accept). It also rules out non-ASCII text,
/// so the fixed-width `HHMM` split can never land inside a multibyte character and panic.
///
/// Returns `Some(())` when `s` is accepted, `None` otherwise.
fn parse_fixed_offset(s: &str) -> Option<()> {
    let s = s.trim();
    if s.eq_ignore_ascii_case("z") || s.eq_ignore_ascii_case("utc") {
        return Some(());
    }

    // The sign only distinguishes east/west; validation does not depend on it.
    let rest = s.strip_prefix('+').or_else(|| s.strip_prefix('-'))?;

    let (hh, mm) = match rest.split_once(':') {
        Some(parts) => parts,
        // `is_ascii` guarantees byte offset 2 is a char boundary, so `split_at` cannot panic.
        None if rest.len() == 4 && rest.is_ascii() => rest.split_at(2),
        None if rest.len() == 2 => (rest, "00"),
        None => return None,
    };

    let is_digits = |part: &str| !part.is_empty() && part.bytes().all(|b| b.is_ascii_digit());
    if !is_digits(hh) || !is_digits(mm) {
        return None;
    }

    let h: u32 = hh.parse().ok()?;
    let m: u32 = mm.parse().ok()?;
    if h > 23 || m > 59 {
        return None;
    }
    Some(())
}
53
54fn stmt_mutates(stmt: &Statement) -> bool {
55 matches!(
56 stmt,
57 Statement::Insert(_)
58 | Statement::Update(_)
59 | Statement::Delete(_)
60 | Statement::CreateTable(_)
61 | Statement::DropTable(_)
62 | Statement::AlterTable(_)
63 | Statement::CreateIndex(_)
64 | Statement::DropIndex(_)
65 | Statement::CreateView(_)
66 | Statement::DropView(_)
67 )
68}
69
70fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
71 let bytes = sql.as_bytes();
72 let len = bytes.len();
73 let mut i = 0;
74
75 while i < len && bytes[i].is_ascii_whitespace() {
76 i += 1;
77 }
78 if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
79 return None;
80 }
81 i += 6;
82 if i >= len || !bytes[i].is_ascii_whitespace() {
83 return None;
84 }
85 while i < len && bytes[i].is_ascii_whitespace() {
86 i += 1;
87 }
88
89 if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
90 return None;
91 }
92 i += 4;
93 if i >= len || !bytes[i].is_ascii_whitespace() {
94 return None;
95 }
96
97 let prefix_start = 0;
98 let mut values_pos = None;
99 let mut j = i;
100 while j + 6 <= len {
101 if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
102 && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
103 && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
104 {
105 values_pos = Some(j);
106 break;
107 }
108 j += 1;
109 }
110 let values_pos = values_pos?;
111
112 let prefix = &sql[prefix_start..values_pos + 6];
113 let mut pos = values_pos + 6;
114
115 while pos < len && bytes[pos].is_ascii_whitespace() {
116 pos += 1;
117 }
118 if pos >= len || bytes[pos] != b'(' {
119 return None;
120 }
121 pos += 1;
122
123 let mut values = Vec::new();
124 let mut normalized = String::with_capacity(sql.len());
125 normalized.push_str(prefix);
126 normalized.push_str(" (");
127
128 loop {
129 while pos < len && bytes[pos].is_ascii_whitespace() {
130 pos += 1;
131 }
132 if pos >= len {
133 return None;
134 }
135
136 let param_idx = values.len() + 1;
137 if param_idx > 1 {
138 normalized.push_str(", ");
139 }
140
141 if bytes[pos] == b'\'' {
142 pos += 1;
143 let mut seg_start = pos;
144 let mut s = String::new();
145 loop {
146 if pos >= len {
147 return None;
148 }
149 if bytes[pos] == b'\'' {
150 s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
151 pos += 1;
152 if pos < len && bytes[pos] == b'\'' {
153 s.push('\'');
154 pos += 1;
155 seg_start = pos;
156 } else {
157 break;
158 }
159 } else {
160 pos += 1;
161 }
162 }
163 values.push(Value::Text(s.into()));
164 } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
165 let start = pos;
166 if bytes[pos] == b'-' {
167 pos += 1;
168 }
169 while pos < len && bytes[pos].is_ascii_digit() {
170 pos += 1;
171 }
172 if pos < len && bytes[pos] == b'.' {
173 pos += 1;
174 while pos < len && bytes[pos].is_ascii_digit() {
175 pos += 1;
176 }
177 let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
178 values.push(Value::Real(num));
179 } else {
180 let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
181 values.push(Value::Integer(num));
182 }
183 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
184 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
185 if !after.is_ascii_alphanumeric() && after != b'_' {
186 pos += 4;
187 values.push(Value::Null);
188 } else {
189 return None;
190 }
191 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
192 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
193 if !after.is_ascii_alphanumeric() && after != b'_' {
194 pos += 4;
195 values.push(Value::Boolean(true));
196 } else {
197 return None;
198 }
199 } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
200 let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
201 if !after.is_ascii_alphanumeric() && after != b'_' {
202 pos += 5;
203 values.push(Value::Boolean(false));
204 } else {
205 return None;
206 }
207 } else {
208 return None;
209 }
210
211 normalized.push('$');
212 normalized.push_str(¶m_idx.to_string());
213
214 while pos < len && bytes[pos].is_ascii_whitespace() {
215 pos += 1;
216 }
217 if pos >= len {
218 return None;
219 }
220
221 if bytes[pos] == b',' {
222 pos += 1;
223 } else if bytes[pos] == b')' {
224 pos += 1;
225 break;
226 } else {
227 return None;
228 }
229 }
230
231 normalized.push(')');
232
233 while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
234 pos += 1;
235 }
236 if pos != len {
237 return None;
238 }
239
240 if values.is_empty() {
241 return None;
242 }
243
244 Some((normalized, values))
245}
246
/// A parsed statement stored in `ConnectionInner::stmt_cache`.
///
/// The key is either the raw SQL text or the normalized text from `try_normalize_insert`.
pub(crate) struct CacheEntry {
    /// Parsed statement; reference-counted so callers can keep using it after eviction.
    pub(crate) stmt: Arc<Statement>,
    /// `SchemaManager::generation()` at parse time. An entry whose generation no longer
    /// matches the current schema is re-parsed instead of being used.
    pub(crate) schema_gen: u64,
    /// Number of placeholders, as reported by `parser::count_params`.
    pub(crate) param_count: usize,
    /// Plan from `executor::compile`. It is filled in lazily, the first time the statement
    /// runs outside an explicit transaction.
    pub(crate) compiled: Option<Arc<dyn executor::CompiledPlan>>,
}
253
/// One open `SAVEPOINT` on the connection's savepoint stack.
struct SavepointEntry {
    /// Savepoint name; `RELEASE` / `ROLLBACK TO` match it with exact string equality.
    name: String,
    /// State captured lazily, just before the first mutating statement after this savepoint
    /// (or after the last `ROLLBACK TO` it). `None` means nothing has been modified since.
    snapshot: Option<SavepointSnapshot>,
}
258
/// Storage and schema state restored by `ROLLBACK TO`.
struct SavepointSnapshot {
    /// Write-transaction state obtained from `WriteTxn::begin_savepoint`.
    wtx_snap: WriteTxnSnapshot,
    /// In-memory schema state obtained from `SchemaManager::save_snapshot`.
    schema_snap: SchemaSnapshot,
}
263
/// Mutable per-connection state, kept behind the `RefCell` in `Connection`.
pub(crate) struct ConnectionInner<'a> {
    /// In-memory schema; reloaded on `ROLLBACK` and restored on `ROLLBACK TO`.
    pub(crate) schema: SchemaManager,
    /// Write transaction opened by `BEGIN`; `None` in autocommit mode.
    active_txn: Option<WriteTxn<'a>>,
    /// Open savepoints, innermost last.
    savepoint_stack: Vec<SavepointEntry>,
    /// The transaction's `in_place` flag before the first savepoint turned it off.
    /// It is put back once every savepoint has been released.
    in_place_saved: Option<bool>,
    /// LRU cache of parsed (and possibly compiled) statements; see `CacheEntry`.
    pub(crate) stmt_cache: LruCache<String, CacheEntry>,
    /// Timestamp fixed at `BEGIN`. It is passed to `crate::datetime::with_txn_clock` for
    /// every statement of that transaction.
    txn_start_ts: Option<i64>,
    /// Session time zone, as validated by `set_session_timezone_impl`.
    session_timezone: String,
}
273
/// A SQL session over a `citadel::Database`.
///
/// Session state lives in a `RefCell`, so a `Connection` cannot be shared between threads.
/// A re-entrant call made while another call on the same connection holds the borrow panics.
pub struct Connection<'a> {
    /// Database this connection executes against.
    pub(crate) db: &'a Database,
    /// Mutable session state (schema, transaction, savepoints, statement cache).
    pub(crate) inner: RefCell<ConnectionInner<'a>>,
}
279
280impl<'a> Connection<'a> {
281 pub fn open(db: &'a Database) -> Result<Self> {
283 let schema = SchemaManager::load(db)?;
284 let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
285 Ok(Self {
286 db,
287 inner: RefCell::new(ConnectionInner {
288 schema,
289 active_txn: None,
290 savepoint_stack: Vec::new(),
291 in_place_saved: None,
292 stmt_cache,
293 txn_start_ts: None,
294 session_timezone: "UTC".to_string(),
295 }),
296 })
297 }
298
299 pub fn txn_start_ts(&self) -> Option<i64> {
301 self.inner.borrow().txn_start_ts
302 }
303
304 pub fn session_timezone(&self) -> String {
306 self.inner.borrow().session_timezone.clone()
307 }
308
309 pub fn set_session_timezone(&self, tz: &str) -> Result<()> {
311 self.inner.borrow_mut().set_session_timezone_impl(tz)
312 }
313
314 pub fn execute(&self, sql: &str) -> Result<ExecutionResult> {
316 self.inner.borrow_mut().execute_impl(self.db, sql)
317 }
318
319 pub fn execute_params(&self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
321 self.inner
322 .borrow_mut()
323 .execute_params_impl(self.db, sql, params)
324 }
325
326 pub fn query(&self, sql: &str) -> Result<QueryResult> {
328 self.query_params(sql, &[])
329 }
330
331 pub fn query_params(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
333 match self.execute_params(sql, params)? {
334 ExecutionResult::Query(qr) => Ok(qr),
335 ExecutionResult::RowsAffected(n) => Ok(QueryResult {
336 columns: vec!["rows_affected".into()],
337 rows: vec![vec![Value::Integer(n as i64)]],
338 }),
339 ExecutionResult::Ok => Ok(QueryResult {
340 columns: vec![],
341 rows: vec![],
342 }),
343 }
344 }
345
346 pub fn prepare(&self, sql: &str) -> Result<PreparedStatement<'_, 'a>> {
348 PreparedStatement::new(self, sql)
349 }
350
351 pub fn tables(&self) -> Vec<String> {
353 self.inner
354 .borrow()
355 .schema
356 .table_names()
357 .into_iter()
358 .map(String::from)
359 .collect()
360 }
361
362 pub fn in_transaction(&self) -> bool {
364 self.inner.borrow().active_txn.is_some()
365 }
366
367 pub fn table_schema(&self, name: &str) -> Option<TableSchema> {
369 self.inner.borrow().schema.get(name).cloned()
370 }
371
372 pub fn refresh_schema(&self) -> Result<()> {
374 let new_schema = SchemaManager::load(self.db)?;
375 self.inner.borrow_mut().schema = new_schema;
376 Ok(())
377 }
378}
379
380impl<'a> ConnectionInner<'a> {
381 pub(crate) fn active_txn_is_some(&self) -> bool {
382 self.active_txn.is_some()
383 }
384
385 fn set_session_timezone_impl(&mut self, tz: &str) -> Result<()> {
386 let upper = tz.to_ascii_uppercase();
387 if (upper.starts_with("UTC+") || upper.starts_with("UTC-")) && tz.len() > 3 {
388 return Err(SqlError::InvalidTimezone(format!(
389 "'{tz}' is ambiguous; use ISO-8601 offset (e.g. '+05:00') or named zone (e.g. 'Etc/GMT-5')"
390 )));
391 }
392 if jiff::tz::TimeZone::get(tz).is_err() && parse_fixed_offset(tz).is_none() {
393 return Err(SqlError::InvalidTimezone(format!(
394 "{tz}: not a known IANA zone or ISO-8601 offset (e.g. '+05:00', 'UTC', 'America/New_York')"
395 )));
396 }
397 self.session_timezone = tz.to_string();
398 Ok(())
399 }
400
401 fn execute_impl(&mut self, db: &'a Database, sql: &str) -> Result<ExecutionResult> {
402 if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
403 if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
404 let gen = self.schema.generation();
405 let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
406 if entry.schema_gen == gen {
407 Arc::clone(&entry.stmt)
408 } else {
409 self.parse_and_cache(normalized_key, gen)?
410 }
411 } else {
412 self.parse_and_cache(normalized_key, gen)?
413 };
414 return self.dispatch(db, &stmt, &extracted);
415 }
416 }
417 self.execute_params_impl(db, sql, &[])
418 }
419
420 fn execute_params_impl(
421 &mut self,
422 db: &'a Database,
423 sql: &str,
424 params: &[Value],
425 ) -> Result<ExecutionResult> {
426 let gen = self.schema.generation();
427 if self.active_txn.is_none() {
428 if let Some(entry) = self.stmt_cache.get(sql) {
429 if entry.schema_gen == gen && entry.param_count == params.len() {
430 if let Some(plan) = entry.compiled.as_ref().map(Arc::clone) {
431 let stmt = Arc::clone(&entry.stmt);
432 return self.run_compiled(db, &plan, &stmt, params);
433 }
434 }
435 }
436 }
437
438 let (stmt, param_count) = self.get_or_parse(sql)?;
439
440 if param_count != params.len() {
441 return Err(SqlError::ParameterCountMismatch {
442 expected: param_count,
443 got: params.len(),
444 });
445 }
446
447 if self.active_txn.is_none() {
448 if let Some(plan) = executor::compile(&self.schema, &stmt) {
449 if let Some(e) = self.stmt_cache.get_mut(sql) {
450 e.compiled = Some(Arc::clone(&plan));
451 }
452 let stmt_owned = Arc::clone(&stmt);
453 return self.run_compiled(db, &plan, &stmt_owned, params);
454 }
455 }
456
457 self.dispatch(db, &stmt, params)
458 }
459
460 fn run_compiled(
461 &mut self,
462 db: &'a Database,
463 plan: &Arc<dyn executor::CompiledPlan>,
464 stmt: &Statement,
465 params: &[Value],
466 ) -> Result<ExecutionResult> {
467 let cached_ts = self
468 .txn_start_ts
469 .or_else(|| Some(crate::datetime::now_micros()));
470 let schema = &self.schema;
471 crate::datetime::with_txn_clock(cached_ts, || {
472 if params.is_empty() {
473 plan.execute(db, schema, stmt, params, None)
474 } else {
475 crate::eval::with_scoped_params(params, || {
476 plan.execute(db, schema, stmt, params, None)
477 })
478 }
479 })
480 }
481
482 pub(crate) fn parse_and_cache(
483 &mut self,
484 normalized_key: String,
485 gen: u64,
486 ) -> Result<Arc<Statement>> {
487 let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
488 let param_count = parser::count_params(&stmt);
489 self.stmt_cache.put(
490 normalized_key,
491 CacheEntry {
492 stmt: Arc::clone(&stmt),
493 schema_gen: gen,
494 param_count,
495 compiled: None,
496 },
497 );
498 Ok(stmt)
499 }
500
501 pub(crate) fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
502 let gen = self.schema.generation();
503
504 if let Some(entry) = self.stmt_cache.get(sql) {
505 if entry.schema_gen == gen {
506 return Ok((Arc::clone(&entry.stmt), entry.param_count));
507 }
508 }
509
510 let stmt = Arc::new(parser::parse_sql(sql)?);
511 let param_count = parser::count_params(&stmt);
512
513 let cacheable = !matches!(
514 *stmt,
515 Statement::CreateTable(_)
516 | Statement::DropTable(_)
517 | Statement::CreateIndex(_)
518 | Statement::DropIndex(_)
519 | Statement::CreateView(_)
520 | Statement::DropView(_)
521 | Statement::AlterTable(_)
522 );
523
524 if cacheable {
525 self.stmt_cache.put(
526 sql.to_string(),
527 CacheEntry {
528 stmt: Arc::clone(&stmt),
529 schema_gen: gen,
530 param_count,
531 compiled: None,
532 },
533 );
534 }
535
536 Ok((stmt, param_count))
537 }
538
539 pub(crate) fn execute_prepared(
540 &mut self,
541 db: &'a Database,
542 stmt: &Statement,
543 compiled: Option<&Arc<dyn executor::CompiledPlan>>,
544 params: &[Value],
545 ) -> Result<ExecutionResult> {
546 if let Some(plan) = compiled {
547 if self.active_txn.is_none() {
548 return self.run_compiled(db, plan, stmt, params);
549 }
550 if stmt_mutates(stmt) {
551 self.capture_pending_snapshots();
552 }
553 return self.run_compiled_in_txn(db, plan, stmt, params);
554 }
555 self.dispatch(db, stmt, params)
556 }
557
558 fn run_compiled_in_txn(
559 &mut self,
560 db: &'a Database,
561 plan: &Arc<dyn executor::CompiledPlan>,
562 stmt: &Statement,
563 params: &[Value],
564 ) -> Result<ExecutionResult> {
565 let cached_ts = self
566 .txn_start_ts
567 .or_else(|| Some(crate::datetime::now_micros()));
568 let schema = &self.schema;
569 let wtx = self.active_txn.as_mut();
570 crate::datetime::with_txn_clock(cached_ts, || {
571 if params.is_empty() {
572 plan.execute(db, schema, stmt, params, wtx)
573 } else {
574 crate::eval::with_scoped_params(params, || {
575 plan.execute(db, schema, stmt, params, wtx)
576 })
577 }
578 })
579 }
580
581 pub(crate) fn dispatch(
582 &mut self,
583 db: &'a Database,
584 stmt: &Statement,
585 params: &[Value],
586 ) -> Result<ExecutionResult> {
587 let cached_ts = self
588 .txn_start_ts
589 .or_else(|| Some(crate::datetime::now_micros()));
590 crate::datetime::with_txn_clock(cached_ts, || {
591 if params.is_empty() {
592 self.dispatch_inner(db, stmt, params)
593 } else {
594 crate::eval::with_scoped_params(params, || self.dispatch_inner(db, stmt, params))
595 }
596 })
597 }
598
599 fn dispatch_inner(
600 &mut self,
601 db: &'a Database,
602 stmt: &Statement,
603 params: &[Value],
604 ) -> Result<ExecutionResult> {
605 match stmt {
606 Statement::Begin => {
607 if self.active_txn.is_some() {
608 return Err(SqlError::TransactionAlreadyActive);
609 }
610 let wtx = db.begin_write().map_err(SqlError::Storage)?;
611 self.active_txn = Some(wtx);
612 self.txn_start_ts = Some(crate::datetime::txn_or_clock_micros());
613 Ok(ExecutionResult::Ok)
614 }
615 Statement::Commit => {
616 let wtx = self
617 .active_txn
618 .take()
619 .ok_or(SqlError::NoActiveTransaction)?;
620 wtx.commit().map_err(SqlError::Storage)?;
621 self.clear_savepoint_state();
622 self.txn_start_ts = None;
623 Ok(ExecutionResult::Ok)
624 }
625 Statement::Rollback => {
626 let wtx = self
627 .active_txn
628 .take()
629 .ok_or(SqlError::NoActiveTransaction)?;
630 wtx.abort();
631 self.clear_savepoint_state();
632 self.schema = SchemaManager::load(db)?;
633 self.txn_start_ts = None;
634 Ok(ExecutionResult::Ok)
635 }
636 Statement::Savepoint(name) => self.do_savepoint(name),
637 Statement::ReleaseSavepoint(name) => self.do_release(name),
638 Statement::RollbackTo(name) => self.do_rollback_to(name),
639 Statement::SetTimezone(zone) => {
640 self.set_session_timezone_impl(zone)?;
641 Ok(ExecutionResult::Ok)
642 }
643 Statement::Insert(ins) if self.active_txn.is_some() => {
644 self.capture_pending_snapshots();
645 let wtx = self.active_txn.as_mut().unwrap();
646 executor::exec_insert_in_txn(wtx, &self.schema, ins, params)
647 }
648 _ => {
649 if self.active_txn.is_some() && stmt_mutates(stmt) {
650 self.capture_pending_snapshots();
651 }
652 if let Some(ref mut wtx) = self.active_txn {
653 executor::execute_in_txn(wtx, &mut self.schema, stmt, params)
654 } else {
655 executor::execute(db, &mut self.schema, stmt, params)
656 }
657 }
658 }
659 }
660
661 fn clear_savepoint_state(&mut self) {
662 self.savepoint_stack.clear();
663 self.in_place_saved = None;
664 }
665
666 fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
667 let wtx = self
668 .active_txn
669 .as_mut()
670 .ok_or(SqlError::NoActiveTransaction)?;
671
672 if self.savepoint_stack.is_empty() {
673 self.in_place_saved = Some(wtx.in_place());
674 wtx.set_in_place(false);
675 }
676
677 self.savepoint_stack.push(SavepointEntry {
678 name: name.to_string(),
679 snapshot: None,
680 });
681
682 Ok(ExecutionResult::Ok)
683 }
684
685 fn capture_pending_snapshots(&mut self) {
686 if !self.savepoint_stack.iter().any(|e| e.snapshot.is_none()) {
687 return;
688 }
689 let wtx = match self.active_txn.as_mut() {
690 Some(w) => w,
691 None => return,
692 };
693 let wtx_snap = wtx.begin_savepoint();
694 let schema_snap = self.schema.save_snapshot();
695 let mut pending = self
696 .savepoint_stack
697 .iter_mut()
698 .filter(|e| e.snapshot.is_none());
699 if let Some(first) = pending.next() {
700 first.snapshot = Some(SavepointSnapshot {
701 wtx_snap: wtx_snap.clone(),
702 schema_snap: schema_snap.clone(),
703 });
704 }
705 for entry in pending {
706 entry.snapshot = Some(SavepointSnapshot {
707 wtx_snap: wtx_snap.clone(),
708 schema_snap: schema_snap.clone(),
709 });
710 }
711 }
712
713 fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
714 if self.active_txn.is_none() {
715 return Err(SqlError::NoActiveTransaction);
716 }
717
718 let idx = self
719 .savepoint_stack
720 .iter()
721 .rposition(|e| e.name == name)
722 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
723 self.savepoint_stack.truncate(idx);
724
725 if self.savepoint_stack.is_empty() {
726 if let (Some(wtx), Some(original)) =
727 (self.active_txn.as_mut(), self.in_place_saved.take())
728 {
729 wtx.set_in_place(original);
730 }
731 }
732
733 Ok(ExecutionResult::Ok)
734 }
735
736 fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
737 if self.active_txn.is_none() {
738 return Err(SqlError::NoActiveTransaction);
739 }
740
741 let idx = self
742 .savepoint_stack
743 .iter()
744 .rposition(|e| e.name == name)
745 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
746
747 self.savepoint_stack.truncate(idx + 1);
748 let entry = self.savepoint_stack.last_mut().unwrap();
749 let snapshot = match entry.snapshot.take() {
750 Some(s) => s,
751 None => return Ok(ExecutionResult::Ok),
752 };
753
754 let wtx = self.active_txn.as_mut().unwrap();
755 wtx.restore_snapshot(snapshot.wtx_snap);
756 self.schema.restore_snapshot(snapshot.schema_snap);
757
758 self.stmt_cache.clear();
759
760 Ok(ExecutionResult::Ok)
761 }
762}