degen_sql/
sql_builder.rs

1
2use std::sync::Arc;
3use std::collections::BTreeMap;
4
5use tokio_postgres::types::ToSql;
6use crate::pagination::PaginationData;
7use crate::tiny_safe_string::TinySafeString;
8
9
/*

Example usage (illustrative only; `domain_address`, `chain_id`, `pagination`,
`psql_db` and `Invoice` are supplied by the calling code):

        let mut where_params: BTreeMap<TinySafeString, (ComparisonType, Arc<dyn ToSql + Sync>)> = BTreeMap::new();
        where_params.insert("owner_address".into(), (ComparisonType::EQ, Arc::new(domain_address) as Arc<dyn ToSql + Sync>));
        where_params.insert("chain_id".into(), (ComparisonType::EQ, Arc::new(chain_id) as Arc<dyn ToSql + Sync>));

        let sql_builder = SqlBuilder {
            statement_base: SqlStatementBase::SelectAll,
            table_name: "invoices".to_string(),
            where_params,
            order: Some(("created_at".into(), OrderingDirection::DESC)),
            limit: None,
            pagination: pagination.cloned(),
        };

        // Build the SQL query and parameters
        let (query, params) = sql_builder.build();

        let built_params = &params.iter().map(|x| &**x).collect::<Vec<_>>();

        // Execute the query
        let rows = psql_db.query(&query, &built_params).await?;

        let mut invoices = Vec::new();
        for row in rows {
            match Invoice::from_row(&row) {
                Ok(invoice) => invoices.push(invoice),
                Err(e) => {
                    eprintln!("Error parsing invoice row: {}", e);
                    // Continue to next row instead of failing entirely
                }
            }
        }

        Ok(invoices)

*/
58
59
60
61pub struct SqlBuilder {
62	pub statement_base: SqlStatementBase,
63	pub table_name : String, 
64 
65	pub where_params: BTreeMap<TinySafeString, (ComparisonType, Arc<dyn ToSql + Sync>)>, 
66    
67	pub order: Option<(TinySafeString,OrderingDirection)> , 
68 
69 
70 
71	pub limit: Option< u32 >, 
72	
73	// Optional pagination that overrides order, limit and offset when provided
74	pub pagination: Option<PaginationData>,
75}
76
77impl SqlBuilder {
78
79        // Create a new instance with default values
80            pub fn new(statement_base: SqlStatementBase, table_name: impl Into<String>) -> Self {
81                SqlBuilder {
82                    statement_base,
83                    table_name: table_name.into(),
84                    where_params: BTreeMap::new(),
85                    order: None,
86                    limit: None,
87                    pagination: None,
88                }
89            }
90            
91            // Add a where condition with equality comparison
92            pub fn where_eq(mut self, key: impl Into<TinySafeString>, value: impl ToSql + Sync + 'static
93         ) -> Self {
94                self.where_params.insert(key.into(), (ComparisonType::EQ, Arc::new(value) as Arc<dyn ToSql + Sync>));
95                self
96            }
97            
98           // Add a where condition with less than comparison
99            pub fn where_lt(mut self, key: impl Into<TinySafeString>, value: impl ToSql + Sync + 'static
100         ) -> Self {
101                self.where_params.insert(key.into(), (ComparisonType::LT, Arc::new(value) as Arc<dyn ToSql + Sync>));
102                self
103           }
104           
105           // Add a where condition with greater than comparison
106           pub fn where_gt(mut self, key: impl Into<TinySafeString>, value: impl ToSql + Sync + 'static
107         ) -> Self {
108               self.where_params.insert(key.into(), (ComparisonType::GT, Arc::new(value) as Arc<dyn ToSql + Sync>));
109               self
110           }
111           
112           // Add a where condition with less than or equal comparison
113           pub fn where_lte(mut self, key: impl Into<TinySafeString>, value: impl ToSql + Sync + 'static) -> Self {
114               self.where_params.insert(key.into(), (ComparisonType::LTE, Arc::new(value) as Arc<dyn ToSql + Sync>));
115               self
116           }
117           
118           // Add a where condition with greater than or equal comparison
119           pub fn where_gte(mut self, key: impl Into<TinySafeString>, value: impl ToSql + Sync + 'static) -> Self {
120               self.where_params.insert(key.into(), (ComparisonType::GTE, Arc::new(value) as Arc<dyn ToSql + Sync>));
121               self
122           }
123           
124           // Add a where condition with LIKE comparison
125           pub fn where_like(mut self, key: impl Into<TinySafeString>, value: impl ToSql + Sync + 'static) -> Self {
126               self.where_params.insert(key.into(), (ComparisonType::LIKE, Arc::new(value) as Arc<dyn ToSql + Sync>));
127               self
128           }
129           
130           // Add a where condition with IN comparison
131           pub fn where_in(mut self, key: impl Into<TinySafeString>, value: impl ToSql + Sync + 'static
132         ) -> Self {
133               self.where_params.insert(key.into(), (ComparisonType::IN, Arc::new(value) as Arc<dyn ToSql + Sync>));
134               self
135           }
136           
137           // Add a where condition with IS NULL comparison
138           pub fn where_null(mut self, key: impl Into<TinySafeString>) -> Self {
139               // The value doesn't matter for NULL comparison, just using a dummy value
140               self.where_params.insert(key.into(), (ComparisonType::NULL, Arc::new(0_i32) as Arc<dyn ToSql + Sync>));
141               self
142           }
143           
144           // Add a generic where condition with custom comparison
145           pub fn where_custom(mut self, key: impl Into<TinySafeString>, comparison_type: ComparisonType, value: impl ToSql + Sync + 'static) -> Self {
146               self.where_params.insert(key.into(), (comparison_type, Arc::new(value) as Arc<dyn ToSql + Sync>));
147               self
148           }
149           
150           // Set the ORDER BY clause
151           pub fn order_by(mut self, column: impl Into<TinySafeString>, direction: OrderingDirection) -> Self {
152               self.order = Some((column.into(), direction));
153               self
154           }
155           
156           // Set the LIMIT clause
157           pub fn limit(mut self, limit: u32) -> Self {
158               self.limit = Some(limit);
159               self
160           }
161           
162           // Helper method to set pagination
163           pub fn with_pagination(mut self, pagination: PaginationData) -> Self {
164               self.pagination = Some(pagination);
165               self
166           }
167
168
169
170
171    pub fn build(&self) -> (String , Vec<Arc<dyn ToSql + Sync>>  ) {
172        let mut query = format!("{} FROM {}", self.statement_base.build(), self.table_name);
173        let mut conditions = Vec::new();
174       
175
176           let mut params: Vec<Arc<dyn ToSql + Sync>> = Vec::new();
177        
178        // WHERE conditions
179        for (key, (comparison_type, param)) in &self.where_params {
180            params.push(Arc::clone(param)); // Clone Arc reference
181            
182            let operator = comparison_type.to_operator();
183            if *comparison_type == ComparisonType::IN {
184                conditions.push(format!("{} {} (${})", key, operator, params.len()));
185            } else if *comparison_type == ComparisonType::NULL {
186                conditions.push(format!("{} {}", key, operator));
187                // Pop the last parameter as NULL doesn't need a parameter
188                params.pop();
189            } else {
190                conditions.push(format!("{} {} ${}", key, operator, params.len()));
191            }
192        }
193
194        if !conditions.is_empty() {
195            query.push_str(" WHERE ");
196            query.push_str(&conditions.join(" AND "));
197        }
198
199        // Use pagination if provided, otherwise fall back to manual order and limit
200        if let Some(pagination) = &self.pagination {
201            // Append the pagination query part (includes ORDER BY, LIMIT, and OFFSET)
202            query.push_str(&format!(" {}", pagination.build_query_part()));
203        } else {
204            // ORDER BY clause
205            if let Some((column, direction)) = &self.order {
206                query.push_str(&format!(" ORDER BY {} {}", column, direction.build()));
207            }
208
209            // LIMIT clause
210            if let Some(limit) = self.limit {
211                query.push_str(&format!(" LIMIT {}", limit));
212            }
213        }
214
215        ( query , params) 
216    }
217}
218
219
220
221
/// Comparison applied between a column and its bound value in a `WHERE` condition.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ComparisonType {
    /// `=` (the default comparison).
    #[default]
    EQ,
    /// `<`
    LT,
    /// `>`
    GT,
    /// `<=`
    LTE,
    /// `>=`
    GTE,
    /// `LIKE`
    LIKE,
    /// `IN`
    IN,
    /// `IS NULL`
    NULL,
}

impl ComparisonType {
    /// Returns the SQL operator text for this comparison.
    ///
    /// Every operator is a string literal, so the result is `&'static str` and is not
    /// tied to the lifetime of `self`; callers may keep it after the enum value is gone.
    pub fn to_operator(&self) -> &'static str {
        match self {
            Self::EQ => "=",
            Self::LT => "<",
            Self::GT => ">",
            Self::LTE => "<=",
            Self::GTE => ">=",
            Self::LIKE => "LIKE",
            Self::IN => "IN",
            Self::NULL => "IS NULL",
        }
    }
}
249
250
/// Leading clause of a statement, i.e. the text that precedes ` FROM <table>`.
pub enum SqlStatementBase {
    /// `SELECT *`
    SelectAll,
    /// `SELECT COUNT(*)`
    SelectCountAll,
    /// `DELETE`
    Delete,
}

impl SqlStatementBase {
    /// Returns the leading SQL text for this statement kind as an owned string.
    pub fn build(&self) -> String {
        let prefix = match self {
            Self::SelectAll => "SELECT *",
            Self::SelectCountAll => "SELECT COUNT(*)",
            Self::Delete => "DELETE",
        };

        String::from(prefix)
    }
}
271
/// Sort direction used in an `ORDER BY` clause.
pub enum OrderingDirection {
    /// Descending order (`DESC`).
    DESC,
    /// Ascending order (`ASC`).
    ASC,
}

impl OrderingDirection {
    /// Returns the SQL keyword for this direction as an owned string.
    pub fn build(&self) -> String {
        let keyword = match self {
            Self::DESC => "DESC",
            Self::ASC => "ASC",
        };

        String::from(keyword)
    }
}
292
293
294
295
296
297
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;
    use std::sync::Arc;

    /// Same shape as `SqlBuilder::where_params`; aliased to keep the fixtures readable.
    type WhereParams = BTreeMap<TinySafeString, (ComparisonType, Arc<dyn ToSql + Sync>)>;

    /// Wraps a value as the type-erased parameter stored in `where_params`.
    fn bind(value: impl ToSql + Sync + 'static) -> Arc<dyn ToSql + Sync> {
        Arc::new(value)
    }

    /// Builds the condition map from `(column, comparison, value)` triples.
    fn conditions(entries: Vec<(&str, ComparisonType, Arc<dyn ToSql + Sync>)>) -> WhereParams {
        let mut map = WhereParams::new();
        for (column, comparison, value) in entries {
            map.insert(column.into(), (comparison, value));
        }
        map
    }

    /// Builder with the given conditions and no ordering, limit or pagination.
    fn plain_builder(
        statement_base: SqlStatementBase,
        table: &str,
        where_params: WhereParams,
    ) -> SqlBuilder {
        SqlBuilder {
            statement_base,
            table_name: table.into(),
            where_params,
            order: None,
            limit: None,
            pagination: None,
        }
    }

    #[test]
    fn test_sql_builder() {
        let sql_builder = SqlBuilder {
            statement_base: SqlStatementBase::SelectAll,
            table_name: "teller_bids".into(),
            where_params: conditions(vec![
                ("chain_id", ComparisonType::EQ, bind(1_i64)),
                ("status", ComparisonType::EQ, bind("active".to_string())),
            ]),
            order: Some(("created_at".into(), OrderingDirection::DESC)),
            limit: Some(10),
            pagination: None,
        };

        let (query, params) = sql_builder.build();

        assert_eq!(
            query,
            "SELECT * FROM teller_bids WHERE chain_id = $1 AND status = $2 ORDER BY created_at DESC LIMIT 10"
        );
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_sql_builder_with_different_comparison_types() {
        let sql_builder = plain_builder(
            SqlStatementBase::SelectAll,
            "transactions",
            conditions(vec![
                ("amount", ComparisonType::GT, bind(1000_i64)),
                ("created_at", ComparisonType::LTE, bind("2023-01-01".to_string())),
                ("name", ComparisonType::LIKE, bind("%test%".to_string())),
            ]),
        );

        let (query, params) = sql_builder.build();

        assert_eq!(
            query,
            "SELECT * FROM transactions WHERE amount > $1 AND created_at <= $2 AND name LIKE $3"
        );
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_sql_builder_with_null_comparison() {
        let sql_builder = plain_builder(
            SqlStatementBase::SelectAll,
            "users",
            conditions(vec![
                // A NULL check still needs a placeholder value in the map; it is never bound.
                ("deleted_at", ComparisonType::NULL, bind(0_i32)),
                ("status", ComparisonType::EQ, bind("active".to_string())),
            ]),
        );

        let (query, params) = sql_builder.build();

        assert_eq!(
            query,
            "SELECT * FROM users WHERE deleted_at IS NULL AND status = $1"
        );
        // Only the equality condition contributes a bound parameter.
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sql_builder_with_in_operator() {
        let sql_builder = plain_builder(
            SqlStatementBase::SelectCountAll,
            "orders",
            // For an IN condition, you'd typically pass an array value.
            conditions(vec![("status", ComparisonType::IN, bind("(1, 2, 3)".to_string()))]),
        );

        let (query, params) = sql_builder.build();

        assert_eq!(query, "SELECT COUNT(*) FROM orders WHERE status IN ($1)");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sql_builder_with_pagination() {
        let pagination = PaginationData {
            page: Some(2),
            page_size: Some(20),
            sort_by: Some("created_at".into()),
            sort_dir: Some(crate::pagination::ColumnSortDir::Desc),
        };

        let sql_builder = SqlBuilder {
            statement_base: SqlStatementBase::SelectAll,
            table_name: "products".into(),
            where_params: conditions(vec![("active", ComparisonType::EQ, bind(true))]),
            // Both of these should be overridden by the pagination below.
            order: Some(("id".into(), OrderingDirection::ASC)),
            limit: Some(50),
            pagination: Some(pagination),
        };

        let (query, params) = sql_builder.build();

        // The tail of the query depends on PaginationData::build_query_part(), so only
        // the part produced by the builder itself is checked.
        assert!(query.contains("FROM products WHERE active = $1"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_delete_statement() {
        let sql_builder = plain_builder(
            SqlStatementBase::Delete,
            "logs",
            conditions(vec![("id", ComparisonType::EQ, bind(42_i64))]),
        );

        let (query, params) = sql_builder.build();

        assert_eq!(query, "DELETE FROM logs WHERE id = $1");
        assert_eq!(params.len(), 1);
    }
}
447}