degen_sql/
sql_builder.rs

1
2use std::sync::Arc;
3use std::collections::BTreeMap;
4
5use tokio_postgres::types::ToSql;
6use crate::pagination::PaginationData;
7
8pub struct SqlBuilder {
9	pub statement_base: SqlStatementBase,
10	pub table_name : String, 
11	pub where_params: BTreeMap<String, Arc<dyn ToSql + Sync> > , 
12
13	pub order: Option<(String,OrderingDirection)> , 
14
15	pub limit: Option< u32 >, 
16	
17	// Optional pagination that overrides order, limit and offset when provided
18	pub pagination: Option<PaginationData>,
19}
20
21impl SqlBuilder {
22    pub fn build(&self) -> (String , Vec<Arc<dyn ToSql + Sync>>  ) {
23        let mut query = format!("{} FROM {}", self.statement_base.build(), self.table_name);
24        let mut conditions = Vec::new();
25        let mut params: Vec<Arc<dyn ToSql + Sync>> = Vec::new();
26
27        // WHERE conditions
28        for (key, param) in &self.where_params {
29            params.push(Arc::clone(param)); // Clone Arc reference
30            conditions.push(format!("{} = ${}", key, params.len()));
31        }
32
33        if !conditions.is_empty() {
34            query.push_str(" WHERE ");
35            query.push_str(&conditions.join(" AND "));
36        }
37
38        // Use pagination if provided, otherwise fall back to manual order and limit
39        if let Some(pagination) = &self.pagination {
40            // Append the pagination query part (includes ORDER BY, LIMIT, and OFFSET)
41            query.push_str(&format!(" {}", pagination.build_query_part()));
42        } else {
43            // ORDER BY clause
44            if let Some((column, direction)) = &self.order {
45                query.push_str(&format!(" ORDER BY {} {}", column, direction.build()));
46            }
47
48            // LIMIT clause
49            if let Some(limit) = self.limit {
50                query.push_str(&format!(" LIMIT {}", limit));
51            }
52        }
53
54        ( query , params) 
55    }
56    
57    // Helper method to set pagination
58    pub fn with_pagination(mut self, pagination: PaginationData) -> Self {
59        self.pagination = Some(pagination);
60        self
61    }
62}
63
64
65
/// Leading part of a SQL statement: everything rendered before `FROM`.
///
/// The type is a plain field-less enum, so it derives the usual value-type
/// traits to make it printable, comparable and freely copyable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlStatementBase {
    /// Rendered as `SELECT *`.
    SelectAll,
    /// Rendered as `SELECT COUNT(*)`.
    SelectCountAll,
    /// Rendered as `DELETE`.
    Delete,
}

impl SqlStatementBase {
    /// Returns the SQL text for this statement kind as an owned string.
    pub fn build(&self) -> String {
        let verb = match self {
            Self::SelectAll => "SELECT *",
            Self::SelectCountAll => "SELECT COUNT(*)",
            Self::Delete => "DELETE",
        };

        verb.to_owned()
    }
}
86
/// Sort direction used in an `ORDER BY` clause.
///
/// The type is a plain field-less enum, so it derives the usual value-type
/// traits to make it printable, comparable and freely copyable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OrderingDirection {
    /// Rendered as `DESC`.
    DESC,
    /// Rendered as `ASC`.
    ASC,
}

impl OrderingDirection {
    /// Returns the SQL keyword for this direction as an owned string.
    pub fn build(&self) -> String {
        let keyword = match self {
            Self::DESC => "DESC",
            Self::ASC => "ASC",
        };

        keyword.to_owned()
    }
}
107
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;
    use std::sync::Arc;
    use crate::pagination::{PaginationData, ColumnSortDir};
    use crate::tiny_safe_string::TinySafeString;

    /// Shorthand for the filter map stored in `SqlBuilder::where_params`.
    type Filters = BTreeMap<String, Arc<dyn ToSql + Sync>>;

    /// Wraps an integer as a bind parameter.
    fn int_param(value: i64) -> Arc<dyn ToSql + Sync> {
        Arc::new(value)
    }

    /// Wraps an owned copy of `value` as a bind parameter.
    fn text_param(value: &str) -> Arc<dyn ToSql + Sync> {
        Arc::new(value.to_string())
    }

    /// Collects `(column, parameter)` pairs into a filter map.
    fn filters(pairs: Vec<(&str, Arc<dyn ToSql + Sync>)>) -> Filters {
        pairs
            .into_iter()
            .map(|(column, param)| (column.to_string(), param))
            .collect()
    }

    /// Builder with no ordering, limit or pagination.
    fn plain_builder(
        statement_base: SqlStatementBase,
        table: &str,
        where_params: Filters,
    ) -> SqlBuilder {
        SqlBuilder {
            statement_base,
            table_name: table.to_string(),
            where_params,
            order: None,
            limit: None,
            pagination: None,
        }
    }

    #[test]
    fn test_sql_builder() {
        let builder = SqlBuilder {
            statement_base: SqlStatementBase::SelectAll,
            table_name: "teller_bids".to_string(),
            where_params: filters(vec![
                ("chain_id", int_param(1)),
                ("status", text_param("active")),
            ]),
            order: Some(("created_at".to_string(), OrderingDirection::DESC)),
            limit: Some(10),
            pagination: None,
        };

        let (sql, bound) = builder.build();

        assert_eq!(
            sql,
            "SELECT * FROM teller_bids WHERE chain_id = $1 AND status = $2 ORDER BY created_at DESC LIMIT 10"
        );
        assert_eq!(bound.len(), 2);
    }

    #[test]
    fn test_sql_builder_with_pagination() {
        let mut pagination = PaginationData::default();
        pagination.page = Some(2);
        pagination.page_size = Some(20);
        pagination.sort_by = Some(TinySafeString::new("updated_at").unwrap());
        pagination.sort_dir = Some(ColumnSortDir::Asc);

        // `order` and `limit` are set on purpose: pagination must win over both.
        let builder = SqlBuilder {
            statement_base: SqlStatementBase::SelectAll,
            table_name: "teller_bids".to_string(),
            where_params: filters(vec![("chain_id", int_param(1))]),
            order: Some(("created_at".to_string(), OrderingDirection::DESC)),
            limit: Some(10),
            pagination: Some(pagination),
        };

        let (sql, bound) = builder.build();

        assert_eq!(
            sql,
            "SELECT * FROM teller_bids WHERE chain_id = $1 ORDER BY updated_at ASC LIMIT 20 OFFSET 20"
        );
        assert_eq!(bound.len(), 1);
    }

    #[test]
    fn test_sql_builder_with_pagination_method() {
        let mut pagination = PaginationData::default();
        pagination.page = Some(3);
        pagination.page_size = Some(15);

        let builder = plain_builder(
            SqlStatementBase::SelectAll,
            "orders",
            filters(vec![("status", text_param("pending"))]),
        )
        .with_pagination(pagination);

        let (sql, bound) = builder.build();

        assert_eq!(
            sql,
            "SELECT * FROM orders WHERE status = $1 ORDER BY created_at DESC LIMIT 15 OFFSET 30"
        );
        assert_eq!(bound.len(), 1);
    }

    // Tests for the example queries in the delete_by_apikey function.
    #[test]
    fn test_sql_builder_count_query() {
        let builder = plain_builder(
            SqlStatementBase::SelectCountAll,
            "api_keys",
            filters(vec![("apikey", text_param("test-api-key"))]),
        );

        let (sql, bound) = builder.build();

        assert_eq!(sql, "SELECT COUNT(*) FROM api_keys WHERE apikey = $1");
        assert_eq!(bound.len(), 1);
    }

    #[test]
    fn test_sql_builder_delete_query() {
        let builder = plain_builder(
            SqlStatementBase::Delete,
            "api_keys",
            filters(vec![("apikey", text_param("test-api-key"))]),
        );

        let (sql, bound) = builder.build();

        assert_eq!(sql, "DELETE FROM api_keys WHERE apikey = $1");
        assert_eq!(bound.len(), 1);
    }

    /// Builds both statements of the delete_by_apikey flow from one filter map.
    ///
    /// In application code the two `(query, params)` pairs would be handed to
    /// the database client: run the count query first and return early when
    /// it reports zero rows, then execute the delete and report whether any
    /// rows were affected. Nothing is executed here; only the generated text
    /// is checked.
    #[test]
    fn test_delete_by_apikey_example() {
        let apikey = "example-api-key";
        let where_params = filters(vec![("apikey", text_param(apikey))]);

        // First query: "SELECT COUNT(*) FROM api_keys WHERE apikey = $1;"
        let count_builder = plain_builder(
            SqlStatementBase::SelectCountAll,
            "api_keys",
            where_params.clone(),
        );
        let (count_query, _count_params) = count_builder.build();

        assert_eq!(
            count_query,
            "SELECT COUNT(*) FROM api_keys WHERE apikey = $1"
        );

        // Second query: "DELETE FROM api_keys WHERE apikey = $1;"
        let delete_builder = plain_builder(SqlStatementBase::Delete, "api_keys", where_params);
        let (delete_query, _delete_params) = delete_builder.build();

        assert_eq!(delete_query, "DELETE FROM api_keys WHERE apikey = $1");
    }
}