use crate::column::{reindex_params, FilterExpr, OrderExpr, SqlValue};
use crate::error::{OrmError, OrmResult};
use crate::pagination::{Page};
use crate::scope::Scope;
use sqlx::postgres::{PgArguments, PgRow};
use sqlx::{PgPool, Postgres};
use std::marker::PhantomData;
/// Kind of SQL join, rendered by `QueryBuilder::build_select` as `INNER JOIN`,
/// `LEFT JOIN` or `RIGHT JOIN` respectively.
#[derive(Debug, Clone)]
enum JoinType {
Inner,
Left,
Right,
}
/// One join attached to a `QueryBuilder`, rendered as
/// ` <KIND> "<table>"[ AS "<alias>"] ON <on>`.
#[derive(Debug, Clone)]
struct JoinClause {
/// Which join keyword to emit.
join_type: JoinType,
/// Joined table name; double-quoted when rendered.
table: String,
/// Optional alias, rendered as ` AS "<alias>"`.
/// NOTE(review): the visible join builders always leave this `None`.
alias: Option<String>,
/// Raw SQL join condition, inserted verbatim after `ON`; must never contain untrusted input.
on: String,
}
/// A single `"col" = $n` assignment consumed by `QueryBuilder::update_all`.
#[derive(Debug, Clone)]
pub struct UpdateSet {
/// Column name; double-quoted when rendered.
pub col: String,
/// New value; sent as a bound parameter rather than spliced into the SQL text.
pub val: SqlValue,
}
/// Fluent builder for SELECT / COUNT / UPDATE / DELETE statements against one table.
///
/// `T` is the row type produced by the fetch methods. It is only tracked through
/// `PhantomData`, so the builder never stores a `T` value.
#[derive(Debug)]
pub struct QueryBuilder<T> {
/// Table name, interpolated double-quoted into the generated SQL.
pub(crate) table: String,
/// Primary-key column name; `last()` orders by it when no other ordering is set.
pub(crate) pk: String,
/// Projection list, joined with ", " (`["*"]` by default).
select_cols: Vec<String>,
/// Emit `SELECT DISTINCT` instead of `SELECT`.
distinct: bool,
/// WHERE predicates, combined with AND.
filters: Vec<FilterExpr>,
/// Soft-delete scope flag: include soft-deleted rows.
with_deleted: bool,
/// Soft-delete scope flag: return only soft-deleted rows (wins over `with_deleted`).
only_deleted: bool,
/// Soft-delete column; when `None`, the two flags above have no effect.
soft_delete_col: Option<String>,
/// Join clauses, rendered in insertion order.
joins: Vec<JoinClause>,
/// Raw GROUP BY expressions.
group_by: Vec<String>,
/// Optional HAVING predicate; its placeholders are renumbered after the WHERE bindings.
having: Option<FilterExpr>,
/// Raw ORDER BY expressions.
order_by: Vec<String>,
/// When true, renders `ORDER BY RANDOM()` and ignores `order_by`.
order_random: bool,
/// Optional LIMIT.
limit: Option<i64>,
/// Optional OFFSET.
offset: Option<i64>,
/// Statement kind; `new` sets `Select`.
/// NOTE(review): not read by any visible method besides `clone`.
op: QueryOp,
/// NOTE(review): never populated or read in visible code; `update_all` uses `UpdateBuilder`.
update_sets: Vec<UpdateSet>,
/// Ties the builder to its row type without storing one.
_marker: PhantomData<T>,
}
impl<T> Clone for QueryBuilder<T> {
fn clone(&self) -> Self {
Self {
table: self.table.clone(),
pk: self.pk.clone(),
select_cols: self.select_cols.clone(),
distinct: self.distinct,
filters: self.filters.clone(),
with_deleted: self.with_deleted,
only_deleted: self.only_deleted,
soft_delete_col: self.soft_delete_col.clone(),
joins: self.joins.clone(),
group_by: self.group_by.clone(),
having: self.having.clone(),
order_by: self.order_by.clone(),
order_random: self.order_random,
limit: self.limit,
offset: self.offset,
op: self.op.clone(),
update_sets: self.update_sets.clone(),
_marker: PhantomData,
}
}
}
/// Statement kind recorded on a `QueryBuilder`.
/// NOTE(review): only `Select` is constructed in visible code (by `QueryBuilder::new`);
/// the other variants appear unused here.
#[derive(Debug, Clone, PartialEq)]
enum QueryOp {
Select,
Update,
Delete,
Count,
}
impl<T> QueryBuilder<T>
where
T: Send + Sync + Unpin + 'static,
{
pub fn new(table: impl Into<String>, pk: impl Into<String>) -> Self {
Self {
table: table.into(),
pk: pk.into(),
select_cols: vec!["*".into()],
distinct: false,
filters: vec![],
with_deleted: false,
only_deleted: false,
soft_delete_col: None,
joins: vec![],
group_by: vec![],
having: None,
order_by: vec![],
order_random: false,
limit: None,
offset: None,
op: QueryOp::Select,
update_sets: vec![],
_marker: PhantomData,
}
}
pub fn with_soft_delete_col(mut self, col: impl Into<String>) -> Self {
self.soft_delete_col = Some(col.into());
self
}
pub fn filter<F>(mut self, f: F) -> Self
where
F: FnOnce(&T::Columns) -> FilterExpr,
T: HasColumns,
{
let cols = T::columns_proxy();
let expr = f(&cols);
self.filters.push(expr);
self
}
pub fn filter_raw(mut self, sql: impl Into<String>) -> Self {
self.filters.push(FilterExpr::raw(sql));
self
}
pub fn filter_if<F>(self, condition: bool, f: F) -> Self
where
F: FnOnce(&T::Columns) -> FilterExpr,
T: HasColumns,
{
if condition {
self.filter(f)
} else {
self
}
}
pub fn apply(self, scope: Scope<T>) -> Self {
(scope.apply_fn)(self)
}
pub fn with_deleted(mut self) -> Self {
self.with_deleted = true;
self
}
pub fn only_deleted(mut self) -> Self {
self.only_deleted = true;
self
}
pub fn select_cols(mut self, cols: impl IntoIterator<Item = impl Into<String>>) -> Self {
self.select_cols = cols.into_iter().map(|c| c.into()).collect();
self
}
pub fn select_distinct_col(mut self, col: impl Into<String>) -> Self {
self.distinct = true;
self.select_cols = vec![col.into()];
self
}
pub fn inner_join(mut self, table: impl Into<String>, on: impl Into<String>) -> Self {
self.joins.push(JoinClause {
join_type: JoinType::Inner,
table: table.into(),
alias: None,
on: on.into(),
});
self
}
pub fn left_join(mut self, table: impl Into<String>, on: impl Into<String>) -> Self {
self.joins.push(JoinClause {
join_type: JoinType::Left,
table: table.into(),
alias: None,
on: on.into(),
});
self
}
pub fn group_by_col(mut self, col: impl Into<String>) -> Self {
self.group_by.push(col.into());
self
}
pub fn having_raw(mut self, sql: impl Into<String>) -> Self {
self.having = Some(FilterExpr::raw(sql));
self
}
pub fn order_by<F>(mut self, f: F) -> Self
where
F: FnOnce(&T::Columns) -> OrderExpr,
T: HasColumns,
{
let cols = T::columns_proxy();
let expr = f(&cols);
self.order_by.push(expr.sql);
self
}
pub fn order_by_raw(mut self, sql: impl Into<String>) -> Self {
self.order_by.push(sql.into());
self
}
pub fn order_by_random(mut self) -> Self {
self.order_random = true;
self
}
pub fn limit(mut self, n: i64) -> Self {
self.limit = Some(n);
self
}
pub fn offset(mut self, n: i64) -> Self {
self.offset = Some(n);
self
}
pub fn paginate(mut self, page: i64, per_page: i64) -> Self {
let page = page.max(1);
self.limit = Some(per_page);
self.offset = Some((page - 1) * per_page);
self
}
pub fn build_select(&self) -> (String, Vec<SqlValue>) {
let mut bindings: Vec<SqlValue> = vec![];
let mut all_filters = self.filters.clone();
if let Some(ref col) = self.soft_delete_col {
if !self.with_deleted && !self.only_deleted {
all_filters.push(FilterExpr::raw(format!("\"{}\" IS NULL", col)));
} else if self.only_deleted {
all_filters.push(FilterExpr::raw(format!("\"{}\" IS NOT NULL", col)));
}
}
let distinct_kw = if self.distinct { "DISTINCT " } else { "" };
let cols = self.select_cols.join(", ");
let mut sql = format!("SELECT {}{} FROM \"{}\"", distinct_kw, cols, self.table);
for j in &self.joins {
let kw = match j.join_type {
JoinType::Inner => "INNER JOIN",
JoinType::Left => "LEFT JOIN",
JoinType::Right => "RIGHT JOIN",
};
let alias_part = j
.alias
.as_deref()
.map(|a| format!(" AS \"{}\"", a))
.unwrap_or_default();
sql.push_str(&format!(
" {} \"{}\"{} ON {}",
kw, j.table, alias_part, j.on
));
}
if !all_filters.is_empty() {
let mut parts: Vec<String> = vec![];
for expr in &all_filters {
let offset = bindings.len();
let reindexed = reindex_params(&expr.sql, offset);
parts.push(reindexed);
bindings.extend(expr.bindings.clone());
}
sql.push_str(" WHERE ");
sql.push_str(&parts.join(" AND "));
}
if !self.group_by.is_empty() {
sql.push_str(" GROUP BY ");
sql.push_str(&self.group_by.join(", "));
}
if let Some(ref hav) = self.having {
let offset = bindings.len();
let reindexed = reindex_params(&hav.sql, offset);
sql.push_str(&format!(" HAVING {}", reindexed));
bindings.extend(hav.bindings.clone());
}
if self.order_random {
sql.push_str(" ORDER BY RANDOM()");
} else if !self.order_by.is_empty() {
sql.push_str(" ORDER BY ");
sql.push_str(&self.order_by.join(", "));
}
if let Some(l) = self.limit {
sql.push_str(&format!(" LIMIT {}", l));
}
if let Some(o) = self.offset {
sql.push_str(&format!(" OFFSET {}", o));
}
(sql, bindings)
}
fn build_count(self) -> (String, Vec<SqlValue>) {
let count_builder = QueryBuilder::<T> {
select_cols: vec!["COUNT(*)".into()],
order_by: vec![],
order_random: false,
limit: None,
offset: None,
..self.clone()
};
count_builder.build_select()
}
pub async fn fetch_all(self, pool: &PgPool) -> OrmResult<Vec<T>>
where
T: for<'r> sqlx::FromRow<'r, PgRow>,
{
let (sql, bindings) = self.build_select();
let mut q = sqlx::query_as::<Postgres, T>(&sql);
for b in bindings {
q = bind_value(q, b);
}
q.fetch_all(pool).await.map_err(OrmError::from_sqlx)
}
pub async fn first(mut self, pool: &PgPool) -> OrmResult<Option<T>>
where
T: for<'r> sqlx::FromRow<'r, PgRow>,
{
self.limit = Some(1);
let (sql, bindings) = self.build_select();
let mut q = sqlx::query_as::<Postgres, T>(&sql);
for b in bindings {
q = bind_value(q, b);
}
q.fetch_optional(pool).await.map_err(OrmError::from_sqlx)
}
pub async fn first_or_fail(self, pool: &PgPool) -> OrmResult<T>
where
T: for<'r> sqlx::FromRow<'r, PgRow>,
{
self.first(pool).await?.ok_or(OrmError::NotFound)
}
pub async fn last(mut self, pool: &PgPool) -> OrmResult<Option<T>>
where
T: for<'r> sqlx::FromRow<'r, PgRow>,
{
if self.order_by.is_empty() && !self.order_random {
self.order_by.push(format!("\"{}\" DESC", self.pk));
}
self.limit = Some(1);
let (sql, bindings) = self.build_select();
let mut q = sqlx::query_as::<Postgres, T>(&sql);
for b in bindings {
q = bind_value(q, b);
}
q.fetch_optional(pool).await.map_err(OrmError::from_sqlx)
}
pub async fn count(self, pool: &PgPool) -> OrmResult<i64> {
let (sql, bindings) = self.build_count();
let mut q = sqlx::query_as::<Postgres, (i64,)>(&sql);
for b in bindings {
q = bind_i64_value(q, b);
}
let row = q.fetch_one(pool).await.map_err(OrmError::from_sqlx)?;
Ok(row.0)
}
pub async fn exists(self, pool: &PgPool) -> OrmResult<bool> {
Ok(self.count(pool).await? > 0)
}
pub async fn fetch_page(self, page: i64, per_page: i64, pool: &PgPool) -> OrmResult<Page<T>>
where
T: for<'r> sqlx::FromRow<'r, PgRow>,
{
let page = page.max(1);
let total = self.clone().count(pool).await?;
let items = self.paginate(page, per_page).fetch_all(pool).await?;
Ok(Page::new(items, total, page, per_page))
}
pub async fn update_all<F>(self, f: F, pool: &PgPool) -> OrmResult<u64>
where
F: FnOnce(&mut UpdateBuilder<T>),
T: HasColumns,
{
let mut ub = UpdateBuilder::new();
f(&mut ub);
let mut set_parts: Vec<String> = vec![];
let mut bindings: Vec<SqlValue> = vec![];
for us in ub.sets {
let idx = bindings.len() + 1;
set_parts.push(format!("\"{}\" = ${}", us.col, idx));
bindings.push(us.val);
}
let mut where_parts: Vec<String> = vec![];
for expr in &self.filters {
let offset = bindings.len();
let reindexed = reindex_params(&expr.sql, offset);
where_parts.push(reindexed);
bindings.extend(expr.bindings.clone());
}
let mut sql = format!("UPDATE \"{}\" SET {}", self.table, set_parts.join(", "));
if !where_parts.is_empty() {
sql.push_str(" WHERE ");
sql.push_str(&where_parts.join(" AND "));
}
let mut q = sqlx::query(&sql);
for b in bindings {
q = bind_query_value(q, b);
}
let result = q.execute(pool).await.map_err(OrmError::from_sqlx)?;
Ok(result.rows_affected())
}
pub async fn delete_all(self, pool: &PgPool) -> OrmResult<u64> {
let mut bindings: Vec<SqlValue> = vec![];
let mut where_parts: Vec<String> = vec![];
for expr in &self.filters {
let offset = bindings.len();
let reindexed = reindex_params(&expr.sql, offset);
where_parts.push(reindexed);
bindings.extend(expr.bindings.clone());
}
let mut sql = format!("DELETE FROM \"{}\"", self.table);
if !where_parts.is_empty() {
sql.push_str(" WHERE ");
sql.push_str(&where_parts.join(" AND "));
}
let mut q = sqlx::query(&sql);
for b in bindings {
q = bind_query_value(q, b);
}
let result = q.execute(pool).await.map_err(OrmError::from_sqlx)?;
Ok(result.rows_affected())
}
}
/// Collects column assignments for `QueryBuilder::update_all`.
///
/// The closure passed to `update_all` receives `&mut UpdateBuilder<T>` and records
/// assignments by pushing `UpdateSet`s onto `sets`.
pub struct UpdateBuilder<T> {
/// Assignments, rendered in order as `"col" = $n`.
pub sets: Vec<UpdateSet>,
/// Ties the builder to the model type without storing one.
_m: PhantomData<T>,
}
impl<T> UpdateBuilder<T> {
fn new() -> Self {
Self {
sets: vec![],
_m: PhantomData,
}
}
}
/// Implemented by models that expose a typed column proxy for building queries.
pub trait HasColumns {
/// Proxy type whose reference is passed to `QueryBuilder::filter` / `order_by` closures.
type Columns;
/// Builds a fresh proxy. `QueryBuilder` calls it once per `filter`/`order_by` call.
fn columns_proxy() -> Self::Columns;
}
/// Binds one `SqlValue` as the next positional parameter of a row-mapping query.
///
/// `Null` is bound as an untyped `Option<String>::None`, and `Json` is wrapped in
/// `sqlx::types::Json`.
fn bind_value<'q, T>(
    query: sqlx::query::QueryAs<'q, Postgres, T, PgArguments>,
    value: SqlValue,
) -> sqlx::query::QueryAs<'q, Postgres, T, PgArguments>
where
    T: Send + Unpin,
{
    match value {
        SqlValue::Null => query.bind(Option::<String>::None),
        SqlValue::Bool(flag) => query.bind(flag),
        SqlValue::Int(number) => query.bind(number),
        SqlValue::Float(number) => query.bind(number),
        SqlValue::Text(text) => query.bind(text),
        SqlValue::Bytes(raw) => query.bind(raw),
        SqlValue::Json(doc) => query.bind(sqlx::types::Json(doc)),
    }
}
/// Binds one `SqlValue` to a query whose single result column decodes as `i64`
/// (used by `count`).
fn bind_i64_value<'q>(
    counter: sqlx::query::QueryAs<'q, Postgres, (i64,), PgArguments>,
    param: SqlValue,
) -> sqlx::query::QueryAs<'q, Postgres, (i64,), PgArguments> {
    match param {
        SqlValue::Text(s) => counter.bind(s),
        SqlValue::Int(i) => counter.bind(i),
        SqlValue::Float(f) => counter.bind(f),
        SqlValue::Bool(b) => counter.bind(b),
        SqlValue::Bytes(bytes) => counter.bind(bytes),
        SqlValue::Json(json) => counter.bind(sqlx::types::Json(json)),
        SqlValue::Null => counter.bind(Option::<String>::None),
    }
}
/// Binds one `SqlValue` to a statement that returns no rows (UPDATE / DELETE).
fn bind_query_value<'q>(
    stmt: sqlx::query::Query<'q, Postgres, PgArguments>,
    arg: SqlValue,
) -> sqlx::query::Query<'q, Postgres, PgArguments> {
    match arg {
        SqlValue::Bytes(blob) => stmt.bind(blob),
        SqlValue::Json(payload) => stmt.bind(sqlx::types::Json(payload)),
        SqlValue::Text(string) => stmt.bind(string),
        SqlValue::Bool(truth) => stmt.bind(truth),
        SqlValue::Float(real) => stmt.bind(real),
        SqlValue::Int(integer) => stmt.bind(integer),
        SqlValue::Null => stmt.bind(Option::<String>::None),
    }
}