use crate::core::{
BulkInsertQuery, CountQuery, DeleteQuery, InsertQuery, Model, SelectQuery, SqlValue,
UpdateQuery,
};
use crate::query::{QuerySet, UpdateBuilder};
use sqlx::postgres::{PgArguments, PgPool, PgRow};
use sqlx::query::{Query, QueryAs};
use super::{Dialect, ExecError, Postgres};
/// Runs a query and decodes every matching row into `T`.
pub trait Fetcher<T>
where
    T: Send + Unpin + Model + for<'r> sqlx::FromRow<'r, PgRow>,
{
    /// Executes the query against `pool`.
    ///
    /// The returned future is `Send` and resolves to all decoded rows, or an [`ExecError`].
    fn fetch(
        self,
        pool: &PgPool,
    ) -> impl core::future::Future<Output = Result<Vec<T>, ExecError>> + Send;
}
impl<T> Fetcher<T> for QuerySet<T>
where
T: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
async fn fetch(self, pool: &PgPool) -> Result<Vec<T>, ExecError> {
let select = self.compile()?;
let stmt = Postgres.compile_select(&select)?;
let mut q: QueryAs<'_, sqlx::Postgres, T, PgArguments> = sqlx::query_as::<_, T>(&stmt.sql);
for value in stmt.params {
q = bind_query_as(q, value);
}
let rows = q.fetch_all(pool).await?;
Ok(rows)
}
}
pub async fn insert(pool: &PgPool, query: &InsertQuery) -> Result<(), ExecError> {
query.validate()?;
let stmt = Postgres.compile_insert(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
q.execute(pool).await?;
Ok(())
}
pub async fn insert_returning(pool: &PgPool, query: &InsertQuery) -> Result<PgRow, ExecError> {
if query.returning.is_empty() {
return Err(ExecError::EmptyReturning);
}
query.validate()?;
let stmt = Postgres.compile_insert(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
let row = q.fetch_one(pool).await?;
Ok(row)
}
pub async fn bulk_insert(
pool: &PgPool,
query: &BulkInsertQuery,
) -> Result<Vec<PgRow>, ExecError> {
query.validate()?;
let stmt = Postgres.compile_bulk_insert(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
if query.returning.is_empty() {
q.execute(pool).await?;
Ok(Vec::new())
} else {
Ok(q.fetch_all(pool).await?)
}
}
pub async fn update(pool: &PgPool, query: &UpdateQuery) -> Result<u64, ExecError> {
query.validate()?;
let stmt = Postgres.compile_update(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
let result = q.execute(pool).await?;
Ok(result.rows_affected())
}
pub async fn delete(pool: &PgPool, query: &DeleteQuery) -> Result<u64, ExecError> {
let stmt = Postgres.compile_delete(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
let result = q.execute(pool).await?;
Ok(result.rows_affected())
}
pub async fn select_rows(pool: &PgPool, query: &SelectQuery) -> Result<Vec<PgRow>, ExecError> {
let stmt = Postgres.compile_select(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
Ok(q.fetch_all(pool).await?)
}
pub async fn select_one_row(
pool: &PgPool,
query: &SelectQuery,
) -> Result<Option<PgRow>, ExecError> {
let stmt = Postgres.compile_select(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
Ok(q.fetch_optional(pool).await?)
}
pub async fn count_rows(pool: &PgPool, query: &CountQuery) -> Result<i64, ExecError> {
let stmt = Postgres.compile_count(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
let row = q.fetch_one(pool).await?;
Ok(sqlx::Row::try_get::<i64, _>(&row, 0)?)
}
/// Counts the rows a query would match.
pub trait Counter<T>
where
    T: Model + Send,
{
    /// Returns a `Send` future resolving to the row count, or an [`ExecError`].
    fn count(
        self,
        pool: &PgPool,
    ) -> impl core::future::Future<Output = Result<i64, ExecError>> + Send;
}
impl<T: Model + Send> Counter<T> for QuerySet<T> {
    /// Compiles the query set and counts matching rows via `count_rows`.
    ///
    /// NOTE(review): only `model` and `where_clause` are carried over from the compiled select.
    /// Any other clauses it may hold (e.g. limits or ordering) do not affect the count.
    /// Confirm this is intended.
    async fn count(self, pool: &PgPool) -> Result<i64, ExecError> {
        let compiled = self.compile()?;
        let count_query = CountQuery {
            model: compiled.model,
            where_clause: compiled.where_clause,
        };
        count_rows(pool, &count_query).await
    }
}
/// Deletes the rows a query matches.
pub trait Deleter<T>
where
    T: Model + Send,
{
    /// Returns a `Send` future resolving to the number of affected rows, or an [`ExecError`].
    fn delete(
        self,
        pool: &PgPool,
    ) -> impl core::future::Future<Output = Result<u64, ExecError>> + Send;
}
impl<T: Model + Send> Deleter<T> for QuerySet<T> {
    /// Compiles the query set into a `DeleteQuery` and runs it through the module-level
    /// `delete` function.
    async fn delete(self, pool: &PgPool) -> Result<u64, ExecError> {
        let delete_query = self.compile_delete()?;
        delete(pool, &delete_query).await
    }
}
/// Executes a pending update.
pub trait Updater<T>
where
    T: Model + Send,
{
    /// Returns a `Send` future resolving to the number of affected rows, or an [`ExecError`].
    fn execute(
        self,
        pool: &PgPool,
    ) -> impl core::future::Future<Output = Result<u64, ExecError>> + Send;
}
impl<T: Model + Send> Updater<T> for UpdateBuilder<T> {
    /// Compiles the builder into an `UpdateQuery` and runs it through the module-level
    /// `update` function, which validates it first.
    async fn execute(self, pool: &PgPool) -> Result<u64, ExecError> {
        let update_query = self.compile()?;
        update(pool, &update_query).await
    }
}
// Binds one `SqlValue` onto a sqlx builder (`Query` or `QueryAs`), dispatching on the variant
// so each payload is encoded with its own Rust type. Expands to an expression of the same
// builder type as `$builder`. It is shared by `bind_query` and `bind_query_as`, which wrap the
// two builder types.
//
// NOTE(review): `Null` is bound as `None::<String>`, i.e. with a text parameter type. Postgres
// may reject that for non-text target columns. Confirm against the columns this is used with.
// NOTE(review): `Json` and `List` panic here. This assumes the SQL writer never emits them as
// bind parameters; verify that assumption upstream.
macro_rules! bind_match {
    ($builder:expr, $param:expr) => {
        match $param {
            SqlValue::Bool(inner) => $builder.bind(inner),
            SqlValue::I32(inner) => $builder.bind(inner),
            SqlValue::I64(inner) => $builder.bind(inner),
            SqlValue::F32(inner) => $builder.bind(inner),
            SqlValue::F64(inner) => $builder.bind(inner),
            SqlValue::String(inner) => $builder.bind(inner),
            SqlValue::Uuid(inner) => $builder.bind(inner),
            SqlValue::Date(inner) => $builder.bind(inner),
            SqlValue::DateTime(inner) => $builder.bind(inner),
            SqlValue::Null => $builder.bind(None::<String>),
            SqlValue::List(_) => {
                unreachable!("`SqlValue::List` is expanded to scalars by the SQL writer")
            }
            SqlValue::Json(_) => {
                unreachable!(
                    "`SqlValue::Json` requires the `sqlx/json` feature, not enabled in v0.1"
                )
            }
        }
    };
}
/// Binds `param` onto a typed (`query_as`) builder and returns the extended builder.
fn bind_query_as<T>(
    builder: QueryAs<'_, sqlx::Postgres, T, PgArguments>,
    param: SqlValue,
) -> QueryAs<'_, sqlx::Postgres, T, PgArguments> {
    bind_match!(builder, param)
}
/// Binds `param` onto an untyped (`query`) builder and returns the extended builder.
fn bind_query(
    builder: Query<'_, sqlx::Postgres, PgArguments>,
    param: SqlValue,
) -> Query<'_, sqlx::Postgres, PgArguments> {
    bind_match!(builder, param)
}