use {
crate::{table::Schema, Error, Transaction},
std::path::Path,
};
/// A connection pool over one of the enabled database backends.
///
/// Each variant wraps an `r2d2::Pool` driven by the matching connection
/// manager from this module. Which variants exist depends on the enabled
/// cargo features (`postgresql`, `sqlite`).
#[derive(Debug, Clone)]
pub enum Pool {
    /// PostgreSQL pool whose connections are opened without TLS (`NoTls`).
    #[cfg(feature = "postgresql")]
    PostgreSQL(r2d2::Pool<self::postgres::PostgresConnectionManager<::postgres::NoTls>>),
    /// SQLite pool whose manager opens connections to a database file.
    #[cfg(feature = "sqlite")]
    SQLite(r2d2::Pool<self::sqlite::SqliteConnectionManager>),
}
impl Pool {
    /// Creates a PostgreSQL pool for `config`; connections are opened with `NoTls`.
    ///
    /// # Errors
    ///
    /// Returns an error when `r2d2::Pool::new` fails to build the pool.
    #[cfg(feature = "postgresql")]
    pub fn postgres(config: ::postgres::Config) -> Result<Self, Error> {
        // Resolve the manager through `self::`, exactly like the `Pool::PostgreSQL`
        // variant does, instead of a hard-coded `crate::pool::` location.
        let manager = self::postgres::PostgresConnectionManager::new(config, ::postgres::NoTls);
        Ok(Pool::PostgreSQL(r2d2::Pool::new(manager)?))
    }

    /// Creates a SQLite pool whose connections open the database file at `path`.
    ///
    /// # Errors
    ///
    /// Returns an error when `r2d2::Pool::new` fails to build the pool.
    #[cfg(feature = "sqlite")]
    pub fn sqlite(path: impl AsRef<Path>) -> Result<Self, Error> {
        // Same module-relative path as the `Pool::SQLite` variant.
        let manager = self::sqlite::SqliteConnectionManager::file(path);
        Ok(Pool::SQLite(r2d2::Pool::new(manager)?))
    }

    /// Returns the [`PoolKind`] tag identifying this pool's backend.
    pub fn as_kind(&self) -> PoolKind {
        match self {
            #[cfg(feature = "postgresql")]
            Pool::PostgreSQL(_) => PoolKind::PostgreSQL,
            #[cfg(feature = "sqlite")]
            Pool::SQLite(_) => PoolKind::SQLite,
        }
    }

    /// Executes `exec` (which may contain several statements) on one pooled
    /// connection.
    ///
    /// # Errors
    ///
    /// Propagates failures from checking out a connection or from the backend's
    /// batch execution.
    pub fn batch_execute(&self, exec: impl AsRef<str>) -> Result<(), Error> {
        let sql = exec.as_ref();
        match self {
            #[cfg(feature = "postgresql")]
            Pool::PostgreSQL(pool) => {
                let mut conn = pool.get()?;
                conn.batch_execute(sql)?;
            }
            #[cfg(feature = "sqlite")]
            Pool::SQLite(pool) => {
                let conn = pool.get()?;
                conn.execute_batch(sql)?;
            }
        }
        Ok(())
    }

    /// Checks out one connection, begins a transaction on it and hands that
    /// transaction to `run`.
    ///
    /// `run` receives the transaction by value, so this method never commits it:
    /// committing is `run`'s responsibility. NOTE(review): the underlying
    /// `postgres` and `rusqlite` transactions roll back when dropped without a
    /// commit — confirm the crate's `Transaction` wrapper keeps that behavior.
    ///
    /// # Errors
    ///
    /// Propagates failures from checking out the connection, beginning the
    /// transaction, or from `run` itself.
    pub fn transaction(
        &self,
        run: impl FnOnce(Transaction<'_>) -> Result<(), Error>,
    ) -> Result<(), Error> {
        match self {
            #[cfg(feature = "postgresql")]
            Pool::PostgreSQL(pool) => {
                let mut conn = pool.get()?;
                let trans = conn.transaction()?;
                run(Transaction::PostgreSQL(trans))?;
            }
            #[cfg(feature = "sqlite")]
            Pool::SQLite(pool) => {
                let mut conn = pool.get()?;
                let trans = conn.transaction()?;
                run(Transaction::SQLite(trans))?;
            }
        }
        Ok(())
    }

    /// Executes `T`'s schema SQL on one pooled connection, choosing
    /// `T::schema_postgres()` or `T::schema_sqlite()` according to the backend.
    ///
    /// # Errors
    ///
    /// Propagates failures from checking out a connection or from executing
    /// the schema batch.
    pub fn schema<T: Schema>(&self) -> Result<(), Error> {
        match self {
            #[cfg(feature = "postgresql")]
            Pool::PostgreSQL(pool) => {
                let mut conn = pool.get()?;
                conn.batch_execute(T::schema_postgres())?;
            }
            #[cfg(feature = "sqlite")]
            Pool::SQLite(pool) => {
                let conn = pool.get()?;
                conn.execute_batch(T::schema_sqlite())?;
            }
        }
        Ok(())
    }
}
/// Fieldless tag naming which database backend a pool uses.
///
/// Variants are gated by the same cargo features as the pool backends.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum PoolKind {
    /// The PostgreSQL backend.
    #[cfg(feature = "postgresql")]
    PostgreSQL,
    /// The SQLite backend.
    #[cfg(feature = "sqlite")]
    SQLite,
}
impl From<Pool> for PoolKind {
fn from(pool: Pool) -> PoolKind {
pool.as_kind()
}
}
impl<'a> From<&'a Pool> for PoolKind {
fn from(pool: &'a Pool) -> PoolKind {
pool.as_kind()
}
}
/// r2d2 connection manager for the synchronous `postgres` client.
#[cfg(feature = "postgresql")]
pub mod postgres {
    use {
        postgres::{
            tls::{MakeTlsConnect, TlsConnect},
            Client, Config, Error, Socket,
        },
        r2d2::ManageConnection,
    };

    /// Builds `postgres::Client` connections from a stored [`Config`] and TLS connector.
    #[derive(Debug)]
    pub struct PostgresConnectionManager<T> {
        config: Config,
        tls_connector: T,
    }

    impl<T> PostgresConnectionManager<T>
    where
        T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
        T::TlsConnect: Send,
        T::Stream: Send,
        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
    {
        /// Creates a manager that connects with `config`, passing a clone of
        /// `tls_connector` to every new connection.
        pub fn new(config: Config, tls_connector: T) -> Self {
            Self {
                config,
                tls_connector,
            }
        }
    }

    impl<T> ManageConnection for PostgresConnectionManager<T>
    where
        T: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
        T::TlsConnect: Send,
        T::Stream: Send,
        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
    {
        type Connection = Client;
        type Error = Error;

        /// Opens a new client; the connector is cloned because `Config::connect` consumes it.
        fn connect(&self) -> Result<Client, Error> {
            let connector = self.tls_connector.clone();
            self.config.connect(connector)
        }

        /// Probes the connection with an empty simple query, discarding its result.
        fn is_valid(&self, conn: &mut Client) -> Result<(), Error> {
            conn.simple_query("")?;
            Ok(())
        }

        /// A connection counts as broken once the client reports it is closed.
        fn has_broken(&self, conn: &mut Client) -> bool {
            conn.is_closed()
        }
    }
}
/// r2d2 connection manager for file-backed `rusqlite` connections.
#[cfg(feature = "sqlite")]
pub mod sqlite {
    use {
        rusqlite::{Connection, Error, OpenFlags},
        std::path::{Path, PathBuf},
    };

    /// Opens `rusqlite` connections to one database file.
    ///
    /// The derived `Debug` output (`SqliteConnectionManager { path: .. }`) is the
    /// same as the former hand-written implementation.
    #[derive(Debug)]
    pub struct SqliteConnectionManager {
        path: PathBuf,
    }

    impl SqliteConnectionManager {
        /// Creates a manager whose connections open the database file at `path`.
        pub fn file<P: AsRef<Path>>(path: P) -> Self {
            Self {
                path: path.as_ref().to_path_buf(),
            }
        }
    }

    impl r2d2::ManageConnection for SqliteConnectionManager {
        type Connection = Connection;
        type Error = Error;

        /// Opens a new connection to `path` using `OpenFlags::default()`.
        fn connect(&self) -> Result<Connection, Error> {
            // Already returns `rusqlite::Error`; no conversion needed.
            Connection::open_with_flags(&self.path, OpenFlags::default())
        }

        /// Checks the connection by executing an empty batch.
        fn is_valid(&self, conn: &mut Connection) -> Result<(), Error> {
            conn.execute_batch("")
        }

        /// Never reports a connection as broken; only `is_valid` probes it.
        fn has_broken(&self, _: &mut Connection) -> bool {
            false
        }
    }
}