1
2use std::sync::Arc;
3use std::collections::BTreeMap;
4
5use tokio_postgres::types::ToSql;
6use crate::pagination::PaginationData;
7use crate::tiny_safe_string::TinySafeString;
8
9
10pub struct SqlBuilder {
62 pub statement_base: SqlStatementBase,
63 pub table_name : String,
64 pub where_params: BTreeMap<TinySafeString, Arc<dyn ToSql + Sync> > ,
65
66 pub order: Option<(TinySafeString,OrderingDirection)> ,
67
68 pub limit: Option< u32 >,
69
70 pub pagination: Option<PaginationData>,
72}
73
74impl SqlBuilder {
75 pub fn build(&self) -> (String , Vec<Arc<dyn ToSql + Sync>> ) {
76 let mut query = format!("{} FROM {}", self.statement_base.build(), self.table_name);
77 let mut conditions = Vec::new();
78 let mut params: Vec<Arc<dyn ToSql + Sync>> = Vec::new();
79
80 for (key, param) in &self.where_params {
82 params.push(Arc::clone(param)); conditions.push(format!("{} = ${}", key, params.len()));
84 }
85
86 if !conditions.is_empty() {
87 query.push_str(" WHERE ");
88 query.push_str(&conditions.join(" AND "));
89 }
90
91 if let Some(pagination) = &self.pagination {
93 query.push_str(&format!(" {}", pagination.build_query_part()));
95 } else {
96 if let Some((column, direction)) = &self.order {
98 query.push_str(&format!(" ORDER BY {} {}", column, direction.build()));
99 }
100
101 if let Some(limit) = self.limit {
103 query.push_str(&format!(" LIMIT {}", limit));
104 }
105 }
106
107 ( query , params)
108 }
109
110 pub fn with_pagination(mut self, pagination: PaginationData) -> Self {
112 self.pagination = Some(pagination);
113 self
114 }
115}
116
117
118
/// Leading keyword(s) of the statement rendered by `SqlBuilder::build`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SqlStatementBase {
    /// Renders as `SELECT *`.
    SelectAll,
    /// Renders as `SELECT COUNT(*)`.
    SelectCountAll,
    /// Renders as `DELETE`.
    Delete,
}

impl SqlStatementBase {
    /// Returns the SQL keyword prefix for this statement without allocating.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::SelectAll => "SELECT *",
            Self::SelectCountAll => "SELECT COUNT(*)",
            Self::Delete => "DELETE",
        }
    }

    /// Returns an owned copy of [`as_str`](Self::as_str); kept for existing callers.
    pub fn build(&self) -> String {
        self.as_str().to_string()
    }
}
139
/// Sort direction written after the column in an `ORDER BY` clause.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OrderingDirection {
    /// Renders as `DESC`.
    DESC,
    /// Renders as `ASC`.
    ASC,
}

impl OrderingDirection {
    /// Returns the SQL keyword for this direction without allocating.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::DESC => "DESC",
            Self::ASC => "ASC",
        }
    }

    /// Returns an owned copy of [`as_str`](Self::as_str); kept for existing callers.
    pub fn build(&self) -> String {
        self.as_str().to_string()
    }
}
160
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;
    use std::sync::Arc;
    use crate::pagination::{PaginationData, ColumnSortDir};
    use crate::tiny_safe_string::TinySafeString;

    /// Same shape as `SqlBuilder::where_params`.
    type WhereParams = BTreeMap<TinySafeString, Arc<dyn ToSql + Sync>>;

    /// Builds a column identifier from a literal that is known to be valid.
    fn col(name: &'static str) -> TinySafeString {
        TinySafeString::new(name).unwrap()
    }

    /// A builder with only the statement, table and filters set. Ordering, limit
    /// and pagination are left empty so each test can override what it needs.
    fn bare_builder(
        statement_base: SqlStatementBase,
        table_name: &str,
        where_params: WhereParams,
    ) -> SqlBuilder {
        SqlBuilder {
            statement_base,
            table_name: table_name.to_string(),
            where_params,
            order: None,
            limit: None,
            pagination: None,
        }
    }

    #[test]
    fn test_sql_builder() {
        let mut filters = WhereParams::new();
        filters.insert(col("chain_id"), Arc::new(1_i64));
        filters.insert(col("status"), Arc::new("active".to_string()));

        let builder = SqlBuilder {
            order: Some((col("created_at"), OrderingDirection::DESC)),
            limit: Some(10),
            ..bare_builder(SqlStatementBase::SelectAll, "teller_bids", filters)
        };

        let (sql, binds) = builder.build();

        assert_eq!(
            sql,
            "SELECT * FROM teller_bids WHERE chain_id = $1 AND status = $2 ORDER BY created_at DESC LIMIT 10"
        );
        assert_eq!(binds.len(), 2);
    }

    #[test]
    fn test_sql_builder_with_pagination() {
        let mut filters = WhereParams::new();
        filters.insert(col("chain_id"), Arc::new(1_i64));

        let mut pagination = PaginationData::default();
        pagination.page = Some(2);
        pagination.page_size = Some(20);
        pagination.sort_by = Some(col("updated_at"));
        pagination.sort_dir = Some(ColumnSortDir::Asc);

        // `order` and `limit` are set deliberately: pagination must win over them.
        let builder = SqlBuilder {
            order: Some((col("created_at"), OrderingDirection::DESC)),
            limit: Some(10),
            pagination: Some(pagination),
            ..bare_builder(SqlStatementBase::SelectAll, "teller_bids", filters)
        };

        let (sql, binds) = builder.build();

        assert_eq!(
            sql,
            "SELECT * FROM teller_bids WHERE chain_id = $1 ORDER BY updated_at ASC LIMIT 20 OFFSET 20"
        );
        assert_eq!(binds.len(), 1);
    }

    #[test]
    fn test_sql_builder_with_pagination_method() {
        let mut filters = WhereParams::new();
        filters.insert(col("status"), Arc::new("pending".to_string()));

        let mut pagination = PaginationData::default();
        pagination.page = Some(3);
        pagination.page_size = Some(15);

        let builder = bare_builder(SqlStatementBase::SelectAll, "orders", filters)
            .with_pagination(pagination);

        let (sql, binds) = builder.build();

        assert_eq!(
            sql,
            "SELECT * FROM orders WHERE status = $1 ORDER BY created_at DESC LIMIT 15 OFFSET 30"
        );
        assert_eq!(binds.len(), 1);
    }

    #[test]
    fn test_sql_builder_count_query() {
        let mut filters = WhereParams::new();
        filters.insert(col("apikey"), Arc::new("test-api-key".to_string()));

        let (sql, binds) =
            bare_builder(SqlStatementBase::SelectCountAll, "api_keys", filters).build();

        assert_eq!(sql, "SELECT COUNT(*) FROM api_keys WHERE apikey = $1");
        assert_eq!(binds.len(), 1);
    }

    #[test]
    fn test_sql_builder_delete_query() {
        let mut filters = WhereParams::new();
        filters.insert(col("apikey"), Arc::new("test-api-key".to_string()));

        let (sql, binds) = bare_builder(SqlStatementBase::Delete, "api_keys", filters).build();

        assert_eq!(sql, "DELETE FROM api_keys WHERE apikey = $1");
        assert_eq!(binds.len(), 1);
    }

    #[test]
    fn test_delete_by_apikey_example() {
        let apikey = "example-api-key";

        let mut filters = WhereParams::new();
        filters.insert(col("apikey"), Arc::new(apikey.to_string()));

        // The same filters drive both the existence check and the delete.
        let (count_query, _count_params) =
            bare_builder(SqlStatementBase::SelectCountAll, "api_keys", filters.clone()).build();

        assert_eq!(count_query, "SELECT COUNT(*) FROM api_keys WHERE apikey = $1");

        let (delete_query, _delete_params) =
            bare_builder(SqlStatementBase::Delete, "api_keys", filters).build();

        assert_eq!(delete_query, "DELETE FROM api_keys WHERE apikey = $1");
    }
}