1
2use std::sync::Arc;
3use std::collections::BTreeMap;
4
5use tokio_postgres::types::ToSql;
6use crate::pagination::PaginationData;
7use crate::tiny_safe_string::TinySafeString;
8
9
10pub struct SqlBuilder {
62 pub statement_base: SqlStatementBase,
63 pub table_name : String,
64
65 pub where_params: BTreeMap<TinySafeString, (ComparisonType, Arc<dyn ToSql + Sync>)>,
66
67 pub order: Option<(TinySafeString,OrderingDirection)> ,
68
69
70
71 pub limit: Option< u32 >,
72
73 pub pagination: Option<PaginationData>,
75}
76
77impl SqlBuilder {
78 pub fn build(&self) -> (String , Vec<Arc<dyn ToSql + Sync>> ) {
79 let mut query = format!("{} FROM {}", self.statement_base.build(), self.table_name);
80 let mut conditions = Vec::new();
81
82
83 let mut params: Vec<Arc<dyn ToSql + Sync>> = Vec::new();
84
85 for (key, (comparison_type, param)) in &self.where_params {
87 params.push(Arc::clone(param)); let operator = comparison_type.to_operator();
90 if *comparison_type == ComparisonType::IN {
91 conditions.push(format!("{} {} (${}}})", key, operator, params.len()));
92 } else {
93 conditions.push(format!("{} {} ${}", key, operator, params.len()));
94 }
95 }
96
97 if !conditions.is_empty() {
98 query.push_str(" WHERE ");
99 query.push_str(&conditions.join(" AND "));
100 }
101
102 if let Some(pagination) = &self.pagination {
104 query.push_str(&format!(" {}", pagination.build_query_part()));
106 } else {
107 if let Some((column, direction)) = &self.order {
109 query.push_str(&format!(" ORDER BY {} {}", column, direction.build()));
110 }
111
112 if let Some(limit) = self.limit {
114 query.push_str(&format!(" LIMIT {}", limit));
115 }
116 }
117
118 ( query , params)
119 }
120
121 pub fn with_pagination(mut self, pagination: PaginationData) -> Self {
123 self.pagination = Some(pagination);
124 self
125 }
126}
127
128
129
130
/// Comparison applied between a column and its bound parameter in a `WHERE`
/// condition.
///
/// The default is [`ComparisonType::EQ`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ComparisonType {
    /// Equality (`=`).
    #[default]
    EQ,
    /// Strictly less than (`<`).
    LT,
    /// Strictly greater than (`>`).
    GT,
    /// Less than or equal (`<=`).
    LTE,
    /// Greater than or equal (`>=`).
    GTE,
    /// Pattern match (`LIKE`).
    LIKE,
    /// Membership (`IN`).
    IN,
}
142
143impl ComparisonType {
144 pub fn to_operator(&self) -> &str {
145 match self {
146 Self::EQ => "=",
147 Self::LT => "<",
148 Self::GT => ">",
149 Self::LTE => "<=",
150 Self::GTE => ">=",
151 Self::LIKE => "LIKE",
152 Self::IN => "IN",
153 }
154 }
155}
156
157
/// Leading clause of a statement, rendered before `FROM <table>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlStatementBase {
    /// Rendered as `SELECT *`.
    SelectAll,
    /// Rendered as `SELECT COUNT(*)`.
    SelectCountAll,
    /// Rendered as `DELETE`.
    Delete,
}
163
164impl SqlStatementBase {
165
166 pub fn build(&self) -> String {
167
168 match self {
169
170 Self::SelectAll => "SELECT *" ,
171 Self::SelectCountAll => "SELECT COUNT(*)" ,
172 Self::Delete => "DELETE"
173
174 }.to_string()
175 }
176
177}
178
/// Sort direction used in an `ORDER BY` clause.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OrderingDirection {
    /// Rendered as `DESC`.
    DESC,
    /// Rendered as `ASC`.
    ASC,
}
184
185
186impl OrderingDirection {
187
188 pub fn build(&self) -> String {
189
190 match self {
191
192 Self::DESC => "DESC" ,
193 Self::ASC => "ASC"
194
195 }.to_string()
196 }
197
198}
199
200
201
202
203
204
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;
    use std::sync::Arc;

    /// Shape of the `where_params` field, spelled out once for the tests below.
    type WhereParams = BTreeMap<TinySafeString, (ComparisonType, Arc<dyn ToSql + Sync>)>;

    /// Wraps a value as the shared trait object stored in `where_params`.
    fn bind<T>(value: T) -> Arc<dyn ToSql + Sync>
    where
        T: ToSql + Sync + 'static,
    {
        Arc::new(value)
    }

    /// Equality conditions are followed by `ORDER BY` and `LIMIT`.
    #[test]
    fn test_sql_builder() {
        let mut filters = WhereParams::new();
        filters.insert("chain_id".into(), (ComparisonType::EQ, bind(1_i64)));
        filters.insert(
            "status".into(),
            (ComparisonType::EQ, bind("active".to_string())),
        );

        let (query, params) = SqlBuilder {
            statement_base: SqlStatementBase::SelectAll,
            table_name: "teller_bids".into(),
            where_params: filters,
            order: Some(("created_at".into(), OrderingDirection::DESC)),
            limit: Some(10),
            pagination: None,
        }
        .build();

        assert_eq!(
            query,
            "SELECT * FROM teller_bids WHERE chain_id = $1 AND status = $2 ORDER BY created_at DESC LIMIT 10"
        );
        assert_eq!(params.len(), 2);
    }

    /// Each comparison type renders its own operator; no trailing clauses
    /// appear when order, limit and pagination are all unset.
    #[test]
    fn test_sql_builder_with_different_comparison_types() {
        let mut filters = WhereParams::new();
        filters.insert("amount".into(), (ComparisonType::GT, bind(1000_i64)));
        filters.insert(
            "created_at".into(),
            (ComparisonType::LTE, bind("2023-01-01".to_string())),
        );
        filters.insert(
            "name".into(),
            (ComparisonType::LIKE, bind("%test%".to_string())),
        );

        let (query, params) = SqlBuilder {
            statement_base: SqlStatementBase::SelectAll,
            table_name: "transactions".into(),
            where_params: filters,
            order: None,
            limit: None,
            pagination: None,
        }
        .build();

        assert_eq!(
            query,
            "SELECT * FROM transactions WHERE amount > $1 AND created_at <= $2 AND name LIKE $3"
        );
        assert_eq!(params.len(), 3);
    }
}