use std::sync::Arc;
use std::collections::BTreeMap;
use tokio_postgres::types::ToSql;
use crate::pagination::PaginationData;
/// Declarative description of a single-table SQL statement.
///
/// Call `SqlBuilder::build` to render it into a query string plus the
/// positional parameters referenced by its `$n` placeholders.
pub struct SqlBuilder {
    /// Leading statement keyword(s), e.g. `SELECT *` or `DELETE`.
    pub statement_base: SqlStatementBase,
    /// Target table; interpolated verbatim after `FROM`.
    pub table_name: String,
    /// Equality filters keyed by column name, joined with `AND` in the `WHERE` clause.
    pub where_params: BTreeMap<String, Arc<dyn ToSql + Sync>>,
    /// Optional `ORDER BY` column and direction (ignored when `pagination` is set).
    pub order: Option<(String, OrderingDirection)>,
    /// Optional `LIMIT` (ignored when `pagination` is set).
    pub limit: Option<u32>,
    /// Optional pagination clause; when present it replaces `order` and `limit`.
    pub pagination: Option<PaginationData>,
}
impl SqlBuilder {
    /// Renders the builder into a SQL string and its positional parameters.
    ///
    /// # Output shape
    ///
    /// The output has the shape
    /// `<base> FROM <table> [WHERE k1 = $1 AND k2 = $2 ...] [tail]`.
    /// When `pagination` is set, `tail` is the pagination clause. Otherwise it is
    /// the optional `ORDER BY <column> <direction>` followed by the optional
    /// `LIMIT <n>`. `order` and `limit` are ignored whenever `pagination` is present.
    ///
    /// `WHERE` conditions follow the `BTreeMap`'s sorted key order. Placeholder
    /// numbering is therefore deterministic, and the `i`-th returned parameter
    /// binds to `$i` (1-based).
    ///
    /// NOTE(review): `table_name`, the `where_params` keys and the `order` column
    /// are interpolated verbatim (neither quoted nor parameterized). They must come
    /// from trusted code, never from user input.
    ///
    /// NOTE(review): ordering/limit/pagination clauses are appended regardless of
    /// `statement_base`. PostgreSQL does not accept `ORDER BY`/`LIMIT` on `DELETE`,
    /// so confirm callers never combine them.
    #[must_use]
    pub fn build(&self) -> (String, Vec<Arc<dyn ToSql + Sync>>) {
        use std::fmt::Write as _;

        let mut query = format!("{} FROM {}", self.statement_base.build(), self.table_name);

        let mut conditions = Vec::with_capacity(self.where_params.len());
        let mut params: Vec<Arc<dyn ToSql + Sync>> = Vec::with_capacity(self.where_params.len());
        for (key, param) in &self.where_params {
            params.push(Arc::clone(param));
            // After the push, `params.len()` is this parameter's 1-based placeholder index.
            conditions.push(format!("{} = ${}", key, params.len()));
        }
        if !conditions.is_empty() {
            query.push_str(" WHERE ");
            query.push_str(&conditions.join(" AND "));
        }

        // `fmt::Write` for `String` never returns an error, so these `expect`s cannot fire.
        if let Some(pagination) = &self.pagination {
            write!(query, " {}", pagination.build_query_part())
                .expect("writing to a String cannot fail");
        } else {
            if let Some((column, direction)) = &self.order {
                write!(query, " ORDER BY {} {}", column, direction.build())
                    .expect("writing to a String cannot fail");
            }
            if let Some(limit) = self.limit {
                write!(query, " LIMIT {}", limit).expect("writing to a String cannot fail");
            }
        }

        (query, params)
    }

    /// Sets the pagination clause and returns the builder for chaining.
    ///
    /// Once set, the pagination clause replaces any `order`/`limit` when `build` runs.
    #[must_use]
    pub fn with_pagination(mut self, pagination: PaginationData) -> Self {
        self.pagination = Some(pagination);
        self
    }
}
/// Leading clause of the statement produced by `SqlBuilder::build`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlStatementBase {
    /// Renders as `SELECT *`.
    SelectAll,
    /// Renders as `SELECT COUNT(*)`.
    SelectCountAll,
    /// Renders as `DELETE`.
    Delete,
}

impl SqlStatementBase {
    /// Returns the SQL text for this statement base without allocating.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::SelectAll => "SELECT *",
            Self::SelectCountAll => "SELECT COUNT(*)",
            Self::Delete => "DELETE",
        }
    }

    /// Returns the SQL text for this statement base as an owned `String`.
    ///
    /// Kept for existing callers; prefer [`Self::as_str`] to avoid the allocation.
    pub fn build(&self) -> String {
        self.as_str().to_string()
    }
}
/// Sort direction used in an `ORDER BY` clause.
///
/// NOTE(review): the upper-case variant names are non-idiomatic (`Desc`/`Asc`
/// would be conventional) but are kept so existing callers keep compiling.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OrderingDirection {
    /// Renders as `DESC`.
    DESC,
    /// Renders as `ASC`.
    ASC,
}

impl OrderingDirection {
    /// Returns the SQL keyword for this direction without allocating.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::DESC => "DESC",
            Self::ASC => "ASC",
        }
    }

    /// Returns the SQL keyword for this direction as an owned `String`.
    ///
    /// Kept for existing callers; prefer [`Self::as_str`] to avoid the allocation.
    pub fn build(&self) -> String {
        self.as_str().to_string()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;
    use std::sync::Arc;
    use crate::pagination::{ColumnSortDir, PaginationData};
    use crate::tiny_safe_string::TinySafeString;

    type WhereParams = BTreeMap<String, Arc<dyn ToSql + Sync>>;

    /// Builds a `WHERE` map holding a single `column = value` filter.
    fn where_eq(column: &str, value: Arc<dyn ToSql + Sync>) -> WhereParams {
        let mut filters = WhereParams::new();
        filters.insert(column.to_string(), value);
        filters
    }

    /// Returns a builder with no ordering, limit or pagination configured.
    fn plain_builder(
        statement_base: SqlStatementBase,
        table_name: &str,
        where_params: WhereParams,
    ) -> SqlBuilder {
        SqlBuilder {
            statement_base,
            table_name: table_name.to_string(),
            where_params,
            order: None,
            limit: None,
            pagination: None,
        }
    }

    #[test]
    fn test_sql_builder() {
        let mut filters = where_eq("chain_id", Arc::new(1_i64));
        filters.insert("status".to_string(), Arc::new("active".to_string()));

        let mut builder = plain_builder(SqlStatementBase::SelectAll, "teller_bids", filters);
        builder.order = Some(("created_at".to_string(), OrderingDirection::DESC));
        builder.limit = Some(10);

        let (query, params) = builder.build();
        assert_eq!(
            query,
            "SELECT * FROM teller_bids WHERE chain_id = $1 AND status = $2 ORDER BY created_at DESC LIMIT 10"
        );
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_sql_builder_with_pagination() {
        let mut pagination = PaginationData::default();
        pagination.page = Some(2);
        pagination.page_size = Some(20);
        pagination.sort_by = Some(TinySafeString::new("updated_at").unwrap());
        pagination.sort_dir = Some(ColumnSortDir::Asc);

        let mut builder = plain_builder(
            SqlStatementBase::SelectAll,
            "teller_bids",
            where_eq("chain_id", Arc::new(1_i64)),
        );
        // These must be overridden by the pagination clause below.
        builder.order = Some(("created_at".to_string(), OrderingDirection::DESC));
        builder.limit = Some(10);
        builder.pagination = Some(pagination);

        let (query, params) = builder.build();
        assert_eq!(
            query,
            "SELECT * FROM teller_bids WHERE chain_id = $1 ORDER BY updated_at ASC LIMIT 20 OFFSET 20"
        );
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sql_builder_with_pagination_method() {
        let mut pagination = PaginationData::default();
        pagination.page = Some(3);
        pagination.page_size = Some(15);

        let builder = plain_builder(
            SqlStatementBase::SelectAll,
            "orders",
            where_eq("status", Arc::new("pending".to_string())),
        )
        .with_pagination(pagination);

        let (query, params) = builder.build();
        assert_eq!(
            query,
            "SELECT * FROM orders WHERE status = $1 ORDER BY created_at DESC LIMIT 15 OFFSET 30"
        );
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sql_builder_count_query() {
        let builder = plain_builder(
            SqlStatementBase::SelectCountAll,
            "api_keys",
            where_eq("apikey", Arc::new("test-api-key".to_string())),
        );

        let (query, params) = builder.build();
        assert_eq!(query, "SELECT COUNT(*) FROM api_keys WHERE apikey = $1");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sql_builder_delete_query() {
        let builder = plain_builder(
            SqlStatementBase::Delete,
            "api_keys",
            where_eq("apikey", Arc::new("test-api-key".to_string())),
        );

        let (query, params) = builder.build();
        assert_eq!(query, "DELETE FROM api_keys WHERE apikey = $1");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_delete_by_apikey_example() {
        let apikey = "example-api-key";
        let filters = where_eq("apikey", Arc::new(apikey.to_string()));

        let (count_query, _count_params) =
            plain_builder(SqlStatementBase::SelectCountAll, "api_keys", filters.clone()).build();
        assert_eq!(count_query, "SELECT COUNT(*) FROM api_keys WHERE apikey = $1");

        let (delete_query, _delete_params) =
            plain_builder(SqlStatementBase::Delete, "api_keys", filters).build();
        assert_eq!(delete_query, "DELETE FROM api_keys WHERE apikey = $1");
    }
}