1use crate::{Database, DbxResult};
7use arrow::record_batch::RecordBatch;
8
/// Logical operator that attaches a `WHERE` predicate to the one before it.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Connector {
    /// Both predicates must hold (`AND`).
    And,
    /// Either predicate may hold (`OR`).
    Or,
}

impl Connector {
    /// Returns the SQL keyword for this connector.
    ///
    /// The keyword is a string literal, so the result is `'static` instead of
    /// being needlessly tied to the lifetime of `&self`.
    fn to_sql(&self) -> &'static str {
        match self {
            Connector::And => "AND",
            Connector::Or => "OR",
        }
    }
}
24
/// Kind of SQL join produced by the query builder.
#[derive(Debug, Clone, PartialEq, Eq)]
enum JoinType {
    /// `INNER JOIN`: only rows matching on both sides.
    Inner,
    /// `LEFT JOIN`: all rows from the left side.
    Left,
    /// `RIGHT JOIN`: all rows from the right side.
    Right,
}

impl JoinType {
    /// Returns the SQL keyword(s) for this join kind.
    ///
    /// The keyword is a string literal, so the result is `'static` instead of
    /// being needlessly tied to the lifetime of `&self`.
    fn to_sql(&self) -> &'static str {
        match self {
            JoinType::Inner => "INNER JOIN",
            JoinType::Left => "LEFT JOIN",
            JoinType::Right => "RIGHT JOIN",
        }
    }
}
42
43#[derive(Debug, Clone)]
45struct JoinClause {
46 join_type: JoinType,
47 table: String,
48 on_conditions: Vec<(String, String)>, }
50
51#[derive(Debug, Clone)]
53struct WhereClause {
54 column: String,
55 operator: String,
56 value: String,
57 connector: Connector,
58}
59
/// One `ORDER BY` term: a column plus its sort direction.
#[derive(Clone, Debug)]
struct OrderByClause {
    /// Column (or expression) to sort by, emitted verbatim.
    column: String,
    /// Direction text as stored by the builder (upper-cased), e.g. `ASC`/`DESC`.
    direction: String,
}
66
/// An aggregate applied to one column expression in the `SELECT` list.
#[derive(Debug, Clone)]
enum AggregateFunction {
    Count(String),
    Sum(String),
    Avg(String),
    Min(String),
    Max(String),
}

impl AggregateFunction {
    /// Renders the aggregate as SQL, e.g. `SUM(amount)`.
    ///
    /// The column text is inserted verbatim, so `Count("*")` yields `COUNT(*)`.
    fn to_sql(&self) -> String {
        let (func_name, column) = match self {
            Self::Count(c) => ("COUNT", c),
            Self::Sum(c) => ("SUM", c),
            Self::Avg(c) => ("AVG", c),
            Self::Min(c) => ("MIN", c),
            Self::Max(c) => ("MAX", c),
        };
        format!("{}({})", func_name, column)
    }
}
88
89pub struct QueryBuilder<'a> {
91 db: &'a Database,
92 select_columns: Vec<String>,
93 from_table: Option<String>,
94 join_clauses: Vec<JoinClause>,
95 where_clauses: Vec<WhereClause>,
96 order_by_clauses: Vec<OrderByClause>,
97 limit_value: Option<usize>,
98 offset_value: Option<usize>,
99 aggregate: Option<AggregateFunction>,
100}
101
102impl<'a> QueryBuilder<'a> {
103 pub(crate) fn new(db: &'a Database) -> Self {
105 Self {
106 db,
107 select_columns: Vec::new(),
108 from_table: None,
109 join_clauses: Vec::new(),
110 where_clauses: Vec::new(),
111 order_by_clauses: Vec::new(),
112 limit_value: None,
113 offset_value: None,
114 aggregate: None,
115 }
116 }
117
118 pub fn select(mut self, columns: &[&str]) -> Self {
121 self.select_columns = columns.iter().map(|s| s.to_string()).collect();
122 self
123 }
124
125 pub fn from(mut self, table: &str) -> Self {
128 self.from_table = Some(table.to_string());
129 self
130 }
131
132 pub fn where_(mut self, column: &str, operator: &str, value: &str) -> Self {
135 self.where_clauses.push(WhereClause {
136 column: column.to_string(),
137 operator: operator.to_string(),
138 value: value.to_string(),
139 connector: Connector::And, });
141 self
142 }
143
144 pub fn and(mut self, column: &str, operator: &str, value: &str) -> Self {
147 self.where_clauses.push(WhereClause {
148 column: column.to_string(),
149 operator: operator.to_string(),
150 value: value.to_string(),
151 connector: Connector::And,
152 });
153 self
154 }
155
156 pub fn or(mut self, column: &str, operator: &str, value: &str) -> Self {
159 self.where_clauses.push(WhereClause {
160 column: column.to_string(),
161 operator: operator.to_string(),
162 value: value.to_string(),
163 connector: Connector::Or,
164 });
165 self
166 }
167
168 pub fn order_by(mut self, column: &str, direction: &str) -> Self {
171 self.order_by_clauses.push(OrderByClause {
172 column: column.to_string(),
173 direction: direction.to_uppercase(),
174 });
175 self
176 }
177
178 pub fn limit(mut self, limit: usize) -> Self {
181 self.limit_value = Some(limit);
182 self
183 }
184
185 pub fn offset(mut self, offset: usize) -> Self {
188 self.offset_value = Some(offset);
189 self
190 }
191
192 pub fn inner_join(mut self, table: &str, left_col: &str, right_col: &str) -> Self {
195 self.join_clauses.push(JoinClause {
196 join_type: JoinType::Inner,
197 table: table.to_string(),
198 on_conditions: vec![(left_col.to_string(), right_col.to_string())],
199 });
200 self
201 }
202
203 pub fn left_join(mut self, table: &str, left_col: &str, right_col: &str) -> Self {
206 self.join_clauses.push(JoinClause {
207 join_type: JoinType::Left,
208 table: table.to_string(),
209 on_conditions: vec![(left_col.to_string(), right_col.to_string())],
210 });
211 self
212 }
213
214 pub fn right_join(mut self, table: &str, left_col: &str, right_col: &str) -> Self {
217 self.join_clauses.push(JoinClause {
218 join_type: JoinType::Right,
219 table: table.to_string(),
220 on_conditions: vec![(left_col.to_string(), right_col.to_string())],
221 });
222 self
223 }
224
225 pub fn count(mut self, column: &str) -> Self {
228 self.aggregate = Some(AggregateFunction::Count(column.to_string()));
229 self
230 }
231
232 pub fn sum(mut self, column: &str) -> Self {
234 self.aggregate = Some(AggregateFunction::Sum(column.to_string()));
235 self
236 }
237
238 pub fn avg(mut self, column: &str) -> Self {
240 self.aggregate = Some(AggregateFunction::Avg(column.to_string()));
241 self
242 }
243
244 pub fn min(mut self, column: &str) -> Self {
246 self.aggregate = Some(AggregateFunction::Min(column.to_string()));
247 self
248 }
249
250 pub fn max(mut self, column: &str) -> Self {
252 self.aggregate = Some(AggregateFunction::Max(column.to_string()));
253 self
254 }
255
256 fn build_sql(&self) -> String {
258 let mut sql = String::new();
259
260 if let Some(agg) = &self.aggregate {
262 sql.push_str(&format!("SELECT {}", agg.to_sql()));
263 } else {
264 let columns = if self.select_columns.is_empty() {
265 "*".to_string()
266 } else {
267 self.select_columns.join(", ")
268 };
269 sql.push_str(&format!("SELECT {}", columns));
270 }
271
272 if let Some(table) = &self.from_table {
274 sql.push_str(&format!(" FROM {}", table));
275 }
276
277 for join in &self.join_clauses {
279 sql.push_str(&format!(" {} {}", join.join_type.to_sql(), join.table));
280
281 if !join.on_conditions.is_empty() {
282 sql.push_str(" ON ");
283 let conditions: Vec<String> = join
284 .on_conditions
285 .iter()
286 .map(|(left, right)| format!("{} = {}", left, right))
287 .collect();
288 sql.push_str(&conditions.join(" AND "));
289 }
290 }
291
292 if !self.where_clauses.is_empty() {
294 sql.push_str(" WHERE ");
295 for (i, clause) in self.where_clauses.iter().enumerate() {
296 if i > 0 {
297 sql.push_str(&format!(" {} ", clause.connector.to_sql()));
298 }
299 sql.push_str(&format!(
300 "{} {} {}",
301 clause.column, clause.operator, clause.value
302 ));
303 }
304 }
305
306 if !self.order_by_clauses.is_empty() {
308 sql.push_str(" ORDER BY ");
309 let orders: Vec<String> = self
310 .order_by_clauses
311 .iter()
312 .map(|o| format!("{} {}", o.column, o.direction))
313 .collect();
314 sql.push_str(&orders.join(", "));
315 }
316
317 if let Some(limit) = self.limit_value {
319 sql.push_str(&format!(" LIMIT {}", limit));
320 }
321
322 if let Some(offset) = self.offset_value {
324 sql.push_str(&format!(" OFFSET {}", offset));
325 }
326
327 sql
328 }
329
330 pub fn execute(self) -> DbxResult<Vec<RecordBatch>> {
333 let sql = self.build_sql();
334 self.db.execute_sql(&sql)
335 }
336}