Skip to main content

citadel_sql/
connection.rs

1//! Public SQL connection API.
2
3use std::num::NonZeroUsize;
4
5use lru::LruCache;
6
7use citadel::Database;
8use citadel_txn::write_txn::WriteTxn;
9
10use crate::error::{Result, SqlError};
11use crate::executor;
12use crate::parser;
13use crate::parser::Statement;
14use crate::schema::SchemaManager;
15use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
16
/// Number of parsed statements each connection keeps in its LRU statement cache.
const DEFAULT_CACHE_CAPACITY: usize = 64;
18
19struct CacheEntry {
20    stmt: Statement,
21    schema_gen: u64,
22    param_count: usize,
23}
24
25/// A SQL connection wrapping a Citadel database.
26///
27/// Supports explicit transactions via BEGIN / COMMIT / ROLLBACK.
28/// Without BEGIN, each statement runs in auto-commit mode.
29///
30/// Caches parsed SQL statements in an LRU cache keyed by SQL string.
31/// Cache entries are invalidated when the schema changes (DDL operations).
32pub struct Connection<'a> {
33    db: &'a Database,
34    schema: SchemaManager,
35    active_txn: Option<WriteTxn<'a>>,
36    stmt_cache: LruCache<String, CacheEntry>,
37}
38
39impl<'a> Connection<'a> {
40    /// Open a SQL connection to a database.
41    pub fn open(db: &'a Database) -> Result<Self> {
42        let schema = SchemaManager::load(db)?;
43        let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
44        Ok(Self {
45            db,
46            schema,
47            active_txn: None,
48            stmt_cache,
49        })
50    }
51
52    /// Execute a SQL statement. Returns the result.
53    pub fn execute(&mut self, sql: &str) -> Result<ExecutionResult> {
54        self.execute_params(sql, &[])
55    }
56
57    /// Execute a SQL statement with positional parameters ($1, $2, ...).
58    pub fn execute_params(&mut self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
59        let (stmt, param_count) = self.get_or_parse(sql)?;
60
61        if param_count != params.len() {
62            return Err(SqlError::ParameterCountMismatch {
63                expected: param_count,
64                got: params.len(),
65            });
66        }
67
68        let bound = if param_count > 0 {
69            parser::bind_params(&stmt, params)?
70        } else {
71            stmt
72        };
73
74        self.dispatch(bound)
75    }
76
77    /// Execute a SQL query and return the result set.
78    pub fn query(&mut self, sql: &str) -> Result<QueryResult> {
79        self.query_params(sql, &[])
80    }
81
82    /// Execute a SQL query with positional parameters ($1, $2, ...).
83    pub fn query_params(&mut self, sql: &str, params: &[Value]) -> Result<QueryResult> {
84        match self.execute_params(sql, params)? {
85            ExecutionResult::Query(qr) => Ok(qr),
86            ExecutionResult::RowsAffected(n) => Ok(QueryResult {
87                columns: vec!["rows_affected".into()],
88                rows: vec![vec![Value::Integer(n as i64)]],
89            }),
90            ExecutionResult::Ok => Ok(QueryResult {
91                columns: vec![],
92                rows: vec![],
93            }),
94        }
95    }
96
97    /// List all table names.
98    pub fn tables(&self) -> Vec<&str> {
99        self.schema.table_names()
100    }
101
102    /// Returns true if an explicit transaction is active (BEGIN was issued).
103    pub fn in_transaction(&self) -> bool {
104        self.active_txn.is_some()
105    }
106
107    /// Get the schema for a named table.
108    pub fn table_schema(&self, name: &str) -> Option<&TableSchema> {
109        self.schema.get(name)
110    }
111
112    /// Reload schemas from the database.
113    pub fn refresh_schema(&mut self) -> Result<()> {
114        self.schema = SchemaManager::load(self.db)?;
115        Ok(())
116    }
117
118    fn get_or_parse(&mut self, sql: &str) -> Result<(Statement, usize)> {
119        let gen = self.schema.generation();
120
121        if let Some(entry) = self.stmt_cache.get(sql) {
122            if entry.schema_gen == gen {
123                return Ok((entry.stmt.clone(), entry.param_count));
124            }
125        }
126
127        let stmt = parser::parse_sql(sql)?;
128        let param_count = parser::count_params(&stmt);
129
130        let cacheable = !matches!(
131            stmt,
132            Statement::CreateTable(_)
133                | Statement::DropTable(_)
134                | Statement::CreateIndex(_)
135                | Statement::DropIndex(_)
136                | Statement::Begin
137                | Statement::Commit
138                | Statement::Rollback
139        );
140
141        if cacheable {
142            self.stmt_cache.put(
143                sql.to_string(),
144                CacheEntry {
145                    stmt: stmt.clone(),
146                    schema_gen: gen,
147                    param_count,
148                },
149            );
150        }
151
152        Ok((stmt, param_count))
153    }
154
155    fn dispatch(&mut self, stmt: Statement) -> Result<ExecutionResult> {
156        match stmt {
157            Statement::Begin => {
158                if self.active_txn.is_some() {
159                    return Err(SqlError::TransactionAlreadyActive);
160                }
161                let wtx = self.db.begin_write().map_err(SqlError::Storage)?;
162                self.active_txn = Some(wtx);
163                Ok(ExecutionResult::Ok)
164            }
165            Statement::Commit => {
166                let wtx = self
167                    .active_txn
168                    .take()
169                    .ok_or(SqlError::NoActiveTransaction)?;
170                wtx.commit().map_err(SqlError::Storage)?;
171                Ok(ExecutionResult::Ok)
172            }
173            Statement::Rollback => {
174                let wtx = self
175                    .active_txn
176                    .take()
177                    .ok_or(SqlError::NoActiveTransaction)?;
178                wtx.abort();
179                self.schema = SchemaManager::load(self.db)?;
180                Ok(ExecutionResult::Ok)
181            }
182            _ => {
183                if let Some(ref mut wtx) = self.active_txn {
184                    executor::execute_in_txn(wtx, &mut self.schema, &stmt)
185                } else {
186                    executor::execute(self.db, &mut self.schema, &stmt)
187                }
188            }
189        }
190    }
191}