use crate::core::{
BulkInsertQuery, CountQuery, DeleteQuery, InsertQuery, Model, SelectQuery, SqlValue,
UpdateQuery,
};
use crate::query::{QuerySet, UpdateBuilder};
use sqlx::postgres::{PgArguments, PgPool, PgRow};
use sqlx::query::{Query, QueryAs};
use super::{Dialect, ExecError, Postgres};
/// Reads the integer primary key referenced by one of a model's foreign-key
/// fields.
///
/// `fetch_with_prefetch` calls this on each fetched child, passing the
/// child's FK column name, to group children under their parent. Children for
/// which this returns `None` are dropped from the grouping.
#[doc(hidden)]
pub trait FkPkAccess {
    /// Returns the referenced primary key for `field_name`, or `None` when the
    /// field is unknown/unset. NOTE(review): exact `None` semantics depend on
    /// the implementor (presumably generated code) — confirm.
    fn __rustango_fk_pk(&self, field_name: &str) -> Option<i64>;
}
/// Populates a select-related field of a model from columns of an already
/// fetched joined row.
///
/// `QuerySet::fetch_on` calls this once per join alias on every decoded row,
/// passing the alias as both `field_name` and `alias`.
#[doc(hidden)]
pub trait LoadRelated {
    /// Loads the related object for `field_name` from `row`, reading the
    /// columns belonging to the join `alias`.
    ///
    /// The returned `bool` is ignored by `fetch_on`; presumably it reports
    /// whether `field_name` was recognised — TODO confirm with implementors.
    ///
    /// # Errors
    ///
    /// Propagates any `sqlx::Error` raised while decoding the related columns.
    fn __rustango_load_related(
        &mut self,
        row: &PgRow,
        field_name: &str,
        alias: &str,
    ) -> Result<bool, sqlx::Error>;
}
/// Executes a queryset against a Postgres pool and returns the decoded rows.
///
/// Implemented for [`QuerySet`] in this module.
pub trait Fetcher<T>
where
    T: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
    /// Compiles the queryset, binds its parameters, runs it on `pool`, and
    /// returns every row decoded as `T`.
    fn fetch(
        self,
        pool: &PgPool,
    ) -> impl std::future::Future<Output = Result<Vec<T>, ExecError>> + Send;
}
impl<T> Fetcher<T> for QuerySet<T>
where
T: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
async fn fetch(self, pool: &PgPool) -> Result<Vec<T>, ExecError> {
let select = self.compile()?;
let stmt = Postgres.compile_select(&select)?;
let mut q: QueryAs<'_, sqlx::Postgres, T, PgArguments> = sqlx::query_as::<_, T>(&stmt.sql);
for value in stmt.params {
q = bind_query_as(q, value);
}
let rows = q.fetch_all(pool).await?;
Ok(rows)
}
}
impl<T> QuerySet<T>
where
T: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
pub async fn fetch_on<'c, E>(self, executor: E) -> Result<Vec<T>, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
T: LoadRelated,
{
let select = self.compile()?;
let select_related_aliases: Vec<&'static str> =
select.joins.iter().map(|j| j.alias).collect();
let stmt = Postgres.compile_select(&select)?;
if select_related_aliases.is_empty() {
let mut q: QueryAs<'_, sqlx::Postgres, T, PgArguments> =
sqlx::query_as::<_, T>(&stmt.sql);
for value in stmt.params {
q = bind_query_as(q, value);
}
let rows = q.fetch_all(executor).await?;
return Ok(rows);
}
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
let raw_rows = q.fetch_all(executor).await?;
let mut out = Vec::with_capacity(raw_rows.len());
for row in &raw_rows {
let mut t = T::from_row(row)?;
for alias in &select_related_aliases {
let _ = t.__rustango_load_related(row, alias, alias)?;
}
out.push(t);
}
Ok(out)
}
pub async fn fetch_paginated_on<'c, E>(self, executor: E) -> Result<Page<T>, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let select = self.compile()?;
let stmt = Postgres.compile_select(&select)?;
let sql = inject_total_count(&stmt.sql);
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&sql);
for value in stmt.params {
q = bind_query(q, value);
}
let raw_rows: Vec<PgRow> = q.fetch_all(executor).await?;
let total: i64 = raw_rows
.first()
.map(|row| sqlx::Row::try_get::<i64, _>(row, "__rustango_total"))
.transpose()?
.unwrap_or(0);
let mut rows = Vec::with_capacity(raw_rows.len());
for row in &raw_rows {
rows.push(T::from_row(row)?);
}
Ok(Page { rows, total })
}
pub async fn fetch_paginated(self, pool: &PgPool) -> Result<Page<T>, ExecError> {
self.fetch_paginated_on(pool).await
}
}
/// One page of query results plus the total number of rows the query matched.
///
/// Produced by `QuerySet::fetch_paginated` / `fetch_paginated_on`.
#[derive(Debug, Clone)]
pub struct Page<T> {
    /// The rows on this page.
    pub rows: Vec<T>,
    /// Total number of matching rows across all pages (0 for an empty page).
    pub total: i64,
}

// Written by hand rather than derived: `#[derive(Default)]` would require
// `T: Default`, which an empty page does not need.
impl<T> Default for Page<T> {
    fn default() -> Self {
        Self {
            rows: Vec::new(),
            total: 0,
        }
    }
}
/// Splices a `COUNT(*) OVER ()` window column named `__rustango_total` into a
/// compiled SELECT, immediately before the first ` FROM `.
///
/// When no ` FROM ` anchor exists the SQL is returned unchanged, apart from a
/// leading `/* ... */` comment explaining that the total column is missing.
fn inject_total_count(sql: &str) -> String {
    const TOTAL_COLUMN: &str = ", COUNT(*) OVER () AS \"__rustango_total\"";

    let Some(anchor) = sql.find(" FROM ") else {
        return format!(
            "/* rustango: fetch_paginated_on could not splice COUNT(*) OVER () \
             into the compiled SELECT — anchor ` FROM ` not found. The query \
             below will run unchanged but `total` will be 0. */ {sql}"
        );
    };

    let (projection, remainder) = sql.split_at(anchor);
    let mut spliced = String::with_capacity(sql.len() + TOTAL_COLUMN.len());
    spliced.push_str(projection);
    spliced.push_str(TOTAL_COLUMN);
    spliced.push_str(remainder);
    spliced
}
/// Pool-based convenience wrapper around [`insert_on`].
///
/// # Errors
///
/// Same as [`insert_on`].
pub async fn insert(pool: &PgPool, query: &InsertQuery) -> Result<(), ExecError> {
    insert_on(pool, query).await
}
pub async fn insert_on<'c, E>(executor: E, query: &InsertQuery) -> Result<(), ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
query.validate()?;
let stmt = Postgres.compile_insert(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
q.execute(executor).await?;
Ok(())
}
/// Pool-based convenience wrapper around [`insert_returning_on`].
///
/// # Errors
///
/// Same as [`insert_returning_on`].
pub async fn insert_returning(pool: &PgPool, query: &InsertQuery) -> Result<PgRow, ExecError> {
    insert_returning_on(pool, query).await
}
pub async fn insert_returning_on<'c, E>(
executor: E,
query: &InsertQuery,
) -> Result<PgRow, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
if query.returning.is_empty() {
return Err(ExecError::EmptyReturning);
}
query.validate()?;
let stmt = Postgres.compile_insert(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
let row = q.fetch_one(executor).await?;
Ok(row)
}
/// Pool-based convenience wrapper around [`bulk_insert_on`].
///
/// # Errors
///
/// Same as [`bulk_insert_on`].
pub async fn bulk_insert(
    pool: &PgPool,
    query: &BulkInsertQuery,
) -> Result<Vec<PgRow>, ExecError> {
    bulk_insert_on(pool, query).await
}
pub async fn bulk_insert_on<'c, E>(
executor: E,
query: &BulkInsertQuery,
) -> Result<Vec<PgRow>, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
query.validate()?;
let stmt = Postgres.compile_bulk_insert(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
if query.returning.is_empty() {
q.execute(executor).await?;
Ok(Vec::new())
} else {
Ok(q.fetch_all(executor).await?)
}
}
/// Pool-based convenience wrapper around [`update_on`].
///
/// # Errors
///
/// Same as [`update_on`].
pub async fn update(pool: &PgPool, query: &UpdateQuery) -> Result<u64, ExecError> {
    update_on(pool, query).await
}
pub async fn update_on<'c, E>(executor: E, query: &UpdateQuery) -> Result<u64, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
query.validate()?;
let stmt = Postgres.compile_update(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
let result = q.execute(executor).await?;
Ok(result.rows_affected())
}
/// Pool-based convenience wrapper around [`delete_on`].
///
/// # Errors
///
/// Same as [`delete_on`].
pub async fn delete(pool: &PgPool, query: &DeleteQuery) -> Result<u64, ExecError> {
    delete_on(pool, query).await
}
pub async fn delete_on<'c, E>(executor: E, query: &DeleteQuery) -> Result<u64, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let stmt = Postgres.compile_delete(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
let result = q.execute(executor).await?;
Ok(result.rows_affected())
}
pub async fn select_rows(pool: &PgPool, query: &SelectQuery) -> Result<Vec<PgRow>, ExecError> {
let stmt = Postgres.compile_select(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
Ok(q.fetch_all(pool).await?)
}
pub async fn select_one_row(
pool: &PgPool,
query: &SelectQuery,
) -> Result<Option<PgRow>, ExecError> {
let stmt = Postgres.compile_select(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
Ok(q.fetch_optional(pool).await?)
}
/// Pool-based convenience wrapper around [`count_rows_on`].
///
/// # Errors
///
/// Same as [`count_rows_on`].
pub async fn count_rows(pool: &PgPool, query: &CountQuery) -> Result<i64, ExecError> {
    count_rows_on(pool, query).await
}
pub async fn count_rows_on<'c, E>(executor: E, query: &CountQuery) -> Result<i64, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let stmt = Postgres.compile_count(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
let row = q.fetch_one(executor).await?;
Ok(sqlx::Row::try_get::<i64, _>(&row, 0)?)
}
/// Pool-based convenience wrapper around [`annotate_count_children_on`].
///
/// # Errors
///
/// Same as [`annotate_count_children_on`].
pub async fn annotate_count_children<P>(
    parent_qs: crate::query::QuerySet<P>,
    child_table: &'static str,
    child_fk_column: &'static str,
    pool: &PgPool,
) -> Result<Vec<(P, i64)>, ExecError>
where
    P: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
    annotate_count_children_on(parent_qs, child_table, child_fk_column, pool).await
}
/// Fetches parent rows together with the number of rows in `child_table`
/// whose `child_fk_column` references each parent's primary key.
///
/// The SQL is a `LEFT JOIN` + `GROUP BY` over all of the parent's scalar
/// columns, so parents without children come back with a count of 0. The
/// child table is joined under an alias, so self-referential relations (child
/// table == parent table) work too.
///
/// NOTE(review): `parent_qs` is compiled only to discover the parent model;
/// its filters, ordering, and limits are currently NOT applied, so every
/// parent row is returned. TODO: splice the compiled WHERE clause in.
///
/// `child_table` and `child_fk_column` are interpolated as quoted identifiers
/// without escaping; they are `&'static str`, i.e. expected to be
/// compile-time constants rather than user input.
///
/// # Errors
///
/// - [`ExecError::MissingPrimaryKey`] if the parent model has no primary key;
/// - any compilation, database, or row-decoding error.
pub async fn annotate_count_children_on<'c, P, E>(
    parent_qs: crate::query::QuerySet<P>,
    child_table: &'static str,
    child_fk_column: &'static str,
    executor: E,
) -> Result<Vec<(P, i64)>, ExecError>
where
    P: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
    E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
    let select = parent_qs.compile()?;
    let parent = select.model;
    let pk_field = parent.primary_key().ok_or(ExecError::MissingPrimaryKey {
        table: parent.table,
    })?;

    // Fully-qualified parent columns, shared verbatim by SELECT and GROUP BY.
    let column_list = parent
        .scalar_fields()
        .map(|field| format!("\"{}\".\"{}\"", parent.table, field.column))
        .collect::<Vec<String>>()
        .join(", ");

    // COUNT over the FK column rather than an assumed `id` PK: every matched
    // child row has a non-NULL FK (it equals the parent PK), while the
    // NULL-padded row a LEFT JOIN produces for a childless parent is skipped.
    let sql = format!(
        "SELECT {column_list}, COUNT(\"{child_alias}\".\"{child_fk_column}\") AS \"__annotated_count\" \
         FROM \"{parent_table}\" \
         LEFT JOIN \"{child_table}\" AS \"{child_alias}\" \
         ON \"{child_alias}\".\"{child_fk_column}\" = \"{parent_table}\".\"{parent_pk}\" \
         GROUP BY {column_list}",
        child_alias = "__rustango_child",
        parent_table = parent.table,
        parent_pk = pk_field.column,
    );

    // The statement has no placeholders, so nothing is bound.
    let q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&sql);
    let raw_rows = q.fetch_all(executor).await?;
    let mut out = Vec::with_capacity(raw_rows.len());
    for row in &raw_rows {
        let parent_obj = P::from_row(row)?;
        let count: i64 = sqlx::Row::try_get(row, "__annotated_count")?;
        out.push((parent_obj, count));
    }
    Ok(out)
}
/// Fetches the parents matched by `parent_qs` and, with one extra query, all
/// children whose `child_fk_column` references one of them; returns each
/// parent paired with its children.
///
/// Only integer primary keys (`I32`/`I64`) participate. A parent whose key is
/// not an integer gets an empty child list. If the same primary key appears
/// more than once among the parents, only the first occurrence receives the
/// children.
///
/// NOTE(review): all distinct parent keys are sent in one `IN (...)` list,
/// one bind parameter each; very large parent sets may exceed Postgres's
/// bind-parameter limit.
///
/// # Errors
///
/// - [`ExecError::MissingPrimaryKey`] if at least one parent was fetched and
///   the parent model declares no primary key;
/// - any error from fetching parents or children.
pub async fn fetch_with_prefetch<P, C>(
    parent_qs: crate::query::QuerySet<P>,
    child_fk_column: &'static str,
    pool: &PgPool,
) -> Result<Vec<(P, Vec<C>)>, ExecError>
where
    P: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin + LoadRelated + HasPkValue,
    C: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin + LoadRelated + FkPkAccess,
{
    let parents: Vec<P> = parent_qs.fetch(pool).await?;
    if parents.is_empty() {
        return Ok(Vec::new());
    }
    if P::SCHEMA.primary_key().is_none() {
        return Err(ExecError::MissingPrimaryKey {
            table: P::SCHEMA.table,
        });
    }

    // Integer key of each parent, aligned index-for-index with `parents`;
    // `None` for keys that are not integers (never matched to children).
    let parent_keys: Vec<Option<i64>> = parents
        .iter()
        .map(|parent| sql_value_as_i64(&extract_pk_value(parent)))
        .collect();

    let mut parent_pks: Vec<i64> = parent_keys.iter().flatten().copied().collect();
    parent_pks.sort_unstable();
    parent_pks.dedup();
    if parent_pks.is_empty() {
        return Ok(parents.into_iter().map(|p| (p, Vec::new())).collect());
    }

    let pk_values: Vec<crate::core::SqlValue> = parent_pks
        .into_iter()
        .map(crate::core::SqlValue::I64)
        .collect();
    let children: Vec<C> = crate::query::QuerySet::<C>::new()
        .filter(
            child_fk_column,
            crate::core::Op::In,
            crate::core::SqlValue::List(pk_values),
        )
        .fetch(pool)
        .await?;

    let mut grouped: std::collections::HashMap<i64, Vec<C>> = std::collections::HashMap::new();
    for child in children {
        // Children whose FK cannot be read are skipped.
        if let Some(fk_pk) = child.__rustango_fk_pk(child_fk_column) {
            grouped.entry(fk_pk).or_default().push(child);
        }
    }

    Ok(parents
        .into_iter()
        .zip(parent_keys)
        .map(|(parent, key)| {
            let kids = key.and_then(|pk| grouped.remove(&pk)).unwrap_or_default();
            (parent, kids)
        })
        .collect())
}
/// Returns `parent`'s primary-key value via [`HasPkValue`].
fn extract_pk_value<P: HasPkValue>(parent: &P) -> crate::core::SqlValue {
    parent.__rustango_pk_value_impl()
}
/// Widens an integer `SqlValue` (`I32` or `I64`) to `i64`; every other
/// variant yields `None`.
fn sql_value_as_i64(v: &crate::core::SqlValue) -> Option<i64> {
    match *v {
        crate::core::SqlValue::I32(small) => Some(small.into()),
        crate::core::SqlValue::I64(wide) => Some(wide),
        _ => None,
    }
}
/// Exposes a model instance's primary-key value as a `SqlValue`.
///
/// Used by `fetch_with_prefetch` (through `extract_pk_value`). There, only
/// `I32`/`I64` values are matched against children.
#[doc(hidden)]
pub trait HasPkValue {
    /// Returns this instance's primary-key value.
    fn __rustango_pk_value_impl(&self) -> crate::core::SqlValue;
}
/// Counts the rows a queryset matches on a Postgres pool.
pub trait Counter<T: Model + Send> {
    /// Returns the number of matching rows.
    fn count(
        self,
        pool: &PgPool,
    ) -> impl std::future::Future<Output = Result<i64, ExecError>> + Send;
}
impl<T: Model + Send> Counter<T> for QuerySet<T> {
    /// Pool-based wrapper around `QuerySet::count_on`.
    async fn count(self, pool: &PgPool) -> Result<i64, ExecError> {
        self.count_on(pool).await
    }
}
impl<T: Model + Send> QuerySet<T> {
    /// Counts the rows matching the queryset's WHERE clause on any Postgres
    /// executor.
    ///
    /// Only `model` and `where_clause` from the compiled SELECT are forwarded
    /// to the `CountQuery`. Other parts (joins, and any ordering or limit it
    /// may carry) are not applied. NOTE(review): confirm that ignoring
    /// limit/offset is intended.
    ///
    /// # Errors
    ///
    /// Returns an [`ExecError`] if compilation or the count query fails.
    pub async fn count_on<'c, E>(self, executor: E) -> Result<i64, ExecError>
    where
        E: sqlx::Executor<'c, Database = sqlx::Postgres>,
    {
        let compiled = self.compile()?;
        let count_query = CountQuery {
            model: compiled.model,
            where_clause: compiled.where_clause,
        };
        count_rows_on(executor, &count_query).await
    }
}
/// Deletes the rows a queryset matches on a Postgres pool.
pub trait Deleter<T: Model + Send> {
    /// Runs the DELETE and returns the number of rows affected.
    fn delete(
        self,
        pool: &PgPool,
    ) -> impl std::future::Future<Output = Result<u64, ExecError>> + Send;
}
impl<T: Model + Send> Deleter<T> for QuerySet<T> {
    /// Compiles the queryset into a `DeleteQuery` and runs it on `pool` via
    /// the module-level `delete` function.
    async fn delete(self, pool: &PgPool) -> Result<u64, ExecError> {
        let compiled = self.compile_delete()?;
        delete(pool, &compiled).await
    }
}
/// Executes a built UPDATE on a Postgres pool.
pub trait Updater<T: Model + Send> {
    /// Runs the UPDATE and returns the number of rows affected.
    fn execute(
        self,
        pool: &PgPool,
    ) -> impl std::future::Future<Output = Result<u64, ExecError>> + Send;
}
impl<T: Model + Send> Updater<T> for UpdateBuilder<T> {
    /// Compiles the builder into an `UpdateQuery` and runs it on `pool` via
    /// the module-level `update` function.
    async fn execute(self, pool: &PgPool) -> Result<u64, ExecError> {
        let compiled = self.compile()?;
        update(pool, &compiled).await
    }
}
// Binds one `SqlValue` onto a sqlx `Query` or `QueryAs` builder (`$q`) and
// evaluates to the updated builder. It dispatches on the variant so that each
// payload is bound with its own Rust type. It is shared by `bind_query` and
// `bind_query_as`, whose builder types differ but both expose `.bind`.
macro_rules! bind_match {
    ($q:expr, $value:expr) => {
        match $value {
            // NULL is bound with the type of `Option<String>`.
            // NOTE(review): Postgres may reject a text-typed NULL for a
            // non-text target column (e.g. INSERT into an integer column) —
            // confirm against real usage.
            SqlValue::Null => $q.bind(None::<String>),
            SqlValue::I32(v) => $q.bind(v),
            SqlValue::I64(v) => $q.bind(v),
            SqlValue::F32(v) => $q.bind(v),
            SqlValue::F64(v) => $q.bind(v),
            SqlValue::Bool(v) => $q.bind(v),
            SqlValue::String(v) => $q.bind(v),
            SqlValue::DateTime(v) => $q.bind(v),
            SqlValue::Date(v) => $q.bind(v),
            SqlValue::Uuid(v) => $q.bind(v),
            // Panics at runtime if a JSON value ever reaches the binder.
            SqlValue::Json(_) => unreachable!(
                "`SqlValue::Json` requires the `sqlx/json` feature, not enabled in v0.1"
            ),
            // Lists are expected to have been flattened into individual scalar
            // placeholders before binding; reaching this arm panics.
            SqlValue::List(_) => {
                unreachable!("`SqlValue::List` is expanded to scalars by the SQL writer")
            }
        }
    };
}
/// Binds the next positional parameter `value` onto a typed `QueryAs` builder
/// and returns the builder.
///
/// # Panics
///
/// Panics (via `bind_match!`) on `SqlValue::Json` or `SqlValue::List`.
fn bind_query_as<T>(
    q: QueryAs<'_, sqlx::Postgres, T, PgArguments>,
    value: SqlValue,
) -> QueryAs<'_, sqlx::Postgres, T, PgArguments> {
    bind_match!(q, value)
}
/// Binds the next positional parameter `value` onto an untyped `Query`
/// builder and returns the builder.
///
/// # Panics
///
/// Panics (via `bind_match!`) on `SqlValue::Json` or `SqlValue::List`.
fn bind_query(
    q: Query<'_, sqlx::Postgres, PgArguments>,
    value: SqlValue,
) -> Query<'_, sqlx::Postgres, PgArguments> {
    bind_match!(q, value)
}