// karbon_framework/db/select_builder.rs

use super::{placeholder, DbPool, DbRow};
use sqlx::Row;

/// Validate a SQL identifier (table name, column name, alias).
///
/// Accepted shapes, where each NAME is built only from alphanumerics, `_`,
/// `.` (for `table.column`) and `*`, separated by plain spaces:
///
/// * `NAME` — e.g. `created_at`, `u.id`, `*`
/// * `NAME ALIAS` — e.g. `users u`
/// * `NAME AS ALIAS` — e.g. `users AS u` (`AS` is case-insensitive)
///
/// Longer space-separated phrases and the boolean operators `AND`, `OR`, `NOT`
/// are rejected, so an identifier cannot smuggle extra predicates or clauses
/// (e.g. `"id IS NULL OR id"` or `"users UNION SELECT ..."`) into the query.
/// Empty and whitespace-only input is rejected.
fn validate_identifier(s: &str) -> bool {
    let is_name_char = |c: char| c.is_alphanumeric() || matches!(c, '_' | '.' | '*');
    if !s.chars().all(|c| is_name_char(c) || c == ' ') {
        return false;
    }

    let tokens: Vec<&str> = s.split(' ').filter(|t| !t.is_empty()).collect();

    const OPERATORS: [&str; 3] = ["AND", "OR", "NOT"];
    if tokens
        .iter()
        .any(|t| OPERATORS.iter().any(|op| t.eq_ignore_ascii_case(op)))
    {
        return false;
    }

    match tokens.as_slice() {
        [_] | [_, _] => true,
        [_, keyword, _] => keyword.eq_ignore_ascii_case("AS"),
        // Empty / whitespace-only input, or a multi-word phrase.
        _ => false,
    }
}

/// Validate a column list (e.g. `"id, name, email, u.created_at"`).
///
/// The list is split on commas and every item must pass [`validate_identifier`]
/// (so `"u.name AS author"` is accepted). Empty items, e.g. from a trailing
/// comma, are rejected.
fn validate_column_list(s: &str) -> bool {
    !s.is_empty() && s.split(',').all(validate_identifier)
}

/// Validate a JOIN clause (the same check is applied to `where_raw` fragments).
///
/// This is a deny-list, not a grammar check. It accepts any non-empty string
/// that contains none of `;`, `'`, `"`, `#`, NUL, `--` or `/*`. That rules out
/// statement separators, quoted strings and comment openers; `#` starts a line
/// comment in MySQL.
fn validate_join(s: &str) -> bool {
    const FORBIDDEN_CHARS: [char; 5] = [';', '\'', '"', '#', '\0'];
    const FORBIDDEN_SEQUENCES: [&str; 2] = ["--", "/*"];

    !s.is_empty()
        && !s.chars().any(|c| FORBIDDEN_CHARS.contains(&c))
        && !FORBIDDEN_SEQUENCES.iter().any(|seq| s.contains(*seq))
}

/// Sort direction for [`SelectBuilder::order_by`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Order {
    /// Ascending (`ASC`).
    Asc,
    /// Descending (`DESC`).
    Desc,
}

40impl std::fmt::Display for Order {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        match self {
43            Order::Asc => write!(f, "ASC"),
44            Order::Desc => write!(f, "DESC"),
45        }
46    }
47}
48
49/// A typed, fluent SELECT query builder with parameterized conditions.
50///
51/// ```ignore
52/// use framework::db::{SelectBuilder, Order};
53///
54/// let users: Vec<User> = SelectBuilder::table("users")
55///     .columns("id, name, email, created_at")
56///     .where_eq("active", true)
57///     .where_like("name", "%alice%")
58///     .where_gt("age", 18)
59///     .where_in("role", &["admin", "editor"])
60///     .where_null("deleted_at")
61///     .order_by("created_at", Order::Desc)
62///     .limit(20)
63///     .offset(0)
64///     .fetch_all(&pool)
65///     .await?;
66///
67/// let count = SelectBuilder::table("users")
68///     .where_eq("active", true)
69///     .count(&pool)
70///     .await?;
71///
72/// let user: Option<User> = SelectBuilder::table("users")
73///     .where_eq("email", "alice@example.com")
74///     .fetch_one(&pool)
75///     .await?;
76/// ```
77pub struct SelectBuilder {
78    table: String,
79    columns: String,
80    joins: Vec<String>,
81    conditions: Vec<Condition>,
82    order: Vec<(String, Order)>,
83    limit_val: Option<u32>,
84    offset_val: Option<u32>,
85    group_by_val: Option<String>,
86}
87
88enum Condition {
89    Eq(String, BindValue),
90    Ne(String, BindValue),
91    Gt(String, BindValue),
92    Gte(String, BindValue),
93    Lt(String, BindValue),
94    Lte(String, BindValue),
95    Like(String, BindValue),
96    IsNull(String),
97    IsNotNull(String),
98    In(String, Vec<BindValue>),
99    Raw(String),
100}
101
/// A value bound as a query parameter by the `where_*` methods of `SelectBuilder`.
#[derive(Clone, Debug, PartialEq)]
#[doc(hidden)]
pub enum BindValue {
    /// Integer input, widened to `i64`.
    Int(i64),
    /// Floating-point input.
    Float(f64),
    /// Text input, bound as a string parameter.
    String(String),
    /// Boolean input.
    Bool(bool),
}

111#[doc(hidden)]
112pub trait IntoBindValue {
113    fn into_bind_value(self) -> BindValue;
114}
115
116impl IntoBindValue for i32 {
117    fn into_bind_value(self) -> BindValue { BindValue::Int(self as i64) }
118}
119
120impl IntoBindValue for i64 {
121    fn into_bind_value(self) -> BindValue { BindValue::Int(self) }
122}
123
124impl IntoBindValue for u32 {
125    fn into_bind_value(self) -> BindValue { BindValue::Int(self as i64) }
126}
127
128impl IntoBindValue for f64 {
129    fn into_bind_value(self) -> BindValue { BindValue::Float(self) }
130}
131
132impl IntoBindValue for bool {
133    fn into_bind_value(self) -> BindValue { BindValue::Bool(self) }
134}
135
136impl IntoBindValue for &str {
137    fn into_bind_value(self) -> BindValue { BindValue::String(self.to_string()) }
138}
139
140impl IntoBindValue for String {
141    fn into_bind_value(self) -> BindValue { BindValue::String(self) }
142}
143
144impl SelectBuilder {
145    pub fn table(table: &str) -> Self {
146        assert!(validate_identifier(table), "Invalid table name: {table}");
147        Self {
148            table: table.to_string(),
149            columns: "*".to_string(),
150            joins: Vec::new(),
151            conditions: Vec::new(),
152            order: Vec::new(),
153            limit_val: None,
154            offset_val: None,
155            group_by_val: None,
156        }
157    }
158
159    pub fn columns(mut self, cols: &str) -> Self {
160        assert!(validate_column_list(cols), "Invalid column list: {cols}");
161        self.columns = cols.to_string();
162        self
163    }
164
165    /// Add a JOIN clause. Must not contain SQL injection vectors (;, quotes, comments).
166    pub fn join(mut self, join_clause: &str) -> Self {
167        assert!(validate_join(join_clause), "Invalid JOIN clause: contains forbidden characters");
168        self.joins.push(join_clause.to_string());
169        self
170    }
171
172    pub fn where_eq<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
173        assert!(validate_identifier(column), "Invalid column name: {column}");
174        self.conditions.push(Condition::Eq(column.to_string(), value.into_bind_value()));
175        self
176    }
177
178    pub fn where_ne<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
179        assert!(validate_identifier(column), "Invalid column name: {column}");
180        self.conditions.push(Condition::Ne(column.to_string(), value.into_bind_value()));
181        self
182    }
183
184    pub fn where_gt<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
185        assert!(validate_identifier(column), "Invalid column name: {column}");
186        self.conditions.push(Condition::Gt(column.to_string(), value.into_bind_value()));
187        self
188    }
189
190    pub fn where_gte<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
191        assert!(validate_identifier(column), "Invalid column name: {column}");
192        self.conditions.push(Condition::Gte(column.to_string(), value.into_bind_value()));
193        self
194    }
195
196    pub fn where_lt<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
197        assert!(validate_identifier(column), "Invalid column name: {column}");
198        self.conditions.push(Condition::Lt(column.to_string(), value.into_bind_value()));
199        self
200    }
201
202    pub fn where_lte<V: IntoBindValue>(mut self, column: &str, value: V) -> Self {
203        assert!(validate_identifier(column), "Invalid column name: {column}");
204        self.conditions.push(Condition::Lte(column.to_string(), value.into_bind_value()));
205        self
206    }
207
208    pub fn where_like<V: IntoBindValue>(mut self, column: &str, pattern: V) -> Self {
209        assert!(validate_identifier(column), "Invalid column name: {column}");
210        self.conditions.push(Condition::Like(column.to_string(), pattern.into_bind_value()));
211        self
212    }
213
214    pub fn where_null(mut self, column: &str) -> Self {
215        assert!(validate_identifier(column), "Invalid column name: {column}");
216        self.conditions.push(Condition::IsNull(column.to_string()));
217        self
218    }
219
220    pub fn where_not_null(mut self, column: &str) -> Self {
221        assert!(validate_identifier(column), "Invalid column name: {column}");
222        self.conditions.push(Condition::IsNotNull(column.to_string()));
223        self
224    }
225
226    pub fn where_in<V: IntoBindValue>(mut self, column: &str, values: &[V]) -> Self
227    where V: Clone {
228        assert!(validate_identifier(column), "Invalid column name: {column}");
229        let bind_values: Vec<BindValue> = values.iter().map(|v| v.clone().into_bind_value()).collect();
230        self.conditions.push(Condition::In(column.to_string(), bind_values));
231        self
232    }
233
234    /// Add a raw WHERE clause. **WARNING**: This is NOT parameterized.
235    /// Only use with hardcoded strings, NEVER with user input.
236    pub fn where_raw(mut self, raw: &str) -> Self {
237        assert!(validate_join(raw), "where_raw contains forbidden characters");
238        self.conditions.push(Condition::Raw(raw.to_string()));
239        self
240    }
241
242    pub fn order_by(mut self, column: &str, direction: Order) -> Self {
243        assert!(validate_identifier(column), "Invalid ORDER BY column: {column}");
244        self.order.push((column.to_string(), direction));
245        self
246    }
247
248    pub fn limit(mut self, limit: u32) -> Self {
249        self.limit_val = Some(limit);
250        self
251    }
252
253    pub fn offset(mut self, offset: u32) -> Self {
254        self.offset_val = Some(offset);
255        self
256    }
257
258    pub fn group_by(mut self, clause: &str) -> Self {
259        assert!(validate_column_list(clause), "Invalid GROUP BY clause: {clause}");
260        self.group_by_val = Some(clause.to_string());
261        self
262    }
263
264    /// Execute and return all matching rows
265    pub async fn fetch_all<T>(self, pool: &DbPool) -> Result<Vec<T>, sqlx::Error>
266    where
267        T: for<'r> sqlx::FromRow<'r, DbRow> + Send + Unpin,
268    {
269        let (sql, binds) = self.build_select();
270        let mut query = sqlx::query_as::<_, T>(&sql);
271        for bind in &binds {
272            query = bind_value(query, bind);
273        }
274        query.fetch_all(pool).await
275    }
276
277    /// Execute and return the first matching row
278    pub async fn fetch_one<T>(self, pool: &DbPool) -> Result<Option<T>, sqlx::Error>
279    where
280        T: for<'r> sqlx::FromRow<'r, DbRow> + Send + Unpin,
281    {
282        let (sql, binds) = self.limit(1).build_select();
283        let mut query = sqlx::query_as::<_, T>(&sql);
284        for bind in &binds {
285            query = bind_value(query, bind);
286        }
287        query.fetch_optional(pool).await
288    }
289
290    /// Execute a COUNT(*) query with the same conditions
291    pub async fn count(self, pool: &DbPool) -> Result<i64, sqlx::Error> {
292        let (sql, binds) = self.build_count();
293        let mut query = sqlx::query(&sql);
294        for bind in &binds {
295            query = bind_value_raw(query, bind);
296        }
297        let row = query.fetch_one(pool).await?;
298        Ok(row.try_get::<i64, _>(0).unwrap_or(0))
299    }
300
301    /// Build the SELECT SQL and collect bind values
302    fn build_select(self) -> (String, Vec<BindValue>) {
303        let mut binds = Vec::new();
304        let mut idx = 1usize;
305
306        let joins = self.joins.join(" ");
307        let where_clause = self.build_where(&mut binds, &mut idx);
308
309        let order_clause = if self.order.is_empty() {
310            String::new()
311        } else {
312            let parts: Vec<String> = self.order.iter()
313                .map(|(col, dir)| format!("{} {}", col, dir))
314                .collect();
315            format!(" ORDER BY {}", parts.join(", "))
316        };
317
318        let group = self.group_by_val
319            .as_ref()
320            .map(|g| format!(" GROUP BY {}", g))
321            .unwrap_or_default();
322
323        let limit = self.limit_val
324            .map(|l| format!(" LIMIT {}", l))
325            .unwrap_or_default();
326
327        let offset = self.offset_val
328            .map(|o| format!(" OFFSET {}", o))
329            .unwrap_or_default();
330
331        let sql = format!(
332            "SELECT {} FROM {} {}{}{}{}{}{}",
333            self.columns, self.table, joins, where_clause, group, order_clause, limit, offset
334        );
335
336        (sql.trim().to_string(), binds)
337    }
338
339    fn build_count(self) -> (String, Vec<BindValue>) {
340        let mut binds = Vec::new();
341        let mut idx = 1usize;
342
343        let joins = self.joins.join(" ");
344        let where_clause = self.build_where(&mut binds, &mut idx);
345
346        let group = self.group_by_val
347            .as_ref()
348            .map(|g| format!(" GROUP BY {}", g))
349            .unwrap_or_default();
350
351        let sql = format!(
352            "SELECT COUNT(*) FROM {} {}{}{}",
353            self.table, joins, where_clause, group
354        );
355
356        (sql.trim().to_string(), binds)
357    }
358
359    fn build_where(&self, binds: &mut Vec<BindValue>, idx: &mut usize) -> String {
360        if self.conditions.is_empty() {
361            return String::new();
362        }
363
364        let parts: Vec<String> = self.conditions.iter().map(|c| {
365            match c {
366                Condition::Eq(col, val) => {
367                    let ph = placeholder(*idx);
368                    *idx += 1;
369                    binds.push(val.clone());
370                    format!("{} = {}", col, ph)
371                }
372                Condition::Ne(col, val) => {
373                    let ph = placeholder(*idx);
374                    *idx += 1;
375                    binds.push(val.clone());
376                    format!("{} != {}", col, ph)
377                }
378                Condition::Gt(col, val) => {
379                    let ph = placeholder(*idx);
380                    *idx += 1;
381                    binds.push(val.clone());
382                    format!("{} > {}", col, ph)
383                }
384                Condition::Gte(col, val) => {
385                    let ph = placeholder(*idx);
386                    *idx += 1;
387                    binds.push(val.clone());
388                    format!("{} >= {}", col, ph)
389                }
390                Condition::Lt(col, val) => {
391                    let ph = placeholder(*idx);
392                    *idx += 1;
393                    binds.push(val.clone());
394                    format!("{} < {}", col, ph)
395                }
396                Condition::Lte(col, val) => {
397                    let ph = placeholder(*idx);
398                    *idx += 1;
399                    binds.push(val.clone());
400                    format!("{} <= {}", col, ph)
401                }
402                Condition::Like(col, val) => {
403                    let ph = placeholder(*idx);
404                    *idx += 1;
405                    binds.push(val.clone());
406                    format!("{} LIKE {}", col, ph)
407                }
408                Condition::IsNull(col) => format!("{} IS NULL", col),
409                Condition::IsNotNull(col) => format!("{} IS NOT NULL", col),
410                Condition::In(col, vals) => {
411                    let placeholders: Vec<String> = vals.iter().map(|v| {
412                        let ph = placeholder(*idx);
413                        *idx += 1;
414                        binds.push(v.clone());
415                        ph
416                    }).collect();
417                    format!("{} IN ({})", col, placeholders.join(", "))
418                }
419                Condition::Raw(raw) => raw.clone(),
420            }
421        }).collect();
422
423        format!(" WHERE {}", parts.join(" AND "))
424    }
425}
426
427// Helper: bind a value to a query_as
428fn bind_value<'q, T>(
429    query: sqlx::query::QueryAs<'q, super::Db, T, super::DbArguments>,
430    value: &'q BindValue,
431) -> sqlx::query::QueryAs<'q, super::Db, T, super::DbArguments>
432where
433    T: for<'r> sqlx::FromRow<'r, DbRow>,
434{
435    match value {
436        BindValue::Int(v) => query.bind(*v),
437        BindValue::Float(v) => query.bind(*v),
438        BindValue::String(v) => query.bind(v.as_str()),
439        BindValue::Bool(v) => query.bind(*v),
440    }
441}
442
443// Helper: bind a value to a raw query
444fn bind_value_raw<'q>(
445    query: sqlx::query::Query<'q, super::Db, super::DbArguments>,
446    value: &'q BindValue,
447) -> sqlx::query::Query<'q, super::Db, super::DbArguments> {
448    match value {
449        BindValue::Int(v) => query.bind(*v),
450        BindValue::Float(v) => query.bind(*v),
451        BindValue::String(v) => query.bind(v.as_str()),
452        BindValue::Bool(v) => query.bind(*v),
453    }
454}