citadel_sql/
connection.rs1use std::num::NonZeroUsize;
4use std::sync::Arc;
5
6use lru::LruCache;
7
8use citadel::Database;
9use citadel_txn::write_txn::WriteTxn;
10
11use crate::error::{Result, SqlError};
12use crate::executor;
13use crate::parser;
14use crate::parser::Statement;
15use crate::schema::SchemaManager;
16use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
17
/// Number of parsed statements kept in the LRU statement cache of a connection created with
/// `Connection::open`.
const DEFAULT_CACHE_CAPACITY: usize = 64;
19
20fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
21 let bytes = sql.as_bytes();
22 let len = bytes.len();
23 let mut i = 0;
24
25 while i < len && bytes[i].is_ascii_whitespace() {
26 i += 1;
27 }
28 if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
29 return None;
30 }
31 i += 6;
32 if i >= len || !bytes[i].is_ascii_whitespace() {
33 return None;
34 }
35 while i < len && bytes[i].is_ascii_whitespace() {
36 i += 1;
37 }
38
39 if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
40 return None;
41 }
42 i += 4;
43 if i >= len || !bytes[i].is_ascii_whitespace() {
44 return None;
45 }
46
47 let prefix_start = 0;
48 let mut values_pos = None;
49 let mut j = i;
50 while j + 6 <= len {
51 if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
52 && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
53 && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
54 {
55 values_pos = Some(j);
56 break;
57 }
58 j += 1;
59 }
60 let values_pos = values_pos?;
61
62 let prefix = &sql[prefix_start..values_pos + 6];
63 let mut pos = values_pos + 6;
64
65 while pos < len && bytes[pos].is_ascii_whitespace() {
66 pos += 1;
67 }
68 if pos >= len || bytes[pos] != b'(' {
69 return None;
70 }
71 pos += 1;
72
73 let mut values = Vec::new();
74 let mut normalized = String::with_capacity(sql.len());
75 normalized.push_str(prefix);
76 normalized.push_str(" (");
77
78 loop {
79 while pos < len && bytes[pos].is_ascii_whitespace() {
80 pos += 1;
81 }
82 if pos >= len {
83 return None;
84 }
85
86 let param_idx = values.len() + 1;
87 if param_idx > 1 {
88 normalized.push_str(", ");
89 }
90
91 if bytes[pos] == b'\'' {
92 pos += 1;
93 let mut seg_start = pos;
94 let mut s = String::new();
95 loop {
96 if pos >= len {
97 return None;
98 }
99 if bytes[pos] == b'\'' {
100 s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
101 pos += 1;
102 if pos < len && bytes[pos] == b'\'' {
103 s.push('\'');
104 pos += 1;
105 seg_start = pos;
106 } else {
107 break;
108 }
109 } else {
110 pos += 1;
111 }
112 }
113 values.push(Value::Text(s.into()));
114 } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
115 let start = pos;
116 if bytes[pos] == b'-' {
117 pos += 1;
118 }
119 while pos < len && bytes[pos].is_ascii_digit() {
120 pos += 1;
121 }
122 if pos < len && bytes[pos] == b'.' {
123 pos += 1;
124 while pos < len && bytes[pos].is_ascii_digit() {
125 pos += 1;
126 }
127 let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
128 values.push(Value::Real(num));
129 } else {
130 let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
131 values.push(Value::Integer(num));
132 }
133 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
134 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
135 if !after.is_ascii_alphanumeric() && after != b'_' {
136 pos += 4;
137 values.push(Value::Null);
138 } else {
139 return None;
140 }
141 } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
142 let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
143 if !after.is_ascii_alphanumeric() && after != b'_' {
144 pos += 4;
145 values.push(Value::Boolean(true));
146 } else {
147 return None;
148 }
149 } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
150 let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
151 if !after.is_ascii_alphanumeric() && after != b'_' {
152 pos += 5;
153 values.push(Value::Boolean(false));
154 } else {
155 return None;
156 }
157 } else {
158 return None;
159 }
160
161 normalized.push('$');
162 normalized.push_str(¶m_idx.to_string());
163
164 while pos < len && bytes[pos].is_ascii_whitespace() {
165 pos += 1;
166 }
167 if pos >= len {
168 return None;
169 }
170
171 if bytes[pos] == b',' {
172 pos += 1;
173 } else if bytes[pos] == b')' {
174 pos += 1;
175 break;
176 } else {
177 return None;
178 }
179 }
180
181 normalized.push(')');
182
183 while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
184 pos += 1;
185 }
186 if pos != len {
187 return None;
188 }
189
190 if values.is_empty() {
191 return None;
192 }
193
194 Some((normalized, values))
195}
196
197struct CacheEntry {
198 stmt: Arc<Statement>,
199 schema_gen: u64,
200 param_count: usize,
201}
202
203pub struct Connection<'a> {
211 db: &'a Database,
212 schema: SchemaManager,
213 active_txn: Option<WriteTxn<'a>>,
214 stmt_cache: LruCache<String, CacheEntry>,
215 insert_bufs: executor::InsertBufs,
216}
217
218impl<'a> Connection<'a> {
219 pub fn open(db: &'a Database) -> Result<Self> {
221 let schema = SchemaManager::load(db)?;
222 let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
223 Ok(Self {
224 db,
225 schema,
226 active_txn: None,
227 stmt_cache,
228 insert_bufs: executor::InsertBufs::new(),
229 })
230 }
231
232 pub fn execute(&mut self, sql: &str) -> Result<ExecutionResult> {
234 if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
235 let gen = self.schema.generation();
236 let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
237 if entry.schema_gen == gen {
238 Arc::clone(&entry.stmt)
239 } else {
240 self.parse_and_cache(normalized_key, gen)?
241 }
242 } else {
243 self.parse_and_cache(normalized_key, gen)?
244 };
245 return self.dispatch(&stmt, &extracted);
246 }
247
248 self.execute_params(sql, &[])
249 }
250
251 pub fn execute_params(&mut self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
253 let (stmt, param_count) = self.get_or_parse(sql)?;
254
255 if param_count != params.len() {
256 return Err(SqlError::ParameterCountMismatch {
257 expected: param_count,
258 got: params.len(),
259 });
260 }
261
262 if param_count > 0 && matches!(*stmt, Statement::Insert(_)) {
263 self.dispatch(&stmt, params)
264 } else if param_count > 0 {
265 let bound = parser::bind_params(&stmt, params)?;
266 self.dispatch(&bound, &[])
267 } else {
268 self.dispatch(&stmt, &[])
269 }
270 }
271
272 pub fn query(&mut self, sql: &str) -> Result<QueryResult> {
274 self.query_params(sql, &[])
275 }
276
277 pub fn query_params(&mut self, sql: &str, params: &[Value]) -> Result<QueryResult> {
279 match self.execute_params(sql, params)? {
280 ExecutionResult::Query(qr) => Ok(qr),
281 ExecutionResult::RowsAffected(n) => Ok(QueryResult {
282 columns: vec!["rows_affected".into()],
283 rows: vec![vec![Value::Integer(n as i64)]],
284 }),
285 ExecutionResult::Ok => Ok(QueryResult {
286 columns: vec![],
287 rows: vec![],
288 }),
289 }
290 }
291
292 pub fn tables(&self) -> Vec<&str> {
294 self.schema.table_names()
295 }
296
297 pub fn in_transaction(&self) -> bool {
299 self.active_txn.is_some()
300 }
301
302 pub fn table_schema(&self, name: &str) -> Option<&TableSchema> {
304 self.schema.get(name)
305 }
306
307 pub fn refresh_schema(&mut self) -> Result<()> {
309 self.schema = SchemaManager::load(self.db)?;
310 Ok(())
311 }
312
313 fn parse_and_cache(&mut self, normalized_key: String, gen: u64) -> Result<Arc<Statement>> {
314 let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
315 let param_count = parser::count_params(&stmt);
316 self.stmt_cache.put(
317 normalized_key,
318 CacheEntry {
319 stmt: Arc::clone(&stmt),
320 schema_gen: gen,
321 param_count,
322 },
323 );
324 Ok(stmt)
325 }
326
327 fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
328 let gen = self.schema.generation();
329
330 if let Some(entry) = self.stmt_cache.get(sql) {
331 if entry.schema_gen == gen {
332 return Ok((Arc::clone(&entry.stmt), entry.param_count));
333 }
334 }
335
336 let stmt = Arc::new(parser::parse_sql(sql)?);
337 let param_count = parser::count_params(&stmt);
338
339 let cacheable = !matches!(
340 *stmt,
341 Statement::CreateTable(_)
342 | Statement::DropTable(_)
343 | Statement::CreateIndex(_)
344 | Statement::DropIndex(_)
345 | Statement::Begin
346 | Statement::Commit
347 | Statement::Rollback
348 );
349
350 if cacheable {
351 self.stmt_cache.put(
352 sql.to_string(),
353 CacheEntry {
354 stmt: Arc::clone(&stmt),
355 schema_gen: gen,
356 param_count,
357 },
358 );
359 }
360
361 Ok((stmt, param_count))
362 }
363
364 fn dispatch(&mut self, stmt: &Statement, params: &[Value]) -> Result<ExecutionResult> {
365 match stmt {
366 Statement::Begin => {
367 if self.active_txn.is_some() {
368 return Err(SqlError::TransactionAlreadyActive);
369 }
370 let wtx = self.db.begin_write().map_err(SqlError::Storage)?;
371 self.active_txn = Some(wtx);
372 Ok(ExecutionResult::Ok)
373 }
374 Statement::Commit => {
375 let wtx = self
376 .active_txn
377 .take()
378 .ok_or(SqlError::NoActiveTransaction)?;
379 wtx.commit().map_err(SqlError::Storage)?;
380 Ok(ExecutionResult::Ok)
381 }
382 Statement::Rollback => {
383 let wtx = self
384 .active_txn
385 .take()
386 .ok_or(SqlError::NoActiveTransaction)?;
387 wtx.abort();
388 self.schema = SchemaManager::load(self.db)?;
389 Ok(ExecutionResult::Ok)
390 }
391 Statement::Insert(ins) if self.active_txn.is_some() => {
392 let wtx = self.active_txn.as_mut().unwrap();
393 executor::exec_insert_in_txn(wtx, &self.schema, ins, params, &mut self.insert_bufs)
394 }
395 _ => {
396 if let Some(ref mut wtx) = self.active_txn {
397 executor::execute_in_txn(wtx, &mut self.schema, stmt, params)
398 } else {
399 executor::execute(self.db, &mut self.schema, stmt, params)
400 }
401 }
402 }
403 }
404}