1use super::{placeholder, DbPool, DbRow};
2use sqlx::Row;
3
/// Returns `true` when `s` is non-empty and every character is alphanumeric
/// or one of `_`, `.`, `*`, or a space.
///
/// Used to screen table names and single column references before they are
/// interpolated into SQL text.
fn validate_identifier(s: &str) -> bool {
    let allowed = |c: char| c.is_alphanumeric() || matches!(c, '_' | '.' | '*' | ' ');
    !s.is_empty() && s.chars().all(allowed)
}
13
/// Returns `true` when `s` is a non-empty, comma-separated column list made
/// only of alphanumerics and `_`, `.`, `*`, space or `,`.
///
/// Used for the SELECT column list and the GROUP BY clause.
fn validate_column_list(s: &str) -> bool {
    if s.is_empty() {
        return false;
    }
    !s.chars()
        .any(|c| !(c.is_alphanumeric() || matches!(c, '_' | '.' | '*' | ' ' | ',')))
}
22
/// Returns `true` when `s` is non-empty and contains none of the sequences
/// `;`, `'`, `"`, `--` or `/*`.
///
/// This is a blocklist screen for JOIN clauses and raw WHERE fragments,
/// which are interpolated verbatim; it is not a sanitizer.
fn validate_join(s: &str) -> bool {
    const FORBIDDEN: [&str; 5] = [";", "'", "\"", "--", "/*"];
    if s.is_empty() {
        return false;
    }
    FORBIDDEN.iter().copied().all(|token| !s.contains(token))
}
32
/// Sort direction for an `ORDER BY` term.
///
/// Derives equality and hashing so directions can be compared and used as
/// map/set keys by callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Order {
    /// Ascending (`ASC`).
    Asc,
    /// Descending (`DESC`).
    Desc,
}
39
40impl std::fmt::Display for Order {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 match self {
43 Order::Asc => write!(f, "ASC"),
44 Order::Desc => write!(f, "DESC"),
45 }
46 }
47}
48
49pub struct SelectBuilder {
78 table: String,
79 columns: String,
80 joins: Vec<String>,
81 conditions: Vec<Condition>,
82 order: Vec<(String, Order)>,
83 limit_val: Option<u32>,
84 offset_val: Option<u32>,
85 group_by_val: Option<String>,
86}
87
88enum Condition {
89 Eq(String, BindValue),
90 Ne(String, BindValue),
91 Gt(String, BindValue),
92 Gte(String, BindValue),
93 Lt(String, BindValue),
94 Lte(String, BindValue),
95 Like(String, BindValue),
96 IsNull(String),
97 IsNotNull(String),
98 In(String, Vec<BindValue>),
99 Raw(String),
100}
101
/// A value attached to a query placeholder.
///
/// These are always passed to the driver as bind parameters, never
/// interpolated into the SQL text. `Debug` and `PartialEq` are derived so
/// values can be logged and compared (e.g. in tests).
#[derive(Clone, Debug, PartialEq)]
#[doc(hidden)]
pub enum BindValue {
    /// Integer input, widened to `i64`.
    Int(i64),
    /// Floating-point input.
    Float(f64),
    /// Text input; UTC timestamps are also bound as formatted text.
    String(String),
    /// Boolean input.
    Bool(bool),
}
110
111#[doc(hidden)]
112pub trait IntoBindValue {
113 fn into_bind_value(self) -> BindValue;
114}
115
116impl IntoBindValue for i32 {
117 fn into_bind_value(self) -> BindValue { BindValue::Int(self as i64) }
118}
119
120impl IntoBindValue for i64 {
121 fn into_bind_value(self) -> BindValue { BindValue::Int(self) }
122}
123
124impl IntoBindValue for u32 {
125 fn into_bind_value(self) -> BindValue { BindValue::Int(self as i64) }
126}
127
128impl IntoBindValue for f64 {
129 fn into_bind_value(self) -> BindValue { BindValue::Float(self) }
130}
131
132impl IntoBindValue for bool {
133 fn into_bind_value(self) -> BindValue { BindValue::Bool(self) }
134}
135
136impl IntoBindValue for &str {
137 fn into_bind_value(self) -> BindValue { BindValue::String(self.to_string()) }
138}
139
140impl IntoBindValue for String {
141 fn into_bind_value(self) -> BindValue { BindValue::String(self) }
142}
143
144impl IntoBindValue for chrono::DateTime<chrono::Utc> {
145 fn into_bind_value(self) -> BindValue {
146 BindValue::String(self.format("%Y-%m-%d %H:%M:%S%.6f").to_string())
147 }
148}
149
150impl SelectBuilder {
151 pub fn table(table: &str) -> Self {
152 assert!(validate_identifier(table), "Invalid table name: {table}");
153 Self {
154 table: table.to_string(),
155 columns: "*".to_string(),
156 joins: Vec::new(),
157 conditions: Vec::new(),
158 order: Vec::new(),
159 limit_val: None,
160 offset_val: None,
161 group_by_val: None,
162 }
163 }
164
165 pub fn columns(mut self, cols: &str) -> Self {
166 assert!(validate_column_list(cols), "Invalid column list: {cols}");
167 self.columns = cols.to_string();
168 self
169 }
170
171 pub fn join(mut self, join_clause: &str) -> Self {
173 assert!(validate_join(join_clause), "Invalid JOIN clause: contains forbidden characters");
174 self.joins.push(join_clause.to_string());
175 self
176 }
177
178 pub fn where_eq<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
179 assert!(validate_identifier(column), "Invalid column name: {column}");
180 self.conditions.push(Condition::Eq(column.to_string(), value.into_bind_value()));
181 self
182 }
183
184 pub fn where_ne<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
185 assert!(validate_identifier(column), "Invalid column name: {column}");
186 self.conditions.push(Condition::Ne(column.to_string(), value.into_bind_value()));
187 self
188 }
189
190 pub fn where_gt<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
191 assert!(validate_identifier(column), "Invalid column name: {column}");
192 self.conditions.push(Condition::Gt(column.to_string(), value.into_bind_value()));
193 self
194 }
195
196 pub fn where_gte<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
197 assert!(validate_identifier(column), "Invalid column name: {column}");
198 self.conditions.push(Condition::Gte(column.to_string(), value.into_bind_value()));
199 self
200 }
201
202 pub fn where_lt<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
203 assert!(validate_identifier(column), "Invalid column name: {column}");
204 self.conditions.push(Condition::Lt(column.to_string(), value.into_bind_value()));
205 self
206 }
207
208 pub fn where_lte<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
209 assert!(validate_identifier(column), "Invalid column name: {column}");
210 self.conditions.push(Condition::Lte(column.to_string(), value.into_bind_value()));
211 self
212 }
213
214 pub fn where_like<V: IntoBindValue>(mut self, column: &str, pattern: V) -> Self {
215 assert!(validate_identifier(column), "Invalid column name: {column}");
216 self.conditions.push(Condition::Like(column.to_string(), pattern.into_bind_value()));
217 self
218 }
219
220 pub fn where_null(mut self, column: &str) -> Self {
221 assert!(validate_identifier(column), "Invalid column name: {column}");
222 self.conditions.push(Condition::IsNull(column.to_string()));
223 self
224 }
225
226 pub fn where_not_null(mut self, column: &str) -> Self {
227 assert!(validate_identifier(column), "Invalid column name: {column}");
228 self.conditions.push(Condition::IsNotNull(column.to_string()));
229 self
230 }
231
232 pub fn where_in<V: IntoBindValue>(mut self, column: &str, values: &[V]) -> Self
233 where V: Clone {
234 assert!(validate_identifier(column), "Invalid column name: {column}");
235 let bind_values: Vec<BindValue> = values.iter().map(|v| v.clone().into_bind_value()).collect();
236 self.conditions.push(Condition::In(column.to_string(), bind_values));
237 self
238 }
239
240 pub fn where_raw(mut self, raw: &str) -> Self {
243 assert!(validate_join(raw), "where_raw contains forbidden characters");
244 self.conditions.push(Condition::Raw(raw.to_string()));
245 self
246 }
247
248 pub fn order_by(mut self, column: &str, direction: Order) -> Self {
249 assert!(validate_identifier(column), "Invalid ORDER BY column: {column}");
250 self.order.push((column.to_string(), direction));
251 self
252 }
253
254 pub fn limit(mut self, limit: u32) -> Self {
255 self.limit_val = Some(limit);
256 self
257 }
258
259 pub fn offset(mut self, offset: u32) -> Self {
260 self.offset_val = Some(offset);
261 self
262 }
263
264 pub fn group_by(mut self, clause: &str) -> Self {
265 assert!(validate_column_list(clause), "Invalid GROUP BY clause: {clause}");
266 self.group_by_val = Some(clause.to_string());
267 self
268 }
269
270 pub async fn fetch_all<T>(self, pool: &DbPool) -> Result<Vec<T>, sqlx::Error>
272 where
273 T: for<'r> sqlx::FromRow<'r, DbRow> + Send + Unpin,
274 {
275 let (sql, binds) = self.build_select();
276 let mut query = sqlx::query_as::<_, T>(&sql);
277 for bind in &binds {
278 query = bind_value(query, bind);
279 }
280 query.fetch_all(pool).await
281 }
282
283 pub async fn fetch_one<T>(self, pool: &DbPool) -> Result<Option<T>, sqlx::Error>
285 where
286 T: for<'r> sqlx::FromRow<'r, DbRow> + Send + Unpin,
287 {
288 let (sql, binds) = self.limit(1).build_select();
289 let mut query = sqlx::query_as::<_, T>(&sql);
290 for bind in &binds {
291 query = bind_value(query, bind);
292 }
293 query.fetch_optional(pool).await
294 }
295
296 pub async fn count(self, pool: &DbPool) -> Result<i64, sqlx::Error> {
298 let (sql, binds) = self.build_count();
299 let mut query = sqlx::query(&sql);
300 for bind in &binds {
301 query = bind_value_raw(query, bind);
302 }
303 let row = query.fetch_one(pool).await?;
304 Ok(row.try_get::<i64, _>(0).unwrap_or(0))
305 }
306
307 fn build_select(self) -> (String, Vec<BindValue>) {
309 let mut binds = Vec::new();
310 let mut idx = 1usize;
311
312 let joins = self.joins.join(" ");
313 let where_clause = self.build_where(&mut binds, &mut idx);
314
315 let order_clause = if self.order.is_empty() {
316 String::new()
317 } else {
318 let parts: Vec<String> = self.order.iter()
319 .map(|(col, dir)| format!("{} {}", col, dir))
320 .collect();
321 format!(" ORDER BY {}", parts.join(", "))
322 };
323
324 let group = self.group_by_val
325 .as_ref()
326 .map(|g| format!(" GROUP BY {}", g))
327 .unwrap_or_default();
328
329 let limit = self.limit_val
330 .map(|l| format!(" LIMIT {}", l))
331 .unwrap_or_default();
332
333 let offset = self.offset_val
334 .map(|o| format!(" OFFSET {}", o))
335 .unwrap_or_default();
336
337 let sql = format!(
338 "SELECT {} FROM {} {}{}{}{}{}{}",
339 self.columns, self.table, joins, where_clause, group, order_clause, limit, offset
340 );
341
342 (sql.trim().to_string(), binds)
343 }
344
345 fn build_count(self) -> (String, Vec<BindValue>) {
346 let mut binds = Vec::new();
347 let mut idx = 1usize;
348
349 let joins = self.joins.join(" ");
350 let where_clause = self.build_where(&mut binds, &mut idx);
351
352 let group = self.group_by_val
353 .as_ref()
354 .map(|g| format!(" GROUP BY {}", g))
355 .unwrap_or_default();
356
357 let sql = format!(
358 "SELECT COUNT(*) FROM {} {}{}{}",
359 self.table, joins, where_clause, group
360 );
361
362 (sql.trim().to_string(), binds)
363 }
364
365 fn build_where(&self, binds: &mut Vec<BindValue>, idx: &mut usize) -> String {
366 if self.conditions.is_empty() {
367 return String::new();
368 }
369
370 let parts: Vec<String> = self.conditions.iter().map(|c| {
371 match c {
372 Condition::Eq(col, val) => {
373 let ph = placeholder(*idx);
374 *idx += 1;
375 binds.push(val.clone());
376 format!("{} = {}", col, ph)
377 }
378 Condition::Ne(col, val) => {
379 let ph = placeholder(*idx);
380 *idx += 1;
381 binds.push(val.clone());
382 format!("{} != {}", col, ph)
383 }
384 Condition::Gt(col, val) => {
385 let ph = placeholder(*idx);
386 *idx += 1;
387 binds.push(val.clone());
388 format!("{} > {}", col, ph)
389 }
390 Condition::Gte(col, val) => {
391 let ph = placeholder(*idx);
392 *idx += 1;
393 binds.push(val.clone());
394 format!("{} >= {}", col, ph)
395 }
396 Condition::Lt(col, val) => {
397 let ph = placeholder(*idx);
398 *idx += 1;
399 binds.push(val.clone());
400 format!("{} < {}", col, ph)
401 }
402 Condition::Lte(col, val) => {
403 let ph = placeholder(*idx);
404 *idx += 1;
405 binds.push(val.clone());
406 format!("{} <= {}", col, ph)
407 }
408 Condition::Like(col, val) => {
409 let ph = placeholder(*idx);
410 *idx += 1;
411 binds.push(val.clone());
412 format!("{} LIKE {}", col, ph)
413 }
414 Condition::IsNull(col) => format!("{} IS NULL", col),
415 Condition::IsNotNull(col) => format!("{} IS NOT NULL", col),
416 Condition::In(col, vals) => {
417 let placeholders: Vec<String> = vals.iter().map(|v| {
418 let ph = placeholder(*idx);
419 *idx += 1;
420 binds.push(v.clone());
421 ph
422 }).collect();
423 format!("{} IN ({})", col, placeholders.join(", "))
424 }
425 Condition::Raw(raw) => raw.clone(),
426 }
427 }).collect();
428
429 format!(" WHERE {}", parts.join(" AND "))
430 }
431}
432
433fn bind_value<'q, T>(
435 query: sqlx::query::QueryAs<'q, super::Db, T, super::DbArguments>,
436 value: &'q BindValue,
437) -> sqlx::query::QueryAs<'q, super::Db, T, super::DbArguments>
438where
439 T: for<'r> sqlx::FromRow<'r, DbRow>,
440{
441 match value {
442 BindValue::Int(v) => query.bind(*v),
443 BindValue::Float(v) => query.bind(*v),
444 BindValue::String(v) => query.bind(v.as_str()),
445 BindValue::Bool(v) => query.bind(*v),
446 }
447}
448
449fn bind_value_raw<'q>(
451 query: sqlx::query::Query<'q, super::Db, super::DbArguments>,
452 value: &'q BindValue,
453) -> sqlx::query::Query<'q, super::Db, super::DbArguments> {
454 match value {
455 BindValue::Int(v) => query.bind(*v),
456 BindValue::Float(v) => query.bind(*v),
457 BindValue::String(v) => query.bind(v.as_str()),
458 BindValue::Bool(v) => query.bind(*v),
459 }
460}