use crate::paginated_query_as::internal::quote_identifier;
use crate::paginated_query_as::models::QuerySortDirection;
use crate::{FlatQueryParams, PaginatedResponse, QueryParams};
use serde::Serialize;
use sqlx::{query::QueryAs, Database, Execute, Executor, FromRow, IntoArguments, Pool};
/// Boxed, thread-safe callback that derives SQL filter conditions from [`QueryParams`].
///
/// It returns the individual condition strings (the builder joins them with `AND` into a
/// `WHERE` clause) together with the database arguments those conditions bind.
type QueryBuilderFn<T, DB> = Box<
    dyn for<'p> Fn(&'p QueryParams<T>) -> (Vec<String>, <DB as Database>::Arguments<'p>)
        + Send
        + Sync
        + 'static,
>;
/// Wraps a `sqlx` [`QueryAs`] so it can be executed as a filtered, sorted, paginated query.
///
/// The wrapped query's SQL is embedded as a CTE named `base_query`; the generated
/// conditions, `ORDER BY` and `LIMIT`/`OFFSET` clauses are applied on top of that CTE.
pub struct PaginatedQueryBuilder<'q, T, DB: Database, A>
where
    T: for<'r> FromRow<'r, DB::Row> + Send + Unpin,
{
    /// Base query whose SQL becomes the body of the `base_query` CTE.
    query: QueryAs<'q, DB, T, A>,
    /// Search / filter / sort / pagination parameters driving the generated SQL.
    params: QueryParams<'q, T>,
    /// When `true`, an extra `COUNT(*)` query is run to fill in the totals.
    totals_count_enabled: bool,
    /// Callback that turns `params` into `WHERE` conditions and their bound arguments.
    build_query_fn: QueryBuilderFn<T, DB>,
}
impl<'q, T, DB, A> PaginatedQueryBuilder<'q, T, DB, A>
where
    DB: Database,
    T: for<'r> FromRow<'r, <DB as Database>::Row> + Send + Unpin + Serialize + Default,
    A: 'q + IntoArguments<'q, DB> + Send,
    DB::Arguments<'q>: IntoArguments<'q, DB>,
    for<'c> &'c Pool<DB>: Executor<'c, Database = DB>,
    usize: sqlx::ColumnIndex<<DB as Database>::Row>,
    i64: sqlx::Type<DB> + for<'r> sqlx::Decode<'r, DB> + Send + Unpin,
{
    /// Creates a builder around `query`.
    ///
    /// `build_query_fn` receives the current [`QueryParams`] and returns the SQL
    /// conditions (joined with `AND` into a `WHERE` clause) plus the arguments they bind.
    /// Parameters start from `FlatQueryParams::default()` and the totals count is enabled.
    pub fn new<F>(query: QueryAs<'q, DB, T, A>, build_query_fn: F) -> Self
    where
        F: for<'p> Fn(&'p QueryParams<T>) -> (Vec<String>, DB::Arguments<'p>)
            + Send
            + Sync
            + 'static,
    {
        Self {
            query,
            params: FlatQueryParams::default().into(),
            totals_count_enabled: true,
            build_query_fn: Box::new(build_query_fn),
        }
    }

    /// Replaces the callback used to derive `WHERE` conditions and their arguments.
    pub fn with_query_builder<F>(mut self, build_query_fn: F) -> Self
    where
        F: for<'p> Fn(&'p QueryParams<T>) -> (Vec<String>, DB::Arguments<'p>)
            + Send
            + Sync
            + 'static,
    {
        self.build_query_fn = Box::new(build_query_fn);
        self
    }

    /// Sets the query parameters (search, filters, sorting, pagination).
    pub fn with_params(mut self, params: impl Into<QueryParams<'q, T>>) -> Self {
        self.params = params.into();
        self
    }

    /// Disables the extra `COUNT(*)` query used to compute totals.
    pub fn disable_totals_count(mut self) -> Self {
        self.totals_count_enabled = false;
        self
    }

    /// Wraps the user query in a CTE named `base_query`.
    fn build_base_query(&self) -> String {
        format!("WITH base_query AS ({})", self.query.sql())
    }

    /// Returns `""` when there are no conditions, otherwise `" WHERE c1 AND c2 ..."`.
    fn build_where_clause(&self, conditions: &[String]) -> String {
        if conditions.is_empty() {
            String::new()
        } else {
            format!(" WHERE {}", conditions.join(" AND "))
        }
    }

    /// Builds `" ORDER BY <column> ASC|DESC"` from the sort parameters.
    ///
    /// The column name is passed through `quote_identifier` rather than interpolated raw.
    fn build_order_clause(&self) -> String {
        let order = match self.params.sort.sort_direction {
            QuerySortDirection::Ascending => "ASC",
            QuerySortDirection::Descending => "DESC",
        };
        let column_name = quote_identifier(&self.params.sort.sort_column);
        format!(" ORDER BY {} {}", column_name, order)
    }

    /// Builds `" LIMIT <page_size> OFFSET <offset>"` for the 1-based page number.
    fn build_limit_offset_clause(&self) -> String {
        let pagination = &self.params.pagination;
        // Clamp so a page number <= 0 (or a negative page size) never produces a negative
        // OFFSET, which PostgreSQL rejects; saturate so huge page numbers cannot overflow.
        let offset = pagination
            .page
            .saturating_sub(1)
            .max(0)
            .saturating_mul(pagination.page_size.max(0));
        format!(" LIMIT {} OFFSET {}", pagination.page_size, offset)
    }
}
#[cfg(feature = "postgres")]
impl<'q, T, A> PaginatedQueryBuilder<'q, T, sqlx::Postgres, A>
where
    T: for<'r> FromRow<'r, <sqlx::Postgres as sqlx::Database>::Row>
        + Send
        + Unpin
        + Serialize
        + Default,
    A: 'q + IntoArguments<'q, sqlx::Postgres> + Send,
{
    /// Creates a PostgreSQL builder whose conditions come from
    /// `build_query_with_safe_defaults`.
    pub fn new_with_defaults(query: sqlx::query::QueryAs<'q, sqlx::Postgres, T, A>) -> Self {
        use crate::paginated_query_as::examples::postgres_examples::build_query_with_safe_defaults;
        Self::new(query, |params| {
            build_query_with_safe_defaults::<T, sqlx::Postgres>(params)
        })
    }

    /// Executes the paginated query against `pool`.
    ///
    /// When totals counting is enabled, a `COUNT(*)` query over the same filtered
    /// `base_query` runs first and `total`, `total_pages` and `pagination` are populated;
    /// otherwise all three are `None`.
    ///
    /// # Errors
    ///
    /// Returns any [`sqlx::Error`] raised while executing the count or the main query.
    pub async fn fetch_paginated(
        self,
        pool: &sqlx::PgPool,
    ) -> Result<PaginatedResponse<T>, sqlx::Error> {
        let base_sql = self.build_base_query();
        let params_ref = &self.params;
        let (conditions, main_arguments) = (self.build_query_fn)(params_ref);
        let where_clause = self.build_where_clause(&conditions);

        let mut main_sql = format!("{} SELECT * FROM base_query{}", base_sql, where_clause);
        main_sql.push_str(&self.build_order_clause());
        main_sql.push_str(&self.build_limit_offset_clause());

        let (total, total_pages, pagination) = if self.totals_count_enabled {
            let count_sql = format!(
                "{} SELECT COUNT(*) FROM base_query{}",
                base_sql, where_clause
            );
            // Query arguments are consumed on execution, so build a fresh set for the count.
            let (_, count_arguments) = (self.build_query_fn)(params_ref);
            let count: i64 = sqlx::query_scalar_with(&count_sql, count_arguments)
                .fetch_one(pool)
                .await?;
            let page_params = self.params.pagination.clone();
            let page_size = page_params.page_size;
            // Ceiling division in a form that cannot overflow; a non-positive page size
            // would otherwise divide by zero (panic) or produce a meaningless page count.
            let available_pages = if count <= 0 || page_size <= 0 {
                0
            } else {
                (count - 1) / page_size + 1
            };
            (Some(count), Some(available_pages), Some(page_params))
        } else {
            (None, None, None)
        };

        let records = sqlx::query_as_with::<sqlx::Postgres, T, _>(&main_sql, main_arguments)
            .fetch_all(pool)
            .await?;

        Ok(PaginatedResponse {
            records,
            pagination,
            total,
            total_pages,
        })
    }
}
#[cfg(feature = "sqlite")]
impl<'q, T, A> PaginatedQueryBuilder<'q, T, sqlx::Sqlite, A>
where
    T: for<'r> FromRow<'r, <sqlx::Sqlite as sqlx::Database>::Row>
        + Send
        + Unpin
        + Serialize
        + Default,
    A: 'q + IntoArguments<'q, sqlx::Sqlite> + Send,
{
    /// Creates a SQLite builder whose conditions come from `QueryBuilder` with search,
    /// filters and date range applied.
    pub fn new_with_defaults(query: sqlx::query::QueryAs<'q, sqlx::Sqlite, T, A>) -> Self {
        use crate::QueryBuilder;
        Self::new(query, |params| {
            QueryBuilder::<T, sqlx::Sqlite>::new()
                .with_search(params)
                .with_filters(params)
                .with_date_range(params)
                .build()
        })
    }

    /// Executes the paginated query against `pool`.
    ///
    /// When totals counting is enabled, a `COUNT(*)` query over the same filtered
    /// `base_query` runs first and `total`, `total_pages` and `pagination` are populated;
    /// otherwise all three are `None`.
    ///
    /// # Errors
    ///
    /// Returns any [`sqlx::Error`] raised while executing the count or the main query.
    pub async fn fetch_paginated(
        self,
        pool: &sqlx::SqlitePool,
    ) -> Result<PaginatedResponse<T>, sqlx::Error> {
        let base_sql = self.build_base_query();
        let params_ref = &self.params;
        let (conditions, main_arguments) = (self.build_query_fn)(params_ref);
        let where_clause = self.build_where_clause(&conditions);

        let mut main_sql = format!("{} SELECT * FROM base_query{}", base_sql, where_clause);
        main_sql.push_str(&self.build_order_clause());
        main_sql.push_str(&self.build_limit_offset_clause());

        let (total, total_pages, pagination) = if self.totals_count_enabled {
            let count_sql = format!(
                "{} SELECT COUNT(*) FROM base_query{}",
                base_sql, where_clause
            );
            // Query arguments are consumed on execution, so build a fresh set for the count.
            let (_, count_arguments) = (self.build_query_fn)(params_ref);
            let count: i64 = sqlx::query_scalar_with(&count_sql, count_arguments)
                .fetch_one(pool)
                .await?;
            let page_params = self.params.pagination.clone();
            let page_size = page_params.page_size;
            // Ceiling division in a form that cannot overflow; a non-positive page size
            // would otherwise divide by zero (panic) or produce a meaningless page count.
            let available_pages = if count <= 0 || page_size <= 0 {
                0
            } else {
                (count - 1) / page_size + 1
            };
            (Some(count), Some(available_pages), Some(page_params))
        } else {
            (None, None, None)
        };

        let records = sqlx::query_as_with::<sqlx::Sqlite, T, _>(&main_sql, main_arguments)
            .fetch_all(pool)
            .await?;

        Ok(PaginatedResponse {
            records,
            pagination,
            total,
            total_pages,
        })
    }
}