use std::future::Future;
use std::sync::Arc;
use crate::DBError;
use bb8_postgres::PostgresConnectionManager;
use sqlw::{FromRow, Query, RowCell, RowError, RowLike, Value};
use tokio_postgres::NoTls;
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::types::ToSql;
pub struct PostgresRowRef<'a> {
row: &'a tokio_postgres::Row,
}
impl<'a> PostgresRowRef<'a> {
pub fn new(row: &'a tokio_postgres::Row) -> Self {
Self { row }
}
fn column_index(&self, name: &str) -> Result<usize, RowError> {
self.row
.columns()
.iter()
.position(|col| col.name() == name)
.ok_or_else(|| RowError::ColumnNotFound {
name: name.to_string(),
})
}
fn get_value_by_index(&self, index: usize) -> Result<Value, RowError> {
let col = &self.row.columns()[index];
let ty = col.type_();
let null_check: Option<i64> = self
.row
.try_get(index)
.map_err(|e| RowError::Any(e.to_string()))?;
if null_check.is_none() {
return Ok(Value::Null);
}
match ty {
&tokio_postgres::types::Type::TEXT
| &tokio_postgres::types::Type::VARCHAR
| &tokio_postgres::types::Type::BPCHAR
| &tokio_postgres::types::Type::NAME => {
let val: String = self
.row
.try_get(index)
.map_err(|e| RowError::Any(e.to_string()))?;
Ok(Value::Text(val))
}
&tokio_postgres::types::Type::INT2 => {
let val: i16 = self
.row
.try_get(index)
.map_err(|e| RowError::Any(e.to_string()))?;
Ok(Value::Int(val as i64))
}
&tokio_postgres::types::Type::INT4 => {
let val: i32 = self
.row
.try_get(index)
.map_err(|e| RowError::Any(e.to_string()))?;
Ok(Value::Int(val as i64))
}
&tokio_postgres::types::Type::INT8 => {
let val: i64 = self
.row
.try_get(index)
.map_err(|e| RowError::Any(e.to_string()))?;
Ok(Value::Int(val))
}
&tokio_postgres::types::Type::FLOAT4 => {
let val: f32 = self
.row
.try_get(index)
.map_err(|e| RowError::Any(e.to_string()))?;
Ok(Value::Float(val as f64))
}
&tokio_postgres::types::Type::FLOAT8 => {
let val: f64 = self
.row
.try_get(index)
.map_err(|e| RowError::Any(e.to_string()))?;
Ok(Value::Float(val))
}
&tokio_postgres::types::Type::BOOL => {
let val: bool = self
.row
.try_get(index)
.map_err(|e| RowError::Any(e.to_string()))?;
Ok(Value::Bool(val))
}
&tokio_postgres::types::Type::BYTEA => {
let val: Vec<u8> = self
.row
.try_get(index)
.map_err(|e| RowError::Any(e.to_string()))?;
Ok(Value::Blob(val))
}
_ => {
let val: String = self
.row
.try_get(index)
.map_err(|e| RowError::Any(e.to_string()))?;
Ok(Value::Text(val))
}
}
}
}
impl RowLike for PostgresRowRef<'_> {
    /// Resolves the column called `name` and returns its decoded value as an owned cell.
    ///
    /// # Errors
    /// Propagates the [`RowError`] from the column lookup or from decoding the value.
    fn cell<'b>(&'b self, name: &str) -> Result<RowCell<'b>, RowError> {
        self.column_index(name)
            .and_then(|idx| self.get_value_by_index(idx))
            .map(RowCell::Owned)
    }
}
/// Async query executor backed by a `bb8` pool of `tokio_postgres` connections.
///
/// `Tls` is the connector used for new connections and defaults to [`NoTls`]. The same
/// `Tls` bounds are repeated on every impl block for this type; presumably they are what
/// `PostgresConnectionManager<Tls>` needs to act as a `bb8` connection manager (confirm).
pub struct PostgresExecutor<Tls = NoTls>
where
    Tls: MakeTlsConnect<tokio_postgres::Socket> + Send + Sync + Clone + 'static,
    Tls::Stream: Send + Sync,
    Tls::TlsConnect: Send,
    <Tls::TlsConnect as tokio_postgres::tls::TlsConnect<tokio_postgres::Socket>>::Future: Send,
{
    // Shared pool; each query method clones this `Arc` into the future it returns, so the
    // future does not borrow the executor.
    pool: Arc<bb8::Pool<PostgresConnectionManager<Tls>>>,
}
impl PostgresExecutor<NoTls> {
    /// Creates an executor for `connection_string` using plain (non-TLS) connections and an
    /// unmodified `bb8` pool builder.
    ///
    /// # Errors
    /// Returns a [`DBError`] if the connection string is rejected or the pool fails to build.
    pub async fn new(connection_string: &str) -> Result<Self, DBError> {
        Self::with_tls(connection_string, NoTls).await
    }

    /// Alias for [`PostgresExecutor::new`].
    pub async fn from_url(connection_string: &str) -> Result<Self, DBError> {
        Self::new(connection_string).await
    }
}
impl<Tls> PostgresExecutor<Tls>
where
    Tls: MakeTlsConnect<tokio_postgres::Socket> + Send + Sync + Clone + 'static,
    Tls::Stream: Send + Sync,
    Tls::TlsConnect: Send,
    <Tls::TlsConnect as tokio_postgres::tls::TlsConnect<tokio_postgres::Socket>>::Future: Send,
{
    /// Creates an executor whose connections use the `tls` connector, leaving the `bb8`
    /// pool builder untouched.
    ///
    /// # Errors
    /// Same as [`PostgresExecutor::with_config`].
    pub async fn with_tls(connection_string: &str, tls: Tls) -> Result<Self, DBError> {
        Self::with_config(connection_string, tls, std::convert::identity).await
    }

    /// Creates an executor, letting `config_fn` customise the `bb8` pool builder
    /// (pool size, timeouts, ...) before the pool is built.
    ///
    /// The connection manager is created from `connection_string` first; `config_fn` is
    /// only invoked if that succeeds.
    ///
    /// # Errors
    /// Returns [`DBError::Connection`] if the connection manager cannot be created from
    /// `connection_string` or if building the pool fails.
    pub async fn with_config<F>(
        connection_string: &str,
        tls: Tls,
        config_fn: F,
    ) -> Result<Self, DBError>
    where
        F: FnOnce(
            bb8::Builder<PostgresConnectionManager<Tls>>,
        ) -> bb8::Builder<PostgresConnectionManager<Tls>>,
    {
        let conn_manager = PostgresConnectionManager::new_from_stringlike(connection_string, tls)
            .map_err(|err| DBError::Connection(err.into()))?;

        let built_pool = config_fn(bb8::Pool::builder())
            .build(conn_manager)
            .await
            .map_err(|err| DBError::Connection(err.into()))?;

        Ok(Self {
            pool: Arc::new(built_pool),
        })
    }

    /// Returns the underlying `bb8` connection pool.
    pub fn pool(&self) -> &bb8::Pool<PostgresConnectionManager<Tls>> {
        Arc::as_ref(&self.pool)
    }
}
impl<Tls> sqlw::QueryExecutor for PostgresExecutor<Tls>
where
Tls: MakeTlsConnect<tokio_postgres::Socket> + Send + Sync + Clone + 'static,
Tls::Stream: Send + Sync,
Tls::TlsConnect: Send,
<Tls::TlsConnect as tokio_postgres::tls::TlsConnect<tokio_postgres::Socket>>::Future: Send,
{
type Error = DBError;
fn query_void(&self, query: Query) -> impl Future<Output = Result<(), DBError>> {
let pool = Arc::clone(&self.pool);
async move {
let (sql, args) = query.split();
let conn = pool.get().await.map_err(|e| DBError::Execution(e.into()))?;
let params_owned = to_postgres_params(args);
let params: Vec<&(dyn ToSql + Sync)> =
params_owned.iter().map(|v| v.as_ref()).collect();
conn.execute(&sql, ¶ms)
.await
.map(|_| ())
.map_err(|e| DBError::Execution(e.into()))
}
}
fn query_one<T: FromRow + Send + 'static>(
&self,
query: Query,
) -> impl Future<Output = Result<Option<T>, DBError>> {
let pool = Arc::clone(&self.pool);
async move {
let (sql, args) = query.split();
let conn = pool.get().await.map_err(|e| DBError::Execution(e.into()))?;
let params_owned = to_postgres_params(args);
let params: Vec<&(dyn ToSql + Sync)> =
params_owned.iter().map(|v| v.as_ref()).collect();
let row = conn
.query_opt(&sql, ¶ms)
.await
.map_err(|e| DBError::Execution(e.into()))?;
match row {
Some(row) => {
let row_ref = PostgresRowRef::new(&row);
T::from_row(&row_ref).map(Some)
}
None => Ok(None),
}
}
}
fn query_list<T: FromRow + Send + 'static>(
&self,
query: Query,
) -> impl Future<Output = Result<Vec<T>, DBError>> {
let pool = Arc::clone(&self.pool);
async move {
let (sql, args) = query.split();
let conn = pool.get().await.map_err(|e| DBError::Execution(e.into()))?;
let params_owned = to_postgres_params(args);
let params: Vec<&(dyn ToSql + Sync)> =
params_owned.iter().map(|v| v.as_ref()).collect();
let rows = conn
.query(&sql, ¶ms)
.await
.map_err(|e| DBError::Execution(e.into()))?;
let mut results = Vec::new();
for row in rows {
let row_ref = PostgresRowRef::new(&row);
results.push(T::from_row(&row_ref)?);
}
Ok(results)
}
}
}
/// Converts `sqlw` argument values into boxed `tokio_postgres` parameters, keeping their
/// order.
///
/// Each variant is boxed as its payload type (`String`, `i64`, `f64`, `bool`, `Vec<u8>`);
/// `Value::Null` is bound as `Option::<i64>::None`.
///
/// NOTE(review): `tokio_postgres` validates `ToSql::accepts` against the server-inferred
/// parameter type even for `None`, so this NULL (and the `i64`/`f64` bindings) is presumably
/// rejected with a `WrongType` error for parameters that are not INT8/FLOAT8 (e.g. TEXT,
/// INT4). Confirm, and consider a type-agnostic NULL wrapper.
fn to_postgres_params(args: Vec<Value>) -> Vec<Box<dyn ToSql + Sync>> {
    let mut params: Vec<Box<dyn ToSql + Sync>> = Vec::with_capacity(args.len());
    for arg in args {
        let param: Box<dyn ToSql + Sync> = match arg {
            Value::Text(text) => Box::new(text),
            Value::Int(int) => Box::new(int),
            Value::Float(float) => Box::new(float),
            Value::Bool(flag) => Box::new(flag),
            Value::Blob(bytes) => Box::new(bytes),
            Value::Null => Box::new(None::<i64>),
        };
        params.push(param);
    }
    params
}