azoth_sqlite/
read_pool.rs1use azoth_core::{
7 error::{AzothError, Result},
8 ReadPoolConfig,
9};
10use rusqlite::{Connection, OpenFlags};
11use std::path::{Path, PathBuf};
12use std::sync::Mutex;
13use std::time::Duration;
14use tokio::sync::{Semaphore, SemaphorePermit};
15
16pub struct PooledSqliteConnection<'a> {
21 conn: std::sync::MutexGuard<'a, Connection>,
22 _permit: SemaphorePermit<'a>,
23}
24
25impl<'a> PooledSqliteConnection<'a> {
26 pub fn query_row<T, P, F>(&self, sql: &str, params: P, f: F) -> Result<T>
34 where
35 P: rusqlite::Params,
36 F: FnOnce(&rusqlite::Row<'_>) -> rusqlite::Result<T>,
37 {
38 self.conn
39 .query_row(sql, params, f)
40 .map_err(|e| AzothError::Projection(e.to_string()))
41 }
42
43 pub fn prepare(&self, sql: &str) -> Result<rusqlite::Statement<'_>> {
45 self.conn
46 .prepare(sql)
47 .map_err(|e| AzothError::Projection(e.to_string()))
48 }
49
50 pub fn connection(&self) -> &Connection {
54 &self.conn
55 }
56}
57
58pub struct SqliteReadPool {
74 connections: Vec<Mutex<Connection>>,
75 semaphore: Semaphore,
76 acquire_timeout: Duration,
77 enabled: bool,
78 db_path: PathBuf,
79}
80
81impl SqliteReadPool {
82 pub fn new(db_path: &Path, config: ReadPoolConfig) -> Result<Self> {
86 let pool_size = if config.enabled { config.pool_size } else { 1 };
87 let mut connections = Vec::with_capacity(pool_size);
88
89 for _ in 0..pool_size {
90 let conn = Connection::open_with_flags(
91 db_path,
92 OpenFlags::SQLITE_OPEN_READ_ONLY | OpenFlags::SQLITE_OPEN_NO_MUTEX,
93 )
94 .map_err(|e| AzothError::Projection(e.to_string()))?;
95
96 connections.push(Mutex::new(conn));
97 }
98
99 Ok(Self {
100 connections,
101 semaphore: Semaphore::new(pool_size),
102 acquire_timeout: Duration::from_millis(config.acquire_timeout_ms),
103 enabled: config.enabled,
104 db_path: db_path.to_path_buf(),
105 })
106 }
107
108 pub async fn acquire(&self) -> Result<PooledSqliteConnection<'_>> {
113 let permit = tokio::time::timeout(self.acquire_timeout, self.semaphore.acquire())
114 .await
115 .map_err(|_| {
116 AzothError::Timeout(format!(
117 "Read pool acquire timeout after {:?}",
118 self.acquire_timeout
119 ))
120 })?
121 .map_err(|e| AzothError::Internal(format!("Semaphore closed: {}", e)))?;
122
123 for conn in &self.connections {
125 if let Ok(guard) = conn.try_lock() {
126 return Ok(PooledSqliteConnection {
127 conn: guard,
128 _permit: permit,
129 });
130 }
131 }
132
133 Err(AzothError::Internal(
135 "No available connection despite having permit".into(),
136 ))
137 }
138
139 pub fn try_acquire(&self) -> Result<Option<PooledSqliteConnection<'_>>> {
143 match self.semaphore.try_acquire() {
144 Ok(permit) => {
145 for conn in &self.connections {
146 if let Ok(guard) = conn.try_lock() {
147 return Ok(Some(PooledSqliteConnection {
148 conn: guard,
149 _permit: permit,
150 }));
151 }
152 }
153 Ok(None)
155 }
156 Err(_) => Ok(None),
157 }
158 }
159
160 pub fn available_permits(&self) -> usize {
162 self.semaphore.available_permits()
163 }
164
165 pub fn is_enabled(&self) -> bool {
167 self.enabled
168 }
169
170 pub fn db_path(&self) -> &Path {
172 &self.db_path
173 }
174
175 pub fn pool_size(&self) -> usize {
177 self.connections.len()
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a temporary database containing a two-row `test` table.
    ///
    /// The `TempDir` is returned alongside the path so the caller keeps the
    /// directory alive for the duration of the test.
    fn create_test_db() -> (TempDir, PathBuf) {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");

        {
            // Scoped so the writer connection is closed before the pool opens
            // its read-only connections.
            let writer = Connection::open(&db_path).unwrap();
            for statement in [
                "CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT)",
                "INSERT INTO test (id, value) VALUES (1, 'hello')",
                "INSERT INTO test (id, value) VALUES (2, 'world')",
            ] {
                writer.execute(statement, []).unwrap();
            }
        }

        (temp_dir, db_path)
    }

    /// Permits go down as connections are checked out and come back on drop.
    #[tokio::test]
    async fn test_pool_acquire_release() {
        let (_temp_dir, db_path) = create_test_db();
        let pool = SqliteReadPool::new(&db_path, ReadPoolConfig::enabled(2)).unwrap();
        assert_eq!(pool.available_permits(), 2);

        let first = pool.acquire().await.unwrap();
        assert_eq!(pool.available_permits(), 1);

        let second = pool.acquire().await.unwrap();
        assert_eq!(pool.available_permits(), 0);

        // Pool exhausted: a non-blocking attempt must come back empty.
        let extra = pool.try_acquire().unwrap();
        assert!(extra.is_none());

        drop(first);
        assert_eq!(pool.available_permits(), 1);

        drop(second);
        assert_eq!(pool.available_permits(), 2);
    }

    /// Queries issued through a pooled connection see the seeded rows.
    #[tokio::test]
    async fn test_pool_query() {
        let (_temp_dir, db_path) = create_test_db();
        let pool = SqliteReadPool::new(&db_path, ReadPoolConfig::enabled(2)).unwrap();
        let conn = pool.acquire().await.unwrap();

        let value: String = conn
            .query_row("SELECT value FROM test WHERE id = ?1", [1], |row| {
                row.get(0)
            })
            .unwrap();
        assert_eq!(value, "hello");

        let count: i64 = conn
            .query_row("SELECT COUNT(*) FROM test", [], |row| row.get(0))
            .unwrap();
        assert_eq!(count, 2);
    }

    /// `try_acquire` yields a connection only while one is free.
    #[test]
    fn test_try_acquire() {
        let (_temp_dir, db_path) = create_test_db();
        let pool = SqliteReadPool::new(&db_path, ReadPoolConfig::enabled(1)).unwrap();

        let held = pool.try_acquire().unwrap();
        assert!(held.is_some());

        // The single connection is checked out, so a second attempt fails.
        assert!(pool.try_acquire().unwrap().is_none());

        drop(held);
        assert!(pool.try_acquire().unwrap().is_some());
    }
}