use std::time::Duration;
use diesel::{
r2d2,
r2d2::{ConnectionManager, PooledConnection},
result::Error as DieselError,
PgConnection,
};
use diesel_migrations::{EmbeddedMigrations, MigrationHarness};
/// Diesel's r2d2 connection manager, specialised for PostgreSQL connections.
pub type PgConnectionManager = ConnectionManager<PgConnection>;
/// An r2d2 pool whose connections are managed by [`PgConnectionManager`].
pub type PgPool = r2d2::Pool<ConnectionManager<PgConnection>>;
/// A single PostgreSQL connection checked out of a [`PgPool`].
pub type PgPooledConnection = PooledConnection<ConnectionManager<PgConnection>>;
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("create pool: {0}")]
CreatePool(String),
#[error("get connection from pool: {0}")]
GetConnection(String),
#[error("database migration: {0}")]
GetMigration(String),
#[error("there are {0} pending migrations")]
PendingMigration(usize),
#[error("diesel failed: {0}")]
Diesel(DieselError),
}
/// Handle to a PostgreSQL connection pool.
///
/// Built with [`Pool::new`] or [`Pool::new_with_opts`], both of which refuse
/// to return a pool while the supplied migrations are still pending. Queries
/// are run through [`Pool::execute`].
#[derive(Clone)]
pub struct Pool {
    // The underlying r2d2 pool.
    // NOTE(review): cloning `Pool` presumably shares this pool rather than
    // opening new connections (r2d2's `Pool` is reference-counted) — confirm.
    pool: PgPool,
}
/// Optional tuning knobs for the r2d2 pool built by `Pool::new_with_opts`.
///
/// A `Some` value is handed to the r2d2 builder setter of the same name;
/// `None` leaves that builder setting untouched, so r2d2's own default applies.
#[derive(Clone, Debug)]
pub struct PoolOpts {
    /// Forwarded to r2d2's `Builder::connection_timeout`.
    pub connection_timeout: Option<Duration>,
    /// Forwarded to r2d2's `Builder::idle_timeout`.
    pub idle_timeout: Option<Duration>,
    /// Forwarded to r2d2's `Builder::max_size`.
    pub max_size: Option<u32>,
    /// Forwarded to r2d2's `Builder::min_idle`.
    pub min_idle: Option<u32>,
    /// Forwarded to r2d2's `Builder::test_on_check_out`.
    pub test_on_check_out: Option<bool>,
}

impl Default for PoolOpts {
    /// A 10-second connection timeout with `test_on_check_out` enabled;
    /// every other setting is left to r2d2's defaults.
    fn default() -> Self {
        const CHECKOUT_TIMEOUT: Duration = Duration::from_secs(10);

        PoolOpts {
            connection_timeout: Some(CHECKOUT_TIMEOUT),
            test_on_check_out: Some(true),
            idle_timeout: None,
            max_size: None,
            min_idle: None,
        }
    }
}
impl Pool {
pub fn new(connection_url: String, migrations: EmbeddedMigrations) -> Result<Self, Error> {
Self::new_with_opts(connection_url, migrations, PoolOpts::default())
}
pub fn new_with_opts(
connection_url: String,
migrations: EmbeddedMigrations,
opts: PoolOpts,
) -> Result<Self, Error> {
let mut builder = r2d2::Pool::builder();
if let Some(connection_timeout) = opts.connection_timeout {
builder = builder.connection_timeout(connection_timeout);
}
if let Some(idle_timeout) = opts.idle_timeout {
builder = builder.idle_timeout(Some(idle_timeout));
}
if let Some(min_idle) = opts.min_idle {
builder = builder.min_idle(Some(min_idle));
}
if let Some(max_size) = opts.max_size {
builder = builder.max_size(max_size);
}
if let Some(test_on_check_out) = opts.test_on_check_out {
builder = builder.test_on_check_out(test_on_check_out);
}
let pool = builder
.build(ConnectionManager::<PgConnection>::new(connection_url))
.map_err(|e| Error::CreatePool(e.to_string()))?;
check_pending_migrations(&pool, migrations)?;
Ok(Self { pool })
}
pub fn execute<T, Q, E>(&self, query: Q) -> Result<T, E>
where
T: Send + 'static,
Q: FnOnce(PgPooledConnection) -> Result<T, E> + Send + 'static,
E: From<Error>,
{
let conn = self.connection()?;
tokio::task::block_in_place(|| query(conn))
}
fn connection(&self) -> Result<PgPooledConnection, Error> {
self.pool
.get()
.map_err(|e| Error::GetConnection(e.to_string()))
}
}
fn check_pending_migrations(pool: &PgPool, migrations: EmbeddedMigrations) -> Result<(), Error> {
match count_pending_migrations(pool, migrations)? {
0 => Ok(()),
n => Err(Error::PendingMigration(n)),
}
}
fn count_pending_migrations(pool: &PgPool, migrations: EmbeddedMigrations) -> Result<usize, Error> {
let count_pending_migrations = MigrationHarness::pending_migrations(
&mut pool
.get()
.map_err(|e| Error::GetConnection(e.to_string()))?,
migrations,
)
.map_err(|e| Error::GetMigration(e.to_string()))?
.len();
Ok(count_pending_migrations)
}