use azoth_core::{
error::{AzothError, Result},
ReadPoolConfig,
};
use parking_lot::Mutex;
use rusqlite::{Connection, OpenFlags};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::{Semaphore, SemaphorePermit};
/// A connection checked out of a `SqliteReadPool`.
///
/// Holds both the lock on one pooled connection and a semaphore permit; both are
/// released when this value is dropped.
pub struct PooledSqliteConnection<'a> {
    // Field order matters: Rust drops struct fields in declaration order, so the
    // connection lock is released *before* the permit goes back to the semaphore.
    // This way a task that obtains the freed permit can find an unlocked connection.
    conn: parking_lot::MutexGuard<'a, Connection>,
    // Held only for its Drop side effect (returning the permit); never read.
    _permit: SemaphorePermit<'a>,
}
impl<'a> PooledSqliteConnection<'a> {
pub fn query_row<T, P, F>(&self, sql: &str, params: P, f: F) -> Result<T>
where
P: rusqlite::Params,
F: FnOnce(&rusqlite::Row<'_>) -> rusqlite::Result<T>,
{
self.conn
.query_row(sql, params, f)
.map_err(|e| AzothError::Projection(e.to_string()))
}
pub fn prepare(&self, sql: &str) -> Result<rusqlite::Statement<'_>> {
self.conn
.prepare(sql)
.map_err(|e| AzothError::Projection(e.to_string()))
}
pub fn connection(&self) -> &Connection {
&self.conn
}
}
/// A fixed-size pool of read-only SQLite connections guarded by a semaphore.
///
/// The connections are all opened up front in `new`. The semaphore is created
/// with one permit per connection, so a caller holding a permit is expected to
/// find an unlocked connection.
pub struct SqliteReadPool {
    // One mutex per open connection; a connection is checked out while its mutex is locked.
    connections: Vec<Mutex<Connection>>,
    // Bounds the number of concurrent checkouts to the number of connections.
    semaphore: Semaphore,
    // Maximum wait used by `acquire` and `acquire_blocking`.
    acquire_timeout: Duration,
    // Whether pooling was enabled in the config (when disabled, `new` opens one connection).
    enabled: bool,
    // Path of the database file the connections were opened from.
    db_path: PathBuf,
    // Round-robin counter that picks the slot where a checkout starts scanning.
    next_idx: AtomicUsize,
}
impl SqliteReadPool {
pub fn new(db_path: &Path, config: ReadPoolConfig) -> Result<Self> {
let pool_size = if config.enabled { config.pool_size } else { 1 };
let mut connections = Vec::with_capacity(pool_size);
for _ in 0..pool_size {
let conn = Connection::open_with_flags(
db_path,
OpenFlags::SQLITE_OPEN_READ_ONLY | OpenFlags::SQLITE_OPEN_NO_MUTEX,
)
.map_err(|e| AzothError::Projection(e.to_string()))?;
connections.push(Mutex::new(conn));
}
Ok(Self {
connections,
semaphore: Semaphore::new(pool_size),
acquire_timeout: Duration::from_millis(config.acquire_timeout_ms),
enabled: config.enabled,
db_path: db_path.to_path_buf(),
next_idx: AtomicUsize::new(0),
})
}
pub async fn acquire(&self) -> Result<PooledSqliteConnection<'_>> {
let permit = tokio::time::timeout(self.acquire_timeout, self.semaphore.acquire())
.await
.map_err(|_| {
AzothError::Timeout(format!(
"Read pool acquire timeout after {:?}",
self.acquire_timeout
))
})?
.map_err(|e| AzothError::Internal(format!("Semaphore closed: {}", e)))?;
let start = self.next_idx.fetch_add(1, Ordering::Relaxed) % self.connections.len();
for i in 0..self.connections.len() {
let idx = (start + i) % self.connections.len();
if let Some(guard) = self.connections[idx].try_lock() {
return Ok(PooledSqliteConnection {
conn: guard,
_permit: permit,
});
}
}
Err(AzothError::Internal(
"No available connection despite having permit".into(),
))
}
pub fn try_acquire(&self) -> Result<Option<PooledSqliteConnection<'_>>> {
match self.semaphore.try_acquire() {
Ok(permit) => {
let start = self.next_idx.fetch_add(1, Ordering::Relaxed) % self.connections.len();
for i in 0..self.connections.len() {
let idx = (start + i) % self.connections.len();
if let Some(guard) = self.connections[idx].try_lock() {
return Ok(Some(PooledSqliteConnection {
conn: guard,
_permit: permit,
}));
}
}
Ok(None)
}
Err(_) => Ok(None),
}
}
pub fn available_permits(&self) -> usize {
self.semaphore.available_permits()
}
pub fn is_enabled(&self) -> bool {
self.enabled
}
pub fn db_path(&self) -> &Path {
&self.db_path
}
pub fn pool_size(&self) -> usize {
self.connections.len()
}
pub fn acquire_blocking(&self) -> Result<PooledSqliteConnection<'_>> {
let deadline = Instant::now() + self.acquire_timeout;
loop {
if let Ok(Some(conn)) = self.try_acquire() {
return Ok(conn);
}
if Instant::now() >= deadline {
return Err(AzothError::Timeout(format!(
"Read pool acquire timeout after {:?}",
self.acquire_timeout
)));
}
std::thread::sleep(Duration::from_millis(1));
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a temporary database holding a two-row `test` table.
    ///
    /// The `TempDir` is returned so the caller keeps the directory alive for the test.
    fn create_test_db() -> (TempDir, PathBuf) {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.db");
        {
            // Scoped so the writable setup connection is closed before the pool opens.
            let setup = Connection::open(&path).unwrap();
            for statement in &[
                "CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT)",
                "INSERT INTO test (id, value) VALUES (1, 'hello')",
                "INSERT INTO test (id, value) VALUES (2, 'world')",
            ] {
                setup.execute(statement, []).unwrap();
            }
        }
        (dir, path)
    }

    #[tokio::test]
    async fn test_pool_acquire_release() {
        let (_dir, path) = create_test_db();
        let pool = SqliteReadPool::new(&path, ReadPoolConfig::enabled(2)).unwrap();
        assert_eq!(pool.available_permits(), 2);

        let first = pool.acquire().await.unwrap();
        assert_eq!(pool.available_permits(), 1);
        let second = pool.acquire().await.unwrap();
        assert_eq!(pool.available_permits(), 0);
        // Every permit is checked out, so a non-blocking attempt must come back empty.
        assert!(pool.try_acquire().unwrap().is_none());

        drop(first);
        assert_eq!(pool.available_permits(), 1);
        drop(second);
        assert_eq!(pool.available_permits(), 2);
    }

    #[tokio::test]
    async fn test_pool_query() {
        let (_dir, path) = create_test_db();
        let pool = SqliteReadPool::new(&path, ReadPoolConfig::enabled(2)).unwrap();
        let conn = pool.acquire().await.unwrap();

        let greeting: String = conn
            .query_row("SELECT value FROM test WHERE id = ?1", [1], |row| row.get(0))
            .unwrap();
        assert_eq!(greeting, "hello");

        let rows: i64 = conn
            .query_row("SELECT COUNT(*) FROM test", [], |row| row.get(0))
            .unwrap();
        assert_eq!(rows, 2);
    }

    #[test]
    fn test_try_acquire() {
        let (_dir, path) = create_test_db();
        let pool = SqliteReadPool::new(&path, ReadPoolConfig::enabled(1)).unwrap();

        let held = pool.try_acquire().unwrap();
        assert!(held.is_some());
        // The single permit is taken, so a second attempt yields nothing.
        assert!(pool.try_acquire().unwrap().is_none());

        drop(held);
        assert!(pool.try_acquire().unwrap().is_some());
    }
}