use cache::PgCache;
use diesel::{
AsChangeset, ConnectionError, ConnectionResult, QueryResult,
query_builder::{AsQuery, IntoUpdateTarget, QueryFragment, QueryId},
};
use futures_util::stream::{BoxStream, StreamExt, TryStreamExt};
use prepared_client::PreparedClient;
use tokio_postgres::{
Client, Error, Row, RowStream, Statement, ToStatement,
types::{BorrowToSql, ToSql, Type},
};
pub use transaction::AsyncPgTransaction;
use self::{error_helper::ErrorHelper, row::PgRow};
use crate::{
AsyncConnection, AsyncExecute, AsyncTransactional, UpdateAndFetchResults, run_query_dsl::*,
};
mod cache;
mod error_helper;
mod metadata;
mod prepared_client;
mod row;
mod serialize;
mod transaction;
/// An asynchronous diesel connection to PostgreSQL (backend `diesel::pg::Pg`),
/// built on a `tokio_postgres` client.
///
/// Created with [`AsyncConnection::establish`] (which connects without TLS) or
/// [`AsyncPgConnection::try_from`] (which wraps an already-connected client).
pub struct AsyncPgConnection {
/// The driver client that every query and transaction goes through.
conn: tokio_postgres::Client,
/// Consulted by `load` / `execute_returning_count` and lent to transactions.
/// NOTE(review): presumably a prepared-statement cache (see `PgCache`) — confirm.
cache: PgCache,
}
impl AsyncExecute for AsyncPgConnection {
type Stream<'conn> = BoxStream<'conn, QueryResult<PgRow>>;
type Row<'conn> = PgRow;
type Backend = diesel::pg::Pg;
async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
Ok(self.conn.batch_execute(query).await.map_err(ErrorHelper)?)
}
async fn load<T>(&mut self, source: T) -> QueryResult<Self::Stream<'_>>
where
T: AsQuery,
T::Query: QueryFragment<Self::Backend> + QueryId,
{
let res = self.cache.load_cached(&mut self.conn, source).await?;
let res = res
.map_err(|e| diesel::result::Error::from(ErrorHelper(e)))
.map_ok(PgRow::new);
Ok(res.boxed())
}
async fn execute_returning_count<T>(&mut self, source: T) -> QueryResult<usize>
where
T: QueryFragment<Self::Backend> + QueryId + Send,
{
self.cache
.execute_returning_count_cached(&mut self.conn, source)
.await
}
}
/// Transactions on [`AsyncPgConnection`] are [`AsyncPgTransaction`] values that
/// keep the connection mutably borrowed for as long as they exist.
impl AsyncTransactional for AsyncPgConnection {
    type Transaction<'a>
        = AsyncPgTransaction<'a>
    where
        Self: 'a;

    /// Starts a `tokio_postgres` transaction on the client and hands it, together
    /// with this connection's cache, to a new [`AsyncPgTransaction`].
    ///
    /// # Errors
    /// Fails if the driver cannot start the transaction; the error is converted
    /// through `ErrorHelper`.
    async fn begin_transaction(&mut self) -> QueryResult<Self::Transaction<'_>> {
        // Split the borrow so the transaction and the cache can coexist.
        let Self { conn, cache } = self;
        let pg_tx = conn.transaction().await.map_err(ErrorHelper)?;
        Ok(AsyncPgTransaction::new(pg_tx, cache))
    }
}
/// Opening and health-checking [`AsyncPgConnection`]s.
impl AsyncConnection for AsyncPgConnection {
    /// Connects to `database_url` without TLS (`tokio_postgres::NoTls`) and
    /// applies the initial session configuration.
    ///
    /// The `tokio_postgres` connection future is spawned with `tokio::spawn`, so
    /// this must be called from within a Tokio runtime. Whatever that future
    /// eventually returns is discarded.
    ///
    /// # Errors
    /// Fails if the connection cannot be opened or the session cannot be configured.
    async fn establish(database_url: &str) -> ConnectionResult<Self> {
        let connected = tokio_postgres::connect(database_url, tokio_postgres::NoTls).await;
        let (client, background) = connected.map_err(ErrorHelper)?;
        // The connection object does the actual socket I/O and must be polled on
        // its own task; its result is intentionally ignored here.
        tokio::spawn(async move {
            let _ = background.await;
        });
        Self::setup(client).await
    }

    /// Reports whether the underlying `tokio_postgres` client has been closed.
    fn is_broken(&self) -> bool {
        self.conn.is_closed()
    }
}
impl AsyncPgConnection {
pub async fn try_from(conn: tokio_postgres::Client) -> ConnectionResult<Self> {
Self::setup(conn).await
}
async fn setup(conn: tokio_postgres::Client) -> ConnectionResult<Self> {
let cache = PgCache::new();
let mut result = Self { conn, cache };
result
.set_config_options()
.await
.map_err(ConnectionError::CouldntSetupConfiguration)?;
Ok(result)
}
async fn set_config_options(&mut self) -> QueryResult<()> {
use crate::run_query_dsl::RunQueryDsl;
diesel::sql_query("SET TIME ZONE 'UTC'")
.execute(self)
.await?;
diesel::sql_query("SET CLIENT_ENCODING TO 'UTF8'")
.execute(self)
.await?;
Ok(())
}
pub fn cancel_token(&self) -> tokio_postgres::CancelToken {
self.conn.cancel_token()
}
}
impl PreparedClient for Client {
async fn query_one<T>(
&self,
statement: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<Row, Error>
where
T: ?Sized + Send + Sync + ToStatement,
{
(self as &Client).query_one(statement, params).await
}
async fn prepare_typed(
&self,
query: &str,
parameter_types: &[Type],
) -> Result<Statement, Error> {
(self as &Client)
.prepare_typed(query, parameter_types)
.await
}
async fn query_raw<T, P, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
where
T: ?Sized + Send + Sync + ToStatement,
P: BorrowToSql,
I: IntoIterator<Item = P> + Send + Sync,
I::IntoIter: ExactSizeIterator,
{
(self as &Client).query_raw(statement, params).await
}
async fn execute<T>(&self, statement: &T, params: &[&(dyn ToSql + Sync)]) -> Result<u64, Error>
where
T: ToStatement + ?Sized,
{
(self as &Client).execute(statement, params).await
}
}
/// Update-and-reload support for [`AsyncPgConnection`].
///
/// The update targets the row identified by the changeset's `Identifiable` id
/// (the `Find<Tab, Changes::Id>` bound). The resulting row is loaded as `Output`.
/// NOTE(review): the `Send` bounds are presumably there so the returned future is
/// `Send`; confirm against the `UpdateAndFetchResults` trait definition.
impl<Changes, Output, Tab, V> UpdateAndFetchResults<Changes, Output> for crate::AsyncPgConnection
where
    Output: Send,
    Changes:
        Copy + AsChangeset<Target = Tab> + Send + diesel::associations::Identifiable<Table = Tab>,
    Tab: diesel::Table + diesel::query_dsl::methods::FindDsl<Changes::Id>,
    diesel::dsl::Find<Tab, Changes::Id>: IntoUpdateTarget<Table = Tab, WhereClause = V>,
    diesel::query_builder::UpdateStatement<Tab, V, Changes::Changeset>:
        diesel::query_builder::AsQuery,
    diesel::dsl::Update<Changes, Changes>: LoadQuery<Self, Output>,
    V: Send,
    Changes::Changeset: Send,
    Tab::FromClause: Send,
{
    /// Applies `changeset` to its own row and returns that row via `get_result`.
    async fn update_and_fetch(&mut self, changeset: Changes) -> QueryResult<Output> {
        // `Changes: Copy`, so one value serves as both update target and changeset.
        let statement = diesel::update(changeset).set(changeset);
        statement.get_result(self).await
    }
}