use async_trait::async_trait;
use futures::{channel::oneshot, prelude::*};
use std::{
convert::{AsMut, AsRef},
ops::{Deref, DerefMut},
};
use tokio::spawn;
use tokio_postgres::error::Error;
use tokio_postgres::{
tls::{MakeTlsConnect, TlsConnect},
Client, Socket,
};
use tracing::{debug, debug_span, info, warn, Instrument};
use std::fmt;
/// A pooled Postgres connection: a [`Client`] plus the bookkeeping used to
/// observe and stop the spawned task that drives the underlying connection.
pub struct AsyncConnection {
    /// The `tokio_postgres` client used to issue queries.
    pub client: Client,
    /// Latched to `true` by `has_broken` once the connection is known to be unusable.
    broken: bool,
    /// Receives a value when the spawned connection task finishes; reports cancellation
    /// if that task's sender is dropped without sending.
    done_rx: oneshot::Receiver<()>,
    /// Fired from `Drop` to tell the spawned connection task to stop; `None` once used.
    drop_tx: Option<oneshot::Sender<()>>,
}
impl Drop for AsyncConnection {
    /// Tells the spawned connection task to shut down.
    ///
    /// The send can only fail when the task has already finished and dropped its
    /// receiver, in which case there is nothing left to stop, so the error is ignored.
    fn drop(&mut self) {
        let _ = self.drop_tx.take().map(|stop| stop.send(()));
    }
}
impl Deref for AsyncConnection {
    type Target = Client;

    /// Exposes the wrapped [`Client`] so query methods can be called on the connection directly.
    fn deref(&self) -> &Client {
        let Self { client, .. } = self;
        client
    }
}
impl DerefMut for AsyncConnection {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.client
}
}
impl AsMut<Client> for AsyncConnection {
fn as_mut(&mut self) -> &mut Client {
&mut self.client
}
}
impl AsRef<Client> for AsyncConnection {
fn as_ref(&self) -> &Client {
&self.client
}
}
pub struct PostgresConnectionManager<T>
where
T: 'static + MakeTlsConnect<Socket> + Clone + Send + Sync,
{
config: tokio_postgres::Config,
make_tls_connect: T,
}
impl<T> PostgresConnectionManager<T>
where
T: 'static + MakeTlsConnect<Socket> + Clone + Send + Sync,
{
pub fn new(config: tokio_postgres::Config, make_tls_connect: T) -> Self {
Self {
config,
make_tls_connect,
}
}
}
#[async_trait]
impl<T> l337::ManageConnection for PostgresConnectionManager<T>
where
T: 'static + MakeTlsConnect<Socket> + Clone + Send + Sync,
T::Stream: Send + Sync,
T::TlsConnect: Send,
<T::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
type Connection = AsyncConnection;
type Error = Error;
async fn connect(&self) -> Result<Self::Connection, l337::Error<Self::Error>> {
let (client, connection) = self
.config
.connect(self.make_tls_connect.clone())
.instrument(debug_span!("connect: open new postgres connection"))
.await
.map_err(|e| l337::Error::External(e))?;
let (done_tx, done_rx) = oneshot::channel();
let (drop_tx, drop_rx) = oneshot::channel();
spawn(async move {
debug!("connect: start connection future");
let connection = connection.fuse();
let drop_rx = drop_rx.fuse();
futures::pin_mut!(connection, drop_rx);
futures::select! {
result = connection => {
if let Err(e) = result {
warn!("future backing postgres future ended with an error: {}", e);
}
}
_ = drop_rx => { }
}
let _ = done_tx.send(());
info!("connect: connection future ended");
});
debug!("connect: postgres connection established");
Ok(AsyncConnection {
broken: false,
client,
done_rx,
drop_tx: Some(drop_tx),
})
}
async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), l337::Error<Self::Error>> {
conn.simple_query("")
.await
.map_err(|e| l337::Error::External(e))?;
Ok(())
}
fn has_broken(&self, conn: &mut Self::Connection) -> bool {
if conn.broken {
return true;
}
if conn.client.is_closed() {
return true;
}
match conn.done_rx.try_recv() {
Ok(Some(_)) => {
conn.broken = true;
true
}
Ok(None) => false,
Err(error) => {
warn!(%error, "cannot receive from connection future");
conn.broken = true;
true
}
}
}
fn timed_out(&self) -> l337::Error<Self::Error> {
unimplemented!()
}
}
impl<T> fmt::Debug for PostgresConnectionManager<T>
where
    T: 'static + MakeTlsConnect<Socket> + Clone + Send + Sync,
{
    /// Formats the manager showing only its connection config; the TLS connector
    /// is left out because `T` is not required to implement `Debug`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = formatter.debug_struct("PostgresConnectionManager");
        builder.field("config", &self.config);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use l337::{Config, Pool};
    use std::time::Duration;
    use tokio::time::sleep;

    // These tests need a live Postgres server reachable at this address.
    const DATABASE_URL: &str = "postgres://pass_user:password@localhost:5433/postgres";

    /// Builds a plain-text (no TLS) manager pointed at the test database.
    fn test_manager() -> PostgresConnectionManager<tokio_postgres::NoTls> {
        PostgresConnectionManager::new(DATABASE_URL.parse().unwrap(), tokio_postgres::NoTls)
    }

    /// Prepares and runs `sql`, asserting the first column of every row equals `expected`.
    async fn assert_selects(client: &Client, sql: &str, expected: i32) {
        let statement = client.prepare(sql).await.unwrap();
        for row in client.query(&statement, &[]).await.unwrap() {
            let value: i32 = row.get(0);
            assert_eq!(expected, value);
        }
    }

    #[tokio::test]
    async fn it_works() {
        let pool = Pool::new(test_manager(), Config::default()).await.unwrap();
        let conn = pool.connection().await.unwrap();
        assert_selects(&conn, "SELECT 1::INT4", 1).await;
    }

    #[tokio::test]
    async fn it_allows_multiple_queries_at_the_same_time() {
        let pool = Pool::new(test_manager(), Config::default()).await.unwrap();

        // Each future keeps its connection checked out through the sleep, so the pool
        // has to serve two connections concurrently.
        let first = async {
            let conn = pool.connection().await.unwrap();
            assert_selects(&conn, "SELECT 1::INT4", 1).await;
            sleep(Duration::from_secs(5)).await;
            conn
        };
        let second = async {
            let conn = pool.connection().await.unwrap();
            assert_selects(&conn, "SELECT 2::INT4", 2).await;
            sleep(Duration::from_secs(5)).await;
            conn
        };

        futures::join!(first, second);
    }

    #[tokio::test]
    async fn it_reuses_connections() {
        let pool = Pool::new(test_manager(), Config::default()).await.unwrap();

        // Check out a connection once and release it at the end of the block.
        async {
            let conn = pool.connection().await.unwrap();
            assert_selects(&conn, "SELECT 1::INT4", 1).await;
        }
        .await;
        // Presumably gives the released connection time to return to the pool.
        sleep(Duration::from_millis(500)).await;

        let second = async {
            let conn = pool.connection().await.unwrap();
            assert_selects(&conn, "SELECT 2::INT4", 2).await;
        };
        let third = async {
            let conn = pool.connection().await.unwrap();
            assert_selects(&conn, "SELECT 3::INT4", 3).await;
        };
        futures::join!(second, third);

        // Two concurrent checkouts after the first one should need only two connections
        // in total if the first connection was reused.
        assert_eq!(pool.total_conns(), 2);
    }
}