use crate::entity::{EntityType, FromRow};
use crate::error::LrefResult;
use crate::provider::{DatabaseProvider, DbValue};
use std::marker::PhantomData;
use std::sync::Arc;
/// A single predicate of a `WHERE` clause: a column, an operator and an
/// optional bound value.
#[derive(Debug, Clone)]
pub struct FilterCondition {
    /// Left-hand side of the predicate; interpolated into the SQL verbatim.
    column: String,
    /// Operator text (for example `=` or `IS NULL`); interpolated verbatim.
    operator: String,
    /// Value associated with the placeholder, or `None` when the predicate
    /// renders without one.
    value: Option<DbValue>,
}

impl FilterCondition {
    /// Builds a condition from its three parts, converting the string-like
    /// arguments into owned `String`s.
    pub fn new(column: impl Into<String>, operator: impl Into<String>, value: Option<DbValue>) -> Self {
        let column = column.into();
        let operator = operator.into();
        Self { column, operator, value }
    }

    /// Renders the condition as SQL text.
    ///
    /// `placeholder` is appended only when a value is attached; otherwise the
    /// result is just `column operator` and `placeholder` is ignored.
    pub fn to_sql(&self, placeholder: &str) -> String {
        if self.value.is_some() {
            format!("{} {} {}", self.column, self.operator, placeholder)
        } else {
            format!("{} {}", self.column, self.operator)
        }
    }

    /// Borrows the attached value, if there is one.
    pub fn db_value(&self) -> Option<&DbValue> {
        self.value.as_ref()
    }
}
/// One term of an `ORDER BY` clause: a column and its sort direction.
#[derive(Debug, Clone)]
pub struct OrderBy {
    /// Column (or expression) text; interpolated into the SQL verbatim.
    column: String,
    /// Direction keyword emitted after the column.
    direction: OrderDirection,
}

/// Sort direction of an [`OrderBy`] term.
#[derive(Debug, Clone, Copy)]
pub enum OrderDirection {
    /// Rendered as `ASC`.
    Ascending,
    /// Rendered as `DESC`.
    Descending,
}

impl OrderBy {
    /// Creates an ordering term for `column` in the given `direction`.
    pub fn new(column: impl Into<String>, direction: OrderDirection) -> Self {
        let column = column.into();
        Self { column, direction }
    }

    /// Renders the term as `<column> ASC` or `<column> DESC`.
    pub fn to_sql(&self) -> String {
        let keyword = match self.direction {
            OrderDirection::Descending => "DESC",
            OrderDirection::Ascending => "ASC",
        };
        let mut sql = String::with_capacity(self.column.len() + 1 + keyword.len());
        sql.push_str(&self.column);
        sql.push(' ');
        sql.push_str(keyword);
        sql
    }
}
/// Description of a related navigation requested on a query.
#[derive(Debug, Clone)]
pub struct IncludePath {
    /// Name of the navigation being included.
    pub navigation: String,
    /// Names of further navigations chained beneath this one.
    pub nested: Vec<String>,
    /// Table holding the related rows, when the include carries join details.
    pub related_table: Option<String>,
    /// Foreign-key column used for the join, when known.
    pub foreign_key_column: Option<String>,
    /// Column referenced by the foreign key, when known.
    pub referenced_key_column: Option<String>,
}
/// One join clause of a query.
#[derive(Debug, Clone)]
pub struct JoinSpec {
    /// Keyword(s) placed before `JOIN`, e.g. `INNER` or `LEFT`.
    pub join_type: String,
    /// Table being joined.
    pub table: String,
    /// Condition text placed after `ON`.
    pub on_clause: String,
}

impl JoinSpec {
    /// Renders the clause as `<join_type> JOIN <table> ON <on_clause>`.
    pub fn to_sql(&self) -> String {
        let pieces = [
            self.join_type.as_str(),
            "JOIN",
            self.table.as_str(),
            "ON",
            self.on_clause.as_str(),
        ];
        pieces.join(" ")
    }
}
/// Columns of a `GROUP BY` clause.
#[derive(Debug, Clone)]
pub struct GroupBy {
    /// Grouping columns, emitted in order and separated by `, `.
    pub columns: Vec<String>,
}

impl GroupBy {
    /// Renders `GROUP BY c1, c2, ...`, or an empty string when there are no
    /// columns.
    pub fn to_sql(&self) -> String {
        match self.columns.as_slice() {
            [] => String::new(),
            columns => format!("GROUP BY {}", columns.join(", ")),
        }
    }
}
/// Expression of a `HAVING` clause.
#[derive(Debug, Clone)]
pub struct HavingCondition {
    /// Raw expression text placed after the `HAVING` keyword.
    pub expression: String,
}

impl HavingCondition {
    /// Renders the clause as `HAVING <expression>`.
    pub fn to_sql(&self) -> String {
        let mut sql = String::from("HAVING ");
        sql.push_str(&self.expression);
        sql
    }
}
/// Accumulated description of a `SELECT` statement, rendered by
/// [`QueryState::to_sql`].
#[derive(Debug, Clone)]
pub struct QueryState {
    /// Source named in the `FROM` clause.
    pub from: String,
    /// `WHERE` predicates, joined with `AND` in insertion order.
    pub filters: Vec<FilterCondition>,
    /// Join clauses appended after `FROM`, in insertion order.
    pub joins: Vec<JoinSpec>,
    /// Columns of the `GROUP BY` clause; the clause is omitted when empty.
    pub group_bys: Vec<String>,
    /// `HAVING` expressions, joined with `AND`; omitted when empty.
    pub havings: Vec<String>,
    /// `ORDER BY` terms; the clause is omitted when empty.
    pub orderings: Vec<OrderBy>,
    /// Value of the `OFFSET` clause, if any.
    pub offset: Option<usize>,
    /// Value of the `LIMIT` clause, if any.
    pub limit: Option<usize>,
    /// Requested navigation includes; not consulted by `to_sql`.
    pub includes: Vec<IncludePath>,
    /// Explicit select list; `SELECT *` is used when `None`.
    pub projected_columns: Option<Vec<String>>,
    /// When set, the select list is `COUNT(*)`; wins over every other option.
    pub is_count: bool,
    /// When set (and `is_count` is not), the select list is the literal `1`.
    pub is_exists: bool,
    /// Aggregate function name wrapped around `aggregate_column`.
    pub aggregate: Option<String>,
    /// Argument of the aggregate; `*` is used when `None`.
    pub aggregate_column: Option<String>,
    /// Values bound to the `$n` placeholders, in placeholder order.
    pub parameters: Vec<DbValue>,
}

impl QueryState {
    /// Creates an empty state selecting from `from`.
    pub fn new(from: impl Into<String>) -> Self {
        Self {
            from: from.into(),
            filters: Vec::new(),
            joins: Vec::new(),
            group_bys: Vec::new(),
            havings: Vec::new(),
            orderings: Vec::new(),
            offset: None,
            limit: None,
            includes: Vec::new(),
            projected_columns: None,
            is_count: false,
            is_exists: false,
            aggregate: None,
            aggregate_column: None,
            parameters: Vec::new(),
        }
    }

    /// Renders the state as a SQL `SELECT` statement.
    ///
    /// Clauses are emitted in the order `SELECT`, `FROM`, joins, `WHERE`,
    /// `GROUP BY`, `HAVING`, `ORDER BY`, `LIMIT`/`OFFSET`; empty clauses are
    /// skipped.
    ///
    /// Placeholders in the `WHERE` clause are numbered by parameter position
    /// rather than by filter position. A filter that carries a value receives
    /// the next free `$n`. A filter without a value consumes no number of its
    /// own, but any `$n` already embedded in its text (as produced for `IN`
    /// and `BETWEEN` predicates) moves the counter past it, so later
    /// placeholders never collide with it.
    pub fn to_sql(&self) -> String {
        /// Largest `n` among the `$n` placeholders in `sql`, or 0 if none.
        fn highest_placeholder(sql: &str) -> usize {
            let mut highest = 0;
            let mut rest = sql;
            while let Some(pos) = rest.find('$') {
                rest = &rest[pos + 1..];
                let digit_count = rest.bytes().take_while(u8::is_ascii_digit).count();
                let (digits, tail) = rest.split_at(digit_count);
                if let Ok(n) = digits.parse::<usize>() {
                    highest = highest.max(n);
                }
                rest = tail;
            }
            highest
        }

        let select = if self.is_count {
            "SELECT COUNT(*)".to_string()
        } else if self.is_exists {
            "SELECT 1".to_string()
        } else if let Some(ref agg) = self.aggregate {
            let col = self.aggregate_column.as_deref().unwrap_or("*");
            format!("SELECT {}({})", agg, col)
        } else if let Some(ref cols) = self.projected_columns {
            format!("SELECT {}", cols.join(", "))
        } else {
            "SELECT *".to_string()
        };

        let mut sql = format!("{} FROM {}", select, self.from);

        for join in &self.joins {
            sql.push(' ');
            sql.push_str(&join.to_sql());
        }

        if !self.filters.is_empty() {
            let mut next_index = 1usize;
            let mut clauses = Vec::with_capacity(self.filters.len());
            for filter in &self.filters {
                if filter.db_value().is_some() {
                    clauses.push(filter.to_sql(&format!("${}", next_index)));
                    next_index += 1;
                } else {
                    // The placeholder argument is ignored for value-less filters.
                    let clause = filter.to_sql("");
                    next_index = next_index.max(highest_placeholder(&clause) + 1);
                    clauses.push(clause);
                }
            }
            sql.push_str(" WHERE ");
            sql.push_str(&clauses.join(" AND "));
        }

        if !self.group_bys.is_empty() {
            sql.push_str(" GROUP BY ");
            sql.push_str(&self.group_bys.join(", "));
        }

        if !self.havings.is_empty() {
            sql.push_str(" HAVING ");
            sql.push_str(&self.havings.join(" AND "));
        }

        if !self.orderings.is_empty() {
            let ords: Vec<String> = self.orderings.iter().map(|o| o.to_sql()).collect();
            sql.push_str(" ORDER BY ");
            sql.push_str(&ords.join(", "));
        }

        match (self.limit, self.offset) {
            (Some(limit), Some(offset)) => {
                sql.push_str(&format!(" LIMIT {} OFFSET {}", limit, offset));
            }
            (Some(limit), None) => {
                sql.push_str(&format!(" LIMIT {}", limit));
            }
            (None, Some(offset)) => {
                sql.push_str(&format!(" OFFSET {}", offset));
            }
            (None, None) => {}
        }

        sql
    }

    /// Values to bind to the statement's placeholders, in placeholder order.
    pub fn params(&self) -> &[DbValue] {
        &self.parameters
    }
}
pub struct QueryBuilder<T: EntityType> {
state: QueryState,
provider: Option<Arc<dyn DatabaseProvider>>,
_phantom: PhantomData<T>,
}
impl<T: EntityType> QueryBuilder<T> {
pub fn new(table_name: impl Into<String>) -> Self {
Self {
state: QueryState::new(table_name),
provider: None,
_phantom: PhantomData,
}
}
pub fn with_provider(table_name: impl Into<String>, provider: Arc<dyn DatabaseProvider>) -> Self {
Self {
state: QueryState::new(table_name),
provider: Some(provider),
_phantom: PhantomData,
}
}
pub fn state(&self) -> &QueryState {
&self.state
}
pub fn filter<F>(mut self, _predicate: F) -> Self
where
F: Fn(&T) -> bool,
{
self.state.filters.push(FilterCondition::new(
"__filter__",
"=",
Some(DbValue::String("?".to_string())),
));
self
}
pub fn filter_column(mut self, column: &str, operator: &str, value: impl Into<DbValue>) -> Self {
let db_val = value.into();
self.state.parameters.push(db_val.clone());
self.state
.filters
.push(FilterCondition::new(column, operator, Some(db_val)));
self
}
pub fn find_by_id(mut self, id: i32) -> Self {
let val = DbValue::I32(id);
self.state.parameters.push(val.clone());
self.state.filters.push(FilterCondition::new(
"id",
"=",
Some(val),
));
self
}
pub fn find_by_key(mut self, key_values: &std::collections::HashMap<String, DbValue>) -> Self {
for (col, val) in key_values {
self.state.parameters.push(val.clone());
self.state.filters.push(FilterCondition::new(col.as_str(), "=", Some(val.clone())));
}
self
}
pub fn filter_in(mut self, column: &str, values: Vec<impl Into<DbValue>>) -> Self {
let db_vals: Vec<DbValue> = values.into_iter().map(|v| v.into()).collect();
let start = self.state.parameters.len() + 1;
let placeholders: Vec<String> = (0..db_vals.len()).map(|i| format!("${}", start + i)).collect();
for v in db_vals {
self.state.parameters.push(v);
}
self.state.filters.push(FilterCondition::new(
column, &format!("IN ({})", placeholders.join(", ")), None,
));
self
}
pub fn filter_is_null(mut self, column: &str) -> Self {
self.state.filters.push(FilterCondition::new(column, "IS NULL", None));
self
}
pub fn filter_is_not_null(mut self, column: &str) -> Self {
self.state.filters.push(FilterCondition::new(column, "IS NOT NULL", None));
self
}
pub fn filter_between(mut self, column: &str, low: impl Into<DbValue>, high: impl Into<DbValue>) -> Self {
let lo: DbValue = low.into();
let hi: DbValue = high.into();
let start = self.state.parameters.len() + 1;
self.state.parameters.push(lo);
self.state.parameters.push(hi);
self.state.filters.push(FilterCondition::new(
column, &format!("BETWEEN ${} AND ${}", start, start + 1), None,
));
self
}
pub fn order_by<V>(mut self, _accessor: fn(&T) -> &V) -> Self {
self.state
.orderings
.push(OrderBy::new("__order__", OrderDirection::Ascending));
self
}
pub fn order_by_column(mut self, column: &str) -> Self {
self.state
.orderings
.push(OrderBy::new(column, OrderDirection::Ascending));
self
}
pub fn order_by_desc<V>(mut self, _accessor: fn(&T) -> &V) -> Self {
self.state
.orderings
.push(OrderBy::new("__order__", OrderDirection::Descending));
self
}
pub fn order_by_desc_column(mut self, column: &str) -> Self {
self.state
.orderings
.push(OrderBy::new(column, OrderDirection::Descending));
self
}
pub fn skip(mut self, count: usize) -> Self {
self.state.offset = Some(count);
self
}
pub fn take(mut self, count: usize) -> Self {
self.state.limit = Some(count);
self
}
pub fn include<Nav>(mut self, _navigation: fn(&T) -> &Nav) -> Self {
self.state.includes.push(IncludePath {
navigation: "__include__".to_string(),
nested: Vec::new(),
related_table: None,
foreign_key_column: None,
referenced_key_column: None,
});
self
}
pub fn include_named(mut self, navigation: &str) -> Self {
self.state.includes.push(IncludePath {
navigation: navigation.to_string(),
nested: Vec::new(),
related_table: None,
foreign_key_column: None,
referenced_key_column: None,
});
self
}
pub fn include_with_join(
mut self,
navigation: &str,
related_table: &str,
foreign_key: &str,
referenced_key: &str,
join_type: &str,
) -> Self {
self.state.includes.push(IncludePath {
navigation: navigation.to_string(),
nested: Vec::new(),
related_table: Some(related_table.to_string()),
foreign_key_column: Some(foreign_key.to_string()),
referenced_key_column: Some(referenced_key.to_string()),
});
let on_clause = format!("{}.{} = {}.{}",
self.state.from, foreign_key, related_table, referenced_key);
self.state.joins.push(JoinSpec {
join_type: join_type.to_string(),
table: related_table.to_string(),
on_clause,
});
self
}
pub fn inner_join(
mut self,
table: &str,
left_column: &str,
right_column: &str,
) -> Self {
let on_clause = format!("{}.{} = {}.{}", self.state.from, left_column, table, right_column);
self.state.joins.push(JoinSpec {
join_type: "INNER".to_string(),
table: table.to_string(),
on_clause,
});
self
}
pub fn left_join(
mut self,
table: &str,
left_column: &str,
right_column: &str,
) -> Self {
let on_clause = format!("{}.{} = {}.{}", self.state.from, left_column, table, right_column);
self.state.joins.push(JoinSpec {
join_type: "LEFT".to_string(),
table: table.to_string(),
on_clause,
});
self
}
pub fn group_by(mut self, columns: &[&str]) -> Self {
self.state.group_bys = columns.iter().map(|s| s.to_string()).collect();
self
}
pub fn having(mut self, expression: &str) -> Self {
self.state.havings.push(expression.to_string());
self
}
pub fn then_include<Nav, SubNav>(mut self, _navigation: fn(&Nav) -> &SubNav) -> Self {
if let Some(last) = self.state.includes.last_mut() {
last.nested.push("__then__".to_string());
}
self
}
pub async fn sum(self, column: &str) -> LrefResult<f64> {
let mut state = self.state.clone();
state.aggregate = Some("SUM".to_string());
state.aggregate_column = Some(column.to_string());
let sql = state.to_sql();
let params = state.params().to_vec();
let provider = self.provider.as_ref()
.ok_or_else(|| crate::error::LrefError::Configuration(
"No provider attached to QueryBuilder.".to_string()
))?;
let mut conn = provider.get_connection().await?;
let rows = conn.query(&sql, ¶ms).await?;
if let Some(first) = rows.first().and_then(|r| r.first()) {
first.trim().parse::<f64>().map_err(|_| {
crate::error::LrefError::TypeConversion("SUM result is not f64".to_string())
})
} else {
Ok(0.0)
}
}
pub async fn avg(self, column: &str) -> LrefResult<f64> {
let mut state = self.state.clone();
state.aggregate = Some("AVG".to_string());
state.aggregate_column = Some(column.to_string());
let sql = state.to_sql();
let params = state.params().to_vec();
let provider = self.provider.as_ref()
.ok_or_else(|| crate::error::LrefError::Configuration(
"No provider attached to QueryBuilder.".to_string()
))?;
let mut conn = provider.get_connection().await?;
let rows = conn.query(&sql, ¶ms).await?;
if let Some(first) = rows.first().and_then(|r| r.first()) {
first.trim().parse::<f64>().map_err(|_| {
crate::error::LrefError::TypeConversion("AVG result is not f64".to_string())
})
} else {
Ok(0.0)
}
}
pub async fn min<V>(self, column: &str) -> LrefResult<Option<String>> {
let mut state = self.state.clone();
state.aggregate = Some("MIN".to_string());
state.aggregate_column = Some(column.to_string());
let sql = state.to_sql();
let params = state.params().to_vec();
let provider = self.provider.as_ref()
.ok_or_else(|| crate::error::LrefError::Configuration(
"No provider attached to QueryBuilder.".to_string()
))?;
let mut conn = provider.get_connection().await?;
let rows = conn.query(&sql, ¶ms).await?;
Ok(rows.first().and_then(|r| r.first().cloned()))
}
pub async fn max<V>(self, column: &str) -> LrefResult<Option<String>> {
let mut state = self.state.clone();
state.aggregate = Some("MAX".to_string());
state.aggregate_column = Some(column.to_string());
let sql = state.to_sql();
let params = state.params().to_vec();
let provider = self.provider.as_ref()
.ok_or_else(|| crate::error::LrefError::Configuration(
"No provider attached to QueryBuilder.".to_string()
))?;
let mut conn = provider.get_connection().await?;
let rows = conn.query(&sql, ¶ms).await?;
Ok(rows.first().and_then(|r| r.first().cloned()))
}
pub fn select<R, F>(self, _selector: F) -> SelectQueryBuilder<T, R>
where
F: Fn(&T) -> R,
{
let mut state = self.state.clone();
state.projected_columns = Some(vec!["__projected__".to_string()]);
SelectQueryBuilder::<T, R> {
state,
_phantom_t: PhantomData,
_phantom_r: PhantomData,
}
}
pub fn select_columns(self, columns: &[&str]) -> SelectQueryBuilder<T, ()> {
let mut state = self.state.clone();
state.projected_columns = Some(columns.iter().map(|s| s.to_string()).collect());
SelectQueryBuilder::<T, ()> {
state,
_phantom_t: PhantomData,
_phantom_r: PhantomData,
}
}
pub fn to_sql(&self) -> String {
self.state.to_sql()
}
pub async fn to_list(self) -> LrefResult<Vec<T>>
where
T: FromRow,
{
let sql = self.to_sql();
let params = self.state.params().to_vec();
let provider = self.provider.as_ref()
.ok_or_else(|| crate::error::LrefError::Configuration(
"No provider attached to QueryBuilder. Use DbSet::query() or attach a provider.".to_string()
))?;
let mut conn = provider.get_connection().await?;
let rows = conn.query(&sql, ¶ms).await?;
materialize_entities::<T>(&rows)
}
pub async fn first(self) -> LrefResult<T>
where
T: FromRow,
{
let mut results = self.take(1).to_list().await?;
results.pop().ok_or_else(|| {
crate::error::LrefError::NotFound("Entity not found".to_string())
})
}
pub async fn first_or_default(self) -> LrefResult<Option<T>>
where
T: FromRow,
{
let mut results = self.take(1).to_list().await?;
Ok(results.pop())
}
pub async fn count(self) -> LrefResult<i64> {
let mut state = self.state.clone();
state.is_count = true;
let sql = state.to_sql();
let params = state.params().to_vec();
let provider = self.provider.as_ref()
.ok_or_else(|| crate::error::LrefError::Configuration(
"No provider attached to QueryBuilder.".to_string()
))?;
let mut conn = provider.get_connection().await?;
let rows = conn.query(&sql, ¶ms).await?;
if let Some(first_row) = rows.first() {
if let Some(first_val) = first_row.first() {
return first_val.trim().parse::<i64>().map_err(|e| {
crate::error::LrefError::TypeConversion(
format!("COUNT result '{}' is not i64: {}", first_val, e)
)
});
}
}
Ok(0)
}
pub async fn any(self) -> LrefResult<bool> {
let mut state = self.state.clone();
state.is_exists = true;
state.limit = Some(1);
let sql = state.to_sql();
let params = state.params().to_vec();
let provider = self.provider.as_ref()
.ok_or_else(|| crate::error::LrefError::Configuration(
"No provider attached to QueryBuilder.".to_string()
))?;
let mut conn = provider.get_connection().await?;
let rows = conn.query(&sql, ¶ms).await?;
Ok(!rows.is_empty())
}
pub fn execute_update(self) -> ExecuteUpdateBuilder<T> {
ExecuteUpdateBuilder {
state: self.state.clone(),
updates: Vec::new(),
provider: self.provider.clone(),
_phantom: PhantomData,
}
}
pub async fn execute_delete(self) -> LrefResult<u64> {
let sql = format!("DELETE FROM {} {}", self.state.from, build_where(&self.state.filters));
let params = self.state.params().to_vec();
let provider = self.provider.as_ref()
.ok_or_else(|| crate::error::LrefError::Configuration(
"No provider attached to QueryBuilder.".to_string()
))?;
let mut conn = provider.get_connection().await?;
conn.execute(&sql, ¶ms).await
}
}
/// Builder for an `UPDATE` statement restricted by the filters of a query.
pub struct ExecuteUpdateBuilder<T: EntityType> {
    /// Query state supplying the table name, the filters and their parameters.
    state: QueryState,
    /// `(column, value)` assignments of the `SET` clause, in insertion order.
    updates: Vec<(String, DbValue)>,
    /// Source of database connections; `None` means `execute` fails.
    provider: Option<Arc<dyn DatabaseProvider>>,
    /// Ties the builder to `T` without storing a value of that type.
    _phantom: PhantomData<T>,
}

impl<T: EntityType> ExecuteUpdateBuilder<T> {
    /// Records an assignment described by a property accessor.
    ///
    /// NOTE(review): the accessor is never inspected; the placeholder column
    /// name `__column__` is recorded instead.
    pub fn set_property(mut self, _accessor: fn(&T) -> &str, value: impl Into<DbValue>) -> Self {
        self.updates.push(("__column__".to_string(), value.into()));
        self
    }

    /// Records the assignment `column = value`.
    pub fn set_column(mut self, column: &str, value: impl Into<DbValue>) -> Self {
        self.updates.push((column.to_string(), value.into()));
        self
    }

    /// Renders `UPDATE <from> SET ... [WHERE ...]`.
    ///
    /// `SET` assignments use `$1..$k`. `WHERE` placeholders continue from
    /// `$k+1` so they never collide with the assignments. Placeholders already
    /// embedded in value-less filters (as produced for `IN` and `BETWEEN`) are
    /// shifted by `k` as well. The numbering matches the order of
    /// [`ExecuteUpdateBuilder::params`].
    pub fn to_sql(&self) -> String {
        /// Rewrites every `$n` in `sql` to `$(n + offset)`; also returns the
        /// largest rewritten index (0 when there is none).
        fn shift_placeholders(sql: &str, offset: usize) -> (String, usize) {
            let mut shifted = String::with_capacity(sql.len());
            let mut highest = 0;
            let mut rest = sql;
            while let Some(pos) = rest.find('$') {
                shifted.push_str(&rest[..=pos]);
                rest = &rest[pos + 1..];
                let digit_count = rest.bytes().take_while(u8::is_ascii_digit).count();
                let (digits, tail) = rest.split_at(digit_count);
                match digits.parse::<usize>() {
                    Ok(n) => {
                        let renumbered = n + offset;
                        highest = highest.max(renumbered);
                        shifted.push_str(&renumbered.to_string());
                    }
                    // A bare `$` (no digits) is copied through untouched.
                    Err(_) => shifted.push_str(digits),
                }
                rest = tail;
            }
            shifted.push_str(rest);
            (shifted, highest)
        }

        let sets: Vec<String> = self
            .updates
            .iter()
            .enumerate()
            .map(|(i, (col, _))| format!("{} = ${}", col, i + 1))
            .collect();

        let offset = self.updates.len();
        let mut next_index = offset + 1;
        let clauses: Vec<String> = self
            .state
            .filters
            .iter()
            .map(|filter| {
                if filter.db_value().is_some() {
                    let clause = filter.to_sql(&format!("${}", next_index));
                    next_index += 1;
                    clause
                } else {
                    // The placeholder argument is ignored for value-less filters.
                    let (clause, highest) = shift_placeholders(&filter.to_sql(""), offset);
                    next_index = next_index.max(highest + 1);
                    clause
                }
            })
            .collect();
        let where_clause = if clauses.is_empty() {
            String::new()
        } else {
            format!("WHERE {}", clauses.join(" AND "))
        };

        format!(
            "UPDATE {} SET {} {}",
            self.state.from,
            sets.join(", "),
            where_clause
        )
    }

    /// Values to bind, in placeholder order: the `SET` values first, then the
    /// parameters recorded on the query state for its filters.
    pub fn params(&self) -> Vec<DbValue> {
        self.updates
            .iter()
            .map(|(_, value)| value.clone())
            .chain(self.state.params().iter().cloned())
            .collect()
    }

    /// Executes the statement and returns the value reported by the
    /// connection's `execute` call.
    ///
    /// # Errors
    /// `Configuration` when no provider is attached, plus any connection or
    /// execution error.
    pub async fn execute(self) -> LrefResult<u64> {
        let sql = self.to_sql();
        let params = self.params();
        let provider = self.provider.as_ref()
            .ok_or_else(|| crate::error::LrefError::Configuration(
                "No provider attached to ExecuteUpdateBuilder.".to_string()
            ))?;
        let mut conn = provider.get_connection().await?;
        conn.execute(&sql, &params).await
    }
}
/// Query over entity type `T` whose rows are projected into `R`.
pub struct SelectQueryBuilder<T: EntityType, R> {
    /// State of the underlying query, including the projected columns.
    state: QueryState,
    /// Marks the source entity type.
    _phantom_t: PhantomData<T>,
    /// Marks the projection result type.
    _phantom_r: PhantomData<R>,
}

impl<T: EntityType, R> SelectQueryBuilder<T, R> {
    /// Renders the projected query as SQL.
    pub fn to_sql(&self) -> String {
        let state = &self.state;
        state.to_sql()
    }

    /// Returns the projected rows.
    ///
    /// NOTE(review): the SQL is rendered and then discarded; nothing is
    /// executed and the result is always an empty list.
    pub async fn to_list(self) -> LrefResult<Vec<R>> {
        let _rendered = self.to_sql();
        let projected: Vec<R> = Vec::new();
        Ok(projected)
    }
}
/// Converts every row into a `T` via `FromRow::from_row`, preserving order.
///
/// Conversion stops at the first failing row and that error is returned.
pub fn materialize_entities<T: EntityType + FromRow>(rows: &[Vec<String>]) -> LrefResult<Vec<T>> {
    rows.iter()
        // `?` is kept so the error passes through the same conversion as in
        // a plain loop body.
        .map(|row| -> LrefResult<T> { Ok(T::from_row(row)?) })
        .collect()
}
/// Renders `filters` as a `WHERE` clause, or an empty string when there are
/// none.
///
/// Placeholders are numbered by parameter position rather than by filter
/// position. A filter that carries a value receives the next free `$n`. A
/// filter without a value consumes no number of its own, but any `$n` already
/// embedded in its text (as produced for `IN` and `BETWEEN` predicates) moves
/// the counter past it, so later placeholders never collide with it.
fn build_where(filters: &[FilterCondition]) -> String {
    /// Largest `n` among the `$n` placeholders in `sql`, or 0 if none.
    fn highest_placeholder(sql: &str) -> usize {
        let mut highest = 0;
        let mut rest = sql;
        while let Some(pos) = rest.find('$') {
            rest = &rest[pos + 1..];
            let digit_count = rest.bytes().take_while(u8::is_ascii_digit).count();
            let (digits, tail) = rest.split_at(digit_count);
            if let Ok(n) = digits.parse::<usize>() {
                highest = highest.max(n);
            }
            rest = tail;
        }
        highest
    }

    if filters.is_empty() {
        return String::new();
    }

    let mut next_index = 1usize;
    let clauses: Vec<String> = filters
        .iter()
        .map(|filter| {
            if filter.db_value().is_some() {
                let clause = filter.to_sql(&format!("${}", next_index));
                next_index += 1;
                clause
            } else {
                // The placeholder argument is ignored for value-less filters.
                let clause = filter.to_sql("");
                next_index = next_index.max(highest_placeholder(&clause) + 1);
                clause
            }
        })
        .collect();
    format!("WHERE {}", clauses.join(" AND "))
}
/// Implemented by types that can start a query over entity type `T`.
pub trait IQueryable<T: EntityType> {
    /// Returns a new [`QueryBuilder`] for `T`.
    fn query(&self) -> QueryBuilder<T>;
}