use crate::core::{
AggregateQuery, BulkInsertQuery, BulkUpdateQuery, CountQuery, DeleteQuery, InsertQuery, Model,
SelectQuery, SqlValue, UpdateQuery,
};
use crate::query::{QuerySet, UpdateBuilder};
use sqlx::postgres::{PgArguments, PgPool, PgRow};
use sqlx::query::{Query, QueryAs};
use super::{Dialect, ExecError, Postgres};
/// Reads the primary-key value stored in one of a model's foreign-key fields.
///
/// `field` names the foreign-key field; implementations return `None` when the field
/// is unknown or currently holds no key. Used by [`fetch_with_prefetch`] to bucket
/// child rows under their parent.
#[doc(hidden)]
pub trait FkPkAccess {
    fn __rustango_fk_pk(&self, field: &str) -> Option<i64>;
}
/// Hydrates a related object from the extra columns a join contributed to `row`.
///
/// `field` names the relation field and `join_alias` names the join whose columns
/// should be read. NOTE(review): the meaning of the returned `bool` is defined by the
/// implementations (presumably generated model code); callers in this module ignore it.
#[doc(hidden)]
pub trait LoadRelated {
    fn __rustango_load_related(
        &mut self,
        row: &PgRow,
        field: &str,
        join_alias: &str,
    ) -> Result<bool, sqlx::Error>;
}
/// Executes a query builder against a pool and decodes every result row into `T`.
pub trait Fetcher<T>
where
    T: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
    /// Runs the query on `pool` and returns all decoded rows.
    fn fetch(
        self,
        pool: &PgPool,
    ) -> impl Send + std::future::Future<Output = Result<Vec<T>, ExecError>>;
}
impl<T> Fetcher<T> for QuerySet<T>
where
T: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
async fn fetch(self, pool: &PgPool) -> Result<Vec<T>, ExecError> {
let select = self.compile()?;
let stmt = Postgres.compile_select(&select)?;
let mut q: QueryAs<'_, sqlx::Postgres, T, PgArguments> = sqlx::query_as::<_, T>(&stmt.sql);
for value in stmt.params {
q = bind_query_as(q, value);
}
let rows = q.fetch_all(pool).await?;
Ok(rows)
}
}
impl<T> QuerySet<T>
where
T: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
pub async fn fetch_on<'c, E>(self, executor: E) -> Result<Vec<T>, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
T: LoadRelated,
{
let select = self.compile()?;
let select_related_aliases: Vec<&'static str> =
select.joins.iter().map(|j| j.alias).collect();
let stmt = Postgres.compile_select(&select)?;
if select_related_aliases.is_empty() {
let mut q: QueryAs<'_, sqlx::Postgres, T, PgArguments> =
sqlx::query_as::<_, T>(&stmt.sql);
for value in stmt.params {
q = bind_query_as(q, value);
}
let rows = q.fetch_all(executor).await?;
return Ok(rows);
}
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
let raw_rows = q.fetch_all(executor).await?;
let mut out = Vec::with_capacity(raw_rows.len());
for row in &raw_rows {
let mut t = T::from_row(row)?;
for alias in &select_related_aliases {
let _ = t.__rustango_load_related(row, alias, alias)?;
}
out.push(t);
}
Ok(out)
}
pub async fn fetch_paginated_on<'c, E>(self, executor: E) -> Result<Page<T>, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let select = self.compile()?;
let stmt = Postgres.compile_select(&select)?;
let sql = inject_total_count(&stmt.sql);
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&sql);
for value in stmt.params {
q = bind_query(q, value);
}
let raw_rows: Vec<PgRow> = q.fetch_all(executor).await?;
let total: i64 = raw_rows
.first()
.map(|row| sqlx::Row::try_get::<i64, _>(row, "__rustango_total"))
.transpose()?
.unwrap_or(0);
let mut rows = Vec::with_capacity(raw_rows.len());
for row in &raw_rows {
rows.push(T::from_row(row)?);
}
Ok(Page { rows, total })
}
pub async fn fetch_paginated(self, pool: &PgPool) -> Result<Page<T>, ExecError> {
self.fetch_paginated_on(pool).await
}
}
/// One page of query results together with the unpaginated match count.
///
/// Produced by `QuerySet::fetch_paginated_on`; `total` is the number of rows that
/// matched the filter before `LIMIT`/`OFFSET` were applied.
#[derive(Debug, Clone, PartialEq)]
pub struct Page<T> {
    /// The decoded rows of this page.
    pub rows: Vec<T>,
    /// Total number of matching rows across all pages.
    pub total: i64,
}
impl<T> Default for Page<T> {
fn default() -> Self {
Self { rows: Vec::new(), total: 0 }
}
}
/// Splices a `COUNT(*) OVER ()` window column into a compiled `SELECT`, so one
/// round-trip returns both the page rows and the unpaginated match count.
///
/// The column is inserted immediately before the first ` FROM ` that sits at
/// parenthesis depth zero and outside `'…'` / `"…"` quoting. A sub-select in the
/// projection list (for example `(SELECT 1 FROM x)`) is therefore not mistaken for the
/// outer `FROM`. Window functions are evaluated before `LIMIT`/`OFFSET`, so the count
/// covers every row matching the filter.
///
/// If no such anchor exists, the SQL is returned unchanged behind an explanatory SQL
/// comment. The result set will then have no `__rustango_total` column.
fn inject_total_count(sql: &str) -> String {
    const COUNT_COLUMN: &str = ", COUNT(*) OVER () AS \"__rustango_total\"";
    match find_top_level_from(sql) {
        Some(idx) => {
            // `idx` points at an ASCII space, so it is always a valid char boundary.
            let mut out = String::with_capacity(sql.len() + COUNT_COLUMN.len());
            out.push_str(&sql[..idx]);
            out.push_str(COUNT_COLUMN);
            out.push_str(&sql[idx..]);
            out
        }
        None => format!(
            "/* rustango: fetch_paginated_on could not splice COUNT(*) OVER () \
             into the compiled SELECT — anchor ` FROM ` not found. The query \
             below will run unchanged but `total` will be 0. */ {sql}"
        ),
    }
}

/// Returns the byte offset of the first ` FROM ` at parenthesis depth zero that is not
/// inside a single-quoted literal or double-quoted identifier.
fn find_top_level_from(sql: &str) -> Option<usize> {
    const ANCHOR: &[u8] = b" FROM ";
    let bytes = sql.as_bytes();
    let mut depth = 0usize;
    let mut open_quote: Option<u8> = None;
    for (i, &b) in bytes.iter().enumerate() {
        match open_quote {
            // Doubled quotes ('' / "") close and immediately reopen, which is harmless.
            Some(q) => {
                if b == q {
                    open_quote = None;
                }
            }
            None => match b {
                b'\'' | b'"' => open_quote = Some(b),
                b'(' => depth += 1,
                b')' => depth = depth.saturating_sub(1),
                b' ' if depth == 0 && bytes[i..].starts_with(ANCHOR) => return Some(i),
                _ => {}
            },
        }
    }
    None
}
/// Runs a single-row `INSERT` against `pool`; pool convenience form of [`insert_on`].
pub async fn insert(pool: &PgPool, query: &InsertQuery) -> Result<(), ExecError> {
    insert_on::<&PgPool>(pool, query).await
}
pub async fn insert_on<'c, E>(executor: E, query: &InsertQuery) -> Result<(), ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
query.validate()?;
let stmt = Postgres.compile_insert(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
q.execute(executor).await?;
Ok(())
}
/// Pool convenience form of [`insert_returning_on`]: inserts and returns the
/// `RETURNING` row.
pub async fn insert_returning(pool: &PgPool, query: &InsertQuery) -> Result<PgRow, ExecError> {
    insert_returning_on::<&PgPool>(pool, query).await
}
pub async fn insert_returning_on<'c, E>(
executor: E,
query: &InsertQuery,
) -> Result<PgRow, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
if query.returning.is_empty() {
return Err(ExecError::EmptyReturning);
}
query.validate()?;
let stmt = Postgres.compile_insert(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
let row = q.fetch_one(executor).await?;
Ok(row)
}
/// Pool convenience form of [`bulk_insert_on`].
pub async fn bulk_insert(
    pool: &PgPool,
    query: &BulkInsertQuery,
) -> Result<Vec<PgRow>, ExecError> {
    bulk_insert_on::<&PgPool>(pool, query).await
}
pub async fn bulk_insert_on<'c, E>(
executor: E,
query: &BulkInsertQuery,
) -> Result<Vec<PgRow>, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
query.validate()?;
let stmt = Postgres.compile_bulk_insert(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
if query.returning.is_empty() {
q.execute(executor).await?;
Ok(Vec::new())
} else {
Ok(q.fetch_all(executor).await?)
}
}
/// Pool convenience form of [`update_on`]; returns the number of rows affected.
pub async fn update(pool: &PgPool, query: &UpdateQuery) -> Result<u64, ExecError> {
    update_on::<&PgPool>(pool, query).await
}
pub async fn update_on<'c, E>(executor: E, query: &UpdateQuery) -> Result<u64, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
query.validate()?;
let stmt = Postgres.compile_update(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
let result = q.execute(executor).await?;
Ok(result.rows_affected())
}
/// Pool convenience form of [`delete_on`]; returns the number of rows deleted.
pub async fn delete(pool: &PgPool, query: &DeleteQuery) -> Result<u64, ExecError> {
    delete_on::<&PgPool>(pool, query).await
}
pub async fn delete_on<'c, E>(executor: E, query: &DeleteQuery) -> Result<u64, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let stmt = Postgres.compile_delete(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
let result = q.execute(executor).await?;
Ok(result.rows_affected())
}
pub async fn select_rows(pool: &PgPool, query: &SelectQuery) -> Result<Vec<PgRow>, ExecError> {
let stmt = Postgres.compile_select(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
Ok(q.fetch_all(pool).await?)
}
pub async fn select_one_row(
pool: &PgPool,
query: &SelectQuery,
) -> Result<Option<PgRow>, ExecError> {
let stmt = Postgres.compile_select(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
Ok(q.fetch_optional(pool).await?)
}
/// Pool convenience form of [`count_rows_on`].
pub async fn count_rows(pool: &PgPool, query: &CountQuery) -> Result<i64, ExecError> {
    count_rows_on::<&PgPool>(pool, query).await
}
pub async fn count_rows_on<'c, E>(executor: E, query: &CountQuery) -> Result<i64, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let stmt = Postgres.compile_count(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
let row = q.fetch_one(executor).await?;
Ok(sqlx::Row::try_get::<i64, _>(&row, 0)?)
}
/// Pool convenience form of [`annotate_count_children_on`].
pub async fn annotate_count_children<P>(
    parent_qs: crate::query::QuerySet<P>,
    child_table: &'static str,
    child_fk_column: &'static str,
    pool: &PgPool,
) -> Result<Vec<(P, i64)>, ExecError>
where
    P: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
    annotate_count_children_on::<P, &PgPool>(parent_qs, child_table, child_fk_column, pool)
        .await
}
/// Fetches the parents selected by `parent_qs` together with, for each parent, the
/// number of rows in `child_table` whose `child_fk_column` equals the parent's primary
/// key.
///
/// The count comes from a correlated scalar sub-select rather than `LEFT JOIN` +
/// `GROUP BY`. The statement therefore keeps the shape `SELECT … FROM parent <tail>`,
/// so the queryset's `WHERE`/search clause stays valid and `ORDER BY`/`LIMIT`/`OFFSET`
/// page over parents. The count does not depend on the child's primary-key column
/// name. The child table is aliased inside the sub-select, so a self-referential
/// foreign key still correlates with the outer row. Parents with no children get `0`.
///
/// # Errors
/// - [`ExecError::MissingPrimaryKey`] if the parent model has no primary key.
/// - Any error from compiling the queryset tail, executing the query or decoding a row.
pub async fn annotate_count_children_on<'c, P, E>(
    parent_qs: crate::query::QuerySet<P>,
    child_table: &'static str,
    child_fk_column: &'static str,
    executor: E,
) -> Result<Vec<(P, i64)>, ExecError>
where
    P: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
    E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
    use std::fmt::Write as _;
    let select = parent_qs.compile()?;
    let parent = select.model;
    let pk_field = parent.primary_key().ok_or(ExecError::MissingPrimaryKey {
        table: parent.table,
    })?;

    let mut sql = String::from("SELECT ");
    // Each parent column is followed by ", " so the count expression can come last
    // regardless of how many scalar fields the model has. Writing to a String cannot fail.
    for field in parent.scalar_fields() {
        let _ = write!(sql, "\"{}\".\"{}\", ", parent.table, field.column);
    }
    let _ = write!(
        sql,
        "(SELECT COUNT(*) FROM \"{child_table}\" AS \"__rustango_child\" \
         WHERE \"__rustango_child\".\"{child_fk_column}\" = \"{parent_table}\".\"{parent_pk}\") \
         AS \"__annotated_count\" FROM \"{parent_table}\"",
        parent_table = parent.table,
        parent_pk = pk_field.column,
    );

    // WHERE / search / ORDER BY / LIMIT / OFFSET, qualified against the parent table.
    let tail = crate::sql::postgres::compile_where_order_tail(
        &select.where_clause,
        select.search.as_ref(),
        &select.order_by,
        select.limit,
        select.offset,
        Some(parent.table),
        Some(parent),
    )?;
    sql.push_str(&tail.sql);

    let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&sql);
    for param in tail.params {
        q = bind_query(q, param);
    }
    let raw_rows = q.fetch_all(executor).await?;
    let mut out = Vec::with_capacity(raw_rows.len());
    for row in &raw_rows {
        let parent_obj = P::from_row(row)?;
        let count: i64 = sqlx::Row::try_get(row, "__annotated_count")?;
        out.push((parent_obj, count));
    }
    Ok(out)
}
pub async fn fetch_with_prefetch<P, C>(
parent_qs: crate::query::QuerySet<P>,
child_fk_column: &'static str,
pool: &PgPool,
) -> Result<Vec<(P, Vec<C>)>, ExecError>
where
P: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin + LoadRelated + HasPkValue,
C: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin + LoadRelated + FkPkAccess,
{
let parents: Vec<P> = parent_qs.fetch(pool).await?;
if parents.is_empty() {
return Ok(Vec::new());
}
let pk_field = P::SCHEMA.primary_key().ok_or(ExecError::MissingPrimaryKey {
table: P::SCHEMA.table,
})?;
let mut parent_pks: Vec<i64> = Vec::with_capacity(parents.len());
for parent in &parents {
if let Some(pk) = sql_value_as_i64(&extract_pk_value(parent)) {
parent_pks.push(pk);
}
}
parent_pks.sort_unstable();
parent_pks.dedup();
if parent_pks.is_empty() {
return Ok(parents.into_iter().map(|p| (p, Vec::new())).collect());
}
let pk_values: Vec<crate::core::SqlValue> = parent_pks
.iter()
.copied()
.map(crate::core::SqlValue::I64)
.collect();
let children: Vec<C> = crate::query::QuerySet::<C>::new()
.filter(
child_fk_column,
crate::core::Op::In,
crate::core::SqlValue::List(pk_values),
)
.fetch(pool)
.await?;
let mut grouped: std::collections::HashMap<i64, Vec<C>> = std::collections::HashMap::new();
for child in children {
let Some(fk_pk) = child.__rustango_fk_pk(child_fk_column) else {
continue;
};
grouped.entry(fk_pk).or_default().push(child);
}
let mut out = Vec::with_capacity(parents.len());
for parent in parents {
let pk = sql_value_as_i64(&extract_pk_value(&parent)).unwrap_or(0);
let kids = grouped.remove(&pk).unwrap_or_default();
out.push((parent, kids));
}
let _ = pk_field; Ok(out)
}
fn extract_pk_value<P: HasPkValue>(parent: &P) -> crate::core::SqlValue {
parent.__rustango_pk_value_impl()
}
fn sql_value_as_i64(v: &crate::core::SqlValue) -> Option<i64> {
match v {
crate::core::SqlValue::I64(n) => Some(*n),
crate::core::SqlValue::I32(n) => Some(i64::from(*n)),
_ => None,
}
}
#[doc(hidden)]
pub trait HasPkValue {
fn __rustango_pk_value_impl(&self) -> crate::core::SqlValue;
}
/// Counts the rows a query builder would match.
pub trait Counter<T>
where
    T: Model + Send,
{
    /// Runs a `COUNT` for this builder's filter on `pool`.
    fn count(self, pool: &PgPool) -> impl Send + std::future::Future<Output = Result<i64, ExecError>>;
}
impl<T: Model + Send> Counter<T> for QuerySet<T> {
async fn count(self, pool: &PgPool) -> Result<i64, ExecError> {
self.count_on(pool).await
}
}
impl<T> QuerySet<T>
where
    T: Model + Send,
{
    /// Counts the rows matching this queryset's filter on `executor`.
    ///
    /// Only the model and `WHERE` clause are carried into the count; ordering, limits
    /// and other compiled parts are not used.
    pub async fn count_on<'c, E>(self, executor: E) -> Result<i64, ExecError>
    where
        E: sqlx::Executor<'c, Database = sqlx::Postgres>,
    {
        let compiled = self.compile()?;
        let count_query = CountQuery {
            model: compiled.model,
            where_clause: compiled.where_clause,
        };
        count_rows_on(executor, &count_query).await
    }
}
/// Deletes the rows a query builder matches.
pub trait Deleter<T>
where
    T: Model + Send,
{
    /// Runs the `DELETE` on `pool` and returns the number of rows removed.
    fn delete(self, pool: &PgPool) -> impl Send + std::future::Future<Output = Result<u64, ExecError>>;
}
impl<T> Deleter<T> for QuerySet<T>
where
    T: Model + Send,
{
    /// Compiles the queryset into a `DELETE` and executes it through the free
    /// `delete` function.
    async fn delete(self, pool: &PgPool) -> Result<u64, ExecError> {
        delete(pool, &self.compile_delete()?).await
    }
}
/// Executes an update builder.
pub trait Updater<T>
where
    T: Model + Send,
{
    /// Runs the `UPDATE` on `pool` and returns the number of rows affected.
    fn execute(self, pool: &PgPool) -> impl Send + std::future::Future<Output = Result<u64, ExecError>>;
}
impl<T> Updater<T> for UpdateBuilder<T>
where
    T: Model + Send,
{
    /// Compiles the builder into an `UPDATE` and executes it through the free
    /// `update` function.
    async fn execute(self, pool: &PgPool) -> Result<u64, ExecError> {
        update(pool, &self.compile()?).await
    }
}
// Binds one `SqlValue` onto a sqlx builder (`Query` or `QueryAs`) by dispatching on the
// variant, so both builder kinds share one exhaustive match.
// NOTE(review): `Null` is bound as `Option<String>`, i.e. a TEXT-typed NULL; Postgres
// may reject it for non-text columns unless the SQL casts — confirm.
macro_rules! bind_match {
    ($query:expr, $val:expr) => {
        match $val {
            SqlValue::Null => $query.bind(None::<String>),
            SqlValue::Bool(b) => $query.bind(b),
            SqlValue::I32(n) => $query.bind(n),
            SqlValue::I64(n) => $query.bind(n),
            SqlValue::F32(x) => $query.bind(x),
            SqlValue::F64(x) => $query.bind(x),
            SqlValue::String(s) => $query.bind(s),
            SqlValue::DateTime(ts) => $query.bind(ts),
            SqlValue::Date(d) => $query.bind(d),
            SqlValue::Uuid(id) => $query.bind(id),
            SqlValue::Json(doc) => $query.bind(sqlx::types::Json(doc)),
            SqlValue::List(_) => {
                unreachable!("`SqlValue::List` is expanded to scalars by the SQL writer")
            }
        }
    };
}
/// Appends one bound parameter to a typed (`query_as`) query.
fn bind_query_as<T>(
    typed: QueryAs<'_, sqlx::Postgres, T, PgArguments>,
    param: SqlValue,
) -> QueryAs<'_, sqlx::Postgres, T, PgArguments> {
    bind_match!(typed, param)
}
/// Appends one bound parameter to an untyped (`query`) query.
fn bind_query(
    untyped: Query<'_, sqlx::Postgres, PgArguments>,
    param: SqlValue,
) -> Query<'_, sqlx::Postgres, PgArguments> {
    bind_match!(untyped, param)
}
/// Pool convenience form of [`bulk_update_on`].
pub async fn bulk_update(
    pool: &PgPool,
    query: &BulkUpdateQuery,
) -> Result<u64, ExecError> {
    bulk_update_on::<&PgPool>(pool, query).await
}
pub async fn bulk_update_on<'c, E>(
executor: E,
query: &BulkUpdateQuery,
) -> Result<u64, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
if query.rows.is_empty() {
return Ok(0);
}
let stmt = Postgres.compile_bulk_update(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for p in stmt.params {
q = bind_query(q, p);
}
Ok(q.execute(executor).await?.rows_affected())
}
/// Pool convenience form of [`raw_query_on`].
pub async fn raw_query<T>(
    sql: &str,
    binds: Vec<SqlValue>,
    pool: &PgPool,
) -> Result<Vec<T>, ExecError>
where
    T: for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
    raw_query_on::<T, &PgPool>(sql, binds, pool).await
}
pub async fn raw_query_on<'c, T, E>(
sql: &str,
binds: Vec<SqlValue>,
executor: E,
) -> Result<Vec<T>, ExecError>
where
T: for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let mut q: QueryAs<'_, sqlx::Postgres, T, PgArguments> = sqlx::query_as(sql);
for b in binds {
q = bind_query_as(q, b);
}
Ok(q.fetch_all(executor).await?)
}
/// Pool convenience form of [`raw_execute_on`].
pub async fn raw_execute(
    sql: &str,
    binds: Vec<SqlValue>,
    pool: &PgPool,
) -> Result<u64, ExecError> {
    raw_execute_on::<&PgPool>(sql, binds, pool).await
}
pub async fn raw_execute_on<'c, E>(
sql: &str,
binds: Vec<SqlValue>,
executor: E,
) -> Result<u64, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(sql);
for b in binds {
q = bind_query(q, b);
}
Ok(q.execute(executor).await?.rows_affected())
}
/// Pool convenience form of [`fetch_aggregate_on`].
pub async fn fetch_aggregate(
    query: &AggregateQuery,
    pool: &PgPool,
) -> Result<Vec<std::collections::HashMap<String, SqlValue>>, ExecError> {
    fetch_aggregate_on::<&PgPool>(query, pool).await
}
pub async fn fetch_aggregate_on<'c, E>(
query: &AggregateQuery,
executor: E,
) -> Result<Vec<std::collections::HashMap<String, SqlValue>>, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let stmt = Postgres.compile_aggregate(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for p in stmt.params {
q = bind_query(q, p);
}
let raw_rows = q.fetch_all(executor).await?;
let mut out = Vec::with_capacity(raw_rows.len());
for row in &raw_rows {
use sqlx::{Column as _, Row as _};
let mut map = std::collections::HashMap::new();
for (i, col) in row.columns().iter().enumerate() {
let name = col.name().to_owned();
let val: SqlValue = if let Ok(v) = row.try_get::<i64, _>(i) {
SqlValue::I64(v)
} else if let Ok(v) = row.try_get::<i32, _>(i) {
SqlValue::I32(v)
} else if let Ok(v) = row.try_get::<f64, _>(i) {
SqlValue::F64(v)
} else if let Ok(v) = row.try_get::<bool, _>(i) {
SqlValue::Bool(v)
} else if let Ok(v) = row.try_get::<String, _>(i) {
SqlValue::String(v)
} else {
SqlValue::Null
};
map.insert(name, val);
}
out.push(map);
}
Ok(out)
}
/// Runs `f` inside a database transaction opened on `pool`.
///
/// If `f`'s future resolves to `Ok`, the transaction is committed and the value is
/// returned; a commit failure is returned as the error. If it resolves to `Err`, a
/// rollback is attempted (its own failure is ignored) and the original error is
/// returned.
///
/// NOTE(review): `Fut` is a single type that is not parameterised over the lifetime of
/// the `&mut PgConnection`, so the returned future cannot keep using that borrow.
/// Supporting that would need a boxed/higher-ranked future in the signature.
pub async fn transaction<F, Fut, T>(pool: &PgPool, f: F) -> Result<T, ExecError>
where
    F: FnOnce(&mut sqlx::PgConnection) -> Fut,
    Fut: std::future::Future<Output = Result<T, ExecError>>,
{
    let mut tx = pool.begin().await?;
    let outcome = f(&mut *tx).await;
    if outcome.is_ok() {
        tx.commit().await?;
    } else {
        // Best effort: dropping an uncommitted transaction also rolls it back.
        let _ = tx.rollback().await;
    }
    outcome
}