Skip to main content

citadel_sql/
connection.rs

1//! Public SQL connection API.
2
3use std::num::NonZeroUsize;
4use std::sync::Arc;
5
6use lru::LruCache;
7
8use citadel::Database;
9use citadel_txn::write_txn::{WriteTxn, WriteTxnSnapshot};
10
11use crate::error::{Result, SqlError};
12use crate::executor;
13use crate::parser;
14use crate::parser::{InsertSource, Statement};
15use crate::schema::{SchemaManager, SchemaSnapshot};
16use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
17
/// Capacity, in entries, that `Connection::open` gives the per-connection
/// statement cache.
18const DEFAULT_CACHE_CAPACITY: usize = 64;
19
20fn try_normalize_insert(sql: &str) -> Option<(String, Vec<Value>)> {
21    let bytes = sql.as_bytes();
22    let len = bytes.len();
23    let mut i = 0;
24
25    while i < len && bytes[i].is_ascii_whitespace() {
26        i += 1;
27    }
28    if i + 6 > len || !bytes[i..i + 6].eq_ignore_ascii_case(b"INSERT") {
29        return None;
30    }
31    i += 6;
32    if i >= len || !bytes[i].is_ascii_whitespace() {
33        return None;
34    }
35    while i < len && bytes[i].is_ascii_whitespace() {
36        i += 1;
37    }
38
39    if i + 4 > len || !bytes[i..i + 4].eq_ignore_ascii_case(b"INTO") {
40        return None;
41    }
42    i += 4;
43    if i >= len || !bytes[i].is_ascii_whitespace() {
44        return None;
45    }
46
47    let prefix_start = 0;
48    let mut values_pos = None;
49    let mut j = i;
50    while j + 6 <= len {
51        if bytes[j..j + 6].eq_ignore_ascii_case(b"VALUES")
52            && (j == 0 || !bytes[j - 1].is_ascii_alphanumeric() && bytes[j - 1] != b'_')
53            && (j + 6 >= len || !bytes[j + 6].is_ascii_alphanumeric() && bytes[j + 6] != b'_')
54        {
55            values_pos = Some(j);
56            break;
57        }
58        j += 1;
59    }
60    let values_pos = values_pos?;
61
62    let prefix = &sql[prefix_start..values_pos + 6];
63    let mut pos = values_pos + 6;
64
65    while pos < len && bytes[pos].is_ascii_whitespace() {
66        pos += 1;
67    }
68    if pos >= len || bytes[pos] != b'(' {
69        return None;
70    }
71    pos += 1;
72
73    let mut values = Vec::new();
74    let mut normalized = String::with_capacity(sql.len());
75    normalized.push_str(prefix);
76    normalized.push_str(" (");
77
78    loop {
79        while pos < len && bytes[pos].is_ascii_whitespace() {
80            pos += 1;
81        }
82        if pos >= len {
83            return None;
84        }
85
86        let param_idx = values.len() + 1;
87        if param_idx > 1 {
88            normalized.push_str(", ");
89        }
90
91        if bytes[pos] == b'\'' {
92            pos += 1;
93            let mut seg_start = pos;
94            let mut s = String::new();
95            loop {
96                if pos >= len {
97                    return None;
98                }
99                if bytes[pos] == b'\'' {
100                    s.push_str(std::str::from_utf8(&bytes[seg_start..pos]).ok()?);
101                    pos += 1;
102                    if pos < len && bytes[pos] == b'\'' {
103                        s.push('\'');
104                        pos += 1;
105                        seg_start = pos;
106                    } else {
107                        break;
108                    }
109                } else {
110                    pos += 1;
111                }
112            }
113            values.push(Value::Text(s.into()));
114        } else if bytes[pos] == b'-' || bytes[pos].is_ascii_digit() {
115            let start = pos;
116            if bytes[pos] == b'-' {
117                pos += 1;
118            }
119            while pos < len && bytes[pos].is_ascii_digit() {
120                pos += 1;
121            }
122            if pos < len && bytes[pos] == b'.' {
123                pos += 1;
124                while pos < len && bytes[pos].is_ascii_digit() {
125                    pos += 1;
126                }
127                let num: f64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
128                values.push(Value::Real(num));
129            } else {
130                let num: i64 = std::str::from_utf8(&bytes[start..pos]).ok()?.parse().ok()?;
131                values.push(Value::Integer(num));
132            }
133        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"NULL") {
134            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
135            if !after.is_ascii_alphanumeric() && after != b'_' {
136                pos += 4;
137                values.push(Value::Null);
138            } else {
139                return None;
140            }
141        } else if pos + 4 <= len && bytes[pos..pos + 4].eq_ignore_ascii_case(b"TRUE") {
142            let after = if pos + 4 < len { bytes[pos + 4] } else { b')' };
143            if !after.is_ascii_alphanumeric() && after != b'_' {
144                pos += 4;
145                values.push(Value::Boolean(true));
146            } else {
147                return None;
148            }
149        } else if pos + 5 <= len && bytes[pos..pos + 5].eq_ignore_ascii_case(b"FALSE") {
150            let after = if pos + 5 < len { bytes[pos + 5] } else { b')' };
151            if !after.is_ascii_alphanumeric() && after != b'_' {
152                pos += 5;
153                values.push(Value::Boolean(false));
154            } else {
155                return None;
156            }
157        } else {
158            return None;
159        }
160
161        normalized.push('$');
162        normalized.push_str(&param_idx.to_string());
163
164        while pos < len && bytes[pos].is_ascii_whitespace() {
165            pos += 1;
166        }
167        if pos >= len {
168            return None;
169        }
170
171        if bytes[pos] == b',' {
172            pos += 1;
173        } else if bytes[pos] == b')' {
174            pos += 1;
175            break;
176        } else {
177            return None;
178        }
179    }
180
181    normalized.push(')');
182
183    while pos < len && (bytes[pos].is_ascii_whitespace() || bytes[pos] == b';') {
184        pos += 1;
185    }
186    if pos != len {
187        return None;
188    }
189
190    if values.is_empty() {
191        return None;
192    }
193
194    Some((normalized, values))
195}
196
/// One entry of the connection's statement cache.
197struct CacheEntry {
    /// Parsed statement; handed out to callers as a cloned `Arc`.
198    stmt: Arc<Statement>,
    /// Value of `SchemaManager::generation()` when the entry was stored;
    /// lookups treat the entry as stale when it differs from the current one.
199    schema_gen: u64,
    /// Result of `parser::count_params` for `stmt`.
200    param_count: usize,
    /// Compiled form of an UPDATE, filled in by `execute_params` the first
    /// time the statement runs outside an explicit transaction.
201    compiled_update: Option<executor::CompiledUpdate>,
202}
203
/// State recorded by `SAVEPOINT <name>` so that `ROLLBACK TO <name>` can
/// restore it.
204struct SavepointEntry {
    /// Savepoint name as given in the statement.
205    name: String,
    /// Snapshot returned by `WriteTxn::begin_savepoint`.
206    wtx_snap: WriteTxnSnapshot,
    /// Snapshot returned by `SchemaManager::save_snapshot`.
207    schema_snap: SchemaSnapshot,
208}
209
210/// SQL connection with LRU statement cache. Auto-commit; explicit txns via
211/// BEGIN/COMMIT/ROLLBACK with nested SAVEPOINT/RELEASE/ROLLBACK TO support.
212pub struct Connection<'a> {
    /// Database this connection was opened on.
213    db: &'a Database,
    /// In-memory schema, loaded with `SchemaManager::load`.
214    schema: SchemaManager,
    /// Write transaction started by BEGIN; `None` outside an explicit
    /// transaction.
215    active_txn: Option<WriteTxn<'a>>,
    /// Savepoints of the active transaction, oldest first.
216    savepoint_stack: Vec<SavepointEntry>,
    /// The transaction's `in_place()` flag as it was before the first
    /// savepoint switched it off; restored when a RELEASE empties the stack.
217    in_place_saved: Option<bool>,
    /// Parsed statements keyed by SQL text (for the INSERT fast path, by the
    /// normalized text with placeholders).
218    stmt_cache: LruCache<String, CacheEntry>,
    /// Scratch buffers passed to `executor::exec_insert_in_txn`.
219    insert_bufs: executor::InsertBufs,
    /// Scratch buffers passed to `executor::exec_update_compiled`.
220    update_bufs: executor::UpdateBufs,
221}
222
223impl<'a> Connection<'a> {
224    /// Open a SQL connection to a database.
225    pub fn open(db: &'a Database) -> Result<Self> {
226        let schema = SchemaManager::load(db)?;
227        let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
228        Ok(Self {
229            db,
230            schema,
231            active_txn: None,
232            savepoint_stack: Vec::new(),
233            in_place_saved: None,
234            stmt_cache,
235            insert_bufs: executor::InsertBufs::new(),
236            update_bufs: executor::UpdateBufs::new(),
237        })
238    }
239
240    /// Execute a SQL statement. Returns the result.
241    pub fn execute(&mut self, sql: &str) -> Result<ExecutionResult> {
242        if matches!(sql.as_bytes().first(), Some(b'I' | b'i')) {
243            if let Some((normalized_key, extracted)) = try_normalize_insert(sql) {
244                let gen = self.schema.generation();
245                let stmt = if let Some(entry) = self.stmt_cache.get(&normalized_key) {
246                    if entry.schema_gen == gen {
247                        Arc::clone(&entry.stmt)
248                    } else {
249                        self.parse_and_cache(normalized_key, gen)?
250                    }
251                } else {
252                    self.parse_and_cache(normalized_key, gen)?
253                };
254                return self.dispatch(&stmt, &extracted);
255            }
256        }
257
258        self.execute_params(sql, &[])
259    }
260
261    /// Execute a SQL statement with positional parameters ($1, $2, ...).
262    pub fn execute_params(&mut self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
263        if params.is_empty() && self.active_txn.is_none() {
264            let gen = self.schema.generation();
265            if let Some(entry) = self.stmt_cache.get(sql) {
266                if entry.schema_gen == gen && entry.param_count == 0 {
267                    if let Statement::Update(ref upd) = *entry.stmt {
268                        if let Some(ref compiled) = entry.compiled_update {
269                            return executor::exec_update_compiled(
270                                self.db,
271                                &self.schema,
272                                upd,
273                                compiled,
274                                &mut self.update_bufs,
275                            );
276                        }
277                        let compiled = executor::compile_update(&self.schema, upd)?;
278                        let result = executor::exec_update_compiled(
279                            self.db,
280                            &self.schema,
281                            upd,
282                            &compiled,
283                            &mut self.update_bufs,
284                        )?;
285                        if let Some(e) = self.stmt_cache.get_mut(sql) {
286                            e.compiled_update = Some(compiled);
287                        }
288                        return Ok(result);
289                    }
290                }
291            }
292        }
293
294        let (stmt, param_count) = self.get_or_parse(sql)?;
295
296        if param_count != params.len() {
297            return Err(SqlError::ParameterCountMismatch {
298                expected: param_count,
299                got: params.len(),
300            });
301        }
302
303        if param_count == 0 && self.active_txn.is_none() {
304            if let Statement::Update(ref upd) = *stmt {
305                let compiled = executor::compile_update(&self.schema, upd)?;
306                let result = executor::exec_update_compiled(
307                    self.db,
308                    &self.schema,
309                    upd,
310                    &compiled,
311                    &mut self.update_bufs,
312                )?;
313                if let Some(e) = self.stmt_cache.get_mut(sql) {
314                    e.compiled_update = Some(compiled);
315                }
316                return Ok(result);
317            }
318        }
319
320        if param_count > 0
321            && matches!(*stmt, Statement::Insert(ref ins) if matches!(ins.source, InsertSource::Values(_)))
322        {
323            self.dispatch(&stmt, params)
324        } else if param_count > 0 {
325            let bound = parser::bind_params(&stmt, params)?;
326            self.dispatch(&bound, &[])
327        } else {
328            self.dispatch(&stmt, &[])
329        }
330    }
331
332    /// Execute a SQL query and return the result set.
333    pub fn query(&mut self, sql: &str) -> Result<QueryResult> {
334        self.query_params(sql, &[])
335    }
336
337    /// Execute a SQL query with positional parameters ($1, $2, ...).
338    pub fn query_params(&mut self, sql: &str, params: &[Value]) -> Result<QueryResult> {
339        match self.execute_params(sql, params)? {
340            ExecutionResult::Query(qr) => Ok(qr),
341            ExecutionResult::RowsAffected(n) => Ok(QueryResult {
342                columns: vec!["rows_affected".into()],
343                rows: vec![vec![Value::Integer(n as i64)]],
344            }),
345            ExecutionResult::Ok => Ok(QueryResult {
346                columns: vec![],
347                rows: vec![],
348            }),
349        }
350    }
351
352    /// List all table names.
353    pub fn tables(&self) -> Vec<&str> {
354        self.schema.table_names()
355    }
356
357    /// Returns true if an explicit transaction is active (BEGIN was issued).
358    pub fn in_transaction(&self) -> bool {
359        self.active_txn.is_some()
360    }
361
362    /// Get the schema for a named table.
363    pub fn table_schema(&self, name: &str) -> Option<&TableSchema> {
364        self.schema.get(name)
365    }
366
367    /// Reload schemas from the database.
368    pub fn refresh_schema(&mut self) -> Result<()> {
369        self.schema = SchemaManager::load(self.db)?;
370        Ok(())
371    }
372
373    fn parse_and_cache(&mut self, normalized_key: String, gen: u64) -> Result<Arc<Statement>> {
374        let stmt = Arc::new(parser::parse_sql(&normalized_key)?);
375        let param_count = parser::count_params(&stmt);
376        self.stmt_cache.put(
377            normalized_key,
378            CacheEntry {
379                stmt: Arc::clone(&stmt),
380                schema_gen: gen,
381                param_count,
382                compiled_update: None,
383            },
384        );
385        Ok(stmt)
386    }
387
388    fn get_or_parse(&mut self, sql: &str) -> Result<(Arc<Statement>, usize)> {
389        let gen = self.schema.generation();
390
391        if let Some(entry) = self.stmt_cache.get(sql) {
392            if entry.schema_gen == gen {
393                return Ok((Arc::clone(&entry.stmt), entry.param_count));
394            }
395        }
396
397        let stmt = Arc::new(parser::parse_sql(sql)?);
398        let param_count = parser::count_params(&stmt);
399
400        let cacheable = !matches!(
401            *stmt,
402            Statement::CreateTable(_)
403                | Statement::DropTable(_)
404                | Statement::CreateIndex(_)
405                | Statement::DropIndex(_)
406                | Statement::CreateView(_)
407                | Statement::DropView(_)
408                | Statement::AlterTable(_)
409                | Statement::Begin
410                | Statement::Commit
411                | Statement::Rollback
412                | Statement::Savepoint(_)
413                | Statement::ReleaseSavepoint(_)
414                | Statement::RollbackTo(_)
415        );
416
417        if cacheable {
418            self.stmt_cache.put(
419                sql.to_string(),
420                CacheEntry {
421                    stmt: Arc::clone(&stmt),
422                    schema_gen: gen,
423                    param_count,
424                    compiled_update: None,
425                },
426            );
427        }
428
429        Ok((stmt, param_count))
430    }
431
432    fn dispatch(&mut self, stmt: &Statement, params: &[Value]) -> Result<ExecutionResult> {
433        match stmt {
434            Statement::Begin => {
435                if self.active_txn.is_some() {
436                    return Err(SqlError::TransactionAlreadyActive);
437                }
438                let wtx = self.db.begin_write().map_err(SqlError::Storage)?;
439                self.active_txn = Some(wtx);
440                Ok(ExecutionResult::Ok)
441            }
442            Statement::Commit => {
443                let wtx = self
444                    .active_txn
445                    .take()
446                    .ok_or(SqlError::NoActiveTransaction)?;
447                wtx.commit().map_err(SqlError::Storage)?;
448                self.clear_savepoint_state();
449                Ok(ExecutionResult::Ok)
450            }
451            Statement::Rollback => {
452                let wtx = self
453                    .active_txn
454                    .take()
455                    .ok_or(SqlError::NoActiveTransaction)?;
456                wtx.abort();
457                self.clear_savepoint_state();
458                self.schema = SchemaManager::load(self.db)?;
459                Ok(ExecutionResult::Ok)
460            }
461            Statement::Savepoint(name) => self.do_savepoint(name),
462            Statement::ReleaseSavepoint(name) => self.do_release(name),
463            Statement::RollbackTo(name) => self.do_rollback_to(name),
464            Statement::Insert(ins) if self.active_txn.is_some() => {
465                let wtx = self.active_txn.as_mut().unwrap();
466                executor::exec_insert_in_txn(wtx, &self.schema, ins, params, &mut self.insert_bufs)
467            }
468            _ => {
469                if let Some(ref mut wtx) = self.active_txn {
470                    executor::execute_in_txn(wtx, &mut self.schema, stmt, params)
471                } else {
472                    executor::execute(self.db, &mut self.schema, stmt, params)
473                }
474            }
475        }
476    }
477
478    fn clear_savepoint_state(&mut self) {
479        self.savepoint_stack.clear();
480        self.in_place_saved = None;
481    }
482
483    fn do_savepoint(&mut self, name: &str) -> Result<ExecutionResult> {
484        let wtx = self
485            .active_txn
486            .as_mut()
487            .ok_or(SqlError::NoActiveTransaction)?;
488
489        if self.savepoint_stack.is_empty() {
490            self.in_place_saved = Some(wtx.in_place());
491            wtx.set_in_place(false);
492        }
493
494        let wtx_snap = wtx.begin_savepoint();
495        let schema_snap = self.schema.save_snapshot();
496
497        self.savepoint_stack.push(SavepointEntry {
498            name: name.to_string(),
499            wtx_snap,
500            schema_snap,
501        });
502
503        Ok(ExecutionResult::Ok)
504    }
505
506    fn do_release(&mut self, name: &str) -> Result<ExecutionResult> {
507        if self.active_txn.is_none() {
508            return Err(SqlError::NoActiveTransaction);
509        }
510
511        let idx = self
512            .savepoint_stack
513            .iter()
514            .rposition(|e| e.name == name)
515            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
516        self.savepoint_stack.truncate(idx);
517
518        if self.savepoint_stack.is_empty() {
519            if let (Some(wtx), Some(original)) =
520                (self.active_txn.as_mut(), self.in_place_saved.take())
521            {
522                wtx.set_in_place(original);
523            }
524        }
525
526        Ok(ExecutionResult::Ok)
527    }
528
529    fn do_rollback_to(&mut self, name: &str) -> Result<ExecutionResult> {
530        if self.active_txn.is_none() {
531            return Err(SqlError::NoActiveTransaction);
532        }
533
534        let idx = self
535            .savepoint_stack
536            .iter()
537            .rposition(|e| e.name == name)
538            .ok_or_else(|| SqlError::SavepointNotFound(name.to_string()))?;
539
540        self.savepoint_stack.truncate(idx + 1);
541        let entry = self.savepoint_stack.last().unwrap();
542        let wtx_snap = entry.wtx_snap.clone();
543        let schema_snap = entry.schema_snap.clone();
544
545        let wtx = self.active_txn.as_mut().unwrap();
546        wtx.restore_snapshot(wtx_snap);
547        self.schema.restore_snapshot(schema_snap);
548
549        // schema_gen went backward; evict cache entries keyed on it.
550        self.stmt_cache.clear();
551
552        Ok(ExecutionResult::Ok)
553    }
554}