ankurah_storage_sqlite/
connection.rs1use std::path::PathBuf;
4use std::sync::Arc;
5
6use rusqlite::Connection;
7use tokio::sync::Mutex;
8
9use crate::error::SqliteError;
10
/// Describes which SQLite database a [`SqliteConnectionManager`] opens.
///
/// Derives `PartialEq`/`Eq` in addition to `Clone`/`Debug` so configurations
/// can be compared (e.g. in tests or when deduplicating pools).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SqliteConfig {
    /// A database stored in the file at this path (opened with `Connection::open`).
    File(PathBuf),
    /// An in-memory database (opened with `Connection::open_in_memory`).
    ///
    /// NOTE(review): SQLite documents that every `:memory:` connection gets its
    /// own, separate database, so connections opened for this variant presumably
    /// do not share data if the pool holds more than one — confirm the pool size
    /// used with this mode.
    Memory,
}
19
20pub struct SqliteConnectionManager {
25 config: SqliteConfig,
26}
27
28impl SqliteConnectionManager {
29 pub fn new(config: SqliteConfig) -> Self { Self { config } }
30
31 pub fn file(path: impl Into<PathBuf>) -> Self { Self::new(SqliteConfig::File(path.into())) }
32
33 pub fn memory() -> Self { Self::new(SqliteConfig::Memory) }
34
35 fn create_connection(&self) -> Result<Connection, SqliteError> {
36 let conn = match &self.config {
37 SqliteConfig::File(path) => Connection::open(path)?,
38 SqliteConfig::Memory => Connection::open_in_memory()?,
39 };
40
41 conn.execute_batch(
43 "PRAGMA journal_mode=WAL;
44 PRAGMA synchronous=NORMAL;
45 PRAGMA foreign_keys=ON;
46 PRAGMA cache_size=-64000;
47 PRAGMA mmap_size=268435456;
48 PRAGMA temp_store=MEMORY;",
49 )?;
50
51 Ok(conn)
52 }
53}
54
55pub struct PooledConnection {
60 inner: Arc<Mutex<Connection>>,
61}
62
63impl PooledConnection {
64 pub fn new(conn: Connection) -> Self { Self { inner: Arc::new(Mutex::new(conn)) } }
65
66 pub async fn with_connection<F, T>(&self, f: F) -> Result<T, SqliteError>
71 where
72 F: FnOnce(&Connection) -> Result<T, SqliteError> + Send + 'static,
73 T: Send + 'static,
74 {
75 let conn = self.inner.clone();
76 tokio::task::spawn_blocking(move || {
77 let guard = conn.blocking_lock();
78 f(&guard)
79 })
80 .await
81 .map_err(|e| SqliteError::TaskJoin(e.to_string()))?
82 }
83
84 pub async fn with_connection_mut<F, T>(&self, f: F) -> Result<T, SqliteError>
86 where
87 F: FnOnce(&mut Connection) -> Result<T, SqliteError> + Send + 'static,
88 T: Send + 'static,
89 {
90 let conn = self.inner.clone();
91 tokio::task::spawn_blocking(move || {
92 let mut guard = conn.blocking_lock();
93 f(&mut guard)
94 })
95 .await
96 .map_err(|e| SqliteError::TaskJoin(e.to_string()))?
97 }
98}
99
100impl Clone for PooledConnection {
101 fn clone(&self) -> Self { Self { inner: self.inner.clone() } }
102}
103
104impl bb8::ManageConnection for SqliteConnectionManager {
105 type Connection = PooledConnection;
106 type Error = SqliteError;
107
108 fn connect(&self) -> impl std::future::Future<Output = Result<Self::Connection, Self::Error>> + Send {
109 let config = self.config.clone();
110 async move {
111 let manager = SqliteConnectionManager::new(config);
112 tokio::task::spawn_blocking(move || manager.create_connection().map(PooledConnection::new))
113 .await
114 .map_err(|e| SqliteError::TaskJoin(e.to_string()))?
115 }
116 }
117
118 #[allow(refining_impl_trait)]
119 fn is_valid<'a, 'b>(&'a self, conn: &'b mut Self::Connection) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
120 let conn_inner = conn.inner.clone();
121 async move {
122 tokio::task::spawn_blocking(move || {
123 let guard = conn_inner.blocking_lock();
124 guard.execute_batch("SELECT 1").map_err(SqliteError::from)
125 })
126 .await
127 .map_err(|e| SqliteError::TaskJoin(e.to_string()))?
128 }
129 }
130
131 fn has_broken(&self, _conn: &mut Self::Connection) -> bool { false }
132}