azoth_sqlite/
read_pool.rs1use azoth_core::{
7 error::{AzothError, Result},
8 ReadPoolConfig,
9};
10use rusqlite::{Connection, OpenFlags};
11use std::path::{Path, PathBuf};
12use std::sync::Mutex;
13use std::time::{Duration, Instant};
14use tokio::sync::{Semaphore, SemaphorePermit};
15
16pub struct PooledSqliteConnection<'a> {
21 conn: std::sync::MutexGuard<'a, Connection>,
22 _permit: SemaphorePermit<'a>,
23}
24
25impl<'a> PooledSqliteConnection<'a> {
26 pub fn query_row<T, P, F>(&self, sql: &str, params: P, f: F) -> Result<T>
34 where
35 P: rusqlite::Params,
36 F: FnOnce(&rusqlite::Row<'_>) -> rusqlite::Result<T>,
37 {
38 self.conn
39 .query_row(sql, params, f)
40 .map_err(|e| AzothError::Projection(e.to_string()))
41 }
42
43 pub fn prepare(&self, sql: &str) -> Result<rusqlite::Statement<'_>> {
45 self.conn
46 .prepare(sql)
47 .map_err(|e| AzothError::Projection(e.to_string()))
48 }
49
50 pub fn connection(&self) -> &Connection {
54 &self.conn
55 }
56}
57
58pub struct SqliteReadPool {
74 connections: Vec<Mutex<Connection>>,
75 semaphore: Semaphore,
76 acquire_timeout: Duration,
77 enabled: bool,
78 db_path: PathBuf,
79}
80
81impl SqliteReadPool {
82 pub fn new(db_path: &Path, config: ReadPoolConfig) -> Result<Self> {
86 let pool_size = if config.enabled { config.pool_size } else { 1 };
87 let mut connections = Vec::with_capacity(pool_size);
88
89 for _ in 0..pool_size {
90 let conn = Connection::open_with_flags(
91 db_path,
92 OpenFlags::SQLITE_OPEN_READ_ONLY | OpenFlags::SQLITE_OPEN_NO_MUTEX,
93 )
94 .map_err(|e| AzothError::Projection(e.to_string()))?;
95
96 connections.push(Mutex::new(conn));
97 }
98
99 Ok(Self {
100 connections,
101 semaphore: Semaphore::new(pool_size),
102 acquire_timeout: Duration::from_millis(config.acquire_timeout_ms),
103 enabled: config.enabled,
104 db_path: db_path.to_path_buf(),
105 })
106 }
107
108 pub async fn acquire(&self) -> Result<PooledSqliteConnection<'_>> {
113 let permit = tokio::time::timeout(self.acquire_timeout, self.semaphore.acquire())
114 .await
115 .map_err(|_| {
116 AzothError::Timeout(format!(
117 "Read pool acquire timeout after {:?}",
118 self.acquire_timeout
119 ))
120 })?
121 .map_err(|e| AzothError::Internal(format!("Semaphore closed: {}", e)))?;
122
123 for conn in &self.connections {
125 if let Ok(guard) = conn.try_lock() {
126 return Ok(PooledSqliteConnection {
127 conn: guard,
128 _permit: permit,
129 });
130 }
131 }
132
133 Err(AzothError::Internal(
135 "No available connection despite having permit".into(),
136 ))
137 }
138
139 pub fn try_acquire(&self) -> Result<Option<PooledSqliteConnection<'_>>> {
143 match self.semaphore.try_acquire() {
144 Ok(permit) => {
145 for conn in &self.connections {
146 if let Ok(guard) = conn.try_lock() {
147 return Ok(Some(PooledSqliteConnection {
148 conn: guard,
149 _permit: permit,
150 }));
151 }
152 }
153 Ok(None)
155 }
156 Err(_) => Ok(None),
157 }
158 }
159
160 pub fn available_permits(&self) -> usize {
162 self.semaphore.available_permits()
163 }
164
165 pub fn is_enabled(&self) -> bool {
167 self.enabled
168 }
169
170 pub fn db_path(&self) -> &Path {
172 &self.db_path
173 }
174
175 pub fn pool_size(&self) -> usize {
177 self.connections.len()
178 }
179
180 pub fn acquire_blocking(&self) -> Result<PooledSqliteConnection<'_>> {
184 let deadline = Instant::now() + self.acquire_timeout;
185
186 loop {
187 if let Ok(Some(conn)) = self.try_acquire() {
188 return Ok(conn);
189 }
190
191 if Instant::now() >= deadline {
192 return Err(AzothError::Timeout(format!(
193 "Read pool acquire timeout after {:?}",
194 self.acquire_timeout
195 )));
196 }
197
198 std::thread::sleep(Duration::from_millis(1));
199 }
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a temporary database holding a `test` table with two rows
    /// (`1 → 'hello'`, `2 → 'world'`).
    ///
    /// Callers keep the returned `TempDir` bound for as long as they use the
    /// database file.
    fn create_test_db() -> (TempDir, PathBuf) {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.db");

        {
            // Writable setup connection; it is closed at the end of this scope,
            // before the read-only pool opens the file.
            let setup = Connection::open(&path).unwrap();
            for stmt in [
                "CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT)",
                "INSERT INTO test (id, value) VALUES (1, 'hello')",
                "INSERT INTO test (id, value) VALUES (2, 'world')",
            ] {
                setup.execute(stmt, []).unwrap();
            }
        }

        (dir, path)
    }

    #[tokio::test]
    async fn test_pool_acquire_release() {
        let (_temp_dir, db_path) = create_test_db();
        let pool = SqliteReadPool::new(&db_path, ReadPoolConfig::enabled(2)).unwrap();
        assert_eq!(pool.available_permits(), 2);

        // Every live checkout consumes one permit.
        let first = pool.acquire().await.unwrap();
        assert_eq!(pool.available_permits(), 1);
        let second = pool.acquire().await.unwrap();
        assert_eq!(pool.available_permits(), 0);

        // Exhausted pool: the non-waiting path yields nothing.
        assert!(pool.try_acquire().unwrap().is_none());

        // Dropping a checkout hands its permit back.
        drop(first);
        assert_eq!(pool.available_permits(), 1);
        drop(second);
        assert_eq!(pool.available_permits(), 2);
    }

    #[tokio::test]
    async fn test_pool_query() {
        let (_temp_dir, db_path) = create_test_db();
        let pool = SqliteReadPool::new(&db_path, ReadPoolConfig::enabled(2)).unwrap();
        let conn = pool.acquire().await.unwrap();

        let greeting: String = conn
            .query_row("SELECT value FROM test WHERE id = ?1", [1], |row| row.get(0))
            .unwrap();
        assert_eq!(greeting, "hello");

        let rows: i64 = conn
            .query_row("SELECT COUNT(*) FROM test", [], |row| row.get(0))
            .unwrap();
        assert_eq!(rows, 2);
    }

    #[test]
    fn test_try_acquire() {
        let (_temp_dir, db_path) = create_test_db();
        let pool = SqliteReadPool::new(&db_path, ReadPoolConfig::enabled(1)).unwrap();

        let held = pool.try_acquire().unwrap();
        assert!(held.is_some());

        // The single connection is checked out, so a second attempt fails.
        assert!(pool.try_acquire().unwrap().is_none());

        drop(held);
        assert!(pool.try_acquire().unwrap().is_some());
    }
}