use std::marker::PhantomData;
use std::sync::Arc;
use prax_query::QueryResult;
use prax_query::filter::FilterValue;
use prax_query::traits::{BoxFuture, Model, QueryEngine};
use tracing::trace;
use crate::pool::PgPool;
use crate::types::filter_value_to_sql;
/// PostgreSQL implementation of `prax_query`'s `QueryEngine`.
///
/// A plain engine checks a connection out of `pool` for every query. An engine handed to a
/// `transaction` closure instead carries the connection on which `BEGIN` was issued in
/// `tx_conn`, and routes all of its queries through that connection.
#[derive(Clone)]
pub struct PgEngine {
    /// Pool used whenever no transaction connection is pinned.
    pool: PgPool,
    /// Connection with an open transaction, shared (via `Arc`) by every clone of a
    /// transactional engine; `None` for a plain pooled engine.
    tx_conn: Option<Arc<deadpool_postgres::Object>>,
}
impl PgEngine {
    /// Creates an engine that is not bound to any transaction, so each query checks a
    /// connection out of `pool`.
    pub fn new(pool: PgPool) -> Self {
        let tx_conn = None;
        Self { pool, tx_conn }
    }

    /// Borrows the connection pool this engine draws from.
    pub fn pool(&self) -> &PgPool {
        let Self { pool, .. } = self;
        pool
    }

    /// Converts query parameters into boxed `ToSql` values for tokio-postgres.
    ///
    /// Conversion stops at the first value `filter_value_to_sql` rejects; that failure is
    /// returned as a `database` error whose source is the original conversion error.
    // `QueryError` is defined in `prax_query`, so the size that trips `result_large_err`
    // cannot be reduced from within this crate.
    #[allow(clippy::result_large_err)]
    fn to_params(
        values: &[FilterValue],
    ) -> Result<Vec<Box<dyn tokio_postgres::types::ToSql + Sync + Send>>, prax_query::QueryError>
    {
        let mut converted: Vec<Box<dyn tokio_postgres::types::ToSql + Sync + Send>> =
            Vec::with_capacity(values.len());
        for value in values {
            let param = filter_value_to_sql(value).map_err(|err| {
                let message = err.to_string();
                prax_query::QueryError::database(message).with_source(err)
            })?;
            converted.push(param);
        }
        Ok(converted)
    }
}
impl QueryEngine for PgEngine {
fn dialect(&self) -> &dyn prax_query::dialect::SqlDialect {
&prax_query::dialect::Postgres
}
fn query_many<T: Model + prax_query::row::FromRow + Send + 'static>(
&self,
sql: &str,
params: Vec<FilterValue>,
) -> BoxFuture<'_, QueryResult<Vec<T>>> {
let sql = sql.to_string();
Box::pin(async move {
trace!(sql = %sql, "Executing query_many");
let pg_params = Self::to_params(¶ms)?;
let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
pg_params.iter().map(|p| p.as_ref() as _).collect();
let rows = if let Some(tx) = &self.tx_conn {
tx.query(&sql, ¶m_refs)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
} else {
let conn = self.pool.get().await.map_err(|e| {
prax_query::QueryError::connection(e.to_string()).with_source(e)
})?;
conn.query(&sql, ¶m_refs)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
};
crate::deserialize::rows_into::<T>(rows)
})
}
fn query_one<T: Model + prax_query::row::FromRow + Send + 'static>(
&self,
sql: &str,
params: Vec<FilterValue>,
) -> BoxFuture<'_, QueryResult<T>> {
let sql = sql.to_string();
Box::pin(async move {
trace!(sql = %sql, "Executing query_one");
let pg_params = Self::to_params(¶ms)?;
let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
pg_params.iter().map(|p| p.as_ref() as _).collect();
let map_err = |e: String| -> prax_query::QueryError {
if e.contains("no rows") {
prax_query::QueryError::not_found(T::MODEL_NAME)
} else {
prax_query::QueryError::database(e)
}
};
let row = if let Some(tx) = &self.tx_conn {
tx.query_one(&sql, ¶m_refs)
.await
.map_err(|e| map_err(e.to_string()).with_source(e))?
} else {
let conn = self.pool.get().await.map_err(|e| {
prax_query::QueryError::connection(e.to_string()).with_source(e)
})?;
conn.query_one(&sql, ¶m_refs)
.await
.map_err(|e| map_err(e.to_string()).with_source(e))?
};
crate::deserialize::row_into::<T>(row)
})
}
fn query_optional<T: Model + prax_query::row::FromRow + Send + 'static>(
&self,
sql: &str,
params: Vec<FilterValue>,
) -> BoxFuture<'_, QueryResult<Option<T>>> {
let sql = sql.to_string();
Box::pin(async move {
trace!(sql = %sql, "Executing query_optional");
let pg_params = Self::to_params(¶ms)?;
let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
pg_params.iter().map(|p| p.as_ref() as _).collect();
let row = if let Some(tx) = &self.tx_conn {
tx.query_opt(&sql, ¶m_refs)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
} else {
let conn = self.pool.get().await.map_err(|e| {
prax_query::QueryError::connection(e.to_string()).with_source(e)
})?;
conn.query_opt(&sql, ¶m_refs)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
};
row.map(crate::deserialize::row_into::<T>).transpose()
})
}
fn execute_insert<T: Model + prax_query::row::FromRow + Send + 'static>(
&self,
sql: &str,
params: Vec<FilterValue>,
) -> BoxFuture<'_, QueryResult<T>> {
let sql = sql.to_string();
Box::pin(async move {
trace!(sql = %sql, "Executing insert");
let pg_params = Self::to_params(¶ms)?;
let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
pg_params.iter().map(|p| p.as_ref() as _).collect();
let row = if let Some(tx) = &self.tx_conn {
tx.query_one(&sql, ¶m_refs)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
} else {
let conn = self.pool.get().await.map_err(|e| {
prax_query::QueryError::connection(e.to_string()).with_source(e)
})?;
conn.query_one(&sql, ¶m_refs)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
};
crate::deserialize::row_into::<T>(row)
})
}
fn execute_update<T: Model + prax_query::row::FromRow + Send + 'static>(
&self,
sql: &str,
params: Vec<FilterValue>,
) -> BoxFuture<'_, QueryResult<Vec<T>>> {
let sql = sql.to_string();
Box::pin(async move {
trace!(sql = %sql, "Executing update");
let pg_params = Self::to_params(¶ms)?;
let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
pg_params.iter().map(|p| p.as_ref() as _).collect();
let rows = if let Some(tx) = &self.tx_conn {
tx.query(&sql, ¶m_refs)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
} else {
let conn = self.pool.get().await.map_err(|e| {
prax_query::QueryError::connection(e.to_string()).with_source(e)
})?;
conn.query(&sql, ¶m_refs)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
};
crate::deserialize::rows_into::<T>(rows)
})
}
fn execute_delete(
&self,
sql: &str,
params: Vec<FilterValue>,
) -> BoxFuture<'_, QueryResult<u64>> {
let sql = sql.to_string();
Box::pin(async move {
trace!(sql = %sql, "Executing delete");
let pg_params = Self::to_params(¶ms)?;
let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
pg_params.iter().map(|p| p.as_ref() as _).collect();
if let Some(tx) = &self.tx_conn {
tx.execute(&sql, ¶m_refs)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))
} else {
let conn = self.pool.get().await.map_err(|e| {
prax_query::QueryError::connection(e.to_string()).with_source(e)
})?;
conn.execute(&sql, ¶m_refs)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))
}
})
}
fn execute_raw(&self, sql: &str, params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
let sql = sql.to_string();
Box::pin(async move {
trace!(sql = %sql, "Executing raw SQL");
let pg_params = Self::to_params(¶ms)?;
let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
pg_params.iter().map(|p| p.as_ref() as _).collect();
if let Some(tx) = &self.tx_conn {
tx.execute(&sql, ¶m_refs)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))
} else {
let conn = self.pool.get().await.map_err(|e| {
prax_query::QueryError::connection(e.to_string()).with_source(e)
})?;
conn.execute(&sql, ¶m_refs)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))
}
})
}
fn count(&self, sql: &str, params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
let sql = sql.to_string();
Box::pin(async move {
trace!(sql = %sql, "Executing count");
let pg_params = Self::to_params(¶ms)?;
let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
pg_params.iter().map(|p| p.as_ref() as _).collect();
let row = if let Some(tx) = &self.tx_conn {
tx.query_one(&sql, ¶m_refs)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
} else {
let conn = self.pool.get().await.map_err(|e| {
prax_query::QueryError::connection(e.to_string()).with_source(e)
})?;
conn.query_one(&sql, ¶m_refs)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
};
let count: i64 = row.get(0);
Ok(count as u64)
})
}
fn aggregate_query(
&self,
sql: &str,
params: Vec<FilterValue>,
) -> BoxFuture<'_, QueryResult<Vec<std::collections::HashMap<String, FilterValue>>>> {
let sql = sql.to_string();
Box::pin(async move {
trace!(sql = %sql, "Executing aggregate_query");
let pg_params = Self::to_params(¶ms)?;
let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
pg_params.iter().map(|p| p.as_ref() as _).collect();
let rows = if let Some(tx) = &self.tx_conn {
tx.query(&sql, ¶m_refs)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
} else {
let conn = self.pool.get().await.map_err(|e| {
prax_query::QueryError::connection(e.to_string()).with_source(e)
})?;
conn.query(&sql, ¶m_refs)
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
};
Ok(rows
.into_iter()
.map(|row| {
let mut map = std::collections::HashMap::new();
for (i, col) in row.columns().iter().enumerate() {
let name = col.name().to_string();
let value = decode_aggregate_cell(&row, i, col.type_());
map.insert(name, value);
}
map
})
.collect())
})
}
fn transaction<'a, R, Fut, F>(&'a self, f: F) -> BoxFuture<'a, QueryResult<R>>
where
F: FnOnce(Self) -> Fut + Send + 'a,
Fut: std::future::Future<Output = QueryResult<R>> + Send + 'a,
R: Send + 'a,
Self: Clone,
{
Box::pin(async move {
if self.tx_conn.is_some() {
return Err(prax_query::QueryError::internal(
"nested transactions not yet implemented \
(call .transaction() on the outer engine only, or \
issue SAVEPOINT via execute_raw)",
));
}
let conn =
self.pool.inner().get().await.map_err(|e| {
prax_query::QueryError::connection(e.to_string()).with_source(e)
})?;
conn.batch_execute("BEGIN")
.await
.map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?;
let tx_conn = Arc::new(conn);
let tx_engine = PgEngine {
pool: self.pool.clone(),
tx_conn: Some(tx_conn.clone()),
};
let result = f(tx_engine).await;
match result {
Ok(v) => {
tx_conn.batch_execute("COMMIT").await.map_err(|e| {
prax_query::QueryError::database(e.to_string()).with_source(e)
})?;
Ok(v)
}
Err(e) => {
let _ = tx_conn.batch_execute("ROLLBACK").await;
Err(e)
}
}
})
}
}
/// Typed handle pairing a [`PgEngine`] with the model type `T` it builds queries for.
///
/// `T` is tracked only at the type level through `PhantomData`; no `T` value is stored.
pub struct PgQueryBuilder<T: Model> {
    /// Engine that queries built by this handle execute on.
    engine: PgEngine,
    /// Zero-sized marker tying the builder to `T`.
    _marker: PhantomData<T>,
}
impl<T: Model> PgQueryBuilder<T> {
    /// Wraps `engine` in a query builder for model `T`.
    pub fn new(engine: PgEngine) -> Self {
        let _marker = PhantomData;
        Self { engine, _marker }
    }

    /// Borrows the engine this builder runs its queries on.
    pub fn engine(&self) -> &PgEngine {
        let Self { engine, .. } = self;
        engine
    }
}
/// Converts column `idx` of an aggregate result row into a [`FilterValue`].
///
/// Values map to variants as follows:
/// - BOOL, the integer types, the float types and JSON/JSONB map to the matching variant.
/// - NUMERIC (e.g. `SUM`/`AVG` over integer columns) becomes its exact decimal text.
/// - Postgres' internal `"char"` becomes a one-character string.
/// - Every other type is read as text when tokio-postgres can decode it as `String`.
///
/// SQL NULL, and any value that cannot be decoded for its column type, yields
/// `FilterValue::Null`.
fn decode_aggregate_cell(
    row: &tokio_postgres::Row,
    idx: usize,
    ty: &tokio_postgres::types::Type,
) -> FilterValue {
    use tokio_postgres::types::{FromSql, Type};

    /// Reads column `idx` as `Option<T>` and wraps a present value with `wrap`.
    /// NULLs and decode failures (wrong type, malformed data) collapse to `Null`.
    fn cell<'r, T, F>(row: &'r tokio_postgres::Row, idx: usize, wrap: F) -> FilterValue
    where
        T: FromSql<'r>,
        F: FnOnce(T) -> FilterValue,
    {
        row.try_get::<_, Option<T>>(idx)
            .ok()
            .flatten()
            .map(wrap)
            .unwrap_or(FilterValue::Null)
    }

    match *ty {
        Type::BOOL => cell(row, idx, FilterValue::Bool),
        Type::INT2 => cell(row, idx, |n: i16| FilterValue::Int(i64::from(n))),
        Type::INT4 => cell(row, idx, |n: i32| FilterValue::Int(i64::from(n))),
        Type::INT8 => cell(row, idx, FilterValue::Int),
        Type::FLOAT4 => cell(row, idx, |f: f32| FilterValue::Float(f64::from(f))),
        Type::FLOAT8 => cell(row, idx, FilterValue::Float),
        // tokio-postgres' `String` decoder does not accept NUMERIC, so decode the binary
        // form ourselves and keep full precision as text.
        Type::NUMERIC => cell(row, idx, |n: NumericText| FilterValue::String(n.0)),
        // The internal single-byte `"char"` type decodes as `i8`, not as `String`; the
        // `as u8` reinterprets the raw byte.
        Type::CHAR => cell(row, idx, |b: i8| {
            FilterValue::String(char::from(b as u8).to_string())
        }),
        Type::JSON | Type::JSONB => cell(row, idx, FilterValue::Json),
        // TEXT, VARCHAR, BPCHAR, NAME and any other type `String` can decode.
        _ => cell(row, idx, FilterValue::String),
    }
}

/// A NUMERIC value rendered as its exact decimal text (or `NaN` / `Infinity` / `-Infinity`).
struct NumericText(String);

impl<'a> tokio_postgres::types::FromSql<'a> for NumericText {
    fn from_sql(
        _ty: &tokio_postgres::types::Type,
        raw: &'a [u8],
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
        numeric_binary_to_string(raw)
            .map(NumericText)
            .ok_or_else(|| "malformed NUMERIC value".into())
    }

    fn accepts(ty: &tokio_postgres::types::Type) -> bool {
        *ty == tokio_postgres::types::Type::NUMERIC
    }
}

/// Renders a NUMERIC from PostgreSQL's binary wire format as canonical decimal text.
///
/// Wire layout (big-endian): `ndigits: i16`, `weight: i16`, `sign: u16`, `dscale: i16`,
/// followed by `ndigits` base-10000 digit groups (`i16` each, 0..=9999). The value is
/// `Σ digits[i] · 10000^(weight − i)`, printed with exactly `dscale` fractional digits.
/// Returns `None` for truncated or otherwise malformed input.
fn numeric_binary_to_string(raw: &[u8]) -> Option<String> {
    use std::fmt::Write;

    const SIGN_POS: u16 = 0x0000;
    const SIGN_NEG: u16 = 0x4000;
    const SIGN_NAN: u16 = 0xC000;
    const SIGN_PINF: u16 = 0xD000;
    const SIGN_NINF: u16 = 0xF000;

    let word = |at: usize| raw.get(at..at + 2).map(|b| [b[0], b[1]]);
    let ndigits = usize::try_from(i16::from_be_bytes(word(0)?)).ok()?;
    let weight = i32::from(i16::from_be_bytes(word(2)?));
    let sign = u16::from_be_bytes(word(4)?);
    let dscale = usize::try_from(i16::from_be_bytes(word(6)?)).ok()?;

    let negative = match sign {
        SIGN_POS => false,
        SIGN_NEG => true,
        SIGN_NAN => return Some("NaN".to_owned()),
        SIGN_PINF => return Some("Infinity".to_owned()),
        SIGN_NINF => return Some("-Infinity".to_owned()),
        _ => return None,
    };

    let digits = (0..ndigits)
        .map(|i| {
            let d = i16::from_be_bytes(word(8 + 2 * i)?);
            (0..10_000).contains(&d).then_some(d)
        })
        .collect::<Option<Vec<i16>>>()?;

    // Digit group at position `i`; positions outside the stored groups are implicit zeros.
    let group = |i: i32| -> i16 {
        usize::try_from(i)
            .ok()
            .and_then(|i| digits.get(i).copied())
            .unwrap_or(0)
    };

    // `write!` into a `String` cannot fail, so its `fmt::Result` is ignored below.
    let mut text = String::new();
    if negative {
        text.push('-');
    }
    if weight < 0 {
        text.push('0');
    } else {
        // The leading group is printed without padding; later groups are padded to 4 digits.
        let _ = write!(text, "{}", group(0));
        for i in 1..=weight {
            let _ = write!(text, "{:04}", group(i));
        }
    }
    if dscale > 0 {
        text.push('.');
        let mut frac = String::with_capacity(dscale + 4);
        let mut i = weight + 1;
        while frac.len() < dscale {
            let _ = write!(frac, "{:04}", group(i));
            i += 1;
        }
        frac.truncate(dscale);
        text.push_str(&frac);
    }
    Some(text)
}
#[cfg(test)]
mod tests {
    // NOTE(review): this module is currently empty — no unit tests exist for this file yet.
}