Skip to main content

citadeldb_sql/
connection.rs

1//! Public SQL connection API.
2
3use std::num::NonZeroUsize;
4
5use lru::LruCache;
6
7use citadel::Database;
8use citadel_txn::write_txn::WriteTxn;
9
10use crate::error::{Result, SqlError};
11use crate::executor;
12use crate::parser;
13use crate::parser::Statement;
14use crate::schema::SchemaManager;
15use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
16
/// Maximum number of parsed statements a new [`Connection`] keeps in its LRU
/// statement cache.
const DEFAULT_CACHE_CAPACITY: usize = 64;
18
19struct CacheEntry {
20    stmt: Statement,
21    schema_gen: u64,
22    param_count: usize,
23}
24
25/// A SQL connection wrapping a Citadel database.
26///
27/// Supports explicit transactions via BEGIN / COMMIT / ROLLBACK.
28/// Without BEGIN, each statement runs in auto-commit mode.
29///
30/// Caches parsed SQL statements in an LRU cache keyed by SQL string.
31/// Cache entries are invalidated when the schema changes (DDL operations).
32pub struct Connection<'a> {
33    db: &'a Database,
34    schema: SchemaManager,
35    active_txn: Option<WriteTxn<'a>>,
36    stmt_cache: LruCache<String, CacheEntry>,
37}
38
39impl<'a> Connection<'a> {
40    /// Open a SQL connection to a database.
41    pub fn open(db: &'a Database) -> Result<Self> {
42        let schema = SchemaManager::load(db)?;
43        let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
44        Ok(Self { db, schema, active_txn: None, stmt_cache })
45    }
46
47    /// Execute a SQL statement. Returns the result.
48    pub fn execute(&mut self, sql: &str) -> Result<ExecutionResult> {
49        self.execute_params(sql, &[])
50    }
51
52    /// Execute a SQL statement with positional parameters ($1, $2, ...).
53    pub fn execute_params(&mut self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
54        let (stmt, param_count) = self.get_or_parse(sql)?;
55
56        if param_count != params.len() {
57            return Err(SqlError::ParameterCountMismatch {
58                expected: param_count,
59                got: params.len(),
60            });
61        }
62
63        let bound = if param_count > 0 {
64            parser::bind_params(&stmt, params)?
65        } else {
66            stmt
67        };
68
69        self.dispatch(bound)
70    }
71
72    /// Execute a SQL query and return the result set.
73    pub fn query(&mut self, sql: &str) -> Result<QueryResult> {
74        self.query_params(sql, &[])
75    }
76
77    /// Execute a SQL query with positional parameters ($1, $2, ...).
78    pub fn query_params(&mut self, sql: &str, params: &[Value]) -> Result<QueryResult> {
79        match self.execute_params(sql, params)? {
80            ExecutionResult::Query(qr) => Ok(qr),
81            ExecutionResult::RowsAffected(n) => Ok(QueryResult {
82                columns: vec!["rows_affected".into()],
83                rows: vec![vec![Value::Integer(n as i64)]],
84            }),
85            ExecutionResult::Ok => Ok(QueryResult {
86                columns: vec![],
87                rows: vec![],
88            }),
89        }
90    }
91
92    /// List all table names.
93    pub fn tables(&self) -> Vec<&str> {
94        self.schema.table_names()
95    }
96
97    /// Returns true if an explicit transaction is active (BEGIN was issued).
98    pub fn in_transaction(&self) -> bool {
99        self.active_txn.is_some()
100    }
101
102    /// Get the schema for a named table.
103    pub fn table_schema(&self, name: &str) -> Option<&TableSchema> {
104        self.schema.get(name)
105    }
106
107    /// Reload schemas from the database.
108    pub fn refresh_schema(&mut self) -> Result<()> {
109        self.schema = SchemaManager::load(self.db)?;
110        Ok(())
111    }
112
113    fn get_or_parse(&mut self, sql: &str) -> Result<(Statement, usize)> {
114        let gen = self.schema.generation();
115
116        if let Some(entry) = self.stmt_cache.get(sql) {
117            if entry.schema_gen == gen {
118                return Ok((entry.stmt.clone(), entry.param_count));
119            }
120        }
121
122        let stmt = parser::parse_sql(sql)?;
123        let param_count = parser::count_params(&stmt);
124
125        let cacheable = !matches!(
126            stmt,
127            Statement::CreateTable(_) | Statement::DropTable(_)
128            | Statement::CreateIndex(_) | Statement::DropIndex(_)
129            | Statement::Begin | Statement::Commit | Statement::Rollback
130        );
131
132        if cacheable {
133            self.stmt_cache.put(sql.to_string(), CacheEntry {
134                stmt: stmt.clone(),
135                schema_gen: gen,
136                param_count,
137            });
138        }
139
140        Ok((stmt, param_count))
141    }
142
143    fn dispatch(&mut self, stmt: Statement) -> Result<ExecutionResult> {
144        match stmt {
145            Statement::Begin => {
146                if self.active_txn.is_some() {
147                    return Err(SqlError::TransactionAlreadyActive);
148                }
149                let wtx = self.db.begin_write().map_err(SqlError::Storage)?;
150                self.active_txn = Some(wtx);
151                Ok(ExecutionResult::Ok)
152            }
153            Statement::Commit => {
154                let wtx = self.active_txn.take()
155                    .ok_or(SqlError::NoActiveTransaction)?;
156                wtx.commit().map_err(SqlError::Storage)?;
157                Ok(ExecutionResult::Ok)
158            }
159            Statement::Rollback => {
160                let wtx = self.active_txn.take()
161                    .ok_or(SqlError::NoActiveTransaction)?;
162                wtx.abort();
163                self.schema = SchemaManager::load(self.db)?;
164                Ok(ExecutionResult::Ok)
165            }
166            _ => {
167                if let Some(ref mut wtx) = self.active_txn {
168                    executor::execute_in_txn(wtx, &mut self.schema, &stmt)
169                } else {
170                    executor::execute(self.db, &mut self.schema, &stmt)
171                }
172            }
173        }
174    }
175}