1use crate::config::{SQLITE_URI_NUL, STATEMENT_CACHE_CAPACITY, VFS_NAME_NUL};
8use crate::db::row::{FromColumn, Row};
9use crate::db::statement::Statement;
10use crate::db::value::ToSql;
11use crate::db::{pragmas, DbError};
12use crate::sqlite_vfs::ffi;
13use std::cell::RefCell;
14use std::ffi::{c_char, c_int, c_void, CStr, CString};
15use std::ops::{Deref, DerefMut};
16use std::ptr::{self, NonNull};
17
/// An open SQLite database connection together with its prepared-statement cache.
///
/// Dropping the connection finalizes every cached statement and then closes the
/// handle. Uncached [`Statement`]s and [`CachedStatement`]s borrow the connection,
/// so they are always dropped before it.
pub struct Connection {
    // Owned `sqlite3` handle produced by `sqlite3_open_v2`; closed in `Drop`.
    raw: NonNull<ffi::sqlite3>,
    // Idle prepared statements reused by `prepare_cached`. `RefCell` because the
    // cache is mutated through `&self` (and from `CachedStatement::drop`).
    cached: RefCell<StatementCache>,
}
22
/// A prepared statement checked out of a connection's statement cache.
///
/// Dereferences to [`Statement`]. When dropped, the statement is reset, its
/// bindings are cleared, and it is returned to the cache under `sql`; call
/// [`CachedStatement::discard`] to finalize it instead of caching it.
pub struct CachedStatement<'connection> {
    // The checked-out statement; only `None` after `discard` or `drop` took it.
    statement: Option<Statement<'connection>>,
    // SQL text used as the cache key when the statement is handed back.
    sql: String,
    // Cache owned by the connection the statement was prepared on.
    cache: &'connection RefCell<StatementCache>,
}
28
/// Bounded, recency-ordered cache of idle prepared statements keyed by SQL text.
///
/// The front of `statements` is the least recently returned entry and is the first
/// one finalized once more than `STATEMENT_CACHE_CAPACITY` entries are held.
struct StatementCache {
    // Idle statements, oldest first; `insert` keeps at most one entry per SQL string.
    statements: Vec<CachedEntry>,
}
32
/// One idle statement held by [`StatementCache`].
struct CachedEntry {
    // SQL text the statement was prepared from (the cache key).
    sql: String,
    // Statement handle owned by the cache until taken back out or finalized.
    statement: NonNull<ffi::sqlite3_stmt>,
    // Parameter count recorded when the statement was cached; handed back to
    // `Statement::from_cached_raw` on reuse.
    parameter_count: usize,
}
38
39impl StatementCache {
40 fn new() -> Self {
41 Self {
42 statements: Vec::new(),
43 }
44 }
45
46 fn take(&mut self, sql: &str) -> Option<(String, NonNull<ffi::sqlite3_stmt>, usize)> {
47 if let Some(entry) = self.statements.last() {
48 if entry.sql == sql {
49 let entry = self.statements.pop().expect("last cached statement exists");
50 return Some((entry.sql, entry.statement, entry.parameter_count));
51 }
52 }
53 let index = self.statements.iter().position(|entry| entry.sql == sql)?;
54 let entry = self.statements.remove(index);
55 Some((entry.sql, entry.statement, entry.parameter_count))
56 }
57
58 unsafe fn insert(
59 &mut self,
60 sql: String,
61 raw: NonNull<ffi::sqlite3_stmt>,
62 parameter_count: usize,
63 ) {
64 if let Some(index) = self.statements.iter().position(|entry| entry.sql == sql) {
65 let previous = self.statements.remove(index);
66 ffi::sqlite3_finalize(previous.statement.as_ptr());
67 }
68 self.statements.push(CachedEntry {
69 sql,
70 statement: raw,
71 parameter_count,
72 });
73 self.evict_over_capacity();
74 }
75
76 unsafe fn evict_over_capacity(&mut self) {
77 while self.statements.len() > STATEMENT_CACHE_CAPACITY {
78 let entry = self.statements.remove(0);
79 ffi::sqlite3_finalize(entry.statement.as_ptr());
80 }
81 }
82
83 unsafe fn finalize_all(&mut self) {
84 for entry in std::mem::take(&mut self.statements) {
85 ffi::sqlite3_finalize(entry.statement.as_ptr());
86 }
87 }
88}
89
90pub fn open_read_write() -> Result<Connection, DbError> {
91 open_read_write_with_page_size(true)
92}
93
94pub(crate) fn open_read_write_existing() -> Result<Connection, DbError> {
95 open_read_write_with_page_size(false)
96}
97
98fn open_read_write_with_page_size(apply_page_size: bool) -> Result<Connection, DbError> {
99 let flags = ffi::SQLITE_OPEN_READWRITE
100 | ffi::SQLITE_OPEN_CREATE
101 | ffi::SQLITE_OPEN_URI
102 | ffi::SQLITE_OPEN_NOMUTEX;
103 let connection = Connection::open(flags)?;
104 pragmas::apply_read_write(&connection, apply_page_size)?;
105 Ok(connection)
106}
107
108pub fn open_read_only() -> Result<Connection, DbError> {
109 let flags = ffi::SQLITE_OPEN_READONLY | ffi::SQLITE_OPEN_URI | ffi::SQLITE_OPEN_NOMUTEX;
110 let connection = Connection::open(flags)?;
111 pragmas::apply_read_only(&connection)?;
112 Ok(connection)
113}
114
115impl Connection {
116 fn open(flags: c_int) -> Result<Self, DbError> {
117 debug_assert!(CStr::from_bytes_with_nul(SQLITE_URI_NUL).is_ok());
118 debug_assert!(CStr::from_bytes_with_nul(VFS_NAME_NUL).is_ok());
119 let filename = unsafe { CStr::from_bytes_with_nul_unchecked(SQLITE_URI_NUL) };
120 let vfs = unsafe { CStr::from_bytes_with_nul_unchecked(VFS_NAME_NUL) };
121 let mut db = ptr::null_mut();
122 let rc = unsafe { ffi::sqlite3_open_v2(filename.as_ptr(), &mut db, flags, vfs.as_ptr()) };
123 let Some(raw) = NonNull::new(db) else {
124 return Err(DbError::Sqlite(
125 rc,
126 "sqlite3_open_v2 returned null".to_string(),
127 ));
128 };
129 if rc != ffi::SQLITE_OK {
130 let error = sqlite_error(raw.as_ptr(), rc);
131 unsafe {
132 ffi::sqlite3_close(raw.as_ptr());
133 }
134 return Err(error);
135 }
136 Ok(Self {
137 raw,
138 cached: RefCell::new(StatementCache::new()),
139 })
140 }
141
142 pub fn raw(&self) -> *mut ffi::sqlite3 {
143 self.raw.as_ptr()
144 }
145
146 pub fn execute_batch(&self, sql: &str) -> Result<(), DbError> {
147 let sql = CString::new(sql).map_err(|_| DbError::InteriorNul)?;
148 self.execute_batch_cstr(&sql)
149 }
150
151 pub(crate) fn execute_batch_nul_terminated(&self, sql: &'static [u8]) -> Result<(), DbError> {
152 debug_assert!(CStr::from_bytes_with_nul(sql).is_ok());
153 let sql = unsafe { CStr::from_bytes_with_nul_unchecked(sql) };
154 self.execute_batch_cstr(sql)
155 }
156
157 fn execute_batch_cstr(&self, sql: &CStr) -> Result<(), DbError> {
158 let mut error = ptr::null_mut();
159 let rc = unsafe {
160 ffi::sqlite3_exec(
161 self.raw.as_ptr(),
162 sql.as_ptr(),
163 None,
164 ptr::null_mut(),
165 &mut error,
166 )
167 };
168 if rc == ffi::SQLITE_OK {
169 return Ok(());
170 }
171 Err(classify_sqlite_error(rc, take_error_message(error)))
172 }
173
174 pub fn execute(&self, sql: &str, values: &[&dyn ToSql]) -> Result<(), DbError> {
175 let mut statement = self.prepare(sql)?;
176 statement.execute(values)
177 }
178
179 pub fn execute_named(&self, sql: &str, values: &[(&str, &dyn ToSql)]) -> Result<(), DbError> {
180 let mut statement = self.prepare(sql)?;
181 statement.execute_named(values)
182 }
183
184 pub fn execute_text_text(&self, sql: &str, first: &str, second: &str) -> Result<(), DbError> {
185 let mut statement = self.prepare(sql)?;
186 statement.execute_text_text(first, second)
187 }
188
189 #[inline(always)]
190 pub fn changes(&self) -> u64 {
191 unsafe { ffi::sqlite3_changes64(self.raw.as_ptr()) as u64 }
192 }
193
194 pub fn prepare(&self, sql: &str) -> Result<Statement<'_>, DbError> {
195 let sql = CString::new(sql).map_err(|_| DbError::InteriorNul)?;
196 let mut statement = ptr::null_mut();
197 let mut tail = ptr::null();
198 let rc = unsafe {
199 ffi::sqlite3_prepare_v2(
200 self.raw.as_ptr(),
201 sql.as_ptr(),
202 -1,
203 &mut statement,
204 &mut tail,
205 )
206 };
207 if rc != ffi::SQLITE_OK {
208 return Err(sqlite_error(self.raw.as_ptr(), rc));
209 }
210 let Some(raw) = NonNull::new(statement) else {
211 return Err(DbError::EmptySql);
212 };
213 if !tail_is_empty(tail) {
214 unsafe {
215 ffi::sqlite3_finalize(raw.as_ptr());
216 }
217 return Err(DbError::TrailingSql);
218 }
219 Ok(Statement::new(self.raw.as_ptr(), raw))
220 }
221
222 pub fn prepare_cached(&self, sql: &str) -> Result<CachedStatement<'_>, DbError> {
223 if let Some((cached_sql, raw, parameter_count)) = self.cached.borrow_mut().take(sql) {
224 return Ok(CachedStatement::new(
225 Statement::from_cached_raw(self.raw.as_ptr(), raw, parameter_count),
226 cached_sql,
227 &self.cached,
228 ));
229 }
230 let statement = self.prepare(sql)?;
231 Ok(CachedStatement::new(
232 statement,
233 sql.to_string(),
234 &self.cached,
235 ))
236 }
237
238 pub fn query_one<T, F>(&self, sql: &str, values: &[&dyn ToSql], f: F) -> Result<T, DbError>
239 where
240 F: FnOnce(&Row<'_>) -> Result<T, DbError>,
241 {
242 let mut statement = self.prepare(sql)?;
243 statement.query_one(values, f)
244 }
245
246 pub fn query_one_named<T, F>(
247 &self,
248 sql: &str,
249 values: &[(&str, &dyn ToSql)],
250 f: F,
251 ) -> Result<T, DbError>
252 where
253 F: FnOnce(&Row<'_>) -> Result<T, DbError>,
254 {
255 let mut statement = self.prepare(sql)?;
256 statement.query_one_named(values, f)
257 }
258
259 pub fn query_row<T, F>(&self, sql: &str, values: &[&dyn ToSql], f: F) -> Result<T, DbError>
263 where
264 F: FnOnce(&Row<'_>) -> Result<T, DbError>,
265 {
266 self.query_one(sql, values, f)
267 }
268
269 pub fn query_row_named<T, F>(
273 &self,
274 sql: &str,
275 values: &[(&str, &dyn ToSql)],
276 f: F,
277 ) -> Result<T, DbError>
278 where
279 F: FnOnce(&Row<'_>) -> Result<T, DbError>,
280 {
281 self.query_one_named(sql, values, f)
282 }
283
284 pub fn query_optional<T, F>(
285 &self,
286 sql: &str,
287 values: &[&dyn ToSql],
288 f: F,
289 ) -> Result<Option<T>, DbError>
290 where
291 F: FnOnce(&Row<'_>) -> Result<T, DbError>,
292 {
293 let mut statement = self.prepare(sql)?;
294 statement.query_optional(values, f)
295 }
296
297 pub fn query_optional_named<T, F>(
298 &self,
299 sql: &str,
300 values: &[(&str, &dyn ToSql)],
301 f: F,
302 ) -> Result<Option<T>, DbError>
303 where
304 F: FnOnce(&Row<'_>) -> Result<T, DbError>,
305 {
306 let mut statement = self.prepare(sql)?;
307 statement.query_optional_named(values, f)
308 }
309
310 pub fn query_all<T, F>(&self, sql: &str, values: &[&dyn ToSql], f: F) -> Result<Vec<T>, DbError>
311 where
312 F: FnMut(&Row<'_>) -> Result<T, DbError>,
313 {
314 let mut statement = self.prepare(sql)?;
315 statement.query_all(values, f)
316 }
317
318 pub fn query_all_named<T, F>(
319 &self,
320 sql: &str,
321 values: &[(&str, &dyn ToSql)],
322 f: F,
323 ) -> Result<Vec<T>, DbError>
324 where
325 F: FnMut(&Row<'_>) -> Result<T, DbError>,
326 {
327 let mut statement = self.prepare(sql)?;
328 statement.query_all_named(values, f)
329 }
330
331 pub fn query_map<T, F>(&self, sql: &str, values: &[&dyn ToSql], f: F) -> Result<Vec<T>, DbError>
337 where
338 F: FnMut(&Row<'_>) -> Result<T, DbError>,
339 {
340 self.query_all(sql, values, f)
341 }
342
343 pub fn query_map_named<T, F>(
349 &self,
350 sql: &str,
351 values: &[(&str, &dyn ToSql)],
352 f: F,
353 ) -> Result<Vec<T>, DbError>
354 where
355 F: FnMut(&Row<'_>) -> Result<T, DbError>,
356 {
357 self.query_all_named(sql, values, f)
358 }
359
360 pub fn exists(&self, sql: &str, values: &[&dyn ToSql]) -> Result<bool, DbError> {
361 self.query_optional(sql, values, |row| row.get::<i64>(0))
362 .map(|value| value.unwrap_or(0) != 0)
363 }
364
365 pub fn query_scalar<T: FromColumn>(
366 &self,
367 sql: &str,
368 values: &[&dyn ToSql],
369 ) -> Result<T, DbError> {
370 self.query_one(sql, values, |row| row.get(0))
371 }
372
373 pub fn query_scalar_named<T: FromColumn>(
374 &self,
375 sql: &str,
376 values: &[(&str, &dyn ToSql)],
377 ) -> Result<T, DbError> {
378 self.query_one_named(sql, values, |row| row.get(0))
379 }
380
381 pub fn query_optional_scalar<T: FromColumn>(
382 &self,
383 sql: &str,
384 values: &[&dyn ToSql],
385 ) -> Result<Option<T>, DbError> {
386 self.query_optional(sql, values, |row| row.get(0))
387 }
388
389 pub fn query_optional_string_text(
390 &self,
391 sql: &str,
392 value: &str,
393 ) -> Result<Option<String>, DbError> {
394 let mut statement = self.prepare_cached(sql)?;
395 statement.query_optional_string_text_borrowed(value)
396 }
397
398 #[doc(hidden)]
399 pub fn query_text_iter_text_len_sum<'value, I>(
403 &self,
404 sql: &str,
405 values: I,
406 ) -> Result<u64, DbError>
407 where
408 I: ExactSizeIterator<Item = &'value str>,
409 {
410 let mut statement = self.prepare_cached(sql)?;
411 statement.query_text_iter_text_len_sum(values)
412 }
413
414 pub fn query_optional_scalar_named<T: FromColumn>(
415 &self,
416 sql: &str,
417 values: &[(&str, &dyn ToSql)],
418 ) -> Result<Option<T>, DbError> {
419 self.query_optional_named(sql, values, |row| row.get(0))
420 }
421
422 pub fn query_column<T: FromColumn>(
423 &self,
424 sql: &str,
425 values: &[&dyn ToSql],
426 ) -> Result<Vec<T>, DbError> {
427 self.query_all(sql, values, |row| row.get(0))
428 }
429
430 pub fn query_column_named<T: FromColumn>(
431 &self,
432 sql: &str,
433 values: &[(&str, &dyn ToSql)],
434 ) -> Result<Vec<T>, DbError> {
435 self.query_all_named(sql, values, |row| row.get(0))
436 }
437}
438
439impl Drop for Connection {
440 fn drop(&mut self) {
441 unsafe {
442 self.cached.get_mut().finalize_all();
443 let rc = ffi::sqlite3_close(self.raw.as_ptr());
444 debug_assert_eq!(rc, ffi::SQLITE_OK, "sqlite3_close left resources open");
445 }
446 }
447}
448
449impl<'connection> CachedStatement<'connection> {
450 fn new(
451 statement: Statement<'connection>,
452 sql: String,
453 cache: &'connection RefCell<StatementCache>,
454 ) -> Self {
455 Self {
456 statement: Some(statement),
457 sql,
458 cache,
459 }
460 }
461
462 pub fn discard(mut self) {
463 if let Some(statement) = self.statement.take() {
464 unsafe {
465 ffi::sqlite3_finalize(statement.into_raw().as_ptr());
466 }
467 }
468 }
469}
470
471impl<'connection> Deref for CachedStatement<'connection> {
472 type Target = Statement<'connection>;
473
474 fn deref(&self) -> &Self::Target {
475 self.statement
476 .as_ref()
477 .expect("cached statement is present")
478 }
479}
480
481impl DerefMut for CachedStatement<'_> {
482 fn deref_mut(&mut self) -> &mut Self::Target {
483 self.statement
484 .as_mut()
485 .expect("cached statement is present")
486 }
487}
488
489impl Drop for CachedStatement<'_> {
490 fn drop(&mut self) {
491 let Some(statement) = self.statement.take() else {
492 return;
493 };
494 let parameter_count = statement.parameter_count();
495 let raw = statement.into_raw();
496 unsafe {
497 ffi::sqlite3_reset(raw.as_ptr());
498 ffi::sqlite3_clear_bindings(raw.as_ptr());
499 self.cache
500 .borrow_mut()
501 .insert(std::mem::take(&mut self.sql), raw, parameter_count);
502 }
503 }
504}
505
506pub(crate) fn sqlite_error(db: *mut ffi::sqlite3, code: c_int) -> DbError {
507 let message = unsafe {
508 let ptr = ffi::sqlite3_errmsg(db);
509 if ptr.is_null() {
510 "unknown sqlite error".to_string()
511 } else {
512 CStr::from_ptr(ptr).to_string_lossy().into_owned()
513 }
514 };
515 classify_sqlite_error(code, message)
516}
517
518fn classify_sqlite_error(code: c_int, message: String) -> DbError {
519 if code == ffi::SQLITE_CONSTRAINT {
520 DbError::Constraint(message)
521 } else {
522 DbError::Sqlite(code, message)
523 }
524}
525
526fn take_error_message(error: *mut c_char) -> String {
527 if error.is_null() {
528 return "unknown sqlite error".to_string();
529 }
530 let message = unsafe { CStr::from_ptr(error).to_string_lossy().into_owned() };
531 unsafe {
532 ffi::sqlite3_free(error.cast::<c_void>());
533 }
534 message
535}
536
/// Reports whether the remainder left by `sqlite3_prepare_v2` contains no further
/// SQL.
///
/// The remainder may hold only:
/// * whitespace,
/// * `;` (empty statements),
/// * `--` line comments,
/// * `/* ... */` block comments; an unterminated one runs to the end of input, as
///   in SQLite.
///
/// A null pointer counts as empty.
///
/// `tail` must be null or point into a live NUL-terminated buffer. This holds for
/// the `pzTail` output of `sqlite3_prepare_v2` when it was given such a buffer.
fn tail_is_empty(tail: *const c_char) -> bool {
    if tail.is_null() {
        return true;
    }
    // SAFETY: per the contract above, `tail` points to a NUL-terminated string.
    let mut rest = unsafe { CStr::from_ptr(tail) }.to_bytes();
    loop {
        match rest {
            [] => return true,
            [byte, after @ ..] if byte.is_ascii_whitespace() || *byte == b';' => rest = after,
            [b'-', b'-', comment @ ..] => {
                // A line comment ends at the next newline, or at the end of input.
                rest = match comment.iter().position(|&b| b == b'\n') {
                    Some(newline) => &comment[newline + 1..],
                    None => &[],
                };
            }
            // SQLite tokenizes a bare "/*" at the very end as `/`, not as a comment,
            // so at least one byte must follow the opener.
            [b'/', b'*', comment @ ..] if !comment.is_empty() => {
                rest = match comment.windows(2).position(|pair| pair == b"*/") {
                    Some(close) => &comment[close + 2..],
                    None => &[],
                };
            }
            _ => return false,
        }
    }
}
547
#[cfg(test)]
mod tests {
    use super::open_read_write;
    use crate::config::{
        SQLITE_URI, SQLITE_URI_NUL, STATEMENT_CACHE_CAPACITY, VFS_NAME, VFS_NAME_NUL,
    };
    use crate::sqlite_vfs::{lock, stable_blob};
    use crate::stable::memory;
    use crate::Db;
    use serial_test::serial;
    use std::ffi::CStr;

    /// Clears the VFS blob/read-cache, stable-memory and lock state through their
    /// test helpers and re-runs `Db::init`, so each test starts from the same
    /// baseline. Tests using it are `#[serial]` because this state is shared.
    fn reset() {
        stable_blob::rollback_update();
        stable_blob::invalidate_read_cache();
        memory::reset_for_tests();
        lock::reset_for_tests();
        Db::init(memory::memory_for_tests()).unwrap();
    }

    /// The `_NUL` constants handed to `from_bytes_with_nul_unchecked` in
    /// `Connection::open` must be well-formed and match their plain counterparts.
    #[test]
    fn sqlite_open_strings_are_static_nul_terminated() {
        let uri = CStr::from_bytes_with_nul(SQLITE_URI_NUL).unwrap();
        let vfs = CStr::from_bytes_with_nul(VFS_NAME_NUL).unwrap();
        assert_eq!(uri.to_str().unwrap(), SQLITE_URI);
        assert_eq!(vfs.to_str().unwrap(), VFS_NAME);
    }

    /// Inserting more distinct statements than the capacity keeps the cache at
    /// capacity and evicts the oldest entries first.
    #[test]
    #[serial]
    fn cached_statements_are_lru_bounded() {
        reset();
        let connection = open_read_write().unwrap();

        for index in 0..(STATEMENT_CACHE_CAPACITY + 8) {
            let sql = format!("SELECT {index}");
            let mut statement = connection.prepare_cached(&sql).unwrap();
            let value = statement.query_scalar::<i64>(crate::params![]).unwrap();
            assert_eq!(value, i64::try_from(index).unwrap());
        }

        let cache = connection.cached.borrow();
        assert_eq!(cache.statements.len(), STATEMENT_CACHE_CAPACITY);
        assert!(!cache.statements.iter().any(|entry| entry.sql == "SELECT 0"));
        assert!(cache
            .statements
            .iter()
            .any(|entry| entry.sql == format!("SELECT {}", STATEMENT_CACHE_CAPACITY + 7)));
    }

    /// Reusing a cached statement moves it to the most-recent end, so the next
    /// eviction removes the next-oldest entry instead of the reused one.
    #[test]
    #[serial]
    fn recently_used_cached_statement_survives_eviction() {
        // The scenario needs room for at least two entries.
        if STATEMENT_CACHE_CAPACITY < 2 {
            return;
        }
        reset();
        let connection = open_read_write().unwrap();

        // Fill the cache exactly to capacity: "SELECT 0" is the oldest entry.
        for index in 0..STATEMENT_CACHE_CAPACITY {
            let sql = format!("SELECT {index}");
            let mut statement = connection.prepare_cached(&sql).unwrap();
            statement.query_scalar::<i64>(crate::params![]).unwrap();
        }
        // Touch "SELECT 0": `take` removes it and `Drop` re-inserts it at the back.
        {
            let mut statement = connection.prepare_cached("SELECT 0").unwrap();
            assert_eq!(statement.query_scalar::<i64>(crate::params![]).unwrap(), 0);
        }
        // One more distinct statement forces exactly one eviction.
        {
            let sql = format!("SELECT {}", STATEMENT_CACHE_CAPACITY);
            let mut statement = connection.prepare_cached(&sql).unwrap();
            statement.query_scalar::<i64>(crate::params![]).unwrap();
        }

        let cache = connection.cached.borrow();
        assert_eq!(cache.statements.len(), STATEMENT_CACHE_CAPACITY);
        assert!(cache.statements.iter().any(|entry| entry.sql == "SELECT 0"));
        assert!(!cache.statements.iter().any(|entry| entry.sql == "SELECT 1"));
    }

    /// `discard` finalizes the statement rather than putting it back in the cache.
    #[test]
    #[serial]
    fn discarded_cached_statement_is_finalized_not_cached() {
        reset();
        let connection = open_read_write().unwrap();

        let statement = connection.prepare_cached("SELECT 1").unwrap();
        statement.discard();

        assert_eq!(connection.cached.borrow().statements.len(), 0);
    }

    /// A cached statement that failed with a constraint error is reset on drop and
    /// works normally when reused.
    #[test]
    #[serial]
    fn cached_statement_reuses_sql_after_constraint_error() {
        reset();
        let connection = open_read_write().unwrap();
        connection
            .execute_batch("CREATE TABLE cached_error(k TEXT PRIMARY KEY, v TEXT NOT NULL)")
            .unwrap();

        {
            let mut statement = connection
                .prepare_cached("INSERT INTO cached_error(k, v) VALUES (?1, ?2)")
                .unwrap();
            statement.execute(crate::params!["a", "one"]).unwrap();
        }
        {
            let mut statement = connection
                .prepare_cached("INSERT INTO cached_error(k, v) VALUES (?1, ?2)")
                .unwrap();
            let duplicate = statement.execute(crate::params!["a", "duplicate"]);
            assert!(matches!(duplicate, Err(crate::db::DbError::Constraint(_))));
        }
        {
            let mut statement = connection
                .prepare_cached("INSERT INTO cached_error(k, v) VALUES (?1, ?2)")
                .unwrap();
            statement.execute(crate::params!["b", "two"]).unwrap();
        }

        let values = connection
            .query_column::<String>("SELECT v FROM cached_error ORDER BY k", crate::params![])
            .unwrap();
        assert_eq!(values, vec!["one".to_string(), "two".to_string()]);
    }

    /// Uncached statements are finalized when dropped, so repeated queries do not
    /// accumulate open statements on the connection.
    #[test]
    #[serial]
    fn regular_statements_are_finalized_before_connection_close() {
        reset();
        let connection = open_read_write().unwrap();

        {
            let _statement = connection.prepare("SELECT 1").unwrap();
            assert_eq!(open_statement_count(&connection), 1);
        }
        assert_eq!(open_statement_count(&connection), 0);

        for _ in 0..512 {
            let value = connection
                .query_one("SELECT 42", crate::params![], |row| row.get::<i64>(0))
                .unwrap();
            assert_eq!(value, 42);
        }
        assert_eq!(open_statement_count(&connection), 0);
    }

    /// Cached and uncached statements each finalize their own handle exactly once.
    #[test]
    #[serial]
    fn cached_and_regular_statement_lifetimes_do_not_double_finalize() {
        reset();
        let connection = open_read_write().unwrap();

        {
            let mut cached = connection.prepare_cached("SELECT ?1").unwrap();
            let value = cached.query_scalar::<i64>(crate::params![7_i64]).unwrap();
            assert_eq!(value, 7);
        }
        assert_eq!(open_statement_count(&connection), 1);

        {
            let _regular = connection.prepare("SELECT 8").unwrap();
            assert_eq!(open_statement_count(&connection), 2);
        }
        assert_eq!(open_statement_count(&connection), 1);

        unsafe {
            connection.cached.borrow_mut().finalize_all();
        }
        assert_eq!(open_statement_count(&connection), 0);
    }

    /// Every `prepare` failure path finalizes (or never creates) its statement.
    #[test]
    #[serial]
    fn prepare_error_paths_do_not_leave_statements_open() {
        reset();
        let connection = open_read_write().unwrap();

        assert!(connection.prepare("").is_err());
        assert_eq!(open_statement_count(&connection), 0);

        assert!(connection.prepare("SELECT 1; SELECT 2").is_err());
        assert_eq!(open_statement_count(&connection), 0);

        assert!(connection.prepare("SELECT * FROM missing_table").is_err());
        assert_eq!(open_statement_count(&connection), 0);
    }

    /// Counts the prepared statements still open on `connection` by walking
    /// `sqlite3_next_stmt`.
    fn open_statement_count(connection: &super::Connection) -> usize {
        let mut count = 0;
        let mut statement = std::ptr::null_mut();
        loop {
            statement =
                unsafe { crate::sqlite_vfs::ffi::sqlite3_next_stmt(connection.raw(), statement) };
            if statement.is_null() {
                return count;
            }
            count += 1;
        }
    }
}