#[allow(unused)]
use std::time::Duration;
use r2d2::ManageConnection;
use rocket::{Rocket, Build};
#[allow(unused_imports)]
use crate::{Config, Error};
/// A connection type that can build an `r2d2` connection pool of itself.
///
/// This module provides feature-gated implementations for several diesel,
/// `postgres`, `rusqlite`, and `memcache` connection types, each of which reads
/// its settings through [`Config`].
pub trait Poolable: Send + Sized + 'static {
    /// The `r2d2` connection manager whose connections are `Self`.
    type Manager: ManageConnection<Connection = Self>;

    /// Driver-specific error type, wrapped in [`Error`] by [`PoolResult`].
    type Error: std::fmt::Debug;

    /// Creates a connection pool for the database named `db_name`, reading its
    /// configuration from `rocket`.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self>;
}
#[allow(type_alias_bounds)]
pub type PoolResult<P: Poolable> = Result<r2d2::Pool<P::Manager>, Error<P::Error>>;
#[cfg(feature = "diesel_sqlite_pool")]
impl Poolable for diesel::SqliteConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::SqliteConnection>;
    type Error = std::convert::Infallible;

    /// Builds a pool of diesel SQLite connections for the database `db_name`.
    ///
    /// A connection customizer runs `PRAGMA journal_mode = WAL`,
    /// `PRAGMA busy_timeout = 1000` and `PRAGMA foreign_keys = ON` from its
    /// `on_acquire` hook as the pool sets up each connection.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use diesel::connection::SimpleConnection;
        use diesel::r2d2::{ConnectionManager, CustomizeConnection, Error, Pool};
        use diesel::SqliteConnection;

        // Issues the per-connection PRAGMAs listed above.
        #[derive(Debug)]
        struct SqlitePragmas;

        impl CustomizeConnection<SqliteConnection, Error> for SqlitePragmas {
            fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), Error> {
                conn.batch_execute("\
                    PRAGMA journal_mode = WAL;\
                    PRAGMA busy_timeout = 1000;\
                    PRAGMA foreign_keys = ON;\
                ")
                .map_err(Error::QueryError)
            }
        }

        let config = Config::from(db_name, rocket)?;
        let manager = ConnectionManager::new(&config.url);
        let timeout = Duration::from_secs(config.timeout as u64);

        Ok(Pool::builder()
            .connection_customizer(Box::new(SqlitePragmas))
            .max_size(config.pool_size)
            .connection_timeout(timeout)
            .build(manager)?)
    }
}
#[cfg(feature = "diesel_postgres_pool")]
impl Poolable for diesel::PgConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::PgConnection>;
    type Error = std::convert::Infallible;

    /// Builds a pool of diesel PostgreSQL connections for the database
    /// `db_name`, sized and timed out according to its [`Config`].
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use diesel::r2d2::ConnectionManager;

        let config = Config::from(db_name, rocket)?;
        let manager = ConnectionManager::new(&config.url);
        let timeout = Duration::from_secs(config.timeout as u64);

        Ok(r2d2::Pool::builder()
            .max_size(config.pool_size)
            .connection_timeout(timeout)
            .build(manager)?)
    }
}
#[cfg(feature = "diesel_mysql_pool")]
impl Poolable for diesel::MysqlConnection {
    type Manager = diesel::r2d2::ConnectionManager<diesel::MysqlConnection>;
    type Error = std::convert::Infallible;

    /// Builds a pool of diesel MySQL connections for the database `db_name`,
    /// sized and timed out according to its [`Config`].
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use diesel::r2d2::ConnectionManager;

        let config = Config::from(db_name, rocket)?;
        let manager = ConnectionManager::new(&config.url);
        let timeout = Duration::from_secs(config.timeout as u64);

        Ok(r2d2::Pool::builder()
            .max_size(config.pool_size)
            .connection_timeout(timeout)
            .build(manager)?)
    }
}
#[cfg(feature = "postgres_pool")]
impl Poolable for postgres::Client {
    type Manager = r2d2_postgres::PostgresConnectionManager<postgres::tls::NoTls>;
    type Error = postgres::Error;

    /// Builds a pool of `postgres` clients (without TLS) for the database
    /// `db_name`.
    ///
    /// A configured URL that fails to parse as a `postgres::Config` is
    /// reported as `Error::Custom` with the parser's error.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use postgres::tls::NoTls;
        use r2d2_postgres::PostgresConnectionManager;

        let config = Config::from(db_name, rocket)?;
        let pg_config: postgres::Config = match config.url.parse() {
            Ok(parsed) => parsed,
            Err(e) => return Err(Error::Custom(e)),
        };

        let manager = PostgresConnectionManager::new(pg_config, NoTls);
        let timeout = Duration::from_secs(config.timeout as u64);

        Ok(r2d2::Pool::builder()
            .max_size(config.pool_size)
            .connection_timeout(timeout)
            .build(manager)?)
    }
}
#[cfg(feature = "sqlite_pool")]
impl Poolable for rusqlite::Connection {
    type Manager = r2d2_sqlite::SqliteConnectionManager;
    type Error = std::convert::Infallible;

    /// Builds a pool of `rusqlite` connections for the database `db_name`.
    ///
    /// Besides the common [`Config`] keys, this reads an optional `open_flags`
    /// array of snake_case flag names (e.g. `["read_only", "shared_cache"]`),
    /// which defaults to empty.
    ///
    /// Requested flags are layered on top of `rusqlite::OpenFlags::default()`.
    /// When a requested flag is mutually exclusive with a default one, the
    /// requested flag replaces the default:
    ///
    /// * `read_only` drops the default `READ_WRITE | CREATE` access mode;
    /// * `no_mutex` / `full_mutex` replace the default threading mode.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use rocket::figment::providers::Serialized;
        use rusqlite::OpenFlags;

        // User-facing names for the `SQLITE_OPEN_*` flags accepted in config.
        #[derive(Debug, serde::Deserialize, serde::Serialize)]
        #[serde(rename_all = "snake_case")]
        enum OpenFlag {
            ReadOnly,
            ReadWrite,
            Create,
            Uri,
            Memory,
            NoMutex,
            FullMutex,
            SharedCache,
            PrivateCache,
            Nofollow,
        }

        impl OpenFlag {
            // The `rusqlite` flag bit this name stands for.
            fn into_open_flags(self) -> OpenFlags {
                match self {
                    OpenFlag::ReadOnly => OpenFlags::SQLITE_OPEN_READ_ONLY,
                    OpenFlag::ReadWrite => OpenFlags::SQLITE_OPEN_READ_WRITE,
                    OpenFlag::Create => OpenFlags::SQLITE_OPEN_CREATE,
                    OpenFlag::Uri => OpenFlags::SQLITE_OPEN_URI,
                    OpenFlag::Memory => OpenFlags::SQLITE_OPEN_MEMORY,
                    OpenFlag::NoMutex => OpenFlags::SQLITE_OPEN_NO_MUTEX,
                    OpenFlag::FullMutex => OpenFlags::SQLITE_OPEN_FULL_MUTEX,
                    OpenFlag::SharedCache => OpenFlags::SQLITE_OPEN_SHARED_CACHE,
                    OpenFlag::PrivateCache => OpenFlags::SQLITE_OPEN_PRIVATE_CACHE,
                    OpenFlag::Nofollow => OpenFlags::SQLITE_OPEN_NOFOLLOW,
                }
            }
        }

        let figment = Config::figment(db_name, rocket);
        let config: Config = figment.extract()?;
        let open_flags: Vec<OpenFlag> = figment
            .join(Serialized::default("open_flags", <Vec<OpenFlag>>::new()))
            .extract_inner("open_flags")?;

        // Union of every flag the user asked for.
        let requested = open_flags
            .into_iter()
            .fold(OpenFlags::empty(), |acc, flag| acc | flag.into_open_flags());

        let mut flags = OpenFlags::default();

        // SQLite accepts exactly one access mode (READ_ONLY, READ_WRITE, or
        // READ_WRITE | CREATE). OR-ing `read_only` into the default
        // READ_WRITE | CREATE makes every open fail with SQLITE_MISUSE, so an
        // explicit `read_only` replaces the default access mode.
        if requested.contains(OpenFlags::SQLITE_OPEN_READ_ONLY) {
            flags.remove(OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE);
        }

        // SQLite honours NO_MUTEX ahead of FULL_MUTEX, so `full_mutex` would be
        // silently ignored next to the default NO_MUTEX. An explicitly
        // requested threading mode therefore replaces the default one.
        let mutex_modes = OpenFlags::SQLITE_OPEN_NO_MUTEX | OpenFlags::SQLITE_OPEN_FULL_MUTEX;
        if requested.intersects(mutex_modes) {
            flags.remove(mutex_modes);
        }

        flags.insert(requested);

        let manager = r2d2_sqlite::SqliteConnectionManager::file(&*config.url)
            .with_flags(flags);

        let pool = r2d2::Pool::builder()
            .max_size(config.pool_size)
            .connection_timeout(Duration::from_secs(config.timeout as u64))
            .build(manager)?;

        Ok(pool)
    }
}
#[cfg(feature = "memcache_pool")]
impl Poolable for memcache::Client {
    type Manager = r2d2_memcache::MemcacheConnectionManager;
    type Error = memcache::MemcacheError;

    /// Builds a pool of `memcache` clients for the database `db_name`, using
    /// the configured URL as the connection target.
    fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
        use r2d2_memcache::MemcacheConnectionManager;

        let config = Config::from(db_name, rocket)?;
        let target: &str = &*config.url;
        let manager = MemcacheConnectionManager::new(target);
        let timeout = Duration::from_secs(config.timeout as u64);

        Ok(r2d2::Pool::builder()
            .max_size(config.pool_size)
            .connection_timeout(timeout)
            .build(manager)?)
    }
}