1use super::{placeholder, DbPool, DbRow};
2use sqlx::Row;
3
/// Returns `true` when `s` is a non-empty name made only of alphanumeric
/// characters (Unicode-aware), `_`, `.`, `*` or spaces.
///
/// This is a character whitelist, not a full SQL identifier parser: because
/// spaces are accepted (e.g. for `"users u"`), keywords can still pass, so the
/// input must come from trusted code.
fn validate_identifier(s: &str) -> bool {
    let allowed = |ch: char| ch.is_alphanumeric() || matches!(ch, '_' | '.' | '*' | ' ');
    !s.is_empty() && s.chars().all(allowed)
}
13
/// Returns `true` when `s` is a non-empty, comma-separated column list: the same
/// character set as `validate_identifier` plus `,`.
///
/// Parentheses are rejected, so function calls such as `COUNT(*)` do not pass.
fn validate_column_list(s: &str) -> bool {
    if s.is_empty() {
        false
    } else {
        s.chars()
            .all(|ch| matches!(ch, '_' | '.' | '*' | ' ' | ',') || ch.is_alphanumeric())
    }
}
22
/// Returns `true` when `s` is non-empty and free of the sequences this module
/// treats as dangerous in raw SQL fragments: `;`, `'`, `"`, `--` and `/*`.
///
/// This is a denylist used for `join` and `where_raw`; it does not make
/// untrusted text safe, so those fragments must come from trusted code.
fn validate_join(s: &str) -> bool {
    const FORBIDDEN: [&str; 5] = [";", "'", "\"", "--", "/*"];
    !s.is_empty() && !FORBIDDEN.iter().any(|&needle| s.contains(needle))
}
32
/// Sort direction for `SelectBuilder::order_by`.
///
/// Renders (via `Display`) as the SQL keyword `ASC` or `DESC`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Order {
    /// Ascending order (`ASC`).
    Asc,
    /// Descending order (`DESC`).
    Desc,
}

impl std::fmt::Display for Order {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match self {
            Order::Asc => "ASC",
            Order::Desc => "DESC",
        };
        f.write_str(keyword)
    }
}
48
49pub struct SelectBuilder {
78 table: String,
79 columns: String,
80 joins: Vec<String>,
81 conditions: Vec<Condition>,
82 order: Vec<(String, Order)>,
83 limit_val: Option<u32>,
84 offset_val: Option<u32>,
85 group_by_val: Option<String>,
86}
87
88enum Condition {
89 Eq(String, BindValue),
90 Ne(String, BindValue),
91 Gt(String, BindValue),
92 Gte(String, BindValue),
93 Lt(String, BindValue),
94 Lte(String, BindValue),
95 Like(String, BindValue),
96 IsNull(String),
97 IsNotNull(String),
98 In(String, Vec<BindValue>),
99 Raw(String),
100}
101
/// A value to be bound as a query parameter.
#[derive(Clone, Debug, PartialEq)]
#[doc(hidden)]
pub enum BindValue {
    Int(i64),
    Float(f64),
    String(String),
    Bool(bool),
}

/// Conversion of a Rust value into a [`BindValue`] for parameter binding.
#[doc(hidden)]
pub trait IntoBindValue {
    fn into_bind_value(self) -> BindValue;
}

// Integer impls all widen losslessly into `i64` (no truncating `as` casts).
impl IntoBindValue for i8 {
    fn into_bind_value(self) -> BindValue { BindValue::Int(i64::from(self)) }
}

impl IntoBindValue for i16 {
    fn into_bind_value(self) -> BindValue { BindValue::Int(i64::from(self)) }
}

impl IntoBindValue for i32 {
    fn into_bind_value(self) -> BindValue { BindValue::Int(i64::from(self)) }
}

impl IntoBindValue for i64 {
    fn into_bind_value(self) -> BindValue { BindValue::Int(self) }
}

impl IntoBindValue for u8 {
    fn into_bind_value(self) -> BindValue { BindValue::Int(i64::from(self)) }
}

impl IntoBindValue for u16 {
    fn into_bind_value(self) -> BindValue { BindValue::Int(i64::from(self)) }
}

impl IntoBindValue for u32 {
    fn into_bind_value(self) -> BindValue { BindValue::Int(i64::from(self)) }
}

impl IntoBindValue for f32 {
    fn into_bind_value(self) -> BindValue { BindValue::Float(f64::from(self)) }
}

impl IntoBindValue for f64 {
    fn into_bind_value(self) -> BindValue { BindValue::Float(self) }
}

impl IntoBindValue for bool {
    fn into_bind_value(self) -> BindValue { BindValue::Bool(self) }
}

impl IntoBindValue for &str {
    fn into_bind_value(self) -> BindValue { BindValue::String(self.to_string()) }
}

// Lets callers pass `&owned_string` without an explicit `.as_str()`.
impl IntoBindValue for &String {
    fn into_bind_value(self) -> BindValue { BindValue::String(self.clone()) }
}

impl IntoBindValue for String {
    fn into_bind_value(self) -> BindValue { BindValue::String(self) }
}
143
144impl SelectBuilder {
145 pub fn table(table: &str) -> Self {
146 assert!(validate_identifier(table), "Invalid table name: {table}");
147 Self {
148 table: table.to_string(),
149 columns: "*".to_string(),
150 joins: Vec::new(),
151 conditions: Vec::new(),
152 order: Vec::new(),
153 limit_val: None,
154 offset_val: None,
155 group_by_val: None,
156 }
157 }
158
159 pub fn columns(mut self, cols: &str) -> Self {
160 assert!(validate_column_list(cols), "Invalid column list: {cols}");
161 self.columns = cols.to_string();
162 self
163 }
164
165 pub fn join(mut self, join_clause: &str) -> Self {
167 assert!(validate_join(join_clause), "Invalid JOIN clause: contains forbidden characters");
168 self.joins.push(join_clause.to_string());
169 self
170 }
171
172 pub fn where_eq<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
173 assert!(validate_identifier(column), "Invalid column name: {column}");
174 self.conditions.push(Condition::Eq(column.to_string(), value.into_bind_value()));
175 self
176 }
177
178 pub fn where_ne<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
179 assert!(validate_identifier(column), "Invalid column name: {column}");
180 self.conditions.push(Condition::Ne(column.to_string(), value.into_bind_value()));
181 self
182 }
183
184 pub fn where_gt<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
185 assert!(validate_identifier(column), "Invalid column name: {column}");
186 self.conditions.push(Condition::Gt(column.to_string(), value.into_bind_value()));
187 self
188 }
189
190 pub fn where_gte<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
191 assert!(validate_identifier(column), "Invalid column name: {column}");
192 self.conditions.push(Condition::Gte(column.to_string(), value.into_bind_value()));
193 self
194 }
195
196 pub fn where_lt<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
197 assert!(validate_identifier(column), "Invalid column name: {column}");
198 self.conditions.push(Condition::Lt(column.to_string(), value.into_bind_value()));
199 self
200 }
201
202 pub fn where_lte<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
203 assert!(validate_identifier(column), "Invalid column name: {column}");
204 self.conditions.push(Condition::Lte(column.to_string(), value.into_bind_value()));
205 self
206 }
207
208 pub fn where_like<V: IntoBindValue>(mut self, column: &str, pattern: V) -> Self {
209 assert!(validate_identifier(column), "Invalid column name: {column}");
210 self.conditions.push(Condition::Like(column.to_string(), pattern.into_bind_value()));
211 self
212 }
213
214 pub fn where_null(mut self, column: &str) -> Self {
215 assert!(validate_identifier(column), "Invalid column name: {column}");
216 self.conditions.push(Condition::IsNull(column.to_string()));
217 self
218 }
219
220 pub fn where_not_null(mut self, column: &str) -> Self {
221 assert!(validate_identifier(column), "Invalid column name: {column}");
222 self.conditions.push(Condition::IsNotNull(column.to_string()));
223 self
224 }
225
226 pub fn where_in<V: IntoBindValue>(mut self, column: &str, values: &[V]) -> Self
227 where V: Clone {
228 assert!(validate_identifier(column), "Invalid column name: {column}");
229 let bind_values: Vec<BindValue> = values.iter().map(|v| v.clone().into_bind_value()).collect();
230 self.conditions.push(Condition::In(column.to_string(), bind_values));
231 self
232 }
233
234 pub fn where_raw(mut self, raw: &str) -> Self {
237 assert!(validate_join(raw), "where_raw contains forbidden characters");
238 self.conditions.push(Condition::Raw(raw.to_string()));
239 self
240 }
241
242 pub fn order_by(mut self, column: &str, direction: Order) -> Self {
243 assert!(validate_identifier(column), "Invalid ORDER BY column: {column}");
244 self.order.push((column.to_string(), direction));
245 self
246 }
247
248 pub fn limit(mut self, limit: u32) -> Self {
249 self.limit_val = Some(limit);
250 self
251 }
252
253 pub fn offset(mut self, offset: u32) -> Self {
254 self.offset_val = Some(offset);
255 self
256 }
257
258 pub fn group_by(mut self, clause: &str) -> Self {
259 assert!(validate_column_list(clause), "Invalid GROUP BY clause: {clause}");
260 self.group_by_val = Some(clause.to_string());
261 self
262 }
263
264 pub async fn fetch_all<T>(self, pool: &DbPool) -> Result<Vec<T>, sqlx::Error>
266 where
267 T: for<'r> sqlx::FromRow<'r, DbRow> + Send + Unpin,
268 {
269 let (sql, binds) = self.build_select();
270 let mut query = sqlx::query_as::<_, T>(&sql);
271 for bind in &binds {
272 query = bind_value(query, bind);
273 }
274 query.fetch_all(pool).await
275 }
276
277 pub async fn fetch_one<T>(self, pool: &DbPool) -> Result<Option<T>, sqlx::Error>
279 where
280 T: for<'r> sqlx::FromRow<'r, DbRow> + Send + Unpin,
281 {
282 let (sql, binds) = self.limit(1).build_select();
283 let mut query = sqlx::query_as::<_, T>(&sql);
284 for bind in &binds {
285 query = bind_value(query, bind);
286 }
287 query.fetch_optional(pool).await
288 }
289
290 pub async fn count(self, pool: &DbPool) -> Result<i64, sqlx::Error> {
292 let (sql, binds) = self.build_count();
293 let mut query = sqlx::query(&sql);
294 for bind in &binds {
295 query = bind_value_raw(query, bind);
296 }
297 let row = query.fetch_one(pool).await?;
298 Ok(row.try_get::<i64, _>(0).unwrap_or(0))
299 }
300
301 fn build_select(self) -> (String, Vec<BindValue>) {
303 let mut binds = Vec::new();
304 let mut idx = 1usize;
305
306 let joins = self.joins.join(" ");
307 let where_clause = self.build_where(&mut binds, &mut idx);
308
309 let order_clause = if self.order.is_empty() {
310 String::new()
311 } else {
312 let parts: Vec<String> = self.order.iter()
313 .map(|(col, dir)| format!("{} {}", col, dir))
314 .collect();
315 format!(" ORDER BY {}", parts.join(", "))
316 };
317
318 let group = self.group_by_val
319 .as_ref()
320 .map(|g| format!(" GROUP BY {}", g))
321 .unwrap_or_default();
322
323 let limit = self.limit_val
324 .map(|l| format!(" LIMIT {}", l))
325 .unwrap_or_default();
326
327 let offset = self.offset_val
328 .map(|o| format!(" OFFSET {}", o))
329 .unwrap_or_default();
330
331 let sql = format!(
332 "SELECT {} FROM {} {}{}{}{}{}{}",
333 self.columns, self.table, joins, where_clause, group, order_clause, limit, offset
334 );
335
336 (sql.trim().to_string(), binds)
337 }
338
339 fn build_count(self) -> (String, Vec<BindValue>) {
340 let mut binds = Vec::new();
341 let mut idx = 1usize;
342
343 let joins = self.joins.join(" ");
344 let where_clause = self.build_where(&mut binds, &mut idx);
345
346 let group = self.group_by_val
347 .as_ref()
348 .map(|g| format!(" GROUP BY {}", g))
349 .unwrap_or_default();
350
351 let sql = format!(
352 "SELECT COUNT(*) FROM {} {}{}{}",
353 self.table, joins, where_clause, group
354 );
355
356 (sql.trim().to_string(), binds)
357 }
358
359 fn build_where(&self, binds: &mut Vec<BindValue>, idx: &mut usize) -> String {
360 if self.conditions.is_empty() {
361 return String::new();
362 }
363
364 let parts: Vec<String> = self.conditions.iter().map(|c| {
365 match c {
366 Condition::Eq(col, val) => {
367 let ph = placeholder(*idx);
368 *idx += 1;
369 binds.push(val.clone());
370 format!("{} = {}", col, ph)
371 }
372 Condition::Ne(col, val) => {
373 let ph = placeholder(*idx);
374 *idx += 1;
375 binds.push(val.clone());
376 format!("{} != {}", col, ph)
377 }
378 Condition::Gt(col, val) => {
379 let ph = placeholder(*idx);
380 *idx += 1;
381 binds.push(val.clone());
382 format!("{} > {}", col, ph)
383 }
384 Condition::Gte(col, val) => {
385 let ph = placeholder(*idx);
386 *idx += 1;
387 binds.push(val.clone());
388 format!("{} >= {}", col, ph)
389 }
390 Condition::Lt(col, val) => {
391 let ph = placeholder(*idx);
392 *idx += 1;
393 binds.push(val.clone());
394 format!("{} < {}", col, ph)
395 }
396 Condition::Lte(col, val) => {
397 let ph = placeholder(*idx);
398 *idx += 1;
399 binds.push(val.clone());
400 format!("{} <= {}", col, ph)
401 }
402 Condition::Like(col, val) => {
403 let ph = placeholder(*idx);
404 *idx += 1;
405 binds.push(val.clone());
406 format!("{} LIKE {}", col, ph)
407 }
408 Condition::IsNull(col) => format!("{} IS NULL", col),
409 Condition::IsNotNull(col) => format!("{} IS NOT NULL", col),
410 Condition::In(col, vals) => {
411 let placeholders: Vec<String> = vals.iter().map(|v| {
412 let ph = placeholder(*idx);
413 *idx += 1;
414 binds.push(v.clone());
415 ph
416 }).collect();
417 format!("{} IN ({})", col, placeholders.join(", "))
418 }
419 Condition::Raw(raw) => raw.clone(),
420 }
421 }).collect();
422
423 format!(" WHERE {}", parts.join(" AND "))
424 }
425}
426
427fn bind_value<'q, T>(
429 query: sqlx::query::QueryAs<'q, super::Db, T, super::DbArguments>,
430 value: &'q BindValue,
431) -> sqlx::query::QueryAs<'q, super::Db, T, super::DbArguments>
432where
433 T: for<'r> sqlx::FromRow<'r, DbRow>,
434{
435 match value {
436 BindValue::Int(v) => query.bind(*v),
437 BindValue::Float(v) => query.bind(*v),
438 BindValue::String(v) => query.bind(v.as_str()),
439 BindValue::Bool(v) => query.bind(*v),
440 }
441}
442
443fn bind_value_raw<'q>(
445 query: sqlx::query::Query<'q, super::Db, super::DbArguments>,
446 value: &'q BindValue,
447) -> sqlx::query::Query<'q, super::Db, super::DbArguments> {
448 match value {
449 BindValue::Int(v) => query.bind(*v),
450 BindValue::Float(v) => query.bind(*v),
451 BindValue::String(v) => query.bind(v.as_str()),
452 BindValue::Bool(v) => query.bind(*v),
453 }
454}