1
2use std::sync::Arc;
3use std::collections::BTreeMap;
4
5use tokio_postgres::types::ToSql;
6use crate::pagination::PaginationData;
7
8pub struct SqlBuilder {
9 pub statement_base: SqlStatementBase,
10 pub table_name : String,
11 pub where_params: BTreeMap<String, Arc<dyn ToSql + Sync> > ,
12
13 pub order: Option<(String,OrderingDirection)> ,
14
15 pub limit: Option< u32 >,
16
17 pub pagination: Option<PaginationData>,
19}
20
21impl SqlBuilder {
22 pub fn build(&self) -> (String , Vec<Arc<dyn ToSql + Sync>> ) {
23 let mut query = format!("{} FROM {}", self.statement_base.build(), self.table_name);
24 let mut conditions = Vec::new();
25 let mut params: Vec<Arc<dyn ToSql + Sync>> = Vec::new();
26
27 for (key, param) in &self.where_params {
29 params.push(Arc::clone(param)); conditions.push(format!("{} = ${}", key, params.len()));
31 }
32
33 if !conditions.is_empty() {
34 query.push_str(" WHERE ");
35 query.push_str(&conditions.join(" AND "));
36 }
37
38 if let Some(pagination) = &self.pagination {
40 query.push_str(&format!(" {}", pagination.build_query_part()));
42 } else {
43 if let Some((column, direction)) = &self.order {
45 query.push_str(&format!(" ORDER BY {} {}", column, direction.build()));
46 }
47
48 if let Some(limit) = self.limit {
50 query.push_str(&format!(" LIMIT {}", limit));
51 }
52 }
53
54 ( query , params)
55 }
56
57 pub fn with_pagination(mut self, pagination: PaginationData) -> Self {
59 self.pagination = Some(pagination);
60 self
61 }
62}
63
64
65
/// Leading clause of a statement built by `SqlBuilder`; `FROM <table>` is appended after it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SqlStatementBase {
    /// Renders as `SELECT *`.
    SelectAll,
    /// Renders as `SELECT COUNT(*)`.
    SelectCountAll,
    /// Renders as `DELETE`.
    Delete,
}

impl SqlStatementBase {
    /// Returns the SQL text for this statement kind as a static string, without allocating.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::SelectAll => "SELECT *",
            Self::SelectCountAll => "SELECT COUNT(*)",
            Self::Delete => "DELETE",
        }
    }

    /// Returns the same text as [`SqlStatementBase::as_str`], as an owned `String`.
    pub fn build(&self) -> String {
        self.as_str().to_string()
    }
}
86
/// Sort direction rendered into an `ORDER BY` clause.
// The all-caps variant names mirror the SQL keywords and are part of the public API,
// so they are kept rather than renamed to `Desc`/`Asc` (which clippy would prefer).
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OrderingDirection {
    /// Renders as `DESC`.
    DESC,
    /// Renders as `ASC`.
    ASC,
}

impl OrderingDirection {
    /// Returns the SQL keyword for this direction as a static string, without allocating.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::DESC => "DESC",
            Self::ASC => "ASC",
        }
    }

    /// Returns the same keyword as [`OrderingDirection::as_str`], as an owned `String`.
    pub fn build(&self) -> String {
        self.as_str().to_string()
    }
}
107
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;
    use std::sync::Arc;
    use crate::pagination::{PaginationData, ColumnSortDir};
    use crate::tiny_safe_string::TinySafeString;

    /// Filter map type accepted by `SqlBuilder::where_params`.
    type WhereParams = BTreeMap<String, Arc<dyn ToSql + Sync>>;

    /// Returns a filter map holding the single entry `column -> value`.
    fn single_filter(column: &str, value: Arc<dyn ToSql + Sync>) -> WhereParams {
        let mut filters = WhereParams::new();
        filters.insert(column.to_string(), value);
        filters
    }

    /// Returns a builder with the given base, table and filters and no order/limit/pagination.
    fn bare_builder(
        statement_base: SqlStatementBase,
        table: &str,
        where_params: WhereParams,
    ) -> SqlBuilder {
        SqlBuilder {
            statement_base,
            table_name: table.to_string(),
            where_params,
            order: None,
            limit: None,
            pagination: None,
        }
    }

    #[test]
    fn test_sql_builder() {
        let mut where_params = single_filter("chain_id", Arc::new(1_i64));
        where_params.insert("status".to_string(), Arc::new("active".to_string()));

        let sql_builder = SqlBuilder {
            order: Some(("created_at".to_string(), OrderingDirection::DESC)),
            limit: Some(10),
            ..bare_builder(SqlStatementBase::SelectAll, "teller_bids", where_params)
        };

        let (query, params) = sql_builder.build();

        assert_eq!(
            query,
            "SELECT * FROM teller_bids WHERE chain_id = $1 AND status = $2 ORDER BY created_at DESC LIMIT 10"
        );
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_sql_builder_with_pagination() {
        let mut pagination = PaginationData::default();
        pagination.page = Some(2);
        pagination.page_size = Some(20);
        pagination.sort_by = Some(TinySafeString::new("updated_at").unwrap());
        pagination.sort_dir = Some(ColumnSortDir::Asc);

        // `order` and `limit` are set on purpose: pagination must take precedence over them.
        let sql_builder = SqlBuilder {
            order: Some(("created_at".to_string(), OrderingDirection::DESC)),
            limit: Some(10),
            pagination: Some(pagination),
            ..bare_builder(
                SqlStatementBase::SelectAll,
                "teller_bids",
                single_filter("chain_id", Arc::new(1_i64)),
            )
        };

        let (query, params) = sql_builder.build();

        assert_eq!(
            query,
            "SELECT * FROM teller_bids WHERE chain_id = $1 ORDER BY updated_at ASC LIMIT 20 OFFSET 20"
        );
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sql_builder_with_pagination_method() {
        let mut pagination = PaginationData::default();
        pagination.page = Some(3);
        pagination.page_size = Some(15);

        let sql_builder = bare_builder(
            SqlStatementBase::SelectAll,
            "orders",
            single_filter("status", Arc::new("pending".to_string())),
        )
        .with_pagination(pagination);

        let (query, params) = sql_builder.build();

        assert_eq!(
            query,
            "SELECT * FROM orders WHERE status = $1 ORDER BY created_at DESC LIMIT 15 OFFSET 30"
        );
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sql_builder_count_query() {
        let sql_builder = bare_builder(
            SqlStatementBase::SelectCountAll,
            "api_keys",
            single_filter("apikey", Arc::new("test-api-key".to_string())),
        );

        let (query, params) = sql_builder.build();

        assert_eq!(query, "SELECT COUNT(*) FROM api_keys WHERE apikey = $1");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_sql_builder_delete_query() {
        let sql_builder = bare_builder(
            SqlStatementBase::Delete,
            "api_keys",
            single_filter("apikey", Arc::new("test-api-key".to_string())),
        );

        let (query, params) = sql_builder.build();

        assert_eq!(query, "DELETE FROM api_keys WHERE apikey = $1");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_delete_by_apikey_example() {
        let apikey = "example-api-key";
        let where_params = single_filter("apikey", Arc::new(apikey.to_string()));

        let (count_query, _count_params) =
            bare_builder(SqlStatementBase::SelectCountAll, "api_keys", where_params.clone()).build();
        assert_eq!(count_query, "SELECT COUNT(*) FROM api_keys WHERE apikey = $1");

        let (delete_query, _delete_params) =
            bare_builder(SqlStatementBase::Delete, "api_keys", where_params).build();
        assert_eq!(delete_query, "DELETE FROM api_keys WHERE apikey = $1");
    }
}