use crate::database::search::pagination::Pagination;
use crate::database::sql::clauses::{
having_clause::HavingClause,
join_clause::JoinClause,
order_by_clause::OrderByClause,
where_clause::WhereClause
};
use inflector::cases::snakecase::to_snake_case;
use crate::database::sql::enums::join_type::JoinType;
use crate::database::sql::enums::operator::Operator;
use crate::database::sql::enums::where_joiner::WhereJoiner;
use crate::database::sql::enums::order::Order;
use crate::database::sql::enums::sql_symbol::SqlSymbol;
use crate::database::sql::pagination_fragment::PaginationFragment;
use diesel::{MysqlConnection, sql_query, RunQueryDsl};
use crate::RpaError;
use crate::database::sql::count_query::CountQuery;
use regex::Regex;
use crate::database::sql::query::Query;
/// Fluent builder for a single MySQL `SELECT` statement.
///
/// Each builder method consumes and returns the builder; [`QueryBuilder::build`]
/// renders the collected clauses into a [`Query`].
#[derive(Clone, Debug)]
pub struct QueryBuilder {
    /// Rendered SQL text; only filled in by `build`.
    query: String,
    /// Columns of the select list; an empty list renders as `SELECT *`.
    select_fields: Vec<String>,
    /// Table named in the `FROM` clause.
    table_name: String,
    /// `JOIN ... ON ... = ...` clauses, rendered in insertion order.
    join_clauses: Vec<JoinClause>,
    /// `WHERE` conditions in insertion order; each clause's joiner links it to the following one.
    where_clauses: Vec<WhereClause>,
    /// Columns of the `GROUP BY` clause.
    group_by_clauses: Vec<String>,
    /// Conditions of the `HAVING` clause.
    having_clauses: Vec<HavingClause>,
    /// Entries of the `ORDER BY` clause, rendered in insertion order.
    order_by_clauses: Vec<OrderByClause>,
    /// Optional page request; when set, `build` also runs a row-count query.
    pagination: Option<Pagination>,
}
impl QueryBuilder {
    /// Creates a builder that selects from `table_name`.
    ///
    /// The table name is stored exactly as given (unlike [`QueryBuilder::from_table`],
    /// which converts it to snake_case).
    pub fn new(table_name: String) -> QueryBuilder {
        QueryBuilder {
            query: String::new(),
            select_fields: Vec::new(),
            table_name,
            join_clauses: Vec::new(),
            where_clauses: Vec::new(),
            group_by_clauses: Vec::new(),
            having_clauses: Vec::new(),
            order_by_clauses: Vec::new(),
            pagination: None,
        }
    }

    /// Appends one column (converted to snake_case) to the select list.
    pub fn select(mut self, field_name: String) -> Self {
        self.select_fields.push(to_snake_case(&field_name));
        self
    }

    /// Replaces the whole select list with `field_names`, each converted to snake_case.
    pub fn select_fields(mut self, field_names: Vec<String>) -> Self {
        self.select_fields = field_names
            .iter()
            .map(|field_name| to_snake_case(field_name))
            .collect();
        self
    }

    /// Replaces the `FROM` table with `table_name` converted to snake_case.
    pub fn from_table(mut self, table_name: String) -> Self {
        self.table_name = to_snake_case(&table_name);
        self
    }

    /// Appends a `JOIN join_table ON from_field = to_field` clause.
    ///
    /// `join_table` is used verbatim; `from_field` and `to_field` are converted to snake_case.
    pub fn join_table(mut self, join_table: String, from_field: String, to_field: String, join_type: JoinType) -> Self {
        let from_field = to_snake_case(&from_field);
        let to_field = to_snake_case(&to_field);
        self.join_clauses.push(JoinClause {
            join_table,
            from_field,
            to_field,
            join_type,
        });
        self
    }

    /// Appends a `WHERE` condition `field_name operator 'value'`.
    ///
    /// `field_name` is converted to snake_case. For `Operator::LIKE` the value is wrapped
    /// as `'%value%'`. `joiner` links this condition to the *next* one; the joiner of the
    /// last condition is not rendered.
    pub fn where_values(mut self, field_name: String, operator: Operator, value: String, joiner: WhereJoiner) -> Self {
        let field_name = to_snake_case(&field_name);
        self.where_clauses.push(WhereClause {
            field_name,
            operator,
            value,
            joiner,
        });
        self
    }

    /// Appends an `ORDER BY` entry; `condition` is converted to snake_case.
    pub fn order_by(mut self, condition: String, order: Order) -> Self {
        let condition = to_snake_case(&condition);
        self.order_by_clauses.push(OrderByClause { condition, order });
        self
    }

    /// Appends a `GROUP BY` column; `field_name` is converted to snake_case.
    pub fn group_by(mut self, field_name: String) -> Self {
        self.group_by_clauses.push(to_snake_case(&field_name));
        self
    }

    /// Appends a `HAVING` condition `condition operator 'value'`.
    ///
    /// `condition` is converted to snake_case. Multiple conditions are combined with `AND`.
    pub fn having(mut self, condition: String, operator: Operator, value: String) -> Self {
        let condition = to_snake_case(&condition);
        self.having_clauses.push(HavingClause {
            condition,
            operator,
            value,
        });
        self
    }

    /// Sets the page request, or clears it when `pagination` is `None`.
    pub fn with_pagination(mut self, pagination: Option<Pagination>) -> Self {
        self.pagination = pagination;
        self
    }

    /// Renders `SELECT <fields>`, or `SELECT *` when no fields were selected.
    fn build_select_fragment(select_fields: Vec<String>) -> String {
        if select_fields.is_empty() {
            String::from("SELECT * ")
        } else {
            format!("SELECT {} ", select_fields.join(", "))
        }
    }

    /// Renders `FROM <table>`.
    fn build_from_fragment(table_name: String) -> String {
        format!(" FROM {} ", table_name)
    }

    /// Renders every join clause in insertion order; empty input yields an empty string.
    fn build_join_fragment(join_clauses: Vec<JoinClause>) -> String {
        let joins: Vec<String> = join_clauses
            .into_iter()
            .map(|join_clause| {
                format!(
                    " {} {} ON {} = {} ",
                    join_clause.join_type.symbol(),
                    join_clause.join_table,
                    join_clause.from_field,
                    join_clause.to_field
                )
            })
            .collect();
        joins.join(" ")
    }

    /// Escapes `value` for use inside a single-quoted MySQL string literal.
    ///
    /// Backslashes are doubled first, then `'` becomes `''`, so user input cannot terminate
    /// the literal. NOTE(review): assumes the server does not run with
    /// `NO_BACKSLASH_ESCAPES`; under that mode backslashes would be stored doubled.
    fn escape_literal(value: &str) -> String {
        value.replace('\\', "\\\\").replace('\'', "''")
    }

    /// Renders one ` field operator 'value' ` condition, wrapping the value as
    /// `'%value%'` for `Operator::LIKE`. The value is escaped; the field name is not.
    fn build_condition(field_name: &str, operator: &Operator, value: &str) -> String {
        let value = QueryBuilder::escape_literal(value);
        match operator {
            Operator::LIKE => format!(" {} {} '%{}%' ", field_name, operator.clone().symbol(), value),
            _ => format!(" {} {} '{}' ", field_name, operator.clone().symbol(), value),
        }
    }

    /// Renders the `WHERE` clause.
    ///
    /// Every condition except the last is followed by its own joiner, so the joiner of
    /// clause *i* connects it to clause *i + 1*. Empty input yields an empty string.
    fn build_where_fragment(where_clauses: Vec<WhereClause>) -> String {
        if where_clauses.is_empty() {
            return String::new();
        }
        let last_index = where_clauses.len() - 1;
        let conditions: Vec<String> = where_clauses
            .iter()
            .enumerate()
            .map(|(index, where_clause)| {
                let condition = QueryBuilder::build_condition(
                    &where_clause.field_name,
                    &where_clause.operator,
                    &where_clause.value,
                );
                if index == last_index {
                    condition
                } else {
                    format!("{} {}", condition, where_clause.joiner.clone().symbol())
                }
            })
            .collect();
        format!(" WHERE {} ", conditions.join(" "))
    }

    /// Renders `GROUP BY a, b, ...`; empty input yields an empty string.
    fn build_group_by_fragment(group_by_clauses: Vec<String>) -> String {
        if group_by_clauses.is_empty() {
            String::new()
        } else {
            format!(" GROUP BY {} ", group_by_clauses.join(", "))
        }
    }

    /// Renders `HAVING c1 AND c2 ...` with escaped values; empty input yields an empty string.
    fn build_having_fragment(having_clauses: Vec<HavingClause>) -> String {
        if having_clauses.is_empty() {
            return String::new();
        }
        let conditions: Vec<String> = having_clauses
            .iter()
            .map(|having_clause| {
                format!(
                    " {} {} '{}' ",
                    having_clause.condition,
                    having_clause.operator.clone().symbol(),
                    QueryBuilder::escape_literal(&having_clause.value)
                )
            })
            .collect();
        // HAVING conditions must be combined with a boolean operator; a comma is a syntax error.
        format!(" HAVING {} ", conditions.join(" AND "))
    }

    /// Renders `ORDER BY a ASC, b DESC, ...`; empty input yields an empty string.
    fn build_order_by_fragment(order_by_clauses: Vec<OrderByClause>) -> String {
        if order_by_clauses.is_empty() {
            return String::new();
        }
        let order_by: Vec<String> = order_by_clauses
            .iter()
            .map(|order_by_clause| {
                format!(" {} {} ", order_by_clause.condition, order_by_clause.order.clone().symbol())
            })
            .collect();
        format!(" ORDER BY {} ", order_by.join(", "))
    }

    /// Concatenates the rendered fragments in SQL clause order.
    ///
    /// If `pagination_fragment` is present, its ready-made query is returned instead.
    fn build_fragments(
        select_fragment: String,
        from_fragment: String,
        join_fragment: String,
        where_fragment: String,
        group_by_fragment: String,
        having_fragment: String,
        order_by_fragment: String,
        pagination_fragment: Option<PaginationFragment>,
    ) -> String {
        if let Some(pagination_fragment) = pagination_fragment {
            return pagination_fragment.query;
        }
        format!(
            "{} {} {} {} {} {} {}",
            select_fragment,
            from_fragment,
            join_fragment,
            where_fragment,
            group_by_fragment,
            having_fragment,
            order_by_fragment
        )
    }

    /// Resolves the requested page, if any, into a paged query plus paging metadata.
    ///
    /// Returns `Ok(None)` when `pagination` is `None`. Otherwise it counts the rows the
    /// unpaged query would return and appends `LIMIT <offset>,<page_size>`.
    ///
    /// # Errors
    /// * `page_size` is less than 1.
    /// * `page` is less than 1 or greater than the total page count. A query that matches
    ///   no rows has zero pages, so every page is out of range.
    /// * The count query fails on `connection`.
    fn build_pagination_fragment(
        pagination: Option<Pagination>,
        select_fragment: String,
        query_without_select_fragment: String,
        connection: &MysqlConnection,
    ) -> Result<Option<PaginationFragment>, RpaError> {
        let pagination = match pagination {
            Some(pagination) => pagination,
            None => return Ok(None),
        };
        let page_size = pagination.page_size;
        let page = pagination.page;
        if page_size < 1 {
            // Guards the division below and a meaningless LIMIT.
            return Err(RpaError::builder().with_description("Page size must be greater than zero").build());
        }
        if page < 1 {
            // Page 0 or below would produce a negative LIMIT offset.
            return Err(RpaError::builder().with_description("Page is out of range").build());
        }
        // Count through a derived table so GROUP BY queries count result rows (groups)
        // rather than returning the row count of the first group only.
        let count_query = format!(
            "SELECT COUNT(1) AS count_result FROM (SELECT 1 AS count_row {}) AS count_subquery",
            query_without_select_fragment
        );
        let count_query_result = sql_query(count_query.as_str())
            .get_result::<CountQuery>(connection)
            .map_err(|_| {
                RpaError::builder()
                    .with_description("Unable to count records for pagination")
                    .build()
            })?;
        let record_count = count_query_result.count_result;
        // Round up: a trailing partial page is still a page. Written without
        // `record_count + page_size - 1` so it cannot overflow.
        let total_pages = record_count / page_size + if record_count % page_size > 0 { 1 } else { 0 };
        if page > total_pages {
            return Err(RpaError::builder().with_description("Page is out of range").build());
        }
        let start_record = page_size * (page - 1);
        let pagination_query = format!(
            "{} {} LIMIT {},{}",
            select_fragment, query_without_select_fragment, start_record, page_size
        );
        Ok(Some(PaginationFragment {
            query: pagination_query,
            page_size,
            page,
            total_pages,
        }))
    }

    /// Collapses runs of spaces to one space, trims the ends, and appends `;`.
    ///
    /// Single-quoted literals are copied through unchanged, including `''` and backslash
    /// escapes, so whitespace inside user-supplied values is preserved.
    fn clean_query(dirty_query: String) -> String {
        // Group 1 matches a whole quoted literal and is written back as-is. Group 2
        // matches a run of spaces and keeps only one. The group that did not
        // participate expands to "".
        let re = Regex::new(r"(?s)('(?:[^'\\]|\\.|'')*')|( )+").expect("static regex is valid");
        let cleaned = re.replace_all(&dirty_query, "${1}${2}");
        format!("{};", cleaned.trim())
    }

    /// Consumes the builder and renders the final SQL statement (terminated by `;`).
    ///
    /// When pagination was requested, a COUNT query is executed on `connection` and the
    /// returned [`Query`] carries `page`, `page_size` and `total_pages`. Otherwise those
    /// fields are `None`.
    ///
    /// Values passed to `where_values`/`having` are escaped. Identifiers (table, column,
    /// join and ORDER BY names) are interpolated verbatim and must not come from
    /// untrusted input.
    ///
    /// # Errors
    /// Returns an [`RpaError`] in these cases:
    /// * the page size is invalid;
    /// * the page is out of range;
    /// * the row-count query fails.
    pub fn build(mut self, connection: &MysqlConnection) -> Result<Query, RpaError> {
        let select_fragment = QueryBuilder::build_select_fragment(self.select_fields);
        let from_fragment = QueryBuilder::build_from_fragment(self.table_name);
        let join_fragment = QueryBuilder::build_join_fragment(self.join_clauses);
        let where_fragment = QueryBuilder::build_where_fragment(self.where_clauses);
        let group_by_fragment = QueryBuilder::build_group_by_fragment(self.group_by_clauses);
        let having_fragment = QueryBuilder::build_having_fragment(self.having_clauses);
        let order_by_fragment = QueryBuilder::build_order_by_fragment(self.order_by_clauses);
        let query_without_select_fragment = QueryBuilder::build_fragments(
            String::new(),
            from_fragment,
            join_fragment,
            where_fragment,
            group_by_fragment,
            having_fragment,
            order_by_fragment,
            None,
        );
        let pagination_fragment = QueryBuilder::build_pagination_fragment(
            self.pagination,
            select_fragment.clone(),
            query_without_select_fragment.clone(),
            connection,
        )?;
        let (dirty_query, page, page_size, total_pages) = match pagination_fragment {
            Some(fragment) => (
                fragment.query,
                Some(fragment.page),
                Some(fragment.page_size),
                Some(fragment.total_pages),
            ),
            None => (
                format!("{} {}", select_fragment, query_without_select_fragment),
                None,
                None,
                None,
            ),
        };
        self.query = QueryBuilder::clean_query(dirty_query);
        Ok(Query {
            sql: self.query,
            page,
            page_size,
            total_pages,
        })
    }
}