1use crate::config::{SQLITE_URI, STATEMENT_CACHE_CAPACITY, VFS_NAME};
8use crate::db::row::{FromColumn, Row};
9use crate::db::statement::Statement;
10use crate::db::value::ToSql;
11use crate::db::{pragmas, DbError};
12use crate::sqlite_vfs::ffi;
13use std::cell::RefCell;
14use std::collections::{BTreeMap, VecDeque};
15use std::ffi::{c_char, c_int, c_void, CStr, CString};
16use std::ops::{Deref, DerefMut};
17use std::ptr::{self, NonNull};
18
19pub struct Connection {
20 raw: NonNull<ffi::sqlite3>,
21 cached: RefCell<StatementCache>,
22}
23
24pub struct CachedStatement<'connection> {
25 statement: Option<Statement<'connection>>,
26 sql: String,
27 cache: &'connection RefCell<StatementCache>,
28}
29
30struct StatementCache {
31 statements: BTreeMap<String, NonNull<ffi::sqlite3_stmt>>,
32 returned_lru: VecDeque<String>,
33}
34
35impl StatementCache {
36 fn new() -> Self {
37 Self {
38 statements: BTreeMap::new(),
39 returned_lru: VecDeque::new(),
40 }
41 }
42
43 fn take(&mut self, sql: &str) -> Option<NonNull<ffi::sqlite3_stmt>> {
44 let raw = self.statements.remove(sql)?;
45 self.returned_lru.retain(|cached_sql| cached_sql != sql);
46 Some(raw)
47 }
48
49 unsafe fn insert(&mut self, sql: String, raw: NonNull<ffi::sqlite3_stmt>) {
50 if let Some(previous) = self.statements.insert(sql.clone(), raw) {
51 ffi::sqlite3_finalize(previous.as_ptr());
52 }
53 self.returned_lru.retain(|cached_sql| cached_sql != &sql);
54 self.returned_lru.push_back(sql);
55 self.evict_over_capacity();
56 }
57
58 unsafe fn evict_over_capacity(&mut self) {
59 while self.statements.len() > STATEMENT_CACHE_CAPACITY {
60 let Some(sql) = self.returned_lru.pop_front() else {
61 return;
62 };
63 if let Some(statement) = self.statements.remove(&sql) {
64 ffi::sqlite3_finalize(statement.as_ptr());
65 }
66 }
67 }
68
69 unsafe fn finalize_all(&mut self) {
70 for (_, statement) in std::mem::take(&mut self.statements) {
71 ffi::sqlite3_finalize(statement.as_ptr());
72 }
73 self.returned_lru.clear();
74 }
75}
76
77pub fn open_read_write() -> Result<Connection, DbError> {
78 let flags = ffi::SQLITE_OPEN_READWRITE
79 | ffi::SQLITE_OPEN_CREATE
80 | ffi::SQLITE_OPEN_URI
81 | ffi::SQLITE_OPEN_NOMUTEX;
82 let connection = Connection::open(flags)?;
83 pragmas::apply_read_write(&connection)?;
84 Ok(connection)
85}
86
87pub fn open_read_only() -> Result<Connection, DbError> {
88 let flags = ffi::SQLITE_OPEN_READONLY | ffi::SQLITE_OPEN_URI | ffi::SQLITE_OPEN_NOMUTEX;
89 let connection = Connection::open(flags)?;
90 pragmas::apply_read_only(&connection)?;
91 Ok(connection)
92}
93
94impl Connection {
95 fn open(flags: c_int) -> Result<Self, DbError> {
96 let filename = CString::new(SQLITE_URI).map_err(|_| DbError::InteriorNul)?;
97 let vfs = CString::new(VFS_NAME).map_err(|_| DbError::InteriorNul)?;
98 let mut db = ptr::null_mut();
99 let rc = unsafe { ffi::sqlite3_open_v2(filename.as_ptr(), &mut db, flags, vfs.as_ptr()) };
100 let Some(raw) = NonNull::new(db) else {
101 return Err(DbError::Sqlite(
102 rc,
103 "sqlite3_open_v2 returned null".to_string(),
104 ));
105 };
106 if rc != ffi::SQLITE_OK {
107 let error = sqlite_error(raw.as_ptr(), rc);
108 unsafe {
109 ffi::sqlite3_close(raw.as_ptr());
110 }
111 return Err(error);
112 }
113 Ok(Self {
114 raw,
115 cached: RefCell::new(StatementCache::new()),
116 })
117 }
118
119 pub fn raw(&self) -> *mut ffi::sqlite3 {
120 self.raw.as_ptr()
121 }
122
123 pub fn execute_batch(&self, sql: &str) -> Result<(), DbError> {
124 let sql = CString::new(sql).map_err(|_| DbError::InteriorNul)?;
125 let mut error = ptr::null_mut();
126 let rc = unsafe {
127 ffi::sqlite3_exec(
128 self.raw.as_ptr(),
129 sql.as_ptr(),
130 None,
131 ptr::null_mut(),
132 &mut error,
133 )
134 };
135 if rc == ffi::SQLITE_OK {
136 return Ok(());
137 }
138 Err(classify_sqlite_error(rc, take_error_message(error)))
139 }
140
141 pub fn execute(&self, sql: &str, values: &[&dyn ToSql]) -> Result<(), DbError> {
142 let mut statement = self.prepare(sql)?;
143 statement.execute(values)
144 }
145
146 pub fn execute_named(&self, sql: &str, values: &[(&str, &dyn ToSql)]) -> Result<(), DbError> {
147 let mut statement = self.prepare(sql)?;
148 statement.execute_named(values)
149 }
150
151 pub fn prepare(&self, sql: &str) -> Result<Statement<'_>, DbError> {
152 let sql = CString::new(sql).map_err(|_| DbError::InteriorNul)?;
153 let mut statement = ptr::null_mut();
154 let mut tail = ptr::null();
155 let rc = unsafe {
156 ffi::sqlite3_prepare_v2(
157 self.raw.as_ptr(),
158 sql.as_ptr(),
159 -1,
160 &mut statement,
161 &mut tail,
162 )
163 };
164 if rc != ffi::SQLITE_OK {
165 return Err(sqlite_error(self.raw.as_ptr(), rc));
166 }
167 let Some(raw) = NonNull::new(statement) else {
168 return Err(DbError::EmptySql);
169 };
170 if !tail_is_empty(tail) {
171 unsafe {
172 ffi::sqlite3_finalize(raw.as_ptr());
173 }
174 return Err(DbError::TrailingSql);
175 }
176 Ok(Statement::new(self.raw.as_ptr(), raw))
177 }
178
179 pub fn prepare_cached(&self, sql: &str) -> Result<CachedStatement<'_>, DbError> {
180 if let Some(raw) = self.cached.borrow_mut().take(sql) {
181 return Ok(CachedStatement::new(
182 Statement::new(self.raw.as_ptr(), raw),
183 sql.to_string(),
184 &self.cached,
185 ));
186 }
187 let statement = self.prepare(sql)?;
188 Ok(CachedStatement::new(
189 statement,
190 sql.to_string(),
191 &self.cached,
192 ))
193 }
194
195 pub fn query_one<T, F>(&self, sql: &str, values: &[&dyn ToSql], f: F) -> Result<T, DbError>
196 where
197 F: FnOnce(&Row<'_>) -> Result<T, DbError>,
198 {
199 let mut statement = self.prepare(sql)?;
200 statement.query_one(values, f)
201 }
202
203 pub fn query_one_named<T, F>(
204 &self,
205 sql: &str,
206 values: &[(&str, &dyn ToSql)],
207 f: F,
208 ) -> Result<T, DbError>
209 where
210 F: FnOnce(&Row<'_>) -> Result<T, DbError>,
211 {
212 let mut statement = self.prepare(sql)?;
213 statement.query_one_named(values, f)
214 }
215
216 pub fn query_row<T, F>(&self, sql: &str, values: &[&dyn ToSql], f: F) -> Result<T, DbError>
220 where
221 F: FnOnce(&Row<'_>) -> Result<T, DbError>,
222 {
223 self.query_one(sql, values, f)
224 }
225
226 pub fn query_row_named<T, F>(
230 &self,
231 sql: &str,
232 values: &[(&str, &dyn ToSql)],
233 f: F,
234 ) -> Result<T, DbError>
235 where
236 F: FnOnce(&Row<'_>) -> Result<T, DbError>,
237 {
238 self.query_one_named(sql, values, f)
239 }
240
241 pub fn query_optional<T, F>(
242 &self,
243 sql: &str,
244 values: &[&dyn ToSql],
245 f: F,
246 ) -> Result<Option<T>, DbError>
247 where
248 F: FnOnce(&Row<'_>) -> Result<T, DbError>,
249 {
250 let mut statement = self.prepare(sql)?;
251 statement.query_optional(values, f)
252 }
253
254 pub fn query_optional_named<T, F>(
255 &self,
256 sql: &str,
257 values: &[(&str, &dyn ToSql)],
258 f: F,
259 ) -> Result<Option<T>, DbError>
260 where
261 F: FnOnce(&Row<'_>) -> Result<T, DbError>,
262 {
263 let mut statement = self.prepare(sql)?;
264 statement.query_optional_named(values, f)
265 }
266
267 pub fn query_all<T, F>(&self, sql: &str, values: &[&dyn ToSql], f: F) -> Result<Vec<T>, DbError>
268 where
269 F: FnMut(&Row<'_>) -> Result<T, DbError>,
270 {
271 let mut statement = self.prepare(sql)?;
272 statement.query_all(values, f)
273 }
274
275 pub fn query_all_named<T, F>(
276 &self,
277 sql: &str,
278 values: &[(&str, &dyn ToSql)],
279 f: F,
280 ) -> Result<Vec<T>, DbError>
281 where
282 F: FnMut(&Row<'_>) -> Result<T, DbError>,
283 {
284 let mut statement = self.prepare(sql)?;
285 statement.query_all_named(values, f)
286 }
287
288 pub fn query_map<T, F>(&self, sql: &str, values: &[&dyn ToSql], f: F) -> Result<Vec<T>, DbError>
294 where
295 F: FnMut(&Row<'_>) -> Result<T, DbError>,
296 {
297 self.query_all(sql, values, f)
298 }
299
300 pub fn query_map_named<T, F>(
306 &self,
307 sql: &str,
308 values: &[(&str, &dyn ToSql)],
309 f: F,
310 ) -> Result<Vec<T>, DbError>
311 where
312 F: FnMut(&Row<'_>) -> Result<T, DbError>,
313 {
314 self.query_all_named(sql, values, f)
315 }
316
317 pub fn exists(&self, sql: &str, values: &[&dyn ToSql]) -> Result<bool, DbError> {
318 self.query_optional(sql, values, |row| row.get::<i64>(0))
319 .map(|value| value.unwrap_or(0) != 0)
320 }
321
322 pub fn query_scalar<T: FromColumn>(
323 &self,
324 sql: &str,
325 values: &[&dyn ToSql],
326 ) -> Result<T, DbError> {
327 self.query_one(sql, values, |row| row.get(0))
328 }
329
330 pub fn query_scalar_named<T: FromColumn>(
331 &self,
332 sql: &str,
333 values: &[(&str, &dyn ToSql)],
334 ) -> Result<T, DbError> {
335 self.query_one_named(sql, values, |row| row.get(0))
336 }
337
338 pub fn query_optional_scalar<T: FromColumn>(
339 &self,
340 sql: &str,
341 values: &[&dyn ToSql],
342 ) -> Result<Option<T>, DbError> {
343 self.query_optional(sql, values, |row| row.get(0))
344 }
345
346 pub fn query_optional_string_text(
347 &self,
348 sql: &str,
349 value: &str,
350 ) -> Result<Option<String>, DbError> {
351 let mut statement = self.prepare(sql)?;
352 statement.query_optional_string_text(value)
353 }
354
355 pub fn query_optional_scalar_named<T: FromColumn>(
356 &self,
357 sql: &str,
358 values: &[(&str, &dyn ToSql)],
359 ) -> Result<Option<T>, DbError> {
360 self.query_optional_named(sql, values, |row| row.get(0))
361 }
362
363 pub fn query_column<T: FromColumn>(
364 &self,
365 sql: &str,
366 values: &[&dyn ToSql],
367 ) -> Result<Vec<T>, DbError> {
368 self.query_all(sql, values, |row| row.get(0))
369 }
370
371 pub fn query_column_named<T: FromColumn>(
372 &self,
373 sql: &str,
374 values: &[(&str, &dyn ToSql)],
375 ) -> Result<Vec<T>, DbError> {
376 self.query_all_named(sql, values, |row| row.get(0))
377 }
378}
379
380impl Drop for Connection {
381 fn drop(&mut self) {
382 unsafe {
383 self.cached.get_mut().finalize_all();
384 ffi::sqlite3_close(self.raw.as_ptr());
385 }
386 }
387}
388
389impl<'connection> CachedStatement<'connection> {
390 fn new(
391 statement: Statement<'connection>,
392 sql: String,
393 cache: &'connection RefCell<StatementCache>,
394 ) -> Self {
395 Self {
396 statement: Some(statement),
397 sql,
398 cache,
399 }
400 }
401
402 pub fn discard(mut self) {
403 if let Some(statement) = self.statement.take() {
404 unsafe {
405 ffi::sqlite3_finalize(statement.into_raw().as_ptr());
406 }
407 }
408 }
409}
410
411impl<'connection> Deref for CachedStatement<'connection> {
412 type Target = Statement<'connection>;
413
414 fn deref(&self) -> &Self::Target {
415 self.statement
416 .as_ref()
417 .expect("cached statement is present")
418 }
419}
420
421impl DerefMut for CachedStatement<'_> {
422 fn deref_mut(&mut self) -> &mut Self::Target {
423 self.statement
424 .as_mut()
425 .expect("cached statement is present")
426 }
427}
428
429impl Drop for CachedStatement<'_> {
430 fn drop(&mut self) {
431 let Some(statement) = self.statement.take() else {
432 return;
433 };
434 let raw = statement.into_raw();
435 unsafe {
436 ffi::sqlite3_reset(raw.as_ptr());
437 ffi::sqlite3_clear_bindings(raw.as_ptr());
438 self.cache.borrow_mut().insert(self.sql.clone(), raw);
439 }
440 }
441}
442
443pub(crate) fn sqlite_error(db: *mut ffi::sqlite3, code: c_int) -> DbError {
444 let message = unsafe {
445 let ptr = ffi::sqlite3_errmsg(db);
446 if ptr.is_null() {
447 "unknown sqlite error".to_string()
448 } else {
449 CStr::from_ptr(ptr).to_string_lossy().into_owned()
450 }
451 };
452 classify_sqlite_error(code, message)
453}
454
455fn classify_sqlite_error(code: c_int, message: String) -> DbError {
456 if code == ffi::SQLITE_CONSTRAINT {
457 DbError::Constraint(message)
458 } else {
459 DbError::Sqlite(code, message)
460 }
461}
462
463fn take_error_message(error: *mut c_char) -> String {
464 if error.is_null() {
465 return "unknown sqlite error".to_string();
466 }
467 let message = unsafe { CStr::from_ptr(error).to_string_lossy().into_owned() };
468 unsafe {
469 ffi::sqlite3_free(error.cast::<c_void>());
470 }
471 message
472}
473
/// Returns `true` when the text left after `sqlite3_prepare_v2` contains no
/// further SQL statement.
///
/// `tail` is the `pzTail` pointer from `sqlite3_prepare_v2`; a null pointer
/// means nothing is left. The remainder counts as empty when it consists only
/// of the following, skipped the way SQLite's tokenizer skips them:
///
/// * ASCII whitespace;
/// * empty statements (`;`);
/// * `-- line comments`;
/// * `/* block comments */`. An unterminated block comment runs to the end of
///   the input, but a bare trailing `/*` is not a comment.
///
/// Anything else is another statement, and the function returns `false`.
fn tail_is_empty(tail: *const c_char) -> bool {
    if tail.is_null() {
        return true;
    }
    // SAFETY: `tail` points into the NUL-terminated buffer that was passed to
    // `sqlite3_prepare_v2`, which the caller still owns.
    let bytes = unsafe { CStr::from_ptr(tail).to_bytes() };
    let mut index = 0;
    while let Some(&byte) = bytes.get(index) {
        let next = bytes.get(index + 1).copied();
        index = match (byte, next) {
            (b';', _) => index + 1,
            (other, _) if other.is_ascii_whitespace() => index + 1,
            // Line comment: skip through the next newline, or to the end.
            (b'-', Some(b'-')) => bytes[index..]
                .iter()
                .position(|&c| c == b'\n')
                .map_or(bytes.len(), |offset| index + offset + 1),
            // Block comment: skip past the closing `*/`, or to the end if it is
            // unterminated. SQLite only treats `/*` as a comment when at least
            // one byte follows it.
            (b'/', Some(b'*')) if index + 2 < bytes.len() => bytes[index + 2..]
                .windows(2)
                .position(|pair| pair == b"*/")
                .map_or(bytes.len(), |offset| index + 2 + offset + 2),
            _ => return false,
        };
    }
    true
}
481
#[cfg(test)]
mod tests {
    use super::open_read_write;
    use crate::config::STATEMENT_CACHE_CAPACITY;
    use crate::sqlite_vfs::{lock, stable_blob};
    use crate::stable::memory;
    use crate::Db;
    use serial_test::serial;

    /// Rolls back VFS, stable-memory and lock state, then re-initialises the
    /// database on the test memory.
    fn reset() {
        stable_blob::rollback_update();
        stable_blob::invalidate_read_cache();
        memory::reset_for_tests();
        lock::reset_for_tests();
        Db::init(memory::memory_for_tests()).unwrap();
    }

    #[test]
    #[serial]
    fn cached_statements_are_lru_bounded() {
        reset();
        let conn = open_read_write().unwrap();
        let total = STATEMENT_CACHE_CAPACITY + 8;

        for n in 0..total {
            let query = format!("SELECT {n}");
            // Each statement returns to the cache at the end of the iteration.
            let mut cached = conn.prepare_cached(&query).unwrap();
            let got = cached.query_scalar::<i64>(crate::params![]).unwrap();
            assert_eq!(got, i64::try_from(n).unwrap());
        }

        let cache = conn.cached.borrow();
        assert_eq!(cache.statements.len(), STATEMENT_CACHE_CAPACITY);
        // The oldest entry was evicted and the newest one survived.
        assert!(!cache.statements.contains_key("SELECT 0"));
        let newest = format!("SELECT {}", total - 1);
        assert!(cache.statements.contains_key(&newest));
    }

    #[test]
    #[serial]
    fn discarded_cached_statement_is_finalized_not_cached() {
        reset();
        let conn = open_read_write().unwrap();

        conn.prepare_cached("SELECT 1").unwrap().discard();

        let cached_count = conn.cached.borrow().statements.len();
        assert_eq!(cached_count, 0);
    }
}