Skip to main content

oxide_sql_core/builder/
expr.rs

1//! Expression builder for dynamic (string-based) queries.
2//!
3//! For compile-time validated column expressions, use `col` from `builder::typed`.
4
5use super::value::{SqlValue, ToSqlValue};
6
7/// Creates a column reference for dynamic (string-based) queries.
8///
9/// For compile-time validated queries, use `col` from `builder::typed`.
10#[must_use]
11pub fn dyn_col(name: &str) -> ColumnRef {
12    ColumnRef {
13        table: None,
14        name: String::from(name),
15    }
16}
17
/// A column reference for dynamic (string-based) queries.
///
/// Rendered by `ColumnRef::to_sql` as `table.name` when a qualifier is present,
/// or just `name` otherwise. Because it is plain data, two references compare
/// equal exactly when both the qualifier and the name match, which also makes
/// it usable as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ColumnRef {
    /// Optional table qualifier (the `table` in `table.name`).
    pub table: Option<String>,
    /// Column name, used verbatim (not quoted or escaped).
    pub name: String,
}
26
27impl ColumnRef {
28    /// Creates a qualified column reference.
29    #[must_use]
30    pub fn qualified(table: &str, name: &str) -> Self {
31        Self {
32            table: Some(String::from(table)),
33            name: String::from(name),
34        }
35    }
36
37    /// Returns the SQL representation.
38    #[must_use]
39    pub fn to_sql(&self) -> String {
40        match &self.table {
41            Some(t) => format!("{t}.{}", self.name),
42            None => self.name.clone(),
43        }
44    }
45
46    /// Creates an equality expression.
47    #[must_use]
48    pub fn eq<T: ToSqlValue>(self, value: T) -> ExprBuilder {
49        ExprBuilder::binary(self.into(), "=", value.to_sql_value().into())
50    }
51
52    /// Creates an inequality expression.
53    #[must_use]
54    pub fn not_eq<T: ToSqlValue>(self, value: T) -> ExprBuilder {
55        ExprBuilder::binary(self.into(), "!=", value.to_sql_value().into())
56    }
57
58    /// Creates a less-than expression.
59    #[must_use]
60    pub fn lt<T: ToSqlValue>(self, value: T) -> ExprBuilder {
61        ExprBuilder::binary(self.into(), "<", value.to_sql_value().into())
62    }
63
64    /// Creates a less-than-or-equal expression.
65    #[must_use]
66    pub fn lt_eq<T: ToSqlValue>(self, value: T) -> ExprBuilder {
67        ExprBuilder::binary(self.into(), "<=", value.to_sql_value().into())
68    }
69
70    /// Creates a greater-than expression.
71    #[must_use]
72    pub fn gt<T: ToSqlValue>(self, value: T) -> ExprBuilder {
73        ExprBuilder::binary(self.into(), ">", value.to_sql_value().into())
74    }
75
76    /// Creates a greater-than-or-equal expression.
77    #[must_use]
78    pub fn gt_eq<T: ToSqlValue>(self, value: T) -> ExprBuilder {
79        ExprBuilder::binary(self.into(), ">=", value.to_sql_value().into())
80    }
81
82    /// Creates an IS NULL expression.
83    #[must_use]
84    pub fn is_null(self) -> ExprBuilder {
85        ExprBuilder::postfix(self.into(), "IS NULL")
86    }
87
88    /// Creates an IS NOT NULL expression.
89    #[must_use]
90    pub fn is_not_null(self) -> ExprBuilder {
91        ExprBuilder::postfix(self.into(), "IS NOT NULL")
92    }
93
94    /// Creates a LIKE expression.
95    #[must_use]
96    pub fn like<T: ToSqlValue>(self, pattern: T) -> ExprBuilder {
97        ExprBuilder::binary(self.into(), "LIKE", pattern.to_sql_value().into())
98    }
99
100    /// Creates a NOT LIKE expression.
101    #[must_use]
102    pub fn not_like<T: ToSqlValue>(self, pattern: T) -> ExprBuilder {
103        ExprBuilder::binary(self.into(), "NOT LIKE", pattern.to_sql_value().into())
104    }
105
106    /// Creates a BETWEEN expression.
107    #[must_use]
108    pub fn between<T: ToSqlValue, U: ToSqlValue>(self, low: T, high: U) -> ExprBuilder {
109        ExprBuilder::between(self.into(), low.to_sql_value(), high.to_sql_value(), false)
110    }
111
112    /// Creates a NOT BETWEEN expression.
113    #[must_use]
114    pub fn not_between<T: ToSqlValue, U: ToSqlValue>(self, low: T, high: U) -> ExprBuilder {
115        ExprBuilder::between(self.into(), low.to_sql_value(), high.to_sql_value(), true)
116    }
117
118    /// Creates an IN expression.
119    #[must_use]
120    pub fn in_list<T: ToSqlValue>(self, values: Vec<T>) -> ExprBuilder {
121        let sql_values: Vec<SqlValue> = values.into_iter().map(ToSqlValue::to_sql_value).collect();
122        ExprBuilder::in_list_impl(self.into(), sql_values, false)
123    }
124
125    /// Creates a NOT IN expression.
126    #[must_use]
127    pub fn not_in_list<T: ToSqlValue>(self, values: Vec<T>) -> ExprBuilder {
128        let sql_values: Vec<SqlValue> = values.into_iter().map(ToSqlValue::to_sql_value).collect();
129        ExprBuilder::in_list_impl(self.into(), sql_values, true)
130    }
131}
132
133/// A type-safe expression builder.
134#[derive(Debug, Clone)]
135pub struct ExprBuilder {
136    sql: String,
137    params: Vec<SqlValue>,
138}
139
140impl ExprBuilder {
141    /// Creates a new expression from raw SQL.
142    ///
143    /// **Warning**: Only use this for SQL fragments that don't contain user input.
144    #[must_use]
145    pub fn raw(sql: impl Into<String>) -> Self {
146        Self {
147            sql: sql.into(),
148            params: vec![],
149        }
150    }
151
152    /// Creates a column reference expression.
153    ///
154    /// This is used internally by typed column accessors.
155    #[must_use]
156    pub fn column(name: &str) -> Self {
157        Self {
158            sql: String::from(name),
159            params: vec![],
160        }
161    }
162
163    /// Creates an expression from a value (parameterized).
164    #[must_use]
165    pub fn value<T: ToSqlValue>(value: T) -> Self {
166        Self {
167            sql: String::from("?"),
168            params: vec![value.to_sql_value()],
169        }
170    }
171
172    /// Creates a binary expression.
173    fn binary(left: Self, op: &str, right: Self) -> Self {
174        let mut params = left.params;
175        params.extend(right.params);
176        Self {
177            sql: format!("{} {op} {}", left.sql, right.sql),
178            params,
179        }
180    }
181
182    /// Creates a postfix expression.
183    fn postfix(operand: Self, op: &str) -> Self {
184        Self {
185            sql: format!("{} {op}", operand.sql),
186            params: operand.params,
187        }
188    }
189
190    /// Creates a BETWEEN expression.
191    fn between(expr: Self, low: SqlValue, high: SqlValue, negated: bool) -> Self {
192        let keyword = if negated { "NOT BETWEEN" } else { "BETWEEN" };
193        let mut params = expr.params;
194        params.push(low);
195        params.push(high);
196        Self {
197            sql: format!("{} {keyword} ? AND ?", expr.sql),
198            params,
199        }
200    }
201
202    /// Creates an IN expression (internal).
203    fn in_list_impl(expr: Self, values: Vec<SqlValue>, negated: bool) -> Self {
204        let keyword = if negated { "NOT IN" } else { "IN" };
205        let placeholders: Vec<&str> = values.iter().map(|_| "?").collect();
206        let mut params = expr.params;
207        params.extend(values);
208        Self {
209            sql: format!("{} {keyword} ({})", expr.sql, placeholders.join(", ")),
210            params,
211        }
212    }
213
214    /// Creates an AND expression.
215    #[must_use]
216    pub fn and(self, other: Self) -> Self {
217        Self::binary(self, "AND", other)
218    }
219
220    /// Creates an OR expression.
221    #[must_use]
222    pub fn or(self, other: Self) -> Self {
223        Self::binary(self, "OR", other)
224    }
225
226    /// Wraps the expression in parentheses.
227    #[must_use]
228    pub fn paren(self) -> Self {
229        Self {
230            sql: format!("({})", self.sql),
231            params: self.params,
232        }
233    }
234
235    /// Negates the expression with NOT.
236    #[must_use]
237    #[allow(clippy::should_implement_trait)]
238    pub fn not(self) -> Self {
239        Self {
240            sql: format!("NOT {}", self.sql),
241            params: self.params,
242        }
243    }
244
245    /// Creates an equality expression.
246    #[must_use]
247    pub fn eq<T: ToSqlValue>(self, value: T) -> Self {
248        Self::binary(self, "=", value.to_sql_value().into())
249    }
250
251    /// Creates an inequality expression.
252    #[must_use]
253    pub fn not_eq<T: ToSqlValue>(self, value: T) -> Self {
254        Self::binary(self, "!=", value.to_sql_value().into())
255    }
256
257    /// Creates a less-than expression.
258    #[must_use]
259    pub fn lt<T: ToSqlValue>(self, value: T) -> Self {
260        Self::binary(self, "<", value.to_sql_value().into())
261    }
262
263    /// Creates a less-than-or-equal expression.
264    #[must_use]
265    pub fn lt_eq<T: ToSqlValue>(self, value: T) -> Self {
266        Self::binary(self, "<=", value.to_sql_value().into())
267    }
268
269    /// Creates a greater-than expression.
270    #[must_use]
271    pub fn gt<T: ToSqlValue>(self, value: T) -> Self {
272        Self::binary(self, ">", value.to_sql_value().into())
273    }
274
275    /// Creates a greater-than-or-equal expression.
276    #[must_use]
277    pub fn gt_eq<T: ToSqlValue>(self, value: T) -> Self {
278        Self::binary(self, ">=", value.to_sql_value().into())
279    }
280
281    /// Creates an IS NULL expression.
282    #[must_use]
283    pub fn is_null(self) -> Self {
284        Self::postfix(self, "IS NULL")
285    }
286
287    /// Creates an IS NOT NULL expression.
288    #[must_use]
289    pub fn is_not_null(self) -> Self {
290        Self::postfix(self, "IS NOT NULL")
291    }
292
293    /// Creates a LIKE expression.
294    #[must_use]
295    pub fn like<T: ToSqlValue>(self, pattern: T) -> Self {
296        Self::binary(self, "LIKE", pattern.to_sql_value().into())
297    }
298
299    /// Creates an IN expression.
300    #[must_use]
301    pub fn in_list<T: ToSqlValue>(self, values: Vec<T>) -> Self {
302        let sql_values: Vec<SqlValue> = values.into_iter().map(ToSqlValue::to_sql_value).collect();
303        Self::in_list_impl(self, sql_values, false)
304    }
305
306    /// Creates a NOT IN expression.
307    #[must_use]
308    pub fn not_in_list<T: ToSqlValue>(self, values: Vec<T>) -> Self {
309        let sql_values: Vec<SqlValue> = values.into_iter().map(ToSqlValue::to_sql_value).collect();
310        Self::in_list_impl(self, sql_values, true)
311    }
312
313    /// Returns the SQL string.
314    #[must_use]
315    pub fn sql(&self) -> &str {
316        &self.sql
317    }
318
319    /// Returns the parameters.
320    #[must_use]
321    pub fn params(&self) -> &[SqlValue] {
322        &self.params
323    }
324
325    /// Consumes the builder and returns the SQL and parameters.
326    #[must_use]
327    pub fn build(self) -> (String, Vec<SqlValue>) {
328        (self.sql, self.params)
329    }
330}
331
332impl From<ColumnRef> for ExprBuilder {
333    fn from(col: ColumnRef) -> Self {
334        Self {
335            sql: col.to_sql(),
336            params: vec![],
337        }
338    }
339}
340
341impl From<SqlValue> for ExprBuilder {
342    fn from(value: SqlValue) -> Self {
343        Self {
344            sql: String::from("?"),
345            params: vec![value],
346        }
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_column_eq() {
        let expr = dyn_col("name").eq("Alice");
        assert_eq!(expr.sql(), "name = ?");
        assert_eq!(expr.params().len(), 1);
    }

    #[test]
    fn test_column_comparison() {
        assert_eq!(dyn_col("age").gt(18).sql(), "age > ?");
        assert_eq!(dyn_col("age").lt_eq(65).sql(), "age <= ?");
    }

    #[test]
    fn test_is_null() {
        let expr = dyn_col("deleted_at").is_null();
        assert_eq!(expr.sql(), "deleted_at IS NULL");
        assert!(expr.params().is_empty());
    }

    #[test]
    fn test_is_not_null() {
        let expr = dyn_col("deleted_at").is_not_null();
        assert_eq!(expr.sql(), "deleted_at IS NOT NULL");
        assert!(expr.params().is_empty());
    }

    #[test]
    fn test_like() {
        let expr = dyn_col("email").like("%@example.com");
        assert_eq!(expr.sql(), "email LIKE ?");
    }

    #[test]
    fn test_not_like() {
        let expr = dyn_col("email").not_like("%@spam.test");
        assert_eq!(expr.sql(), "email NOT LIKE ?");
        assert_eq!(expr.params().len(), 1);
    }

    #[test]
    fn test_between() {
        let expr = dyn_col("price").between(10, 100);
        assert_eq!(expr.sql(), "price BETWEEN ? AND ?");
        assert_eq!(expr.params().len(), 2);
    }

    #[test]
    fn test_not_between() {
        let expr = dyn_col("price").not_between(10, 100);
        assert_eq!(expr.sql(), "price NOT BETWEEN ? AND ?");
        assert_eq!(expr.params().len(), 2);
    }

    #[test]
    fn test_in_list() {
        let expr = dyn_col("status").in_list(vec!["active", "pending"]);
        assert_eq!(expr.sql(), "status IN (?, ?)");
        assert_eq!(expr.params().len(), 2);
    }

    #[test]
    fn test_not_in_list() {
        let expr = dyn_col("status").not_in_list(vec!["banned"]);
        assert_eq!(expr.sql(), "status NOT IN (?)");
        assert_eq!(expr.params().len(), 1);
    }

    #[test]
    fn test_and_or() {
        let expr = dyn_col("active").eq(true).and(
            dyn_col("age")
                .gt(18)
                .or(dyn_col("verified").eq(true))
                .paren(),
        );
        assert_eq!(expr.sql(), "active = ? AND (age > ? OR verified = ?)");
        assert_eq!(expr.params().len(), 3);
    }

    #[test]
    fn test_not_paren() {
        let expr = dyn_col("active").eq(true).paren().not();
        assert_eq!(expr.sql(), "NOT (active = ?)");
        assert_eq!(expr.params().len(), 1);
    }

    #[test]
    fn test_raw_value_and_column_constructors() {
        let expr = ExprBuilder::raw("1 = 1").and(ExprBuilder::column("age").gt(21));
        assert_eq!(expr.sql(), "1 = 1 AND age > ?");
        assert_eq!(expr.params().len(), 1);

        let value = ExprBuilder::value("x");
        assert_eq!(value.sql(), "?");
        assert!(matches!(&value.params()[0], SqlValue::Text(s) if s == "x"));
    }

    #[test]
    fn test_params_follow_placeholder_order() {
        let (sql, params) = dyn_col("first")
            .eq("a")
            .and(dyn_col("last").eq("b"))
            .build();
        assert_eq!(sql, "first = ? AND last = ?");
        assert_eq!(params.len(), 2);
        assert!(matches!(&params[0], SqlValue::Text(s) if s == "a"));
        assert!(matches!(&params[1], SqlValue::Text(s) if s == "b"));
    }

    #[test]
    fn test_qualified_column() {
        let expr = ColumnRef::qualified("users", "name").eq("Bob");
        assert_eq!(expr.sql(), "users.name = ?");
    }

    #[test]
    fn test_sql_injection_prevention() {
        let malicious = "'; DROP TABLE users; --";
        let expr = dyn_col("name").eq(malicious);
        // The value is parameterized, not interpolated
        assert_eq!(expr.sql(), "name = ?");
        // The malicious input is stored safely as a parameter
        assert!(matches!(&expr.params()[0], SqlValue::Text(s) if s == malicious));
    }
}