use crate::SqlError;
use serde_json::Value;
use sqlx::{AssertSqlSafe, Decode, Execute, Executor, FromRow, Row, SqlSafeStr, SqlStr, Type};
use sqlx_json::{QueryResult as _, RowExt};
use crate::timeout::execute_with_timeout;
/// A value that can be turned into SQL text plus (optionally) its bind arguments for `DB`.
///
/// This is the input type accepted by every query method of `Connection`.
pub trait IntoSafeQuery<DB>
where
    DB: sqlx::Database,
{
    /// Consumes `self`, yielding the SQL string and any bind arguments attached to it.
    ///
    /// # Errors
    /// Implementations return `SqlError` when the arguments cannot be extracted.
    fn into_sql_and_args(self) -> Result<(SqlStr, Option<DB::Arguments>), SqlError>;
}
/// Raw SQL text with no bind arguments.
///
/// The string is wrapped in `AssertSqlSafe`, i.e. it is *asserted* safe rather than checked.
/// NOTE(review): this lets any `&str` bypass the `SqlSafeStr` guard, so callers must never
/// build the string from untrusted input — use `sqlx::query(..).bind(..)` for values instead.
impl<DB> IntoSafeQuery<DB> for &str
where
    DB: sqlx::Database,
{
    fn into_sql_and_args(self) -> Result<(SqlStr, Option<DB::Arguments>), SqlError> {
        // No arguments can accompany a bare string, so the second element is always `None`.
        let sql = AssertSqlSafe(self).into_sql_str();
        Ok((sql, None))
    }
}
/// A prepared `sqlx::query(..)` builder, carrying its bound arguments along with the SQL.
impl<'q, DB, A> IntoSafeQuery<DB> for sqlx::query::Query<'q, DB, A>
where
    DB: sqlx::Database,
    A: Send + sqlx::IntoArguments<DB>,
{
    /// Extracts the bound arguments first, then consumes the query for its SQL text.
    ///
    /// # Errors
    /// A failure while taking the arguments is reported as `SqlError::Query` carrying the
    /// underlying error's message.
    fn into_sql_and_args(mut self) -> Result<(SqlStr, Option<DB::Arguments>), SqlError> {
        // Arguments must be taken before `sql()`, which consumes the query.
        let args = match self.take_arguments() {
            Ok(args) => args,
            Err(err) => return Err(SqlError::Query(err.to_string())),
        };
        let sql = self.sql();
        Ok((sql, args))
    }
}
/// Pool-backed access to a SQL database with a per-query timeout.
///
/// Implementors supply [`Connection::pool`] and [`Connection::query_timeout`]. Every provided
/// query method follows the same steps:
/// 1. convert the query with [`IntoSafeQuery::into_sql_and_args`];
/// 2. obtain a pool by forwarding `database` to `pool`;
/// 3. run the statement through `execute_with_timeout` using `query_timeout()`.
// `async fn` in a public trait triggers `async_fn_in_trait` because callers cannot require the
// returned futures to be `Send`. NOTE(review): allowed on the assumption that no caller needs
// that bound (e.g. to spawn these futures on a multi-threaded runtime) — confirm.
#[allow(async_fn_in_trait)]
pub trait Connection: Send + Sync
where
    for<'c> &'c mut <Self::DB as sqlx::Database>::Connection: Executor<'c, Database = Self::DB>,
    usize: sqlx::ColumnIndex<<Self::DB as sqlx::Database>::Row>,
    <Self::DB as sqlx::Database>::Row: RowExt,
    <Self::DB as sqlx::Database>::QueryResult: sqlx_json::QueryResult,
{
    /// The sqlx database driver this connection targets.
    type DB: sqlx::Database;

    /// Returns the pool that queries should run against.
    ///
    /// `target` receives the `database` argument of the provided methods unchanged.
    /// Presumably `None` selects a default database — TODO confirm against implementors.
    async fn pool(&self, target: Option<&str>) -> Result<sqlx::Pool<Self::DB>, SqlError>;

    /// Timeout handed to `execute_with_timeout` by every provided query method.
    ///
    /// NOTE(review): the unit and the meaning of `None` are defined by `execute_with_timeout`,
    /// not here — confirm there.
    fn query_timeout(&self) -> Option<u64>;

    /// Executes a statement and returns the number of rows it affected.
    ///
    /// # Errors
    /// Returns `SqlError` if the query cannot be converted, the pool cannot be obtained, or the
    /// execution fails inside `execute_with_timeout` (presumably including a timeout).
    async fn execute<Q>(&self, query: Q, database: Option<&str>) -> Result<u64, SqlError>
    where
        Q: IntoSafeQuery<Self::DB>,
    {
        let (sql, arguments) = query.into_sql_and_args()?;
        let pool = self.pool(database).await?;
        execute_with_timeout(self.query_timeout(), sql, |sql| async move {
            Ok(pool.execute((sql, arguments)).await?.rows_affected())
        })
        .await
    }

    /// Fetches every row and converts each one to JSON via `RowExt::to_json`.
    ///
    /// # Errors
    /// Same failure sources as [`Connection::execute`].
    async fn fetch_json<Q>(&self, query: Q, database: Option<&str>) -> Result<Vec<Value>, SqlError>
    where
        Q: IntoSafeQuery<Self::DB>,
    {
        let (sql, arguments) = query.into_sql_and_args()?;
        let pool = self.pool(database).await?;
        execute_with_timeout(self.query_timeout(), sql, |sql| async move {
            let rows = pool.fetch_all((sql, arguments)).await?;
            Ok(rows.iter().map(RowExt::to_json).collect())
        })
        .await
    }

    /// Fetches at most one row and decodes its first column (index 0) as `T`.
    ///
    /// Returns `Ok(None)` only when the query yields no row. A row whose first column cannot
    /// be decoded as `T` is an error, not `None`. To accept a NULL value, use `T = Option<U>`.
    ///
    /// # Errors
    /// Same failure sources as [`Connection::execute`], plus any decode error for column 0.
    async fn fetch_optional<Q, T>(
        &self,
        query: Q,
        database: Option<&str>,
    ) -> Result<Option<T>, SqlError>
    where
        Q: IntoSafeQuery<Self::DB>,
        T: for<'r> Decode<'r, Self::DB> + Type<Self::DB> + Send + Unpin,
    {
        let (sql, arguments) = query.into_sql_and_args()?;
        let pool = self.pool(database).await?;
        execute_with_timeout(self.query_timeout(), sql, |sql| async move {
            let row = pool.fetch_optional((sql, arguments)).await?;
            // Propagate decode failures (as `fetch_scalar` does) instead of discarding them
            // with `.ok()`, which made a type mismatch indistinguishable from "no row".
            row.map(|r| r.try_get(0usize)).transpose()
        })
        .await
    }

    /// Fetches every row and decodes the first column (index 0) of each as `T`.
    ///
    /// # Errors
    /// Same failure sources as [`Connection::execute`]. Also returns the first column-0 decode
    /// error encountered; collection stops at that row.
    async fn fetch_scalar<Q, T>(&self, query: Q, database: Option<&str>) -> Result<Vec<T>, SqlError>
    where
        Q: IntoSafeQuery<Self::DB>,
        T: for<'r> Decode<'r, Self::DB> + Type<Self::DB> + Send + Unpin,
    {
        let (sql, arguments) = query.into_sql_and_args()?;
        let pool = self.pool(database).await?;
        execute_with_timeout(self.query_timeout(), sql, |sql| async move {
            let rows = pool.fetch_all((sql, arguments)).await?;
            rows.iter().map(|r| r.try_get(0usize)).collect()
        })
        .await
    }

    /// Fetches every row and maps each one into `T` via `FromRow::from_row`.
    ///
    /// # Errors
    /// Same failure sources as [`Connection::execute`]. Also returns the first `from_row` error
    /// encountered; collection stops at that row.
    async fn fetch<Q, T>(&self, query: Q, database: Option<&str>) -> Result<Vec<T>, SqlError>
    where
        Q: IntoSafeQuery<Self::DB>,
        T: for<'r> FromRow<'r, <Self::DB as sqlx::Database>::Row> + Send + Unpin,
    {
        let (sql, arguments) = query.into_sql_and_args()?;
        let pool = self.pool(database).await?;
        execute_with_timeout(self.query_timeout(), sql, |sql| async move {
            let rows = pool.fetch_all((sql, arguments)).await?;
            rows.iter().map(T::from_row).collect()
        })
        .await
    }
}