use crate::core::{
AggregateQuery, BulkInsertQuery, BulkUpdateQuery, CountQuery, DeleteQuery, InsertQuery, Model,
SelectQuery, SqlValue, UpdateQuery,
};
use crate::query::{QuerySet, UpdateBuilder};
use sqlx::postgres::{PgArguments, PgPool, PgRow};
use sqlx::query::{Query, QueryAs};
use super::{Dialect, ExecError, Postgres};
/// Foreign-key accessors used by the prefetch machinery in this module.
///
/// NOTE(review): presumably implemented by rustango's generated model code (the
/// `__rustango_` prefix and `#[doc(hidden)]` suggest it is not meant for hand-written
/// impls) — confirm against the derive.
#[doc(hidden)]
pub trait FkPkAccess {
    /// Returns the foreign-key value held by `field_name` as an `i64`, if any.
    fn __rustango_fk_pk(&self, field_name: &str) -> Option<i64>;
    /// Returns the foreign-key value held by `field_name` as a [`crate::core::SqlValue`],
    /// if any. `fetch_with_prefetch` uses it to group children under their parent;
    /// children for which it returns `None` are dropped from that result.
    fn __rustango_fk_pk_value(&self, field_name: &str) -> Option<crate::core::SqlValue>;
}
/// Hook that fills in `select_related` relations from a joined result row.
#[doc(hidden)]
pub trait LoadRelated {
    /// Populates the relation named `field_name` from the columns of `row` that belong
    /// to the join `alias`.
    ///
    /// `QuerySet::fetch_on` calls this once per join alias, passing the alias as both
    /// `field_name` and `alias`, and ignores the returned `bool`.
    /// NOTE(review): the `bool` presumably reports whether anything was loaded — confirm.
    fn __rustango_load_related(
        &mut self,
        row: &PgRow,
        field_name: &str,
        alias: &str,
    ) -> Result<bool, sqlx::Error>;
}
pub trait Fetcher<T>
where
T: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
fn fetch(
self,
pool: &PgPool,
) -> impl std::future::Future<Output = Result<Vec<T>, ExecError>> + Send;
}
impl<T> Fetcher<T> for QuerySet<T>
where
T: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
async fn fetch(self, pool: &PgPool) -> Result<Vec<T>, ExecError> {
let select = self.compile()?;
let stmt = Postgres.compile_select(&select)?;
let mut q: QueryAs<'_, sqlx::Postgres, T, PgArguments> = sqlx::query_as::<_, T>(&stmt.sql);
for value in stmt.params {
q = bind_query_as(q, value);
}
let rows = q.fetch_all(pool).await?;
Ok(rows)
}
}
impl<T> QuerySet<T>
where
T: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
pub async fn fetch_on<'c, E>(self, executor: E) -> Result<Vec<T>, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
T: LoadRelated,
{
let select = self.compile()?;
let select_related_aliases: Vec<&'static str> =
select.joins.iter().map(|j| j.alias).collect();
let stmt = Postgres.compile_select(&select)?;
if select_related_aliases.is_empty() {
let mut q: QueryAs<'_, sqlx::Postgres, T, PgArguments> =
sqlx::query_as::<_, T>(&stmt.sql);
for value in stmt.params {
q = bind_query_as(q, value);
}
let rows = q.fetch_all(executor).await?;
return Ok(rows);
}
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
let raw_rows = q.fetch_all(executor).await?;
let mut out = Vec::with_capacity(raw_rows.len());
for row in &raw_rows {
let mut t = T::from_row(row)?;
for alias in &select_related_aliases {
let _ = t.__rustango_load_related(row, alias, alias)?;
}
out.push(t);
}
Ok(out)
}
pub async fn fetch_paginated_on<'c, E>(self, executor: E) -> Result<Page<T>, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let select = self.compile()?;
let stmt = Postgres.compile_select(&select)?;
let sql = inject_total_count(&stmt.sql);
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&sql);
for value in stmt.params {
q = bind_query(q, value);
}
let raw_rows: Vec<PgRow> = q.fetch_all(executor).await?;
let total: i64 = raw_rows
.first()
.map(|row| sqlx::Row::try_get::<i64, _>(row, "__rustango_total"))
.transpose()?
.unwrap_or(0);
let mut rows = Vec::with_capacity(raw_rows.len());
for row in &raw_rows {
rows.push(T::from_row(row)?);
}
Ok(Page { rows, total })
}
pub async fn fetch_paginated(self, pool: &PgPool) -> Result<Page<T>, ExecError> {
self.fetch_paginated_on(pool).await
}
}
/// One page of results as returned by `QuerySet::fetch_paginated`.
pub struct Page<T> {
    /// The rows on this page.
    pub rows: Vec<T>,
    /// Row count reported alongside the page (0 for an empty page).
    pub total: i64,
}
// Written by hand rather than derived so `Page<T>: Default` holds for every `T`;
// `#[derive(Default)]` would add a `T: Default` bound.
impl<T> Default for Page<T> {
    fn default() -> Self {
        let rows: Vec<T> = Vec::new();
        Page { rows, total: 0 }
    }
}
/// Splices `, COUNT(*) OVER () AS "__rustango_total"` into a compiled SELECT, directly
/// before its first ` FROM `, so every returned row also carries the unpaginated count.
///
/// Assumes the first ` FROM ` belongs to the outermost SELECT. If no ` FROM ` is found,
/// the SQL is returned unchanged except for a leading explanatory block comment.
fn inject_total_count(sql: &str) -> String {
    const TOTAL_COLUMN: &str = ", COUNT(*) OVER () AS \"__rustango_total\"";
    match sql.find(" FROM ") {
        Some(anchor) => {
            // ` FROM ` is ASCII, so `anchor` is always a char boundary.
            let (select_list, remainder) = sql.split_at(anchor);
            let mut spliced = String::with_capacity(sql.len() + 48);
            spliced.push_str(select_list);
            spliced.push_str(TOTAL_COLUMN);
            spliced.push_str(remainder);
            spliced
        }
        None => format!(
            "/* rustango: fetch_paginated_on could not splice COUNT(*) OVER () \
             into the compiled SELECT — anchor ` FROM ` not found. The query \
             below will run unchanged but `total` will be 0. */ {sql}"
        ),
    }
}
pub async fn insert(pool: &PgPool, query: &InsertQuery) -> Result<(), ExecError> {
insert_on(pool, query).await
}
pub async fn insert_on<'c, E>(executor: E, query: &InsertQuery) -> Result<(), ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
query.validate()?;
let stmt = Postgres.compile_insert(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
q.execute(executor).await?;
Ok(())
}
pub async fn insert_returning(pool: &PgPool, query: &InsertQuery) -> Result<PgRow, ExecError> {
insert_returning_on(pool, query).await
}
pub async fn insert_returning_on<'c, E>(
executor: E,
query: &InsertQuery,
) -> Result<PgRow, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
if query.returning.is_empty() {
return Err(ExecError::EmptyReturning);
}
query.validate()?;
let stmt = Postgres.compile_insert(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
let row = q.fetch_one(executor).await?;
Ok(row)
}
pub async fn bulk_insert(pool: &PgPool, query: &BulkInsertQuery) -> Result<Vec<PgRow>, ExecError> {
bulk_insert_on(pool, query).await
}
pub async fn bulk_insert_on<'c, E>(
executor: E,
query: &BulkInsertQuery,
) -> Result<Vec<PgRow>, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
query.validate()?;
let stmt = Postgres.compile_bulk_insert(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
if query.returning.is_empty() {
q.execute(executor).await?;
Ok(Vec::new())
} else {
Ok(q.fetch_all(executor).await?)
}
}
pub async fn update(pool: &PgPool, query: &UpdateQuery) -> Result<u64, ExecError> {
update_on(pool, query).await
}
pub async fn update_on<'c, E>(executor: E, query: &UpdateQuery) -> Result<u64, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
query.validate()?;
let stmt = Postgres.compile_update(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
let result = q.execute(executor).await?;
Ok(result.rows_affected())
}
pub async fn delete(pool: &PgPool, query: &DeleteQuery) -> Result<u64, ExecError> {
delete_on(pool, query).await
}
pub async fn delete_on<'c, E>(executor: E, query: &DeleteQuery) -> Result<u64, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let stmt = Postgres.compile_delete(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
let result = q.execute(executor).await?;
Ok(result.rows_affected())
}
/// Runs a compiled [`SelectQuery`] on a pool and returns the raw rows;
/// see [`select_rows_on`].
pub async fn select_rows(pool: &PgPool, query: &SelectQuery) -> Result<Vec<PgRow>, ExecError> {
    let rows = select_rows_on(pool, query).await?;
    Ok(rows)
}
/// Converts a Postgres row into a JSON object keyed by each field's `name`, reading the
/// value from the field's `column`.
///
/// Values are decoded according to the field's declared `FieldType`. Any decode failure
/// (missing column, SQL `NULL`, type mismatch) yields JSON `null` for that entry rather
/// than an error. Dates and UUIDs are rendered with `to_string()`, timestamps with
/// `to_rfc3339()`.
#[must_use]
pub fn row_to_json(
    row: &sqlx::postgres::PgRow,
    fields: &[&'static crate::core::FieldSchema],
) -> serde_json::Value {
    use crate::core::FieldType;
    use serde_json::{json, Value};
    use sqlx::Row as _;

    // Decode `$column` as `$rust_ty`, convert it with `$to_json`, and map any error to null.
    macro_rules! cell {
        ($rust_ty:ty, $column:expr, $to_json:expr) => {
            row.try_get::<$rust_ty, _>($column)
                .map($to_json)
                .unwrap_or(Value::Null)
        };
    }

    let entries = fields.iter().map(|field| {
        let column = field.column;
        let value = match field.ty {
            FieldType::I16 => cell!(i16, column, |n| json!(n)),
            FieldType::I32 => cell!(i32, column, |n| json!(n)),
            FieldType::I64 => cell!(i64, column, |n| json!(n)),
            FieldType::F32 => cell!(f32, column, |n| json!(n)),
            FieldType::F64 => cell!(f64, column, |n| json!(n)),
            FieldType::Bool => cell!(bool, column, |b| json!(b)),
            FieldType::String => cell!(String, column, |s| json!(s)),
            FieldType::Date => cell!(chrono::NaiveDate, column, |d| json!(d.to_string())),
            FieldType::DateTime => cell!(chrono::DateTime<chrono::Utc>, column, |dt| json!(
                dt.to_rfc3339()
            )),
            FieldType::Uuid => cell!(uuid::Uuid, column, |u| json!(u.to_string())),
            FieldType::Json => row.try_get::<Value, _>(column).unwrap_or(Value::Null),
        };
        (field.name.to_owned(), value)
    });
    // Later duplicates of a field name overwrite earlier ones, as with repeated `insert`.
    Value::Object(entries.collect())
}
pub async fn select_rows_on<'c, E>(
executor: E,
query: &SelectQuery,
) -> Result<Vec<PgRow>, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let stmt = Postgres.compile_select(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
Ok(q.fetch_all(executor).await?)
}
pub async fn select_one_row(
pool: &PgPool,
query: &SelectQuery,
) -> Result<Option<PgRow>, ExecError> {
select_one_row_on(pool, query).await
}
pub async fn select_one_row_on<'c, E>(
executor: E,
query: &SelectQuery,
) -> Result<Option<PgRow>, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let stmt = Postgres.compile_select(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
Ok(q.fetch_optional(executor).await?)
}
pub async fn count_rows(pool: &PgPool, query: &CountQuery) -> Result<i64, ExecError> {
count_rows_on(pool, query).await
}
pub async fn count_rows_on<'c, E>(executor: E, query: &CountQuery) -> Result<i64, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let stmt = Postgres.compile_count(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for value in stmt.params {
q = bind_query(q, value);
}
let row = q.fetch_one(executor).await?;
Ok(sqlx::Row::try_get::<i64, _>(&row, 0)?)
}
pub async fn annotate_count_children<P>(
parent_qs: crate::query::QuerySet<P>,
child_table: &'static str,
child_fk_column: &'static str,
pool: &PgPool,
) -> Result<Vec<(P, i64)>, ExecError>
where
P: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
annotate_count_children_on(parent_qs, child_table, child_fk_column, pool).await
}
pub async fn annotate_count_children_on<'c, P, E>(
parent_qs: crate::query::QuerySet<P>,
child_table: &'static str,
child_fk_column: &'static str,
executor: E,
) -> Result<Vec<(P, i64)>, ExecError>
where
P: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
use std::fmt::Write as _;
let select = parent_qs.compile()?;
let parent = select.model;
let pk_field = parent.primary_key().ok_or(ExecError::MissingPrimaryKey {
table: parent.table,
})?;
let cols: Vec<&'static str> = parent.scalar_fields().map(|f| f.column).collect();
let mut sql = String::from("SELECT ");
for (i, col) in cols.iter().enumerate() {
if i > 0 {
sql.push_str(", ");
}
let _ = write!(sql, "\"{}\".\"{col}\"", parent.table);
}
let _ = write!(
sql,
", COUNT(\"{child_table}\".\"{child_pk}\") AS \"__annotated_count\" FROM \"{parent_table}\" LEFT JOIN \"{child_table}\" ON \"{child_table}\".\"{child_fk_column}\" = \"{parent_table}\".\"{parent_pk}\"",
parent_table = parent.table,
parent_pk = pk_field.column,
child_pk = "id",
);
let tail = crate::sql::postgres::compile_where_order_tail(
&select.where_clause,
select.search.as_ref(),
&select.order_by,
select.limit,
select.offset,
Some(parent.table),
Some(parent),
)?;
sql.push_str(" GROUP BY ");
for (i, col) in cols.iter().enumerate() {
if i > 0 {
sql.push_str(", ");
}
let _ = write!(sql, "\"{}\".\"{col}\"", parent.table);
}
sql.push_str(&tail.sql);
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&sql);
for param in tail.params {
q = bind_query(q, param);
}
let raw_rows = q.fetch_all(executor).await?;
let mut out = Vec::with_capacity(raw_rows.len());
for row in &raw_rows {
let parent_obj = P::from_row(row)?;
let count: i64 = sqlx::Row::try_get(row, "__annotated_count")?;
out.push((parent_obj, count));
}
Ok(out)
}
/// Loads the parents selected by `parent_qs` and, with one extra query, every `C` whose
/// `child_fk_column` is one of their primary keys, pairing each parent with its children.
///
/// Parents keep the order returned by `parent_qs`, and a parent without children gets an
/// empty `Vec`. Keys are matched through `SqlValue::to_display_string()`. Parents whose
/// primary key is `SqlValue::Null` are left out of the child query, and children whose
/// FK value is unavailable are discarded. If several parents share a key, only the first
/// receives the children.
///
/// # Errors
/// [`ExecError::MissingPrimaryKey`] if `P` has no primary key (checked only once at
/// least one parent was fetched), plus any error from the two fetches.
pub async fn fetch_with_prefetch<P, C>(
    parent_qs: crate::query::QuerySet<P>,
    child_fk_column: &'static str,
    pool: &PgPool,
) -> Result<Vec<(P, Vec<C>)>, ExecError>
where
    P: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin + LoadRelated + HasPkValue,
    C: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin + LoadRelated + FkPkAccess,
{
    let parents: Vec<P> = parent_qs.fetch(pool).await?;
    if parents.is_empty() {
        return Ok(Vec::new());
    }
    // Children can only be matched to parents through a primary key.
    if P::SCHEMA.primary_key().is_none() {
        return Err(ExecError::MissingPrimaryKey {
            table: P::SCHEMA.table,
        });
    }

    // Distinct, non-NULL parent keys in first-seen order.
    let mut seen_keys = std::collections::HashSet::new();
    let parent_pks: Vec<crate::core::SqlValue> = parents
        .iter()
        .map(extract_pk_value)
        .filter(|pk| !matches!(pk, crate::core::SqlValue::Null))
        .filter(|pk| seen_keys.insert(pk.to_display_string()))
        .collect();
    if parent_pks.is_empty() {
        return Ok(parents.into_iter().map(|p| (p, Vec::new())).collect());
    }

    let children: Vec<C> = crate::query::QuerySet::<C>::new()
        .filter(
            child_fk_column,
            crate::core::Op::In,
            crate::core::SqlValue::List(parent_pks),
        )
        .fetch(pool)
        .await?;

    let mut children_by_key: std::collections::HashMap<String, Vec<C>> =
        std::collections::HashMap::new();
    for child in children {
        let Some(fk_value) = child.__rustango_fk_pk_value(child_fk_column) else {
            continue;
        };
        children_by_key
            .entry(fk_value.to_display_string())
            .or_default()
            .push(child);
    }

    Ok(parents
        .into_iter()
        .map(|parent| {
            let key = extract_pk_value(&parent).to_display_string();
            let kids = children_by_key.remove(&key).unwrap_or_default();
            (parent, kids)
        })
        .collect())
}
/// Returns `parent`'s primary-key value through [`HasPkValue`].
fn extract_pk_value<P: HasPkValue>(parent: &P) -> crate::core::SqlValue {
    HasPkValue::__rustango_pk_value_impl(parent)
}
/// Exposes a model's primary-key value as a [`crate::core::SqlValue`].
///
/// NOTE(review): presumably implemented by generated model code — confirm.
#[doc(hidden)]
pub trait HasPkValue {
    fn __rustango_pk_value_impl(&self) -> crate::core::SqlValue;
}
pub trait Counter<T: Model + Send> {
fn count(
self,
pool: &PgPool,
) -> impl std::future::Future<Output = Result<i64, ExecError>> + Send;
}
impl<T: Model + Send> Counter<T> for QuerySet<T> {
async fn count(self, pool: &PgPool) -> Result<i64, ExecError> {
self.count_on(pool).await
}
}
impl<T: Model + Send> QuerySet<T> {
pub async fn count_on<'c, E>(self, executor: E) -> Result<i64, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let select = self.compile()?;
count_rows_on(
executor,
&CountQuery {
model: select.model,
where_clause: select.where_clause,
search: select.search,
},
)
.await
}
pub async fn explain(self, pool: &PgPool) -> Result<Vec<String>, ExecError> {
self.explain_on(pool, ExplainOptions::default()).await
}
pub async fn explain_on<'c, E>(
self,
executor: E,
options: ExplainOptions,
) -> Result<Vec<String>, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let select = self.compile()?;
let stmt = Postgres.compile_select(&select)?;
let mut sql = String::with_capacity(stmt.sql.len() + 32);
sql.push_str("EXPLAIN ");
let prefix = options.to_clause();
if !prefix.is_empty() {
sql.push_str(&prefix);
sql.push(' ');
}
sql.push_str(&stmt.sql);
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&sql);
for value in stmt.params {
q = bind_query(q, value);
}
let rows = q.fetch_all(executor).await?;
let mut out = Vec::with_capacity(rows.len());
for row in &rows {
let line: String = match options.format {
ExplainFormat::Json => {
let v: serde_json::Value = sqlx::Row::try_get(row, 0)?;
v.to_string()
}
ExplainFormat::Text | ExplainFormat::Yaml | ExplainFormat::Xml => {
sqlx::Row::try_get(row, 0)?
}
};
out.push(line);
}
Ok(out)
}
}
/// Options for the `EXPLAIN` statement issued by `QuerySet::explain_on`.
#[derive(Debug, Clone, Default)]
pub struct ExplainOptions {
    /// Emit `ANALYZE` (Postgres executes the statement to report actual run-time figures).
    pub analyze: bool,
    /// Emit `BUFFERS`.
    pub buffers: bool,
    /// Emit `VERBOSE`.
    pub verbose: bool,
    /// Output format; `Text` (the default) emits no `FORMAT` option.
    pub format: ExplainFormat,
}
/// Output format requested from `EXPLAIN`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ExplainFormat {
    #[default]
    Text,
    Json,
    Yaml,
    Xml,
}
impl ExplainOptions {
    /// Renders the parenthesised option list that follows `EXPLAIN`, e.g.
    /// `(ANALYZE, FORMAT JSON)`, or an empty string when no option is set.
    /// Options always appear in the order ANALYZE, BUFFERS, VERBOSE, FORMAT.
    fn to_clause(&self) -> String {
        let switches = [
            (self.analyze, "ANALYZE"),
            (self.buffers, "BUFFERS"),
            (self.verbose, "VERBOSE"),
        ];
        let format_opt = match self.format {
            ExplainFormat::Text => None,
            ExplainFormat::Json => Some("FORMAT JSON"),
            ExplainFormat::Yaml => Some("FORMAT YAML"),
            ExplainFormat::Xml => Some("FORMAT XML"),
        };
        let parts: Vec<&str> = switches
            .iter()
            .filter(|(enabled, _)| *enabled)
            .map(|(_, keyword)| *keyword)
            .chain(format_opt)
            .collect();
        if parts.is_empty() {
            return String::new();
        }
        format!("({})", parts.join(", "))
    }
}
pub trait Deleter<T: Model + Send> {
fn delete(
self,
pool: &PgPool,
) -> impl std::future::Future<Output = Result<u64, ExecError>> + Send;
}
impl<T: Model + Send> Deleter<T> for QuerySet<T> {
async fn delete(self, pool: &PgPool) -> Result<u64, ExecError> {
let query = self.compile_delete()?;
delete(pool, &query).await
}
}
pub trait Updater<T: Model + Send> {
fn execute(
self,
pool: &PgPool,
) -> impl std::future::Future<Output = Result<u64, ExecError>> + Send;
}
impl<T: Model + Send> Updater<T> for UpdateBuilder<T> {
async fn execute(self, pool: &PgPool) -> Result<u64, ExecError> {
let query = self.compile()?;
update(pool, &query).await
}
}
/// Binds one [`SqlValue`] onto an sqlx `Query`/`QueryAs` builder and evaluates to the
/// new builder.
///
/// Expands to a plain `match`, so the same arms serve every backend's builder type; each
/// payload must be encodable for the backend in question.
///
/// # Panics
/// On `SqlValue::List`. Lists are expected to have been expanded into scalar placeholders
/// by the SQL writer before binding; the raw-SQL paths in this module bind caller values
/// as given.
macro_rules! bind_match {
    ($q:expr, $value:expr) => {
        match $value {
            // NULL is bound as `None::<String>`.
            // NOTE(review): on Postgres this likely sends the parameter typed as TEXT,
            // which may be rejected for non-text target columns — confirm with tests.
            SqlValue::Null => $q.bind(None::<String>),
            SqlValue::I16(v) => $q.bind(v),
            SqlValue::I32(v) => $q.bind(v),
            SqlValue::I64(v) => $q.bind(v),
            SqlValue::F32(v) => $q.bind(v),
            SqlValue::F64(v) => $q.bind(v),
            SqlValue::Bool(v) => $q.bind(v),
            SqlValue::String(v) => $q.bind(v),
            SqlValue::DateTime(v) => $q.bind(v),
            SqlValue::Date(v) => $q.bind(v),
            SqlValue::Uuid(v) => $q.bind(v),
            // Wrapped in `sqlx::types::Json` so it is encoded as a JSON value.
            SqlValue::Json(v) => $q.bind(sqlx::types::Json(v)),
            SqlValue::List(_) => {
                unreachable!("`SqlValue::List` is expanded to scalars by the SQL writer")
            }
        }
    };
}
/// Binds one [`SqlValue`] onto a typed Postgres `query_as` builder via `bind_match!`.
///
/// # Panics
/// On `SqlValue::List` (see `bind_match!`).
fn bind_query_as<T>(
    q: QueryAs<'_, sqlx::Postgres, T, PgArguments>,
    value: SqlValue,
) -> QueryAs<'_, sqlx::Postgres, T, PgArguments> {
    bind_match!(q, value)
}
/// Binds one [`SqlValue`] onto an untyped Postgres `query` builder via `bind_match!`.
///
/// # Panics
/// On `SqlValue::List` (see `bind_match!`).
fn bind_query(
    q: Query<'_, sqlx::Postgres, PgArguments>,
    value: SqlValue,
) -> Query<'_, sqlx::Postgres, PgArguments> {
    bind_match!(q, value)
}
pub async fn bulk_update(pool: &PgPool, query: &BulkUpdateQuery) -> Result<u64, ExecError> {
bulk_update_on(pool, query).await
}
pub async fn bulk_update_on<'c, E>(executor: E, query: &BulkUpdateQuery) -> Result<u64, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
if query.rows.is_empty() {
return Ok(0);
}
let stmt = Postgres.compile_bulk_update(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for p in stmt.params {
q = bind_query(q, p);
}
Ok(q.execute(executor).await?.rows_affected())
}
pub async fn raw_query<T>(
sql: &str,
binds: Vec<SqlValue>,
pool: &PgPool,
) -> Result<Vec<T>, ExecError>
where
T: for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
raw_query_on(sql, binds, pool).await
}
pub async fn raw_query_on<'c, T, E>(
sql: &str,
binds: Vec<SqlValue>,
executor: E,
) -> Result<Vec<T>, ExecError>
where
T: for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let mut q: QueryAs<'_, sqlx::Postgres, T, PgArguments> = sqlx::query_as(sql);
for b in binds {
q = bind_query_as(q, b);
}
Ok(q.fetch_all(executor).await?)
}
pub async fn raw_execute(sql: &str, binds: Vec<SqlValue>, pool: &PgPool) -> Result<u64, ExecError> {
raw_execute_on(sql, binds, pool).await
}
pub async fn raw_execute_on<'c, E>(
sql: &str,
binds: Vec<SqlValue>,
executor: E,
) -> Result<u64, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(sql);
for b in binds {
q = bind_query(q, b);
}
Ok(q.execute(executor).await?.rows_affected())
}
pub async fn fetch_aggregate(
query: &AggregateQuery,
pool: &PgPool,
) -> Result<Vec<std::collections::HashMap<String, SqlValue>>, ExecError> {
fetch_aggregate_on(query, pool).await
}
pub async fn fetch_aggregate_on<'c, E>(
query: &AggregateQuery,
executor: E,
) -> Result<Vec<std::collections::HashMap<String, SqlValue>>, ExecError>
where
E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
let stmt = Postgres.compile_aggregate(query)?;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
for p in stmt.params {
q = bind_query(q, p);
}
let raw_rows = q.fetch_all(executor).await?;
let mut out = Vec::with_capacity(raw_rows.len());
for row in &raw_rows {
use sqlx::{Column as _, Row as _};
let mut map = std::collections::HashMap::new();
for (i, col) in row.columns().iter().enumerate() {
let name = col.name().to_owned();
let val: SqlValue = if let Ok(v) = row.try_get::<i64, _>(i) {
SqlValue::I64(v)
} else if let Ok(v) = row.try_get::<i32, _>(i) {
SqlValue::I32(v)
} else if let Ok(v) = row.try_get::<f64, _>(i) {
SqlValue::F64(v)
} else if let Ok(v) = row.try_get::<bool, _>(i) {
SqlValue::Bool(v)
} else if let Ok(v) = row.try_get::<String, _>(i) {
SqlValue::String(v)
} else {
SqlValue::Null
};
map.insert(name, val);
}
out.push(map);
}
Ok(out)
}
/// Runs `f` inside a Postgres transaction begun on `pool`.
///
/// The transaction is committed when `f`'s future resolves to `Ok`. It is rolled back
/// when the future resolves to `Err`; a rollback failure is ignored so that `f`'s own
/// error is what the caller sees.
///
/// NOTE(review): `Fut` is a single type chosen independently of the `&mut PgConnection`
/// borrow handed to `f`, so a future that keeps using the connection across `.await`
/// will likely not type-check — confirm the intended usage.
///
/// # Errors
/// Failure to begin or commit the transaction, or the error returned by `f`.
pub async fn transaction<F, Fut, T>(pool: &PgPool, f: F) -> Result<T, ExecError>
where
    F: FnOnce(&mut sqlx::PgConnection) -> Fut,
    Fut: std::future::Future<Output = Result<T, ExecError>>,
{
    let mut tx = pool.begin().await?;
    let outcome = f(&mut *tx).await;
    let value = match outcome {
        Ok(value) => value,
        Err(err) => {
            let _ = tx.rollback().await;
            return Err(err);
        }
    };
    tx.commit().await?;
    Ok(value)
}
/// A transaction begun on a backend-agnostic [`Pool`] by `transaction_pool`.
///
/// Only the variants whose cargo feature is enabled exist.
/// NOTE(review): dropping it without `commit`/`rollback` presumably rolls back, as sqlx's
/// `Transaction` does on drop — confirm.
pub enum PoolTx<'a> {
    /// Transaction on a Postgres pool.
    #[cfg(feature = "postgres")]
    Postgres(sqlx::Transaction<'a, sqlx::Postgres>),
    /// Transaction on a MySQL pool.
    #[cfg(feature = "mysql")]
    Mysql(sqlx::Transaction<'a, sqlx::MySql>),
    /// Transaction on a SQLite pool.
    #[cfg(feature = "sqlite")]
    Sqlite(sqlx::Transaction<'a, sqlx::Sqlite>),
}
impl<'a> PoolTx<'a> {
    /// Commits the transaction on whichever backend it belongs to.
    pub async fn commit(self) -> Result<(), sqlx::Error> {
        match self {
            #[cfg(feature = "postgres")]
            PoolTx::Postgres(tx) => tx.commit().await,
            #[cfg(feature = "mysql")]
            PoolTx::Mysql(tx) => tx.commit().await,
            #[cfg(feature = "sqlite")]
            PoolTx::Sqlite(tx) => tx.commit().await,
        }
    }
    /// Rolls the transaction back on whichever backend it belongs to.
    pub async fn rollback(self) -> Result<(), sqlx::Error> {
        match self {
            #[cfg(feature = "postgres")]
            PoolTx::Postgres(tx) => tx.rollback().await,
            #[cfg(feature = "mysql")]
            PoolTx::Mysql(tx) => tx.rollback().await,
            #[cfg(feature = "sqlite")]
            PoolTx::Sqlite(tx) => tx.rollback().await,
        }
    }
}
/// Begins a transaction on whichever backend `pool` wraps.
///
/// # Errors
/// Any error from the backend's `begin`.
pub async fn transaction_pool(pool: &Pool) -> Result<PoolTx<'_>, ExecError> {
    match pool {
        #[cfg(feature = "postgres")]
        Pool::Postgres(pg) => Ok(PoolTx::Postgres(pg.begin().await?)),
        #[cfg(feature = "mysql")]
        Pool::Mysql(my) => Ok(PoolTx::Mysql(my.begin().await?)),
        #[cfg(feature = "sqlite")]
        Pool::Sqlite(sq) => Ok(PoolTx::Sqlite(sq.begin().await?)),
    }
}
use super::Pool;
/// MySQL counterpart of `bind_query`: binds one [`SqlValue`] onto a MySQL query builder.
///
/// # Panics
/// On `SqlValue::List` (see `bind_match!`).
#[cfg(feature = "mysql")]
fn bind_query_my(
    q: sqlx::query::Query<'_, sqlx::MySql, sqlx::mysql::MySqlArguments>,
    value: SqlValue,
) -> sqlx::query::Query<'_, sqlx::MySql, sqlx::mysql::MySqlArguments> {
    bind_match!(q, value)
}
/// SQLite counterpart of `bind_query`: binds one [`SqlValue`] onto a SQLite query builder.
///
/// # Panics
/// On `SqlValue::List` (see `bind_match!`).
#[cfg(feature = "sqlite")]
fn bind_query_sqlite<'a>(
    q: sqlx::query::Query<'a, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'a>>,
    value: SqlValue,
) -> sqlx::query::Query<'a, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'a>> {
    bind_match!(q, value)
}
/// Validates and runs an `INSERT` on whichever backend `pool` wraps, discarding the result.
///
/// # Errors
/// Validation, compilation and database errors.
pub async fn insert_pool(pool: &Pool, query: &InsertQuery) -> Result<(), ExecError> {
    query.validate()?;
    let stmt = pool.dialect().compile_insert(query)?;
    execute_pool(pool, &stmt.sql, stmt.params).await?;
    Ok(())
}
/// Runs an `INSERT` and reports what was inserted, in the form each backend supports:
/// - Postgres and SQLite return the `RETURNING` row.
/// - MySQL runs the insert without `RETURNING` and reports `LAST_INSERT_ID()` read on the
///   same connection.
///
/// # Errors
/// - [`ExecError::EmptyReturning`] when `query.returning` is empty. This is checked
///   first, as in [`insert_returning_on`].
/// - Validation, compilation and database errors.
/// - On MySQL, a decode error if the generated id does not fit in an `i64`.
pub async fn insert_returning_pool(
    pool: &Pool,
    query: &InsertQuery,
) -> Result<InsertReturningPool, ExecError> {
    if query.returning.is_empty() {
        return Err(ExecError::EmptyReturning);
    }
    query.validate()?;
    match pool {
        #[cfg(feature = "postgres")]
        Pool::Postgres(pg) => {
            let row = insert_returning_on(pg, query).await?;
            Ok(InsertReturningPool::PgRow(row))
        }
        #[cfg(feature = "mysql")]
        Pool::Mysql(my) => {
            // No RETURNING here: insert without it, then ask the same connection for
            // the id it just generated (LAST_INSERT_ID() is per-connection).
            let plain = InsertQuery {
                model: query.model,
                columns: query.columns.clone(),
                values: query.values.clone(),
                returning: ::std::vec::Vec::new(),
                on_conflict: query.on_conflict.clone(),
            };
            let stmt = pool.dialect().compile_insert(&plain)?;
            let mut conn = my.acquire().await?;
            let mut q: sqlx::query::Query<'_, sqlx::MySql, sqlx::mysql::MySqlArguments> =
                sqlx::query(&stmt.sql);
            for v in stmt.params {
                q = bind_query_my(q, v);
            }
            q.execute(&mut *conn).await?;
            use sqlx::Row as _;
            let row = sqlx::query("SELECT LAST_INSERT_ID()")
                .fetch_one(&mut *conn)
                .await?;
            let id_u64: u64 = row.try_get::<u64, _>(0)?;
            // An id outside i64's range is reported as a decode error rather than being
            // silently replaced by `i64::MAX`, which would identify the wrong row.
            let id = i64::try_from(id_u64).map_err(|e| sqlx::Error::Decode(Box::new(e)))?;
            Ok(InsertReturningPool::MySqlAutoId(id))
        }
        #[cfg(feature = "sqlite")]
        Pool::Sqlite(sq) => {
            let stmt = pool.dialect().compile_insert(query)?;
            let mut q: sqlx::query::Query<'_, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'_>> =
                sqlx::query(&stmt.sql);
            for v in stmt.params {
                q = bind_query_sqlite(q, v);
            }
            let row = q.fetch_one(sq).await?;
            Ok(InsertReturningPool::SqliteRow(row))
        }
    }
}
/// Result of `insert_returning_pool`; the shape depends on the backend.
///
/// Only the variants whose cargo feature is enabled exist.
pub enum InsertReturningPool {
    /// The Postgres `RETURNING` row.
    #[cfg(feature = "postgres")]
    PgRow(PgRow),
    /// The MySQL `LAST_INSERT_ID()` of the insert (MySQL path runs without `RETURNING`).
    #[cfg(feature = "mysql")]
    MySqlAutoId(i64),
    /// The SQLite `RETURNING` row.
    #[cfg(feature = "sqlite")]
    SqliteRow(sqlx::sqlite::SqliteRow),
}
// Hand-written `Debug`: row variants print a placeholder instead of row contents; only
// the MySQL id is shown.
impl ::core::fmt::Debug for InsertReturningPool {
    fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
        match self {
            #[cfg(feature = "postgres")]
            Self::PgRow(_) => f.debug_tuple("PgRow").field(&"<PgRow>").finish(),
            #[cfg(feature = "mysql")]
            Self::MySqlAutoId(id) => f.debug_tuple("MySqlAutoId").field(id).finish(),
            #[cfg(feature = "sqlite")]
            Self::SqliteRow(_) => f.debug_tuple("SqliteRow").field(&"<SqliteRow>").finish(),
        }
    }
}
/// Validates and runs an `UPDATE` on whichever backend `pool` wraps; returns rows affected.
///
/// # Errors
/// Validation, compilation and database errors.
pub async fn update_pool(pool: &Pool, query: &UpdateQuery) -> Result<u64, ExecError> {
    // Validate exactly as `update_on` does, so invalid updates are rejected the same way
    // on every backend instead of only on the Postgres-specific path.
    query.validate()?;
    let stmt = pool.dialect().compile_update(query)?;
    execute_pool(pool, &stmt.sql, stmt.params).await
}
/// Runs a `DELETE` on whichever backend `pool` wraps; returns rows affected.
///
/// # Errors
/// Compilation and database errors.
pub async fn delete_pool(pool: &Pool, query: &DeleteQuery) -> Result<u64, ExecError> {
    let stmt = pool.dialect().compile_delete(query)?;
    execute_pool(pool, &stmt.sql, stmt.params).await
}
/// Runs a compiled [`CountQuery`] on whichever backend `pool` wraps and returns the first
/// column of the result row as `i64`.
///
/// # Errors
/// Compilation, database and decoding errors.
pub async fn count_rows_pool(pool: &Pool, query: &CountQuery) -> Result<i64, ExecError> {
    let stmt = pool.dialect().compile_count(query)?;
    fetch_scalar_pool(pool, &stmt.sql, stmt.params).await
}
pub async fn bulk_insert_pool(pool: &Pool, query: &BulkInsertQuery) -> Result<(), ExecError> {
if query.rows.is_empty() {
return Ok(());
}
let stmt = pool.dialect().compile_bulk_insert(query)?;
execute_pool(pool, &stmt.sql, stmt.params).await?;
Ok(())
}
pub async fn bulk_update_pool(pool: &Pool, query: &BulkUpdateQuery) -> Result<u64, ExecError> {
if query.rows.is_empty() {
return Ok(0);
}
let stmt = pool.dialect().compile_bulk_update(query)?;
execute_pool(pool, &stmt.sql, stmt.params).await
}
pub async fn raw_execute_pool(
pool: &Pool,
sql: &str,
binds: Vec<SqlValue>,
) -> Result<u64, ExecError> {
execute_pool(pool, sql, binds).await
}
async fn execute_pool(pool: &Pool, sql: &str, binds: Vec<SqlValue>) -> Result<u64, ExecError> {
match pool {
#[cfg(feature = "postgres")]
Pool::Postgres(pg) => {
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(sql);
for v in binds {
q = bind_query(q, v);
}
Ok(q.execute(pg).await?.rows_affected())
}
#[cfg(feature = "mysql")]
Pool::Mysql(my) => {
let mut q: sqlx::query::Query<'_, sqlx::MySql, sqlx::mysql::MySqlArguments> =
sqlx::query(sql);
for v in binds {
q = bind_query_my(q, v);
}
Ok(q.execute(my).await?.rows_affected())
}
#[cfg(feature = "sqlite")]
Pool::Sqlite(sq) => {
let mut q: sqlx::query::Query<'_, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'_>> =
sqlx::query(sql);
for v in binds {
q = bind_query_sqlite(q, v);
}
Ok(q.execute(sq).await?.rows_affected())
}
}
}
async fn fetch_scalar_pool(pool: &Pool, sql: &str, binds: Vec<SqlValue>) -> Result<i64, ExecError> {
match pool {
#[cfg(feature = "postgres")]
Pool::Postgres(pg) => {
use sqlx::Row as _;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(sql);
for v in binds {
q = bind_query(q, v);
}
let row = q.fetch_one(pg).await?;
Ok(row.try_get::<i64, _>(0)?)
}
#[cfg(feature = "mysql")]
Pool::Mysql(my) => {
use sqlx::Row as _;
let mut q: sqlx::query::Query<'_, sqlx::MySql, sqlx::mysql::MySqlArguments> =
sqlx::query(sql);
for v in binds {
q = bind_query_my(q, v);
}
let row = q.fetch_one(my).await?;
Ok(row.try_get::<i64, _>(0)?)
}
#[cfg(feature = "sqlite")]
Pool::Sqlite(sq) => {
use sqlx::Row as _;
let mut q: sqlx::query::Query<'_, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'_>> =
sqlx::query(sql);
for v in binds {
q = bind_query_sqlite(q, v);
}
let row = q.fetch_one(sq).await?;
Ok(row.try_get::<i64, _>(0)?)
}
}
}
// Backend-conditional bound helpers.
//
// Each `Maybe*` trait is the real sqlx bound when its backend feature is enabled, and an
// empty, blanket-implemented marker when it is not. This lets the `*_pool` functions
// spell a single set of bounds regardless of which backends are compiled in.

/// `T` can be decoded from a MySQL row (`mysql` feature on).
#[cfg(feature = "mysql")]
pub trait MaybeMyFromRow: for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow> {}
#[cfg(feature = "mysql")]
impl<T> MaybeMyFromRow for T where T: for<'r> sqlx::FromRow<'r, sqlx::mysql::MySqlRow> {}
/// No-op marker implemented for every type (`mysql` feature off).
#[cfg(not(feature = "mysql"))]
pub trait MaybeMyFromRow {}
#[cfg(not(feature = "mysql"))]
impl<T> MaybeMyFromRow for T {}
/// MySQL counterpart of `LoadRelated`: populates a `select_related` relation from a
/// joined MySQL row.
#[cfg(feature = "mysql")]
pub trait LoadRelatedMy {
    fn __rustango_load_related_my(
        &mut self,
        row: &sqlx::mysql::MySqlRow,
        field_name: &str,
        alias: &str,
    ) -> Result<bool, sqlx::Error>;
}
/// `T` implements [`LoadRelatedMy`] (`mysql` feature on).
#[cfg(feature = "mysql")]
pub trait MaybeMyLoadRelated: LoadRelatedMy {}
#[cfg(feature = "mysql")]
impl<T> MaybeMyLoadRelated for T where T: LoadRelatedMy {}
/// No-op marker implemented for every type (`mysql` feature off).
#[cfg(not(feature = "mysql"))]
pub trait MaybeMyLoadRelated {}
#[cfg(not(feature = "mysql"))]
impl<T> MaybeMyLoadRelated for T {}
/// `T` can be decoded from a SQLite row (`sqlite` feature on).
#[cfg(feature = "sqlite")]
pub trait MaybeSqliteFromRow: for<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> {}
#[cfg(feature = "sqlite")]
impl<T> MaybeSqliteFromRow for T where T: for<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow> {}
/// No-op marker implemented for every type (`sqlite` feature off).
#[cfg(not(feature = "sqlite"))]
#[allow(dead_code)]
pub trait MaybeSqliteFromRow {}
#[cfg(not(feature = "sqlite"))]
impl<T> MaybeSqliteFromRow for T {}
/// SQLite counterpart of `LoadRelated`: populates a `select_related` relation from a
/// joined SQLite row.
#[cfg(feature = "sqlite")]
pub trait LoadRelatedSqlite {
    fn __rustango_load_related_sqlite(
        &mut self,
        row: &sqlx::sqlite::SqliteRow,
        field_name: &str,
        alias: &str,
    ) -> Result<bool, sqlx::Error>;
}
/// `T` implements [`LoadRelatedSqlite`] (`sqlite` feature on).
#[cfg(feature = "sqlite")]
pub trait MaybeSqliteLoadRelated: LoadRelatedSqlite {}
#[cfg(feature = "sqlite")]
impl<T> MaybeSqliteLoadRelated for T where T: LoadRelatedSqlite {}
/// No-op marker implemented for every type (`sqlite` feature off).
#[cfg(not(feature = "sqlite"))]
#[allow(dead_code)]
pub trait MaybeSqliteLoadRelated {}
#[cfg(not(feature = "sqlite"))]
impl<T> MaybeSqliteLoadRelated for T {}
/// Compiles `query` for whichever backend `pool` wraps, runs it, and decodes every row
/// into `T`.
///
/// The `Maybe*FromRow` bounds reduce to no-op markers for backends whose feature is off.
///
/// # Errors
/// Compilation, database and decoding errors.
pub async fn select_rows_pool<T>(pool: &Pool, query: &SelectQuery) -> Result<Vec<T>, ExecError>
where
    T: for<'r> sqlx::FromRow<'r, PgRow> + MaybeMyFromRow + MaybeSqliteFromRow + Send + Unpin,
{
    let compiled = pool.dialect().compile_select(query)?;
    let rows = match pool {
        #[cfg(feature = "postgres")]
        Pool::Postgres(pg) => {
            let seed: QueryAs<'_, sqlx::Postgres, T, PgArguments> =
                sqlx::query_as::<_, T>(&compiled.sql);
            let bound = compiled.params.into_iter().fold(seed, bind_query_as);
            bound.fetch_all(pg).await?
        }
        #[cfg(feature = "mysql")]
        Pool::Mysql(my) => {
            let seed: sqlx::query::QueryAs<'_, sqlx::MySql, T, sqlx::mysql::MySqlArguments> =
                sqlx::query_as::<_, T>(&compiled.sql);
            let bound = compiled.params.into_iter().fold(seed, bind_query_as_my);
            bound.fetch_all(my).await?
        }
        #[cfg(feature = "sqlite")]
        Pool::Sqlite(sq) => {
            let seed: sqlx::query::QueryAs<
                '_,
                sqlx::Sqlite,
                T,
                sqlx::sqlite::SqliteArguments<'_>,
            > = sqlx::query_as::<_, T>(&compiled.sql);
            let bound = compiled.params.into_iter().fold(seed, bind_query_as_sqlite);
            bound.fetch_all(sq).await?
        }
    };
    Ok(rows)
}
/// Compiles `query` for whichever backend `pool` wraps and returns its first row decoded
/// into `T`, or `None` when it yields no rows.
///
/// # Errors
/// Compilation, database and decoding errors.
pub async fn select_one_row_pool<T>(
    pool: &Pool,
    query: &SelectQuery,
) -> Result<Option<T>, ExecError>
where
    T: for<'r> sqlx::FromRow<'r, PgRow> + MaybeMyFromRow + MaybeSqliteFromRow + Send + Unpin,
{
    let compiled = pool.dialect().compile_select(query)?;
    let maybe_row = match pool {
        #[cfg(feature = "postgres")]
        Pool::Postgres(pg) => {
            let seed: QueryAs<'_, sqlx::Postgres, T, PgArguments> =
                sqlx::query_as::<_, T>(&compiled.sql);
            let bound = compiled.params.into_iter().fold(seed, bind_query_as);
            bound.fetch_optional(pg).await?
        }
        #[cfg(feature = "mysql")]
        Pool::Mysql(my) => {
            let seed: sqlx::query::QueryAs<'_, sqlx::MySql, T, sqlx::mysql::MySqlArguments> =
                sqlx::query_as::<_, T>(&compiled.sql);
            let bound = compiled.params.into_iter().fold(seed, bind_query_as_my);
            bound.fetch_optional(my).await?
        }
        #[cfg(feature = "sqlite")]
        Pool::Sqlite(sq) => {
            let seed: sqlx::query::QueryAs<
                '_,
                sqlx::Sqlite,
                T,
                sqlx::sqlite::SqliteArguments<'_>,
            > = sqlx::query_as::<_, T>(&compiled.sql);
            let bound = compiled.params.into_iter().fold(seed, bind_query_as_sqlite);
            bound.fetch_optional(sq).await?
        }
    };
    Ok(maybe_row)
}
/// MySQL counterpart of `bind_query_as`: binds one [`SqlValue`] onto a typed MySQL
/// `query_as` builder.
///
/// # Panics
/// On `SqlValue::List` (see `bind_match!`).
#[cfg(feature = "mysql")]
fn bind_query_as_my<T>(
    q: sqlx::query::QueryAs<'_, sqlx::MySql, T, sqlx::mysql::MySqlArguments>,
    value: SqlValue,
) -> sqlx::query::QueryAs<'_, sqlx::MySql, T, sqlx::mysql::MySqlArguments> {
    bind_match!(q, value)
}
/// SQLite counterpart of `bind_query_as`: binds one [`SqlValue`] onto a typed SQLite
/// `query_as` builder.
///
/// # Panics
/// On `SqlValue::List` (see `bind_match!`).
#[cfg(feature = "sqlite")]
fn bind_query_as_sqlite<'a, T>(
    q: sqlx::query::QueryAs<'a, sqlx::Sqlite, T, sqlx::sqlite::SqliteArguments<'a>>,
    value: SqlValue,
) -> sqlx::query::QueryAs<'a, sqlx::Sqlite, T, sqlx::sqlite::SqliteArguments<'a>> {
    bind_match!(q, value)
}
/// Execute an aggregate query on whichever backend `pool` wraps and decode
/// every resulting row into `T`.
///
/// The aggregate is compiled with the pool's dialect via `compile_aggregate`,
/// its parameters are bound in order, and all rows are fetched.
///
/// # Errors
/// Propagates compilation failures and database / row-decoding errors as
/// [`ExecError`].
pub async fn fetch_aggregate_pool<T>(
    pool: &Pool,
    query: &AggregateQuery,
) -> Result<Vec<T>, ExecError>
where
    T: for<'r> sqlx::FromRow<'r, PgRow> + MaybeMyFromRow + MaybeSqliteFromRow + Send + Unpin,
{
    let compiled = pool.dialect().compile_aggregate(query)?;
    match pool {
        #[cfg(feature = "postgres")]
        Pool::Postgres(pg) => {
            let mut typed: QueryAs<'_, sqlx::Postgres, T, PgArguments> =
                sqlx::query_as(&compiled.sql);
            for param in compiled.params {
                typed = bind_query_as(typed, param);
            }
            typed.fetch_all(pg).await.map_err(ExecError::from)
        }
        #[cfg(feature = "mysql")]
        Pool::Mysql(my) => {
            let mut typed: sqlx::query::QueryAs<'_, sqlx::MySql, T, sqlx::mysql::MySqlArguments> =
                sqlx::query_as(&compiled.sql);
            for param in compiled.params {
                typed = bind_query_as_my(typed, param);
            }
            typed.fetch_all(my).await.map_err(ExecError::from)
        }
        #[cfg(feature = "sqlite")]
        Pool::Sqlite(sq) => {
            let mut typed: sqlx::query::QueryAs<
                '_,
                sqlx::Sqlite,
                T,
                sqlx::sqlite::SqliteArguments<'_>,
            > = sqlx::query_as(&compiled.sql);
            for param in compiled.params {
                typed = bind_query_as_sqlite(typed, param);
            }
            typed.fetch_all(sq).await.map_err(ExecError::from)
        }
    }
}
/// Run caller-supplied SQL with positional `binds` on whichever backend
/// `pool` wraps, decoding every row into `T`.
///
/// `sql` is passed through unchanged, so it must already use the placeholder
/// syntax of the target backend; `binds` are attached in order.
///
/// # Errors
/// Propagates database and row-decoding errors as [`ExecError`].
pub async fn raw_query_pool<T>(
    sql: &str,
    binds: Vec<SqlValue>,
    pool: &Pool,
) -> Result<Vec<T>, ExecError>
where
    T: for<'r> sqlx::FromRow<'r, PgRow> + MaybeMyFromRow + MaybeSqliteFromRow + Send + Unpin,
{
    match pool {
        #[cfg(feature = "postgres")]
        Pool::Postgres(pg) => {
            let prepared = binds.into_iter().fold(
                sqlx::query_as::<sqlx::Postgres, T>(sql),
                |acc, param| bind_query_as(acc, param),
            );
            Ok(prepared.fetch_all(pg).await?)
        }
        #[cfg(feature = "mysql")]
        Pool::Mysql(my) => {
            let prepared = binds.into_iter().fold(
                sqlx::query_as::<sqlx::MySql, T>(sql),
                |acc, param| bind_query_as_my(acc, param),
            );
            Ok(prepared.fetch_all(my).await?)
        }
        #[cfg(feature = "sqlite")]
        Pool::Sqlite(sq) => {
            let prepared = binds.into_iter().fold(
                sqlx::query_as::<sqlx::Sqlite, T>(sql),
                |acc, param| bind_query_as_sqlite(acc, param),
            );
            Ok(prepared.fetch_all(sq).await?)
        }
    }
}
/// Backend-agnostic row counting for a query over model `T`.
pub trait CounterPool<T>
where
    T: Model + Send,
{
    /// Consume the query and resolve to its row count as an `i64`, executed
    /// against the given [`Pool`].
    fn count_pool(
        self,
        pool: &Pool,
    ) -> impl std::future::Future<Output = Result<i64, ExecError>> + Send;
}
impl<T> CounterPool<T> for QuerySet<T>
where
    T: Model + Send,
{
    /// Compile the queryset, keep only its model, WHERE clause and search
    /// term, and count matching rows via `count_rows_pool`.
    async fn count_pool(self, pool: &Pool) -> Result<i64, ExecError> {
        let compiled = self.compile()?;
        let count_query = CountQuery {
            model: compiled.model,
            where_clause: compiled.where_clause,
            search: compiled.search,
        };
        count_rows_pool(pool, &count_query).await
    }
}
pub async fn fetch_paginated_pool<T>(
qs: crate::query::QuerySet<T>,
pool: &Pool,
) -> Result<Page<T>, ExecError>
where
T: Model
+ for<'r> sqlx::FromRow<'r, PgRow>
+ MaybeMyFromRow
+ MaybeSqliteFromRow
+ Send
+ Unpin,
{
let select = qs.compile()?;
let stmt = pool.dialect().compile_select(&select)?;
let sql = inject_total_count(&stmt.sql);
match pool {
#[cfg(feature = "postgres")]
Pool::Postgres(pg) => {
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&sql);
for v in stmt.params {
q = bind_query(q, v);
}
use sqlx::Row as _;
let raw_rows: Vec<PgRow> = q.fetch_all(pg).await?;
let total: i64 = raw_rows
.first()
.map(|row| row.try_get::<i64, _>("__rustango_total"))
.transpose()?
.unwrap_or(0);
let mut rows = Vec::with_capacity(raw_rows.len());
for row in &raw_rows {
rows.push(T::from_row(row)?);
}
Ok(Page { rows, total })
}
#[cfg(feature = "mysql")]
Pool::Mysql(my) => {
let mut q: sqlx::query::Query<'_, sqlx::MySql, sqlx::mysql::MySqlArguments> =
sqlx::query(&sql);
for v in stmt.params {
q = bind_query_my(q, v);
}
use sqlx::Row as _;
let raw_rows: Vec<sqlx::mysql::MySqlRow> = q.fetch_all(my).await?;
let total: i64 = raw_rows
.first()
.map(|row| row.try_get::<i64, _>("__rustango_total"))
.transpose()?
.unwrap_or(0);
let mut rows = Vec::with_capacity(raw_rows.len());
for row in &raw_rows {
rows.push(<T as sqlx::FromRow<sqlx::mysql::MySqlRow>>::from_row(row)?);
}
Ok(Page { rows, total })
}
#[cfg(feature = "sqlite")]
Pool::Sqlite(sq) => {
let mut q: sqlx::query::Query<'_, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'_>> =
sqlx::query(&sql);
for v in stmt.params {
q = bind_query_sqlite(q, v);
}
use sqlx::Row as _;
let raw_rows: Vec<sqlx::sqlite::SqliteRow> = q.fetch_all(sq).await?;
let total: i64 = raw_rows
.first()
.map(|row| row.try_get::<i64, _>("__rustango_total"))
.transpose()?
.unwrap_or(0);
let mut rows = Vec::with_capacity(raw_rows.len());
for row in &raw_rows {
rows.push(<T as sqlx::FromRow<sqlx::sqlite::SqliteRow>>::from_row(
row,
)?);
}
Ok(Page { rows, total })
}
}
}
/// Fetch the parents selected by `parent_qs`. With one extra query, also fetch
/// every `C` row whose `child_fk_column` holds one of those parents' keys.
///
/// Returns one `(parent, children)` pair per fetched parent, in the order the
/// parent query produced them.
///
/// Children are attached by comparing `to_display_string()` renderings,
/// not the raw `SqlValue`s:
/// * the child side is the FK value from `__rustango_fk_pk_value`;
/// * the parent side is the PK value from `extract_pk_value`.
///
/// This is presumably so equal keys held in different `SqlValue` variants
/// still match (TODO confirm).
///
/// Edge cases:
/// * Parents whose PK is `SqlValue::Null` always get an empty child list.
/// * Children for which `__rustango_fk_pk_value` yields `None` are dropped.
/// * If the parent query returns the same PK more than once, only the first
///   such parent receives the children. `C` is not required to be `Clone`,
///   so the child list is moved rather than copied.
///
/// # Errors
/// Returns `ExecError::MissingPrimaryKey` when `P`'s schema declares no
/// primary key. The check runs before any query is issued, so the error is
/// reported consistently, even when the parent query would match nothing.
/// Errors from either fetch are propagated.
pub async fn fetch_with_prefetch_pool<P, C>(
    parent_qs: crate::query::QuerySet<P>,
    child_fk_column: &'static str,
    pool: &Pool,
) -> Result<Vec<(P, Vec<C>)>, ExecError>
where
    P: Model
        + for<'r> sqlx::FromRow<'r, PgRow>
        + MaybeMyFromRow
        + MaybeSqliteFromRow
        + LoadRelated
        + MaybeMyLoadRelated
        + MaybeSqliteLoadRelated
        + HasPkValue
        + Send
        + Unpin,
    C: Model
        + for<'r> sqlx::FromRow<'r, PgRow>
        + MaybeMyFromRow
        + MaybeSqliteFromRow
        + LoadRelated
        + MaybeMyLoadRelated
        + MaybeSqliteLoadRelated
        + FkPkAccess
        + Send
        + Unpin,
{
    // Grouping children needs a parent primary key. Fail fast instead of
    // running the parent query and only then discovering it is missing.
    if P::SCHEMA.primary_key().is_none() {
        return Err(ExecError::MissingPrimaryKey {
            table: P::SCHEMA.table,
        });
    }
    let parents: Vec<P> = parent_qs.fetch_pool(pool).await?;
    if parents.is_empty() {
        return Ok(Vec::new());
    }
    // Distinct, non-NULL parent keys for the `IN (...)` filter, kept in
    // first-seen order.
    let parent_pks: Vec<crate::core::SqlValue> = {
        let mut seen = std::collections::HashSet::new();
        let mut pks = Vec::with_capacity(parents.len());
        for parent in &parents {
            let pk = extract_pk_value(parent);
            if !matches!(pk, crate::core::SqlValue::Null) && seen.insert(pk.to_display_string())
            {
                pks.push(pk);
            }
        }
        pks
    };
    if parent_pks.is_empty() {
        return Ok(parents.into_iter().map(|p| (p, Vec::new())).collect());
    }
    let children: Vec<C> = crate::query::QuerySet::<C>::new()
        .filter(
            child_fk_column,
            crate::core::Op::In,
            crate::core::SqlValue::List(parent_pks),
        )
        .fetch_pool(pool)
        .await?;
    let mut grouped: std::collections::HashMap<String, Vec<C>> = std::collections::HashMap::new();
    for child in children {
        let Some(fk_pk) = child.__rustango_fk_pk_value(child_fk_column) else {
            continue;
        };
        grouped
            .entry(fk_pk.to_display_string())
            .or_default()
            .push(child);
    }
    let mut out = Vec::with_capacity(parents.len());
    for parent in parents {
        let pk = extract_pk_value(&parent).to_display_string();
        // `remove` moves the list out, so a repeated PK gets an empty list.
        let kids = grouped.remove(&pk).unwrap_or_default();
        out.push((parent, kids));
    }
    Ok(out)
}
/// Run `query` on whichever backend `pool` wraps and decode every row into
/// `T`, hydrating each `select_related` join along the way.
///
/// Without joins, rows are decoded directly with `query_as`. With joins, rows
/// are fetched untyped. For each row, `from_row` first builds the base value.
/// Then the backend's `__rustango_load_related*` hook is called once per join
/// alias, with the alias passed as both the `field_name` and `alias`
/// arguments.
///
/// # Errors
/// Propagates the following as [`ExecError`]:
/// * compilation failures;
/// * database errors;
/// * row-decoding errors;
/// * errors from the related-loading hooks.
///
/// The first error aborts the whole fetch.
pub async fn select_rows_pool_with_related<T>(
    pool: &Pool,
    query: &SelectQuery,
) -> Result<Vec<T>, ExecError>
where
    T: for<'r> sqlx::FromRow<'r, PgRow>
        + MaybeMyFromRow
        + MaybeSqliteFromRow
        + LoadRelated
        + MaybeMyLoadRelated
        + MaybeSqliteLoadRelated
        + Send
        + Unpin,
{
    let join_aliases: Vec<&'static str> = query.joins.iter().map(|join| join.alias).collect();
    let stmt = pool.dialect().compile_select(query)?;
    match pool {
        #[cfg(feature = "postgres")]
        Pool::Postgres(pg) => {
            if join_aliases.is_empty() {
                let typed = stmt.params.into_iter().fold(
                    sqlx::query_as::<sqlx::Postgres, T>(&stmt.sql),
                    |acc, param| bind_query_as(acc, param),
                );
                return Ok(typed.fetch_all(pg).await?);
            }
            let untyped = stmt.params.into_iter().fold(
                sqlx::query::<sqlx::Postgres>(&stmt.sql),
                |acc, param| bind_query(acc, param),
            );
            let raw_rows = untyped.fetch_all(pg).await?;
            let hydrated = raw_rows
                .iter()
                .map(|row| -> Result<T, sqlx::Error> {
                    let mut item = T::from_row(row)?;
                    for alias in &join_aliases {
                        item.__rustango_load_related(row, alias, alias)?;
                    }
                    Ok(item)
                })
                .collect::<Result<Vec<T>, sqlx::Error>>()?;
            Ok(hydrated)
        }
        #[cfg(feature = "mysql")]
        Pool::Mysql(my) => {
            if join_aliases.is_empty() {
                let typed = stmt.params.into_iter().fold(
                    sqlx::query_as::<sqlx::MySql, T>(&stmt.sql),
                    |acc, param| bind_query_as_my(acc, param),
                );
                return Ok(typed.fetch_all(my).await?);
            }
            let untyped = stmt.params.into_iter().fold(
                sqlx::query::<sqlx::MySql>(&stmt.sql),
                |acc, param| bind_query_my(acc, param),
            );
            let raw_rows = untyped.fetch_all(my).await?;
            let hydrated = raw_rows
                .iter()
                .map(|row| -> Result<T, sqlx::Error> {
                    let mut item = <T as sqlx::FromRow<sqlx::mysql::MySqlRow>>::from_row(row)?;
                    for alias in &join_aliases {
                        item.__rustango_load_related_my(row, alias, alias)?;
                    }
                    Ok(item)
                })
                .collect::<Result<Vec<T>, sqlx::Error>>()?;
            Ok(hydrated)
        }
        #[cfg(feature = "sqlite")]
        Pool::Sqlite(sq) => {
            if join_aliases.is_empty() {
                let typed = stmt.params.into_iter().fold(
                    sqlx::query_as::<sqlx::Sqlite, T>(&stmt.sql),
                    |acc, param| bind_query_as_sqlite(acc, param),
                );
                return Ok(typed.fetch_all(sq).await?);
            }
            let untyped = stmt.params.into_iter().fold(
                sqlx::query::<sqlx::Sqlite>(&stmt.sql),
                |acc, param| bind_query_sqlite(acc, param),
            );
            let raw_rows = untyped.fetch_all(sq).await?;
            let hydrated = raw_rows
                .iter()
                .map(|row| -> Result<T, sqlx::Error> {
                    let mut item =
                        <T as sqlx::FromRow<sqlx::sqlite::SqliteRow>>::from_row(row)?;
                    for alias in &join_aliases {
                        item.__rustango_load_related_sqlite(row, alias, alias)?;
                    }
                    Ok(item)
                })
                .collect::<Result<Vec<T>, sqlx::Error>>()?;
            Ok(hydrated)
        }
    }
}
/// Backend-agnostic fetching of a query's rows as `T`, including any
/// `select_related` joins.
pub trait FetcherPool<
    T: Model
        + for<'r> sqlx::FromRow<'r, PgRow>
        + MaybeMyFromRow
        + MaybeSqliteFromRow
        + LoadRelated
        + MaybeMyLoadRelated
        + MaybeSqliteLoadRelated
        + Send
        + Unpin,
>
{
    /// Consume the query and resolve to every matching row, executed against
    /// the given [`Pool`].
    fn fetch_pool(
        self,
        pool: &Pool,
    ) -> impl std::future::Future<Output = Result<Vec<T>, ExecError>> + Send;
}
impl<T> FetcherPool<T> for QuerySet<T>
where
    T: Model
        + for<'r> sqlx::FromRow<'r, PgRow>
        + MaybeMyFromRow
        + MaybeSqliteFromRow
        + LoadRelated
        + MaybeMyLoadRelated
        + MaybeSqliteLoadRelated
        + Send
        + Unpin,
{
    /// Compile the queryset and delegate to `select_rows_pool_with_related`,
    /// so joined relations are hydrated on every backend.
    async fn fetch_pool(self, pool: &Pool) -> Result<Vec<T>, ExecError> {
        let compiled = self.compile()?;
        select_rows_pool_with_related(pool, &compiled).await
    }
}
#[cfg(test)]
mod pool_dispatch_tests {
    use super::*;

    // These tests only check which dialect a `Pool` reports. The pools are
    // built with `connect_lazy` against port 1, which is presumably meant to
    // be unreachable so that no real connection is made (TODO confirm).

    #[cfg(feature = "mysql")]
    #[tokio::test]
    async fn mysql_pool_dispatch_uses_mysql_dialect() {
        let pool: Pool = sqlx::mysql::MySqlPoolOptions::new()
            .max_connections(1)
            .connect_lazy("mysql://user:pass@localhost:1/none")
            .unwrap()
            .into();
        let dialect = pool.dialect();
        assert_eq!(dialect.name(), "mysql");
        assert_eq!(dialect.quote_ident("col"), "`col`");
        assert_eq!(dialect.placeholder(1), "?");
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_pool_dispatch_uses_postgres_dialect() {
        let pool: Pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost:1/none")
            .unwrap()
            .into();
        let dialect = pool.dialect();
        assert_eq!(dialect.name(), "postgres");
        assert_eq!(dialect.quote_ident("col"), "\"col\"");
        assert_eq!(dialect.placeholder(1), "$1");
    }

    /// Compile-time check: `()` must satisfy `MaybeMyFromRow`.
    #[test]
    fn maybe_my_from_row_resolves_for_unit_type() {
        fn assert_bound<T: super::MaybeMyFromRow>() {}
        assert_bound::<()>();
    }
}