#![warn(missing_docs)]
#![deny(warnings, clippy::pedantic, clippy::nursery, unused_crate_dependencies)]
#![allow(clippy::future_not_send)]
#[cfg(feature = "async")]
use {
deadpool_postgres::Pool as Deadpool,
std::{future::Future, pin::Pin},
tokio_postgres::error::SqlState,
};
#[cfg(feature = "sync")]
use {
r2d2_postgres::PostgresConnectionManager,
tokio::task::block_in_place,
tokio_postgres::{
tls::{MakeTlsConnect, TlsConnect},
Socket,
},
};
pub const COCKROACH_SAVEPOINT: &str = "cockroach_restart";
pub enum Error<E> {
#[cfg(feature = "async")]
Pool(deadpool::managed::PoolError<tokio_postgres::Error>),
#[cfg(feature = "async")]
Postgres(tokio_postgres::Error),
#[cfg(feature = "sync")]
R2d2(r2d2::Error),
Other(E),
}
#[cfg(feature = "async")]
impl<E> From<deadpool::managed::PoolError<tokio_postgres::Error>> for Error<E> {
#[inline]
fn from(e: deadpool::managed::PoolError<tokio_postgres::Error>) -> Self {
Self::Pool(e)
}
}
#[cfg(feature = "async")]
impl<E> From<tokio_postgres::Error> for Error<E> {
#[inline]
fn from(e: tokio_postgres::Error) -> Self {
Self::Postgres(e)
}
}
#[cfg(feature = "sync")]
impl<E> From<r2d2::Error> for Error<E> {
#[inline]
fn from(e: r2d2::Error) -> Self {
Self::R2d2(e)
}
}
#[cfg(feature = "sync")]
pub type Pool<T> = r2d2::Pool<PostgresConnectionManager<T>>;
#[cfg(feature = "async")]
pub type AsyncResult<'a, T, I> = Pin<Box<dyn Future<Output = Result<T, I>> + Send + 'a>>;
#[cfg(feature = "async")]
#[inline]
pub async fn tx<T, E, I, S, F>(pool: &Deadpool, savepoint: S, f: F) -> Result<T, Error<E>>
where
I: Into<Error<E>>,
S: AsRef<str>,
for<'a> F: Fn(&'a tokio_postgres::Transaction<'a>) -> AsyncResult<'a, T, I>,
{
let mut client = pool.get().await?;
let mut tx = client.transaction().await?;
let savepoint = savepoint.as_ref();
let v = loop {
match execute_fn(&mut tx, savepoint, &f).await {
Err(Error::Postgres(e)) if e.code() == Some(&SqlState::T_R_SERIALIZATION_FAILURE) => {}
r => break r,
}
}?;
tx.commit().await?;
Ok(v)
}
#[cfg(feature = "async")]
#[inline]
async fn execute_fn<T, E, I, F>(
tx: &mut tokio_postgres::Transaction<'_>,
savepoint: &str,
f: &F,
) -> Result<T, Error<E>>
where
I: Into<Error<E>>,
for<'a> F: Fn(&'a tokio_postgres::Transaction<'a>) -> AsyncResult<'a, T, I>,
{
let mut sp = tx.savepoint(savepoint).await?;
let v = f(&mut sp).await.map_err(Into::into)?;
sp.commit().await?;
Ok(v)
}
#[cfg(feature = "sync")]
#[inline]
pub fn tx_sync<T, M, E, I, S, F>(pool: &Pool<M>, savepoint: S, mut f: F) -> Result<T, Error<E>>
where
M: MakeTlsConnect<Socket> + Clone + 'static + Sync + Send,
M::TlsConnect: Send,
M::Stream: Send,
<M::TlsConnect as TlsConnect<Socket>>::Future: Send,
I: Into<Error<E>>,
S: AsRef<str>,
F: FnMut(&mut postgres::Transaction<'_>) -> Result<T, I>,
{
block_in_place(|| {
let mut con = pool.get()?;
let mut tx = con.transaction()?;
let savepoint = savepoint.as_ref();
loop {
let mut sp = tx.savepoint(savepoint)?;
match f(&mut sp)
.map_err(Into::into)
.and_then(|t| sp.commit().map(|_| t).map_err(Error::from))
{
Err(Error::Postgres(e))
if e.code() == Some(&SqlState::T_R_SERIALIZATION_FAILURE) => {}
r => break r,
}
}
.and_then(|t| tx.commit().map(|_| t).map_err(Error::from))
})
}