citadel_sql/
connection.rs1use std::num::NonZeroUsize;
4use std::sync::Arc;
5
6use lru::LruCache;
7
8use citadel::Database;
9use citadel_txn::write_txn::WriteTxn;
10
11use crate::error::{Result, SqlError};
12use crate::executor;
13use crate::parser;
14use crate::parser::{InsertSource, Statement};
15use crate::schema::SchemaManager;
16use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
17
/// Number of parsed statements each [`Connection`] keeps in its LRU statement cache.
const DEFAULT_CACHE_CAPACITY: usize = 64;
19
20fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
21 let bytes = sql.as_bytes();
22 let len = bytes.len();
23 let mut i = 0;
24
25 while i < len && bytes[i].is_ascii_whitespace() {
26 i += 1;
27 }
28 if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
29 return None;
30 }
31 i += 6;
32 if i >= len || !bytes[i].is_ascii_whitespace() {
33 return None;
34 }
35 while i < len && bytes[i].is_ascii_whitespace() {
36 i += 1;
37 }
38
39 if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
40 return None;
41 }
42 i += 4;
43 if i >= len || !bytes[i].is_ascii_whitespace() {
44 return None;
45 }
46
47 let prefix_start = 0;
48 let mut values_pos = None;
49 let mut j = i;
50 while j + 6 <= len {
51 if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
52 && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
53 && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
54 {
55 values_pos = Some(j);
56 break;
57 }
58 j += 1;
59 }
60 let values_pos = values_pos?;
61
62 let prefix = &sql[prefix_start..values_pos + 6];
63 let mut pos = values_pos + 6;
64
65 while pos < len && bytes[pos].is_ascii_whitespace() {
66 pos += 1;
67 }
68 if pos >= len || bytes[pos] != b'(' {
69 return None;
70 }
71 pos += 1;
72
73 let mut values = Vec::new();
74 let mut normalized = String::with_capacity(sql.len());
75 normalized.push_str(prefix);
76 normalized.push_str(" (");
77
78 loop {
79 while pos < len && bytes[pos].is_ascii_whitespace() {
80 pos += 1;
81 }
82 if pos >= len {
83 return None;
84 }
85
86 let param_idx = values.len() + 1;
87 if param_idx > 1 {
88 normalized.push_str(", ");
89 }
90
91 if bytes[pos] == b'\'' {
92 pos += 1;
93 let mut seg_start = pos;
94 let mut s = String::new();
95 loop {
96 if pos >= len {
97 return None;
98 }
99 if bytes[pos] == b'\'' {
100 s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
101 pos += 1;
102 if pos < len && bytes[pos] == b'\'' {
103 s.push('\'');
104 pos += 1;
105 seg_start = pos;
106 } else {
107 break;
108 }
109 } else {
110 pos += 1;
111 }
112 }
113 values.push(Value::Text(s.into()));
114 } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
115 let start = pos;
116 if bytes[pos] == b'-' {
117 pos += 1;
118 }
119 while pos < len && bytes[pos].is_ascii_digit() {
120 pos += 1;
121 }
122 if pos < len && bytes[pos] == b'.' {
123 pos += 1;
124 while pos < len && bytes[pos].is_ascii_digit() {
125 pos += 1;
126 }
127 let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
128 values.push(Value::Real(num));
129 } else {
130 let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
131 values.push(Value::Integer(num));
132 }
133 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
134 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
135 if !after.is_ascii_alphanumeric() && after != b'_' {
136 pos += 4;
137 values.push(Value::Null);
138 } else {
139 return None;
140 }
141 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
142 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
143 if !after.is_ascii_alphanumeric() && after != b'_' {
144 pos += 4;
145 values.push(Value::Boolean(true));
146 } else {
147 return None;
148 }
149 } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
150 let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
151 if !after.is_ascii_alphanumeric() && after != b'_' {
152 pos += 5;
153 values.push(Value::Boolean(false));
154 } else {
155 return None;
156 }
157 } else {
158 return None;
159 }
160
161 normalized.push('$');
162 normalized.push_str(¶m_idx.to_string());
163
164 while pos < len && bytes[pos].is_ascii_whitespace() {
165 pos += 1;
166 }
167 if pos >= len {
168 return None;
169 }
170
171 if bytes[pos] == b',' {
172 pos += 1;
173 } else if bytes[pos] == b')' {
174 pos += 1;
175 break;
176 } else {
177 return None;
178 }
179 }
180
181 normalized.push(')');
182
183 while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
184 pos += 1;
185 }
186 if pos != len {
187 return None;
188 }
189
190 if values.is_empty() {
191 return None;
192 }
193
194 Some((normalized, values))
195}
196
197struct CacheEntry {
198 stmt: Arc<Statement>,
199 schema_gen: u64,
200 param_count: usize,
201}
202
/// A SQL session bound to a [`Database`].
///
/// Tracks the in-memory schema, an explicit write transaction opened with `BEGIN` (if any),
/// an LRU cache of parsed statements, and buffers for the in-transaction INSERT path.
pub struct Connection<'a> {
    /// The database every statement runs against.
    db: &'a Database,
    /// In-memory schema. It is passed mutably to the executor, and is reloaded from `db` on
    /// `ROLLBACK` and by `refresh_schema`.
    schema: SchemaManager,
    /// Write transaction opened by `BEGIN` and taken by `COMMIT`/`ROLLBACK`; `None` otherwise.
    active_txn: Option<WriteTxn<'a>>,
    /// Parsed statements keyed by SQL text (including normalized INSERT text). Each entry is
    /// tagged with the schema generation it was parsed under.
    stmt_cache: LruCache<String, CacheEntry>,
    /// Buffers passed to `executor::exec_insert_in_txn`. Presumably reused across inserts to
    /// avoid reallocation — TODO confirm against the executor.
    insert_bufs: executor::InsertBufs,
}
217
218impl<'a> Connection<'a> {
219 pub fn open(db: &'a Database) -> Result<Self> {
221 let schema = SchemaManager::load(db)?;
222 let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
223 Ok(Self {
224 db,
225 schema,
226 active_txn: None,
227 stmt_cache,
228 insert_bufs: executor::InsertBufs::new(),
229 })
230 }
231
232 pub fn execute(&mut self, sql: &str) -> Result<ExecutionResult> {
234 if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
235 let gen = self.schema.generation();
236 let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
237 if entry.schema_gen == gen {
238 Arc::clone(&entry.stmt)
239 } else {
240 self.parse_and_cache(normalized_key, gen)?
241 }
242 } else {
243 self.parse_and_cache(normalized_key, gen)?
244 };
245 return self.dispatch(&stmt, &extracted);
246 }
247
248 self.execute_params(sql, &[])
249 }
250
251 pub fn execute_params(&mut self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
253 let (stmt, param_count) = self.get_or_parse(sql)?;
254
255 if param_count != params.len() {
256 return Err(SqlError::ParameterCountMismatch {
257 expected: param_count,
258 got: params.len(),
259 });
260 }
261
262 if param_count > 0
263 && matches!(*stmt, Statement::Insert(ref ins) if matches!(ins.source, InsertSource::Values(_)))
264 {
265 self.dispatch(&stmt, params)
266 } else if param_count > 0 {
267 let bound = parser::bind_params(&stmt, params)?;
268 self.dispatch(&bound, &[])
269 } else {
270 self.dispatch(&stmt, &[])
271 }
272 }
273
274 pub fn query(&mut self, sql: &str) -> Result<QueryResult> {
276 self.query_params(sql, &[])
277 }
278
279 pub fn query_params(&mut self, sql: &str, params: &[Value]) -> Result<QueryResult> {
281 match self.execute_params(sql, params)? {
282 ExecutionResult::Query(qr) => Ok(qr),
283 ExecutionResult::RowsAffected(n) => Ok(QueryResult {
284 columns: vec!["rows_affected".into()],
285 rows: vec![vec![Value::Integer(n as i64)]],
286 }),
287 ExecutionResult::Ok => Ok(QueryResult {
288 columns: vec![],
289 rows: vec![],
290 }),
291 }
292 }
293
294 pub fn tables(&self) -> Vec<&str> {
296 self.schema.table_names()
297 }
298
299 pub fn in_transaction(&self) -> bool {
301 self.active_txn.is_some()
302 }
303
304 pub fn table_schema(&self, name: &str) -> Option<&TableSchema> {
306 self.schema.get(name)
307 }
308
309 pub fn refresh_schema(&mut self) -> Result<()> {
311 self.schema = SchemaManager::load(self.db)?;
312 Ok(())
313 }
314
315 fn parse_and_cache(&mut self, normalized_key: String, gen: u64) -> Result<Arc<Statement>> {
316 let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
317 let param_count = parser::count_params(&stmt);
318 self.stmt_cache.put(
319 normalized_key,
320 CacheEntry {
321 stmt: Arc::clone(&stmt),
322 schema_gen: gen,
323 param_count,
324 },
325 );
326 Ok(stmt)
327 }
328
329 fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
330 let gen = self.schema.generation();
331
332 if let Some(entry) = self.stmt_cache.get(sql) {
333 if entry.schema_gen == gen {
334 return Ok((Arc::clone(&entry.stmt), entry.param_count));
335 }
336 }
337
338 let stmt = Arc::new(parser::parse_sql(sql)?);
339 let param_count = parser::count_params(&stmt);
340
341 let cacheable = !matches!(
342 *stmt,
343 Statement::CreateTable(_)
344 | Statement::DropTable(_)
345 | Statement::CreateIndex(_)
346 | Statement::DropIndex(_)
347 | Statement::AlterTable(_)
348 | Statement::Begin
349 | Statement::Commit
350 | Statement::Rollback
351 );
352
353 if cacheable {
354 self.stmt_cache.put(
355 sql.to_string(),
356 CacheEntry {
357 stmt: Arc::clone(&stmt),
358 schema_gen: gen,
359 param_count,
360 },
361 );
362 }
363
364 Ok((stmt, param_count))
365 }
366
367 fn dispatch(&mut self, stmt: &Statement, params: &[Value]) -> Result<ExecutionResult> {
368 match stmt {
369 Statement::Begin => {
370 if self.active_txn.is_some() {
371 return Err(SqlError::TransactionAlreadyActive);
372 }
373 let wtx = self.db.begin_write().map_err(SqlError::Storage)?;
374 self.active_txn = Some(wtx);
375 Ok(ExecutionResult::Ok)
376 }
377 Statement::Commit => {
378 let wtx = self
379 .active_txn
380 .take()
381 .ok_or(SqlError::NoActiveTransaction)?;
382 wtx.commit().map_err(SqlError::Storage)?;
383 Ok(ExecutionResult::Ok)
384 }
385 Statement::Rollback => {
386 let wtx = self
387 .active_txn
388 .take()
389 .ok_or(SqlError::NoActiveTransaction)?;
390 wtx.abort();
391 self.schema = SchemaManager::load(self.db)?;
392 Ok(ExecutionResult::Ok)
393 }
394 Statement::Insert(ins) if self.active_txn.is_some() => {
395 let wtx = self.active_txn.as_mut().unwrap();
396 executor::exec_insert_in_txn(wtx, &self.schema, ins, params, &mut self.insert_bufs)
397 }
398 _ => {
399 if let Some(ref mut wtx) = self.active_txn {
400 executor::execute_in_txn(wtx, &mut self.schema, stmt, params)
401 } else {
402 executor::execute(self.db, &mut self.schema, stmt, params)
403 }
404 }
405 }
406 }
407}