1
2use std::sync::Arc;
3use std::collections::BTreeMap;
4
5use tokio_postgres::types::ToSql;
6use crate::pagination::PaginationData;
7use crate::tiny_safe_string::TinySafeString;
8
9
10pub struct SqlBuilder {
62 pub statement_base: SqlStatementBase,
63 pub table_name : String,
64
65 pub where_params: BTreeMap<TinySafeString, (ComparisonType, Arc<dyn ToSql + Sync>)>,
66
67 pub order: Option<(TinySafeString,OrderingDirection)> ,
68
69
70
71 pub limit: Option< u32 >,
72
73 pub pagination: Option<PaginationData>,
75}
76
77impl SqlBuilder {
78
79 pub fn new(statement_base: SqlStatementBase, table_name: impl Into<String>) -> Self {
81 SqlBuilder {
82 statement_base,
83 table_name: table_name.into(),
84 where_params: BTreeMap::new(),
85 order: None,
86 limit: None,
87 pagination: None,
88 }
89 }
90
91 pub fn where_eq(mut self, key: impl Into<TinySafeString>, value: impl ToSql + Sync + 'static
93 ) -> Self {
94 self.where_params.insert(key.into(), (ComparisonType::EQ, Arc::new(value) as Arc<dyn ToSql + Sync>));
95 self
96 }
97
98 pub fn where_lt(mut self, key: impl Into<TinySafeString>, value: impl ToSql + Sync + 'static
100 ) -> Self {
101 self.where_params.insert(key.into(), (ComparisonType::LT, Arc::new(value) as Arc<dyn ToSql + Sync>));
102 self
103 }
104
105 pub fn where_gt(mut self, key: impl Into<TinySafeString>, value: impl ToSql + Sync + 'static
107 ) -> Self {
108 self.where_params.insert(key.into(), (ComparisonType::GT, Arc::new(value) as Arc<dyn ToSql + Sync>));
109 self
110 }
111
112 pub fn where_lte(mut self, key: impl Into<TinySafeString>, value: impl ToSql + Sync + 'static) -> Self {
114 self.where_params.insert(key.into(), (ComparisonType::LTE, Arc::new(value) as Arc<dyn ToSql + Sync>));
115 self
116 }
117
118 pub fn where_gte(mut self, key: impl Into<TinySafeString>, value: impl ToSql + Sync + 'static) -> Self {
120 self.where_params.insert(key.into(), (ComparisonType::GTE, Arc::new(value) as Arc<dyn ToSql + Sync>));
121 self
122 }
123
124 pub fn where_like(mut self, key: impl Into<TinySafeString>, value: impl ToSql + Sync + 'static) -> Self {
126 self.where_params.insert(key.into(), (ComparisonType::LIKE, Arc::new(value) as Arc<dyn ToSql + Sync>));
127 self
128 }
129
130 pub fn where_in(mut self, key: impl Into<TinySafeString>, value: impl ToSql + Sync + 'static
132 ) -> Self {
133 self.where_params.insert(key.into(), (ComparisonType::IN, Arc::new(value) as Arc<dyn ToSql + Sync>));
134 self
135 }
136
137 pub fn where_null(mut self, key: impl Into<TinySafeString>) -> Self {
139 self.where_params.insert(key.into(), (ComparisonType::NULL, Arc::new(0_i32) as Arc<dyn ToSql + Sync>));
141 self
142 }
143
144 pub fn where_custom(mut self, key: impl Into<TinySafeString>, comparison_type: ComparisonType, value: impl ToSql + Sync + 'static) -> Self {
146 self.where_params.insert(key.into(), (comparison_type, Arc::new(value) as Arc<dyn ToSql + Sync>));
147 self
148 }
149
150 pub fn order_by(mut self, column: impl Into<TinySafeString>, direction: OrderingDirection) -> Self {
152 self.order = Some((column.into(), direction));
153 self
154 }
155
156 pub fn limit(mut self, limit: u32) -> Self {
158 self.limit = Some(limit);
159 self
160 }
161
162 pub fn with_pagination(mut self, pagination: PaginationData) -> Self {
164 self.pagination = Some(pagination);
165 self
166 }
167
168
169
170
171 pub fn build(&self) -> (String , Vec<Arc<dyn ToSql + Sync>> ) {
172 let mut query = format!("{} FROM {}", self.statement_base.build(), self.table_name);
173 let mut conditions = Vec::new();
174
175
176 let mut params: Vec<Arc<dyn ToSql + Sync>> = Vec::new();
177
178 for (key, (comparison_type, param)) in &self.where_params {
180 params.push(Arc::clone(param)); let operator = comparison_type.to_operator();
183 if *comparison_type == ComparisonType::IN {
184 conditions.push(format!("{} {} (${})", key, operator, params.len()));
185 } else if *comparison_type == ComparisonType::NULL {
186 conditions.push(format!("{} {}", key, operator));
187 params.pop();
189 } else {
190 conditions.push(format!("{} {} ${}", key, operator, params.len()));
191 }
192 }
193
194 if !conditions.is_empty() {
195 query.push_str(" WHERE ");
196 query.push_str(&conditions.join(" AND "));
197 }
198
199 if let Some(pagination) = &self.pagination {
201 query.push_str(&format!(" {}", pagination.build_query_part()));
203 } else {
204 if let Some((column, direction)) = &self.order {
206 query.push_str(&format!(" ORDER BY {} {}", column, direction.build()));
207 }
208
209 if let Some(limit) = self.limit {
211 query.push_str(&format!(" LIMIT {}", limit));
212 }
213 }
214
215 ( query , params)
216 }
217}
218
219
220
221
/// Comparison applied between a column and its bound value in a WHERE condition.
///
/// Rendered to SQL text by `ComparisonType::to_operator`. It is a fieldless enum, so it
/// derives `Copy`/`Eq` and can be compared and passed by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ComparisonType {
    /// `=` — the default comparison.
    #[default]
    EQ,
    /// `<`
    LT,
    /// `>`
    GT,
    /// `<=`
    LTE,
    /// `>=`
    GTE,
    /// `LIKE`
    LIKE,
    /// `IN (...)`
    IN,
    /// `IS NULL`; the SQL builder binds no value for this comparison.
    NULL,
}
234
235impl ComparisonType {
236 pub fn to_operator(&self) -> &str {
237 match self {
238 Self::EQ => "=",
239 Self::LT => "<",
240 Self::GT => ">",
241 Self::LTE => "<=",
242 Self::GTE => ">=",
243 Self::LIKE => "LIKE",
244 Self::IN => "IN",
245 Self::NULL => "IS NULL",
246 }
247 }
248}
249
250
/// Kind of statement an `SqlBuilder` produces; selects the text placed before `FROM`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlStatementBase {
    /// `SELECT *`
    SelectAll,
    /// `SELECT COUNT(*)`
    SelectCountAll,
    /// `DELETE`
    Delete,
}
256
257impl SqlStatementBase {
258
259 pub fn build(&self) -> String {
260
261 match self {
262
263 Self::SelectAll => "SELECT *" ,
264 Self::SelectCountAll => "SELECT COUNT(*)" ,
265 Self::Delete => "DELETE"
266
267 }.to_string()
268 }
269
270}
271
/// Sort direction for a single-column `ORDER BY`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OrderingDirection {
    /// Descending (`DESC`).
    DESC,
    /// Ascending (`ASC`).
    ASC,
}
277
278
279impl OrderingDirection {
280
281 pub fn build(&self) -> String {
282
283 match self {
284
285 Self::DESC => "DESC" ,
286 Self::ASC => "ASC"
287
288 }.to_string()
289 }
290
291}
292
293
294
295
296
297
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;
    use std::sync::Arc;

    /// The `where_params` map type, for building `SqlBuilder` struct literals.
    type WhereParams = BTreeMap<TinySafeString, (ComparisonType, Arc<dyn ToSql + Sync>)>;

    /// Boxes a value as a shared bind parameter.
    fn param(value: impl ToSql + Sync + 'static) -> Arc<dyn ToSql + Sync> {
        Arc::new(value)
    }

    #[test]
    fn test_sql_builder() {
        let mut where_params = WhereParams::new();
        where_params.insert("chain_id".into(), (ComparisonType::EQ, param(1_i64)));
        where_params.insert("status".into(), (ComparisonType::EQ, param("active".to_string())));

        let sql_builder = SqlBuilder {
            statement_base: SqlStatementBase::SelectAll,
            table_name: "teller_bids".into(),
            where_params,
            order: Some(("created_at".into(), OrderingDirection::DESC)),
            limit: Some(10),
            pagination: None,
        };

        let (query, params) = sql_builder.build();
        assert_eq!(
            query,
            "SELECT * FROM teller_bids WHERE chain_id = $1 AND status = $2 ORDER BY created_at DESC LIMIT 10"
        );
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_sql_builder_with_different_comparison_types() {
        let mut where_params = WhereParams::new();
        where_params.insert("amount".into(), (ComparisonType::GT, param(1000_i64)));
        where_params.insert("created_at".into(), (ComparisonType::LTE, param("2023-01-01".to_string())));
        where_params.insert("name".into(), (ComparisonType::LIKE, param("%test%".to_string())));

        let sql_builder = SqlBuilder {
            statement_base: SqlStatementBase::SelectAll,
            table_name: "transactions".into(),
            where_params,
            order: None,
            limit: None,
            pagination: None,
        };

        let (query, params) = sql_builder.build();
        assert_eq!(
            query,
            "SELECT * FROM transactions WHERE amount > $1 AND created_at <= $2 AND name LIKE $3"
        );
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_sql_builder_with_null_comparison() {
        let mut where_params = WhereParams::new();
        where_params.insert("deleted_at".into(), (ComparisonType::NULL, param(0_i32)));
        where_params.insert("status".into(), (ComparisonType::EQ, param("active".to_string())));

        let sql_builder = SqlBuilder {
            statement_base: SqlStatementBase::SelectAll,
            table_name: "users".into(),
            where_params,
            order: None,
            limit: None,
            pagination: None,
        };

        let (query, params) = sql_builder.build();
        assert_eq!(
            query,
            "SELECT * FROM users WHERE deleted_at IS NULL AND status = $1"
        );
        // The IS NULL placeholder value must not be bound.
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sql_builder_with_in_operator() {
        let mut where_params = WhereParams::new();
        where_params.insert("status".into(), (ComparisonType::IN, param("(1, 2, 3)".to_string())));

        let sql_builder = SqlBuilder {
            statement_base: SqlStatementBase::SelectCountAll,
            table_name: "orders".into(),
            where_params,
            order: None,
            limit: None,
            pagination: None,
        };

        let (query, params) = sql_builder.build();
        assert_eq!(
            query,
            "SELECT COUNT(*) FROM orders WHERE status IN ($1)"
        );
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sql_builder_with_pagination() {
        let pagination = PaginationData {
            page: Some(2),
            page_size: Some(20),
            sort_by: Some("created_at".into()),
            sort_dir: Some(crate::pagination::ColumnSortDir::Desc),
        };

        let mut where_params = WhereParams::new();
        where_params.insert("active".into(), (ComparisonType::EQ, param(true)));

        let sql_builder = SqlBuilder {
            statement_base: SqlStatementBase::SelectAll,
            table_name: "products".into(),
            where_params,
            order: Some(("id".into(), OrderingDirection::ASC)),
            limit: Some(50),
            pagination: Some(pagination),
        };

        let (query, params) = sql_builder.build();
        assert!(query.contains("FROM products WHERE active = $1"));
        // Pagination supersedes the explicit order and limit.
        assert!(!query.contains("ORDER BY id ASC"));
        assert!(!query.contains("LIMIT 50"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_delete_statement() {
        let mut where_params = WhereParams::new();
        where_params.insert("id".into(), (ComparisonType::EQ, param(42_i64)));

        let sql_builder = SqlBuilder {
            statement_base: SqlStatementBase::Delete,
            table_name: "logs".into(),
            where_params,
            order: None,
            limit: None,
            pagination: None,
        };

        let (query, params) = sql_builder.build();
        assert_eq!(
            query,
            "DELETE FROM logs WHERE id = $1"
        );
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_fluent_builder_api() {
        let (query, params) = SqlBuilder::new(SqlStatementBase::SelectAll, "users")
            .where_eq("status", "active".to_string())
            .where_null("deleted_at")
            .order_by("id", OrderingDirection::ASC)
            .limit(5)
            .build();
        assert_eq!(
            query,
            "SELECT * FROM users WHERE deleted_at IS NULL AND status = $1 ORDER BY id ASC LIMIT 5"
        );
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_later_condition_on_same_column_replaces_earlier() {
        let (query, params) = SqlBuilder::new(SqlStatementBase::SelectAll, "orders")
            .where_gt("amount", 1_i64)
            .where_lt("amount", 100_i64)
            .build();
        assert_eq!(query, "SELECT * FROM orders WHERE amount < $1");
        assert_eq!(params.len(), 1);
    }
}