#![warn(missing_docs)]
use std::{fmt::Debug, marker::PhantomData, sync::Arc};
#[cfg(feature = "mysql")]
use sqlx::MySqlPool;
#[cfg(feature = "postgres")]
use sqlx::PgPool;
#[cfg(feature = "sqlite")]
use sqlx::SqlitePool;
use crate::dialects::get_dialect;
use crate::filter::{Filter, Filtered};
use crate::helpers::{StartingSql, bind_value, build_filter_expr, get_starting_sql};
use crate::schema::{ColumnInfo, Select, Value};
use crate::{database::error::DatabaseError, row::Row, schema::Schema};
/// Builder for a `SELECT` query against the table of schema `T`.
///
/// `S` is the selection type accepted by [`Query::select`] and [`Query::select_distinct`];
/// when no selection is given, every column of `T`'s table is selected (`table.*`).
/// Builder methods consume and return the query; [`Query::execute`] runs it and returns
/// `Vec<Row<T>>`.
#[derive(Debug)]
pub struct Query<T, S> {
    /// Ties the query to schema `T` without storing a value of it.
    pub(crate) table: PhantomData<T>,
    /// Conditions joined with `AND` into the `WHERE` clause.
    pub(crate) filters: Vec<Box<dyn Filtered>>,
    /// Connection pool the query runs on (MySQL build).
    #[cfg(feature = "mysql")]
    pub(crate) conn: Arc<MySqlPool>,
    /// Connection pool the query runs on (PostgreSQL build).
    #[cfg(feature = "postgres")]
    pub(crate) conn: Arc<PgPool>,
    /// Connection pool the query runs on (SQLite build).
    #[cfg(feature = "sqlite")]
    pub(crate) conn: Arc<SqlitePool>,
    /// Base-table columns to select; `None` selects `table.*`.
    pub(crate) select: Option<S>,
    /// Whether ` DISTINCT ` is appended after the starting `SELECT` text.
    pub(crate) distinct: bool,
    /// Joins, rendered in the order they were added.
    pub(crate) joins: Vec<JoinInfo>,
    /// Value for the `LIMIT` clause, if any.
    pub(crate) limit: Option<u64>,
    /// Value for the `OFFSET` clause, if any.
    pub(crate) offset: Option<u64>,
}
/// Everything needed to render one join clause and to pick up its columns.
#[derive(Debug)]
pub(crate) struct JoinInfo {
    /// Name of the joined table, taken from the joined schema's `table_name()`.
    pub(crate) table_name: String,
    /// Join condition, rendered as `column_one = column_two` in the `ON` clause;
    /// ignored for cross joins.
    pub(crate) condition: Filter,
    /// Kind of join to emit.
    pub(crate) join_type: JoinType,
    /// All columns of the joined table, from the schema's `get_all_columns()`.
    /// NOTE(review): presumably consumed when decoding result rows in `Row` — confirm.
    pub(crate) columns: Vec<ColumnInfo<'static>>,
    /// Columns of the joined table appended to the `SELECT` list.
    pub(crate) selected_columns: Vec<&'static str>,
}
/// Kind of SQL join to emit.
///
/// `Right` is compiled out when the `sqlite` feature is enabled, and `Full` exists only with
/// the `postgres` feature, matching the join builder methods available on `Query`.
/// The enum is fieldless, so it is `Copy` and can be compared with `==` without borrowing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum JoinType {
    /// `LEFT JOIN`
    Left,
    /// `INNER JOIN`
    Inner,
    /// `RIGHT JOIN`
    #[cfg(not(feature = "sqlite"))]
    Right,
    /// `FULL JOIN`
    #[cfg(feature = "postgres")]
    Full,
    /// `CROSS JOIN`, which has no `ON` condition.
    Cross,
}
impl<T: Schema + Debug, S: Select + Debug> Query<T, S> {
#[cfg(feature = "mysql")]
pub(crate) fn new(conn: Arc<MySqlPool>) -> Self {
Self {
table: PhantomData,
filters: Vec::new(),
select: None,
distinct: false,
limit: None,
offset: None,
joins: Vec::new(),
conn,
}
}
#[cfg(feature = "postgres")]
pub(crate) fn new(conn: Arc<PgPool>) -> Self {
Self {
table: PhantomData,
filters: Vec::new(),
select: None,
distinct: false,
limit: None,
offset: None,
joins: Vec::new(),
conn,
}
}
#[cfg(feature = "sqlite")]
pub(crate) fn new(conn: Arc<SqlitePool>) -> Self {
Self {
table: PhantomData,
filters: Vec::new(),
select: None,
distinct: false,
limit: None,
offset: None,
joins: Vec::new(),
conn,
}
}
pub fn filter<F>(mut self, filter: F) -> Self
where
F: Filtered + 'static,
{
self.filters.push(Box::new(filter));
self
}
pub fn limit(mut self, limit: u64) -> Self {
self.limit = Some(limit);
self
}
pub fn offset(mut self, offset: u64) -> Self {
self.offset = Some(offset);
self
}
pub fn select(mut self, select_schema: S) -> Self {
self.select = Some(select_schema);
self
}
pub fn select_distinct(mut self, select_schema: S) -> Self {
self.select = Some(select_schema);
self.distinct = true;
self
}
pub fn left_join<LeftJoinSchema: Schema + Debug, LeftJoinSchemaSelect: Select + Debug>(
mut self,
filter: Filter,
select_schema: LeftJoinSchemaSelect,
) -> Self {
self.joins.push(JoinInfo {
table_name: LeftJoinSchema::table_name().to_string(),
condition: filter,
join_type: JoinType::Left,
columns: LeftJoinSchema::get_all_columns(),
selected_columns: select_schema.get_selected(),
});
self
}
pub fn inner_join<InnerJoinSchema: Schema + Debug, InnerJoinSchemaSelect: Select + Debug>(
mut self,
filter: Filter,
select_schema: InnerJoinSchemaSelect,
) -> Self {
self.joins.push(JoinInfo {
table_name: InnerJoinSchema::table_name().to_string(),
condition: filter,
join_type: JoinType::Inner,
columns: InnerJoinSchema::get_all_columns(),
selected_columns: select_schema.get_selected(),
});
self
}
#[cfg(not(feature = "sqlite"))]
pub fn right_join<RightJoinSchema: Schema + Debug, RightJoinSchemaSelect: Select + Debug>(
mut self,
filter: Filter,
select_schema: RightJoinSchemaSelect,
) -> Self {
self.joins.push(JoinInfo {
table_name: RightJoinSchema::table_name().to_string(),
condition: filter,
join_type: JoinType::Right,
columns: RightJoinSchema::get_all_columns(),
selected_columns: select_schema.get_selected(),
});
self
}
#[cfg(feature = "postgres")]
pub fn full_join<FullJoinSchema: Schema + Debug, FullJoinSchemaSelect: Select + Debug>(
mut self,
filter: Filter,
select_schema: FullJoinSchemaSelect,
) -> Self {
self.joins.push(JoinInfo {
table_name: FullJoinSchema::table_name().to_string(),
condition: filter,
join_type: JoinType::Full,
columns: FullJoinSchema::get_all_columns(),
selected_columns: select_schema.get_selected(),
});
self
}
pub fn cross_join<CrossJoinSchema: Schema + Debug, CrossJoinSchemaSelect: Select + Debug>(
mut self,
select_schema: CrossJoinSchemaSelect,
) -> Self {
self.joins.push(JoinInfo {
table_name: CrossJoinSchema::table_name().to_string(),
condition: Filter::default(),
join_type: JoinType::Cross,
columns: CrossJoinSchema::get_all_columns(),
selected_columns: select_schema.get_selected(),
});
self
}
pub async fn execute(self) -> Result<Vec<Row<T>>, DatabaseError> {
let mut sql = get_starting_sql(StartingSql::Select, T::table_name());
if self.distinct {
sql.push_str(" DISTINCT ");
}
let sql = Self::select_sql(sql, self.select, T::table_name(), &self.joins);
let sql = Self::joins_sql(sql, &self.joins);
let mut params: Vec<Value> = Vec::new();
let mut sql = Self::filter_sql(sql, self.filters, &mut params);
if let Some(limit) = self.limit {
sql.push_str(&format!(" LIMIT {}", limit));
}
if let Some(offset) = self.offset {
if self.limit.is_none() {
sql.push_str(" LIMIT 18446744073709551615");
}
sql.push_str(&format!(" OFFSET {}", offset));
}
let mut conn = self
.conn
.acquire()
.await
.map_err(DatabaseError::ConnectionError)?;
let mut query = sqlx::query(&sql);
for v in params {
query = bind_value(query, v);
}
let data = query
.fetch_all(&mut *conn)
.await
.map_err(|e| DatabaseError::QueryError(e.to_string()))?;
#[cfg(feature = "mysql")]
let rows = Row::from_mysql_row(data, Some(&self.joins));
#[cfg(feature = "postgres")]
let rows = Row::from_postgres_row(data, Some(&self.joins));
#[cfg(feature = "sqlite")]
let rows = Row::from_sqlite_row(data, Some(&self.joins));
Ok(rows)
}
pub(crate) fn select_sql(
mut sql: String,
select: Option<S>,
table_name: &str,
joins: &Vec<JoinInfo>,
) -> String {
if let Some(selection) = select {
sql.push_str(&selection.get_selected().join(", "));
} else {
let dialect = get_dialect();
sql.push_str(&format!("{}.*", dialect.quote_identifier(table_name)));
}
if !joins.is_empty() {
for join in joins {
for column in &join.selected_columns {
sql.push_str(&format!(", {}", column));
}
}
}
sql.push_str(format!(" FROM {}", get_dialect().quote_identifier(table_name)).as_str());
sql
}
pub(crate) fn joins_sql(mut sql: String, joins: &Vec<JoinInfo>) -> String {
if joins.is_empty() {
return sql;
}
for join in joins {
let join_type = match join.join_type {
JoinType::Left => "LEFT JOIN",
JoinType::Inner => "INNER JOIN",
#[cfg(not(feature = "sqlite"))]
JoinType::Right => "RIGHT JOIN",
#[cfg(feature = "postgres")]
JoinType::Full => "FULL JOIN",
JoinType::Cross => "CROSS JOIN",
};
let join_table = &join.table_name;
if join_type == "CROSS JOIN" {
sql.push_str(&format!(" {} {}", join_type, join_table,));
} else {
sql.push_str(&format!(
" {} {} ON {}.{} = {}.{}",
join_type,
join_table,
join.condition.column_one.0,
join.condition.column_one.1,
join.condition.column_two.as_ref().unwrap().0,
join.condition.column_two.as_ref().unwrap().1
));
}
}
sql
}
pub(crate) fn filter_sql(
mut sql: String,
filters: Vec<Box<dyn Filtered>>,
params: &mut Vec<Value>,
) -> String {
if filters.is_empty() {
return sql;
}
sql.push_str(" WHERE ");
let mut parts: Vec<String> = Vec::with_capacity(filters.len());
for filter in &filters {
parts.push(build_filter_expr(filter.as_ref(), params));
}
sql.push_str(&parts.join(" AND "));
sql
}
}