citadel_sql/
connection.rs1use std::num::NonZeroUsize;
4
5use lru::LruCache;
6
7use citadel::Database;
8use citadel_txn::write_txn::WriteTxn;
9
10use crate::error::{Result, SqlError};
11use crate::executor;
12use crate::parser;
13use crate::parser::Statement;
14use crate::schema::SchemaManager;
15use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
16
/// Number of parsed statements each connection's LRU statement cache holds
/// when created by `Connection::open`.
const DEFAULT_CACHE_CAPACITY: usize = 64;
18
19struct CacheEntry {
20 stmt: Statement,
21 schema_gen: u64,
22 param_count: usize,
23}
24
25pub struct Connection<'a> {
33 db: &'a Database,
34 schema: SchemaManager,
35 active_txn: Option<WriteTxn<'a>>,
36 stmt_cache: LruCache<String, CacheEntry>,
37}
38
39impl<'a> Connection<'a> {
40 pub fn open(db: &'a Database) -> Result<Self> {
42 let schema = SchemaManager::load(db)?;
43 let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
44 Ok(Self {
45 db,
46 schema,
47 active_txn: None,
48 stmt_cache,
49 })
50 }
51
52 pub fn execute(&mut self, sql: &str) -> Result<ExecutionResult> {
54 self.execute_params(sql, &[])
55 }
56
57 pub fn execute_params(&mut self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
59 let (stmt, param_count) = self.get_or_parse(sql)?;
60
61 if param_count != params.len() {
62 return Err(SqlError::ParameterCountMismatch {
63 expected: param_count,
64 got: params.len(),
65 });
66 }
67
68 let bound = if param_count > 0 {
69 parser::bind_params(&stmt, params)?
70 } else {
71 stmt
72 };
73
74 self.dispatch(bound)
75 }
76
77 pub fn query(&mut self, sql: &str) -> Result<QueryResult> {
79 self.query_params(sql, &[])
80 }
81
82 pub fn query_params(&mut self, sql: &str, params: &[Value]) -> Result<QueryResult> {
84 match self.execute_params(sql, params)? {
85 ExecutionResult::Query(qr) => Ok(qr),
86 ExecutionResult::RowsAffected(n) => Ok(QueryResult {
87 columns: vec!["rows_affected".into()],
88 rows: vec![vec![Value::Integer(n as i64)]],
89 }),
90 ExecutionResult::Ok => Ok(QueryResult {
91 columns: vec![],
92 rows: vec![],
93 }),
94 }
95 }
96
97 pub fn tables(&self) -> Vec<&str> {
99 self.schema.table_names()
100 }
101
102 pub fn in_transaction(&self) -> bool {
104 self.active_txn.is_some()
105 }
106
107 pub fn table_schema(&self, name: &str) -> Option<&TableSchema> {
109 self.schema.get(name)
110 }
111
112 pub fn refresh_schema(&mut self) -> Result<()> {
114 self.schema = SchemaManager::load(self.db)?;
115 Ok(())
116 }
117
118 fn get_or_parse(&mut self, sql: &str) -> Result<(Statement, usize)> {
119 let gen = self.schema.generation();
120
121 if let Some(entry) = self.stmt_cache.get(sql) {
122 if entry.schema_gen == gen {
123 return Ok((entry.stmt.clone(), entry.param_count));
124 }
125 }
126
127 let stmt = parser::parse_sql(sql)?;
128 let param_count = parser::count_params(&stmt);
129
130 let cacheable = !matches!(
131 stmt,
132 Statement::CreateTable(_)
133 | Statement::DropTable(_)
134 | Statement::CreateIndex(_)
135 | Statement::DropIndex(_)
136 | Statement::Begin
137 | Statement::Commit
138 | Statement::Rollback
139 );
140
141 if cacheable {
142 self.stmt_cache.put(
143 sql.to_string(),
144 CacheEntry {
145 stmt: stmt.clone(),
146 schema_gen: gen,
147 param_count,
148 },
149 );
150 }
151
152 Ok((stmt, param_count))
153 }
154
155 fn dispatch(&mut self, stmt: Statement) -> Result<ExecutionResult> {
156 match stmt {
157 Statement::Begin => {
158 if self.active_txn.is_some() {
159 return Err(SqlError::TransactionAlreadyActive);
160 }
161 let wtx = self.db.begin_write().map_err(SqlError::Storage)?;
162 self.active_txn = Some(wtx);
163 Ok(ExecutionResult::Ok)
164 }
165 Statement::Commit => {
166 let wtx = self
167 .active_txn
168 .take()
169 .ok_or(SqlError::NoActiveTransaction)?;
170 wtx.commit().map_err(SqlError::Storage)?;
171 Ok(ExecutionResult::Ok)
172 }
173 Statement::Rollback => {
174 let wtx = self
175 .active_txn
176 .take()
177 .ok_or(SqlError::NoActiveTransaction)?;
178 wtx.abort();
179 self.schema = SchemaManager::load(self.db)?;
180 Ok(ExecutionResult::Ok)
181 }
182 _ => {
183 if let Some(ref mut wtx) = self.active_txn {
184 executor::execute_in_txn(wtx, &mut self.schema, &stmt)
185 } else {
186 executor::execute(self.db, &mut self.schema, &stmt)
187 }
188 }
189 }
190 }
191}