Skip to main content

dbx_core/engine/
query_builder.rs

1// Query Builder - Fluent API for building SQL queries
2//
3// This module provides a type-safe, fluent API for building SQL queries
4// without writing raw SQL strings.
5
6use crate::{Database, DbxResult};
7use arrow::record_batch::RecordBatch;
8
/// Logical connector that joins a WHERE condition to the condition before it.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Connector {
    /// Rendered as `AND`.
    And,
    /// Rendered as `OR`.
    Or,
}

impl Connector {
    /// SQL keyword for this connector.
    ///
    /// The keywords are string literals, so the result is `'static` rather than being
    /// tied to the lifetime of `self`.
    fn to_sql(&self) -> &'static str {
        match self {
            Connector::And => "AND",
            Connector::Or => "OR",
        }
    }
}
24
/// Kind of JOIN to emit.
#[derive(Debug, Clone, PartialEq, Eq)]
enum JoinType {
    /// Rendered as `INNER JOIN`.
    Inner,
    /// Rendered as `LEFT JOIN`.
    Left,
    /// Rendered as `RIGHT JOIN`.
    Right,
}

impl JoinType {
    /// SQL keyword phrase for this join kind.
    ///
    /// The phrases are string literals, so the result is `'static` rather than being
    /// tied to the lifetime of `self`.
    fn to_sql(&self) -> &'static str {
        match self {
            JoinType::Inner => "INNER JOIN",
            JoinType::Left => "LEFT JOIN",
            JoinType::Right => "RIGHT JOIN",
        }
    }
}
42
43/// JOIN clause representation
44#[derive(Debug, Clone)]
45struct JoinClause {
46    join_type: JoinType,
47    table: String,
48    on_conditions: Vec<(String, String)>, // (left_column, right_column)
49}
50
51/// WHERE clause representation
52#[derive(Debug, Clone)]
53struct WhereClause {
54    column: String,
55    operator: String,
56    value: String,
57    connector: Connector,
58}
59
/// One ORDER BY term: a column and its sort direction text.
#[derive(Clone, Debug)]
struct OrderByClause {
    /// Column to sort by.
    column: String,
    /// Sort direction text, e.g. `ASC` or `DESC`.
    direction: String,
}
66
/// Single-column aggregate; each variant carries the column (or expression) text.
#[derive(Clone, Debug)]
enum AggregateFunction {
    /// `COUNT(column)`
    Count(String),
    /// `SUM(column)`
    Sum(String),
    /// `AVG(column)`
    Avg(String),
    /// `MIN(column)`
    Min(String),
    /// `MAX(column)`
    Max(String),
}

impl AggregateFunction {
    /// Render as a SQL call such as `SUM(amount)`; the column text is inserted verbatim.
    fn to_sql(&self) -> String {
        let (function_name, argument) = match self {
            Self::Count(argument) => ("COUNT", argument),
            Self::Sum(argument) => ("SUM", argument),
            Self::Avg(argument) => ("AVG", argument),
            Self::Min(argument) => ("MIN", argument),
            Self::Max(argument) => ("MAX", argument),
        };
        format!("{}({})", function_name, argument)
    }
}
88
89/// Query Builder for constructing SQL queries using a fluent API
90pub struct QueryBuilder<'a> {
91    db: &'a Database,
92    select_columns: Vec<String>,
93    from_table: Option<String>,
94    join_clauses: Vec<JoinClause>,
95    where_clauses: Vec<WhereClause>,
96    order_by_clauses: Vec<OrderByClause>,
97    limit_value: Option<usize>,
98    offset_value: Option<usize>,
99    aggregate: Option<AggregateFunction>,
100}
101
102impl<'a> QueryBuilder<'a> {
103    /// Create a new QueryBuilder
104    pub(crate) fn new(db: &'a Database) -> Self {
105        Self {
106            db,
107            select_columns: Vec::new(),
108            from_table: None,
109            join_clauses: Vec::new(),
110            where_clauses: Vec::new(),
111            order_by_clauses: Vec::new(),
112            limit_value: None,
113            offset_value: None,
114            aggregate: None,
115        }
116    }
117
118    /// Select specific columns
119    ///
120    pub fn select(mut self, columns: &[&str]) -> Self {
121        self.select_columns = columns.iter().map(|s| s.to_string()).collect();
122        self
123    }
124
125    /// Specify the table to query from
126    ///
127    pub fn from(mut self, table: &str) -> Self {
128        self.from_table = Some(table.to_string());
129        self
130    }
131
132    /// Add a WHERE clause
133    ///
134    pub fn where_(mut self, column: &str, operator: &str, value: &str) -> Self {
135        self.where_clauses.push(WhereClause {
136            column: column.to_string(),
137            operator: operator.to_string(),
138            value: value.to_string(),
139            connector: Connector::And, // First clause uses AND (will be ignored)
140        });
141        self
142    }
143
144    /// Add an AND condition to the WHERE clause
145    ///
146    pub fn and(mut self, column: &str, operator: &str, value: &str) -> Self {
147        self.where_clauses.push(WhereClause {
148            column: column.to_string(),
149            operator: operator.to_string(),
150            value: value.to_string(),
151            connector: Connector::And,
152        });
153        self
154    }
155
156    /// Add an OR condition to the WHERE clause
157    ///
158    pub fn or(mut self, column: &str, operator: &str, value: &str) -> Self {
159        self.where_clauses.push(WhereClause {
160            column: column.to_string(),
161            operator: operator.to_string(),
162            value: value.to_string(),
163            connector: Connector::Or,
164        });
165        self
166    }
167
168    /// Add an ORDER BY clause
169    ///
170    pub fn order_by(mut self, column: &str, direction: &str) -> Self {
171        self.order_by_clauses.push(OrderByClause {
172            column: column.to_string(),
173            direction: direction.to_uppercase(),
174        });
175        self
176    }
177
178    /// Set the LIMIT clause
179    ///
180    pub fn limit(mut self, limit: usize) -> Self {
181        self.limit_value = Some(limit);
182        self
183    }
184
185    /// Set the OFFSET clause
186    ///
187    pub fn offset(mut self, offset: usize) -> Self {
188        self.offset_value = Some(offset);
189        self
190    }
191
192    /// Add an INNER JOIN clause
193    ///
194    pub fn inner_join(mut self, table: &str, left_col: &str, right_col: &str) -> Self {
195        self.join_clauses.push(JoinClause {
196            join_type: JoinType::Inner,
197            table: table.to_string(),
198            on_conditions: vec![(left_col.to_string(), right_col.to_string())],
199        });
200        self
201    }
202
203    /// Add a LEFT JOIN clause
204    ///
205    pub fn left_join(mut self, table: &str, left_col: &str, right_col: &str) -> Self {
206        self.join_clauses.push(JoinClause {
207            join_type: JoinType::Left,
208            table: table.to_string(),
209            on_conditions: vec![(left_col.to_string(), right_col.to_string())],
210        });
211        self
212    }
213
214    /// Add a RIGHT JOIN clause
215    ///
216    pub fn right_join(mut self, table: &str, left_col: &str, right_col: &str) -> Self {
217        self.join_clauses.push(JoinClause {
218            join_type: JoinType::Right,
219            table: table.to_string(),
220            on_conditions: vec![(left_col.to_string(), right_col.to_string())],
221        });
222        self
223    }
224
225    /// Count rows
226    ///
227    pub fn count(mut self, column: &str) -> Self {
228        self.aggregate = Some(AggregateFunction::Count(column.to_string()));
229        self
230    }
231
232    /// Sum values
233    pub fn sum(mut self, column: &str) -> Self {
234        self.aggregate = Some(AggregateFunction::Sum(column.to_string()));
235        self
236    }
237
238    /// Average values
239    pub fn avg(mut self, column: &str) -> Self {
240        self.aggregate = Some(AggregateFunction::Avg(column.to_string()));
241        self
242    }
243
244    /// Minimum value
245    pub fn min(mut self, column: &str) -> Self {
246        self.aggregate = Some(AggregateFunction::Min(column.to_string()));
247        self
248    }
249
250    /// Maximum value
251    pub fn max(mut self, column: &str) -> Self {
252        self.aggregate = Some(AggregateFunction::Max(column.to_string()));
253        self
254    }
255
256    /// Build the SQL query string
257    fn build_sql(&self) -> String {
258        let mut sql = String::new();
259
260        // SELECT clause
261        if let Some(agg) = &self.aggregate {
262            sql.push_str(&format!("SELECT {}", agg.to_sql()));
263        } else {
264            let columns = if self.select_columns.is_empty() {
265                "*".to_string()
266            } else {
267                self.select_columns.join(", ")
268            };
269            sql.push_str(&format!("SELECT {}", columns));
270        }
271
272        // FROM clause
273        if let Some(table) = &self.from_table {
274            sql.push_str(&format!(" FROM {}", table));
275        }
276
277        // JOIN clauses
278        for join in &self.join_clauses {
279            sql.push_str(&format!(" {} {}", join.join_type.to_sql(), join.table));
280
281            if !join.on_conditions.is_empty() {
282                sql.push_str(" ON ");
283                let conditions: Vec<String> = join
284                    .on_conditions
285                    .iter()
286                    .map(|(left, right)| format!("{} = {}", left, right))
287                    .collect();
288                sql.push_str(&conditions.join(" AND "));
289            }
290        }
291
292        // WHERE clause
293        if !self.where_clauses.is_empty() {
294            sql.push_str(" WHERE ");
295            for (i, clause) in self.where_clauses.iter().enumerate() {
296                if i > 0 {
297                    sql.push_str(&format!(" {} ", clause.connector.to_sql()));
298                }
299                sql.push_str(&format!(
300                    "{} {} {}",
301                    clause.column, clause.operator, clause.value
302                ));
303            }
304        }
305
306        // ORDER BY clause
307        if !self.order_by_clauses.is_empty() {
308            sql.push_str(" ORDER BY ");
309            let orders: Vec<String> = self
310                .order_by_clauses
311                .iter()
312                .map(|o| format!("{} {}", o.column, o.direction))
313                .collect();
314            sql.push_str(&orders.join(", "));
315        }
316
317        // LIMIT clause
318        if let Some(limit) = self.limit_value {
319            sql.push_str(&format!(" LIMIT {}", limit));
320        }
321
322        // OFFSET clause
323        if let Some(offset) = self.offset_value {
324            sql.push_str(&format!(" OFFSET {}", offset));
325        }
326
327        sql
328    }
329
330    /// Execute the query and return results
331    ///
332    pub fn execute(self) -> DbxResult<Vec<RecordBatch>> {
333        let sql = self.build_sql();
334        self.db.execute_sql(&sql)
335    }
336}