use serde::Serialize;
use sqlx::postgres::PgRow;
use crate::core::condition::{Condition, JoinOp, SqlValue};
use crate::core::model::Model;
use crate::core::query::QueryBuilder;
use crate::core::sqlx::pg as sqlx_pg;
use crate::orm::pagination;
use crate::orm::postgres::{executor, pool};
/// Query wrapper around a [`QueryBuilder`] for model `M` that carries extra
/// per-query flags (soft-delete handling, global scopes, target database)
/// which are applied when the query is finalised or executed.
pub struct ModelQuery<M> {
/// Underlying builder that accumulates conditions, ordering, limits, etc.
builder: QueryBuilder<M>,
/// When `true`, finalisation does not add the soft-delete `IS NULL` filter.
include_trashed: bool,
/// When `true`, finalisation skips `crate::orm::scopes::apply_scopes`.
skip_scopes: bool,
/// Name of a registered pool to run against instead of the ambient pool.
named_db: Option<String>,
}
impl<M> ModelQuery<M>
where
M: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Sync + Unpin + 'static,
{
/// Wraps `builder` with every flag at its default: trashed rows are not
/// included, global scopes are not skipped, and no named database is set.
pub(crate) fn new(builder: QueryBuilder<M>) -> Self {
    let (include_trashed, skip_scopes, named_db) = (false, false, None);
    Self {
        builder,
        include_trashed,
        skip_scopes,
        named_db,
    }
}
pub(crate) fn into_final_builder(self) -> QueryBuilder<M> {
let mut b = self.builder;
if !self.skip_scopes {
b = crate::orm::scopes::apply_scopes::<M>(b);
}
if !self.include_trashed {
if let Some(col) = M::soft_delete_column() {
b = b.where_null(col);
}
}
b
}
fn pool_for_query(&self) -> Result<sqlx::PgPool, sqlx::Error> {
if let Some(ref name) = self.named_db {
return pool::get_named_pool(name).ok_or_else(|| {
sqlx::Error::Configuration(
format!("no named pool '{name}' registered — call pool::register_named_pool()")
.into(),
)
});
}
pool::try_current_pool().ok_or_else(|| {
sqlx::Error::Configuration(
"no database pool in scope — add OrmLayer to your router or \
call pool::with_pool() in tests"
.to_string()
.into(),
)
})
}
/// Looks up the pool registered under `named_db`, if any.
///
/// Returns `None` both when no named database was requested and when the
/// requested name is not registered; callers cannot tell the two apart.
fn named_pool_override(&self) -> Option<sqlx::PgPool> {
    let name = self.named_db.as_deref()?;
    pool::get_named_pool(name)
}
/// Adds an equality condition on `col` by forwarding to `QueryBuilder::where_eq`.
#[must_use]
pub fn and_where(mut self, col: &str, val: impl Into<SqlValue>) -> Self {
    self.builder = self.builder.where_eq(col, val);
    self
}
/// Adds an OR-joined equality condition on `col` by forwarding to
/// `QueryBuilder::or_where_eq`.
#[must_use]
pub fn or_where(mut self, col: &str, val: impl Into<SqlValue>) -> Self {
    self.builder = self.builder.or_where_eq(col, val);
    self
}
/// Requires `col` to be NULL by forwarding to `QueryBuilder::where_null`.
#[must_use]
pub fn and_where_null(mut self, col: &str) -> Self {
    self.builder = self.builder.where_null(col);
    self
}
/// Requires `col` to be non-NULL by forwarding to `QueryBuilder::where_not_null`.
#[must_use]
pub fn and_where_not_null(mut self, col: &str) -> Self {
    self.builder = self.builder.where_not_null(col);
    self
}
/// Adds a pattern condition on `col` by forwarding to `QueryBuilder::where_like`.
#[must_use]
pub fn and_where_like(mut self, col: &str, pattern: &str) -> Self {
    self.builder = self.builder.where_like(col, pattern);
    self
}
/// Adds a membership condition on `col` by forwarding `vals` to
/// `QueryBuilder::where_in`.
#[must_use]
pub fn and_where_in(mut self, col: &str, vals: Vec<impl Into<SqlValue>>) -> Self {
    self.builder = self.builder.where_in(col, vals);
    self
}
/// Adds a negated membership condition on `col` by forwarding `vals` to
/// `QueryBuilder::where_not_in`.
#[must_use]
pub fn and_where_not_in(mut self, col: &str, vals: Vec<impl Into<SqlValue>>) -> Self {
    self.builder = self.builder.where_not_in(col, vals);
    self
}
/// Adds a pattern condition on `col` by forwarding to `QueryBuilder::where_ilike`.
#[must_use]
pub fn and_where_ilike(mut self, col: &str, pattern: &str) -> Self {
    self.builder = self.builder.where_ilike(col, pattern);
    self
}
/// Adds a range condition on `col` by forwarding the `lo` and `hi` bounds to
/// `QueryBuilder::where_between`.
#[must_use]
pub fn and_where_between(
    mut self,
    col: &str,
    lo: impl Into<SqlValue>,
    hi: impl Into<SqlValue>,
) -> Self {
    self.builder = self.builder.where_between(col, lo, hi);
    self
}
/// Adds a comparison on `col` using the caller-supplied operator string `op`,
/// by forwarding to `QueryBuilder::where_op`.
#[must_use]
pub fn and_where_op(mut self, col: &str, op: &str, val: impl Into<SqlValue>) -> Self {
    self.builder = self.builder.where_op(col, op, val);
    self
}
/// Adds a grouped set of conditions built by `f`, by forwarding the closure
/// to `QueryBuilder::where_group`.
#[must_use]
pub fn and_where_group<F>(mut self, f: F) -> Self
where
    F: FnOnce(QueryBuilder<M>) -> QueryBuilder<M>,
{
    self.builder = self.builder.where_group(f);
    self
}
/// Adds an OR-joined grouped set of conditions built by `f`, by forwarding
/// the closure to `QueryBuilder::or_where_group`.
#[must_use]
pub fn or_where_group<F>(mut self, f: F) -> Self
where
    F: FnOnce(QueryBuilder<M>) -> QueryBuilder<M>,
{
    self.builder = self.builder.or_where_group(f);
    self
}
/// Adds a condition on `key` inside the JSON column `col`, by forwarding to
/// `QueryBuilder::where_json`.
#[must_use]
pub fn and_where_json(mut self, col: &str, key: &str, val: impl Into<SqlValue>) -> Self {
    self.builder = self.builder.where_json(col, key, val);
    self
}
/// Adds a JSON containment condition on `col`, by forwarding `json_val` to
/// `QueryBuilder::where_json_contains`.
#[must_use]
pub fn and_where_json_contains(mut self, col: &str, json_val: &str) -> Self {
    self.builder = self.builder.where_json_contains(col, json_val);
    self
}
/// Adds a raw SQL condition by forwarding `sql` to `QueryBuilder::where_raw`.
///
/// `sql` is passed through as given; the caller is responsible for its safety.
#[must_use]
pub fn and_where_raw(mut self, sql: &str) -> Self {
    self.builder = self.builder.where_raw(sql);
    self
}
/// Adds a raw `EXISTS (<subquery_sql>)` condition.
///
/// `subquery_sql` is interpolated verbatim into the clause.
#[must_use]
pub fn and_where_exists_sql(mut self, subquery_sql: &str) -> Self {
    let clause = format!("EXISTS ({subquery_sql})");
    self.builder = self.builder.where_raw(&clause);
    self
}
/// Adds a raw `NOT EXISTS (<subquery_sql>)` condition.
///
/// `subquery_sql` is interpolated verbatim into the clause.
#[must_use]
pub fn and_where_not_exists_sql(mut self, subquery_sql: &str) -> Self {
    let clause = format!("NOT EXISTS ({subquery_sql})");
    self.builder = self.builder.where_raw(&clause);
    self
}
/// Adds an ordering on `col` by forwarding to `QueryBuilder::order_by`.
#[must_use]
pub fn order_by(mut self, col: &str) -> Self {
    self.builder = self.builder.order_by(col);
    self
}
/// Adds a descending ordering on `col` by forwarding to
/// `QueryBuilder::order_by_desc`.
#[must_use]
pub fn order_by_desc(mut self, col: &str) -> Self {
    self.builder = self.builder.order_by_desc(col);
    self
}
/// Sets the row limit by forwarding `n` to `QueryBuilder::limit`.
#[must_use]
pub fn limit(mut self, n: usize) -> Self {
    self.builder = self.builder.limit(n);
    self
}
/// Sets the row offset by forwarding `n` to `QueryBuilder::offset`.
#[must_use]
pub fn offset(mut self, n: usize) -> Self {
    self.builder = self.builder.offset(n);
    self
}
#[must_use]
pub fn with_trashed(mut self) -> Self {
self.include_trashed = true;
self
}
#[must_use]
pub fn without_global_scopes(mut self) -> Self {
self.skip_scopes = true;
self
}
#[must_use]
pub fn only_trashed(mut self) -> Self {
self.include_trashed = true;
if let Some(col) = M::soft_delete_column() {
self.builder = self.builder.where_not_null(col);
}
self
}
/// Chooses the columns to select by forwarding `cols` to `QueryBuilder::select`.
#[must_use]
pub fn select(mut self, cols: &[&str]) -> Self {
    self.builder = self.builder.select(cols);
    self
}
/// Forwards a raw select expression to `QueryBuilder::select_raw`.
#[must_use]
pub fn select_raw(mut self, expr: &str) -> Self {
    self.builder = self.builder.select_raw(expr);
    self
}
/// Forwards a raw join fragment to `QueryBuilder::join_raw`.
#[must_use]
pub fn join_raw(mut self, raw: &str) -> Self {
    self.builder = self.builder.join_raw(raw);
    self
}
/// Converts a DSL expression to a `Condition` and pushes it with `JoinOp::And`.
#[cfg(all(feature = "active", feature = "query"))]
#[must_use]
pub fn and_expr(mut self, expr: crate::dsl::expr::Expr) -> Self {
    let condition = crate::orm::bridge::expr_to_condition(expr);
    self.builder = self.builder.push_condition(JoinOp::And, condition);
    self
}
/// Converts a DSL expression to a `Condition` and pushes it with `JoinOp::Or`.
#[cfg(all(feature = "active", feature = "query"))]
#[must_use]
pub fn or_expr(mut self, expr: crate::dsl::expr::Expr) -> Self {
    let condition = crate::orm::bridge::expr_to_condition(expr);
    self.builder = self.builder.push_condition(JoinOp::Or, condition);
    self
}
/// Finalises the query (scopes and soft-delete filter applied) and converts
/// the resulting builder into a DSL `SelectBuilder`.
#[cfg(all(feature = "active", feature = "query"))]
pub fn into_dsl(self) -> crate::dsl::SelectBuilder {
    let finalised = self.into_final_builder();
    crate::orm::bridge::model_query_into_select(finalised)
}
/// Requests a row lock by forwarding to `QueryBuilder::for_update`.
#[must_use]
pub fn lock_for_update(mut self) -> Self {
    self.builder = self.builder.for_update();
    self
}
/// Requests a row lock via `for_update()` followed by `nowait()` on the builder.
#[must_use]
pub fn lock_for_update_nowait(mut self) -> Self {
    let locked = self.builder.for_update();
    self.builder = locked.nowait();
    self
}
/// Requests a shared row lock by forwarding to `QueryBuilder::for_share`.
#[must_use]
pub fn lock_for_share(mut self) -> Self {
    self.builder = self.builder.for_share();
    self
}
/// Requests a shared row lock via `for_share()` followed by `skip_locked()`
/// on the builder.
#[must_use]
pub fn lock_for_share_skip_locked(mut self) -> Self {
    let locked = self.builder.for_share();
    self.builder = locked.skip_locked();
    self
}
#[must_use]
pub fn on(mut self, db_name: impl Into<String>) -> Self {
self.named_db = Some(db_name.into());
self
}
/// Sets the builder's `use_replica` flag to `true`.
#[must_use]
pub fn on_replica(mut self) -> Self {
    self.builder.use_replica = true;
    self
}
/// Forwards to `QueryBuilder::on_write_db`.
#[must_use]
pub fn on_write_db(mut self) -> Self {
    self.builder = self.builder.on_write_db();
    self
}
/// Adds a raw `<col> IN (<subquery_sql>)` condition.
///
/// Both arguments are interpolated verbatim into the clause.
#[must_use]
pub fn where_in_subquery(mut self, col: &str, subquery_sql: &str) -> Self {
    let clause = format!("{col} IN ({subquery_sql})");
    self.builder = self.builder.where_raw(&clause);
    self
}
/// Adds a raw `<col> NOT IN (<subquery_sql>)` condition.
///
/// Both arguments are interpolated verbatim into the clause.
#[must_use]
pub fn where_not_in_subquery(mut self, col: &str, subquery_sql: &str) -> Self {
    let clause = format!("{col} NOT IN ({subquery_sql})");
    self.builder = self.builder.where_raw(&clause);
    self
}
/// Adds a raw `EXISTS (<subquery_sql>)` condition; builds the same clause as
/// `and_where_exists_sql`.
#[must_use]
pub fn where_exists_sql(mut self, subquery_sql: &str) -> Self {
    let clause = format!("EXISTS ({subquery_sql})");
    self.builder = self.builder.where_raw(&clause);
    self
}
/// Groups by a single expression, passed to `QueryBuilder::group_by` as a
/// one-element slice.
#[must_use]
pub fn group_by_raw(mut self, expr: &str) -> Self {
    let exprs = [expr];
    self.builder = self.builder.group_by(&exprs);
    self
}
/// Forwards a raw HAVING expression to `QueryBuilder::having`.
#[must_use]
pub fn having_raw(mut self, expr: &str) -> Self {
    self.builder = self.builder.having(expr);
    self
}
/// Adds a raw `<col> = <other_col>` column-to-column comparison.
///
/// Both names are interpolated verbatim into the clause.
#[must_use]
pub fn where_column(mut self, col: &str, other_col: &str) -> Self {
    let clause = format!("{col} = {other_col}");
    self.builder = self.builder.where_raw(&clause);
    self
}
/// Adds a correlated `COUNT(*)` sub-select over `relation`, aliased
/// `<relation>_count`.
///
/// The foreign key on `relation` is guessed as `<singular parent table>_id`
/// and matched against `<parent table>.id`.
#[must_use]
pub fn with_count(mut self, relation: &str) -> Self {
    let parent = M::table_name();
    let fk = format!("{}_id", naive_singular(parent));
    let expr = format!(
        "(SELECT COUNT(*) FROM {relation} WHERE {relation}.{fk} = {parent}.id) AS {relation}_count"
    );
    self.builder = self.builder.add_select_expr(expr);
    self
}
#[must_use]
pub fn with_sum(self, relation: &str, col: &str) -> Self {
self.rel_agg(relation, "SUM", col)
}
#[must_use]
pub fn with_avg(self, relation: &str, col: &str) -> Self {
self.rel_agg(relation, "AVG", col)
}
#[must_use]
pub fn with_min(self, relation: &str, col: &str) -> Self {
self.rel_agg(relation, "MIN", col)
}
#[must_use]
pub fn with_max(self, relation: &str, col: &str) -> Self {
self.rel_agg(relation, "MAX", col)
}
/// Shared implementation of the `with_sum`/`with_avg`/`with_min`/`with_max`
/// helpers.
///
/// Adds `(SELECT <agg>(relation.col) FROM relation WHERE relation.<fk> =
/// parent.id) AS <relation>_<agg lowercased>_<col>`, where `<fk>` is
/// `<singular parent table>_id`.
fn rel_agg(mut self, relation: &str, agg: &str, col: &str) -> Self {
    let parent = M::table_name();
    let fk = format!("{}_id", naive_singular(parent));
    let alias = format!("{relation}_{}_{col}", agg.to_lowercase());
    let expr = format!(
        "(SELECT {agg}({relation}.{col}) FROM {relation} WHERE {relation}.{fk} = {parent}.id) AS {alias}"
    );
    self.builder = self.builder.add_select_expr(expr);
    self
}
/// Wraps this query in an `EagerModelQuery` whose relation list starts with
/// `relation`.
#[cfg(feature = "postgres")]
#[must_use]
pub fn with(self, relation: impl Into<String>) -> crate::orm::eager::EagerModelQuery<M> {
    let relations = vec![relation.into()];
    crate::orm::eager::EagerModelQuery {
        relations,
        query: self,
    }
}
/// Pushes an AND-joined `Condition::Subquery` with `exists = true` over
/// `relation`, with no inner conditions.
#[must_use]
pub fn has(mut self, relation: &str) -> Self {
    let cond = self.subquery_cond(relation, true, Vec::new());
    self.builder = self.builder.push_condition(JoinOp::And, cond);
    self
}
/// Pushes an AND-joined `Condition::Subquery` with `exists = false` over
/// `relation`, with no inner conditions.
#[must_use]
pub fn doesnt_have(mut self, relation: &str) -> Self {
    let cond = self.subquery_cond(relation, false, Vec::new());
    self.builder = self.builder.push_condition(JoinOp::And, cond);
    self
}
/// Like `has`, but with extra conditions inside the subquery.
///
/// `f` receives a fresh `QueryBuilder::new(relation)`; only the conditions of
/// the builder it returns are copied into the subquery.
#[must_use]
pub fn where_has<F>(mut self, relation: &str, f: F) -> Self
where
    F: FnOnce(QueryBuilder<M>) -> QueryBuilder<M>,
{
    let inner = f(QueryBuilder::new(relation)).conditions().to_vec();
    let cond = self.subquery_cond(relation, true, inner);
    self.builder = self.builder.push_condition(JoinOp::And, cond);
    self
}
/// Like `doesnt_have`, but with extra conditions inside the subquery.
///
/// `f` receives a fresh `QueryBuilder::new(relation)`; only the conditions of
/// the builder it returns are copied into the subquery.
#[must_use]
pub fn where_doesnt_have<F>(mut self, relation: &str, f: F) -> Self
where
    F: FnOnce(QueryBuilder<M>) -> QueryBuilder<M>,
{
    let inner = f(QueryBuilder::new(relation)).conditions().to_vec();
    let cond = self.subquery_cond(relation, false, inner);
    self.builder = self.builder.push_condition(JoinOp::And, cond);
    self
}
/// Builds the `Condition::Subquery` used by the `has`-family methods.
///
/// The correlation expression is `relation.<fk> = parent.id`, where `<fk>` is
/// `<singular parent table>_id`.
fn subquery_cond(
    &self,
    relation: &str,
    exists: bool,
    inner: Vec<(JoinOp, Condition)>,
) -> Condition {
    let parent = M::table_name();
    let fk = format!("{}_id", naive_singular(parent));
    let fk_expr = format!("{relation}.{fk} = {parent}.id");
    let table = relation.to_string();
    Condition::Subquery {
        exists,
        table,
        fk_expr,
        inner,
    }
}
/// Runs the finalised query and returns every matching row.
///
/// Runs inside `pool::with_pool` when a registered named pool was resolved;
/// otherwise `pool::fetch_all` is awaited directly. The row count is reported
/// to `crate::orm::n1::record` under the model's table name.
pub async fn get(self) -> Result<Vec<M>, sqlx::Error> {
    let table = M::table_name();
    let pool_override = self.named_pool_override();
    let builder = self.into_final_builder();
    let rows = if let Some(p) = pool_override {
        pool::with_pool(p, pool::fetch_all(builder)).await?
    } else {
        pool::fetch_all(builder).await?
    };
    crate::orm::n1::record(table, rows.len());
    Ok(rows)
}
/// Runs the finalised query with `limit(1)` and returns the row, if any.
pub async fn first(self) -> Result<Option<M>, sqlx::Error> {
    let pool_override = self.named_pool_override();
    let builder = self.into_final_builder().limit(1);
    if let Some(p) = pool_override {
        pool::with_pool(p, pool::fetch_optional(builder)).await
    } else {
        pool::fetch_optional(builder).await
    }
}
pub async fn first_or_404(self) -> Result<M, sqlx::Error> {
self.first().await?.ok_or(sqlx::Error::RowNotFound)
}
/// Like `first`, but returns `M::default()` when no row matches.
pub async fn first_or_default(self) -> Result<M, sqlx::Error>
where
    M: Default,
{
    let found = self.first().await?;
    Ok(found.unwrap_or_default())
}
pub async fn first_or_else(self, f: impl FnOnce() -> M) -> Result<M, sqlx::Error> {
Ok(self.first().await?.unwrap_or_else(f))
}
/// Counts the rows matched by the finalised query via `pool::count`.
pub async fn count(self) -> Result<i64, sqlx::Error> {
    let pool_override = self.named_pool_override();
    let builder = self.into_final_builder();
    if let Some(p) = pool_override {
        pool::with_pool(p, pool::count(builder)).await
    } else {
        pool::count(builder).await
    }
}
pub async fn force_delete(self) -> Result<u64, sqlx::Error> {
let pool = self.pool_for_query()?;
executor::delete(&pool, self.builder).await
}
pub async fn soft_delete(self) -> Result<u64, sqlx::Error> {
let col = M::soft_delete_column().unwrap_or("deleted_at");
let pool = self.pool_for_query()?;
executor::soft_delete::<M>(&pool, self.builder, col).await
}
pub async fn restore(self) -> Result<u64, sqlx::Error> {
let col = M::soft_delete_column().unwrap_or("deleted_at");
let pool = self.pool_for_query()?;
executor::restore::<M>(&pool, self.builder, col).await
}
/// Streams the matching rows through a bounded channel (capacity 64).
///
/// The pool is resolved and the SQL rendered before this returns. A spawned
/// Tokio task then runs the query and forwards each item. A pool-resolution
/// error is delivered as the first and only item. The task stops when the
/// receiver is dropped or the row stream ends.
pub fn stream(
    self,
) -> impl futures_core::Stream<Item = Result<M, sqlx::Error>> + Send + 'static {
    use futures::TryStreamExt as _;

    let pool_result = self.pool_for_query();
    let (sql, params) = self.into_final_builder().to_sql();
    let (tx, rx) = tokio::sync::mpsc::channel::<Result<M, sqlx::Error>>(64);

    tokio::spawn(async move {
        let pool = match pool_result {
            Err(e) => {
                // Surface the configuration error to the consumer, then stop.
                let _ = tx.send(Err(e)).await;
                return;
            }
            Ok(p) => p,
        };
        let mut rows = sqlx_pg::build_query_as::<M>(&sql, params).fetch(&pool);
        loop {
            let item = match rows.try_next().await {
                Ok(Some(row)) => Ok(row),
                Ok(None) => break,
                Err(e) => Err(e),
            };
            // A send error means the receiver is gone; nobody is listening.
            if tx.send(item).await.is_err() {
                break;
            }
        }
    });

    MpscStream(rx)
}
/// Runs `MAX(<col>)` over the finalised query via `pool::aggregate`.
///
/// `col` is interpolated verbatim into the aggregate expression.
pub async fn max(self, col: &str) -> Result<Option<f64>, sqlx::Error> {
    let pool_override = self.named_pool_override();
    let builder = self.into_final_builder();
    let agg = format!("MAX({col})");
    if let Some(p) = pool_override {
        pool::with_pool(p, pool::aggregate(builder, &agg)).await
    } else {
        pool::aggregate(builder, &agg).await
    }
}
/// Runs `MIN(<col>)` over the finalised query via `pool::aggregate`.
///
/// `col` is interpolated verbatim into the aggregate expression.
pub async fn min(self, col: &str) -> Result<Option<f64>, sqlx::Error> {
    let pool_override = self.named_pool_override();
    let builder = self.into_final_builder();
    let agg = format!("MIN({col})");
    if let Some(p) = pool_override {
        pool::with_pool(p, pool::aggregate(builder, &agg)).await
    } else {
        pool::aggregate(builder, &agg).await
    }
}
/// Runs `SUM(<col>)` over the finalised query via `pool::aggregate`.
///
/// `col` is interpolated verbatim into the aggregate expression.
pub async fn sum(self, col: &str) -> Result<Option<f64>, sqlx::Error> {
    let pool_override = self.named_pool_override();
    let builder = self.into_final_builder();
    let agg = format!("SUM({col})");
    if let Some(p) = pool_override {
        pool::with_pool(p, pool::aggregate(builder, &agg)).await
    } else {
        pool::aggregate(builder, &agg).await
    }
}
/// Runs `AVG(<col>)` over the finalised query via `pool::aggregate`.
///
/// `col` is interpolated verbatim into the aggregate expression.
pub async fn avg(self, col: &str) -> Result<Option<f64>, sqlx::Error> {
    let pool_override = self.named_pool_override();
    let builder = self.into_final_builder();
    let agg = format!("AVG({col})");
    if let Some(p) = pool_override {
        pool::with_pool(p, pool::aggregate(builder, &agg)).await
    } else {
        pool::aggregate(builder, &agg).await
    }
}
/// Returns `true` when `count()` reports at least one matching row.
pub async fn exists(self) -> Result<bool, sqlx::Error> {
    let matched = self.count().await?;
    Ok(matched > 0)
}
/// Returns `true` when `count()` reports zero matching rows.
pub async fn doesnt_exist(self) -> Result<bool, sqlx::Error> {
    let matched = self.count().await?;
    Ok(matched == 0)
}
/// Returns the first row after re-ordering the finalised query by primary
/// key descending (`reorder_desc`) and limiting it to one row.
pub async fn last(self) -> Result<Option<M>, sqlx::Error> {
    let pool_override = self.named_pool_override();
    let finalised = self.into_final_builder();
    let builder = finalised.reorder_desc(M::primary_key()).limit(1);
    if let Some(p) = pool_override {
        pool::with_pool(p, pool::fetch_optional(builder)).await
    } else {
        pool::fetch_optional(builder).await
    }
}
pub async fn find_or_fail(self) -> Result<M, sqlx::Error> {
self.first_or_404().await
}
/// Fetches `col` from the first matching row as a `String`.
///
/// Returns `Ok(None)` when there is no row, when the value is NULL, and also
/// when the column cannot be decoded as a string. Decode failures are not
/// reported as errors.
pub async fn value(self, col: &str) -> Result<Option<String>, sqlx::Error> {
    use sqlx::Row;

    let pool = self.pool_for_query()?;
    let (sql, params) = self.into_final_builder().select(&[col]).limit(1).to_sql();
    let maybe_row = sqlx_pg::build_query(&sql, params)
        .fetch_optional(&pool)
        .await?;
    Ok(match maybe_row {
        Some(r) => r.try_get::<Option<String>, _>(0).unwrap_or(None),
        None => None,
    })
}
/// Collects `col` from every matching row as `String`s.
///
/// Rows whose value cannot be read as a `String` (a failed `try_get`) are
/// silently skipped rather than reported.
pub async fn pluck(self, col: &str) -> Result<Vec<String>, sqlx::Error> {
    use sqlx::Row;

    let pool = self.pool_for_query()?;
    let (sql, params) = self.into_final_builder().select(&[col]).to_sql();
    let rows = sqlx_pg::build_query(&sql, params).fetch_all(&pool).await?;

    let mut vals = Vec::new();
    for row in rows {
        if let Ok(v) = row.try_get::<String, _>(0) {
            vals.push(v);
        }
    }
    Ok(vals)
}
/// Collects `(col1, col2)` from every matching row as `String` pairs.
///
/// A row is silently skipped when either column cannot be read as a `String`.
pub async fn pairs(self, col1: &str, col2: &str) -> Result<Vec<(String, String)>, sqlx::Error> {
    use sqlx::Row;

    let pool = self.pool_for_query()?;
    let (sql, params) = self.into_final_builder().select(&[col1, col2]).to_sql();
    let rows = sqlx_pg::build_query(&sql, params).fetch_all(&pool).await?;

    let mut vals = Vec::new();
    for row in rows {
        if let Ok(a) = row.try_get::<String, _>(0) {
            if let Ok(b) = row.try_get::<String, _>(1) {
                vals.push((a, b));
            }
        }
    }
    Ok(vals)
}
/// Processes the result set in primary-key-ordered batches of `size` rows,
/// awaiting `f` on each batch.
///
/// Batches are fetched with `LIMIT size OFFSET n`. Iteration stops after the
/// first batch shorter than `size`. When the row count is an exact multiple
/// of `size`, `f` is therefore called one final time with an empty batch.
///
/// A `size` of zero returns `Ok(())` immediately without querying. It used
/// to loop forever, because an empty `LIMIT 0` batch is never shorter than
/// `size` and the offset never advances.
///
/// # Errors
/// Database errors are converted with `E::from`; an error returned by `f`
/// aborts the iteration and is returned as-is.
pub async fn chunk<F, Fut, E>(self, size: usize, mut f: F) -> Result<(), E>
where
    M: Clone,
    F: FnMut(Vec<M>) -> Fut,
    Fut: std::future::Future<Output = Result<(), E>>,
    E: From<sqlx::Error>,
{
    if size == 0 {
        return Ok(());
    }

    let pool_override = self.named_pool_override();
    let base = self.into_final_builder().order_by(M::primary_key());
    let mut offset = 0usize;
    loop {
        let page = base.clone().limit(size).offset(offset);
        let batch = match &pool_override {
            Some(p) => pool::with_pool(p.clone(), pool::fetch_all(page))
                .await
                .map_err(E::from)?,
            None => pool::fetch_all(page).await.map_err(E::from)?,
        };
        // Decide before handing the batch over: `f` takes ownership of it.
        let is_last = batch.len() < size;
        f(batch).await?;
        if is_last {
            break;
        }
        offset += size;
    }
    Ok(())
}
pub async fn chunk_by_id<F, Fut, E>(self, size: usize, f: F) -> Result<(), E>
where
M: Clone,
F: FnMut(Vec<M>) -> Fut,
Fut: std::future::Future<Output = Result<(), E>>,
E: From<sqlx::Error>,
{
self.chunk(size, f).await
}
/// Fetches one page plus the total row count.
///
/// Two queries are issued in order: a count over the finalised query, then
/// a fetch with `LIMIT per_page OFFSET (current_page - 1) * per_page`.
/// `current_page` values of 0 and 1 both map to offset 0.
pub async fn paginate(
    self,
    per_page: u32,
    current_page: u32,
) -> Result<pagination::Page<M>, sqlx::Error>
where
    M: Serialize + Clone,
{
    let pool_override = self.named_pool_override();
    let final_builder = self.into_final_builder();
    let offset = ((current_page.saturating_sub(1)) as usize) * (per_page as usize);

    let count_builder = final_builder.clone();
    let page_builder = final_builder.limit(per_page as usize).offset(offset);

    let total = match pool_override.clone() {
        Some(p) => pool::with_pool(p, pool::count(count_builder)).await?,
        None => pool::count(count_builder).await?,
    };
    let data = match pool_override {
        Some(p) => pool::with_pool(p, pool::fetch_all(page_builder)).await?,
        None => pool::fetch_all(page_builder).await?,
    };
    Ok(pagination::Page::new(data, total, per_page, current_page))
}
/// Fetches one page without a total count.
///
/// Requests `per_page + 1` rows at offset `(current_page - 1) * per_page` and
/// hands all of them to `pagination::SimplePage::new`.
pub async fn simple_paginate(
    self,
    per_page: u32,
    current_page: u32,
) -> Result<pagination::SimplePage<M>, sqlx::Error>
where
    M: Serialize,
{
    let pool_override = self.named_pool_override();
    let offset = ((current_page.saturating_sub(1)) as usize) * (per_page as usize);
    let finalised = self.into_final_builder();
    let builder = finalised.limit(per_page as usize + 1).offset(offset);
    let data = if let Some(p) = pool_override {
        pool::with_pool(p, pool::fetch_all(builder)).await?
    } else {
        pool::fetch_all(builder).await?
    };
    Ok(pagination::SimplePage::new(data, per_page, current_page))
}
/// Fetches one keyset ("cursor") page ordered by `cursor_col`.
///
/// When `cursor` decodes successfully, only rows with `cursor_col` greater
/// than the decoded value are returned, and the same value is re-encoded as
/// the page's previous cursor. A token that fails to decode is ignored and
/// the query starts from the beginning.
///
/// `per_page + 1` rows are requested. If that extra row comes back, the next
/// cursor is taken from the primary key of the last row on the visible page,
/// and only when that key is a `SqlValue::Integer`.
///
/// `per_page == 0` no longer panics: there is no "last visible row", so no
/// next cursor is produced.
///
/// NOTE(review): the next cursor is derived from `pk_value()` while filtering
/// and ordering use `cursor_col` — presumably callers pass the primary-key
/// column; confirm.
pub async fn cursor_paginate(
    self,
    per_page: u32,
    cursor_col: &str,
    cursor: Option<&str>,
) -> Result<pagination::CursorPage<M>, sqlx::Error>
where
    M: Serialize,
{
    let pool_override = self.named_pool_override();
    let mut b = self.into_final_builder();
    let page_size = per_page as usize;

    let prev_cursor: Option<String> =
        match cursor.and_then(|token| pagination::decode_cursor(token)) {
            Some(id) => {
                b = b.where_gt(cursor_col, id);
                Some(pagination::encode_cursor(id))
            }
            None => None,
        };

    // One extra row tells us whether another page exists.
    let fetch_builder = b.order_by(cursor_col).limit(page_size.saturating_add(1));
    let data = match pool_override {
        Some(p) => pool::with_pool(p, pool::fetch_all(fetch_builder)).await?,
        None => pool::fetch_all(fetch_builder).await?,
    };

    let next_id_cursor: Option<String> = if data.len() > page_size {
        // `checked_sub` guards `per_page == 0`, where `page_size - 1` would
        // underflow and the index would panic.
        page_size
            .checked_sub(1)
            .and_then(|last| data.get(last))
            .and_then(|row| match row.pk_value() {
                SqlValue::Integer(pk) => Some(pagination::encode_cursor(pk)),
                _ => None,
            })
    } else {
        None
    };

    Ok(pagination::CursorPage::new(
        data,
        per_page,
        next_id_cursor,
        prev_cursor,
    ))
}
}
/// Newtype over a `tokio::sync::mpsc::Receiver` so it can be returned as a
/// `futures_core::Stream` (see the `Stream` impl below).
struct MpscStream<T>(tokio::sync::mpsc::Receiver<T>);
/// Yields whatever the wrapped receiver yields; polling is delegated directly
/// to `Receiver::poll_recv`.
impl<T: Unpin> futures_core::Stream for MpscStream<T> {
    type Item = T;

    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<T>> {
        let this = self.get_mut();
        this.0.poll_recv(cx)
    }
}
/// Best-effort singularisation of a table name, used to guess foreign-key
/// column names (`<singular>_id`).
///
/// Rules, first match wins:
/// - `…ies`  → `…y`    (`categories` → `category`)
/// - `…sses` → `…ss`   (`addresses` → `address`)
/// - `…xes`  → `…x`    (`boxes` → `box`)
/// - `…ss`   → unchanged (`address` is already singular)
/// - `…s`    → drop the `s` (`users` → `user`)
/// - anything else is returned unchanged.
///
/// This is deliberately naive and not a full inflector; irregular plurals
/// are not handled.
fn naive_singular(table: &str) -> String {
    if let Some(stem) = table.strip_suffix("ies") {
        return format!("{stem}y");
    }
    // Sibilant plurals take "es"; stripping only the "s" would leave a stray
    // "e" ("addresse", "boxe").
    if table.ends_with("sses") || table.ends_with("xes") {
        if let Some(stem) = table.strip_suffix("es") {
            return stem.to_string();
        }
    }
    // A trailing double "s" belongs to the word itself, not a plural marker.
    if table.ends_with("ss") {
        return table.to_string();
    }
    table.strip_suffix('s').unwrap_or(table).to_string()
}