use crate::core::{
AggregateQuery, BulkInsertQuery, BulkUpdateQuery, CountQuery, DeleteQuery, InsertQuery, Model,
SelectQuery, SqlValue, UpdateQuery,
};
use crate::query::{QuerySet, UpdateBuilder};
#[cfg(feature = "postgres")]
use sqlx::postgres::{PgArguments, PgPool, PgRow};
#[cfg(feature = "postgres")]
use sqlx::query::{Query, QueryAs};
use super::Dialect;
use super::ExecError;
#[cfg(feature = "postgres")]
use super::Postgres;
/// Reads foreign-key values off a model instance, looked up by FK field name.
///
/// `fetch_with_prefetch` calls `__rustango_fk_pk_value` on every fetched child
/// to decide which parent it belongs to; a `None` result makes that child be
/// skipped. Implementations are presumably generated by the model derive —
/// TODO confirm.
#[doc(hidden)]
pub trait FkPkAccess {
    /// Integer form of the key held in FK field `field_name`, if any.
    /// NOTE(review): whether `None` means "unknown field" or "NULL key" is not
    /// visible here — confirm against the generated impls.
    fn __rustango_fk_pk(&self, field_name: &str) -> Option<i64>;
    /// `SqlValue` form of the key held in FK field `field_name`, if any.
    fn __rustango_fk_pk_value(&self, field_name: &str) -> Option<crate::core::SqlValue>;
}
/// Populates `select_related` objects on `self` from one joined Postgres row.
///
/// `QuerySet::fetch_on` and `select_rows_tx_with_related` call this once per
/// leaf join alias. They pass the full leaf alias (e.g. `"a__b"`) as
/// `field_name` and its first `__`-separated hop (e.g. `"a"`) as `alias`, as
/// produced by `select_related_leaves`. Those callers discard the returned
/// `bool`; presumably it reports whether anything was loaded — TODO confirm.
/// Decode failures surface as `sqlx::Error`.
#[doc(hidden)]
#[cfg(feature = "postgres")]
pub trait LoadRelated {
    fn __rustango_load_related(
        &mut self,
        row: &PgRow,
        field_name: &str,
        alias: &str,
    ) -> Result<bool, sqlx::Error>;
}
/// Without the `postgres` feature `LoadRelated` is an empty marker trait.
///
/// The blanket impl below makes every type implement it, so generic bounds
/// such as `T: LoadRelated` elsewhere in this module stay satisfiable.
#[cfg(not(feature = "postgres"))]
#[doc(hidden)]
pub trait LoadRelated {}
// Blanket impl: every type trivially satisfies the marker bound.
#[cfg(not(feature = "postgres"))]
impl<T> LoadRelated for T {}
/// Reduces a list of `select_related` join aliases to the leaf paths.
///
/// An alias is a leaf when no other alias extends it by a `__` hop: `"author"`
/// is dropped when `"author__publisher"` is also present. Each surviving leaf is
/// paired with its first hop, which is the text before the first `__`, or the
/// whole alias when it has none. Input order is preserved, and duplicates are
/// kept as-is.
pub(crate) fn select_related_leaves(aliases: &[&'static str]) -> Vec<(&'static str, &'static str)> {
    // `longer` extends `shorter` when it is `shorter` followed by `__...`.
    let extends = |longer: &str, shorter: &str| {
        matches!(longer.strip_prefix(shorter), Some(rest) if rest.starts_with("__"))
    };

    let mut leaves = Vec::new();
    for &alias in aliases {
        if aliases.iter().any(|&other| extends(other, alias)) {
            continue;
        }
        let first_hop = match alias.find("__") {
            Some(cut) => &alias[..cut],
            None => alias,
        };
        leaves.push((alias, first_hop));
    }
    leaves
}
#[cfg(feature = "postgres")]
impl<T> QuerySet<T>
where
    T: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
    /// Runs the compiled SELECT on `executor` and decodes every row into `T`.
    ///
    /// Without `select_related` joins, rows are decoded directly through
    /// `query_as`. With joins, each raw row is decoded into `T` first. Then every
    /// leaf join alias (see `select_related_leaves`) is handed to
    /// `LoadRelated::__rustango_load_related` so the joined columns populate
    /// the related objects.
    pub async fn fetch_on<'c, E>(self, executor: E) -> Result<Vec<T>, ExecError>
    where
        E: sqlx::Executor<'c, Database = sqlx::Postgres>,
        T: LoadRelated,
    {
        let select = self.compile()?;
        let join_aliases: Vec<&'static str> = select.joins.iter().map(|join| join.alias).collect();
        let stmt = Postgres.compile_select(&select)?;

        if join_aliases.is_empty() {
            let seed: QueryAs<'_, sqlx::Postgres, T, PgArguments> =
                sqlx::query_as::<_, T>(&stmt.sql);
            let typed = stmt
                .params
                .into_iter()
                .fold(seed, |q, value| bind_query_as(q, value));
            return Ok(typed.fetch_all(executor).await?);
        }

        let seed: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
        let untyped = stmt
            .params
            .into_iter()
            .fold(seed, |q, value| bind_query(q, value));
        let raw_rows = untyped.fetch_all(executor).await?;

        let leaves = select_related_leaves(&join_aliases);
        let mut decoded = Vec::with_capacity(raw_rows.len());
        for row in raw_rows.iter() {
            let mut item = T::from_row(row)?;
            for &(leaf, first_hop) in leaves.iter() {
                // The "anything loaded?" flag is not needed here.
                item.__rustango_load_related(row, leaf, first_hop)?;
            }
            decoded.push(item);
        }
        Ok(decoded)
    }

    /// Fetches one page of results together with the total match count.
    ///
    /// The total is read from the `__rustango_total` column that
    /// `inject_total_count` adds to the SELECT, taken from the first returned
    /// row. When the page is empty the total is reported as `0`, even if rows
    /// exist before the requested offset. Related objects are not loaded.
    pub async fn fetch_paginated_on<'c, E>(self, executor: E) -> Result<Page<T>, ExecError>
    where
        E: sqlx::Executor<'c, Database = sqlx::Postgres>,
    {
        let select = self.compile()?;
        let stmt = Postgres.compile_select(&select)?;
        let counted_sql = inject_total_count(&stmt.sql);
        let seed: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&counted_sql);
        let query = stmt
            .params
            .into_iter()
            .fold(seed, |q, value| bind_query(q, value));
        let raw_rows: Vec<PgRow> = query.fetch_all(executor).await?;

        let total = match raw_rows.first() {
            Some(first) => sqlx::Row::try_get::<i64, _>(first, "__rustango_total")?,
            None => 0,
        };
        let rows = raw_rows
            .iter()
            .map(|row| T::from_row(row))
            .collect::<Result<Vec<T>, sqlx::Error>>()?;
        Ok(Page { rows, total })
    }

    /// Runs `fetch_paginated_on` against a connection pool.
    pub async fn fetch_paginated(self, pool: &PgPool) -> Result<Page<T>, ExecError> {
        let page = self.fetch_paginated_on(pool).await?;
        Ok(page)
    }

    /// Fetches rows whose `C` column is in `ids`, keyed by `extract(&row)`.
    ///
    /// An empty `ids` returns an empty map without touching the database. When
    /// `extract` yields the same key for several rows, the last row fetched
    /// wins.
    pub async fn in_bulk_on<'c, E, C, K, I, F>(
        self,
        column: C,
        ids: I,
        extract: F,
        executor: E,
    ) -> Result<std::collections::HashMap<K, T>, ExecError>
    where
        E: sqlx::Executor<'c, Database = sqlx::Postgres>,
        T: LoadRelated,
        C: crate::core::Column<Model = T>,
        K: Eq + std::hash::Hash + Into<crate::core::SqlValue>,
        I: IntoIterator<Item = K>,
        F: Fn(&T) -> K,
    {
        // `column` only pins the column type `C`; its SQL name is `C::COLUMN`.
        let _ = column;
        let id_values: Vec<SqlValue> = ids.into_iter().map(Into::into).collect();
        if id_values.is_empty() {
            return Ok(std::collections::HashMap::new());
        }
        let fetched = self
            .filter_op(C::COLUMN, crate::core::Op::In, SqlValue::List(id_values))
            .fetch_on(executor)
            .await?;
        Ok(fetched.into_iter().map(|row| (extract(&row), row)).collect())
    }
}
mod page;
use page::inject_total_count;
pub use page::Page;
#[cfg(feature = "postgres")]
mod pg_on;
#[cfg(feature = "postgres")]
pub use pg_on::{
bulk_insert_on, delete_on, insert_on, insert_returning_on, select_one_row_on, select_rows_on,
update_on,
};
mod row_to_json;
#[cfg(feature = "postgres")]
pub use row_to_json::row_to_json;
#[cfg(feature = "mysql")]
pub use row_to_json::row_to_json_my;
#[cfg(feature = "sqlite")]
pub use row_to_json::row_to_json_sqlite;
pub use row_to_json::{select_one_row_as_json, select_rows_as_json};
/// Pool convenience wrapper around `annotate_count_children_on`.
///
/// Pairs every parent matched by `parent_qs` with the number of rows in
/// `child_table` whose `child_fk_column` references it.
#[cfg(feature = "postgres")]
pub async fn annotate_count_children<P>(
    parent_qs: crate::query::QuerySet<P>,
    child_table: &'static str,
    child_fk_column: &'static str,
    pool: &PgPool,
) -> Result<Vec<(P, i64)>, ExecError>
where
    P: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
    let annotated =
        annotate_count_children_on(parent_qs, child_table, child_fk_column, pool).await?;
    Ok(annotated)
}
/// Pairs every parent matched by `parent_qs` with its number of children.
///
/// A child is a row of `child_table` whose `child_fk_column` equals the parent's
/// primary key. The count comes from a correlated scalar subquery, so the
/// queryset's compiled WHERE / ORDER BY / LIMIT / OFFSET tail is appended
/// directly after `FROM <parent>`. That placement stays valid SQL, unlike a
/// `GROUP BY` that would have to precede the WHERE clause. The child table is
/// aliased inside the subquery, so a self-referential relation (child table ==
/// parent table) still correlates against the outer parent row.
///
/// # Errors
/// - `ExecError::MissingPrimaryKey` when the parent model has no primary key.
/// - Any error from compiling the queryset or from the driver.
#[cfg(feature = "postgres")]
pub async fn annotate_count_children_on<'c, P, E>(
    parent_qs: crate::query::QuerySet<P>,
    child_table: &'static str,
    child_fk_column: &'static str,
    executor: E,
) -> Result<Vec<(P, i64)>, ExecError>
where
    P: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
    E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
    use std::fmt::Write as _;

    let select = parent_qs.compile()?;
    let parent = select.model;
    let pk_field = parent.primary_key().ok_or(ExecError::MissingPrimaryKey {
        table: parent.table,
    })?;

    // Parent scalar columns, table-qualified, in the order `P::from_row` reads them.
    let mut sql = String::from("SELECT ");
    for (i, field) in parent.scalar_fields().enumerate() {
        if i > 0 {
            sql.push_str(", ");
        }
        let _ = write!(sql, "\"{}\".\"{}\"", parent.table, field.column);
    }

    // COUNT(*) over the aliased child table. No child PK column is assumed, and
    // the alias keeps `"<parent_table>"` resolving to the outer row.
    let _ = write!(
        sql,
        ", (SELECT COUNT(*) FROM \"{child_table}\" AS \"__rustango_child\" \
         WHERE \"__rustango_child\".\"{child_fk_column}\" = \"{parent_table}\".\"{parent_pk}\") \
         AS \"__annotated_count\" FROM \"{parent_table}\"",
        parent_table = parent.table,
        parent_pk = pk_field.column,
    );

    let tail = crate::sql::postgres::compile_where_order_tail(
        &select.where_clause,
        select.search.as_ref(),
        &select.order_by,
        select.limit,
        select.offset,
        Some(parent.table),
        Some(parent),
    )?;
    sql.push_str(&tail.sql);

    let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&sql);
    for param in tail.params {
        q = bind_query(q, param);
    }
    let raw_rows = q.fetch_all(executor).await?;

    let mut out = Vec::with_capacity(raw_rows.len());
    for row in &raw_rows {
        let parent_obj = P::from_row(row)?;
        let count: i64 = sqlx::Row::try_get(row, "__annotated_count")?;
        out.push((parent_obj, count));
    }
    Ok(out)
}
/// Fetches the parents of `parent_qs`, each paired with its `C` children.
///
/// Children are loaded with `child_fk_column IN (<parent pks>)` queries and
/// grouped by the display string of their FK value (read through
/// `FkPkAccess::__rustango_fk_pk_value`). Parents without children get an empty
/// `Vec`. Every `SqlValue::List` element is bound as its own parameter, and
/// PostgreSQL caps a statement at 65 535 parameters, so the parent keys are sent
/// in batches of at most `PREFETCH_BATCH`.
///
/// # Errors
/// - `ExecError::MissingPrimaryKey` when at least one parent was fetched and
///   `P` has no primary key.
/// - Any driver or compile error from the parent or child queries.
#[cfg(feature = "postgres")]
pub async fn fetch_with_prefetch<P, C>(
    parent_qs: crate::query::QuerySet<P>,
    child_fk_column: &'static str,
    pool: &PgPool,
) -> Result<Vec<(P, Vec<C>)>, ExecError>
where
    P: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin + LoadRelated + HasPkValue,
    C: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin + LoadRelated + FkPkAccess,
{
    /// Parent keys per child query; kept well below PostgreSQL's 65 535
    /// bind-parameter limit.
    const PREFETCH_BATCH: usize = 10_000;

    let parents: Vec<P> = parent_qs.fetch_on(pool).await?;
    if parents.is_empty() {
        return Ok(Vec::new());
    }
    // Children are matched against the parent PK, so a PK-less model is an error.
    if P::SCHEMA.primary_key().is_none() {
        return Err(ExecError::MissingPrimaryKey {
            table: P::SCHEMA.table,
        });
    }

    // Distinct, non-NULL parent keys. They are deduplicated by display string,
    // the same key the children are grouped by below.
    let mut seen = std::collections::HashSet::new();
    let mut parent_pks: Vec<crate::core::SqlValue> = Vec::with_capacity(parents.len());
    for parent in &parents {
        let pk = extract_pk_value(parent);
        if !matches!(pk, crate::core::SqlValue::Null) && seen.insert(pk.to_display_string()) {
            parent_pks.push(pk);
        }
    }
    if parent_pks.is_empty() {
        return Ok(parents.into_iter().map(|p| (p, Vec::new())).collect());
    }

    let mut grouped: std::collections::HashMap<String, Vec<C>> = std::collections::HashMap::new();
    let mut pending = parent_pks;
    while !pending.is_empty() {
        // Split off everything after the first PREFETCH_BATCH keys and query those first.
        let rest = pending.split_off(pending.len().min(PREFETCH_BATCH));
        let batch = std::mem::replace(&mut pending, rest);
        let children: Vec<C> = crate::query::QuerySet::<C>::new()
            .filter_op(
                child_fk_column,
                crate::core::Op::In,
                crate::core::SqlValue::List(batch),
            )
            .fetch_on(pool)
            .await?;
        for child in children {
            let Some(fk_pk) = child.__rustango_fk_pk_value(child_fk_column) else {
                continue;
            };
            grouped
                .entry(fk_pk.to_display_string())
                .or_default()
                .push(child);
        }
    }

    let mut out = Vec::with_capacity(parents.len());
    for parent in parents {
        let pk = extract_pk_value(&parent).to_display_string();
        let kids = grouped.remove(&pk).unwrap_or_default();
        out.push((parent, kids));
    }
    Ok(out)
}
/// Returns `parent`'s primary-key value through its `HasPkValue` impl.
pub(super) fn extract_pk_value<P: HasPkValue>(parent: &P) -> crate::core::SqlValue {
    <P as HasPkValue>::__rustango_pk_value_impl(parent)
}
/// Exposes a model instance's primary-key value as a `SqlValue`.
///
/// `extract_pk_value` uses it to collect parent keys in `fetch_with_prefetch`.
/// That function skips `SqlValue::Null` results, so implementations may return
/// `Null` for an unset key.
#[doc(hidden)]
pub trait HasPkValue {
    fn __rustango_pk_value_impl(&self) -> crate::core::SqlValue;
}
#[cfg(feature = "postgres")]
impl<T: Model + Send> QuerySet<T> {
    /// Counts the rows matched by this queryset on `executor`.
    ///
    /// Only the model, filter and search are carried into the `CountQuery`;
    /// ordering, limit and offset are not.
    pub async fn count_on<'c, E>(self, executor: E) -> Result<i64, ExecError>
    where
        E: sqlx::Executor<'c, Database = sqlx::Postgres>,
    {
        let compiled = self.compile()?;
        let count_query = CountQuery {
            model: compiled.model,
            where_clause: compiled.where_clause,
            search: compiled.search,
        };
        let stmt = Postgres.compile_count(&count_query)?;
        let seed: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
        let query = stmt
            .params
            .into_iter()
            .fold(seed, |q, value| bind_query(q, value));
        let row = query.fetch_one(executor).await?;
        Ok(sqlx::Row::try_get::<i64, _>(&row, 0)?)
    }

    /// Returns the `EXPLAIN` output for this queryset using default options.
    pub async fn explain(self, pool: &PgPool) -> Result<Vec<String>, ExecError> {
        let options = ExplainOptions::default();
        self.explain_on(pool, options).await
    }

    /// Prefixes the compiled SELECT with `EXPLAIN` (plus `options`) and runs it.
    ///
    /// Returns one string per output row. JSON plans are re-serialised with
    /// `serde_json::Value::to_string`; text, YAML and XML plans are read as text.
    pub async fn explain_on<'c, E>(
        self,
        executor: E,
        options: ExplainOptions,
    ) -> Result<Vec<String>, ExecError>
    where
        E: sqlx::Executor<'c, Database = sqlx::Postgres>,
    {
        let select = self.compile()?;
        let stmt = Postgres.compile_select(&select)?;
        let clause = options.to_clause();

        let mut sql = String::with_capacity(stmt.sql.len() + 32);
        sql.push_str("EXPLAIN ");
        if !clause.is_empty() {
            sql.push_str(&clause);
            sql.push(' ');
        }
        sql.push_str(&stmt.sql);

        let seed: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&sql);
        let query = stmt
            .params
            .into_iter()
            .fold(seed, |q, value| bind_query(q, value));
        let rows = query.fetch_all(executor).await?;

        let mut plan = Vec::with_capacity(rows.len());
        for row in rows.iter() {
            plan.push(match options.format {
                ExplainFormat::Json => {
                    sqlx::Row::try_get::<serde_json::Value, _>(row, 0)?.to_string()
                }
                ExplainFormat::Text | ExplainFormat::Yaml | ExplainFormat::Xml => {
                    sqlx::Row::try_get::<String, _>(row, 0)?
                }
            });
        }
        Ok(plan)
    }
}
mod explain;
pub use explain::{explain_pool, ExplainFormat, ExplainOptions};
#[cfg(feature = "postgres")]
impl<T: Model + Send> QuerySet<T> {
    /// Deletes every row matched by this queryset using `executor`.
    ///
    /// Compiles the queryset into a `DeleteQuery` and returns the affected-row
    /// count reported by the free `delete_on` function.
    pub async fn delete_on<'c, E>(self, executor: E) -> Result<u64, ExecError>
    where
        E: sqlx::Executor<'c, Database = sqlx::Postgres>,
    {
        let query = self.compile_delete()?;
        // The bare name resolves to the free function re-exported from `pg_on`,
        // not to this method: methods are never in scope by their bare name.
        delete_on(executor, &query).await
    }
}
#[cfg(feature = "postgres")]
impl<T: Model + Send> UpdateBuilder<T> {
    /// Compiles the pending UPDATE and runs it on `executor`.
    ///
    /// Returns the number of rows affected.
    pub async fn execute_on<'c, E>(self, executor: E) -> Result<u64, ExecError>
    where
        E: sqlx::Executor<'c, Database = sqlx::Postgres>,
    {
        let compiled = self.compile()?;
        let affected = update_on(executor, &compiled).await?;
        Ok(affected)
    }
}
/// Runs an update through the backend-agnostic `Pool` rather than a
/// Postgres-specific executor.
pub trait UpdaterPool<T: Model + Send> {
    /// Compiles the pending update and executes it on `pool`.
    ///
    /// The future resolves to the number of rows affected. It is `Send`, so it
    /// can be spawned on multi-threaded runtimes.
    fn execute_pool(
        self,
        pool: &Pool,
    ) -> impl std::future::Future<Output = Result<u64, ExecError>> + Send;
}
impl<T: Model + Send> UpdaterPool<T> for UpdateBuilder<T> {
    /// Compiles the builder and dispatches it through `update_pool`.
    async fn execute_pool(self, pool: &Pool) -> Result<u64, ExecError> {
        let compiled = self.compile()?;
        let affected = update_pool(pool, &compiled).await?;
        Ok(affected)
    }
}
/// Binds one `SqlValue` onto a Postgres `Query`/`QueryAs` and evaluates to the
/// rebound query.
///
/// Notes on the less obvious arms:
/// - `Null` is bound as `None::<String>`, i.e. a text-typed NULL.
/// - `List` panics: the SQL writer is expected to have expanded lists into
///   scalar parameters already.
/// - `Array` takes its element type from the *first* element:
///   - elements of any other variant are silently dropped by the `filter_map`s;
///   - an empty array is bound as `Vec<i32>`;
///   - a first element other than I32/I64/String/Bool panics.
#[cfg(feature = "postgres")]
macro_rules! bind_match {
    ($q:expr, $value:expr) => {
        match $value {
            SqlValue::Null => $q.bind(None::<String>),
            SqlValue::I16(v) => $q.bind(v),
            SqlValue::I32(v) => $q.bind(v),
            SqlValue::I64(v) => $q.bind(v),
            SqlValue::F32(v) => $q.bind(v),
            SqlValue::F64(v) => $q.bind(v),
            SqlValue::Bool(v) => $q.bind(v),
            SqlValue::String(v) => $q.bind(v),
            SqlValue::DateTime(v) => $q.bind(v),
            SqlValue::Date(v) => $q.bind(v),
            SqlValue::Time(v) => $q.bind(v),
            SqlValue::Uuid(v) => $q.bind(v),
            SqlValue::Json(v) => $q.bind(sqlx::types::Json(v)),
            SqlValue::Decimal(v) => $q.bind(v),
            SqlValue::Binary(v) => $q.bind(v),
            SqlValue::List(_) => {
                unreachable!("`SqlValue::List` is expanded to scalars by the SQL writer")
            }
            SqlValue::RangeLiteral(s) => $q.bind(s),
            SqlValue::HStore(pairs) => {
                $q.bind(sqlx::postgres::types::PgHstore(pairs.into_iter().collect()))
            }
            SqlValue::Vector(v) => $q.bind(crate::sql::Vector(v)),
            SqlValue::Geometry { x, y, srid } => {
                $q.bind(crate::sql::Point { x, y, srid })
            }
            SqlValue::Array(elems) => match elems.first() {
                None => $q.bind(Vec::<i32>::new()),
                Some(SqlValue::I64(_)) => {
                    let v: Vec<i64> = elems
                        .into_iter()
                        .filter_map(|e| if let SqlValue::I64(n) = e { Some(n) } else { None })
                        .collect();
                    $q.bind(v)
                }
                Some(SqlValue::I32(_)) => {
                    let v: Vec<i32> = elems
                        .into_iter()
                        .filter_map(|e| if let SqlValue::I32(n) = e { Some(n) } else { None })
                        .collect();
                    $q.bind(v)
                }
                Some(SqlValue::String(_)) => {
                    let v: Vec<String> = elems
                        .into_iter()
                        .filter_map(|e| {
                            if let SqlValue::String(s) = e {
                                Some(s)
                            } else {
                                None
                            }
                        })
                        .collect();
                    $q.bind(v)
                }
                Some(SqlValue::Bool(_)) => {
                    let v: Vec<bool> = elems
                        .into_iter()
                        .filter_map(|e| if let SqlValue::Bool(b) = e { Some(b) } else { None })
                        .collect();
                    $q.bind(v)
                }
                Some(_) => unreachable!(
                    "SqlValue::Array elements other than I32/I64/String/Bool are not yet supported (v1, issue #30)"
                ),
            },
        }
    };
}
/// Binds one `SqlValue` onto a MySQL `Query`/`QueryAs` and evaluates to the
/// rebound query.
///
/// `Null` is bound as `None::<String>`. `List` panics because lists are expanded
/// to scalars beforehand. The PostgreSQL-only variants (`Array`,
/// `RangeLiteral`, `HStore`, `Vector`, `Geometry`) also panic; per their
/// messages they are rejected earlier, before binding.
#[cfg(feature = "mysql")]
macro_rules! bind_match_mysql {
    ($q:expr, $value:expr) => {
        match $value {
            SqlValue::Null => $q.bind(None::<String>),
            SqlValue::I16(v) => $q.bind(v),
            SqlValue::I32(v) => $q.bind(v),
            SqlValue::I64(v) => $q.bind(v),
            SqlValue::F32(v) => $q.bind(v),
            SqlValue::F64(v) => $q.bind(v),
            SqlValue::Bool(v) => $q.bind(v),
            SqlValue::String(v) => $q.bind(v),
            SqlValue::DateTime(v) => $q.bind(v),
            SqlValue::Date(v) => $q.bind(v),
            SqlValue::Time(v) => $q.bind(v),
            SqlValue::Uuid(v) => $q.bind(v),
            SqlValue::Json(v) => $q.bind(sqlx::types::Json(v)),
            SqlValue::Decimal(v) => $q.bind(v),
            SqlValue::Binary(v) => $q.bind(v),
            SqlValue::List(_) => {
                unreachable!("`SqlValue::List` is expanded to scalars by the SQL writer")
            }
            SqlValue::Array(_) => unreachable!(
                "MySQL has no array type; `write_array_op` rejects before bind. Issue #30."
            ),
            SqlValue::RangeLiteral(_) => unreachable!(
                "MySQL has no range type; `write_range_op` rejects before bind. Issue #31."
            ),
            SqlValue::HStore(_) => {
                unreachable!("MySQL has no hstore type; `HStore` columns are PG-only. Issue #342.")
            }
            SqlValue::Vector(_) => unreachable!(
                "MySQL has no vector type; pgvector distance operators are rejected before bind. Issue #824."
            ),
            SqlValue::Geometry { .. } => unreachable!(
                "MySQL has no geometry type; PostGIS `Point` columns are PG-only. Issue #443."
            ),
        }
    };
}
/// Binds one `SqlValue` onto a SQLite `Query`/`QueryAs` and evaluates to the
/// rebound query.
///
/// Same shape as the MySQL variant, except that `Decimal` is bound as its
/// `to_string()` text. `Null` is bound as `None::<String>`. `List` and the
/// PostgreSQL-only variants panic; per their messages they are rejected before
/// binding.
#[cfg(feature = "sqlite")]
macro_rules! bind_match_sqlite {
    ($q:expr, $value:expr) => {
        match $value {
            SqlValue::Null => $q.bind(None::<String>),
            SqlValue::I16(v) => $q.bind(v),
            SqlValue::I32(v) => $q.bind(v),
            SqlValue::I64(v) => $q.bind(v),
            SqlValue::F32(v) => $q.bind(v),
            SqlValue::F64(v) => $q.bind(v),
            SqlValue::Bool(v) => $q.bind(v),
            SqlValue::String(v) => $q.bind(v),
            SqlValue::DateTime(v) => $q.bind(v),
            SqlValue::Date(v) => $q.bind(v),
            SqlValue::Time(v) => $q.bind(v),
            SqlValue::Uuid(v) => $q.bind(v),
            SqlValue::Json(v) => $q.bind(sqlx::types::Json(v)),
            // Decimal values are sent as their textual representation.
            SqlValue::Decimal(v) => $q.bind(v.to_string()),
            SqlValue::Binary(v) => $q.bind(v),
            SqlValue::List(_) => {
                unreachable!("`SqlValue::List` is expanded to scalars by the SQL writer")
            }
            SqlValue::Array(_) => unreachable!(
                "SQLite has no array type; `write_array_op` rejects before bind. Issue #30."
            ),
            SqlValue::RangeLiteral(_) => unreachable!(
                "SQLite has no range type; `write_range_op` rejects before bind. Issue #31."
            ),
            SqlValue::HStore(_) => {
                unreachable!("SQLite has no hstore type; `HStore` columns are PG-only. Issue #342.")
            }
            SqlValue::Vector(_) => unreachable!(
                "SQLite has no vector type; pgvector distance operators are rejected before bind. Issue #824."
            ),
            SqlValue::Geometry { .. } => unreachable!(
                "SQLite has no geometry type; PostGIS `Point` columns are PG-only. Issue #443."
            ),
        }
    };
}
/// Binds `value` onto a typed Postgres `QueryAs` via `bind_match!`.
///
/// Panics for the `SqlValue` shapes that `bind_match!` treats as unreachable.
#[cfg(feature = "postgres")]
pub(super) fn bind_query_as<T>(
    q: QueryAs<'_, sqlx::Postgres, T, PgArguments>,
    value: SqlValue,
) -> QueryAs<'_, sqlx::Postgres, T, PgArguments> {
    bind_match!(q, value)
}
/// Binds `value` onto an untyped Postgres `Query` via `bind_match!`.
///
/// Panics for the `SqlValue` shapes that `bind_match!` treats as unreachable.
#[cfg(feature = "postgres")]
pub(super) fn bind_query(
    q: Query<'_, sqlx::Postgres, PgArguments>,
    value: SqlValue,
) -> Query<'_, sqlx::Postgres, PgArguments> {
    bind_match!(q, value)
}
/// Runs caller-supplied `sql` with positional `binds` and decodes each row as `T`.
///
/// The SQL text is passed through verbatim. Callers must only interpolate
/// values through `binds`, never into `sql` itself.
#[cfg(feature = "postgres")]
pub async fn raw_query_on<'c, T, E>(
    sql: &str,
    binds: Vec<SqlValue>,
    executor: E,
) -> Result<Vec<T>, ExecError>
where
    T: for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
    E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
    let seed: QueryAs<'_, sqlx::Postgres, T, PgArguments> = sqlx::query_as(sql);
    let bound = binds.into_iter().fold(seed, |q, b| bind_query_as(q, b));
    let rows = bound.fetch_all(executor).await?;
    Ok(rows)
}
/// Runs an aggregate query and returns each result row as a column-name map.
///
/// Every cell is decoded best-effort into a `SqlValue` by
/// `decode_aggregate_cell`. SQL NULLs map to `SqlValue::Null`, and so do values
/// of a type none of the decoders accept.
#[cfg(feature = "postgres")]
pub async fn fetch_aggregate_on<'c, E>(
    query: &AggregateQuery,
    executor: E,
) -> Result<Vec<std::collections::HashMap<String, SqlValue>>, ExecError>
where
    E: sqlx::Executor<'c, Database = sqlx::Postgres>,
{
    let stmt = Postgres.compile_aggregate(query)?;
    let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&stmt.sql);
    for p in stmt.params {
        q = bind_query(q, p);
    }
    let raw_rows = q.fetch_all(executor).await?;
    let mut out = Vec::with_capacity(raw_rows.len());
    for row in &raw_rows {
        use sqlx::{Column as _, Row as _};
        let columns = row.columns();
        let mut map = std::collections::HashMap::with_capacity(columns.len());
        for (i, col) in columns.iter().enumerate() {
            map.insert(col.name().to_owned(), decode_aggregate_cell(row, i));
        }
        out.push(map);
    }
    Ok(out)
}

/// Decodes cell `i` of an aggregate result row into the closest `SqlValue`.
///
/// NULL is detected up front, which avoids running every decoder only to fail.
/// Besides the integer, float, bool, text and JSON types, the decoders also
/// cover:
/// - int2 and float4;
/// - NUMERIC, which PostgreSQL returns for `AVG` and for `SUM` over bigint;
/// - date/time values, e.g. `MIN`/`MAX` of timestamps;
/// - uuid group keys.
///
/// Without these, such cells would silently come back as `Null`. Text and
/// bigint arrays (e.g. from `array_agg`) become JSON arrays.
#[cfg(feature = "postgres")]
fn decode_aggregate_cell(row: &PgRow, i: usize) -> SqlValue {
    use sqlx::{Row as _, ValueRef as _};
    if matches!(row.try_get_raw(i), Ok(raw) if raw.is_null()) {
        return SqlValue::Null;
    }
    if let Ok(v) = row.try_get::<i64, _>(i) {
        SqlValue::I64(v)
    } else if let Ok(v) = row.try_get::<i32, _>(i) {
        SqlValue::I32(v)
    } else if let Ok(v) = row.try_get::<i16, _>(i) {
        SqlValue::I16(v)
    } else if let Ok(v) = row.try_get::<f64, _>(i) {
        SqlValue::F64(v)
    } else if let Ok(v) = row.try_get::<f32, _>(i) {
        SqlValue::F32(v)
    } else if let Ok(v) = row.try_get(i) {
        // Payload type inferred from the variant; NUMERIC (AVG / SUM results).
        SqlValue::Decimal(v)
    } else if let Ok(v) = row.try_get::<bool, _>(i) {
        SqlValue::Bool(v)
    } else if let Ok(v) = row.try_get::<String, _>(i) {
        SqlValue::String(v)
    } else if let Ok(v) = row.try_get(i) {
        SqlValue::DateTime(v)
    } else if let Ok(v) = row.try_get(i) {
        SqlValue::Date(v)
    } else if let Ok(v) = row.try_get(i) {
        SqlValue::Time(v)
    } else if let Ok(v) = row.try_get(i) {
        SqlValue::Uuid(v)
    } else if let Ok(v) = row.try_get::<serde_json::Value, _>(i) {
        SqlValue::Json(v)
    } else if let Ok(v) = row.try_get::<Vec<String>, _>(i) {
        SqlValue::Json(serde_json::Value::Array(
            v.into_iter().map(serde_json::Value::String).collect(),
        ))
    } else if let Ok(v) = row.try_get::<Vec<i64>, _>(i) {
        SqlValue::Json(serde_json::Value::Array(
            v.into_iter()
                .map(|n| serde_json::Value::Number(n.into()))
                .collect(),
        ))
    } else {
        SqlValue::Null
    }
}
mod tx;
pub use tx::{transaction_pool, PoolTx};
mod atomic;
pub use atomic::{atomic, on_commit, on_commit_pending};
use super::Pool;
/// Binds `value` onto an untyped MySQL `Query` via `bind_match_mysql!`.
///
/// Panics for the `SqlValue` shapes that macro treats as unreachable.
#[cfg(feature = "mysql")]
pub(super) fn bind_query_my(
    q: sqlx::query::Query<'_, sqlx::MySql, sqlx::mysql::MySqlArguments>,
    value: SqlValue,
) -> sqlx::query::Query<'_, sqlx::MySql, sqlx::mysql::MySqlArguments> {
    bind_match_mysql!(q, value)
}
/// Binds `value` onto an untyped SQLite `Query` via `bind_match_sqlite!`.
///
/// Panics for the `SqlValue` shapes that macro treats as unreachable.
#[cfg(feature = "sqlite")]
pub(super) fn bind_query_sqlite<'a>(
    q: sqlx::query::Query<'a, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'a>>,
    value: SqlValue,
) -> sqlx::query::Query<'a, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'a>> {
    bind_match_sqlite!(q, value)
}
/// Validates `query`, compiles it for the pool's dialect and executes it.
///
/// The affected-row count is discarded.
pub async fn insert_pool(pool: &Pool, query: &InsertQuery) -> Result<(), ExecError> {
    query.validate()?;
    let compiled = pool.dialect().compile_insert(query)?;
    execute_pool(pool, &compiled.sql, compiled.params)
        .await
        .map(|_rows| ())
}
/// Inserts a row and returns what the backend can report back about it.
///
/// - PostgreSQL and SQLite run the INSERT with its `RETURNING` list and hand
///   back the returned row.
/// - MySQL has no `RETURNING`. It runs the plain INSERT and then
///   `SELECT LAST_INSERT_ID()` on the *same* acquired connection. The id is
///   saturated to `i64::MAX` if it does not fit in an `i64`.
///
/// # Errors
/// - `ExecError::EmptyReturning` when `query.returning` is empty.
/// - Validation, compile and driver errors.
pub async fn insert_returning_pool(
    pool: &Pool,
    query: &InsertQuery,
) -> Result<InsertReturningPool, ExecError> {
    crate::test_assertions::query_counter::bump();
    query.validate()?;
    if query.returning.is_empty() {
        return Err(ExecError::EmptyReturning);
    }
    match pool {
        #[cfg(feature = "postgres")]
        Pool::Postgres(pg) => Ok(InsertReturningPool::PgRow(
            insert_returning_on(pg, query).await?,
        )),
        #[cfg(feature = "mysql")]
        Pool::Mysql(my) => {
            let without_returning = InsertQuery {
                model: query.model,
                columns: query.columns.clone(),
                values: query.values.clone(),
                returning: ::std::vec::Vec::new(),
                on_conflict: query.on_conflict.clone(),
            };
            let compiled = pool.dialect().compile_insert(&without_returning)?;
            // LAST_INSERT_ID() is per-connection, so both statements share one.
            let mut conn = my.acquire().await?;
            let seed: sqlx::query::Query<'_, sqlx::MySql, sqlx::mysql::MySqlArguments> =
                sqlx::query(&compiled.sql);
            let insert = compiled
                .params
                .into_iter()
                .fold(seed, |q, v| bind_query_my(q, v));
            insert.execute(&mut *conn).await?;
            let id_row = sqlx::query("SELECT LAST_INSERT_ID()")
                .fetch_one(&mut *conn)
                .await?;
            let raw_id = sqlx::Row::try_get::<u64, _>(&id_row, 0)?;
            Ok(InsertReturningPool::MySqlAutoId(
                i64::try_from(raw_id).unwrap_or(i64::MAX),
            ))
        }
        #[cfg(feature = "sqlite")]
        Pool::Sqlite(sq) => {
            let compiled = pool.dialect().compile_insert(query)?;
            let seed: sqlx::query::Query<'_, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'_>> =
                sqlx::query(&compiled.sql);
            let returned = compiled
                .params
                .into_iter()
                .fold(seed, |q, v| bind_query_sqlite(q, v))
                .fetch_one(sq)
                .await?;
            Ok(InsertReturningPool::SqliteRow(returned))
        }
    }
}
/// Backend-specific result of `insert_returning_pool` / `insert_returning_tx`.
///
/// PostgreSQL and SQLite hand back the row produced by `RETURNING`. MySQL is
/// driven with a plain INSERT followed by `SELECT LAST_INSERT_ID()`, so it only
/// yields the generated auto-increment id.
pub enum InsertReturningPool {
    /// Row returned by PostgreSQL's `RETURNING` clause.
    #[cfg(feature = "postgres")]
    PgRow(PgRow),
    /// `LAST_INSERT_ID()` of the MySQL connection, saturated to `i64::MAX`.
    #[cfg(feature = "mysql")]
    MySqlAutoId(i64),
    /// Row returned by SQLite's `RETURNING` clause.
    #[cfg(feature = "sqlite")]
    SqliteRow(sqlx::sqlite::SqliteRow),
}
impl ::core::fmt::Debug for InsertReturningPool {
    /// Prints the variant name. Driver rows are shown as an opaque placeholder;
    /// the MySQL auto id is printed as-is.
    fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
        match *self {
            #[cfg(feature = "postgres")]
            Self::PgRow(..) => {
                let placeholder = "<PgRow>";
                f.debug_tuple("PgRow").field(&placeholder).finish()
            }
            #[cfg(feature = "mysql")]
            Self::MySqlAutoId(ref id) => f.debug_tuple("MySqlAutoId").field(id).finish(),
            #[cfg(feature = "sqlite")]
            Self::SqliteRow(..) => {
                let placeholder = "<SqliteRow>";
                f.debug_tuple("SqliteRow").field(&placeholder).finish()
            }
        }
    }
}
/// Compiles `query` for the pool's dialect, runs it and returns rows affected.
pub async fn update_pool(pool: &Pool, query: &UpdateQuery) -> Result<u64, ExecError> {
    let dialect = pool.dialect();
    let compiled = dialect.compile_update(query)?;
    execute_pool(pool, &compiled.sql, compiled.params).await
}
/// Compiles `query` for the pool's dialect, runs it and returns rows affected.
pub async fn delete_pool(pool: &Pool, query: &DeleteQuery) -> Result<u64, ExecError> {
    let dialect = pool.dialect();
    let compiled = dialect.compile_delete(query)?;
    execute_pool(pool, &compiled.sql, compiled.params).await
}
/// Compiles the COUNT `query` for the pool's dialect and returns its scalar result.
pub async fn count_rows_pool(pool: &Pool, query: &CountQuery) -> Result<i64, ExecError> {
    let dialect = pool.dialect();
    let compiled = dialect.compile_count(query)?;
    fetch_scalar_pool(pool, &compiled.sql, compiled.params).await
}
/// Inserts all of `query.rows` in one statement.
///
/// An empty row set is a no-op that never reaches the database.
pub async fn bulk_insert_pool(pool: &Pool, query: &BulkInsertQuery) -> Result<(), ExecError> {
    if !query.rows.is_empty() {
        let compiled = pool.dialect().compile_bulk_insert(query)?;
        execute_pool(pool, &compiled.sql, compiled.params).await?;
    }
    Ok(())
}
/// Updates all of `query.rows` in one statement and returns rows affected.
///
/// An empty row set reports `0` without reaching the database.
pub async fn bulk_update_pool(pool: &Pool, query: &BulkUpdateQuery) -> Result<u64, ExecError> {
    if query.rows.is_empty() {
        Ok(0)
    } else {
        let compiled = pool.dialect().compile_bulk_update(query)?;
        execute_pool(pool, &compiled.sql, compiled.params).await
    }
}
/// Executes caller-supplied `sql` with positional `binds` on `pool`.
///
/// Returns the number of rows affected.
pub async fn raw_execute_pool(
    pool: &Pool,
    sql: &str,
    binds: Vec<SqlValue>,
) -> Result<u64, ExecError> {
    let affected = execute_pool(pool, sql, binds).await?;
    Ok(affected)
}
pub async fn raw_execute_tx(
tx: &mut tx::PoolTx<'_>,
sql: &str,
binds: Vec<SqlValue>,
) -> Result<u64, ExecError> {
crate::test_assertions::query_counter::bump();
match tx {
#[cfg(feature = "postgres")]
tx::PoolTx::Postgres(t) => {
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(sql);
for v in binds {
q = bind_query(q, v);
}
Ok(q.execute(&mut **t).await?.rows_affected())
}
#[cfg(feature = "mysql")]
tx::PoolTx::Mysql(t) => {
let mut q: sqlx::query::Query<'_, sqlx::MySql, sqlx::mysql::MySqlArguments> =
sqlx::query(sql);
for v in binds {
q = bind_query_my(q, v);
}
Ok(q.execute(&mut **t).await?.rows_affected())
}
#[cfg(feature = "sqlite")]
tx::PoolTx::Sqlite(t) => {
let mut q: sqlx::query::Query<'_, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'_>> =
sqlx::query(sql);
for v in binds {
q = bind_query_sqlite(q, v);
}
Ok(q.execute(&mut **t).await?.rows_affected())
}
}
}
pub async fn run_ddl_idempotent(pool: &Pool, ddl: &str) -> Result<(), sqlx::Error> {
for stmt in ddl.split(';').map(str::trim).filter(|s| !s.is_empty()) {
match execute_pool(pool, stmt, Vec::new()).await {
Ok(_) => {}
Err(crate::sql::ExecError::Driver(err)) => {
if !crate::sql::is_mysql_dup_index_error(&err) {
return Err(err);
}
}
Err(other) => return Err(sqlx::Error::Protocol(format!("{other}"))),
}
}
Ok(())
}
async fn execute_pool(pool: &Pool, sql: &str, binds: Vec<SqlValue>) -> Result<u64, ExecError> {
crate::test_assertions::query_counter::bump();
match pool {
#[cfg(feature = "postgres")]
Pool::Postgres(pg) => {
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(sql);
for v in binds {
q = bind_query(q, v);
}
Ok(q.execute(pg).await?.rows_affected())
}
#[cfg(feature = "mysql")]
Pool::Mysql(my) => {
let mut q: sqlx::query::Query<'_, sqlx::MySql, sqlx::mysql::MySqlArguments> =
sqlx::query(sql);
for v in binds {
q = bind_query_my(q, v);
}
Ok(q.execute(my).await?.rows_affected())
}
#[cfg(feature = "sqlite")]
Pool::Sqlite(sq) => {
let mut q: sqlx::query::Query<'_, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'_>> =
sqlx::query(sql);
for v in binds {
q = bind_query_sqlite(q, v);
}
Ok(q.execute(sq).await?.rows_affected())
}
}
}
async fn execute_tx(
tx: &mut PoolTx<'_>,
sql: &str,
binds: Vec<SqlValue>,
) -> Result<u64, ExecError> {
match tx {
#[cfg(feature = "postgres")]
PoolTx::Postgres(t) => {
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(sql);
for v in binds {
q = bind_query(q, v);
}
Ok(q.execute(&mut **t).await?.rows_affected())
}
#[cfg(feature = "mysql")]
PoolTx::Mysql(t) => {
let mut q: sqlx::query::Query<'_, sqlx::MySql, sqlx::mysql::MySqlArguments> =
sqlx::query(sql);
for v in binds {
q = bind_query_my(q, v);
}
Ok(q.execute(&mut **t).await?.rows_affected())
}
#[cfg(feature = "sqlite")]
PoolTx::Sqlite(t) => {
let mut q: sqlx::query::Query<'_, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'_>> =
sqlx::query(sql);
for v in binds {
q = bind_query_sqlite(q, v);
}
Ok(q.execute(&mut **t).await?.rows_affected())
}
}
}
/// Validates `query`, compiles it for the transaction's dialect and runs it in `tx`.
///
/// The affected-row count is discarded.
pub async fn insert_tx(tx: &mut PoolTx<'_>, query: &InsertQuery) -> Result<(), ExecError> {
    query.validate()?;
    let compiled = tx.dialect().compile_insert(query)?;
    execute_tx(tx, &compiled.sql, compiled.params)
        .await
        .map(|_rows| ())
}
/// Transactional counterpart of `insert_returning_pool`.
///
/// PostgreSQL and SQLite return the `RETURNING` row. MySQL runs the plain
/// INSERT and then reads `SELECT LAST_INSERT_ID()` on the transaction's
/// connection, saturating to `i64::MAX` if the id does not fit.
///
/// # Errors
/// - `ExecError::EmptyReturning` when `query.returning` is empty.
/// - Validation, compile and driver errors.
pub async fn insert_returning_tx(
    tx: &mut PoolTx<'_>,
    query: &InsertQuery,
) -> Result<InsertReturningPool, ExecError> {
    query.validate()?;
    if query.returning.is_empty() {
        return Err(ExecError::EmptyReturning);
    }
    match tx {
        #[cfg(feature = "postgres")]
        PoolTx::Postgres(t) => Ok(InsertReturningPool::PgRow(
            insert_returning_on(&mut **t, query).await?,
        )),
        #[cfg(feature = "mysql")]
        PoolTx::Mysql(t) => {
            let without_returning = InsertQuery {
                model: query.model,
                columns: query.columns.clone(),
                values: query.values.clone(),
                returning: ::std::vec::Vec::new(),
                on_conflict: query.on_conflict.clone(),
            };
            let compiled = super::mysql::DIALECT.compile_insert(&without_returning)?;
            let seed: sqlx::query::Query<'_, sqlx::MySql, sqlx::mysql::MySqlArguments> =
                sqlx::query(&compiled.sql);
            compiled
                .params
                .into_iter()
                .fold(seed, |q, v| bind_query_my(q, v))
                .execute(&mut **t)
                .await?;
            let id_row = sqlx::query("SELECT LAST_INSERT_ID()")
                .fetch_one(&mut **t)
                .await?;
            let raw_id = sqlx::Row::try_get::<u64, _>(&id_row, 0)?;
            Ok(InsertReturningPool::MySqlAutoId(
                i64::try_from(raw_id).unwrap_or(i64::MAX),
            ))
        }
        #[cfg(feature = "sqlite")]
        PoolTx::Sqlite(t) => {
            let compiled = super::sqlite::DIALECT.compile_insert(query)?;
            let seed: sqlx::query::Query<'_, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'_>> =
                sqlx::query(&compiled.sql);
            let returned = compiled
                .params
                .into_iter()
                .fold(seed, |q, v| bind_query_sqlite(q, v))
                .fetch_one(&mut **t)
                .await?;
            Ok(InsertReturningPool::SqliteRow(returned))
        }
    }
}
/// Compiles `query` for the transaction's dialect and runs it in `tx`.
///
/// Returns the number of rows affected.
pub async fn update_tx(tx: &mut PoolTx<'_>, query: &UpdateQuery) -> Result<u64, ExecError> {
    let compiled = tx.dialect().compile_update(query)?;
    let affected = execute_tx(tx, &compiled.sql, compiled.params).await?;
    Ok(affected)
}
/// Compiles `query` for the transaction's dialect and runs it in `tx`.
///
/// Returns the number of rows affected.
pub async fn delete_tx(tx: &mut PoolTx<'_>, query: &DeleteQuery) -> Result<u64, ExecError> {
    let compiled = tx.dialect().compile_delete(query)?;
    let affected = execute_tx(tx, &compiled.sql, compiled.params).await?;
    Ok(affected)
}
/// Runs a SELECT inside `tx` and decodes rows as `T`, loading `select_related` joins.
///
/// Without joins, rows are decoded directly through `query_as`. With joins, each
/// raw row is decoded into `T`. Then every leaf join alias (from
/// `select_related_leaves`) is passed to that backend's load-related hook,
/// together with its first hop.
pub async fn select_rows_tx_with_related<T>(
    tx: &mut PoolTx<'_>,
    query: &SelectQuery,
) -> Result<Vec<T>, ExecError>
where
    T: MaybePgFromRow
        + MaybeMyFromRow
        + MaybeSqliteFromRow
        + LoadRelated
        + MaybeMyLoadRelated
        + MaybeSqliteLoadRelated
        + Send
        + Unpin,
{
    let compiled = tx.dialect().compile_select(query)?;
    let join_aliases: Vec<&'static str> = query.joins.iter().map(|join| join.alias).collect();
    let leaves = select_related_leaves(&join_aliases);
    match tx {
        #[cfg(feature = "postgres")]
        PoolTx::Postgres(t) => {
            if join_aliases.is_empty() {
                let seed: QueryAs<'_, sqlx::Postgres, T, PgArguments> =
                    sqlx::query_as::<_, T>(&compiled.sql);
                let typed = compiled
                    .params
                    .into_iter()
                    .fold(seed, |q, v| bind_query_as(q, v));
                return Ok(typed.fetch_all(&mut **t).await?);
            }
            let seed: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(&compiled.sql);
            let untyped = compiled
                .params
                .into_iter()
                .fold(seed, |q, v| bind_query(q, v));
            let raw_rows = untyped.fetch_all(&mut **t).await?;
            let mut decoded = Vec::with_capacity(raw_rows.len());
            for row in raw_rows.iter() {
                let mut item = T::from_row(row)?;
                for &(leaf, first_hop) in leaves.iter() {
                    let _ = item.__rustango_load_related(row, leaf, first_hop)?;
                }
                decoded.push(item);
            }
            Ok(decoded)
        }
        #[cfg(feature = "mysql")]
        PoolTx::Mysql(t) => {
            if join_aliases.is_empty() {
                let seed: sqlx::query::QueryAs<'_, sqlx::MySql, T, sqlx::mysql::MySqlArguments> =
                    sqlx::query_as::<_, T>(&compiled.sql);
                let typed = compiled
                    .params
                    .into_iter()
                    .fold(seed, |q, v| bind_query_as_my(q, v));
                return Ok(typed.fetch_all(&mut **t).await?);
            }
            let seed: sqlx::query::Query<'_, sqlx::MySql, sqlx::mysql::MySqlArguments> =
                sqlx::query(&compiled.sql);
            let untyped = compiled
                .params
                .into_iter()
                .fold(seed, |q, v| bind_query_my(q, v));
            let raw_rows = untyped.fetch_all(&mut **t).await?;
            let mut decoded = Vec::with_capacity(raw_rows.len());
            for row in raw_rows.iter() {
                let mut item = <T as sqlx::FromRow<sqlx::mysql::MySqlRow>>::from_row(row)?;
                for &(leaf, first_hop) in leaves.iter() {
                    let _ = item.__rustango_load_related_my(row, leaf, first_hop)?;
                }
                decoded.push(item);
            }
            Ok(decoded)
        }
        #[cfg(feature = "sqlite")]
        PoolTx::Sqlite(t) => {
            if join_aliases.is_empty() {
                let seed: sqlx::query::QueryAs<
                    '_,
                    sqlx::Sqlite,
                    T,
                    sqlx::sqlite::SqliteArguments<'_>,
                > = sqlx::query_as::<_, T>(&compiled.sql);
                let typed = compiled
                    .params
                    .into_iter()
                    .fold(seed, |q, v| bind_query_as_sqlite(q, v));
                return Ok(typed.fetch_all(&mut **t).await?);
            }
            let seed: sqlx::query::Query<'_, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'_>> =
                sqlx::query(&compiled.sql);
            let untyped = compiled
                .params
                .into_iter()
                .fold(seed, |q, v| bind_query_sqlite(q, v));
            let raw_rows = untyped.fetch_all(&mut **t).await?;
            let mut decoded = Vec::with_capacity(raw_rows.len());
            for row in raw_rows.iter() {
                let mut item = <T as sqlx::FromRow<sqlx::sqlite::SqliteRow>>::from_row(row)?;
                for &(leaf, first_hop) in leaves.iter() {
                    let _ = item.__rustango_load_related_sqlite(row, leaf, first_hop)?;
                }
                decoded.push(item);
            }
            Ok(decoded)
        }
    }
}
async fn fetch_scalar_pool(pool: &Pool, sql: &str, binds: Vec<SqlValue>) -> Result<i64, ExecError> {
match pool {
#[cfg(feature = "postgres")]
Pool::Postgres(pg) => {
use sqlx::Row as _;
let mut q: Query<'_, sqlx::Postgres, PgArguments> = sqlx::query(sql);
for v in binds {
q = bind_query(q, v);
}
let row = q.fetch_one(pg).await?;
Ok(row.try_get::<i64, _>(0)?)
}
#[cfg(feature = "mysql")]
Pool::Mysql(my) => {
use sqlx::Row as _;
let mut q: sqlx::query::Query<'_, sqlx::MySql, sqlx::mysql::MySqlArguments> =
sqlx::query(sql);
for v in binds {
q = bind_query_my(q, v);
}
let row = q.fetch_one(my).await?;
Ok(row.try_get::<i64, _>(0)?)
}
#[cfg(feature = "sqlite")]
Pool::Sqlite(sq) => {
use sqlx::Row as _;
let mut q: sqlx::query::Query<'_, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'_>> =
sqlx::query(sql);
for v in binds {
q = bind_query_sqlite(q, v);
}
let row = q.fetch_one(sq).await?;
Ok(row.try_get::<i64, _>(0)?)
}
}
}
mod traits;
#[cfg(feature = "mysql")]
pub use traits::LoadRelatedMy;
#[cfg(feature = "sqlite")]
pub use traits::LoadRelatedSqlite;
pub use traits::{
MaybeMyFromRow, MaybeMyLoadRelated, MaybePgFromRow, MaybeSqliteFromRow, MaybeSqliteLoadRelated,
};
/// Compiles `query` for the pool's dialect and decodes every result row as `T`.
///
/// Joined `select_related` columns are not loaded here. NOTE(review): unlike
/// `select_one_row_pool`, this does not bump the test-assertion query counter —
/// confirm that this is intentional.
pub async fn select_rows_pool<T>(pool: &Pool, query: &SelectQuery) -> Result<Vec<T>, ExecError>
where
    T: MaybePgFromRow + MaybeMyFromRow + MaybeSqliteFromRow + Send + Unpin,
{
    let compiled = pool.dialect().compile_select(query)?;
    let rows = match pool {
        #[cfg(feature = "postgres")]
        Pool::Postgres(pg) => {
            let seed: QueryAs<'_, sqlx::Postgres, T, PgArguments> =
                sqlx::query_as::<_, T>(&compiled.sql);
            compiled
                .params
                .into_iter()
                .fold(seed, |q, v| bind_query_as(q, v))
                .fetch_all(pg)
                .await?
        }
        #[cfg(feature = "mysql")]
        Pool::Mysql(my) => {
            let seed: sqlx::query::QueryAs<'_, sqlx::MySql, T, sqlx::mysql::MySqlArguments> =
                sqlx::query_as::<_, T>(&compiled.sql);
            compiled
                .params
                .into_iter()
                .fold(seed, |q, v| bind_query_as_my(q, v))
                .fetch_all(my)
                .await?
        }
        #[cfg(feature = "sqlite")]
        Pool::Sqlite(sq) => {
            let seed: sqlx::query::QueryAs<
                '_,
                sqlx::Sqlite,
                T,
                sqlx::sqlite::SqliteArguments<'_>,
            > = sqlx::query_as::<_, T>(&compiled.sql);
            compiled
                .params
                .into_iter()
                .fold(seed, |q, v| bind_query_as_sqlite(q, v))
                .fetch_all(sq)
                .await?
        }
    };
    Ok(rows)
}
/// Compiles `query` for the dialect of `pool` and decodes at most one row into `T`.
///
/// Bumps the test query counter once per call, then executes with `fetch_optional`, so an
/// empty result yields `Ok(None)`.
///
/// # Errors
///
/// Returns [`ExecError`] if compilation fails, the driver fails, or the row cannot be
/// decoded as `T`.
pub async fn select_one_row_pool<T>(
    pool: &Pool,
    query: &SelectQuery,
) -> Result<Option<T>, ExecError>
where
    T: MaybePgFromRow + MaybeMyFromRow + MaybeSqliteFromRow + Send + Unpin,
{
    crate::test_assertions::query_counter::bump();
    let compiled = pool.dialect().compile_select(query)?;
    match pool {
        #[cfg(feature = "postgres")]
        Pool::Postgres(pg) => {
            let bound: QueryAs<'_, sqlx::Postgres, T, PgArguments> = compiled
                .params
                .into_iter()
                .fold(sqlx::query_as::<_, T>(&compiled.sql), |acc, v| {
                    bind_query_as(acc, v)
                });
            Ok(bound.fetch_optional(pg).await?)
        }
        #[cfg(feature = "mysql")]
        Pool::Mysql(my) => {
            let bound: sqlx::query::QueryAs<'_, sqlx::MySql, T, sqlx::mysql::MySqlArguments> =
                compiled
                    .params
                    .into_iter()
                    .fold(sqlx::query_as::<_, T>(&compiled.sql), |acc, v| {
                        bind_query_as_my(acc, v)
                    });
            Ok(bound.fetch_optional(my).await?)
        }
        #[cfg(feature = "sqlite")]
        Pool::Sqlite(sq) => {
            let bound: sqlx::query::QueryAs<
                '_,
                sqlx::Sqlite,
                T,
                sqlx::sqlite::SqliteArguments<'_>,
            > = compiled
                .params
                .into_iter()
                .fold(sqlx::query_as::<_, T>(&compiled.sql), |acc, v| {
                    bind_query_as_sqlite(acc, v)
                });
            Ok(bound.fetch_optional(sq).await?)
        }
    }
}
/// Binds one [`SqlValue`] onto a MySQL `query_as` builder and returns the updated builder.
///
/// The per-value binding logic lives in the `bind_match_mysql!` macro, which is defined
/// elsewhere and presumably dispatches on the `SqlValue` variant (verify against the
/// macro). This wrapper only fixes the builder type to `QueryAs<MySql, T, MySqlArguments>`.
#[cfg(feature = "mysql")]
pub(super) fn bind_query_as_my<T>(
q: sqlx::query::QueryAs<'_, sqlx::MySql, T, sqlx::mysql::MySqlArguments>,
value: SqlValue,
) -> sqlx::query::QueryAs<'_, sqlx::MySql, T, sqlx::mysql::MySqlArguments> {
bind_match_mysql!(q, value)
}
/// Binds one [`SqlValue`] onto a SQLite `query_as` builder and returns the updated builder.
///
/// The per-value binding logic lives in the `bind_match_sqlite!` macro, which is defined
/// elsewhere and presumably dispatches on the `SqlValue` variant (verify against the
/// macro). The shared lifetime `'a` ties the builder to its `SqliteArguments<'a>`.
#[cfg(feature = "sqlite")]
pub(super) fn bind_query_as_sqlite<'a, T>(
q: sqlx::query::QueryAs<'a, sqlx::Sqlite, T, sqlx::sqlite::SqliteArguments<'a>>,
value: SqlValue,
) -> sqlx::query::QueryAs<'a, sqlx::Sqlite, T, sqlx::sqlite::SqliteArguments<'a>> {
bind_match_sqlite!(q, value)
}
/// Compiles an [`AggregateQuery`] for the dialect of `pool` and decodes every result row
/// into `T`.
///
/// Does not bump the test query counter.
///
/// # Errors
///
/// Returns [`ExecError`] if compilation fails, the driver fails, or a row cannot be
/// decoded as `T`.
pub async fn fetch_aggregate_pool<T>(
    pool: &Pool,
    query: &AggregateQuery,
) -> Result<Vec<T>, ExecError>
where
    T: MaybePgFromRow + MaybeMyFromRow + MaybeSqliteFromRow + Send + Unpin,
{
    let compiled = pool.dialect().compile_aggregate(query)?;
    match pool {
        #[cfg(feature = "postgres")]
        Pool::Postgres(pg) => {
            let bound: QueryAs<'_, sqlx::Postgres, T, PgArguments> = compiled
                .params
                .into_iter()
                .fold(sqlx::query_as::<_, T>(&compiled.sql), |acc, v| {
                    bind_query_as(acc, v)
                });
            Ok(bound.fetch_all(pg).await?)
        }
        #[cfg(feature = "mysql")]
        Pool::Mysql(my) => {
            let bound: sqlx::query::QueryAs<'_, sqlx::MySql, T, sqlx::mysql::MySqlArguments> =
                compiled
                    .params
                    .into_iter()
                    .fold(sqlx::query_as::<_, T>(&compiled.sql), |acc, v| {
                        bind_query_as_my(acc, v)
                    });
            Ok(bound.fetch_all(my).await?)
        }
        #[cfg(feature = "sqlite")]
        Pool::Sqlite(sq) => {
            let bound: sqlx::query::QueryAs<
                '_,
                sqlx::Sqlite,
                T,
                sqlx::sqlite::SqliteArguments<'_>,
            > = compiled
                .params
                .into_iter()
                .fold(sqlx::query_as::<_, T>(&compiled.sql), |acc, v| {
                    bind_query_as_sqlite(acc, v)
                });
            Ok(bound.fetch_all(sq).await?)
        }
    }
}
mod values;
#[allow(unused_imports)]
pub use values::{
fetch_aggregate_dict, fetch_values_dict, fetch_values_flat, fetch_values_list, MaybeMyScalar,
MaybePgScalar, MaybeSqliteScalar,
};
pub async fn raw_query_pool<T>(
sql: &str,
binds: Vec<SqlValue>,
pool: &Pool,
) -> Result<Vec<T>, ExecError>
where
T: MaybePgFromRow + MaybeMyFromRow + MaybeSqliteFromRow + Send + Unpin,
{
match pool {
#[cfg(feature = "postgres")]
Pool::Postgres(pg) => {
let mut q: QueryAs<'_, sqlx::Postgres, T, PgArguments> = sqlx::query_as::<_, T>(sql);
for v in binds {
q = bind_query_as(q, v);
}
Ok(q.fetch_all(pg).await?)
}
#[cfg(feature = "mysql")]
Pool::Mysql(my) => {
let mut q: sqlx::query::QueryAs<'_, sqlx::MySql, T, sqlx::mysql::MySqlArguments> =
sqlx::query_as::<_, T>(sql);
for v in binds {
q = bind_query_as_my(q, v);
}
Ok(q.fetch_all(my).await?)
}
#[cfg(feature = "sqlite")]
Pool::Sqlite(sq) => {
let mut q: sqlx::query::QueryAs<
'_,
sqlx::Sqlite,
T,
sqlx::sqlite::SqliteArguments<'_>,
> = sqlx::query_as::<_, T>(sql);
for v in binds {
q = bind_query_as_sqlite(q, v);
}
Ok(q.fetch_all(sq).await?)
}
}
}
/// Runs caller-supplied SQL inside the open transaction `tx` and decodes every row into `T`.
///
/// Bumps the test query counter once per call. `sql` is passed to the driver unchanged,
/// so it must use the placeholder syntax of the transaction's backend. `binds` are
/// attached in order.
///
/// # Errors
///
/// Returns [`ExecError`] on driver failure or when a row cannot be decoded as `T`.
pub async fn raw_query_tx<T>(
    tx: &mut tx::PoolTx<'_>,
    sql: &str,
    binds: Vec<SqlValue>,
) -> Result<Vec<T>, ExecError>
where
    T: MaybePgFromRow + MaybeMyFromRow + MaybeSqliteFromRow + Send + Unpin,
{
    crate::test_assertions::query_counter::bump();
    match tx {
        #[cfg(feature = "postgres")]
        tx::PoolTx::Postgres(txn) => {
            let bound: QueryAs<'_, sqlx::Postgres, T, PgArguments> = binds
                .into_iter()
                .fold(sqlx::query_as::<_, T>(sql), |acc, v| bind_query_as(acc, v));
            Ok(bound.fetch_all(&mut **txn).await?)
        }
        #[cfg(feature = "mysql")]
        tx::PoolTx::Mysql(txn) => {
            let bound: sqlx::query::QueryAs<'_, sqlx::MySql, T, sqlx::mysql::MySqlArguments> =
                binds
                    .into_iter()
                    .fold(sqlx::query_as::<_, T>(sql), |acc, v| bind_query_as_my(acc, v));
            Ok(bound.fetch_all(&mut **txn).await?)
        }
        #[cfg(feature = "sqlite")]
        tx::PoolTx::Sqlite(txn) => {
            let bound: sqlx::query::QueryAs<
                '_,
                sqlx::Sqlite,
                T,
                sqlx::sqlite::SqliteArguments<'_>,
            > = binds
                .into_iter()
                .fold(sqlx::query_as::<_, T>(sql), |acc, v| {
                    bind_query_as_sqlite(acc, v)
                });
            Ok(bound.fetch_all(&mut **txn).await?)
        }
    }
}
/// Returns the distinct truncated date buckets produced by a
/// [`crate::query::DatesQuerySet`].
///
/// The inner query set is compiled and used as a subquery. The resolved column is quoted,
/// wrapped in the dialect-specific truncation expression for `qs.kind`, de-duplicated with
/// `SELECT DISTINCT`, and ordered `ASC`, or `DESC` when `qs.descending` is set.
///
/// # Errors
///
/// Fails if the column cannot be resolved, the inner query set does not compile, the
/// statement fails, or a bucket cannot be decoded as [`chrono::NaiveDate`].
pub async fn fetch_dates_pool<T: crate::core::Model + Send>(
    pool: &Pool,
    qs: crate::query::DatesQuerySet<T>,
) -> Result<Vec<chrono::NaiveDate>, ExecError> {
    let descending = qs.descending;
    let kind = qs.kind;
    let column = qs.resolve_column()?;
    let select_query = qs.qs.compile()?;

    let dialect = pool.dialect();
    let subquery = dialect.compile_select(&select_query)?;
    let quoted_column = dialect.quote_ident(column);
    let bucket_expr = kind.trunc_sql(dialect.name(), &quoted_column);
    let direction = if descending { "DESC" } else { "ASC" };
    let sql = format!(
        "SELECT DISTINCT {trunc_sql} AS rs_dates_bucket FROM ({inner_sql}) AS rs_dates_sub ORDER BY rs_dates_bucket {order_dir}",
        trunc_sql = bucket_expr,
        inner_sql = subquery.sql,
        order_dir = direction,
    );

    let buckets: Vec<(chrono::NaiveDate,)> = raw_query_pool(&sql, subquery.params, pool).await?;
    Ok(buckets.into_iter().map(|(date,)| date).collect())
}
/// Returns the distinct truncated datetime buckets produced by a
/// [`crate::query::DateTimesQuerySet`].
///
/// Mirrors `fetch_dates_pool`. The inner query set becomes a subquery, the resolved column
/// is truncated per `qs.kind` with the dialect's expression, and buckets are de-duplicated
/// and ordered `ASC`, or `DESC` when `qs.descending` is set.
///
/// # Errors
///
/// Fails if the column cannot be resolved, the inner query set does not compile, the
/// statement fails, or a bucket cannot be decoded as `chrono::DateTime<Utc>`.
pub async fn fetch_datetimes_pool<T: crate::core::Model + Send>(
    pool: &Pool,
    qs: crate::query::DateTimesQuerySet<T>,
) -> Result<Vec<chrono::DateTime<chrono::Utc>>, ExecError> {
    let descending = qs.descending;
    let kind = qs.kind;
    let column = qs.resolve_column()?;
    let select_query = qs.qs.compile()?;

    let dialect = pool.dialect();
    let subquery = dialect.compile_select(&select_query)?;
    let quoted_column = dialect.quote_ident(column);
    let bucket_expr = kind.trunc_sql(dialect.name(), &quoted_column);
    let direction = if descending { "DESC" } else { "ASC" };
    let sql = format!(
        "SELECT DISTINCT {trunc_sql} AS rs_datetimes_bucket FROM ({inner_sql}) AS rs_datetimes_sub ORDER BY rs_datetimes_bucket {order_dir}",
        trunc_sql = bucket_expr,
        inner_sql = subquery.sql,
        order_dir = direction,
    );

    let buckets: Vec<(chrono::DateTime<chrono::Utc>,)> =
        raw_query_pool(&sql, subquery.params, pool).await?;
    Ok(buckets.into_iter().map(|(moment,)| moment).collect())
}
/// Counts the rows a query set matches on any backend wrapped by [`Pool`].
///
/// The `QuerySet` implementation in this module builds its count from the compiled query's
/// model, WHERE clause and search terms only.
pub trait CounterPool<T: Model + Send> {
/// Returns the number of matching rows.
///
/// # Errors
///
/// Returns [`ExecError`] if the query set fails to compile or the count query fails.
fn count(self, pool: &Pool)
-> impl std::future::Future<Output = Result<i64, ExecError>> + Send;
}
impl<T: Model + Send> CounterPool<T> for QuerySet<T> {
async fn count(self, pool: &Pool) -> Result<i64, ExecError> {
let select = self.compile()?;
count_rows_pool(
pool,
&CountQuery {
model: select.model,
where_clause: select.where_clause,
search: select.search,
},
)
.await
}
}
/// Existence checks for a query set on any backend wrapped by [`Pool`].
///
/// The `QuerySet` implementation in this module answers every method through
/// `CounterPool::count`.
pub trait ExistsPool<T: Model + Send> {
/// Resolves to `true` when at least one row matches.
fn exists(
self,
pool: &Pool,
) -> impl std::future::Future<Output = Result<bool, ExecError>> + Send;
/// Resolves to `true` when no row matches.
fn is_empty(
self,
pool: &Pool,
) -> impl std::future::Future<Output = Result<bool, ExecError>> + Send;
/// Alias for [`ExistsPool::is_empty`].
fn doesnt_exist(
self,
pool: &Pool,
) -> impl std::future::Future<Output = Result<bool, ExecError>> + Send;
/// Resolves to `true` when a matching row has primary key equal to `pk_value`.
///
/// The `QuerySet` implementation fails with `QueryError::UnknownField` when the model
/// declares no primary key.
fn contains_pk(
self,
pool: &Pool,
pk_value: impl Into<crate::core::SqlValue> + Send,
) -> impl std::future::Future<Output = Result<bool, ExecError>> + Send;
}
impl<T: Model + Send> ExistsPool<T> for QuerySet<T> {
    /// `true` when the count of matching rows is positive.
    async fn exists(self, pool: &Pool) -> Result<bool, ExecError> {
        self.count(pool).await.map(|matched| matched > 0)
    }

    /// `true` when the count of matching rows is zero.
    async fn is_empty(self, pool: &Pool) -> Result<bool, ExecError> {
        self.count(pool).await.map(|matched| matched == 0)
    }

    /// Delegates to [`ExistsPool::is_empty`].
    async fn doesnt_exist(self, pool: &Pool) -> Result<bool, ExecError> {
        self.is_empty(pool).await
    }

    /// Narrows the query set to `pk == pk_value` and checks whether any row remains.
    ///
    /// # Errors
    ///
    /// Returns `ExecError::Query(QueryError::UnknownField { field: "primary_key", .. })`
    /// when the model has no primary key, and propagates count failures otherwise.
    async fn contains_pk(
        self,
        pool: &Pool,
        pk_value: impl Into<crate::core::SqlValue> + Send,
    ) -> Result<bool, ExecError> {
        let pk_column = match T::SCHEMA.primary_key() {
            Some(field) => field.column,
            None => {
                return Err(ExecError::Query(crate::core::QueryError::UnknownField {
                    model: T::SCHEMA.name,
                    field: "primary_key".into(),
                }));
            }
        };
        self.filter_op(pk_column, crate::core::Op::Eq, pk_value.into())
            .exists(pool)
            .await
    }
}
/// Fetches one page of `qs` together with the total number of matching rows, in one
/// round trip.
///
/// The compiled SELECT is rewritten by `inject_total_count` so each row also carries a
/// `__rustango_total` column. The total is read from the first returned row, and an empty
/// page reports a total of `0`.
///
/// NOTE(review): if `inject_total_count` uses a window count, a page whose offset lies
/// past the last row also reports `0` even though matching rows exist. Confirm whether
/// callers rely on that.
///
/// # Errors
///
/// Returns [`ExecError`] if compilation fails, the driver fails, `__rustango_total`
/// cannot be decoded as `i64`, or a row cannot be decoded as `T`.
pub async fn fetch_paginated_pool<T>(
    qs: crate::query::QuerySet<T>,
    pool: &Pool,
) -> Result<Page<T>, ExecError>
where
    T: Model + MaybePgFromRow + MaybeMyFromRow + MaybeSqliteFromRow + Send + Unpin,
{
    let select = qs.compile()?;
    let compiled = pool.dialect().compile_select(&select)?;
    let counted_sql = inject_total_count(&compiled.sql);
    match pool {
        #[cfg(feature = "postgres")]
        Pool::Postgres(pg) => {
            use sqlx::Row as _;
            let query: Query<'_, sqlx::Postgres, PgArguments> = compiled
                .params
                .into_iter()
                .fold(sqlx::query(&counted_sql), |acc, v| bind_query(acc, v));
            let raw_rows: Vec<PgRow> = query.fetch_all(pg).await?;
            let total: i64 = match raw_rows.first() {
                Some(head) => head.try_get::<i64, _>("__rustango_total")?,
                None => 0,
            };
            let rows = raw_rows
                .iter()
                .map(|row| T::from_row(row))
                .collect::<Result<Vec<T>, _>>()?;
            Ok(Page { rows, total })
        }
        #[cfg(feature = "mysql")]
        Pool::Mysql(my) => {
            use sqlx::Row as _;
            let query: sqlx::query::Query<'_, sqlx::MySql, sqlx::mysql::MySqlArguments> = compiled
                .params
                .into_iter()
                .fold(sqlx::query(&counted_sql), |acc, v| bind_query_my(acc, v));
            let raw_rows: Vec<sqlx::mysql::MySqlRow> = query.fetch_all(my).await?;
            let total: i64 = match raw_rows.first() {
                Some(head) => head.try_get::<i64, _>("__rustango_total")?,
                None => 0,
            };
            let rows = raw_rows
                .iter()
                .map(|row| <T as sqlx::FromRow<sqlx::mysql::MySqlRow>>::from_row(row))
                .collect::<Result<Vec<T>, _>>()?;
            Ok(Page { rows, total })
        }
        #[cfg(feature = "sqlite")]
        Pool::Sqlite(sq) => {
            use sqlx::Row as _;
            let query: sqlx::query::Query<
                '_,
                sqlx::Sqlite,
                sqlx::sqlite::SqliteArguments<'_>,
            > = compiled
                .params
                .into_iter()
                .fold(sqlx::query(&counted_sql), |acc, v| bind_query_sqlite(acc, v));
            let raw_rows: Vec<sqlx::sqlite::SqliteRow> = query.fetch_all(sq).await?;
            let total: i64 = match raw_rows.first() {
                Some(head) => head.try_get::<i64, _>("__rustango_total")?,
                None => 0,
            };
            let rows = raw_rows
                .iter()
                .map(|row| <T as sqlx::FromRow<sqlx::sqlite::SqliteRow>>::from_row(row))
                .collect::<Result<Vec<T>, _>>()?;
            Ok(Page { rows, total })
        }
    }
}
mod prefetch;
pub use prefetch::{fetch_with_prefetch_filtered, fetch_with_prefetch_pool};
/// Runs `query` on `pool` and hydrates `select_related` joins onto every returned row.
///
/// Bumps the test query counter once. When `query` has no joins, rows are decoded
/// directly with `query_as`. Otherwise each raw row is decoded into `T`, and then every
/// leaf join alias from `select_related_leaves` is loaded through the backend's
/// `__rustango_load_related*` hook. The hook's boolean result is ignored.
///
/// NOTE(review): `LoadRelated::__rustango_load_related` names its parameters
/// `(row, field_name, alias)`, but it is called here with `(row, leaf_alias, first_hop)`.
/// Confirm that the derive's implementation expects this order.
///
/// # Errors
///
/// Returns [`ExecError`] if compilation fails, the driver fails, or decoding/hydration of
/// any row fails.
pub async fn select_rows_pool_with_related<T>(
    pool: &Pool,
    query: &SelectQuery,
) -> Result<Vec<T>, ExecError>
where
    T: MaybePgFromRow
        + MaybeMyFromRow
        + MaybeSqliteFromRow
        + LoadRelated
        + MaybeMyLoadRelated
        + MaybeSqliteLoadRelated
        + Send
        + Unpin,
{
    crate::test_assertions::query_counter::bump();
    let compiled = pool.dialect().compile_select(query)?;
    let aliases: Vec<&'static str> = query.joins.iter().map(|j| j.alias).collect();
    let has_joins = !aliases.is_empty();
    let leaves = select_related_leaves(&aliases);
    match pool {
        #[cfg(feature = "postgres")]
        Pool::Postgres(pg) => {
            if !has_joins {
                let plain: QueryAs<'_, sqlx::Postgres, T, PgArguments> = compiled
                    .params
                    .into_iter()
                    .fold(sqlx::query_as::<_, T>(&compiled.sql), |acc, v| {
                        bind_query_as(acc, v)
                    });
                return Ok(plain.fetch_all(pg).await?);
            }
            let joined: Query<'_, sqlx::Postgres, PgArguments> = compiled
                .params
                .into_iter()
                .fold(sqlx::query(&compiled.sql), |acc, v| bind_query(acc, v));
            let raw_rows = joined.fetch_all(pg).await?;
            raw_rows
                .iter()
                .map(|row| -> Result<T, ExecError> {
                    let mut item = T::from_row(row)?;
                    for &(leaf, hop) in &leaves {
                        item.__rustango_load_related(row, leaf, hop)?;
                    }
                    Ok(item)
                })
                .collect()
        }
        #[cfg(feature = "mysql")]
        Pool::Mysql(my) => {
            if !has_joins {
                let plain: sqlx::query::QueryAs<
                    '_,
                    sqlx::MySql,
                    T,
                    sqlx::mysql::MySqlArguments,
                > = compiled
                    .params
                    .into_iter()
                    .fold(sqlx::query_as::<_, T>(&compiled.sql), |acc, v| {
                        bind_query_as_my(acc, v)
                    });
                return Ok(plain.fetch_all(my).await?);
            }
            let joined: sqlx::query::Query<'_, sqlx::MySql, sqlx::mysql::MySqlArguments> =
                compiled
                    .params
                    .into_iter()
                    .fold(sqlx::query(&compiled.sql), |acc, v| bind_query_my(acc, v));
            let raw_rows = joined.fetch_all(my).await?;
            raw_rows
                .iter()
                .map(|row| -> Result<T, ExecError> {
                    let mut item = <T as sqlx::FromRow<sqlx::mysql::MySqlRow>>::from_row(row)?;
                    for &(leaf, hop) in &leaves {
                        item.__rustango_load_related_my(row, leaf, hop)?;
                    }
                    Ok(item)
                })
                .collect()
        }
        #[cfg(feature = "sqlite")]
        Pool::Sqlite(sq) => {
            if !has_joins {
                let plain: sqlx::query::QueryAs<
                    '_,
                    sqlx::Sqlite,
                    T,
                    sqlx::sqlite::SqliteArguments<'_>,
                > = compiled
                    .params
                    .into_iter()
                    .fold(sqlx::query_as::<_, T>(&compiled.sql), |acc, v| {
                        bind_query_as_sqlite(acc, v)
                    });
                return Ok(plain.fetch_all(sq).await?);
            }
            let joined: sqlx::query::Query<
                '_,
                sqlx::Sqlite,
                sqlx::sqlite::SqliteArguments<'_>,
            > = compiled
                .params
                .into_iter()
                .fold(sqlx::query(&compiled.sql), |acc, v| bind_query_sqlite(acc, v));
            let raw_rows = joined.fetch_all(sq).await?;
            raw_rows
                .iter()
                .map(|row| -> Result<T, ExecError> {
                    let mut item =
                        <T as sqlx::FromRow<sqlx::sqlite::SqliteRow>>::from_row(row)?;
                    for &(leaf, hop) in &leaves {
                        item.__rustango_load_related_sqlite(row, leaf, hop)?;
                    }
                    Ok(item)
                })
                .collect()
        }
    }
}
/// Fetches every row of a query set from any backend wrapped by [`Pool`].
///
/// The `QuerySet` implementation in this module routes through
/// `select_rows_pool_with_related`, so `select_related` joins are hydrated.
pub trait FetcherPool<T>
where
T: Model
+ MaybePgFromRow
+ MaybeMyFromRow
+ MaybeSqliteFromRow
+ LoadRelated
+ MaybeMyLoadRelated
+ MaybeSqliteLoadRelated
+ Send
+ Unpin,
{
/// Compiles and executes the query set, returning all decoded rows.
///
/// # Errors
///
/// Returns [`ExecError`] on compilation, driver, or decoding failure.
fn fetch(
self,
pool: &Pool,
) -> impl std::future::Future<Output = Result<Vec<T>, ExecError>> + Send;
}
impl<T> FetcherPool<T> for QuerySet<T>
where
    T: Model
        + MaybePgFromRow
        + MaybeMyFromRow
        + MaybeSqliteFromRow
        + LoadRelated
        + MaybeMyLoadRelated
        + MaybeSqliteLoadRelated
        + Send
        + Unpin,
{
    /// Compiles the query set and hands it to `select_rows_pool_with_related`.
    async fn fetch(self, pool: &Pool) -> Result<Vec<T>, ExecError> {
        select_rows_pool_with_related(pool, &self.compile()?).await
    }
}
impl<T> crate::query::QuerySet<T>
where
    T: Model
        + MaybePgFromRow
        + MaybeMyFromRow
        + MaybeSqliteFromRow
        + LoadRelated
        + MaybeMyLoadRelated
        + MaybeSqliteLoadRelated
        + Send
        + Unpin,
{
    /// Returns the first row, or `None` when nothing matches.
    ///
    /// An unordered query set is ordered by its primary key first (`ensure_pk_ordering`),
    /// and a single row is fetched.
    pub async fn first(self, pool: &Pool) -> Result<Option<T>, ExecError> {
        let single = ensure_pk_ordering(self, false).limit(1).fetch(pool).await?;
        Ok(single.into_iter().next())
    }

    /// Returns the last row, or `None` when nothing matches.
    ///
    /// An existing ordering is flipped. An unordered query set is ordered by its primary
    /// key with the reverse flag set. A single row is fetched.
    pub async fn last(self, pool: &Pool) -> Result<Option<T>, ExecError> {
        let single = ensure_pk_ordering(self, true).limit(1).fetch(pool).await?;
        Ok(single.into_iter().next())
    }

    /// Replaces any ordering with `field` (ordering flag `false`) and returns the first row.
    pub async fn earliest(self, field: &str, pool: &Pool) -> Result<Option<T>, ExecError> {
        let ordered = self.replace_order_by(&[(field, false)]).limit(1);
        Ok(ordered.fetch(pool).await?.into_iter().next())
    }

    /// Replaces any ordering with `field` (ordering flag `true`) and returns the first row.
    pub async fn latest(self, field: &str, pool: &Pool) -> Result<Option<T>, ExecError> {
        let ordered = self.replace_order_by(&[(field, true)]).limit(1);
        Ok(ordered.fetch(pool).await?.into_iter().next())
    }

    /// Returns the row whose key equals `pk` (via `where_key`), or `None`.
    pub async fn find(self, pk: impl Into<SqlValue>, pool: &Pool) -> Result<Option<T>, ExecError> {
        self.where_key(pk).first(pool).await
    }

    /// Like [`Self::first`], but an empty result is an error.
    ///
    /// # Errors
    ///
    /// `ExecError::Driver(sqlx::Error::RowNotFound)` when nothing matches.
    pub async fn first_or_fail(self, pool: &Pool) -> Result<T, ExecError> {
        self.first(pool)
            .await?
            .ok_or_else(|| ExecError::Driver(sqlx::Error::RowNotFound))
    }

    /// Like [`Self::find`], but a missing row is an error.
    ///
    /// # Errors
    ///
    /// `ExecError::Driver(sqlx::Error::RowNotFound)` when no row has that key.
    pub async fn find_or_fail(self, pk: impl Into<SqlValue>, pool: &Pool) -> Result<T, ExecError> {
        self.where_key(pk).first_or_fail(pool).await
    }

    /// Returns the only matching row.
    ///
    /// At most two rows are fetched, which is enough to detect ambiguity.
    ///
    /// # Errors
    ///
    /// `RowNotFound` when nothing matches; `MultipleRowsReturned` when more than one does.
    pub async fn sole(self, pool: &Pool) -> Result<T, ExecError> {
        let mut candidates = self.limit(2).fetch(pool).await?;
        if candidates.len() > 1 {
            return Err(ExecError::MultipleRowsReturned {
                op: "sole",
                table: T::SCHEMA.name,
                count: candidates.len(),
            });
        }
        candidates
            .pop()
            .ok_or_else(|| ExecError::Driver(sqlx::Error::RowNotFound))
    }

    /// [`Self::latest`] on the model's configured `get_latest_by` column.
    ///
    /// NOTE(review): the second element of `get_latest_by` is ignored here, so the ordering
    /// flag is always `true` regardless of how the attribute was written. Confirm that is
    /// intended.
    ///
    /// # Errors
    ///
    /// A configuration error when the model declares no `get_latest_by`.
    pub async fn latest_default(self, pool: &Pool) -> Result<Option<T>, ExecError> {
        match T::SCHEMA.get_latest_by {
            Some((field, _attr_desc)) => self.latest(field, pool).await,
            None => Err(ExecError::Driver(sqlx::Error::Configuration(
                ::std::format!(
                    "`{model}::latest_default()` requires `#[rustango(get_latest_by = \"<col>\")]`",
                    model = T::SCHEMA.name
                )
                .into(),
            ))),
        }
    }

    /// [`Self::earliest`] on the model's configured `get_latest_by` column.
    ///
    /// NOTE(review): as in `latest_default`, the second element of `get_latest_by` is ignored.
    ///
    /// # Errors
    ///
    /// A configuration error when the model declares no `get_latest_by`.
    pub async fn earliest_default(self, pool: &Pool) -> Result<Option<T>, ExecError> {
        match T::SCHEMA.get_latest_by {
            Some((field, _attr_desc)) => self.earliest(field, pool).await,
            None => Err(ExecError::Driver(sqlx::Error::Configuration(
                ::std::format!(
                    "`{model}::earliest_default()` requires `#[rustango(get_latest_by = \"<col>\")]`",
                    model = T::SCHEMA.name
                )
                .into(),
            ))),
        }
    }

    /// Compiles the query set into a [`ChunkedIter`] that starts at offset 0 with an empty
    /// buffer.
    ///
    /// # Panics
    ///
    /// If `chunk_size` is not positive.
    ///
    /// # Errors
    ///
    /// Propagates query compilation errors.
    pub fn iterator(self, chunk_size: i64) -> Result<ChunkedIter<T>, crate::core::QueryError> {
        assert!(
            chunk_size > 0,
            "QuerySet::iterator: chunk_size must be > 0; got {chunk_size}"
        );
        let query = self.compile()?;
        let chunked = ChunkedIter {
            query,
            chunk_size,
            offset: 0,
            exhausted: false,
            buffer: std::collections::VecDeque::new(),
            seen: 0,
            _model: std::marker::PhantomData,
        };
        Ok(chunked)
    }

    /// Fetches rows whose `C::COLUMN` value is in `ids` and keys them by `extract(&row)`.
    ///
    /// Returns an empty map without querying when `ids` is empty. `column` is only used for
    /// its type. When two rows yield the same key, the later row in fetch order wins.
    pub async fn in_bulk<C, K, I, F>(
        self,
        column: C,
        ids: I,
        extract: F,
        pool: &Pool,
    ) -> Result<std::collections::HashMap<K, T>, ExecError>
    where
        C: crate::core::Column<Model = T>,
        K: Eq + std::hash::Hash + Into<crate::core::SqlValue>,
        I: IntoIterator<Item = K>,
        F: Fn(&T) -> K,
    {
        let _ = column;
        let wanted: Vec<crate::core::SqlValue> = ids.into_iter().map(Into::into).collect();
        if wanted.is_empty() {
            return Ok(std::collections::HashMap::new());
        }
        let fetched = self
            .filter_op(
                C::COLUMN,
                crate::core::Op::In,
                crate::core::SqlValue::List(wanted),
            )
            .fetch(pool)
            .await?;
        Ok(fetched
            .into_iter()
            .map(|row| (extract(&row), row))
            .collect())
    }
}
mod get_or_create;
pub use get_or_create::{get_or_create, update_or_create};
/// Gives `qs` a deterministic order before `first`/`last` fetch a single row.
///
/// * Already ordered: flip the existing ordering when `reverse` is set; otherwise leave
///   it as is.
/// * Not ordered: order by the primary-key column, passing `reverse` as the ordering flag.
///   A model without a primary key is returned unchanged.
fn ensure_pk_ordering<T: Model>(
    qs: crate::query::QuerySet<T>,
    reverse: bool,
) -> crate::query::QuerySet<T> {
    if qs.has_order_by() {
        return if reverse { qs.flip_order_by() } else { qs };
    }
    match T::SCHEMA.primary_key() {
        Some(pk_field) => qs.replace_order_by(&[(pk_field.column, reverse)]),
        None => qs,
    }
}
/// Fetches every row of a query set inside an open [`PoolTx`] transaction.
///
/// The `QuerySet` implementation in this module routes through
/// `select_rows_tx_with_related`.
pub trait FetcherTx<T>
where
T: Model
+ MaybePgFromRow
+ MaybeMyFromRow
+ MaybeSqliteFromRow
+ LoadRelated
+ MaybeMyLoadRelated
+ MaybeSqliteLoadRelated
+ Send
+ Unpin,
{
/// Compiles and executes the query set on `tx`, returning all decoded rows.
///
/// # Errors
///
/// Returns [`ExecError`] on compilation, driver, or decoding failure.
fn fetch_tx(
self,
tx: &mut PoolTx<'_>,
) -> impl std::future::Future<Output = Result<Vec<T>, ExecError>> + Send;
}
impl<T> FetcherTx<T> for QuerySet<T>
where
    T: Model
        + MaybePgFromRow
        + MaybeMyFromRow
        + MaybeSqliteFromRow
        + LoadRelated
        + MaybeMyLoadRelated
        + MaybeSqliteLoadRelated
        + Send
        + Unpin,
{
    /// Compiles the query set and hands it to `select_rows_tx_with_related`.
    async fn fetch_tx(self, tx: &mut PoolTx<'_>) -> Result<Vec<T>, ExecError> {
        select_rows_tx_with_related(tx, &self.compile()?).await
    }
}
mod iter;
pub use iter::ChunkedIter;
/// Unit tests for `Pool` dialect dispatch and `select_related_leaves`.
#[cfg(test)]
mod pool_dispatch_tests {
// The glob import is only exercised by the feature-gated tests; with no backend feature
// enabled, the remaining tests reach items through explicit `super::` paths.
#[allow(unused_imports)]
use super::*;
// `connect_lazy` opens no connection, so these tests need no running server; they only
// check that the pool reports the MySQL dialect (backtick quoting, `?` placeholders).
#[cfg(feature = "mysql")]
#[tokio::test]
async fn mysql_pool_dispatch_uses_mysql_dialect() {
let my = sqlx::mysql::MySqlPoolOptions::new()
.max_connections(1)
.connect_lazy("mysql://user:pass@localhost:1/none")
.unwrap();
let pool: Pool = my.into();
assert_eq!(pool.dialect().name(), "mysql");
assert_eq!(pool.dialect().quote_ident("col"), "`col`");
assert_eq!(pool.dialect().placeholder(1), "?");
}
// Same check for Postgres: double-quote quoting and `$N` placeholders.
#[cfg(feature = "postgres")]
#[tokio::test]
async fn postgres_pool_dispatch_uses_postgres_dialect() {
let pg = sqlx::postgres::PgPoolOptions::new()
.max_connections(1)
.connect_lazy("postgres://localhost:1/none")
.unwrap();
let pool: Pool = pg.into();
assert_eq!(pool.dialect().name(), "postgres");
assert_eq!(pool.dialect().quote_ident("col"), "\"col\"");
assert_eq!(pool.dialect().placeholder(1), "$1");
}
// Compile-time check: `()` satisfies `MaybeMyFromRow` under the current feature set.
#[test]
fn maybe_my_from_row_resolves_for_unit_type() {
fn check<T: super::MaybeMyFromRow>() {}
check::<()>();
}
// Pins `select_related_leaves`: only aliases not extended by a longer `__` chain are
// kept, each paired with its first hop; input order is preserved; and a plain string
// prefix ("author" vs "authorship") is not treated as a chain.
#[test]
fn select_related_leaves_keeps_deepest_chain_aliases() {
assert_eq!(
super::select_related_leaves(&["author", "editor"]),
vec![("author", "author"), ("editor", "editor")],
);
assert_eq!(
super::select_related_leaves(&[
"author",
"author__profile",
"author__profile__country"
]),
vec![("author__profile__country", "author")],
);
assert_eq!(
super::select_related_leaves(&["editor", "author", "author__profile"]),
vec![("editor", "editor"), ("author__profile", "author")],
);
assert_eq!(
super::select_related_leaves(&["author", "authorship"]),
vec![("author", "author"), ("authorship", "authorship")],
);
}
}