use std::path::Path;
use r2d2::Pool;
use r2d2_sqlite::SqliteConnectionManager;
use rusqlite::Connection;
use crate::errors::SqliteGraphError;
/// r2d2 connection manager that opens SQLite connections for the pool.
pub type ConnectionManager = SqliteConnectionManager;
/// Connection handle checked out of the pool by `PoolManager::get`.
pub type PooledConnection = r2d2::PooledConnection<SqliteConnectionManager>;
/// Owns the SQLite connection(s) used by the graph store.
///
/// Every constructor puts the manager into exactly one of two modes:
///
/// * **pooled** — built by [`PoolManager::new`] or [`PoolManager::with_max_size`];
///   connections are checked out with [`PoolManager::get`].
/// * **direct** — built by [`PoolManager::in_memory`] or
///   [`PoolManager::from_connection`]; the single owned connection is borrowed
///   with [`PoolManager::direct_connection`].
pub struct PoolManager {
    /// r2d2 pool; `Some` only in pooled mode.
    pool: Option<Pool<ConnectionManager>>,
    /// Single owned connection; `Some` only in direct mode.
    direct_conn: Option<Connection>,
}

impl PoolManager {
    /// Pool size used by [`PoolManager::new`].
    const DEFAULT_MAX_SIZE: u32 = 5;

    /// Wraps any displayable error as a [`SqliteGraphError`] connection error.
    fn connection_error<E: std::fmt::Display>(err: E) -> SqliteGraphError {
        SqliteGraphError::connection(err.to_string())
    }

    /// Opens a pooled manager for the SQLite file at `path` with the default
    /// pool size (5 connections).
    ///
    /// # Errors
    ///
    /// Returns a connection error if the pool cannot be built.
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, SqliteGraphError> {
        Self::with_max_size(path, Self::DEFAULT_MAX_SIZE)
    }

    /// Opens a pooled manager for the SQLite file at `path`, allowing at most
    /// `max_size` connections.
    ///
    /// Passing `":memory:"` as `path` gives every pooled connection its own,
    /// separate in-memory database; use [`PoolManager::in_memory`] when a single
    /// in-memory database is needed.
    ///
    /// # Errors
    ///
    /// Returns a connection error if `max_size` is zero or if r2d2 fails to
    /// build the pool.
    pub fn with_max_size<P: AsRef<Path>>(
        path: P,
        max_size: u32,
    ) -> Result<Self, SqliteGraphError> {
        // r2d2's `Builder::max_size` asserts that the size is positive and would
        // panic; report the invalid argument as an ordinary error instead.
        if max_size == 0 {
            return Err(SqliteGraphError::connection(
                "Pool max_size must be greater than zero".to_string(),
            ));
        }
        let pool = Pool::builder()
            .max_size(max_size)
            .build(SqliteConnectionManager::file(path))
            .map_err(Self::connection_error)?;
        Ok(Self {
            pool: Some(pool),
            direct_conn: None,
        })
    }

    /// Opens a fresh private in-memory database in direct mode.
    ///
    /// # Errors
    ///
    /// Returns a connection error if SQLite fails to open the database.
    pub fn in_memory() -> Result<Self, SqliteGraphError> {
        let conn = Connection::open_in_memory().map_err(Self::connection_error)?;
        Ok(Self::from_connection(conn))
    }

    /// Wraps an already-open connection in direct mode.
    ///
    /// The connection may be file-backed or in-memory; either way the manager
    /// reports [`PoolManager::is_in_memory`] as `true`.
    pub fn from_connection(conn: Connection) -> Self {
        Self {
            pool: None,
            direct_conn: Some(conn),
        }
    }

    /// Checks a connection out of the pool.
    ///
    /// # Errors
    ///
    /// Returns a connection error in direct mode (there is no pool — use
    /// [`PoolManager::direct_connection`] instead) or if r2d2 cannot supply a
    /// connection.
    pub fn get(&self) -> Result<PooledConnection, SqliteGraphError> {
        let pool = self.pool.as_ref().ok_or_else(|| {
            SqliteGraphError::connection(
                "Cannot checkout from in-memory database (use direct_connection() instead)"
                    .to_string(),
            )
        })?;
        pool.get().map_err(Self::connection_error)
    }

    /// Borrows the owned connection in direct mode; `None` in pooled mode.
    pub fn direct_connection(&self) -> Option<&Connection> {
        self.direct_conn.as_ref()
    }

    /// Returns `true` when the manager is in direct mode.
    ///
    /// Despite the name, this reports the mode rather than the storage. A
    /// file-backed connection passed to [`PoolManager::from_connection`] also
    /// yields `true`, while a pool opened on `":memory:"` yields `false`.
    pub fn is_in_memory(&self) -> bool {
        self.direct_conn.is_some()
    }

    /// Maximum pool size in pooled mode; `None` in direct mode.
    pub fn max_size(&self) -> Option<u32> {
        self.pool.as_ref().map(|p| p.max_size())
    }

    /// Runs `f` against one connection checked out of the pool.
    ///
    /// `f` is applied to exactly one pooled connection. Other connections in
    /// the pool, including any opened later, are never passed to `f`, so
    /// per-connection settings made here do not reach the whole pool. In direct
    /// mode this does nothing and returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns a connection error if checkout fails or `f` returns an error.
    pub fn configure_pool<F>(&self, f: F) -> Result<(), SqliteGraphError>
    where
        F: FnOnce(&Connection) -> Result<(), rusqlite::Error>,
    {
        if let Some(pool) = &self.pool {
            let conn = pool.get().map_err(Self::connection_error)?;
            f(&conn).map_err(Self::connection_error)?;
        }
        Ok(())
    }

    /// Runs `f` against the owned connection in direct mode; does nothing and
    /// returns `Ok(())` in pooled mode.
    ///
    /// # Errors
    ///
    /// Returns a connection error if `f` returns an error.
    pub fn configure_direct<F>(&mut self, f: F) -> Result<(), SqliteGraphError>
    where
        F: FnOnce(&Connection) -> Result<(), rusqlite::Error>,
    {
        if let Some(conn) = &self.direct_conn {
            f(conn).map_err(Self::connection_error)?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;

    #[test]
    fn test_pool_manager_size() {
        // Each pooled ":memory:" connection is its own database; only the pool
        // configuration is checked here.
        let manager = PoolManager::with_max_size(":memory:", 10).unwrap();
        assert!(!manager.is_in_memory());
        assert_eq!(manager.max_size(), Some(10));
        assert!(manager.direct_connection().is_none());
    }

    #[test]
    fn test_pooled_checkout() {
        let manager = PoolManager::with_max_size(":memory:", 2).unwrap();
        let conn = manager.get().unwrap();
        conn.execute_batch("CREATE TABLE t (x INTEGER); INSERT INTO t VALUES (1);")
            .unwrap();
    }

    #[test]
    fn test_in_memory_pool_manager() {
        let manager = PoolManager::in_memory().unwrap();
        assert!(manager.is_in_memory());
        assert!(manager.direct_connection().is_some());
        assert!(manager.pool.is_none());
        assert_eq!(manager.max_size(), None);
        // Direct-mode managers cannot hand out pooled connections.
        assert!(manager.get().is_err());
    }

    #[test]
    fn test_from_connection() {
        let conn = Connection::open_in_memory().unwrap();
        let manager = PoolManager::from_connection(conn);
        assert!(manager.is_in_memory());
        assert!(manager.direct_connection().is_some());
    }

    #[test]
    fn test_configure_direct_runs_closure() {
        let mut manager = PoolManager::in_memory().unwrap();
        manager
            .configure_direct(|conn| conn.execute_batch("CREATE TABLE t (x INTEGER)"))
            .unwrap();
        // The table created by the closure must exist on the owned connection.
        manager
            .direct_connection()
            .unwrap()
            .execute_batch("INSERT INTO t VALUES (1)")
            .unwrap();
    }

    #[test]
    fn test_configure_pool_is_noop_in_direct_mode() {
        let manager = PoolManager::in_memory().unwrap();
        let called = Cell::new(false);
        manager
            .configure_pool(|_| {
                called.set(true);
                Ok(())
            })
            .unwrap();
        assert!(!called.get());
    }
}