use sqlx::{
Postgres,
Type,
};
use sqlxo_traits::{
Filterable,
JoinKind,
JoinPath,
Sortable,
};
use sqlxo_traits::{
QueryContext,
SqlWrite,
};
use std::collections::HashSet;
mod expression;
mod head;
mod pagination;
mod sort;
pub use expression::Expression;
pub use head::{
DeleteHead,
InsertHead,
QualifiedColumn,
ReadHead,
SelectProjection,
SelectType,
UpdateHead,
};
pub use pagination::{
Page,
Pagination,
};
pub use sort::SortOrder;
use crate::blocks::head::ToHead;
/// A builder that can be narrowed by a filter expression over `C::Query`.
pub trait BuildableFilter<C>
where
    C: QueryContext,
{
    /// Attaches `expr` as a filter and hands back the updated builder.
    fn r#where(self, expr: Expression<C::Query>) -> Self;
}
/// A builder that can be extended with joins described by the context `C`.
pub trait BuildableJoin<C>
where
    C: QueryContext,
{
    /// Adds a single join `relation` using the given `kind` (inner/left) and
    /// returns the updated builder.
    fn join(self, relation: C::Join, kind: JoinKind) -> Self;

    /// Adds an already-assembled multi-hop `route` and returns the updated
    /// builder.
    fn join_path(self, route: JoinPath) -> Self;
}
/// A builder that can be given an ordering over the context's sort keys.
pub trait BuildableSort<C>
where
    C: QueryContext,
{
    /// Applies `order` and hands back the updated builder.
    fn order_by(self, order: SortOrder<C::Sort>) -> Self;
}
/// A builder that can be restricted to a single page of results.
pub trait BuildablePage<C>
where
    C: QueryContext,
{
    /// Applies `pagination` and hands back the updated builder.
    fn paginate(self, pagination: Pagination) -> Self;
}
/// Incremental writer for a single Postgres statement.
///
/// Wraps a `sqlx::QueryBuilder` together with a set of flags that the
/// `push_*` methods consult to decide whether a clause keyword
/// (`WHERE`, `ORDER BY`, …) has already been written.
pub struct SqlWriter {
    /// Accumulated SQL text plus bound arguments.
    qb: sqlx::QueryBuilder<'static, Postgres>,
    // Clause bookkeeping, listed in the order the clauses appear in SQL.
    has_join: bool,
    has_where: bool,
    has_group_by: bool,
    has_having: bool,
    has_sort: bool,
    has_pagination: bool,
}
impl SqlWriter {
    /// Creates a writer whose buffer starts with the rendered statement head
    /// (the `Display` output of `head.to_head()`).
    pub fn new(head: impl ToHead) -> Self {
        Self::from_sql(head.to_head().to_string())
    }

    /// Creates a writer whose buffer starts with `sql` verbatim and with every
    /// clause flag cleared.
    pub fn from_sql(sql: impl Into<String>) -> Self {
        let qb = sqlx::QueryBuilder::<Postgres>::new(sql.into());
        Self {
            qb,
            has_join: false,
            has_where: false,
            has_sort: false,
            has_group_by: false,
            has_having: false,
            has_pagination: false,
        }
    }

    /// Consumes the writer and returns the underlying query builder.
    pub fn into_builder(self) -> sqlx::QueryBuilder<'static, Postgres> {
        self.qb
    }

    /// Direct mutable access to the underlying builder.
    ///
    /// Text pushed through this handle bypasses the clause flags.
    pub fn query_builder_mut(
        &mut self,
    ) -> &mut sqlx::QueryBuilder<'static, Postgres> {
        &mut self.qb
    }

    /// Emits JOIN clauses for every path in `joins`, rooted at `base_table`.
    ///
    /// Joins shared between paths are written once (see `push_join_path`).
    /// Only the first call with a non-empty `joins` has any effect; later
    /// calls return without writing anything.
    pub fn push_joins(&mut self, joins: &[JoinPath], base_table: &str) {
        if self.has_join || joins.is_empty() {
            return;
        }
        let mut emitted: HashSet<String> = HashSet::new();
        for path in joins {
            self.push_join_path(path, base_table, &mut emitted);
        }
        self.has_join = true;
    }

    /// Emits the JOIN clauses for a single `path`, starting from `base_table`.
    ///
    /// Each segment contributes up to two joins:
    /// - one to its optional `through` table, aliased as the current alias
    ///   prefix plus the through table's own `alias_segment`;
    /// - one to its `right_table`, aliased as the concatenation of every
    ///   segment `alias_segment` seen so far along the path.
    ///
    /// A join is written only when its key is not already in `emitted`, so
    /// paths that share a prefix do not repeat joins.
    ///
    /// NOTE(review): the dedup key includes the join keyword, but the alias
    /// does not depend on it. Two paths that reach the same segment with
    /// different `JoinKind`s therefore emit the same alias twice, which
    /// Postgres rejects.
    ///
    /// NOTE(review): table names are interpolated unquoted (aliases and
    /// columns are quoted). This is presumably safe because they come from
    /// static descriptors; confirm they never carry user input.
    fn push_join_path(
        &mut self,
        path: &JoinPath,
        base_table: &str,
        emitted: &mut HashSet<String>,
    ) {
        if path.is_empty() {
            return;
        }
        // Alias of the table the next join hangs off.
        let mut left_alias = base_table.to_string();
        // Running concatenation of segment alias pieces; names each hop uniquely.
        let mut alias_prefix = String::new();
        for segment in path.segments() {
            let join_word = match segment.kind {
                JoinKind::Inner => " INNER JOIN ",
                JoinKind::Left => " LEFT JOIN ",
            };
            // Many-to-many hops go through a junction table first.
            if let Some(through) = segment.descriptor.through {
                let mut through_alias = alias_prefix.clone();
                through_alias.push_str(through.alias_segment);
                let join_key = format!(
                    "{join}|{table}|{alias}|{left}|{left_field}|{right_field}",
                    join = join_word.trim(),
                    table = through.table,
                    alias = &through_alias,
                    left = &left_alias,
                    left_field = through.left_field,
                    right_field = through.right_field,
                );
                if emitted.insert(join_key) {
                    let clause = format!(
                        r#"{join}{table} AS "{alias}" ON "{left}"."{left_field}" = "{alias}"."{right_field}""#,
                        join = join_word,
                        table = through.table,
                        alias = &through_alias,
                        left = &left_alias,
                        left_field = through.left_field,
                        right_field = through.right_field,
                    );
                    self.qb.push(clause);
                }
                // The target table joins against the junction, not the original left side.
                left_alias = through_alias;
            }
            alias_prefix.push_str(segment.descriptor.alias_segment);
            let right_alias = alias_prefix.clone();
            let join_key = format!(
                "{join}|{table}|{alias}|{left}|{left_field}|{right_field}",
                join = join_word.trim(),
                table = segment.descriptor.right_table,
                alias = &right_alias,
                left = &left_alias,
                left_field = segment.descriptor.left_field,
                right_field = segment.descriptor.right_field,
            );
            if emitted.insert(join_key) {
                let clause = format!(
                    r#"{join}{table} AS "{alias}" ON "{left}"."{left_field}" = "{alias}"."{right_field}""#,
                    join = join_word,
                    table = segment.descriptor.right_table,
                    alias = &right_alias,
                    left = &left_alias,
                    left_field = segment.descriptor.left_field,
                    right_field = segment.descriptor.right_field,
                );
                self.qb.push(clause);
            }
            left_alias = right_alias;
        }
    }

    /// Appends `expr` as a WHERE predicate (see `push_where_raw`).
    pub fn push_where<F: Filterable>(&mut self, expr: &Expression<F>) {
        self.push_where_raw(|writer| expr.write(writer));
    }

    /// Appends a WHERE predicate produced by `build`.
    ///
    /// The first call writes ` WHERE `; later calls write ` AND `. `build` is
    /// invoked exactly once, so any `FnOnce` closure is accepted.
    ///
    /// NOTE(review): predicates are joined with a bare `AND`. If `build`
    /// writes an unparenthesized `OR`, precedence changes. Confirm the
    /// expression writer parenthesizes.
    pub fn push_where_raw(&mut self, build: impl FnOnce(&mut SqlWriter)) {
        if self.has_where {
            self.qb.push(" AND ");
        } else {
            self.qb.push(" WHERE ");
            self.has_where = true;
        }
        build(self);
    }

    /// Writes ` ORDER BY ` followed by `sort.to_sql()`.
    ///
    /// Only the first ordering (from this method or `push_order_by_raw`)
    /// takes effect.
    pub fn push_sort<S: Sortable>(&mut self, sort: &SortOrder<S>) {
        if self.has_sort {
            return;
        }
        self.qb.push(" ORDER BY ");
        self.has_sort = true;
        self.qb.push(sort.to_sql());
    }

    /// Writes ` ORDER BY ` followed by whatever `build` emits.
    ///
    /// Shares the "already sorted" flag with `push_sort`, so this is ignored
    /// if an ordering was already written. `build` is invoked at most once.
    pub fn push_order_by_raw(&mut self, build: impl FnOnce(&mut SqlWriter)) {
        if self.has_sort {
            return;
        }
        self.qb.push(" ORDER BY ");
        self.has_sort = true;
        build(self);
    }

    /// Writes ` GROUP BY "alias"."column", …` for `columns`.
    ///
    /// No-op when `columns` is empty or a GROUP BY was already written.
    pub fn push_group_by_columns(&mut self, columns: &[QualifiedColumn]) {
        if columns.is_empty() || self.has_group_by {
            return;
        }
        self.qb.push(" GROUP BY ");
        for (idx, col) in columns.iter().enumerate() {
            if idx > 0 {
                self.qb.push(", ");
            }
            self.qb
                .push(format!(r#""{}"."{}""#, col.table_alias, col.column));
        }
        self.has_group_by = true;
    }

    /// Appends a HAVING predicate produced by `build`.
    ///
    /// The first call writes ` HAVING `; later calls write ` AND `. `build` is
    /// invoked exactly once.
    pub fn push_having(&mut self, build: impl FnOnce(&mut SqlWriter)) {
        if self.has_having {
            self.qb.push(" AND ");
        } else {
            self.qb.push(" HAVING ");
            self.has_having = true;
        }
        build(self);
    }

    /// Writes ` LIMIT $n OFFSET $m` with `page_size` and `page * page_size`
    /// as bound parameters.
    ///
    /// Only the first call takes effect; a second LIMIT/OFFSET would be
    /// invalid SQL.
    pub fn push_pagination(&mut self, p: &Pagination) {
        if self.has_pagination {
            return;
        }
        self.qb.push(" LIMIT ");
        self.bind(p.page_size);
        self.qb.push(" OFFSET ");
        // NOTE(review): unchecked multiplication. It panics in debug and wraps
        // in release if `page` is large (e.g. user-supplied). Consider a
        // checked/saturating multiply once the field type is pinned down.
        self.bind(p.page * p.page_size);
        // Previously never set, which made the guard above dead code.
        self.has_pagination = true;
    }
}
impl SqlWrite for SqlWriter {
    /// Appends raw SQL text to the buffer as-is (no quoting or escaping).
    fn push(&mut self, s: &str) {
        let builder = &mut self.qb;
        builder.push(s);
    }

    /// Appends a bind placeholder to the buffer and records `value` as its
    /// argument via `QueryBuilder::push_bind`.
    fn bind<T>(&mut self, value: T)
    where
        T: sqlx::Encode<'static, Postgres> + Type<Postgres> + Send + 'static,
    {
        let builder = &mut self.qb;
        builder.push_bind(value);
    }
}