1use crate::dialect::Dialect;
4use crate::pagination::{Cursor, IntoCursor};
5use crate::validate::{assert_valid_sql_expression, assert_valid_sql_identifier};
6
7use super::filter::{build_condition_impl, build_filter_expr_impl};
8use super::types::{
9 Aggregate, CompoundFilter, ComputedField, CursorDirection, Filter, FilterExpr, Operator,
10 QueryResult, SortDir, SortField, Value,
11};
12
/// Fluent builder for a single parameterized `SELECT` statement.
///
/// Every builder method takes and returns `self`; nothing is rendered until
/// `build()` is called, which consumes the builder.
#[derive(Debug)]
#[must_use = "builder does nothing until .build() is called"]
pub struct QueryBuilder<D: Dialect> {
    /// Dialect used to render parameter placeholders; also handed to the
    /// filter renderers.
    dialect: D,
    /// Table named in the `FROM` clause (validated as an identifier in `new`).
    table: String,
    /// Plain columns of the `SELECT` list.
    fields: Vec<String>,
    /// Computed expressions appended to the `SELECT` list after `fields`.
    computed: Vec<ComputedField>,
    /// Aggregates appended to the `SELECT` list after `computed`. When
    /// `fields`, `computed` and `aggregates` are all empty the query selects `*`.
    aggregates: Vec<Aggregate>,
    /// Simple `field op value` conditions, ANDed into the `WHERE` clause.
    filters: Vec<Filter>,
    /// Optional compound condition; rendered first in the `WHERE` clause.
    filter_expr: Option<FilterExpr>,
    /// `GROUP BY` columns.
    group_by: Vec<String>,
    /// Optional `HAVING` condition.
    having: Option<FilterExpr>,
    /// `ORDER BY` keys in priority order; they also drive cursor comparisons.
    sorts: Vec<SortField>,
    /// `LIMIT` value, if any.
    limit: Option<u32>,
    /// `OFFSET` value, if any.
    offset: Option<u32>,
    /// Keyset-pagination cursor; set together with `cursor_direction`.
    cursor: Option<Cursor>,
    /// Whether rows after or before `cursor` are selected.
    cursor_direction: Option<CursorDirection>,
}
32
33impl<D: Dialect> QueryBuilder<D> {
34 pub fn new(dialect: D, table: impl Into<String>) -> Self {
40 let table = table.into();
41 assert_valid_sql_identifier(&table, "table");
42 Self {
43 dialect,
44 table,
45 fields: Vec::new(),
46 computed: Vec::new(),
47 aggregates: Vec::new(),
48 filters: Vec::new(),
49 filter_expr: None,
50 group_by: Vec::new(),
51 having: None,
52 sorts: Vec::new(),
53 limit: None,
54 offset: None,
55 cursor: None,
56 cursor_direction: None,
57 }
58 }
59
60 pub fn fields(mut self, fields: &[&str]) -> Self {
66 for field in fields {
67 assert_valid_sql_identifier(field, "field");
68 }
69 self.fields = fields.iter().map(|s| (*s).to_string()).collect();
70 self
71 }
72
73 pub fn computed(mut self, alias: impl Into<String>, expression: impl Into<String>) -> Self {
94 let alias = alias.into();
95 let expression = expression.into();
96 assert_valid_sql_identifier(&alias, "computed field alias");
97 assert_valid_sql_expression(&expression, "computed field");
98 self.computed.push(ComputedField::new(alias, expression));
99 self
100 }
101
102 pub fn aggregate(mut self, agg: Aggregate) -> Self {
104 self.aggregates.push(agg);
105 self
106 }
107
108 pub fn count(mut self) -> Self {
110 self.aggregates.push(Aggregate::count());
111 self
112 }
113
114 pub fn sum(mut self, field: impl Into<String>) -> Self {
116 self.aggregates.push(Aggregate::sum(field));
117 self
118 }
119
120 pub fn avg(mut self, field: impl Into<String>) -> Self {
122 self.aggregates.push(Aggregate::avg(field));
123 self
124 }
125
126 pub fn min(mut self, field: impl Into<String>) -> Self {
128 self.aggregates.push(Aggregate::min(field));
129 self
130 }
131
132 pub fn max(mut self, field: impl Into<String>) -> Self {
134 self.aggregates.push(Aggregate::max(field));
135 self
136 }
137
138 pub fn filter(mut self, field: impl Into<String>, op: Operator, value: Value) -> Self {
144 let field = field.into();
145 assert_valid_sql_identifier(&field, "filter field");
146 self.filters.push(Filter { field, op, value });
147 self
148 }
149
150 pub fn filter_expr(mut self, expr: FilterExpr) -> Self {
152 self.filter_expr = Some(expr);
153 self
154 }
155
156 pub fn and(mut self, filters: Vec<FilterExpr>) -> Self {
158 self.filter_expr = Some(FilterExpr::Compound(CompoundFilter::and(filters)));
159 self
160 }
161
162 pub fn or(mut self, filters: Vec<FilterExpr>) -> Self {
164 self.filter_expr = Some(FilterExpr::Compound(CompoundFilter::or(filters)));
165 self
166 }
167
168 pub fn group_by(mut self, fields: &[&str]) -> Self {
174 for field in fields {
175 assert_valid_sql_identifier(field, "group by field");
176 }
177 self.group_by = fields.iter().map(|s| (*s).to_string()).collect();
178 self
179 }
180
181 pub fn having(mut self, expr: FilterExpr) -> Self {
183 self.having = Some(expr);
184 self
185 }
186
187 pub fn sort(mut self, field: impl Into<String>, dir: SortDir) -> Self {
193 let field = field.into();
194 assert_valid_sql_identifier(&field, "sort field");
195 self.sorts.push(SortField::new(field, dir));
196 self
197 }
198
199 pub fn sorts(mut self, sorts: &[SortField]) -> Self {
201 self.sorts.extend(sorts.iter().cloned());
202 self
203 }
204
205 pub const fn page(mut self, page: u32, limit: u32) -> Self {
207 self.limit = Some(limit);
208 self.offset = Some(page.saturating_sub(1).saturating_mul(limit));
209 self
210 }
211
212 pub const fn limit_offset(mut self, limit: u32, offset: u32) -> Self {
214 self.limit = Some(limit);
215 self.offset = Some(offset);
216 self
217 }
218
219 pub const fn limit(mut self, limit: u32) -> Self {
221 self.limit = Some(limit);
222 self
223 }
224
225 pub fn after_cursor(mut self, cursor: impl IntoCursor) -> Self {
235 if let Some(c) = cursor.into_cursor() {
236 self.cursor = Some(c);
237 self.cursor_direction = Some(CursorDirection::After);
238 }
239 self
240 }
241
242 pub fn before_cursor(mut self, cursor: impl IntoCursor) -> Self {
251 if let Some(c) = cursor.into_cursor() {
252 self.cursor = Some(c);
253 self.cursor_direction = Some(CursorDirection::Before);
254 }
255 self
256 }
257
258 pub fn build(self) -> QueryResult {
260 let mut sql = String::new();
261 let mut params = Vec::new();
262 let mut param_idx = 1usize;
263
264 let mut select_parts = Vec::new();
266
267 if !self.fields.is_empty() {
269 select_parts.extend(self.fields.clone());
270 }
271
272 for comp in &self.computed {
274 select_parts.push(comp.to_sql());
275 }
276
277 for agg in &self.aggregates {
279 select_parts.push(agg.to_sql());
280 }
281
282 let select_str = if select_parts.is_empty() {
283 "*".to_string()
284 } else {
285 select_parts.join(", ")
286 };
287
288 sql.push_str(&format!("SELECT {} FROM {}", select_str, self.table));
289
290 let has_filter_expr = self.filter_expr.is_some();
292 let has_simple_filters = !self.filters.is_empty();
293 let has_cursor = self.cursor.is_some() && self.cursor_direction.is_some();
294
295 if has_filter_expr || has_simple_filters || has_cursor {
296 sql.push_str(" WHERE ");
297 let mut all_conditions = Vec::new();
298
299 if let Some(ref expr) = self.filter_expr {
301 let (condition, new_params, new_idx) =
302 build_filter_expr_impl(&self.dialect, expr, param_idx);
303 all_conditions.push(condition);
304 params.extend(new_params);
305 param_idx = new_idx;
306 }
307
308 for filter in &self.filters {
310 let (condition, new_params, new_idx) =
311 build_condition_impl(&self.dialect, filter, param_idx);
312 all_conditions.push(condition);
313 params.extend(new_params);
314 param_idx = new_idx;
315 }
316
317 if let (Some(cursor), Some(direction)) = (&self.cursor, self.cursor_direction) {
319 let (condition, new_params, new_idx) =
320 self.build_cursor_condition(cursor, direction, param_idx);
321 if !condition.is_empty() {
322 all_conditions.push(condition);
323 params.extend(new_params);
324 param_idx = new_idx;
325 }
326 }
327
328 sql.push_str(&all_conditions.join(" AND "));
329 }
330
331 if !self.group_by.is_empty() {
333 sql.push_str(&format!(" GROUP BY {}", self.group_by.join(", ")));
334 }
335
336 if let Some(ref expr) = self.having {
339 let (condition, new_params, _new_idx) =
340 build_filter_expr_impl(&self.dialect, expr, param_idx);
341 sql.push_str(&format!(" HAVING {condition}"));
342 params.extend(new_params);
343 }
344
345 if !self.sorts.is_empty() {
347 sql.push_str(" ORDER BY ");
348 let sort_parts: Vec<String> = self
349 .sorts
350 .iter()
351 .map(|s| {
352 let dir = match s.dir {
353 SortDir::Asc => "ASC",
354 SortDir::Desc => "DESC",
355 };
356 format!("{} {}", s.field, dir)
357 })
358 .collect();
359 sql.push_str(&sort_parts.join(", "));
360 }
361
362 if let Some(limit) = self.limit {
364 sql.push_str(&format!(" LIMIT {limit}"));
365 }
366 if let Some(offset) = self.offset {
367 sql.push_str(&format!(" OFFSET {offset}"));
368 }
369
370 QueryResult { sql, params }
371 }
372
373 fn build_cursor_condition(
379 &self,
380 cursor: &Cursor,
381 direction: CursorDirection,
382 start_idx: usize,
383 ) -> (String, Vec<Value>, usize) {
384 let sort_fields: Vec<SortField> = if self.sorts.is_empty() {
386 cursor
387 .fields
388 .iter()
389 .map(|(name, _)| SortField::new(name.clone(), SortDir::Asc))
390 .collect()
391 } else {
392 self.sorts.clone()
393 };
394
395 if sort_fields.is_empty() {
396 return (String::new(), vec![], start_idx);
397 }
398
399 let mut cursor_values: Vec<(&str, &Value)> = Vec::new();
401 for sort in &sort_fields {
402 if let Some((_, value)) = cursor.fields.iter().find(|(name, _)| name == &sort.field) {
403 cursor_values.push((&sort.field, value));
404 }
405 }
406
407 if cursor_values.is_empty() {
408 return (String::new(), vec![], start_idx);
409 }
410
411 let mut idx = start_idx;
412 let mut params = Vec::new();
413
414 if cursor_values.len() == 1 {
415 let (field, value) = cursor_values[0];
417 let sort = &sort_fields[0];
418 let op = match (direction, sort.dir) {
419 (CursorDirection::After, SortDir::Asc) => ">",
420 (CursorDirection::After, SortDir::Desc) => "<",
421 (CursorDirection::Before, SortDir::Asc) => "<",
422 (CursorDirection::Before, SortDir::Desc) => ">",
423 };
424
425 let sql = format!("{} {} {}", field, op, self.dialect.param(idx));
426 params.push(value.clone());
427 idx += 1;
428
429 (sql, params, idx)
430 } else {
431 let fields: Vec<&str> = cursor_values.iter().map(|(f, _)| *f).collect();
434 let placeholders: Vec<String> = cursor_values
435 .iter()
436 .enumerate()
437 .map(|(i, (_, value))| {
438 params.push((*value).clone());
439 self.dialect.param(idx + i)
440 })
441 .collect();
442 idx += cursor_values.len();
443
444 let primary_dir = sort_fields[0].dir;
446 let op = match (direction, primary_dir) {
447 (CursorDirection::After, SortDir::Asc) => ">",
448 (CursorDirection::After, SortDir::Desc) => "<",
449 (CursorDirection::Before, SortDir::Asc) => "<",
450 (CursorDirection::Before, SortDir::Desc) => ">",
451 };
452
453 let sql = format!(
454 "({}) {} ({})",
455 fields.join(", "),
456 op,
457 placeholders.join(", ")
458 );
459
460 (sql, params, idx)
461 }
462 }
463}