use smallvec::SmallVec;
use std::collections::HashSet;
use std::marker::PhantomData;
use sqlx::{
postgres::PgRow,
Execute,
Executor,
FromRow,
Postgres,
Row as _,
};
use sqlxo_traits::{
AliasedColumn,
FullTextSearchConfig,
FullTextSearchable,
GetDeleteMarker,
JoinIdentifiable,
JoinKind,
JoinNavigationModel,
JoinPath,
PrimaryKey,
QueryContext,
SqlWrite,
};
use crate::{
and,
blocks::{
BuildableFilter,
BuildableJoin,
BuildablePage,
BuildableSort,
Expression,
Page,
Pagination,
QualifiedColumn,
ReadHead,
SelectProjection,
SelectType,
SortOrder,
SqlWriter,
},
order_by,
select::{
AggregateFunction,
AggregateSelection,
GroupByList,
HavingList,
HavingPredicate,
SelectionColumn,
SelectionEntry,
SelectionList,
},
Buildable,
ExecutablePlan,
FetchablePlan,
Planable,
PrimaryKeyExpression,
QueryBuilder,
Result,
};
/// Umbrella trait for read-query builders: a [`Buildable`] whose plan is
/// [`Planable`], combined with the filter, join, sort and pagination builder
/// capabilities.
///
/// `Row` is the type each result row is decoded into; it defaults to the
/// context's model and must be decodable from a Postgres row.
// NOTE(review): the `dead_code` allowance suggests this trait is not yet
// referenced anywhere — confirm before removing or relying on it.
#[allow(dead_code)]
pub trait BuildableReadQuery<C, Row = <C as QueryContext>::Model>:
Buildable<C, Row = Row, Plan: Planable<C, Row>>
+ BuildableFilter<C>
+ BuildableJoin<C>
+ BuildableSort<C>
+ BuildablePage<C>
where
C: QueryContext,
Row: Send + Sync + Unpin + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>,
{
}
/// Object-safe view of a full-text-search configuration. It lets a
/// [`ReadQueryPlan`] hold a search without being generic over the searchable
/// model type.
pub(crate) trait DynFullTextSearchPlan: Send + Sync {
/// Writes the search predicate that is added to the `WHERE` clause.
///
/// `base_alias` is the alias of the root table. `joins` are the join paths
/// active in the query.
fn write_condition(
&self,
w: &mut SqlWriter,
base_alias: &str,
joins: Option<&[JoinPath]>,
);
/// Writes the relevance-score expression. Callers in this file append
/// ` DESC` to it to order results by rank.
fn write_rank_expr(
&self,
w: &mut SqlWriter,
base_alias: &str,
joins: Option<&[JoinPath]>,
);
/// Whether results should be ordered by rank when no explicit sort is set.
fn include_rank(&self) -> bool;
}
/// Adapter pairing a model's full-text-search configuration with the model
/// type `M` whose associated functions render it to SQL.
struct ModelFullTextSearchPlan<M>
where
M: FullTextSearchable,
{
// Configuration forwarded to `M`'s search-rendering functions.
config: M::FullTextSearchConfig,
// Ties the struct to `M` without storing an `M` value.
_marker: PhantomData<M>,
}
impl<M> ModelFullTextSearchPlan<M>
where
    M: FullTextSearchable,
{
    /// Wraps `config` so it can later be stored behind
    /// `dyn DynFullTextSearchPlan`.
    fn new(config: M::FullTextSearchConfig) -> Self {
        let _marker = PhantomData;
        Self { config, _marker }
    }
}
/// Forwards each dynamic call to the statically-typed `FullTextSearchable`
/// functions of `M`, passing along the stored configuration.
impl<M> DynFullTextSearchPlan for ModelFullTextSearchPlan<M>
where
    M: FullTextSearchable + Send + Sync + 'static,
    M::FullTextSearchConfig: Send + Sync,
{
    fn write_condition(
        &self,
        w: &mut SqlWriter,
        base_alias: &str,
        joins: Option<&[JoinPath]>,
    ) {
        let config = &self.config;
        <M as FullTextSearchable>::write_search_predicate(
            w, base_alias, joins, config,
        );
    }

    fn write_rank_expr(
        &self,
        w: &mut SqlWriter,
        base_alias: &str,
        joins: Option<&[JoinPath]>,
    ) {
        let config = &self.config;
        <M as FullTextSearchable>::write_search_score(
            w, base_alias, joins, config,
        );
    }

    fn include_rank(&self) -> bool {
        let config = &self.config;
        config.include_rank()
    }
}
/// A resolved `SELECT` over `table`, produced by [`ReadQueryBuilder`]'s
/// `build`. It can be rendered to SQL and executed.
///
/// `Row` is the type each result row is decoded into: the context's model by
/// default, or a projection type.
pub struct ReadQueryPlan<'a, C: QueryContext, Row = <C as QueryContext>::Model>
{
// De-duplicated join paths (automatic plus explicit); `None` when empty.
pub(crate) joins: Option<Vec<JoinPath>>,
// Optional `WHERE` expression.
pub(crate) where_expr: Option<Expression<C::Query>>,
// Optional `ORDER BY`.
pub(crate) sort_expr: Option<SortOrder<C::Sort>>,
// Optional page (`LIMIT`/`OFFSET`).
pub(crate) pagination: Option<Pagination>,
// Root table name, also used verbatim as its alias when qualifying columns.
pub(crate) table: &'a str,
// When `false` and a delete marker exists, only rows whose marker IS NULL match.
pub(crate) include_deleted: bool,
// Soft-delete marker column of the root table, if any.
pub(crate) delete_marker_field: Option<&'a str>,
// Explicit column/aggregate projection; `None` selects the whole model.
pub(crate) selection: Option<SelectionList<Row, SelectionEntry>>,
// `GROUP BY` columns.
pub(crate) group_by: Option<Vec<SelectionColumn>>,
// `HAVING` predicates; moved into `aggregate_filter` when there is no
// selection and no grouping.
pub(crate) having: Option<Vec<HavingPredicate>>,
// Optional full-text search, erased over the model type.
pub(crate) full_text_search: Option<Box<dyn DynFullTextSearchPlan>>,
// `HAVING` rewritten as a primary-key `IN (subquery)` filter.
pub(crate) aggregate_filter: Option<AggregateFilter>,
row: PhantomData<Row>,
}
/// A `HAVING` clause rewritten as a filter on the root table's primary key.
/// The outer query keeps rows whose key is `IN` a grouped subquery that
/// applies `predicates`.
#[derive(Clone)]
pub struct AggregateFilter {
/// Primary-key columns of the root table, used to project, group and match.
pub columns: SmallVec<[&'static str; 2]>,
/// Aggregate predicates evaluated in the subquery's `HAVING` clause.
pub predicates: Vec<HavingPredicate>,
}
/// Lists `(table, alias)` pairs for every table introduced by `joins`.
///
/// Aliases are built by concatenating each segment's `alias_segment` along
/// the path. A `through` (junction) table gets the path prefix so far plus
/// its own segment. Returns an empty list when `joins` is `None`.
fn build_alias_lookup(
    joins: Option<&[JoinPath]>,
) -> Vec<(&'static str, String)> {
    let mut lookup = Vec::new();
    for path in joins.unwrap_or_default() {
        let mut prefix = String::new();
        for segment in path.segments() {
            let descriptor = &segment.descriptor;
            if let Some(through) = descriptor.through {
                lookup.push((
                    through.table,
                    format!("{prefix}{}", through.alias_segment),
                ));
            }
            prefix.push_str(descriptor.alias_segment);
            lookup.push((descriptor.right_table, prefix.clone()));
        }
    }
    lookup
}
/// Returns `true` when every segment of `prefix` equals the segment at the
/// same position in `candidate`. Equal paths count as prefixes of each other.
fn path_is_prefix(prefix: &JoinPath, candidate: &JoinPath) -> bool {
    let (head, full) = (prefix.segments(), candidate.segments());
    head.len() <= full.len()
        && head.iter().zip(full.iter()).all(|(a, b)| a == b)
}
/// Normalises a list of join paths.
///
/// 1. Exact duplicates are dropped, keeping the first occurrence.
/// 2. Any path that is a strict prefix of another remaining path is dropped,
///    because the longer path already joins the same tables.
///
/// The relative order of the surviving paths is preserved.
fn merge_unique_join_paths(paths: Vec<JoinPath>) -> Vec<JoinPath> {
    let mut unique: Vec<JoinPath> = Vec::new();
    for path in paths {
        if !unique.contains(&path) {
            unique.push(path);
        }
    }
    let is_shadowed = |idx: usize, candidate: &JoinPath| {
        unique.iter().enumerate().any(|(other_idx, other)| {
            other_idx != idx
                && path_is_prefix(candidate, other)
                && candidate.len() < other.len()
        })
    };
    unique
        .iter()
        .enumerate()
        .filter(|(idx, candidate)| !is_shadowed(*idx, *candidate))
        .map(|(_, path)| path.clone())
        .collect()
}
/// Resolves the SQL alias under which `table` is reachable in the query.
///
/// Columns of the root table (`table == base_table`) use the root table name.
/// Any other table must appear exactly once in `aliases`.
///
/// # Panics
/// Panics when `table` is not joined at all, or is joined more than once
/// (ambiguous).
fn resolve_alias_for_table(
    table: &'static str,
    column: &'static str,
    base_table: &str,
    aliases: &[(&'static str, String)],
) -> String {
    if table == base_table {
        return base_table.to_string();
    }
    let mut candidates = aliases
        .iter()
        .filter_map(|(tbl, alias)| (*tbl == table).then_some(alias));
    match (candidates.next(), candidates.next()) {
        (Some(alias), None) => alias.clone(),
        (None, _) => panic!(
            "`take!` requested column `{table}.{column}` but `{table}` is not \
             part of the join set"
        ),
        (Some(_), Some(_)) => panic!(
            "`take!` requested column `{table}.{column}` but `{table}` is \
             joined multiple times; disambiguation is not implemented yet"
        ),
    }
}
/// Qualifies every selected column with the alias of the table it belongs to,
/// given the root table and the active join paths.
fn resolve_selection_columns(
    selection: &[SelectionColumn],
    base_table: &str,
    joins: Option<&[JoinPath]>,
) -> SmallVec<[QualifiedColumn; 4]> {
    let aliases = build_alias_lookup(joins);
    selection
        .iter()
        .map(|column| resolve_selection_column(column, base_table, &aliases))
        .collect()
}
/// Qualifies one selected column with its table's alias (see
/// `resolve_alias_for_table` for the panic conditions).
fn resolve_selection_column(
    column: &SelectionColumn,
    base_table: &str,
    aliases: &[(&'static str, String)],
) -> QualifiedColumn {
    let (table, name) = (column.table, column.column);
    QualifiedColumn {
        table_alias: resolve_alias_for_table(table, name, base_table, aliases),
        column: name,
    }
}
fn format_aggregate_expression(
selection: &AggregateSelection,
base_table: &str,
aliases: &[(&'static str, String)],
) -> String {
match selection.column {
Some(col) => {
let qualified = resolve_selection_column(&col, base_table, aliases);
match selection.function {
AggregateFunction::CountDistinct => format!(
r#"COUNT(DISTINCT "{}"."{}")"#,
qualified.table_alias, qualified.column
),
_ => format!(
r#"{}("{}"."{}")"#,
selection.function.sql_name(),
qualified.table_alias,
qualified.column
),
}
}
None => format!("{}(*)", selection.function.sql_name()),
}
}
/// Column alias given to the `idx`-th primary-key column when it is projected
/// as text for client-side de-duplication.
fn primary_key_text_alias(idx: usize) -> String {
    format!("__sqlxo_pk_{idx}")
}
/// Writes `<aggregate> <comparator> <bound value>` for one `HAVING`
/// predicate. The comparison value is bound by the predicate itself.
fn write_having_predicate(
    predicate: &HavingPredicate,
    writer: &mut SqlWriter,
    base_table: &str,
    aliases: &[(&'static str, String)],
) {
    let aggregate =
        format_aggregate_expression(&predicate.selection, base_table, aliases);
    for fragment in [
        aggregate.as_str(),
        " ",
        predicate.comparator.as_str(),
        " ",
    ] {
        writer.push(fragment);
    }
    predicate.bind_value(writer);
}
impl<'a, C, Row> ReadQueryPlan<'a, C, Row>
where
C: QueryContext,
C::Model: JoinNavigationModel,
{
/// Rewrites a bare `HAVING` (no selection, no `GROUP BY`) into an
/// [`AggregateFilter`] on the root table's primary key.
///
/// The rewrite only happens when the model has primary-key columns. When it
/// applies, `self.having` is moved out (an empty predicate list is simply
/// dropped).
fn compute_aggregate_filter(&mut self) {
    let eligible = self.having.is_some()
        && self.selection.is_none()
        && self.group_by.is_none();
    if !eligible {
        return;
    }
    let pk_columns = <C::Model as PrimaryKey>::PRIMARY_KEY;
    if pk_columns.is_empty() {
        return;
    }
    let predicates = match self.having.take() {
        Some(list) if !list.is_empty() => list,
        _ => return,
    };
    self.aggregate_filter = Some(AggregateFilter {
        columns: SmallVec::from_slice(pk_columns),
        predicates,
    });
}
/// Adds `WHERE <pk> IN (<grouped subquery>)` to the outer query.
///
/// A single key column is written bare. Composite keys are written as a row
/// value `(a, b, ...)`.
fn push_aggregate_filter_clause(
    &self,
    writer: &mut SqlWriter,
    filter: &AggregateFilter,
) {
    let table = self.table;
    let qualify = |col: &str| format!(r#""{}"."{}""#, table, col);
    writer.push_where_raw(|w| {
        match filter.columns.as_slice() {
            [single] => {
                w.push(&qualify(*single));
            }
            many => {
                w.push("(");
                for (position, col) in many.iter().enumerate() {
                    if position != 0 {
                        w.push(", ");
                    }
                    w.push(&qualify(*col));
                }
                w.push(")");
            }
        }
        w.push(" IN (");
        self.write_aggregate_subquery(w, filter);
        w.push(")");
    });
}
/// Writes the subquery used by an [`AggregateFilter`]:
/// `SELECT <pk cols> FROM <table> <joins> [WHERE ...] GROUP BY <pk cols>
/// HAVING ...`.
fn write_aggregate_subquery(
    &self,
    writer: &mut SqlWriter,
    filter: &AggregateFilter,
) {
    let table = self.table;
    writer.push("SELECT ");
    for (position, pk) in filter.columns.iter().enumerate() {
        if position != 0 {
            writer.push(", ");
        }
        writer.push(&format!(r#""{table}"."{pk}""#));
    }
    writer.push(" FROM ");
    writer.push(table);
    for path in self.joins.iter().flatten() {
        push_join_path_inline(writer.query_builder_mut(), path, table);
    }
    self.write_subquery_filters(writer);
    self.write_subquery_group_by(writer, filter);
    self.write_subquery_having(writer, filter);
}
/// Writes the subquery's `WHERE` clause by hand.
///
/// The clause combines, joined with `AND`: the soft-delete guard (unless
/// deleted rows are included), the plan's `where_expr`, and the full-text
/// search predicate. Nothing is written when none apply.
fn write_subquery_filters(&self, writer: &mut SqlWriter) {
    let mut wrote_condition = false;
    if let (false, Some(marker)) =
        (self.include_deleted, self.delete_marker_field)
    {
        writer.push(" WHERE ");
        writer.push(&format!(r#""{}"."{}" IS NULL"#, self.table, marker));
        wrote_condition = true;
    }
    if let Some(expr) = &self.where_expr {
        writer.push(if wrote_condition { " AND (" } else { " WHERE (" });
        expr.write(writer);
        writer.push(")");
        wrote_condition = true;
    }
    if let Some(fts) = &self.full_text_search {
        writer.push(if wrote_condition { " AND (" } else { " WHERE (" });
        fts.write_condition(writer, self.table, self.joins.as_deref());
        writer.push(")");
    }
}
/// Writes ` GROUP BY` over the root table's primary-key columns.
fn write_subquery_group_by(
    &self,
    writer: &mut SqlWriter,
    filter: &AggregateFilter,
) {
    let table = self.table;
    writer.push(" GROUP BY ");
    for (position, pk) in filter.columns.iter().enumerate() {
        if position != 0 {
            writer.push(", ");
        }
        writer.push(&format!(r#""{table}"."{pk}""#));
    }
}
/// Writes ` HAVING p1 AND p2 ...` for the filter's predicates. Nothing is
/// written when there are none.
fn write_subquery_having(
    &self,
    writer: &mut SqlWriter,
    filter: &AggregateFilter,
) {
    let Some((first, rest)) = filter.predicates.split_first() else {
        return;
    };
    let aliases = build_alias_lookup(self.joins.as_deref());
    writer.push(" HAVING ");
    write_having_predicate(first, writer, self.table, &aliases);
    for predicate in rest {
        writer.push(" AND ");
        write_having_predicate(predicate, writer, self.table, &aliases);
    }
}
/// Renders the complete query for `select_type`, applying the plan's own
/// pagination and sort order.
fn to_query_builder(
&self,
select_type: SelectType,
) -> sqlx::QueryBuilder<'static, Postgres> {
self.to_query_builder_with_options(select_type, self.pagination, true)
}
/// Renders the query with explicit control over pagination and sorting.
///
/// Count queries use this to drop both. The head receives the resolved
/// select type (selection or join extras applied). The body receives the
/// requested kind, for example to close an `EXISTS` wrapper.
fn to_query_builder_with_options(
    &self,
    select_type: SelectType,
    pagination: Option<Pagination>,
    include_sort: bool,
) -> sqlx::QueryBuilder<'static, Postgres> {
    let head_select = self.select_type_for(select_type.clone());
    let mut writer = SqlWriter::new(ReadHead::new(self.table, head_select));
    self.write_query_with_options(
        &mut writer,
        select_type,
        pagination,
        include_sort,
    );
    writer.into_builder()
}
/// Writes everything after the `SELECT ... FROM` head: joins, `WHERE`,
/// aggregate filter or `GROUP BY`/`HAVING`, optional `ORDER BY`, and
/// pagination.
///
/// Without an explicit sort, full-text searches that request it are ordered
/// by rank descending (never for `EXISTS`). `EXISTS` queries are limited to
/// one row and their wrapper is closed with `)`.
fn write_query_with_options(
    &self,
    w: &mut SqlWriter,
    select_type: SelectType,
    pagination: Option<Pagination>,
    include_sort: bool,
) {
    let is_exists = matches!(select_type, SelectType::Exists);
    if let Some(joins) = &self.joins {
        w.push_joins(joins, self.table);
    }
    self.push_where_clause(w);
    match &self.aggregate_filter {
        Some(filter) => {
            self.push_aggregate_filter_clause(w, filter);
        }
        None => {
            self.push_group_by_clause(w);
            self.push_having_clause(w);
        }
    }
    if include_sort {
        match (&self.sort_expr, &self.full_text_search) {
            (Some(sort), _) => {
                w.push_sort(sort);
            }
            (None, Some(fts)) if !is_exists && fts.include_rank() => {
                w.push_order_by_raw(|writer| {
                    fts.write_rank_expr(
                        writer,
                        self.table,
                        self.joins.as_deref(),
                    );
                    writer.push(" DESC");
                });
            }
            _ => {}
        }
    }
    if is_exists {
        w.push_pagination(&Pagination {
            page: 0,
            page_size: 1,
        });
        w.push(")");
    } else if let Some(page) = &pagination {
        w.push_pagination(page);
    }
}
/// Builds a join-free read of the root models with this plan's deletion
/// flag, filter, sort and the given page. `fetch_page` uses it to select
/// which models form a page before loading their joins.
fn base_query_builder_without_joins(
    &self,
    pagination: Option<Pagination>,
) -> ReadQueryBuilder<'a, C>
where
    C::Model: crate::GetDeleteMarker + JoinNavigationModel,
{
    let mut builder = QueryBuilder::<C>::read().without_auto_joins();
    if self.include_deleted {
        builder = builder.include_deleted();
    }
    if let Some(expr) = self.where_expr.clone() {
        builder = builder.r#where(expr);
    }
    if let Some(sort) = self.sort_expr.clone() {
        builder = builder.order_by(sort);
    }
    match pagination {
        Some(page) => builder.paginate(page),
        None => builder,
    }
}
/// Counts the distinct root rows, by primary key, matched by this plan.
/// Pagination and sort are ignored.
///
/// Without grouping, this is a single `COUNT(DISTINCT ...)` over the primary
/// key (a row value for composite keys). When the plan groups or filters on
/// aggregates, a SQL count would be per group. In that case the primary keys
/// are fetched as text and de-duplicated client-side.
///
/// A model without primary-key columns has nothing to de-duplicate on, so the
/// rows of the unpaginated query are counted instead (see `projected_total`).
/// Returning `0` there would make `fetch_page` report an empty page no
/// matter what the table contains.
///
/// # Errors
/// Propagates database and row-decoding errors.
async fn distinct_total_with_primary_key_projection<'e, E>(
    &self,
    exec: E,
) -> Result<i64>
where
    E: Executor<'e, Database = Postgres>,
{
    let pk_columns = C::Model::PRIMARY_KEY;
    if pk_columns.is_empty() {
        return self.projected_total(exec).await;
    }
    let has_grouped_filter = self.aggregate_filter.is_some()
        || self
            .group_by
            .as_ref()
            .is_some_and(|columns| !columns.is_empty())
        || self
            .having
            .as_ref()
            .is_some_and(|predicates| !predicates.is_empty());
    if has_grouped_filter {
        // Project every key column as text under a stable alias, then count
        // the distinct key tuples in memory.
        let pk_aliases: Vec<String> = pk_columns
            .iter()
            .enumerate()
            .map(|(idx, _)| primary_key_text_alias(idx))
            .collect();
        let projections: Vec<SelectProjection> = pk_columns
            .iter()
            .enumerate()
            .map(|(idx, column)| SelectProjection {
                expression: format!(r#""{}"."{}"::text"#, self.table, column),
                alias: Some(primary_key_text_alias(idx)),
            })
            .collect();
        let rows: Vec<PgRow> = self
            .to_query_builder_with_options(
                SelectType::Projection(projections),
                None,
                false,
            )
            .build()
            .fetch_all(exec)
            .await?;
        let mut unique: HashSet<Vec<String>> =
            HashSet::with_capacity(rows.len());
        for row in rows {
            let mut key: Vec<String> = Vec::with_capacity(pk_aliases.len());
            for alias in &pk_aliases {
                key.push(row.try_get(alias.as_str())?);
            }
            unique.insert(key);
        }
        return Ok(unique.len() as i64);
    }
    let count_expression = if pk_columns.len() == 1 {
        format!(r#"COUNT(DISTINCT "{}"."{}")"#, self.table, pk_columns[0])
    } else {
        let tuple = pk_columns
            .iter()
            .map(|column| format!(r#""{}"."{}""#, self.table, column))
            .collect::<Vec<_>>()
            .join(", ");
        format!("COUNT(DISTINCT ({}))", tuple)
    };
    let row: PgRow = self
        .to_query_builder_with_options(
            SelectType::Projection(vec![SelectProjection {
                expression: count_expression,
                alias: Some("__sqlxo_total".to_string()),
            }]),
            None,
            false,
        )
        .build()
        .fetch_one(exec)
        .await?;
    Ok(row.try_get("__sqlxo_total")?)
}
/// Resolves the select type used in the query head.
///
/// `Star` is replaced by the explicit selection when there is one. Join
/// extras are then appended where applicable.
fn select_type_for(&self, base: SelectType) -> SelectType {
    let resolved = match (base, self.selection.as_ref()) {
        (SelectType::Star, Some(selection)) => {
            self.selection_select_type(selection)
        }
        (other, _) => other,
    };
    self.apply_join_extras(resolved)
}
/// Chooses how an explicit selection is rendered.
///
/// A selection of only columns becomes a plain qualified column list.
/// Anything containing aggregates becomes an aliased projection list.
///
/// # Panics
/// Panics when columns are mixed with aggregates without a `group_by!`.
fn selection_select_type(
    &self,
    selection: &SelectionList<Row, SelectionEntry>,
) -> SelectType {
    let entries = selection.entries();
    let has_columns = entries
        .iter()
        .any(|entry| matches!(entry, SelectionEntry::Column(_)));
    let has_aggregates = entries
        .iter()
        .any(|entry| matches!(entry, SelectionEntry::Aggregate(_)));
    if has_columns && has_aggregates && self.group_by.is_none() {
        panic!(
            "`group_by!` must be provided when selecting columns \
             alongside aggregates"
        );
    }
    if !has_columns || has_aggregates {
        return SelectType::Projection(self.build_projections(selection));
    }
    let columns: SmallVec<[SelectionColumn; 4]> = entries
        .iter()
        .filter_map(|entry| match entry {
            SelectionEntry::Column(column) => Some(*column),
            SelectionEntry::Aggregate(_) => None,
        })
        .collect();
    SelectType::Columns(resolve_selection_columns(
        &columns,
        self.table,
        self.joins.as_deref(),
    ))
}
fn build_projections(
&self,
selection: &SelectionList<Row, SelectionEntry>,
) -> Vec<SelectProjection> {
let aliases = build_alias_lookup(self.joins.as_deref());
selection
.entries()
.iter()
.enumerate()
.map(|(idx, entry)| match entry {
SelectionEntry::Column(col) => {
let qualified =
resolve_selection_column(col, self.table, &aliases);
SelectProjection {
expression: format!(
r#""{}"."{}""#,
qualified.table_alias, qualified.column
),
alias: None,
}
}
SelectionEntry::Aggregate(agg) => {
let expr =
format_aggregate_expression(agg, self.table, &aliases);
let alias = format!(r#"__sqlxo_sel_{}"#, idx);
SelectProjection {
expression: expr,
alias: Some(alias),
}
}
})
.collect()
}
/// Upgrades `Star`/`StarAndCount` to their "with extras" forms when the joins
/// contribute extra columns. Other select types are returned unchanged.
fn apply_join_extras(&self, select: SelectType) -> SelectType {
    let extras = self.join_projection_columns();
    match select {
        SelectType::Star if !extras.is_empty() => {
            SelectType::StarWithExtras(extras)
        }
        SelectType::StarAndCount if !extras.is_empty() => {
            SelectType::StarAndCountExtras(extras)
        }
        other => other,
    }
}
/// Extra aliased columns needed to hydrate joined navigations. This is empty
/// when an explicit selection is in effect.
fn join_projection_columns(&self) -> SmallVec<[AliasedColumn; 4]> {
    match self.selection {
        Some(_) => SmallVec::new(),
        None => C::Model::collect_join_columns(self.joins.as_deref(), ""),
    }
}
/// Adds the outer query's `WHERE` conditions through `push_where_raw`.
///
/// The conditions are: the soft-delete guard, the plan's filter, and the
/// full-text predicate. Each condition after the first is parenthesised.
fn push_where_clause(&self, w: &mut SqlWriter) {
    let table = self.table;
    let joins = self.joins.as_deref();
    let mut wrap_next = false;
    if let (false, Some(marker)) =
        (self.include_deleted, self.delete_marker_field)
    {
        let qualified = format!(r#""{table}"."{marker}""#);
        w.push_where_raw(|writer| {
            writer.push(&qualified);
            writer.push(" IS NULL");
        });
        wrap_next = true;
    }
    if let Some(expr) = &self.where_expr {
        let wrap = wrap_next;
        w.push_where_raw(|writer| {
            if wrap {
                writer.push("(");
            }
            expr.write(writer);
            if wrap {
                writer.push(")");
            }
        });
        wrap_next = true;
    }
    if let Some(fts) = &self.full_text_search {
        let wrap = wrap_next;
        w.push_where_raw(|writer| {
            if wrap {
                writer.push("(");
            }
            fts.write_condition(writer, table, joins);
            if wrap {
                writer.push(")");
            }
        });
    }
}
/// Adds `GROUP BY` over the resolved grouping columns. Nothing is added when
/// there are no grouping columns.
fn push_group_by_clause(&self, w: &mut SqlWriter) {
    let Some(columns) =
        self.group_by.as_deref().filter(|cols| !cols.is_empty())
    else {
        return;
    };
    let resolved =
        resolve_selection_columns(columns, self.table, self.joins.as_deref());
    w.push_group_by_columns(&resolved);
}
/// Adds `HAVING p1 AND p2 ...`. Nothing is added when there are no
/// predicates.
fn push_having_clause(&self, w: &mut SqlWriter) {
    let predicates = match self.having.as_deref() {
        Some(list) if !list.is_empty() => list,
        _ => return,
    };
    let aliases = build_alias_lookup(self.joins.as_deref());
    let table = self.table;
    w.push_having(|writer| {
        let mut remaining = predicates.iter();
        if let Some(first) = remaining.next() {
            write_having_predicate(first, writer, table, &aliases);
        }
        for predicate in remaining {
            writer.push(" AND ");
            write_having_predicate(predicate, writer, table, &aliases);
        }
    });
}
/// Fetches one page of results together with the total number of matching
/// items.
///
/// The page comes from `self.pagination`, or `Pagination::default()` when
/// unset. The total is computed first by a separate count query. When it is
/// zero, an empty page is returned without running the row query.
///
/// Some plans need collection joins merged and have no full-text search,
/// aggregate filter, grouping or `HAVING`. For those, the page is resolved
/// in two steps so that pagination applies to root models rather than to
/// joined rows:
/// 1. The page of root models is fetched without joins.
/// 2. The joined rows for exactly those models are loaded and merged, then
///    re-ordered to match step 1 (where the row type supports it).
///
/// # Errors
/// Propagates database and row-decoding errors.
pub async fn fetch_page<'e, E>(&self, exec: E) -> Result<Page<Row>>
where
E: Executor<'e, Database = Postgres> + Clone,
Row: HydrateRow<C> + PageOrderBridge<C>,
{
let pagination = self.pagination.unwrap_or_default();
// Projections, grouping and aggregate filters are counted by wrapping the
// whole query; plain model reads count distinct primary keys.
let has_projected_page = self.selection.is_some() ||
self.aggregate_filter.is_some() ||
self.group_by
.as_ref()
.is_some_and(|columns| !columns.is_empty()) ||
self.having
.as_ref()
.is_some_and(|predicates| !predicates.is_empty());
// The two-step path is only taken for plain reads whose rows fan out over
// collection joins.
let can_batch_collections =
<Row as HydrateRow<C>>::requires_collection_merge(self) &&
self.full_text_search.is_none() &&
self.aggregate_filter.is_none() &&
self.group_by.is_none() &&
self.having
.as_ref()
.is_none_or(|predicates| predicates.is_empty());
let total = if has_projected_page {
self.projected_total(exec.clone()).await?
} else {
self.distinct_total_with_primary_key_projection(exec.clone())
.await?
};
if total == 0 {
return Ok(Page::new(vec![], pagination, 0));
}
if can_batch_collections {
// Step 1: the page of root models, without joins.
// NOTE(review): all joins are dropped here, so a `where_expr` or
// `sort_expr` that references a joined table would fail — confirm such
// expressions cannot reach this path.
let page_models = self
.base_query_builder_without_joins(Some(pagination))
.build()
.fetch_all(exec.clone())
.await?;
if page_models.is_empty() {
return Ok(Page::new(vec![], pagination, total));
}
// A filter matching exactly the page's models. Without one, fall back to
// a single joined query.
let Some(key_filter) =
<C::Model as PageKeyFilterSupport<C>>::page_key_filter(
&page_models,
)
else {
// NOTE(review): this fallback paginates the joined rows, so a page may
// hold fewer logical items than `page_size` when collections fan out.
let rows: Vec<PgRow> = self
.to_query_builder(SelectType::Star)
.build()
.fetch_all(exec)
.await?;
let mut items = Vec::with_capacity(rows.len());
for row in rows {
items.push(self.map_pg_row(row)?);
}
items =
<Row as HydrateRow<C>>::merge_collection_rows(items, self);
return Ok(Page::new(items, pagination, total));
};
// Step 2: reload the page's models with joins, restricted to their keys.
let combined_where = match &self.where_expr {
Some(where_expr) => Some(and![where_expr.clone(), key_filter]),
None => Some(key_filter),
};
// Sort and pagination are dropped: the key filter already bounds the
// rows, and the order is restored from `page_models` below.
let detail_plan = ReadQueryPlan::<C, Row> {
joins: self.joins.clone(),
where_expr: combined_where,
sort_expr: None,
pagination: None,
table: self.table,
include_deleted: self.include_deleted,
delete_marker_field: self.delete_marker_field,
selection: None,
group_by: None,
having: None,
full_text_search: None,
aggregate_filter: None,
row: PhantomData,
};
let rows: Vec<PgRow> = detail_plan
.to_query_builder(SelectType::Star)
.build()
.fetch_all(exec)
.await?;
let mut items = Vec::with_capacity(rows.len());
for row in rows {
items.push(detail_plan.map_pg_row(row)?);
}
items = <Row as HydrateRow<C>>::merge_collection_rows(
items,
&detail_plan,
);
// Restore the ordering established by step 1.
let ordered =
<Row as PageOrderBridge<C>>::order_batched(items, &page_models);
return Ok(Page::new(ordered, pagination, total));
}
// Single-query path: paginate the (possibly joined) rows directly.
let rows: Vec<PgRow> = self
.to_query_builder(SelectType::Star)
.build()
.fetch_all(exec)
.await?;
let mut items = Vec::with_capacity(rows.len());
for row in rows {
items.push(self.map_pg_row(row)?);
}
items = <Row as HydrateRow<C>>::merge_collection_rows(items, self);
Ok(Page::new(items, pagination, total))
}
/// Counts the rows produced by the unpaginated, unsorted query by wrapping it
/// as `SELECT COUNT(*) FROM (<query>) AS "__sqlxo_count"`. The inner query's
/// bound arguments are reused.
///
/// # Errors
/// Propagates argument-encoding and database errors.
async fn projected_total<'e, E>(&self, exec: E) -> Result<i64>
where
    E: Executor<'e, Database = Postgres>,
{
    let mut builder =
        self.to_query_builder_with_options(SelectType::Star, None, false);
    let mut inner_query = builder.build();
    let count_sql = format!(
        r#"SELECT COUNT(*) AS "__sqlxo_total" FROM ({}) AS "__sqlxo_count""#,
        inner_query.sql()
    );
    let arguments = inner_query
        .take_arguments()
        .map_err(|err| anyhow::anyhow!(err.to_string()))?;
    let total = match arguments {
        Some(args) => {
            sqlx::query_scalar_with::<Postgres, i64, _>(&count_sql, args)
                .fetch_one(exec)
                .await?
        }
        None => {
            sqlx::query_scalar::<Postgres, i64>(&count_sql)
                .fetch_one(exec)
                .await?
        }
    };
    Ok(total)
}
/// Returns whether at least one row matches the plan. The query is rendered
/// as an `EXISTS` check limited to a single row.
///
/// # Errors
/// Propagates database and decoding errors.
pub async fn exists<'e, E>(&self, exec: E) -> Result<bool>
where
    E: Executor<'e, Database = Postgres>,
{
    // The field name must match the `exists` result column.
    #[derive(sqlx::FromRow)]
    struct ExistsRow {
        exists: bool,
    }
    let ExistsRow { exists } = self
        .to_query_builder(SelectType::Exists)
        .build_query_as::<ExistsRow>()
        .fetch_one(exec)
        .await?;
    Ok(exists)
}
/// Test helper: the SQL text generated for `build` (with the plan's own
/// pagination and sort). `Execute` is already imported at file scope.
#[cfg(any(test, feature = "test-utils"))]
pub fn sql(&self, build: SelectType) -> String {
    let mut builder = self.to_query_builder(build);
    builder.build().sql().to_string()
}
/// Decodes one Postgres row into `Row` via [`HydrateRow`]. For the context's
/// model, this also hydrates joined navigations when no selection is set.
fn map_pg_row(&self, row: PgRow) -> Result<Row>
where
Row: HydrateRow<C>,
{
<Row as HydrateRow<C>>::from_pg_row(self, row)
}
}
#[async_trait::async_trait]
impl<'a, C, Row> FetchablePlan<C, Row> for ReadQueryPlan<'a, C, Row>
where
C: QueryContext,
Row: Send + Sync + Unpin + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>,
{
async fn fetch_one<'e, E>(&self, exec: E) -> Result<Row>
where
E: Executor<'e, Database = Postgres>,
{
if <Row as HydrateRow<C>>::requires_collection_merge(self) {
let rows = self
.to_query_builder(SelectType::Star)
.build()
.fetch_all(exec)
.await?;
let mapped = rows
.into_iter()
.map(|row| self.map_pg_row(row))
.collect::<Result<Vec<Row>, _>>()?;
let merged =
<Row as HydrateRow<C>>::merge_collection_rows(mapped, self);
return merged
.into_iter()
.next()
.ok_or(sqlx::Error::RowNotFound.into());
}
let row = self
.to_query_builder(SelectType::Star)
.build()
.fetch_one(exec)
.await?;
Ok(self.map_pg_row(row)?)
}
async fn fetch_all<'e, E>(&self, exec: E) -> Result<Vec<Row>>
where
E: Executor<'e, Database = Postgres>,
{
let rows = self
.to_query_builder(SelectType::Star)
.build()
.fetch_all(exec)
.await?;
let mapped = rows
.into_iter()
.map(|row| self.map_pg_row(row))
.collect::<Result<Vec<Row>, _>>()?;
Ok(<Row as HydrateRow<C>>::merge_collection_rows(mapped, self))
}
async fn fetch_optional<'e, E>(&self, exec: E) -> Result<Option<Row>>
where
E: Executor<'e, Database = Postgres>,
{
if <Row as HydrateRow<C>>::requires_collection_merge(self) {
let rows = self
.to_query_builder(SelectType::Star)
.build()
.fetch_all(exec)
.await?;
let mapped = rows
.into_iter()
.map(|row| self.map_pg_row(row))
.collect::<Result<Vec<Row>, _>>()?;
let merged =
<Row as HydrateRow<C>>::merge_collection_rows(mapped, self);
return Ok(merged.into_iter().next());
}
let row = self
.to_query_builder(SelectType::Star)
.build()
.fetch_optional(exec)
.await?;
match row {
Some(row) => Ok(Some(self.map_pg_row(row)?)),
None => Ok(None),
}
}
}
/// Appends the joins of one `path` directly to `qb`.
///
/// Each clause has the form
/// `<JOIN> table AS "alias" ON "left"."left_field" = "alias"."right_field"`.
/// Junction (`through`) tables are joined before the segment's target table.
/// Aliases accumulate segment by segment, matching `build_alias_lookup`.
fn push_join_path_inline(
    qb: &mut sqlx::QueryBuilder<'static, Postgres>,
    path: &JoinPath,
    base_table: &str,
) {
    /// Appends one fully-formatted join clause.
    fn push_clause(
        qb: &mut sqlx::QueryBuilder<'static, Postgres>,
        join: &str,
        table: &str,
        alias: &str,
        left: &str,
        left_field: impl std::fmt::Display,
        right_field: impl std::fmt::Display,
    ) {
        qb.push(format!(
            r#"{join}{table} AS "{alias}" ON "{left}"."{left_field}" = "{alias}"."{right_field}""#
        ));
    }

    if path.is_empty() {
        return;
    }
    let mut previous_alias = base_table.to_string();
    let mut prefix = String::new();
    for segment in path.segments() {
        let join = match segment.kind {
            JoinKind::Inner => " INNER JOIN ",
            JoinKind::Left => " LEFT JOIN ",
        };
        let descriptor = &segment.descriptor;
        if let Some(through) = descriptor.through {
            let through_alias = format!("{prefix}{}", through.alias_segment);
            push_clause(
                qb,
                join,
                through.table,
                &through_alias,
                &previous_alias,
                through.left_field,
                through.right_field,
            );
            previous_alias = through_alias;
        }
        prefix.push_str(descriptor.alias_segment);
        push_clause(
            qb,
            join,
            descriptor.right_table,
            &prefix,
            &previous_alias,
            descriptor.left_field,
            descriptor.right_field,
        );
        previous_alias = prefix.clone();
    }
}
/// Runs the plan's `SELECT` and reports the row count the driver returns as
/// "rows affected".
#[async_trait::async_trait]
impl<'a, C, Row> ExecutablePlan<C> for ReadQueryPlan<'a, C, Row>
where
    C: QueryContext,
    Row: Send + Sync + Unpin + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>,
{
    async fn execute<'e, E>(&self, exec: E) -> Result<u64>
    where
        E: Executor<'e, Database = Postgres>,
    {
        let mut builder = self.to_query_builder(SelectType::Star);
        let outcome = builder.build().execute(exec).await?;
        Ok(outcome.rows_affected())
    }
}
/// Marks [`ReadQueryPlan`] as a [`Planable`] plan. This impl adds no items of
/// its own.
impl<'a, C, Row> Planable<C, Row> for ReadQueryPlan<'a, C, Row>
where
C: QueryContext,
Row: Send + Sync + Unpin + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>,
{
}
/// Builds a `WHERE` filter matching exactly the given models. `fetch_page`
/// uses it to reload a page's models together with their joins.
///
/// The trait is resolved through specialization: a blanket fallback returning
/// `None`, and an override for models implementing [`PrimaryKeyExpression`].
trait PageKeyFilterSupport<C: QueryContext>: JoinNavigationModel {
fn page_key_filter(_models: &[Self]) -> Option<Expression<C::Query>>
where
Self: Sized,
{
None
}
}
/// Blanket fallback: no key filter can be built, so `fetch_page` falls back
/// to a single joined query.
impl<C, M> PageKeyFilterSupport<C> for M
where
C: QueryContext,
M: JoinNavigationModel,
{
default fn page_key_filter(_models: &[Self]) -> Option<Expression<C::Query>>
where
Self: Sized,
{
None
}
}
/// Models that can express their own primary-key predicate combine one
/// predicate per model with `OR`. A single model yields its predicate
/// directly, and an empty slice yields `None`.
impl<C, M> PageKeyFilterSupport<C> for M
where
    C: QueryContext,
    M: JoinNavigationModel + PrimaryKeyExpression<C>,
{
    fn page_key_filter(models: &[Self]) -> Option<Expression<C::Query>>
    where
        Self: Sized,
    {
        match models {
            [] => None,
            [only] => Some(only.primary_key_expression()),
            several => Some(Expression::Or(
                several
                    .iter()
                    .map(|model| model.primary_key_expression())
                    .collect(),
            )),
        }
    }
}
/// Re-orders hydrated models to follow the order of a reference page of
/// models. The fallback keeps the input order unchanged.
trait PageOrderSupport: JoinNavigationModel {
fn order_batched(items: Vec<Self>, _page_models: &[Self]) -> Vec<Self>
where
Self: Sized,
{
items
}
}
/// Blanket fallback for models without a join key: items are returned in the
/// order they were loaded.
impl<M> PageOrderSupport for M
where
M: JoinNavigationModel,
{
default fn order_batched(
items: Vec<Self>,
_page_models: &[Self],
) -> Vec<Self>
where
Self: Sized,
{
items
}
}
/// Models with a join key are matched to `page_models` by key.
///
/// The result follows `page_models` order. Items without a matching page
/// model are dropped, and each item is used at most once.
impl<M> PageOrderSupport for M
where
    M: JoinNavigationModel + JoinIdentifiable,
    M::Key: PartialEq,
{
    fn order_batched(mut items: Vec<Self>, page_models: &[Self]) -> Vec<Self>
    where
        Self: Sized,
    {
        page_models
            .iter()
            .filter_map(|page_model| {
                let wanted = page_model.join_key();
                let idx = items
                    .iter()
                    .position(|candidate| candidate.join_key() == wanted)?;
                Some(items.swap_remove(idx))
            })
            .collect()
    }
}
/// Bridges `fetch_page`'s row type to [`PageOrderSupport`].
///
/// It re-orders rows from the batched path to match the page's models. The
/// fallback keeps the input order.
#[doc(hidden)]
pub trait PageOrderBridge<C: QueryContext>: Sized {
fn order_batched(items: Vec<Self>, _page_models: &[C::Model]) -> Vec<Self> {
items
}
}
/// Blanket fallback for any decodable row type: no re-ordering.
impl<C, Row> PageOrderBridge<C> for Row
where
C: QueryContext,
Row: Send + Sync + Unpin + for<'r> sqlx::FromRow<'r, PgRow>,
{
default fn order_batched(
items: Vec<Self>,
_page_models: &[C::Model],
) -> Vec<Self> {
items
}
}
/// For the context's model, defer to [`PageOrderSupport`]. That trait
/// re-orders by join key when the model is [`JoinIdentifiable`].
impl<C> PageOrderBridge<C> for C::Model
where
C: QueryContext,
C::Model: JoinNavigationModel + PageOrderSupport,
{
fn order_batched(items: Vec<Self>, page_models: &[C::Model]) -> Vec<Self> {
<C::Model as PageOrderSupport>::order_batched(items, page_models)
}
}
/// Decodes query rows into a plan's row type. Where applicable, it also
/// merges rows that were duplicated by collection joins.
#[doc(hidden)]
pub trait HydrateRow<C: QueryContext>: Sized {
/// Decodes one Postgres row.
fn from_pg_row(plan: &ReadQueryPlan<C, Self>, row: PgRow) -> Result<Self>;
/// Whether results must go through `merge_collection_rows` before use
/// (default `false`).
fn requires_collection_merge(_plan: &ReadQueryPlan<C, Self>) -> bool {
false
}
/// Merges rows fanned out by collection joins. The default returns the rows
/// unchanged.
fn merge_collection_rows(
rows: Vec<Self>,
_plan: &ReadQueryPlan<C, Self>,
) -> Vec<Self> {
rows
}
}
/// Blanket implementation for any `FromRow` type. It only decodes the row:
/// no navigation hydration and no merging.
impl<C, Row> HydrateRow<C> for Row
where
C: QueryContext,
Row: Send + Sync + Unpin + for<'r> sqlx::FromRow<'r, PgRow>,
{
default fn from_pg_row(
_plan: &ReadQueryPlan<C, Row>,
row: PgRow,
) -> Result<Self> {
Ok(Row::from_row(&row)?)
}
default fn requires_collection_merge(
_plan: &ReadQueryPlan<C, Self>,
) -> bool {
false
}
default fn merge_collection_rows(
rows: Vec<Self>,
_plan: &ReadQueryPlan<C, Self>,
) -> Vec<Self> {
rows
}
}
/// Hydration for the context's own model.
///
/// Without an explicit selection, joined navigations are filled from the same
/// row. Rows duplicated by collection joins are merged back into one model
/// each.
impl<C> HydrateRow<C> for C::Model
where
    C: QueryContext,
    C::Model: JoinNavigationModel,
{
    fn from_pg_row(plan: &ReadQueryPlan<C, Self>, row: PgRow) -> Result<Self> {
        let mut hydrated = Self::from_row(&row)?;
        let hydrate_joins = plan.selection.is_none();
        if hydrate_joins {
            hydrated.hydrate_navigations(plan.joins.as_deref(), &row, "")?;
        }
        Ok(hydrated)
    }

    fn requires_collection_merge(plan: &ReadQueryPlan<C, Self>) -> bool {
        if plan.selection.is_some() {
            return false;
        }
        Self::has_collection_joins(plan.joins.as_deref())
    }

    fn merge_collection_rows(
        rows: Vec<Self>,
        plan: &ReadQueryPlan<C, Self>,
    ) -> Vec<Self> {
        let joins = plan.joins.as_deref();
        if Self::has_collection_joins(joins) {
            <Self as JoinNavigationModel>::merge_collection_rows(rows, joins)
        } else {
            rows
        }
    }
}
/// The model or models a [`RelationReloadBuilder`] refreshes in place.
enum RelationReloadTarget<'m, M> {
/// A single model, reloaded with `fetch_one`.
One(&'m mut M),
/// Several models, reloaded with one query and matched back by join key.
Many(&'m mut [M]),
}
/// Re-reads already-loaded models from the database, by primary key, so their
/// joined relations are refreshed. Targets are replaced in place.
pub struct RelationReloadBuilder<'m, C: QueryContext> {
// Model(s) to overwrite with freshly loaded data.
target: RelationReloadTarget<'m, C::Model>,
// Whether the reload query also includes lazy relations.
include_lazy_relations: bool,
_phantom: PhantomData<C>,
}
impl<'m, C> RelationReloadBuilder<'m, C>
where
C: QueryContext,
{
pub(crate) fn one(model: &'m mut C::Model) -> Self {
Self {
target: RelationReloadTarget::One(model),
include_lazy_relations: false,
_phantom: PhantomData,
}
}
pub(crate) fn many(models: &'m mut [C::Model]) -> Self {
Self {
target: RelationReloadTarget::Many(models),
include_lazy_relations: false,
_phantom: PhantomData,
}
}
pub fn include_lazy_relations(mut self) -> Self {
self.include_lazy_relations = true;
self
}
pub async fn execute<'e, E>(self, exec: E) -> Result<u64>
where
E: Executor<'e, Database = Postgres>,
C::Model:
PrimaryKeyExpression<C> + JoinIdentifiable + JoinNavigationModel,
{
match self.target {
RelationReloadTarget::One(model) => {
let mut builder = QueryBuilder::<C>::read();
if self.include_lazy_relations {
builder = builder.include_lazy_relations();
}
let loaded = builder
.r#where(model.primary_key_expression())
.build()
.fetch_one(exec)
.await?;
*model = loaded;
Ok(1)
}
RelationReloadTarget::Many(models) => {
if models.is_empty() {
return Ok(0);
}
let mut filters: Vec<Expression<C::Query>> =
Vec::with_capacity(models.len());
for model in models.iter() {
filters.push(model.primary_key_expression());
}
let filter_expr = if filters.len() == 1 {
filters.remove(0)
} else {
Expression::Or(filters)
};
let mut builder = QueryBuilder::<C>::read();
if self.include_lazy_relations {
builder = builder.include_lazy_relations();
}
let loaded = builder
.r#where(filter_expr)
.build()
.fetch_all(exec)
.await?;
let mut refreshed = 0u64;
for model in models.iter_mut() {
if let Some(found) = loaded.iter().find(|candidate| {
candidate.join_key() == model.join_key()
}) {
*model = found.clone();
refreshed += 1;
}
}
Ok(refreshed)
}
}
}
}
/// Fluent builder for a read query over the context's table.
///
/// `build` turns it into a [`ReadQueryPlan`]. `Row` is the decoded row type:
/// the context's model by default, or a projection type.
pub struct ReadQueryBuilder<
'a,
C: QueryContext,
Row = <C as QueryContext>::Model,
> {
// Root table name.
pub(crate) table: &'a str,
// Explicitly requested join paths, merged with automatic ones at build time.
pub(crate) joins: Option<Vec<JoinPath>>,
pub(crate) where_expr: Option<Expression<C::Query>>,
pub(crate) sort_expr: Option<SortOrder<C::Sort>>,
pub(crate) pagination: Option<Pagination>,
// Include soft-deleted rows.
pub(crate) include_deleted: bool,
// Add the model's default join paths at build time.
pub(crate) auto_joins: bool,
// Passed to the model's default join paths when `auto_joins` is on.
pub(crate) include_lazy_relations: bool,
pub(crate) delete_marker_field: Option<&'a str>,
pub(crate) selection: Option<SelectionList<Row, SelectionEntry>>,
pub(crate) group_by: Option<Vec<SelectionColumn>>,
pub(crate) having: Option<Vec<HavingPredicate>>,
pub(crate) full_text_search: Option<Box<dyn DynFullTextSearchPlan>>,
row: PhantomData<Row>,
}
impl<'a, C, Row> ReadQueryBuilder<'a, C, Row>
where
C: QueryContext,
{
pub fn include_deleted(mut self) -> Self {
self.include_deleted = true;
self
}
pub fn without_auto_joins(mut self) -> Self {
self.auto_joins = false;
self
}
pub fn include_lazy_relations(mut self) -> Self {
self.include_lazy_relations = true;
self
}
pub fn search(
mut self,
config: <C::Model as FullTextSearchable>::FullTextSearchConfig,
) -> Self
where
C::Model: FullTextSearchable + 'static,
<C::Model as FullTextSearchable>::FullTextSearchConfig:
Send + Sync + 'static,
{
self.full_text_search =
Some(Box::new(ModelFullTextSearchPlan::<C::Model>::new(config)));
self
}
}
impl<'a, C, Row> Buildable<C> for ReadQueryBuilder<'a, C, Row>
where
    C: QueryContext,
    C::Model: crate::GetDeleteMarker + JoinNavigationModel,
    Row: Send + Sync + Unpin + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>,
{
    type Row = Row;
    type Plan = ReadQueryPlan<'a, C, Row>;

    /// Starts an unfiltered, unsorted, unpaginated read of the context's table.
    /// Automatic joins are enabled and soft-deleted rows are excluded.
    fn from_ctx() -> Self {
        Self {
            table: C::TABLE,
            joins: None,
            where_expr: None,
            sort_expr: None,
            pagination: None,
            include_deleted: false,
            auto_joins: true,
            include_lazy_relations: false,
            delete_marker_field: C::Model::delete_marker_field(),
            selection: None,
            group_by: None,
            having: None,
            full_text_search: None,
            row: PhantomData,
        }
    }

    /// Freezes the builder into a plan.
    ///
    /// Automatic and explicit join paths are merged, then de-duplicated with
    /// prefixes removed. A bare `HAVING` is rewritten into a primary-key
    /// aggregate filter.
    fn build(self) -> Self::Plan {
        let mut join_paths: Vec<JoinPath> = if self.auto_joins {
            C::Model::default_join_paths(self.include_lazy_relations).into_vec()
        } else {
            Vec::new()
        };
        join_paths.extend(self.joins.into_iter().flatten());
        let joins = if join_paths.is_empty() {
            None
        } else {
            Some(merge_unique_join_paths(join_paths))
        };
        let mut plan = ReadQueryPlan {
            joins,
            where_expr: self.where_expr,
            sort_expr: self.sort_expr,
            pagination: self.pagination,
            table: self.table,
            include_deleted: self.include_deleted,
            delete_marker_field: self.delete_marker_field,
            selection: self.selection,
            group_by: self.group_by,
            having: self.having,
            full_text_search: self.full_text_search,
            aggregate_filter: None,
            row: PhantomData,
        };
        plan.compute_aggregate_filter();
        plan
    }
}
impl<'a, C, Row> ReadQueryBuilder<'a, C, Row>
where
    C: QueryContext,
    C::Model: crate::GetDeleteMarker + JoinNavigationModel,
    Row: Send + Sync + Unpin + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>,
{
    /// Replaces the projection with `selection` and re-types the builder so
    /// that rows decode into `NewRow`.
    ///
    /// Every other setting (filters, joins, sort, pagination, grouping,
    /// search, flags) carries over unchanged. Any projection set earlier is
    /// discarded.
    pub fn take<NewRow>(
        self,
        selection: SelectionList<NewRow, SelectionEntry>,
    ) -> ReadQueryBuilder<'a, C, NewRow>
    where
        NewRow: Send
            + Sync
            + Unpin
            + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>,
    {
        let ReadQueryBuilder {
            table,
            joins,
            where_expr,
            sort_expr,
            pagination,
            include_deleted,
            auto_joins,
            include_lazy_relations,
            delete_marker_field,
            group_by,
            having,
            full_text_search,
            selection: _,
            row: _,
        } = self;
        ReadQueryBuilder {
            table,
            joins,
            where_expr,
            sort_expr,
            pagination,
            include_deleted,
            auto_joins,
            include_lazy_relations,
            delete_marker_field,
            selection: Some(selection),
            group_by,
            having,
            full_text_search,
            row: PhantomData,
        }
    }

    /// Sets the GROUP BY columns, replacing any previously set grouping.
    pub fn group_by(self, group_by: GroupByList) -> Self {
        Self {
            group_by: Some(group_by.into_columns().into_vec()),
            ..self
        }
    }

    /// Sets the HAVING predicates, replacing any previously set ones.
    pub fn having(self, having: HavingList) -> Self {
        let predicates = having.into_predicates();
        Self {
            having: Some(predicates),
            ..self
        }
    }
}
impl<'a, C, Row> BuildableFilter<C> for ReadQueryBuilder<'a, C, Row>
where
    C: QueryContext,
    C::Model: crate::GetDeleteMarker + JoinNavigationModel,
{
    /// Adds a filter expression.
    ///
    /// A second or later call does not replace the existing filter; it is
    /// AND-ed onto it with `and!`.
    fn r#where(mut self, e: Expression<<C as QueryContext>::Query>) -> Self {
        let combined = match self.where_expr.take() {
            None => e,
            Some(existing) => and![existing, e],
        };
        self.where_expr = Some(combined);
        self
    }
}
impl<'a, C, Row> BuildableJoin<C> for ReadQueryBuilder<'a, C, Row>
where
    C: QueryContext,
    C::Model: crate::GetDeleteMarker + JoinNavigationModel,
{
    /// Adds a single join of the given `kind`.
    ///
    /// The join is wrapped in a `JoinPath` and handed to `join_path`.
    fn join(self, join: <C as QueryContext>::Join, kind: JoinKind) -> Self {
        let path = JoinPath::from_join(join, kind);
        self.join_path(path)
    }

    /// Appends an explicit join path.
    ///
    /// # Panics
    ///
    /// Panics if the path reports a first table that differs from this
    /// builder's base table.
    fn join_path(mut self, path: JoinPath) -> Self {
        if let Some(start) = path.first_table() {
            assert_eq!(
                start, self.table,
                "join path must start at base table `{}` but started at `{}`",
                self.table, start,
            );
        }
        self.joins.get_or_insert_with(Vec::new).push(path);
        self
    }
}
impl<'a, C, Row> BuildableSort<C> for ReadQueryBuilder<'a, C, Row>
where
    C: QueryContext,
    C::Model: crate::GetDeleteMarker + JoinNavigationModel,
{
    /// Adds a sort order.
    ///
    /// Later calls are appended to the existing order with `order_by!`;
    /// they do not replace it.
    fn order_by(mut self, s: SortOrder<<C as QueryContext>::Sort>) -> Self {
        self.sort_expr = Some(match self.sort_expr.take() {
            Some(existing) => order_by![existing, s],
            None => s,
        });
        self
    }
}
impl<'a, C, Row> BuildablePage<C> for ReadQueryBuilder<'a, C, Row>
where
    C: QueryContext,
    C::Model: crate::GetDeleteMarker + JoinNavigationModel,
{
    /// Sets the pagination, overwriting any value from an earlier call.
    fn paginate(self, p: Pagination) -> Self {
        Self {
            pagination: Some(p),
            ..self
        }
    }
}
// Marker impl: `BuildableReadQuery` has no items of its own. It only bundles
// `Buildable` (with a `Planable` plan) and the `BuildableFilter`, `BuildableJoin`,
// `BuildableSort` and `BuildablePage` traits, all of which are implemented for
// `ReadQueryBuilder` under these same bounds.
impl<'a, C, Row> BuildableReadQuery<C, Row> for ReadQueryBuilder<'a, C, Row>
where
C: QueryContext,
C::Model: crate::GetDeleteMarker + JoinNavigationModel,
// Same row-decoding bounds as the `Buildable` impl, which the supertrait requires.
Row: Send + Sync + Unpin + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>,
{
}