use std::marker::PhantomData;
use crate::dialect::{
render_count, render_delete, render_exists, render_select, render_update, Dialect,
};
use crate::error::OrmError;
use crate::executor::Executor;
use crate::model::{FromRow, Model};
use crate::query::ast::{
Cte, CteQuery, JoinKind, LockClause, LockStrength, LockWait, OrderTerm, SelectItem,
SelectStatement, UnionStatement, WithClause,
};
use crate::query::column::Column;
use crate::query::expr::{BinaryOp, Expr};
use crate::query::projection::ExprTuple;
use crate::query::write::{Assignment, DeleteStatement, UpdateStatement};
use crate::value::Value;
/// Builder for queries against model `M`, accumulating a [`SelectStatement`].
pub struct QuerySet<M: Model> {
/// The SELECT statement being built up by the builder methods.
statement: SelectStatement,
/// `true` while the scope filter installed by `new` (for models that declare
/// `M::DELETED_AT`) is still present; `with_deleted` and `only_deleted`
/// remove it from index 0 of `statement.filters` and clear this flag.
scope_active: bool,
/// Associates the builder with `M` without storing a value of that type.
_marker: PhantomData<fn() -> M>,
}
/// One page of results, as produced by `QuerySet::paginate` and
/// `QuerySet::paginate_as`.
#[derive(Debug, Clone)]
pub struct Page<T> {
/// Rows decoded for this page.
pub items: Vec<T>,
/// Row count reported by the count query for the whole result set.
pub total: i64,
/// Page number of this page, starting at 1, after clamping to `1..=pages`.
pub page: u64,
/// Page size that was applied (at least 1).
pub page_size: u64,
/// Total number of pages; 1 when the result set is empty.
pub pages: u64,
}
impl<M: Model> QuerySet<M> {
/// Creates a query set that selects every column listed in `M::COLUMNS`
/// from `M::TABLE`.
///
/// If the model declares `M::DELETED_AT`, a scope filter on that column is
/// pushed onto the statement's filters and `scope_active` is set.
pub fn new() -> Self {
    let columns = M::COLUMNS.iter().map(|meta| SelectItem::Column {
        table: M::TABLE,
        column: meta.name,
    });
    let mut statement = SelectStatement::new(M::TABLE, columns.collect());

    let scope_active = match M::DELETED_AT {
        Some(deleted_at) => {
            statement
                .filters
                .push(scope_filter(M::TABLE, deleted_at, false));
            true
        }
        None => false,
    };

    Self {
        statement,
        scope_active,
        _marker: PhantomData,
    }
}
/// Drops the scope filter installed by `new`, if it is still active.
///
/// Does nothing when the scope was never installed or was already removed.
pub fn with_deleted(mut self) -> Self {
    if !self.scope_active {
        return self;
    }
    // The scope filter is removed from the front of the filter list.
    self.statement.filters.remove(0);
    self.scope_active = false;
    self
}
/// Replaces the active scope filter (if any) with the inverse scope filter
/// on `M::DELETED_AT`.
///
/// Models without `M::DELETED_AT` are returned unchanged.
pub fn only_deleted(mut self) -> Self {
    let Some(deleted_at) = M::DELETED_AT else {
        return self;
    };
    if self.scope_active {
        // Remove the scope filter from the front of the filter list.
        self.statement.filters.remove(0);
        self.scope_active = false;
    }
    let inverse_scope = scope_filter(M::TABLE, deleted_at, true);
    self.statement.filters.push(inverse_scope);
    self
}
/// Appends `predicate` to the statement's filter list.
pub fn filter(mut self, predicate: Expr) -> Self {
    let filters = &mut self.statement.filters;
    filters.push(predicate);
    self
}
/// Appends an `Expr::Raw` filter made of the given SQL text and the bound
/// values obtained by calling `to_value` on each item of `params`.
pub fn filter_raw<V, I>(mut self, sql: impl Into<String>, params: I) -> Self
where
    V: crate::value::BindValue,
    I: IntoIterator<Item = V>,
{
    let bound = params.into_iter().map(|param| param.to_value()).collect();
    let raw = Expr::Raw {
        sql: sql.into(),
        params: bound,
    };
    self.statement.filters.push(raw);
    self
}
/// Appends one filter built by `Expr::any` from `predicates`.
pub fn filter_any(mut self, predicates: impl IntoIterator<Item = Expr>) -> Self {
    let combined = Expr::any(predicates);
    self.statement.filters.push(combined);
    self
}
/// Appends one filter built by `Expr::all` from `predicates`.
pub fn filter_all(mut self, predicates: impl IntoIterator<Item = Expr>) -> Self {
    let combined = Expr::all(predicates);
    self.statement.filters.push(combined);
    self
}
/// Appends the negation of `predicate`, built with `Expr::not`.
pub fn filter_not(mut self, predicate: Expr) -> Self {
    let negated = Expr::not(predicate);
    self.statement.filters.push(negated);
    self
}
/// Appends the join node produced by `relation.join_node()`.
pub fn join<C>(mut self, relation: crate::relation::Relation<M, C>) -> Self {
    let node = relation.join_node();
    self.statement.joins.push(node);
    self
}
/// Appends the relation's join node with kind `JoinKind::Left`.
pub fn left_join<C>(mut self, relation: crate::relation::Relation<M, C>) -> Self {
    let node = relation.join_node_with_kind(JoinKind::Left);
    self.statement.joins.push(node);
    self
}
/// Appends the relation's join node with kind `JoinKind::Right`.
pub fn right_join<C>(mut self, relation: crate::relation::Relation<M, C>) -> Self {
    let node = relation.join_node_with_kind(JoinKind::Right);
    self.statement.joins.push(node);
    self
}
/// Appends the relation's join node with kind `JoinKind::Full`.
///
/// Executing methods that call `validate_for_dialect` reject this join on
/// dialects whose `supports_full_join()` is false.
pub fn full_join<C>(mut self, relation: crate::relation::Relation<M, C>) -> Self {
    let node = relation.join_node_with_kind(JoinKind::Full);
    self.statement.joins.push(node);
    self
}
/// Appends a `JoinKind::Cross` join against `C::TABLE`, with no alias and
/// all join-column fields left empty.
pub fn cross_join<C: crate::model::Model>(mut self) -> Self {
    use crate::query::ast::Join;

    let node = Join {
        kind: JoinKind::Cross,
        table: C::TABLE,
        alias: None,
        left_table: "",
        left_column: "",
        right_table: "",
        right_column: "",
    };
    self.statement.joins.push(node);
    self
}
pub fn self_join(
self,
alias: &'static str,
base_column: &'static str,
alias_column: &'static str,
) -> Self {
self.self_join_with_kind(JoinKind::Inner, alias, base_column, alias_column)
}
pub fn self_left_join(
self,
alias: &'static str,
base_column: &'static str,
alias_column: &'static str,
) -> Self {
self.self_join_with_kind(JoinKind::Left, alias, base_column, alias_column)
}
/// Appends a join of `M::TABLE` onto itself.
///
/// The join node records `M::TABLE` as both the joined table and the left
/// table, `alias` as the alias and the right table, and the two column
/// names as the left and right columns respectively.
fn self_join_with_kind(
    mut self,
    kind: JoinKind,
    alias: &'static str,
    base_column: &'static str,
    alias_column: &'static str,
) -> Self {
    use crate::query::ast::Join;

    let node = Join {
        kind,
        table: M::TABLE,
        alias: Some(alias),
        left_table: M::TABLE,
        left_column: base_column,
        right_table: alias,
        right_column: alias_column,
    };
    self.statement.joins.push(node);
    self
}
/// Wraps this query set in a `Preloader` and registers `relation` on it.
pub fn preload<C: Model>(
    self,
    relation: crate::relation::Relation<M, C>,
) -> crate::preload::Preloader<M> {
    let loader = crate::preload::Preloader::new(self);
    loader.preload(relation)
}
/// Appends `term` to the statement's ordering terms.
pub fn order_by(mut self, term: OrderTerm) -> Self {
    let ordering = &mut self.statement.order_by;
    ordering.push(term);
    self
}
/// Sets the statement's limit, replacing any earlier value.
pub fn limit(mut self, limit: u64) -> Self {
    let statement = &mut self.statement;
    statement.limit = Some(limit);
    self
}
/// Sets the statement's offset, replacing any earlier value.
pub fn offset(mut self, offset: u64) -> Self {
    let statement = &mut self.statement;
    statement.offset = Some(offset);
    self
}
/// Turns on the statement's `distinct` flag.
pub fn distinct(mut self) -> Self {
    let statement = &mut self.statement;
    statement.distinct = true;
    self
}
/// Replaces the statement's `distinct_on` expressions with those from `group`.
///
/// Executing methods that call `validate_for_dialect` reject a non-empty
/// list on dialects whose `supports_distinct_on()` is false.
pub fn distinct_on<G: ExprTuple>(mut self, group: G) -> Self {
    let exprs = group.into_exprs();
    self.statement.distinct_on = exprs;
    self
}
pub fn none(mut self) -> Self {
self.statement.filters.push(Expr::binary(
Expr::value(crate::value::Value::Int(0)),
BinaryOp::Eq,
Expr::value(crate::value::Value::Int(1)),
));
self
}
/// Replaces the statement's lock clause with a fresh one of strength
/// `LockStrength::Update`.
pub fn for_update(mut self) -> Self {
    let clause = LockClause::new(LockStrength::Update);
    self.statement.lock = Some(clause);
    self
}
/// Replaces the statement's lock clause with a fresh one of strength
/// `LockStrength::Share`.
pub fn for_share(mut self) -> Self {
    let clause = LockClause::new(LockStrength::Share);
    self.statement.lock = Some(clause);
    self
}
/// Sets the lock clause's wait mode to `LockWait::SkipLocked`.
///
/// If no lock clause exists yet, `lock_mut` creates one with strength
/// `LockStrength::Update` first.
pub fn skip_locked(mut self) -> Self {
    let clause = self.lock_mut();
    clause.wait = LockWait::SkipLocked;
    self
}
/// Sets the lock clause's wait mode to `LockWait::NoWait`.
///
/// If no lock clause exists yet, `lock_mut` creates one with strength
/// `LockStrength::Update` first.
pub fn nowait(mut self) -> Self {
    let clause = self.lock_mut();
    clause.wait = LockWait::NoWait;
    self
}
/// Replaces the lock clause's `of` table list with a copy of `tables`.
///
/// If no lock clause exists yet, `lock_mut` creates one with strength
/// `LockStrength::Update` first.
pub fn lock_of(mut self, tables: &[&'static str]) -> Self {
    let clause = self.lock_mut();
    clause.of = Vec::from(tables);
    self
}
/// Returns the statement's lock clause for mutation, inserting one of
/// strength `LockStrength::Update` when none is set.
fn lock_mut(&mut self) -> &mut LockClause {
    let slot = &mut self.statement.lock;
    slot.get_or_insert_with(|| LockClause::new(LockStrength::Update))
}
/// Appends a keyset filter for rows after `cursor`, built from the ordering
/// terms registered so far.
///
/// # Panics
///
/// `keyset_predicate` asserts that at least one `order_by` term exists and
/// that `cursor` has one value per term.
pub fn keyset_after(mut self, cursor: Vec<Value>) -> Self {
    let ordering = &self.statement.order_by;
    let boundary = keyset_predicate(ordering, &cursor, true);
    self.statement.filters.push(boundary);
    self
}
/// Appends a keyset filter for rows before `cursor`, built from the ordering
/// terms registered so far.
///
/// # Panics
///
/// `keyset_predicate` asserts that at least one `order_by` term exists and
/// that `cursor` has one value per term.
pub fn keyset_before(mut self, cursor: Vec<Value>) -> Self {
    let ordering = &self.statement.order_by;
    let boundary = keyset_predicate(ordering, &cursor, false);
    self.statement.filters.push(boundary);
    self
}
/// Replaces the statement's WITH clause with a non-recursive one holding
/// the given `(name, query)` pairs; each CTE gets no explicit column list.
pub fn with(mut self, ctes: impl IntoIterator<Item = (&'static str, CteQuery)>) -> Self {
    let ctes = ctes
        .into_iter()
        .map(|(name, query)| Cte {
            name,
            columns: None,
            query,
        })
        .collect();
    self.statement.with = Some(WithClause {
        recursive: false,
        ctes,
    });
    self
}
/// Replaces the statement's WITH clause with a recursive one.
///
/// Each `(name, query)` pair becomes a CTE whose body is the union query's
/// statement wrapped in `CteQuery::Union`, with no explicit column list.
pub fn with_recursive(
    mut self,
    ctes: impl IntoIterator<Item = (&'static str, crate::query::union::UnionQuery<M>)>,
) -> Self {
    let ctes = ctes
        .into_iter()
        .map(|(name, query)| {
            let body = CteQuery::Union(Box::new(query.into_statement()));
            Cte {
                name,
                columns: None,
                query: body,
            }
        })
        .collect();
    self.statement.with = Some(WithClause {
        recursive: true,
        ctes,
    });
    self
}
/// Replaces the statement's projection with the select items from `projection`.
pub fn select<P: crate::query::projection::Projection>(mut self, projection: P) -> Self {
    let items = projection.into_select_items();
    self.statement.projection = items;
    self
}
/// Replaces the statement's GROUP BY expressions with those from `group`.
pub fn group_by<G: crate::query::projection::ExprTuple>(mut self, group: G) -> Self {
    let exprs = group.into_exprs();
    self.statement.group_by = exprs;
    self
}
/// Sets the statement's HAVING predicate, replacing any earlier one.
pub fn having(mut self, predicate: Expr) -> Self {
    let statement = &mut self.statement;
    statement.having = Some(predicate);
    self
}
pub fn statement(&self) -> &SelectStatement {
&self.statement
}
pub fn into_statement(self) -> SelectStatement {
self.statement
}
pub fn to_subquery(self) -> crate::query::expr::Expr {
crate::query::expr::Expr::subquery(self.statement)
}
/// Combines this query set with `other` into a `UnionQuery`, passing `false`
/// as the constructor's third argument.
// NOTE(review): the third argument is presumably the "ALL" flag, since
// `union_all` passes `true` — confirm against `UnionQuery::new`.
pub fn union(self, other: QuerySet<M>) -> crate::query::union::UnionQuery<M> {
    use crate::query::union::UnionQuery;

    UnionQuery::new(self, other, false)
}
/// Combines this query set with `other` into a `UnionQuery`, passing `true`
/// as the constructor's third argument.
// NOTE(review): the third argument is presumably the "ALL" flag, since
// `union` passes `false` — confirm against `UnionQuery::new`.
pub fn union_all(self, other: QuerySet<M>) -> crate::query::union::UnionQuery<M> {
    use crate::query::union::UnionQuery;

    UnionQuery::new(self, other, true)
}
pub async fn all(self, executor: impl Executor) -> crate::Result<Vec<M>> {
self.all_as::<M>(executor).await
}
/// Runs the query and decodes every row as `T`.
///
/// # Errors
///
/// Fails if `validate_for_dialect` rejects the statement, if the executor
/// fails, or if any row cannot be decoded (the first failure is returned).
pub async fn all_as<T: FromRow>(self, executor: impl Executor) -> crate::Result<Vec<T>> {
    validate_for_dialect(executor.dialect(), &self.statement)?;
    let (sql, params) = render_select(executor.dialect(), &self.statement);
    let rows = executor.fetch_all(sql, params).await?;

    let mut decoded = Vec::new();
    for row in rows.iter() {
        decoded.push(T::from_row(row)?);
    }
    Ok(decoded)
}
/// Runs the query with the limit forced to 1 and returns the first row, if any.
///
/// # Errors
///
/// Fails if `validate_for_dialect` rejects the statement, if the executor
/// fails, or if the row cannot be decoded.
pub async fn first<E: Executor>(mut self, executor: E) -> crate::Result<Option<M>> {
    self.statement.limit = Some(1);
    validate_for_dialect(executor.dialect(), &self.statement)?;
    let (sql, params) = render_select(executor.dialect(), &self.statement);
    let rows = executor.fetch_all(sql, params).await?;
    rows.first().map(M::from_row).transpose()
}
/// Runs the query expecting exactly one matching row.
///
/// The limit is forced to 2 so a second match can be detected without
/// fetching the whole result set.
///
/// # Errors
///
/// Returns a not-found error when nothing matches, a multiple-found error
/// when more than one row matches, and otherwise propagates validation,
/// executor and decoding failures.
pub async fn one<E: Executor>(mut self, executor: E) -> crate::Result<M> {
    self.statement.limit = Some(2);
    validate_for_dialect(executor.dialect(), &self.statement)?;
    let (sql, params) = render_select(executor.dialect(), &self.statement);
    let rows = executor.fetch_all(sql, params).await?;

    if rows.is_empty() {
        return Err(OrmError::not_found(format!(
            "no row in `{}` matched the query",
            M::TABLE
        )));
    }
    if rows.len() > 1 {
        return Err(OrmError::multiple_found(format!(
            "more than one row in `{}` matched the query",
            M::TABLE
        )));
    }
    M::from_row(&rows[0])
}
/// Runs the query expecting at most one matching row.
///
/// The limit is forced to 2 so a second match can be detected without
/// fetching the whole result set.
///
/// # Errors
///
/// Returns a multiple-found error when more than one row matches, and
/// otherwise propagates validation, executor and decoding failures.
pub async fn one_or_none<E: Executor>(
    mut self,
    executor: E,
) -> crate::Result<Option<M>> {
    self.statement.limit = Some(2);
    validate_for_dialect(executor.dialect(), &self.statement)?;
    let (sql, params) = render_select(executor.dialect(), &self.statement);
    let rows = executor.fetch_all(sql, params).await?;

    if rows.is_empty() {
        return Ok(None);
    }
    if rows.len() > 1 {
        return Err(OrmError::multiple_found(format!(
            "more than one row in `{}` matched the query",
            M::TABLE
        )));
    }
    M::from_row(&rows[0]).map(Some)
}
/// Runs the count form of the query (via `render_count`) and returns the
/// value in the first column of the first row, or 0 if no row comes back.
pub async fn count<E: Executor>(self, executor: E) -> crate::Result<i64> {
    let (sql, params) = render_count(executor.dialect(), &self.statement);
    let rows = executor.fetch_all(sql, params).await?;
    rows.first()
        .map_or(Ok(0), |row| row.get_index::<i64>(0))
}
pub async fn paginate<E: Executor>(
self,
executor: E,
page: u64,
page_size: u64,
) -> crate::Result<Page<M>> {
self.paginate_as::<M, E>(executor, page, page_size).await
}
pub async fn paginate_as<T: FromRow, E: Executor>(
self,
executor: E,
page: u64,
page_size: u64,
) -> crate::Result<Page<T>> {
let page = page.max(1);
let page_size = page_size.max(1);
let (count_sql, count_params) = render_count(executor.dialect(), &self.statement);
let rows = executor.fetch_all(count_sql, count_params).await?;
let total: i64 = match rows.first() {
Some(row) => row.get_index::<i64>(0)?,
None => 0,
};
let pages = if total == 0 {
1
} else {
(total as u64).div_ceil(page_size)
};
let page = page.min(pages);
let offset = (page - 1) * page_size;
let mut statement = self.statement;
statement.limit = Some(page_size);
statement.offset = Some(offset);
let (sql, params) = render_select(executor.dialect(), &statement);
let rows = executor.fetch_all(sql, params).await?;
let items: Vec<T> = rows.iter().map(T::from_row).collect::<crate::Result<_>>()?;
Ok(Page {
items,
total,
page,
page_size,
pages,
})
}
/// Runs the exists form of the query (via `render_exists`) and returns the
/// boolean in the first column of the first row, or `false` if no row
/// comes back.
pub async fn exists<E: Executor>(self, executor: E) -> crate::Result<bool> {
    let (sql, params) = render_exists(executor.dialect(), &self.statement);
    let rows = executor.fetch_all(sql, params).await?;
    rows.first()
        .map_or(Ok(false), |row| row.get_index::<bool>(0))
}
pub async fn chunk<E: Executor>(
self,
executor: E,
size: u64,
) -> crate::Result<Vec<Vec<M>>> {
let size = size.max(1);
let mut batches = Vec::new();
let mut offset = 0u64;
loop {
let mut batch_stmt = self.statement.clone();
batch_stmt.limit = Some(size);
batch_stmt.offset = Some(offset);
let (sql, params) = render_select(executor.dialect(), &batch_stmt);
let rows = executor.fetch_all(sql, params).await?;
if rows.is_empty() {
break;
}
let batch: Vec<M> = rows.iter().map(M::from_row).collect::<crate::Result<_>>()?;
let batch_len = batch.len() as u64;
batches.push(batch);
if batch_len < size {
break;
}
offset += size;
}
Ok(batches)
}
/// Runs the query selecting only `column` and returns that column's value
/// from every row.
///
/// # Errors
///
/// Fails if `validate_for_dialect` rejects the statement, if the executor
/// fails, or if a value cannot be read as `T`.
pub async fn pluck<T: crate::value::FromValue, E: Executor>(
    mut self,
    executor: E,
    column: Column<M, T>,
) -> crate::Result<Vec<T>> {
    let only_column = SelectItem::Column {
        table: column.table(),
        column: column.name(),
    };
    self.statement.projection = vec![only_column];

    validate_for_dialect(executor.dialect(), &self.statement)?;
    let (sql, params) = render_select(executor.dialect(), &self.statement);
    let rows = executor.fetch_all(sql, params).await?;

    let mut values = Vec::new();
    for row in rows.iter() {
        values.push(row.get::<T>(column.name())?);
    }
    Ok(values)
}
/// Issues an UPDATE applying `assignments` to the rows selected by this
/// query set's table and filters, returning the number of rows affected.
///
/// Only the statement's table and filters are carried over to the UPDATE.
pub async fn update<E: Executor>(
    self,
    executor: E,
    assignments: impl IntoIterator<Item = Assignment>,
) -> crate::Result<u64> {
    let table = self.statement.table;
    let assignments = assignments.into_iter().collect();
    let filters = self.statement.filters;
    let statement = UpdateStatement {
        table,
        assignments,
        filters,
        returning: Vec::new(),
    };
    let (sql, params) = render_update(executor.dialect(), &statement);
    let outcome = executor.execute(sql, params).await?;
    Ok(outcome.rows_affected)
}
/// Issues an UPDATE applying `assignments` to the rows selected by this
/// query set's table and filters, and decodes the returned rows as `M`.
///
/// The statement's `returning` list is every column name in `M::COLUMNS`.
pub async fn update_returning<E: Executor>(
    self,
    executor: E,
    assignments: impl IntoIterator<Item = Assignment>,
) -> crate::Result<Vec<M>> {
    let returning = M::COLUMNS.iter().map(|meta| meta.name).collect();
    let table = self.statement.table;
    let assignments = assignments.into_iter().collect();
    let filters = self.statement.filters;
    let statement = UpdateStatement {
        table,
        assignments,
        filters,
        returning,
    };
    let (sql, params) = render_update(executor.dialect(), &statement);
    let rows = executor.fetch_all(sql, params).await?;

    let mut updated = Vec::new();
    for row in rows.iter() {
        updated.push(M::from_row(row)?);
    }
    Ok(updated)
}
/// Deletes the rows selected by this query set's table and filters and
/// returns the number of rows affected.
///
/// For models declaring `M::DELETED_AT` this issues an UPDATE setting that
/// column to `CURRENT_TIMESTAMP`; otherwise it issues a DELETE.
pub async fn delete<E: Executor>(self, executor: E) -> crate::Result<u64> {
    let table = self.statement.table;
    let outcome = match M::DELETED_AT {
        Some(column) => {
            let soft = UpdateStatement {
                table,
                assignments: vec![Assignment::new(column, Expr::raw("CURRENT_TIMESTAMP"))],
                filters: self.statement.filters,
                returning: Vec::new(),
            };
            let (sql, params) = render_update(executor.dialect(), &soft);
            executor.execute(sql, params).await?
        }
        None => {
            let hard = DeleteStatement {
                table,
                filters: self.statement.filters,
                returning: Vec::new(),
            };
            let (sql, params) = render_delete(executor.dialect(), &hard);
            executor.execute(sql, params).await?
        }
    };
    Ok(outcome.rows_affected)
}
/// Issues a DELETE for the rows selected by this query set's table and
/// filters, regardless of `M::DELETED_AT`, and returns the rows affected.
pub async fn hard_delete<E: Executor>(self, executor: E) -> crate::Result<u64> {
    let table = self.statement.table;
    let filters = self.statement.filters;
    let statement = DeleteStatement {
        table,
        filters,
        returning: Vec::new(),
    };
    let (sql, params) = render_delete(executor.dialect(), &statement);
    let outcome = executor.execute(sql, params).await?;
    Ok(outcome.rows_affected)
}
pub async fn restore<E: Executor>(self, executor: E) -> crate::Result<u64> {
let Some(column) = M::DELETED_AT else {
return Ok(0);
};
let statement = UpdateStatement {
table: self.statement.table,
assignments: vec![Assignment::new(column, Expr::value(Value::Null))],
filters: self.statement.filters,
returning: Vec::new(),
};
let (sql, params) = render_update(executor.dialect(), &statement);
Ok(executor.execute(sql, params).await?.rows_affected)
}
pub async fn delete_returning<E: Executor>(self, executor: E) -> crate::Result<Vec<M>> {
let returning = M::COLUMNS.iter().map(|c| c.name).collect();
let statement = DeleteStatement {
table: self.statement.table,
filters: self.statement.filters,
returning,
};
let (sql, params) = render_delete(executor.dialect(), &statement);
let rows = executor.fetch_all(sql, params).await?;
rows.iter().map(M::from_row).collect()
}
}
impl<M: Model> Default for QuerySet<M> {
fn default() -> Self {
Self::new()
}
}
/// Checks that `statement` only uses features the dialect reports support for.
///
/// Rejected, in this order: a `JoinKind::Full` join when
/// `supports_full_join()` is false; a non-empty `distinct_on` when
/// `supports_distinct_on()` is false; a lock clause whose `uses_modifiers()`
/// is true when `supports_lock_modifiers()` is false. Every CTE body in the
/// WITH clause is then validated recursively.
///
/// # Errors
///
/// Returns an `OrmError::query` naming the unsupported feature and dialect.
pub(crate) fn validate_for_dialect(
    dialect: &dyn Dialect,
    statement: &SelectStatement,
) -> crate::Result<()> {
    if !dialect.supports_full_join() {
        let has_full_join = statement
            .joins
            .iter()
            .any(|join| join.kind == JoinKind::Full);
        if has_full_join {
            return Err(OrmError::query(format!(
                "FULL OUTER JOIN is not supported by the `{}` dialect",
                dialect.name()
            )));
        }
    }

    if !dialect.supports_distinct_on() {
        if !statement.distinct_on.is_empty() {
            return Err(OrmError::query(format!(
                "DISTINCT ON is not supported by the `{}` dialect",
                dialect.name()
            )));
        }
    }

    if !dialect.supports_lock_modifiers() {
        let uses_modifiers = statement
            .lock
            .as_ref()
            .map_or(false, |lock| lock.uses_modifiers());
        if uses_modifiers {
            return Err(OrmError::query(format!(
                "FOR SHARE / SKIP LOCKED / NOWAIT / OF is not supported by the `{}` dialect",
                dialect.name()
            )));
        }
    }

    let Some(with) = &statement.with else {
        return Ok(());
    };
    for cte in &with.ctes {
        match &cte.query {
            CteQuery::Select(select) => validate_for_dialect(dialect, select)?,
            CteQuery::Union(union) => validate_union_for_dialect(dialect, union)?,
        }
    }
    Ok(())
}
/// Builds the scope predicate on `table.column` by passing the column
/// expression and the `deleted` flag to `Expr::is_null`.
///
/// `QuerySet::new` uses `deleted = false` for the default scope;
/// `only_deleted` uses `deleted = true` for its inverse.
fn scope_filter(table: &'static str, column: &'static str, deleted: bool) -> Expr {
    let target = Expr::column(table, column);
    Expr::is_null(target, deleted)
}
/// Builds the keyset (seek) predicate for `cursor` under the given ordering.
///
/// For each ordering position `i` one branch is produced: every earlier
/// term equals its cursor value, and term `i` is strictly beyond its cursor
/// value in the direction of travel. The branches are combined with
/// `Expr::any`. The comparison is `>` when the term's direction matches the
/// direction of travel (ascending and `after`, or descending and not
/// `after`), and `<` otherwise.
///
/// # Panics
///
/// Panics if `order` is empty or if `cursor` and `order` differ in length.
fn keyset_predicate(order: &[OrderTerm], cursor: &[Value], after: bool) -> Expr {
    assert!(
        !order.is_empty(),
        "keyset pagination requires at least one `order_by` term"
    );
    assert_eq!(
        order.len(),
        cursor.len(),
        "keyset cursor length ({}) must match the number of `order_by` terms ({})",
        cursor.len(),
        order.len(),
    );

    let branches: Vec<Expr> = order
        .iter()
        .zip(cursor)
        .enumerate()
        .map(|(position, (term, value))| {
            // Ties on every earlier ordering term…
            let mut parts: Vec<Expr> = order[..position]
                .iter()
                .zip(&cursor[..position])
                .map(|(earlier, pinned)| {
                    Expr::binary(
                        earlier.expr.clone(),
                        BinaryOp::Eq,
                        Expr::value(pinned.clone()),
                    )
                })
                .collect();
            // …followed by a strict step past the cursor on this term.
            let step = if term.descending != after {
                BinaryOp::Gt
            } else {
                BinaryOp::Lt
            };
            parts.push(Expr::binary(
                term.expr.clone(),
                step,
                Expr::value(value.clone()),
            ));
            Expr::all(parts)
        })
        .collect();

    Expr::any(branches)
}
/// Validates every branch of `union` with `validate_for_dialect`: the first
/// statement, then each statement in `rest`, stopping at the first error.
pub(crate) fn validate_union_for_dialect(
    dialect: &dyn Dialect,
    union: &UnionStatement,
) -> crate::Result<()> {
    validate_for_dialect(dialect, &union.first)?;
    union
        .rest
        .iter()
        .try_for_each(|(_, branch)| validate_for_dialect(dialect, branch))
}