1use std::num::NonZeroUsize;
4use std::sync::Arc;
5
6use lru::LruCache;
7
8use citadel::Database;
9use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
10
11use crate::error::{Result, SqlError};
12use crate::executor;
13use crate::parser;
14use crate::parser::{InsertSource, Statement};
15use crate::schema::{SchemaManager, SchemaSnapshot};
16use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
17
/// Maximum number of parsed statements kept in each connection's LRU statement cache.
const DEFAULT_CACHE_CAPACITY: usize = 64;
19
20fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
21 let bytes = sql.as_bytes();
22 let len = bytes.len();
23 let mut i = 0;
24
25 while i < len && bytes[i].is_ascii_whitespace() {
26 i += 1;
27 }
28 if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
29 return None;
30 }
31 i += 6;
32 if i >= len || !bytes[i].is_ascii_whitespace() {
33 return None;
34 }
35 while i < len && bytes[i].is_ascii_whitespace() {
36 i += 1;
37 }
38
39 if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
40 return None;
41 }
42 i += 4;
43 if i >= len || !bytes[i].is_ascii_whitespace() {
44 return None;
45 }
46
47 let prefix_start = 0;
48 let mut values_pos = None;
49 let mut j = i;
50 while j + 6 <= len {
51 if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
52 && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
53 && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
54 {
55 values_pos = Some(j);
56 break;
57 }
58 j += 1;
59 }
60 let values_pos = values_pos?;
61
62 let prefix = &sql[prefix_start..values_pos + 6];
63 let mut pos = values_pos + 6;
64
65 while pos < len && bytes[pos].is_ascii_whitespace() {
66 pos += 1;
67 }
68 if pos >= len || bytes[pos] != b'(' {
69 return None;
70 }
71 pos += 1;
72
73 let mut values = Vec::new();
74 let mut normalized = String::with_capacity(sql.len());
75 normalized.push_str(prefix);
76 normalized.push_str(" (");
77
78 loop {
79 while pos < len && bytes[pos].is_ascii_whitespace() {
80 pos += 1;
81 }
82 if pos >= len {
83 return None;
84 }
85
86 let param_idx = values.len() + 1;
87 if param_idx > 1 {
88 normalized.push_str(", ");
89 }
90
91 if bytes[pos] == b'\'' {
92 pos += 1;
93 let mut seg_start = pos;
94 let mut s = String::new();
95 loop {
96 if pos >= len {
97 return None;
98 }
99 if bytes[pos] == b'\'' {
100 s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
101 pos += 1;
102 if pos < len && bytes[pos] == b'\'' {
103 s.push('\'');
104 pos += 1;
105 seg_start = pos;
106 } else {
107 break;
108 }
109 } else {
110 pos += 1;
111 }
112 }
113 values.push(Value::Text(s.into()));
114 } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
115 let start = pos;
116 if bytes[pos] == b'-' {
117 pos += 1;
118 }
119 while pos < len && bytes[pos].is_ascii_digit() {
120 pos += 1;
121 }
122 if pos < len && bytes[pos] == b'.' {
123 pos += 1;
124 while pos < len && bytes[pos].is_ascii_digit() {
125 pos += 1;
126 }
127 let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
128 values.push(Value::Real(num));
129 } else {
130 let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
131 values.push(Value::Integer(num));
132 }
133 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
134 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
135 if !after.is_ascii_alphanumeric() && after != b'_' {
136 pos += 4;
137 values.push(Value::Null);
138 } else {
139 return None;
140 }
141 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
142 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
143 if !after.is_ascii_alphanumeric() && after != b'_' {
144 pos += 4;
145 values.push(Value::Boolean(true));
146 } else {
147 return None;
148 }
149 } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
150 let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
151 if !after.is_ascii_alphanumeric() && after != b'_' {
152 pos += 5;
153 values.push(Value::Boolean(false));
154 } else {
155 return None;
156 }
157 } else {
158 return None;
159 }
160
161 normalized.push('$');
162 normalized.push_str(¶m_idx.to_string());
163
164 while pos < len && bytes[pos].is_ascii_whitespace() {
165 pos += 1;
166 }
167 if pos >= len {
168 return None;
169 }
170
171 if bytes[pos] == b',' {
172 pos += 1;
173 } else if bytes[pos] == b')' {
174 pos += 1;
175 break;
176 } else {
177 return None;
178 }
179 }
180
181 normalized.push(')');
182
183 while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
184 pos += 1;
185 }
186 if pos != len {
187 return None;
188 }
189
190 if values.is_empty() {
191 return None;
192 }
193
194 Some((normalized, values))
195}
196
/// A parsed statement held in a connection's LRU statement cache.
///
/// Keys are either the caller's SQL text or, for literal INSERTs, the normalized text
/// produced by `try_normalize_insert`.
struct CacheEntry {
    /// Parsed statement, shared by `Arc` with whoever is currently executing it.
    stmt: Arc<Statement>,
    /// Schema generation at parse time; the entry is only reused while the schema's
    /// current generation equals this value.
    schema_gen: u64,
    /// Parameter count of `stmt` as reported by `parser::count_params`.
    param_count: usize,
    /// Compiled plan for a parameterless UPDATE, filled in after its first execution
    /// outside an explicit transaction; `None` until then and for all other statements.
    compiled_update: Option<executor::CompiledUpdate>,
}
203
/// State captured by `SAVEPOINT <name>` so that `ROLLBACK TO <name>` can restore it.
struct SavepointEntry {
    /// Savepoint name as written in the statement; lookups compare it exactly and
    /// search from the most recent savepoint backwards.
    name: String,
    /// Snapshot returned by `WriteTxn::begin_savepoint`.
    wtx_snap: WriteTxnSnapshot,
    /// Snapshot returned by `SchemaManager::save_snapshot`, taken together with `wtx_snap`.
    schema_snap: SchemaSnapshot,
}
209
/// A SQL session on a [`Database`].
///
/// Tracks the in-memory schema, the explicit write transaction opened by `BEGIN` (if any),
/// that transaction's savepoints, and an LRU cache of parsed statements.
pub struct Connection<'a> {
    /// Database that statements run against and that write transactions are opened on.
    db: &'a Database,
    /// In-memory schema; its generation number stamps statement-cache entries.
    schema: SchemaManager,
    /// Transaction opened by `BEGIN`; `None` means each statement runs on its own.
    active_txn: Option<WriteTxn<'a>>,
    /// Savepoints of the active transaction, oldest first.
    savepoint_stack: Vec<SavepointEntry>,
    /// The transaction's `in_place` setting from before the first savepoint switched it
    /// off; restored when the last savepoint is released.
    in_place_saved: Option<bool>,
    /// Parsed statements keyed by SQL text (or by normalized INSERT text).
    stmt_cache: LruCache<String, CacheEntry>,
    /// Buffers handed to `executor::exec_insert_in_txn` on every in-transaction INSERT
    /// (presumably reusable scratch space — see the executor).
    insert_bufs: executor::InsertBufs,
    /// Buffers handed to `executor::exec_update_compiled` on every compiled UPDATE.
    update_bufs: executor::UpdateBufs,
}
222
223impl<'a> Connection<'a> {
224 pub fn open(db: &'a Database) -> Result<Self> {
226 let schema = SchemaManager::load(db)?;
227 let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
228 Ok(Self {
229 db,
230 schema,
231 active_txn: None,
232 savepoint_stack: Vec::new(),
233 in_place_saved: None,
234 stmt_cache,
235 insert_bufs: executor::InsertBufs::new(),
236 update_bufs: executor::UpdateBufs::new(),
237 })
238 }
239
240 pub fn execute(&mut self, sql: &str) -> Result<ExecutionResult> {
242 if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
243 if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
244 let gen = self.schema.generation();
245 let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
246 if entry.schema_gen == gen {
247 Arc::clone(&entry.stmt)
248 } else {
249 self.parse_and_cache(normalized_key, gen)?
250 }
251 } else {
252 self.parse_and_cache(normalized_key, gen)?
253 };
254 return self.dispatch(&stmt, &extracted);
255 }
256 }
257
258 self.execute_params(sql, &[])
259 }
260
261 pub fn execute_params(&mut self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
263 if params.is_empty() && self.active_txn.is_none() {
264 let gen = self.schema.generation();
265 if let Some(entry) = self.stmt_cache.get(sql) {
266 if entry.schema_gen == gen && entry.param_count == 0 {
267 if let Statement::Update(ref upd) = *entry.stmt {
268 if let Some(ref compiled) = entry.compiled_update {
269 return executor::exec_update_compiled(
270 self.db,
271 &self.schema,
272 upd,
273 compiled,
274 &mut self.update_bufs,
275 );
276 }
277 let compiled = executor::compile_update(&self.schema, upd)?;
278 let result = executor::exec_update_compiled(
279 self.db,
280 &self.schema,
281 upd,
282 &compiled,
283 &mut self.update_bufs,
284 )?;
285 if let Some(e) = self.stmt_cache.get_mut(sql) {
286 e.compiled_update = Some(compiled);
287 }
288 return Ok(result);
289 }
290 }
291 }
292 }
293
294 let (stmt, param_count) = self.get_or_parse(sql)?;
295
296 if param_count != params.len() {
297 return Err(SqlError::ParameterCountMismatch {
298 expected: param_count,
299 got: params.len(),
300 });
301 }
302
303 if param_count == 0 && self.active_txn.is_none() {
304 if let Statement::Update(ref upd) = *stmt {
305 let compiled = executor::compile_update(&self.schema, upd)?;
306 let result = executor::exec_update_compiled(
307 self.db,
308 &self.schema,
309 upd,
310 &compiled,
311 &mut self.update_bufs,
312 )?;
313 if let Some(e) = self.stmt_cache.get_mut(sql) {
314 e.compiled_update = Some(compiled);
315 }
316 return Ok(result);
317 }
318 }
319
320 if param_count > 0
321 && matches!(*stmt, Statement::Insert(ref ins) if matches!(ins.source, InsertSource::Values(_)))
322 {
323 self.dispatch(&stmt, params)
324 } else if param_count > 0 {
325 let bound = parser::bind_params(&stmt, params)?;
326 self.dispatch(&bound, &[])
327 } else {
328 self.dispatch(&stmt, &[])
329 }
330 }
331
332 pub fn query(&mut self, sql: &str) -> Result<QueryResult> {
334 self.query_params(sql, &[])
335 }
336
337 pub fn query_params(&mut self, sql: &str, params: &[Value]) -> Result<QueryResult> {
339 match self.execute_params(sql, params)? {
340 ExecutionResult::Query(qr) => Ok(qr),
341 ExecutionResult::RowsAffected(n) => Ok(QueryResult {
342 columns: vec!["rows_affected".into()],
343 rows: vec![vec![Value::Integer(n as i64)]],
344 }),
345 ExecutionResult::Ok => Ok(QueryResult {
346 columns: vec![],
347 rows: vec![],
348 }),
349 }
350 }
351
352 pub fn tables(&self) -> Vec<&str> {
354 self.schema.table_names()
355 }
356
357 pub fn in_transaction(&self) -> bool {
359 self.active_txn.is_some()
360 }
361
362 pub fn table_schema(&self, name: &str) -> Option<&TableSchema> {
364 self.schema.get(name)
365 }
366
367 pub fn refresh_schema(&mut self) -> Result<()> {
369 self.schema = SchemaManager::load(self.db)?;
370 Ok(())
371 }
372
373 fn parse_and_cache(&mut self, normalized_key: String, gen: u64) -> Result<Arc<Statement>> {
374 let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
375 let param_count = parser::count_params(&stmt);
376 self.stmt_cache.put(
377 normalized_key,
378 CacheEntry {
379 stmt: Arc::clone(&stmt),
380 schema_gen: gen,
381 param_count,
382 compiled_update: None,
383 },
384 );
385 Ok(stmt)
386 }
387
388 fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
389 let gen = self.schema.generation();
390
391 if let Some(entry) = self.stmt_cache.get(sql) {
392 if entry.schema_gen == gen {
393 return Ok((Arc::clone(&entry.stmt), entry.param_count));
394 }
395 }
396
397 let stmt = Arc::new(parser::parse_sql(sql)?);
398 let param_count = parser::count_params(&stmt);
399
400 let cacheable = !matches!(
401 *stmt,
402 Statement::CreateTable(_)
403 | Statement::DropTable(_)
404 | Statement::CreateIndex(_)
405 | Statement::DropIndex(_)
406 | Statement::CreateView(_)
407 | Statement::DropView(_)
408 | Statement::AlterTable(_)
409 | Statement::Begin
410 | Statement::Commit
411 | Statement::Rollback
412 | Statement::Savepoint(_)
413 | Statement::ReleaseSavepoint(_)
414 | Statement::RollbackTo(_)
415 );
416
417 if cacheable {
418 self.stmt_cache.put(
419 sql.to_string(),
420 CacheEntry {
421 stmt: Arc::clone(&stmt),
422 schema_gen: gen,
423 param_count,
424 compiled_update: None,
425 },
426 );
427 }
428
429 Ok((stmt, param_count))
430 }
431
432 fn dispatch(&mut self, stmt: &Statement, params: &[Value]) -> Result<ExecutionResult> {
433 match stmt {
434 Statement::Begin => {
435 if self.active_txn.is_some() {
436 return Err(SqlError::TransactionAlreadyActive);
437 }
438 let wtx = self.db.begin_write().map_err(SqlError::Storage)?;
439 self.active_txn = Some(wtx);
440 Ok(ExecutionResult::Ok)
441 }
442 Statement::Commit => {
443 let wtx = self
444 .active_txn
445 .take()
446 .ok_or(SqlError::NoActiveTransaction)?;
447 wtx.commit().map_err(SqlError::Storage)?;
448 self.clear_savepoint_state();
449 Ok(ExecutionResult::Ok)
450 }
451 Statement::Rollback => {
452 let wtx = self
453 .active_txn
454 .take()
455 .ok_or(SqlError::NoActiveTransaction)?;
456 wtx.abort();
457 self.clear_savepoint_state();
458 self.schema = SchemaManager::load(self.db)?;
459 Ok(ExecutionResult::Ok)
460 }
461 Statement::Savepoint(name) => self.do_savepoint(name),
462 Statement::ReleaseSavepoint(name) => self.do_release(name),
463 Statement::RollbackTo(name) => self.do_rollback_to(name),
464 Statement::Insert(ins) if self.active_txn.is_some() => {
465 let wtx = self.active_txn.as_mut().unwrap();
466 executor::exec_insert_in_txn(wtx, &self.schema, ins, params, &mut self.insert_bufs)
467 }
468 _ => {
469 if let Some(ref mut wtx) = self.active_txn {
470 executor::execute_in_txn(wtx, &mut self.schema, stmt, params)
471 } else {
472 executor::execute(self.db, &mut self.schema, stmt, params)
473 }
474 }
475 }
476 }
477
478 fn clear_savepoint_state(&mut self) {
479 self.savepoint_stack.clear();
480 self.in_place_saved = None;
481 }
482
483 fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
484 let wtx = self
485 .active_txn
486 .as_mut()
487 .ok_or(SqlError::NoActiveTransaction)?;
488
489 if self.savepoint_stack.is_empty() {
490 self.in_place_saved = Some(wtx.in_place());
491 wtx.set_in_place(false);
492 }
493
494 let wtx_snap = wtx.begin_savepoint();
495 let schema_snap = self.schema.save_snapshot();
496
497 self.savepoint_stack.push(SavepointEntry {
498 name: name.to_string(),
499 wtx_snap,
500 schema_snap,
501 });
502
503 Ok(ExecutionResult::Ok)
504 }
505
506 fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
507 if self.active_txn.is_none() {
508 return Err(SqlError::NoActiveTransaction);
509 }
510
511 let idx = self
512 .savepoint_stack
513 .iter()
514 .rposition(|e| e.name == name)
515 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
516 self.savepoint_stack.truncate(idx);
517
518 if self.savepoint_stack.is_empty() {
519 if let (Some(wtx), Some(original)) =
520 (self.active_txn.as_mut(), self.in_place_saved.take())
521 {
522 wtx.set_in_place(original);
523 }
524 }
525
526 Ok(ExecutionResult::Ok)
527 }
528
529 fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
530 if self.active_txn.is_none() {
531 return Err(SqlError::NoActiveTransaction);
532 }
533
534 let idx = self
535 .savepoint_stack
536 .iter()
537 .rposition(|e| e.name == name)
538 .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
539
540 self.savepoint_stack.truncate(idx + 1);
541 let entry = self.savepoint_stack.last().unwrap();
542 let wtx_snap = entry.wtx_snap.clone();
543 let schema_snap = entry.schema_snap.clone();
544
545 let wtx = self.active_txn.as_mut().unwrap();
546 wtx.restore_snapshot(wtx_snap);
547 self.schema.restore_snapshot(schema_snap);
548
549 self.stmt_cache.clear();
551
552 Ok(ExecutionResult::Ok)
553 }
554}