citadeldb_sql/
connection.rs1use std::num::NonZeroUsize;
4
5use lru::LruCache;
6
7use citadel::Database;
8use citadel_txn::write_txn::WriteTxn;
9
10use crate::error::{Result, SqlError};
11use crate::executor;
12use crate::parser;
13use crate::parser::Statement;
14use crate::schema::SchemaManager;
15use crate::types::{ExecutionResult, QueryResult, TableSchema, Value};
16
/// Number of parsed statements each connection keeps in its LRU statement
/// cache before the least recently used entry is evicted. Must be non-zero.
const DEFAULT_CACHE_CAPACITY: usize = 64;
18
19struct CacheEntry {
20 stmt: Statement,
21 schema_gen: u64,
22 param_count: usize,
23}
24
25pub struct Connection<'a> {
33 db: &'a Database,
34 schema: SchemaManager,
35 active_txn: Option<WriteTxn<'a>>,
36 stmt_cache: LruCache<String, CacheEntry>,
37}
38
39impl<'a> Connection<'a> {
40 pub fn open(db: &'a Database) -> Result<Self> {
42 let schema = SchemaManager::load(db)?;
43 let stmt_cache = LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).unwrap());
44 Ok(Self { db, schema, active_txn: None, stmt_cache })
45 }
46
47 pub fn execute(&mut self, sql: &str) -> Result<ExecutionResult> {
49 self.execute_params(sql, &[])
50 }
51
52 pub fn execute_params(&mut self, sql: &str, params: &[Value]) -> Result<ExecutionResult> {
54 let (stmt, param_count) = self.get_or_parse(sql)?;
55
56 if param_count != params.len() {
57 return Err(SqlError::ParameterCountMismatch {
58 expected: param_count,
59 got: params.len(),
60 });
61 }
62
63 let bound = if param_count > 0 {
64 parser::bind_params(&stmt, params)?
65 } else {
66 stmt
67 };
68
69 self.dispatch(bound)
70 }
71
72 pub fn query(&mut self, sql: &str) -> Result<QueryResult> {
74 self.query_params(sql, &[])
75 }
76
77 pub fn query_params(&mut self, sql: &str, params: &[Value]) -> Result<QueryResult> {
79 match self.execute_params(sql, params)? {
80 ExecutionResult::Query(qr) => Ok(qr),
81 ExecutionResult::RowsAffected(n) => Ok(QueryResult {
82 columns: vec!["rows_affected".into()],
83 rows: vec![vec![Value::Integer(n as i64)]],
84 }),
85 ExecutionResult::Ok => Ok(QueryResult {
86 columns: vec![],
87 rows: vec![],
88 }),
89 }
90 }
91
92 pub fn tables(&self) -> Vec<&str> {
94 self.schema.table_names()
95 }
96
97 pub fn in_transaction(&self) -> bool {
99 self.active_txn.is_some()
100 }
101
102 pub fn table_schema(&self, name: &str) -> Option<&TableSchema> {
104 self.schema.get(name)
105 }
106
107 pub fn refresh_schema(&mut self) -> Result<()> {
109 self.schema = SchemaManager::load(self.db)?;
110 Ok(())
111 }
112
113 fn get_or_parse(&mut self, sql: &str) -> Result<(Statement, usize)> {
114 let gen = self.schema.generation();
115
116 if let Some(entry) = self.stmt_cache.get(sql) {
117 if entry.schema_gen == gen {
118 return Ok((entry.stmt.clone(), entry.param_count));
119 }
120 }
121
122 let stmt = parser::parse_sql(sql)?;
123 let param_count = parser::count_params(&stmt);
124
125 let cacheable = !matches!(
126 stmt,
127 Statement::CreateTable(_) | Statement::DropTable(_)
128 | Statement::CreateIndex(_) | Statement::DropIndex(_)
129 | Statement::Begin | Statement::Commit | Statement::Rollback
130 );
131
132 if cacheable {
133 self.stmt_cache.put(sql.to_string(), CacheEntry {
134 stmt: stmt.clone(),
135 schema_gen: gen,
136 param_count,
137 });
138 }
139
140 Ok((stmt, param_count))
141 }
142
143 fn dispatch(&mut self, stmt: Statement) -> Result<ExecutionResult> {
144 match stmt {
145 Statement::Begin => {
146 if self.active_txn.is_some() {
147 return Err(SqlError::TransactionAlreadyActive);
148 }
149 let wtx = self.db.begin_write().map_err(SqlError::Storage)?;
150 self.active_txn = Some(wtx);
151 Ok(ExecutionResult::Ok)
152 }
153 Statement::Commit => {
154 let wtx = self.active_txn.take()
155 .ok_or(SqlError::NoActiveTransaction)?;
156 wtx.commit().map_err(SqlError::Storage)?;
157 Ok(ExecutionResult::Ok)
158 }
159 Statement::Rollback => {
160 let wtx = self.active_txn.take()
161 .ok_or(SqlError::NoActiveTransaction)?;
162 wtx.abort();
163 self.schema = SchemaManager::load(self.db)?;
164 Ok(ExecutionResult::Ok)
165 }
166 _ => {
167 if let Some(ref mut wtx) = self.active_txn {
168 executor::execute_in_txn(wtx, &mut self.schema, &stmt)
169 } else {
170 executor::execute(self.db, &mut self.schema, &stmt)
171 }
172 }
173 }
174 }
175}