Skip to main content

karbon_framework/db/
select_builder.rs

1use super::{placeholder, DbPool, DbRow};
2use sqlx::Row;
3
/// Validate a SQL identifier (table name, column name, alias).
///
/// Accepts a non-empty string made only of alphanumeric characters
/// (`char::is_alphanumeric`, so non-ASCII letters and digits pass as well),
/// `_`, `.` (for `table.column`), `*`, and spaces (so a table may carry an
/// alias, e.g. `"users u"`).
///
/// Quotes, semicolons, comment markers, operators and parentheses are
/// rejected. This is NOT a complete injection guard: because letters and
/// spaces are allowed, a value such as `"x OR 1"` still passes. Identifiers
/// must come from trusted code, never from user input.
fn validate_identifier(s: &str) -> bool {
    !s.is_empty()
        && s
            .chars()
            .all(|c| c.is_alphanumeric() || matches!(c, '_' | '.' | '*' | ' '))
}
13
/// Check that `s` looks like a comma-separated column list, such as
/// `"id, name, email, u.created_at"`.
///
/// Beyond alphanumerics (`char::is_alphanumeric`), only `_`, `.`, `*`, space
/// and `,` are permitted. An empty string is rejected.
fn validate_column_list(s: &str) -> bool {
    const PERMITTED_PUNCTUATION: [char; 5] = ['_', '.', '*', ' ', ','];

    if s.is_empty() {
        return false;
    }
    !s
        .chars()
        .any(|ch| !(ch.is_alphanumeric() || PERMITTED_PUNCTUATION.contains(&ch)))
}
22
/// Validate a JOIN clause (the same check also guards raw WHERE fragments).
///
/// This is a blocklist, not a parser. It rejects the sequences that let a
/// fragment do any of the following:
/// - end the statement (`;`);
/// - open a string literal (`'`, `"`);
/// - comment out the rest of the query (`--`, `/*`, and MySQL's `#` line comment).
///
/// Everything else is accepted, including keywords, operators and parentheses.
/// Only pass fragments built from trusted, hardcoded text.
/// An empty string is rejected.
fn validate_join(s: &str) -> bool {
    const FORBIDDEN: [&str; 6] = [";", "'", "\"", "--", "/*", "#"];
    !s.is_empty() && !FORBIDDEN.iter().any(|&bad| s.contains(bad))
}
32
/// Sort direction for an `ORDER BY` term.
///
/// Renders as the SQL keyword `ASC` or `DESC` via `Display`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Order {
    /// Ascending order (`ASC`).
    Asc,
    /// Descending order (`DESC`).
    Desc,
}

impl std::fmt::Display for Order {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Order::Asc => "ASC",
            Order::Desc => "DESC",
        })
    }
}
48
49/// A typed, fluent SELECT query builder with parameterized conditions.
50///
51/// ```ignore
52/// use framework::db::{SelectBuilder, Order};
53///
54/// let users: Vec<User> = SelectBuilder::table("users")
55///     .columns("id, name, email, created_at")
56///     .where_eq("active", true)
57///     .where_like("name", "%alice%")
58///     .where_gt("age", 18)
59///     .where_in("role", &["admin", "editor"])
60///     .where_null("deleted_at")
61///     .order_by("created_at", Order::Desc)
62///     .limit(20)
63///     .offset(0)
64///     .fetch_all(&pool)
65///     .await?;
66///
67/// let count = SelectBuilder::table("users")
68///     .where_eq("active", true)
69///     .count(&pool)
70///     .await?;
71///
72/// let user: Option<User> = SelectBuilder::table("users")
73///     .where_eq("email", "alice@example.com")
74///     .fetch_one(&pool)
75///     .await?;
76/// ```
77pub struct SelectBuilder {
78    table: String,
79    columns: String,
80    joins: Vec<String>,
81    conditions: Vec<Condition>,
82    order: Vec<(String, Order)>,
83    limit_val: Option<u32>,
84    offset_val: Option<u32>,
85    group_by_val: Option<String>,
86}
87
88enum Condition {
89    Eq(String, BindValue),
90    Ne(String, BindValue),
91    Gt(String, BindValue),
92    Gte(String, BindValue),
93    Lt(String, BindValue),
94    Lte(String, BindValue),
95    Like(String, BindValue),
96    IsNull(String),
97    IsNotNull(String),
98    In(String, Vec<BindValue>),
99    Raw(String),
100}
101
/// A value that will be sent to the database as a positional bind parameter.
///
/// `Debug` and `PartialEq` are derived so builder state can be inspected and
/// compared in tests and logs. `Eq` is not derived because of the `f64` variant.
#[derive(Clone, Debug, PartialEq)]
#[doc(hidden)]
pub enum BindValue {
    /// Any integer parameter, stored widened to `i64`.
    Int(i64),
    /// Floating-point parameter.
    Float(f64),
    /// Text parameter.
    String(String),
    /// Boolean parameter.
    Bool(bool),
}
110
111#[doc(hidden)]
112pub trait IntoBindValue {
113    fn into_bind_value(self) -> BindValue;
114}
115
116impl IntoBindValue for i32 {
117    fn into_bind_value(self) -> BindValue { BindValue::Int(self as i64) }
118}
119
120impl IntoBindValue for i64 {
121    fn into_bind_value(self) -> BindValue { BindValue::Int(self) }
122}
123
124impl IntoBindValue for u32 {
125    fn into_bind_value(self) -> BindValue { BindValue::Int(self as i64) }
126}
127
128impl IntoBindValue for f64 {
129    fn into_bind_value(self) -> BindValue { BindValue::Float(self) }
130}
131
132impl IntoBindValue for bool {
133    fn into_bind_value(self) -> BindValue { BindValue::Bool(self) }
134}
135
136impl IntoBindValue for &str {
137    fn into_bind_value(self) -> BindValue { BindValue::String(self.to_string()) }
138}
139
140impl IntoBindValue for String {
141    fn into_bind_value(self) -> BindValue { BindValue::String(self) }
142}
143
144impl IntoBindValue for chrono::DateTime<chrono::Utc> {
145    fn into_bind_value(self) -> BindValue {
146        BindValue::String(self.format("%Y-%m-%d %H:%M:%S%.6f").to_string())
147    }
148}
149
150impl SelectBuilder {
151    pub fn table(table: &str) -> Self {
152        assert!(validate_identifier(table), "Invalid table name: {table}");
153        Self {
154            table: table.to_string(),
155            columns: "*".to_string(),
156            joins: Vec::new(),
157            conditions: Vec::new(),
158            order: Vec::new(),
159            limit_val: None,
160            offset_val: None,
161            group_by_val: None,
162        }
163    }
164
165    pub fn columns(mut self, cols: &str) -> Self {
166        assert!(validate_column_list(cols), "Invalid column list: {cols}");
167        self.columns = cols.to_string();
168        self
169    }
170
171    /// Add a JOIN clause. Must not contain SQL injection vectors (;, quotes, comments).
172    pub fn join(mut self, join_clause: &str) -> Self {
173        assert!(validate_join(join_clause), "Invalid JOIN clause: contains forbidden characters");
174        self.joins.push(join_clause.to_string());
175        self
176    }
177
178    pub fn where_eq<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
179        assert!(validate_identifier(column), "Invalid column name: {column}");
180        self.conditions.push(Condition::Eq(column.to_string(), value.into_bind_value()));
181        self
182    }
183
184    pub fn where_ne<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
185        assert!(validate_identifier(column), "Invalid column name: {column}");
186        self.conditions.push(Condition::Ne(column.to_string(), value.into_bind_value()));
187        self
188    }
189
190    pub fn where_gt<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
191        assert!(validate_identifier(column), "Invalid column name: {column}");
192        self.conditions.push(Condition::Gt(column.to_string(), value.into_bind_value()));
193        self
194    }
195
196    pub fn where_gte<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
197        assert!(validate_identifier(column), "Invalid column name: {column}");
198        self.conditions.push(Condition::Gte(column.to_string(), value.into_bind_value()));
199        self
200    }
201
202    pub fn where_lt<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
203        assert!(validate_identifier(column), "Invalid column name: {column}");
204        self.conditions.push(Condition::Lt(column.to_string(), value.into_bind_value()));
205        self
206    }
207
208    pub fn where_lte<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
209        assert!(validate_identifier(column), "Invalid column name: {column}");
210        self.conditions.push(Condition::Lte(column.to_string(), value.into_bind_value()));
211        self
212    }
213
214    pub fn where_like<V: IntoBindValue>(mut self, column: &str, pattern: V) -> Self {
215        assert!(validate_identifier(column), "Invalid column name: {column}");
216        self.conditions.push(Condition::Like(column.to_string(), pattern.into_bind_value()));
217        self
218    }
219
220    pub fn where_null(mut self, column: &str) -> Self {
221        assert!(validate_identifier(column), "Invalid column name: {column}");
222        self.conditions.push(Condition::IsNull(column.to_string()));
223        self
224    }
225
226    pub fn where_not_null(mut self, column: &str) -> Self {
227        assert!(validate_identifier(column), "Invalid column name: {column}");
228        self.conditions.push(Condition::IsNotNull(column.to_string()));
229        self
230    }
231
232    pub fn where_in<V: IntoBindValue>(mut self, column: &str, values: &[V]) -> Self
233    where V: Clone {
234        assert!(validate_identifier(column), "Invalid column name: {column}");
235        let bind_values: Vec<BindValue> = values.iter().map(|v| v.clone().into_bind_value()).collect();
236        self.conditions.push(Condition::In(column.to_string(), bind_values));
237        self
238    }
239
240    /// Add a raw WHERE clause. **WARNING**: This is NOT parameterized.
241    /// Only use with hardcoded strings, NEVER with user input.
242    pub fn where_raw(mut self, raw: &str) -> Self {
243        assert!(validate_join(raw), "where_raw contains forbidden characters");
244        self.conditions.push(Condition::Raw(raw.to_string()));
245        self
246    }
247
248    pub fn order_by(mut self, column: &str, direction: Order) -> Self {
249        assert!(validate_identifier(column), "Invalid ORDER BY column: {column}");
250        self.order.push((column.to_string(), direction));
251        self
252    }
253
254    pub fn limit(mut self, limit: u32) -> Self {
255        self.limit_val = Some(limit);
256        self
257    }
258
259    pub fn offset(mut self, offset: u32) -> Self {
260        self.offset_val = Some(offset);
261        self
262    }
263
264    pub fn group_by(mut self, clause: &str) -> Self {
265        assert!(validate_column_list(clause), "Invalid GROUP BY clause: {clause}");
266        self.group_by_val = Some(clause.to_string());
267        self
268    }
269
270    /// Execute and return all matching rows
271    pub async fn fetch_all<T>(self, pool: &DbPool) -> Result<Vec<T>, sqlx::Error>
272    where
273        T: for<'r> sqlx::FromRow<'r, DbRow> + Send + Unpin,
274    {
275        let (sql, binds) = self.build_select();
276        let mut query = sqlx::query_as::<_, T>(&sql);
277        for bind in &binds {
278            query = bind_value(query, bind);
279        }
280        query.fetch_all(pool).await
281    }
282
283    /// Execute and return the first matching row
284    pub async fn fetch_one<T>(self, pool: &DbPool) -> Result<Option<T>, sqlx::Error>
285    where
286        T: for<'r> sqlx::FromRow<'r, DbRow> + Send + Unpin,
287    {
288        let (sql, binds) = self.limit(1).build_select();
289        let mut query = sqlx::query_as::<_, T>(&sql);
290        for bind in &binds {
291            query = bind_value(query, bind);
292        }
293        query.fetch_optional(pool).await
294    }
295
296    /// Execute a COUNT(*) query with the same conditions
297    pub async fn count(self, pool: &DbPool) -> Result<i64, sqlx::Error> {
298        let (sql, binds) = self.build_count();
299        let mut query = sqlx::query(&sql);
300        for bind in &binds {
301            query = bind_value_raw(query, bind);
302        }
303        let row = query.fetch_one(pool).await?;
304        Ok(row.try_get::<i64, _>(0).unwrap_or(0))
305    }
306
307    /// Build the SELECT SQL and collect bind values
308    fn build_select(self) -> (String, Vec<BindValue>) {
309        let mut binds = Vec::new();
310        let mut idx = 1usize;
311
312        let joins = self.joins.join(" ");
313        let where_clause = self.build_where(&mut binds, &mut idx);
314
315        let order_clause = if self.order.is_empty() {
316            String::new()
317        } else {
318            let parts: Vec<String> = self.order.iter()
319                .map(|(col, dir)| format!("{} {}", col, dir))
320                .collect();
321            format!(" ORDER BY {}", parts.join(", "))
322        };
323
324        let group = self.group_by_val
325            .as_ref()
326            .map(|g| format!(" GROUP BY {}", g))
327            .unwrap_or_default();
328
329        let limit = self.limit_val
330            .map(|l| format!(" LIMIT {}", l))
331            .unwrap_or_default();
332
333        let offset = self.offset_val
334            .map(|o| format!(" OFFSET {}", o))
335            .unwrap_or_default();
336
337        let sql = format!(
338            "SELECT {} FROM {} {}{}{}{}{}{}",
339            self.columns, self.table, joins, where_clause, group, order_clause, limit, offset
340        );
341
342        (sql.trim().to_string(), binds)
343    }
344
345    fn build_count(self) -> (String, Vec<BindValue>) {
346        let mut binds = Vec::new();
347        let mut idx = 1usize;
348
349        let joins = self.joins.join(" ");
350        let where_clause = self.build_where(&mut binds, &mut idx);
351
352        let group = self.group_by_val
353            .as_ref()
354            .map(|g| format!(" GROUP BY {}", g))
355            .unwrap_or_default();
356
357        let sql = format!(
358            "SELECT COUNT(*) FROM {} {}{}{}",
359            self.table, joins, where_clause, group
360        );
361
362        (sql.trim().to_string(), binds)
363    }
364
365    fn build_where(&self, binds: &mut Vec<BindValue>, idx: &mut usize) -> String {
366        if self.conditions.is_empty() {
367            return String::new();
368        }
369
370        let parts: Vec<String> = self.conditions.iter().map(|c| {
371            match c {
372                Condition::Eq(col, val) => {
373                    let ph = placeholder(*idx);
374                    *idx += 1;
375                    binds.push(val.clone());
376                    format!("{} = {}", col, ph)
377                }
378                Condition::Ne(col, val) => {
379                    let ph = placeholder(*idx);
380                    *idx += 1;
381                    binds.push(val.clone());
382                    format!("{} != {}", col, ph)
383                }
384                Condition::Gt(col, val) => {
385                    let ph = placeholder(*idx);
386                    *idx += 1;
387                    binds.push(val.clone());
388                    format!("{} > {}", col, ph)
389                }
390                Condition::Gte(col, val) => {
391                    let ph = placeholder(*idx);
392                    *idx += 1;
393                    binds.push(val.clone());
394                    format!("{} >= {}", col, ph)
395                }
396                Condition::Lt(col, val) => {
397                    let ph = placeholder(*idx);
398                    *idx += 1;
399                    binds.push(val.clone());
400                    format!("{} < {}", col, ph)
401                }
402                Condition::Lte(col, val) => {
403                    let ph = placeholder(*idx);
404                    *idx += 1;
405                    binds.push(val.clone());
406                    format!("{} <= {}", col, ph)
407                }
408                Condition::Like(col, val) => {
409                    let ph = placeholder(*idx);
410                    *idx += 1;
411                    binds.push(val.clone());
412                    format!("{} LIKE {}", col, ph)
413                }
414                Condition::IsNull(col) => format!("{} IS NULL", col),
415                Condition::IsNotNull(col) => format!("{} IS NOT NULL", col),
416                Condition::In(col, vals) => {
417                    let placeholders: Vec<String> = vals.iter().map(|v| {
418                        let ph = placeholder(*idx);
419                        *idx += 1;
420                        binds.push(v.clone());
421                        ph
422                    }).collect();
423                    format!("{} IN ({})", col, placeholders.join(", "))
424                }
425                Condition::Raw(raw) => raw.clone(),
426            }
427        }).collect();
428
429        format!(" WHERE {}", parts.join(" AND "))
430    }
431}
432
433// Helper: bind a value to a query_as
434fn bind_value<'q, T>(
435    query: sqlx::query::QueryAs<'q, super::Db, T, super::DbArguments>,
436    value: &'q BindValue,
437) -> sqlx::query::QueryAs<'q, super::Db, T, super::DbArguments>
438where
439    T: for<'r> sqlx::FromRow<'r, DbRow>,
440{
441    match value {
442        BindValue::Int(v) => query.bind(*v),
443        BindValue::Float(v) => query.bind(*v),
444        BindValue::String(v) => query.bind(v.as_str()),
445        BindValue::Bool(v) => query.bind(*v),
446    }
447}
448
449// Helper: bind a value to a raw query
450fn bind_value_raw<'q>(
451    query: sqlx::query::Query<'q, super::Db, super::DbArguments>,
452    value: &'q BindValue,
453) -> sqlx::query::Query<'q, super::Db, super::DbArguments> {
454    match value {
455        BindValue::Int(v) => query.bind(*v),
456        BindValue::Float(v) => query.bind(*v),
457        BindValue::String(v) => query.bind(v.as_str()),
458        BindValue::Bool(v) => query.bind(*v),
459    }
460}