1use std::fmt::Display;
2
3use crate::{
4 ast::{self, fmt::ToTokens},
5 to_sql_string::{ToSqlContext, ToSqlString},
6};
7
8impl ToSqlString for ast::Select {
9 fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
10 let mut ret = Vec::new();
11 if let Some(with) = &self.with {
12 ret.push(with.to_sql_string(context));
13 }
14
15 ret.push(self.body.to_sql_string(context));
16
17 if let Some(order_by) = &self.order_by {
18 let joined_cols = order_by
20 .iter()
21 .map(|col| col.to_sql_string(context))
22 .collect::<Vec<_>>()
23 .join(", ");
24 ret.push(format!("ORDER BY {}", joined_cols));
25 }
26 if let Some(limit) = &self.limit {
27 ret.push(limit.to_sql_string(context));
28 }
29 ret.join(" ")
30 }
31}
32
33impl ToSqlString for ast::SelectBody {
34 fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
35 let mut ret = self.select.to_sql_string(context);
36
37 if let Some(compounds) = &self.compounds {
38 ret.push(' ');
39 let compound_selects = compounds
40 .iter()
41 .map(|compound_select| {
42 let mut curr = compound_select.operator.to_string();
43 curr.push(' ');
44 curr.push_str(&compound_select.select.to_sql_string(context));
45 curr
46 })
47 .collect::<Vec<_>>()
48 .join(" ");
49 ret.push_str(&compound_selects);
50 }
51 ret
52 }
53}
54
55impl ToSqlString for ast::OneSelect {
56 fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
57 match self {
58 ast::OneSelect::Select(select) => select.to_sql_string(context),
59 ast::OneSelect::Values(values) => {
60 let joined_values = values
61 .iter()
62 .map(|value| {
63 let joined_value = value
64 .iter()
65 .map(|e| e.to_sql_string(context))
66 .collect::<Vec<_>>()
67 .join(", ");
68 format!("({})", joined_value)
69 })
70 .collect::<Vec<_>>()
71 .join(", ");
72 format!("VALUES {}", joined_values)
73 }
74 }
75 }
76}
77
78impl ToSqlString for ast::SelectInner {
79 fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
80 dbg!(&self);
81 let mut ret = Vec::with_capacity(2 + self.columns.len());
82 ret.push("SELECT".to_string());
83 if let Some(distinct) = self.distinctness {
84 ret.push(distinct.to_string());
85 }
86 let joined_cols = self
87 .columns
88 .iter()
89 .map(|col| col.to_sql_string(context))
90 .collect::<Vec<_>>()
91 .join(", ");
92 ret.push(joined_cols);
93
94 if let Some(from) = &self.from {
95 ret.push(from.to_sql_string(context));
96 }
97 if let Some(where_expr) = &self.where_clause {
98 ret.push("WHERE".to_string());
99 ret.push(where_expr.to_sql_string(context));
100 }
101 if let Some(group_by) = &self.group_by {
102 ret.push(group_by.to_sql_string(context));
103 }
104 if let Some(window_clause) = &self.window_clause {
105 ret.push("WINDOW".to_string());
106 let joined_window = window_clause
107 .iter()
108 .map(|window_def| window_def.to_sql_string(context))
109 .collect::<Vec<_>>()
110 .join(",");
111 ret.push(joined_window);
112 }
113
114 ret.join(" ")
115 }
116}
117
118impl ToSqlString for ast::FromClause {
119 fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
120 let mut ret = String::from("FROM");
121 if let Some(select_table) = &self.select {
122 ret.push(' ');
123 ret.push_str(&select_table.to_sql_string(context));
124 }
125 if let Some(joins) = &self.joins {
126 ret.push(' ');
127 let joined_joins = joins
128 .iter()
129 .map(|join| {
130 let mut curr = join.operator.to_string();
131 curr.push(' ');
132 curr.push_str(&join.table.to_sql_string(context));
133 if let Some(join_constraint) = &join.constraint {
134 curr.push(' ');
135 curr.push_str(&join_constraint.to_sql_string(context));
136 }
137 curr
138 })
139 .collect::<Vec<_>>()
140 .join(" ");
141 ret.push_str(&joined_joins);
142 }
143 ret
144 }
145}
146
147impl ToSqlString for ast::SelectTable {
148 fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
149 let mut ret = String::new();
150 match self {
151 Self::Table(name, alias, indexed) => {
152 ret.push_str(&name.to_sql_string(context));
153 if let Some(alias) = alias {
154 ret.push(' ');
155 ret.push_str(&alias.to_string());
156 }
157 if let Some(indexed) = indexed {
158 ret.push(' ');
159 ret.push_str(&indexed.to_string());
160 }
161 }
162 Self::TableCall(table_func, args, alias) => {
163 ret.push_str(&table_func.to_sql_string(context));
164 if let Some(args) = args {
165 ret.push(' ');
166 let joined_args = args
167 .iter()
168 .map(|arg| arg.to_sql_string(context))
169 .collect::<Vec<_>>()
170 .join(", ");
171 ret.push_str(&joined_args);
172 }
173 if let Some(alias) = alias {
174 ret.push(' ');
175 ret.push_str(&alias.to_string());
176 }
177 }
178 Self::Select(select, alias) => {
179 ret.push('(');
180 ret.push_str(&select.to_sql_string(context));
181 ret.push(')');
182 if let Some(alias) = alias {
183 ret.push(' ');
184 ret.push_str(&alias.to_string());
185 }
186 }
187 Self::Sub(from_clause, alias) => {
188 ret.push('(');
189 ret.push_str(&from_clause.to_sql_string(context));
190 ret.push(')');
191 if let Some(alias) = alias {
192 ret.push(' ');
193 ret.push_str(&alias.to_string());
194 }
195 }
196 }
197 ret
198 }
199}
200
201impl ToSqlString for ast::With {
202 fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
203 format!(
204 "WITH{} {}",
205 if self.recursive { " RECURSIVE " } else { "" },
206 self.ctes
207 .iter()
208 .map(|cte| cte.to_sql_string(context))
209 .collect::<Vec<_>>()
210 .join(", ")
211 )
212 }
213}
214
215impl ToSqlString for ast::Limit {
216 fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
217 format!(
218 "LIMIT {}{}",
219 self.expr.to_sql_string(context),
220 self.offset
221 .as_ref()
222 .map_or("".to_string(), |offset| format!(
223 " OFFSET {}",
224 offset.to_sql_string(context)
225 ))
226 )
227 }
229}
230
231impl ToSqlString for ast::CommonTableExpr {
232 fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
233 let mut ret = Vec::with_capacity(self.columns.as_ref().map_or(2, |cols| cols.len()));
234 ret.push(self.tbl_name.0.clone());
235 if let Some(cols) = &self.columns {
236 let joined_cols = cols
237 .iter()
238 .map(|col| col.to_string())
239 .collect::<Vec<_>>()
240 .join(", ");
241
242 ret.push(format!("({})", joined_cols));
243 }
244 ret.push(format!(
245 "AS {}({})",
246 {
247 let mut materialized = self.materialized.to_string();
248 if !materialized.is_empty() {
249 materialized.push(' ');
250 }
251 materialized
252 },
253 self.select.to_sql_string(context)
254 ));
255 ret.join(" ")
256 }
257}
258
259impl Display for ast::IndexedColumn {
260 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261 write!(f, "{}", self.col_name.0)
262 }
263}
264
265impl ToSqlString for ast::SortedColumn {
266 fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
267 let mut curr = self.expr.to_sql_string(context);
268 if let Some(sort_order) = self.order {
269 curr.push(' ');
270 curr.push_str(&sort_order.to_string());
271 }
272 if let Some(nulls_order) = self.nulls {
273 curr.push(' ');
274 curr.push_str(&nulls_order.to_string());
275 }
276 curr
277 }
278}
279
280impl Display for ast::SortOrder {
281 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282 self.to_fmt(f)
283 }
284}
285
286impl Display for ast::NullsOrder {
287 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288 self.to_fmt(f)
289 }
290}
291
292impl Display for ast::Materialized {
293 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
294 let value = match self {
295 Self::Any => "",
296 Self::No => "NOT MATERIALIZED",
297 Self::Yes => "MATERIALIZED",
298 };
299 write!(f, "{}", value)
300 }
301}
302
303impl ToSqlString for ast::ResultColumn {
304 fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
305 let mut ret = String::new();
306 match self {
307 Self::Expr(expr, alias) => {
308 ret.push_str(&expr.to_sql_string(context));
309 if let Some(alias) = alias {
310 ret.push(' ');
311 ret.push_str(&alias.to_string());
312 }
313 }
314 Self::Star => {
315 ret.push('*');
316 }
317 Self::TableStar(name) => {
318 ret.push_str(&format!("{}.*", name.0));
319 }
320 }
321 ret
322 }
323}
324
325impl Display for ast::As {
326 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327 write!(
328 f,
329 "{}",
330 match self {
331 Self::As(alias) => {
332 format!("AS {}", alias.0)
333 }
334 Self::Elided(alias) => alias.0.clone(),
335 }
336 )
337 }
338}
339
340impl Display for ast::Indexed {
341 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
342 write!(
343 f,
344 "{}",
345 match self {
346 Self::NotIndexed => "NOT INDEXED".to_string(),
347 Self::IndexedBy(name) => format!("INDEXED BY {}", name.0),
348 }
349 )
350 }
351}
352
353impl Display for ast::JoinOperator {
354 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355 write!(
356 f,
357 "{}",
358 match self {
359 Self::Comma => ",".to_string(),
360 Self::TypedJoin(join) => {
361 let join_keyword = "JOIN";
362 if let Some(join) = join {
363 format!("{} {}", join, join_keyword)
364 } else {
365 join_keyword.to_string()
366 }
367 }
368 }
369 )
370 }
371}
372
373impl Display for ast::JoinType {
374 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
375 let value = {
376 let mut modifiers = Vec::new();
377 if self.contains(Self::NATURAL) {
378 modifiers.push("NATURAL");
379 }
380 if self.contains(Self::LEFT) || self.contains(Self::RIGHT) {
381 if self.contains(Self::LEFT | Self::RIGHT) {
383 modifiers.push("FULL");
384 } else if self.contains(Self::LEFT) {
385 modifiers.push("LEFT");
386 } else if self.contains(Self::RIGHT) {
387 modifiers.push("RIGHT");
388 }
389 }
394
395 if self.contains(Self::INNER) {
396 modifiers.push("INNER");
397 }
398 if self.contains(Self::CROSS) {
399 modifiers.push("CROSS");
400 }
401 modifiers.join(" ")
402 };
403 write!(f, "{}", value)
404 }
405}
406
407impl ToSqlString for ast::JoinConstraint {
408 fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
409 match self {
410 Self::On(expr) => {
411 format!("ON {}", expr.to_sql_string(context))
412 }
413 Self::Using(col_names) => {
414 let joined_names = col_names
415 .iter()
416 .map(|col| col.0.clone())
417 .collect::<Vec<_>>()
418 .join(",");
419 format!("USING ({})", joined_names)
420 }
421 }
422 }
423}
424
425impl ToSqlString for ast::GroupBy {
426 fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
427 let mut ret = String::from("GROUP BY ");
428 let curr = self
429 .exprs
430 .iter()
431 .map(|expr| expr.to_sql_string(context))
432 .collect::<Vec<_>>()
433 .join(",");
434 ret.push_str(&curr);
435 if let Some(having) = &self.having {
436 ret.push_str(&format!(" HAVING {}", having.to_sql_string(context)));
437 }
438 ret
439 }
440}
441
442impl ToSqlString for ast::WindowDef {
443 fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
444 format!("{} AS {}", self.name.0, self.window.to_sql_string(context))
445 }
446}
447
448impl ToSqlString for ast::Window {
449 fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
450 let mut ret = Vec::new();
451 if let Some(name) = &self.base {
452 ret.push(name.0.clone());
453 }
454 if let Some(partition) = &self.partition_by {
455 let joined_exprs = partition
456 .iter()
457 .map(|e| e.to_sql_string(context))
458 .collect::<Vec<_>>()
459 .join(",");
460 ret.push(format!("PARTITION BY {}", joined_exprs));
461 }
462 if let Some(order_by) = &self.order_by {
463 let joined_cols = order_by
464 .iter()
465 .map(|col| col.to_sql_string(context))
466 .collect::<Vec<_>>()
467 .join(", ");
468 ret.push(format!("ORDER BY {}", joined_cols));
469 }
470 if let Some(frame_claue) = &self.frame_clause {
471 ret.push(frame_claue.to_sql_string(context));
472 }
473 format!("({})", ret.join(" "))
474 }
475}
476
477impl ToSqlString for ast::FrameClause {
478 fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
479 let mut ret = Vec::new();
480 ret.push(self.mode.to_string());
481 let start_sql = self.start.to_sql_string(context);
482 if let Some(end) = &self.end {
483 ret.push(format!(
484 "BETWEEN {} AND {}",
485 start_sql,
486 end.to_sql_string(context)
487 ));
488 } else {
489 ret.push(start_sql);
490 }
491 if let Some(exclude) = &self.exclude {
492 ret.push(exclude.to_string());
493 }
494
495 ret.join(" ")
496 }
497}
498
499impl Display for ast::FrameMode {
500 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
501 self.to_fmt(f)
502 }
503}
504
505impl ToSqlString for ast::FrameBound {
506 fn to_sql_string<C: ToSqlContext>(&self, context: &C) -> String {
507 match self {
508 Self::CurrentRow => "CURRENT ROW".to_string(),
509 Self::Following(expr) => format!("{} FOLLOWING", expr.to_sql_string(context)),
510 Self::Preceding(expr) => format!("{} PRECEDING", expr.to_sql_string(context)),
511 Self::UnboundedFollowing => "UNBOUNDED FOLLOWING".to_string(),
512 Self::UnboundedPreceding => "UNBOUNDED PRECEDING".to_string(),
513 }
514 }
515}
516
517impl Display for ast::FrameExclude {
518 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
519 write!(f, "{}", {
520 let clause = match self {
521 Self::CurrentRow => "CURRENT ROW",
522 Self::Group => "GROUP",
523 Self::NoOthers => "NO OTHERS",
524 Self::Ties => "TIES",
525 };
526 format!("EXCLUDE {}", clause)
527 })
528 }
529}
530
#[cfg(test)]
mod tests {
    // Each case hands one SQL statement to `to_sql_string_test!`, which is defined
    // elsewhere in the crate.
    // NOTE(review): presumably the macro parses the statement, renders it back with
    // `to_sql_string`, and compares the result. Confirm against the macro definition.
    // Cases that pass `ignore = "..."` carry the reason they are currently disabled.
    use crate::to_sql_string_test;

    to_sql_string_test!(test_select_basic, "SELECT 1;");

    to_sql_string_test!(test_select_table, "SELECT * FROM t;");

    to_sql_string_test!(test_select_table_2, "SELECT a FROM t;");

    to_sql_string_test!(test_select_multiple_columns, "SELECT a, b, c FROM t;");

    to_sql_string_test!(test_select_with_alias, "SELECT a AS col1 FROM t;");

    to_sql_string_test!(test_select_with_table_alias, "SELECT t1.a FROM t AS t1;");

    to_sql_string_test!(test_select_with_where, "SELECT a FROM t WHERE b = 1;");

    to_sql_string_test!(
        test_select_with_multiple_conditions,
        "SELECT a FROM t WHERE b = 1 AND c > 2;"
    );

    to_sql_string_test!(
        test_select_with_order_by,
        "SELECT a FROM t ORDER BY a DESC;"
    );

    to_sql_string_test!(test_select_with_limit, "SELECT a FROM t LIMIT 10;");

    to_sql_string_test!(
        test_select_with_offset,
        "SELECT a FROM t LIMIT 10 OFFSET 5;"
    );

    to_sql_string_test!(
        test_select_with_join,
        "SELECT a FROM t JOIN t2 ON t.b = t2.b;"
    );

    to_sql_string_test!(
        test_select_with_group_by,
        "SELECT a, COUNT(*) FROM t GROUP BY a;"
    );

    to_sql_string_test!(
        test_select_with_having,
        "SELECT a, COUNT(*) FROM t GROUP BY a HAVING COUNT(*) > 1;"
    );

    to_sql_string_test!(test_select_with_distinct, "SELECT DISTINCT a FROM t;");

    to_sql_string_test!(test_select_with_function, "SELECT COUNT(a) FROM t;");

    to_sql_string_test!(
        test_select_with_subquery,
        "SELECT a FROM (SELECT b FROM t) AS sub;"
    );

    to_sql_string_test!(
        test_select_nested_subquery,
        "SELECT a FROM (SELECT b FROM (SELECT c FROM t WHERE c > 10) AS sub1 WHERE b < 20) AS sub2;"
    );

    to_sql_string_test!(
        test_select_multiple_joins,
        "SELECT t1.a, t2.b, t3.c FROM t1 JOIN t2 ON t1.id = t2.id LEFT JOIN t3 ON t2.id = t3.id;"
    );

    to_sql_string_test!(
        test_select_with_cte,
        "WITH cte AS (SELECT a FROM t WHERE b = 1) SELECT a FROM cte WHERE a > 10;"
    );

    to_sql_string_test!(
        test_select_with_window_function,
        "SELECT a, ROW_NUMBER() OVER (PARTITION BY b ORDER BY c DESC) AS rn FROM t;"
    );

    to_sql_string_test!(
        test_select_with_complex_where,
        "SELECT a FROM t WHERE b IN (1, 2, 3) AND c BETWEEN 10 AND 20 OR d IS NULL;"
    );

    to_sql_string_test!(
        test_select_with_case,
        "SELECT CASE WHEN a > 0 THEN 'positive' ELSE 'non-positive' END AS result FROM t;"
    );

    to_sql_string_test!(test_select_with_aggregate_and_join, "SELECT t1.a, COUNT(t2.b) FROM t1 LEFT JOIN t2 ON t1.id = t2.id GROUP BY t1.a HAVING COUNT(t2.b) > 5;");

    to_sql_string_test!(test_select_with_multiple_ctes, "WITH cte1 AS (SELECT a FROM t WHERE b = 1), cte2 AS (SELECT c FROM t2 WHERE d = 2) SELECT cte1.a, cte2.c FROM cte1 JOIN cte2 ON cte1.a = cte2.c;");

    to_sql_string_test!(
        test_select_with_union,
        "SELECT a FROM t1 UNION SELECT b FROM t2;"
    );

    to_sql_string_test!(
        test_select_with_union_all,
        "SELECT a FROM t1 UNION ALL SELECT b FROM t2;"
    );

    to_sql_string_test!(
        test_select_with_exists,
        "SELECT a FROM t WHERE EXISTS (SELECT 1 FROM t2 WHERE t2.b = t.a);"
    );

    to_sql_string_test!(
        test_select_with_correlated_subquery,
        "SELECT a, (SELECT COUNT(*) FROM t2 WHERE t2.b = t.a) AS count_b FROM t;"
    );

    to_sql_string_test!(
        test_select_with_complex_order_by,
        "SELECT a, b FROM t ORDER BY CASE WHEN a IS NULL THEN 1 ELSE 0 END, b ASC, c DESC;"
    );

    to_sql_string_test!(
        test_select_with_full_outer_join,
        "SELECT t1.a, t2.b FROM t1 FULL OUTER JOIN t2 ON t1.id = t2.id;",
        ignore = "OUTER JOIN is incorrectly parsed in parser"
    );

    to_sql_string_test!(test_select_with_aggregate_window, "SELECT a, SUM(b) OVER (PARTITION BY c ORDER BY d ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS running_sum FROM t;");

    to_sql_string_test!(
        test_select_with_exclude,
        "SELECT
 c.name,
 o.order_id,
 o.order_amount,
 SUM(o.order_amount) OVER (PARTITION BY c.id
 ORDER BY o.order_date
 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
 EXCLUDE CURRENT ROW) AS running_total_excluding_current
FROM customers c
JOIN orders o ON c.id = o.customer_id
WHERE EXISTS (SELECT 1
 FROM orders o2
 WHERE o2.customer_id = c.id
 AND o2.order_amount > 1000);"
    );
}
674}