use std::marker::PhantomData;
use super::pool::DbPool;
use super::repository::{extract_count, placeholder};
use super::{DbError, Model, ToColumn, Value};
use crate::pagination::{CursorPage, Page};
/// Sort direction for an `ORDER BY` clause, set via [`QueryBuilder::order_by`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Order {
    /// Ascending: rendered as `ASC`.
    Asc,
    /// Descending: rendered as `DESC`.
    Desc,
}

impl Order {
    /// The SQL keyword emitted after the `ORDER BY` column.
    fn as_sql(self) -> &'static str {
        match self {
            Self::Desc => "DESC",
            Self::Asc => "ASC",
        }
    }
}
/// Fluent builder for `SELECT` / `UPDATE` / `DELETE` statements against the
/// table of model `T`.
///
/// Accumulated filters are joined with `AND` into a single `WHERE` clause when
/// the statement is finally built and executed on `pool`.
pub struct QueryBuilder<'a, T: Model> {
    /// Pool the finished statement is executed on.
    pool: &'a DbPool,
    /// `(sql_fragment, bound_values)` pairs making up the WHERE clause.
    filters: Vec<(String, Vec<Value>)>,
    /// Optional `ORDER BY` column and direction.
    order: Option<(String, Order)>,
    /// Optional `LIMIT` row count.
    limit: Option<u64>,
    /// Optional `OFFSET` row count.
    offset: Option<u64>,
    /// Binds the builder to `T` without storing a `T`.
    _phantom: PhantomData<T>,
}

// Implemented by hand instead of derived: `#[derive(Clone)]` would add a
// `T: Clone` bound, yet no `T` value is ever stored in the builder.
impl<T: Model> Clone for QueryBuilder<'_, T> {
    fn clone(&self) -> Self {
        let filters = self.filters.clone();
        let order = self.order.clone();
        Self {
            pool: self.pool,
            filters,
            order,
            limit: self.limit,
            offset: self.offset,
            _phantom: PhantomData,
        }
    }
}
impl<'a, T: Model> QueryBuilder<'a, T> {
pub fn new(pool: &'a DbPool) -> Self {
QueryBuilder {
pool,
filters: Vec::new(),
order: None,
limit: None,
offset: None,
_phantom: PhantomData,
}
}
pub fn where_eq(mut self, col: &str, val: impl ToColumn) -> Self {
self.filters.push((
format!("{} = __placeholder__", col),
vec![val.to_column()],
));
self
}
pub fn filter(mut self, expr: &str, params: Vec<Value>) -> Self {
self.filters.push((expr.to_owned(), params));
self
}
pub fn order_by(mut self, col: &str, order: Order) -> Self {
self.order = Some((col.to_owned(), order));
self
}
pub fn limit(mut self, n: u64) -> Self {
self.limit = Some(n);
self
}
pub fn offset(mut self, n: u64) -> Self {
self.offset = Some(n);
self
}
pub async fn fetch_all(self) -> Result<Vec<T>, DbError> {
let (sql, params) = build_select::<T>(
self.filters, self.order, self.limit, self.offset, "*",
);
let rows = self.pool.query_rows(&sql, ¶ms).await?;
rows.iter().map(|r| T::from_row(r)).collect()
}
pub async fn fetch_one(self) -> Result<Option<T>, DbError> {
let (sql, params) = build_select::<T>(
self.filters, self.order, Some(1), self.offset, "*",
);
let rows = self.pool.query_rows(&sql, ¶ms).await?;
match rows.into_iter().next() {
Some(row) => Ok(Some(T::from_row(&row)?)),
None => Ok(None),
}
}
pub async fn count(self) -> Result<i64, DbError> {
let (sql, params) = build_select::<T>(
self.filters, self.order, self.limit, self.offset, "COUNT(*)",
);
let rows = self.pool.query_rows(&sql, ¶ms).await?;
extract_count(rows)
}
pub async fn delete(self) -> Result<(), DbError> {
let (where_clause, params, _) = build_where(self.filters, 0);
let sql = format!("DELETE FROM {}{}", T::table_name(), where_clause);
self.pool.execute(&sql, ¶ms).await?;
Ok(())
}
pub async fn update(self, col: &str, val: impl ToColumn) -> Result<(), DbError> {
let set_ph = placeholder(1);
let set_val = val.to_column();
let (where_clause, mut where_params, _) = build_where(self.filters, 1);
let mut params = vec![set_val];
params.append(&mut where_params);
let sql = format!(
"UPDATE {} SET {} = {}{}",
T::table_name(), col, set_ph, where_clause,
);
self.pool.execute(&sql, ¶ms).await?;
Ok(())
}
pub async fn paginate(self, page: u64, per_page: u64) -> Result<Page<T>, DbError> {
let page = page.max(1);
let per_page = per_page.max(1);
let offset = (page - 1) * per_page;
let total_items = self.clone().count().await? as u64;
let items = self.limit(per_page).offset(offset).fetch_all().await?;
Ok(Page::new(items, page, per_page, total_items))
}
pub async fn paginate_after(self, cursor: Option<&str>, per_page: u64) -> Result<CursorPage<T>, DbError> {
let per_page = per_page.max(1);
let pk_col = T::primary_key_name();
let mut builder = self;
if let Some(cursor_str) = cursor {
let cursor_val: i64 = cursor_str
.parse()
.map_err(|_| DbError::new(format!("invalid cursor '{}': expected an integer primary key", cursor_str)))?;
builder = builder.filter(&format!("{} > __placeholder__", pk_col), vec![Value::Int(cursor_val)]);
}
let mut rows = builder.order_by(pk_col, Order::Asc).limit(per_page + 1).fetch_all().await?;
let has_more = rows.len() as u64 > per_page;
if has_more {
rows.truncate(per_page as usize);
}
let next_cursor = if has_more {
rows.last().map(|item| match item.primary_key_value() {
Value::Int(n) => n.to_string(),
Value::Text(s) => s,
other => format!("{:?}", other),
})
} else {
None
};
Ok(CursorPage { items: rows, next_cursor })
}
}
/// Assembles `SELECT <projection> FROM <table of T>` plus its bound parameters.
///
/// Clauses are appended in SQL order: WHERE (from `build_where`, with
/// placeholders numbered from 1), then ORDER BY, LIMIT and OFFSET.
/// `projection` and the order column are interpolated verbatim and must be
/// trusted.
fn build_select<T: Model>(
    filters: Vec<(String, Vec<Value>)>,
    order: Option<(String, Order)>,
    limit: Option<u64>,
    offset: Option<u64>,
    projection: &str,
) -> (String, Vec<Value>) {
    // Largest LIMIT/OFFSET that both SQLite and PostgreSQL accept. Each treats
    // it as a signed 64-bit integer, so larger u64 values are out of range.
    const MAX_ROWS: u64 = i64::MAX as u64;

    let (where_clause, params, _) = build_where(filters, 0);
    let mut sql = format!("SELECT {} FROM {}{}", projection, T::table_name(), where_clause);
    if let Some((col, ord)) = order {
        sql.push_str(&format!(" ORDER BY {} {}", col, ord.as_sql()));
    }
    // SQLite only accepts OFFSET as part of a LIMIT clause. An offset on its
    // own is therefore paired with an effectively unbounded limit.
    let limit = limit.or_else(|| offset.map(|_| MAX_ROWS));
    if let Some(n) = limit {
        sql.push_str(&format!(" LIMIT {}", n.min(MAX_ROWS)));
    }
    if let Some(n) = offset {
        sql.push_str(&format!(" OFFSET {}", n.min(MAX_ROWS)));
    }
    (sql, params)
}
/// Turns `(fragment, values)` filters into a ` WHERE a AND b ...` clause.
///
/// Placeholders are numbered consecutively starting at `start_idx + 1`. Each
/// value replaces the next `__placeholder__` marker in its fragment. When no
/// marker is left, a `?` is rewritten instead, but only on PostgreSQL-only
/// builds. Returns the clause (empty when there are no filters), the values in
/// binding order, and the last placeholder index used.
pub(crate) fn build_where(
    filters: Vec<(String, Vec<Value>)>,
    start_idx: usize,
) -> (String, Vec<Value>, usize) {
    const MARKER: &str = "__placeholder__";

    let mut next = start_idx;
    let mut bound: Vec<Value> = Vec::new();
    let mut clauses: Vec<String> = Vec::with_capacity(filters.len());

    for (fragment, values) in filters {
        let mut clause = fragment;
        for _ in &values {
            next += 1;
            if clause.contains(MARKER) {
                clause = clause.replacen(MARKER, &placeholder(next), 1);
            } else {
                // Without a marker, PostgreSQL-only builds number a bare `?`.
                // Other backends leave `?` for the driver to bind positionally.
                #[cfg(all(feature = "model-postgres", not(feature = "model-sqlite")))]
                {
                    clause = clause.replacen("?", &placeholder(next), 1);
                }
            }
        }
        bound.extend(values);
        clauses.push(clause);
    }

    let where_clause = match clauses.as_slice() {
        [] => String::new(),
        parts => format!(" WHERE {}", parts.join(" AND ")),
    };
    (where_clause, bound, next)
}