Skip to main content

ankurah_storage_sqlite/
connection.rs

1//! Connection manager for bb8 pool with rusqlite
2
3use std::path::PathBuf;
4use std::sync::Arc;
5
6use rusqlite::Connection;
7use tokio::sync::Mutex;
8
9use crate::error::SqliteError;
10
/// Selects the backing store used when SQLite connections are opened.
#[derive(Debug, Clone)]
pub enum SqliteConfig {
    /// A database stored in the file at the given path.
    File(PathBuf),
    /// A transient database held entirely in memory (intended for tests).
    Memory,
}
19
20/// A wrapper around a SQLite connection that can be used with bb8
21///
22/// Since rusqlite::Connection is not Send, we wrap it in a Mutex
23/// and use spawn_blocking for all operations.
24pub struct SqliteConnectionManager {
25    config: SqliteConfig,
26}
27
28impl SqliteConnectionManager {
29    pub fn new(config: SqliteConfig) -> Self { Self { config } }
30
31    pub fn file(path: impl Into<PathBuf>) -> Self { Self::new(SqliteConfig::File(path.into())) }
32
33    pub fn memory() -> Self { Self::new(SqliteConfig::Memory) }
34
35    fn create_connection(&self) -> Result<Connection, SqliteError> {
36        let conn = match &self.config {
37            SqliteConfig::File(path) => Connection::open(path)?,
38            SqliteConfig::Memory => Connection::open_in_memory()?,
39        };
40
41        // Performance optimizations
42        conn.execute_batch(
43            "PRAGMA journal_mode=WAL;
44             PRAGMA synchronous=NORMAL;
45             PRAGMA foreign_keys=ON;
46             PRAGMA cache_size=-64000;
47             PRAGMA mmap_size=268435456;
48             PRAGMA temp_store=MEMORY;",
49        )?;
50
51        Ok(conn)
52    }
53}
54
55/// A pooled SQLite connection wrapper
56///
57/// Wraps the rusqlite Connection in an Arc<Mutex> for thread-safe access
58/// since rusqlite connections are not Send.
59pub struct PooledConnection {
60    inner: Arc<Mutex<Connection>>,
61}
62
63impl PooledConnection {
64    pub fn new(conn: Connection) -> Self { Self { inner: Arc::new(Mutex::new(conn)) } }
65
66    /// Execute a function with the connection
67    ///
68    /// This acquires the mutex lock and runs the provided closure with the connection.
69    /// The closure is executed within spawn_blocking since rusqlite operations are synchronous.
70    pub async fn with_connection<F, T>(&self, f: F) -> Result<T, SqliteError>
71    where
72        F: FnOnce(&Connection) -> Result<T, SqliteError> + Send + 'static,
73        T: Send + 'static,
74    {
75        let conn = self.inner.clone();
76        tokio::task::spawn_blocking(move || {
77            let guard = conn.blocking_lock();
78            f(&guard)
79        })
80        .await
81        .map_err(|e| SqliteError::TaskJoin(e.to_string()))?
82    }
83
84    /// Execute a function with mutable access to the connection
85    pub async fn with_connection_mut<F, T>(&self, f: F) -> Result<T, SqliteError>
86    where
87        F: FnOnce(&mut Connection) -> Result<T, SqliteError> + Send + 'static,
88        T: Send + 'static,
89    {
90        let conn = self.inner.clone();
91        tokio::task::spawn_blocking(move || {
92            let mut guard = conn.blocking_lock();
93            f(&mut guard)
94        })
95        .await
96        .map_err(|e| SqliteError::TaskJoin(e.to_string()))?
97    }
98}
99
100impl Clone for PooledConnection {
101    fn clone(&self) -> Self { Self { inner: self.inner.clone() } }
102}
103
104impl bb8::ManageConnection for SqliteConnectionManager {
105    type Connection = PooledConnection;
106    type Error = SqliteError;
107
108    fn connect(&self) -> impl std::future::Future<Output = Result<Self::Connection, Self::Error>> + Send {
109        let config = self.config.clone();
110        async move {
111            let manager = SqliteConnectionManager::new(config);
112            tokio::task::spawn_blocking(move || manager.create_connection().map(PooledConnection::new))
113                .await
114                .map_err(|e| SqliteError::TaskJoin(e.to_string()))?
115        }
116    }
117
118    #[allow(refining_impl_trait)]
119    fn is_valid<'a, 'b>(&'a self, conn: &'b mut Self::Connection) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
120        let conn_inner = conn.inner.clone();
121        async move {
122            tokio::task::spawn_blocking(move || {
123                let guard = conn_inner.blocking_lock();
124                guard.execute_batch("SELECT 1").map_err(SqliteError::from)
125            })
126            .await
127            .map_err(|e| SqliteError::TaskJoin(e.to_string()))?
128        }
129    }
130
131    fn has_broken(&self, _conn: &mut Self::Connection) -> bool { false }
132}