1use std::num::NonZeroUsize;
4use std::sync::Arc;
5
6use lru::LruCache;
7
8use citadel::Database;
9use citadel_txn::write_txn::WriteTxn;
10
11use crate::error::{Result, SqlError};
12use crate::executor;
13use crate::parser;
14use crate::parser::{InsertSource, Statement};
15use crate::schema::SchemaManager;
16use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
17
/// Capacity (number of entries) of the LRU statement cache each `Connection`
/// creates in `Connection::open`.
const DEFAULT_CACHE_CAPACITY: usize = 64;
19
20fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
21 let bytes = sql.as_bytes();
22 let len = bytes.len();
23 let mut i = 0;
24
25 while i < len && bytes[i].is_ascii_whitespace() {
26 i += 1;
27 }
28 if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
29 return None;
30 }
31 i += 6;
32 if i >= len || !bytes[i].is_ascii_whitespace() {
33 return None;
34 }
35 while i < len && bytes[i].is_ascii_whitespace() {
36 i += 1;
37 }
38
39 if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
40 return None;
41 }
42 i += 4;
43 if i >= len || !bytes[i].is_ascii_whitespace() {
44 return None;
45 }
46
47 let prefix_start = 0;
48 let mut values_pos = None;
49 let mut j = i;
50 while j + 6 <= len {
51 if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
52 && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
53 && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
54 {
55 values_pos = Some(j);
56 break;
57 }
58 j += 1;
59 }
60 let values_pos = values_pos?;
61
62 let prefix = &sql[prefix_start..values_pos + 6];
63 let mut pos = values_pos + 6;
64
65 while pos < len && bytes[pos].is_ascii_whitespace() {
66 pos += 1;
67 }
68 if pos >= len || bytes[pos] != b'(' {
69 return None;
70 }
71 pos += 1;
72
73 let mut values = Vec::new();
74 let mut normalized = String::with_capacity(sql.len());
75 normalized.push_str(prefix);
76 normalized.push_str(" (");
77
78 loop {
79 while pos < len && bytes[pos].is_ascii_whitespace() {
80 pos += 1;
81 }
82 if pos >= len {
83 return None;
84 }
85
86 let param_idx = values.len() + 1;
87 if param_idx > 1 {
88 normalized.push_str(", ");
89 }
90
91 if bytes[pos] == b'\'' {
92 pos += 1;
93 let mut seg_start = pos;
94 let mut s = String::new();
95 loop {
96 if pos >= len {
97 return None;
98 }
99 if bytes[pos] == b'\'' {
100 s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
101 pos += 1;
102 if pos < len && bytes[pos] == b'\'' {
103 s.push('\'');
104 pos += 1;
105 seg_start = pos;
106 } else {
107 break;
108 }
109 } else {
110 pos += 1;
111 }
112 }
113 values.push(Value::Text(s.into()));
114 } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
115 let start = pos;
116 if bytes[pos] == b'-' {
117 pos += 1;
118 }
119 while pos < len && bytes[pos].is_ascii_digit() {
120 pos += 1;
121 }
122 if pos < len && bytes[pos] == b'.' {
123 pos += 1;
124 while pos < len && bytes[pos].is_ascii_digit() {
125 pos += 1;
126 }
127 let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
128 values.push(Value::Real(num));
129 } else {
130 let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
131 values.push(Value::Integer(num));
132 }
133 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
134 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
135 if !after.is_ascii_alphanumeric() && after != b'_' {
136 pos += 4;
137 values.push(Value::Null);
138 } else {
139 return None;
140 }
141 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
142 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
143 if !after.is_ascii_alphanumeric() && after != b'_' {
144 pos += 4;
145 values.push(Value::Boolean(true));
146 } else {
147 return None;
148 }
149 } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
150 let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
151 if !after.is_ascii_alphanumeric() && after != b'_' {
152 pos += 5;
153 values.push(Value::Boolean(false));
154 } else {
155 return None;
156 }
157 } else {
158 return None;
159 }
160
161 normalized.push('$');
162 normalized.push_str(¶m_idx.to_string());
163
164 while pos < len && bytes[pos].is_ascii_whitespace() {
165 pos += 1;
166 }
167 if pos >= len {
168 return None;
169 }
170
171 if bytes[pos] == b',' {
172 pos += 1;
173 } else if bytes[pos] == b')' {
174 pos += 1;
175 break;
176 } else {
177 return None;
178 }
179 }
180
181 normalized.push(')');
182
183 while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
184 pos += 1;
185 }
186 if pos != len {
187 return None;
188 }
189
190 if values.is_empty() {
191 return None;
192 }
193
194 Some((normalized, values))
195}
196
/// One entry of a connection's statement cache.
struct CacheEntry {
    /// The parsed statement; handed out via `Arc::clone` on cache hits.
    stmt: Arc<Statement>,
    /// Schema generation at parse time. Entries whose generation differs from
    /// the current `SchemaManager::generation()` are treated as stale and re-parsed.
    schema_gen: u64,
    /// Number of parameters in `stmt`, as computed by `parser::count_params`.
    param_count: usize,
    /// Compiled plan for a parameterless UPDATE. Filled in after its first
    /// execution outside an explicit transaction; `None` otherwise.
    compiled_update: Option<executor::CompiledUpdate>,
}
203
/// A SQL session over a borrowed `citadel::Database`.
///
/// Holds the in-memory schema, an optional explicit write transaction, an LRU
/// cache of parsed statements, and executor buffers that are created once in
/// `open` and passed to every insert/update execution.
pub struct Connection<'a> {
    db: &'a Database,
    schema: SchemaManager,
    // Write transaction opened by BEGIN and consumed by COMMIT/ROLLBACK; `None`
    // means statements are executed directly against `db`.
    active_txn: Option<WriteTxn<'a>>,
    // Keyed by SQL text, or by the normalized text built for literal INSERTs.
    stmt_cache: LruCache<String, CacheEntry>,
    insert_bufs: executor::InsertBufs,
    update_bufs: executor::UpdateBufs,
}
213
214impl<'a> Connection<'a> {
215 pub fn open(db: &'a Database) -> Result<Self> {
217 let schema = SchemaManager::load(db)?;
218 let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
219 Ok(Self {
220 db,
221 schema,
222 active_txn: None,
223 stmt_cache,
224 insert_bufs: executor::InsertBufs::new(),
225 update_bufs: executor::UpdateBufs::new(),
226 })
227 }
228
229 pub fn execute(&mut self, sql: &str) -> Result<ExecutionResult> {
231 if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
232 if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
233 let gen = self.schema.generation();
234 let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
235 if entry.schema_gen == gen {
236 Arc::clone(&entry.stmt)
237 } else {
238 self.parse_and_cache(normalized_key, gen)?
239 }
240 } else {
241 self.parse_and_cache(normalized_key, gen)?
242 };
243 return self.dispatch(&stmt, &extracted);
244 }
245 }
246
247 self.execute_params(sql, &[])
248 }
249
250 pub fn execute_params(&mut self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
252 if params.is_empty() && self.active_txn.is_none() {
253 let gen = self.schema.generation();
254 if let Some(entry) = self.stmt_cache.get(sql) {
255 if entry.schema_gen == gen && entry.param_count == 0 {
256 if let Statement::Update(ref upd) = *entry.stmt {
257 if let Some(ref compiled) = entry.compiled_update {
258 return executor::exec_update_compiled(
259 self.db,
260 &self.schema,
261 upd,
262 compiled,
263 &mut self.update_bufs,
264 );
265 }
266 let compiled = executor::compile_update(&self.schema, upd)?;
267 let result = executor::exec_update_compiled(
268 self.db,
269 &self.schema,
270 upd,
271 &compiled,
272 &mut self.update_bufs,
273 )?;
274 if let Some(e) = self.stmt_cache.get_mut(sql) {
275 e.compiled_update = Some(compiled);
276 }
277 return Ok(result);
278 }
279 }
280 }
281 }
282
283 let (stmt, param_count) = self.get_or_parse(sql)?;
284
285 if param_count != params.len() {
286 return Err(SqlError::ParameterCountMismatch {
287 expected: param_count,
288 got: params.len(),
289 });
290 }
291
292 if param_count == 0 && self.active_txn.is_none() {
293 if let Statement::Update(ref upd) = *stmt {
294 let compiled = executor::compile_update(&self.schema, upd)?;
295 let result = executor::exec_update_compiled(
296 self.db,
297 &self.schema,
298 upd,
299 &compiled,
300 &mut self.update_bufs,
301 )?;
302 if let Some(e) = self.stmt_cache.get_mut(sql) {
303 e.compiled_update = Some(compiled);
304 }
305 return Ok(result);
306 }
307 }
308
309 if param_count > 0
310 && matches!(*stmt, Statement::Insert(ref ins) if matches!(ins.source, InsertSource::Values(_)))
311 {
312 self.dispatch(&stmt, params)
313 } else if param_count > 0 {
314 let bound = parser::bind_params(&stmt, params)?;
315 self.dispatch(&bound, &[])
316 } else {
317 self.dispatch(&stmt, &[])
318 }
319 }
320
321 pub fn query(&mut self, sql: &str) -> Result<QueryResult> {
323 self.query_params(sql, &[])
324 }
325
326 pub fn query_params(&mut self, sql: &str, params: &[Value]) -> Result<QueryResult> {
328 match self.execute_params(sql, params)? {
329 ExecutionResult::Query(qr) => Ok(qr),
330 ExecutionResult::RowsAffected(n) => Ok(QueryResult {
331 columns: vec!["rows_affected".into()],
332 rows: vec![vec![Value::Integer(n as i64)]],
333 }),
334 ExecutionResult::Ok => Ok(QueryResult {
335 columns: vec![],
336 rows: vec![],
337 }),
338 }
339 }
340
341 pub fn tables(&self) -> Vec<&str> {
343 self.schema.table_names()
344 }
345
346 pub fn in_transaction(&self) -> bool {
348 self.active_txn.is_some()
349 }
350
351 pub fn table_schema(&self, name: &str) -> Option<&TableSchema> {
353 self.schema.get(name)
354 }
355
356 pub fn refresh_schema(&mut self) -> Result<()> {
358 self.schema = SchemaManager::load(self.db)?;
359 Ok(())
360 }
361
362 fn parse_and_cache(&mut self, normalized_key: String, gen: u64) -> Result<Arc<Statement>> {
363 let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
364 let param_count = parser::count_params(&stmt);
365 self.stmt_cache.put(
366 normalized_key,
367 CacheEntry {
368 stmt: Arc::clone(&stmt),
369 schema_gen: gen,
370 param_count,
371 compiled_update: None,
372 },
373 );
374 Ok(stmt)
375 }
376
377 fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
378 let gen = self.schema.generation();
379
380 if let Some(entry) = self.stmt_cache.get(sql) {
381 if entry.schema_gen == gen {
382 return Ok((Arc::clone(&entry.stmt), entry.param_count));
383 }
384 }
385
386 let stmt = Arc::new(parser::parse_sql(sql)?);
387 let param_count = parser::count_params(&stmt);
388
389 let cacheable = !matches!(
390 *stmt,
391 Statement::CreateTable(_)
392 | Statement::DropTable(_)
393 | Statement::CreateIndex(_)
394 | Statement::DropIndex(_)
395 | Statement::CreateView(_)
396 | Statement::DropView(_)
397 | Statement::AlterTable(_)
398 | Statement::Begin
399 | Statement::Commit
400 | Statement::Rollback
401 );
402
403 if cacheable {
404 self.stmt_cache.put(
405 sql.to_string(),
406 CacheEntry {
407 stmt: Arc::clone(&stmt),
408 schema_gen: gen,
409 param_count,
410 compiled_update: None,
411 },
412 );
413 }
414
415 Ok((stmt, param_count))
416 }
417
418 fn dispatch(&mut self, stmt: &Statement, params: &[Value]) -> Result<ExecutionResult> {
419 match stmt {
420 Statement::Begin => {
421 if self.active_txn.is_some() {
422 return Err(SqlError::TransactionAlreadyActive);
423 }
424 let wtx = self.db.begin_write().map_err(SqlError::Storage)?;
425 self.active_txn = Some(wtx);
426 Ok(ExecutionResult::Ok)
427 }
428 Statement::Commit => {
429 let wtx = self
430 .active_txn
431 .take()
432 .ok_or(SqlError::NoActiveTransaction)?;
433 wtx.commit().map_err(SqlError::Storage)?;
434 Ok(ExecutionResult::Ok)
435 }
436 Statement::Rollback => {
437 let wtx = self
438 .active_txn
439 .take()
440 .ok_or(SqlError::NoActiveTransaction)?;
441 wtx.abort();
442 self.schema = SchemaManager::load(self.db)?;
443 Ok(ExecutionResult::Ok)
444 }
445 Statement::Insert(ins) if self.active_txn.is_some() => {
446 let wtx = self.active_txn.as_mut().unwrap();
447 executor::exec_insert_in_txn(wtx, &self.schema, ins, params, &mut self.insert_bufs)
448 }
449 _ => {
450 if let Some(ref mut wtx) = self.active_txn {
451 executor::execute_in_txn(wtx, &mut self.schema, stmt, params)
452 } else {
453 executor::execute(self.db, &mut self.schema, stmt, params)
454 }
455 }
456 }
457 }
458}