postrust_sql/
select.rs

1//! SELECT statement builder.
2
3use crate::{
4    builder::SqlFragment,
5    expr::{Expr, OrderExpr},
6    identifier::{escape_ident, from_qi, QualifiedIdentifier},
7};
8
9/// Builder for SELECT statements.
10#[derive(Clone, Debug, Default)]
11pub struct SelectBuilder {
12    columns: Vec<SqlFragment>,
13    from: Option<SqlFragment>,
14    joins: Vec<SqlFragment>,
15    where_clauses: Vec<SqlFragment>,
16    group_by: Vec<SqlFragment>,
17    having: Vec<SqlFragment>,
18    order_by: Vec<SqlFragment>,
19    limit: Option<i64>,
20    offset: Option<i64>,
21    distinct: bool,
22    cte: Vec<(String, SqlFragment)>,
23}
24
25impl SelectBuilder {
26    /// Create a new SELECT builder.
27    pub fn new() -> Self {
28        Self::default()
29    }
30
31    /// Add a CTE (WITH clause).
32    pub fn with_cte(mut self, name: &str, query: SqlFragment) -> Self {
33        self.cte.push((name.to_string(), query));
34        self
35    }
36
37    /// Set DISTINCT.
38    pub fn distinct(mut self) -> Self {
39        self.distinct = true;
40        self
41    }
42
43    /// Add a column to select.
44    pub fn column(mut self, name: &str) -> Self {
45        self.columns.push(SqlFragment::raw(escape_ident(name)));
46        self
47    }
48
49    /// Add a column with alias.
50    pub fn column_as(mut self, name: &str, alias: &str) -> Self {
51        self.columns.push(SqlFragment::raw(format!(
52            "{} AS {}",
53            escape_ident(name),
54            escape_ident(alias)
55        )));
56        self
57    }
58
59    /// Add a qualified column (table.column).
60    pub fn qualified_column(mut self, table: &str, column: &str) -> Self {
61        self.columns.push(SqlFragment::raw(format!(
62            "{}.{}",
63            escape_ident(table),
64            escape_ident(column)
65        )));
66        self
67    }
68
69    /// Add a raw SQL column expression.
70    pub fn column_raw(mut self, sql: SqlFragment) -> Self {
71        self.columns.push(sql);
72        self
73    }
74
75    /// Add all columns (*).
76    pub fn all_columns(mut self) -> Self {
77        self.columns.push(SqlFragment::raw("*"));
78        self
79    }
80
81    /// Add all columns from a table (table.*).
82    pub fn all_columns_from(mut self, table: &str) -> Self {
83        self.columns
84            .push(SqlFragment::raw(format!("{}.*", escape_ident(table))));
85        self
86    }
87
88    /// Set the FROM table.
89    pub fn from_table(mut self, qi: &QualifiedIdentifier) -> Self {
90        self.from = Some(SqlFragment::raw(from_qi(qi)));
91        self
92    }
93
94    /// Set FROM with alias.
95    pub fn from_table_as(mut self, qi: &QualifiedIdentifier, alias: &str) -> Self {
96        self.from = Some(SqlFragment::raw(format!(
97            "{} AS {}",
98            from_qi(qi),
99            escape_ident(alias)
100        )));
101        self
102    }
103
104    /// Set FROM from raw SQL.
105    pub fn from_raw(mut self, sql: SqlFragment) -> Self {
106        self.from = Some(sql);
107        self
108    }
109
110    /// Add an INNER JOIN.
111    pub fn inner_join(mut self, table: &str, condition: &str) -> Self {
112        self.joins.push(SqlFragment::raw(format!(
113            " INNER JOIN {} ON {}",
114            escape_ident(table),
115            condition
116        )));
117        self
118    }
119
120    /// Add a LEFT JOIN.
121    pub fn left_join(mut self, table: &str, condition: &str) -> Self {
122        self.joins.push(SqlFragment::raw(format!(
123            " LEFT JOIN {} ON {}",
124            escape_ident(table),
125            condition
126        )));
127        self
128    }
129
130    /// Add a LEFT JOIN LATERAL with subquery.
131    pub fn left_join_lateral(mut self, subquery: SqlFragment, alias: &str, on: &str) -> Self {
132        let mut join = SqlFragment::raw(" LEFT JOIN LATERAL (");
133        join.append(subquery);
134        join.push(") AS ");
135        join.push(&escape_ident(alias));
136        join.push(" ON ");
137        join.push(on);
138        self.joins.push(join);
139        self
140    }
141
142    /// Add a WHERE clause.
143    pub fn where_expr(mut self, expr: Expr) -> Self {
144        self.where_clauses.push(expr.into_fragment());
145        self
146    }
147
148    /// Add a raw WHERE clause.
149    pub fn where_raw(mut self, sql: SqlFragment) -> Self {
150        self.where_clauses.push(sql);
151        self
152    }
153
154    /// Add a GROUP BY column.
155    pub fn group_by(mut self, column: &str) -> Self {
156        self.group_by.push(SqlFragment::raw(escape_ident(column)));
157        self
158    }
159
160    /// Add a HAVING clause.
161    pub fn having(mut self, expr: Expr) -> Self {
162        self.having.push(expr.into_fragment());
163        self
164    }
165
166    /// Add an ORDER BY clause.
167    pub fn order_by(mut self, expr: OrderExpr) -> Self {
168        self.order_by.push(expr.into_fragment());
169        self
170    }
171
172    /// Add ORDER BY from raw SQL.
173    pub fn order_by_raw(mut self, sql: SqlFragment) -> Self {
174        self.order_by.push(sql);
175        self
176    }
177
178    /// Set LIMIT.
179    pub fn limit(mut self, n: i64) -> Self {
180        self.limit = Some(n);
181        self
182    }
183
184    /// Set OFFSET.
185    pub fn offset(mut self, n: i64) -> Self {
186        self.offset = Some(n);
187        self
188    }
189
190    /// Build the SELECT statement.
191    pub fn build(self) -> SqlFragment {
192        let mut result = SqlFragment::new();
193
194        // CTEs
195        if !self.cte.is_empty() {
196            result.push("WITH ");
197            for (i, (name, query)) in self.cte.into_iter().enumerate() {
198                if i > 0 {
199                    result.push(", ");
200                }
201                result.push(&escape_ident(&name));
202                result.push(" AS (");
203                result.append(query);
204                result.push(")");
205            }
206            result.push(" ");
207        }
208
209        // SELECT
210        result.push("SELECT ");
211        if self.distinct {
212            result.push("DISTINCT ");
213        }
214
215        // Columns
216        if self.columns.is_empty() {
217            result.push("*");
218        } else {
219            for (i, col) in self.columns.into_iter().enumerate() {
220                if i > 0 {
221                    result.push(", ");
222                }
223                result.append(col);
224            }
225        }
226
227        // FROM
228        if let Some(from) = self.from {
229            result.push(" FROM ");
230            result.append(from);
231        }
232
233        // JOINs
234        for join in self.joins {
235            result.append(join);
236        }
237
238        // WHERE
239        if !self.where_clauses.is_empty() {
240            result.push(" WHERE ");
241            for (i, clause) in self.where_clauses.into_iter().enumerate() {
242                if i > 0 {
243                    result.push(" AND ");
244                }
245                result.append(clause);
246            }
247        }
248
249        // GROUP BY
250        if !self.group_by.is_empty() {
251            result.push(" GROUP BY ");
252            for (i, col) in self.group_by.into_iter().enumerate() {
253                if i > 0 {
254                    result.push(", ");
255                }
256                result.append(col);
257            }
258        }
259
260        // HAVING
261        if !self.having.is_empty() {
262            result.push(" HAVING ");
263            for (i, clause) in self.having.into_iter().enumerate() {
264                if i > 0 {
265                    result.push(" AND ");
266                }
267                result.append(clause);
268            }
269        }
270
271        // ORDER BY
272        if !self.order_by.is_empty() {
273            result.push(" ORDER BY ");
274            for (i, order) in self.order_by.into_iter().enumerate() {
275                if i > 0 {
276                    result.push(", ");
277                }
278                result.append(order);
279            }
280        }
281
282        // LIMIT
283        if let Some(limit) = self.limit {
284            result.push(" LIMIT ");
285            result.push(&limit.to_string());
286        }
287
288        // OFFSET
289        if let Some(offset) = self.offset {
290            result.push(" OFFSET ");
291            result.push(&offset.to_string());
292        }
293
294        result
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_select() {
        let qi = QualifiedIdentifier::new("public", "users");
        let sql = SelectBuilder::new()
            .column("id")
            .column("name")
            .from_table(&qi)
            .build();

        assert_eq!(
            sql.sql(),
            "SELECT \"id\", \"name\" FROM \"public\".\"users\""
        );
    }

    #[test]
    fn test_empty_column_list_selects_star() {
        let qi = QualifiedIdentifier::new("public", "users");
        let sql = SelectBuilder::new().from_table(&qi).build();

        assert_eq!(sql.sql(), "SELECT * FROM \"public\".\"users\"");
    }

    #[test]
    fn test_aliases_and_qualified_columns() {
        let qi = QualifiedIdentifier::new("public", "users");
        let sql = SelectBuilder::new()
            .column_as("name", "n")
            .qualified_column("u", "id")
            .all_columns_from("u")
            .from_table_as(&qi, "u")
            .build();

        assert_eq!(
            sql.sql(),
            "SELECT \"name\" AS \"n\", \"u\".\"id\", \"u\".* FROM \"public\".\"users\" AS \"u\""
        );
    }

    #[test]
    fn test_with_cte() {
        let sql = SelectBuilder::new()
            .with_cte("recent", SqlFragment::raw("SELECT 1"))
            .all_columns()
            .from_raw(SqlFragment::raw("\"recent\""))
            .build();

        assert_eq!(
            sql.sql(),
            "WITH \"recent\" AS (SELECT 1) SELECT * FROM \"recent\""
        );
    }

    #[test]
    fn test_select_with_where() {
        let qi = QualifiedIdentifier::new("public", "users");
        let sql = SelectBuilder::new()
            .all_columns()
            .from_table(&qi)
            .where_expr(Expr::eq("id", 1i64))
            .build();

        // The WHERE clause must come directly after the FROM source.
        assert!(sql
            .sql()
            .starts_with("SELECT * FROM \"public\".\"users\" WHERE "));
        assert!(sql.sql().contains("$1"));
    }

    #[test]
    fn test_select_with_order_limit() {
        let qi = QualifiedIdentifier::new("public", "users");
        let sql = SelectBuilder::new()
            .all_columns()
            .from_table(&qi)
            .order_by(OrderExpr::new("created_at").desc())
            .limit(10)
            .offset(20)
            .build();

        let text = sql.sql();
        let order_at = text.find(" ORDER BY ").expect("ORDER BY clause missing");
        let limit_at = text.find(" LIMIT 10").expect("LIMIT clause missing");
        assert!(order_at < limit_at, "ORDER BY must precede LIMIT");
        assert!(text.ends_with(" LIMIT 10 OFFSET 20"));
    }

    #[test]
    fn test_group_by_then_limit_offset() {
        let qi = QualifiedIdentifier::new("public", "users");
        let sql = SelectBuilder::new()
            .column("status")
            .from_table(&qi)
            .group_by("status")
            .limit(10)
            .offset(20)
            .build();

        assert_eq!(
            sql.sql(),
            "SELECT \"status\" FROM \"public\".\"users\" GROUP BY \"status\" LIMIT 10 OFFSET 20"
        );
    }

    #[test]
    fn test_select_distinct() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = SelectBuilder::new()
            .distinct()
            .column("status")
            .from_table(&qi)
            .build();

        assert!(sql.sql().starts_with("SELECT DISTINCT \"status\" FROM "));
    }
}