use crate::backend::{self, Backend, Error};
use anyhow::anyhow;
use async_bb8_diesel::AsyncR2D2Connection;
use async_bb8_diesel::AsyncSimpleConnection;
use async_trait::async_trait;
use diesel::pg::PgConnection;
use diesel::Connection;
use diesel_dtrace::DTraceConnection;
/// Concrete synchronous connection type used by this module: a Diesel
/// `PgConnection` wrapped in `diesel_dtrace`'s `DTraceConnection`.
type DbConnection = DTraceConnection<PgConnection>;
/// Builds PostgreSQL connection URLs for backends that differ only by address.
///
/// A URL is assembled as `{prefix}{address}{suffix}`; the prefix and suffix are
/// computed once in [`DieselPgConnector::new`] and reused for every backend.
pub struct DieselPgConnector {
    /// Everything that precedes the address: `postgresql://{user}@`.
    prefix: String,
    /// Everything that follows the address: `/{db}`, plus `?{args}` when
    /// arguments were supplied.
    suffix: String,
}

impl DieselPgConnector {
    /// Creates a connector for the given user and database.
    ///
    /// * `user` - user name placed before the `@` in the URL.
    /// * `db` - database name placed after the address.
    /// * `args` - optional query string (without the leading `?`); when
    ///   `None`, no `?` is appended.
    pub fn new(user: &str, db: &str, args: Option<&str>) -> Self {
        let query = match args {
            Some(extra) => format!("?{extra}"),
            None => String::new(),
        };
        let prefix = format!("postgresql://{user}@");
        let suffix = format!("/{db}{query}");
        Self { prefix, suffix }
    }

    /// Returns the full connection URL for `address`, formatted with the
    /// address's `Display` implementation between the stored prefix and suffix.
    fn to_url(&self, address: std::net::SocketAddr) -> String {
        let Self { prefix, suffix } = self;
        format!("{prefix}{address}{suffix}")
    }
}
/// SQL batch that sets `disallow_full_table_scans = on` and
/// `large_full_scan_rows = 0`.
///
/// Executed on a connection by `on_acquire` in this file's
/// `backend::Connector` implementation.
// NOTE(review): these look like CockroachDB session variables — confirm
// against the database this connector targets.
pub const DISALLOW_FULL_TABLE_SCAN_SQL: &str =
"set disallow_full_table_scans = on; set large_full_scan_rows = 0;";
#[async_trait]
impl backend::Connector for DieselPgConnector {
type Connection = async_bb8_diesel::Connection<DbConnection>;
async fn connect(&self, backend: &Backend) -> Result<Self::Connection, Error> {
let url = self.to_url(backend.address);
let conn = tokio::task::spawn_blocking(move || {
let pg_conn = DbConnection::establish(&url).map_err(|e| Error::Other(anyhow!(e)))?;
Ok::<_, Error>(async_bb8_diesel::Connection::new(pg_conn))
})
.await
.expect("Task panicked establishing connection")?;
Ok(conn)
}
async fn on_acquire(&self, conn: &mut Self::Connection) -> Result<(), Error> {
conn.batch_execute_async(DISALLOW_FULL_TABLE_SCAN_SQL)
.await
.map_err(|e| Error::Other(anyhow!(e)))?;
Ok(())
}
async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Error> {
let is_broken = conn.is_broken_async().await;
if is_broken {
return Err(Error::Other(anyhow!("Connection broken")));
}
conn.ping_async()
.await
.map_err(|e| Error::Other(anyhow!(e)))
}
}